From f415f321aa92982d87028f52a37a09642a6c593e Mon Sep 17 00:00:00 2001
From: "Liu, Kaixuan"
Date: Fri, 8 Nov 2024 11:09:32 -0500
Subject: [PATCH] fix bug when calling IPEXModelForCausalLM.forward directly;
 fix bug in save_pretrained

Signed-off-by: Liu, Kaixuan
---
 optimum/intel/ipex/modeling_base.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 270a9b32da..3939cfbcda 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -138,11 +138,11 @@ def __init__(
         self.model.to(self._device)
         self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model_save_dir = model_save_dir
-        self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
+        self._add_patch = _is_patched_with_ipex(model, self.export_feature)
 
         self.input_names = set(inspect.signature(model.forward).parameters)
 
-        if self._is_ipex_exported:
+        if self._add_patch:
             model = _patch_model(model)
         # Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
         # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
@@ -230,7 +230,6 @@ def _from_pretrained(
         }
 
         task = cls.export_feature
-        config.torch_dtype = torch_dtype
         model = TasksManager.get_model_from_task(
             task,
             model_id,
@@ -240,15 +239,16 @@ def _from_pretrained(
             _commit_hash=commit_hash,
             **model_kwargs,
         )
+        config = model.config
         return cls(model, config=config, export=True, **kwargs)
 
     def _save_pretrained(self, save_directory: Union[str, Path]):
-        output_path = os.path.join(save_directory, WEIGHTS_NAME)
         if getattr(self.config, "torchscript", None):
+            output_path = os.path.join(save_directory, WEIGHTS_NAME)
             torch.jit.save(self.model, output_path)
         else:
             logger.warning("The module is not a torchscript model, will be treated as a transformers model.")
-            self.model.save_pretrained(output_path)
+            self.model.save_pretrained(save_directory)
 
     def forward(
         self,
@@ -305,18 +305,14 @@ def can_generate(self):
         return isinstance(self, GenerationMixin)
 
     def _call_model(self, *args, **kwargs):
-        try:
-            with torch.autocast(self.device.type, self.dtype), torch.no_grad():
-                out = self.model(*args, **kwargs)
-        except RuntimeError:
-            out = self.model(*args, **kwargs)
+        out = self.model(*args, **kwargs)
         return out
 
     def _init_warmup(self):
         # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
         # the results of the compute are unpredictable
         # TODO : add warmup for IPEX exported model
-        if not self._is_ipex_exported:
+        if not self._add_patch:
             use_cache = "past_key_values" in self.input_names
             dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, use_cache)
             if self._device.type != "cpu":
@@ -483,6 +479,10 @@ def forward(
             inputs["position_ids"] = position_ids
 
         if self.use_cache:
+            if past_key_values is None and self._add_patch:
+                max_length = self.config.max_length + input_ids.shape[1]
+                batch_size = input_ids.shape[0]
+                past_key_values = IPEXPagedCache(self.config, batch_size, max_length, input_ids.device, dtype=self.dtype)
             inputs["past_key_values"] = past_key_values
 
         # 2. Model forward
@@ -511,7 +511,7 @@ def _prepare_generation_config(
         return generation_config, model_kwargs
 
     def generate(self, *args, **kwargs):
-        if is_ipex_version("<", "2.4.0") and self._is_ipex_exported and kwargs.get("assistant_model", None):
+        if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None):
             raise ValueError(
                 f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
             )
@@ -523,9 +523,9 @@ def generate(self, *args, **kwargs):
             transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("paged")
         if kwargs.get("generation_config", None):
             kwargs["generation_config"].cache_implementation = "paged"
-        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+        if self._add_patch and kwargs.get("assistant_model", None):
             transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
-        elif self._is_ipex_exported:
+        elif self._add_patch:
             transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values
 
         try:
@@ -535,7 +535,7 @@ def generate(self, *args, **kwargs):
             transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
             raise e
 
-        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+        if self._add_patch and kwargs.get("assistant_model", None):
             transformers.generation.utils._crop_past_key_values = _crop_past_key_values
             transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
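
For reference, a minimal sketch that exercises both fixed paths. It assumes
the public IPEXModelForCausalLM API from optimum-intel; the checkpoint id and
save path below are placeholders, not part of the patch:

    import torch
    from transformers import AutoTokenizer
    from optimum.intel import IPEXModelForCausalLM

    model_id = "gpt2"  # placeholder; any causal LM checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)

    inputs = tokenizer("Hello world", return_tensors="pt")

    # Direct forward call (without generate()): previously failed on patched
    # models because past_key_values was None; the patch now creates an
    # IPEXPagedCache on the fly when use_cache is set.
    with torch.no_grad():
        outputs = model(**inputs)

    # save_pretrained: for non-torchscript models this previously passed
    # <save_directory>/pytorch_model.bin where a directory was expected; it
    # now saves the transformers checkpoint into save_directory itself.
    model.save_pretrained("./ipex_model_dir")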