diff --git a/models/demos/llama3/demo/multimodal_demo_chat.py b/models/demos/llama3/demo/multimodal_demo_chat.py index 7b39fb3db61..ca3d5b498e3 100644 --- a/models/demos/llama3/demo/multimodal_demo_chat.py +++ b/models/demos/llama3/demo/multimodal_demo_chat.py @@ -8,19 +8,21 @@ from PIL import Image as PIL_Image from termcolor import cprint -from models.demos.llama3.demo.multimodal_demo_text import create_multimodal_model -import llama_models.llama3.reference_impl.generation as llama_reference_generation +import pytest +import os +import ttnn +import llama_models.llama3.reference_impl.generation as llama_reference_generation +from llama_models.llama3.api.tokenizer import Tokenizer +from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import ImageMedia, UserMessage from pkg_resources import resource_filename IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) -import torch -import pytest -import os -import ttnn +from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision +from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model @pytest.mark.parametrize( @@ -36,39 +38,36 @@ "target", ("tt", "cpu"), ) -@pytest.mark.parametrize( - "warmup_iters", - (0, 1), -) def test_llama_multimodal_demo_chat( mesh_device, target, - warmup_iters, temperature: float = 0.5, top_p: float = 0.9, max_seq_len: int = 512, - max_batch_size: int = 4, + max_batch_size: int = 1, max_gen_len: Optional[int] = 200, model_parallel_size: Optional[int] = None, ): - mesh_device.enable_program_cache() - mesh_device.enable_async(True) ckpt_dir = os.environ["LLAMA_DIR"] tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'") - generator = llama_reference_generation.Llama.build( - ckpt_dir, - tokenizer_path=tokenizer_path, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - model_parallel_size=model_parallel_size, - ) - - if target == "tt": + if target == "cpu": + generator = llama_reference_generation.Llama.build( + ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + model_parallel_size=model_parallel_size, + ) + else: logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices") - model = create_multimodal_model(generator.args, mesh_device) - generator.model = model + mesh_device.enable_program_cache() + mesh_device.enable_async(True) + model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter) # image understanding dialogs = [] @@ -85,26 +84,21 @@ def test_llama_multimodal_demo_chat( ) ], ] - # text only - dialogs += [ - [UserMessage(content="what is the recipe of mayonnaise in two sentences?")], - ] print(f"Running text completion on {target}") - for _ in range(warmup_iters + 1): - for dialog in dialogs: - result = generator.chat_completion( - dialog, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ) + for dialog in dialogs: + result = generator.chat_completion( + dialog, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) - for msg in dialog: - print(f"{msg.role.capitalize()}: {msg.content}\n") + for msg in dialog: + print(f"{msg.role.capitalize()}: {msg.content}\n") - out_message = 
result.generation - print(f"> {out_message.role.capitalize()}: {out_message.content}") - for t in out_message.tool_calls: - print(f" Tool call: {t.tool_name} ({t.arguments})") - print("\n==================================\n") + out_message = result.generation + print(f"> {out_message.role.capitalize()}: {out_message.content}") + for t in out_message.tool_calls: + print(f" Tool call: {t.tool_name} ({t.arguments})") + print("\n==================================\n") diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py index 102b03975e4..2029c43458b 100644 --- a/models/demos/llama3/demo/multimodal_demo_text.py +++ b/models/demos/llama3/demo/multimodal_demo_text.py @@ -8,36 +8,22 @@ from PIL import Image as PIL_Image from termcolor import cprint -import llama_models.llama3.reference_impl.generation as llama_reference_generation +import pytest +import os +import ttnn +import llama_models.llama3.reference_impl.generation as llama_reference_generation from llama_models.llama3.api.datatypes import ImageMedia +from llama_models.llama3.api.tokenizer import Tokenizer +from llama_models.llama3.api.chat_format import ChatFormat + from pkg_resources import resource_filename IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) -import torch -import pytest -import os -import ttnn - - -def create_multimodal_model(model_args, mesh_device, dtype=ttnn.bfloat16): - from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer - from models.demos.llama3.tt.model_config import TtModelArgs - - tt_model_args = TtModelArgs(mesh_device) - checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True) - model = CrossAttentionTransformer( - model_args, - mesh_device, - checkpoint, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=tt_model_args, - ) - model.setup_cache(model_args.max_batch_size, torch.float32) - return model +from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model +from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision @pytest.mark.parametrize( @@ -64,28 +50,30 @@ def test_llama_multimodal_demo_text( temperature: float = 0.5, top_p: float = 0.9, max_seq_len: int = 512, - max_batch_size: int = 4, + max_batch_size: int = 1, max_gen_len: Optional[int] = 200, model_parallel_size: Optional[int] = None, ): - mesh_device.enable_program_cache() - mesh_device.enable_async(True) ckpt_dir = os.environ["LLAMA_DIR"] tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'") - generator = llama_reference_generation.Llama.build( - ckpt_dir, - tokenizer_path=tokenizer_path, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - model_parallel_size=model_parallel_size, - ) - - if target == "tt": + if target == "cpu": + generator = llama_reference_generation.Llama.build( + ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + model_parallel_size=model_parallel_size, + ) + else: logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices") - model = create_multimodal_model(generator.args, mesh_device) - generator.model = model + mesh_device.enable_program_cache() + mesh_device.enable_async(True) + model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) + tokenizer = 
Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter) with open(IMG_PATH / "dog.jpg", "rb") as f: img = PIL_Image.open(f).convert("RGB") @@ -100,8 +88,6 @@ def test_llama_multimodal_demo_text( clutter = PIL_Image.open(f).convert("RGB") interleaved_contents = [ - # text only - "The color of the sky is blue but sometimes it can also be", # image understanding [ImageMedia(image=img), "If I had to write a haiku for this one"], [ImageMedia(image=img2), "Couting the number of individual spaghetti strands in this image"], diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index a5bf099d027..964554280ee 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -10,8 +10,7 @@ import llama_models.llama3.reference_impl.generation as llama_reference_generation from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.llama3.api.chat_format import ChatFormat, ModelInput - +from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import ImageMedia, UserMessage from pkg_resources import resource_filename @@ -24,294 +23,7 @@ import ttnn import time - -class LlamaVision: - def __init__(self, model, model_args, mesh_device, vllm=False): - """ - Creating a LlamaVision wrapper requires only a mesh_device and model_args. - With model_args you have the checkpoint location, can specify max batch size - and max seqlen, and other model specific parameters. - - LlamaVision is general to text and chat. - - For bringup, make this class general to any backend implementation, as long as it takes torch tensors and returns torch tensors. - - """ - self.model = model - self.model_args = model_args - self.mesh_device = mesh_device - self.vllm = vllm - - def prefill_forward_single_user( - self, - vision_images, - vision_mask, - tokens, - xattn_caches, - user_id, - total_len, - prefill_len, - ): - """ - Performs vision encode step then text prefill. - Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) - """ - B = tokens.shape[0] - xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( - batch_images=[vision_images], - batch_masks=[vision_mask], - total_len=total_len, - xattn_caches=xattn_caches, - user_id=user_id, - ) - - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, - tt_position_id, - rot_mats, - transformation_mats, - ) = self.model.prepare_inputs_prefill( - tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len=prefill_len - ) - - tt_logits = self.model.ttnn_prefill_forward( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, - xattn_caches, - tt_position_id, - rot_mats, - transformation_mats, - user_id, - ) - - logits = self.model.process_output_prefill(tt_logits, B, prefill_len) - - return xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits - - def decode_forward( - self, - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ): - """ - Performs text decode step. - Returns logits - """ - - # forward_decode should be traced callable - # decorator does compilation, capture, execute - # B = 1 # TODO: Only supports batch=1 right now! Might make tokens input a tensor. 
- # S = 1 - B, S = tokens.shape - - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - _, - tt_position_id, - rot_mats, - _, - ) = self.model.prepare_inputs_decode( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id - ) - - tt_logits = self.model.ttnn_decode_forward( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - xattn_caches, - tt_position_id, - rot_mats, - ) - - logits = self.model.process_output_decode(tt_logits, B, S) - return logits - - def capture_trace( - self, - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ): - """ - Captures a trace for the decode_forward method. - """ - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - _, - tt_position_id, - rot_mats, - _, - ) = self.model.prepare_inputs_decode( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id - ) - - # Compile run - tt_logits_rm = self.model.ttnn_decode_forward( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - xattn_caches, - tt_position_id, - rot_mats, - ) - - # Get inputs ready for trace run - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - _, - tt_position_id, - rot_mats, - _, - ) = self.model.prepare_decode_inputs_host( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id - ) - - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_position_id, - rot_mats, - ) = self.model.copy_host_to_device( - (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats) - ) - - trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0) - B = tokens.shape[0] - # Do on-device transformations of inputs before forward - tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform = self.model.transform_decode_inputs_device( - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - B=B, - ) - - tt_logits_rm = self.model.ttnn_decode_forward( - tt_h, - tt_xattn_mask_transform, - tt_full_text_mask_expand_1NSH_transform, - xattn_caches, - tt_position_id, - rot_mats, - ) - - ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) - - return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats - - def decode_forward_trace( - self, - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, # TODO: unused since captured in trace? - trace_id, - trace_logits_rm, - trace_h, - trace_xattn_mask, - trace_full_text_mask_expand_1NSH, - trace_position_id, - trace_rot_mats, - ): - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - _, - tt_position_id, - rot_mats, - _, - ) = self.model.prepare_decode_inputs_host( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id - ) - - self.model.copy_host_to_device( - host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats), - device_tensors=( - trace_h, - trace_xattn_mask, - trace_full_text_mask_expand_1NSH, - trace_position_id, - trace_rot_mats, - ), - ) - - ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) - - B, S = tokens.shape - logits = self.model.process_output_decode(trace_logits_rm, B=B, S=S) - - return logits - - def easy_trace( - self, - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ): - """ - Tracing is easy! 
Just call this method and you'll run traced - """ - if not hasattr(self, "trace_id"): - ( - trace_id, - tt_logits_rm, - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_position_id, - rot_mats, - ) = self.capture_trace( - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ) - self.trace_id = trace_id - self.trace_inputs = { - "tt_h": tt_h, - "tt_xattn_mask": tt_xattn_mask, - "tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH, - "tt_position_id": tt_position_id, - "rot_mats": rot_mats, - } - self.trace_outputs = { - "tt_logits_rm": tt_logits_rm, - } - - return self.decode_forward_trace( - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - self.trace_id, - self.trace_outputs["tt_logits_rm"], - self.trace_inputs["tt_h"], - self.trace_inputs["tt_xattn_mask"], - self.trace_inputs["tt_full_text_mask_expand_1NSH"], - self.trace_inputs["tt_position_id"], - self.trace_inputs["rot_mats"], - ) +from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision def get_sampler(temperature, top_p, tokenizer): @@ -370,11 +82,17 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn "normal", ], ) +@pytest.mark.parametrize( + "enable_trace", + (False, True), + ids=["no_trace", "trace"], +) @pytest.mark.parametrize("device_params", [{"trace_region_size": 14951424, "num_command_queues": 2}], indirect=True) def test_llama_multimodal_demo_text( mesh_device, warmup_iters, test_case, + enable_trace, temperature: float = 0, top_p: float = 0.9, max_seq_len: int = 512, @@ -391,11 +109,11 @@ def test_llama_multimodal_demo_text( mesh_device.enable_program_cache() mesh_device.enable_async(True) model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) - model = LlamaVision(model, model_args, mesh_device) + generator = LlamaVision(model, model_args, mesh_device) tokenizer = Tokenizer(model_path=tokenizer_path) formatter = ChatFormat(tokenizer) - xattn_caches = model.model.setup_cache(model_args.max_batch_size) + xattn_caches = generator.model.setup_cache(model_args.max_batch_size) with open(IMG_PATH / "dog.jpg", "rb") as f: img = PIL_Image.open(f).convert("RGB") @@ -447,7 +165,7 @@ def test_llama_multimodal_demo_text( cross_attention_masks, full_text_row_masked_out_mask, logits, - ) = model.prefill_forward_single_user( + ) = generator.prefill_forward_single_user( vision_images, vision_mask, prompt_tokens_tensor, @@ -459,57 +177,35 @@ def test_llama_multimodal_demo_text( prefill_end = time.perf_counter() next_token, text = sampler(logits) - # logger.info(f"Prefill output: {next_token}:{text}") tokens[0, prefill_len] = next_token decode_times = [] - # Capture trace - # next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S - # trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats = model.capture_trace( - # prefill_len, - # next_token_tensor, - # cross_attention_masks, - # full_text_row_masked_out_mask, - # xattn_caches, - # ) - for gen_idx in range(max_gen_len - 1): decode_start = time.perf_counter() position_id = prefill_len + gen_idx next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S - # logits = model.decode_forward( - # position_id, - # next_token_tensor, - # cross_attention_masks, - # full_text_row_masked_out_mask, - # xattn_caches, - # ) - logits = model.easy_trace( - position_id, - 
next_token_tensor, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ) - # logits = model.decode_forward_trace( - # position_id, - # next_token_tensor, - # cross_attention_masks, - # full_text_row_masked_out_mask, - # xattn_caches, - # trace_id, - # tt_logits_rm, - # tt_h, - # tt_xattn_mask, - # tt_full_text_mask_expand_1NSH, - # tt_position_id, - # rot_mats - # ) + + if enable_trace: + logits = generator.easy_trace( + position_id, + next_token_tensor, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + else: + logits = generator.decode_forward( + position_id, + next_token_tensor, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + next_token, text = sampler(logits) # Update next token tokens[0, position_id + 1] = next_token - # logger.info(f"Decode output {position_id}: {next_token}:{text}") decode_end = time.perf_counter() decode_times.append(decode_end - decode_start) @@ -530,4 +226,4 @@ def test_llama_multimodal_demo_text( decode_time_ms = sum(decode_times) / (gen_idx + 1) * 1000 logger.info(f"Decode time: {decode_time_ms:.2f} ms") - # ttnn.release_trace(model.mesh_device, trace_id) + # ttnn.release_trace(generator.mesh_device, trace_id) diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 80a27df0679..15ec522058a 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -498,7 +498,6 @@ def transform_decode_inputs_device(self, tt_xattn_mask, tt_full_text_mask_expand """ Does any transformations on device tensors which are necessary before ttnn_decode_forward """ - print("transforming xattn mask") assert ( B == self.configuration.max_batch_size ), f"Batch size must match max batch size. Got {B}, expected {self.configuration.max_batch_size}" diff --git a/models/demos/llama3/tt/multimodal/vision_generator.py b/models/demos/llama3/tt/multimodal/vision_generator.py new file mode 100644 index 00000000000..06f32bc160d --- /dev/null +++ b/models/demos/llama3/tt/multimodal/vision_generator.py @@ -0,0 +1,429 @@ +import ttnn +import torch + +from llama_models.llama3.api.datatypes import ( + InterleavedTextMedia, + StopReason, +) + +from llama_models.llama3.reference_impl.generation import ( + ChatPrediction, + CompletionPrediction, + TokenResult, + sample_top_p, +) + + +class LlamaVision: + def __init__(self, model, model_args, mesh_device, vllm=False, tokenizer=None, formatter=None): + """ + Creating a LlamaVision wrapper requires only a mesh_device and model_args. + With model_args you have the checkpoint location, can specify max batch size + and max seqlen, and other model specific parameters. + + LlamaVision is general to text and chat. + + For bringup, make this class general to any backend implementation, as long as it takes torch tensors and returns torch tensors. + + """ + self.model = model + self.model_args = model_args + self.mesh_device = mesh_device + self.vllm = vllm + self.tokenizer = tokenizer + self.formatter = formatter + + def prefill_forward_single_user( + self, + vision_images, + vision_mask, + tokens, + xattn_caches, + user_id, + total_len, + prefill_len, + ): + """ + Performs vision encode step then text prefill. 
+ Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) + """ + B = tokens.shape[0] + xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( + batch_images=[vision_images], + batch_masks=[vision_mask], + total_len=total_len, + xattn_caches=xattn_caches, + user_id=user_id, + ) + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + tt_position_id, + rot_mats, + transformation_mats, + ) = self.model.prepare_inputs_prefill( + tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len=prefill_len + ) + + tt_logits = self.model.ttnn_prefill_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + xattn_caches, + tt_position_id, + rot_mats, + transformation_mats, + user_id, + ) + + logits = self.model.process_output_prefill(tt_logits, B, prefill_len) + + return xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits + + def decode_forward( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ): + """ + Performs text decode step. + Returns logits + """ + + # forward_decode should be traced callable + # decorator does compilation, capture, execute + # B = 1 # TODO: Only supports batch=1 right now! Might make tokens input a tensor. + B, S = tokens.shape + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_inputs_decode( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + tt_logits = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + xattn_caches, + tt_position_id, + rot_mats, + ) + + logits = self.model.process_output_decode(tt_logits, B, S) + return logits + + def capture_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ): + """ + Captures a trace for the decode_forward method. 
+ """ + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_inputs_decode( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + # Compile run + tt_logits_rm = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + xattn_caches, + tt_position_id, + rot_mats, + ) + + # Get inputs ready for trace run + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_decode_inputs_host( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id + ) + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) = self.model.copy_host_to_device( + (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats) + ) + + trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0) + B = tokens.shape[0] + # Do on-device transformations of inputs before forward + tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform = self.model.transform_decode_inputs_device( + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + B=B, + ) + + tt_logits_rm = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask_transform, + tt_full_text_mask_expand_1NSH_transform, + xattn_caches, + tt_position_id, + rot_mats, + ) + + ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) + + return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats + + def decode_forward_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, # TODO: unused since captured in trace? + trace_id, + trace_logits_rm, + trace_h, + trace_xattn_mask, + trace_full_text_mask_expand_1NSH, + trace_position_id, + trace_rot_mats, + ): + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_decode_inputs_host( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + self.model.copy_host_to_device( + host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats), + device_tensors=( + trace_h, + trace_xattn_mask, + trace_full_text_mask_expand_1NSH, + trace_position_id, + trace_rot_mats, + ), + ) + + ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) + + B, S = tokens.shape + logits = self.model.process_output_decode(trace_logits_rm, B=B, S=S) + + return logits + + def easy_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ): + """ + Tracing is easy! Just call this method and we'll handle tracing for you. 
+ """ + if not hasattr(self, "trace_id"): + ( + trace_id, + tt_logits_rm, + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) = self.capture_trace( + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + self.trace_id = trace_id + self.trace_inputs = { + "tt_h": tt_h, + "tt_xattn_mask": tt_xattn_mask, + "tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH, + "tt_position_id": tt_position_id, + "rot_mats": rot_mats, + } + self.trace_outputs = { + "tt_logits_rm": tt_logits_rm, + } + + return self.decode_forward_trace( + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + self.trace_id, + self.trace_outputs["tt_logits_rm"], + self.trace_inputs["tt_h"], + self.trace_inputs["tt_xattn_mask"], + self.trace_inputs["tt_full_text_mask_expand_1NSH"], + self.trace_inputs["tt_position_id"], + self.trace_inputs["rot_mats"], + ) + + def generate( + self, + model_input, + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + ): + # Do initial prefill + vision_images = model_input.vision.images + vision_mask = model_input.vision.mask + prompt_tokens = model_input.tokens + prefill_len = len(prompt_tokens) + total_len = prefill_len + max_gen_len # Prepares mask for full length of output + + prompt_tokens_tensor = torch.tensor(prompt_tokens, dtype=torch.long).reshape(1, -1) # B, S + # Suboptimal to allocate caches every time + xattn_caches = self.model.setup_cache(self.model_args.max_batch_size) + ( + xattn_caches, + cross_attention_masks, + full_text_row_masked_out_mask, + logits, + ) = self.prefill_forward_single_user( + vision_images, + vision_mask, + prompt_tokens_tensor, + xattn_caches, + user_id=0, + total_len=total_len, + prefill_len=prefill_len, + ) + + def sample(logits): + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + next_token = next_token.reshape(-1) + return next_token, self.tokenizer.decode(next_token.tolist()) + + next_token, text = sample(logits) + + yield TokenResult( + token=next_token[0].item(), + text=text, + ) + + for gen_idx in range(max_gen_len - 1): + position_id = prefill_len + gen_idx + next_token_tensor = next_token.reshape(1, 1) # B, S + + logits = self.decode_forward( + position_id, + next_token_tensor, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + + next_token, text = sample(logits) + yield TokenResult( + token=next_token[0].item(), + text=text, + ) + + def chat_completion( + self, + messages, + temperature=0.6, + top_p: float = 0.9, + max_gen_len=None, + ): + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.configuration.max_seq_len: + max_gen_len = self.model.configuration.max_seq_len - 1 + + tokens = [] + + stop_reason = None + breakpoint() + for result in self.generate( + model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format=False), + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ): + tokens.append(result.token) + if result.text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + elif result.text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + + if stop_reason is None: + stop_reason = StopReason.out_of_tokens + + message = self.formatter.decode_assistant_message(tokens, stop_reason) + + return ChatPrediction(generation=message) + + def text_completion( + self, 
+ content: InterleavedTextMedia, + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len=None, + ): + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.configuration.max_seq_len: + max_gen_len = self.model.configuration.max_seq_len - 1 + + model_input = self.formatter.encode_content(content) + + tokens = [] + + for result in self.generate( + model_input=model_input, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ): + tokens.append(result.token) + + generation = self.tokenizer.decode(tokens) + + return CompletionPrediction(generation=generation)
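# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the diff above). It shows how the
# refactored generator in models/demos/llama3/tt/multimodal/vision_generator.py
# is driven, mirroring the construction sequence used by the updated demos.
# Assumptions: LLAMA_DIR points at a Llama vision checkpoint containing
# tokenizer.model, `mesh_device` is an initialized ttnn mesh device (the demos
# get it from a pytest fixture), the function name and prompt strings are
# hypothetical, and the text_completion call follows the interleaved-content
# format used by multimodal_demo_text.py.
# ---------------------------------------------------------------------------
import os
from pathlib import Path

from PIL import Image as PIL_Image
from pkg_resources import resource_filename

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import ImageMedia, UserMessage
from llama_models.llama3.api.tokenizer import Tokenizer

from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision

IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/"))


def run_vision_generator_sketch(mesh_device):
    ckpt_dir = os.environ["LLAMA_DIR"]
    tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")

    # Same bring-up order as the demos: enable program cache and async mode,
    # build the TT model, then wrap it with tokenizer and chat formatter.
    mesh_device.enable_program_cache()
    mesh_device.enable_async(True)
    model_args, model = create_multimodal_model(mesh_device, max_batch_size=1, max_seq_len=512)
    tokenizer = Tokenizer(model_path=tokenizer_path)
    formatter = ChatFormat(tokenizer)
    generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)

    with open(IMG_PATH / "dog.jpg", "rb") as f:
        img = PIL_Image.open(f).convert("RGB")

    # Chat-style generation: one dialog containing an image plus a text prompt.
    dialog = [UserMessage(content=[ImageMedia(image=img), "Describe this image in two sentences"])]
    result = generator.chat_completion(dialog, max_gen_len=200, temperature=0.5, top_p=0.9)
    print(f"> {result.generation.role.capitalize()}: {result.generation.content}")

    # Interleaved text completion over the same image; returns a plain string.
    completion = generator.text_completion(
        [ImageMedia(image=img), "If I had to write a haiku for this one"],
        max_gen_len=200,
        temperature=0.5,
        top_p=0.9,
    )
    print(completion.generation)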