From bab4a1dae4c45c5cc8cee0deada3d4d0d3488d18 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Fri, 18 Oct 2024 14:51:41 +0000 Subject: [PATCH 01/16] #13368: Move repeat interleave to xattn cache generation. (cherry picked from commit 7d90c9cd597f14bc924fa5314342b3a9c71ae009) --- .../demos/llama3/demo/multimodal_demo_text.py | 28 +++++++++++-------- .../multimodal/test_llama_cross_attention.py | 10 +++++-- ..._llama_cross_attention_transformer_text.py | 11 ++++++-- .../multimodal/test_llama_cross_block.py | 10 +++++-- .../tt/multimodal/llama_cross_attention.py | 16 +++++++---- 5 files changed, 52 insertions(+), 23 deletions(-) diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py index df2e9a730d3..cb09b79cd2f 100644 --- a/models/demos/llama3/demo/multimodal_demo_text.py +++ b/models/demos/llama3/demo/multimodal_demo_text.py @@ -59,9 +59,14 @@ def create_multimodal_model(model_args, mesh_device, dtype=ttnn.bfloat16): "target", ("tt", "cpu"), ) +@pytest.mark.parametrize( + "warmup_iters", + (0, 1), +) def test_llama_multimodal_demo_text( mesh_device, target, + warmup_iters, temperature: float = 0, top_p: float = 0.9, max_seq_len: int = 512, @@ -115,14 +120,15 @@ def test_llama_multimodal_demo_text( ] print(f"Running text completion on {target}") - for content in interleaved_contents: - result = generator.text_completion( - content, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ) - - cprint(f"{content}", end="") - cprint(f"{result.generation}", color="yellow") - print("\n==================================\n") + for _ in range(warmup_iters + 1): + for content in interleaved_contents: + result = generator.text_completion( + content, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + + cprint(f"{content}", end="") + cprint(f"{result.generation}", color="yellow") + print("\n==================================\n") diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index 46f7aba4b78..9292cd7df32 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -100,13 +100,19 @@ def test_llama_cross_attention_inference( pt_xattn_cache = reference_model.compute_xattn_kv_cache(pt_xattn_tokens) pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0) pt_xattn_cache_chunks = [ - x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache + # x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache + x.view(batch, n_heads, vision_seq_len, head_dim) + for x in pt_xattn_cache ] tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_xattn_tokens) tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( - batch, n_kv_heads, vision_seq_len, head_dim + # batch, n_kv_heads, vision_seq_len, head_dim + batch, + n_heads, + vision_seq_len, + head_dim, ) for x in tt_xattn_cache ] diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 1ce9d2e5699..60de5e6197f 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -5,6 +5,7 @@ import pytest from loguru import logger import os +import time import ttnn import importlib @@ -118,13 +119,19 @@ def test_llama_cross_attention_transformer_text_inference( pt_xattn_cache_chunks = [x for xx in pt_xattn_cache_chunks for x in xx] # slice out replicated k/v heads pt_xattn_cache_chunks = [ - x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache_chunks + # x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache_chunks + x.view(batch, n_heads, vision_seq_len, head_dim) + for x in pt_xattn_cache_chunks ] tt_xattn_cache = [layer.compute_xattn_kv_cache(tt_vision_tokens) for layer in tt_model.cross_attention_layers] tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( - batch, n_kv_heads, vision_seq_len, head_dim + # batch, n_kv_heads, vision_seq_len, head_dim + batch, + n_heads, + vision_seq_len, + head_dim, ) for kv_cache in tt_xattn_cache for x in kv_cache diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 36043e15437..e02873def82 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -94,13 +94,19 @@ def test_llama_cross_attention_transformer_block_inference( pt_xattn_cache = reference_model.compute_xattn_kv_cache(pt_xattn_tokens) pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0) pt_xattn_cache_chunks = [ - x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache + # x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache + x.view(batch, n_heads, vision_seq_len, head_dim) + for x in pt_xattn_cache ] tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_xattn_tokens) tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( - batch, n_kv_heads, vision_seq_len, head_dim + # batch, n_kv_heads, vision_seq_len, head_dim + batch, + n_heads, + vision_seq_len, + head_dim, ) for x in tt_xattn_cache ] diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index 4303da8acf8..bf7a8ee7575 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -164,6 +164,9 @@ def compute_xattn_kv_cache(self, xattn_tokens): ) xk = self.k_norm(xk) + + xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) + xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) return [xk, xv] ### EVERYTHING BELOW IS BROKEN OMG @@ -224,8 +227,8 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xk, xv = xattn_cache cache_seq_len = xk.shape[-2] - xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) - xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) + # xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) + # xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) scores = ttnn.matmul( xq, @@ -310,13 +313,14 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH xk, xv = xattn_cache cache_seq_len = xk.shape[-2] + # NOTE: Doing repeat in xattn_cache generation to avoid massive overhead in forward # NOTE: Using naive SDPA for now since FlashDecode does not allow non-causal mask # xq = ttnn.reshape(xq, [self.n_local_heads // self.n_local_kv_heads, self.n_local_kv_heads, seq_len, self.head_dim]) # NOTE: repeat doesn't work, need to use repeat_interleave - # xk = ttnn.repeat(xk, ttnn.Shape((self.n_local_heads // self.n_local_kv_heads, 1, 1, 1))) - xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) - # xv = ttnn.repeat(xv, ttnn.Shape((self.n_local_heads // self.n_local_kv_heads, 1, 1, 1))) - xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) + # # xk = ttnn.repeat(xk, ttnn.Shape((self.n_local_heads // self.n_local_kv_heads, 1, 1, 1))) + # xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) + # # xv = ttnn.repeat(xv, ttnn.Shape((self.n_local_heads // self.n_local_kv_heads, 1, 1, 1))) + # xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) scores = ttnn.matmul( xq, From ee48aaa9dce473e7e33aefe84c64b7d7c1a1b7dc Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Mon, 21 Oct 2024 13:50:54 +0000 Subject: [PATCH 02/16] #0: Clean up demo, enable arbitrary padding for multimodal text sequence (cherry picked from commit a5254cfaa5670b9d0348b1ff80f482ee4821dbd6) --- .../demos/llama3/demo/multimodal_demo_text.py | 22 ++++----- models/demos/llama3/tt/model_config.py | 2 +- .../tt/multimodal/llama_vision_model.py | 46 +++++++++++++++---- 3 files changed, 46 insertions(+), 24 deletions(-) diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py index cb09b79cd2f..c953db0794d 100644 --- a/models/demos/llama3/demo/multimodal_demo_text.py +++ b/models/demos/llama3/demo/multimodal_demo_text.py @@ -67,11 +67,11 @@ def test_llama_multimodal_demo_text( mesh_device, target, warmup_iters, - temperature: float = 0, + temperature: float = 0.5, top_p: float = 0.9, max_seq_len: int = 512, max_batch_size: int = 4, - max_gen_len: Optional[int] = None, + max_gen_len: Optional[int] = 200, model_parallel_size: Optional[int] = None, ): mesh_device.enable_program_cache() @@ -101,22 +101,18 @@ def test_llama_multimodal_demo_text( with open(THIS_DIR / "resources/ocr_image.jpeg", "rb") as f: ocr_image = PIL_Image.open(f).convert("RGB") - # with open(THIS_DIR / "resources/clutter.jpeg", "rb") as f: - # clutter = PIL_Image.open(f).convert("RGB") + + with open(THIS_DIR / "resources/clutter.jpeg", "rb") as f: + clutter = PIL_Image.open(f).convert("RGB") interleaved_contents = [ # text only - # "The color of the sky is blue but sometimes it can also be", + "The color of the sky is blue but sometimes it can also be", # image understanding - # [ - # ImageMedia(image=img), - # "If I had to write a haiku for this one", - # ], + [ImageMedia(image=img), "If I had to write a haiku for this one"], + [ImageMedia(image=img2), "Couting the number of individual spaghetti strands in this image"], [ImageMedia(image=ocr_image), "The full text in this image is as follows"], - # [ - # ImageMedia(image=clutter), - # "The count of vases, books, and miscellaneous items in this image is", - # ] + [ImageMedia(image=clutter), "The count of vases, books, and miscellaneous items in this image is"], ] print(f"Running text completion on {target}") diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 7d32bf6ce6f..ad5173533cb 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -476,7 +476,7 @@ def find_largest_divisor(n, max_divisor=8): k=cache_seq_len, n=self.head_dim, grid_size=(8, 8), - in0_block_w=1, + # in0_block_w=1, # TODO: Remove this when we get non-causal FlashDecode fuse_batch=False, ) self.model_config["VISION_XATTN_DENSE_PROGCFG"] = lambda seq_len: self.matmul_config( diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index c6b72a78506..ab8841acd0f 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -272,6 +272,13 @@ def compute_vision_tokens_masks( ) return (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask) + def validate_inputs(self, tokens): + batch, seq_len = tokens.shape[:2] + assert batch == 1, f"Only batch 1 is supported, got {batch}" + assert ( + seq_len <= self.configuration.max_seq_len + ), f"Sequence length {seq_len} exceeds max sequence length {self.configuration.max_seq_len}" + def forward( self, position_ids: torch.Tensor, @@ -281,8 +288,10 @@ def forward( xattn_caches: torch.Tensor, text_only_inference: bool = False, ) -> torch.Tensor: + self.validate_inputs(tokens) h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids]) batch, seq_len = h.shape[:2] + padded_seq_len = _get_padded_prefill_seqlen(seq_len) if seq_len == 1: mode = "decode" else: @@ -302,35 +311,39 @@ def forward( if mode == "prefill": xattn_mask_expand = torch.nn.functional.pad( xattn_mask_expand, - (0, 0, 0, 128 - xattn_mask_expand.shape[2]), + (0, 0, 0, padded_seq_len - xattn_mask_expand.shape[2]), "constant", get_negative_inf_value(torch.float32), ) + tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, device=self.mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, + dtype=ttnn.bfloat16, + layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) + tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) full_text_mask = full_text_row_masked_out_mask[:, :, position_ids] if mode == "prefill": full_text_mask = torch.nn.functional.pad( - full_text_mask, (0, 0, 0, 128 - full_text_mask.shape[2]), "constant", 0 + full_text_mask, (0, 0, 0, padded_seq_len - full_text_mask.shape[2]), "constant", 0 ) full_text_mask_expand_1NSH = full_text_mask.expand( -1, self.configuration.n_heads // self.configuration.num_devices, -1, self.configuration.head_dim ) + tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, device=self.mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, + dtype=ttnn.bfloat16, + layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) + tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, self.configuration.dim) tt_full_text_mask_expand_11SD = ttnn.from_torch( @@ -343,8 +356,7 @@ def forward( ) # Check mask shapes, pad if in prefill? if mode == "prefill": - # DEBUG: pad h seqlen to 128 - h = torch.nn.functional.pad(h, (0, 0, 0, 128 - h.shape[1]), "constant", 0) + h = torch.nn.functional.pad(h, (0, 0, 0, padded_seq_len - h.shape[1]), "constant", 0) tt_h = prepare_inputs_ttnn_prefill( h, self.mesh_device, @@ -427,9 +439,10 @@ def forward( text_only_inference=text_only_inference, ) - tt_out = ttnn.to_torch(logits, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() + tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) + tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() if mode == "prefill": - tt_out = tt_out[0].reshape(batch, 128, -1)[:, :seq_len, :] # DEBUG: undo padding + tt_out = tt_out[0].reshape(batch, padded_seq_len, -1)[:, :seq_len, :] # DEBUG: undo padding else: tt_out = tt_out[0, ..., :batch, :].transpose(0, 1).reshape(batch, seq_len, -1) @@ -493,3 +506,16 @@ def _pad_masks( out_masks[idx, mask_elem[0] : mask_elem[1], mask_idx, :mask_num_chunks].fill_(0.0) return out_masks + + +def _get_padded_prefill_seqlen(seq_len): + """ + If seq_len is less than 128, pad to 128 + If seq_len is more than 128, pad to whichever is smaller: a power of 2 or a multiple of 1024 + """ + if seq_len < 128: + return 128 + else: + mult_1024 = 1024 * math.ceil(seq_len / 1024) + pow_2 = 2 ** math.ceil(math.log2(seq_len)) + return min(mult_1024, pow_2) From 29ae070cecddec6a349654b0ac48caad158caec9 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Mon, 21 Oct 2024 16:10:35 +0000 Subject: [PATCH 03/16] #13368: Add llama_models Meta reference for Llama3.2 as a submodule --- .gitmodules | 3 +++ models/demos/llama3/reference/llama_models | 1 + 2 files changed, 4 insertions(+) create mode 160000 models/demos/llama3/reference/llama_models diff --git a/.gitmodules b/.gitmodules index ab121e423f3..1c29f48e987 100644 --- a/.gitmodules +++ b/.gitmodules @@ -28,3 +28,6 @@ [submodule "tt_metal/third_party/tt_llk_blackhole"] path = tt_metal/third_party/tt_llk_blackhole url = https://github.com/tenstorrent/tt-llk-bh.git +[submodule "models/demos/llama3/reference/llama_models"] + path = models/demos/llama3/reference/llama_models + url = https://github.com/tenstorrent/llama-models.git diff --git a/models/demos/llama3/reference/llama_models b/models/demos/llama3/reference/llama_models new file mode 160000 index 00000000000..c217d3eb10f --- /dev/null +++ b/models/demos/llama3/reference/llama_models @@ -0,0 +1 @@ +Subproject commit c217d3eb10f6c01bbaa1aa7c714bb7c5ccf3b14f From 303be06ad38321e0580d9d623a2272073f423c0f Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Mon, 21 Oct 2024 16:13:10 +0000 Subject: [PATCH 04/16] #13368: Change reference imports to use new submodule --- .../demos/llama3/demo/multimodal_demo_text.py | 14 ++----- .../multimodal/test_llama_conv2d_patch.py | 39 +------------------ .../multimodal/test_llama_cross_attention.py | 5 +-- ..._llama_cross_attention_transformer_text.py | 6 +-- .../multimodal/test_llama_cross_block.py | 5 +-- .../multimodal/test_llama_image_attention.py | 9 +---- .../multimodal/test_llama_image_block.py | 9 +---- .../tests/multimodal/test_llama_image_mlp.py | 5 +-- .../test_llama_image_transformer.py | 9 +---- .../test_llama_image_transformer_vision.py | 8 ++-- .../test_llama_image_vision_encoder.py | 5 +-- .../tests/multimodal/test_llama_layernorm.py | 5 +-- .../test_llama_positional_embedding.py | 2 - .../test_llama_tile_position_embedding.py | 6 +-- .../multimodal/llama_image_vision_encoder.py | 6 +-- .../tt/multimodal/llama_vision_model.py | 18 ++++----- 16 files changed, 28 insertions(+), 123 deletions(-) diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py index c953db0794d..223d4dc96f3 100644 --- a/models/demos/llama3/demo/multimodal_demo_text.py +++ b/models/demos/llama3/demo/multimodal_demo_text.py @@ -8,19 +8,11 @@ from PIL import Image as PIL_Image from termcolor import cprint -import importlib +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.generation as llama_reference_generation -llama_reference_generation = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.generation" -) - -# Must import from reference for formatter to understand type of ImageMedia -datatypes = importlib.import_module("models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.api.datatypes") -ImageMedia = datatypes.ImageMedia +from models.demos.llama3.reference.llama_models.models.llama3.api.datatypes import ImageMedia -# THIS_DIR = Path(__file__).parent.resolve() -# TODO: Generalize not to cglagovich home :) -THIS_DIR = Path("/home/cglagovich/tt-metal/models/demos/t3000/llama2_70b/reference/llama-models/models/scripts/") +THIS_DIR = Path(__file__).parent.parent.resolve() / "reference/llama_models/models/scripts/" import torch import pytest diff --git a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py index 5458a1ca8c2..10a0f95ae3c 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py +++ b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py @@ -29,44 +29,7 @@ ) from models.demos.llama3.tt.model_config import TtModelArgs - -import importlib - -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) - - -# ##### Torch op ##### -# class Conv2dPatch(torch.nn.Module): -# """Conv2D Patching layer with model parallelism. -# Column parallel over unfolded input. -# Arguments: -# in_channels: Input channels. -# out_channels: Output channels. -# kernel_size: Size of convolution kernel. -# stride (default 1): Stride for convolution. -# bias (default False): Use bias in Conv2d. -# Input: (bsz, in_channels, width, height) -# Output: (bsz, num_tokens, out_channels) -# """ - -# def __init__(self, in_channels, out_channels, kernel_size, stride, bias) -> None: -# super().__init__() -# if isinstance(kernel_size, int): -# kernel_size = (kernel_size, kernel_size) -# self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) -# self._linear = torch.nn.Linear( -# in_channels * kernel_size[0] * kernel_size[1], -# out_channels, -# bias=bias, -# ) - -# def forward(self, x: torch.Tensor) -> torch.Tensor: -# x = self._unfold(x) -# x = x.permute(0, 2, 1) -# x = F.linear(x, self._linear.weight) -# return x +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod @skip_for_grayskull("Requires wormhole_b0 to run") diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index 9292cd7df32..04518ceea47 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -6,11 +6,8 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_cross_attention import TtLlamaCrossAttention from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 60de5e6197f..8a52a2b317b 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -5,13 +5,9 @@ import pytest from loguru import logger import os -import time import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_text import ( TtLlamaCrossAttentionTransformerText, ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index e02873def82..d8d4d731d3e 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -6,11 +6,8 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_cross_block import TtLlamaCrossAttentionTransformerBlock from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py index dce2fedf4bc..96ffc2d2aa6 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py @@ -6,14 +6,9 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) -encoder_utils = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.encoder_utils" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +from models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_attention import TtLlamaImageAttention from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import pad_seq_one_tile, mask_tile_padding from models.demos.llama3.tt.model_config import TtModelArgs diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py index 38c4356b406..7e6acf469a0 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py @@ -6,14 +6,9 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) -encoder_utils = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.encoder_utils" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +from models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_block import TtLlamaImageTransformerBlock from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import pad_seq_one_tile, mask_tile_padding from models.demos.llama3.tt.model_config import TtModelArgs diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py index 900be3f49fe..20f396a525d 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py @@ -7,11 +7,8 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_image_mlp import TtLlamaImageFeedForward from models.demos.llama3.tt.model_config import TtModelArgs from models.utility_functions import ( diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py index c425caec570..bc9d6b12b9a 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py @@ -6,14 +6,9 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) -encoder_utils = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.encoder_utils" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +from models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_transformer import TtLlamaImageTransformer from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import pad_seq_one_tile, mask_tile_padding diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py index 3a719955c20..11b4c58939f 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py @@ -6,11 +6,8 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_image_transformer_vision import TtLlamaCrossAttentionTransformerVision from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( @@ -62,6 +59,7 @@ def test_llama_vision_transformer_inference(mesh_device, use_program_cache, rese # Create rand inputs of the right shape batch, num_media, num_chunks, n_channel, patch_size = (1, 1, 4, 3, 448) + chunk_seq_len = (patch_size // model_args.vision_patch_size) ** 2 + 1 # tokens per chunk + 1 class token images = torch.randn(batch, num_media, num_chunks, n_channel, patch_size, patch_size) ars = torch.tensor([2, 2]).reshape(batch, num_media, 2) @@ -69,7 +67,7 @@ def test_llama_vision_transformer_inference(mesh_device, use_program_cache, rese reference_output = reference_model(images, ars) tt_out = tt_model(images, ars) tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) - tt_output_torch = tt_output_torch[0, :, :, :].view(reference_output.shape) + tt_output_torch = tt_output_torch[0, :, :chunk_seq_len, :].view(reference_output.shape) logger.info(f"Reference output shape: {reference_output.shape}") logger.info(f"TT output shape: {tt_output_torch.shape}") diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py b/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py index 4da57f6cc33..65147f91fea 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py @@ -6,11 +6,8 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import TtLlamaVisionEncoder from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( diff --git a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py index 9e36eb42247..cdabc148528 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py +++ b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py @@ -7,11 +7,8 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_layernorm import TtLayerNorm # Updated import for LayerNorm from models.demos.llama3.tt.model_config import TtModelArgs from models.utility_functions import ( diff --git a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py index 7430796270f..5aa233f39db 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py @@ -31,8 +31,6 @@ ) from models.demos.llama3.tt.model_config import TtModelArgs -import importlib - ##### Torch op ##### class PositionalEmbedding(nn.Module): diff --git a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py index 619ca0bdb60..7ffa7943d1e 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py @@ -31,11 +31,7 @@ ) from models.demos.llama3.tt.model_config import TtModelArgs -import importlib - -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod @skip_for_grayskull("Requires wormhole_b0 to run") diff --git a/models/demos/llama3/tt/multimodal/llama_image_vision_encoder.py b/models/demos/llama3/tt/multimodal/llama_image_vision_encoder.py index 2efce9ecd0f..ee22ab29e30 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_vision_encoder.py +++ b/models/demos/llama3/tt/multimodal/llama_image_vision_encoder.py @@ -24,11 +24,7 @@ synchronize_devices, ) -import importlib - -encoder_utils = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.encoder_utils" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.encoder_utils as encoder_utils def to_2tuple(x): diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index ab8841acd0f..4a7560de5e9 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -15,16 +15,8 @@ from torch import nn, Tensor -BFLOAT = False - -import importlib - -llama_reference_model = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) -llama_reference_image_transforms = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.image_transform" -) +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_model +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.image_transform as llama_reference_image_transforms import ttnn from models.demos.llama3.tt.multimodal.llama_image_transformer_vision import TtLlamaCrossAttentionTransformerVision @@ -225,8 +217,12 @@ def compute_vision_tokens_masks( vision_tokens = self.vision_model(stacked_images, aspect_ratios) # Back to torch vision_tokens = ttnn.to_torch(vision_tokens, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)) + chunk_seq_len = (self.configuration.vision_chunk_size // self.configuration.vision_patch_size) ** 2 + 1 + # NOTE: slicing up to chunk_seq_len is necessary because padding information is lost by this point vision_tokens = ( - vision_tokens[0].reshape(bsz, max_num_images, self.max_num_chunks, -1, self.model_dim).float() + vision_tokens[0, :, :chunk_seq_len] + .reshape(bsz, max_num_images, self.max_num_chunks, -1, self.model_dim) + .float() ) bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) From ef329cd95415b99c05eeed0304383a7765d22cf2 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Mon, 21 Oct 2024 16:18:41 +0000 Subject: [PATCH 05/16] #13368: Clean up comments after pushing repeat_interleave into xattn_cache generation. --- .../llama3/tests/multimodal/test_llama_cross_attention.py | 7 +------ .../test_llama_cross_attention_transformer_text.py | 7 +------ .../llama3/tests/multimodal/test_llama_cross_block.py | 7 +------ 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index 04518ceea47..ef7c4555348 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -96,16 +96,11 @@ def test_llama_cross_attention_inference( """ pt_xattn_cache = reference_model.compute_xattn_kv_cache(pt_xattn_tokens) pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0) - pt_xattn_cache_chunks = [ - # x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache - x.view(batch, n_heads, vision_seq_len, head_dim) - for x in pt_xattn_cache - ] + pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache] tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_xattn_tokens) tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( - # batch, n_kv_heads, vision_seq_len, head_dim batch, n_heads, vision_seq_len, diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 8a52a2b317b..7fcba58fedc 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -114,16 +114,11 @@ def test_llama_cross_attention_transformer_text_inference( pt_xattn_cache_chunks = [torch.chunk(x, 2, dim=1) for x in pt_xattn_cache_chunks] pt_xattn_cache_chunks = [x for xx in pt_xattn_cache_chunks for x in xx] # slice out replicated k/v heads - pt_xattn_cache_chunks = [ - # x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache_chunks - x.view(batch, n_heads, vision_seq_len, head_dim) - for x in pt_xattn_cache_chunks - ] + pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache_chunks] tt_xattn_cache = [layer.compute_xattn_kv_cache(tt_vision_tokens) for layer in tt_model.cross_attention_layers] tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( - # batch, n_kv_heads, vision_seq_len, head_dim batch, n_heads, vision_seq_len, diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index d8d4d731d3e..4f33c000977 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -90,16 +90,11 @@ def test_llama_cross_attention_transformer_block_inference( """ pt_xattn_cache = reference_model.compute_xattn_kv_cache(pt_xattn_tokens) pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0) - pt_xattn_cache_chunks = [ - # x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache - x.view(batch, n_heads, vision_seq_len, head_dim) - for x in pt_xattn_cache - ] + pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache] tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_xattn_tokens) tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( - # batch, n_kv_heads, vision_seq_len, head_dim batch, n_heads, vision_seq_len, From 27fa6d5a642c47d0ed22a7723fee08c36b8239d1 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Mon, 21 Oct 2024 21:17:25 +0000 Subject: [PATCH 06/16] #13368: Clean up vision tests. Unify assertions and pcc checks. Fix LM head splitting on T3k. --- models/demos/llama3/lt | 23 ++++++- .../multimodal/test_llama_class_embedding.py | 12 ++-- .../multimodal/test_llama_conv2d_patch.py | 8 +-- .../multimodal/test_llama_cross_attention.py | 14 ++-- ..._llama_cross_attention_transformer_text.py | 18 ++--- .../multimodal/test_llama_cross_block.py | 14 ++-- .../multimodal/test_llama_image_attention.py | 23 ++----- .../multimodal/test_llama_image_block.py | 21 ++---- .../tests/multimodal/test_llama_image_mlp.py | 49 +------------- .../test_llama_image_transformer.py | 26 +++---- .../test_llama_image_transformer_vision.py | 23 ++----- .../test_llama_image_vision_encoder.py | 67 ++++++------------- .../tests/multimodal/test_llama_layernorm.py | 7 +- .../test_llama_positional_embedding.py | 13 ++-- .../test_llama_tile_position_embedding.py | 13 ++-- models/demos/llama3/tt/model_config.py | 4 +- .../tt/multimodal/llama_cross_attention.py | 26 +++++-- .../llama_cross_attention_transformer_text.py | 3 +- .../tt/multimodal/llama_image_attention.py | 4 +- 19 files changed, 127 insertions(+), 241 deletions(-) diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index 388339ea584..be513f6e1bd 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -184,7 +184,14 @@ def main(stdscr): commands = parse_list(command_input, allow_space=False) # Generate combinations (reordered) - combinations = [(c, m, d) for c in commands for m in models for d in devices] + # combinations = [(c, m, d) for c in commands for m in models for d in devices] + combinations = [ + (c, m, d) + for c in commands + for m in models + for d in devices + if not (m == "11b" and d == "n150") + ] # Create output entries for command, model, device in combinations: @@ -230,7 +237,7 @@ def main(stdscr): else: # Ignore enter key when exiting continue - elif c == curses.KEY_BACKSPACE or c == 127 or c == ord("x"): + elif c == curses.KEY_BACKSPACE or c == 127: if current_line < len(input_fields): current_field = current_line # Remove last character from current field @@ -506,6 +513,18 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): "model": "pytest models/demos/llama3/tests/test_llama_model.py::test_llama_model_inference[wormhole_b0-True-mesh_device0-full]", "model-prefill": "pytest models/demos/llama3/tests/test_llama_model_prefill.py::test_llama_model_inference[wormhole_b0-True-mesh_device0-4096]", "model-quick": "pytest models/demos/llama3/tests/test_llama_model.py -k quick", + "vision-mlp": "pytest models/demos/llama3/tests/multimodal/test_llama_image_mlp.py", + "vision-attn": "pytest models/demos/llama3/tests/multimodal/test_llama_image_attention.py", + "vision-block": "pytest models/demos/llama3/tests/multimodal/test_llama_image_block.py", + "vision-xattn": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention.py", + "vision-xblock": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_block.py", + "vision-conv": "pytest models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py", + "vision-class": "pytest models/demos/llama3/tests/multimodal/test_llama_class_embedding.py", + "vision-tile-pos": "pytest models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py", + "vision-pos": "pytest models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py", + "vision-encoder": "pytest models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py", + "vision-text-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py", + "vision-vision-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py", } # Check if the command is a shortcut and replace it if necessary diff --git a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py index 533bc9c2106..09aaa82d16c 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py @@ -56,7 +56,7 @@ def forward(self, x): @pytest.mark.parametrize( "mesh_device", [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get( + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) ) ], @@ -86,7 +86,7 @@ def test_llama_class_embedding_inference( ensure_gc, ): dtype = ttnn.bfloat16 - pcc = 0.9999 + pcc_required = 0.9999 mesh_device.enable_async(True) @@ -145,12 +145,8 @@ def test_llama_class_embedding_inference( # Only select output from one device tt_output_torch = tt_output_torch[..., :dim].view(reference_output.shape) - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info(f"Llama_ClassEmbedding Passed!") - else: - logger.warning(f"Llama_ClassEmbedding Failed!") - assert passing, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py index 10a0f95ae3c..12de8c159f1 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py +++ b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py @@ -63,7 +63,7 @@ def test_llama_conv2d_inference( bias, ensure_gc, ): - pcc = 0.9999 + pcc_required = 0.9999 dtype = ttnn.bfloat16 mesh_device.enable_async(True) @@ -128,8 +128,4 @@ def test_llama_conv2d_inference( logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info(f"Llama_Conv2dPatch Passed!") - else: - logger.warning(f"Llama_Conv2dPatch Failed!") - assert passing, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index ef7c4555348..434da78df59 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -43,7 +43,7 @@ def test_llama_cross_attention_inference( vision_seq_len, text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc ): dtype = ttnn.bfloat16 - pcc = 0.99 + pcc_required = 0.99 mesh_device.enable_async(True) @@ -110,7 +110,7 @@ def test_llama_cross_attention_inference( ] for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch): - passing, pcc_message = comp_pcc(pt, tt, pcc) + passing, pcc_message = comp_pcc(pt, tt, pcc_required) logger.info(comp_allclose(pt, tt)) logger.info(f"PCC: {pcc_message}") @@ -120,6 +120,8 @@ def test_llama_cross_attention_inference( logger.warning(f"compute_xattn_kv_cache Failed!") all_tests_pass = False + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" + """ Test forward, prefill and decode! """ @@ -216,13 +218,9 @@ def test_llama_cross_attention_inference( tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(batch, seq_len, dim) else: tt_output_torch = tt_output_torch[0, ..., :batch, :].transpose(0, 1).view(batch, seq_len, dim) - passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required) logger.info(comp_allclose(pt_out, tt_output_torch)) logger.info(f"PCC: {pcc_message}") all_tests_pass = all_tests_pass and passing - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 7fcba58fedc..a0a21debe64 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -53,8 +53,8 @@ def test_llama_cross_attention_transformer_text_inference( reset_seeds, ): dtype = ttnn.bfloat8_b - prefill_pcc = 0.98 - decode_pcc = 0.73 + prefill_pcc_required = 0.98 + decode_pcc_required = 0.73 mesh_device.enable_async(True) @@ -129,10 +129,10 @@ def test_llama_cross_attention_transformer_text_inference( ] for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch): - passing, pcc_message = comp_pcc(pt, tt, prefill_pcc) + passing, pcc_message = comp_pcc(pt, tt, prefill_pcc_required) logger.info(comp_allclose(pt, tt)) - logger.info(pcc_message) + logger.info(f"PCC: {pcc_message}") if passing: logger.info(f"compute_xattn_kv_cache Passed!") @@ -310,12 +310,12 @@ def test_llama_cross_attention_transformer_text_inference( tt_out = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) if mode == "prefill": tt_out = tt_out[0].reshape(logits.shape) - pcc = prefill_pcc + pcc_required = prefill_pcc_required else: tt_out = tt_out[0, ..., :batch, :].transpose(0, 1).view(logits.shape) - pcc = decode_pcc - passing, pcc_message = comp_pcc(logits, tt_out, pcc) + pcc_required = decode_pcc_required + passing, pcc_message = comp_pcc(logits, tt_out, pcc_required) logger.info(comp_allclose(logits, tt_out)) - logger.info(pcc_message) + logger.info(f"PCC: {pcc_message}") prev_pos = cur_pos - assert passing, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 4f33c000977..78020e80358 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -43,7 +43,7 @@ def test_llama_cross_attention_transformer_block_inference( vision_seq_len, text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc ): dtype = ttnn.bfloat16 - pcc = 0.99 + pcc_required = 0.99 mesh_device.enable_async(True) @@ -104,7 +104,7 @@ def test_llama_cross_attention_transformer_block_inference( ] for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch): - passing, pcc_message = comp_pcc(pt, tt, pcc) + passing, pcc_message = comp_pcc(pt, tt, pcc_required) logger.info(comp_allclose(pt, tt)) logger.info(f"PCC: {pcc_message}") @@ -114,6 +114,8 @@ def test_llama_cross_attention_transformer_block_inference( logger.warning(f"compute_xattn_kv_cache Failed!") all_tests_pass = False + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" + """ Test forward, prefill and decode! """ @@ -230,13 +232,9 @@ def test_llama_cross_attention_transformer_block_inference( tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(batch, seq_len, dim) else: tt_output_torch = tt_output_torch[0, ..., :batch, :].transpose(0, 1).view(batch, seq_len, dim) - passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required) logger.info(comp_allclose(pt_out, tt_output_torch)) logger.info(f"PCC: {pcc_message}") all_tests_pass = all_tests_pass and passing - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py index 96ffc2d2aa6..e0a14b69e7b 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py @@ -38,7 +38,7 @@ ) def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 - pcc = 0.99 + pcc_required = 0.99 mesh_device.enable_async(True) @@ -56,8 +56,6 @@ def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_pro reference_model = llama_reference_mod.ImageAttention(dim=dim, head_dim=dim // heads, n_heads=heads) reference_model.load_state_dict(partial_state_dict) - all_tests_pass = True - tt_model = TtLlamaImageAttention( mesh_device, state_dict, @@ -89,8 +87,8 @@ def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_pro ) tt_attn_mask = encoder_utils.build_encoder_attention_mask(fake_x, ar, ntok, num_chunks, 1) # Make striped attention mask to mask out our padding between 8 and 32 - # Striped mask doesn't affect PCC - # tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, 32, num_chunks) + # Striped mask doesn't affect PCC on first layer but is necessary for later layers + tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, 32, num_chunks) attention_input = attention_input.reshape(1, batch, -1, dim) @@ -114,18 +112,9 @@ def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_pro reference_output = reference_output.reshape(batch, num_chunks, ntok + npad, dim) reference_output = encoder_utils.contract_num_tokens_from_mult8(reference_output, npad) - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info(f"Llama_Attention Passed!") - else: - logger.warning(f"Llama_Attention Failed!") - all_tests_pass = False - - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py index 7e6acf469a0..17b3c219446 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py @@ -33,12 +33,12 @@ ) @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], indirect=True, ) def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 - pcc = 0.99 + pcc_required = 0.99 mesh_device.enable_async(True) @@ -61,8 +61,6 @@ def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_ ) reference_model.load_state_dict(partial_state_dict) - all_tests_pass = True - tt_model = TtLlamaImageTransformerBlock( mesh_device, state_dict, @@ -94,6 +92,7 @@ def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_ attention_input.shape[0], attention_input.shape[1], attention_input.shape[2], attention_input.shape[3] ) tt_attn_mask = encoder_utils.build_encoder_attention_mask(fake_x, ar, ntok, num_chunks, 1) + tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, 32, num_chunks) attention_input = attention_input.reshape(1, batch, -1, dim) tt_mask = ttnn.from_torch( @@ -114,18 +113,8 @@ def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_ reference_output = reference_output.reshape(batch, num_chunks, ntok + npad, dim) reference_output = encoder_utils.contract_num_tokens_from_mult8(reference_output, npad) - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info(f"Llama_Attention Passed!") - else: - logger.warning(f"Llama_Attention Failed!") - all_tests_pass = False - - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py index 20f396a525d..958bfb919ca 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py @@ -48,48 +48,6 @@ def test_llama_mlp_inference(mesh_device, seq_len, use_program_cache, reset_seed model_args.WEIGHTS_DTYPE = dtype - """ - self.patch_size = 14 - self.vision_encoder = VisionEncoder( - max_num_tiles=4, - image_size=args.vision_chunk_size, - patch_size=self.patch_size, - n_global_layers=8, - global_model=True, - return_intermediate=return_intermediate, - ) - - class VisionEncoder(nn.Module): - def __init__( - self, - max_num_tiles: int, - ckpt_path: str = None, - image_size: int = 224, - patch_size: int = 14, - width: int = 1280, - layers: int = 32, - heads: int = 16, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - in_channels: int = 3, - load_ckpt: bool = False, - n_global_layers: int = 2, - global_model: bool = False, - return_intermediate=None, - ... - - self.global_transformer = ImageTransformer( - width, n_global_layers, heads, mlp_ratio, act_layer=act_layer, gated=True - - self.mlp = ImageFeedForward( - dim=d_model, - hidden_dim=int(mlp_ratio * d_model), - dropout=0.0, - act_layer=act_layer, - ) - ) - """ - dim = 1280 mlp_ratio = 4.0 act_layer = torch.nn.GELU @@ -133,9 +91,4 @@ def __init__( logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info("Llama_MLP Passed!") - else: - logger.warning("Llama_MLP Failed!") - - assert passing, f"Llama_MLP output does not meet PCC requirement {pcc_required}: {pcc_message}." + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py index bc9d6b12b9a..feffe9e4472 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py @@ -33,14 +33,14 @@ ) @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], indirect=True, ) def test_llama_image_transformer_inference( batch, num_chunks, ntok, mesh_device, is_global, use_program_cache, reset_seeds, ensure_gc ): dtype = ttnn.bfloat16 - pcc = 0.86 + pcc_required = 0.86 mesh_device.enable_async(True) @@ -54,11 +54,11 @@ def test_llama_image_transformer_inference( n_layers = model_args.vision_n_global_layers return_intermediate = None else: - # first_layer_prefix = "vision_model.vision_encoder.transformer." first_layer_prefix = "vision_model.vision_encoder." gated = False n_layers = model_args.vision_n_layers # return_intermediate = [int(l) for l in "3,7,15,23,30".split(",")] + # Checks all intermediates return_intermediate = list(range(n_layers)) partial_state_dict = { @@ -124,7 +124,6 @@ def test_llama_image_transformer_inference( ) tt_attn_mask = encoder_utils.build_encoder_attention_mask(fake_x, ar, ntok, num_chunks, 1) # Make striped attention mask to mask out our padding between 8 and 32 - # Striped mask doesn't affect PCC tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, npadtt, num_chunks) attention_input = attention_input.reshape(1, batch, -1, dim) @@ -162,25 +161,16 @@ def test_llama_image_transformer_inference( intermediates = [i.squeeze(-1) for i in intermediates] reference_output = reference_output.reshape(batch, num_chunks, ntok + npad, dim) reference_output = encoder_utils.contract_num_tokens_from_mult8(reference_output, npad) - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) - # Check mse + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + all_tests_pass = all_tests_pass and passing logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") if return_intermediate: for idx, (pt_interm, tt_interm) in enumerate(zip(intermediates, tt_intermed_torch)): - passing, pcc_message = comp_pcc(pt_interm, tt_interm, pcc) + passing, pcc_message = comp_pcc(pt_interm, tt_interm, pcc_required) logger.info(f"Intermediate {idx}: {pcc_message}") logger.info(comp_allclose(pt_interm, tt_interm)) + all_tests_pass = all_tests_pass and passing - if passing: - logger.info(f"Llama_Attention Passed!") - else: - logger.warning(f"Llama_Attention Failed!") - all_tests_pass = False - - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py index 11b4c58939f..0dbb20bed56 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py @@ -23,12 +23,12 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], indirect=True, ) def test_llama_vision_transformer_inference(mesh_device, use_program_cache, reset_seeds): dtype = ttnn.bfloat16 - pcc = 0.79 + pcc_required = 0.79 model_args = TtModelArgs(mesh_device) state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) @@ -45,8 +45,6 @@ def test_llama_vision_transformer_inference(mesh_device, use_program_cache, rese reference_model = llama_reference_mod.CrossAttentionTransformerVision(model_args) reference_model.load_state_dict(partial_state_dict, strict=True) - all_tests_pass = True - tt_model = TtLlamaCrossAttentionTransformerVision( mesh_device, state_dict, @@ -72,19 +70,8 @@ def test_llama_vision_transformer_inference(mesh_device, use_program_cache, rese logger.info(f"Reference output shape: {reference_output.shape}") logger.info(f"TT output shape: {tt_output_torch.shape}") - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(pcc_message) - - if passing: - logger.info(f"Llama_Attention Passed!") - else: - logger.warning(f"Llama_Attention Failed!") - all_tests_pass = False - - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py b/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py index 65147f91fea..5a739393486 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py @@ -23,12 +23,12 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], indirect=True, ) def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_seeds): dtype = ttnn.bfloat16 - pcc = 0.88 + pcc_required = 0.88 model_args = TtModelArgs(mesh_device) state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) @@ -52,8 +52,6 @@ def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_se ) reference_model.load_state_dict(partial_state_dict, strict=True) - all_tests_pass = True - tt_model = TtLlamaVisionEncoder( mesh_device, state_dict, @@ -75,45 +73,22 @@ def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_se tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) tt_output_torch = tt_output_torch[0, :, :, :].view(reference_output.shape) - INTERM_OUT = True - if INTERM_OUT: - # reference_output is [x] + [shuffled_int_x] - # tt_output is [x] + [int_x] - # To compare, we will shuffle tt_output. NOTE! This requires that the vision model shuffle its projection weights - tt_output_shuffled = torch.zeros_like(tt_output_torch) - tt_output_shuffled[..., : model_args.vision_dim] = tt_output_torch[..., : model_args.vision_dim] - tt_int_x = tt_output_torch[..., model_args.vision_dim :] - tt_int_x = ( - tt_int_x.reshape(reference_output.shape[:-1] + (5, model_args.vision_dim)) - .transpose(-1, -2) - .reshape(reference_output.shape[:-1] + (model_args.vision_dim * 5,)) - ) - tt_output_shuffled[..., model_args.vision_dim :] = tt_int_x - - logger.info(f"Reference output shape: {reference_output.shape}") - logger.info(f"TT output shape: {tt_output_shuffled.shape}") - - passing, pcc_message = comp_pcc(reference_output, tt_output_shuffled, pcc) - - logger.info(comp_allclose(reference_output, tt_output_shuffled)) - logger.info(f"PCC: {pcc_message}") - else: - logger.info(f"Reference output shape: {reference_output.shape}") - logger.info(f"TT output shape: {tt_output_torch.shape}") - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(pcc_message) - - if passing: - logger.info(f"Llama_Attention Passed!") - else: - logger.warning(f"Llama_Attention Failed!") - all_tests_pass = False - - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + # reference_output is [x] + [shuffled_int_x] + # tt_output is [x] + [int_x] + # To compare, we will shuffle tt_output. + tt_output_shuffled = torch.zeros_like(tt_output_torch) + tt_output_shuffled[..., : model_args.vision_dim] = tt_output_torch[..., : model_args.vision_dim] + tt_int_x = tt_output_torch[..., model_args.vision_dim :] + tt_int_x = ( + tt_int_x.reshape(reference_output.shape[:-1] + (5, model_args.vision_dim)) + .transpose(-1, -2) + .reshape(reference_output.shape[:-1] + (model_args.vision_dim * 5,)) + ) + tt_output_shuffled[..., model_args.vision_dim :] = tt_int_x + + passing, pcc_message = comp_pcc(reference_output, tt_output_shuffled, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_shuffled)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py index cdabc148528..9c0efb1b2d4 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py +++ b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py @@ -96,9 +96,4 @@ def test_layernorm_inference(mesh_device, seq_len, use_program_cache, reset_seed logger.info(comp_allclose(reference_output, tt_output)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info("LayerNorm on device {idx} Passed!") - else: - logger.warning("LayerNorm {idx} Failed!") - - assert passing, f"LayerNorm output does not meet PCC requirement {pcc_required}: {pcc_message}." + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py index 5aa233f39db..75ab5b0bbf1 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py @@ -76,7 +76,7 @@ def forward(self, x, ar): @pytest.mark.parametrize( "mesh_device", [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get( + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) ) ], @@ -114,7 +114,7 @@ def test_llama_positional_embedding_inference( ensure_gc, ): dtype = ttnn.bfloat16 - pcc = 0.9999 + pcc_required = 0.9999 mesh_device.enable_async(True) @@ -196,13 +196,8 @@ def test_llama_positional_embedding_inference( # Only select output from one device tt_output_torch = tt_output_torch[..., :dim] - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info(f"Llama_PositionalEmbedding Passed!") - else: - logger.warning(f"Llama_PositionalEmbedding Failed!") - assert passing, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py index 7ffa7943d1e..9d3230541e7 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py @@ -38,7 +38,7 @@ @pytest.mark.parametrize( "mesh_device", [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get( + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) ) ], @@ -81,7 +81,7 @@ def test_llama_conv2d_inference( ensure_gc, ): dtype = ttnn.bfloat16 - pcc = 0.9999 + pcc_required = 0.9999 mesh_device.enable_async(True) @@ -156,13 +156,8 @@ def test_llama_conv2d_inference( # Only select output from one device tt_output_torch = tt_output_torch[..., :dim] - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info(f"Llama_TilePositionEmbedding Passed!") - else: - logger.warning(f"Llama_TilePositionEmbedding Failed!") - assert passing, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index ad5173533cb..699a323f220 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -500,9 +500,7 @@ def find_largest_divisor(n, max_divisor=8): self.model_config["CROSS_TRANSFORMER_TEXT_OUTPUT_PROGCFG"] = lambda seq_len, max_seq: self.matmul_config( m=min(seq_len, max_seq), k=self.dim, - n=self.vocab_size - // 4 - // self.num_devices, # TODO: Remove magic number 8 from cross attention transformer text + n=self.vocab_size // 8, # Magic number. LM Head always contains 8 splits grid_size=(8, 8), in0_block_w=1, fuse_batch=seq_len <= max_seq, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index bf7a8ee7575..7128609c04a 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -156,12 +156,26 @@ def compute_xattn_kv_cache(self, xattn_tokens): xk = ttnn.reshape(xk, [1, bsz, seqlen_y, -1]) xv = ttnn.reshape(xv, [1, bsz, seqlen_y, -1]) - xk, _, _ = ttnn.experimental.nlp_create_qkv_heads( - xk, xk, num_heads=self.n_local_kv_heads, num_kv_heads=self.n_local_kv_heads // 2, transpose_k_heads=False - ) - xv, _, _ = ttnn.experimental.nlp_create_qkv_heads( - xv, xv, num_heads=self.n_local_kv_heads, num_kv_heads=self.n_local_kv_heads // 2, transpose_k_heads=False - ) + if self.n_local_kv_heads == 1: + # Only a simple reshape required, no need to split + xk = ttnn.reshape(xk, [bsz, 1, seqlen_y, -1]) + xv = ttnn.reshape(xv, [bsz, 1, seqlen_y, -1]) + else: + # 1, B, S, D -> B, NH, S, DH? + xk, _, _ = ttnn.experimental.nlp_create_qkv_heads( + xk, + xk, + num_heads=self.n_local_kv_heads, + num_kv_heads=self.n_local_kv_heads // 2, + transpose_k_heads=False, + ) + xv, _, _ = ttnn.experimental.nlp_create_qkv_heads( + xv, + xv, + num_heads=self.n_local_kv_heads, + num_kv_heads=self.n_local_kv_heads // 2, + transpose_k_heads=False, + ) xk = self.k_norm(xk) diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py index 777a19d3b61..c185d89208c 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py @@ -95,7 +95,8 @@ def __init__( # ) lm_head_torch = self.state_dict[f"{state_dict_prefix}output.weight"].transpose(-1, -2) - num_splits = 4 # arbitrary, reasonable number + total_splits = 8 # Arbitrary value which allows whole-tile splits in LM Head + num_splits = total_splits // self.mesh_device.num_devices lm_head_torch = torch.chunk(lm_head_torch, num_splits, dim=-1) cache_name = lambda name, suffix, split: weight_cache_path / (state_dict_prefix + f"{name}{suffix}{split}") diff --git a/models/demos/llama3/tt/multimodal/llama_image_attention.py b/models/demos/llama3/tt/multimodal/llama_image_attention.py index 20da869640c..6600258986b 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_image_attention.py @@ -273,9 +273,7 @@ def forward_tt(self, x_11SH, mask=None): ttnn.deallocate(xqkv_fused) # sdpa_cfg = self.model_config["SDPA_PROGCFG"](seq_len) sdpa_cfg = ttnn.SDPAProgramConfig( - compute_with_storage_grid_size=(8, 8), - q_chunk_size=128, - k_chunk_size=128, + compute_with_storage_grid_size=(8, 8), q_chunk_size=128, k_chunk_size=128, exp_approx_mode=False ) attn_output_1QSD = ttnn.transformer.scaled_dot_product_attention( q_heads_1QSD, From f11162cbd90840e53d280072e2ff8f388e3569bd Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Tue, 22 Oct 2024 12:37:34 +0000 Subject: [PATCH 07/16] #13368: Fix LM head splits calculation --- .../llama3/tests/multimodal/test_llama_class_embedding.py | 4 +--- .../demos/llama3/tests/multimodal/test_llama_conv2d_patch.py | 3 +-- .../tt/multimodal/llama_cross_attention_transformer_text.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py index 09aaa82d16c..fcdf64003e0 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py @@ -65,9 +65,7 @@ def forward(self, x): @pytest.mark.parametrize( "input_shape", [ - ((1, 4, 4, 1024, 1280)), - ((1, 4, 4, 1024 + 1, 1280)), - ((1, 4, 4, 1032, 1280)), + ((1, 4, 4, 1024, 1280)), # Patch 448 ], ) @pytest.mark.parametrize( diff --git a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py index 12de8c159f1..41e50345745 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py +++ b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py @@ -45,8 +45,7 @@ @pytest.mark.parametrize( "input_shape, in_channels, out_channels, kernel_size, stride, bias", [ - # ((1, 3, 32 * 32, 32 * 32), 3, 512, 32, 32, False), - ((1, 3, 14 * 32, 14 * 32), 3, 1280, 14, 14, False), # Llama3.2 case + ((1, 3, 448, 448), 3, 1280, 14, 14, False), # Llama3.2-11B Base ], ) def test_llama_conv2d_inference( diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py index c185d89208c..34dd88d19c1 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py @@ -96,7 +96,7 @@ def __init__( lm_head_torch = self.state_dict[f"{state_dict_prefix}output.weight"].transpose(-1, -2) total_splits = 8 # Arbitrary value which allows whole-tile splits in LM Head - num_splits = total_splits // self.mesh_device.num_devices + num_splits = total_splits // self.configuration.num_devices lm_head_torch = torch.chunk(lm_head_torch, num_splits, dim=-1) cache_name = lambda name, suffix, split: weight_cache_path / (state_dict_prefix + f"{name}{suffix}{split}") From f64f65ad23af874f91cd9d5f45db6b6dcd62f045 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Wed, 23 Oct 2024 13:13:49 +0000 Subject: [PATCH 08/16] #13368: For all vision tests, get model-specific parameters from model_args rather than fixtures. This generalizes tests for base and finetuned 11B models. --- .../demos/llama3/demo/multimodal_demo_chat.py | 108 ++++++++++++ .../demos/llama3/demo/multimodal_demo_text.py | 2 +- models/demos/llama3/lt | 8 +- .../multimodal/test_llama_class_embedding.py | 19 +-- .../multimodal/test_llama_conv2d_patch.py | 29 +--- .../multimodal/test_llama_cross_attention.py | 11 +- ..._llama_cross_attention_transformer_text.py | 9 +- ...ama_cross_attention_transformer_vision.py} | 8 +- .../multimodal/test_llama_cross_block.py | 13 +- .../multimodal/test_llama_image_attention.py | 17 +- .../multimodal/test_llama_image_block.py | 13 +- .../tests/multimodal/test_llama_image_mlp.py | 14 +- .../test_llama_image_transformer.py | 10 +- .../tests/multimodal/test_llama_layernorm.py | 9 +- .../test_llama_positional_embedding.py | 40 ++--- .../test_llama_tile_position_embedding.py | 44 ++--- ...ncoder.py => test_llama_vision_encoder.py} | 2 +- .../multimodal/test_llama_vision_model.py | 154 ++++++++++++++++++ models/demos/llama3/tt/model_config.py | 13 +- .../tt/multimodal/llama_conv2d_patch.py | 1 - .../tt/multimodal/llama_cross_attention.py | 36 +--- .../llama_cross_attention_transformer_text.py | 22 --- ...ama_cross_attention_transformer_vision.py} | 2 +- .../tt/multimodal/llama_image_attention.py | 145 +---------------- .../llama3/tt/multimodal/llama_image_block.py | 41 +---- .../llama3/tt/multimodal/llama_image_mlp.py | 59 +------ .../llama3/tt/multimodal/llama_layernorm.py | 41 +---- .../llama_tile_position_embedding.py | 1 - ...ion_encoder.py => llama_vision_encoder.py} | 10 +- .../tt/multimodal/llama_vision_model.py | 14 +- 30 files changed, 407 insertions(+), 488 deletions(-) rename models/demos/llama3/tests/multimodal/{test_llama_image_transformer_vision.py => test_llama_cross_attention_transformer_vision.py} (91%) rename models/demos/llama3/tests/multimodal/{test_llama_image_vision_encoder.py => test_llama_vision_encoder.py} (97%) create mode 100644 models/demos/llama3/tests/multimodal/test_llama_vision_model.py rename models/demos/llama3/tt/multimodal/{llama_image_transformer_vision.py => llama_cross_attention_transformer_vision.py} (97%) rename models/demos/llama3/tt/multimodal/{llama_image_vision_encoder.py => llama_vision_encoder.py} (98%) diff --git a/models/demos/llama3/demo/multimodal_demo_chat.py b/models/demos/llama3/demo/multimodal_demo_chat.py index e69de29bb2d..05ee6c4159d 100644 --- a/models/demos/llama3/demo/multimodal_demo_chat.py +++ b/models/demos/llama3/demo/multimodal_demo_chat.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import Optional +from loguru import logger + +from PIL import Image as PIL_Image +from termcolor import cprint + +from models.demos.llama3.demo.multimodal_demo_text import create_multimodal_model +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.generation as llama_reference_generation + +from models.demos.llama3.reference.llama_models.models.llama3.api.datatypes import ImageMedia, UserMessage + +THIS_DIR = Path(__file__).parent.parent.resolve() / "reference/llama_models/models/scripts/" + +import torch +import pytest +import os +import ttnn + + +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "target", + ("tt", "cpu"), +) +@pytest.mark.parametrize( + "warmup_iters", + (0, 1), +) +def test_llama_multimodal_demo_chat( + mesh_device, + target, + warmup_iters, + temperature: float = 0.5, + top_p: float = 0.9, + max_seq_len: int = 512, + max_batch_size: int = 4, + max_gen_len: Optional[int] = 200, + model_parallel_size: Optional[int] = None, +): + mesh_device.enable_program_cache() + mesh_device.enable_async(True) + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'") + generator = llama_reference_generation.Llama.build( + ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + model_parallel_size=model_parallel_size, + ) + + if target == "tt": + logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices") + model = create_multimodal_model(generator.args, mesh_device) + generator.model = model + + # image understanding + dialogs = [] + with open(THIS_DIR / "resources/dog.jpg", "rb") as f: + img = PIL_Image.open(f).convert("RGB") + + dialogs = [ + [ + UserMessage( + content=[ + ImageMedia(image=img), + "Describe this image in two sentences", + ], + ) + ], + ] + # text only + dialogs += [ + [UserMessage(content="what is the recipe of mayonnaise in two sentences?")], + ] + + print(f"Running text completion on {target}") + for _ in range(warmup_iters + 1): + for dialog in dialogs: + result = generator.chat_completion( + dialog, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + + for msg in dialog: + print(f"{msg.role.capitalize()}: {msg.content}\n") + + out_message = result.generation + print(f"> {out_message.role.capitalize()}: {out_message.content}") + for t in out_message.tool_calls: + print(f" Tool call: {t.tool_name} ({t.arguments})") + print("\n==================================\n") diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py index 223d4dc96f3..f2eada1966c 100644 --- a/models/demos/llama3/demo/multimodal_demo_text.py +++ b/models/demos/llama3/demo/multimodal_demo_text.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - # SPDX-License-Identifier: Apache-2.0 + from pathlib import Path from typing import Optional from loguru import logger diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index be513f6e1bd..280751cf7b9 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -190,7 +190,7 @@ def main(stdscr): for c in commands for m in models for d in devices - if not (m == "11b" and d == "n150") + if not (m in ["11b", "11b-b"] and d == "n150") ] # Create output entries @@ -516,15 +516,16 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): "vision-mlp": "pytest models/demos/llama3/tests/multimodal/test_llama_image_mlp.py", "vision-attn": "pytest models/demos/llama3/tests/multimodal/test_llama_image_attention.py", "vision-block": "pytest models/demos/llama3/tests/multimodal/test_llama_image_block.py", + "vision-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_image_transformer.py", "vision-xattn": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention.py", "vision-xblock": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_block.py", "vision-conv": "pytest models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py", "vision-class": "pytest models/demos/llama3/tests/multimodal/test_llama_class_embedding.py", "vision-tile-pos": "pytest models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py", "vision-pos": "pytest models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py", - "vision-encoder": "pytest models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py", + "vision-encoder": "pytest models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py", "vision-text-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py", - "vision-vision-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py", + "vision-vision-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py", } # Check if the command is a shortcut and replace it if necessary @@ -676,6 +677,7 @@ def get_llama_dir(model): "3b": os.environ.get("LLAMA_32_3B_DIR", "/proj_sw/user_dev/llama32-data/Llama3.2-3B-Instruct"), "8b": os.environ.get("LLAMA_31_8B_DIR", "/proj_sw/user_dev/llama31-8b-data/Meta-Llama-3.1-8B-Instruct"), "11b": os.environ.get("LLAMA_32_11B_DIR", "/proj_sw/user_dev/llama32-data/Llama3.2-11B-Vision-Instruct"), + "11b-b": os.environ.get("LLAMA_32_11B_BASE_DIR", "/proj_sw/user_dev/llama32-data/Llama3.2-11B-Vision"), }.get(model.lower(), "") if not llama_dir or not os.path.exists(llama_dir): diff --git a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py index fcdf64003e0..663787a18d1 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py @@ -3,11 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 ##### Python imports ##### -import math import pytest from loguru import logger import os -import itertools ##### PyTorch imports ##### import torch @@ -63,9 +61,9 @@ def forward(self, x): indirect=True, ) @pytest.mark.parametrize( - "input_shape", + "bsz, num_concurrent_media, num_chunks", [ - ((1, 4, 4, 1024, 1280)), # Patch 448 + ((1, 4, 4)), ], ) @pytest.mark.parametrize( @@ -79,7 +77,9 @@ def test_llama_class_embedding_inference( use_program_cache, reset_seeds, # Input params - input_shape, + bsz, + num_concurrent_media, + num_chunks, layout, ensure_gc, ): @@ -95,13 +95,8 @@ def test_llama_class_embedding_inference( k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - ( - bsz, - num_concurrent_media, - num_chunks, - ntok, - dim, - ) = input_shape + ntok = nearest_32(model_args.vision_chunk_ntok) + dim = model_args.vision_dim ##### Prepare inputs ##### input_tensor = torch.randn(bsz * num_concurrent_media * num_chunks, ntok, dim) diff --git a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py index 41e50345745..d98d1c8613e 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py +++ b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 ##### Python imports ##### -import math import pytest from loguru import logger import os @@ -21,9 +20,6 @@ comp_pcc, comp_allclose, ) -from models.utility_functions import ( - nearest_32, -) from models.demos.llama3.tt.multimodal.llama_conv2d_patch import ( TtLlamaConv2dPatch, ) @@ -42,24 +38,10 @@ ], indirect=True, ) -@pytest.mark.parametrize( - "input_shape, in_channels, out_channels, kernel_size, stride, bias", - [ - ((1, 3, 448, 448), 3, 1280, 14, 14, False), # Llama3.2-11B Base - ], -) def test_llama_conv2d_inference( mesh_device, use_program_cache, reset_seeds, - # Input params - input_shape, - # Conv2d patch params - in_channels, - out_channels, - kernel_size, - stride, - bias, ensure_gc, ): pcc_required = 0.9999 @@ -78,7 +60,14 @@ def test_llama_conv2d_inference( num_devices = model_args.num_devices ##### Create input tensor for the all gather ##### - B, NCH, H, W = input_shape + B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size) + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + model_args.vision_dim, + model_args.vision_patch_size, + model_args.vision_patch_size, + False, + ) assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch." assert type(kernel_size) == int, "Only symmetric kernel_size is currently supported." @@ -88,7 +77,7 @@ def test_llama_conv2d_inference( assert W % kernel_size == 0, "Width should be divisible by kernel_size." ##### Prepare inputs ##### - input_tensor = torch.randn(input_shape) + input_tensor = torch.randn((B, NCH, H, W)) logger.info(f"Input tensor shape: {input_tensor.shape}") ##### Perform the torch ops ##### diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index 434da78df59..ba0e269480f 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -17,15 +17,12 @@ from models.utility_functions import ( comp_pcc, comp_allclose, + nearest_32, ) from models.utility_functions import skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "vision_seq_len", - (4224,), -) @pytest.mark.parametrize( "text_seq_len", (2048,), @@ -39,9 +36,7 @@ ], indirect=True, ) -def test_llama_cross_attention_inference( - vision_seq_len, text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc -): +def test_llama_cross_attention_inference(text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 pcc_required = 0.99 @@ -67,6 +62,8 @@ def test_llama_cross_attention_inference( reference_model.load_state_dict(partial_state_dict) batch = 1 + num_chunks = 4 + vision_seq_len = num_chunks * nearest_32(model_args.vision_chunk_ntok) all_tests_pass = True diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index a0a21debe64..f11165862b6 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -23,15 +23,12 @@ from models.utility_functions import ( comp_pcc, comp_allclose, + nearest_32, ) from models.utility_functions import skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "vision_seq_len", - (4224,), -) @pytest.mark.parametrize( "text_seq_len", (2048,), @@ -46,7 +43,6 @@ indirect=True, ) def test_llama_cross_attention_transformer_text_inference( - vision_seq_len, text_seq_len, mesh_device, use_program_cache, @@ -90,6 +86,9 @@ def test_llama_cross_attention_transformer_text_inference( reference_model.load_state_dict(partial_state_dict) batch = 1 + num_chunks = 4 + chunk_length = nearest_32(model_args.vision_chunk_ntok) + vision_seq_len = num_chunks * chunk_length all_tests_pass = True diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py similarity index 91% rename from models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py rename to models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py index 0dbb20bed56..6555321578a 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py @@ -8,7 +8,9 @@ import ttnn import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod -from models.demos.llama3.tt.multimodal.llama_image_transformer_vision import TtLlamaCrossAttentionTransformerVision +from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_vision import ( + TtLlamaCrossAttentionTransformerVision, +) from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, @@ -56,8 +58,8 @@ def test_llama_vision_transformer_inference(mesh_device, use_program_cache, rese ) # Create rand inputs of the right shape - batch, num_media, num_chunks, n_channel, patch_size = (1, 1, 4, 3, 448) - chunk_seq_len = (patch_size // model_args.vision_patch_size) ** 2 + 1 # tokens per chunk + 1 class token + batch, num_media, num_chunks, n_channel, patch_size = (1, 1, 4, 3, model_args.vision_chunk_size) + chunk_seq_len = model_args.vision_chunk_ntok - 1 # tokens per chunk without class token images = torch.randn(batch, num_media, num_chunks, n_channel, patch_size, patch_size) ars = torch.tensor([2, 2]).reshape(batch, num_media, 2) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 78020e80358..f45f0eaa432 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -14,18 +14,11 @@ prepare_inputs_ttnn_prefill, prepare_inputs_ttnn, ) -from models.utility_functions import ( - comp_pcc, - comp_allclose, -) +from models.utility_functions import comp_pcc, comp_allclose, nearest_32 from models.utility_functions import skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "vision_seq_len", - (4224,), -) @pytest.mark.parametrize( "text_seq_len", (2048,), @@ -40,7 +33,7 @@ indirect=True, ) def test_llama_cross_attention_transformer_block_inference( - vision_seq_len, text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc + text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc ): dtype = ttnn.bfloat16 pcc_required = 0.99 @@ -65,6 +58,8 @@ def test_llama_cross_attention_transformer_block_inference( reference_model.load_state_dict(partial_state_dict) batch = 1 + num_chunks = 4 + vision_seq_len = num_chunks * nearest_32(model_args.vision_chunk_ntok) all_tests_pass = True diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py index e0a14b69e7b..357f02a5b10 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py @@ -10,7 +10,7 @@ import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_attention import TtLlamaImageAttention -from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import pad_seq_one_tile, mask_tile_padding +from models.demos.llama3.tt.multimodal.llama_vision_encoder import pad_seq_one_tile, mask_tile_padding from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, @@ -24,8 +24,8 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( - "batch, num_chunks, ntok", - ((1, 4, 1024),), + "batch, num_chunks", + ((1, 4),), ) @pytest.mark.parametrize( "mesh_device", @@ -36,7 +36,7 @@ ], indirect=True, ) -def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_program_cache, reset_seeds, ensure_gc): +def test_llama_attention_inference(batch, num_chunks, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 pcc_required = 0.99 @@ -51,8 +51,9 @@ def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_pro k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - dim = 1280 - heads = 16 + dim = model_args.vision_dim + heads = model_args.vision_attn_n_heads + ntok = model_args.vision_chunk_ntok reference_model = llama_reference_mod.ImageAttention(dim=dim, head_dim=dim // heads, n_heads=heads) reference_model.load_state_dict(partial_state_dict) @@ -88,7 +89,7 @@ def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_pro tt_attn_mask = encoder_utils.build_encoder_attention_mask(fake_x, ar, ntok, num_chunks, 1) # Make striped attention mask to mask out our padding between 8 and 32 # Striped mask doesn't affect PCC on first layer but is necessary for later layers - tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, 32, num_chunks) + tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, npadtt, num_chunks) attention_input = attention_input.reshape(1, batch, -1, dim) @@ -104,7 +105,7 @@ def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_pro tt_out = tt_model(attention_input, mask=tt_mask) # Doing contract in tt is correct!! - tt_out = tt_out.reshape(batch, num_chunks, ntok + 32, dim) + tt_out = tt_out.reshape(batch, num_chunks, ntok + npadtt, dim) tt_out = ttnn.slice(tt_out, (0, 0, 0, 0), (batch, num_chunks, ntok, dim)) tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :] diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py index 17b3c219446..613fd2a3021 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py @@ -10,7 +10,7 @@ import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_block import TtLlamaImageTransformerBlock -from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import pad_seq_one_tile, mask_tile_padding +from models.demos.llama3.tt.multimodal.llama_vision_encoder import pad_seq_one_tile, mask_tile_padding from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, @@ -24,8 +24,8 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( - "batch, num_chunks, ntok", - ((1, 4, 1024),), + "batch, num_chunks", + ((1, 4),), ) @pytest.mark.parametrize( "gated", @@ -36,7 +36,7 @@ [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], indirect=True, ) -def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_program_cache, reset_seeds, ensure_gc): +def test_llama_block_inference(batch, num_chunks, mesh_device, gated, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 pcc_required = 0.99 @@ -56,6 +56,7 @@ def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_ dim = model_args.vision_dim heads = model_args.vision_attn_n_heads + ntok = model_args.vision_chunk_ntok reference_model = llama_reference_mod.ImageTransformerBlock( d_model=dim, n_head=heads, mlp_ratio=model_args.vision_mlp_ratio, gated=gated ) @@ -92,7 +93,7 @@ def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_ attention_input.shape[0], attention_input.shape[1], attention_input.shape[2], attention_input.shape[3] ) tt_attn_mask = encoder_utils.build_encoder_attention_mask(fake_x, ar, ntok, num_chunks, 1) - tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, 32, num_chunks) + tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, npadtt, num_chunks) attention_input = attention_input.reshape(1, batch, -1, dim) tt_mask = ttnn.from_torch( @@ -105,7 +106,7 @@ def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_ ) tt_out = tt_model(attention_input, mask=tt_mask) - tt_out = tt_out.reshape(batch, num_chunks, ntok + 32, dim) + tt_out = tt_out.reshape(batch, num_chunks, ntok + npadtt, dim) tt_out = ttnn.slice(tt_out, (0, 0, 0, 0), (batch, num_chunks, ntok, dim)) tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :] diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py index 958bfb919ca..4181f9dfd0c 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py @@ -14,14 +14,15 @@ from models.utility_functions import ( comp_pcc, comp_allclose, + nearest_32, ) from models.utility_functions import skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( - "seq_len", - (4224,), + "batch, num_chunks", + ((1, 4),), ) @pytest.mark.parametrize( "mesh_device", @@ -32,7 +33,7 @@ ], indirect=True, ) -def test_llama_mlp_inference(mesh_device, seq_len, use_program_cache, reset_seeds, ensure_gc): +def test_llama_mlp_inference(batch, num_chunks, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 mesh_device.enable_async(True) @@ -48,8 +49,9 @@ def test_llama_mlp_inference(mesh_device, seq_len, use_program_cache, reset_seed model_args.WEIGHTS_DTYPE = dtype - dim = 1280 - mlp_ratio = 4.0 + dim = model_args.vision_dim + seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks + mlp_ratio = model_args.vision_mlp_ratio act_layer = torch.nn.GELU dropout = 0.0 reference_model = llama_reference_mod.ImageFeedForward( @@ -68,7 +70,7 @@ def test_llama_mlp_inference(mesh_device, seq_len, use_program_cache, reset_seed weight_cache_path=model_args.weight_cache_path(dtype), dtype=dtype, ) - torch_input = torch.randn(1, 1, seq_len, dim) + torch_input = torch.randn(1, batch, seq_len, dim) reference_output = reference_model(torch_input).squeeze() tt_input = ttnn.from_torch( torch_input, diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py index feffe9e4472..1fee8a125c4 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py @@ -11,7 +11,7 @@ from models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_transformer import TtLlamaImageTransformer from models.demos.llama3.tt.model_config import TtModelArgs -from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import pad_seq_one_tile, mask_tile_padding +from models.demos.llama3.tt.multimodal.llama_vision_encoder import pad_seq_one_tile, mask_tile_padding from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, ) @@ -24,8 +24,8 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( - "batch, num_chunks, ntok", - ((1, 4, 1024),), + "batch, num_chunks", + ((1, 4),), ) @pytest.mark.parametrize( "is_global", @@ -37,7 +37,7 @@ indirect=True, ) def test_llama_image_transformer_inference( - batch, num_chunks, ntok, mesh_device, is_global, use_program_cache, reset_seeds, ensure_gc + batch, num_chunks, mesh_device, is_global, use_program_cache, reset_seeds, ensure_gc ): dtype = ttnn.bfloat16 pcc_required = 0.86 @@ -66,7 +66,7 @@ def test_llama_image_transformer_inference( } dim = model_args.vision_dim - heads = model_args.vision_attn_n_heads + ntok = model_args.vision_chunk_ntok - 1 # NOTE: -1 to remove class embedding reference_model = llama_reference_mod.VisionEncoder( max_num_tiles=4, diff --git a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py index 9c0efb1b2d4..23294782651 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py +++ b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py @@ -14,15 +14,12 @@ from models.utility_functions import ( comp_pcc, comp_allclose, + nearest_32, ) from models.utility_functions import skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "seq_len", - (4224,), -) @pytest.mark.parametrize( "mesh_device", [ @@ -32,13 +29,15 @@ ], indirect=True, ) -def test_layernorm_inference(mesh_device, seq_len, use_program_cache, reset_seeds, ensure_gc): +def test_layernorm_inference(mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 mesh_device.enable_async(True) model_args = TtModelArgs(mesh_device) width = model_args.vision_dim + num_chunks = 4 + seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) # Ref model needs partial state dict, but our models use full state dict keys as cached weight names diff --git a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py index 75ab5b0bbf1..c5262bf2235 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py @@ -83,37 +83,21 @@ def forward(self, x, ar): indirect=True, ) @pytest.mark.parametrize( - "image_size, patch_size", - [ - ((448, 448), (14, 14)), - ], -) -@pytest.mark.parametrize( - "input_shape, max_num_tiles", - [ - ((1, 4, 4, 1024 + 1, 1280), 4), - ], -) -@pytest.mark.parametrize( - "layout", - [ - ttnn.TILE_LAYOUT, - ], + "bsz, num_concurrent_media, num_chunks", + [(1, 4, 4)], ) def test_llama_positional_embedding_inference( mesh_device, use_program_cache, reset_seeds, # Input params - input_shape, - layout, - # Positional Embedding params - image_size, - patch_size, - max_num_tiles, + bsz, + num_concurrent_media, + num_chunks, ensure_gc, ): dtype = ttnn.bfloat16 + layout = ttnn.TILE_LAYOUT pcc_required = 0.9999 mesh_device.enable_async(True) @@ -125,15 +109,13 @@ def test_llama_positional_embedding_inference( k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - ( - bsz, - num_concurrent_media, - num_chunks, - ntok, - dim, - ) = input_shape + ntok = model_args.vision_chunk_ntok + dim = model_args.vision_dim + image_size = (model_args.vision_chunk_size, model_args.vision_chunk_size) + patch_size = (model_args.vision_patch_size, model_args.vision_patch_size) ##### Check parms ##### + max_num_tiles = model_args.vision_max_num_chunks assert num_chunks == max_num_tiles, "num_chunks must be the same value as max_num_tiles!" ##### Prepare inputs ##### diff --git a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py index 9d3230541e7..2249b684cbd 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py @@ -45,42 +45,27 @@ indirect=True, ) @pytest.mark.parametrize( - "gated", + "bsz, num_concurrent_media, num_chunks", [ - True, - ], -) -@pytest.mark.parametrize( - "input_shape, dim, max_num_tiles", - [ - ((1, 32, 4, 1032), 1280, 4), - ((1, 8, 4, 1032), 1280, 4), - ((1, 4, 4, 1032), 1280, 4), - ((1, 1, 4, 1032), 1280, 4), - ((1, 1, 4, 1024), 1280, 4), - # ((1, 32, 16, 1032), 1280, 16), # Large test, takes some time - ], -) -@pytest.mark.parametrize( - "layout", - [ - ttnn.TILE_LAYOUT, + (1, 1, 4), + (1, 4, 4), ], ) +@pytest.mark.parametrize("pre_embed", [False, True]) def test_llama_conv2d_inference( mesh_device, use_program_cache, reset_seeds, # Input params - input_shape, - layout, - # Tile Position Embedding params - dim, - gated, - max_num_tiles, + bsz, + num_concurrent_media, + num_chunks, + pre_embed, ensure_gc, ): dtype = ttnn.bfloat16 + layout = ttnn.TILE_LAYOUT + gated = True pcc_required = 0.9999 mesh_device.enable_async(True) @@ -89,13 +74,16 @@ def test_llama_conv2d_inference( state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - first_layer_prefix = "vision_model.vision_encoder.pre_tile_pos_embed." + first_layer_prefix = "vision_model.vision_encoder." + ( + "pre_tile_pos_embed." if pre_embed else "post_tile_pos_embed." + ) partial_state_dict = { k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - num_devices = model_args.num_devices - bsz, num_concurrent_media, num_chunks, ntok = input_shape + ntok = nearest_32(model_args.vision_chunk_ntok - (0 if pre_embed else 1)) + dim = model_args.vision_dim + max_num_tiles = model_args.vision_max_num_tiles ##### Check parms ##### assert num_chunks == max_num_tiles, "num_chunks must be the same value as max_num_tiles!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py similarity index 97% rename from models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py rename to models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py index 5a739393486..111c584f781 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py +++ b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py @@ -8,7 +8,7 @@ import ttnn import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod -from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import TtLlamaVisionEncoder +from models.demos.llama3.tt.multimodal.llama_vision_encoder import TtLlamaVisionEncoder from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, diff --git a/models/demos/llama3/tests/multimodal/test_llama_vision_model.py b/models/demos/llama3/tests/multimodal/test_llama_vision_model.py new file mode 100644 index 00000000000..e39555d4be1 --- /dev/null +++ b/models/demos/llama3/tests/multimodal/test_llama_vision_model.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path +from typing import Optional +from loguru import logger + +from PIL import Image as PIL_Image +from termcolor import cprint + +import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.generation as llama_reference_generation + +from models.demos.llama3.reference.llama_models.models.llama3.api.datatypes import ImageMedia + +from models.utility_functions import ( + comp_pcc, + comp_allclose, +) + +THIS_DIR = Path(__file__).parent.parent.parent.resolve() / "reference/llama_models/models/scripts/" + +import torch +import pytest +import os +import ttnn + + +def create_multimodal_model(model_args, mesh_device, dtype=ttnn.bfloat16): + from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer + from models.demos.llama3.tt.model_config import TtModelArgs + + tt_model_args = TtModelArgs(mesh_device) + checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True) + model = CrossAttentionTransformer( + model_args, + mesh_device, + checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=tt_model_args, + ) + model.setup_cache(model_args.max_batch_size, torch.float32) + return model + + +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_llama_vision_model( + mesh_device, + temperature: float = 0, + max_seq_len: int = 512, + max_batch_size: int = 4, + max_gen_len: Optional[int] = 50, + model_parallel_size: Optional[int] = None, +): + """ + This test runs the Llama3.2 vision model on CPU and TT concurrently. + It does not use teacher forcing and compares output logits at every token. + """ + mesh_device.enable_program_cache() + mesh_device.enable_async(True) + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'") + generator_pt = llama_reference_generation.Llama.build( + ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + model_parallel_size=model_parallel_size, + ) + + generator_tt = llama_reference_generation.Llama(generator_pt.model, generator_pt.tokenizer, generator_pt.args) + logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices") + model = create_multimodal_model(generator_tt.args, mesh_device) + generator_tt.model = model + + # with open(THIS_DIR / "resources/dog.jpg", "rb") as f: + # img = PIL_Image.open(f).convert("RGB") + + # with open(THIS_DIR / "resources/pasta.jpeg", "rb") as f: + # img2 = PIL_Image.open(f).convert("RGB") + + with open(THIS_DIR / "resources/ocr_image.jpeg", "rb") as f: + ocr_image = PIL_Image.open(f).convert("RGB") + + # with open(THIS_DIR / "resources/clutter.jpeg", "rb") as f: + # clutter = PIL_Image.open(f).convert("RGB") + + interleaved_contents = [ + # text only + # "The color of the sky is blue but sometimes it can also be", + # image understanding + # [ImageMedia(image=img), "If I had to write a haiku for this one"], + # [ImageMedia(image=img2), "Couting the number of individual spaghetti strands in this image"], + [ImageMedia(image=ocr_image), "The full text in this image is as follows"], + # [ImageMedia(image=clutter), "The count of vases, books, and miscellaneous items in this image is"], + ] + + for content in interleaved_contents: + logger.info(f"Generating text for content: {content}") + model_input = generator_pt.formatter.encode_content(content) + gen_pt = generator_pt.generate( + model_input, max_gen_len=max_gen_len, temperature=temperature, return_logits=True + ) + gen_tt = generator_tt.generate( + model_input, max_gen_len=max_gen_len, temperature=temperature, return_logits=True + ) + + for out_idx, (token_pt, token_tt) in enumerate(zip(gen_pt, gen_tt)): + logger.info(f"Comparing output token {out_idx}") + out_pt, out_tt = token_pt[1], token_tt[1] + out_pt = out_pt[0, -1] + out_tt = out_tt[0, -1] + passing, pcc_message = comp_pcc(out_pt, out_tt, 0.90) + print(f"PCC: {pcc_message}") + # Check shapes of logprobs + + ref_argmax = torch.argmax(out_pt).item() + ref_logprob = out_pt[ref_argmax].item() + ref_token = generator_pt.tokenizer.decode([ref_argmax]) + + # Reference model: top-5 tokens + ref_top5_vals, ref_top5_idxs = torch.topk(out_pt, 5) + ref_top5_tokens = [generator_pt.tokenizer.decode([idx.item()]) for idx in ref_top5_idxs] + ref_top5_logprobs = ref_top5_vals.tolist() + + # Test model: top-5 tokens + top5_vals, top5_idxs = torch.topk(out_tt, 5) + top5_tokens = [generator_pt.tokenizer.decode([idx.item()]) for idx in top5_idxs] + top5_logprobs = top5_vals.tolist() + + def entropy(logits): + probs = torch.softmax(logits, dim=-1) + return -(probs * torch.log(probs)).sum().item() + + # Print the information + print(f"Token Position {out_idx}:") + print(f" Reference | Test") + print(f" Entropy: {entropy(out_pt):.4f} | {entropy(out_tt):.4f}") + print(f" Top-5 Tokens:") + for rank in range(5): + print( + f" {rank+1}. Token='{ref_top5_tokens[rank]}' @ {ref_top5_logprobs[rank]:.2f} | '{top5_tokens[rank]}' @ {top5_logprobs[rank]:.2f}" + ) + print() diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 699a323f220..30f0055c497 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -208,7 +208,7 @@ def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_s self.compute_kernel_config_sdpa = ttnn.WormholeComputeKernelConfig( math_fidelity=ttnn.MathFidelity.HiFi4, math_approx_mode=False, - fp32_dest_acc_en=False, + fp32_dest_acc_en=True, packer_l1_acc=False, ) @@ -506,6 +506,8 @@ def find_largest_divisor(n, max_divisor=8): fuse_batch=seq_len <= max_seq, ) + self.VISION_MAX_MM_SEQ = nearest_32(self.vision_chunk_ntok) + def _set_llama_params_from_dict(self, params): # Text params self.dim = params["dim"] @@ -533,13 +535,20 @@ def _set_llama_params_from_dict(self, params): self.vision_act_layer = ttnn.UnaryOpType.GELU self.vision_dropout = 0.0 self.vision_attn_n_heads = 16 - self.vision_head_dim = self.vision_hidden_dim // self.vision_attn_n_heads + self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads self.vision_n_layers = 32 self.vision_n_global_layers = 8 self.vision_max_num_tiles = 4 self.vision_patch_size = 14 self.vision_in_channels = 3 + @property + def vision_chunk_ntok(self): + """ + Returns the number of tokens per chunk, accounting for the extra class token + """ + return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1 + def _set_llama_params(self, checkpoint_dir): params_file = os.path.join(checkpoint_dir, "params.json") assert os.path.exists(params_file), f"params.json file not found at {params_file}" diff --git a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py index 668edf2c001..a4d1bb59885 100644 --- a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py +++ b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py @@ -85,7 +85,6 @@ def __init__( fp32_dest_acc_en=True, packer_l1_acc=True, ) - self.program_config = None # TODO: Update with actual program config def forward(self, x: torch.Tensor): x = self._unfold(x) diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index 7128609c04a..a1554b0dfe7 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -47,7 +47,8 @@ def __init__( self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 - self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa + + self.configuration = configuration self.model_config = configuration.get_model_config() @@ -131,7 +132,7 @@ def __init__( def compute_xattn_kv_cache(self, xattn_tokens): bsz, seqlen_y = xattn_tokens.shape[1], xattn_tokens.shape[2] - MAX_MM_SEQ_LEN = 1056 + MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ if seqlen_y > MAX_MM_SEQ_LEN: xattn_tokens = ttnn.reshape(xattn_tokens, [1, bsz * seqlen_y // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) @@ -179,23 +180,19 @@ def compute_xattn_kv_cache(self, xattn_tokens): xk = self.k_norm(xk) + # NOTE: Doing repeat in xattn_cache generation to avoid massive overhead in forward xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) return [xk, xv] - ### EVERYTHING BELOW IS BROKEN OMG - # BEWARNED! TMs are dangerous! - # WORKAROUND + ### Below is how I would like to implement TMs, but it results in poor PCC xk = ttnn.to_layout(xk, layout=ttnn.ROW_MAJOR_LAYOUT) xv = ttnn.to_layout(xv, layout=ttnn.ROW_MAJOR_LAYOUT) xk = xk.reshape(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim) xv = xv.reshape(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim) - # xk = ttnn.to_memory_config(xk, ttnn.L1_MEMORY_CONFIG) - # xk = ttnn.to_memory_config(xk, ttnn.DRAM_MEMORY_CONFIG) - return xk - xk = ttnn.transpose(xk, 1, 2, memory_config=ttnn.L1_MEMORY_CONFIG) + xk = ttnn.transpose(xk, 1, 2) xv = ttnn.transpose(xv, 1, 2) xk = ttnn.to_layout(xk, layout=ttnn.TILE_LAYOUT) @@ -208,10 +205,7 @@ def compute_xattn_kv_cache(self, xattn_tokens): return [xk, xv] def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache): - # batch = x_11SH.shape[-2] batch = xattn_cache[0].shape[0] - # assert seq_len % 32 == 0 and seq_len > 0, "Seqlen must be divisible by 128" - # 1, B, D xq = ttnn.linear( x_11SH, @@ -241,9 +235,6 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xk, xv = xattn_cache cache_seq_len = xk.shape[-2] - # xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) - # xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) - scores = ttnn.matmul( xq, ttnn.transpose(xk, -1, -2), @@ -300,7 +291,7 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache): seq_len = x_11SH.shape[-2] # B, S, D - # assert seq_len % 32 == 0 and seq_len > 0, "Seqlen must be divisible by 128" + assert seq_len % 32 == 0 and seq_len > 0, "Seqlen must be divisible by 32" if seq_len > 1024: x_11SH = ttnn.reshape(x_11SH, [1, seq_len // 1024, 1024, -1]) @@ -327,21 +318,12 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH xk, xv = xattn_cache cache_seq_len = xk.shape[-2] - # NOTE: Doing repeat in xattn_cache generation to avoid massive overhead in forward - # NOTE: Using naive SDPA for now since FlashDecode does not allow non-causal mask - # xq = ttnn.reshape(xq, [self.n_local_heads // self.n_local_kv_heads, self.n_local_kv_heads, seq_len, self.head_dim]) - # NOTE: repeat doesn't work, need to use repeat_interleave - # # xk = ttnn.repeat(xk, ttnn.Shape((self.n_local_heads // self.n_local_kv_heads, 1, 1, 1))) - # xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) - # # xv = ttnn.repeat(xv, ttnn.Shape((self.n_local_heads // self.n_local_kv_heads, 1, 1, 1))) - # xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) - scores = ttnn.matmul( xq, ttnn.transpose(xk, -1, -2), dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config_hifi4, + compute_kernel_config=self.compute_kernel_config_hifi2, program_config=self.model_config["VISION_XATTN_SCORE_PROGCFG"](seq_len, cache_seq_len), ) @@ -350,8 +332,6 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH scores = ttnn.add(scores, xattn_mask) scores = ttnn.softmax(scores, dim=-1, numeric_stable=True) - # TODO: scale_mask_softmax doesn't work for this xattn_mask shape - # scores = ttnn.scale_mask_softmax(scores, self.scale, xattn_mask, numeric_stable=True) output = ttnn.matmul( scores, xv, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py index 34dd88d19c1..eb552bc0e17 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py @@ -55,13 +55,6 @@ def __init__( self.model_config = configuration.get_model_config() self.state_dict = state_dict - # self.tok_embeddings = TtLlamaEmbedding( - # mesh_device=mesh_device, - # args=configuration, - # weight_cache_path=configuration.weight_cache_path(dtype), - # state_dict=state_dict, - # dtype=ttnn.bfloat16, # Row major layout requires bfloat16 - # ) # NOTE: Running all embeddings in torch for now since learnable embeddings use complex indexing ops which must be in torch self.tok_embeddings = torch.nn.Embedding(configuration.vocab_size, configuration.dim) tok_embedding_prefix = f"{state_dict_prefix}tok_embeddings." @@ -79,21 +72,7 @@ def __init__( weight_key="norm", ) - # # self.output layer weight # TODO: Generalize LMHead, maybe use llama_model's single-tile-sequence LMHead - # self.output = LMHead( - # configuration, - # mesh_device, - # ttnn.bfloat8_b, - # state_dict, - # state_dict_prefix, - # weight_cache_path, - # ) - - # torch_weight = lambda name, suffix: torch.transpose( - # self.state_dict[f"{state_dict_prefix}{name}.{suffix}"], -2, -1 - # ) - lm_head_torch = self.state_dict[f"{state_dict_prefix}output.weight"].transpose(-1, -2) total_splits = 8 # Arbitrary value which allows whole-tile splits in LM Head num_splits = total_splits // self.configuration.num_devices @@ -114,7 +93,6 @@ def __init__( self.outputs = [ as_interleaved_tensor("output", "weight", idx, ttnn.bfloat8_b, dim=-1) for idx in range(len(lm_head_torch)) ] - # self.output = as_interleaved_tensor("output", "weight", ttnn.bfloat8_b, dim=-1) self.n_llama_layers = configuration.n_layers self.model_dim = configuration.dim diff --git a/models/demos/llama3/tt/multimodal/llama_image_transformer_vision.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py similarity index 97% rename from models/demos/llama3/tt/multimodal/llama_image_transformer_vision.py rename to models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py index b02d04097b7..cded53b7120 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_transformer_vision.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py @@ -11,7 +11,7 @@ nearest_32, ) from models.common.lightweightmodule import LightweightModule -from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import TtLlamaVisionEncoder +from models.demos.llama3.tt.multimodal.llama_vision_encoder import TtLlamaVisionEncoder from models.demos.falcon7b_common.tests.test_utils import ( synchronize_devices, diff --git a/models/demos/llama3/tt/multimodal/llama_image_attention.py b/models/demos/llama3/tt/multimodal/llama_image_attention.py index 6600258986b..b15e64d2374 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_image_attention.py @@ -45,6 +45,7 @@ def __init__( self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa + self.configuration = configuration self.model_config = configuration.get_model_config() @@ -140,106 +141,10 @@ def pad_head_dim(weight, heads_out=True): self.scale = self.head_dim**-0.5 - def forward(self, x, mask): - return self.forward_tt(x, mask) - if os.environ.get("ATTN") == "tt": - return self.forward_tt(x, mask) - else: - return self.forward_pt(x, mask) - - def forward_pt(self, x_11SH, mask=None): + def forward(self, x_11SH, mask=None): seq_len = x_11SH.shape[-2] - x_torch = ttnn.to_torch(x_11SH, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - x_torch = x_torch[0].unsqueeze(0) - wqkv = ttnn.to_torch(self.wqkv, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - wqkv = wqkv.view(2, wqkv.shape[0] // 2, wqkv.shape[1]) - - wqkv = wqkv.reshape(2, self.hidden_size, 3 * self.n_heads // 2, -1) - wqkv = wqkv[..., : self.head_dim] - # wqkv = wqkv.reshape(2, self.hidden_size, -1) - wqkv = wqkv.reshape(2, self.hidden_size, 3, self.n_heads // 2, -1).permute(1, 2, 0, 3, 4) - wqkv = wqkv.reshape(self.hidden_size, 3, self.n_heads, -1).reshape(self.hidden_size, -1) - - # xqkv_fused_torch = torch.matmul(x_torch, wqkv).bfloat16().float() - xqkv_fused_torch = torch.nn.functional.linear(x_torch, wqkv.T).bfloat16().float() - # xqkv_fused_torch = torch.nn.functional.linear(x_torch, wqkv.tranpose).bfloat16().float() - # n, s, d = xqkv_fused_torch.shape[-3:] - s, d = xqkv_fused_torch.shape[-2:] - xqkv = xqkv_fused_torch.reshape(s, 3, d // 3) - q = xqkv[..., 0, :] - k = xqkv[..., 1, :] - v = xqkv[..., 2, :] - # xq = q.reshape(n, s, self.n_heads//2, -1).transpose(1, 2) - # xk = k.reshape(n, s, self.n_heads//2, -1).transpose(1, 2) - # xv = v.reshape(n, s, self.n_heads//2, -1).transpose(1, 2) - xq = q.reshape(s, self.n_heads, -1).transpose(0, 1) - xk = k.reshape(s, self.n_heads, -1).transpose(0, 1) - xv = v.reshape(s, self.n_heads, -1).transpose(0, 1) - mask_torch = ttnn.to_torch(mask, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - mask_torch = mask_torch[0] - attn_output = ( - torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask_torch, scale=self.scale) - .bfloat16() - .float() - ) # [...,:self.head_dim] - - # attn_output = attn_output.transpose(1, 2).reshape(n, s, -1).transpose(0, 1).reshape(s, -1) - attn_output = attn_output.transpose(0, 1).reshape(s, -1) - - wo = ttnn.to_torch(self.wo, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - wo = wo.view(2, wo.shape[0] // 2, wo.shape[1]) # .reshape(-1, wo.shape[1]) - - wo = ( - wo.transpose(1, 2) - .reshape(2, self.hidden_size, self.n_heads // 2, -1)[..., : self.head_dim] - .reshape(2, self.hidden_size, -1) - .transpose(1, 2) - .reshape(-1, self.hidden_size) - ) - - out = torch.nn.functional.linear(attn_output, wo.T).bfloat16().float() - # out = torch.sum(out, dim=0).unsqueeze(0).unsqueeze(0).bfloat16().float() - out = out.view(1, 1, 4224, -1) - - out_tt = ttnn.from_torch( - out, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - return out_tt - - def forward_tt(self, x_11SH, mask=None): - seq_len = x_11SH.shape[-2] - # assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" - ### - # QKV matmuls - ### - - # reshaping long sequence to matmul fit on device - - # x_torch = ttnn.to_torch(x_11SH, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # x_torch = x_torch.view(2, seq_len, -1) - # wqkv = ttnn.to_torch(self.wqkv, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # wqkv = wqkv.view(2, wqkv.shape[0]//2, wqkv.shape[1]) - - # xqkv_fused_torch = torch.bmm(x_torch, wqkv).unsqueeze(1) - # xqkv_fused = ttnn.from_torch( - # xqkv_fused_torch, - # device=self.mesh_device, - # mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=0), - # dtype=ttnn.bfloat16, - # memory_config=ttnn.DRAM_MEMORY_CONFIG, - # layout=ttnn.TILE_LAYOUT, - # ) - - # Depends on whether we are padding or not - MAX_MM_SEQ_LEN = 1056 - # MAX_MM_SEQ_LEN = 1024 - - # DEBUG: Don't batch it up - # MAX_MM_SEQ_LEN = 10000 + MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ if seq_len > MAX_MM_SEQ_LEN: x_11SH = ttnn.reshape(x_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) @@ -255,8 +160,6 @@ def forward_tt(self, x_11SH, mask=None): if seq_len > MAX_MM_SEQ_LEN: xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) - # ttnn.deallocate(x_11SH) - # split qkv into heads ( q_heads_1QSD, @@ -271,7 +174,7 @@ def forward_tt(self, x_11SH, mask=None): ) ttnn.deallocate(xqkv_fused) - # sdpa_cfg = self.model_config["SDPA_PROGCFG"](seq_len) + # TODO: get this from model_config sdpa_cfg = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=(8, 8), q_chunk_size=128, k_chunk_size=128, exp_approx_mode=False ) @@ -290,28 +193,6 @@ def forward_tt(self, x_11SH, mask=None): ttnn.deallocate(k_heads_1KSD) ttnn.deallocate(v_heads_1VSD) - # q_heads_torch = ttnn.to_torch(q_heads_1QSD, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # k_heads_torch = ttnn.to_torch(k_heads_1KSD, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # v_heads_torch = ttnn.to_torch(v_heads_1VSD, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # mask_torch = ttnn.to_torch(mask, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - - # attn_output_torch = torch.nn.functional.scaled_dot_product_attention( - # q_heads_torch, - # k_heads_torch, - # v_heads_torch, - # attn_mask=mask_torch, - # scale=self.scale, - # ) - - # attn_output_1QSD = ttnn.from_torch( - # attn_output_torch, - # device=self.mesh_device, - # mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=0), - # dtype=ttnn.bfloat16, - # memory_config=ttnn.DRAM_MEMORY_CONFIG, - # layout=ttnn.TILE_LAYOUT, - # ) - ### # Output matmul ### @@ -321,24 +202,6 @@ def forward_tt(self, x_11SH, mask=None): ) ttnn.deallocate(attn_output_1QSD) - # attn_output_torch = ttnn.to_torch(attn_output_11SH, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # # breakpoint() - # wo = ttnn.to_torch(self.wo, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # wo = wo.view(2, 1, wo.shape[0]//2, wo.shape[1]) - # output = torch.matmul(attn_output_torch, wo) - # output = torch.sum(output, dim=0).unsqueeze(0).unsqueeze(0) - - # output_11SH = ttnn.from_torch( - # output, - # device=self.mesh_device, - # mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - # dtype=ttnn.bfloat16, - # memory_config=ttnn.DRAM_MEMORY_CONFIG, - # layout=ttnn.TILE_LAYOUT, - # ) - # return output_11SH - # breakpoint() - # reshaping long sequence to matmul fit on device if seq_len > MAX_MM_SEQ_LEN: attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) diff --git a/models/demos/llama3/tt/multimodal/llama_image_block.py b/models/demos/llama3/tt/multimodal/llama_image_block.py index d668b54897c..9ab361aed26 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_block.py +++ b/models/demos/llama3/tt/multimodal/llama_image_block.py @@ -92,46 +92,7 @@ def __init__( memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - def forward(self, x, mask): - return self.forward_tt(x, mask) - if os.environ.get("BLOCK") == "tt": - return self.forward_tt(x, mask) - else: - return self.forward_pt(x, mask) - - def forward_pt(self, x_11SH, mask=None): - seq_len = x_11SH.shape[-2] - assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" - - attn_out = self.attn(self.ln_1(x_11SH), mask=mask) - if self.gated: - assert False - attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn)) - - x = ttnn.to_torch(x_11SH, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - attn_out = ttnn.to_torch(attn_out, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - res = (x + attn_out).bfloat16().float() - res = ttnn.from_torch( - res, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=0), - ) - mlp_out = self.mlp(self.ln_2(res)) - if self.gated: - mlp_out = ttnn.mul(mlp_out, ttnn.tanh(self.gate_ffn)) - res = ttnn.to_torch(res, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - mlp_out = ttnn.to_torch(mlp_out, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - out = (res + mlp_out).bfloat16().float() - out = ttnn.from_torch( - out, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=0), - ) - return out - - def forward_tt(self, x_11SH, mask=None): + def forward(self, x_11SH, mask=None): seq_len = x_11SH.shape[-2] assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" diff --git a/models/demos/llama3/tt/multimodal/llama_image_mlp.py b/models/demos/llama3/tt/multimodal/llama_image_mlp.py index fda3582597d..9f1c1f67590 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_mlp.py +++ b/models/demos/llama3/tt/multimodal/llama_image_mlp.py @@ -54,55 +54,7 @@ def __init__( self.c_proj_weight = as_interleaved_tensor("c_proj", "weight", dtype, dim=-2) self.c_proj_bias = as_interleaved_tensor("c_proj", "bias", ttnn.bfloat16, dim=None) - def forward(self, x): - return self.forward_tt(x) - if os.environ.get("MLP") == "tt": - return self.forward_tt(x) - else: - return self.forward_pt(x) - - def forward_pt(self, x): - x = ttnn.to_torch( - x, device=self.mesh_device, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) - ).float() - x = x[0] - - c_fc_weight = ttnn.to_torch( - self.c_fc_weight, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1) - ).float() - c_fc_bias = ttnn.to_torch( - self.c_fc_bias, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1) - ).float()[:1] - c_proj_weight = ttnn.to_torch( - self.c_proj_weight, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-2) - ).float() - c_proj_bias = ttnn.to_torch( - self.c_proj_bias, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) - ).float() - c_proj_bias = c_proj_bias[:1] - - # hidden = (torch.matmul(x, c_fc_weight).bfloat16().float() + c_fc_bias).bfloat16().float() - # hidden = torch.nn.functional.gelu(hidden).bfloat16().float() - # hidden = (torch.matmul(hidden, c_proj_weight).bfloat16().float() + c_proj_bias).bfloat16().float() - # hidden = hidden.view(1, 1, 5120, -1) - x = x.bfloat16().float() - hidden = torch.nn.functional.linear(x, c_fc_weight.T, c_fc_bias).bfloat16().float() - hidden = torch.nn.functional.gelu(hidden).bfloat16().float() - hidden = torch.nn.functional.linear(hidden, c_proj_weight.T).bfloat16().float() - hidden += c_proj_bias - hidden = hidden.view(1, 1, 4224, -1) - - hidden = ttnn.from_torch( - hidden, - device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - ) - return hidden - - def forward_tt(self, x: ttnn.Tensor) -> ttnn.Tensor: + def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: """ w1 -> gate_proj w2 -> down_proj @@ -112,8 +64,8 @@ def forward_tt(self, x: ttnn.Tensor) -> ttnn.Tensor: seq_len = x.shape[-2] # Depends on whether we are padding or not - MAX_MM_SEQ_LEN = 1056 - # MAX_MM_SEQ_LEN = 1024 + MAX_MM_SEQ_LEN = self.args.VISION_MAX_MM_SEQ + x_in = x if seq_len >= MAX_MM_SEQ_LEN: # Too big to compute. Set different program configs based on seqlen # Reshape input to to fit on device and parallelize computation @@ -131,9 +83,9 @@ def forward_tt(self, x: ttnn.Tensor) -> ttnn.Tensor: dtype=ttnn.bfloat16, program_config=pc_1, memory_config=ttnn.DRAM_MEMORY_CONFIG, - # activation="gelu", # NOTE: activation must be passed to linear here, not in program config! Bad output otherwise + activation="gelu", # NOTE: activation must be passed to linear here, not in program config! Bad output otherwise ) - c_fc_out = ttnn.gelu(c_fc_out, fast_and_approximate_mode=False) + c_proj_out = ttnn.linear( c_fc_out, self.c_proj_weight, @@ -144,7 +96,6 @@ def forward_tt(self, x: ttnn.Tensor) -> ttnn.Tensor: memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - # if seq_len >= 1024: # Reshape back to intended shape # NOTE: Need to reshape to 4D so that fast_reduce_nc hsa a dim1 to work on c_proj_out = ttnn.reshape(c_proj_out, [1, 1, seq_len, -1]) diff --git a/models/demos/llama3/tt/multimodal/llama_layernorm.py b/models/demos/llama3/tt/multimodal/llama_layernorm.py index 0750d82da52..a20c4764ad1 100644 --- a/models/demos/llama3/tt/multimodal/llama_layernorm.py +++ b/models/demos/llama3/tt/multimodal/llama_layernorm.py @@ -82,46 +82,7 @@ def __init__( ) self.sharded_output_config = self.sharded_input_config - def forward(self, x): - return self.forward_tt(x) - if os.environ.get("LN") == "tt": - return self.forward_tt(x) - else: - return self.forward_pt(x) - - def forward_pt(self, x: ttnn.Tensor, in_sharded=False, out_sharded=False) -> ttnn.Tensor: - # If input is sharded do sharded RMSNorm and optionally return sharded output - - x = ttnn.to_torch( - x, - device=self.device, - mesh_composer=ttnn.ConcatMeshToTensor(self.device, dim=0), - )[0].float() - weight = ttnn.to_torch( - self.weight, - device=self.device, - mesh_composer=ttnn.ConcatMeshToTensor(self.device, dim=0), - )[0, 0].float() - bias = ttnn.to_torch( - self.bias, - device=self.device, - mesh_composer=ttnn.ConcatMeshToTensor(self.device, dim=0), - )[0, 0].float() - - out = torch.nn.functional.layer_norm(x, x.shape[-1:], weight=weight, bias=bias, eps=self.eps) - out = out - - out = ttnn.from_torch( - out, - device=self.device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.device), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - ) - return out - - def forward_tt(self, x: ttnn.Tensor, in_sharded=False, out_sharded=False) -> ttnn.Tensor: + def forward(self, x: ttnn.Tensor, in_sharded=False, out_sharded=False) -> ttnn.Tensor: if in_sharded: x = ttnn.layer_norm( x, diff --git a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py index 031ca56e775..9ef2aadddac 100644 --- a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py @@ -101,7 +101,6 @@ def forward(self, x: ttnn.Tensor, ar: torch.Tensor, num_tiles: int = None): if num_tiles is None: num_tiles = self.num_tiles elif num_tiles > self.num_tiles: - # TODO: Need to implement? assert False, "_dynamic_resize is currently not supported for TtLllamaTilePositionEmbedding" # Get the correct embeddings for the given aspect ratios diff --git a/models/demos/llama3/tt/multimodal/llama_image_vision_encoder.py b/models/demos/llama3/tt/multimodal/llama_vision_encoder.py similarity index 98% rename from models/demos/llama3/tt/multimodal/llama_image_vision_encoder.py rename to models/demos/llama3/tt/multimodal/llama_vision_encoder.py index ee22ab29e30..8ab71f831c3 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_vision_encoder.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_encoder.py @@ -192,7 +192,6 @@ def forward(self, images, ar): assert isinstance( images, torch.Tensor ), "VisionEncoder input must be a torch tensor because of unfold in self.conv1" - SKIP_EMBED = False if images.ndim == 5: num_concurrent_media = 1 bsz, num_chunks, nch, w, h = images.shape @@ -215,17 +214,15 @@ def forward(self, images, ar): x = ttnn.reshape(x, (1, bsz * num_concurrent_media * num_chunks, ntok, dim)) # apply cls token - if not SKIP_EMBED: - x = self.class_embedding(x) - ntok += 1 + x = self.class_embedding(x) + ntok += 1 # apply position embeddings # NOTE! After class embedding, x is padded tilized tensor. Reshapes fail for padded tilized tensors, so do the reshape in row-major x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) x = ttnn.reshape(x, (bsz * num_concurrent_media, num_chunks, ntok, dim)) x = ttnn.to_layout(x, ttnn.TILE_LAYOUT) - if not SKIP_EMBED: - x = self.positional_embedding(x, ar) + x = self.positional_embedding(x, ar) # BUG: layernorm takes 4d tensor -> 3d?? x = self.ln_pre(x) @@ -239,6 +236,7 @@ def forward(self, images, ar): fake_x = torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]) attn_mask = encoder_utils.build_encoder_attention_mask(fake_x, ar, ntok, num_chunks, 1) + # Mask stripes for the extra padding required on TT hardware attn_mask = mask_tile_padding(attn_mask, ntok, npad, num_chunks) attn_mask = ttnn.as_tensor( attn_mask, diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 4a7560de5e9..5a64637aad8 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -19,7 +19,9 @@ import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.image_transform as llama_reference_image_transforms import ttnn -from models.demos.llama3.tt.multimodal.llama_image_transformer_vision import TtLlamaCrossAttentionTransformerVision +from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_vision import ( + TtLlamaCrossAttentionTransformerVision, +) from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_text import ( TtLlamaCrossAttentionTransformerText, ) @@ -31,6 +33,9 @@ get_rot_transformation_mat, get_single_rot_mat, ) +from models.utility_functions import ( + nearest_32, +) logger = logging.getLogger(__name__) MP_SCALE = 8 @@ -217,7 +222,7 @@ def compute_vision_tokens_masks( vision_tokens = self.vision_model(stacked_images, aspect_ratios) # Back to torch vision_tokens = ttnn.to_torch(vision_tokens, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)) - chunk_seq_len = (self.configuration.vision_chunk_size // self.configuration.vision_patch_size) ** 2 + 1 + chunk_seq_len = self.configuration.vision_chunk_ntok # NOTE: slicing up to chunk_seq_len is necessary because padding information is lost by this point vision_tokens = ( vision_tokens[0, :, :chunk_seq_len] @@ -226,11 +231,12 @@ def compute_vision_tokens_masks( ) bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) + padded_seq_len = self.max_num_chunks * nearest_32(self.configuration.vision_chunk_ntok) # Prepare vision tokens for TT text_model vision_tokens_squeeze = vision_tokens.view(1, bsz, -1, image_token_dim) vision_tokens_squeeze = torch.nn.functional.pad( - vision_tokens_squeeze, (0, 0, 0, 4224 - vision_tokens_squeeze.shape[2]), "constant", 0 + vision_tokens_squeeze, (0, 0, 0, padded_seq_len - vision_tokens_squeeze.shape[2]), "constant", 0 ) vision_tokens_tt = ttnn.from_torch( vision_tokens_squeeze, @@ -262,7 +268,7 @@ def compute_vision_tokens_masks( cross_attention_masks = torch.nn.functional.pad( cross_attention_masks, - (0, 4224 - cross_attention_masks.shape[3]), + (0, padded_seq_len - cross_attention_masks.shape[3]), "constant", get_negative_inf_value(torch.float32), ) From 3a39037e8a950401e0ece2b03bde093970539f9f Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Wed, 23 Oct 2024 14:32:24 +0000 Subject: [PATCH 09/16] #13368: Fixup tests --- .../multimodal/test_llama_cross_attention_transformer_vision.py | 2 +- .../demos/llama3/tests/multimodal/test_llama_vision_encoder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py index 6555321578a..e23d9063787 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py @@ -59,7 +59,7 @@ def test_llama_vision_transformer_inference(mesh_device, use_program_cache, rese # Create rand inputs of the right shape batch, num_media, num_chunks, n_channel, patch_size = (1, 1, 4, 3, model_args.vision_chunk_size) - chunk_seq_len = model_args.vision_chunk_ntok - 1 # tokens per chunk without class token + chunk_seq_len = model_args.vision_chunk_ntok # tokens per chunk, including class token images = torch.randn(batch, num_media, num_chunks, n_channel, patch_size, patch_size) ars = torch.tensor([2, 2]).reshape(batch, num_media, 2) diff --git a/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py index 111c584f781..50064a2e480 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py +++ b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py @@ -63,7 +63,7 @@ def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_se ) # Create rand inputs of the right shape - batch, num_media, num_chunks, n_channel, patch_size = (1, 1, 4, 3, 448) + batch, num_media, num_chunks, n_channel, patch_size = (1, 1, 4, 3, model_args.vision_chunk_size) images = torch.randn(batch, num_media, num_chunks, n_channel, patch_size, patch_size) ars = torch.tensor([2, 2]).reshape(batch, num_media, 2) From 8689d406734444e9ac081bc5b0c1c3e208bdf171 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Wed, 23 Oct 2024 14:34:45 +0000 Subject: [PATCH 10/16] #13368: Add vision tests to unit, frequent, and demo --- .github/workflows/t3000-demo-tests-impl.yaml | 1 + .../workflows/t3000-frequent-tests-impl.yaml | 2 + .github/workflows/t3000-unit-tests-impl.yaml | 2 + tests/scripts/t3000/run_t3000_demo_tests.sh | 27 ++++++++ .../scripts/t3000/run_t3000_frequent_tests.sh | 52 ++++++++++++++++ tests/scripts/t3000/run_t3000_unit_tests.sh | 62 +++++++++++++++++++ 6 files changed, 146 insertions(+) diff --git a/.github/workflows/t3000-demo-tests-impl.yaml b/.github/workflows/t3000-demo-tests-impl.yaml index 1c353969673..a5fce0e2c1a 100644 --- a/.github/workflows/t3000-demo-tests-impl.yaml +++ b/.github/workflows/t3000-demo-tests-impl.yaml @@ -12,6 +12,7 @@ jobs: { name: "t3k_falcon40b_tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 50, owner_id: U053W15B6JF}, #Djordje Ivanovic { name: "t3k_llama3_70b_tests", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k_llama3_tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 30, owner_id: U03PUAKE719}, # Miguel Tairum + { name: "t3k_llama3_vision_tests", arch: wormhole_b0, cmd: run_t3000_llama3_vision_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k_falcon7b_tests", arch: wormhole_b0, cmd: run_t3000_falcon7b_tests, timeout: 90, owner_id: U05RWH3QUPM}, #Salar Hosseini { name: "t3k_mixtral_tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 50, owner_id: U03PUAKE719}, # Miguel Tairum ] diff --git a/.github/workflows/t3000-frequent-tests-impl.yaml b/.github/workflows/t3000-frequent-tests-impl.yaml index e83be639c72..a58a497e7e0 100644 --- a/.github/workflows/t3000-frequent-tests-impl.yaml +++ b/.github/workflows/t3000-frequent-tests-impl.yaml @@ -15,6 +15,8 @@ jobs: { name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 120, owner_id: U04S2UV6L8N}, #Sofija Jovic { name: "t3k llama2_70b tests", arch: wormhole_b0, cmd: run_t3000_llama2_70b_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k llama3 tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 60, owner_id: U03PUAKE719}, #Miguel Tairum Cruz + { name: "t3k llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich + { name: "t3k n300 mesh llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k mixtral tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 60, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k resnet tests", arch: wormhole_b0, cmd: run_t3000_resnet_tests, timeout: 30, owner_id: U013121KDH9}, #Austin Ho ] diff --git a/.github/workflows/t3000-unit-tests-impl.yaml b/.github/workflows/t3000-unit-tests-impl.yaml index 9856f92ec06..b2da2f63311 100644 --- a/.github/workflows/t3000-unit-tests-impl.yaml +++ b/.github/workflows/t3000-unit-tests-impl.yaml @@ -15,6 +15,8 @@ jobs: { name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 30, owner_id: U053W15B6JF}, #Djordje Ivanovic { name: "t3k llama3-small tests", arch: wormhole_b0, cmd: run_t3000_llama3-small_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k llama3.2-11b tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz + { name: "t3k llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich + { name: "t3k n300 mesh llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k mixtral tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k grok tests", arch: wormhole_b0, cmd: run_t3000_grok_tests, timeout: 30, owner_id: U03HY7MK4BT}, #Mark O'Connor { name: "t3k unet shallow tests", arch: wormhole_b0, cmd: run_t3000_unet_shallow_tests, timeout: 30, owner_id: U06ECNVR0EN}, #Evan Smal diff --git a/tests/scripts/t3000/run_t3000_demo_tests.sh b/tests/scripts/t3000/run_t3000_demo_tests.sh index a63433dd028..4497b392828 100755 --- a/tests/scripts/t3000/run_t3000_demo_tests.sh +++ b/tests/scripts/t3000/run_t3000_demo_tests.sh @@ -75,6 +75,33 @@ run_t3000_llama3_tests() { fi } +run_t3000_llama3_vision_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_llama3_vision_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + n300=N300 + t3k=T3K + + for fake_device in "$n300" "$t3k"; do + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/multimodal_demo_chat.py -k "tt and 1" --timeout 600; fail+=$? + echo "LOG_METAL: Llama3 vision tests for $fake_device completed" + done + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_llama3_vision_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + run_t3000_falcon7b_tests(){ # Record the start time fail=0 diff --git a/tests/scripts/t3000/run_t3000_frequent_tests.sh b/tests/scripts/t3000/run_t3000_frequent_tests.sh index 468e6757845..fe5626b9cfe 100755 --- a/tests/scripts/t3000/run_t3000_frequent_tests.sh +++ b/tests/scripts/t3000/run_t3000_frequent_tests.sh @@ -75,6 +75,58 @@ run_t3000_llama3_tests() { fi } +run_t3000_llama3.2-11b-vision_freq_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_llama3.2-11b-vision_freq_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_transformer.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py ; fail+=$? + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_llama3.2-11b-vision_freq_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + +run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + # Use FAKE_DEVICE env variable to run on an N300 mesh + fake_device=N300 + + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_transformer.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py ; fail+=$? + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + run_t3000_mixtral_tests() { # Record the start time fail=0 diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh index 2302718fb45..4f54ecea9fd 100755 --- a/tests/scripts/t3000/run_t3000_unit_tests.sh +++ b/tests/scripts/t3000/run_t3000_unit_tests.sh @@ -152,6 +152,68 @@ run_t3000_llama3.2-11b_tests() { fi } +run_t3000_llama3.2-11b-vision_unit_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_llama3.2-11b-vision_unit_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_block.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_class_embedding.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py ; fail+=$? + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_llama3.2-11b-vision_unit_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + +run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + # Use FAKE_DEVICE env variable to run on an N300 mesh + fake_device=N300 + + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_block.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_class_embedding.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py ; fail+=$? + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + run_t3000_mixtral_tests() { # Record the start time fail=0 From 80db0b847b5a15d19dcd9afc3d17f12a6c8b9526 Mon Sep 17 00:00:00 2001 From: mtairum Date: Tue, 22 Oct 2024 13:54:22 +0000 Subject: [PATCH 11/16] #13368: Relaxed 11B perf estimate to avoid error in CI --- models/demos/llama3/tests/test_llama_perf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/demos/llama3/tests/test_llama_perf.py b/models/demos/llama3/tests/test_llama_perf.py index 260a0221e9d..ea3c922646d 100644 --- a/models/demos/llama3/tests/test_llama_perf.py +++ b/models/demos/llama3/tests/test_llama_perf.py @@ -52,7 +52,7 @@ def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_ elif "3.1-8B" in model_args.DEFAULT_CACHE_PATH: expected_inference_time = 0.07 elif "3.2-11B" in model_args.DEFAULT_CACHE_PATH: - expected_inference_time = 0.07 + expected_inference_time = 0.085 else: assert False, f"Llama model not found. Supported Llama models: [3.2-1B, 3.2-3B, 3.1-8B]" From dd10a03354f3566c114d64d99af5446f2275f8f2 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Wed, 23 Oct 2024 18:13:09 +0000 Subject: [PATCH 12/16] #0: Added Llama-models python requirements --- tt_metal/python_env/requirements-dev.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tt_metal/python_env/requirements-dev.txt b/tt_metal/python_env/requirements-dev.txt index f1761202d68..dd3ad7f4035 100644 --- a/tt_metal/python_env/requirements-dev.txt +++ b/tt_metal/python_env/requirements-dev.txt @@ -57,5 +57,7 @@ fsspec==2023.9.2 # Temporary pin to 2023.9.2: https://github.com/tenstorrent/tt- docopt==0.6.2 tabulate==0.9.0 blobfile==2.1.1 # Required for llama3 +pydantic==2.9.2 # Required for llama3 +pydantic_core==2.23.4 # Required for llama3 numpy>=1.24.4,<2 huggingface-hub==0.25.2 From 2c5ff7f9cc8cb9b6cece8b2e632c842e051f6d62 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Wed, 23 Oct 2024 23:30:51 +0000 Subject: [PATCH 13/16] #13368: Fixup mesh_device when not passed FAKE_DEVICE --- .../test_llama_cross_attention_transformer_vision.py | 6 +++++- .../demos/llama3/tests/multimodal/test_llama_image_block.py | 6 +++++- .../llama3/tests/multimodal/test_llama_image_transformer.py | 6 +++++- .../llama3/tests/multimodal/test_llama_vision_encoder.py | 6 +++++- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py index e23d9063787..abcc1bd8156 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py @@ -25,7 +25,11 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_vision_transformer_inference(mesh_device, use_program_cache, reset_seeds): diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py index 613fd2a3021..bea85a0a16f 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py @@ -33,7 +33,11 @@ ) @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_block_inference(batch, num_chunks, mesh_device, gated, use_program_cache, reset_seeds, ensure_gc): diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py index 1fee8a125c4..d042eb1e683 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py @@ -33,7 +33,11 @@ ) @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_image_transformer_inference( diff --git a/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py index 50064a2e480..61824eb484e 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py +++ b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py @@ -23,7 +23,11 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_seeds): From bc994401ef48ff621eae1beda64ce2f593ab99db Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Mon, 28 Oct 2024 06:42:27 -0700 Subject: [PATCH 14/16] #13368: Remove llama-specific packages from requirements-dev.txt --- tests/scripts/t3000/run_t3000_demo_tests.sh | 3 +++ tests/scripts/t3000/run_t3000_frequent_tests.sh | 6 ++++++ tests/scripts/t3000/run_t3000_unit_tests.sh | 6 ++++++ tt_metal/python_env/requirements-dev.txt | 2 -- 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/scripts/t3000/run_t3000_demo_tests.sh b/tests/scripts/t3000/run_t3000_demo_tests.sh index 4497b392828..ad329ca2319 100755 --- a/tests/scripts/t3000/run_t3000_demo_tests.sh +++ b/tests/scripts/t3000/run_t3000_demo_tests.sh @@ -88,6 +88,9 @@ run_t3000_llama3_vision_tests() { n300=N300 t3k=T3K + # Install Vision-specific packages + pip install -r models/demos/llama3/reference/llama_models/requirements.txt + for fake_device in "$n300" "$t3k"; do FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/multimodal_demo_chat.py -k "tt and 1" --timeout 600; fail+=$? echo "LOG_METAL: Llama3 vision tests for $fake_device completed" diff --git a/tests/scripts/t3000/run_t3000_frequent_tests.sh b/tests/scripts/t3000/run_t3000_frequent_tests.sh index fe5626b9cfe..7b95c2570d9 100755 --- a/tests/scripts/t3000/run_t3000_frequent_tests.sh +++ b/tests/scripts/t3000/run_t3000_frequent_tests.sh @@ -86,6 +86,9 @@ run_t3000_llama3.2-11b-vision_freq_tests() { # Llama3.2-11B llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + # Install Vision-specific packages + pip install -r models/demos/llama3/reference/llama_models/requirements.txt + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_transformer.py ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py ; fail+=$? @@ -113,6 +116,9 @@ run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests() { # Use FAKE_DEVICE env variable to run on an N300 mesh fake_device=N300 + # Install Vision-specific packages + pip install -r models/demos/llama3/reference/llama_models/requirements.txt + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_transformer.py ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py ; fail+=$? diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh index 4f54ecea9fd..e15ebbb00df 100755 --- a/tests/scripts/t3000/run_t3000_unit_tests.sh +++ b/tests/scripts/t3000/run_t3000_unit_tests.sh @@ -163,6 +163,9 @@ run_t3000_llama3.2-11b-vision_unit_tests() { # Llama3.2-11B llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + # Install Vision-specific packages + pip install -r models/demos/llama3/reference/llama_models/requirements.txt + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_block.py ; fail+=$? @@ -195,6 +198,9 @@ run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests() { # Use FAKE_DEVICE env variable to run on an N300 mesh fake_device=N300 + # Install Vision-specific packages + pip install -r models/demos/llama3/reference/llama_models/requirements.txt + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_block.py ; fail+=$? diff --git a/tt_metal/python_env/requirements-dev.txt b/tt_metal/python_env/requirements-dev.txt index dd3ad7f4035..f1761202d68 100644 --- a/tt_metal/python_env/requirements-dev.txt +++ b/tt_metal/python_env/requirements-dev.txt @@ -57,7 +57,5 @@ fsspec==2023.9.2 # Temporary pin to 2023.9.2: https://github.com/tenstorrent/tt- docopt==0.6.2 tabulate==0.9.0 blobfile==2.1.1 # Required for llama3 -pydantic==2.9.2 # Required for llama3 -pydantic_core==2.23.4 # Required for llama3 numpy>=1.24.4,<2 huggingface-hub==0.25.2 From df3e545e3777427c9ce53d500ac8b1dcd9daecf9 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Tue, 29 Oct 2024 06:02:08 -0700 Subject: [PATCH 15/16] #13368: Remove llama_models as submodule. Move its install to llama3 requirements.txt. --- .gitmodules | 3 --- models/demos/llama3/demo/multimodal_demo_chat.py | 4 ++-- models/demos/llama3/demo/multimodal_demo_text.py | 4 ++-- models/demos/llama3/reference/llama_models | 1 - models/demos/llama3/requirements.txt | 1 + .../demos/llama3/tests/multimodal/test_llama_conv2d_patch.py | 2 +- .../llama3/tests/multimodal/test_llama_cross_attention.py | 2 +- .../multimodal/test_llama_cross_attention_transformer_text.py | 2 +- .../test_llama_cross_attention_transformer_vision.py | 2 +- .../demos/llama3/tests/multimodal/test_llama_cross_block.py | 2 +- .../llama3/tests/multimodal/test_llama_image_attention.py | 4 ++-- .../demos/llama3/tests/multimodal/test_llama_image_block.py | 4 ++-- models/demos/llama3/tests/multimodal/test_llama_image_mlp.py | 2 +- .../llama3/tests/multimodal/test_llama_image_transformer.py | 4 ++-- models/demos/llama3/tests/multimodal/test_llama_layernorm.py | 2 +- .../tests/multimodal/test_llama_tile_position_embedding.py | 2 +- .../llama3/tests/multimodal/test_llama_vision_encoder.py | 2 +- .../demos/llama3/tests/multimodal/test_llama_vision_model.py | 4 ++-- models/demos/llama3/tt/multimodal/llama_vision_encoder.py | 2 +- models/demos/llama3/tt/multimodal/llama_vision_model.py | 4 ++-- tests/scripts/t3000/run_t3000_demo_tests.sh | 2 +- tests/scripts/t3000/run_t3000_frequent_tests.sh | 4 ++-- tests/scripts/t3000/run_t3000_unit_tests.sh | 4 ++-- 23 files changed, 30 insertions(+), 33 deletions(-) delete mode 160000 models/demos/llama3/reference/llama_models create mode 100644 models/demos/llama3/requirements.txt diff --git a/.gitmodules b/.gitmodules index 1c29f48e987..ab121e423f3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -28,6 +28,3 @@ [submodule "tt_metal/third_party/tt_llk_blackhole"] path = tt_metal/third_party/tt_llk_blackhole url = https://github.com/tenstorrent/tt-llk-bh.git -[submodule "models/demos/llama3/reference/llama_models"] - path = models/demos/llama3/reference/llama_models - url = https://github.com/tenstorrent/llama-models.git diff --git a/models/demos/llama3/demo/multimodal_demo_chat.py b/models/demos/llama3/demo/multimodal_demo_chat.py index 05ee6c4159d..a7d8c3ffe4b 100644 --- a/models/demos/llama3/demo/multimodal_demo_chat.py +++ b/models/demos/llama3/demo/multimodal_demo_chat.py @@ -9,9 +9,9 @@ from termcolor import cprint from models.demos.llama3.demo.multimodal_demo_text import create_multimodal_model -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.generation as llama_reference_generation +import llama_models.llama3.reference_impl.generation as llama_reference_generation -from models.demos.llama3.reference.llama_models.models.llama3.api.datatypes import ImageMedia, UserMessage +from llama_models.llama3.api.datatypes import ImageMedia, UserMessage THIS_DIR = Path(__file__).parent.parent.resolve() / "reference/llama_models/models/scripts/" diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py index f2eada1966c..4d3dad9f7f2 100644 --- a/models/demos/llama3/demo/multimodal_demo_text.py +++ b/models/demos/llama3/demo/multimodal_demo_text.py @@ -8,9 +8,9 @@ from PIL import Image as PIL_Image from termcolor import cprint -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.generation as llama_reference_generation +import llama_models.llama3.reference_impl.generation as llama_reference_generation -from models.demos.llama3.reference.llama_models.models.llama3.api.datatypes import ImageMedia +from llama_models.llama3.api.datatypes import ImageMedia THIS_DIR = Path(__file__).parent.parent.resolve() / "reference/llama_models/models/scripts/" diff --git a/models/demos/llama3/reference/llama_models b/models/demos/llama3/reference/llama_models deleted file mode 160000 index c217d3eb10f..00000000000 --- a/models/demos/llama3/reference/llama_models +++ /dev/null @@ -1 +0,0 @@ -Subproject commit c217d3eb10f6c01bbaa1aa7c714bb7c5ccf3b14f diff --git a/models/demos/llama3/requirements.txt b/models/demos/llama3/requirements.txt new file mode 100644 index 00000000000..e830cffd233 --- /dev/null +++ b/models/demos/llama3/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/tenstorrent/llama-models.git@tt_metal_tag diff --git a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py index d98d1c8613e..c38dd5ccb26 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py +++ b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py @@ -25,7 +25,7 @@ ) from models.demos.llama3.tt.model_config import TtModelArgs -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod @skip_for_grayskull("Requires wormhole_b0 to run") diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index ba0e269480f..14c7db894eb 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -7,7 +7,7 @@ import os import ttnn -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_cross_attention import TtLlamaCrossAttention from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index f11165862b6..286e7bac509 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -7,7 +7,7 @@ import os import ttnn -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_text import ( TtLlamaCrossAttentionTransformerText, ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py index abcc1bd8156..a3f360bfa23 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py @@ -7,7 +7,7 @@ import os import ttnn -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_vision import ( TtLlamaCrossAttentionTransformerVision, ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index f45f0eaa432..f64f9c98f7f 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -7,7 +7,7 @@ import os import ttnn -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_cross_block import TtLlamaCrossAttentionTransformerBlock from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py index 357f02a5b10..49f4ee58d2f 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py @@ -7,8 +7,8 @@ import os import ttnn -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod -from models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal import encoder_utils +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod +from llama_models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_attention import TtLlamaImageAttention from models.demos.llama3.tt.multimodal.llama_vision_encoder import pad_seq_one_tile, mask_tile_padding from models.demos.llama3.tt.model_config import TtModelArgs diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py index bea85a0a16f..8eecfe156d6 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py @@ -7,8 +7,8 @@ import os import ttnn -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod -from models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal import encoder_utils +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod +from llama_models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_block import TtLlamaImageTransformerBlock from models.demos.llama3.tt.multimodal.llama_vision_encoder import pad_seq_one_tile, mask_tile_padding from models.demos.llama3.tt.model_config import TtModelArgs diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py index 4181f9dfd0c..c6b65ef7f9d 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py @@ -8,7 +8,7 @@ import os import ttnn -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_image_mlp import TtLlamaImageFeedForward from models.demos.llama3.tt.model_config import TtModelArgs from models.utility_functions import ( diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py index d042eb1e683..b92d74290d6 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py @@ -7,8 +7,8 @@ import os import ttnn -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod -from models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal import encoder_utils +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod +from llama_models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_transformer import TtLlamaImageTransformer from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.multimodal.llama_vision_encoder import pad_seq_one_tile, mask_tile_padding diff --git a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py index 23294782651..d52d9f415f3 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py +++ b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py @@ -8,7 +8,7 @@ import os import ttnn -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_layernorm import TtLayerNorm # Updated import for LayerNorm from models.demos.llama3.tt.model_config import TtModelArgs from models.utility_functions import ( diff --git a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py index 2249b684cbd..4ba64dd76ff 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py @@ -31,7 +31,7 @@ ) from models.demos.llama3.tt.model_config import TtModelArgs -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod @skip_for_grayskull("Requires wormhole_b0 to run") diff --git a/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py index 61824eb484e..b3790d498e5 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py +++ b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py @@ -7,7 +7,7 @@ import os import ttnn -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_mod +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_vision_encoder import TtLlamaVisionEncoder from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( diff --git a/models/demos/llama3/tests/multimodal/test_llama_vision_model.py b/models/demos/llama3/tests/multimodal/test_llama_vision_model.py index e39555d4be1..f55a47891ac 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_vision_model.py +++ b/models/demos/llama3/tests/multimodal/test_llama_vision_model.py @@ -8,9 +8,9 @@ from PIL import Image as PIL_Image from termcolor import cprint -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.generation as llama_reference_generation +import llama_models.llama3.reference_impl.generation as llama_reference_generation -from models.demos.llama3.reference.llama_models.models.llama3.api.datatypes import ImageMedia +from llama_models.llama3.api.datatypes import ImageMedia from models.utility_functions import ( comp_pcc, diff --git a/models/demos/llama3/tt/multimodal/llama_vision_encoder.py b/models/demos/llama3/tt/multimodal/llama_vision_encoder.py index 8ab71f831c3..ff8a71c7de5 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_encoder.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_encoder.py @@ -24,7 +24,7 @@ synchronize_devices, ) -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.encoder_utils as encoder_utils +import llama_models.llama3.reference_impl.multimodal.encoder_utils as encoder_utils def to_2tuple(x): diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 5a64637aad8..f96aba089c4 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -15,8 +15,8 @@ from torch import nn, Tensor -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.model as llama_reference_model -import models.demos.llama3.reference.llama_models.models.llama3.reference_impl.multimodal.image_transform as llama_reference_image_transforms +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_model +import llama_models.llama3.reference_impl.multimodal.image_transform as llama_reference_image_transforms import ttnn from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_vision import ( diff --git a/tests/scripts/t3000/run_t3000_demo_tests.sh b/tests/scripts/t3000/run_t3000_demo_tests.sh index ad329ca2319..1ced36a3955 100755 --- a/tests/scripts/t3000/run_t3000_demo_tests.sh +++ b/tests/scripts/t3000/run_t3000_demo_tests.sh @@ -89,7 +89,7 @@ run_t3000_llama3_vision_tests() { t3k=T3K # Install Vision-specific packages - pip install -r models/demos/llama3/reference/llama_models/requirements.txt + pip install -r models/demos/llama3/requirements.txt for fake_device in "$n300" "$t3k"; do FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/multimodal_demo_chat.py -k "tt and 1" --timeout 600; fail+=$? diff --git a/tests/scripts/t3000/run_t3000_frequent_tests.sh b/tests/scripts/t3000/run_t3000_frequent_tests.sh index 7b95c2570d9..b5fb7360c89 100755 --- a/tests/scripts/t3000/run_t3000_frequent_tests.sh +++ b/tests/scripts/t3000/run_t3000_frequent_tests.sh @@ -87,7 +87,7 @@ run_t3000_llama3.2-11b-vision_freq_tests() { llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ # Install Vision-specific packages - pip install -r models/demos/llama3/reference/llama_models/requirements.txt + pip install -r models/demos/llama3/requirements.txt LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_transformer.py ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py ; fail+=$? @@ -117,7 +117,7 @@ run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests() { fake_device=N300 # Install Vision-specific packages - pip install -r models/demos/llama3/reference/llama_models/requirements.txt + pip install -r models/demos/llama3/requirements.txt FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_transformer.py ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py ; fail+=$? diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh index e15ebbb00df..e47a6b7a93c 100755 --- a/tests/scripts/t3000/run_t3000_unit_tests.sh +++ b/tests/scripts/t3000/run_t3000_unit_tests.sh @@ -164,7 +164,7 @@ run_t3000_llama3.2-11b-vision_unit_tests() { llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ # Install Vision-specific packages - pip install -r models/demos/llama3/reference/llama_models/requirements.txt + pip install -r models/demos/llama3/requirements.txt LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$? @@ -199,7 +199,7 @@ run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests() { fake_device=N300 # Install Vision-specific packages - pip install -r models/demos/llama3/reference/llama_models/requirements.txt + pip install -r models/demos/llama3/requirements.txt FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$? From 3807f76953fa3579506a2d94d7b608dae27ec19e Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Tue, 29 Oct 2024 08:17:02 -0700 Subject: [PATCH 16/16] #13368: Fix resource path in multimodal demos. --- models/demos/llama3/demo/multimodal_demo_chat.py | 6 ++++-- models/demos/llama3/demo/multimodal_demo_text.py | 12 +++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/models/demos/llama3/demo/multimodal_demo_chat.py b/models/demos/llama3/demo/multimodal_demo_chat.py index a7d8c3ffe4b..7b39fb3db61 100644 --- a/models/demos/llama3/demo/multimodal_demo_chat.py +++ b/models/demos/llama3/demo/multimodal_demo_chat.py @@ -13,7 +13,9 @@ from llama_models.llama3.api.datatypes import ImageMedia, UserMessage -THIS_DIR = Path(__file__).parent.parent.resolve() / "reference/llama_models/models/scripts/" +from pkg_resources import resource_filename + +IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) import torch import pytest @@ -70,7 +72,7 @@ def test_llama_multimodal_demo_chat( # image understanding dialogs = [] - with open(THIS_DIR / "resources/dog.jpg", "rb") as f: + with open(IMG_PATH / "dog.jpg", "rb") as f: img = PIL_Image.open(f).convert("RGB") dialogs = [ diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py index 4d3dad9f7f2..102b03975e4 100644 --- a/models/demos/llama3/demo/multimodal_demo_text.py +++ b/models/demos/llama3/demo/multimodal_demo_text.py @@ -12,7 +12,9 @@ from llama_models.llama3.api.datatypes import ImageMedia -THIS_DIR = Path(__file__).parent.parent.resolve() / "reference/llama_models/models/scripts/" +from pkg_resources import resource_filename + +IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) import torch import pytest @@ -85,16 +87,16 @@ def test_llama_multimodal_demo_text( model = create_multimodal_model(generator.args, mesh_device) generator.model = model - with open(THIS_DIR / "resources/dog.jpg", "rb") as f: + with open(IMG_PATH / "dog.jpg", "rb") as f: img = PIL_Image.open(f).convert("RGB") - with open(THIS_DIR / "resources/pasta.jpeg", "rb") as f: + with open(IMG_PATH / "pasta.jpeg", "rb") as f: img2 = PIL_Image.open(f).convert("RGB") - with open(THIS_DIR / "resources/ocr_image.jpeg", "rb") as f: + with open(IMG_PATH / "ocr_image.jpeg", "rb") as f: ocr_image = PIL_Image.open(f).convert("RGB") - with open(THIS_DIR / "resources/clutter.jpeg", "rb") as f: + with open(IMG_PATH / "clutter.jpeg", "rb") as f: clutter = PIL_Image.open(f).convert("RGB") interleaved_contents = [