diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py index 837e03a5dbc..49622ad3482 100644 --- a/models/demos/llama3/demo/demo.py +++ b/models/demos/llama3/demo/demo.py @@ -15,20 +15,27 @@ from pathlib import Path import hashlib +from models.utility_functions import nearest_32 from models.demos.llama3.tt.llama_common import ( - get_single_rot_mat, get_prefill_rot_mat, get_rot_transformation_mat, HostEmbedding, encode_prompt_llama_instruct, + PagedAttentionConfig, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer +from models.demos.llama3.tt.model_config import TtModelArgs from models.perf.benchmarking_utils import BenchmarkProfiler from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf +from models.utility_functions import ( + comp_pcc, +) + def load_and_cache_context(context_url, cache_dir): cache_file = cache_dir / hashlib.md5(context_url.encode()).hexdigest() @@ -65,6 +72,7 @@ def load_inputs(user_input, batch): cache_dir = Path("models/demos/llama3/demo/context_cache") cache_dir.mkdir(parents=True, exist_ok=True) + # TODO Miguel: Clip the long prompt to actually fit within token limit for i in range(batch): prompt = user_input[i]["prompt"] if "context" in user_input[i]: @@ -152,7 +160,18 @@ def preprocess_inputs_prefill( def run_llama3_demo( - user_input, batch_size, single_layer, mesh_device, instruct_mode, is_ci_env, num_batches, print_to_file + user_input, + mesh_device, + max_seq_len, + batch_size, + num_batches, + paged_attention, + paged_attention_config, + max_generated_tokens, + single_layer, + instruct_mode, + is_ci_env, + print_to_file, ): # Creat batch output file timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") @@ -161,10 +180,9 @@ def run_llama3_demo( os.chmod(output_directory, 0o755) output_filename = f"{output_directory}/demo_user_output_{timestamp}.txt" - # This module requires the env paths above for CI runs - from models.demos.llama3.tt.model_config import TtModelArgs - dtype = ttnn.bfloat8_b + assert batch_size <= 32, "Max batch size currently supported is 32" + assert max_seq_len <= 128 * 1024, "Max sequence length must be less than 128k tokens" # We disregard any warmup iteration for profiling, in favour of just measuring compile time on the first iteration N_warmup_iter = {"inference_prefill": 0, "inference_decode": 0} @@ -188,8 +206,10 @@ def run_llama3_demo( for i in range(num_batches): batch_prompts.append([input_prompts[(j + i) % len(input_prompts)] for j in range(len(input_prompts))]) + # TODO Miguel Add configuration for the combinations of Llama3 models and TT architectures and max supported sizes + # Load model args, weights, and tokenizer - model_args = TtModelArgs(mesh_device, instruct=instruct_mode) + model_args = TtModelArgs(mesh_device, instruct=instruct_mode, max_batch_size=batch_size, max_seq_len=max_seq_len) tokenizer = Tokenizer(model_args.tokenizer_path) if single_layer: @@ -200,6 +220,46 @@ def run_llama3_demo( state_dict = model_args.load_state_dict() profiler.end("weight_loading") + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( + mesh_device, + batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, + ) + transformation_mats_decode = rope_setup.get_trans_mats() + + transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim) + transformation_mats_prefill = ttnn.from_torch( + transformation_mats_prefill_torch, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill} + + page_table_tt = None + + if paged_attention: + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + # Load TTNN Llama3.1 model logger.info("Loading weights to device...") profiler.start("loading_weights_to_device") @@ -209,6 +269,8 @@ def run_llama3_demo( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) tt_embd = TtLlamaEmbedding( mesh_device=mesh_device, @@ -223,7 +285,6 @@ def run_llama3_demo( profiler.end("loading_weights_to_device") logger.info("Finished loading weights to device.") - max_generated_tokens = 100 # Maximum number of tokens to generate per user num_tokens_generated_decode = [] logger.info("Starting inference...") @@ -243,6 +304,12 @@ def run_llama3_demo( instruct_mode, max_generated_tokens, ) + + max_encoded_prompt_len = max(len(p) for p in encoded_prompts) + assert ( + max_generated_tokens + max_encoded_prompt_len <= max_seq_len + ), f"Prompt prefill tokens ({max_encoded_prompt_len}) + maximum number of decoded iterations ({max_generated_tokens}) needs to be <= than max_seq_len ({max_seq_len})" + # Prefill embeddings are on host since we need to mask out the tokens after the prefill length after embeddings are computed pt_prefill_input = [embd(input_tokens_prefill_pt[b]).view(1, prefill_lens[b], -1) for b in range(batch_size)] profiler.end(f"preprocess_prefill_inputs", iteration=batch_idx) @@ -256,18 +323,6 @@ def run_llama3_demo( logger.info(f"Starting prefill...") - profiler.start(f"prepare_rot_mat_for_prefill", iteration=batch_idx) - transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.from_torch( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - profiler.end(f"prepare_rot_mat_for_prefill", iteration=batch_idx) - # Do not count the first user for prefill time and instead log it as compile time num_users_generated_prefill = batch_size - 1 if batch_size > 1 else 1 @@ -293,11 +348,11 @@ def run_llama3_demo( tt_out = tt_model( prefill_input, - None, # Current position - rot_mats_prefill, - transformation_mats, + current_pos=None, + rot_mats=rot_mats_prefill, user_id=batch_id, mode="prefill", + page_table=page_table_tt, get_last_token=((decoding_pos[batch_id] - 1) // 32) * 32, ) @@ -311,11 +366,11 @@ def run_llama3_demo( ttnn.deallocate(tt_out) tt_out = tt_model( prefill_input, - None, # Current position - rot_mats_prefill, - transformation_mats, + current_pos=None, + rot_mats=rot_mats_prefill, user_id=batch_id, mode="prefill", + page_table=page_table_tt, get_last_token=((decoding_pos[batch_id] - 1) // 32) * 32, ) @@ -336,8 +391,11 @@ def run_llama3_demo( profiler.start(f"prepare_first_decode_token_{batch_idx}") pt_out_batched = torch.stack(pt_out, dim=-2) pt_out_batched = torch.argmax(pt_out_batched, dim=-1) + # Pad the output tensor to be tile sized tt_out_tok = ttnn.from_torch( - torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0), + torch.nn.functional.pad( + pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 32 - len(pt_out_batched)), "constant", 0 + ), device=mesh_device, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), dtype=ttnn.uint32, @@ -354,32 +412,34 @@ def run_llama3_demo( logger.info("Starting decode...") - profiler.start(f"get_single_rot_mat_decode_{batch_idx}") - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, - mesh_device, - model_args.num_devices, - start_pos=decoding_pos[0] - 2, - ) - profiler.end(f"get_single_rot_mat_decode_{batch_idx}") - # Create events profiler.start(f"compile_trace_{batch_idx}") op_event = ttnn.create_event(mesh_device) write_event = ttnn.create_event(mesh_device) - current_pos = ttnn.from_torch( - torch.tensor(decoding_pos, dtype=torch.int32), + # Initial positions + current_pos = torch.tensor([decoding_pos[b] for b in range(batch_size)]) + + current_pos_tensor = ttnn.from_torch( + current_pos, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) + # Get cos/sin matrices for the current position of each user + rot_mats, rot_mat_idxs = rope_setup.get_rot_mats(current_pos, return_rot_idxs=True) # Compile logger.info(f"Compiling model trace...") decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) - tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) if tt_model.args.num_devices > 1: tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) ttnn.deallocate(tt_out) @@ -387,11 +447,11 @@ def run_llama3_demo( tt_out_gathered = tt_out tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True) ttnn.deallocate(tt_out_gathered) - tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok) + tt_out_tok = ttnn.argmax( + tt_out_rm, dim=3, use_multicore=False if batch_size > 1 else True, output_tensor=tt_out_tok + ) ttnn.deallocate(tt_out_rm) - new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat) - ttnn.plus_one(current_pos) + ttnn.plus_one(current_pos_tensor) profiler.end(f"compile_trace_{batch_idx}") # Capture Trace @@ -401,7 +461,14 @@ def run_llama3_demo( decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) - tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) + rot_mats = rope_setup.get_rot_mats(rot_mat_idxs) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) if tt_model.args.num_devices > 1: tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) ttnn.deallocate(tt_out) @@ -409,28 +476,35 @@ def run_llama3_demo( tt_out_gathered = tt_out tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True) ttnn.deallocate(tt_out_gathered) - tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok) + tt_out_tok = ttnn.argmax( + tt_out_rm, dim=3, use_multicore=False if batch_size > 1 else True, output_tensor=tt_out_tok + ) # TODO Multicore is not compatible with batch > 1 ttnn.deallocate(tt_out_rm) - new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat) - ttnn.plus_one(current_pos) + ttnn.plus_one(current_pos_tensor) + # ttnn.plus_one(rot_mat_idxs) # TODO <- This won't work since embedding requires uint32 and plus_one only works for int32 ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) # Reset the decoding position for the proper run of the model current_pos_reset = ttnn.from_torch( - torch.tensor(decoding_pos, dtype=torch.int32), + current_pos, dtype=ttnn.int32, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if tt_model.args.num_devices > 1 else None, ) tt_out_tok_reset = ttnn.from_torch( - torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0), + torch.nn.functional.pad( + pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 32 - len(pt_out_batched)), "constant", 0 + ), + # torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 30), "constant", 0), dtype=ttnn.uint32, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if tt_model.args.num_devices > 1 else None, ) - ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos) + # Reset the current position and output token tensors for the real decode run + ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos_tensor) ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok) + rot_mat_idxs_reset = rope_setup.get_rot_idxs(current_pos, on_host=True) + ttnn.copy_host_to_device_tensor(rot_mat_idxs_reset, rot_mat_idxs) profiler.end(f"capture_trace_{batch_idx}") @@ -454,6 +528,13 @@ def run_llama3_demo( ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=True) ttnn.record_event(0, op_event) + # Update current pos and mat idxs on host and send to device + # TODO This is required for now since we cannot ttnn.plus_one(rot_mat_idxs) while it being uint32. + # If this tensor is int32, it won't be supported by ttnn.embedding + current_pos += 1 + rot_mat_idxs_updated = rope_setup.get_rot_idxs(current_pos, on_host=True) + ttnn.copy_host_to_device_tensor(rot_mat_idxs_updated, rot_mat_idxs) + # Write to host ttnn.wait_for_event(1, op_event) tt_output_torch = ttnn.to_torch( @@ -506,19 +587,6 @@ def run_llama3_demo( iteration += 1 - # Reset rotation matrix every 100 iterations - profiler.start(f"reset_rot_mat_{iteration-1}", iteration=batch_idx) - if iteration % 100 == 0: - current_rot_mat_reset, rot_matrix_reset = get_single_rot_mat( - model_args.head_dim, - mesh_device, - model_args.num_devices, - start_pos=decoding_pos[0] + iteration, - on_host=True, - ) - ttnn.copy_host_to_device_tensor(current_rot_mat_reset, current_rot_mat) - profiler.end(f"reset_rot_mat_{iteration-1}", iteration=batch_idx) - # Upper limit of generated tokens for each user (to avoid infinite generation in case eos is not seen) if iteration >= max_generated_tokens: users_decoding = False @@ -593,13 +661,10 @@ def run_llama3_demo( "loading_inputs": profiler.get_duration("loading_inputs"), "weight_loading": profiler.get_duration("weight_loading"), "prepare_first_decode_token": profiler.get_duration("prepare_first_decode_token_0"), - "get_single_rot_mat_decode": profiler.get_duration("get_single_rot_mat_decode_0"), # Only for batch 0 "preprocess_prefill_inputs": profiler.get_duration("preprocess_prefill_inputs"), "loading_weights_to_device": profiler.get_duration("loading_weights_to_device"), - "prepare_rot_mat_for_prefill": profiler.get_duration("prepare_rot_mat_for_prefill"), "compile_trace": profiler.get_duration("compile_trace_0"), # Only for batch 0 "capture_trace": profiler.get_duration("capture_trace_0"), # Only for batch 0 - "reset_rot_mat": sum(profiler.get_duration(f"reset_rot_mat_{i}") for i in range(max_generated_tokens)), "Total compile time": compile_prefill_time + compile_decode_time, "Full demo runtime": profiler.get_duration("run"), } @@ -699,23 +764,30 @@ def run_llama3_demo( ) +# input_prompts: Input file size with prompts to process +# max_seq_len: Maximum sequence length supported by the model (max size = 128 * 1024) +# instruct_weights: Whether to use instruct weights or general weights +# Num_batches: How many consecutive batches of users are run +# single_layer: Whether to run the model with a single layer (for debug) @pytest.mark.parametrize( - "input_prompts, instruct_weights, num_batches, single_layer", + "input_prompts, max_seq_len, instruct_weights, num_batches, single_layer", [ - ("models/demos/llama3/demo/input_data_prefill_128.json", False, 1, False), - ("models/demos/llama3/demo/input_data_prefill_128.json", False, 2, False), - ("models/demos/llama3/demo/input_data_questions_prefill_128.json", True, 1, False), - ("models/demos/llama3/demo/input_data_questions_prefill_128.json", True, 2, False), - ("models/demos/llama3/demo/input_data_long.json", True, 1, False), - ("models/demos/llama3/demo/input_data_questions_prefill_128.json", True, 1, True), + ("models/demos/llama3/demo/input_data_prefill_128.json", 1024, False, 1, False), + ("models/demos/llama3/demo/input_data_prefill_128.json", 1024, False, 2, False), + ("models/demos/llama3/demo/input_data_questions_prefill_128.json", 1024, True, 1, False), + ("models/demos/llama3/demo/input_data_questions_prefill_128.json", 1024, True, 2, False), + ("models/demos/llama3/demo/input_data_long.json", 128 * 1024, True, 1, False), + ("models/demos/llama3/demo/input_data_questions_prefill_128.json", 1024, True, 1, True), + ("models/demos/llama3/demo/mayo.json", 1024, True, 1, False), ], ids=[ - "general_weights-1_batch", - "general_weights-2_batch", - "instruct_weights-1_batch", - "instruct_weights-2_batch", - "instruct_weights-long", + "general-1_batch", + "general-2_batch", + "instructs-1_batch", + "instruct-2_batch", + "instruct-long", "single_layer", + "mayo", # TODO Miguel: Remove this debug test ], ) @pytest.mark.parametrize("device_params", [{"trace_region_size": 23887872, "num_command_queues": 2}], indirect=True) @@ -728,21 +800,77 @@ def run_llama3_demo( ], indirect=True, ) +# NOTE: Varying the batch size will result in slightly different outputs. +# For example, a prompt w/ 1 user vs, the same prompt repeated N times for N users, will produce different outputs +# This is because the SDPA op in decode mode has different number of reductions depending on the number of users +# Which leads to slightly different outputs (due to accumulated errors) +@pytest.mark.parametrize( + "batch_size", + ( + 1, + 32, + ), +) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False + ), + ids=( + "paged_attention", + # "default_attention" + ), +) +@pytest.mark.parametrize( # TODO Substitute these values for a proper vLLM integration + "paged_attention_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "max_generated_tokens", # Maximum number of tokens to decode, per user + (100,), +) def test_llama_demo( - mesh_device, use_program_cache, input_prompts, instruct_weights, is_ci_env, num_batches, single_layer, reset_seeds + input_prompts, + max_seq_len, + instruct_weights, + batch_size, + num_batches, + paged_attention, + paged_attention_params, + max_generated_tokens, + single_layer, + mesh_device, + use_program_cache, + is_ci_env, + reset_seeds, ): - if is_ci_env and (instruct_weights == False or "long" in input_prompts or single_layer == True): - pytest.skip("CI demo test only runs instruct weights to reduce CI pipeline load (both are supported)") + if is_ci_env and (instruct_weights == False or "long" in input_prompts or single_layer == True or batch_size > 1): + pytest.skip( + "CI demo test only runs instruct weights with batch_size=1 to reduce CI pipeline load (all modes are supported)" + ) mesh_device.enable_async(True) + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=paged_attention_params["page_block_size"], + max_num_blocks=paged_attention_params["page_max_num_blocks"], + ) + else: + paged_attention_config = None + return run_llama3_demo( user_input=input_prompts, - batch_size=1, - single_layer=single_layer, mesh_device=mesh_device, + max_seq_len=max_seq_len, + batch_size=batch_size, + num_batches=num_batches, + paged_attention=paged_attention, + paged_attention_config=paged_attention_config, + max_generated_tokens=max_generated_tokens, + single_layer=single_layer, instruct_mode=instruct_weights, is_ci_env=is_ci_env, - num_batches=num_batches, print_to_file=False, ) diff --git a/models/demos/llama3/demo/mayo.json b/models/demos/llama3/demo/mayo.json new file mode 100644 index 00000000000..fe794eec893 --- /dev/null +++ b/models/demos/llama3/demo/mayo.json @@ -0,0 +1,98 @@ +[ + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + }, + { + "prompt": "Do you have mayonnaise recipes? Mayonnaise is a versatile ingredient that can be used in countless recipes beyond just a sandwich spread. What are some of your favorite ways to use mayonnaise in cooking or baking? Do you have a special recipe for a creamy potato salad, a tangy coleslaw, or perhaps a savory dip for vegetables and chips? Mayonnaise can also be used as a base for homemade dressings and sauces, adding richness and flavor to your dishes. Have you tried baking with mayonnaise to keep cakes moist and tender? Share any recipes, tips, or creative uses you have for mayonnaise. How did you discover these recipes, and do you have any variations that you particularly enjoy?" + } +] diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index 673c3bc5a73..b4946c3eecf 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -49,8 +49,6 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn tt_model_args = TtModelArgs(mesh_device, max_batch_size=max_batch_size) # limit length or we'll run out of space tt_model_args.max_seq_len = max_seq_len - tt_model_args.kv_seq_len = max_seq_len - tt_model_args.sliding_window = max_seq_len checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True) model = CrossAttentionTransformer( mesh_device, diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 172531645c9..7448601b8ce 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -63,8 +63,6 @@ def test_llama_cross_attention_transformer_text_inference( model_args = TtModelArgs(mesh_device, max_batch_size=batch) # Limit the max seqlen to 4k to avoid OOM on host model_args.max_seq_len = 4096 - model_args.kv_seq_len = model_args.max_seq_len - model_args.sliding_window = model_args.max_seq_len state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 1b0013c78ee..96637e5090c 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -46,8 +46,6 @@ def test_llama_cross_attention_transformer_block_inference( model_args = TtModelArgs(mesh_device, max_batch_size=batch) # Limit the max seqlen to 4k to avoid OOM on host model_args.max_seq_len = 4096 - model_args.kv_seq_len = model_args.max_seq_len - model_args.sliding_window = model_args.max_seq_len state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) # Ref model needs partial state dict, but our models use full state dict keys as cached weight names diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py index acdcc257901..bbc896bc064 100644 --- a/models/demos/llama3/tests/test_llama_accuracy.py +++ b/models/demos/llama3/tests/test_llama_accuracy.py @@ -8,21 +8,24 @@ import os import ttnn from models.demos.llama3.tt.llama_common import ( - get_single_rot_mat, get_prefill_rot_mat, get_rot_transformation_mat, HostEmbedding, + PagedAttentionConfig, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.model_config import TtModelArgs +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.demos.llama3.demo.demo import preprocess_inputs_prefill @torch.no_grad() @pytest.mark.timeout(900) -@pytest.mark.parametrize("prefill_len", [512]) -@pytest.mark.parametrize("decode_len", [128]) +@pytest.mark.parametrize( + "prefill_len, decode_len, max_seq_len", # Max seqlen should be at least prefill_len + decode_len + ((512, 128, 1024),), +) @pytest.mark.parametrize( "mesh_device", [ @@ -32,15 +35,47 @@ ], indirect=True, ) -def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cache, reset_seeds): +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False + ), + ids=( + "paged_attention", + # "default_attention" + ), +) +@pytest.mark.parametrize( + "paged_attention_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_tt_model_accuracy( + prefill_len, + decode_len, + max_seq_len, + batch_size, + paged_attention, + paged_attention_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b + # TODO min_top1_acc = 75 min_top5_acc = 96 mesh_device.enable_async(True) # Load model args and tokenizer - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) + tokenizer = Tokenizer(model_args.tokenizer_path) # Load state_dict for TT model @@ -62,6 +97,51 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac N = prefill_len + decode_len input_ids = reference_tokens[:, : N + 1] # Shape [1, N+1] + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( + mesh_device, + model_args.max_batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, + ) + transformation_mats_decode = rope_setup.get_trans_mats() + + transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim) + transformation_mats_prefill = ttnn.from_torch( + transformation_mats_prefill_torch, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill} + + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=paged_attention_params["page_block_size"], + max_num_blocks=paged_attention_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + # Initialize TT model tt_model = TtTransformer( args=model_args, @@ -69,6 +149,8 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) # Initialize embedding embd = HostEmbedding(model_args) @@ -96,18 +178,9 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac pt_prefill_input = [embd(input_tokens_prefill_pt[b]).view(1, prefill_lens[b], -1) for b in range(1)] # Pre-compute the rotational embedding matrix and send to device - rot_mats = get_prefill_rot_mat( + rot_mats_prefill = get_prefill_rot_mat( model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=prefill_lens[0] ) - transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.from_torch( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) prefill_input = model_args.prepare_inputs_ttnn_prefill( pt_prefill_input[batch_id], @@ -115,11 +188,11 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac tt_out = tt_model( prefill_input, - None, # Current position - rot_mats, - transformation_mats, + current_pos=None, + rot_mats=rot_mats_prefill, user_id=batch_id, mode="prefill", + page_table=page_table_tt, get_last_token=((decoding_pos[batch_id] - 1) // 32) * 32, ) @@ -127,19 +200,19 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac logger.info(f"Starting decode...") generation_start_pos = prefill_len generation_length = decode_len - current_pos = ttnn.from_torch( - torch.tensor([generation_start_pos]), + + # Initial positions + decoding_pos = [generation_start_pos] * model_args.max_batch_size + current_pos = torch.tensor([decoding_pos[b] for b in range(model_args.max_batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, device=mesh_device, dtype=ttnn.int32, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, - mesh_device, - model_args.num_devices, - start_pos=max(0, generation_start_pos - 1), - ) + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) # Print table header logger.info(f"{'Progress':<15}{'Correct':<8}{'True':<15}{'Actual':<15}{'Top 5 Predictions':<75}") @@ -164,7 +237,13 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) # Run TT model - tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) if tt_model.args.num_devices > 1: tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) @@ -173,23 +252,20 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac tt_out_gathered = tt_out tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True) ttnn.deallocate(tt_out_gathered) - tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True) + tt_out_tok = ttnn.argmax( + tt_out_rm, + dim=3, + use_multicore=True if model_args.max_batch_size == 1 else False, + ) tt_argmax_token = ttnn.to_torch(tt_out_tok, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ 0, 0, 0, 0 ] ttnn.deallocate(tt_out_rm) - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - ttnn.plus_one(current_pos) - - # Reset rotation matrix every 100 iterations - if i % 100 == 0: # Doing this every 100 iterations as in demo takes top5 from 99% -> - current_rot_mat, rot_matrix_reset = get_single_rot_mat( - model_args.head_dim, - mesh_device, - model_args.num_devices, - start_pos=generation_start_pos + i, - on_host=False, - ) + ttnn.plus_one(current_pos_tensor) + + # Update rot_mats for next iteration + current_pos += 1 + rot_mats = rope_setup.get_rot_mats(current_pos) # Get reference top5 tokens and probabilities for this position ref_top5_tokens = top5_tokens[prefill_len + i] diff --git a/models/demos/llama3/tests/test_llama_attention.py b/models/demos/llama3/tests/test_llama_attention.py index c41ac5644ca..010e5079e0a 100644 --- a/models/demos/llama3/tests/test_llama_attention.py +++ b/models/demos/llama3/tests/test_llama_attention.py @@ -7,10 +7,11 @@ import os import ttnn from models.demos.llama3.tt.llama_attention import TtLlamaAttention +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( precompute_freqs, - get_single_rot_mat, + PagedAttentionConfig, ) from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Attention from models.utility_functions import ( @@ -31,14 +32,47 @@ ], indirect=True, ) -def test_llama_attention_inference(mesh_device, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False, + ), + ids=( + "paged_attention", + # "default_attention", + ), +) +@pytest.mark.parametrize( + "paged_attention_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) +def test_llama_attention_inference( + max_seq_len, + batch_size, + paged_attention, + paged_attention_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b pcc = 0.99 mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) - model_args.n_layers = 1 + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) + model_args.n_layers = 1 # For the unit test, just run a sigle layer + state_dict = model_args.load_state_dict() first_layer_prefix = model_args.get_state_dict_prefix("TtLlamaAttention", 0) + "." @@ -50,44 +84,79 @@ def test_llama_attention_inference(mesh_device, use_program_cache, reset_seeds, reference_model = Attention(args=model_args) reference_model.load_state_dict(partial_state_dict) - batch = model_args.max_batch_size seq_len = 1 generation_start_pos = 0 generation_length = 10 all_tests_pass = True - # pre-compute the rotational embedding matrix and send to device - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( mesh_device, - model_args.num_devices, - start_pos=0, + batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, ) + transformation_mats = rope_setup.get_trans_mats() + transformation_mats = {"decode": transformation_mats} + + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=paged_attention_params["page_block_size"], + max_num_blocks=paged_attention_params["page_max_num_blocks"], + ) + + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_model = TtLlamaAttention( mesh_device, state_dict, weight_cache_path=model_args.weight_cache_path(dtype), layer_num=0, dtype=dtype, + transformation_mats=transformation_mats, configuration=model_args, + paged_attention_config=paged_attention_config, ) - cos, sin = precompute_freqs(model_args.head_dim, model_args.max_seq_len * 2) + cos, sin = precompute_freqs( + model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope + ) freqs_cis = torch.complex(cos, sin) + + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + for i in range(generation_length): # 70B attention block typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1 - pt_attention_input = torch.randn(batch, seq_len, model_args.dim) * 0.05 + pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim) * 0.05 tt_attention_input = pt_attention_input.clone() - current_pos = generation_start_pos + i - current_pos_tensor = ttnn.from_torch( - torch.tensor([current_pos] * batch), - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) attention_input = model_args.prepare_inputs_ttnn_decode( tt_attention_input, @@ -95,48 +164,84 @@ def test_llama_attention_inference(mesh_device, use_program_cache, reset_seeds, force_replicated=True, ) - tt_out = tt_model(attention_input, current_pos_tensor, rot_mats=current_rot_mat, mode="decode") + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) + + tt_out = tt_model( + attention_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) # multi-device attention module returns replicated output tt_output_torch = ( ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[0, :, :, : model_args.dim] .view(1, -1, model_args.dim) .permute(1, 0, 2)[: model_args.max_batch_size, :, :] - ) # [ batch, seq, hidden_dim] + ) # [ batch_size, seq, hidden_dim] - freqs_cis_i = freqs_cis[current_pos, :].unsqueeze(0) - # positions = torch.tensor([current_pos]) + # In this test all users have the same position + freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0) - reference_output = reference_model(pt_attention_input, current_pos, freqs_cis_i, mask=None) + reference_output = reference_model(pt_attention_input, current_pos[0], freqs_cis_i, mask=None) passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") if passing: - logger.info(f"[pos={current_pos}] Llama_Attention Passed!") + logger.info(f"[pos={current_pos[0]}] Llama_Attention Passed!") else: - logger.warning(f"[pos={current_pos}] Llama_Attention Failed!") + logger.warning(f"[pos={current_pos[0]}] Llama_Attention Failed!") all_tests_pass = False - # Update rotation matrix for next iteration - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) + # Increment position + current_pos = torch.tensor([generation_start_pos + i for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) check_kv_cache = True if check_kv_cache: # PyTorch output -------------------------------------------------------------------- pytorch_layer_present = [ - reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch, n_kv_heads, seq, head_dim] - reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch, n_kv_heads, seq, head_dim] + reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] + reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] ] # TT hardware execution ------------------------------------------------------------- - tt_layer_present = [ - ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) - for cache in tt_model.layer_past - ] + if paged_attention: + tt_layer_present = [ + ( + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + reverse_permutation + ] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch_size, ... + ] + ) + for cache in tt_model.layer_past + ] + else: + tt_layer_present = [ + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) + for cache in tt_model.layer_past + ] for i, (cache_pt, cache_tt) in enumerate(zip(pytorch_layer_present, tt_layer_present)): - cache_length_to_check = min(model_args.sliding_window, generation_start_pos + generation_length + 1) + cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + generation_length + 1) cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :] does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc) diff --git a/models/demos/llama3/tests/test_llama_attention_prefill.py b/models/demos/llama3/tests/test_llama_attention_prefill.py index fe3f1834eae..ad4f0c96bfe 100644 --- a/models/demos/llama3/tests/test_llama_attention_prefill.py +++ b/models/demos/llama3/tests/test_llama_attention_prefill.py @@ -11,6 +11,7 @@ from models.demos.llama3.tt.llama_common import ( get_prefill_rot_mat, get_rot_transformation_mat, + PagedAttentionConfig, ) from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Attention, precompute_freqs_cis from models.utility_functions import ( @@ -22,10 +23,6 @@ @torch.no_grad() @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "seq_len", - (2048,), -) @pytest.mark.parametrize( "mesh_device", [ @@ -35,13 +32,45 @@ ], indirect=True, ) -def test_llama_attention_inference(seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False + ), + ids=( + "paged_attention", + # "default_attention" + ), +) +@pytest.mark.parametrize( + "paged_attention_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (2048,), +) +def test_llama_attention_inference( + max_seq_len, + batch_size, + paged_attention, + paged_attention_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b pcc = 0.99 mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) model_args.n_layers = 1 state_dict = model_args.load_state_dict() @@ -53,45 +82,78 @@ def test_llama_attention_inference(seq_len, mesh_device, use_program_cache, rese reference_model = Attention(args=model_args) reference_model.load_state_dict(partial_state_dict) - batch = 1 - # pre-compute the rotational embedding matrix and send to device rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len) transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.as_tensor( + transformation_mats_prefill = ttnn.as_tensor( transformation_mat_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) + transformation_mats = {"prefill": transformation_mats_prefill} + generation_start_pos = 0 generation_length = 3 all_tests_pass = True + # Setup page table + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=paged_attention_params["page_block_size"], + max_num_blocks=paged_attention_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_model = TtLlamaAttention( mesh_device, state_dict, weight_cache_path=model_args.weight_cache_path(dtype), layer_num=0, dtype=dtype, + transformation_mats=transformation_mats, configuration=model_args, + paged_attention_config=paged_attention_config, ) - pt_attention_input = (torch.rand(batch, seq_len, model_args.dim) * 2) - 1 + pt_attention_input = (torch.rand(batch_size, seq_len, model_args.dim) * 2) - 1 tt_attention_input = pt_attention_input.clone() attention_input = model_args.prepare_inputs_ttnn_prefill( tt_attention_input, force_replicated=True, ) - tt_out = tt_model(attention_input, 0, rot_mats, transformation_mats, user_id=0, mode="prefill") + tt_out = tt_model( + attention_input, + current_pos=None, + rot_mats=rot_mats, + user_id=0, + mode="prefill", + page_table=page_table_tt, + ) tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ 0, :, :, : model_args.dim ].view( - batch, seq_len, -1 - ) # [ batch, seq, dim] + batch_size, seq_len, -1 + ) # [ batch_size, seq, dim] positions = torch.LongTensor(range(seq_len)) freqs_cis_i = precompute_freqs_cis( @@ -115,17 +177,36 @@ def test_llama_attention_inference(seq_len, mesh_device, use_program_cache, rese if check_kv_cache: # PyTorch output -------------------------------------------------------------------- pytorch_layer_present = [ - reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch, n_kv_heads, seq, head_dim] - reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch, n_kv_heads, seq, head_dim] + reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] + reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] ] # TT hardware execution ------------------------------------------------------------- - tt_layer_present = [ - ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) - for cache in tt_model.layer_past - ] + if paged_attention: + tt_layer_present = [ + ( + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[reverse_permutation] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch_size, ... + ] + ) + for cache in tt_model.layer_past + ] + else: + tt_layer_present = [ + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) + for cache in tt_model.layer_past + ] for i, (cache_pt, cache_tt) in enumerate(zip(pytorch_layer_present, tt_layer_present)): - cache_length_to_check = min(model_args.sliding_window, generation_start_pos + generation_length + 1) + cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + generation_length + 1) cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :] does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc) diff --git a/models/demos/llama3/tests/test_llama_decoder.py b/models/demos/llama3/tests/test_llama_decoder.py index 1fad070640b..d5ec9833e32 100644 --- a/models/demos/llama3/tests/test_llama_decoder.py +++ b/models/demos/llama3/tests/test_llama_decoder.py @@ -8,10 +8,11 @@ import ttnn from models.demos.llama3.tt.llama_common import ( precompute_freqs, - get_single_rot_mat, + PagedAttentionConfig, ) -from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.demos.llama3.tt.model_config import TtModelArgs +from models.demos.llama3.tt.llama_decoder import TtTransformerBlock +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import TransformerBlock from models.utility_functions import ( comp_pcc, @@ -31,13 +32,45 @@ ], indirect=True, ) -def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False + ), + ids=( + "paged_attention", + # "default_attention" + ), +) +@pytest.mark.parametrize( + "paged_attention_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) +def test_llama_decoder_inference( + max_seq_len, + batch_size, + paged_attention, + paged_attention_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b - mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) model_args.n_layers = 1 + state_dict = model_args.load_state_dict() # Ref model needs partial state dict, but our models use full state dict keys as cached weight names @@ -52,13 +85,41 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en generation_length = 10 all_tests_pass = True - # pre-compute the rotational embedding matrix and send to device - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( mesh_device, - model_args.num_devices, - start_pos=0, + model_args.max_batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, ) + transformation_mats = rope_setup.get_trans_mats() + transformation_mats = {"decode": transformation_mats} + + # Prepare page table for paged attention + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=paged_attention_params["page_block_size"], + max_num_blocks=paged_attention_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) # Initialize TT model tt_model = TtTransformerBlock( @@ -68,27 +129,31 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en state_dict=state_dict, layer_num=0, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) seqlen = 1 - batch = model_args.max_batch_size - cos, sin = precompute_freqs(model_args.head_dim, model_args.max_seq_len * 2) + cos, sin = precompute_freqs( + model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope + ) freqs_cis = torch.complex(cos, sin) + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) for i in range(generation_length): logger.info(f"[Decoder] Generating token {i}") # input = torch.randn(1, 32, 4096) - pt_decode_input = (torch.rand(batch, seqlen, model_args.dim) * 2) - 1 + pt_decode_input = (torch.rand(batch_size, seqlen, model_args.dim) * 2) - 1 tt_decode_input = pt_decode_input.clone() - current_pos = generation_start_pos + i - current_pos_tensor = ttnn.from_torch( - torch.tensor([current_pos] * batch), - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) decode_input = model_args.prepare_inputs_ttnn_decode( tt_decode_input, @@ -96,20 +161,31 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) + # Run TT model - tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) + tt_output_torch = ( ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ :1, :, :, : model_args.dim ] .permute(2, 1, 0, 3) .squeeze(1)[: model_args.max_batch_size, :, :] - ) # [seq, batch, dim] + ) # [seq, batch_size, dim] - freqs_cis_i = freqs_cis[current_pos, :].unsqueeze(0) + # In this test all users have the same position + freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0) # Reference model - ref_output = reference_model(pt_decode_input, current_pos, freqs_cis_i, mask=None) + ref_output = reference_model(pt_decode_input, current_pos[0], freqs_cis_i, mask=None) passing, pcc_message = comp_pcc(ref_output, tt_output_torch) @@ -122,8 +198,14 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en logger.warning("Llama Decoder Block Failed!") all_tests_pass = False - # Update rotation matrix for next iteration - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) + # Increment position + current_pos = torch.tensor([generation_start_pos + i for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) if all_tests_pass: logger.info(f"All {generation_length} Llama decode iterations Passed!") diff --git a/models/demos/llama3/tests/test_llama_decoder_prefill.py b/models/demos/llama3/tests/test_llama_decoder_prefill.py index 998a4ab2f39..a44ec69fd99 100644 --- a/models/demos/llama3/tests/test_llama_decoder_prefill.py +++ b/models/demos/llama3/tests/test_llama_decoder_prefill.py @@ -9,6 +9,7 @@ from models.demos.llama3.tt.llama_common import ( get_prefill_rot_mat, get_rot_transformation_mat, + PagedAttentionConfig, ) from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.demos.llama3.tt.model_config import TtModelArgs @@ -22,13 +23,6 @@ @torch.no_grad() @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "seq_len", - ( - 4096, - 128, - ), -) @pytest.mark.parametrize( "mesh_device", [ @@ -38,13 +32,48 @@ ], indirect=True, ) -def test_llama_decoder_inference(mesh_device, seq_len, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False + ), + ids=( + "paged_attention", + # "default_attention" + ), +) +@pytest.mark.parametrize( + "paged_attention_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + ( + 4096, + 128, + ), +) +def test_llama_decoder_inference( + max_seq_len, + batch_size, + paged_attention, + paged_attention_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b - mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) model_args.n_layers = 1 + state_dict = model_args.load_state_dict() # Ref model needs partial state dict, but our models use full state dict keys as cached weight names @@ -52,7 +81,7 @@ def test_llama_decoder_inference(mesh_device, seq_len, use_program_cache, reset_ partial_state_dict = { k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - batch = 1 + reference_model = TransformerBlock(layer_id=0, args=model_args) reference_model.load_state_dict(partial_state_dict) @@ -61,51 +90,84 @@ def test_llama_decoder_inference(mesh_device, seq_len, use_program_cache, reset_ all_tests_pass = True # pre-compute the rotational embedding matrix and send to device - rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len) + rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=max_seq_len) transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.as_tensor( + transformation_mats_prefill = ttnn.as_tensor( transformation_mat_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) + transformation_mats = {"prefill": transformation_mats_prefill} + + # Setup page table + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=paged_attention_params["page_block_size"], + max_num_blocks=paged_attention_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) # Initialize TT model tt_model = TtTransformerBlock( - args=model_args, mesh_device=mesh_device, - dtype=dtype, state_dict=state_dict, - layer_num=0, weight_cache_path=model_args.weight_cache_path(dtype), + layer_num=0, + dtype=dtype, + transformation_mats=transformation_mats, + args=model_args, + paged_attention_config=paged_attention_config, ) - # TODO Update start_pos (check llama test for reference) for i in range(generation_length): logger.info(f"[Decoder] Generating token {i}") - pt_decode_input = (torch.rand(batch, seq_len, model_args.dim) * 2) - 1 + pt_decode_input = (torch.rand(batch_size, max_seq_len, model_args.dim) * 2) - 1 tt_decode_input = pt_decode_input.clone() decode_input = model_args.prepare_inputs_ttnn_prefill( tt_decode_input, ) - positions = torch.LongTensor(range(seq_len)) + positions = torch.LongTensor(range(max_seq_len)) freqs_cis_i = precompute_freqs_cis( model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope )[positions] # Reference model - attn_mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min) + attn_mask = torch.full((max_seq_len, max_seq_len), torch.finfo(torch.float32).min) attn_mask_torch = torch.triu(attn_mask, diagonal=1) ref_output = reference_model(pt_decode_input, positions[0], freqs_cis_i, mask=attn_mask_torch) # Run TT model - tt_out = tt_model(decode_input, None, rot_mats, transformation_mats, user_id=0, mode="prefill") + tt_out = tt_model( + decode_input, + current_pos=None, + rot_mats=rot_mats, + user_id=0, + mode="prefill", + page_table=page_table_tt, + ) tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ 0, :, :, : model_args.dim ].view( - batch, seq_len, -1 - ) # [ batch, seq, hidden_dim] + batch_size, max_seq_len, -1 + ) # [ batch_size, seq, hidden_dim] passing, pcc_message = comp_pcc(ref_output, tt_output_torch) logger.info(comp_allclose(ref_output, tt_output_torch)) diff --git a/models/demos/llama3/tests/test_llama_embedding.py b/models/demos/llama3/tests/test_llama_embedding.py index e8178f7e2e1..d5223b64254 100644 --- a/models/demos/llama3/tests/test_llama_embedding.py +++ b/models/demos/llama3/tests/test_llama_embedding.py @@ -28,15 +28,22 @@ ], indirect=True, ) -def test_llama_embedding(mesh_device, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) +def test_llama_embedding(max_seq_len, batch_size, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 - mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) model_args.n_layers = 1 - state_dict = model_args.load_state_dict() + state_dict = model_args.load_state_dict() tokenizer = Tokenizer(model_args.tokenizer_path) reference_emb = HostEmbedding(model_args) diff --git a/models/demos/llama3/tests/test_llama_mlp.py b/models/demos/llama3/tests/test_llama_mlp.py index fa7655dd6ff..b810cb357bd 100644 --- a/models/demos/llama3/tests/test_llama_mlp.py +++ b/models/demos/llama3/tests/test_llama_mlp.py @@ -19,15 +19,6 @@ @torch.no_grad() @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "seq_len", - ( - 64 * 1024, - 32 * 1024, - # 1024, - 32, - ), -) @pytest.mark.parametrize( "mesh_device", [ @@ -37,13 +28,21 @@ ], indirect=True, ) -def test_llama_mlp_inference(mesh_device, seq_len, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "seq_len", + ( + 64 * 1024, + 32 * 1024, + 32, + ), +) +def test_llama_mlp_inference(seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat8_b mode = "decode" if seq_len <= 32 else "prefill" mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=1, max_seq_len=128) model_args.n_layers = 1 state_dict = model_args.load_state_dict() diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py index 803381ffce3..e8d276de9c2 100644 --- a/models/demos/llama3/tests/test_llama_model.py +++ b/models/demos/llama3/tests/test_llama_model.py @@ -8,13 +8,14 @@ import ttnn from models.demos.llama3.tt.llama_common import ( precompute_freqs, - get_single_rot_mat, sample, encode_prompt_llama_instruct, HostEmbedding, + PagedAttentionConfig, ) -from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.model_config import TtModelArgs +from models.demos.llama3.tt.llama_model import TtTransformer +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.utility_functions import ( @@ -36,6 +37,29 @@ ], ids=["quick", "full"], ) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False, + ), + ids=( + "paged_attention", + # "default_attention", + ), +) +@pytest.mark.parametrize( + "paged_attention_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) @pytest.mark.parametrize( "mesh_device", [ @@ -45,7 +69,18 @@ ], indirect=True, ) -def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, reset_seeds, ensure_gc): +def test_llama_model_inference( + weights, + layers, + max_seq_len, + batch_size, + paged_attention, + paged_attention_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): run_ref_pt = True # Flag to run reference PyTorch model and compare PCC cache_pcc = layers == 1 # Flag to measure KV cache PCC. Avoid running for all layers to speed up test time. @@ -54,11 +89,19 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, mesh_device.enable_async(True) # This sets the minimum PCC for each iteration - pcc = 0.88 if layers == 1 else 0.94 # TODO For model test quick (1 layer) one iteration might get a worse PCC + if batch_size == 1: + pcc = 0.88 if layers == 1 else 0.94 # TODO For model test quick (1 layer) one iteration might get a worse PCC + else: + pcc = 0.7 # TODO Miguel: Investigate lower PCC with batch_size > 1 instruct = True if weights == "instruct" else False dummy_weights = True if weights == "random" else False - model_args = TtModelArgs(mesh_device, instruct=instruct, dummy_weights=dummy_weights) + model_args = TtModelArgs( + mesh_device, instruct=instruct, dummy_weights=dummy_weights, max_seq_len=max_seq_len, max_batch_size=batch_size + ) + + # Reduce max seq len and KV cache seq_len params to speed up the test + model_args.max_seq_len = 128 model_name = { (16, False): "llama32_1b", @@ -116,7 +159,9 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, prompts = ["This is a test"] * model_args.max_batch_size if dummy_weights: - encoded_prompts = [[128000, 2028, 374, 264, 1296]] # "This is a test" encoded prompt + encoded_prompts = [ + [128000, 2028, 374, 264, 1296] + ] * model_args.max_batch_size # "This is a test" encoded prompt assert not instruct, "Instruct prompt not implemented with dummy weights" else: tokenizer = Tokenizer(model_args.tokenizer_path) @@ -136,13 +181,41 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, generation_start_pos = 0 generation_length = iterations - # pre-compute the rotational embedding matrix and send to device - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( mesh_device, - model_args.num_devices, - start_pos=0, + model_args.max_batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, ) + transformation_mats = rope_setup.get_trans_mats() + transformation_mats = {"decode": transformation_mats} + + page_table_tt = None + paged_attention_config = None + + # Prepare page table for paged attention + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=paged_attention_params["page_block_size"], + max_num_blocks=paged_attention_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) # Load TTNN model tt_model = TtTransformer( @@ -151,6 +224,8 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) logger.info("Model and caches loaded.") @@ -163,7 +238,6 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, # Select the first token from the prompts for initial decoding encoded_prompts_tensor = torch.tensor(encoded_prompts) # [:,0] pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1) - tt_decode_input = pt_decode_input # Keep track of generated outputs to print out later @@ -171,42 +245,59 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, if run_ref_pt: all_outputs_ref = [] + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + for i in range(generation_length): - current_pos = generation_start_pos + i + logger.info(f"[Llama3 Model] Generating token {i}") decode_input = model_args.prepare_inputs_ttnn_decode( tt_decode_input, model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) - current_pos_tensor = ttnn.from_torch( - torch.tensor([current_pos] * batch), - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) + + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) # Run TT model - tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) + # Convert ttnn tensor to torch tensor tt_output_torch = ( ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)) .permute(2, 1, 0, 3) .squeeze(1)[: model_args.max_batch_size, :, :] ) # [seq, batch, hidden_dim] - ttnn.deallocate(tt_out) - # Update rotation matrix for next iteration - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - if run_ref_pt: # Run reference model - # freqs_cis_i = freqs_cis[current_pos, :].unsqueeze(0) - # positions = torch.tensor([current_pos]) - # mask = ttnn.to_torch(attn_mask[0]) - ref_output = reference_model(pt_decode_input, current_pos) + # In this test all users have the same position + ref_output = reference_model(pt_decode_input, current_pos[0]) - # While in "prefill" mode, use the prompt tokens as the output + # Increment position + current_pos = torch.tensor([generation_start_pos + i for _ in range(batch)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + # Append the generated token to the list of outputs if i in range(len(encoded_prompts[0])): + # While in "prefill" mode, use the prompt tokens as the output all_outputs.append(encoded_prompts[0][i]) # Update list of TT outputs if run_ref_pt: all_outputs_ref.append(encoded_prompts[0][i]) # Update list of ref outputs @@ -225,7 +316,6 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, all_outputs_ref.append( pt_out_tok.squeeze(1).tolist()[0] ) # Update generated token to list of ref outputs - # Measure PCC if also running reference model if run_ref_pt: if layers == 1 and i == iterations - 1: # On last iteration in the quick test, set a tighter PCC @@ -256,14 +346,52 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, ] tt_layer_present = [] - for layer_past in tt_model.layers[l].attention.layer_past: - tt_layer_present.append( - ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) - ) + if paged_attention: + for layer_past in tt_model.layers[l].attention.layer_past: + tt_layer_present.append( + ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + reverse_permutation + ] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch, ... + ] + ) + tt_layer_present = [ + ( + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + reverse_permutation + ] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch, ... + ] + ) + for cache in tt_model.layers[l].attention.layer_past + ] + else: + for layer_past in tt_model.layers[l].attention.layer_past: + tt_layer_present.append( + ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) + ) for kv_cache, (cache_pt, cache_tt) in enumerate(zip(pytorch_layer_present, tt_layer_present)): cache_length_to_check = min( - model_args.sliding_window, generation_start_pos + generation_length + 1 + model_args.max_seq_len, generation_start_pos + generation_length + 1 ) cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :] diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py index ca48efd8b11..71ad3505d76 100644 --- a/models/demos/llama3/tests/test_llama_model_prefill.py +++ b/models/demos/llama3/tests/test_llama_model_prefill.py @@ -13,6 +13,7 @@ sample, HostEmbedding, encode_prompt_llama_instruct, + PagedAttentionConfig, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.model_config import TtModelArgs @@ -29,14 +30,6 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.timeout(900) @pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "seq_len", - ( - # 128, - # 1024, - 4096, - ), -) @pytest.mark.parametrize( "mesh_device", [ @@ -46,19 +39,38 @@ ], indirect=True, ) -def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "paged_attention", + (True, False), + ids=("paged_attention", "default_attention"), +) +@pytest.mark.parametrize( + "paged_attention_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "seq_len", + (2048,), +) +def test_llama_model_inference( + seq_len, batch_size, paged_attention, paged_attention_params, mesh_device, use_program_cache, reset_seeds, ensure_gc +): run_ref_pt = True # Flag to run reference PyTorch model and compare PCC cache_pcc = False # Flag to measure KV cache PCC for all layers dtype = ttnn.bfloat8_b - pcc = 0.91 # TODO Look on improving PCC - + pcc = 0.90 mesh_device.enable_async(True) # Use instruct weights instead of general weights instruct = True - model_args = TtModelArgs(mesh_device, instruct=instruct, max_batch_size=1) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=seq_len) + tokenizer = Tokenizer(model_args.tokenizer_path) logger.info("Loading weights...") @@ -101,14 +113,39 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se # pre-compute the rotational embedding matrix and send to device rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len) transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.as_tensor( + transformation_mats_prefill = ttnn.as_tensor( transformation_mat_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) + transformation_mats = {"prefill": transformation_mats_prefill} + + # Setup page table + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=paged_attention_params["page_block_size"], + max_num_blocks=paged_attention_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) # Load TTNN model tt_model = TtTransformer( @@ -117,6 +154,8 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) logger.info("Model and caches loaded.") @@ -124,21 +163,28 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se if run_ref_pt: all_tests_pass = True - batch = 1 + batch = model_args.max_batch_size # 1 # Select the first token from the prompt for initial decoding encoded_prompt_tensor = torch.tensor(encoded_prompt) # [:,0] - pt_decode_input = embd(encoded_prompt_tensor).view(batch, seq_len, -1) + pt_prefill_input = embd(encoded_prompt_tensor).view(batch, seq_len, -1) - tt_decode_input = pt_decode_input + tt_prefill_input = pt_prefill_input - decode_input = model_args.prepare_inputs_ttnn_prefill( - tt_decode_input, + tt_prefill_input = model_args.prepare_inputs_ttnn_prefill( + pt_prefill_input, ) for i in range(1): start_pos = 0 # Run TT model - tt_out = tt_model(decode_input, None, rot_mats, transformation_mats, user_id=i, mode="prefill") + tt_out = tt_model( + tt_prefill_input, + current_pos=None, + rot_mats=rot_mats, + user_id=i, + mode="prefill", + page_table=page_table_tt, + ) # Convert ttnn tensor to torch tensor tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ :, 0, :, : @@ -147,7 +193,7 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se ) # [ batch, seq, hidden_dim] if run_ref_pt: # Run reference model - ref_output = reference_model(pt_decode_input, start_pos, mode="prefill") + ref_output = reference_model(pt_prefill_input, start_pos, mode="prefill") # Measure PCC if also running reference model if run_ref_pt: @@ -176,13 +222,51 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se ] tt_layer_present = [] - for layer_past in tt_model.layers[i].attention.layer_past_list[0]: - tt_layer_present.append( - ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) - ) + if paged_attention: + for layer_past in tt_model.layers[l].attention.layer_past: + tt_layer_present.append( + ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + reverse_permutation + ] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch, ... + ] + ) + tt_layer_present = [ + ( + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + reverse_permutation + ] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch, ... + ] + ) + for cache in tt_model.layers[l].attention.layer_past + ] + else: + for layer_past in tt_model.layers[i].attention.layer_past_list[0]: + tt_layer_present.append( + ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) + ) for i, (cache_pt, cache_tt) in enumerate(zip(pytorch_layer_present, tt_layer_present)): - cache_length_to_check = model_args.sliding_window + cache_length_to_check = model_args.max_seq_len cache_pt = cache_pt[:, :, 0:cache_length_to_check, :] cache_tt = cache_tt[:, :, 0:cache_length_to_check, :] does_pass, output_pcc = comp_pcc(cache_pt, cache_tt) @@ -200,7 +284,7 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se if run_ref_pt: if all_tests_pass: - logger.info(f"All Llama decode iterations Passed!") + logger.info(f"All Llama prefill iterations Passed!") else: - logger.warning("One or more iterations of Llama decode had bad PCC") + logger.warning("One or more iterations of Llama prefill had bad PCC") assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/test_llama_perf.py b/models/demos/llama3/tests/test_llama_perf.py index c2cda7b346c..c2ea707515d 100644 --- a/models/demos/llama3/tests/test_llama_perf.py +++ b/models/demos/llama3/tests/test_llama_perf.py @@ -11,13 +11,13 @@ from models.demos.llama3.tt.llama_common import ( sample, HostEmbedding, - get_single_rot_mat, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding from models.demos.llama3.tt.model_config import TtModelArgs +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer - +from models.demos.llama3.tt.llama_common import PagedAttentionConfig from models.perf.perf_utils import prep_perf_report from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report from models.utility_functions import profiler, skip_for_grayskull @@ -29,13 +29,32 @@ @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.models_performance_bare_metal @pytest.mark.parametrize( - "kv_cache_len, expected_compile_time", + "seq_len, expected_compile_time", ( (32, 30), (128, 30), (1024, 30), ), ) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False + ), + ids=( + "paged_attention", + # "default_attention" + ), +) +@pytest.mark.parametrize( + "paged_attention_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) @pytest.mark.parametrize( "mesh_device", [ @@ -45,12 +64,22 @@ ], indirect=True, ) -def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_program_cache, reset_seeds, ensure_gc): +def test_llama_model_perf( + batch_size, + seq_len, + expected_compile_time, + paged_attention, + paged_attention_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=seq_len) tokenizer = Tokenizer(model_args.tokenizer_path) if "3.2-1B" in model_args.DEFAULT_CACHE_PATH: @@ -83,9 +112,44 @@ def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_ state_dict_prefix = model_args.get_state_dict_prefix("", None) embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]}) - generation_start_pos = kv_cache_len + generation_start_pos = seq_len generation_length = 1 + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( + mesh_device, + model_args.max_batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, + ) + transformation_mats_decode = rope_setup.get_trans_mats() + transformation_mats = {"decode": transformation_mats_decode} + + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=paged_attention_params["page_block_size"], + max_num_blocks=paged_attention_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + profiler.start("TtLlama_model_setup") # Load TTNN model @@ -95,6 +159,8 @@ def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_ dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) # Load TTNN embedding module tt_embd = TtLlamaEmbedding( @@ -108,7 +174,9 @@ def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_ # Call the function profiler.start(f"end_to_end_inference_with_compile") - run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos, generation_length) + run_inference( + tt_model, tt_embd, embd, encoded_prompts, generation_start_pos, generation_length, rope_setup, page_table_tt + ) profiler.end(f"end_to_end_inference_with_compile") profiler.print() compile_and_iter_time = profiler.get("model_run_for_inference_0") @@ -119,12 +187,14 @@ def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_ signpost("Model perf run") profiler.start(f"end_to_end_inference") - run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos, generation_length) + run_inference( + tt_model, tt_embd, embd, encoded_prompts, generation_start_pos, generation_length, rope_setup, page_table_tt + ) profiler.end(f"end_to_end_inference") profiler.print() iter_time = profiler.get("end_to_end_inference") - comment = f"kv_cache_len={kv_cache_len}_num_layers={model_args.n_layers}" + comment = f"kv_cache_len={seq_len}_num_layers={model_args.n_layers}" # Extract the version, number of weights and device name from the cache folder if "3.1" in model_args.DEFAULT_CACHE_PATH: @@ -145,19 +215,13 @@ def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_ ) -def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos, generation_length): +def run_inference( + tt_model, tt_embd, embd, encoded_prompts, generation_start_pos, generation_length, rope_setup, page_table +): seqlen = 1 # Generating one token per user at a time batch = tt_model.args.max_batch_size mesh_device = tt_model.mesh_device - # pre-compute the rotational embedding matrix and send to device - current_rot_mat, rot_matrix = get_single_rot_mat( - tt_model.args.head_dim, - tt_model.mesh_device, - tt_model.args.num_devices, - start_pos=0, - ) - # Select the first token from the prompts for initial decoding encoded_prompts_tensor = torch.tensor(encoded_prompts) # [:,0] @@ -172,29 +236,46 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos ) # Send first input to device - current_pos = ttnn.from_torch( - torch.tensor([generation_start_pos] * batch), + current_pos = torch.tensor([generation_start_pos] * batch) + current_pos_tensor = ttnn.from_torch( + current_pos, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) + for i in range(generation_length): # Run TT model profiler.start(f"model_run_for_inference_{i}") decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) - tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table, + ) tt_out_rm = ttnn.untilize(tt_out, use_multicore=True) ttnn.deallocate(tt_out) - tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok) + tt_out_tok = ttnn.argmax( + tt_out_rm, + dim=3, + use_multicore=True if tt_model.args.max_batch_size == 1 else False, + output_tensor=tt_out_tok, + ) ttnn.deallocate(tt_out_rm) # Update the rotation matrix for the next iteration - new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat) - ttnn.plus_one(current_pos) + ttnn.plus_one(current_pos_tensor) + + # Update rot_mats for next iteration + current_pos += 1 + rot_mats = rope_setup.get_rot_mats(current_pos) profiler.end(f"model_run_for_inference_{i}") diff --git a/models/demos/llama3/tests/test_llama_rms_norm.py b/models/demos/llama3/tests/test_llama_rms_norm.py index bf0ce828900..cca0f113b55 100644 --- a/models/demos/llama3/tests/test_llama_rms_norm.py +++ b/models/demos/llama3/tests/test_llama_rms_norm.py @@ -28,13 +28,30 @@ ], indirect=True, ) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) @pytest.mark.parametrize("mode", ["prefill", "decode"]) -def test_llama_rms_norm_inference(mesh_device, use_program_cache, reset_seeds, ensure_gc, mode): +def test_llama_rms_norm_inference( + max_seq_len, + batch_size, + mode, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat16 mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) + model_args.n_layers = 1 state_dict = model_args.load_state_dict() state_dict_prefix = model_args.get_state_dict_prefix("", 0) diff --git a/models/demos/llama3/tests/test_lm_head.py b/models/demos/llama3/tests/test_lm_head.py index a626910c729..4a5570f5cc0 100644 --- a/models/demos/llama3/tests/test_lm_head.py +++ b/models/demos/llama3/tests/test_lm_head.py @@ -23,6 +23,10 @@ "seq_len", (32,), ) +@pytest.mark.parametrize( + "batch_size", + (1,), +) @pytest.mark.parametrize( "mesh_device", [ @@ -32,12 +36,12 @@ ], indirect=True, ) -def test_llama_lm_head_inference(mesh_device, seq_len, use_program_cache, reset_seeds): +def test_llama_lm_head_inference(seq_len, batch_size, mesh_device, use_program_cache, reset_seeds): dtype = ttnn.bfloat8_b mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=seq_len) model_args.n_layers = 1 state_dict = model_args.load_state_dict() diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index d630e91a3bd..24e7eb572f7 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -20,7 +20,9 @@ def __init__( weight_cache_path, layer_num, dtype, + transformation_mats, configuration, + paged_attention_config=None, ): super().__init__() @@ -34,7 +36,7 @@ def __init__( self.max_seq_len = configuration.max_seq_len self.max_batch_size = configuration.max_batch_size self.n_kv_heads = configuration.n_kv_heads - self.paged_attention_config = configuration.paged_attention_config + self.paged_attention_config = paged_attention_config self.min_kv_prefill_shard_seqlen = configuration.min_kv_prefill_shard_seqlen self.n_local_heads = self.n_heads // configuration.num_devices @@ -42,13 +44,14 @@ def __init__( self.dtype = dtype - self.kv_seq_len = configuration.kv_seq_len - self.sliding_window = configuration.sliding_window + self.max_seq_len = configuration.max_seq_len self.grid_size = configuration.max_grid_size self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 + self.transformation_mats = transformation_mats + self.model_config = configuration.get_model_config() self.ccl_topology = configuration.ccl_topology() self.is_multichip = configuration.is_multichip @@ -113,7 +116,7 @@ def __init__( self.use_fused_all_gather_matmul = self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] if self.is_multichip and self.use_fused_all_gather_matmul: pt_wo = self.state_dict[wo_str].transpose(-1, -2).unsqueeze(0).unsqueeze(0) - wo_ttnn = ttnn.as_tensor( + self.wo = ttnn.as_tensor( pt_wo, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, @@ -122,7 +125,6 @@ def __init__( mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), cache_file_name=cache_name("wo_width_sharded"), ) - self.wo = ttnn.to_device(wo_ttnn, self.mesh_device) else: # For line topology we can't do all gather matmul for now, but we can height shard and reduce scatter # wo: 2048 (2devices) x 4096: width-sharded on 12 banks, 4224 over 12 banks. wo_mem_config = configuration.create_dram_sharded_mem_config( @@ -163,16 +165,16 @@ def __init__( cache_k = torch.zeros( ( self.max_batch_size, - self.n_kv_heads, - self.sliding_window, + self.n_kv_heads // configuration.num_devices, + self.max_seq_len, self.head_dim, ) ) cache_v = torch.zeros( ( self.max_batch_size, - self.n_kv_heads, - self.sliding_window, + self.n_kv_heads // configuration.num_devices, + self.max_seq_len, self.head_dim, ) ) @@ -180,14 +182,14 @@ def __init__( self.layer_past = [ ttnn.as_tensor( k_or_v, - device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), - layout=self.model_config["ATTN_W_LAYOUT_TILE"], dtype=self.dtype, + layout=self.model_config["ATTN_W_LAYOUT_TILE"], + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), cache_file_name=f"{weight_cache_path}/kvcache_{k_or_v.shape}" if weight_cache_path and not configuration.dummy_weights else None, - memory_config=ttnn.DRAM_MEMORY_CONFIG, ) for k_or_v in [cache_k, cache_v] ] @@ -198,14 +200,14 @@ def forward_decode( self, x: ttnn.Tensor, current_pos, - rot_mat=None, + rot_mats=None, page_table=None, ) -> ttnn.Tensor: """ x: (seq_len, 1, batch, dim) current_pos: (batch_size), current token position in the sequence for each user """ - assert self.max_batch_size * self.n_kv_heads < 64 + # assert self.max_batch_size * self.n_kv_heads < 64 # TODO Miguel Are these needed? - check these params ### # QKV matmuls # Use HiFi2 for DRAM-sharded matmuls as they are otherwise flop-bound. Loses 1 bit of activation precision. @@ -218,10 +220,10 @@ def forward_decode( compute_kernel_config=self.compute_kernel_config_hifi2, dtype=ttnn.bfloat16, ) - ttnn.deallocate(x) + # ttnn.deallocate(x) xqkv_fused = ttnn.sharded_to_interleaved(xqkv_fused_sharded, ttnn.L1_MEMORY_CONFIG) - ttnn.deallocate(xqkv_fused_sharded) + # ttnn.deallocate(xqkv_fused_sharded) # Reshape such that true unpadded batch is tracked in shape fqkv_shape = xqkv_fused.shape @@ -243,52 +245,44 @@ def forward_decode( memory_config=ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG, ) - ttnn.deallocate(xqkv_fused) + # ttnn.deallocate(xqkv_fused) - q_heads_1BQD = ttnn.linear( - q_heads_pre_rot_1BQD, - rot_mat, - program_config=self.model_config["ROT_MAT_BMM_PROGCFG"]( - q_heads_pre_rot_1BQD.shape[-2], q_heads_pre_rot_1BQD.shape[-1], rot_mat.shape[-1] - ), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config_hifi2, - dtype=ttnn.bfloat16, + # Q Rotary Embeddings + q_heads_1BQD = ttnn.experimental.rotary_embedding_llama( + q_heads_pre_rot_1BQD, rot_mats[0], rot_mats[1], self.transformation_mats["decode"], is_decode_mode=True ) - k_heads_1BKD = ttnn.linear( - k_heads_pre_rot_1BKD, - rot_mat, - program_config=self.model_config["ROT_MAT_BMM_PROGCFG"]( - k_heads_pre_rot_1BKD.shape[-2], k_heads_pre_rot_1BKD.shape[-1], rot_mat.shape[-1] - ), - memory_config=k_heads_pre_rot_1BKD.memory_config(), - compute_kernel_config=self.compute_kernel_config_hifi2, - dtype=ttnn.bfloat16, + # K Rotary Embeddings + k_heads_1BKD = ttnn.experimental.rotary_embedding_llama( + k_heads_pre_rot_1BKD, rot_mats[0], rot_mats[1], self.transformation_mats["decode"], is_decode_mode=True ) - ttnn.deallocate(q_heads_pre_rot_1BQD) - ttnn.deallocate(k_heads_pre_rot_1BKD) + # ttnn.deallocate(q_heads_pre_rot_1BQD) + # ttnn.deallocate(k_heads_pre_rot_1BKD) ### # KV update ### keys = self.layer_past[0] values = self.layer_past[1] - # k_heads, [seqlen, n_kv_heads, bsz, head_dim] # v_heads [seqlen, n_kv_heads, bsz, head_dim] - # keys, [max_batch_size, n_kv_heads // configuration.num_devices, sliding_window, head_dim] + # keys, [max_batch_size, n_kv_heads // configuration.num_devices, max_seq_len, head_dim] ttnn.experimental.paged_update_cache(keys, k_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table) ttnn.experimental.paged_update_cache( values, v_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table ) + self.layer_past[0] = keys self.layer_past[1] = values - ttnn.deallocate(k_heads_1BKD) - ttnn.deallocate(v_heads_1BKD) + # ttnn.deallocate(k_heads_1BKD) + # ttnn.deallocate(v_heads_1BKD) + # NOTE: Varying the batch size will result in slightly different outputs. + # For example, a prompt w/ 1 user vs, the same prompt repeated N times for N users, will produce different outputs + # This is because the SDPA op in decode mode has different number of reductions depending on batch size + # Which leads to slightly different outputs from attention (due to accumulated errors) if page_table: attn_output_1G4D = ttnn.transformer.paged_scaled_dot_product_attention_decode( q_heads_1BQD, @@ -313,7 +307,7 @@ def forward_decode( memory_config=ttnn.DRAM_MEMORY_CONFIG, # FIXME: why not L1 height sharded e.g. SCORES_BATCHED_MM_OUTPUT_MEMCFG? ) - ttnn.deallocate(q_heads_1BQD) + # ttnn.deallocate(q_heads_1BQD) attn_output_11BH = ttnn.to_memory_config( attn_output_1G4D, memory_config=self.model_config["SCORES_BATCHED_MM_OUTPUT_MEMCFG"] @@ -322,8 +316,8 @@ def forward_decode( attn_output_11BH, num_heads=self.n_local_heads, ) - ttnn.deallocate(attn_output_11BH) - ttnn.deallocate(attn_output_1G4D) + # ttnn.deallocate(attn_output_11BH) + # ttnn.deallocate(attn_output_1G4D) if self.is_multichip and self.use_fused_all_gather_matmul: attn_output_cat = ttnn.to_memory_config( @@ -369,7 +363,7 @@ def forward_decode( dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"]) return dense_out_sharded - def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = 0, page_table=None): + def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None): seq_len = x_11SH.shape[-2] assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" ### @@ -392,7 +386,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = if seq_len > 2048: xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) - ttnn.deallocate(x_11SH) + # ttnn.deallocate(x_11SH) # split qkv into heads ( @@ -407,61 +401,51 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - ttnn.deallocate(xqkv_fused) + # ttnn.deallocate(xqkv_fused) ### # Rotary embeddings ### q_heads_1QSD = ttnn.experimental.rotary_embedding_llama( - q_heads_1QSD_pre_rot, rot_mats[0], rot_mats[1], transformation_mats + q_heads_1QSD_pre_rot, + rot_mats[0], + rot_mats[1], + self.transformation_mats["prefill"], + is_decode_mode=False, ) - ttnn.deallocate(q_heads_1QSD_pre_rot) + # ttnn.deallocate(q_heads_1QSD_pre_rot) k_heads_1KSD = ttnn.experimental.rotary_embedding_llama( - k_heads_1KSD_pre_rot, rot_mats[0], rot_mats[1], transformation_mats + k_heads_1KSD_pre_rot, + rot_mats[0], + rot_mats[1], + self.transformation_mats["prefill"], + is_decode_mode=False, ) - ttnn.deallocate(k_heads_1KSD_pre_rot) + # ttnn.deallocate(k_heads_1KSD_pre_rot) # Fill KV-Cache keys_BKSD, values_BKSD = self.layer_past[0], self.layer_past[1] k_heads_1KSD_8b = ttnn.typecast(k_heads_1KSD, dtype=ttnn.bfloat8_b) - ttnn.deallocate(k_heads_1KSD) - # sharding k_fill to deal with update_cache memory limitation - if seq_len >= self.min_kv_prefill_shard_seqlen: - k_fill = ttnn.interleaved_to_sharded(k_heads_1KSD_8b, self.model_config["KV_PREFILL_MEM_CFG"](seq_len)) - else: - k_fill = k_heads_1KSD_8b - v_heads_1VSD_8b = ttnn.typecast(v_heads_1VSD, dtype=ttnn.bfloat8_b) - ttnn.deallocate(v_heads_1VSD) - # sharding v_fill to deal with update_cache memory limitation - if seq_len >= self.min_kv_prefill_shard_seqlen: - v_fill = ttnn.interleaved_to_sharded(v_heads_1VSD_8b, self.model_config["KV_PREFILL_MEM_CFG"](seq_len)) - else: - v_fill = v_heads_1VSD_8b - if page_table: - ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill, page_table, batch_idx=user_id) - ttnn.experimental.paged_fill_cache(values_BKSD, v_fill, page_table, batch_idx=user_id) + ttnn.experimental.paged_fill_cache(keys_BKSD, k_heads_1KSD_8b, page_table, batch_idx=user_id) + ttnn.experimental.paged_fill_cache(values_BKSD, v_heads_1VSD_8b, page_table, batch_idx=user_id) else: ttnn.fill_cache( keys_BKSD, - k_fill, + k_heads_1KSD_8b, user_id, ) ttnn.fill_cache( values_BKSD, - v_fill, + v_heads_1VSD_8b, user_id, ) - if seq_len >= self.min_kv_prefill_shard_seqlen: - ttnn.deallocate(k_fill) - ttnn.deallocate(v_fill) - self.layer_past = [keys_BKSD, values_BKSD] # SDPA @@ -471,7 +455,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = v_heads_V1SD_8b = ttnn.reshape(v_heads_1VSD_8b, [self.n_local_kv_heads, 1, -1, self.head_dim]) q_heads_1QSD_8b = ttnn.typecast(q_heads_1QSD, dtype=ttnn.bfloat8_b) - ttnn.deallocate(q_heads_1QSD) + # ttnn.deallocate(q_heads_1QSD) q_heads_84SD_8b = ttnn.reshape( q_heads_1QSD_8b, [self.n_local_kv_heads, self.n_local_heads // self.n_local_kv_heads, -1, self.head_dim] @@ -487,9 +471,9 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = ) # deallocate keys and values - ttnn.deallocate(q_heads_84SD_8b) - ttnn.deallocate(k_heads_K1SD_8b) - ttnn.deallocate(v_heads_V1SD_8b) + # ttnn.deallocate(q_heads_84SD_8b) + # ttnn.deallocate(k_heads_K1SD_8b) + # ttnn.deallocate(v_heads_V1SD_8b) attn_output_1QSD = ttnn.reshape(attn_output_84SD, [1, self.n_local_heads, -1, self.head_dim]) @@ -500,7 +484,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = attn_output_1QSD, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - ttnn.deallocate(attn_output_1QSD) + # ttnn.deallocate(attn_output_1QSD) # reshaping long sequence to matmul fit on device if seq_len > 2048: @@ -526,7 +510,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = ) if seq_len > 2048: output_11SH = ttnn.reshape(output_11SH, [1, 1, seq_len, -1]) - ttnn.deallocate(attn_output_11SH) + # ttnn.deallocate(attn_output_11SH) # Reduce-scatter if self.is_multichip and not self.use_fused_all_gather_matmul: @@ -537,15 +521,14 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - ttnn.deallocate(output_11SH) + # ttnn.deallocate(output_11SH) return dense_out_reduced else: return output_11SH - def forward( - self, x, current_pos, rot_mats=None, transformation_mats=None, user_id=0, mode="decode", page_table=None - ): + # TODO Miguel: Remove transformation_mats input (send at initialization instead) + def forward(self, x, current_pos, rot_mats=None, user_id=0, mode="decode", page_table=None): if mode == "prefill": - return self.forward_prefill(x, rot_mats, transformation_mats, user_id, page_table) + return self.forward_prefill(x, rot_mats, user_id, page_table) else: return self.forward_decode(x, current_pos, rot_mats, page_table) diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index 6368443df4f..43ca95bbe74 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -16,6 +16,13 @@ def forward(self, x): return self.emb(x) +# Default configuration for Paged Attention +class PagedAttentionConfig: + def __init__(self, block_size=32, max_num_blocks=1024): + self.block_size = block_size + self.max_num_blocks = max_num_blocks + + def encode_prompt_llama_instruct(tokenizer, prompt_text, system_prompt_text=None): """<|begin_of_text|><|start_header_id|>system<|end_header_id|> {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|> diff --git a/models/demos/llama3/tt/llama_decoder.py b/models/demos/llama3/tt/llama_decoder.py index 578e0bf81a6..e5edfce889a 100644 --- a/models/demos/llama3/tt/llama_decoder.py +++ b/models/demos/llama3/tt/llama_decoder.py @@ -10,7 +10,17 @@ class TtTransformerBlock(LightweightModule): - def __init__(self, args, mesh_device, dtype, state_dict, layer_num, weight_cache_path): + def __init__( + self, + args, + mesh_device, + dtype, + state_dict, + layer_num, + weight_cache_path, + transformation_mats, + paged_attention_config=None, + ): super().__init__() self.state_dict = state_dict @@ -25,7 +35,6 @@ def __init__(self, args, mesh_device, dtype, state_dict, layer_num, weight_cache self.max_batch_size = args.max_batch_size self.n_kv_heads = args.n_kv_heads self.current = 0 - self.sliding_window = args.sliding_window self.model_config = args.get_model_config() self.layer_num = layer_num @@ -36,7 +45,9 @@ def __init__(self, args, mesh_device, dtype, state_dict, layer_num, weight_cache weight_cache_path=weight_cache_path, layer_num=layer_num, dtype=dtype, + transformation_mats=transformation_mats, configuration=args, + paged_attention_config=paged_attention_config, ) self.feed_forward = TtLlamaMLP( mesh_device=mesh_device, @@ -82,8 +93,7 @@ def forward( self, x: ttnn.Tensor, current_pos, - rot_mat=None, - transformation_mats=None, + rot_mats=None, user_id=0, mode="decode", page_table=None, @@ -99,8 +109,7 @@ def forward( attn_out = self.attention.forward( attn_in, current_pos, - rot_mat, - transformation_mats, + rot_mats, user_id, mode, page_table, diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index 04cf2c8d77b..e04ed2c4cf8 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -24,6 +24,8 @@ def __init__( mesh_device, state_dict, weight_cache_path, + transformation_mats, + paged_attention_config=None, ): super().__init__() self.args = args @@ -44,6 +46,8 @@ def __init__( state_dict=state_dict, weight_cache_path=weight_cache_path, layer_num=i, + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) for i in range(self.n_layers) ] @@ -76,8 +80,7 @@ def forward( self, x: ttnn.Tensor, current_pos, - rot_mat=None, - transformation_mats=None, + rot_mats=None, user_id=0, mode="decode", page_table=None, @@ -88,7 +91,7 @@ def forward( x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"]) for layer in self.layers: - x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode, page_table) + x = layer(x, current_pos, rot_mats, user_id, mode, page_table) if mode == "prefill" and get_last_token == -1: return x diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py new file mode 100644 index 00000000000..576ce982e8c --- /dev/null +++ b/models/demos/llama3/tt/llama_rope.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from models.common.lightweightmodule import LightweightModule +from models.demos.llama3.tt.llama_common import precompute_freqs, get_rot_transformation_mat, gather_cos_sin +from models.utility_functions import nearest_32 +from loguru import logger + + +def compute_gather_cos_sin(dhead, end, theta, position_ids, use_scaled_rope): + cos, sin = precompute_freqs(dhead, end, theta, use_scaled_rope) + return gather_cos_sin(position_ids, cos, sin) + + +class TtLlamaRotarySetup(LightweightModule): + def __init__( + self, + device, + batch_size: int, + head_dim: int, + max_seq_len: int, + rope_theta: float = 10000, + use_scaled_rope: bool = False, + datatype=ttnn.bfloat16, + ): + super().__init__() + + self.batch_size = batch_size + self.head_dim = head_dim + self.device = device + self.is_mesh_device = isinstance(device, ttnn._ttnn.multi_device.MeshDevice) + self.num_devices = device.get_num_devices() if self.is_mesh_device else 1 + + self.core_grid = device.compute_with_storage_grid_size() + num_cores = self.core_grid.x * self.core_grid.y + + # Generate the cos/sin matrices needed for ttnn.embedding op + cos_matrix, sin_matrix = compute_gather_cos_sin( + dhead=head_dim, + end=max_seq_len * 2, + theta=rope_theta, + position_ids=torch.arange(max_seq_len), + use_scaled_rope=use_scaled_rope, + ) + + self.cos_matrix = ttnn.from_torch( + cos_matrix, + device=device, + layout=ttnn.ROW_MAJOR_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + self.sin_matrix = ttnn.from_torch( + sin_matrix, + device=device, + layout=ttnn.ROW_MAJOR_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + + batch_grid = ttnn.num_cores_to_corerangeset(batch_size, self.core_grid, row_wise=True) + # Generate the transformation matrix + trans_mat = get_rot_transformation_mat(dhead=ttnn.TILE_SIZE).repeat( + 1, + 1, + batch_size, + 1 + # 1, 1, num_cores, 1 + ) # Repeat across all cores on device + trans_mat_mem_config = ttnn.create_sharded_memory_config( + shape=(ttnn.TILE_SIZE, ttnn.TILE_SIZE), + core_grid=batch_grid, + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + self.transformation_mat = ttnn.from_torch( + trans_mat, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + memory_config=trans_mat_mem_config, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + + def get_trans_mats(self): + assert self.transformation_mat is not None, "Transformation matrix not initialized" + return self.transformation_mat + + def get_rot_idxs(self, position_idxs, on_host=False): + assert isinstance(position_idxs, torch.Tensor), "Position ids must be a torch tensor" + assert len(position_idxs.shape) == 1, "position idxs must be a [batch] tensor" + + batch = position_idxs.shape[0] + position_idxs = position_idxs.reshape(1, batch) # [1, 1, 1, batch] + assert position_idxs.shape == (1, batch), "position idxs must be a [1, batch] tensor" + assert torch.min(position_idxs) >= 0, "position idxs must be non-negative" + + # Add padding if needed + pad_size = nearest_32(batch) - batch + position_idxs = torch.nn.functional.pad(position_idxs, (0, pad_size), "constant", 0) + + if on_host: # If tensor is on host, don't pass a mesh mapper if single-device + rot_idxs = ttnn.as_tensor( + position_idxs, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ReplicateTensorToMesh(self.device) if self.num_devices > 1 else None, + ) + else: # On device + rot_idxs = ttnn.as_tensor( + position_idxs, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + device=self.device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None, + ) + + return rot_idxs + + def get_rot_mats(self, position_idxs, return_rot_idxs=False): + device = self.device + + # If position_idxs is a torch tensor, get the TTNN version of it + if isinstance(position_idxs, torch.Tensor): + rot_idxs = self.get_rot_idxs(position_idxs) + else: + rot_idxs = position_idxs + assert len(rot_idxs.shape) == 2 and rot_idxs.shape[0] == 1, "rot_idxs must be a [1, batch] tensor" + + # Send the idxs to device + if rot_idxs.device != device: + rot_idxs = ttnn.to_device(rot_idxs, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + embedding_layout = ttnn.TILE_LAYOUT + cos = ttnn.embedding(rot_idxs, self.cos_matrix, layout=embedding_layout) # [1, batch, head_dim] + sin = ttnn.embedding(rot_idxs, self.sin_matrix, layout=embedding_layout) # [1, batch, head_dim] + + cos = ttnn.unsqueeze_to_4D(cos) # [1, 1, batch, head_dim] + sin = ttnn.unsqueeze_to_4D(sin) # [1, 1, batch, head_dim] + + cos = ttnn.transpose(cos, 1, 2) # [1, batch, 1[32], head_dim] + sin = ttnn.transpose(sin, 1, 2) # [1, batch, 1[32], head_dim] + + if self.batch_size % ttnn.TILE_SIZE != 0: + cos = cos[:, : self.batch_size, :, :] + sin = sin[:, : self.batch_size, :, :] + + grid = ttnn.num_cores_to_corerangeset(self.batch_size, self.core_grid, row_wise=True) + mem_config = ttnn.create_sharded_memory_config( + shape=(ttnn.TILE_SIZE, self.head_dim), + core_grid=grid, + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + cos = ttnn.interleaved_to_sharded(cos, mem_config) # [1, 1 (= batch / shard_num_cores), 1[32], self.head_dim] + sin = ttnn.interleaved_to_sharded(sin, mem_config) # [1, 1 (= batch / shard_num_cores), 1[32], self.head_dim] + + if return_rot_idxs: + return [cos, sin], rot_idxs + return [cos, sin] diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index aaf0352c809..7ba2d66f1b0 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -24,17 +24,6 @@ class TtModelArgs: - paged_attention_config = None - - # TODO Update these params. In init we update the max_seq_len to 32k if it's a single device - max_batch_size = 1 - # Context length for Llama models (if single device, reduce to 32k in init) - max_seq_len = 8192 * 16 # 128k - kv_seq_len = 8192 * 16 # 128k - sliding_window = 8192 * 16 # 128k - - tile_size = 32 - OP_KEYS = ( # Embedding "EMB_WEIGHTS", @@ -68,13 +57,16 @@ class TtModelArgs: "LLAMA3_1_70B_PARAMS": "models/demos/llama3/model_params/Llama3.1-70B-Instruct", } - def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_size=1): + def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_size=1, max_seq_len=1024 * 128): # Add this near the top of the class, with other class attributes self.num_devices = mesh_device.get_num_devices() if mesh_device else 0 self.mesh_device = mesh_device self.device_name = {0: "CPU", 1: "N150", 2: "N300", 8: "T3K", 32: "TG"}[self.num_devices] self.is_large_model = False self.model_name = "Unknown" # Llama model name will be dependent on the checkpoint directory + self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + self.tile_size = 32 LLAMA_DIR = os.getenv("LLAMA_DIR") if LLAMA_DIR: @@ -144,15 +136,11 @@ def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_s self.num_devices <= 2 ): # for 1-chip or 2-chip devices limit the seqlen to 4K (to avoid OoO on N150/N300 CI tests) self.max_seq_len = 1024 * 4 - self.kv_seq_len = 1024 * 4 - self.sliding_window = 1024 * 4 if ( self.n_layers == 1 ): # When running a single layer just reduce the seq len to 128, since we won't be decoding that many iterations self.max_seq_len = 128 - self.kv_seq_len = 128 - self.sliding_window = 128 # Some consumers like SentencePiece only accept str not Path for files self.model_base_path = Path(self.DEFAULT_CKPT_DIR) @@ -167,7 +155,6 @@ def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_s if "instruct" in self.DEFAULT_CACHE_PATH.lower(): self.instruct = True self.dummy_weights = dummy_weights - self.max_batch_size = max_batch_size self.tile_padded_batch_rows = self.tile_size * int(math.ceil(self.max_batch_size / self.tile_size)) # Enable workarounds by default until di/dt issues are fixed @@ -416,14 +403,6 @@ def find_largest_divisor(n, max_divisor=8): orientation=ttnn.ShardOrientation.ROW_MAJOR, use_height_and_width_as_shard_shape=True, ) - self.model_config["ROT_MAT_BMM_PROGCFG"] = lambda m, k, n: ttnn.MatmulMultiCoreReuseProgramConfig( - compute_with_storage_grid_size=grid_by_batch, - in0_block_w=math.ceil(k / 32), - out_subblock_h=1, - out_subblock_w=1, # TODO How to choose this subblock size? - per_core_M=math.ceil(m / 32), - per_core_N=math.ceil(n / 32), - ) self.model_config["ROT_MAT_MEMCONFIG"] = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py index c9958604dad..cd6efbe74c3 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py @@ -11,21 +11,16 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_pcc, ) -from models.utility_functions import skip_for_grayskull, skip_for_blackhole, nearest_32 -from models.demos.t3000.llama2_70b.tt.llama_common import precompute_freqs, freqs_to_rotation_matrix, gather_rotary_emb -from models.demos.t3000.llama2_70b.tt.llama_rope import TtLlamaRotarySetup +from models.utility_functions import skip_for_grayskull, skip_for_blackhole, nearest_32, skip_for_wormhole_b0 +from models.demos.llama3.tt.llama_common import ( + precompute_freqs, + get_rot_transformation_mat, +) +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup MAX_SEQ_LEN = 128 * 1024 -def get_rotation_mat(dhead, end, start_pos, seqlen, batch): - cos, sin = precompute_freqs(dhead, end) - rot_mat = freqs_to_rotation_matrix(cos, sin) - position_ids = torch.ones(seqlen, batch, dtype=torch.long) * start_pos - rot_emb = gather_rotary_emb(rot_mat, position_ids) - return rot_emb - - class TtLlamaRotary(torch.nn.Module): def __init__( self, @@ -110,15 +105,8 @@ def forward(self, xq, xk, freqs_cis): return xq, xk -def get_rot_transformation_mat(dhead): - rot_emb_matrix = torch.zeros(1, 1, dhead, dhead) - rot_emb_matrix[..., torch.arange(0, dhead, 2), torch.arange(1, dhead, 2)] = 1 - rot_emb_matrix[..., torch.arange(1, dhead, 2), torch.arange(0, dhead, 2)] = -1 - return rot_emb_matrix - - def compute_gather_cos_sin(dhead, end, position_ids): - cos, sin = precompute_freqs(dhead, end) + cos, sin = precompute_freqs(dhead, end, theta=10000.0, use_scaled=False) # Using reference defaults position_id_expanded = position_ids.unsqueeze(1).expand(-1, cos.shape[-1]) cos = cos.gather(0, position_id_expanded) sin = sin.gather(0, position_id_expanded) @@ -185,17 +173,16 @@ def run_test_rotary_embedding_llama( tt_model = TtLlamaRotary(device, head_dim, mode, datatype, fuse_qk) if mode == "decode": - rope_setup_decode = TtLlamaRotarySetup(device, head_dim, max_seq_len) - tt_model.transformation_mat = rope_setup_decode.transformation_mat - # For decode, TTNN expects inputs to be [1, batch, nh, dhead] inp = [x.transpose(1, 2) for x in inp] # inp: [seq_len, batch, n_heads, head_dim] if fuse_qk: - # For fused_qk, repeat the position_ids for q and k - position_ids = torch.concat([position_ids, position_ids]) + # Set up rope with 2 * batch size (for fused qk) + rope_setup_decode = TtLlamaRotarySetup(device, batch * 2, head_dim, max_seq_len) + tt_model.transformation_mat = rope_setup_decode.transformation_mat cos, sin = rope_setup_decode.get_rot_mats(position_ids) + assert ( batch % 8 == 0 or batch == 1 ), "Batch size must be a multiple of 8 or less than 8 for fused_qk rotary embedding" @@ -230,18 +217,19 @@ def run_test_rotary_embedding_llama( input_mem_configs = [q_input_mem_config, k_input_mem_config] else: + # Set up rope with batch size + rope_setup_decode = TtLlamaRotarySetup(device, batch, head_dim, max_seq_len) + tt_model.transformation_mat = rope_setup_decode.transformation_mat cos, sin = rope_setup_decode.get_rot_mats(position_ids) - grid = ( - ttnn.num_cores_to_corerangeset(batch, rope_setup_decode.core_grid, row_wise=True) - .bounding_box() - .grid_size() - ) + + grid = ttnn.num_cores_to_corerangeset(batch, rope_setup_decode.core_grid, row_wise=True) input_mem_configs = [ ttnn.create_sharded_memory_config( - shape=(1, batch, ttnn.TILE_SIZE, head_dim), - core_grid=ttnn.CoreGrid(y=grid.y, x=grid.x), + shape=(ttnn.TILE_SIZE, head_dim), + core_grid=grid, strategy=ttnn.ShardStrategy.HEIGHT, orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, ) for _ in range(len(inp)) ] @@ -313,7 +301,7 @@ def run_test_rotary_embedding_llama( (1, 128 * 1024), (64, 1), (32, 1), - (16, 1), + (15, 1), (8, 1), (1, 1), ), @@ -330,7 +318,7 @@ def run_test_rotary_embedding_llama( "prefill_128k", "decode_64", "decode_32", - "decode_16", + "decode_15", "decode_8", "decode_1", ), @@ -459,12 +447,9 @@ def test_rotary_embedding_llama_with_program_cache( num_ops = 2 # 2 * rope if mode == "decode": - num_ops += 4 # embedding + transpose + pad + interleaved_to_sharded + num_ops += 3 # embedding + transpose + interleaved_to_sharded - # When batch size is 1, transpose is a no-op - if batch == 1: - num_ops -= 1 - elif batch % 32 == 0: - num_ops -= 1 # When batch size is a multiple of 32, no padding + if batch % ttnn.TILE_SIZE != 0: + num_ops += 1 # slice assert device.num_program_cache_entries() == num_ops diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py index 579791f0eab..893fe74baa5 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py @@ -132,9 +132,9 @@ def test_rotary_embedding_llama_fused_qk_with_program_cache( cache_tensors.append(test_tensor) - if batch == 32 or batch == 16: - num_ops = 4 - else: - num_ops = 5 # embedding + fused_qk_rope + transpose + pad + interleaved_to_sharded + num_ops = 4 # embedding + fused_qk_rope + transpose + interleaved_to_sharded + + if (batch * 2) % ttnn.TILE_SIZE != 0: + num_ops += 1 # slice assert device.num_program_cache_entries() == num_ops diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp index 1fb032dcce6..c3c55559038 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp @@ -32,7 +32,7 @@ void Pad::validate_with_output_tensors( TT_FATAL(input_tensor.get_dtype() == DataType::FLOAT32 || input_tensor.get_dtype() == DataType::BFLOAT16, "Cannot pad tilized tensor with specified format"); } else if (input_tensor.get_layout() == Layout::ROW_MAJOR) { TT_FATAL(this->output_tensor_shape[3] % 2 == 0, "RM padding requires output X dim to be a multiple of 2"); - TT_FATAL(input_tensor.get_dtype() == DataType::FLOAT32 || input_tensor.get_dtype() == DataType::BFLOAT16, "Cannot pad RM tensor with specified format"); + // TT_FATAL(input_tensor.get_dtype() == DataType::FLOAT32 || input_tensor.get_dtype() == DataType::BFLOAT16, "Cannot pad RM tensor with specified format"); } if (input_tensor.is_sharded()) {