diff --git a/.github/workflows/t3000-demo-tests-impl.yaml b/.github/workflows/t3000-demo-tests-impl.yaml index 1c353969673..a5fce0e2c1a 100644 --- a/.github/workflows/t3000-demo-tests-impl.yaml +++ b/.github/workflows/t3000-demo-tests-impl.yaml @@ -12,6 +12,7 @@ jobs: { name: "t3k_falcon40b_tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 50, owner_id: U053W15B6JF}, #Djordje Ivanovic { name: "t3k_llama3_70b_tests", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k_llama3_tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 30, owner_id: U03PUAKE719}, # Miguel Tairum + { name: "t3k_llama3_vision_tests", arch: wormhole_b0, cmd: run_t3000_llama3_vision_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k_falcon7b_tests", arch: wormhole_b0, cmd: run_t3000_falcon7b_tests, timeout: 90, owner_id: U05RWH3QUPM}, #Salar Hosseini { name: "t3k_mixtral_tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 50, owner_id: U03PUAKE719}, # Miguel Tairum ] diff --git a/.github/workflows/t3000-frequent-tests-impl.yaml b/.github/workflows/t3000-frequent-tests-impl.yaml index e83be639c72..a58a497e7e0 100644 --- a/.github/workflows/t3000-frequent-tests-impl.yaml +++ b/.github/workflows/t3000-frequent-tests-impl.yaml @@ -15,6 +15,8 @@ jobs: { name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 120, owner_id: U04S2UV6L8N}, #Sofija Jovic { name: "t3k llama2_70b tests", arch: wormhole_b0, cmd: run_t3000_llama2_70b_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k llama3 tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 60, owner_id: U03PUAKE719}, #Miguel Tairum Cruz + { name: "t3k llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich + { name: "t3k n300 mesh llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k mixtral tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 60, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k resnet tests", arch: wormhole_b0, cmd: run_t3000_resnet_tests, timeout: 30, owner_id: U013121KDH9}, #Austin Ho ] diff --git a/.github/workflows/t3000-unit-tests-impl.yaml b/.github/workflows/t3000-unit-tests-impl.yaml index 9856f92ec06..b2da2f63311 100644 --- a/.github/workflows/t3000-unit-tests-impl.yaml +++ b/.github/workflows/t3000-unit-tests-impl.yaml @@ -15,6 +15,8 @@ jobs: { name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 30, owner_id: U053W15B6JF}, #Djordje Ivanovic { name: "t3k llama3-small tests", arch: wormhole_b0, cmd: run_t3000_llama3-small_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k llama3.2-11b tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz + { name: "t3k llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich + { name: "t3k n300 mesh llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k mixtral tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum 
Cruz { name: "t3k grok tests", arch: wormhole_b0, cmd: run_t3000_grok_tests, timeout: 30, owner_id: U03HY7MK4BT}, #Mark O'Connor { name: "t3k unet shallow tests", arch: wormhole_b0, cmd: run_t3000_unet_shallow_tests, timeout: 30, owner_id: U06ECNVR0EN}, #Evan Smal diff --git a/models/demos/llama3/demo/multimodal_demo_chat.py b/models/demos/llama3/demo/multimodal_demo_chat.py index e69de29bb2d..7b39fb3db61 100644 --- a/models/demos/llama3/demo/multimodal_demo_chat.py +++ b/models/demos/llama3/demo/multimodal_demo_chat.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import Optional +from loguru import logger + +from PIL import Image as PIL_Image +from termcolor import cprint + +from models.demos.llama3.demo.multimodal_demo_text import create_multimodal_model +import llama_models.llama3.reference_impl.generation as llama_reference_generation + +from llama_models.llama3.api.datatypes import ImageMedia, UserMessage + +from pkg_resources import resource_filename + +IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) + +import torch +import pytest +import os +import ttnn + + +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "target", + ("tt", "cpu"), +) +@pytest.mark.parametrize( + "warmup_iters", + (0, 1), +) +def test_llama_multimodal_demo_chat( + mesh_device, + target, + warmup_iters, + temperature: float = 0.5, + top_p: float = 0.9, + max_seq_len: int = 512, + max_batch_size: int = 4, + max_gen_len: Optional[int] = 200, + model_parallel_size: Optional[int] = None, +): + mesh_device.enable_program_cache() + mesh_device.enable_async(True) + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'") + generator = llama_reference_generation.Llama.build( + ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + model_parallel_size=model_parallel_size, + ) + + if target == "tt": + logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices") + model = create_multimodal_model(generator.args, mesh_device) + generator.model = model + + # image understanding + dialogs = [] + with open(IMG_PATH / "dog.jpg", "rb") as f: + img = PIL_Image.open(f).convert("RGB") + + dialogs = [ + [ + UserMessage( + content=[ + ImageMedia(image=img), + "Describe this image in two sentences", + ], + ) + ], + ] + # text only + dialogs += [ + [UserMessage(content="what is the recipe of mayonnaise in two sentences?")], + ] + + print(f"Running chat completion on {target}") + for _ in range(warmup_iters + 1): + for dialog in dialogs: + result = generator.chat_completion( + dialog, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + + for msg in dialog: + print(f"{msg.role.capitalize()}: {msg.content}\n") + + out_message = result.generation + print(f"> {out_message.role.capitalize()}: {out_message.content}") + for t in out_message.tool_calls: + print(f" Tool call: {t.tool_name} ({t.arguments})") + print("\n==================================\n") diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py index df2e9a730d3..102b03975e4 100644 --- 
a/models/demos/llama3/demo/multimodal_demo_text.py +++ b/models/demos/llama3/demo/multimodal_demo_text.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - # SPDX-License-Identifier: Apache-2.0 + from pathlib import Path from typing import Optional from loguru import logger @@ -8,19 +8,13 @@ from PIL import Image as PIL_Image from termcolor import cprint -import importlib +import llama_models.llama3.reference_impl.generation as llama_reference_generation -llama_reference_generation = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.generation" -) +from llama_models.llama3.api.datatypes import ImageMedia -# Must import from reference for formatter to understand type of ImageMedia -datatypes = importlib.import_module("models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.api.datatypes") -ImageMedia = datatypes.ImageMedia +from pkg_resources import resource_filename -# THIS_DIR = Path(__file__).parent.resolve() -# TODO: Generalize not to cglagovich home :) -THIS_DIR = Path("/home/cglagovich/tt-metal/models/demos/t3000/llama2_70b/reference/llama-models/models/scripts/") +IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) import torch import pytest @@ -59,14 +53,19 @@ def create_multimodal_model(model_args, mesh_device, dtype=ttnn.bfloat16): "target", ("tt", "cpu"), ) +@pytest.mark.parametrize( + "warmup_iters", + (0, 1), +) def test_llama_multimodal_demo_text( mesh_device, target, - temperature: float = 0, + warmup_iters, + temperature: float = 0.5, top_p: float = 0.9, max_seq_len: int = 512, max_batch_size: int = 4, - max_gen_len: Optional[int] = None, + max_gen_len: Optional[int] = 200, model_parallel_size: Optional[int] = None, ): mesh_device.enable_program_cache() @@ -88,41 +87,38 @@ def test_llama_multimodal_demo_text( model = create_multimodal_model(generator.args, mesh_device) generator.model = model - with open(THIS_DIR / "resources/dog.jpg", "rb") as f: + with open(IMG_PATH / "dog.jpg", "rb") as f: img = PIL_Image.open(f).convert("RGB") - with open(THIS_DIR / "resources/pasta.jpeg", "rb") as f: + with open(IMG_PATH / "pasta.jpeg", "rb") as f: img2 = PIL_Image.open(f).convert("RGB") - with open(THIS_DIR / "resources/ocr_image.jpeg", "rb") as f: + with open(IMG_PATH / "ocr_image.jpeg", "rb") as f: ocr_image = PIL_Image.open(f).convert("RGB") - # with open(THIS_DIR / "resources/clutter.jpeg", "rb") as f: - # clutter = PIL_Image.open(f).convert("RGB") + + with open(IMG_PATH / "clutter.jpeg", "rb") as f: + clutter = PIL_Image.open(f).convert("RGB") interleaved_contents = [ # text only - # "The color of the sky is blue but sometimes it can also be", + "The color of the sky is blue but sometimes it can also be", # image understanding - # [ - # ImageMedia(image=img), - # "If I had to write a haiku for this one", - # ], + [ImageMedia(image=img), "If I had to write a haiku for this one"], + [ImageMedia(image=img2), "Counting the number of individual spaghetti strands in this image"], [ImageMedia(image=ocr_image), "The full text in this image is as follows"], - # [ - # ImageMedia(image=clutter), - # "The count of vases, books, and miscellaneous items in this image is", - # ] + [ImageMedia(image=clutter), "The count of vases, books, and miscellaneous items in this image is"], ] print(f"Running text completion on {target}") - for content in interleaved_contents: - result = generator.text_completion( - content, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ) - - 
cprint(f"{content}", end="") - cprint(f"{result.generation}", color="yellow") - print("\n==================================\n") + for _ in range(warmup_iters + 1): + for content in interleaved_contents: + result = generator.text_completion( + content, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + + cprint(f"{content}", end="") + cprint(f"{result.generation}", color="yellow") + print("\n==================================\n") diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index 388339ea584..280751cf7b9 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -184,7 +184,14 @@ def main(stdscr): commands = parse_list(command_input, allow_space=False) # Generate combinations (reordered) - combinations = [(c, m, d) for c in commands for m in models for d in devices] + # combinations = [(c, m, d) for c in commands for m in models for d in devices] + combinations = [ + (c, m, d) + for c in commands + for m in models + for d in devices + if not (m in ["11b", "11b-b"] and d == "n150") + ] # Create output entries for command, model, device in combinations: @@ -230,7 +237,7 @@ def main(stdscr): else: # Ignore enter key when exiting continue - elif c == curses.KEY_BACKSPACE or c == 127 or c == ord("x"): + elif c == curses.KEY_BACKSPACE or c == 127: if current_line < len(input_fields): current_field = current_line # Remove last character from current field @@ -506,6 +513,19 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): "model": "pytest models/demos/llama3/tests/test_llama_model.py::test_llama_model_inference[wormhole_b0-True-mesh_device0-full]", "model-prefill": "pytest models/demos/llama3/tests/test_llama_model_prefill.py::test_llama_model_inference[wormhole_b0-True-mesh_device0-4096]", "model-quick": "pytest models/demos/llama3/tests/test_llama_model.py -k quick", + "vision-mlp": "pytest models/demos/llama3/tests/multimodal/test_llama_image_mlp.py", + "vision-attn": "pytest models/demos/llama3/tests/multimodal/test_llama_image_attention.py", + "vision-block": "pytest models/demos/llama3/tests/multimodal/test_llama_image_block.py", + "vision-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_image_transformer.py", + "vision-xattn": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention.py", + "vision-xblock": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_block.py", + "vision-conv": "pytest models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py", + "vision-class": "pytest models/demos/llama3/tests/multimodal/test_llama_class_embedding.py", + "vision-tile-pos": "pytest models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py", + "vision-pos": "pytest models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py", + "vision-encoder": "pytest models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py", + "vision-text-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py", + "vision-vision-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py", } # Check if the command is a shortcut and replace it if necessary @@ -657,6 +677,7 @@ def get_llama_dir(model): "3b": os.environ.get("LLAMA_32_3B_DIR", "/proj_sw/user_dev/llama32-data/Llama3.2-3B-Instruct"), "8b": os.environ.get("LLAMA_31_8B_DIR", "/proj_sw/user_dev/llama31-8b-data/Meta-Llama-3.1-8B-Instruct"), "11b": os.environ.get("LLAMA_32_11B_DIR", 
"/proj_sw/user_dev/llama32-data/Llama3.2-11B-Vision-Instruct"), + "11b-b": os.environ.get("LLAMA_32_11B_BASE_DIR", "/proj_sw/user_dev/llama32-data/Llama3.2-11B-Vision"), }.get(model.lower(), "") if not llama_dir or not os.path.exists(llama_dir): diff --git a/models/demos/llama3/requirements.txt b/models/demos/llama3/requirements.txt new file mode 100644 index 00000000000..e830cffd233 --- /dev/null +++ b/models/demos/llama3/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/tenstorrent/llama-models.git@tt_metal_tag diff --git a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py index 533bc9c2106..663787a18d1 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py @@ -3,11 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 ##### Python imports ##### -import math import pytest from loguru import logger import os -import itertools ##### PyTorch imports ##### import torch @@ -56,18 +54,16 @@ def forward(self, x): @pytest.mark.parametrize( "mesh_device", [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get( + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) ) ], indirect=True, ) @pytest.mark.parametrize( - "input_shape", + "bsz, num_concurrent_media, num_chunks", [ - ((1, 4, 4, 1024, 1280)), - ((1, 4, 4, 1024 + 1, 1280)), - ((1, 4, 4, 1032, 1280)), + ((1, 4, 4)), ], ) @pytest.mark.parametrize( @@ -81,12 +77,14 @@ def test_llama_class_embedding_inference( use_program_cache, reset_seeds, # Input params - input_shape, + bsz, + num_concurrent_media, + num_chunks, layout, ensure_gc, ): dtype = ttnn.bfloat16 - pcc = 0.9999 + pcc_required = 0.9999 mesh_device.enable_async(True) @@ -97,13 +95,8 @@ def test_llama_class_embedding_inference( k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - ( - bsz, - num_concurrent_media, - num_chunks, - ntok, - dim, - ) = input_shape + ntok = nearest_32(model_args.vision_chunk_ntok) + dim = model_args.vision_dim ##### Prepare inputs ##### input_tensor = torch.randn(bsz * num_concurrent_media * num_chunks, ntok, dim) @@ -145,12 +138,8 @@ def test_llama_class_embedding_inference( # Only select output from one device tt_output_torch = tt_output_torch[..., :dim].view(reference_output.shape) - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info(f"Llama_ClassEmbedding Passed!") - else: - logger.warning(f"Llama_ClassEmbedding Failed!") - assert passing, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py index 5458a1ca8c2..c38dd5ccb26 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py +++ b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 ##### Python imports ##### -import math import pytest from loguru import logger import os @@ -21,52 +20,12 @@ comp_pcc, comp_allclose, ) -from models.utility_functions import ( - nearest_32, -) from models.demos.llama3.tt.multimodal.llama_conv2d_patch import ( TtLlamaConv2dPatch, ) from models.demos.llama3.tt.model_config import TtModelArgs - -import importlib - -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) - - -# ##### Torch op ##### -# class Conv2dPatch(torch.nn.Module): -# """Conv2D Patching layer with model parallelism. -# Column parallel over unfolded input. -# Arguments: -# in_channels: Input channels. -# out_channels: Output channels. -# kernel_size: Size of convolution kernel. -# stride (default 1): Stride for convolution. -# bias (default False): Use bias in Conv2d. -# Input: (bsz, in_channels, width, height) -# Output: (bsz, num_tokens, out_channels) -# """ - -# def __init__(self, in_channels, out_channels, kernel_size, stride, bias) -> None: -# super().__init__() -# if isinstance(kernel_size, int): -# kernel_size = (kernel_size, kernel_size) -# self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) -# self._linear = torch.nn.Linear( -# in_channels * kernel_size[0] * kernel_size[1], -# out_channels, -# bias=bias, -# ) - -# def forward(self, x: torch.Tensor) -> torch.Tensor: -# x = self._unfold(x) -# x = x.permute(0, 2, 1) -# x = F.linear(x, self._linear.weight) -# return x +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod @skip_for_grayskull("Requires wormhole_b0 to run") @@ -79,28 +38,13 @@ ], indirect=True, ) -@pytest.mark.parametrize( - "input_shape, in_channels, out_channels, kernel_size, stride, bias", - [ - # ((1, 3, 32 * 32, 32 * 32), 3, 512, 32, 32, False), - ((1, 3, 14 * 32, 14 * 32), 3, 1280, 14, 14, False), # Llama3.2 case - ], -) def test_llama_conv2d_inference( mesh_device, use_program_cache, reset_seeds, - # Input params - input_shape, - # Conv2d patch params - in_channels, - out_channels, - kernel_size, - stride, - bias, ensure_gc, ): - pcc = 0.9999 + pcc_required = 0.9999 dtype = ttnn.bfloat16 mesh_device.enable_async(True) @@ -116,7 +60,14 @@ def test_llama_conv2d_inference( num_devices = model_args.num_devices ##### Create input tensor for the all gather ##### - B, NCH, H, W = input_shape + B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size) + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + model_args.vision_dim, + model_args.vision_patch_size, + model_args.vision_patch_size, + False, + ) assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch." assert type(kernel_size) == int, "Only symmetric kernel_size is currently supported." @@ -126,7 +77,7 @@ def test_llama_conv2d_inference( assert W % kernel_size == 0, "Width should be divisible by kernel_size." 
##### Prepare inputs ##### - input_tensor = torch.randn(input_shape) + input_tensor = torch.randn((B, NCH, H, W)) logger.info(f"Input tensor shape: {input_tensor.shape}") ##### Perform the torch ops ##### @@ -165,8 +116,4 @@ def test_llama_conv2d_inference( logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info(f"Llama_Conv2dPatch Passed!") - else: - logger.warning(f"Llama_Conv2dPatch Failed!") - assert passing, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index 46f7aba4b78..14c7db894eb 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -6,11 +6,8 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_cross_attention import TtLlamaCrossAttention from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( @@ -20,15 +17,12 @@ from models.utility_functions import ( comp_pcc, comp_allclose, + nearest_32, ) from models.utility_functions import skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "vision_seq_len", - (4224,), -) @pytest.mark.parametrize( "text_seq_len", (2048,), @@ -42,11 +36,9 @@ ], indirect=True, ) -def test_llama_cross_attention_inference( - vision_seq_len, text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc -): +def test_llama_cross_attention_inference(text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 - pcc = 0.99 + pcc_required = 0.99 mesh_device.enable_async(True) @@ -70,6 +62,8 @@ def test_llama_cross_attention_inference( reference_model.load_state_dict(partial_state_dict) batch = 1 + num_chunks = 4 + vision_seq_len = num_chunks * nearest_32(model_args.vision_chunk_ntok) all_tests_pass = True @@ -99,20 +93,21 @@ def test_llama_cross_attention_inference( """ pt_xattn_cache = reference_model.compute_xattn_kv_cache(pt_xattn_tokens) pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0) - pt_xattn_cache_chunks = [ - x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache - ] + pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache] tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_xattn_tokens) tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( - batch, n_kv_heads, vision_seq_len, head_dim + batch, + n_heads, + vision_seq_len, + head_dim, ) for x in tt_xattn_cache ] for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch): - passing, pcc_message = comp_pcc(pt, tt, pcc) + passing, pcc_message = comp_pcc(pt, tt, pcc_required) logger.info(comp_allclose(pt, tt)) logger.info(f"PCC: {pcc_message}") @@ -122,6 +117,8 @@ def test_llama_cross_attention_inference( logger.warning(f"compute_xattn_kv_cache Failed!") all_tests_pass = False + 
assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" + """ Test forward, prefill and decode! """ @@ -218,13 +215,9 @@ def test_llama_cross_attention_inference( tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(batch, seq_len, dim) else: tt_output_torch = tt_output_torch[0, ..., :batch, :].transpose(0, 1).view(batch, seq_len, dim) - passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required) logger.info(comp_allclose(pt_out, tt_output_torch)) logger.info(f"PCC: {pcc_message}") all_tests_pass = all_tests_pass and passing - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 1ce9d2e5699..286e7bac509 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -6,11 +6,8 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_text import ( TtLlamaCrossAttentionTransformerText, ) @@ -26,15 +23,12 @@ from models.utility_functions import ( comp_pcc, comp_allclose, + nearest_32, ) from models.utility_functions import skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "vision_seq_len", - (4224,), -) @pytest.mark.parametrize( "text_seq_len", (2048,), @@ -49,15 +43,14 @@ indirect=True, ) def test_llama_cross_attention_transformer_text_inference( - vision_seq_len, text_seq_len, mesh_device, use_program_cache, reset_seeds, ): dtype = ttnn.bfloat8_b - prefill_pcc = 0.98 - decode_pcc = 0.73 + prefill_pcc_required = 0.98 + decode_pcc_required = 0.73 mesh_device.enable_async(True) @@ -93,6 +86,9 @@ def test_llama_cross_attention_transformer_text_inference( reference_model.load_state_dict(partial_state_dict) batch = 1 + num_chunks = 4 + chunk_length = nearest_32(model_args.vision_chunk_ntok) + vision_seq_len = num_chunks * chunk_length all_tests_pass = True @@ -117,24 +113,25 @@ def test_llama_cross_attention_transformer_text_inference( pt_xattn_cache_chunks = [torch.chunk(x, 2, dim=1) for x in pt_xattn_cache_chunks] pt_xattn_cache_chunks = [x for xx in pt_xattn_cache_chunks for x in xx] # slice out replicated k/v heads - pt_xattn_cache_chunks = [ - x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache_chunks - ] + pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache_chunks] tt_xattn_cache = [layer.compute_xattn_kv_cache(tt_vision_tokens) for layer in tt_model.cross_attention_layers] tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( - batch, n_kv_heads, vision_seq_len, 
head_dim + batch, + n_heads, + vision_seq_len, + head_dim, ) for kv_cache in tt_xattn_cache for x in kv_cache ] for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch): - passing, pcc_message = comp_pcc(pt, tt, prefill_pcc) + passing, pcc_message = comp_pcc(pt, tt, prefill_pcc_required) logger.info(comp_allclose(pt, tt)) - logger.info(pcc_message) + logger.info(f"PCC: {pcc_message}") if passing: logger.info(f"compute_xattn_kv_cache Passed!") @@ -312,12 +309,12 @@ def test_llama_cross_attention_transformer_text_inference( tt_out = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) if mode == "prefill": tt_out = tt_out[0].reshape(logits.shape) - pcc = prefill_pcc + pcc_required = prefill_pcc_required else: tt_out = tt_out[0, ..., :batch, :].transpose(0, 1).view(logits.shape) - pcc = decode_pcc - passing, pcc_message = comp_pcc(logits, tt_out, pcc) + pcc_required = decode_pcc_required + passing, pcc_message = comp_pcc(logits, tt_out, pcc_required) logger.info(comp_allclose(logits, tt_out)) - logger.info(pcc_message) + logger.info(f"PCC: {pcc_message}") prev_pos = cur_pos - assert passing, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py similarity index 70% rename from models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py rename to models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py index 3a719955c20..a3f360bfa23 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer_vision.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py @@ -6,12 +6,11 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod +from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_vision import ( + TtLlamaCrossAttentionTransformerVision, ) -from models.demos.llama3.tt.multimodal.llama_image_transformer_vision import TtLlamaCrossAttentionTransformerVision from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, @@ -26,12 +25,16 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_vision_transformer_inference(mesh_device, use_program_cache, reset_seeds): dtype = ttnn.bfloat16 - pcc = 0.79 + pcc_required = 0.79 model_args = TtModelArgs(mesh_device) state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) @@ -48,8 +51,6 @@ def test_llama_vision_transformer_inference(mesh_device, use_program_cache, rese reference_model = llama_reference_mod.CrossAttentionTransformerVision(model_args) reference_model.load_state_dict(partial_state_dict, strict=True) - all_tests_pass = True - 
tt_model = TtLlamaCrossAttentionTransformerVision( mesh_device, state_dict, @@ -61,7 +62,8 @@ def test_llama_vision_transformer_inference(mesh_device, use_program_cache, rese ) # Create rand inputs of the right shape - batch, num_media, num_chunks, n_channel, patch_size = (1, 1, 4, 3, 448) + batch, num_media, num_chunks, n_channel, patch_size = (1, 1, 4, 3, model_args.vision_chunk_size) + chunk_seq_len = model_args.vision_chunk_ntok # tokens per chunk, including class token images = torch.randn(batch, num_media, num_chunks, n_channel, patch_size, patch_size) ars = torch.tensor([2, 2]).reshape(batch, num_media, 2) @@ -69,24 +71,13 @@ def test_llama_vision_transformer_inference(mesh_device, use_program_cache, rese reference_output = reference_model(images, ars) tt_out = tt_model(images, ars) tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) - tt_output_torch = tt_output_torch[0, :, :, :].view(reference_output.shape) + tt_output_torch = tt_output_torch[0, :, :chunk_seq_len, :].view(reference_output.shape) logger.info(f"Reference output shape: {reference_output.shape}") logger.info(f"TT output shape: {tt_output_torch.shape}") - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(pcc_message) - - if passing: - logger.info(f"Llama_Attention Passed!") - else: - logger.warning(f"Llama_Attention Failed!") - all_tests_pass = False - - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 36043e15437..f64f9c98f7f 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -6,29 +6,19 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_cross_block import TtLlamaCrossAttentionTransformerBlock from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, prepare_inputs_ttnn, ) -from models.utility_functions import ( - comp_pcc, - comp_allclose, -) +from models.utility_functions import comp_pcc, comp_allclose, nearest_32 from models.utility_functions import skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "vision_seq_len", - (4224,), -) @pytest.mark.parametrize( "text_seq_len", (2048,), @@ -43,10 +33,10 @@ indirect=True, ) def test_llama_cross_attention_transformer_block_inference( - vision_seq_len, text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc + text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc ): dtype = ttnn.bfloat16 - pcc = 0.99 + pcc_required = 0.99 mesh_device.enable_async(True) @@ -68,6 +58,8 @@ def test_llama_cross_attention_transformer_block_inference( reference_model.load_state_dict(partial_state_dict) batch = 1 + num_chunks = 4 + vision_seq_len = num_chunks * nearest_32(model_args.vision_chunk_ntok) all_tests_pass = True @@ -93,20 +85,21 @@ def test_llama_cross_attention_transformer_block_inference( """ pt_xattn_cache = reference_model.compute_xattn_kv_cache(pt_xattn_tokens) pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0) - pt_xattn_cache_chunks = [ - x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache - ] + pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache] tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_xattn_tokens) tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( - batch, n_kv_heads, vision_seq_len, head_dim + batch, + n_heads, + vision_seq_len, + head_dim, ) for x in tt_xattn_cache ] for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch): - passing, pcc_message = comp_pcc(pt, tt, pcc) + passing, pcc_message = comp_pcc(pt, tt, pcc_required) logger.info(comp_allclose(pt, tt)) logger.info(f"PCC: {pcc_message}") @@ -116,6 +109,8 @@ def test_llama_cross_attention_transformer_block_inference( logger.warning(f"compute_xattn_kv_cache Failed!") all_tests_pass = False + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" + """ Test forward, prefill and decode! 
""" @@ -232,13 +227,9 @@ def test_llama_cross_attention_transformer_block_inference( tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(batch, seq_len, dim) else: tt_output_torch = tt_output_torch[0, ..., :batch, :].transpose(0, 1).view(batch, seq_len, dim) - passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required) logger.info(comp_allclose(pt_out, tt_output_torch)) logger.info(f"PCC: {pcc_message}") all_tests_pass = all_tests_pass and passing - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py index dce2fedf4bc..49f4ee58d2f 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py @@ -6,16 +6,11 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) -encoder_utils = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.encoder_utils" -) +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod +from llama_models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_attention import TtLlamaImageAttention -from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import pad_seq_one_tile, mask_tile_padding +from models.demos.llama3.tt.multimodal.llama_vision_encoder import pad_seq_one_tile, mask_tile_padding from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, @@ -29,8 +24,8 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( - "batch, num_chunks, ntok", - ((1, 4, 1024),), + "batch, num_chunks", + ((1, 4),), ) @pytest.mark.parametrize( "mesh_device", @@ -41,9 +36,9 @@ ], indirect=True, ) -def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_program_cache, reset_seeds, ensure_gc): +def test_llama_attention_inference(batch, num_chunks, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 - pcc = 0.99 + pcc_required = 0.99 mesh_device.enable_async(True) @@ -56,13 +51,12 @@ def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_pro k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - dim = 1280 - heads = 16 + dim = model_args.vision_dim + heads = model_args.vision_attn_n_heads + ntok = model_args.vision_chunk_ntok reference_model = llama_reference_mod.ImageAttention(dim=dim, head_dim=dim // heads, n_heads=heads) reference_model.load_state_dict(partial_state_dict) - all_tests_pass = True - tt_model = TtLlamaImageAttention( mesh_device, state_dict, @@ -94,8 +88,8 @@ def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_pro ) tt_attn_mask = encoder_utils.build_encoder_attention_mask(fake_x, ar, ntok, num_chunks, 1) # Make 
striped attention mask to mask out our padding between 8 and 32 - # Striped mask doesn't affect PCC - # tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, 32, num_chunks) + # Striped mask doesn't affect PCC on first layer but is necessary for later layers + tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, npadtt, num_chunks) attention_input = attention_input.reshape(1, batch, -1, dim) @@ -111,7 +105,7 @@ def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_pro tt_out = tt_model(attention_input, mask=tt_mask) # Doing contract in tt is correct!! - tt_out = tt_out.reshape(batch, num_chunks, ntok + 32, dim) + tt_out = tt_out.reshape(batch, num_chunks, ntok + npadtt, dim) tt_out = ttnn.slice(tt_out, (0, 0, 0, 0), (batch, num_chunks, ntok, dim)) tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :] @@ -119,18 +113,9 @@ def test_llama_attention_inference(batch, num_chunks, ntok, mesh_device, use_pro reference_output = reference_output.reshape(batch, num_chunks, ntok + npad, dim) reference_output = encoder_utils.contract_num_tokens_from_mult8(reference_output, npad) - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info(f"Llama_Attention Passed!") - else: - logger.warning(f"Llama_Attention Failed!") - all_tests_pass = False - - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py index 38c4356b406..8eecfe156d6 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py @@ -6,16 +6,11 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) -encoder_utils = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.encoder_utils" -) +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod +from llama_models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_block import TtLlamaImageTransformerBlock -from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import pad_seq_one_tile, mask_tile_padding +from models.demos.llama3.tt.multimodal.llama_vision_encoder import pad_seq_one_tile, mask_tile_padding from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, @@ -29,8 +24,8 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( - "batch, num_chunks, ntok", - ((1, 4, 1024),), + "batch, num_chunks", + ((1, 4),), ) @pytest.mark.parametrize( "gated", @@ -38,12 +33,16 @@ ) @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) -def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_program_cache, reset_seeds, ensure_gc): +def test_llama_block_inference(batch, num_chunks, mesh_device, gated, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 - pcc = 0.99 + pcc_required = 0.99 mesh_device.enable_async(True) @@ -61,13 +60,12 @@ def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_ dim = model_args.vision_dim heads = model_args.vision_attn_n_heads + ntok = model_args.vision_chunk_ntok reference_model = llama_reference_mod.ImageTransformerBlock( d_model=dim, n_head=heads, mlp_ratio=model_args.vision_mlp_ratio, gated=gated ) reference_model.load_state_dict(partial_state_dict) - all_tests_pass = True - tt_model = TtLlamaImageTransformerBlock( mesh_device, state_dict, @@ -99,6 +97,7 @@ def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_ attention_input.shape[0], attention_input.shape[1], attention_input.shape[2], attention_input.shape[3] ) tt_attn_mask = encoder_utils.build_encoder_attention_mask(fake_x, ar, ntok, num_chunks, 1) + tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, npadtt, num_chunks) attention_input = attention_input.reshape(1, batch, -1, dim) tt_mask = ttnn.from_torch( @@ -111,7 +110,7 @@ def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_ ) tt_out = tt_model(attention_input, mask=tt_mask) - tt_out = tt_out.reshape(batch, num_chunks, ntok + 32, dim) + tt_out = tt_out.reshape(batch, num_chunks, ntok + npadtt, dim) tt_out = ttnn.slice(tt_out, (0, 0, 0, 0), (batch, num_chunks, ntok, dim)) tt_output_torch = ttnn.to_torch(tt_out, 
mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :] @@ -119,18 +118,8 @@ def test_llama_block_inference(batch, num_chunks, ntok, mesh_device, gated, use_ reference_output = reference_output.reshape(batch, num_chunks, ntok + npad, dim) reference_output = encoder_utils.contract_num_tokens_from_mult8(reference_output, npad) - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info(f"Llama_Attention Passed!") - else: - logger.warning(f"Llama_Attention Failed!") - all_tests_pass = False - - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py index 900be3f49fe..c6b65ef7f9d 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py @@ -7,24 +7,22 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_image_mlp import TtLlamaImageFeedForward from models.demos.llama3.tt.model_config import TtModelArgs from models.utility_functions import ( comp_pcc, comp_allclose, + nearest_32, ) from models.utility_functions import skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( - "seq_len", - (4224,), + "batch, num_chunks", + ((1, 4),), ) @pytest.mark.parametrize( "mesh_device", @@ -35,7 +33,7 @@ ], indirect=True, ) -def test_llama_mlp_inference(mesh_device, seq_len, use_program_cache, reset_seeds, ensure_gc): +def test_llama_mlp_inference(batch, num_chunks, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 mesh_device.enable_async(True) @@ -51,50 +49,9 @@ def test_llama_mlp_inference(mesh_device, seq_len, use_program_cache, reset_seed model_args.WEIGHTS_DTYPE = dtype - """ - self.patch_size = 14 - self.vision_encoder = VisionEncoder( - max_num_tiles=4, - image_size=args.vision_chunk_size, - patch_size=self.patch_size, - n_global_layers=8, - global_model=True, - return_intermediate=return_intermediate, - ) - - class VisionEncoder(nn.Module): - def __init__( - self, - max_num_tiles: int, - ckpt_path: str = None, - image_size: int = 224, - patch_size: int = 14, - width: int = 1280, - layers: int = 32, - heads: int = 16, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - in_channels: int = 3, - load_ckpt: bool = False, - n_global_layers: int = 2, - global_model: bool = False, - return_intermediate=None, - ... 
- - self.global_transformer = ImageTransformer( - width, n_global_layers, heads, mlp_ratio, act_layer=act_layer, gated=True - - self.mlp = ImageFeedForward( - dim=d_model, - hidden_dim=int(mlp_ratio * d_model), - dropout=0.0, - act_layer=act_layer, - ) - ) - """ - - dim = 1280 - mlp_ratio = 4.0 + dim = model_args.vision_dim + seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks + mlp_ratio = model_args.vision_mlp_ratio act_layer = torch.nn.GELU dropout = 0.0 reference_model = llama_reference_mod.ImageFeedForward( @@ -113,7 +70,7 @@ def __init__( weight_cache_path=model_args.weight_cache_path(dtype), dtype=dtype, ) - torch_input = torch.randn(1, 1, seq_len, dim) + torch_input = torch.randn(1, batch, seq_len, dim) reference_output = reference_model(torch_input).squeeze() tt_input = ttnn.from_torch( torch_input, @@ -136,9 +93,4 @@ def __init__( logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info("Llama_MLP Passed!") - else: - logger.warning("Llama_MLP Failed!") - - assert passing, f"Llama_MLP output does not meet PCC requirement {pcc_required}: {pcc_message}." + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py index c425caec570..b92d74290d6 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py @@ -6,17 +6,12 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) -encoder_utils = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.encoder_utils" -) +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod +from llama_models.llama3.reference_impl.multimodal import encoder_utils from models.demos.llama3.tt.multimodal.llama_image_transformer import TtLlamaImageTransformer from models.demos.llama3.tt.model_config import TtModelArgs -from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import pad_seq_one_tile, mask_tile_padding +from models.demos.llama3.tt.multimodal.llama_vision_encoder import pad_seq_one_tile, mask_tile_padding from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, ) @@ -29,8 +24,8 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( - "batch, num_chunks, ntok", - ((1, 4, 1024),), + "batch, num_chunks", + ((1, 4),), ) @pytest.mark.parametrize( "is_global", @@ -38,14 +33,18 @@ ) @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_image_transformer_inference( - batch, num_chunks, ntok, mesh_device, is_global, use_program_cache, reset_seeds, ensure_gc + batch, num_chunks, mesh_device, is_global, use_program_cache, reset_seeds, ensure_gc ): dtype = ttnn.bfloat16 - pcc = 0.86 + pcc_required = 0.86 mesh_device.enable_async(True) @@ -59,11 +58,11 @@ def test_llama_image_transformer_inference( n_layers = 
model_args.vision_n_global_layers return_intermediate = None else: - # first_layer_prefix = "vision_model.vision_encoder.transformer." first_layer_prefix = "vision_model.vision_encoder." gated = False n_layers = model_args.vision_n_layers # return_intermediate = [int(l) for l in "3,7,15,23,30".split(",")] + # Checks all intermediates return_intermediate = list(range(n_layers)) partial_state_dict = { @@ -71,7 +70,7 @@ def test_llama_image_transformer_inference( } dim = model_args.vision_dim - heads = model_args.vision_attn_n_heads + ntok = model_args.vision_chunk_ntok - 1 # NOTE: -1 to remove class embedding reference_model = llama_reference_mod.VisionEncoder( max_num_tiles=4, @@ -129,7 +128,6 @@ def test_llama_image_transformer_inference( ) tt_attn_mask = encoder_utils.build_encoder_attention_mask(fake_x, ar, ntok, num_chunks, 1) # Make striped attention mask to mask out our padding between 8 and 32 - # Striped mask doesn't affect PCC tt_attn_mask = mask_tile_padding(tt_attn_mask, ntok, npadtt, num_chunks) attention_input = attention_input.reshape(1, batch, -1, dim) @@ -167,25 +165,16 @@ def test_llama_image_transformer_inference( intermediates = [i.squeeze(-1) for i in intermediates] reference_output = reference_output.reshape(batch, num_chunks, ntok + npad, dim) reference_output = encoder_utils.contract_num_tokens_from_mult8(reference_output, npad) - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) - # Check mse + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + all_tests_pass = all_tests_pass and passing logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") if return_intermediate: for idx, (pt_interm, tt_interm) in enumerate(zip(intermediates, tt_intermed_torch)): - passing, pcc_message = comp_pcc(pt_interm, tt_interm, pcc) + passing, pcc_message = comp_pcc(pt_interm, tt_interm, pcc_required) logger.info(f"Intermediate {idx}: {pcc_message}") logger.info(comp_allclose(pt_interm, tt_interm)) + all_tests_pass = all_tests_pass and passing - if passing: - logger.info(f"Llama_Attention Passed!") - else: - logger.warning(f"Llama_Attention Failed!") - all_tests_pass = False - - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py index 9e36eb42247..d52d9f415f3 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py +++ b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py @@ -7,25 +7,19 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod from models.demos.llama3.tt.multimodal.llama_layernorm import TtLayerNorm # Updated import for LayerNorm from models.demos.llama3.tt.model_config import TtModelArgs from models.utility_functions import ( comp_pcc, comp_allclose, + nearest_32, ) from models.utility_functions import skip_for_grayskull @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "seq_len", - (4224,), -) @pytest.mark.parametrize( "mesh_device", [ @@ -35,13 +29,15 @@ ], indirect=True, ) -def test_layernorm_inference(mesh_device, seq_len, use_program_cache, reset_seeds, ensure_gc): +def test_layernorm_inference(mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 mesh_device.enable_async(True) model_args = TtModelArgs(mesh_device) width = model_args.vision_dim + num_chunks = 4 + seq_len = nearest_32(model_args.vision_chunk_ntok) * num_chunks state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) # Ref model needs partial state dict, but our models use full state dict keys as cached weight names @@ -99,9 +95,4 @@ def test_layernorm_inference(mesh_device, seq_len, use_program_cache, reset_seed logger.info(comp_allclose(reference_output, tt_output)) logger.info(f"PCC: {pcc_message}") - if passing: - logger.info("LayerNorm on device {idx} Passed!") - else: - logger.warning("LayerNorm {idx} Failed!") - - assert passing, f"LayerNorm output does not meet PCC requirement {pcc_required}: {pcc_message}." + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py index 7430796270f..c5262bf2235 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py @@ -31,8 +31,6 @@ ) from models.demos.llama3.tt.model_config import TtModelArgs -import importlib - ##### Torch op ##### class PositionalEmbedding(nn.Module): @@ -78,45 +76,29 @@ def forward(self, x, ar): @pytest.mark.parametrize( "mesh_device", [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get( + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) ) ], indirect=True, ) @pytest.mark.parametrize( - "image_size, patch_size", - [ - ((448, 448), (14, 14)), - ], -) -@pytest.mark.parametrize( - "input_shape, max_num_tiles", - [ - ((1, 4, 4, 1024 + 1, 1280), 4), - ], -) -@pytest.mark.parametrize( - "layout", - [ - ttnn.TILE_LAYOUT, - ], + "bsz, num_concurrent_media, num_chunks", + [(1, 4, 4)], ) def test_llama_positional_embedding_inference( mesh_device, use_program_cache, reset_seeds, # Input params - input_shape, - layout, - # Positional Embedding params - image_size, - patch_size, - max_num_tiles, + bsz, + num_concurrent_media, + num_chunks, ensure_gc, ): dtype = ttnn.bfloat16 - pcc = 0.9999 + layout = ttnn.TILE_LAYOUT + pcc_required = 0.9999 mesh_device.enable_async(True) @@ -127,15 +109,13 @@ def test_llama_positional_embedding_inference( k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - ( - bsz, - num_concurrent_media, - num_chunks, - ntok, - dim, - ) = input_shape + ntok = model_args.vision_chunk_ntok + dim = model_args.vision_dim + image_size = (model_args.vision_chunk_size, model_args.vision_chunk_size) + patch_size = (model_args.vision_patch_size, model_args.vision_patch_size) ##### Check parms ##### + max_num_tiles = model_args.vision_max_num_chunks assert num_chunks == max_num_tiles, "num_chunks must be the same value as max_num_tiles!" ##### Prepare inputs ##### @@ -198,13 +178,8 @@ def test_llama_positional_embedding_inference( # Only select output from one device tt_output_torch = tt_output_torch[..., :dim] - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info(f"Llama_PositionalEmbedding Passed!") - else: - logger.warning(f"Llama_PositionalEmbedding Failed!") - assert passing, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py index 619ca0bdb60..4ba64dd76ff 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py @@ -31,61 +31,42 @@ ) from models.demos.llama3.tt.model_config import TtModelArgs -import importlib - -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "mesh_device", [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get( + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) ) ], indirect=True, ) @pytest.mark.parametrize( - "gated", - [ - True, - ], -) -@pytest.mark.parametrize( - "input_shape, dim, max_num_tiles", + "bsz, num_concurrent_media, num_chunks", [ - ((1, 32, 4, 1032), 1280, 4), - ((1, 8, 4, 1032), 1280, 4), - ((1, 4, 4, 1032), 1280, 4), - ((1, 1, 4, 1032), 1280, 4), - ((1, 1, 4, 1024), 1280, 4), - # ((1, 32, 16, 1032), 1280, 16), # Large test, takes some time - ], -) -@pytest.mark.parametrize( - "layout", - [ - ttnn.TILE_LAYOUT, + (1, 1, 4), + (1, 4, 4), ], ) +@pytest.mark.parametrize("pre_embed", [False, True]) def test_llama_conv2d_inference( mesh_device, use_program_cache, reset_seeds, # Input params - input_shape, - layout, - # Tile Position Embedding params - dim, - gated, - max_num_tiles, + bsz, + num_concurrent_media, + num_chunks, + pre_embed, ensure_gc, ): dtype = ttnn.bfloat16 - pcc = 0.9999 + layout = ttnn.TILE_LAYOUT + gated = True + pcc_required = 0.9999 mesh_device.enable_async(True) @@ -93,13 +74,16 @@ def test_llama_conv2d_inference( state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - first_layer_prefix = "vision_model.vision_encoder.pre_tile_pos_embed." + first_layer_prefix = "vision_model.vision_encoder." + ( + "pre_tile_pos_embed." if pre_embed else "post_tile_pos_embed." + ) partial_state_dict = { k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - num_devices = model_args.num_devices - bsz, num_concurrent_media, num_chunks, ntok = input_shape + ntok = nearest_32(model_args.vision_chunk_ntok - (0 if pre_embed else 1)) + dim = model_args.vision_dim + max_num_tiles = model_args.vision_max_num_tiles ##### Check parms ##### assert num_chunks == max_num_tiles, "num_chunks must be the same value as max_num_tiles!" @@ -160,13 +144,8 @@ def test_llama_conv2d_inference( # Only select output from one device tt_output_torch = tt_output_torch[..., :dim] - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info(f"Llama_TilePositionEmbedding Passed!") - else: - logger.warning(f"Llama_TilePositionEmbedding Failed!") - assert passing, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" 
+ assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py similarity index 50% rename from models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py rename to models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py index 4da57f6cc33..b3790d498e5 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_vision_encoder.py +++ b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py @@ -6,12 +6,9 @@ from loguru import logger import os import ttnn -import importlib -llama_reference_mod = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) -from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import TtLlamaVisionEncoder +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_mod +from models.demos.llama3.tt.multimodal.llama_vision_encoder import TtLlamaVisionEncoder from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( prepare_inputs_ttnn_prefill, @@ -26,12 +23,16 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (2, 4), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_seeds): dtype = ttnn.bfloat16 - pcc = 0.88 + pcc_required = 0.88 model_args = TtModelArgs(mesh_device) state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) @@ -55,8 +56,6 @@ def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_se ) reference_model.load_state_dict(partial_state_dict, strict=True) - all_tests_pass = True - tt_model = TtLlamaVisionEncoder( mesh_device, state_dict, @@ -68,7 +67,7 @@ def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_se ) # Create rand inputs of the right shape - batch, num_media, num_chunks, n_channel, patch_size = (1, 1, 4, 3, 448) + batch, num_media, num_chunks, n_channel, patch_size = (1, 1, 4, 3, model_args.vision_chunk_size) images = torch.randn(batch, num_media, num_chunks, n_channel, patch_size, patch_size) ars = torch.tensor([2, 2]).reshape(batch, num_media, 2) @@ -78,45 +77,22 @@ def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_se tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) tt_output_torch = tt_output_torch[0, :, :, :].view(reference_output.shape) - INTERM_OUT = True - if INTERM_OUT: - # reference_output is [x] + [shuffled_int_x] - # tt_output is [x] + [int_x] - # To compare, we will shuffle tt_output. NOTE! 
This requires that the vision model shuffle its projection weights - tt_output_shuffled = torch.zeros_like(tt_output_torch) - tt_output_shuffled[..., : model_args.vision_dim] = tt_output_torch[..., : model_args.vision_dim] - tt_int_x = tt_output_torch[..., model_args.vision_dim :] - tt_int_x = ( - tt_int_x.reshape(reference_output.shape[:-1] + (5, model_args.vision_dim)) - .transpose(-1, -2) - .reshape(reference_output.shape[:-1] + (model_args.vision_dim * 5,)) - ) - tt_output_shuffled[..., model_args.vision_dim :] = tt_int_x - - logger.info(f"Reference output shape: {reference_output.shape}") - logger.info(f"TT output shape: {tt_output_shuffled.shape}") - - passing, pcc_message = comp_pcc(reference_output, tt_output_shuffled, pcc) - - logger.info(comp_allclose(reference_output, tt_output_shuffled)) - logger.info(f"PCC: {pcc_message}") - else: - logger.info(f"Reference output shape: {reference_output.shape}") - logger.info(f"TT output shape: {tt_output_torch.shape}") - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(pcc_message) - - if passing: - logger.info(f"Llama_Attention Passed!") - else: - logger.warning(f"Llama_Attention Failed!") - all_tests_pass = False - - if all_tests_pass: - logger.info("Llama Attention output Passed!") - else: - logger.warning("Llama Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + # reference_output is [x] + [shuffled_int_x] + # tt_output is [x] + [int_x] + # To compare, we will shuffle tt_output. + tt_output_shuffled = torch.zeros_like(tt_output_torch) + tt_output_shuffled[..., : model_args.vision_dim] = tt_output_torch[..., : model_args.vision_dim] + tt_int_x = tt_output_torch[..., model_args.vision_dim :] + tt_int_x = ( + tt_int_x.reshape(reference_output.shape[:-1] + (5, model_args.vision_dim)) + .transpose(-1, -2) + .reshape(reference_output.shape[:-1] + (model_args.vision_dim * 5,)) + ) + tt_output_shuffled[..., model_args.vision_dim :] = tt_int_x + + passing, pcc_message = comp_pcc(reference_output, tt_output_shuffled, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_shuffled)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/multimodal/test_llama_vision_model.py b/models/demos/llama3/tests/multimodal/test_llama_vision_model.py new file mode 100644 index 00000000000..f55a47891ac --- /dev/null +++ b/models/demos/llama3/tests/multimodal/test_llama_vision_model.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path +from typing import Optional +from loguru import logger + +from PIL import Image as PIL_Image +from termcolor import cprint + +import llama_models.llama3.reference_impl.generation as llama_reference_generation + +from llama_models.llama3.api.datatypes import ImageMedia + +from models.utility_functions import ( + comp_pcc, + comp_allclose, +) + +THIS_DIR = Path(__file__).parent.parent.parent.resolve() / "reference/llama_models/models/scripts/" + +import torch +import pytest +import os +import ttnn + + +def create_multimodal_model(model_args, mesh_device, dtype=ttnn.bfloat16): + from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer + from models.demos.llama3.tt.model_config import TtModelArgs + + tt_model_args = TtModelArgs(mesh_device) + checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True) + model = CrossAttentionTransformer( + model_args, + mesh_device, + checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=tt_model_args, + ) + model.setup_cache(model_args.max_batch_size, torch.float32) + return model + + +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_llama_vision_model( + mesh_device, + temperature: float = 0, + max_seq_len: int = 512, + max_batch_size: int = 4, + max_gen_len: Optional[int] = 50, + model_parallel_size: Optional[int] = None, +): + """ + This test runs the Llama3.2 vision model on CPU and TT concurrently. + It does not use teacher forcing and compares output logits at every token. 
+ """ + mesh_device.enable_program_cache() + mesh_device.enable_async(True) + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'") + generator_pt = llama_reference_generation.Llama.build( + ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + model_parallel_size=model_parallel_size, + ) + + generator_tt = llama_reference_generation.Llama(generator_pt.model, generator_pt.tokenizer, generator_pt.args) + logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices") + model = create_multimodal_model(generator_tt.args, mesh_device) + generator_tt.model = model + + # with open(THIS_DIR / "resources/dog.jpg", "rb") as f: + # img = PIL_Image.open(f).convert("RGB") + + # with open(THIS_DIR / "resources/pasta.jpeg", "rb") as f: + # img2 = PIL_Image.open(f).convert("RGB") + + with open(THIS_DIR / "resources/ocr_image.jpeg", "rb") as f: + ocr_image = PIL_Image.open(f).convert("RGB") + + # with open(THIS_DIR / "resources/clutter.jpeg", "rb") as f: + # clutter = PIL_Image.open(f).convert("RGB") + + interleaved_contents = [ + # text only + # "The color of the sky is blue but sometimes it can also be", + # image understanding + # [ImageMedia(image=img), "If I had to write a haiku for this one"], + # [ImageMedia(image=img2), "Couting the number of individual spaghetti strands in this image"], + [ImageMedia(image=ocr_image), "The full text in this image is as follows"], + # [ImageMedia(image=clutter), "The count of vases, books, and miscellaneous items in this image is"], + ] + + for content in interleaved_contents: + logger.info(f"Generating text for content: {content}") + model_input = generator_pt.formatter.encode_content(content) + gen_pt = generator_pt.generate( + model_input, max_gen_len=max_gen_len, temperature=temperature, return_logits=True + ) + gen_tt = generator_tt.generate( + model_input, max_gen_len=max_gen_len, temperature=temperature, return_logits=True + ) + + for out_idx, (token_pt, token_tt) in enumerate(zip(gen_pt, gen_tt)): + logger.info(f"Comparing output token {out_idx}") + out_pt, out_tt = token_pt[1], token_tt[1] + out_pt = out_pt[0, -1] + out_tt = out_tt[0, -1] + passing, pcc_message = comp_pcc(out_pt, out_tt, 0.90) + print(f"PCC: {pcc_message}") + # Check shapes of logprobs + + ref_argmax = torch.argmax(out_pt).item() + ref_logprob = out_pt[ref_argmax].item() + ref_token = generator_pt.tokenizer.decode([ref_argmax]) + + # Reference model: top-5 tokens + ref_top5_vals, ref_top5_idxs = torch.topk(out_pt, 5) + ref_top5_tokens = [generator_pt.tokenizer.decode([idx.item()]) for idx in ref_top5_idxs] + ref_top5_logprobs = ref_top5_vals.tolist() + + # Test model: top-5 tokens + top5_vals, top5_idxs = torch.topk(out_tt, 5) + top5_tokens = [generator_pt.tokenizer.decode([idx.item()]) for idx in top5_idxs] + top5_logprobs = top5_vals.tolist() + + def entropy(logits): + probs = torch.softmax(logits, dim=-1) + return -(probs * torch.log(probs)).sum().item() + + # Print the information + print(f"Token Position {out_idx}:") + print(f" Reference | Test") + print(f" Entropy: {entropy(out_pt):.4f} | {entropy(out_tt):.4f}") + print(f" Top-5 Tokens:") + for rank in range(5): + print( + f" {rank+1}. 
Token='{ref_top5_tokens[rank]}' @ {ref_top5_logprobs[rank]:.2f} | '{top5_tokens[rank]}' @ {top5_logprobs[rank]:.2f}" + ) + print() diff --git a/models/demos/llama3/tests/test_llama_perf.py b/models/demos/llama3/tests/test_llama_perf.py index 260a0221e9d..ea3c922646d 100644 --- a/models/demos/llama3/tests/test_llama_perf.py +++ b/models/demos/llama3/tests/test_llama_perf.py @@ -52,7 +52,7 @@ def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_ elif "3.1-8B" in model_args.DEFAULT_CACHE_PATH: expected_inference_time = 0.07 elif "3.2-11B" in model_args.DEFAULT_CACHE_PATH: - expected_inference_time = 0.07 + expected_inference_time = 0.085 else: assert False, f"Llama model not found. Supported Llama models: [3.2-1B, 3.2-3B, 3.1-8B]" diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 7d32bf6ce6f..30f0055c497 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -208,7 +208,7 @@ def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_s self.compute_kernel_config_sdpa = ttnn.WormholeComputeKernelConfig( math_fidelity=ttnn.MathFidelity.HiFi4, math_approx_mode=False, - fp32_dest_acc_en=False, + fp32_dest_acc_en=True, packer_l1_acc=False, ) @@ -476,7 +476,7 @@ def find_largest_divisor(n, max_divisor=8): k=cache_seq_len, n=self.head_dim, grid_size=(8, 8), - in0_block_w=1, + # in0_block_w=1, # TODO: Remove this when we get non-causal FlashDecode fuse_batch=False, ) self.model_config["VISION_XATTN_DENSE_PROGCFG"] = lambda seq_len: self.matmul_config( @@ -500,14 +500,14 @@ def find_largest_divisor(n, max_divisor=8): self.model_config["CROSS_TRANSFORMER_TEXT_OUTPUT_PROGCFG"] = lambda seq_len, max_seq: self.matmul_config( m=min(seq_len, max_seq), k=self.dim, - n=self.vocab_size - // 4 - // self.num_devices, # TODO: Remove magic number 8 from cross attention transformer text + n=self.vocab_size // 8, # Magic number. 
LM Head always contains 8 splits grid_size=(8, 8), in0_block_w=1, fuse_batch=seq_len <= max_seq, ) + self.VISION_MAX_MM_SEQ = nearest_32(self.vision_chunk_ntok) + def _set_llama_params_from_dict(self, params): # Text params self.dim = params["dim"] @@ -535,13 +535,20 @@ def _set_llama_params_from_dict(self, params): self.vision_act_layer = ttnn.UnaryOpType.GELU self.vision_dropout = 0.0 self.vision_attn_n_heads = 16 - self.vision_head_dim = self.vision_hidden_dim // self.vision_attn_n_heads + self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads self.vision_n_layers = 32 self.vision_n_global_layers = 8 self.vision_max_num_tiles = 4 self.vision_patch_size = 14 self.vision_in_channels = 3 + @property + def vision_chunk_ntok(self): + """ + Returns the number of tokens per chunk, accounting for the extra class token + """ + return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1 + def _set_llama_params(self, checkpoint_dir): params_file = os.path.join(checkpoint_dir, "params.json") assert os.path.exists(params_file), f"params.json file not found at {params_file}" diff --git a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py index 668edf2c001..a4d1bb59885 100644 --- a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py +++ b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py @@ -85,7 +85,6 @@ def __init__( fp32_dest_acc_en=True, packer_l1_acc=True, ) - self.program_config = None # TODO: Update with actual program config def forward(self, x: torch.Tensor): x = self._unfold(x) diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index 4303da8acf8..a1554b0dfe7 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -47,7 +47,8 @@ def __init__( self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 - self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa + + self.configuration = configuration self.model_config = configuration.get_model_config() @@ -131,7 +132,7 @@ def __init__( def compute_xattn_kv_cache(self, xattn_tokens): bsz, seqlen_y = xattn_tokens.shape[1], xattn_tokens.shape[2] - MAX_MM_SEQ_LEN = 1056 + MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ if seqlen_y > MAX_MM_SEQ_LEN: xattn_tokens = ttnn.reshape(xattn_tokens, [1, bsz * seqlen_y // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) @@ -156,29 +157,42 @@ def compute_xattn_kv_cache(self, xattn_tokens): xk = ttnn.reshape(xk, [1, bsz, seqlen_y, -1]) xv = ttnn.reshape(xv, [1, bsz, seqlen_y, -1]) - xk, _, _ = ttnn.experimental.nlp_create_qkv_heads( - xk, xk, num_heads=self.n_local_kv_heads, num_kv_heads=self.n_local_kv_heads // 2, transpose_k_heads=False - ) - xv, _, _ = ttnn.experimental.nlp_create_qkv_heads( - xv, xv, num_heads=self.n_local_kv_heads, num_kv_heads=self.n_local_kv_heads // 2, transpose_k_heads=False - ) + if self.n_local_kv_heads == 1: + # Only a simple reshape required, no need to split + xk = ttnn.reshape(xk, [bsz, 1, seqlen_y, -1]) + xv = ttnn.reshape(xv, [bsz, 1, seqlen_y, -1]) + else: + # 1, B, S, D -> B, NH, S, DH? 
+ xk, _, _ = ttnn.experimental.nlp_create_qkv_heads( + xk, + xk, + num_heads=self.n_local_kv_heads, + num_kv_heads=self.n_local_kv_heads // 2, + transpose_k_heads=False, + ) + xv, _, _ = ttnn.experimental.nlp_create_qkv_heads( + xv, + xv, + num_heads=self.n_local_kv_heads, + num_kv_heads=self.n_local_kv_heads // 2, + transpose_k_heads=False, + ) xk = self.k_norm(xk) + + # NOTE: Doing repeat in xattn_cache generation to avoid massive overhead in forward + xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) + xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) return [xk, xv] - ### EVERYTHING BELOW IS BROKEN OMG - # BEWARNED! TMs are dangerous! - # WORKAROUND + ### Below is how I would like to implement TMs, but it results in poor PCC xk = ttnn.to_layout(xk, layout=ttnn.ROW_MAJOR_LAYOUT) xv = ttnn.to_layout(xv, layout=ttnn.ROW_MAJOR_LAYOUT) xk = xk.reshape(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim) xv = xv.reshape(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim) - # xk = ttnn.to_memory_config(xk, ttnn.L1_MEMORY_CONFIG) - # xk = ttnn.to_memory_config(xk, ttnn.DRAM_MEMORY_CONFIG) - return xk - xk = ttnn.transpose(xk, 1, 2, memory_config=ttnn.L1_MEMORY_CONFIG) + xk = ttnn.transpose(xk, 1, 2) xv = ttnn.transpose(xv, 1, 2) xk = ttnn.to_layout(xk, layout=ttnn.TILE_LAYOUT) @@ -191,10 +205,7 @@ def compute_xattn_kv_cache(self, xattn_tokens): return [xk, xv] def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache): - # batch = x_11SH.shape[-2] batch = xattn_cache[0].shape[0] - # assert seq_len % 32 == 0 and seq_len > 0, "Seqlen must be divisible by 128" - # 1, B, D xq = ttnn.linear( x_11SH, @@ -224,9 +235,6 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xk, xv = xattn_cache cache_seq_len = xk.shape[-2] - xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) - xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) - scores = ttnn.matmul( xq, ttnn.transpose(xk, -1, -2), @@ -283,7 +291,7 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache): seq_len = x_11SH.shape[-2] # B, S, D - # assert seq_len % 32 == 0 and seq_len > 0, "Seqlen must be divisible by 128" + assert seq_len % 32 == 0 and seq_len > 0, "Seqlen must be divisible by 32" if seq_len > 1024: x_11SH = ttnn.reshape(x_11SH, [1, seq_len // 1024, 1024, -1]) @@ -310,20 +318,12 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH xk, xv = xattn_cache cache_seq_len = xk.shape[-2] - # NOTE: Using naive SDPA for now since FlashDecode does not allow non-causal mask - # xq = ttnn.reshape(xq, [self.n_local_heads // self.n_local_kv_heads, self.n_local_kv_heads, seq_len, self.head_dim]) - # NOTE: repeat doesn't work, need to use repeat_interleave - # xk = ttnn.repeat(xk, ttnn.Shape((self.n_local_heads // self.n_local_kv_heads, 1, 1, 1))) - xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) - # xv = ttnn.repeat(xv, ttnn.Shape((self.n_local_heads // self.n_local_kv_heads, 1, 1, 1))) - xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) - scores = ttnn.matmul( xq, ttnn.transpose(xk, -1, -2), dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config_hifi4, + 
compute_kernel_config=self.compute_kernel_config_hifi2, program_config=self.model_config["VISION_XATTN_SCORE_PROGCFG"](seq_len, cache_seq_len), ) @@ -332,8 +332,6 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH scores = ttnn.add(scores, xattn_mask) scores = ttnn.softmax(scores, dim=-1, numeric_stable=True) - # TODO: scale_mask_softmax doesn't work for this xattn_mask shape - # scores = ttnn.scale_mask_softmax(scores, self.scale, xattn_mask, numeric_stable=True) output = ttnn.matmul( scores, xv, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py index 777a19d3b61..eb552bc0e17 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py @@ -55,13 +55,6 @@ def __init__( self.model_config = configuration.get_model_config() self.state_dict = state_dict - # self.tok_embeddings = TtLlamaEmbedding( - # mesh_device=mesh_device, - # args=configuration, - # weight_cache_path=configuration.weight_cache_path(dtype), - # state_dict=state_dict, - # dtype=ttnn.bfloat16, # Row major layout requires bfloat16 - # ) # NOTE: Running all embeddings in torch for now since learnable embeddings use complex indexing ops which must be in torch self.tok_embeddings = torch.nn.Embedding(configuration.vocab_size, configuration.dim) tok_embedding_prefix = f"{state_dict_prefix}tok_embeddings." @@ -79,23 +72,10 @@ def __init__( weight_key="norm", ) - # # self.output layer weight # TODO: Generalize LMHead, maybe use llama_model's single-tile-sequence LMHead - # self.output = LMHead( - # configuration, - # mesh_device, - # ttnn.bfloat8_b, - # state_dict, - # state_dict_prefix, - # weight_cache_path, - # ) - - # torch_weight = lambda name, suffix: torch.transpose( - # self.state_dict[f"{state_dict_prefix}{name}.{suffix}"], -2, -1 - # ) - lm_head_torch = self.state_dict[f"{state_dict_prefix}output.weight"].transpose(-1, -2) - num_splits = 4 # arbitrary, reasonable number + total_splits = 8 # Arbitrary value which allows whole-tile splits in LM Head + num_splits = total_splits // self.configuration.num_devices lm_head_torch = torch.chunk(lm_head_torch, num_splits, dim=-1) cache_name = lambda name, suffix, split: weight_cache_path / (state_dict_prefix + f"{name}{suffix}{split}") @@ -113,7 +93,6 @@ def __init__( self.outputs = [ as_interleaved_tensor("output", "weight", idx, ttnn.bfloat8_b, dim=-1) for idx in range(len(lm_head_torch)) ] - # self.output = as_interleaved_tensor("output", "weight", ttnn.bfloat8_b, dim=-1) self.n_llama_layers = configuration.n_layers self.model_dim = configuration.dim diff --git a/models/demos/llama3/tt/multimodal/llama_image_transformer_vision.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py similarity index 97% rename from models/demos/llama3/tt/multimodal/llama_image_transformer_vision.py rename to models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py index b02d04097b7..cded53b7120 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_transformer_vision.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py @@ -11,7 +11,7 @@ nearest_32, ) from models.common.lightweightmodule import LightweightModule -from models.demos.llama3.tt.multimodal.llama_image_vision_encoder import TtLlamaVisionEncoder +from 
models.demos.llama3.tt.multimodal.llama_vision_encoder import TtLlamaVisionEncoder from models.demos.falcon7b_common.tests.test_utils import ( synchronize_devices, diff --git a/models/demos/llama3/tt/multimodal/llama_image_attention.py b/models/demos/llama3/tt/multimodal/llama_image_attention.py index 20da869640c..b15e64d2374 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_image_attention.py @@ -45,6 +45,7 @@ def __init__( self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa + self.configuration = configuration self.model_config = configuration.get_model_config() @@ -140,106 +141,10 @@ def pad_head_dim(weight, heads_out=True): self.scale = self.head_dim**-0.5 - def forward(self, x, mask): - return self.forward_tt(x, mask) - if os.environ.get("ATTN") == "tt": - return self.forward_tt(x, mask) - else: - return self.forward_pt(x, mask) - - def forward_pt(self, x_11SH, mask=None): + def forward(self, x_11SH, mask=None): seq_len = x_11SH.shape[-2] - x_torch = ttnn.to_torch(x_11SH, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - x_torch = x_torch[0].unsqueeze(0) - wqkv = ttnn.to_torch(self.wqkv, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - wqkv = wqkv.view(2, wqkv.shape[0] // 2, wqkv.shape[1]) - - wqkv = wqkv.reshape(2, self.hidden_size, 3 * self.n_heads // 2, -1) - wqkv = wqkv[..., : self.head_dim] - # wqkv = wqkv.reshape(2, self.hidden_size, -1) - wqkv = wqkv.reshape(2, self.hidden_size, 3, self.n_heads // 2, -1).permute(1, 2, 0, 3, 4) - wqkv = wqkv.reshape(self.hidden_size, 3, self.n_heads, -1).reshape(self.hidden_size, -1) - - # xqkv_fused_torch = torch.matmul(x_torch, wqkv).bfloat16().float() - xqkv_fused_torch = torch.nn.functional.linear(x_torch, wqkv.T).bfloat16().float() - # xqkv_fused_torch = torch.nn.functional.linear(x_torch, wqkv.tranpose).bfloat16().float() - # n, s, d = xqkv_fused_torch.shape[-3:] - s, d = xqkv_fused_torch.shape[-2:] - xqkv = xqkv_fused_torch.reshape(s, 3, d // 3) - q = xqkv[..., 0, :] - k = xqkv[..., 1, :] - v = xqkv[..., 2, :] - # xq = q.reshape(n, s, self.n_heads//2, -1).transpose(1, 2) - # xk = k.reshape(n, s, self.n_heads//2, -1).transpose(1, 2) - # xv = v.reshape(n, s, self.n_heads//2, -1).transpose(1, 2) - xq = q.reshape(s, self.n_heads, -1).transpose(0, 1) - xk = k.reshape(s, self.n_heads, -1).transpose(0, 1) - xv = v.reshape(s, self.n_heads, -1).transpose(0, 1) - mask_torch = ttnn.to_torch(mask, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - mask_torch = mask_torch[0] - attn_output = ( - torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask_torch, scale=self.scale) - .bfloat16() - .float() - ) # [...,:self.head_dim] - - # attn_output = attn_output.transpose(1, 2).reshape(n, s, -1).transpose(0, 1).reshape(s, -1) - attn_output = attn_output.transpose(0, 1).reshape(s, -1) - - wo = ttnn.to_torch(self.wo, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - wo = wo.view(2, wo.shape[0] // 2, wo.shape[1]) # .reshape(-1, wo.shape[1]) - - wo = ( - wo.transpose(1, 2) - .reshape(2, self.hidden_size, self.n_heads // 2, -1)[..., : self.head_dim] - .reshape(2, self.hidden_size, -1) - .transpose(1, 2) - .reshape(-1, self.hidden_size) - ) - - out = torch.nn.functional.linear(attn_output, 
wo.T).bfloat16().float() - # out = torch.sum(out, dim=0).unsqueeze(0).unsqueeze(0).bfloat16().float() - out = out.view(1, 1, 4224, -1) - - out_tt = ttnn.from_torch( - out, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - return out_tt - - def forward_tt(self, x_11SH, mask=None): - seq_len = x_11SH.shape[-2] - # assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" - ### - # QKV matmuls - ### - - # reshaping long sequence to matmul fit on device - - # x_torch = ttnn.to_torch(x_11SH, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # x_torch = x_torch.view(2, seq_len, -1) - # wqkv = ttnn.to_torch(self.wqkv, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # wqkv = wqkv.view(2, wqkv.shape[0]//2, wqkv.shape[1]) - - # xqkv_fused_torch = torch.bmm(x_torch, wqkv).unsqueeze(1) - # xqkv_fused = ttnn.from_torch( - # xqkv_fused_torch, - # device=self.mesh_device, - # mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=0), - # dtype=ttnn.bfloat16, - # memory_config=ttnn.DRAM_MEMORY_CONFIG, - # layout=ttnn.TILE_LAYOUT, - # ) - - # Depends on whether we are padding or not - MAX_MM_SEQ_LEN = 1056 - # MAX_MM_SEQ_LEN = 1024 - - # DEBUG: Don't batch it up - # MAX_MM_SEQ_LEN = 10000 + MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ if seq_len > MAX_MM_SEQ_LEN: x_11SH = ttnn.reshape(x_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) @@ -255,8 +160,6 @@ def forward_tt(self, x_11SH, mask=None): if seq_len > MAX_MM_SEQ_LEN: xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) - # ttnn.deallocate(x_11SH) - # split qkv into heads ( q_heads_1QSD, @@ -271,11 +174,9 @@ def forward_tt(self, x_11SH, mask=None): ) ttnn.deallocate(xqkv_fused) - # sdpa_cfg = self.model_config["SDPA_PROGCFG"](seq_len) + # TODO: get this from model_config sdpa_cfg = ttnn.SDPAProgramConfig( - compute_with_storage_grid_size=(8, 8), - q_chunk_size=128, - k_chunk_size=128, + compute_with_storage_grid_size=(8, 8), q_chunk_size=128, k_chunk_size=128, exp_approx_mode=False ) attn_output_1QSD = ttnn.transformer.scaled_dot_product_attention( q_heads_1QSD, @@ -292,28 +193,6 @@ def forward_tt(self, x_11SH, mask=None): ttnn.deallocate(k_heads_1KSD) ttnn.deallocate(v_heads_1VSD) - # q_heads_torch = ttnn.to_torch(q_heads_1QSD, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # k_heads_torch = ttnn.to_torch(k_heads_1KSD, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # v_heads_torch = ttnn.to_torch(v_heads_1VSD, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # mask_torch = ttnn.to_torch(mask, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - - # attn_output_torch = torch.nn.functional.scaled_dot_product_attention( - # q_heads_torch, - # k_heads_torch, - # v_heads_torch, - # attn_mask=mask_torch, - # scale=self.scale, - # ) - - # attn_output_1QSD = ttnn.from_torch( - # attn_output_torch, - # device=self.mesh_device, - # mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=0), - # dtype=ttnn.bfloat16, - # memory_config=ttnn.DRAM_MEMORY_CONFIG, - # layout=ttnn.TILE_LAYOUT, - # ) - ### # Output matmul ### @@ -323,24 +202,6 @@ def forward_tt(self, x_11SH, mask=None): ) ttnn.deallocate(attn_output_1QSD) - # attn_output_torch = ttnn.to_torch(attn_output_11SH, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # # breakpoint() - # wo = ttnn.to_torch(self.wo, 
mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - # wo = wo.view(2, 1, wo.shape[0]//2, wo.shape[1]) - # output = torch.matmul(attn_output_torch, wo) - # output = torch.sum(output, dim=0).unsqueeze(0).unsqueeze(0) - - # output_11SH = ttnn.from_torch( - # output, - # device=self.mesh_device, - # mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - # dtype=ttnn.bfloat16, - # memory_config=ttnn.DRAM_MEMORY_CONFIG, - # layout=ttnn.TILE_LAYOUT, - # ) - # return output_11SH - # breakpoint() - # reshaping long sequence to matmul fit on device if seq_len > MAX_MM_SEQ_LEN: attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) diff --git a/models/demos/llama3/tt/multimodal/llama_image_block.py b/models/demos/llama3/tt/multimodal/llama_image_block.py index d668b54897c..9ab361aed26 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_block.py +++ b/models/demos/llama3/tt/multimodal/llama_image_block.py @@ -92,46 +92,7 @@ def __init__( memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - def forward(self, x, mask): - return self.forward_tt(x, mask) - if os.environ.get("BLOCK") == "tt": - return self.forward_tt(x, mask) - else: - return self.forward_pt(x, mask) - - def forward_pt(self, x_11SH, mask=None): - seq_len = x_11SH.shape[-2] - assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" - - attn_out = self.attn(self.ln_1(x_11SH), mask=mask) - if self.gated: - assert False - attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn)) - - x = ttnn.to_torch(x_11SH, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - attn_out = ttnn.to_torch(attn_out, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - res = (x + attn_out).bfloat16().float() - res = ttnn.from_torch( - res, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=0), - ) - mlp_out = self.mlp(self.ln_2(res)) - if self.gated: - mlp_out = ttnn.mul(mlp_out, ttnn.tanh(self.gate_ffn)) - res = ttnn.to_torch(res, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - mlp_out = ttnn.to_torch(mlp_out, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() - out = (res + mlp_out).bfloat16().float() - out = ttnn.from_torch( - out, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=0), - ) - return out - - def forward_tt(self, x_11SH, mask=None): + def forward(self, x_11SH, mask=None): seq_len = x_11SH.shape[-2] assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" diff --git a/models/demos/llama3/tt/multimodal/llama_image_mlp.py b/models/demos/llama3/tt/multimodal/llama_image_mlp.py index fda3582597d..9f1c1f67590 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_mlp.py +++ b/models/demos/llama3/tt/multimodal/llama_image_mlp.py @@ -54,55 +54,7 @@ def __init__( self.c_proj_weight = as_interleaved_tensor("c_proj", "weight", dtype, dim=-2) self.c_proj_bias = as_interleaved_tensor("c_proj", "bias", ttnn.bfloat16, dim=None) - def forward(self, x): - return self.forward_tt(x) - if os.environ.get("MLP") == "tt": - return self.forward_tt(x) - else: - return self.forward_pt(x) - - def forward_pt(self, x): - x = ttnn.to_torch( - x, device=self.mesh_device, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) - ).float() - x = x[0] - - c_fc_weight = ttnn.to_torch( - self.c_fc_weight, 
mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1) - ).float() - c_fc_bias = ttnn.to_torch( - self.c_fc_bias, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1) - ).float()[:1] - c_proj_weight = ttnn.to_torch( - self.c_proj_weight, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-2) - ).float() - c_proj_bias = ttnn.to_torch( - self.c_proj_bias, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) - ).float() - c_proj_bias = c_proj_bias[:1] - - # hidden = (torch.matmul(x, c_fc_weight).bfloat16().float() + c_fc_bias).bfloat16().float() - # hidden = torch.nn.functional.gelu(hidden).bfloat16().float() - # hidden = (torch.matmul(hidden, c_proj_weight).bfloat16().float() + c_proj_bias).bfloat16().float() - # hidden = hidden.view(1, 1, 5120, -1) - x = x.bfloat16().float() - hidden = torch.nn.functional.linear(x, c_fc_weight.T, c_fc_bias).bfloat16().float() - hidden = torch.nn.functional.gelu(hidden).bfloat16().float() - hidden = torch.nn.functional.linear(hidden, c_proj_weight.T).bfloat16().float() - hidden += c_proj_bias - hidden = hidden.view(1, 1, 4224, -1) - - hidden = ttnn.from_torch( - hidden, - device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - ) - return hidden - - def forward_tt(self, x: ttnn.Tensor) -> ttnn.Tensor: + def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: """ w1 -> gate_proj w2 -> down_proj @@ -112,8 +64,8 @@ def forward_tt(self, x: ttnn.Tensor) -> ttnn.Tensor: seq_len = x.shape[-2] # Depends on whether we are padding or not - MAX_MM_SEQ_LEN = 1056 - # MAX_MM_SEQ_LEN = 1024 + MAX_MM_SEQ_LEN = self.args.VISION_MAX_MM_SEQ + x_in = x if seq_len >= MAX_MM_SEQ_LEN: # Too big to compute. Set different program configs based on seqlen # Reshape input to to fit on device and parallelize computation @@ -131,9 +83,9 @@ def forward_tt(self, x: ttnn.Tensor) -> ttnn.Tensor: dtype=ttnn.bfloat16, program_config=pc_1, memory_config=ttnn.DRAM_MEMORY_CONFIG, - # activation="gelu", # NOTE: activation must be passed to linear here, not in program config! Bad output otherwise + activation="gelu", # NOTE: activation must be passed to linear here, not in program config! 
Bad output otherwise ) - c_fc_out = ttnn.gelu(c_fc_out, fast_and_approximate_mode=False) + c_proj_out = ttnn.linear( c_fc_out, self.c_proj_weight, @@ -144,7 +96,6 @@ def forward_tt(self, x: ttnn.Tensor) -> ttnn.Tensor: memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - # if seq_len >= 1024: # Reshape back to intended shape # NOTE: Need to reshape to 4D so that fast_reduce_nc hsa a dim1 to work on c_proj_out = ttnn.reshape(c_proj_out, [1, 1, seq_len, -1]) diff --git a/models/demos/llama3/tt/multimodal/llama_layernorm.py b/models/demos/llama3/tt/multimodal/llama_layernorm.py index 0750d82da52..a20c4764ad1 100644 --- a/models/demos/llama3/tt/multimodal/llama_layernorm.py +++ b/models/demos/llama3/tt/multimodal/llama_layernorm.py @@ -82,46 +82,7 @@ def __init__( ) self.sharded_output_config = self.sharded_input_config - def forward(self, x): - return self.forward_tt(x) - if os.environ.get("LN") == "tt": - return self.forward_tt(x) - else: - return self.forward_pt(x) - - def forward_pt(self, x: ttnn.Tensor, in_sharded=False, out_sharded=False) -> ttnn.Tensor: - # If input is sharded do sharded RMSNorm and optionally return sharded output - - x = ttnn.to_torch( - x, - device=self.device, - mesh_composer=ttnn.ConcatMeshToTensor(self.device, dim=0), - )[0].float() - weight = ttnn.to_torch( - self.weight, - device=self.device, - mesh_composer=ttnn.ConcatMeshToTensor(self.device, dim=0), - )[0, 0].float() - bias = ttnn.to_torch( - self.bias, - device=self.device, - mesh_composer=ttnn.ConcatMeshToTensor(self.device, dim=0), - )[0, 0].float() - - out = torch.nn.functional.layer_norm(x, x.shape[-1:], weight=weight, bias=bias, eps=self.eps) - out = out - - out = ttnn.from_torch( - out, - device=self.device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.device), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - ) - return out - - def forward_tt(self, x: ttnn.Tensor, in_sharded=False, out_sharded=False) -> ttnn.Tensor: + def forward(self, x: ttnn.Tensor, in_sharded=False, out_sharded=False) -> ttnn.Tensor: if in_sharded: x = ttnn.layer_norm( x, diff --git a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py index 031ca56e775..9ef2aadddac 100644 --- a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py @@ -101,7 +101,6 @@ def forward(self, x: ttnn.Tensor, ar: torch.Tensor, num_tiles: int = None): if num_tiles is None: num_tiles = self.num_tiles elif num_tiles > self.num_tiles: - # TODO: Need to implement? 
assert False, "_dynamic_resize is currently not supported for TtLllamaTilePositionEmbedding" # Get the correct embeddings for the given aspect ratios diff --git a/models/demos/llama3/tt/multimodal/llama_image_vision_encoder.py b/models/demos/llama3/tt/multimodal/llama_vision_encoder.py similarity index 96% rename from models/demos/llama3/tt/multimodal/llama_image_vision_encoder.py rename to models/demos/llama3/tt/multimodal/llama_vision_encoder.py index 2efce9ecd0f..ff8a71c7de5 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_vision_encoder.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_encoder.py @@ -24,11 +24,7 @@ synchronize_devices, ) -import importlib - -encoder_utils = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.encoder_utils" -) +import llama_models.llama3.reference_impl.multimodal.encoder_utils as encoder_utils def to_2tuple(x): @@ -196,7 +192,6 @@ def forward(self, images, ar): assert isinstance( images, torch.Tensor ), "VisionEncoder input must be a torch tensor because of unfold in self.conv1" - SKIP_EMBED = False if images.ndim == 5: num_concurrent_media = 1 bsz, num_chunks, nch, w, h = images.shape @@ -219,17 +214,15 @@ def forward(self, images, ar): x = ttnn.reshape(x, (1, bsz * num_concurrent_media * num_chunks, ntok, dim)) # apply cls token - if not SKIP_EMBED: - x = self.class_embedding(x) - ntok += 1 + x = self.class_embedding(x) + ntok += 1 # apply position embeddings # NOTE! After class embedding, x is padded tilized tensor. Reshapes fail for padded tilized tensors, so do the reshape in row-major x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) x = ttnn.reshape(x, (bsz * num_concurrent_media, num_chunks, ntok, dim)) x = ttnn.to_layout(x, ttnn.TILE_LAYOUT) - if not SKIP_EMBED: - x = self.positional_embedding(x, ar) + x = self.positional_embedding(x, ar) # BUG: layernorm takes 4d tensor -> 3d?? 
x = self.ln_pre(x) @@ -243,6 +236,7 @@ def forward(self, images, ar): fake_x = torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]) attn_mask = encoder_utils.build_encoder_attention_mask(fake_x, ar, ntok, num_chunks, 1) + # Mask stripes for the extra padding required on TT hardware attn_mask = mask_tile_padding(attn_mask, ntok, npad, num_chunks) attn_mask = ttnn.as_tensor( attn_mask, diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index c6b72a78506..f96aba089c4 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -15,19 +15,13 @@ from torch import nn, Tensor -BFLOAT = False - -import importlib - -llama_reference_model = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.model" -) -llama_reference_image_transforms = importlib.import_module( - "models.demos.t3000.llama2_70b.reference.llama-models.models.llama3.reference_impl.multimodal.image_transform" -) +import llama_models.llama3.reference_impl.multimodal.model as llama_reference_model +import llama_models.llama3.reference_impl.multimodal.image_transform as llama_reference_image_transforms import ttnn -from models.demos.llama3.tt.multimodal.llama_image_transformer_vision import TtLlamaCrossAttentionTransformerVision +from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_vision import ( + TtLlamaCrossAttentionTransformerVision, +) from models.demos.llama3.tt.multimodal.llama_cross_attention_transformer_text import ( TtLlamaCrossAttentionTransformerText, ) @@ -39,6 +33,9 @@ get_rot_transformation_mat, get_single_rot_mat, ) +from models.utility_functions import ( + nearest_32, +) logger = logging.getLogger(__name__) MP_SCALE = 8 @@ -225,16 +222,21 @@ def compute_vision_tokens_masks( vision_tokens = self.vision_model(stacked_images, aspect_ratios) # Back to torch vision_tokens = ttnn.to_torch(vision_tokens, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)) + chunk_seq_len = self.configuration.vision_chunk_ntok + # NOTE: slicing up to chunk_seq_len is necessary because padding information is lost by this point vision_tokens = ( - vision_tokens[0].reshape(bsz, max_num_images, self.max_num_chunks, -1, self.model_dim).float() + vision_tokens[0, :, :chunk_seq_len] + .reshape(bsz, max_num_images, self.max_num_chunks, -1, self.model_dim) + .float() ) bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) + padded_seq_len = self.max_num_chunks * nearest_32(self.configuration.vision_chunk_ntok) # Prepare vision tokens for TT text_model vision_tokens_squeeze = vision_tokens.view(1, bsz, -1, image_token_dim) vision_tokens_squeeze = torch.nn.functional.pad( - vision_tokens_squeeze, (0, 0, 0, 4224 - vision_tokens_squeeze.shape[2]), "constant", 0 + vision_tokens_squeeze, (0, 0, 0, padded_seq_len - vision_tokens_squeeze.shape[2]), "constant", 0 ) vision_tokens_tt = ttnn.from_torch( vision_tokens_squeeze, @@ -266,12 +268,19 @@ def compute_vision_tokens_masks( cross_attention_masks = torch.nn.functional.pad( cross_attention_masks, - (0, 4224 - cross_attention_masks.shape[3]), + (0, padded_seq_len - cross_attention_masks.shape[3]), "constant", get_negative_inf_value(torch.float32), ) return (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask) + def validate_inputs(self, tokens): + batch, seq_len = tokens.shape[:2] + assert batch == 1, f"Only batch 1 is supported, 
got {batch}" + assert ( + seq_len <= self.configuration.max_seq_len + ), f"Sequence length {seq_len} exceeds max sequence length {self.configuration.max_seq_len}" + def forward( self, position_ids: torch.Tensor, @@ -281,8 +290,10 @@ def forward( xattn_caches: torch.Tensor, text_only_inference: bool = False, ) -> torch.Tensor: + self.validate_inputs(tokens) h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids]) batch, seq_len = h.shape[:2] + padded_seq_len = _get_padded_prefill_seqlen(seq_len) if seq_len == 1: mode = "decode" else: @@ -302,35 +313,39 @@ def forward( if mode == "prefill": xattn_mask_expand = torch.nn.functional.pad( xattn_mask_expand, - (0, 0, 0, 128 - xattn_mask_expand.shape[2]), + (0, 0, 0, padded_seq_len - xattn_mask_expand.shape[2]), "constant", get_negative_inf_value(torch.float32), ) + tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, device=self.mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, + dtype=ttnn.bfloat16, + layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) + tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) full_text_mask = full_text_row_masked_out_mask[:, :, position_ids] if mode == "prefill": full_text_mask = torch.nn.functional.pad( - full_text_mask, (0, 0, 0, 128 - full_text_mask.shape[2]), "constant", 0 + full_text_mask, (0, 0, 0, padded_seq_len - full_text_mask.shape[2]), "constant", 0 ) full_text_mask_expand_1NSH = full_text_mask.expand( -1, self.configuration.n_heads // self.configuration.num_devices, -1, self.configuration.head_dim ) + tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, device=self.mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, + dtype=ttnn.bfloat16, + layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) + tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, self.configuration.dim) tt_full_text_mask_expand_11SD = ttnn.from_torch( @@ -343,8 +358,7 @@ def forward( ) # Check mask shapes, pad if in prefill? 
if mode == "prefill": - # DEBUG: pad h seqlen to 128 - h = torch.nn.functional.pad(h, (0, 0, 0, 128 - h.shape[1]), "constant", 0) + h = torch.nn.functional.pad(h, (0, 0, 0, padded_seq_len - h.shape[1]), "constant", 0) tt_h = prepare_inputs_ttnn_prefill( h, self.mesh_device, @@ -427,9 +441,10 @@ def forward( text_only_inference=text_only_inference, ) - tt_out = ttnn.to_torch(logits, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)).float() + tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) + tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() if mode == "prefill": - tt_out = tt_out[0].reshape(batch, 128, -1)[:, :seq_len, :] # DEBUG: undo padding + tt_out = tt_out[0].reshape(batch, padded_seq_len, -1)[:, :seq_len, :] # DEBUG: undo padding else: tt_out = tt_out[0, ..., :batch, :].transpose(0, 1).reshape(batch, seq_len, -1) @@ -493,3 +508,16 @@ def _pad_masks( out_masks[idx, mask_elem[0] : mask_elem[1], mask_idx, :mask_num_chunks].fill_(0.0) return out_masks + + +def _get_padded_prefill_seqlen(seq_len): + """ + If seq_len is less than 128, pad to 128 + If seq_len is more than 128, pad to whichever is smaller: a power of 2 or a multiple of 1024 + """ + if seq_len < 128: + return 128 + else: + mult_1024 = 1024 * math.ceil(seq_len / 1024) + pow_2 = 2 ** math.ceil(math.log2(seq_len)) + return min(mult_1024, pow_2) diff --git a/tests/scripts/t3000/run_t3000_demo_tests.sh b/tests/scripts/t3000/run_t3000_demo_tests.sh index a63433dd028..1ced36a3955 100755 --- a/tests/scripts/t3000/run_t3000_demo_tests.sh +++ b/tests/scripts/t3000/run_t3000_demo_tests.sh @@ -75,6 +75,36 @@ run_t3000_llama3_tests() { fi } +run_t3000_llama3_vision_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_llama3_vision_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + n300=N300 + t3k=T3K + + # Install Vision-specific packages + pip install -r models/demos/llama3/requirements.txt + + for fake_device in "$n300" "$t3k"; do + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/multimodal_demo_chat.py -k "tt and 1" --timeout 600; fail+=$? + echo "LOG_METAL: Llama3 vision tests for $fake_device completed" + done + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_llama3_vision_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + run_t3000_falcon7b_tests(){ # Record the start time fail=0 diff --git a/tests/scripts/t3000/run_t3000_frequent_tests.sh b/tests/scripts/t3000/run_t3000_frequent_tests.sh index 468e6757845..b5fb7360c89 100755 --- a/tests/scripts/t3000/run_t3000_frequent_tests.sh +++ b/tests/scripts/t3000/run_t3000_frequent_tests.sh @@ -75,6 +75,64 @@ run_t3000_llama3_tests() { fi } +run_t3000_llama3.2-11b-vision_freq_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_llama3.2-11b-vision_freq_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + + # Install Vision-specific packages + pip install -r models/demos/llama3/requirements.txt + + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_transformer.py ; fail+=$? 
+ LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py ; fail+=$? + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_llama3.2-11b-vision_freq_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + +run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + # Use FAKE_DEVICE env variable to run on an N300 mesh + fake_device=N300 + + # Install Vision-specific packages + pip install -r models/demos/llama3/requirements.txt + + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_transformer.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py ; fail+=$? + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + run_t3000_mixtral_tests() { # Record the start time fail=0 diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh index 2302718fb45..e47a6b7a93c 100755 --- a/tests/scripts/t3000/run_t3000_unit_tests.sh +++ b/tests/scripts/t3000/run_t3000_unit_tests.sh @@ -152,6 +152,74 @@ run_t3000_llama3.2-11b_tests() { fi } +run_t3000_llama3.2-11b-vision_unit_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_llama3.2-11b-vision_unit_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + + # Install Vision-specific packages + pip install -r models/demos/llama3/requirements.txt + + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_block.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py ; fail+=$? 
+ LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_class_embedding.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py ; fail+=$? + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_llama3.2-11b-vision_unit_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + +run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + # Use FAKE_DEVICE env variable to run on an N300 mesh + fake_device=N300 + + # Install Vision-specific packages + pip install -r models/demos/llama3/requirements.txt + + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_block.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_class_embedding.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py ; fail+=$? + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + run_t3000_mixtral_tests() { # Record the start time fail=0