diff --git a/models/demos/wormhole/mamba/tt/mamba_ssm.py b/models/demos/wormhole/mamba/tt/mamba_ssm.py index 6c6afe72b3e..a9124cb0265 100644 --- a/models/demos/wormhole/mamba/tt/mamba_ssm.py +++ b/models/demos/wormhole/mamba/tt/mamba_ssm.py @@ -267,12 +267,12 @@ def forward(self, x): ttnn.deallocate(abar2) bmulx0_sharded = ttnn.to_memory_config(bmulx0, self.configs["sharded_scan"]) ttnn.deallocate(bmulx0) - hidden_states_sharded = ttl.operations.primary.transformers.ssm_prefix_scan( + hidden_states_sharded = ttnn.experimental.prefix_scan( abar2_sharded, bmulx0_sharded, prev_hidden_state, - output_mem_config=self.configs["sharded_scan"], - output_dtype=ttnn.bfloat8_b, + memory_config=self.configs["sharded_scan"], + dtype=ttnn.bfloat8_b, ) ttnn.deallocate(abar2_sharded) ttnn.deallocate(bmulx0_sharded) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_prefix_scan.py b/tests/ttnn/unit_tests/operations/test_ssm_prefix_scan.py similarity index 95% rename from tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_prefix_scan.py rename to tests/ttnn/unit_tests/operations/test_ssm_prefix_scan.py index 479b905d507..b15bfcf76b1 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_prefix_scan.py +++ b/tests/ttnn/unit_tests/operations/test_ssm_prefix_scan.py @@ -4,6 +4,7 @@ import torch +import ttnn import tt_lib as ttl import pytest from loguru import logger @@ -64,9 +65,7 @@ def run_ssm_prefix_scan(L: int, E: int, N: int, num_cores: int, dtype, device): .to(device, h_memory_config) ) - actual = ttl.operations.primary.transformers.ssm_prefix_scan( - a, bx, h_prev, output_mem_config=memory_config, output_dtype=dtype - ) + actual = ttnn.experimental.prefix_scan(a, bx, h_prev, memory_config=memory_config, dtype=dtype) assert list(actual.get_legacy_shape()) == list(expected.shape) assert actual.dtype == dtype @@ -164,9 +163,7 @@ def to_device(x): a_chunk = to_device(a_chunks[idx]) bx_chunk = to_device(bx_chunks[idx]) - h_chunk = ttl.operations.primary.transformers.ssm_prefix_scan( - a_chunk, bx_chunk, h_prev, output_mem_config=memory_config, output_dtype=dtype - ) + h_chunk = ttnn.experimental.prefix_scan(a_chunk, bx_chunk, h_prev, memory_config=memory_config, dtype=dtype) actual.append(tt2torch_tensor(h_chunk)) actual = torch.concat(actual, dim=2) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index b81d83a6680..1b75ee63685 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -74,6 +74,10 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/multi_core/groupnorm_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/transformer/device/transformer_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/avgpool/avg_pool.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_multi_core.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt index 6f6f69a685e..08c43af5f29 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt @@ -138,7 +138,6 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_eltwise_mul/multi_core_ssm_eltwise_mul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_1d_sum_reduce/multi_core_ssm_1d_sum_reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_prefix_scan/multi_core_ssm_prefix_scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/split/split_tiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/split/split_last_dim_two_chunks_tiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/concat/multi_core/concat_op_multi_core.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/transformer_tms.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/transformer_tms.cpp index b1c366b72da..88484a6f6e7 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/transformer_tms.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/transformer_tms.cpp @@ -472,58 +472,6 @@ operation::ProgramWithCallbacks SSM1DSumReduce::create_program( return multi_core_ssm_1d_sum_reduce(input_tensor_a, output_tensor, math_fidelity, device_compute_with_storage_grid_size); } -void SSMPrefixScan::validate(const std::vector& input_tensors) const { - TT_FATAL(input_tensors.size() == 3, "Expected 3 input tensors (A, Bx, H)"); - - const auto& a = input_tensors.at(0); - const auto& bx = input_tensors.at(1); - TT_FATAL(a.dtype() == bx.dtype(), "Expected input tensors to have the same data type"); - TT_FATAL(a.layout() == Layout::TILE && bx.layout() == Layout::TILE, "Expected input tensors to be tile layout"); - TT_FATAL(a.get_legacy_shape() == bx.get_legacy_shape(), "Expected input tensors to have the same shape"); - - const auto& shape = a.get_legacy_shape(); - TT_FATAL(shape.rank() == 4, "Expected input tensors to be rank 4"); - TT_FATAL(shape[0] == 1 && shape[1] == 1, "Dimension 0 and 1 should be size 1"); - TT_FATAL(shape[2] >= TILE_HEIGHT && shape[2] % TILE_HEIGHT == 0, "Sequence length should be a multiple of 32"); - - const auto& h = input_tensors.at(2); - TT_FATAL(h.dtype() == DataType::BFLOAT16, "Expected initial hidden state to be bfloat16"); - TT_FATAL(h.layout() == Layout::ROW_MAJOR, "Expected initial hidden state to be row-major"); - //TT_FATAL(h.get_legacy_shape() == {1, 1, 1, shape[3]}, "Expected initial hidden state to have the same hidden size as A and Bx"); - - TT_FATAL(a.is_sharded() && bx.is_sharded() && h.is_sharded(), "Expected input tensors to be sharded"); - TT_FATAL(a.shard_spec().has_value() && bx.shard_spec().has_value() && h.shard_spec().has_value(), "Expected input tensors to be sharded"); - TT_FATAL( - a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, - "Expected A tensor to be row major orientation"); - TT_FATAL( - bx.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, - "Expected Bx tensor to be row major orientation"); - TT_FATAL( - h.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, - "Expected h tensor to be row major orientation"); -} - -std::vector SSMPrefixScan::compute_output_shapes(const std::vector& input_tensors) const { - const auto& a = input_tensors.at(0); - return {a.get_legacy_shape()}; -} - -std::vector SSMPrefixScan::create_output_tensors(const std::vector& input_tensors) const { - return operation::generic_create_output_tensors( - *this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config); -} - -operation::ProgramWithCallbacks SSMPrefixScan::create_program( - const std::vector& input_tensors, std::vector& output_tensors) const { - const auto& a = input_tensors.at(0); - const auto& bx = input_tensors.at(1); - const auto& h = input_tensors.at(2); - auto& output = output_tensors.at(0); - auto device_compute_with_storage_grid_size = a.device()->compute_with_storage_grid_size(); - return multi_core_ssm_prefix_scan(a, bx, h, output, math_fidelity, device_compute_with_storage_grid_size); -} - } // namespace transformers } // namespace primary } // namespace operations diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/transformer_tms.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/transformer_tms.hpp index c64b514257f..54eabd11097 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/transformer_tms.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/transformer_tms.hpp @@ -180,42 +180,6 @@ inline Tensor ssm_1d_sum_reduce(const Tensor &input_tensor_a, const MemoryConfig return output_tensors.at(0); } -struct SSMPrefixScan { - MemoryConfig output_mem_config; - DataType output_dtype; - MathFidelity math_fidelity; - - void validate(const std::vector& input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - std::vector create_output_tensors(const std::vector& input_tensors) const; - operation::ProgramWithCallbacks create_program( - const std::vector& input_tensors, std::vector& output_tensors) const; -}; - -inline Tensor ssm_prefix_scan( - const Tensor& a, - const Tensor& bx, - const Tensor& h, - const MemoryConfig& mem_config, - std::optional output_dtype = std::nullopt, - MathFidelity math_fidelity = MathFidelity::HiFi4) { - std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a, bx}))}; - operation::launch_op( - [mem_config, output_dtype, math_fidelity]( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector>& optional_output_tensors) mutable -> std::vector { - const auto& a = input_tensors.at(0); - const auto& bx = input_tensors.at(1); - const auto& h = input_tensors.at(2); - return operation::run( - SSMPrefixScan{mem_config, output_dtype.value_or(a.get_dtype()), math_fidelity}, input_tensors); - }, - {a, bx, h}, - output_tensors); - return output_tensors.at(0); -} - } // namespace transformers } // namespace primary diff --git a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/transformers/module.hpp b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/transformers/module.hpp index c52fc0ae8d0..7e2634105f7 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/transformers/module.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/transformers/module.hpp @@ -44,19 +44,6 @@ void py_module(py::module& m_transformers) { Performs a custom reduction along dim 3 which is used in the SSM block of the Mamba architecture. Performs the following PyTorch equivalent (where latent_size = 32): x = torch.sum(x.reshape(1, 1, shape[2], shape[3] // latent_size, latent_size), dim=-1).reshape(1, 1, shape[2], shape[3] // latent_size) )doc"); - m_transformers.def( - "ssm_prefix_scan", - &ssm_prefix_scan, - py::arg().noconvert(), - py::arg().noconvert(), - py::arg().noconvert(), - py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("output_dtype").noconvert() = std::nullopt, - py::arg("math_fidelity").noconvert() = MathFidelity::HiFi4, - R"doc( - Performs a prefix scan to produce the SSM hidden states across an entire sequence. All input and output tensors are expected to be shape [1, 1, L, 2EN] where E = 2560 and N = 32. L can be any multiple of 32.)doc"); - - py::class_(m_transformers, "SDPADefaultProgramConfig") .def(py::init<>()); diff --git a/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.hpp b/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.hpp index 083beb8f36b..0b7822ab94e 100644 --- a/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/experimental_pybind.hpp @@ -9,6 +9,7 @@ #include "ttnn/operations/experimental/transformer/transformer_pybind.hpp" #include "ttnn/operations/experimental/reduction/argmax/argmax_pybind.hpp" +#include "ttnn/operations/experimental/ssm/prefix_scan/prefix_scan_pybind.hpp" namespace ttnn::operations::experimental { @@ -17,6 +18,7 @@ void py_module(py::module& module) { transformer::detail::bind_experimental_transformer_operations(module); reduction::detail::bind_argmax_operation(module); reduction::detail::bind_argmin_operation(module); + ssm::detail::bind_prefix_scan(module); } } // namespace ttnn::operations::experimental diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_prefix_scan.cpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/reader_ssm_prefix_scan.cpp similarity index 100% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_prefix_scan.cpp rename to ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/reader_ssm_prefix_scan.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_prefix_scan.cpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp similarity index 100% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_prefix_scan.cpp rename to ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_ssm_prefix_scan.cpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/writer_ssm_prefix_scan.cpp similarity index 100% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_ssm_prefix_scan.cpp rename to ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/writer_ssm_prefix_scan.cpp diff --git a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_op.cpp new file mode 100644 index 00000000000..6f6f802e455 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_op.cpp @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "prefix_scan_op.hpp" + +#include "prefix_scan_program_factory.hpp" + +namespace ttnn::operations::experimental::ssm { + +void PrefixScan::validate(const std::vector& input_tensors) const { + TT_FATAL(input_tensors.size() == 3, "Expected 3 input tensors (A, Bx, H)"); + + const auto& a = input_tensors[0]; + const auto& bx = input_tensors[1]; + TT_FATAL(a.dtype() == bx.dtype(), "Expected input tensors to have the same data type"); + TT_FATAL(a.layout() == Layout::TILE && bx.layout() == Layout::TILE, "Expected input tensors to be tile layout"); + TT_FATAL(a.get_legacy_shape() == bx.get_legacy_shape(), "Expected input tensors to have the same shape"); + + const auto& shape = a.get_legacy_shape(); + TT_FATAL(shape.rank() == 4, "Expected input tensors to be rank 4"); + TT_FATAL(shape[0] == 1 && shape[1] == 1, "Dimension 0 and 1 should be size 1"); + TT_FATAL(shape[2] >= TILE_HEIGHT && shape[2] % TILE_HEIGHT == 0, "Sequence length should be a multiple of 32"); + + const auto& h = input_tensors.at(2); + TT_FATAL(h.dtype() == DataType::BFLOAT16, "Expected initial hidden state to be bfloat16"); + TT_FATAL(h.layout() == Layout::ROW_MAJOR, "Expected initial hidden state to be row-major"); + + TT_FATAL(a.is_sharded() && bx.is_sharded() && h.is_sharded(), "Expected input tensors to be sharded"); + TT_FATAL( + a.shard_spec().has_value() && bx.shard_spec().has_value() && h.shard_spec().has_value(), + "Expected input tensors to be sharded"); + TT_FATAL( + a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, + "Expected A tensor to be row major orientation"); + TT_FATAL( + bx.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, + "Expected Bx tensor to be row major orientation"); + TT_FATAL( + h.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, + "Expected h tensor to be row major orientation"); +} + +std::vector PrefixScan::compute_output_shapes(const std::vector& input_tensors) const { + const auto& a = input_tensors.at(0); + return {a.get_legacy_shape()}; +} + +std::vector PrefixScan::create_output_tensors(const std::vector& input_tensors) const { + return operation::generic_create_output_tensors( + *this, input_tensors, this->dtype, Layout::TILE, this->memory_config); +} + +operation::ProgramWithCallbacks PrefixScan::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { + const auto& a = input_tensors.at(0); + const auto& bx = input_tensors.at(1); + const auto& h = input_tensors.at(2); + auto& output = output_tensors.at(0); + auto device_compute_with_storage_grid_size = a.device()->compute_with_storage_grid_size(); + return tt::operations::experimental::ssm::detail::multi_core_ssm_prefix_scan( + a, bx, h, output, math_fidelity, device_compute_with_storage_grid_size); +} +} // namespace ttnn::operations::experimental::ssm diff --git a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_op.hpp new file mode 100644 index 00000000000..57d732cdc21 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_op.hpp @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/run_operation.hpp" +#include "ttnn/tensor/tensor.hpp" + +namespace ttnn::operations::experimental::ssm { + +struct PrefixScan { + MemoryConfig memory_config; + DataType dtype; + MathFidelity math_fidelity; + + void validate(const std::vector& input_tensors) const; + std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector create_output_tensors(const std::vector& input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, std::vector& output_tensors) const; +}; + +} // namespace ttnn::operations::experimental::ssm diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/multi_core_ssm_prefix_scan/multi_core_ssm_prefix_scan.cpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_program_factory.cpp similarity index 83% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/multi_core_ssm_prefix_scan/multi_core_ssm_prefix_scan.cpp rename to ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_program_factory.cpp index 0cf52d2fa44..eee4ddb74ca 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/multi_core_ssm_prefix_scan/multi_core_ssm_prefix_scan.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_program_factory.cpp @@ -2,17 +2,13 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "prefix_scan_program_factory.hpp" + #include "ttnn/tensor/tensor.hpp" -#include "ttnn/operation.hpp" -#include "tt_metal/host_api.hpp" -using namespace tt::constants; -using namespace tt; +namespace tt::operations::experimental::ssm::detail { -namespace tt { -namespace operations { -namespace primary { -namespace transformers { +using namespace tt::constants; operation::ProgramWithCallbacks multi_core_ssm_prefix_scan( const Tensor& a, @@ -21,7 +17,7 @@ operation::ProgramWithCallbacks multi_core_ssm_prefix_scan( Tensor& output, MathFidelity math_fidelity, CoreCoord compute_with_storage_grid_size) { - tt_metal::Program program = tt_metal::CreateProgram(); + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); auto* a_buffer = a.buffer(); auto* bx_buffer = bx.buffer(); @@ -29,11 +25,11 @@ operation::ProgramWithCallbacks multi_core_ssm_prefix_scan( auto* output_buffer = output.buffer(); TT_ASSERT(output_buffer != nullptr, "Output buffer should be allocated on device"); - const tt::DataFormat input_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - const uint32_t input_tile_size = tt_metal::detail::TileSize(input_format); + const tt::DataFormat input_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + const uint32_t input_tile_size = tt::tt_metal::detail::TileSize(input_format); const tt::DataFormat intermediary_format = tt::DataFormat::Float16_b; - const uint32_t intermediary_tile_size = tt_metal::detail::TileSize(intermediary_format); + const uint32_t intermediary_tile_size = tt::tt_metal::detail::TileSize(intermediary_format); const auto all_cores = a.shard_spec()->grid; const auto create_circular_buffer = [&program, &all_cores]( @@ -41,12 +37,12 @@ operation::ProgramWithCallbacks multi_core_ssm_prefix_scan( uint32_t num_tiles, uint32_t tile_size, const tt::DataFormat& format, - Buffer* buffer = nullptr) -> tt_metal::CBHandle { + Buffer* buffer = nullptr) -> tt::tt_metal::CBHandle { auto config = CircularBufferConfig(num_tiles * tile_size, {{index, format}}).set_page_size(index, tile_size); if (buffer != nullptr) { config = config.set_globally_allocated_address(*buffer); } - return tt_metal::CreateCircularBuffer(program, all_cores, config); + return tt::tt_metal::CreateCircularBuffer(program, all_cores, config); }; const uint32_t sharded_sequence_length = a.shard_spec()->shape[0]; @@ -115,23 +111,23 @@ operation::ProgramWithCallbacks multi_core_ssm_prefix_scan( cb_out_id, cb_h_acc_id}; - auto reader_kernel_id = tt_metal::CreateKernel( + auto reader_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_prefix_scan.cpp", + "ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/reader_ssm_prefix_scan.cpp", all_cores, - tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); - auto writer_kernel_id = tt_metal::CreateKernel( + auto writer_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_ssm_prefix_scan.cpp", + "ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/writer_ssm_prefix_scan.cpp", all_cores, - tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - auto compute_kernel_id = tt_metal::CreateKernel( + auto compute_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_prefix_scan.cpp", + "ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp", all_cores, - tt_metal::ComputeConfig{ + tt::tt_metal::ComputeConfig{ .math_fidelity = math_fidelity, .fp32_dest_acc_en = false, .math_approx_mode = false, @@ -155,10 +151,10 @@ operation::ProgramWithCallbacks multi_core_ssm_prefix_scan( cb_bx_in, cb_h_in, cb_out](Program& program, const Tensor& a, const Tensor& bx, const Tensor& h, const Tensor& output) { - tt_metal::Buffer* a_buffer = a.buffer(); - tt_metal::Buffer* bx_buffer = bx.buffer(); - tt_metal::Buffer* h_buffer = h.buffer(); - tt_metal::Buffer* output_buffer = output.buffer(); + tt::tt_metal::Buffer* a_buffer = a.buffer(); + tt::tt_metal::Buffer* bx_buffer = bx.buffer(); + tt::tt_metal::Buffer* h_buffer = h.buffer(); + tt::tt_metal::Buffer* output_buffer = output.buffer(); UpdateDynamicCircularBufferAddress(program, cb_a_in, *a_buffer); UpdateDynamicCircularBufferAddress(program, cb_bx_in, *bx_buffer); @@ -210,7 +206,4 @@ operation::ProgramWithCallbacks multi_core_ssm_prefix_scan( return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } -} // namespace transformers -} // namespace primary -} // namespace operations -} // namespace tt +} // namespace tt::operations::experimental::ssm::detail diff --git a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_program_factory.hpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_program_factory.hpp new file mode 100644 index 00000000000..54174ef7177 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/prefix_scan_program_factory.hpp @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/operation.hpp" + +namespace tt::operations::experimental::ssm::detail { + +operation::ProgramWithCallbacks multi_core_ssm_prefix_scan( + const Tensor& a, + const Tensor& bx, + const Tensor& h, + Tensor& output, + MathFidelity math_fidelity, + CoreCoord compute_with_storage_grid_size); + +} // namespace tt::operations::experimental::ssm::detail diff --git a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan.cpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan.cpp new file mode 100644 index 00000000000..39a810a75d6 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan.cpp @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "prefix_scan.hpp" + +#include "device/prefix_scan_op.hpp" + +namespace ttnn::operations::experimental::ssm { + +constexpr uint8_t DefaultQueueId = 0; + +ttnn::Tensor ExecutePrefixScan::operator()( + uint8_t queue_id, + const Tensor& a, + const Tensor& bx, + const Tensor& h_prev, + const std::optional& memory_config, + const std::optional dtype, + const std::optional math_fidelity) { + auto program = PrefixScan{ + memory_config.value_or(a.memory_config()), + dtype.value_or(a.dtype()), + math_fidelity.value_or(MathFidelity::HiFi4)}; + return operation::run(program, {a, bx, h_prev}, {}, {}, queue_id).at(0); +} + +ttnn::Tensor ExecutePrefixScan::operator()( + const Tensor& a, + const Tensor& bx, + const Tensor& h_prev, + const std::optional& memory_config, + const std::optional dtype, + const std::optional math_fidelity) { + return operator()(DefaultQueueId, a, bx, h_prev, memory_config, dtype, math_fidelity); +} + +} // namespace ttnn::operations::experimental::ssm diff --git a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan.hpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan.hpp new file mode 100644 index 00000000000..a5881a00ecf --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan.hpp @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace ttnn::operations::experimental::ssm { + +struct ExecutePrefixScan { + static ttnn::Tensor operator()( + uint8_t queue_id, + const Tensor& a, + const Tensor& bx, + const Tensor& h_prev, + const std::optional& memory_config = std::nullopt, + const std::optional dtype = std::nullopt, + const std::optional math_fidelity = std::nullopt); + + static ttnn::Tensor operator()( + const Tensor& a, + const Tensor& bx, + const Tensor& h_prev, + const std::optional& memory_config = std::nullopt, + const std::optional dtype = std::nullopt, + const std::optional math_fidelity = std::nullopt); +}; + +} // namespace ttnn::operations::experimental::ssm + +namespace ttnn::experimental { + +constexpr auto prefix_scan = ttnn::register_operation_with_auto_launch_op< + "ttnn::experimental::prefix_scan", + ttnn::operations::experimental::ssm::ExecutePrefixScan>(); + +} // namespace ttnn::experimental diff --git a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan_pybind.cpp new file mode 100644 index 00000000000..8f8f13296b1 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan_pybind.cpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "prefix_scan_pybind.hpp" + +#include +#include + +#include "prefix_scan.hpp" +#include "ttnn/cpp/pybind11/decorators.hpp" + +namespace ttnn::operations::experimental::ssm::detail { + +namespace py = pybind11; + +void bind_prefix_scan(py::module& module) { + using OperationType = decltype(ttnn::experimental::prefix_scan); + + const auto doc = + R"doc(Performs a prefix scan to produce the SSM hidden states across an entire sequence. All input and output tensors are expected to be shape [1, 1, L, 2EN]. Values of 2EN and L can be any multiple of 32.)doc"; + + ttnn::bind_registered_operation( + module, + ttnn::experimental::prefix_scan, + doc, + ttnn::pybind_overload_t{ + [](const OperationType& self, + const ttnn::Tensor& a, + const ttnn::Tensor& bx, + const ttnn::Tensor& h_prev, + const std::optional& memory_config, + const std::optional dtype, + const std::optional math_fidelity, + uint8_t queue_id) { return self(queue_id, a, bx, h_prev, memory_config, dtype, math_fidelity); }, + py::arg("a"), + py::arg("bx"), + py::arg("h_prev"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("dtype") = std::nullopt, + py::arg("math_fidelity") = std::nullopt, + py::arg("queue_id") = 0}); +} + +} // namespace ttnn::operations::experimental::ssm::detail diff --git a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan_pybind.hpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan_pybind.hpp new file mode 100644 index 00000000000..8387fb969bc --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/prefix_scan_pybind.hpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace ttnn::operations::experimental::ssm::detail { + +void bind_prefix_scan(pybind11::module& module); + +} // namespace ttnn::operations::experimental::ssm::detail