From e9d39b3c6c4b3faaf9f77160b8bbad0d6f4f5340 Mon Sep 17 00:00:00 2001 From: Saad Jameel <163029024+sjameelTT@users.noreply.github.com> Date: Thu, 24 Oct 2024 20:31:24 -0400 Subject: [PATCH] Improve slice coverage in sweeps by adding N-dimensional slice support (#14115) * #13830: add strided slice support for tiled layout #13592: add slice to documentation * #0: support N-d strided slice * #0: add changes to ttnn side * #0: add tests for N-d slice * #0: consolidate assorted slice.cpp implementations * #0: fix PCC issues on row major implementation and add adversarial tests (fixed ones included) - remove bfloat8 froms sweeps for now as we're focused on bfloat16 generality * #0: add tests for fixes to edge cases and tensor-blocked impls * #0: add wrap_index to common and switch legacy and ttnn shape to logical shape and padded shape * #14100: add common function and delete old strided slice code * #0: add TMs team to data_movement sweep codeowners * #0: correct slice docs * #0: use more optimized slice implementation for mamba * #0: address comments related to docs * #0: fix common file includes --- CODEOWNERS | 1 + docs/source/ttnn/ttnn/api.rst | 1 + models/demos/wormhole/mamba/tt/mamba_block.py | 3 +- models/demos/wormhole/mamba/tt/mamba_conv.py | 8 +- .../slice/slice_pytorch2_tiled.py | 2 +- .../ttnn/unit_tests/operations/test_slice.py | 243 ++++++++++++++++-- ttnn/CMakeLists.txt | 1 + .../data_movement/common/common.cpp | 46 ++++ .../data_movement/common/common.hpp | 32 +-- .../data_movement/common/kernels/common.hpp | 47 ++++ .../strided_slice_reader_rm_interleaved.cpp | 118 --------- ...strided_slice_reader_rm_interleaved_nd.cpp | 89 +++++++ .../slice/device/slice_program_factory.cpp | 43 ++-- .../operations/data_movement/slice/slice.cpp | 201 ++++++++------- .../data_movement/slice/slice_pybind.hpp | 33 ++- ttnn/ttnn/operations/core.py | 15 +- 16 files changed, 579 insertions(+), 304 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/common/common.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp delete mode 100644 ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved_nd.cpp diff --git a/CODEOWNERS b/CODEOWNERS index beb2eb6452f..ef6bf8072f2 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -120,6 +120,7 @@ tests/sweep_framework/ @xanderchin @jdesousa-TT @sjameelTT tests/sweep_framework/sweeps tests/sweep_framework/sweeps/eltwise/ @patrickroberts @yan-zaretskiy @eyonland tests/sweep_framework/sweeps/conv2d/ @nkpatel-tt @mywoodstock @shwetankTT @sankarmanoj-tt @pavlejosipovic +tests/sweep_framework/sweeps/data_movement/ @sjameelTT @ntarafdar @jaykru-tt @yugi957 # TTNN Distributed ttnn/cpp/ttnn/distributed/ @cfjchu @ayerofieiev-tt @dmakoviichuk-tt diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst index 30137696eec..2a044aa399c 100644 --- a/docs/source/ttnn/ttnn/api.rst +++ b/docs/source/ttnn/ttnn/api.rst @@ -420,6 +420,7 @@ Data Movement ttnn.reshape ttnn.repeat ttnn.repeat_interleave + ttnn.slice ttnn.tilize ttnn.tilize_with_val_padding ttnn.fill_rm diff --git a/models/demos/wormhole/mamba/tt/mamba_block.py b/models/demos/wormhole/mamba/tt/mamba_block.py index e3d333ae9f2..c895a15a17d 100644 --- a/models/demos/wormhole/mamba/tt/mamba_block.py +++ b/models/demos/wormhole/mamba/tt/mamba_block.py @@ -197,7 +197,8 @@ def forward(self, x): 
for i in range(0, 4): slice_start = (0, 0, x_ssm.shape[2] - (4 - i), 0) slice_end = (1, 1, (x_ssm.shape[2] - (4 - i)) + 1, self.args.d_inner) - entry = ttnn.slice(x_ssm, slice_start, slice_end) + step = (1, 1, 1, 1) + entry = ttnn.slice(x_ssm, starts=slice_start, ends=slice_end, steps=step) self.convolution_cache.set(self.configs["current_user"], i, entry) ttnn.deallocate(entry) diff --git a/models/demos/wormhole/mamba/tt/mamba_conv.py b/models/demos/wormhole/mamba/tt/mamba_conv.py index c4dd0d961ef..a2700198f83 100644 --- a/models/demos/wormhole/mamba/tt/mamba_conv.py +++ b/models/demos/wormhole/mamba/tt/mamba_conv.py @@ -75,9 +75,11 @@ def prepare_input(self, input_tensor): input_tensor_splits = [] split_size = self.config.input_channels // self.config.channels_split_factor for i in range(self.config.channels_split_factor): - slice_start = ttnn.Shape((0, 0, 0, i * split_size)) - slice_end = ttnn.Shape((1, self.config.input_length, 1, (i + 1) * split_size)) - input_tensor_splits.append(ttnn.slice(input_tensor, slice_start, slice_end)) + slice_start = (0, 0, 0, i * split_size) + slice_end = (1, self.config.input_length, 1, (i + 1) * split_size) + input_tensor_splits.append( + ttnn.slice(input_tensor, starts=slice_start, ends=slice_end, steps=(1, 1, 1, 1)) + ) ttnn.deallocate(input_tensor) return input_tensor_splits diff --git a/tests/sweep_framework/sweeps/data_movement/slice/slice_pytorch2_tiled.py b/tests/sweep_framework/sweeps/data_movement/slice/slice_pytorch2_tiled.py index b79e211bdcf..e35a28c1a5f 100644 --- a/tests/sweep_framework/sweeps/data_movement/slice/slice_pytorch2_tiled.py +++ b/tests/sweep_framework/sweeps/data_movement/slice/slice_pytorch2_tiled.py @@ -172,7 +172,7 @@ {"dims": [8732, 4], "dim": 1, "start": 0, "end": -1, "step": 4}, {"dims": [8732, 4], "dim": 1, "start": 0, "end": 2}, ], - "dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "dtype": [ttnn.bfloat16], "layout": [ttnn.TILE_LAYOUT], } } diff --git a/tests/ttnn/unit_tests/operations/test_slice.py b/tests/ttnn/unit_tests/operations/test_slice.py index 7882dde8697..a85b33ada9a 100644 --- a/tests/ttnn/unit_tests/operations/test_slice.py +++ b/tests/ttnn/unit_tests/operations/test_slice.py @@ -311,7 +311,8 @@ def test_stride_slice_three_dim(c, h, w, begins_c, begins_h, begins_w, stride_c, @pytest.mark.parametrize("begins", [[2, 0, 0, 2]]) @pytest.mark.parametrize("ends", [[18, 16, 16, 18]]) @pytest.mark.parametrize("strides", [[2, 2, 2, 2]]) -def test_stride_slice_four_dim(dims, begins, ends, strides, device): +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +def test_stride_slice_four_dim(dims, begins, ends, strides, layout, device): torch.manual_seed(2005) torch_input = torch.rand(dims) slices = [] @@ -320,7 +321,28 @@ def test_stride_slice_four_dim(dims, begins, ends, strides, device): torch_output = torch_input[slices[0], slices[1], slices[2], slices[3]] - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) + ttnn_output = ttnn_input[slices[0], slices[1], slices[2], slices[3]] + ttnn_output = ttnn.to_torch(ttnn_output) + + assert_with_pcc(torch_output, ttnn_output, 0.99) + + +@pytest.mark.parametrize("dims", [[1, 56, 56, 96]]) +@pytest.mark.parametrize("begins", [[0, 0, 0, 0]]) +@pytest.mark.parametrize("ends", [[1, -1, 56, 96]]) +@pytest.mark.parametrize("strides", [[1, 2, 1, 1]]) +@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT]) +def 
test_stride_slice_four_dim_tiled(dims, begins, ends, strides, layout, device): + torch.manual_seed(2005) + torch_input = torch.rand(dims) + slices = [] + for i in range(len(dims)): + slices.append(slice(begins[i], ends[i], strides[i])) + + torch_output = torch_input[slices[0], slices[1], slices[2], slices[3]] + + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) ttnn_output = ttnn_input[slices[0], slices[1], slices[2], slices[3]] ttnn_output = ttnn.to_torch(ttnn_output) @@ -328,9 +350,10 @@ def test_stride_slice_four_dim(dims, begins, ends, strides, device): # these tests are copy and paste from the yolo customers #8920 -def test_slice_usecase1(device): +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +def test_slice_usecase1(layout, device): torch_input = torch.randn(1, 3, 640, 640) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) torch_output = torch_input[..., ::2, ::2] # torch_output shape: [1, 3, 320, 320] ttnn_output = ttnn_input[..., ::2, ::2] @@ -339,9 +362,10 @@ def test_slice_usecase1(device): assert_with_pcc(torch_output, ttnn_output, 0.99) -def test_slice_usecase2(device): +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +def test_slice_usecase2(layout, device): torch_input = torch.randn(1, 3, 640, 640) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) torch_output = torch_input[..., ::2, 1::2] # torch_output shape: [1, 3, 320, 320] ttnn_output = ttnn_input[..., ::2, 1::2] @@ -350,9 +374,10 @@ def test_slice_usecase2(device): assert_with_pcc(torch_output, ttnn_output, 0.99) -def test_slice_usecase3(device): +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +def test_slice_usecase3(layout, device): torch_input = torch.randn(1, 3, 640, 640) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) torch_output = torch_input[..., 1::2, ::2] # torch_output shape: [1, 3, 320, 320] ttnn_output = ttnn_input[..., 1::2, ::2] @@ -361,9 +386,10 @@ def test_slice_usecase3(device): assert_with_pcc(torch_output, ttnn_output, 0.99) -def test_slice_usecase4(device): +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +def test_slice_usecase4(layout, device): torch_input = torch.randn(1, 3, 640, 640) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) torch_output = torch_input[..., 1::2, 1::2] # torch_output shape: [1, 3, 320, 320] ttnn_output = ttnn_input[..., 1::2, 1::2] @@ -428,10 +454,10 @@ def test_slice_bert(input_shape, input_start, input_ends, layout, device): torch_input = torch.randn(input_shape, dtype=torch.bfloat16) ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) else: - if input_ends[-1] - input_start[-1] == 1: + if (input_ends[-1] - input_start[-1]) % 2 != 0: pytest.skip("Cannot slice the last dimension to 1 in row major layout") torch_input = torch.randn(input_shape, dtype=torch.float32) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.float32, layout=layout) + ttnn_input = 
ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) if len(input_shape) == 4: torch_output = torch_input[ @@ -478,10 +504,10 @@ def test_ttnn_slice_bert(input_shape, input_start, input_ends, layout, memory_co torch_input = torch.randn(input_shape, dtype=torch.bfloat16) ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) else: - if input_ends[-1] - input_start[-1] == 1: + if (input_ends[-1] - input_start[-1]) % 2 != 0: pytest.skip("Cannot slice the last dimension to 1 in row major layout") torch_input = torch.randn(input_shape, dtype=torch.float32) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.float32, layout=layout) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) if len(input_shape) == 4: torch_output = torch_input[ @@ -558,7 +584,7 @@ def test_ttnn_slice_optimized_shapes(input_shape, input_start, input_ends, layou if (input_ends[-1] - input_start[-1]) % 2: pytest.skip("Cannot slice the last dimension to 1 in row major layout") torch_input = torch.randn(input_shape, dtype=torch.float32) - ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.float32, layout=layout) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) torch_output = torch_input[ input_start[0] : input_ends[0], @@ -573,3 +599,190 @@ def test_ttnn_slice_optimized_shapes(input_shape, input_start, input_ends, layou ttnn_output = ttnn.to_torch(ttnn_output) assert_with_pcc(torch_output, ttnn_output, 0.99) + + +@pytest.mark.parametrize( + "input_shape, input_start, input_ends", + ( + ((1, 1, 1, 1, 256), (0, 0, 0, 0, 0), (1, 1, 1, 1, 255)), + ((1, 1, 32, 32, 32), (0, 0, 0, 0, 0), (1, 1, 32, 32, 1)), + ((1, 1, 32, 32, 64), (0, 0, 0, 0, 0), (1, 1, 32, 1, 32)), + ((1, 1, 1, 64, 64), (0, 0, 0, 0, 0), (1, 1, 1, 1, 1)), + ((4, 3, 2, 1, 4), (1, 1, 1, 0, 0), (1, 1, 2, 1, 4)), + ), +) +@pytest.mark.parametrize( + "layout", + (ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT), +) +@pytest.mark.parametrize( + "memory_config", + (ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG), +) +def test_ttnn_slice_5d(input_shape, input_start, input_ends, layout, memory_config, device): + if layout == ttnn.TILE_LAYOUT: + torch_input = torch.randn(input_shape, dtype=torch.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) + else: + if (input_ends[-1] - input_start[-1]) % 2: + pytest.skip("Cannot slice the last dimension to 1 in row major layout") + torch_input = torch.randn(input_shape, dtype=torch.float32) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) + + torch_output = torch_input[ + input_start[0] : input_ends[0], + input_start[1] : input_ends[1], + input_start[2] : input_ends[2], + input_start[3] : input_ends[3], + input_start[4] : input_ends[4], + ] + + ttnn_output = ttnn.slice(ttnn_input, input_start, input_ends, (1, 1, 1, 1, 1), memory_config=memory_config) + + ttnn_output = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_output, ttnn_output, 0.99) + + +@pytest.mark.parametrize( + "input_shape, input_start, input_ends, input_stride", + ( + ((1, 1, 5, 1, 256), (0, 0, 0, 0, 0), (1, 1, 1, 1, 234), (1, 1, 1, 1, 1)), + ((1, 2, 32, 32, 32), (0, 0, 0, 0, 0), (1, 1, 32, 32, 1), (1, 1, 1, 1, 1)), + ((1, 1, 32, 32, 64), (0, 0, 0, 0, 0), (1, 1, 32, 1, 32), (1, 1, 2, 1, 2)), + ((2, 1, 1, 64, 64), (1, 0, 0, 0, 0), (2, 1, 1, 1, 1), (1, 1, 1, 1, 1)), + ((4, 3, 2, 1, 18), (1, 1, 
1, 0, 0), (1, 1, 2, 1, -2), (1, 1, 1, 1, 2)), + ), +) +@pytest.mark.parametrize( + "layout", + (ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT), +) +def test_slice_5d(input_shape, input_start, input_ends, input_stride, layout, device): + if layout == ttnn.TILE_LAYOUT: + if input_stride is not (1, 1, 1, 1, 1): + pytest.skip("Cannot untilize 5D tensor") + torch_input = torch.randn(input_shape, dtype=torch.bfloat16) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) + else: + if (input_ends[-1] - input_start[-1]) % 2: + pytest.skip("Cannot slice the last dimension to 1 in row major layout") + torch_input = torch.randn(input_shape, dtype=torch.float32) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=layout) + + torch_output = torch_input[ + input_start[0] : input_ends[0] : input_stride[0], + input_start[1] : input_ends[1] : input_stride[1], + input_start[2] : input_ends[2] : input_stride[2], + input_start[3] : input_ends[3] : input_stride[3], + input_start[4] : input_ends[4] : input_stride[4], + ] + ttnn_output = ttnn_input[ + input_start[0] : input_ends[0] : input_stride[0], + input_start[1] : input_ends[1] : input_stride[1], + input_start[2] : input_ends[2] : input_stride[2], + input_start[3] : input_ends[3] : input_stride[3], + input_start[4] : input_ends[4] : input_stride[4], + ] + + ttnn_output = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_output, ttnn_output, 0.99) + + +def test_slice_7d_strided(device): + torch_input = torch.randn(1, 1, 1, 1, 1, 1, 256) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT) + + torch_output = torch_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:256:2] + ttnn_output = ttnn_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:256:2] + + ttnn_output = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_output, ttnn_output, 0.99) + + +def test_slice_7d(device): + torch_input = torch.randn(1, 1, 1, 1, 1, 1, 256) + ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + torch_output = torch_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:200] + ttnn_output = ttnn_input[..., 0:1, 0:1, 0:1, 0:1, 0:1, 0:200] + + ttnn_output = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_output, ttnn_output, 0.99) + + +@pytest.mark.parametrize( + "input_shape, dim, start, end, step, layout", + ( + ([1, 28, 56, 96], 2, 0, -1, 2, ttnn.TILE_LAYOUT), # Formerly bad pcc + ([1, 56, 56, 96], 1, 0, -1, 2, ttnn.TILE_LAYOUT), # Formerly bad pcc + ([8732, 4], 1, 0, 2, 1, ttnn.ROW_MAJOR_LAYOUT), # Formerly bad pcc + ([1, 14, 28, 192], 2, 1, -1, 2, ttnn.TILE_LAYOUT), # Bad pcc on sweeps but not on unit test (low priority) + ([1, 23, 40, 128], 3, 0, -1, 2, ttnn.TILE_LAYOUT), # Bad pcc on sweeps but not on unit test + ([1, 28, 28, 256], 1, 1, -1, 2, ttnn.TILE_LAYOUT), # Bad pcc on sweeps but not on unit test + ( + [1, 3], + 1, + 0, + -1, + 1, + ttnn.TILE_LAYOUT, + ), # works when you turn it into a 2D tensor (compared to [3] example in the next test) + ), +) +def test_slice_adversarial_fixed(input_shape, dim, start, end, step, layout, device): + torch_input = torch.randn(input_shape, dtype=torch.bfloat16) + + slice_obj = slice(start, end, step) + + # Prepare indices for slicing in the specified dimension + indices = [slice(None)] * len(input_shape) # By default, select all elements along every dimension + indices[dim] = slice_obj # Apply slicing to the target dimension + indices = tuple(indices) + + # Apply slicing to the input_tensor + 
torch_output_tensor = torch_input[indices] + + ttnn_tensor = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) + ttnn_output = ttnn_tensor[indices] + + ttnn_output_tensor = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999) + + +@pytest.mark.parametrize( + "input_shape, dim, start, end, step, layout", + ( + ([8732, 4], 1, 0, -1, 4, ttnn.TILE_LAYOUT), # Need tensor for this or a padding aware tiled kernel + ([1, 7], 0, 0, -1, 1, ttnn.ROW_MAJOR_LAYOUT), # page size must equal buffer size + ( + [1, 7, 71, 64], + 3, + 0, + -1, + 1, + ttnn.ROW_MAJOR_LAYOUT, + ), # An unpadding slice operations for a RowMajor layout on the output tensor requires the last dimension to be on a 32 bit boundary + ([1, 8, 2, 2], 2, -1, -1, 1, ttnn.TILE_LAYOUT), # Buffer size and page size should be larger than 0 bytes + ([3], 0, 0, -1, 1, ttnn.TILE_LAYOUT), # Difference in expected shape as it's a 1D tensor + ), +) +def test_slice_adversarial(input_shape, dim, start, end, step, layout, device): + pytest.skip("These tests are expected to fail at the moment") + torch_input = torch.randn(input_shape, dtype=torch.bfloat16) + + slice_obj = slice(start, end, step) + + # Prepare indices for slicing in the specified dimension + indices = [slice(None)] * len(input_shape) # By default, select all elements along every dimension + indices[dim] = slice_obj # Apply slicing to the target dimension + indices = tuple(indices) + + # Apply slicing to the input_tensor + torch_output_tensor = torch_input[indices] + + ttnn_tensor = ttnn.from_torch(torch_input, device=device, layout=layout, dtype=ttnn.bfloat16) + ttnn_output = ttnn_tensor[indices] + + ttnn_output_tensor = ttnn.to_torch(ttnn_output) + + assert_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 5f1d7d0ff1d..5e814aef73e 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -44,6 +44,7 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/kv_cache/kv_cache.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/kv_cache/kv_cache_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/kv_cache/device/update_cache_op_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/common/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/concat.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp b/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp new file mode 100644 index 00000000000..b1dbec7c2c7 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/operations/data_movement/common/common.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp" + +namespace ttnn { +namespace operations { +namespace data_movement { + ttnn::Tensor pad_to_tile_vol(uint8_t queue_id, + const ttnn::Tensor& tensor, + const float value, + const bool use_multicore, + const std::optional& memory_config) { + auto logical_shape = tensor.get_logical_shape(); + auto padded_shape = tensor.get_padded_shape(); + auto rank = tensor.get_shape().rank(); + if (padded_shape.volume() % tt::constants::TILE_HW != 0) { + TT_ASSERT(rank >= 2, "rank of tensor to pad to tile must be at least 2."); + + auto padded_height = tt::round_up(padded_shape[-2], tt::constants::TILE_HEIGHT); + auto padded_width = tt::round_up(padded_shape[-1], tt::constants::TILE_WIDTH); + uint32_t num_non_hw_dims = rank - 2u; + auto padding_vec = std::vector>(num_non_hw_dims, {0,0}); + padding_vec.reserve(rank); + padding_vec.emplace_back(0, padded_height - padded_shape[-2]); + padding_vec.emplace_back(0, padded_width - padded_shape[-1]); + + constexpr bool pad_use_multicore = true; + auto padded_output = ttnn::pad(queue_id, + tensor, + padding_vec, + value, + use_multicore, + memory_config); + return padded_output; + } + return tensor; + } + uint32_t wrap_index(int index, int size) { + return index < 0 ? size + index : index; + } +} +} +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp index 2280dc608db..f82ef63ccf6 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp @@ -2,6 +2,9 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "ttnn/cpp/ttnn/tensor/types.hpp" +#include "ttnn/cpp/ttnn/tensor/tensor.hpp" + namespace ttnn { namespace operations { namespace data_movement { @@ -9,32 +12,9 @@ namespace data_movement { const ttnn::Tensor& tensor, const float value, const bool use_multicore, - const std::optional& memory_config) { - auto logical_shape = tensor.get_logical_shape(); - auto padded_shape = tensor.get_padded_shape(); - auto rank = tensor.get_shape().rank(); - if (padded_shape.volume() % tt::constants::TILE_HW != 0) { - TT_ASSERT(rank >= 2, "rank of tensor to pad to tile must be at least 2."); - - auto padded_height = tt::round_up(padded_shape[-2], tt::constants::TILE_HEIGHT); - auto padded_width = tt::round_up(padded_shape[-1], tt::constants::TILE_WIDTH); - uint32_t num_non_hw_dims = rank - 2u; - auto padding_vec = std::vector>(num_non_hw_dims, {0,0}); - padding_vec.reserve(rank); - padding_vec.emplace_back(0, padded_height - padded_shape[-2]); - padding_vec.emplace_back(0, padded_width - padded_shape[-1]); - - constexpr bool pad_use_multicore = true; - auto padded_output = ttnn::pad(queue_id, - tensor, - padding_vec, - value, - use_multicore, - memory_config); - return padded_output; - } - return tensor; - } + const std::optional& memory_config); + + uint32_t wrap_index(int index, int size); template struct MassagedOperationParams { diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp new file mode 100644 index 00000000000..bf7062ab92b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +// This file contains common kernel functions used in data movement device kernels +// It's best to copy and paste the functions in rather than include the header as code size will likely explode +// Best to separate in to cpp/hpp at some point to avoid the code size explosion but need to figure out the linking issues + +namespace tt::data_movement::common { + + // this function is useful for converting bfloat16 values to float32 + float bfloat16_to_float32(uint16_t bfloat16_data) { + uint32_t bits = static_cast(bfloat16_data) << 16; + + // Extract the sign bit + uint32_t sign = bits & 0x80000000; + + // Extract the exponent + uint32_t exponent = bits & 0x7F800000; + + // Extract the mantissa + uint32_t mantissa = bits & 0x007FFFFF; + + // Handle special cases + if (exponent == 0 && mantissa == 0) { + // Zero + return sign ? -0.0f : 0.0f; + } else if (exponent == 0x7F800000) { + if (mantissa == 0) { + // Infinity + return sign ? -__builtin_huge_valf() : __builtin_huge_valf(); + } else { + // NaN + return __builtin_nanf(""); + } + } + + // Assemble the float + union { + uint32_t u; + float f; + } ieee_float; + + ieee_float.u = sign | exponent | mantissa; + return ieee_float.f; + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved.cpp deleted file mode 100644 index ac3f9e2ac01..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved.cpp +++ /dev/null @@ -1,118 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" - -#ifdef DEBUG_PRINT -// this function is useful for printing bfloat16 values -#include "dprint.h" - -float bfloat16_to_float32(uint16_t bfloat16_data) { - uint32_t bits = static_cast(bfloat16_data) << 16; - - // Extract the sign bit - uint32_t sign = bits & 0x80000000; - - // Extract the exponent - uint32_t exponent = bits & 0x7F800000; - - // Extract the mantissa - uint32_t mantissa = bits & 0x007FFFFF; - - // Handle special cases - if (exponent == 0 && mantissa == 0) { - // Zero - return sign ? -0.0f : 0.0f; - } else if (exponent == 0x7F800000) { - if (mantissa == 0) { - // Infinity - return sign ? 
-__builtin_huge_valf() : __builtin_huge_valf(); - } else { - // NaN - return __builtin_nanf(""); - } - } - - // Assemble the float - union { - uint32_t u; - float f; - } ieee_float; - - ieee_float.u = sign | exponent | mantissa; - return ieee_float.f; -} -#endif - - -void kernel_main() { - - constexpr bool src0_is_dram = (bool) get_compile_time_arg_val(0); - constexpr uint32_t W = get_compile_time_arg_val(1); - constexpr uint32_t H = get_compile_time_arg_val(2); - constexpr uint32_t C = get_compile_time_arg_val(3); - constexpr uint32_t N = get_compile_time_arg_val(4); - - constexpr uint32_t stride_W = get_compile_time_arg_val(5); - constexpr uint32_t stride_H = get_compile_time_arg_val(6); - constexpr uint32_t stride_C = get_compile_time_arg_val(7); - constexpr uint32_t stride_N = get_compile_time_arg_val(8); - constexpr uint32_t page_size = get_compile_time_arg_val(9); - - const uint32_t src_addr = get_arg_val(0); - const uint32_t start_W = get_arg_val(1); - const uint32_t start_H = get_arg_val(2); - const uint32_t start_C = get_arg_val(3); - const uint32_t start_N = get_arg_val(4); - - const uint32_t end_W = get_arg_val(5); - const uint32_t end_H = get_arg_val(6); - const uint32_t end_C = get_arg_val(7); - const uint32_t end_N = get_arg_val(8); - - const InterleavedAddrGen s0 = { - .bank_base_address = src_addr, - .page_size = page_size - }; - - constexpr uint32_t cb_id_in0 = 0; - constexpr uint32_t cb_id_out0 = 24; - uint32_t src_buffer_l1_addr = get_write_ptr(cb_id_in0); - volatile tt_l1_ptr uint16_t* in_stick = reinterpret_cast(src_buffer_l1_addr); - constexpr uint32_t CH = C*H; - // TODO: optimize this kernel to read in multiple sticks at once - // TODO: add support for negative strides - // TODO: add axis support - for (uint32_t i = start_N; i < end_N; i+=stride_N) { - uint32_t iCH = i*CH; - for (uint32_t j = start_C; j < end_C; j+=stride_C) { - uint32_t jHplusiCH = j*H + iCH; - for (uint32_t k = start_H; k < end_H; k+=stride_H) { - - // relevant page/stick id - uint32_t src_stick_id = k + jHplusiCH; - - // read in entire stick and wait - we may want to allocate a CB and max out our reads before waiting - noc_async_read_page(src_stick_id, s0, src_buffer_l1_addr); - noc_async_read_barrier(); - - - // TODO: optimize when there's no slice or stride along W. In that case, we can just do a single read and write. - // reserve space in output buffer - cb_reserve_back(cb_id_out0, 1); - // write out element by element into output buffer - volatile tt_l1_ptr uint16_t* out_stick = reinterpret_cast(get_write_ptr(cb_id_out0)); - uint32_t out_stick_id = 0; - for (uint32_t l = start_W; l < end_W; l+=stride_W) { - out_stick[out_stick_id] = in_stick[l]; - out_stick_id++; - } - cb_push_back(cb_id_out0, 1); - } - } - } - - -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved_nd.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved_nd.cpp new file mode 100644 index 00000000000..792b9e1ee91 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved_nd.cpp @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" + +void kernel_main() { + + constexpr bool src0_is_dram = (bool) get_compile_time_arg_val(0); + constexpr uint32_t page_size = get_compile_time_arg_val(1); + constexpr uint32_t dims = get_compile_time_arg_val(2); + + const uint32_t src_addr = get_arg_val(0); + + // Initialize shape, starts, ends, strides + uint32_t shape[dims], starts[dims], ends[dims], strides[dims]; + for (uint32_t i = 1; i <= dims; i++) { + shape[i - 1] = get_arg_val(i); + starts[i - 1] = get_arg_val(i + dims); + ends[i - 1] = get_arg_val(i + 2*dims); + strides[i - 1] = get_arg_val(i + 3*dims); + } + + // Calculate the product array, excluding the last dimension + uint32_t prod[dims]; + for (uint32_t i = 0; i < dims - 1; i++) { + prod[i] = 1; + for (uint32_t j = i + 1; j < dims - 1; j++) { // Exclude the last dimension + prod[i] *= shape[j]; + } + } + prod[dims - 1] = 1; // Not used, but set to 1 for completeness + + const InterleavedAddrGen s0 = { + .bank_base_address = src_addr, + .page_size = page_size + }; + + constexpr uint32_t cb_id_in0 = 0; + constexpr uint32_t cb_id_out0 = 24; + uint32_t src_buffer_l1_addr = get_write_ptr(cb_id_in0); + volatile tt_l1_ptr uint16_t* in_stick = reinterpret_cast(src_buffer_l1_addr); + + + uint32_t index[dims - 1]; // To hold current index in each of the first dims-1 dimensions + for (uint32_t i = 0; i < dims - 1; i++) { + index[i] = starts[i]; // Initialize the index with the start values + } + + // Flag to determine when to terminate the loop + bool done = false; + + while (!done) { + // Calculate the base linear index based on the first dims-1 indices + uint32_t base_linear_index = 0; + for (uint32_t i = 0; i < dims - 1; i++) { + base_linear_index += index[i] * prod[i]; + } + + // Now iterate over the last dimension + uint32_t out_stick_id = 0; + // Perform the read operation + noc_async_read_page(base_linear_index, s0, src_buffer_l1_addr); + // Reserve space in the output buffer + cb_reserve_back(cb_id_out0, 1); + noc_async_read_barrier(); + for (uint32_t l = starts[dims - 1]; l < ends[dims - 1]; l += strides[dims - 1]) { + // Write the element into the output buffer + volatile tt_l1_ptr uint16_t* out_stick = reinterpret_cast(get_write_ptr(cb_id_out0)); + out_stick[out_stick_id] = in_stick[l]; // Assuming you write one element at a time + out_stick_id++; + } + cb_push_back(cb_id_out0, 1); + + // Increment the indices for the first dims-1 dimensions + for (int32_t i = dims - 2; i >= 0; i--) { // Start from the last of the first dims-1 + index[i] += strides[i]; + if (index[i] < ends[i]) { + break; // Successfully incremented this dimension, no carry over + } else { + index[i] = starts[i]; // Reset this dimension and carry over to the next + if (i == 0) { + done = true; // If the first dimension is reset, we've completed all iterations + } + } + } + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp index 2585f3561bc..e7215850ea1 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp @@ -59,7 +59,7 @@ inline std::vector, std::vector>> get_ accumulated_total_per_dim[i] = num_total_dim * accumulated_total_per_dim[i - 1]; } - uint32_t unpadded_row_size_bytes_offset = tt::round_up(unpadded_row_size_bytes, TILE_WIDTH / 2); + uint32_t unpadded_row_size_bytes_offset = 
output_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? tt::round_up(unpadded_row_size_bytes, TILE_WIDTH) : tt::round_up(unpadded_row_size_bytes, TILE_WIDTH / 2); vector common_reader_kernel_args = { input_tensor.buffer()->address() + output_tensor_start[-1] * output_tensor.element_size(), @@ -261,7 +261,7 @@ operation::ProgramWithCallbacks slice_rm_multi_core( return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; } -operation::ProgramWithCallbacks slice_rm_strided_single_core(const Tensor& a, Tensor& output, const tt::tt_metal::LegacyShape& output_tensor_start, const tt::tt_metal::LegacyShape& output_tensor_end, const tt::tt_metal::LegacyShape& step) { +operation::ProgramWithCallbacks slice_rm_strided_single_core_n_dims(const Tensor& a, Tensor& output, const tt::tt_metal::LegacyShape& output_tensor_start, const tt::tt_metal::LegacyShape& output_tensor_end, const tt::tt_metal::LegacyShape& step) { // TODO: multi core implementation - work division is not trivial as we need to determine the N/C/H/W start and end points for each split, and base that off stride tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); const tt::tt_metal::LegacyShape output_shape = output.get_legacy_shape(); @@ -291,20 +291,13 @@ operation::ProgramWithCallbacks slice_rm_strided_single_core(const Tensor& a, Te tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved.cpp", + "ttnn/cpp/ttnn/operations/data_movement/slice/device/kernels/dataflow/strided_slice_reader_rm_interleaved_nd.cpp", core, tt::tt_metal::ReaderDataMovementConfig( { src_is_dram, - input_shape[3], - input_shape[2], - input_shape[1], - input_shape[0], - step[3], - step[2], - step[1], - step[0], (uint32_t) page_size_input, + (uint32_t) input_shape.rank(), } )); @@ -320,26 +313,24 @@ operation::ProgramWithCallbacks slice_rm_strided_single_core(const Tensor& a, Te } )); - tt::tt_metal::SetRuntimeArgs( - program, unary_reader_kernel_id, core, - { - a.buffer()->address(), - output_tensor_start[3], - output_tensor_start[2], - output_tensor_start[1], - output_tensor_start[0], - output_tensor_end[3], - output_tensor_end[2], - output_tensor_end[1], - output_tensor_end[0], + std::vector reader_runtime_args; + reader_runtime_args.reserve(1 + (4*input_shape.rank())); + reader_runtime_args.push_back(a.buffer()->address()); - }); + reader_runtime_args.insert(reader_runtime_args.end(), input_shape.begin(), input_shape.end()); + reader_runtime_args.insert(reader_runtime_args.end(), output_tensor_start.begin(), output_tensor_start.end()); + reader_runtime_args.insert(reader_runtime_args.end(), output_tensor_end.begin(), output_tensor_end.end()); + reader_runtime_args.insert(reader_runtime_args.end(), step.begin(), step.end()); + + tt::tt_metal::SetRuntimeArgs( + program, unary_reader_kernel_id, core, reader_runtime_args); + uint32_t pages = output.volume() / output_shape[-1]; tt::tt_metal::SetRuntimeArgs( program, unary_writer_kernel_id, core, { output.buffer()->address(), - output_shape[0]*output_shape[1]*output_shape[2], + pages, }); auto override_address_callback = [unary_reader_kernel_id, unary_writer_kernel_id]( @@ -962,7 +953,7 @@ operation::ProgramWithCallbacks slice_multi_core( case Layout::ROW_MAJOR: return a.is_sharded() ? slice_rm_multi_core_sharded(a, output, output_tensor_start, output_tensor_end) : (has_step ? 
- slice_rm_strided_single_core(a, output, output_tensor_start, output_tensor_end, step) : + slice_rm_strided_single_core_n_dims(a, output, output_tensor_start, output_tensor_end, step) : slice_rm_multi_core(a, output, output_tensor_start, output_tensor_end)); case Layout::TILE: return slice_tile_multi_core(a, output, output_tensor_start, output_tensor_end); default: TT_ASSERT(false, "Unsupported Layout"); diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp index cff6bf540ec..e65f1bba9ce 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp @@ -10,17 +10,10 @@ #include "ttnn/cpp/ttnn/operations/creation.hpp" #include "ttnn/common/constants.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/copy/copy.hpp" - +#include "ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/common/common.hpp" namespace ttnn::operations::data_movement { -namespace detail { - static inline uint32_t wrap_index(int index, int size) { - return index < 0 ? size + index : index; - } - static inline uint32_t round_up_to_multiple_of_32(uint32_t value) { - return value == 0 ? 32u : ((value + 31u) & ~31); - } -} template ttnn::Tensor SliceOperation::invoke( @@ -33,8 +26,10 @@ ttnn::Tensor SliceOperation::invoke( const std::optional& optional_output_tensor) { // Ensure start and end vectors have matching sizes and correct tensor rank - uint32_t input_rank = input_tensor.get_shape().rank(); - const auto &input_shape = input_tensor.get_shape(); + + const auto &input_shape = input_tensor.get_logical_shape(); + uint32_t input_rank = input_shape.rank(); + bool no_step = std::ranges::all_of(step, [](uint32_t s) { return s == 1; }); bool starts_zero = std::ranges::all_of(begins, [](uint32_t s) { return s == 0; }); bool ends_max = true; @@ -44,6 +39,7 @@ ttnn::Tensor SliceOperation::invoke( break; } } + if (no_step && starts_zero && ends_max) { if (input_tensor.storage_type() == StorageType::DEVICE) { auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); @@ -56,57 +52,80 @@ ttnn::Tensor SliceOperation::invoke( TT_FATAL(begins.size() == ends.size(), "Start {} and end {} must have the same size", begins.size(), ends.size()); TT_FATAL(step.size() == begins.size(), "Step {} must have the same size as start {} and end", step.size(), begins.size()); - // Create modified vectors with appropriate size (max rank 4) and wrap indices - Tensor input_4d = (input_rank < 4) ? 
ttnn::unsqueeze_to_4D(input_tensor) : input_tensor; - auto padded_4d_shape = input_4d.get_legacy_shape(); - std::array modified_begins = {0, 0, 0, 0}; - std::array modified_ends = {padded_4d_shape[0], padded_4d_shape[1], padded_4d_shape[2], padded_4d_shape[3]}; - std::array modified_step = {1, 1, 1, 1}; - uint32_t rank_diff = 4 - input_rank; - - // Ideally we would call the 4D array implementation of slice here and then handle reshapes and padding outside of it but it's not ready yet - // Insert values for start, step and end, wrapping indices using detail::wrap_index - // should be able to skip wrap_index if T is uint32_t + bool rm_only = !no_step && input_tensor.get_layout() == Layout::TILE; + Tensor input = input_tensor; + if (rm_only) { + TT_FATAL(input.get_dtype() == DataType::BFLOAT16, "Strided slice is not supported for BFLOAT8 tensors"); + input = ttnn::to_layout(input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); + } + + // Unsqueeze tensor to 4D if necessary + if (input_rank < 4) { + input = ttnn::unsqueeze_to_4D(input); + } + + auto padded_shape = input.get_padded_shape(); + size_t adjusted_rank = padded_shape.rank(); // Now adjusted to 4 after unsqueeze + + // Create modified vectors with wrapped indices and adjust them to match the tensor's rank + std::vector modified_begins(adjusted_rank, 0); + std::vector modified_ends = padded_shape.as_vector(); + std::vector modified_step(adjusted_rank, 1); + + size_t rank_diff = adjusted_rank - input_rank; + + // Wrap indices and adjust begins, ends, and step for (size_t i = 0; i < begins.size(); ++i) { - modified_begins[i + rank_diff] = detail::wrap_index(begins[i], input_tensor.get_shape()[i]); - modified_ends[i + rank_diff] = detail::wrap_index(ends[i], input_tensor.get_shape()[i]); - modified_step[i + rank_diff] = step[i]; + size_t idx = i + rank_diff; + + if constexpr (std::is_signed_v) { + modified_begins[idx] = wrap_index(begins[i], input_shape[i]); + modified_ends[idx] = wrap_index(ends[i], input_shape[i]); + modified_step[idx] = static_cast(step[i]); + } else { + modified_begins[idx] = begins[i]; + modified_ends[idx] = ends[i]; + modified_step[idx] = step[i]; + } } - auto output_dim_i = [&modified_begins, &modified_step] (size_t i, const std::array &modified_ends) { + auto output_dim_i = [&modified_begins, &modified_step](size_t i, const std::vector &modified_ends) { return (modified_ends[i] - modified_begins[i] + modified_step[i] - 1) / modified_step[i]; }; - std::array padded_ends = modified_ends; - if (input_tensor.layout() == Layout::TILE) { - padded_ends[2] = detail::round_up_to_multiple_of_32(padded_ends[2]); - padded_ends[3] = detail::round_up_to_multiple_of_32(padded_ends[3]); + std::vector padded_ends = modified_ends; + if (input.layout() == Layout::TILE) { + padded_ends[adjusted_rank - 2] = std::max(tt::round_up(padded_ends[adjusted_rank - 2], tt::constants::TILE_HEIGHT), tt::constants::TILE_HEIGHT); + padded_ends[adjusted_rank - 1] = std::max(tt::round_up(padded_ends[adjusted_rank - 1], tt::constants::TILE_WIDTH), tt::constants::TILE_WIDTH); } - std::vector actual_shape, padded_shape; + + std::vector actual_shape, final_padded_shape; actual_shape.reserve(input_rank); - padded_shape.reserve(input_rank); + final_padded_shape.reserve(input_rank); bool empty = false; - for (int i = 0; i < input_rank; ++i) { - // Check that end indices are greater than or equal to start indices (empty tensor where end=start is supported) - TT_FATAL(modified_ends[i + rank_diff] >= modified_begins[i + rank_diff], "End {} 
must be greater than or equal to start {}", modified_ends[i + rank_diff], modified_begins[i + rank_diff]); - auto val = output_dim_i(i + rank_diff, modified_ends); + + // Compute actual and padded shapes for the original input rank + for (size_t i = 0; i < input_rank; ++i) { + size_t idx = i + rank_diff; + TT_FATAL(modified_ends[idx] >= modified_begins[idx], "End {} must be greater than or equal to start {}", modified_ends[idx], modified_begins[idx]); + auto val = output_dim_i(idx, modified_ends); if (val == 0) { empty = true; } actual_shape.push_back(val); - padded_shape.push_back(std::max(output_dim_i(i + rank_diff, padded_ends), (uint32_t)1)); + final_padded_shape.push_back(std::max(output_dim_i(idx, padded_ends), static_cast(1))); } - ttnn::Shape output_shape(actual_shape, padded_shape); - // PyTorch supports final dimension = 0 (start = end, where end is inclusive) so >= is okay, just return an empty tensor + ttnn::Shape output_shape(actual_shape, final_padded_shape); + if (empty) { TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Host tensor slice cannot return a scalar or empty tensor"); return ttnn::empty(output_shape, input_tensor.dtype(), input_tensor.layout(), input_tensor.device(), memory_config_arg.value_or(input_tensor.memory_config())); } - // Early exit if slice is a no-op (ends = padding ends and step = 1 for all dimensions) - if (tt::tt_metal::LegacyShape(padded_shape) == input_tensor.get_legacy_shape() and no_step) { + // Early exit if slice is a no-op + if (ttnn::SimpleShape(final_padded_shape) == input.get_padded_shape() && no_step) { if (input_tensor.storage_type() == StorageType::DEVICE) { auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); auto res = ttnn::to_memory_config(input_tensor, memory_config, std::nullopt); @@ -117,33 +136,28 @@ ttnn::Tensor SliceOperation::invoke( if (input_tensor.storage_type() != StorageType::DEVICE) { TT_FATAL(no_step, "Host tensor slice does not support strides"); - // if we support negative strides, we can't do this early exit - if (input_tensor.get_legacy_shape() == actual_shape) { + if (input_tensor.get_padded_shape() == actual_shape) { return input_tensor; } else { - auto input_4d_rm = ttnn::to_layout(input_4d, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); - auto output_4d = input_4d_rm.unpad(ttnn::SimpleShape(modified_begins), ttnn::SimpleShape(modified_ends)); - auto output_4d_rm = ttnn::to_layout(output_4d, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr); - return ttnn::reshape(output_4d_rm, output_shape); + input = ttnn::to_layout(input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); + input = input.unpad(ttnn::SimpleShape(modified_begins), ttnn::SimpleShape(modified_ends)); + input = ttnn::to_layout(input, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr); + return ttnn::reshape(input, output_shape); } - } - else { - // TODO: Generalize this early exit of slice for other cases - auto& input_tensor_shape = input_4d.get_legacy_shape(); + } else { + const auto& input_tensor_shape = input.get_padded_shape(); auto memory_config = optional_output_tensor.has_value() ? 
optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); - if (input_4d.is_sharded() && input_4d.memory_config() == memory_config && - input_tensor_shape.rank() > 1) { + + if (input.is_sharded() && input.memory_config() == memory_config && input_tensor_shape.rank() > 1) { TT_FATAL(no_step, "Sharded tensor slice implementation does not support striding"); uint32_t i; - // Require all leading dims to be 1 (TODO: This can be relaxed to support outermost non-1 dim unpadding) bool in_place_unpad = true; - for (i = 0; i < input_4d.get_legacy_shape().rank() - 2; ++i) { - in_place_unpad &= - modified_begins[i] == 0 && modified_ends[i] == 1 && input_tensor_shape[i] == 1; + for (i = 0; i < input_tensor_shape.rank() - 2; ++i) { + in_place_unpad &= modified_begins[i] == 0 && modified_ends[i] == 1 && input_tensor_shape[i] == 1; } in_place_unpad &= modified_begins[i] == 0 && - tt::div_up(modified_ends[i], input_4d.shard_spec().value().shape[0]) == - tt::div_up(input_tensor_shape[i], input_4d.shard_spec().value().shape[0]); + tt::div_up(modified_ends[i], input.shard_spec().value().shape[0]) == + tt::div_up(input_tensor_shape[i], input.shard_spec().value().shape[0]); i++; in_place_unpad &= modified_begins[i] == 0 && modified_ends[i] == input_tensor_shape[i]; if (in_place_unpad) { @@ -152,16 +166,18 @@ ttnn::Tensor SliceOperation::invoke( } auto res = operation::run( - SliceDeviceOperation{ - tt::tt_metal::LegacyShape(modified_begins), - tt::tt_metal::LegacyShape(padded_ends), - modified_step, - memory_config}, - {input_4d}, {}, {optional_output_tensor}, queue_id) + SliceDeviceOperation{ + tt::tt_metal::LegacyShape(modified_begins), + tt::tt_metal::LegacyShape(padded_ends), + tt::tt_metal::LegacyShape(modified_step), + memory_config}, + {input}, {}, {optional_output_tensor}, queue_id) .at(0); - return ttnn::reshape(res, output_shape); + res = ttnn::reshape(res, output_shape); + return rm_only ? ttnn::to_layout(res, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr) : res; } } + template ttnn::Tensor SliceOperation::invoke( const ttnn::Tensor& input_tensor, @@ -184,7 +200,7 @@ ttnn::Tensor SliceOperation::invoke( const std::optional& memory_config_arg, const std::optional& optional_output_tensor) { - const auto& padded_input_shape = input_tensor.get_shape().with_tile_padding(); + const auto& padded_input_shape = input_tensor.get_padded_shape(); TT_FATAL(padded_input_shape.rank() == 4, "Input tensor must have rank 4"); bool no_step = step[0] == 1 && step[1] == 1 && step[2] == 1 && step[3] == 1; @@ -198,13 +214,18 @@ ttnn::Tensor SliceOperation::invoke( } return input_tensor; } + bool rm_only = !no_step && input_tensor.get_layout() == Layout::TILE; + ttnn::Tensor input = input_tensor; + if (rm_only) { + input = ttnn::to_layout(input_tensor, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); + } - const bool tiled = input_tensor.get_layout() == Layout::TILE; - bool on_device = input_tensor.storage_type() == StorageType::DEVICE; + const bool tiled = input.get_layout() == Layout::TILE; + bool on_device = input.storage_type() == StorageType::DEVICE; std::array actual_shape; std::array padded_shape; - const std::array padded_ends = tiled ? std::array({ends[0], ends[1], detail::round_up_to_multiple_of_32(ends[2]), detail::round_up_to_multiple_of_32(ends[3])}) : ends; + const std::array padded_ends = tiled ? 
std::array({ends[0], ends[1], std::max(tt::round_up(ends[2], tt::constants::TILE_HEIGHT), tt::constants::TILE_HEIGHT), std::max(tt::round_up(ends[3], tt::constants::TILE_WIDTH), tt::constants::TILE_WIDTH)}) : ends; bool empty = false; for (int i = 0; i < 4; ++i) { TT_FATAL(ends[i] >= begins[i], "End {} must be greater than or equal to start {}", ends[i], begins[i]); @@ -219,58 +240,59 @@ ttnn::Tensor SliceOperation::invoke( if (empty) { TT_FATAL(on_device, "Host tensor slice cannot return a scalar or empty tensor"); - auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); - return ttnn::empty(output_shape, input_tensor.dtype(), input_tensor.layout(), - input_tensor.device(), memory_config); + auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); + return ttnn::empty(output_shape, input.dtype(), input_tensor.layout(), + input.device(), memory_config); } // Early exit if slice is a no-op if (ttnn::Shape(padded_shape) == padded_input_shape && no_step) { - if (input_tensor.storage_type() == StorageType::DEVICE) { - auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); - auto res = ttnn::to_memory_config(input_tensor, memory_config, std::nullopt); + if (input.storage_type() == StorageType::DEVICE) { + auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); + auto res = ttnn::to_memory_config(input, memory_config, std::nullopt); return ttnn::reshape(res, output_shape); } - return ttnn::reshape(input_tensor, output_shape); // change to view + return ttnn::reshape(input, output_shape); // change to view } if (on_device) { - auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); + auto memory_config = optional_output_tensor.has_value() ? 
optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); // Check for in-place unpad optimization - if (input_tensor.is_sharded() && input_tensor.memory_config() == memory_config && padded_input_shape.rank() > 1) { + if (input.is_sharded() && input.memory_config() == memory_config && padded_input_shape.rank() > 1) { TT_FATAL(no_step, "Sharded tensor slice implementation does not support striding"); bool in_place_unpad = true; for (int i = 0; i < 2; ++i) { in_place_unpad &= begins[i] == 0 && ends[i] == 1 && padded_input_shape[i] == 1; } in_place_unpad &= begins[2] == 0 && - tt::div_up(ends[2], input_tensor.shard_spec().value().shape[0]) == - tt::div_up(padded_input_shape[2], input_tensor.shard_spec().value().shape[0]); + tt::div_up(ends[2], input.shard_spec().value().shape[0]) == + tt::div_up(padded_input_shape[2], input.shard_spec().value().shape[0]); in_place_unpad &= begins[3] == 0 && ends[3] == padded_input_shape[3]; if (in_place_unpad) { - return ttnn::reshape(input_tensor, output_shape); + return ttnn::reshape(input, output_shape); } } - auto res = operation::run( + input = operation::run( SliceDeviceOperation{ begins, padded_ends, step, memory_config}, - {input_tensor}, {}, {optional_output_tensor}, queue_id)[0]; - return ttnn::reshape(res, output_shape); + {input}, {}, {optional_output_tensor}, queue_id)[0]; + input = ttnn::reshape(input, output_shape); + return rm_only ? ttnn::to_layout(input, input.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr) : input; } TT_FATAL(no_step, "Host tensor slice does not support strides"); - if (input_tensor.get_legacy_shape() == actual_shape) { - return input_tensor; + if (input.get_padded_shape() == actual_shape) { + return input; } else { - auto input_4d_rm = ttnn::to_layout(input_tensor, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); + auto input_4d_rm = ttnn::to_layout(input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); auto output_4d = input_4d_rm.unpad(ttnn::SimpleShape(begins), ttnn::SimpleShape(ends)); - auto output_4d_rm = ttnn::to_layout(output_4d, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr); + auto output_4d_rm = ttnn::to_layout(output_4d, input.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr); return ttnn::reshape(output_4d_rm, output_shape); } } @@ -301,7 +323,6 @@ ttnn::Tensor SliceOperation::invoke( return SliceOperation::invoke(ttnn::DefaultQueueId, input_tensor, output_tensor_start, output_tensor_end, step, memory_config_arg); } - template ttnn::Tensor SliceOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp index 342b282ec96..1aeae87eefd 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp @@ -17,26 +17,31 @@ namespace py = pybind11; void bind_slice(py::module& module) { auto doc = R"doc( - slice(input_tensor: ttnn.Tensor, slice_start: List[int[tensor rank], slice_end: List[int[tensor rank], value: Union[int, float], *, Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor - Returns a sliced tensor. If the input tensor is on host, the slice will be performed on host, and if its on device it will be performed on device. - Equivalent pytorch code: + Args: + input_tensor: Input Tensor. + slice_start: Start indices of input tensor. 
Values along each dim must be < input_tensor_shape[i]. + slice_end: End indices of input tensor. Values along each dim must be < input_tensor_shape[i]. + slice_step: (Optional[List[int[tensor rank]]) Step size for each dim. Default is None, which works out be 1 for each dimension. - .. code-block:: python + Keyword Args: + memory_config Memory Config of the output tensor + queue_id (uint8, optional) command queue id - output_tensor = input_tensor[output_start: output_end] + Returns: + ttnn.Tensor: the output tensor. - Args: - * :attr:`input_tensor`: Input Tensor. - * :attr:`slice_start`: Start indices of input tensor. Values along each dim must be < input_tensor_shape[i]. - * :attr:`slice_end`: End indices of input tensor. Values along each dim must be < input_tensor_shape[i]. - * :attr:`step` (Optional[List[int[tensor rank]]): Step size for each dim. Default is None, which works out be 1 for each dimension. + Example: + >>> tensor = ttnn.slice(ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device), [0, 0, 0, 0], [1, 1, 64, 16], [1, 1, 2, 1]) + >>> print(tensor.shape) + [1, 1, 32, 16] + >>> input = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device) + >>> output = ttnn.slice(input, [0, 0, 0, 0], [1, 1, 32, 32]) + >>> print(output.shape) + [1, 1, 32, 32] + )doc"; - Keyword Args: - * :attr:`memory_config`: Memory Config of the output tensor - * :attr:`queue_id` (Optional[uint8]): command queue id - )doc"; // TODO: implementing the array version and overloading the pybind with all the possible array sizes is better than a vector with a fixed size default value using OperationType = decltype(ttnn::slice); diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 45e024ac82f..2c820effa02 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -64,18 +64,13 @@ def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor: if len(slices) > input_rank: raise RuntimeError(f"Too many slices for tensor of rank {input_rank}") - if input_rank <= 4: - slice_start = [_slice.start if _slice.start is not None else 0 for _slice in slices] - slice_end = [ - _slice.stop if _slice.stop is not None else input_tensor.shape[i] for i, _slice in enumerate(slices) - ] - slice_step = [_slice.step if _slice.step is not None else 1 for _slice in slices] + slice_start = [_slice.start if _slice.start is not None else 0 for _slice in slices] + slice_end = [_slice.stop if _slice.stop is not None else input_tensor.shape[i] for i, _slice in enumerate(slices)] + slice_step = [_slice.step if _slice.step is not None else 1 for _slice in slices] - output = ttnn.slice(input_tensor, slice_start, slice_end, slice_step) + output = ttnn.slice(input_tensor, slice_start, slice_end, slice_step) - return output - - raise NotImplementedError + return output def _preprocess_shape(input_shape, shape):
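
For readers who want to exercise the new N-dimensional strided slice path outside the unit tests above, a minimal usage sketch follows. It is not part of the patch: it assumes a working ttnn build, a device available at id 0, and the ttnn.slice call signature shown in the new tests (tensor, starts, ends, steps); adjust shapes and layout to your setup.

# Illustrative sketch only (not part of this patch): a 5D row-major strided slice,
# mirroring test_slice_5d. Assumes torch and ttnn are installed and device 0 exists.
import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.randn(1, 2, 32, 32, 32, dtype=torch.float32)
ttnn_input = ttnn.from_torch(
    torch_input, device=device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT
)

# Explicit starts/ends/steps over all five dimensions, as the N-d slice expects.
starts = (0, 0, 0, 0, 0)
ends = (1, 2, 32, 32, 32)
steps = (1, 1, 2, 1, 2)

ttnn_output = ttnn.slice(ttnn_input, starts, ends, steps)
torch_reference = torch_input[0:1, 0:2, 0:32:2, 0:32, 0:32:2]

# Shapes should agree; the unit tests above additionally check PCC against torch.
assert ttnn.to_torch(ttnn_output).shape == torch_reference.shape

ttnn.close_device(device)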