add support for nanollava model
eaidova committed Nov 5, 2024
1 parent 54a9727 commit f5e8c91
Showing 6 changed files with 200 additions and 13 deletions.
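For orientation, a minimal usage sketch of what this commit enables. The `OVModelForVisualCausalLM` entry point and the tiny test checkpoint `katuni4ka/tiny-random-nanollava` both appear in the files below; everything else (image handling, prompt preparation) is an assumption rather than part of this diff.

# Hedged sketch: exporting and loading a nanollava (llava-qwen2) checkpoint.
# trust_remote_code is required because nanollava ships custom modeling code
# (see REMOTE_CODE_MODELS in tests/openvino/test_modeling.py below).
from optimum.intel import OVModelForVisualCausalLM

model = OVModelForVisualCausalLM.from_pretrained(
    "katuni4ka/tiny-random-nanollava",  # tiny test checkpoint from utils_tests.py
    export=True,
    trust_remote_code=True,
)
# Downstream preprocessing and generation follow the model's own processor and
# chat template; tests/openvino/test_modeling.py exercises the full flow.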
162 changes: 161 additions & 1 deletion optimum/exporters/openvino/model_configs.py
@@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from packaging import version
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from transformers.utils import is_tf_available

from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
@@ -75,6 +75,7 @@
JaisModelPatcher,
LlamaModelPatcher,
LlavaImageEmbeddingModelPatcher,
LlavaQwen2ImageEmbeddingsModelPatcher,
MiniCPMVImageEmbeddingsModelPatcher,
MiniCPMVResamplerModelPatcher,
MistralModelPatcher,
@@ -1579,6 +1580,165 @@ def patch_model_for_export(
return InternVLChatImageEmbeddingModelPatcher(self, model, model_kwargs)


@register_in_tasks_manager(
"llava-qwen2", *["image-text-to-text", "text-generation", "text-generation-with-past"], library_name="transformers"
)
class LlavaQwen2OpenVINOConfig(OnnxConfig):
SUPPORTS_PAST = True
MIN_TRANSFORMERS_VERSION = version.parse("4.40.0")
SUPPORTED_BEHAVIORS = [model_type.value for model_type in LlavaConfigBehavior]
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)

def __init__(
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
behavior: LlavaConfigBehavior = LlavaConfigBehavior.VISION_EMBEDDINGS,
preprocessors: Optional[List[Any]] = None,
use_past: bool = False,
):
self._behavior = behavior
self._orig_config = config
if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
config = AutoConfig.from_pretrained(config.mm_vision_tower, trust_remote_code=True)
if hasattr(config, "vision_config"):
config = config.vision_config
super().__init__(
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if not self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
return {}
return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if not self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
return {}
return {"last_hidden_state": {0: "batch_size"}}

def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]):
if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
behavior = LlavaConfigBehavior(behavior)

if behavior == LlavaConfigBehavior.LANGUAGE:
model.forward = super(type(model), model).forward
return model

if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
return model

if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
text_embedding = model.model.embed_tokens
text_embedding.config = model.model.config
return text_embedding

def with_behavior(
self,
behavior: Union[str, LlavaConfigBehavior],
):
"""
Creates a config for a different behavior.
Args:
behavior ([`LlavaConfigBehavior`]):
The behavior to use for the new instance.
"""
if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
behavior = LlavaConfigBehavior(behavior)

if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
model_type = self._orig_config.model_type.replace("llava-", "")
model_type = model_type.replace("_", "-")
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
raise ValueError(
f"Unsupported language model type provided `{model_type}`. Please define custom export config"
)

if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]:
raise ValueError(
f"Export config for text generation for `{model_type}` is not available. Please define custom export config"
)
internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][
"text-generation-with-past"
]
internal_export_config = internal_export_config_class(
self._orig_config,
use_past=True,
use_past_in_inputs=True,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
)
InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS
export_config = InputEmbedOpenvVINOConfig(
self._orig_config,
task="feature-extraction",
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
)
return export_config

if behavior == LlavaConfigBehavior.LANGUAGE:
model_type = self._orig_config.model_type.replace("llava-", "")
model_type = model_type.replace("_", "-")

if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
raise ValueError(
f"Unsupported language model type provided `{model_type}`. Please define custom export config"
)

if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]:
raise ValueError(
f"Export config for text generation for `{model_type}` is not available. Please define custom export config"
)
internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][
"text-generation-with-past"
]
internal_export_config = internal_export_config_class(
self._orig_config,
use_past=True,
use_past_in_inputs=True,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
)
export_config = LMInputEmbedsConfigHelper(internal_export_config)
export_config._normalized_config = internal_export_config._normalized_config
return export_config

if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
return self.__class__(
self._orig_config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
behavior=behavior,
preprocessors=self._preprocessors,
)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
):
model_kwargs = model_kwargs or {}
if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
return super().patch_model_for_export(model, model_kwargs)
return LlavaQwen2ImageEmbeddingsModelPatcher(self, model, model_kwargs)

def rename_ambiguous_inputs(self, inputs):
if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
model_inputs = {}
model_inputs["images"] = inputs["pixel_values"]
return model_inputs
return super().rename_ambiguous_inputs(inputs)


class PooledProjectionsDummyInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ["pooled_projections"]

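Worth calling out in the config above: the vision part of nanollava is traced through encode_images(), whose tensor argument is named images, so rename_ambiguous_inputs() maps the exporter's canonical pixel_values input onto that name. A tiny standalone illustration (not the exporter's own code; plain strings stand in for real tensors):

# Mirrors LlavaQwen2OpenVINOConfig.rename_ambiguous_inputs for the
# VISION_EMBEDDINGS behavior.
def rename_for_vision_tower(inputs: dict) -> dict:
    return {"images": inputs["pixel_values"]}

assert rename_for_vision_tower({"pixel_values": "dummy"}) == {"images": "dummy"}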
18 changes: 18 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -2973,3 +2973,21 @@ def __exit__(self, exc_type, exc_value, traceback):
if is_torch_version(">=", "2.0.0"):
for layer in self._model.encoder.layers:
layer.self_attn.forward = layer.self_attn._orig_forward


class LlavaQwen2ImageEmbeddingsModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = model.encode_images
super().__init__(config, model, model_kwargs)
if not self._model.get_vision_tower().is_loaded:
self._model.get_vision_tower().load_model()

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
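The patcher above follows the usual swap-and-restore pattern: while active, the vision submodel's forward is routed to encode_images (and the vision tower is force-loaded); on exit the original forward is restored. A generic, self-contained sketch of that pattern, not the actual ModelPatcher API:

from contextlib import contextmanager

@contextmanager
def forward_as_encode_images(model):
    # swap forward -> encode_images for the duration of the export trace
    orig_forward = model.forward
    model.forward = model.encode_images
    try:
        yield model
    finally:
        model.forward = orig_forward

# toy stand-in for the nanollava vision path
class FakeVisionModel:
    def forward(self, images):
        return "full multimodal forward"
    def encode_images(self, images):
        return "image embeddings"

m = FakeVisionModel()
with forward_as_encode_images(m):
    print(m.forward(None))  # image embeddings
print(m.forward(None))      # full multimodal forward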
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/utils.py
@@ -208,4 +208,4 @@ def get_submodels(model):
return custom_export, fn_get_submodels


MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "internvl-chat", "minicpmv"]
MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv"]
25 changes: 17 additions & 8 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -14,7 +14,7 @@
from transformers.modeling_outputs import BaseModelOutputWithPooling

from ...exporters.openvino import main_export
from ...exporters.openvino.stateful import ensure_stateful_is_available
from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel, OVModelPart
from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM
@@ -122,8 +122,8 @@ def prepare_inputs(
else:
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if past_len:
position_ids = position_ids[:, -inputs_embeds.shape[1] :]

inputs["position_ids"] = position_ids

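The hunk above switches the position_ids slicing from input_ids to inputs_embeds (and gates it on past_len), since with multimodal inputs the embedded sequence, not the token ids, reflects how many new positions are fed. A runnable illustration with a left-padded mask, a populated cache, and one new embedded token (values are illustrative):

import numpy as np

attention_mask = np.array([[0, 0, 1, 1, 1]])   # 2 padding + 3 real positions
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
# position_ids == [[1, 1, 0, 1, 2]]

past_len = 2          # cache already holds processed positions
new_embeds_len = 1    # stands in for inputs_embeds.shape[1]
if past_len:
    position_ids = position_ids[:, -new_embeds_len:]
print(position_ids)   # [[2]] -> position assigned to the single new token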
@@ -177,9 +177,11 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
self.hidden_states_output_names = [
key.get_any_name() for key in self.model.outputs[2:] if "hidden_states" in key.get_any_name()
]
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
self._main_input = "images" if model_has_input_output_name(self.model, "images") else "pixel_values"

def forward(self, pixel_values, **kwargs):
inputs = {"pixel_values": pixel_values}
inputs = {self._main_input: pixel_values}
if len(self.input_names) > 1:
for name in self.input_names:
if name in kwargs:
@@ -568,16 +570,19 @@ def half(self):
def forward(
self,
input_ids,
pixel_values,
pixel_values=None,
past_key_values=None,
inputs_embeds=None,
image_sizes=None,
attention_mask=None,
position_ids=None,
image_bound=None,
tgt_sizes=None,
images=None,
**kwargs,
):
if pixel_values is None and images is not None:
pixel_values = images
inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings(
input_ids,
pixel_values,
@@ -629,6 +634,7 @@ def get_multimodal_embeddings(
)
return inputs_embeds, attention_mask, position_ids

# Adopted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llava/modeling_llava.py#L521
def prepare_inputs_for_generation(
self,
input_ids,
@@ -646,14 +652,15 @@ def prepare_inputs_for_generation(
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
if attention_mask is not None and past_length + 1 > input_ids.shape[1]:
input_discount = max(attention_mask.shape[1] - past_length, 1)
input_ids = input_ids[:, -input_discount:]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif getattr(self.config, "image_token_index", None) in input_ids:
elif getattr(self.config, "image_token_index", -1) in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]

position_ids = kwargs.get("position_ids", None)
Expand All @@ -679,6 +686,7 @@ def prepare_inputs_for_generation(
"image_sizes": image_sizes,
"image_bound": kwargs.get("image_bound"),
"tgt_sizes": kwargs.get("tgt_sizes"),
"images": kwargs.get("images"),
}
)
return model_inputs
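The slicing logic above distinguishes three situations: inputs partly served from the cache (e.g. image features passed as inputs_embeds), a cache shorter than input_ids, and a fully cached prompt that still contains the image token. A runnable walk-through of the first two cases with illustrative values:

import numpy as np

# Case 1: attention_mask covers more positions than input_ids because earlier
# steps were fed as inputs_embeds; keep only the tokens the cache has not seen.
input_ids = np.array([[11, 12]])
attention_mask = np.ones((1, 6), dtype=np.int64)
past_length = 5
if attention_mask is not None and past_length + 1 > input_ids.shape[1]:
    input_discount = max(attention_mask.shape[1] - past_length, 1)
    input_ids = input_ids[:, -input_discount:]
print(input_ids)  # [[12]]

# Case 2: the cache is shorter than input_ids, so input_ids still holds the
# whole prompt; drop the prefix that is already cached.
input_ids = np.array([[1, 2, 3, 4]])
past_length = 2
if past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]
print(input_ids)  # [[3 4]]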
@@ -1546,4 +1554,5 @@ def get_multimodal_embeddings(
"llava_next": _OVLlavaNextForCausalLM,
"internvl_chat": _OvInternVLForCausalLM,
"minicpmv": _OVMiniCPMVForCausalLM,
"llava-qwen2": _OVNanoLlavaForCausalLM,
}
5 changes: 2 additions & 3 deletions tests/openvino/test_modeling.py
@@ -1879,12 +1879,11 @@ def test_compare_with_and_without_past_key_values(self):
class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ["llava"]

REMOTE_CODE_MODELS = ["minicpmv"]

if is_transformers_version(">=", "4.40.0"):
SUPPORTED_ARCHITECTURES += ["llava_next"]
SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"]
if is_transformers_version(">=", "4.45.0"):
SUPPORTED_ARCHITECTURES += ["minicpmv"]
REMOTE_CODE_MODELS = ["minicpmv", "nanollava"]
TASK = "image-text-to-text"

IMAGE = Image.open(
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
@@ -95,6 +95,7 @@
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"mt5": "stas/mt5-tiny-random",
"nanollava": "katuni4ka/tiny-random-nanollava",
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"olmo": "katuni4ka/tiny-random-olmo-hf",
"orion": "katuni4ka/tiny-random-orion",