phi3 vision (#977)
* wip

* wip wip

* fix images processing

* add test

* add input preprocessing

* Update tests/openvino/test_modeling.py

* Update optimum/exporters/openvino/__main__.py

* Update optimum/intel/openvino/modeling_visual_language.py

Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>

* refactor export configs

---------

Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>
eaidova and nikita-savelyevv authored Nov 14, 2024
1 parent 41637d0 commit febc50e
Showing 8 changed files with 426 additions and 216 deletions.
6 changes: 6 additions & 0 deletions optimum/exporters/openvino/__main__.py
@@ -45,6 +45,8 @@
from .utils import _MAX_UNCOMPRESSED_SIZE, MULTI_MODAL_TEXT_GENERATION_MODELS, clear_class_registry


FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}

if TYPE_CHECKING:
from optimum.intel.openvino.configuration import OVConfig

@@ -264,6 +266,10 @@ def main_export(

if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
loading_kwargs["attn_implementation"] = "eager"

# some models force flash_attn attention by default, which does not support loading the model on CPU
if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES:
loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type]
# there are some differences between the remote-code and in-library representations of past key values
# for some models, so to avoid confusion we disable remote code for them
if (
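
A minimal sketch, not part of the diff, of what the new override amounts to when the exporter loads the model; it assumes loading_kwargs is later forwarded to transformers' from_pretrained (the model id below is a placeholder):

FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}

model_type = "phi3-v"
loading_kwargs = {}
if model_type in FORCE_ATTN_MODEL_CLASSES:
    # phi3-v defaults to flash_attn, which cannot be used when loading the model on CPU,
    # so the exporter forces the eager implementation instead
    loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type]

# model = AutoModelForCausalLM.from_pretrained("<model-id>", trust_remote_code=True, **loading_kwargs)
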
17 changes: 16 additions & 1 deletion optimum/exporters/openvino/convert.py
@@ -712,7 +712,18 @@ def export_from_model(
)

model_name_or_path = model.config._name_or_path
maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)
if preprocessors is not None:
# the phi3-vision processor has no chat_template attribute, which breaks saving the processor to disk
if is_transformers_version(">=", "4.45") and model_type == "phi3-v" and len(preprocessors) > 1:
if not hasattr(preprocessors[1], "chat_template"):
preprocessors[1].chat_template = getattr(preprocessors[0], "chat_template", None)
for processor in preprocessors:
try:
processor.save_pretrained(output)
except Exception as ex:
logger.error(f"Saving {type(processor)} failed with {ex}")
else:
maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)

files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()]

@@ -891,6 +902,10 @@ def _get_multi_modal_submodels_and_export_configs(
if model_type == "internvl-chat" and preprocessors is not None:
model.config.img_context_token_id = preprocessors[0].convert_tokens_to_ids("<IMG_CONTEXT>")

if model_type == "phi3-v":
model.config.glb_GN = model.model.vision_embed_tokens.glb_GN.tolist()
model.config.sub_GN = model.model.vision_embed_tokens.sub_GN.tolist()

if hasattr(model, "image_newline"):
model.config.image_newline = model.image_newline.tolist()
main_config_cls = TasksManager.get_exporter_config_constructor(
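
In _get_multi_modal_submodels_and_export_configs above, the learned glb_GN and sub_GN separator embeddings are copied into the config as plain nested lists so they survive serialization and can be rebuilt as tensors at load time (see _OVPhi3VisionForCausalLM.__init__ later in this commit). A minimal round-trip sketch; the shape below is a placeholder, not the real parameter shape:

import torch

# export side: a learned separator parameter becomes a JSON-serializable nested list
glb_GN = torch.zeros(1, 1, 4096)     # stand-in for model.model.vision_embed_tokens.glb_GN
config_value = glb_GN.tolist()       # what ends up in config.glb_GN

# load side: rebuild the tensor from the config entry
restored = torch.tensor(config_value)
assert restored.shape == glb_GN.shape and torch.equal(restored, glb_GN)
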
426 changes: 217 additions & 209 deletions optimum/exporters/openvino/model_configs.py

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -1369,6 +1369,7 @@ def phi3_442_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
@@ -3216,3 +3217,23 @@ def forward(self, input):
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward


def phi3_vision_embeddings_forward(self, pixel_values: torch.FloatTensor):
return self.get_img_features(pixel_values)


class Phi3VisionImageEmbeddingsPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = types.MethodType(phi3_vision_embeddings_forward, model)
super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
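
Phi3VisionImageEmbeddingsPatcher uses the same swap-and-restore pattern as the other patchers in this file: forward is replaced with a thin export-friendly wrapper for the duration of the export and put back in __exit__. A self-contained toy illustration of that pattern (the class and values are invented for the example):

import types

class TinyModule:
    def forward(self, x):
        return x + 1

def export_friendly_forward(self, x):
    # stands in for phi3_vision_embeddings_forward, which simply delegates to get_img_features
    return x * 2

module = TinyModule()
module.__orig_forward = module.forward                        # keep the original bound method
module.forward = types.MethodType(export_friendly_forward, module)
assert module.forward(3) == 6                                  # patched behaviour during export
module.forward = module.__orig_forward                         # restored afterwards, as in __exit__
assert module.forward(3) == 4
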
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/utils.py
@@ -213,7 +213,7 @@ def get_submodels(model):
return custom_export, fn_get_submodels


MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv"]
MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv", "phi3-v"]


def save_config(config, save_dir):
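
Note the two spellings used across this commit: the exporter-side registries key on the hyphenated type "phi3-v" (FORCE_ATTN_MODEL_CLASSES, the list above), while the runtime mapping and the tests use the raw config value "phi3_v". A small sketch of the assumed normalization; the replace call mirrors how main_export is expected to derive model_type from the config and is illustrative only:

MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv", "phi3-v"]

config_model_type = "phi3_v"                                   # as in MODEL_TYPE_TO_CLS_MAPPING and the tests
exporter_model_type = config_model_type.replace("_", "-")      # "phi3-v", as used by the exporter registries
assert exporter_model_type in MULTI_MODAL_TEXT_GENERATION_MODELS
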
165 changes: 162 additions & 3 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -230,7 +230,15 @@ def forward(self, image_feature, pos_embed, key_padding_mask):
return result


MODEL_PARTS_CLS_MAPPING = {"resampler": OVResampler}
class OVVisionProjection(OVModelPart):
_model_name = "vision_projection"

def forward(self, img_features):
self._compile()
return self.request(img_features)[0]


MODEL_PARTS_CLS_MAPPING = {"resampler": OVResampler, "vision_projection": OVVisionProjection}


class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin):
@@ -1802,8 +1810,8 @@ def preprocess_inputs(
raise ValueError("Tokenizer is required.")
if image is not None and processor is None:
raise ValueError("Processor is required.")
text_content = f"<image>\n{text}" if image is not None else text
messages = [{"role": "user", "content": text_content}]
text = f"<image>\n{text}" if image is not None else text
messages = [{"role": "user", "content": text}]
if tokenizer.chat_template is not None:
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
if image is not None:
@@ -1818,10 +1826,161 @@
return result


class _OVPhi3VisionForCausalLM(OVModelForVisualCausalLM):
additional_parts = ["vision_projection"]

def __init__(
self,
language_model: ov.Model,
text_embeddings: ov.Model,
vision_embeddings: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
super().__init__(
language_model,
text_embeddings,
vision_embeddings,
config,
device,
dynamic_shapes,
ov_config,
model_save_dir,
quantization_config,
**kwargs,
)
self.sub_GN = torch.tensor(self.config.sub_GN)
self.glb_GN = torch.tensor(self.config.glb_GN)

def get_vision_embeddings(self, pixel_values, image_sizes, **kwargs):
num_images, num_crops, c, h, w = pixel_values.shape
img_features = self.vision_embeddings(pixel_values.flatten(0, 1)).last_hidden_state.reshape(
num_images, num_crops, -1, self.config.img_processor["image_dim_out"]
)
image_features_proj = self.hd_feature_transform(img_features, image_sizes)
return image_features_proj

def hd_feature_transform(self, image_features, image_sizes):
"""
image_features: (num_images, num_crops+1, 24*24, 1024)
"""

image_features = torch.from_numpy(image_features)
global_image_features = image_features[:, 0] # (num_images, 24*24, 1024)
# global feature can be viewed as a special HD case with num_crops 1x1
global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)
global_image_features_hd_newline = self.add_image_newline(global_image_features_hd)

all_image_embeddings = []
# need a for loop to process each image because of different image sizes
# (patch arrangement is different for each image)
for i, img_size in enumerate(image_sizes):
h, w = img_size
h_crop = h // 336
w_crop = w // 336
num_crops = h_crop * w_crop

# NOTE: real num_crops is padded
# (num_crops, 24*24, 1024)
sub_image_features = image_features[i, 1 : 1 + num_crops]
sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop)
sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)

# [sub features, separator, global features]
all_image_embeddings.extend(
[
sub_image_features_hd_newline.squeeze(0), # (h_crop*12*(w_crop*12+1), 4096)
self.glb_GN.squeeze(0),
global_image_features_hd_newline[i],
]
)
image_features_proj = self.vision_projection(torch.cat(all_image_embeddings, dim=0).unsqueeze(0))[0]

return image_features_proj

def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
"""
image_features: (num_images*num_crops, 24*24, 1024)
output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops
"""
N, L, C = image_features.shape
assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0
num_images = N // (h_crop * w_crop)
H = int(L**0.5)
image_features_hd = (
image_features.reshape(N, H, H, C) # N, 24, 24, 1024
.reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
.permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
.reshape(N, -1, 4 * C) # N, 144, 4096
.reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096
.permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
.reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096
)

return image_features_hd

def add_image_newline(self, image_features_hd):
"""
image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
"""
num_images, h, w, hid_dim = image_features_hd.shape
# add the newline token to the HD image feature patches
newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim)
image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings], dim=2).reshape(
num_images, -1, hid_dim
)
return image_features_hd_newline

def get_multimodal_embeddings(
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, image_sizes=None, **kwargs
):
MAX_INPUT_ID = int(1e9)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])

# positions for image tokens
positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
has_image = len(positions[0].tolist()) > 0
input_ids = input_ids.clamp_min(0).clamp_max(self.config.vocab_size)
inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids, **kwargs))
if has_image:
vision_embeds = self.get_vision_embeddings(
pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs
)
image_features_proj = torch.from_numpy(vision_embeds)
inputs_embeds = inputs_embeds.index_put(positions, image_features_proj, accumulate=False)

return inputs_embeds, attention_mask, position_ids

@staticmethod
def preprocess_inputs(
text: str,
image: Optional[Image] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if image is not None and "<|image_1|>" not in text:
text = "<|image_1|>\n" + text
if getattr(processor.tokenizer, "chat_template", None) is not None:
chat_prompt = [{"role": "user", "content": text}]
text = processor.tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)
inputs = processor(images=image, text=text, return_tensors="pt")
return inputs


MODEL_TYPE_TO_CLS_MAPPING = {
"llava": _OVLlavaForCausalLM,
"llava_next": _OVLlavaNextForCausalLM,
"internvl_chat": _OvInternVLForCausalLM,
"minicpmv": _OVMiniCPMVForCausalLM,
"llava-qwen2": _OVNanoLlavaForCausalLM,
"phi3_v": _OVPhi3VisionForCausalLM,
}
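
With "phi3_v" registered in MODEL_TYPE_TO_CLS_MAPPING, the exported model can be driven through the generic visual-language API. A usage sketch under a few assumptions: the checkpoint id, image path and prompt are placeholders, OVModelForVisualCausalLM is importable from optimum.intel like the other OV model classes, and export=True converts the checkpoint on the fly:

from PIL import Image
from transformers import AutoProcessor
from optimum.intel import OVModelForVisualCausalLM

model_id = "microsoft/Phi-3-vision-128k-instruct"    # placeholder checkpoint id
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

image = Image.open("example.jpg")                    # placeholder image
inputs = model.preprocess_inputs(text="What is shown in this image?", image=image, processor=processor)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
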
4 changes: 2 additions & 2 deletions tests/openvino/test_modeling.py
@@ -1880,9 +1880,9 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
if is_transformers_version(">=", "4.40.0"):
SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"]
if is_transformers_version(">=", "4.45.0"):
SUPPORTED_ARCHITECTURES += ["minicpmv", "internvl2"]
SUPPORTED_ARCHITECTURES += ["minicpmv", "internvl2", "phi3_v"]
TASK = "image-text-to-text"
REMOTE_CODE_MODELS = ["internvl2", "minicpmv", "nanollava"]
REMOTE_CODE_MODELS = ["internvl2", "minicpmv", "nanollava", "phi3_v"]

IMAGE = Image.open(
requests.get(
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
@@ -109,6 +109,7 @@
"pix2struct": "fxmarty/pix2struct-tiny-random",
"phi": "echarlaix/tiny-random-PhiForCausalLM",
"phi3": "Xenova/tiny-random-Phi3ForCausalLM",
"phi3_v": "katuni4ka/tiny-random-phi3-vision",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"qwen": "katuni4ka/tiny-random-qwen",
"qwen2": "fxmarty/tiny-dummy-qwen2",
