Improve slice coverage in sweeps by adding N-dimensional slice support (#14115)

* #13830: add strided slice support for tiled layout
#13592: add slice to documentation

* #0: support N-d strided slice

* #0: add changes to ttnn side

* #0: add tests for N-d slice

* #0: consolidate assorted slice.cpp implementations

* #0: fix PCC issues on row major implementation and add adversarial tests (fixed ones included)
- remove bfloat8 from sweeps for now as we're focused on bfloat16 generality

* #0: add tests for fixes to edge cases and tensor-blocked impls

* #0: add wrap_index to common and switch from legacy and ttnn shapes to logical shape and padded shape (see the sketch below)

* #14100: add common function and delete old strided slice code

* #0: add TMs team to data_movement sweep codeowners

* #0: correct slice docs

* #0: use more optimized slice implementation for mamba

* #0: address comments related to docs

* #0: fix common file includes
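
For reference, here is a minimal Python sketch of the negative-index wrapping that a wrap_index-style helper performs. The helper added in this commit lives in ttnn/cpp/ttnn/operations/data_movement/common/common.cpp; this standalone version, its name, and its exact semantics are illustrative assumptions, not the committed implementation.

def wrap_index(index: int, dim_size: int) -> int:
    # Map a possibly-negative start/end index onto [0, dim_size],
    # mirroring Python/torch slicing conventions (illustrative only).
    return index + dim_size if index < 0 else index

# Example: end = -1 on a dimension of size 56 wraps to 55,
# so [0:-1] keeps 55 of the 56 elements.
assert wrap_index(-1, 56) == 55
assert wrap_index(3, 56) == 3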
sjameelTT authored Oct 25, 2024
1 parent e8c650e commit e9d39b3
Showing 16 changed files with 579 additions and 304 deletions.
1 change: 1 addition & 0 deletions CODEOWNERS
@@ -120,6 +120,7 @@ tests/sweep_framework/ @xanderchin @jdesousa-TT @sjameelTT
tests/sweep_framework/sweeps
tests/sweep_framework/sweeps/eltwise/ @patrickroberts @yan-zaretskiy @eyonland
tests/sweep_framework/sweeps/conv2d/ @nkpatel-tt @mywoodstock @shwetankTT @sankarmanoj-tt @pavlejosipovic
tests/sweep_framework/sweeps/data_movement/ @sjameelTT @ntarafdar @jaykru-tt @yugi957

# TTNN Distributed
ttnn/cpp/ttnn/distributed/ @cfjchu @ayerofieiev-tt @dmakoviichuk-tt
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -420,6 +420,7 @@ Data Movement
ttnn.reshape
ttnn.repeat
ttnn.repeat_interleave
ttnn.slice
ttnn.tilize
ttnn.tilize_with_val_padding
ttnn.fill_rm
3 changes: 2 additions & 1 deletion models/demos/wormhole/mamba/tt/mamba_block.py
@@ -197,7 +197,8 @@ def forward(self, x):
for i in range(0, 4):
slice_start = (0, 0, x_ssm.shape[2] - (4 - i), 0)
slice_end = (1, 1, (x_ssm.shape[2] - (4 - i)) + 1, self.args.d_inner)
entry = ttnn.slice(x_ssm, slice_start, slice_end)
step = (1, 1, 1, 1)
entry = ttnn.slice(x_ssm, starts=slice_start, ends=slice_end, steps=step)
self.convolution_cache.set(self.configs["current_user"], i, entry)
ttnn.deallocate(entry)

8 changes: 5 additions & 3 deletions models/demos/wormhole/mamba/tt/mamba_conv.py
@@ -75,9 +75,11 @@ def prepare_input(self, input_tensor):
input_tensor_splits = []
split_size = self.config.input_channels // self.config.channels_split_factor
for i in range(self.config.channels_split_factor):
slice_start = ttnn.Shape((0, 0, 0, i * split_size))
slice_end = ttnn.Shape((1, self.config.input_length, 1, (i + 1) * split_size))
input_tensor_splits.append(ttnn.slice(input_tensor, slice_start, slice_end))
slice_start = (0, 0, 0, i * split_size)
slice_end = (1, self.config.input_length, 1, (i + 1) * split_size)
input_tensor_splits.append(
ttnn.slice(input_tensor, starts=slice_start, ends=slice_end, steps=(1, 1, 1, 1))
)
ttnn.deallocate(input_tensor)
return input_tensor_splits
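
Both Mamba call sites now pass plain tuples and an explicit steps argument instead of ttnn.Shape objects. As a rough standalone illustration (the shapes, names, and channel split below mirror prepare_input but are assumptions, not code from this commit), the keyword-style call corresponds to ordinary torch slicing:

import torch

# Hypothetical shapes standing in for the Mamba activations above.
length, channels, split_size = 32, 64, 32
x = torch.randn(1, length, 1, channels)

i = 1
starts = (0, 0, 0, i * split_size)
ends = (1, length, 1, (i + 1) * split_size)
steps = (1, 1, 1, 1)

# ttnn.slice(x_tt, starts=starts, ends=ends, steps=steps) selects the same
# region as this torch indexing expression.
reference = x[starts[0]:ends[0]:steps[0],
              starts[1]:ends[1]:steps[1],
              starts[2]:ends[2]:steps[2],
              starts[3]:ends[3]:steps[3]]
assert reference.shape == (1, length, 1, split_size)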

@@ -172,7 +172,7 @@
{"dims": [8732, 4], "dim": 1, "start": 0, "end": -1, "step": 4},
{"dims": [8732, 4], "dim": 1, "start": 0, "end": 2},
],
"dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"dtype": [ttnn.bfloat16],
"layout": [ttnn.TILE_LAYOUT],
}
}
243 changes: 228 additions & 15 deletions tests/ttnn/unit_tests/operations/test_slice.py
@@ -311,7 +311,8 @@ def test_stride_slice_three_dim(c, h, w, begins_c, begins_h, begins_w, stride_c,
@pytest.mark.parametrize("begins", [[2, 0, 0, 2]])
@pytest.mark.parametrize("ends", [[18, 16, 16, 18]])
@pytest.mark.parametrize("strides", [[2, 2, 2, 2]])
def test_stride_slice_four_dim(dims, begins, ends, strides, device):
@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
def test_stride_slice_four_dim(dims, begins, ends, strides, layout, device):
torch.manual_seed(2005)
torch_input = torch.rand(dims)
slices = []
@@ -320,17 +321,39 @@ def test_stride_slice_four_dim(dims, begins, ends, strides, device):

torch_output = torch_input[slices[0], slices[1], slices[2], slices[3]]

ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16)
ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)
ttnn_output = ttnn_input[slices[0], slices[1], slices[2], slices[3]]
ttnn_output = ttnn.to_torch(ttnn_output)

assert_with_pcc(torch_output, ttnn_output, 0.99)


@pytest.mark.parametrize("dims", [[1, 56, 56, 96]])
@pytest.mark.parametrize("begins", [[0, 0, 0, 0]])
@pytest.mark.parametrize("ends", [[1, -1, 56, 96]])
@pytest.mark.parametrize("strides", [[1, 2, 1, 1]])
@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT])
def test_stride_slice_four_dim_tiled(dims, begins, ends, strides, layout, device):
torch.manual_seed(2005)
torch_input = torch.rand(dims)
slices = []
for i in range(len(dims)):
slices.append(slice(begins[i], ends[i], strides[i]))

torch_output = torch_input[slices[0], slices[1], slices[2], slices[3]]

ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)
ttnn_output = ttnn_input[slices[0], slices[1], slices[2], slices[3]]
ttnn_output = ttnn.to_torch(ttnn_output)

assert_with_pcc(torch_output, ttnn_output, 0.99)


# these tests are copied from the yolo customer use cases in #8920
def test_slice_usecase1(device):
@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
def test_slice_usecase1(layout, device):
torch_input = torch.randn(1, 3, 640, 640)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16)
ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)

torch_output = torch_input[..., ::2, ::2] # torch_output shape: [1, 3, 320, 320]
ttnn_output = ttnn_input[..., ::2, ::2]
Expand All @@ -339,9 +362,10 @@ def test_slice_usecase1(device):
assert_with_pcc(torch_output, ttnn_output, 0.99)


def test_slice_usecase2(device):
@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
def test_slice_usecase2(layout, device):
torch_input = torch.randn(1, 3, 640, 640)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16)
ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)

torch_output = torch_input[..., ::2, 1::2] # torch_output shape: [1, 3, 320, 320]
ttnn_output = ttnn_input[..., ::2, 1::2]
Expand All @@ -350,9 +374,10 @@ def test_slice_usecase2(device):
assert_with_pcc(torch_output, ttnn_output, 0.99)


def test_slice_usecase3(device):
@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
def test_slice_usecase3(layout, device):
torch_input = torch.randn(1, 3, 640, 640)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16)
ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)

torch_output = torch_input[..., 1::2, ::2] # torch_output shape: [1, 3, 320, 320]
ttnn_output = ttnn_input[..., 1::2, ::2]
Expand All @@ -361,9 +386,10 @@ def test_slice_usecase3(device):
assert_with_pcc(torch_output, ttnn_output, 0.99)


def test_slice_usecase4(device):
@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
def test_slice_usecase4(layout, device):
torch_input = torch.randn(1, 3, 640, 640)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16)
ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)

torch_output = torch_input[..., 1::2, 1::2] # torch_output shape: [1, 3, 320, 320]
ttnn_output = ttnn_input[..., 1::2, 1::2]
@@ -428,10 +454,10 @@ def test_slice_bert(input_shape, input_start, input_ends, layout, device):
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)
else:
if input_ends[-1] - input_start[-1] == 1:
if (input_ends[-1] - input_start[-1]) % 2 != 0:
pytest.skip("Cannot slice the last dimension to 1 in row major layout")
torch_input = torch.randn(input_shape, dtype=torch.float32)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.float32, layout=layout)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)

if len(input_shape) == 4:
torch_output = torch_input[
@@ -478,10 +504,10 @@ def test_ttnn_slice_bert(input_shape, input_start, input_ends, layout, memory_co
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)
else:
if input_ends[-1] - input_start[-1] == 1:
if (input_ends[-1] - input_start[-1]) % 2 != 0:
pytest.skip("Cannot slice the last dimension to 1 in row major layout")
torch_input = torch.randn(input_shape, dtype=torch.float32)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.float32, layout=layout)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)

if len(input_shape) == 4:
torch_output = torch_input[
@@ -558,7 +584,7 @@ def test_ttnn_slice_optimized_shapes(input_shape, input_start, input_ends, layou
if (input_ends[-1] - input_start[-1]) % 2:
pytest.skip("Cannot slice the last dimension to 1 in row major layout")
torch_input = torch.randn(input_shape, dtype=torch.float32)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.float32, layout=layout)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)

torch_output = torch_input[
input_start[0] : input_ends[0],
@@ -573,3 +599,190 @@ def test_ttnn_slice_optimized_shapes(input_shape, input_start, input_ends, layou

ttnn_output = ttnn.to_torch(ttnn_output)
assert_with_pcc(torch_output, ttnn_output, 0.99)


@pytest.mark.parametrize(
"input_shape, input_start, input_ends",
(
((1, 1, 1, 1, 256), (0, 0, 0, 0, 0), (1, 1, 1, 1, 255)),
((1, 1, 32, 32, 32), (0, 0, 0, 0, 0), (1, 1, 32, 32, 1)),
((1, 1, 32, 32, 64), (0, 0, 0, 0, 0), (1, 1, 32, 1, 32)),
((1, 1, 1, 64, 64), (0, 0, 0, 0, 0), (1, 1, 1, 1, 1)),
((4, 3, 2, 1, 4), (1, 1, 1, 0, 0), (1, 1, 2, 1, 4)),
),
)
@pytest.mark.parametrize(
"layout",
(ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT),
)
@pytest.mark.parametrize(
"memory_config",
(ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
)
def test_ttnn_slice_5d(input_shape, input_start, input_ends, layout, memory_config, device):
if layout == ttnn.TILE_LAYOUT:
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)
else:
if (input_ends[-1] - input_start[-1]) % 2:
pytest.skip("Cannot slice the last dimension to 1 in row major layout")
torch_input = torch.randn(input_shape, dtype=torch.float32)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)

torch_output = torch_input[
input_start[0] : input_ends[0],
input_start[1] : input_ends[1],
input_start[2] : input_ends[2],
input_start[3] : input_ends[3],
input_start[4] : input_ends[4],
]

ttnn_output = ttnn.slice(ttnn_input, input_start, input_ends, (1, 1, 1, 1, 1), memory_config=memory_config)

ttnn_output = ttnn.to_torch(ttnn_output)
assert_with_pcc(torch_output, ttnn_output, 0.99)


@pytest.mark.parametrize(
"input_shape, input_start, input_ends, input_stride",
(
((1, 1, 5, 1, 256), (0, 0, 0, 0, 0), (1, 1, 1, 1, 234), (1, 1, 1, 1, 1)),
((1, 2, 32, 32, 32), (0, 0, 0, 0, 0), (1, 1, 32, 32, 1), (1, 1, 1, 1, 1)),
((1, 1, 32, 32, 64), (0, 0, 0, 0, 0), (1, 1, 32, 1, 32), (1, 1, 2, 1, 2)),
((2, 1, 1, 64, 64), (1, 0, 0, 0, 0), (2, 1, 1, 1, 1), (1, 1, 1, 1, 1)),
((4, 3, 2, 1, 18), (1, 1, 1, 0, 0), (1, 1, 2, 1, -2), (1, 1, 1, 1, 2)),
),
)
@pytest.mark.parametrize(
"layout",
(ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT),
)
def test_slice_5d(input_shape, input_start, input_ends, input_stride, layout, device):
if layout == ttnn.TILE_LAYOUT:
if input_stride != (1, 1, 1, 1, 1):
pytest.skip("Cannot untilize 5D tensor")
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)
else:
if (input_ends[-1] - input_start[-1]) % 2:
pytest.skip("Cannot slice the last dimension to 1 in row major layout")
torch_input = torch.randn(input_shape, dtype=torch.float32)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout)

torch_output = torch_input[
input_start[0] : input_ends[0] : input_stride[0],
input_start[1] : input_ends[1] : input_stride[1],
input_start[2] : input_ends[2] : input_stride[2],
input_start[3] : input_ends[3] : input_stride[3],
input_start[4] : input_ends[4] : input_stride[4],
]
ttnn_output = ttnn_input[
input_start[0] : input_ends[0] : input_stride[0],
input_start[1] : input_ends[1] : input_stride[1],
input_start[2] : input_ends[2] : input_stride[2],
input_start[3] : input_ends[3] : input_stride[3],
input_start[4] : input_ends[4] : input_stride[4],
]

ttnn_output = ttnn.to_torch(ttnn_output)
assert_with_pcc(torch_output, ttnn_output, 0.99)


def test_slice_7d_strided(device):
torch_input = torch.randn(1, 1, 1, 1, 1, 1, 256)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT)

torch_output = torch_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:256:2]
ttnn_output = ttnn_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:256:2]

ttnn_output = ttnn.to_torch(ttnn_output)
assert_with_pcc(torch_output, ttnn_output, 0.99)


def test_slice_7d(device):
torch_input = torch.randn(1, 1, 1, 1, 1, 1, 256)
ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

torch_output = torch_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:200]
ttnn_output = ttnn_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:200]

ttnn_output = ttnn.to_torch(ttnn_output)
assert_with_pcc(torch_output, ttnn_output, 0.99)


@pytest.mark.parametrize(
"input_shape, dim, start, end, step, layout",
(
([1, 28, 56, 96], 2, 0, -1, 2, ttnn.TILE_LAYOUT), # Formerly bad pcc
([1, 56, 56, 96], 1, 0, -1, 2, ttnn.TILE_LAYOUT), # Formerly bad pcc
([8732, 4], 1, 0, 2, 1, ttnn.ROW_MAJOR_LAYOUT), # Formerly bad pcc
([1, 14, 28, 192], 2, 1, -1, 2, ttnn.TILE_LAYOUT), # Bad pcc on sweeps but not on unit test (low priority)
([1, 23, 40, 128], 3, 0, -1, 2, ttnn.TILE_LAYOUT), # Bad pcc on sweeps but not on unit test
([1, 28, 28, 256], 1, 1, -1, 2, ttnn.TILE_LAYOUT), # Bad pcc on sweeps but not on unit test
(
[1, 3],
1,
0,
-1,
1,
ttnn.TILE_LAYOUT,
), # works when you turn it into a 2D tensor (compared to [3] example in the next test)
),
)
def test_slice_adversarial_fixed(input_shape, dim, start, end, step, layout, device):
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)

slice_obj = slice(start, end, step)

# Prepare indices for slicing in the specified dimension
indices = [slice(None)] * len(input_shape) # By default, select all elements along every dimension
indices[dim] = slice_obj # Apply slicing to the target dimension
indices = tuple(indices)

# Apply slicing to the input_tensor
torch_output_tensor = torch_input[indices]

ttnn_tensor = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)
ttnn_output = ttnn_tensor[indices]

ttnn_output_tensor = ttnn.to_torch(ttnn_output)
assert_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999)


@pytest.mark.parametrize(
"input_shape, dim, start, end, step, layout",
(
([8732, 4], 1, 0, -1, 4, ttnn.TILE_LAYOUT), # Need tensor for this or a padding aware tiled kernel
([1, 7], 0, 0, -1, 1, ttnn.ROW_MAJOR_LAYOUT), # page size must equal buffer size
(
[1, 7, 71, 64],
3,
0,
-1,
1,
ttnn.ROW_MAJOR_LAYOUT,
), # An unpadding slice operation for a RowMajor layout on the output tensor requires the last dimension to be on a 32-bit boundary
([1, 8, 2, 2], 2, -1, -1, 1, ttnn.TILE_LAYOUT), # Buffer size and page size should be larger than 0 bytes
([3], 0, 0, -1, 1, ttnn.TILE_LAYOUT), # Difference in expected shape as it's a 1D tensor
),
)
def test_slice_adversarial(input_shape, dim, start, end, step, layout, device):
pytest.skip("These tests are expected to fail at the moment")
torch_input = torch.randn(input_shape, dtype=torch.bfloat16)

slice_obj = slice(start, end, step)

# Prepare indices for slicing in the specified dimension
indices = [slice(None)] * len(input_shape) # By default, select all elements along every dimension
indices[dim] = slice_obj # Apply slicing to the target dimension
indices = tuple(indices)

# Apply slicing to the input_tensor
torch_output_tensor = torch_input[indices]

ttnn_tensor = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16)
ttnn_output = ttnn_tensor[indices]

ttnn_output_tensor = ttnn.to_torch(ttnn_output)

assert_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999)
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
@@ -44,6 +44,7 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/kv_cache/kv_cache.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/kv_cache/kv_cache_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/kv_cache/device/update_cache_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/common/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/concat.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp