unify xpu and cpu backend and use paged attention #1009

Open · wants to merge 28 commits into base: main
Commits (28 total; the changes below are shown from 23 commits)
1c35c4f
add paged attention implementation, remove jit logic
sywangyi Oct 9, 2024
973e034
add support in transformers 4.45
sywangyi Oct 9, 2024
8b574d0
fix config (#935)
jiqing-feng Oct 10, 2024
541a236
move patch model to init
sywangyi Oct 10, 2024
35cd0c1
refine class IPEXPagedCache's update method (#945)
kaixuanliu Oct 17, 2024
80e8071
fix bug when doing beam search (#954)
kaixuanliu Oct 18, 2024
184faea
enable qkv concat layer (#958)
jiqing-feng Oct 23, 2024
b341db6
add xpu cache optimization
sywangyi Oct 23, 2024
34ce74d
xpu mlp optimization
sywangyi Oct 23, 2024
45130c9
optimize cache ops in xpu, improve for beam search
sywangyi Oct 24, 2024
74eec8b
enable gpt2, falcon has core dump error in PagedAttention.single_quer…
jiqing-feng Nov 5, 2024
76d32be
fix unit test case, CPU part is OK; Enable Falcon7b for XPU (#992)
kaixuanliu Nov 13, 2024
039c72d
skip assisted decoding unit test for models using paged attention (#998)
kaixuanliu Nov 22, 2024
459c78c
Merge branch 'main' into paged_attn
sywangyi Nov 22, 2024
1ab0233
fix ci config (#1010)
jiqing-feng Nov 22, 2024
b0cd5db
Fix tests versions (#1011)
jiqing-feng Nov 22, 2024
e31e6d4
fix torch test version (#1012)
jiqing-feng Nov 22, 2024
ed35ffc
use python3.9 test (#1013)
jiqing-feng Nov 22, 2024
a5c48a8
change ipex transformers limited version in setup (#1015)
jiqing-feng Nov 22, 2024
388265f
add XPU LinearAddAdd op (#1017)
kaixuanliu Nov 22, 2024
ad9b795
fix bert and vit patch (#1022)
jiqing-feng Nov 25, 2024
0d7f8b6
Merge branch 'main' into paged_attn
IlyasMoutawwakil Nov 25, 2024
b48192b
Paged attn (#1024)
jiqing-feng Nov 25, 2024
8a8e7e3
set device as the same as origin model (#1031)
jiqing-feng Nov 26, 2024
bcce6b0
Simplify IPEXModel (#1032)
jiqing-feng Nov 26, 2024
51030e5
nice code (#1035)
kaixuanliu Nov 27, 2024
587837e
Paged attn (#1036)
kaixuanliu Nov 27, 2024
6ddf93e
Enable torch.compile for non-generation tasks in CPU (#1037)
jiqing-feng Nov 27, 2024
2 changes: 1 addition & 1 deletion .github/workflows/test_inc.yml
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        torch-version: ["2.4.*", "2.5.0"]
+        torch-version: ["2.4.0", "2.5.*"]

     runs-on: ubuntu-22.04

8 changes: 2 additions & 6 deletions .github/workflows/test_ipex.yml
@@ -18,8 +18,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        torch-version: ["2.2.0", "2.3.*", "2.4.*"]
-        transformers-version: ["4.39.0", "4.44.*"]
+        transformers-version: ["4.46.*"]
+        torch-version: ["2.4.0", "2.5.0"]

     runs-on: ubuntu-22.04

@@ -38,10 +38,6 @@ jobs:
           pip install torch==${{ matrix.torch-version }} torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
           pip install .[ipex,tests] transformers[testing]==${{ matrix.transformers-version }} intel_extension_for_pytorch==${{ matrix.torch-version }}

-      - if: ${{ matrix.torch-version == '2.2.0' }}
-        name: Downgrade Numpy
-        run: pip install numpy==1.*
-
       - name: Assert versions
         run: |
           python -c "import torch; print(torch.__version__); assert torch.__version__.startswith('${{ matrix.torch-version }}'.replace('.*', ''))"
238 changes: 238 additions & 0 deletions optimum/exporters/ipex/cache_utils.py
@@ -0,0 +1,238 @@
from typing import List, Optional, Tuple

import torch
from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig


class IPEXPagedCache(Cache):
"""
A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout.
ipex-xpu:
ipex-cpu:

Example:

```python
>>> from transformers import AutoTokenizer
>>> from optimum.intel import IPEXModelForCausalLM
>>> from optimum.exporters.ipex.cache_utils import IPEXPagedCache

>>> model = IPEXModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", export=True)
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = IPEXPagedCache()
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
```
"""

    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int,
        max_cache_len: int,
        device,
        dtype=None,
        layer_device_map=None,
        **kwargs,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        # Used in `generate` to keep tally of how many tokens the cache has seen
        self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
        self.block_size = 16
        self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
        self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
            batch_size, -1
        )
        self.free_blocks = torch.arange(self.num_blocks, device=device)
        self.max_cache_len = max_cache_len
        self.num_kv_heads = config.num_key_value_heads
        self.num_hidden_layers = config.num_hidden_layers
        if hasattr(config, "head_dim"):
            head_size = config.head_dim
        else:
            head_size = config.hidden_size // config.num_attention_heads
        self.head_size = head_size
        self.max_seq_len = 0

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

        if device.type == "cpu":
            key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
            value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
        elif device.type == "xpu":
            key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
            value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
        for i in range(config.num_hidden_layers):
            new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
            new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def update_for_prefill(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        batch_size: int,
        input_lens: torch.Tensor,
    ):
        if layer_idx == 0:
            all_block_indices = []
            all_slot_offsets = []
            num_blocks = (input_lens + self.block_size - 1) // self.block_size
            for i in range(batch_size):
                for b_idx in range(num_blocks[i]):
                    if self.block_tables[i][b_idx] == -1:
                        # need a free block
                        self.block_tables[i][b_idx] = self.free_blocks[0]
                        self.free_blocks = self.free_blocks[1:]

                slots_range = torch.arange(input_lens[i], device=key_states.device)
                block_indices = slots_range // self.block_size
                slot_offsets = slots_range % self.block_size
                all_block_indices.append(self.block_tables[i][block_indices])
                all_slot_offsets.append(slot_offsets)

            all_block_indices = torch.cat(all_block_indices)
            all_slot_offsets = torch.cat(all_slot_offsets)
            self.slots = all_block_indices * self.block_size + all_slot_offsets

        # Update the cache
        PagedAttention.reshape_and_cache(
            key_states,
            value_states,
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
            self.slots,
        )

        # Update the number of seen tokens
        if layer_idx == self.num_hidden_layers - 1:
            self._seen_tokens = self._seen_tokens + input_lens
            self.max_seq_len, _ = self._seen_tokens.max(dim=0)

    def update_for_decode(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        batch_size: int,
    ):
        if layer_idx == 0:
            start_block_idx = self._seen_tokens // self.block_size
            num_blocks = (self._seen_tokens + self.block_size) // self.block_size
            slot_offset_in_block = (self._seen_tokens) % self.block_size
            self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
            for i in range(batch_size):
                for b_idx in range(start_block_idx[i], num_blocks[i]):
                    if self.block_tables[i][b_idx] == -1:
                        # need a free block
                        self.block_tables[i][b_idx] = self.free_blocks[0]
                        self.free_blocks = self.free_blocks[1:]

                self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
        # Update the cache
        PagedAttention.reshape_and_cache(
            key_states,
            value_states,
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
            self.slots,
        )

        # Update the number of seen tokens
        if layer_idx == self.num_hidden_layers - 1:
            self._seen_tokens = self._seen_tokens + 1
            self.max_seq_len = self.max_seq_len + 1

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        attention_mask: torch.Tensor,
        input_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            attention_mask (`torch.Tensor`):
                The attention mask of the current batch.
            input_lens (`torch.Tensor`):
                The per-sequence input lengths of the current batch.

        Returns:
            A tuple containing the updated key and value states.
        """

        batch_size = input_lens.shape[-1]
        if self.get_seq_length() == 0:
            # prefill
            self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens)
        else:
            # decode
            self.update_for_decode(key_states, value_states, layer_idx, batch_size)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        return self.max_seq_len

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects."""
        self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
        self.block_tables.fill_(-1)
        self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
        self.max_seq_len = 0

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        device = self.block_tables.device
        origin_table = self.block_tables.clone()
        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
        mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
        num_blocks = mask.cumsum(-1)[:, -1]
        updated_table = []
        for i in range(beam_idx.shape[0]):
            self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
            updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
        updated_table = torch.cat(tuple(updated_table), dim=0)
        for layer_idx in range(self.num_hidden_layers):
            self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
            self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
        free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
        self.free_blocks = torch.cat((self.free_blocks, free_table))

    def crop(self, maximum_length: int):
        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""

        max_seq_len = self.get_seq_length()
        if maximum_length < 0:
            maximum_length = max_seq_len - abs(maximum_length)

        if max_seq_len <= maximum_length:
            return
        origin_table = self.block_tables.clone()
        for bs in range(self._seen_tokens.shape[0]):
            new_tokens = self._seen_tokens[bs] + maximum_length - max_seq_len
            num_blocks = (new_tokens + self.block_size - 1) // self.block_size
            self.block_tables[bs, num_blocks:] = -1
            self._seen_tokens[bs] = new_tokens
        self.max_seq_len, _ = self._seen_tokens.max(dim=0)
        free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
        self.free_blocks = torch.cat((self.free_blocks, free_table))
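
For readers unfamiliar with paged KV caches, the bookkeeping in `update_for_prefill` above boils down to mapping each token position to a flat slot through the block table. The following standalone sketch (plain PyTorch, no intel_extension_for_pytorch dependency; the names mirror the attributes above, while the concrete sizes and prompt lengths are made-up illustration values) walks through that slot arithmetic:

```python
import torch

# A token at position `pos` of a sequence lives in logical block `pos // block_size`;
# the block table maps that logical block to a physical block, and the flat cache
# slot is physical_block * block_size + pos % block_size.
block_size = 16
max_cache_len = 64
batch_size = 2
num_blocks = ((max_cache_len + block_size - 1) // block_size) * batch_size

block_tables = -1 * torch.ones(batch_size, num_blocks // batch_size, dtype=torch.int32)
free_blocks = torch.arange(num_blocks)

input_lens = torch.tensor([5, 20])  # made-up prompt lengths for the two sequences
all_slots = []
for i in range(batch_size):
    needed = (input_lens[i] + block_size - 1) // block_size
    for b_idx in range(needed):
        if block_tables[i][b_idx] == -1:
            # claim a free physical block for this logical block
            block_tables[i][b_idx] = free_blocks[0]
            free_blocks = free_blocks[1:]
    positions = torch.arange(input_lens[i])
    physical_blocks = block_tables[i][positions // block_size]
    all_slots.append(physical_blocks * block_size + positions % block_size)

slots = torch.cat(all_slots)
print(block_tables)  # which physical blocks each sequence owns
print(slots[:5])     # flat cache slots for sequence 0's five prompt tokens
```

Running it shows sequence 0 occupying physical block 0 (slots 0 to 4) and sequence 1 occupying blocks 1 and 2 (slots 16 to 35), which is the kind of flat slot tensor that `PagedAttention.reshape_and_cache` receives via `self.slots` in the class above.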
33 changes: 19 additions & 14 deletions optimum/exporters/ipex/model_patcher.py
@@ -13,11 +13,10 @@
 # limitations under the License.

 from transformers.models.bert.modeling_bert import BertIntermediate
-from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel
+from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
-    LlamaForCausalLM,
     LlamaModel,
     LlamaRMSNorm,
 )
@@ -28,7 +27,8 @@

 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _gpt2_block_forward,
+    _falcon_model_forward,
+    _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
     _IPEXGPT2Attention,
@@ -39,8 +39,8 @@


 # Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
-_TRANSFORMERS_MIN_VERSION = "4.39.0"
-_TRANSFORMERS_MAX_VERSION = "4.44.99"
+_TRANSFORMERS_MIN_VERSION = "4.46.0"
+_TRANSFORMERS_MAX_VERSION = "4.46.99"

 _IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)

@@ -75,7 +75,7 @@ def patch_op(m, target_m, new_op_name, new_op):
 def _patch_llama_model(model):
     """
     Patch llama model:
-    1. Use IPEX Rope and IAKV cache
+    1. Use IPEX Rope and Paged cache
     2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
     """
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
@@ -88,10 +88,14 @@ def _patch_falcon_model(model):
     """
     Patch falcon model:
     1. Disable SDPA so the attention mask will be compatible to ipex attention.
-    2. Use IPEX Rope and IAKV cache
+    2. Use IPEX Rope and paged cache
     3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
     """
-    model.transformer._use_sdpa = False
+    num_key_value_heads = (
+        model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
+    )
+    setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
     convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
     return model
@@ -103,9 +107,10 @@ def _patch_gpt2_model(model):
     1. Disable SDPA so the attention mask will be compatible to ipex attention.
     2. Use IAKV cache
     """
-    model.transformer._attn_implementation = "eager"
+    num_key_value_heads = model.config.num_attention_heads
+    setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
     convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
     return model


@@ -136,11 +141,11 @@ def _patch_model(model):
         raise ImportError(
             f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
         )
-    if isinstance(model, LlamaForCausalLM):
+    if model.config.model_type == "llama":
         model = _patch_llama_model(model)
-    elif isinstance(model, FalconForCausalLM):
+    elif model.config.model_type == "falcon":
         model = _patch_falcon_model(model)
-    elif isinstance(model, GPT2LMHeadModel):
+    elif model.config.model_type == "gpt2":
         model = _patch_gpt2_model(model)
     elif model.config.model_type == "bert":
         model = _patch_bert_model(model)
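
The patchers above rely on `convert_functions` and `convert_class` helpers to swap in the paged-attention-aware forwards. Those helpers are not shown in this diff, so the sketch below is only a rough mental model, not the actual implementation: a hypothetical `convert_functions_sketch` that rebinds a method on every submodule of a given type, exercised on a dummy module standing in for, say, `LlamaModel` and `_llama_model_forward`.

```python
import types

import torch
from torch import nn


def convert_functions_sketch(model: nn.Module, target_cls, attr_name: str, new_fn):
    """Hypothetical sketch of a convert_functions-style helper: rebind `attr_name`
    (typically "forward") on every submodule of type `target_cls`. The real helper
    used by the patchers above may differ in its details."""
    for module in model.modules():
        if isinstance(module, target_cls):
            setattr(module, attr_name, types.MethodType(new_fn, module))
    return model


# Tiny demo on a dummy module, standing in for e.g. LlamaModel + _llama_model_forward.
class DummyBlock(nn.Module):
    def forward(self, x):
        return x


def _patched_forward(self, x):
    # a paged-attention-aware forward would go here
    return x + 1


model = nn.Sequential(DummyBlock(), DummyBlock())
convert_functions_sketch(model, DummyBlock, "forward", _patched_forward)
print(model(torch.zeros(1)))  # tensor([2.]) -- both blocks now use the patched forward
```

Note also that the dispatch change in `_patch_model` (matching on `model.config.model_type` instead of `isinstance` checks against `LlamaForCausalLM`, `FalconForCausalLM`, and `GPT2LMHeadModel`) lets the same patching path handle any wrapper class that carries the right config.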