Simplify IPEXModel (#1032)
* simplify forward and save pretrained since no jit support

* fix format

* rm warmup because no jit mode anymore

* simplify forward for causal lm model

* fix paged pkv forward

* disable use_cache when just running forward

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng authored Nov 26, 2024
1 parent 8a8e7e3 commit bcce6b0
Showing 2 changed files with 15 additions and 187 deletions.
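In short: the task wrappers used to marshal a whitelisted set of tensors into an inputs dict (to suit a possibly TorchScript-traced module) and re-wrap the raw outputs in ModelOutput by hand. With JIT support dropped, each wrapper can delegate straight to the underlying eager transformers model. A minimal sketch of the before/after pattern; the class names are illustrative, not the actual optimum-intel classes:

class MarshallingWrapper:
    """Pre-commit pattern: only whitelisted inputs reach the wrapped module."""

    def __init__(self, model, input_names):
        self.model = model
        self.input_names = input_names

    def forward(self, input_ids, attention_mask, token_type_ids=None, **kwargs):
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if "token_type_ids" in self.input_names:
            inputs["token_type_ids"] = token_type_ids
        return self.model(**inputs)

class DelegatingWrapper:
    """Post-commit pattern: pass everything through to the eager model."""

    def __init__(self, model):
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)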
189 changes: 7 additions & 182 deletions optimum/intel/ipex/modeling_base.py
@@ -16,7 +16,6 @@
import copy
import inspect
import logging
import os
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
@@ -41,26 +40,20 @@
)
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.generation.candidate_generator import _crop_past_key_values
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
from transformers.utils import WEIGHTS_NAME

from optimum.exporters import TasksManager
from optimum.exporters.tasks import make_backend_config_constructor_for_task
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ...exporters.ipex.cache_utils import IPEXPagedCache
from ...exporters.ipex.model_config import ipex_onnx_config
from ...exporters.ipex.model_patcher import (
_IPEX_EXPORTED_GENERATION_TASKS,
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_patch_model,
)
from ..generation.modeling import get_float_type
from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_ipex_version, is_transformers_version
from ..utils.modeling_utils import recursive_to_device


logger = logging.getLogger(__name__)
@@ -78,38 +71,6 @@ def _is_patched_with_ipex(model, task, use_cache: bool = True):
return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES


def _prepare_inputs_for_ipex_model(model, task, use_cache):
task = _TASK_ALIASES.get(task, task)
signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
if _is_patched_with_ipex(model, task, use_cache) and model.config.model_type in ipex_onnx_config:
onnx_config_class = make_backend_config_constructor_for_task(
ipex_onnx_config[model.config.model_type], task=task
)
else:
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
float_dtype = get_float_type(model.dtype)
if "text-generation" in task:
onnx_config = onnx_config_class(
model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
)
else:
onnx_config = onnx_config_class(model.config)

dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")

# Check attention_mask shape
if _is_patched_with_ipex(model, task, use_cache) and model.config.model_type in ipex_onnx_config:
past_len = dummy_inputs["past_key_values"][0][0].shape[-2]
input_len = dummy_inputs["input_ids"].shape[-1]
attention_len = dummy_inputs["attention_mask"].shape[-1]
if attention_len != input_len + past_len:
dummy_inputs["attention_mask"] = torch.ones([dummy_inputs["input_ids"].shape[0], input_len + past_len]).to(
dummy_inputs["input_ids"].dtype
)

return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}


class IPEXModel(OptimizedModel):
auto_model_class = AutoModel
export_feature = "feature-extraction"
@@ -123,7 +84,6 @@ def __init__(
config: PretrainedConfig = None,
export: bool = False,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
warmup: bool = True,
**kwargs,
):
config = config or model.config
@@ -143,8 +103,6 @@
AutoConfig.register(self.base_model_prefix, AutoConfig)
if hasattr(self.auto_model_class, "register"):
self.auto_model_class.register(AutoConfig, self.__class__)
if warmup:
self._init_warmup()

@classmethod
def _from_transformers(cls, *args, **kwargs):
@@ -233,39 +191,10 @@ def _from_pretrained(
return cls(model, config=config, export=True, **kwargs)

def _save_pretrained(self, save_directory: Union[str, Path]):
if getattr(self.config, "torchscript", None):
output_path = os.path.join(save_directory, WEIGHTS_NAME)
torch.jit.save(self.model, output_path)
else:
logger.warning("The module is not a torchscript model, will be treated as a transformers model.")
self.model.save_pretrained(save_directory, safe_serialization=False)

def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
position_ids: torch.Tensor = None,
**kwargs,
):
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
self.model.save_pretrained(save_directory, safe_serialization=False)

if "token_type_ids" in self.input_names:
inputs["token_type_ids"] = token_type_ids

if "position_ids" in self.input_names:
inputs["position_ids"] = position_ids

outputs = self._call_model(**inputs)
if isinstance(outputs, dict):
model_output = ModelOutput(**outputs)
else:
model_output = ModelOutput()
model_output[self.output_name] = outputs[0]
return model_output
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def eval(self):
self.model.eval()
@@ -291,28 +220,12 @@ def add_patch(self) -> bool:
return self._add_patch

def to(self, device: Union[torch.device, str]):
self.model.to(self.device)
self.model.to(device)
return self

def can_generate(self):
return isinstance(self, GenerationMixin)

def _call_model(self, *args, **kwargs):
out = self.model(*args, **kwargs)
return out

def _init_warmup(self):
# warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
# the results of the compute are unpredictable
# TODO : add warmup for IPEX exported model
if not self._add_patch:
# use_cache = "past_key_values" in self.input_names
dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, self.use_cache)
if self.device.type != "cpu":
dummy_inputs = recursive_to_device(value=dummy_inputs, device=self.device)
for _ in range(2):
self(**dummy_inputs)


class IPEXModelForSequenceClassification(IPEXModel):
auto_model_class = AutoModelForSequenceClassification
@@ -336,64 +249,16 @@ class IPEXModelForImageClassification(IPEXModel):
auto_model_class = AutoModelForImageClassification
export_feature = "image-classification"

def forward(
self,
pixel_values: torch.Tensor,
**kwargs,
):
inputs = {
"pixel_values": pixel_values,
}

outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


class IPEXModelForAudioClassification(IPEXModel):
auto_model_class = AutoModelForAudioClassification
export_feature = "audio-classification"

def forward(
self,
input_values: torch.Tensor,
attention_mask: torch.Tensor = None,
**kwargs,
):
inputs = {
"input_values": input_values,
}

if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask

outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


class IPEXModelForQuestionAnswering(IPEXModel):
auto_model_class = AutoModelForQuestionAnswering
export_feature = "question-answering"

def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
**kwargs,
):
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}

if "token_type_ids" in self.input_names:
inputs["token_type_ids"] = token_type_ids

outputs = self._call_model(**inputs)
start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
return ModelOutput(start_logits=start_logits, end_logits=end_logits)


class IPEXModelForCausalLM(IPEXModel, GenerationMixin):
auto_model_class = AutoModelForCausalLM
@@ -406,13 +271,9 @@ def __init__(
export: bool = False,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
warmup: bool = True,
**kwargs,
):
# Perform the initial warmup at the end of __init__
super().__init__(
model, config, export=export, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache
)
super().__init__(model, config, export=export, model_save_dir=model_save_dir, use_cache=use_cache)

self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
@@ -442,50 +303,14 @@ def __init__(
self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
if hasattr(self.model_cls, "_convert_to_bloom_cache"):
self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
if warmup:
self._init_warmup()

def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
position_ids: Optional[torch.FloatTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
# 1. Prepare model inputs
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}

if "position_ids" in self.input_names or not self.input_names:
inputs["position_ids"] = position_ids

if self.use_cache:
if past_key_values is None and self._add_patch:
max_length = self.config.max_length + input_ids.shape[1]
batch_size = input_ids.shape[0]
past_key_values = IPEXPagedCache(
self.config, batch_size, max_length, input_ids.device, dtype=self.dtype
)
inputs["past_key_values"] = past_key_values

# 2. Model forward
outputs = self._call_model(**inputs)

# 3. Process model outputs
if isinstance(outputs, (list, tuple)):
logits = outputs[0]
past_key_values = outputs[1] if self.use_cache else None
else:
logits = outputs["logits"]
past_key_values = outputs["past_key_values"] if self.use_cache else None

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
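The causal-LM wrapper gets the same treatment: its removed forward allocated an IPEXPagedCache and rebuilt CausalLMOutputWithPast by hand, while the new one defers caching and output packaging to the patched underlying model. A hedged usage sketch of a cache-free forward pass, mirroring the updated test below (the checkpoint id is illustrative):

import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # illustrative; any supported causal-LM checkpoint
# Per the commit, a plain forward pass does not need the KV cache.
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("This is a sample", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)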
13 changes: 8 additions & 5 deletions tests/ipex/test_modeling.py
@@ -223,10 +223,11 @@ def test_compare_to_transformers(self, model_arch):
dtype = torch.float32
if IS_XPU:
dtype = torch.float16
ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=dtype)
# Test that model forward does not need the cache.
ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=dtype, use_cache=False)
device = ipex_model.device
self.assertIsInstance(ipex_model.config, PretrainedConfig)
self.assertTrue(ipex_model.use_cache)
self.assertFalse(ipex_model.use_cache)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer(
"This is a sample",
@@ -238,18 +239,20 @@

self.assertIsInstance(outputs.logits, torch.Tensor)

transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, use_cache=False).to(
device
)
with torch.no_grad():
transformers_outputs = transformers_model(**tokens)

# Test re-load model
with tempfile.TemporaryDirectory() as tmpdirname:
ipex_model.save_pretrained(tmpdirname)
loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype)
loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype, use_cache=False)
loaded_model_outputs = loaded_model(**inputs)

# Test init method
init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True)
init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True, use_cache=False)
init_model_outputs = init_model(**inputs)

# Compare tensor outputs
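With the TorchScript branch gone from _save_pretrained, a save/reload round trip is just the standard checkpoint API. Continuing the sketch above (reusing model from it):

import tempfile

from optimum.intel import IPEXModelForCausalLM

with tempfile.TemporaryDirectory() as tmpdir:
    model.save_pretrained(tmpdir)  # plain transformers checkpoint, no torch.jit.save
    reloaded = IPEXModelForCausalLM.from_pretrained(tmpdir, use_cache=False)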
