Skip to content

Commit

Permalink
fix bug when calling IPEXCausalModel forward directly; fix bug when save_pretrained
Browse files Browse the repository at this point in the history

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
  • Loading branch information
kaixuanliu committed Nov 8, 2024
1 parent 45130c9 commit f415f32
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ def __init__(
self.model.to(self._device)
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
self.model_save_dir = model_save_dir
self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
self._add_patch = _is_patched_with_ipex(model, self.export_feature)

self.input_names = set(inspect.signature(model.forward).parameters)

if self._is_ipex_exported:
if self._add_patch:
model = _patch_model(model)
# Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
# a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
Expand Down Expand Up @@ -230,7 +230,6 @@ def _from_pretrained(
}

task = cls.export_feature
config.torch_dtype = torch_dtype
model = TasksManager.get_model_from_task(
task,
model_id,
Expand All @@ -240,15 +239,16 @@ def _from_pretrained(
_commit_hash=commit_hash,
**model_kwargs,
)
config = model.config
return cls(model, config=config, export=True, **kwargs)

def _save_pretrained(self, save_directory: Union[str, Path]):
output_path = os.path.join(save_directory, WEIGHTS_NAME)
if getattr(self.config, "torchscript", None):
output_path = os.path.join(save_directory, WEIGHTS_NAME)
torch.jit.save(self.model, output_path)
else:
logger.warning("The module is not a torchscript model, will be treated as a transformers model.")
self.model.save_pretrained(output_path)
self.model.save_pretrained(save_directory)

def forward(
self,
Expand Down Expand Up @@ -305,18 +305,14 @@ def can_generate(self):
return isinstance(self, GenerationMixin)

def _call_model(self, *args, **kwargs):
try:
with torch.autocast(self.device.type, self.dtype), torch.no_grad():
out = self.model(*args, **kwargs)
except RuntimeError:
out = self.model(*args, **kwargs)
out = self.model(*args, **kwargs)
return out

def _init_warmup(self):
# warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
# the results of the compute are unpredictable
# TODO : add warmup for IPEX exported model
if not self._is_ipex_exported:
if not self._add_patch:
use_cache = "past_key_values" in self.input_names
dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, use_cache)
if self._device.type != "cpu":
Expand Down Expand Up @@ -483,6 +479,10 @@ def forward(
inputs["position_ids"] = position_ids

if self.use_cache:
if past_key_values is None and self._add_patch:
max_length = self.config.max_length + input_ids.shape[1]
batch_size = input_ids.shape[0]
past_key_values = IPEXPagedCache(self.config, batch_size, max_length, input_ids.device, dtype=self.dtype)
inputs["past_key_values"] = past_key_values

# 2. Model forward
Expand Down Expand Up @@ -511,7 +511,7 @@ def _prepare_generation_config(
return generation_config, model_kwargs

def generate(self, *args, **kwargs):
if is_ipex_version("<", "2.4.0") and self._is_ipex_exported and kwargs.get("assistant_model", None):
if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None):
raise ValueError(
f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
Expand All @@ -523,9 +523,9 @@ def generate(self, *args, **kwargs):
transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("paged")
if kwargs.get("generation_config", None):
kwargs["generation_config"].cache_implementation = "paged"
if self._is_ipex_exported and kwargs.get("assistant_model", None):
if self._add_patch and kwargs.get("assistant_model", None):
transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
elif self._is_ipex_exported:
elif self._add_patch:
transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values

try:
Expand All @@ -535,7 +535,7 @@ def generate(self, *args, **kwargs):
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
raise e

if self._is_ipex_exported and kwargs.get("assistant_model", None):
if self._add_patch and kwargs.get("assistant_model", None):
transformers.generation.utils._crop_past_key_values = _crop_past_key_values
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values

Expand Down

0 comments on commit f415f32

Please sign in to comment.