diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index ee61563c9..dba4628d7 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -45,6 +45,8 @@ from .utils import _MAX_UNCOMPRESSED_SIZE, MULTI_MODAL_TEXT_GENERATION_MODELS, clear_class_registry +FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"} + if TYPE_CHECKING: from optimum.intel.openvino.configuration import OVConfig @@ -264,6 +266,10 @@ def main_export( if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED: loading_kwargs["attn_implementation"] = "eager" + + # some models force flash_attn attention by default that does not support load model on cpu + if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES: + loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type] # there are some difference between remote and in library representation of past key values for some models, # for avoiding confusion we disable remote code for them if ( diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index fdee8a3ef..a84ecfabd 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -712,7 +712,18 @@ def export_from_model( ) model_name_or_path = model.config._name_or_path - maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code) + if preprocessors is not None: + # phi3-vision processor does not have chat_template attribute that breaks Processor saving on disk + if is_transformers_version(">=", "4.45") and model_type == "phi3-v" and len(preprocessors) > 1: + if not hasattr(preprocessors[1], "chat_template"): + preprocessors[1].chat_template = getattr(preprocessors[0], "chat_template", None) + for processor in preprocessors: + try: + processor.save_pretrained(output) + except Exception as ex: + logger.error(f"Saving {type(processor)} failed with {ex}") + else: + maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code) files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()] @@ -891,6 +902,10 @@ def _get_multi_modal_submodels_and_export_configs( if model_type == "internvl-chat" and preprocessors is not None: model.config.img_context_token_id = preprocessors[0].convert_tokens_to_ids("") + if model_type == "phi3-v": + model.config.glb_GN = model.model.vision_embed_tokens.glb_GN.tolist() + model.config.sub_GN = model.model.vision_embed_tokens.sub_GN.tolist() + if hasattr(model, "image_newline"): model.config.image_newline = model.image_newline.tolist() main_config_cls = TasksManager.get_exporter_config_constructor( diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 876672db4..b8310882b 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -86,6 +86,7 @@ MPTModelPatcher, PersimmonModelPatcher, Phi3ModelPatcher, + Phi3VisionImageEmbeddingsPatcher, QwenModelPatcher, RotaryEmbPatcher, UpdateCausalMaskModelPatcher, @@ -1292,6 +1293,48 @@ def patch_model_for_export( return InputEmbeddingPatcher(self, model, model_kwargs) +def get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype): + model_type = model_type.replace("_", "-") + + if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: + raise ValueError( + f"Unsupported language model type provided 
`{model_type}`. Please define custom export config" + ) + + if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: + raise ValueError( + f"Export config for text generation for `{model_type}` is not available. Please define custom export config" + ) + export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]["text-generation-with-past"] + export_config = export_config_class( + model_config, + use_past=True, + use_past_in_inputs=True, + int_dtype=int_dtype, + float_dtype=float_dtype, + ) + return export_config + + +def get_vlm_text_embeddings_config(model_type, model_config, int_dtype, float_dtype): + internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype) + InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS + export_config = InputEmbedOpenvVINOConfig( + model_config, + task="feature-extraction", + int_dtype=int_dtype, + float_dtype=float_dtype, + ) + return export_config + + +def get_vlm_text_generation_config(model_type, model_config, int_dtype, float_dtype): + internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype) + export_config = LMInputEmbedsConfigHelper(internal_export_config) + export_config._normalized_config = internal_export_config._normalized_config + return export_config + + class LlavaConfigBehavior(str, enum.Enum): LANGUAGE = "language" VISION_EMBEDDINGS = "vision_embeddings" @@ -1355,61 +1398,15 @@ def with_behavior( if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS: model_type = self._orig_config.text_config.model_type - model_type = model_type.replace("_", "-") - if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: - raise ValueError( - f"Unsupported language model type provided `{model_type}`. Please define custom export config" - ) - - if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: - raise ValueError( - f"Export config for text generation for `{model_type}` is not available. Please define custom export config" - ) - internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][ - "text-generation-with-past" - ] - internal_export_config = internal_export_config_class( - self._orig_config.text_config, - use_past=True, - use_past_in_inputs=True, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, + return get_vlm_text_embeddings_config( + model_type, self._orig_config.text_config, self.int_dtype, self.float_dtype ) - InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS - export_config = InputEmbedOpenvVINOConfig( - self._orig_config.text_config, - task="feature-extraction", - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - ) - return export_config if behavior == LlavaConfigBehavior.LANGUAGE: model_type = self._orig_config.text_config.model_type - model_type = model_type.replace("_", "-") - - if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: - raise ValueError( - f"Unsupported language model type provided `{model_type}`. Please define custom export config" - ) - - if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: - raise ValueError( - f"Export config for text generation for `{model_type}` is not available. 
Please define custom export config" - ) - internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][ - "text-generation-with-past" - ] - internal_export_config = internal_export_config_class( - self._orig_config.text_config, - use_past=True, - use_past_in_inputs=True, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, + return get_vlm_text_generation_config( + model_type, self._orig_config.text_config, self.int_dtype, self.float_dtype ) - export_config = LMInputEmbedsConfigHelper(internal_export_config) - export_config._normalized_config = internal_export_config._normalized_config - return export_config if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS: return self.__class__( @@ -1517,61 +1514,15 @@ def with_behavior( if behavior == InternVLChatConfigBehavior.TEXT_EMBEDDINGS: model_type = self._orig_config.llm_config.model_type - model_type = model_type.replace("_", "-") - if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: - raise ValueError( - f"Unsupported language model type provided `{model_type}`. Please define custom export config" - ) - - if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: - raise ValueError( - f"Export config for text generation for `{model_type}` is not available. Please define custom export config" - ) - internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][ - "text-generation-with-past" - ] - internal_export_config = internal_export_config_class( - self._orig_config.llm_config, - use_past=True, - use_past_in_inputs=True, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - ) - InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS - export_config = InputEmbedOpenvVINOConfig( - self._orig_config.llm_config, - task="feature-extraction", - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, + return get_vlm_text_embeddings_config( + model_type, self._orig_config.llm_config, self.int_dtype, self.float_dtype ) - return export_config if behavior == InternVLChatConfigBehavior.LANGUAGE: model_type = self._orig_config.llm_config.model_type - model_type = model_type.replace("_", "-") - - if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: - raise ValueError( - f"Unsupported language model type provided `{model_type}`. Please define custom export config" - ) - - if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: - raise ValueError( - f"Export config for text generation for `{model_type}` is not available. 
Please define custom export config" - ) - internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][ - "text-generation-with-past" - ] - internal_export_config = internal_export_config_class( - self._orig_config.llm_config, - use_past=True, - use_past_in_inputs=True, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, + return get_vlm_text_generation_config( + model_type, self._orig_config.llm_config, self.int_dtype, self.float_dtype ) - export_config = LMInputEmbedsConfigHelper(internal_export_config) - export_config._normalized_config = internal_export_config._normalized_config - return export_config if behavior == InternVLChatConfigBehavior.VISION_EMBEDDINGS: return self.__class__( @@ -1583,7 +1534,8 @@ def with_behavior( preprocessors=self._preprocessors, ) - def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]): + @staticmethod + def get_model_for_behavior(model, behavior: Union[str, LlavaConfigBehavior]): if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior): behavior = InternVLChatConfigBehavior(behavior) @@ -1653,7 +1605,8 @@ def outputs(self) -> Dict[str, Dict[int, str]]: return {} return {"last_hidden_state": {0: "batch_size"}} - def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]): + @staticmethod + def get_model_for_behavior(model, behavior: Union[str, LlavaConfigBehavior]): if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior): behavior = LlavaConfigBehavior(behavior) @@ -1684,61 +1637,11 @@ def with_behavior( if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS: model_type = self._orig_config.model_type.replace("llava-", "") - model_type = model_type.replace("_", "-") - if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: - raise ValueError( - f"Unsupported language model type provided `{model_type}`. Please define custom export config" - ) - - if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: - raise ValueError( - f"Export config for text generation for `{model_type}` is not available. Please define custom export config" - ) - internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][ - "text-generation-with-past" - ] - internal_export_config = internal_export_config_class( - self._orig_config, - use_past=True, - use_past_in_inputs=True, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - ) - InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS - export_config = InputEmbedOpenvVINOConfig( - self._orig_config, - task="feature-extraction", - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - ) - return export_config + return get_vlm_text_embeddings_config(model_type, self._orig_config, self.int_dtype, self.float_dtype) if behavior == LlavaConfigBehavior.LANGUAGE: model_type = self._orig_config.model_type.replace("llava-", "") - model_type = model_type.replace("_", "-") - - if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: - raise ValueError( - f"Unsupported language model type provided `{model_type}`. Please define custom export config" - ) - - if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: - raise ValueError( - f"Export config for text generation for `{model_type}` is not available. 
Please define custom export config" - ) - internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][ - "text-generation-with-past" - ] - internal_export_config = internal_export_config_class( - self._orig_config, - use_past=True, - use_past_in_inputs=True, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - ) - export_config = LMInputEmbedsConfigHelper(internal_export_config) - export_config._normalized_config = internal_export_config._normalized_config - return export_config + return get_vlm_text_generation_config(model_type, self._orig_config, self.int_dtype, self.float_dtype) if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS: return self.__class__( @@ -2090,62 +1993,10 @@ def with_behavior( behavior = MiniCPMVConfigBehavior(behavior) if behavior == MiniCPMVConfigBehavior.TEXT_EMBEDDINGS: - model_type = "qwen2" - model_type = model_type.replace("_", "-") - if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: - raise ValueError( - f"Unsupported language model type provided `{model_type}`. Please define custom export config" - ) - - if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: - raise ValueError( - f"Export config for text generation for `{model_type}` is not available. Please define custom export config" - ) - internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][ - "text-generation-with-past" - ] - internal_export_config = internal_export_config_class( - self._orig_config, - use_past=True, - use_past_in_inputs=True, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - ) - InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS - export_config = InputEmbedOpenvVINOConfig( - self._orig_config, - task="feature-extraction", - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - ) - return export_config + return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype) if behavior == MiniCPMVConfigBehavior.LANGUAGE: - model_type = "qwen2" - model_type = model_type.replace("_", "-") - - if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: - raise ValueError( - f"Unsupported language model type provided `{model_type}`. Please define custom export config" - ) - - if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]: - raise ValueError( - f"Export config for text generation for `{model_type}` is not available. 
Please define custom export config" - ) - internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][ - "text-generation-with-past" - ] - internal_export_config = internal_export_config_class( - self._orig_config, - use_past=True, - use_past_in_inputs=True, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - ) - export_config = LMInputEmbedsConfigHelper(internal_export_config) - export_config._normalized_config = internal_export_config._normalized_config - return export_config + return get_vlm_text_generation_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype) if behavior == MiniCPMVConfigBehavior.VISION_EMBEDDINGS: return self.__class__( @@ -2167,7 +2018,8 @@ def with_behavior( preprocessors=self._preprocessors, ) - def get_model_for_behavior(self, model, behavior: Union[str, MiniCPMVConfigBehavior]): + @staticmethod + def get_model_for_behavior(model, behavior: Union[str, MiniCPMVConfigBehavior]): if isinstance(behavior, str) and not isinstance(behavior, MiniCPMVConfigBehavior): behavior = MiniCPMVConfigBehavior(behavior) @@ -2196,3 +2048,159 @@ def patch_model_for_export( return MiniCPMVResamplerModelPatcher(self, model, model_kwargs) return super().patch_model_for_export(model, model_kwargs) + + +class Phi3VisionConfigBehavior(str, enum.Enum): + LANGUAGE = "language" + VISION_PROJECTION = "vision_projection" + VISION_EMBEDDINGS = "vision_embeddings" + TEXT_EMBEDDINGS = "text_embeddings" + + +class DummyPhi3VisionProjectionInputGenerator(DummyVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ("input",) + + def __init__( + self, + task: str, + normalized_config: NormalizedVisionConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], + width: int = 336, + height: int = 336, + **kwargs, + ): + self.batch_size = batch_size + self._embed_layer_realization = normalized_config.config.embd_layer["embedding_cls"] + self.image_dim_out = normalized_config.config.img_processor["image_dim_out"] + self.height = height + self.width = width + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + h = self.height // 336 + w = self.width // 336 + feat_size = (h * w + 1) * 144 + 1 + (h + 1) * 12 + if self._embed_layer_realization == "linear": + shape = [self.batch_size, feat_size, self.image_dim_out] + else: + shape = [self.batch_size, feat_size, self.image_dim_out * 4] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + + +@register_in_tasks_manager("phi3-v", *["image-text-to-text"], library_name="transformers") +class Phi3VisionOpenVINOConfig(OnnxConfig): + SUPPORTED_BEHAVIORS = [model_type.value for model_type in Phi3VisionConfigBehavior] + NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,) + MIN_TRANSFORMERS_VERSION = version.parse("4.40.0") + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + behavior: Phi3VisionConfigBehavior = Phi3VisionConfigBehavior.VISION_EMBEDDINGS, + preprocessors: Optional[List[Any]] = None, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + ) + self._behavior = behavior + self._orig_config = config + if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "img_processor"): + self._config = 
AutoConfig.from_pretrained( + config.img_processor["model_name"], trust_remote_code=True + ).vision_config + self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) + self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,) + if self._behavior == Phi3VisionConfigBehavior.VISION_PROJECTION and hasattr(config, "img_processor"): + self._config = config + self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) + self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyPhi3VisionProjectionInputGenerator,) + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS: + return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}} + if self._behavior == Phi3VisionConfigBehavior.VISION_PROJECTION: + return {"input": {0: "batch_size", 1: "img_feat_size"}} + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior in [Phi3VisionConfigBehavior.VISION_EMBEDDINGS, Phi3VisionConfigBehavior.VISION_PROJECTION]: + return {"last_hidden_state": {0: "batch_size", 1: "height_width_projection"}} + return {} + + def with_behavior( + self, + behavior: Union[str, Phi3VisionConfigBehavior], + ): + """ + Creates a config for different behaviour. + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + """ + if isinstance(behavior, str) and not isinstance(behavior, Phi3VisionConfigBehavior): + behavior = Phi3VisionConfigBehavior(behavior) + + if behavior == Phi3VisionConfigBehavior.TEXT_EMBEDDINGS: + return get_vlm_text_embeddings_config("phi3", self._orig_config, self.int_dtype, self.float_dtype) + + if behavior == Phi3VisionConfigBehavior.LANGUAGE: + return get_vlm_text_generation_config("phi3", self._orig_config, self.int_dtype, self.float_dtype) + + if behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + if behavior == Phi3VisionConfigBehavior.VISION_PROJECTION: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + + @staticmethod + def get_model_for_behavior(model, behavior: Union[str, Phi3VisionConfigBehavior]): + if isinstance(behavior, str) and not isinstance(behavior, Phi3VisionConfigBehavior): + behavior = Phi3VisionConfigBehavior(behavior) + + if behavior == Phi3VisionConfigBehavior.LANGUAGE: + return model + + if behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS: + vision_embeddings = model.model.vision_embed_tokens + vision_embeddings.config = model.config + return vision_embeddings + + if behavior == Phi3VisionConfigBehavior.VISION_PROJECTION: + projection = model.model.vision_embed_tokens.img_projection + projection.config = model.config + return projection + + if behavior == Phi3VisionConfigBehavior.TEXT_EMBEDDINGS: + text_embedding = model.model.embed_tokens + text_embedding.config = model.config + return text_embedding + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ): + model_kwargs = model_kwargs or {} + if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS: + return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs) + return super().patch_model_for_export(model, model_kwargs) diff --git 
a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 7406e1370..58659e637 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -1369,6 +1369,7 @@ def phi3_442_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask @@ -3216,3 +3217,23 @@ def forward(self, input): def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward + + +def phi3_vision_embeddings_forward(self, pixel_values: torch.FloatTensor): + return self.get_img_features(pixel_values) + + +class Phi3VisionImageEmbeddingsPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Dict[str, Any], + ): + model.__orig_forward = model.forward + model.forward = types.MethodType(phi3_vision_embeddings_forward, model) + super().__init__(config, model, model_kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py index 701334209..7fb1bb5f1 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -213,7 +213,7 @@ def get_submodels(model): return custom_export, fn_get_submodels -MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv"] +MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv", "phi3-v"] def save_config(config, save_dir): diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index b7bf96a0d..35d91488d 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -230,7 +230,15 @@ def forward(self, image_feature, pos_embed, key_padding_mask): return result -MODEL_PARTS_CLS_MAPPING = {"resampler": OVResampler} +class OVVisionProjection(OVModelPart): + _model_name = "vision_projection" + + def forward(self, img_features): + self._compile() + return self.request(img_features)[0] + + +MODEL_PARTS_CLS_MAPPING = {"resampler": OVResampler, "vision_projection": OVVisionProjection} class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin): @@ -1802,8 +1810,8 @@ def preprocess_inputs( raise ValueError("Tokenizer is required.") if image is not None and processor is None: raise ValueError("Processor is required.") - text_content = f"\n{text}" if image is not None else text - messages = [{"role": "user", "content": text_content}] + text = f"\n{text}" if image is not None else text + messages = [{"role": "user", "content": text}] if tokenizer.chat_template is not None: text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if image is not None: @@ -1818,10 +1826,161 @@ def preprocess_inputs( return result +class _OVPhi3VisionForCausalLM(OVModelForVisualCausalLM): + additional_parts = ["vision_projection"] + + def __init__( + self, + language_model: ov.Model, + text_embeddings: ov.Model, + vision_embeddings: ov.Model, + 
config: PretrainedConfig = None, + device: str = "CPU", + dynamic_shapes: bool = True, + ov_config: Optional[Dict[str, str]] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, + **kwargs, + ): + super().__init__( + language_model, + text_embeddings, + vision_embeddings, + config, + device, + dynamic_shapes, + ov_config, + model_save_dir, + quantization_config, + **kwargs, + ) + self.sub_GN = torch.tensor(self.config.sub_GN) + self.glb_GN = torch.tensor(self.config.glb_GN) + + def get_vision_embeddings(self, pixel_values, image_sizes, **kwargs): + num_images, num_crops, c, h, w = pixel_values.shape + img_features = self.vision_embeddings(pixel_values.flatten(0, 1)).last_hidden_state.reshape( + num_images, num_crops, -1, self.config.img_processor["image_dim_out"] + ) + image_features_proj = self.hd_feature_transform(img_features, image_sizes) + return image_features_proj + + def hd_feature_transform(self, image_features, image_sizes): + """ + image_features: (num_images, num_crops+1, 24*24, 1024) + """ + + image_features = torch.from_numpy(image_features) + global_image_features = image_features[:, 0] # (num_images, 24*24, 1024) + # global feature can be viewed as a special HD case with num_crops 1x1 + global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1) + global_image_features_hd_newline = self.add_image_newline(global_image_features_hd) + + all_image_embeddings = [] + # need a for loop to process each image because of different image sizes + # (patch arrangement is different for each image) + for i, img_size in enumerate(image_sizes): + h, w = img_size + h_crop = h // 336 + w_crop = w // 336 + num_crops = h_crop * w_crop + + # NOTE: real num_crops is padded + # (num_crops, 24*24, 1024) + sub_image_features = image_features[i, 1 : 1 + num_crops] + sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop) + sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd) + + # [sub features, separator, global features] + all_image_embeddings.extend( + [ + sub_image_features_hd_newline.squeeze(0), # (h_crop*12*(w_crop*12+1), 4096) + self.glb_GN.squeeze(0), + global_image_features_hd_newline[i], + ] + ) + image_features_proj = self.vision_projection(torch.cat(all_image_embeddings, dim=0).unsqueeze(0))[0] + + return image_features_proj + + def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop): + """ + image_features: (num_images*num_crops, 24*24, 1024) + output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops + """ + N, L, C = image_features.shape + assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0 + num_images = N // (h_crop * w_crop) + H = int(L**0.5) + image_features_hd = ( + image_features.reshape(N, H, H, C) # N, 24, 24, 1024 + .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 + .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 + .reshape(N, -1, 4 * C) # N, 144, 4096 + .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096 + .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 + .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096 + ) + + return image_features_hd + + def add_image_newline(self, image_features_hd): + """ + image_features_hd: (num_images, h_crop*12, w_crop*12, 4096) + output: (num_images, (h_crop*12) * (w_crop*12+1), 4096) + """ + 
num_images, h, w, hid_dim = image_features_hd.shape + # add the newline token to the HD image feature patches + newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim) + image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings], dim=2).reshape( + num_images, -1, hid_dim + ) + return image_features_hd_newline + + def get_multimodal_embeddings( + self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, image_sizes=None, **kwargs + ): + MAX_INPUT_ID = int(1e9) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + # positions for image tokens + positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True) + has_image = len(positions[0].tolist()) > 0 + input_ids = input_ids.clamp_min(0).clamp_max(self.config.vocab_size) + inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids, **kwargs)) + if has_image: + vision_embeds = self.get_vision_embeddings( + pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs + ) + image_features_proj = torch.from_numpy(vision_embeds) + inputs_embeds = inputs_embeds.index_put(positions, image_features_proj, accumulate=False) + + return inputs_embeds, attention_mask, position_ids + + @staticmethod + def preprocess_inputs( + text: str, + image: Optional[Image] = None, + processor: Optional[AutoImageProcessor] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + ): + if processor is None: + raise ValueError("Processor is required.") + if image is not None and "<|image_1|>" not in text: + text = "<|image_1|>\n" + text + if getattr(processor.tokenizer, "chat_template", None) is not None: + chat_prompt = [{"role": "user", "content": text}] + text = processor.tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False) + inputs = processor(images=image, text=text, return_tensors="pt") + return inputs + + MODEL_TYPE_TO_CLS_MAPPING = { "llava": _OVLlavaForCausalLM, "llava_next": _OVLlavaNextForCausalLM, "internvl_chat": _OvInternVLForCausalLM, "minicpmv": _OVMiniCPMVForCausalLM, "llava-qwen2": _OVNanoLlavaForCausalLM, + "phi3_v": _OVPhi3VisionForCausalLM, } diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 12bb9e3e8..d9921e91e 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -1880,9 +1880,9 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase): if is_transformers_version(">=", "4.40.0"): SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"] if is_transformers_version(">=", "4.45.0"): - SUPPORTED_ARCHITECTURES += ["minicpmv", "internvl2"] + SUPPORTED_ARCHITECTURES += ["minicpmv", "internvl2", "phi3_v"] TASK = "image-text-to-text" - REMOTE_CODE_MODELS = ["internvl2", "minicpmv", "nanollava"] + REMOTE_CODE_MODELS = ["internvl2", "minicpmv", "nanollava", "phi3_v"] IMAGE = Image.open( requests.get( diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 3822b7646..394151cc3 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -109,6 +109,7 @@ "pix2struct": "fxmarty/pix2struct-tiny-random", "phi": "echarlaix/tiny-random-PhiForCausalLM", "phi3": "Xenova/tiny-random-Phi3ForCausalLM", + "phi3_v": "katuni4ka/tiny-random-phi3-vision", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "qwen": "katuni4ka/tiny-random-qwen", "qwen2": "fxmarty/tiny-dummy-qwen2",
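
Below is a minimal end-to-end usage sketch of the `phi3_v` support introduced by this patch, intended for reviewers. The class and helper names (`OVModelForVisualCausalLM`, `preprocess_inputs`, the `phi3_v` entry in `MODEL_TYPE_TO_CLS_MAPPING`, and the `image-text-to-text` task) come from the diff above; the checkpoint id, image URL, prompt, and `export=True` / CLI invocation are illustrative assumptions, not part of the patch. The tiny test checkpoint registered above (`katuni4ka/tiny-random-phi3-vision`) can be substituted for quick local checks.

```python
# Hedged usage sketch for the phi3_v support added in this PR.
# Assumptions: the checkpoint id, image URL and prompt are placeholders;
# export can alternatively be done up front, e.g.
#   optimum-cli export openvino --model <checkpoint> --trust-remote-code <output_dir>
import requests
from PIL import Image
from transformers import AutoProcessor

from optimum.intel import OVModelForVisualCausalLM

model_id = "microsoft/Phi-3-vision-128k-instruct"  # assumption: any phi3_v checkpoint

# phi3_v is listed in REMOTE_CODE_MODELS, so trust_remote_code is required
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

image = Image.open(requests.get("https://picsum.photos/336", stream=True).raw)  # placeholder image

# preprocess_inputs prepends the <|image_1|> placeholder when an image is passed
# and applies the processor's chat template before tokenization
inputs = model.preprocess_inputs(
    text="What is shown in this image?", image=image, processor=processor
)

generated_ids = model.generate(**inputs, max_new_tokens=50)
new_tokens = generated_ids[:, inputs["input_ids"].shape[1]:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0])
```

At inference time the request flows through the three exported submodels wired up in `_OVPhi3VisionForCausalLM`: `vision_embeddings` extracts per-crop image features, `hd_feature_transform` rearranges them on the host, and the separate `vision_projection` model maps them into the language model's embedding space before they are scattered into the text embeddings at the negative image-token positions.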