diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index b45818d27..5555b0b80 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -49,6 +49,66 @@
     )
 
 
+class XPULinear2SiluMul(torch.nn.Module):
+    def __init__(
+        self,
+        gate_proj: torch.nn.Module,
+        up_proj: torch.nn.Module,
+    ):
+        super().__init__()
+        self.gate_proj_weight = gate_proj.weight.transpose(0, 1).contiguous()
+        self.up_proj_weight = up_proj.weight.transpose(0, 1).contiguous()
+        self.gate_proj_bias = gate_proj.bias
+        self.up_proj_bias = up_proj.bias
+
+    def forward(
+        self,
+        hidden_states,
+    ):
+        up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight)
+        if self.gate_proj_bias is not None:
+            up += self.gate_proj_bias
+        hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up)
+        if self.up_proj_bias is not None:
+            hidden_states += self.up_proj_bias
+        return hidden_states
+
+
+class XPULinearAdd(torch.nn.Module):
+    def __init__(
+        self,
+        module: torch.nn.Module,
+    ):
+        super().__init__()
+        self.weight = module.weight.transpose(0, 1).contiguous()
+        self.bias = module.bias
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+    ):
+        token_len, _ = hidden_states.size()
+        if residual is None:
+            hidden_states = torch.matmul(hidden_states, self.weight)
+            if self.bias is not None:
+                hidden_states += self.bias
+        else:
+            if self.bias is not None:
+                hidden_states = torch.ops.torch_ipex.mm_bias_resadd(
+                    hidden_states, self.weight, self.bias, 1.0, residual, 1.0
+                )
+            else:
+                hidden_states = torch.addmm(
+                    residual.flatten(0, -2),
+                    hidden_states.flatten(0, -2),
+                    self.weight,
+                    beta=1.0,
+                )
+        hidden_states = hidden_states.view(token_len, -1)
+        return hidden_states
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
 def _ipex_rms_layer_norm_forward(self, hidden_states):
     return rms_norm(hidden_states, self.weight, self.variance_epsilon)
@@ -293,6 +353,10 @@ def __init__(self, module, config) -> None:
             if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
                 self.mha_linear_add = LinearAdd(module.o_proj)
                 del self.__dict__["_modules"]["o_proj"]
+        elif self.module_device == "xpu":
+            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.mha_linear_add = XPULinearAdd(module.o_proj)
+                del self.__dict__["_modules"]["o_proj"]
 
     def qkv_gemm(self, hidden_states):
         qkv_out = self.concat_qkv(hidden_states)
@@ -359,6 +423,14 @@ def __init__(self, module, config) -> None:
             self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
             del self.__dict__["_modules"]["gate_proj"]
             del self.__dict__["_modules"]["up_proj"]
+        elif self.module_device == "xpu":
+            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
+            if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.mlp_linear_add = XPULinearAdd(module.down_proj)
+                del self.__dict__["_modules"]["down_proj"]
+            self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj)
+            del self.__dict__["_modules"]["gate_proj"]
+            del self.__dict__["_modules"]["up_proj"]
 
     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
         if hasattr(self, "linear_silu_mul"):
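
Not part of the diff: a minimal usage sketch of the two new fused modules, assuming the patch above is applied, an Intel XPU device is available, and intel_extension_for_pytorch is installed (it registers the torch.ops.torch_ipex kernels the modules call). The tensor shapes and float16 dtype below are illustrative choices, not requirements from the patch; the reference path simply mirrors the unfused gate/up/down projections the modules replace.

# Sketch only: exercises XPULinear2SiluMul and XPULinearAdd against the unfused path.
import torch
import intel_extension_for_pytorch  # noqa: F401  registers torch.ops.torch_ipex kernels

from optimum.exporters.ipex.modeling_utils import XPULinear2SiluMul, XPULinearAdd

hidden_size, intermediate_size, tokens = 64, 256, 8
device, dtype = "xpu", torch.float16

gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False).to(device=device, dtype=dtype)
up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False).to(device=device, dtype=dtype)
down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False).to(device=device, dtype=dtype)

x = torch.randn(tokens, hidden_size, device=device, dtype=dtype)
residual = torch.randn(tokens, hidden_size, device=device, dtype=dtype)

# Fused path: silu(gate(x)) * up(x) in one module, then down-projection + residual in one call.
act = XPULinear2SiluMul(gate_proj, up_proj)(x)
out = XPULinearAdd(down_proj)(act, residual)

# Unfused reference path for comparison.
ref = down_proj(torch.nn.functional.silu(gate_proj(x)) * up_proj(x)) + residual
print((out - ref).abs().max())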