xpu mlp optimization
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi committed Oct 23, 2024
1 parent b341db6 commit 34ce74d
Showing 1 changed file with 72 additions and 0 deletions.
optimum/exporters/ipex/modeling_utils.py
@@ -49,6 +49,66 @@
)


class XPULinear2SiluMul(torch.nn.Module):
    def __init__(
        self,
        gate_proj: torch.nn.Module,
        up_proj: torch.nn.Module,
    ):
        super().__init__()
        # Pre-transpose the projection weights once so the fused IPEX matmul
        # ops can consume them directly in forward.
        self.gate_proj_weight = gate_proj.weight.transpose(0, 1).contiguous()
        self.up_proj_weight = up_proj.weight.transpose(0, 1).contiguous()
        self.gate_proj_bias = gate_proj.bias
        self.up_proj_bias = up_proj.bias

    def forward(
        self,
        hidden_states,
    ):
        # Fused matmul + SiLU on the gate projection.
        up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight)
        if self.gate_proj_bias is not None:
            up += self.gate_proj_bias
        # Fused matmul on the up projection, multiplied elementwise by `up`.
        hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up)
        if self.up_proj_bias is not None:
            hidden_states += self.up_proj_bias
        return hidden_states
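For reference, the fused forward above is equivalent to the eager-mode sequence below. This is a minimal sketch, assuming torch.ops.torch_ipex.mm_silu computes silu(x @ W) and mm_resmul computes (x @ W) * res; note that, mirroring the fused code as written, any gate bias is added after the activation rather than before it.

import torch
import torch.nn.functional as F

def unfused_silu_mul(hidden_states, gate_w, up_w, gate_b=None, up_b=None):
    # gate_w / up_w are the pre-transposed weights as stored by the module.
    # mm_silu equivalent: matmul on the gate projection, then SiLU.
    gate = F.silu(hidden_states @ gate_w)
    if gate_b is not None:
        gate = gate + gate_b  # mirrors the fused path: bias added post-activation
    # mm_resmul equivalent: matmul on the up projection, multiplied by `gate`.
    out = (hidden_states @ up_w) * gate
    if up_b is not None:
        out = out + up_b
    return out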


class XPULinearAdd(torch.nn.Module):
    def __init__(
        self,
        module: torch.nn.Module,
    ):
        super().__init__()
        self.weight = module.weight.transpose(0, 1).contiguous()
        self.bias = module.bias

    def forward(
        self,
        hidden_states,
        residual,
    ):
        token_len, _ = hidden_states.size()
        if residual is None:
            hidden_states = torch.matmul(hidden_states, self.weight)
            if self.bias is not None:
                hidden_states += self.bias
        else:
            if self.bias is not None:
                # Fused matmul + bias + residual add in a single kernel.
                hidden_states = torch.ops.torch_ipex.mm_bias_resadd(
                    hidden_states, self.weight, self.bias, 1.0, residual, 1.0
                )
            else:
                # No bias: addmm folds the residual add into the matmul.
                hidden_states = torch.addmm(
                    residual.flatten(0, -2),
                    hidden_states.flatten(0, -2),
                    self.weight,
                    beta=1.0,
                )
        hidden_states = hidden_states.view(token_len, -1)
        return hidden_states
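Likewise, XPULinearAdd folds the residual connection into the projection. A minimal unfused sketch, assuming mm_bias_resadd computes hidden_states @ weight * 1.0 + bias + residual * 1.0 with the two scale arguments shown above:

import torch

def unfused_linear_add(hidden_states, weight, bias=None, residual=None):
    # Plain matmul, optional bias, optional residual add — the behavior the
    # fused mm_bias_resadd / addmm branches above reproduce in one kernel each.
    out = hidden_states @ weight
    if bias is not None:
        out = out + bias
    if residual is not None:
        out = out + residual
    return out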


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
def _ipex_rms_layer_norm_forward(self, hidden_states):
    return rms_norm(hidden_states, self.weight, self.variance_epsilon)
@@ -293,6 +353,10 @@ def __init__(self, module, config) -> None:
            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
                self.mha_linear_add = LinearAdd(module.o_proj)
                del self.__dict__["_modules"]["o_proj"]
        elif self.module_device == "xpu":
            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
                self.mha_linear_add = XPULinearAdd(module.o_proj)
                del self.__dict__["_modules"]["o_proj"]
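The swap pattern here — build the fused module from the original projection, then delete the original from `_modules` — keeps a copy of the weights alive inside the wrapper while removing the stock submodule from the module tree, so it no longer appears in `named_modules` or `state_dict`. A standalone illustration of the same pattern, with hypothetical `Parent`/`proj` names not taken from the diff:

import torch

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

parent = Parent()
parent.fused_proj = XPULinearAdd(parent.proj)  # wrapper holds a transposed copy of the weight
del parent.__dict__["_modules"]["proj"]        # stock Linear is gone from the module tree
print([name for name, _ in parent.named_modules()])  # lists "fused_proj", no "proj"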

    def qkv_gemm(self, hidden_states):
        qkv_out = self.concat_qkv(hidden_states)
@@ -359,6 +423,14 @@ def __init__(self, module, config) -> None:
            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
            del self.__dict__["_modules"]["gate_proj"]
            del self.__dict__["_modules"]["up_proj"]
        elif self.module_device == "xpu":
            # LinearAllreduce and LinearLayer cannot use the fused op LinearAdd.
            if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
                self.mlp_linear_add = XPULinearAdd(module.down_proj)
                del self.__dict__["_modules"]["down_proj"]
            self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj)
            del self.__dict__["_modules"]["gate_proj"]
            del self.__dict__["_modules"]["up_proj"]

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
        if hasattr(self, "linear_silu_mul"):
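Putting the two fused modules together: with both patches applied, the decoder MLP computes the same function as the stock Llama MLP followed by its residual add. A pure-PyTorch reference for that end-to-end computation — a sketch for bias-free projections that exercises only the unfused math, so it runs without torch_ipex:

import torch
import torch.nn.functional as F

def reference_mlp(x, gate_proj, up_proj, down_proj, residual):
    h = F.silu(gate_proj(x)) * up_proj(x)  # what XPULinear2SiluMul fuses
    return down_proj(h) + residual         # what XPULinearAdd fuses

torch.manual_seed(0)
hidden_size, intermediate_size = 64, 172
gate = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
up = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
down = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
x = torch.randn(4, hidden_size)
residual = torch.randn(4, hidden_size)
out = reference_mlp(x, gate, up, down, residual)  # shape: (4, 64)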
