fix switching between legacy and new processing for llava (#970)
* fix switching between legacy and new processing for llava

* extend tests

* update legacy processing path

* replace llava test model

* Update tests/openvino/test_modeling.py
eaidova authored Nov 8, 2024
1 parent 222748e commit c887610
Showing 3 changed files with 104 additions and 53 deletions.
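In short, the change makes the LLaVA models pick the processing path explicitly: new-style processing (where the processor pre-expands <image> into image_seq_length placeholder tokens) is used only on the prefill step, when pixel values are present and the config exposes image_seq_length; every other case falls back to the legacy merging path. A minimal sketch of that decision, using the names from the diff below (illustrative only, not the exact library code):

```python
# Sketch of the path selection this commit introduces (illustrative only).
def choose_processing_path(config, input_ids, pixel_values, past_key_values):
    support_new_processing = hasattr(config, "image_seq_length")
    if pixel_values is not None and support_new_processing and past_key_values is None:
        # New processors expand "<image>" into image_seq_length placeholder tokens;
        # if the prompt already holds a full block of them, take the new path.
        n_image_tokens = (input_ids == config.image_token_index).sum(1).max()
        return "new" if n_image_tokens >= config.image_seq_length else "legacy"
    # Decoding steps (past_key_values present) and old-style inputs stay legacy.
    return "legacy"
```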
98 changes: 47 additions & 51 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -697,6 +697,33 @@ def can_generate(self):


class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
def __init__(
self,
language_model: ov.Model,
text_embeddings: ov.Model,
vision_embeddings: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
super().__init__(
language_model=language_model,
text_embeddings=text_embeddings,
vision_embeddings=vision_embeddings,
config=config,
device=device,
dynamic_shapes=dynamic_shapes,
ov_config=ov_config,
model_save_dir=model_save_dir,
quantization_config=quantization_config,
**kwargs,
)
self._support_new_processing = hasattr(self.config, "image_seq_length")

def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
@@ -725,17 +752,11 @@ def merge_vision_text_embeddings(
input_ids,
attention_mask,
position_ids=None,
legacy_processing=None,
legacy_processing=False,
**kwargs,
):
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
if legacy_processing is None:
legacy_processing = (
not hasattr(self.config, "image_seq_length")
or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
or (input_ids.shape[-1] == 1)
)

if legacy_processing:
pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -768,15 +789,6 @@ def merge_vision_text_embeddings(
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)

# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
@@ -787,15 +799,15 @@ def merge_vision_text_embeddings(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model a/pre-releasesre wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)

final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

@@ -815,11 +827,12 @@ def merge_vision_text_embeddings(
def get_multimodal_embeddings(
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
):
legacy_processing = (
not hasattr(self.config, "image_seq_length")
or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
or (input_ids.shape[-1] == 1 and pixel_values is not None)
)
if pixel_values is not None and self._support_new_processing and past_key_values is None:
legacy_processing = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
else:
legacy_processing = True
inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
)
@@ -830,38 +843,19 @@ def get_multimodal_embeddings(
return inputs_embeds, attention_mask, position_ids

def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
if not self.language_model.stateful:
first_layer_past_key_value = torch.from_numpy(past_key_values[0][0][:, :, :, 0])
else:
first_layer_past_key_value = torch.from_numpy(
self.language_model.request.query_state()[0].state.data[:, :, :, 0]
)

# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
past_length = self.language_model._get_past_length(past_key_values)

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
position_ids = torch.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
return attention_mask, position_ids
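
The rewritten _filter_unattended_tokens above also derives position_ids with a cumulative sum over the attention mask instead of the previous single sum. A toy illustration of what those two lines produce on a small mask (hedged example, not taken from the repository):

```python
# Toy illustration of the cumsum-based position ids used above.
import torch

attention_mask = torch.tensor([[1, 1, 0, 1, 1]])
position_ids = torch.cumsum(attention_mask, dim=1) - 1  # dim=1 is equivalent to axis=1 in the diff
position_ids[attention_mask == 0] = 1
print(position_ids)  # tensor([[0, 1, 1, 2, 3]])
```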


@@ -938,11 +932,13 @@ def get_multimodal_embeddings(

inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)

legacy_processing = (
not hasattr(self.config, "image_seq_length")
or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
or (input_ids.shape[-1] == 1 and pixel_values is not None)
)
if pixel_values is not None and self._support_new_processing and past_key_values is None:
legacy_processing = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
else:
legacy_processing = True

if pixel_values is not None and pixel_values.size(0) > 0:
# ! infer image_num_patches from image_sizes
image_num_patches = [
@@ -996,7 +992,7 @@ def merge_vision_text_embeddings(
input_ids,
attention_mask,
position_ids=None,
legacy_processing=None,
legacy_processing=False,
**kwargs,
):
image_token_index = self.config.image_token_index
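For context, the non-legacy branch of merge_vision_text_embeddings (not visible in the hunks above) needs none of this index bookkeeping: with new-style processing the prompt already contains image_seq_length copies of the image token, so the vision features are simply scattered into those positions. A rough sketch of the idea, assuming the transformers >= 4.45 behaviour rather than quoting this file:

```python
# Rough sketch of new-style merging (assumption, not the exact code in this file):
# scatter image features into the pre-expanded <image> placeholder positions.
import torch

def merge_new_style(inputs_embeds, image_features, input_ids, image_token_index):
    mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
    return inputs_embeds.masked_scatter(mask, image_features.to(inputs_embeds.dtype))
```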
57 changes: 56 additions & 1 deletion tests/openvino/test_modeling.py
@@ -1983,12 +1983,67 @@ def test_compare_to_transformers(self, model_arch):
torch.equal(ov_outputs, transformers_outputs),
f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model output {ov_outputs}",
)

del transformers_model
del ov_model

gc.collect()

@parameterized.expand(["llava", "llava_next"])
@unittest.skipIf(
is_transformers_version("<", "4.45.0"), reason="New preprocessing available only in transformers >= 4.45"
)
def test_llava_with_new_preprocessing(self, model_arch):
prompt = "<image>\n What is shown in this image?"
model_id = MODEL_NAMES[model_arch]
config = AutoConfig.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
processor = AutoProcessor.from_pretrained(
model_id,
patch_size=config.vision_config.patch_size,
vision_feature_select_strategy=config.vision_feature_select_strategy,
trust_remote_code=model_arch in self.REMOTE_CODE_MODELS,
)
transformers_model = self.get_transformer_model_class(model_arch).from_pretrained(model_id)
ov_model = OVModelForVisualCausalLM.from_pretrained(
model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
self.assertTrue(ov_model._support_new_processing)
self.assertTrue(processor.patch_size is not None)
self.assertTrue(processor.vision_feature_select_strategy is not None)
inputs = processor(images=self.IMAGE, text=prompt, return_tensors="pt")
self.assertTrue(
(inputs.input_ids == ov_model.config.image_token_index).sum(1).max() >= ov_model.config.image_seq_length
)
set_seed(SEED)
with torch.no_grad():
transformers_outputs = transformers_model(**inputs)
set_seed(SEED)
ov_outputs = ov_model(**inputs)
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
ov_model.generation_config.eos_token_id = None
transformers_model.generation_config.eos_token_id = None
ov_model.config.eos_token_id = None
transformers_model.config.eos_token_id = None
gen_config = GenerationConfig(
max_new_tokens=30,
min_new_tokens=30,
num_beams=3,
do_sample=False,
eos_token_id=None,
)
set_seed(SEED)
ov_outputs = ov_model.generate(**inputs, generation_config=gen_config)
set_seed(SEED)
with torch.no_grad():
transformers_outputs = transformers_model.generate(**inputs, generation_config=gen_config)
self.assertTrue(
torch.equal(ov_outputs, transformers_outputs),
f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model output {ov_outputs}",
)

del ov_model
del transformers_model
gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_generate_utils(self, model_arch):
model_id = MODEL_NAMES[model_arch]
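Outside the test suite, the new preprocessing path can be exercised roughly as follows; the model id and image are placeholders, and the key point is passing patch_size and vision_feature_select_strategy to the processor so that it expands the image tokens itself:

```python
# Hedged usage sketch mirroring the new test; model id and image are placeholders.
from PIL import Image
from transformers import AutoConfig, AutoProcessor
from optimum.intel import OVModelForVisualCausalLM

model_id = "llava-hf/llava-1.5-7b-hf"  # assumption: any LLaVA checkpoint exposing image_seq_length
config = AutoConfig.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(
    model_id,
    patch_size=config.vision_config.patch_size,
    vision_feature_select_strategy=config.vision_feature_select_strategy,
)
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)

inputs = processor(images=Image.open("example.jpg"), text="<image>\nWhat is shown in this image?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```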
2 changes: 1 addition & 1 deletion tests/openvino/utils_tests.py
@@ -76,7 +76,7 @@
"llama": "HuggingFaceM4/tiny-random-LlamaForCausalLM",
"llama_awq": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
"llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ",
"llava": "trl-internal-testing/tiny-random-LlavaForConditionalGeneration",
"llava": "katuni4ka/tiny-random-llava",
"llava_next": "katuni4ka/tiny-random-llava-next",
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"opt": "hf-internal-testing/tiny-random-OPTModel",
