#0: Fix vision tests, work around s2i issue
yieldthought committed Nov 20, 2024
1 parent 7a9b752 commit 1a1e4ae
Showing 3 changed files with 7 additions and 11 deletions.
@@ -58,7 +58,7 @@ def test_llama_cross_attention_transformer_text_inference(
prefill_pcc_required = 0.98
decode_pcc_required = 0.73

-mesh_device.enable_async(True)
+mesh_device.enable_async(False)

model_args = TtModelArgs(mesh_device, max_batch_size=batch)
# Limit the max seqlen to 4k to avoid OOM on host
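The change above flips the mesh device out of async mode before the test body runs. For context, a minimal sketch of that pattern, assuming ttnn's mesh-device API (the ttnn.open_mesh_device / ttnn.MeshShape / ttnn.close_mesh_device names are assumptions, not taken from this diff):

import ttnn

# Open a 1x1 mesh; the shape is illustrative only.
mesh_device = ttnn.open_mesh_device(mesh_shape=ttnn.MeshShape(1, 1))

# The commit disables async execution for this test, sidestepping the
# s2i (sharded_to_interleaved) issue named in the commit title.
mesh_device.enable_async(False)

try:
    pass  # run the model under test here
finally:
    ttnn.close_mesh_device(mesh_device)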
models/demos/llama3/tt/model_config.py (6 changes: 3 additions & 3 deletions)
@@ -92,15 +92,15 @@ def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_s
# Assert if all folders and files exist
assert os.path.exists(
    self.DEFAULT_CKPT_DIR
-), f"Checkpoint directory {self.DEFAULT_CKPT_DIR} does not exist, please use export LLAMA_CKPT_DIR=..."
+), f"Checkpoint directory {self.DEFAULT_CKPT_DIR} does not exist, please set LLAMA_DIR=... or LLAMA_CKPT_DIR=..."
assert os.path.isfile(
    self.DEFAULT_TOKENIZER_PATH + "/tokenizer.model"
-), f"Tokenizer file {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'} does not exist, please use export LLAMA_TOKENIZER_PATH=..."
+), f"Tokenizer file {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'} does not exist, please set LLAMA_TOKENIZER_PATH=..."
if not os.path.exists(self.DEFAULT_CACHE_PATH):
    os.makedirs(self.DEFAULT_CACHE_PATH)
assert os.path.exists(
    self.DEFAULT_CACHE_PATH
-), f"Cache directory {self.DEFAULT_CACHE_PATH} does not exist, please use export LLAMA_CACHE_PATH=..."
+), f"Cache directory {self.DEFAULT_CACHE_PATH} does not exist, please set LLAMA_CACHE_PATH=..."
# Check if weights exist in the specified folder. If not warn the user to run the download and untar script.
# assert os.path.isfile(
#     self.DEFAULT_CKPT_DIR + "/consolidated.00.pth"
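These messages steer users toward environment variables instead of hardcoded paths. A hypothetical sketch of how such defaults could be resolved; the precedence of LLAMA_DIR relative to the more specific variables is an assumption, since this diff shows only the error messages:

import os

# Assumed resolution order: specific variable first, then LLAMA_DIR.
llama_dir = os.environ.get("LLAMA_DIR")
ckpt_dir = os.environ.get("LLAMA_CKPT_DIR") or llama_dir
tokenizer_path = os.environ.get("LLAMA_TOKENIZER_PATH") or llama_dir
cache_path = os.environ.get("LLAMA_CACHE_PATH") or llama_dir

assert ckpt_dir and os.path.exists(ckpt_dir), (
    f"Checkpoint directory {ckpt_dir} does not exist, "
    "please set LLAMA_DIR=... or LLAMA_CKPT_DIR=..."
)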
@@ -2,20 +2,15 @@

# SPDX-License-Identifier: Apache-2.0

import os
import math
import ttnn
import torch
import torch.nn as nn
from models.demos.llama3.tt.llama_decoder import TtTransformerBlock
from models.demos.llama3.tt.multimodal.llama_cross_block import TtLlamaCrossAttentionTransformerBlock
from models.demos.llama3.tt.llama_model import LMHead
from models.demos.llama3.tt.distributed_norm import DistributedNorm
from models.common.rmsnorm import RMSNorm
import ttnn
from typing import Optional
from models.common.lightweightmodule import LightweightModule
from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding

from models.utility_functions import (
nearest_32,
@@ -288,8 +283,9 @@ def forward(

h = self.norm(h, mode=mode)

if mode == "decode": # h is expected to be interleaved for the lm head
h = ttnn.sharded_to_interleaved(h)
# TODO: Switch to using dram-sharded LM haed and remove this
# Note: workaround for sharded_to_interleaved memory corruption (#15113)
h = ttnn.to_memory_config(h, ttnn.DRAM_MEMORY_CONFIG)

seq_len = h.shape[2]
MAX_MM_SEQ_LEN = 1024
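The replaced lines above are the "s2i" workaround from the commit title: instead of calling ttnn.sharded_to_interleaved, which hit memory corruption (#15113), the activation is moved to a DRAM interleaved layout with ttnn.to_memory_config before the LM head. A minimal sketch of the pattern; the helper name is illustrative, not from the commit:

import ttnn

def to_dram_interleaved(h):
    # Workaround for sharded_to_interleaved memory corruption (#15113):
    # explicitly move the (possibly sharded) tensor to DRAM interleaved
    # memory, which the LM head matmul expects.
    return ttnn.to_memory_config(h, ttnn.DRAM_MEMORY_CONFIG)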
