From f5e8c9158e02e50fd1072af3158965f329ffd151 Mon Sep 17 00:00:00 2001
From: eaidova
Date: Thu, 24 Oct 2024 15:57:43 +0400
Subject: [PATCH] add support of nanollava model

---
 optimum/exporters/openvino/model_configs.py   | 162 +++++++++++++++++-
 optimum/exporters/openvino/model_patcher.py   |  18 ++
 optimum/exporters/openvino/utils.py           |   2 +-
 .../openvino/modeling_visual_language.py      |  25 ++-
 tests/openvino/test_modeling.py               |   5 +-
 tests/openvino/utils_tests.py                 |   1 +
 6 files changed, 200 insertions(+), 13 deletions(-)

diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 108deed57..9dbcacb7f 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from packaging import version
-from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
 from transformers.utils import is_tf_available
 
 from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
@@ -75,6 +75,7 @@
     JaisModelPatcher,
     LlamaModelPatcher,
     LlavaImageEmbeddingModelPatcher,
+    LlavaQwen2ImageEmbeddingsModelPatcher,
     MiniCPMVImageEmbeddingsModelPatcher,
     MiniCPMVResamplerModelPatcher,
     MistralModelPatcher,
@@ -1579,6 +1580,165 @@ def patch_model_for_export(
         return InternVLChatImageEmbeddingModelPatcher(self, model, model_kwargs)
 
 
+@register_in_tasks_manager(
+    "llava-qwen2", *["image-text-to-text", "text-generation", "text-generation-with-past"], library_name="transformers"
+)
+class LlavaQwen2OpenVINOConfig(OnnxConfig):
+    SUPPORTS_PAST = True
+    MIN_TRANSFORMERS_VERSION = version.parse("4.40.0")
+    SUPPORTED_BEHAVIORS = [model_type.value for model_type in LlavaConfigBehavior]
+    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)
+
+    def __init__(
+        self,
+        config: "PretrainedConfig",
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        behavior: LlavaConfigBehavior = LlavaConfigBehavior.VISION_EMBEDDINGS,
+        preprocessors: Optional[List[Any]] = None,
+        use_past: bool = False,
+    ):
+        self._behavior = behavior
+        self._orig_config = config
+        if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            config = AutoConfig.from_pretrained(config.mm_vision_tower, trust_remote_code=True)
+            if hasattr(config, "vision_config"):
+                config = config.vision_config
+        super().__init__(
+            config=config,
+            task=task,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            preprocessors=preprocessors,
+        )
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        if not self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return {}
+        return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}}
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        if not self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return {}
+        return {"last_hidden_state": {0: "batch_size"}}
+
+    def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]):
+        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
+            behavior = LlavaConfigBehavior(behavior)
+
+        if behavior == LlavaConfigBehavior.LANGUAGE:
+            model.forward = super(type(model), model).forward
+            return model
+
+        if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return model
+
+        if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
+            text_embedding = model.model.embed_tokens
+            text_embedding.config = model.model.config
+            return text_embedding
+
+    def with_behavior(
+        self,
+        behavior: Union[str, LlavaConfigBehavior],
+    ):
+        """
+        Creates a config for different behaviour.
+        Args:
+            behavior ([`ConfigBehavior`]):
+                The behavior to use for the new instance.
+        """
+        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
+            behavior = LlavaConfigBehavior(behavior)
+
+        if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
+            model_type = self._orig_config.model_type.replace("llava-", "")
+            model_type = model_type.replace("_", "-")
+            if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
+                raise ValueError(
+                    f"Unsupported language model type provided `{model_type}`. Please define custom export config"
+                )
+
+            if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]:
+                raise ValueError(
+                    f"Export config for text generation for `{model_type}` is not available. Please define custom export config"
+                )
+            internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][
+                "text-generation-with-past"
+            ]
+            internal_export_config = internal_export_config_class(
+                self._orig_config,
+                use_past=True,
+                use_past_in_inputs=True,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+            )
+            InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS
+            export_config = InputEmbedOpenvVINOConfig(
+                self._orig_config,
+                task="feature-extraction",
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+            )
+            return export_config
+
+        if behavior == LlavaConfigBehavior.LANGUAGE:
+            model_type = self._orig_config.model_type.replace("llava-", "")
+            model_type = model_type.replace("_", "-")
+
+            if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
+                raise ValueError(
+                    f"Unsupported language model type provided `{model_type}`. Please define custom export config"
+                )
+
+            if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]:
+                raise ValueError(
+                    f"Export config for text generation for `{model_type}` is not available. Please define custom export config"
+                )
+            internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][
+                "text-generation-with-past"
+            ]
+            internal_export_config = internal_export_config_class(
+                self._orig_config,
+                use_past=True,
+                use_past_in_inputs=True,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+            )
+            export_config = LMInputEmbedsConfigHelper(internal_export_config)
+            export_config._normalized_config = internal_export_config._normalized_config
+            return export_config
+
+        if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return self.__class__(
+                self._orig_config,
+                task=self.task,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+                behavior=behavior,
+                preprocessors=self._preprocessors,
+            )
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        model_kwargs = model_kwargs or {}
+        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return super().patch_model_for_export(model, model_kwargs)
+        return LlavaQwen2ImageEmbeddingsModelPatcher(self, model, model_kwargs)
+
+    def rename_ambiguous_inputs(self, inputs):
+        if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            model_inputs = {}
+            model_inputs["images"] = inputs["pixel_values"]
+            return model_inputs
+        return super().rename_ambiguous_inputs(inputs)
+
+
 class PooledProjectionsDummyInputGenerator(DummyInputGenerator):
     SUPPORTED_INPUT_NAMES = ["pooled_projections"]
 
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index b1aa7eaa9..8507d94fe 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -2973,3 +2973,21 @@ def __exit__(self, exc_type, exc_value, traceback):
         if is_torch_version(">=", "2.0.0"):
             for layer in self._model.encoder.layers:
                 layer.self_attn.forward = layer.self_attn._orig_forward
+
+
+class LlavaQwen2ImageEmbeddingsModelPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any],
+    ):
+        model.__orig_forward = model.forward
+        model.forward = model.encode_images
+        super().__init__(config, model, model_kwargs)
+        if not self._model.get_vision_tower().is_loaded:
+            self._model.get_vision_tower().load_model()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py
index 35e0c3017..9286a37f7 100644
--- a/optimum/exporters/openvino/utils.py
+++ b/optimum/exporters/openvino/utils.py
@@ -208,4 +208,4 @@ def get_submodels(model):
     return custom_export, fn_get_submodels
 
 
-MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "internvl-chat", "minicpmv"]
+MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv"]
diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py
index b071602d9..74d7c88d6 100644
--- a/optimum/intel/openvino/modeling_visual_language.py
+++ b/optimum/intel/openvino/modeling_visual_language.py
@@ -14,7 +14,7 @@
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 
 from ...exporters.openvino import main_export
-from ...exporters.openvino.stateful import ensure_stateful_is_available
+from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name
 from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel, OVModelPart
 from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM
@@ -122,8 +122,8 @@ def prepare_inputs(
             else:
                 position_ids = np.cumsum(attention_mask, axis=1) - 1
                 position_ids[attention_mask == 0] = 1
-                if past_key_values:
-                    position_ids = position_ids[:, -input_ids.shape[1] :]
+                if past_len:
+                    position_ids = position_ids[:, -inputs_embeds.shape[1] :]
 
             inputs["position_ids"] = position_ids
 
@@ -177,9 +177,11 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
         self.hidden_states_output_names = [
             key.get_any_name() for key in self.model.outputs[2:] if "hidden_states" in key.get_any_name()
         ]
+        self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
+        self._main_input = "images" if model_has_input_output_name(self.model, "images") else "pixel_values"
 
     def forward(self, pixel_values, **kwargs):
-        inputs = {"pixel_values": pixel_values}
+        inputs = {self._main_input: pixel_values}
         if len(self.input_names) > 1:
             for name in self.input_names:
                 if name in kwargs:
@@ -568,7 +570,7 @@ def half(self):
     def forward(
         self,
         input_ids,
-        pixel_values,
+        pixel_values=None,
         past_key_values=None,
         inputs_embeds=None,
         image_sizes=None,
@@ -576,8 +578,11 @@
         position_ids=None,
         image_bound=None,
         tgt_sizes=None,
+        images=None,
         **kwargs,
     ):
+        if pixel_values is None and images is not None:
+            pixel_values = images
         inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings(
             input_ids,
             pixel_values,
@@ -629,6 +634,7 @@ def get_multimodal_embeddings(
         )
         return inputs_embeds, attention_mask, position_ids
 
+    # Adopted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llava/modeling_llava.py#L521
     def prepare_inputs_for_generation(
         self,
         input_ids,
@@ -646,14 +652,15 @@ def prepare_inputs_for_generation(
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            if attention_mask is not None and past_length + 1 > input_ids.shape[1]:
+                input_discount = max(attention_mask.shape[1] - past_length, 1)
+                input_ids = input_ids[:, -input_discount:]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.llava
             elif past_length < input_ids.shape[1]:
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-            elif getattr(self.config, "image_token_index", None) in input_ids:
+            elif getattr(self.config, "image_token_index", -1) in input_ids:
                 input_ids = input_ids[:, input_ids.shape[1] - 1 :]
 
         position_ids = kwargs.get("position_ids", None)
@@ -679,6 +686,7 @@
                 "image_sizes": image_sizes,
                 "image_bound": kwargs.get("image_bound"),
                 "tgt_sizes": kwargs.get("tgt_sizes"),
+                "images": kwargs.get("images"),
             }
         )
         return model_inputs
@@ -1546,4 +1554,5 @@ def get_multimodal_embeddings(
     "llava_next": _OVLlavaNextForCausalLM,
     "internvl_chat": _OvInternVLForCausalLM,
     "minicpmv": _OVMiniCPMVForCausalLM,
+    "llava-qwen2": _OVNanoLlavaForCausalLM,
 }
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 0dcfaac71..6c68438c7 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -1879,12 +1879,11 @@ def test_compare_with_and_without_past_key_values(self):
 
 class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = ["llava"]
-    REMOTE_CODE_MODELS = ["minicpmv"]
-
     if is_transformers_version(">=", "4.40.0"):
-        SUPPORTED_ARCHITECTURES += ["llava_next"]
+        SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"]
     if is_transformers_version(">=", "4.45.0"):
         SUPPORTED_ARCHITECTURES += ["minicpmv"]
+    REMOTE_CODE_MODELS = ["minicpmv", "nanollava"]
     TASK = "image-text-to-text"
 
     IMAGE = Image.open(
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index ec0ca3981..f062ded11 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -95,6 +95,7 @@
     "mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
     "mpnet": "hf-internal-testing/tiny-random-MPNetModel",
     "mt5": "stas/mt5-tiny-random",
+    "nanollava": "katuni4ka/tiny-random-nanollava",
     "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
     "olmo": "katuni4ka/tiny-random-olmo-hf",
     "orion": "katuni4ka/tiny-random-orion",
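
Usage note (outside the diff): a minimal sketch of how the new llava-qwen2 path is expected to be exercised.
The checkpoint id "qnguyen3/nanoLLaVA", the output directory and the generate() call shown in the comments are
illustrative assumptions rather than something this patch pins down; nanollava ships custom modeling code, so
trust_remote_code=True is required, and the exact prompt/image preprocessing should follow the checkpoint's
model card.

    # Illustrative sketch only, assuming the public nanollava checkpoint "qnguyen3/nanoLLaVA".
    from transformers import AutoTokenizer
    from optimum.intel import OVModelForVisualCausalLM

    model_id = "qnguyen3/nanoLLaVA"  # assumed checkpoint; any llava-qwen2 model should take the same export path
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # export=True converts the PyTorch checkpoint on the fly through the llava-qwen2 export configs added above;
    # trust_remote_code=True is needed because nanollava relies on remote modeling code.
    ov_model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
    ov_model.save_pretrained("nanollava_ov")  # the converted model can be reloaded later without export=True

    # At generation time the wrapper accepts the image tensor either as pixel_values or via the nanollava-style
    # "images" alias added in this patch, e.g.:
    #   outputs = ov_model.generate(input_ids=input_ids, images=pixel_values, max_new_tokens=32)
    # where input_ids and pixel_values come from the checkpoint's own preprocessing (see its model card).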