Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Nov 5, 2024
1 parent 25b42c5 commit c96bb24
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
5 changes: 5 additions & 0 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,11 @@ def half(self):
compress_model_transformation(model)
return self

def to(self, device):
self.language_model.to(device)
super().to(device)
return self

def forward(
self,
input_ids,
Expand Down
32 changes: 31 additions & 1 deletion tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1944,7 +1944,7 @@ def test_compare_to_transformers(self, model_arch):
inputs = self.gen_inputs(model_arch, "What is shown on this image?", self.IMAGE)

ov_model = OVModelForVisualCausalLM.from_pretrained(
model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS, compile=False
)
self.assertIsInstance(ov_model, MODEL_TYPE_TO_CLS_MAPPING[ov_model.config.model_type])
self.assertIsInstance(ov_model.vision_embeddings, OVVisionEmbedding)
Expand All @@ -1953,6 +1953,27 @@ def test_compare_to_transformers(self, model_arch):
self.assertTrue(hasattr(ov_model, additional_part))
self.assertIsInstance(getattr(ov_model, additional_part), MODEL_PARTS_CLS_MAPPING[additional_part])
self.assertIsInstance(ov_model.config, PretrainedConfig)
ov_model.to("AUTO")
self.assertTrue("AUTO" in ov_model._device)
self.assertTrue("AUTO" in ov_model.vision_embeddings._device)
self.assertTrue(ov_model.vision_embeddings.request is None)
self.assertTrue("AUTO" in ov_model.language_model._device)
self.assertTrue(ov_model.language_model.request is None)
self.assertTrue(ov_model.language_model.text_emb_request is None)
for additional_part in ov_model.additional_parts:
self.assertTrue("AUTO" in getattr(ov_model, additional_part)._device)
self.assertTrue(getattr(ov_model, additional_part).request is None)
ov_model.to("CPU")
ov_model.compile()
self.assertTrue("CPU" in ov_model._device)
self.assertTrue("CPU" in ov_model.vision_embeddings._device)
self.assertTrue(ov_model.vision_embeddings.request is not None)
self.assertTrue("CPU" in ov_model.language_model._device)
self.assertTrue(ov_model.language_model.request is not None)
self.assertTrue(ov_model.language_model.text_emb_request is not None)
for additional_part in ov_model.additional_parts:
self.assertTrue("CPU" in getattr(ov_model, additional_part)._device)
self.assertTrue(getattr(ov_model, additional_part).request is not None)
# pytorch minicpmv is not designed to be used via forward
if "minicpmv" not in model_arch:
set_seed(SEED)
Expand Down Expand Up @@ -2015,6 +2036,15 @@ def test_generate_utils(self, model_arch):

gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_model_can_be_loaded_after_saving(self, model_arch):
model_id = MODEL_NAMES[model_arch]
with TemporaryDirectory() as save_dir:
ov_model = OVModelForVisualCausalLM.from_pretrained(model_id, compile=False)
ov_model.save_pretrained(save_dir)
ov_restored_model = OVModelForVisualCausalLM.from_pretrained(save_dir, compile=False)
self.assertIsInstance(ov_restored_model, type(ov_model))


class OVModelForSpeechSeq2SeqIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("whisper",)
Expand Down

0 comments on commit c96bb24

Please sign in to comment.