#0: Fix tracing support with vision demo
yieldthought committed Nov 20, 2024
1 parent 768b956 commit 1223510
Showing 2 changed files with 23 additions and 8 deletions.
13 changes: 7 additions & 6 deletions models/demos/llama3/tt/multimodal/llama_vision_model.py
@@ -388,14 +388,13 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas
             rot_mats,
         ) = self.copy_host_to_device((tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats))
 
-        tt_xattn_mask, tt_full_text_mask_expand_1NSH = self.transform_decode_inputs_device(
+        tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH = self.transform_decode_inputs_device(
+            tt_h,
             tt_xattn_mask,
             tt_full_text_mask_expand_1NSH,
             B=tokens.shape[0],
         )
 
-        tt_h = ttnn.to_memory_config(tt_h, self.configuration.model_config["DECODE_RESIDUAL_MEMCFG"])
-
         return (
             tt_h,
             tt_xattn_mask,
@@ -415,7 +414,7 @@ def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_ro
         h = self.prepare_inputs_common(position_ids, tokens)
         tt_h = self.configuration.prepare_inputs_ttnn_decode(
             h,
-            ttnn.DRAM_MEMORY_CONFIG,  # L1 memory_configs are not respected for on_host tensors
+            None,  # on_host tensors have no memory_config
             on_host=True,
         )
 
@@ -489,7 +488,7 @@ def copy_host_to_device(self, host_tensors, device_tensors=None):
                 ttnn.copy_host_to_device_tensor(host_tensors[i], device_tensors[i])
             return device_tensors
 
-    def transform_decode_inputs_device(self, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B):
+    def transform_decode_inputs_device(self, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B):
         """
         Does any transformations on device tensors which are necessary before ttnn_decode_forward
         """
@@ -498,6 +497,8 @@ def transform_decode_inputs_device(self, tt_xattn_mask, tt_full_text_mask_expand
), f"Batch size must match max batch size. Got {B}, expected {self.configuration.max_batch_size}"
S = 1

tt_h = ttnn.to_memory_config(tt_h, self.configuration.model_config["DECODE_RESIDUAL_MEMCFG"])

tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT)
tt_xattn_mask = ttnn.reshape(
tt_xattn_mask,
@@ -530,7 +531,7 @@ def transform_decode_inputs_device(self, tt_xattn_mask, tt_full_text_mask_expand
             ),
         )
 
-        return (tt_xattn_mask, tt_full_text_mask_expand_1NSH)
+        return (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH)
 
     def process_output_prefill(self, tt_out, B, S):
         padded_seq_len = _get_padded_prefill_seqlen(S)
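For context, here is a minimal sketch (assumed caller code, not part of the commit) of how the reworked decode-input path fits together after this change: host tensors are copied onto the device first, and transform_decode_inputs_device now takes and returns tt_h, applying the DECODE_RESIDUAL_MEMCFG conversion on device so that it is recorded when the call happens inside a trace capture. The helper name and the host-input tuple below are hypothetical; copy_host_to_device and transform_decode_inputs_device are the methods shown in the diff above.

def prepare_traced_decode_inputs(model, tokens, host_inputs):
    """Hypothetical helper: copy host inputs to device and run the on-device transforms.

    `model` is assumed to expose copy_host_to_device / transform_decode_inputs_device
    as in the diff above; `host_inputs` is the 5-tuple prepared on host.
    """
    (
        tt_h,
        tt_xattn_mask,
        tt_full_text_mask_expand_1NSH,
        tt_position_id,
        rot_mats,
    ) = model.copy_host_to_device(host_inputs)

    # transform_decode_inputs_device now also moves tt_h into DECODE_RESIDUAL_MEMCFG
    # (previously applied in prepare_inputs_decode after the transform), so the
    # conversion runs on device and is captured if this call sits inside a trace.
    tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH = model.transform_decode_inputs_device(
        tt_h,
        tt_xattn_mask,
        tt_full_text_mask_expand_1NSH,
        B=tokens.shape[0],
    )
    return tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats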
18 changes: 16 additions & 2 deletions models/demos/llama3/tt/multimodal/vision_generator.py
@@ -185,9 +185,15 @@ def capture_trace(
         )
 
         trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
+        tt_h_trace_input = tt_h
         B = tokens.shape[0]
         # Do on-device transformations of inputs before forward
-        tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform = self.model.transform_decode_inputs_device(
+        (
+            tt_h,
+            tt_xattn_mask_transform,
+            tt_full_text_mask_expand_1NSH_transform,
+        ) = self.model.transform_decode_inputs_device(
+            tt_h,
             tt_xattn_mask,
             tt_full_text_mask_expand_1NSH,
             B=B,
@@ -204,7 +210,15 @@

         ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
 
-        return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats
+        return (
+            trace_id,
+            tt_logits_rm,
+            tt_h_trace_input,
+            tt_xattn_mask,
+            tt_full_text_mask_expand_1NSH,
+            tt_position_id,
+            rot_mats,
+        )
 
     def decode_forward_trace(
         self,
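For completeness, a hedged sketch of how the values returned by capture_trace are typically consumed on the replay path. The body of decode_forward_trace is not shown in this diff, so this mirrors the common ttnn trace pattern rather than the exact implementation: fresh host inputs are written into the persistent device tensors that were captured, the trace is re-executed, and the pre-allocated logits tensor holds the new output. The helper name and argument order are hypothetical.

import ttnn

def replay_decode_step(mesh_device, trace_id, tt_logits_rm, device_inputs, host_inputs):
    """Hypothetical replay helper, not the commit's decode_forward_trace.

    `device_inputs` are the persistent device tensors returned by capture_trace
    (e.g. tt_h_trace_input, tt_xattn_mask, tt_full_text_mask_expand_1NSH,
    tt_position_id, rot_mats); `host_inputs` are freshly prepared host tensors
    in the same order.
    """
    # Overwrite the captured device tensors in place with this step's inputs.
    for host_tensor, device_tensor in zip(host_inputs, device_inputs):
        ttnn.copy_host_to_device_tensor(host_tensor, device_tensor)

    # Re-run the captured command stream on command queue 0; the output is
    # written into the pre-allocated tt_logits_rm when the trace completes.
    ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
    return tt_logits_rm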
