Skip to content

Commit

Permalink
#0: use more optimized slice implementation for mamba
Browse files (browse the repository at this point in the history)
  • Loading branch information
sjameelTT committed Oct 23, 2024
1 parent c23bd94 commit 890a4f2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
3 changes: 2 additions & 1 deletion models/demos/wormhole/mamba/tt/mamba_block.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -197,7 +197,8 @@ def forward(self, x):
for i in range(0, 4):
slice_start = (0, 0, x_ssm.shape[2] - (4 - i), 0)
slice_end = (1, 1, (x_ssm.shape[2] - (4 - i)) + 1, self.args.d_inner)
entry = ttnn.slice(x_ssm, slice_start, slice_end)
step = (1, 1, 1, 1)
entry = ttnn.slice(x_ssm, starts=slice_start, ends=slice_end, steps=step)
self.convolution_cache.set(self.configs["current_user"], i, entry)
ttnn.deallocate(entry)

Expand Down
8 changes: 5 additions & 3 deletions models/demos/wormhole/mamba/tt/mamba_conv.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -75,9 +75,11 @@ def prepare_input(self, input_tensor):
input_tensor_splits = []
split_size = self.config.input_channels // self.config.channels_split_factor
for i in range(self.config.channels_split_factor):
slice_start = ttnn.Shape((0, 0, 0, i * split_size))
slice_end = ttnn.Shape((1, self.config.input_length, 1, (i + 1) * split_size))
input_tensor_splits.append(ttnn.slice(input_tensor, slice_start, slice_end))
slice_start = (0, 0, 0, i * split_size)
slice_end = (1, self.config.input_length, 1, (i + 1) * split_size)
input_tensor_splits.append(
ttnn.slice(input_tensor, starts=slice_start, ends=slice_end, steps=(1, 1, 1, 1))
)
ttnn.deallocate(input_tensor)
return input_tensor_splits

Expand Down

0 comments on commit 890a4f2

Please sign in to comment.