Add batch=32, paged attention, and new RoPE module support to all Llama3 demos and tests #15327

Draft: mtairum wants to merge 28 commits into main from mtairum/paged_attn_llama3

Commits (28)
722b71a
#13368: Add rope module to llama3 codebase
mtairum Nov 7, 2024
fa7cd65
#13368: Add paged attention to test llama attention. Update max batch…
mtairum Nov 8, 2024
3fd83f8
#13368: Fix KV cache file to be replicated instead of sharded
mtairum Nov 11, 2024
0328db6
#13368: Add paged attention support and batch=32 support for test dec…
mtairum Nov 11, 2024
a7d3d72
#13368: Add paged attention and batch=32 support to test model. TODO i…
mtairum Nov 12, 2024
08fce34
#13368: Added rope and paged attn support to llama demo. TODO: Check b…
mtairum Nov 13, 2024
2e38fb8
#0: Add llama rope
mtairum Nov 14, 2024
4c13276
#0: Fix llama demo with batch size > 1 and paged attn. TODO: code cle…
mtairum Nov 18, 2024
564fb20
#0: Fixed the llama tests: attn, attn-prefill, decoder-prefill, decod…
mtairum Nov 18, 2024
135df39
#0: Fix test llama model prefill
mtairum Nov 18, 2024
ba73d3c
#0: Relax PCC check for test_llama_accuracy
mtairum Nov 19, 2024
0165ddb
#0: All llama tests now compatible with paged attention and llama rope
mtairum Nov 19, 2024
b4f52a3
#0: Add support for batch sizes that are not divisible by tile size, …
avoraTT Nov 20, 2024
85e0155
#0: Fix assert
mtairum Nov 21, 2024
57f6fc9
#0: use ttnn.argmax multicore for 1 user
mtairum Nov 21, 2024
7d7536d
#0: [REVERT] Added mayo input
mtairum Nov 21, 2024
5cca54d
#0: Remove debug code to speed up demo
mtairum Nov 21, 2024
a967c69
#0: Update debug max seqlen
mtairum Nov 22, 2024
03c2e79
Add padding to position ids to support rope with batch < 32 in trace …
avoraTT Nov 22, 2024
c71c207
#0: Fix llama rope on-host device for single-chip
mtairum Nov 25, 2024
0ec756b
#0: Refactor llama3 test_attention
mtairum Nov 25, 2024
4ff627a
#0: Fix N150-8B demo. Minor fixes after rebase.
mtairum Nov 26, 2024
29770af
#0: Remove sliding_window references from llama3 codebase
mtairum Nov 26, 2024
a1ad346
#0: Remove references to kv_seq_len to simplify llama3 codebase
mtairum Nov 26, 2024
9616567
Update rope to do padding internally. Add comments explaining inconsi…
avoraTT Nov 26, 2024
834fecd
Add fix for accuracy test to work for batch > 1
avoraTT Nov 26, 2024
a7f633c
#0: Refactored all llama3 tests and demo code
mtairum Nov 26, 2024
b15f45d
Merge branch 'mtairum/paged_attn_llama3' of https://github.com/tensto…
mtairum Nov 26, 2024
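
Note: several commits above (b4f52a3, 03c2e79, 9616567) deal with batch sizes and position ids that are not divisible by the 32-wide tile size; per 9616567, the rope module now does this padding internally. A minimal sketch of the padding idea in plain PyTorch — the helper `pad_positions_to_tile` is hypothetical, for illustration only, not code from this PR:

```python
import torch

TILE_SIZE = 32  # Tenstorrent tensors are tiled in 32x32 blocks

def pad_positions_to_tile(positions: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
    """Pad a [batch] tensor of position ids up to the next multiple of TILE_SIZE."""
    batch = positions.shape[0]
    padded_batch = -(-batch // TILE_SIZE) * TILE_SIZE  # ceil to a multiple of 32
    if padded_batch == batch:
        return positions
    pad = torch.full((padded_batch - batch,), pad_value, dtype=positions.dtype)
    return torch.cat([positions, pad])

# A batch of 5 users is padded to 32 entries so it fills a whole tile row.
print(pad_positions_to_tile(torch.arange(5)).shape)  # torch.Size([32])
```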
Files changed (showing changes from all commits)
298 changes: 213 additions & 85 deletions models/demos/llama3/demo/demo.py

Large diffs are not rendered by default.

98 changes: 98 additions & 0 deletions models/demos/llama3/demo/mayo.json

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions models/demos/llama3/demo/simple_vision_demo.py
@@ -49,8 +49,6 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn
     tt_model_args = TtModelArgs(mesh_device, max_batch_size=max_batch_size)
     # limit length or we'll run out of space
     tt_model_args.max_seq_len = max_seq_len
-    tt_model_args.kv_seq_len = max_seq_len
-    tt_model_args.sliding_window = max_seq_len
     checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True)
     model = CrossAttentionTransformer(
         mesh_device,
@@ -63,8 +63,6 @@ def test_llama_cross_attention_transformer_text_inference(
     model_args = TtModelArgs(mesh_device, max_batch_size=batch)
     # Limit the max seqlen to 4k to avoid OOM on host
     model_args.max_seq_len = 4096
-    model_args.kv_seq_len = model_args.max_seq_len
-    model_args.sliding_window = model_args.max_seq_len

     state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu"))
@@ -46,8 +46,6 @@ def test_llama_cross_attention_transformer_block_inference(
     model_args = TtModelArgs(mesh_device, max_batch_size=batch)
     # Limit the max seqlen to 4k to avoid OOM on host
     model_args.max_seq_len = 4096
-    model_args.kv_seq_len = model_args.max_seq_len
-    model_args.sliding_window = model_args.max_seq_len
     state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu"))

     # Ref model needs partial state dict, but our models use full state dict keys as cached weight names
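
The same two-line deletion recurs in all three files above: `kv_seq_len` and `sliding_window` were redundant copies of `max_seq_len`, so setup reduces to a single length parameter. A sketch of the simplified call, assuming the `TtModelArgs` constructor exercised by the updated accuracy test below (it needs a real ttnn `mesh_device` to run):

```python
from models.demos.llama3.tt.model_config import TtModelArgs

# Before: three fields had to be kept in sync by every caller.
#   model_args.max_seq_len = 4096
#   model_args.kv_seq_len = model_args.max_seq_len
#   model_args.sliding_window = model_args.max_seq_len
# After: the sequence-length budget is a single constructor argument.
model_args = TtModelArgs(mesh_device, max_batch_size=1, max_seq_len=4096)
```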
156 changes: 116 additions & 40 deletions models/demos/llama3/tests/test_llama_accuracy.py
@@ -8,21 +8,24 @@
 import os
 import ttnn
 from models.demos.llama3.tt.llama_common import (
-    get_single_rot_mat,
     get_prefill_rot_mat,
     get_rot_transformation_mat,
     HostEmbedding,
+    PagedAttentionConfig,
 )
 from models.demos.llama3.tt.llama_model import TtTransformer
 from models.demos.llama3.tt.model_config import TtModelArgs
+from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
+from models.demos.llama3.demo.demo import preprocess_inputs_prefill


 @torch.no_grad()
 @pytest.mark.timeout(900)
-@pytest.mark.parametrize("prefill_len", [512])
-@pytest.mark.parametrize("decode_len", [128])
+@pytest.mark.parametrize(
+    "prefill_len, decode_len, max_seq_len",  # Max seqlen should be at least prefill_len + decode_len
+    ((512, 128, 1024),),
+)
 @pytest.mark.parametrize(
     "mesh_device",
     [
@@ -32,15 +35,47 @@
     ],
     indirect=True,
 )
-def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cache, reset_seeds):
+@pytest.mark.parametrize(
+    "paged_attention",
+    (
+        True,
+        # False
+    ),
+    ids=(
+        "paged_attention",
+        # "default_attention"
+    ),
+)
+@pytest.mark.parametrize(
+    "paged_attention_params",
+    [{"page_block_size": 32, "page_max_num_blocks": 1024}],
+)
+@pytest.mark.parametrize(
+    "batch_size",
+    (1,),
+)
+def test_tt_model_accuracy(
+    prefill_len,
+    decode_len,
+    max_seq_len,
+    batch_size,
+    paged_attention,
+    paged_attention_params,
+    mesh_device,
+    use_program_cache,
+    reset_seeds,
+    ensure_gc,
+):
     dtype = ttnn.bfloat8_b
     # TODO
     min_top1_acc = 75
     min_top5_acc = 96

+    mesh_device.enable_async(True)
+
     # Load model args and tokenizer
-    model_args = TtModelArgs(mesh_device)
+    model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len)

     tokenizer = Tokenizer(model_args.tokenizer_path)

     # Load state_dict for TT model
@@ -62,13 +97,60 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac
     N = prefill_len + decode_len
     input_ids = reference_tokens[:, : N + 1]  # Shape [1, N+1]

+    # Setup RoPE transformation matrices
+    rope_setup = TtLlamaRotarySetup(
+        mesh_device,
+        model_args.max_batch_size,
+        model_args.head_dim,
+        model_args.max_seq_len,
+        model_args.rope_theta,
+        model_args.use_scaled_rope,
+    )
+    transformation_mats_decode = rope_setup.get_trans_mats()
+
+    transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim)
+    transformation_mats_prefill = ttnn.from_torch(
+        transformation_mats_prefill_torch,
+        dtype=ttnn.bfloat16,
+        layout=ttnn.TILE_LAYOUT,
+        device=mesh_device,
+        memory_config=ttnn.DRAM_MEMORY_CONFIG,
+        mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
+    )
+    transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill}
+
+    page_table_tt = None
+    paged_attention_config = None
+
+    if paged_attention:
+        paged_attention_config = PagedAttentionConfig(
+            block_size=paged_attention_params["page_block_size"],
+            max_num_blocks=paged_attention_params["page_max_num_blocks"],
+        )
+        # Implied shuffling of blocks
+        permutation = torch.randperm(paged_attention_config.max_num_blocks)
+        # Page table which maps virtual blocks to physical
+        reverse_permutation = torch.argsort(permutation)
+        page_table = reverse_permutation.reshape(
+            model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size
+        )
+        page_table_tt = ttnn.from_torch(
+            page_table,
+            device=mesh_device,
+            dtype=ttnn.int32,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
+            mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
+        )
+
     # Initialize TT model
     tt_model = TtTransformer(
         args=model_args,
         mesh_device=mesh_device,
         dtype=dtype,
         state_dict=state_dict,
         weight_cache_path=model_args.weight_cache_path(dtype),
+        transformation_mats=transformation_mats,
+        paged_attention_config=paged_attention_config,
     )
     # Initialize embedding
     embd = HostEmbedding(model_args)
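
The `if paged_attention:` block above is the core of the paged-attention plumbing: physical KV-cache blocks are allocated in a shuffled order, and `page_table` maps each user's virtual block index to a physical slot. A self-contained sketch of that mapping in plain PyTorch (the lookup at the end is illustrative, not code from this PR):

```python
import torch

max_batch_size = 1
block_size = 32        # tokens per KV-cache block (page_block_size above)
max_num_blocks = 1024  # physical blocks in the whole cache (page_max_num_blocks)

# Shuffle the physical blocks, then invert the permutation so that
# page_table[user, virtual_block] gives the physical block index.
permutation = torch.randperm(max_num_blocks)
reverse_permutation = torch.argsort(permutation)
page_table = reverse_permutation.reshape(max_batch_size, max_num_blocks // max_batch_size)

# Token t of a user lives in physical block page_table[user, t // block_size],
# at row t % block_size inside that block.
user, t = 0, 100
print(int(page_table[user, t // block_size]), t % block_size)
```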
@@ -96,50 +178,41 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac
     pt_prefill_input = [embd(input_tokens_prefill_pt[b]).view(1, prefill_lens[b], -1) for b in range(1)]

     # Pre-compute the rotational embedding matrix and send to device
-    rot_mats = get_prefill_rot_mat(
+    rot_mats_prefill = get_prefill_rot_mat(
         model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=prefill_lens[0]
     )
-    transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim)
-    transformation_mats = ttnn.from_torch(
-        transformation_mat_torch,
-        dtype=ttnn.bfloat16,
-        layout=ttnn.TILE_LAYOUT,
-        device=mesh_device,
-        mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
-        memory_config=ttnn.DRAM_MEMORY_CONFIG,
-    )

     prefill_input = model_args.prepare_inputs_ttnn_prefill(
         pt_prefill_input[batch_id],
     )

     tt_out = tt_model(
         prefill_input,
-        None,  # Current position
-        rot_mats,
-        transformation_mats,
+        current_pos=None,
+        rot_mats=rot_mats_prefill,
         user_id=batch_id,
         mode="prefill",
+        page_table=page_table_tt,
         get_last_token=((decoding_pos[batch_id] - 1) // 32) * 32,
     )

     # Start decoding
     logger.info(f"Starting decode...")
     generation_start_pos = prefill_len
     generation_length = decode_len
-    current_pos = ttnn.from_torch(
-        torch.tensor([generation_start_pos]),
+
+    # Initial positions
+    decoding_pos = [generation_start_pos] * model_args.max_batch_size
+    current_pos = torch.tensor([decoding_pos[b] for b in range(model_args.max_batch_size)])
+    current_pos_tensor = ttnn.from_torch(
+        current_pos,
         device=mesh_device,
         dtype=ttnn.int32,
         mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
     )

-    current_rot_mat, rot_matrix = get_single_rot_mat(
-        model_args.head_dim,
-        mesh_device,
-        model_args.num_devices,
-        start_pos=max(0, generation_start_pos - 1),
-    )
+    # Get cos/sin matrices for the current position of each user
+    rot_mats = rope_setup.get_rot_mats(current_pos)

     # Print table header
     logger.info(f"{'Progress':<15}{'Correct':<8}{'True':<15}{'Actual':<15}{'Top 5 Predictions':<75}")
@@ -164,7 +237,13 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac
             model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
         )
         # Run TT model
-        tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
+        tt_out = tt_model(
+            decode_input,
+            current_pos_tensor,
+            rot_mats=rot_mats,
+            mode="decode",
+            page_table=page_table_tt,
+        )

         if tt_model.args.num_devices > 1:
             tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
@@ -173,23 +252,20 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac
             tt_out_gathered = tt_out
         tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True)
         ttnn.deallocate(tt_out_gathered)
-        tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True)
+        tt_out_tok = ttnn.argmax(
+            tt_out_rm,
+            dim=3,
+            use_multicore=True if model_args.max_batch_size == 1 else False,
+        )
         tt_argmax_token = ttnn.to_torch(tt_out_tok, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[
             0, 0, 0, 0
         ]
         ttnn.deallocate(tt_out_rm)
-        current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat)
-        ttnn.plus_one(current_pos)
-
-        # Reset rotation matrix every 100 iterations
-        if i % 100 == 0:  # Doing this every 100 iterations as in demo takes top5 from 99% ->
-            current_rot_mat, rot_matrix_reset = get_single_rot_mat(
-                model_args.head_dim,
-                mesh_device,
-                model_args.num_devices,
-                start_pos=generation_start_pos + i,
-                on_host=False,
-            )
+        ttnn.plus_one(current_pos_tensor)
+
+        # Update rot_mats for next iteration
+        current_pos += 1
+        rot_mats = rope_setup.get_rot_mats(current_pos)

         # Get reference top5 tokens and probabilities for this position
         ref_top5_tokens = top5_tokens[prefill_len + i]
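
One bookkeeping detail in the new loop: the position is tracked twice — on device (`current_pos_tensor`, bumped with `ttnn.plus_one` so decode stays on device) and on host (`current_pos`, which drives the rotary lookup). A skeleton of that pattern, with the device calls left as comments since they require ttnn hardware:

```python
import torch

decode_len = 4
current_pos = torch.tensor([512])  # host copy: one entry per user

for _ in range(decode_len):
    # rot_mats = rope_setup.get_rot_mats(current_pos)              # cos/sin for this step
    # tt_out = tt_model(decode_input, current_pos_tensor,
    #                   rot_mats=rot_mats, mode="decode", page_table=page_table_tt)
    # ttnn.plus_one(current_pos_tensor)                            # device-side increment
    current_pos += 1  # keep the host copy in lock-step

print(current_pos)  # tensor([516])
```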