From f7c10a255f2cd24c27cb7cb2b80854f9276a91e2 Mon Sep 17 00:00:00 2001
From: Borys Bradel <164946524+bbradelTT@users.noreply.github.com>
Date: Mon, 8 Jul 2024 15:33:03 -0400
Subject: [PATCH] #9492: move matmul code to ttnn directory hierarchy (#10015)

* #9492: move matmul code to ttnn directory hierarchy

* #9492: Create copy of bmm zm fused bias kernel in tests because can't
  reference ttnn
---
 .../tt_eager/integration_tests/test_bert.cpp  |  40 +-
 tests/tt_eager/ops/test_bmm_op.cpp            |   2 +-
 ...ge_block_zm_fused_bias_activation_copy.cpp |   0
 .../1_compute_mm/test_compute_mm.cpp          |   2 +-
 tt_eager/tt_dnn/op_library/CMakeLists.txt     |   9 -
 .../tt_dnn/op_library/complex/complex_ops.cpp |   1 -
 .../op_library/composite/composite_ops.cpp    |  12 +-
 .../fully_connected/fully_connected_op.cpp    |  12 +-
 .../tt_lib/csrc/operations/primary/module.hpp | 113 ------
 ttnn/CMakeLists.txt                           |  11 +-
 ttnn/cpp/pybind11/operations/__init__.hpp     |   2 +-
 ttnn/cpp/pybind11/operations/matmul.hpp       | 100 -----
 ttnn/cpp/ttnn/operations/conv2d.hpp           |   2 +-
 .../matmul/device}/kernels/compute/bmm.cpp    |   0
 .../kernels/compute/bmm_large_block_zm.cpp    |   0
 ...m_large_block_zm_fused_bias_activation.cpp | 367 ++++++++++++++++++
 ...ed_bias_activation_inline_untilize_out.cpp |   0
 ...der_bmm_8bank_output_tiles_partitioned.cpp |   0
 .../dataflow/reader_bmm_interleaved.cpp       |   0
 .../dataflow/reader_bmm_single_core.cpp       |   0
 ...reader_bmm_single_core_tilize_untilize.cpp |   0
 .../dataflow/reader_bmm_tile_layout.cpp       |   0
 .../dataflow/reader_bmm_tile_layout_in0.cpp   |   0
 .../reader_bmm_tile_layout_in0_receiver.cpp   |   0
 ..._tile_layout_in0_receiver_dram_sharded.cpp |   0
 ...mm_tile_layout_in0_sender_dram_sharded.cpp |   0
 ...der_bmm_tile_layout_in0_sender_padding.cpp |   0
 ...ayout_in0_sender_padding_block_sharded.cpp |   0
 ..._sender_receiver_padding_block_sharded.cpp |   0
 ...ile_layout_in1_receiver_writer_padding.cpp |   0
 ...mm_tile_layout_in1_sender_dram_sharded.cpp |   0
 ..._tile_layout_in1_sender_writer_padding.cpp |   0
 .../reader_bmm_tile_layout_padding.cpp        |   0
 .../reader_writer_bmm_tile_layout_in1.cpp     |   0
 .../dataflow/writer_bmm_interleaved.cpp       |   0
 .../dataflow/writer_bmm_single_core_tiled.cpp |   0
 .../dataflow/writer_bmm_tile_layout.cpp       |   0
 .../writer_bmm_tile_layout_padding.cpp        |   0
 .../operations/matmul/device/matmul_op.cpp    |   2 +-
 .../operations/matmul/device/matmul_op.hpp    |   0
 .../device}/multi_core/bmm_op_multi_core.cpp  |   8 +-
 .../bmm_op_multi_core_reuse.cpp               |   9 +-
 ...op_multi_core_reuse_mcast_1d_optimized.cpp |  24 +-
 ...op_multi_core_reuse_mcast_2d_optimized.cpp |  20 +-
 ...ulti_core_reuse_dram_sharded_optimized.cpp |   8 +-
 .../bmm_op_multi_core_reuse_optimized.cpp     |  11 +-
 .../bmm_op_multi_core_reuse_padding.cpp       |   9 +-
 .../bmm_op_single_core_tilize_untilize.cpp    |   8 +-
 .../ttnn/operations/{ => matmul}/matmul.cpp   |   0
 .../ttnn/operations/{ => matmul}/matmul.hpp   |   2 +-
 .../ttnn/operations/matmul/matmul_pybind.hpp  | 214 ++++++++++
 ttnn/ttnn/operations/matmul.py                |  12 +-
 52 files changed, 704 insertions(+), 296 deletions(-)
 rename tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp => tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation_copy.cpp (100%)
 delete mode 100644 ttnn/cpp/pybind11/operations/matmul.hpp
 rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/compute/bmm.cpp (100%)
 rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/compute/bmm_large_block_zm.cpp (100%)
 create mode 100644
ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/compute/bmm_large_block_zm_fused_bias_activation_inline_untilize_out.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_8bank_output_tiles_partitioned.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_interleaved.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_single_core.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_single_core_tilize_untilize.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_in0.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_in0_receiver_dram_sharded.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding_block_sharded.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_bmm_tile_layout_padding.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/writer_bmm_interleaved.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/writer_bmm_single_core_tiled.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/writer_bmm_tile_layout.cpp (100%) rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/kernels/dataflow/writer_bmm_tile_layout_padding.cpp (100%) rename tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp => ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp (99%) rename tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp => 
ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp (100%)
 rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/multi_core/bmm_op_multi_core.cpp (96%)
 rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/multi_core_reuse/bmm_op_multi_core_reuse.cpp (97%)
 rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp (98%)
 rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp (98%)
 rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp (99%)
 rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp (98%)
 rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp (97%)
 rename {tt_eager/tt_dnn/op_library/bmm => ttnn/cpp/ttnn/operations/matmul/device}/single_core/bmm_op_single_core_tilize_untilize.cpp (98%)
 rename ttnn/cpp/ttnn/operations/{ => matmul}/matmul.cpp (100%)
 rename ttnn/cpp/ttnn/operations/{ => matmul}/matmul.hpp (97%)
 create mode 100644 ttnn/cpp/ttnn/operations/matmul/matmul_pybind.hpp

diff --git a/tests/tt_eager/integration_tests/test_bert.cpp b/tests/tt_eager/integration_tests/test_bert.cpp
index 90361d07bd4..8da89329d55 100644
--- a/tests/tt_eager/integration_tests/test_bert.cpp
+++ b/tests/tt_eager/integration_tests/test_bert.cpp
@@ -7,7 +7,6 @@
 #include "tensor/host_buffer/types.hpp"
 #include "tensor/tensor.hpp"
 #include "tt_dnn/op_library/bcast/bcast_op.hpp"
-#include "tt_dnn/op_library/bmm/bmm_op.hpp"
 #include "tt_dnn/op_library/layernorm/layernorm_op.hpp"
 #include "tt_dnn/op_library/operation.hpp"
 #include "ttnn/cpp/ttnn/operations/normalization/softmax/softmax.hpp"
@@ -15,6 +14,7 @@
 #include "tt_metal/common/constants.hpp"
 #include "tt_metal/host_api.hpp"
 #include "tt_numpy/functions.hpp"
+#include "ttnn/operations/matmul/matmul.hpp"

 using Parameters = std::map<std::string, Tensor>;

@@ -35,10 +35,12 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param
         .transpose_mcast = false,
         .fused_activation = std::nullopt,
     };
-    auto fused_qkv_matmul_output = tt::operations::primary::matmul(
+    auto fused_qkv_matmul_output = ttnn::operations::matmul::matmul(
         hidden_states,
         parameters.at(fmt::format("fused_qkv_weight_{}", encoder_index)),
         parameters.at(fmt::format("fused_qkv_bias_{}", encoder_index)),
+        /*transpose_a=*/false,
+        /*transpose_b=*/false,
         fused_qkv_matmul_program_config,
         l1_memory_config
     );
@@ -56,7 +58,15 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param
         .per_core_M = 12,
         .per_core_N = 12,
     };
-    auto pre_softmax_bmm_matmul = tt::operations::primary::matmul(query, key, std::nullopt /*bias*/, pre_softmax_bmm_program_config, dram_memory_config);
+    auto pre_softmax_bmm_matmul = ttnn::operations::matmul::matmul(
+        query,
+        key,
+        /*bias=*/std::nullopt,
+        /*transpose_a=*/false,
+        /*transpose_b=*/false,
+        pre_softmax_bmm_program_config,
+        dram_memory_config
+    );
     query.deallocate();
     key.deallocate();

@@ -72,7 +82,15 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param
         .per_core_M = 12,
         .per_core_N = 2,
     };
-    auto post_softmax_bmm_output =
tt::operations::primary::matmul(pre_softmax_bmm_matmul, value, std::nullopt /*bias*/, post_softmax_bmm_program_config, l1_memory_config); + auto post_softmax_bmm_output = ttnn::operations::matmul::matmul( + pre_softmax_bmm_matmul, + value, + /*bias=*/std::nullopt, + /*transpose_a=*/false, + /*transpose_b=*/false, + post_softmax_bmm_program_config, + l1_memory_config + ); pre_softmax_bmm_matmul.deallocate(); value.deallocate(); @@ -91,10 +109,12 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param .transpose_mcast = false, .fused_activation = std::nullopt, }; - auto selfout_bmm_output = tt::operations::primary::matmul( + auto selfout_bmm_output = ttnn::operations::matmul::matmul( concat_heads_output, parameters.at(fmt::format("selfout_weight_{}", encoder_index)), parameters.at(fmt::format("selfout_bias_{}", encoder_index)), + /*transpose_a=*/false, + /*transpose_b=*/false, selfout_bmm_program_config, l1_memory_config ); @@ -123,10 +143,12 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param .transpose_mcast = false, .fused_activation = UnaryWithParam(UnaryOpType::GELU,1.0f), }; - auto ff1_matmul_output = tt::operations::primary::matmul( + auto ff1_matmul_output = ttnn::operations::matmul::matmul( attention_layernorm_output, parameters.at(fmt::format("ff1_weight_{}", encoder_index)), parameters.at(fmt::format("ff1_bias_{}", encoder_index)), + /*transpose_a=*/false, + /*transpose_b=*/false, ff1_matmul_program_config, dram_memory_config ); @@ -142,10 +164,12 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param .transpose_mcast = false, .fused_activation = std::nullopt, }; - auto ff2_matmul_output = tt::operations::primary::matmul( + auto ff2_matmul_output = ttnn::operations::matmul::matmul( ff1_matmul_output, parameters.at(fmt::format("ff2_weight_{}", encoder_index)), parameters.at(fmt::format("ff2_bias_{}", encoder_index)), + /*transpose_a=*/false, + /*transpose_b=*/false, ff2_matmul_program_config, l1_memory_config ); @@ -169,7 +193,7 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param Tensor qa_head(Tensor&& hidden_states, const Parameters& parameters) { - auto output = tt::operations::primary::matmul(hidden_states, parameters.at("qa_head_weight")); + auto output = ttnn::operations::matmul::matmul(hidden_states, parameters.at("qa_head_weight"), /*bias=*/std::nullopt); hidden_states.deallocate(); diff --git a/tests/tt_eager/ops/test_bmm_op.cpp b/tests/tt_eager/ops/test_bmm_op.cpp index cca9c8483ff..02a169bf29d 100644 --- a/tests/tt_eager/ops/test_bmm_op.cpp +++ b/tests/tt_eager/ops/test_bmm_op.cpp @@ -4,7 +4,7 @@ #include "tt_metal/host_api.hpp" #include "tensor/tensor.hpp" -#include "tt_dnn/op_library/bmm/bmm_op.hpp" +#include "ttnn/operations/matmul/device/matmul_op.hpp" #include "common/constants.hpp" #include "tt_numpy/functions.hpp" diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation_copy.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp rename to tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation_copy.cpp diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/test_compute_mm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/test_compute_mm.cpp 
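Every call-site hunk above applies the same mechanical rewrite: the op moves from tt::operations::primary::matmul to ttnn::operations::matmul::matmul, and the bias plus explicit transpose_a/transpose_b flags now precede the program config. A minimal sketch of the post-move call shape, for reference while reading the remaining hunks (run_linear is a hypothetical helper; tensor setup and program-config construction are omitted):

    #include "ttnn/operations/matmul/matmul.hpp"

    // Hypothetical wrapper showing the new argument order.
    ttnn::Tensor run_linear(const ttnn::Tensor& act, const ttnn::Tensor& weights,
                            const tt::tt_metal::MemoryConfig& mem_config) {
        // Previously: tt::operations::primary::matmul(act, weights, bias,
        //                                             program_config, mem_config);
        return ttnn::operations::matmul::matmul(
            act,
            weights,
            /*bias=*/std::nullopt,
            /*transpose_a=*/false,
            /*transpose_b=*/false,
            /*program_config=*/std::nullopt,
            mem_config);
    }
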
index 4b5f1fe927c..263cb5b3189 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/test_compute_mm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/test_compute_mm.cpp @@ -1027,7 +1027,7 @@ tt_metal::Program create_program_single_core ( auto mm_kernel_id = tt_metal::CreateKernel( program, matmul_block ? - "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp" : + "tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation_copy.cpp" : "tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation.cpp", all_cores, tt_metal::ComputeConfig{ diff --git a/tt_eager/tt_dnn/op_library/CMakeLists.txt b/tt_eager/tt_dnn/op_library/CMakeLists.txt index ff706ef9b3c..be42ddb29af 100644 --- a/tt_eager/tt_dnn/op_library/CMakeLists.txt +++ b/tt_eager/tt_dnn/op_library/CMakeLists.txt @@ -55,15 +55,6 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_h/bcast_op_sharded_h_optimised.cpp ${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_w/bcast_op_multi_core_w.cpp ${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_hw/bcast_op_multi_core_hw.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bmm/bmm_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bmm/single_core/bmm_op_single_core_tilize_untilize.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core/bmm_op_multi_core.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse/bmm_op_multi_core_reuse.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/downsample/downsample_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/optimized_conv_op.cpp diff --git a/tt_eager/tt_dnn/op_library/complex/complex_ops.cpp b/tt_eager/tt_dnn/op_library/complex/complex_ops.cpp index 81864d3c78c..2c344c9e0a6 100644 --- a/tt_eager/tt_dnn/op_library/complex/complex_ops.cpp +++ b/tt_eager/tt_dnn/op_library/complex/complex_ops.cpp @@ -4,7 +4,6 @@ #include "tt_dnn/op_library/complex/complex_ops.hpp" #include "tt_dnn/op_library/concat/concat_op.hpp" -#include "tt_dnn/op_library/bmm/bmm_op.hpp" #include "tt_dnn/op_library/reshape/reshape_op.hpp" #include "ttnn/operations/data_movement/slice/slice.hpp" #include "tt_numpy/functions.hpp" diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp index 2f413755f02..2fafa560d03 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp @@ -5,7 +5,6 @@ #include "tt_dnn/op_library/composite/composite_ops.hpp" #include "tt_dnn/op_library/auto_format.hpp" -#include "tt_dnn/op_library/bmm/bmm_op.hpp" #include "tt_dnn/op_library/concat/concat_op.hpp" #include "tt_dnn/op_library/copy/copy_op.hpp" #include "tt_dnn/op_library/math.hpp" @@ -23,6 +22,7 @@ #include "ttnn/operations/eltwise/binary/binary.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" +#include "ttnn/operations/matmul/matmul.hpp" namespace tt { @@ -1769,7 
+1769,15 @@ Tensor _outer(Tensor& a, Tensor& b, const MemoryConfig& output_mem_config) {
         }
     }

-    return tt::operations::primary::matmul(a_slim, b_slim, std::nullopt, std::nullopt, output_mem_config);
+    return ttnn::operations::matmul::matmul(
+        a_slim,
+        b_slim,
+        /*bias=*/std::nullopt,
+        /*transpose_a=*/false,
+        /*transpose_b=*/false,
+        /*program_config=*/std::nullopt,
+        output_mem_config
+    );
 }

 Tensor outer(Tensor& a, Tensor& b, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _outer)(a, b, output_mem_config);
 }
diff --git a/tt_eager/tt_dnn/op_library/fully_connected/fully_connected_op.cpp b/tt_eager/tt_dnn/op_library/fully_connected/fully_connected_op.cpp
index b920e3a010c..9478cbaa876 100644
--- a/tt_eager/tt_dnn/op_library/fully_connected/fully_connected_op.cpp
+++ b/tt_eager/tt_dnn/op_library/fully_connected/fully_connected_op.cpp
@@ -6,14 +6,22 @@
 #include

 #include "tt_metal/host_api.hpp"
-#include "tt_dnn/op_library/bmm/bmm_op.hpp"
 #include "tt_dnn/op_library/bcast/bcast_op.hpp"
+#include "ttnn/operations/matmul/matmul.hpp"

 namespace tt {
 namespace tt_metal {

 Tensor fully_connected_(const Tensor& act, const Tensor& weights, std::optional<std::reference_wrapper<const Tensor>> bias, const MemoryConfig& output_mem_config) {
-    Tensor mm_output = tt::operations::primary::matmul(act, weights, /*bias=*/std::nullopt, /*program_config=*/std::nullopt, output_mem_config);
+    Tensor mm_output = ttnn::operations::matmul::matmul(
+        act,
+        weights,
+        /*bias=*/std::nullopt,
+        /*transpose_a=*/false,
+        /*transpose_b=*/false,
+        /*program_config=*/std::nullopt,
+        output_mem_config
+    );
     if (bias) {
         return bcast(mm_output, bias.value(), BcastOpMath::ADD, BcastOpDim::H, output_mem_config);
     }
diff --git a/tt_eager/tt_lib/csrc/operations/primary/module.hpp b/tt_eager/tt_lib/csrc/operations/primary/module.hpp
index c9b4d5e8b44..5a4d2b991c5 100644
--- a/tt_eager/tt_lib/csrc/operations/primary/module.hpp
+++ b/tt_eager/tt_lib/csrc/operations/primary/module.hpp
@@ -8,7 +8,6 @@
 #include

 #include "transformers/module.hpp"
-#include "tt_dnn/op_library/bmm/bmm_op.hpp"
 #include "tt_dnn/op_library/groupnorm/groupnorm_op.hpp"
 #include "tt_dnn/op_library/layernorm/layernorm_op.hpp"
 #include "tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.hpp"
@@ -53,118 +52,6 @@ void py_module(py::module& m_primary) {
     auto m_transformers = m_primary.def_submodule("transformers", "Primary transformers operations");
     transformers::py_module(m_transformers);

-    py::class_<MatmulProgramConfig>(m_primary, "MatmulProgramConfig")
-        .def("__repr__", [](const MatmulProgramConfig& config) { return fmt::format("{}", config); });
-
-    py::class_<MatmulMultiCoreReuseProgramConfig>(m_primary, "MatmulMultiCoreReuseProgramConfig")
-        .def(
-            py::init<CoreCoord, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t>(),
-            py::kw_only(),
-            py::arg("compute_with_storage_grid_size"),
-            py::arg("in0_block_w").noconvert(),
-            py::arg("out_subblock_h").noconvert(),
-            py::arg("out_subblock_w").noconvert(),
-            py::arg("per_core_M").noconvert(),
-            py::arg("per_core_N").noconvert())
-        .def_readwrite("compute_with_storage_grid_size", &MatmulMultiCoreReuseProgramConfig::compute_with_storage_grid_size)
-        .def_readwrite("in0_block_w", &MatmulMultiCoreReuseProgramConfig::in0_block_w)
-        .def_readwrite("out_subblock_h", &MatmulMultiCoreReuseProgramConfig::out_subblock_h)
-        .def_readwrite("out_subblock_w", &MatmulMultiCoreReuseProgramConfig::out_subblock_w)
-        .def_readwrite("per_core_M", &MatmulMultiCoreReuseProgramConfig::per_core_M)
-        .def_readwrite("per_core_N", &MatmulMultiCoreReuseProgramConfig::per_core_N)
-        .def("__repr__", [](const MatmulMultiCoreReuseProgramConfig& config) { return fmt::format("{}", config); });
-
-    py::class_<MatmulMultiCoreReuseMultiCastProgramConfig>(m_primary, "MatmulMultiCoreReuseMultiCastProgramConfig")
-        .def(
-            py::init<
-                CoreCoord,
-                std::size_t,
-                std::size_t,
-                std::size_t,
-                std::size_t,
-                std::size_t,
-                bool,
-                std::optional<UnaryWithParam>,
-                bool>(),
-            py::kw_only(),
-            py::arg("compute_with_storage_grid_size"),
-            py::arg("in0_block_w").noconvert(),
-            py::arg("out_subblock_h").noconvert(),
-            py::arg("out_subblock_w").noconvert(),
-            py::arg("per_core_M").noconvert(),
-            py::arg("per_core_N").noconvert(),
-            py::arg("transpose_mcast").noconvert(),
-            py::arg("fused_activation"),
-            py::arg("fuse_batch").noconvert() = true)
-        .def_readwrite("compute_with_storage_grid_size", &MatmulMultiCoreReuseMultiCastProgramConfig::compute_with_storage_grid_size)
-        .def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCastProgramConfig::in0_block_w)
-        .def_readwrite("out_subblock_h", &MatmulMultiCoreReuseMultiCastProgramConfig::out_subblock_h)
-        .def_readwrite("out_subblock_w", &MatmulMultiCoreReuseMultiCastProgramConfig::out_subblock_w)
-        .def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCastProgramConfig::per_core_M)
-        .def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCastProgramConfig::per_core_N)
-        .def_readwrite("transpose_mcast", &MatmulMultiCoreReuseMultiCastProgramConfig::transpose_mcast)
-        .def_readwrite("fused_activation", &MatmulMultiCoreReuseMultiCastProgramConfig::fused_activation)
-        .def_readwrite("fuse_batch", &MatmulMultiCoreReuseMultiCastProgramConfig::fuse_batch)
-        .def("__repr__", [](const MatmulMultiCoreReuseMultiCastProgramConfig& config) {
-            return fmt::format("{}", config);
-        });
-
-    py::class_<MatmulMultiCoreReuseMultiCast1DProgramConfig>(m_primary, "MatmulMultiCoreReuseMultiCast1DProgramConfig")
-        .def(
-            py::init<
-                CoreCoord,
-                std::size_t,
-                std::size_t,
-                std::size_t,
-                std::size_t,
-                std::size_t,
-                bool,
-                std::optional<UnaryWithParam>,
-                bool>(),
-            py::kw_only(),
-            py::arg("compute_with_storage_grid_size"),
-            py::arg("in0_block_w").noconvert(),
-            py::arg("out_subblock_h").noconvert(),
-            py::arg("out_subblock_w").noconvert(),
-            py::arg("per_core_M").noconvert(),
-            py::arg("per_core_N").noconvert(),
-            py::arg("fuse_batch").noconvert(),
-            py::arg("fused_activation"),
-            py::arg("mcast_in0").noconvert())
-        .def_readwrite("compute_with_storage_grid_size", &MatmulMultiCoreReuseMultiCast1DProgramConfig::compute_with_storage_grid_size)
-        .def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::in0_block_w)
-        .def_readwrite("out_subblock_h", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_subblock_h)
-        .def_readwrite("out_subblock_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_subblock_w)
-        .def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCast1DProgramConfig::per_core_M)
-        .def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCast1DProgramConfig::per_core_N)
-        .def_readwrite("fuse_batch", &MatmulMultiCoreReuseMultiCast1DProgramConfig::fuse_batch)
-        .def_readwrite("fused_activation", &MatmulMultiCoreReuseMultiCast1DProgramConfig::fused_activation)
-        .def_readwrite("mcast_in0", &MatmulMultiCoreReuseMultiCast1DProgramConfig::mcast_in0)
-        .def("__repr__", [](const MatmulMultiCoreReuseMultiCast1DProgramConfig& config) {
-            return fmt::format("{}", config);
-        });
-
-    py::class_<MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig>(
-        m_primary, "MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig")
-        .def(
-            py::init<
-                std::size_t,
-                std::size_t,
-                std::size_t,
-                std::optional<UnaryWithParam>>(),
-            py::kw_only(),
-            py::arg("in0_block_w").noconvert(),
-            py::arg("per_core_M").noconvert(),
-            py::arg("per_core_N").noconvert(),
-            py::arg("fused_activation"))
-        .def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::in0_block_w)
-        .def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::per_core_M)
-        .def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::per_core_N)
-        .def_readwrite("fused_activation", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::fused_activation)
-        .def("__repr__", [](const MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig& config) {
-            return fmt::format("{}", config);
-        });
-
     py::class_<LayerNormDefaultProgramConfig>(m_primary, "LayerNormDefaultProgramConfig").def(py::init<>());

     py::class_<LayerNormShardedMultiCoreProgramConfig>(m_primary, "LayerNormShardedMultiCoreProgramConfig")
diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt
index dba47e7c914..aa4604601a9 100644
--- a/ttnn/CMakeLists.txt
+++ b/ttnn/CMakeLists.txt
@@ -3,7 +3,16 @@ set(TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/async_runtime.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/op_library/to_layout/to_layout_op.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv2d.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/matmul.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/matmul_op.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/single_core/bmm_op_single_core_tilize_untilize.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core/bmm_op_multi_core.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse/bmm_op_multi_core_reuse.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp
diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp
index 395d132fdb8..c185943ebfc 100644
--- a/ttnn/cpp/pybind11/operations/__init__.hpp
+++ b/ttnn/cpp/pybind11/operations/__init__.hpp
@@ -13,7 +13,6 @@
 #include "pybind11/operations/core.hpp"
 #include "pybind11/operations/creation.hpp"
 #include "pybind11/operations/kv_cache.hpp"
-#include "pybind11/operations/matmul.hpp"
 #include "pybind11/operations/maxpool2d.hpp"
 #include "pybind11/operations/pool.hpp"
 #include "pybind11/operations/ternary.hpp"
@@ -27,6 +26,7 @@
 #include "ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp"
 #include "ttnn/operations/data_movement/data_movement_pybind.hpp"
 #include "ttnn/operations/embedding/embedding_ops_pybind.hpp"
+#include "ttnn/operations/matmul/matmul_pybind.hpp"

 namespace py = pybind11;

diff --git a/ttnn/cpp/pybind11/operations/matmul.hpp b/ttnn/cpp/pybind11/operations/matmul.hpp
deleted file mode 100644
index 4bd41040893..00000000000
--- a/ttnn/cpp/pybind11/operations/matmul.hpp
+++ /dev/null
@@ -1,100 +0,0 @@
-// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
-//
-// SPDX-License-Identifier: Apache-2.0
-
-#pragma once
-
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
-
-#include "ttnn/operations/matmul.hpp"
-#include "ttnn/types.hpp"
-
-namespace py = pybind11;
-
-namespace ttnn {
-namespace operations {
-namespace matmul {
-
-void py_module(py::module& module) {
-    module.def(
-        "matmul",
-        [](const ttnn::Tensor& input_tensor_a,
-           const ttnn::Tensor& input_tensor_b,
-           const bool transpose_a = false,
-           const bool transpose_b = false,
-           const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
-           const std::optional<const DataType> dtype = std::nullopt,
-           const std::optional<const MatmulProgramConfig> program_config = std::nullopt,
-           const std::optional<const std::string>& activation = std::nullopt,
-           const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
-           const std::optional<const CoreGrid> core_grid = std::nullopt) -> ttnn::Tensor {
-            return ttnn::operations::matmul::matmul(
-                input_tensor_a,
-                input_tensor_b,
-                /*bias=*/std::nullopt,
-                transpose_a,
-                transpose_b,
-                program_config,
-                memory_config,
-                dtype,
-                activation,
-                compute_kernel_config,
-                core_grid,
-                /*propagate_is_b_batched=*/true);
-        },
-        py::arg("input_tensor_a"),
-        py::arg("input_tensor_b"),
-        py::kw_only(),
-        py::arg("transpose_a") = false,
-        py::arg("transpose_b") = false,
-        py::arg("memory_config") = DRAM_MEMORY_CONFIG,
-        py::arg("dtype") = std::nullopt,
-        py::arg("program_config") = std::nullopt,
-        py::arg("activation") = std::nullopt,
-        py::arg("compute_kernel_config") = std::nullopt,
-        py::arg("core_grid") = std::nullopt);
-
-    module.def(
-        "linear",
-        [](const ttnn::Tensor& input_tensor_a,
-           const ttnn::Tensor& input_tensor_b,
-           const std::optional<const ttnn::Tensor>& bias = std::nullopt,
-           const bool transpose_a = false,
-           const bool transpose_b = false,
-           const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
-           const std::optional<const DataType> dtype = std::nullopt,
-           const std::optional<const MatmulProgramConfig> program_config = std::nullopt,
-           const std::optional<const std::string>& activation = std::nullopt,
-           const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
-           const std::optional<const CoreGrid> core_grid = std::nullopt) -> ttnn::Tensor {
-            return ttnn::operations::matmul::matmul(
-                input_tensor_a,
-                input_tensor_b,
-                bias,
-                transpose_a,
-                transpose_b,
-                program_config,
-                memory_config,
-                dtype,
-                activation,
-                compute_kernel_config,
-                core_grid);
-        },
-        py::arg("input_tensor_a"),
-        py::arg("input_tensor_b"),
-        py::kw_only(),
-        py::arg("bias") = std::nullopt,
-        py::arg("transpose_a") = false,
-        py::arg("transpose_b") = false,
-        py::arg("memory_config") = DRAM_MEMORY_CONFIG,
-        py::arg("dtype") = std::nullopt,
-        py::arg("program_config") = std::nullopt,
-        py::arg("activation") = std::nullopt,
-        py::arg("compute_kernel_config") = std::nullopt,
-        py::arg("core_grid") = std::nullopt);
-}
-
-}  // namespace matmul
-}  // namespace operations
-}  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv2d.hpp
index 13b40714eea..e4eb58b4f52 100644
--- a/ttnn/cpp/ttnn/operations/conv2d.hpp
+++ b/ttnn/cpp/ttnn/operations/conv2d.hpp
@@ -8,7 +8,7 @@
 #include "ttnn/core.hpp"
 #include "ttnn/operations/core.hpp"
-#include "ttnn/cpp/ttnn/operations/matmul.hpp"
+#include "ttnn/cpp/ttnn/operations/matmul/matmul.hpp"
 #include "ttnn/types.hpp"
 #include "tt_eager/tensor/tensor_utils.hpp"
 #include "tt_metal/impl/dispatch/command_queue.hpp"
diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm.cpp
similarity index 100%
rename from tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm.cpp
diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm.cpp
similarity index 100%
rename from tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm.cpp
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp
new file mode 100644
index 00000000000..5ec74d7367c
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp
@@ -0,0 +1,367 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <cstdint>
+
+#include "mod_div_lib.h"
+#include "compute_kernel_api/tile_move_copy.h"
+#include "compute_kernel_api/matmul.h"
+#include "compute_kernel_api/pack_untilize.h"
+
+#ifdef FUSE_BIAS
+#include "compute_kernel_api/bcast.h"
+#endif
+
+#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"
+
+
+// Please update tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation_copy.cpp
+// when making any changes to this file.
+// Have to keep a copy because cannot import ttnn into tests/tt_metal.
+
+namespace NAMESPACE {
+
+FORCE_INLINE void reload_from_cb_to_dst(uint32_t in0_cb_id, uint32_t in1_cb_id, uint32_t mm_partials_cb_id, uint32_t out_subblock_num_tiles, uint32_t out_subblock_w, uint32_t out_subblock_h, uint32_t in0_block_w) {
+    // Reconfigure input
+    copy_tile_to_dst_init_short_with_dt(in1_cb_id, mm_partials_cb_id);
+    cb_wait_front(mm_partials_cb_id, out_subblock_num_tiles);
+
+    uint32_t start_dst_index = 0;
+    uint32_t start_tile_index = 0;
+    copy_block_matmul_partials(mm_partials_cb_id, start_tile_index, start_dst_index, out_subblock_num_tiles);
+
+    cb_pop_front(mm_partials_cb_id, out_subblock_num_tiles);
+    // Reconfigure srcA back
+    mm_block_init_short_with_dt(in0_cb_id, in1_cb_id, mm_partials_cb_id, false, out_subblock_w, out_subblock_h, in0_block_w);
+}
+
+template <uint32_t out_subblock_w, uint32_t out_block_w>
+inline void reblock_and_untilize(
+    uint32_t num_out_subblocks_in_col,
+    uint32_t out_subblock_num_tiles,
+    uint32_t out_subblock_h,
+    uint32_t interm_cb_id,
+    uint32_t out_cb_id) {
+
+    uint32_t num_tiles_in_row_of_subblocks = mulsi3(out_subblock_num_tiles, num_out_subblocks_in_col);
+    cb_wait_front(interm_cb_id, num_tiles_in_row_of_subblocks);
+
+    uint32_t within_block_index = 0;
+    for (uint32_t h = 0; h < out_subblock_h; h++) {
+        uint32_t block_offset = 0;
+
+        cb_reserve_back(out_cb_id, out_block_w);
+        for (uint32_t n = 0; n < num_out_subblocks_in_col; n++) {
+            tile_regs_acquire();
+            for (uint32_t w = 0; w < out_subblock_w; w++) {
+                uint32_t tile_index = block_offset + within_block_index + w;
+                copy_tile(interm_cb_id, tile_index, w);
+            }
+            tile_regs_commit();
+            tile_regs_wait();
+            pack_untilize_dst(out_cb_id, 1, n);
+            tile_regs_release();
+            block_offset += out_subblock_num_tiles;
+        }
+        cb_push_back(out_cb_id, out_block_w);
+
+        within_block_index += out_subblock_w;
+    }
+    cb_pop_front(interm_cb_id, num_tiles_in_row_of_subblocks);
+}
+
+void MAIN {
+    // RUNTIME ARGS
+    #ifdef MATMUL_DRAM_SHARDED
+    const bool is_worker_core = get_arg_val<uint32_t>(0) == 1;
+    // if not worker core, skip
+    if (not is_worker_core) {
+        return;
+    }
+    #endif
+
+    constexpr uint32_t in0_block_w = get_compile_time_arg_val(0);              // inner block size in tiles
+    constexpr uint32_t in0_num_subblocks = get_compile_time_arg_val(1);        // outer row block size (in inner row blocks)
+    constexpr uint32_t in0_block_num_tiles = get_compile_time_arg_val(2);      // out_subblock_h*in0_block_w*in0_num_subblocks;
+    constexpr uint32_t in0_subblock_num_tiles = get_compile_time_arg_val(3);   // out_subblock_h*in0_block_w
+    constexpr uint32_t in1_num_subblocks = get_compile_time_arg_val(4);        // outer column block size (in inner column blocks)
+    constexpr uint32_t in1_block_num_tiles = get_compile_time_arg_val(5);      // out_subblock_w*in0_block_w*in1_num_subblocks;
+    constexpr uint32_t in1_per_core_w = get_compile_time_arg_val(6);           // out_subblock_w*in1_num_subblocks
+    constexpr uint32_t num_blocks = get_compile_time_arg_val(7);               // outer inner dim (in inner dim blocks)
+    constexpr uint32_t out_subblock_h = get_compile_time_arg_val(8);           // inner row block size in tiles
+    constexpr uint32_t out_subblock_w = get_compile_time_arg_val(9);           // inner column block size in tiles
+    constexpr uint32_t out_subblock_num_tiles = get_compile_time_arg_val(10);  // out_subblock_h * out_subblock_w;
+    constexpr uint32_t batch = get_compile_time_arg_val(11);                   // batch dim
+    constexpr uint32_t out_block_num_tiles = get_compile_time_arg_val(12);     // number of tiles in out_block
+    constexpr bool untilize_out = get_compile_time_arg_val(13);                // untilize output
+
+    constexpr uint32_t out_block_w = out_subblock_w*in1_num_subblocks;
+
+    constexpr uint32_t in0_cb_id = tt::CB::c_in0;
+    constexpr uint32_t in1_cb_id = tt::CB::c_in1;
+    constexpr uint32_t out_cb_id = tt::CB::c_out0;
+    constexpr uint32_t mm_partials_cb_id = tt::CB::c_intermed0;
+
+    constexpr uint32_t untilize_mode_out_cb_id = untilize_out ? mm_partials_cb_id : out_cb_id;
+
+    #ifdef FUSE_BIAS
+    constexpr uint32_t bias_cb_id = tt::CB::c_in3;
+    constexpr uint32_t mm_out_cb_id = mm_partials_cb_id;
+    #else
+    constexpr uint32_t mm_out_cb_id = untilize_mode_out_cb_id;
+    #endif
+
+    #ifdef SFPU_OP_INIT_ACTIVATION
+    SFPU_OP_INIT_ACTIVATION
+    #endif
+
+    constexpr bool spill = num_blocks > 1;
+
+    mm_block_init(in0_cb_id, in1_cb_id, mm_partials_cb_id, false, out_subblock_w, out_subblock_h, in0_block_w);
+    for (uint32_t b = 0; b < batch; b++){
+        bool enable_reload = false;
+        uint32_t out_num_tiles_to_wait = out_subblock_num_tiles;
+
+        #ifdef PACK_RELU
+        // for each batch we start with relu disabled so that intermediate results are not relu'd
+        if constexpr(batch > 1) {
+            PACK(( llk_pack_relu_config(ReluType::NO_RELU) ));
+        }
+        #endif
+
+        if constexpr(batch > 1) {
+            PACK(( pack_reconfig_data_format(mm_partials_cb_id) ));
+        }
+
+        for(uint32_t block = 0; block < num_blocks; block++)
+        {
+            bool last_out = block == (num_blocks-1);
+            // Configure packer once for pack out without Bias
+            #if not defined FUSE_BIAS and defined PACK_RELU
+            if (last_out) {
+                // if last block we pack the final result with relu enabled
+                PACK(( llk_pack_relu_config(ReluType::ZERO_RELU) ));
+            }
+            #endif
+
+            cb_wait_front(in0_cb_id, in0_block_num_tiles);
+            cb_wait_front(in1_cb_id, in1_block_num_tiles);
+            int in0_index_subblock_offset = 0;
+            for (uint32_t in0_subblock = 0; in0_subblock < in0_num_subblocks; in0_subblock++) {
+                int in1_index_subblock_offset = 0;
+                for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) {
+
+                    tile_regs_acquire();
+                    if (enable_reload) {
+                        reload_from_cb_to_dst(in0_cb_id, in1_cb_id, mm_partials_cb_id, out_subblock_num_tiles, out_subblock_w, out_subblock_h, in0_block_w);
+                    }
+
+                    #ifndef SKIP_COMPUTE
+                    // Compute output sub-block
+                    uint32_t dst_index = 0;  // start at 0, each call to matmul_block internally increments dst_index
+                    uint32_t in0_index = in0_index_subblock_offset;  // offset into in0 block
+                    uint32_t in1_index = in1_index_subblock_offset;  // offset into in1 block
+                    // the inner dim that we accumulate over is the inner dim of in0/in1, which is in0_block_w
+                    for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) {
+                        // matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst
+                        // accumulation is done by iterating matmul_block across the inner dim
+                        // in0_block_w is passed as the inner dim (kt) to matmul_block, internally used to stride in0
+                        matmul_block(in0_cb_id, in1_cb_id, in0_index, in1_index, dst_index, false, out_subblock_w, out_subblock_h, in0_block_w);
+                        in0_index++;                  // stride right by 1
+                        in1_index += in1_per_core_w;  // to stride down by 1 need to stride by in1_per_core_w (should be called in1_block_w)
+                    }
+                    #endif // SKIP_COMPUTE
+
+                    if (last_out) {
+                        // If we fuse bias, we will pack out and run bias + optional sfpu in a separate loop
+                        #if not defined FUSE_BIAS and defined SFPU_OP_INIT_ACTIVATION
+                        for (uint32_t i = 0; i < out_subblock_num_tiles; i++) {
+                            SFPU_OP_FUNC_ACTIVATION
+                        }
+                        #endif
+
+                        tile_regs_commit();
+                        // Pack out to output buffer
+                        cb_reserve_back(mm_out_cb_id, out_subblock_num_tiles);
+                        tile_regs_wait();
+
+                        #ifdef PACKER_L1_ACC
+                        #ifdef FUSE_BIAS
+                        if (block == 0) { // no accumulation for first iteration
+                            PACK(( llk_pack_reconfig_l1_acc(0) ));
+                        } else {
+                            PACK(( llk_pack_reconfig_l1_acc(1) ));
+                        }
+                        #else
+                        PACK(( llk_pack_reconfig_l1_acc(0) ));
+                        #endif
+                        #endif
+
+                        #if defined FP32_DEST_ACC_EN or defined PACKER_L1_ACC
+                        PACK((
pack_reconfig_data_format(mm_out_cb_id) )); + #endif + + uint32_t start_dst_index = 0; + matmul_pack_tile(start_dst_index, mm_out_cb_id, out_subblock_num_tiles); + + tile_regs_release(); + cb_push_back(mm_out_cb_id, out_subblock_num_tiles); + + } else { + tile_regs_commit(); + // Wait for tiles in output buffer to be written out since interm and output share memory + if (block == 0) { + cb_reserve_back(out_cb_id, out_num_tiles_to_wait); + out_num_tiles_to_wait += out_subblock_num_tiles; + } + // Move partial result to interm buffer + cb_reserve_back(mm_partials_cb_id, out_subblock_num_tiles); + tile_regs_wait(); + + #ifdef PACKER_L1_ACC + if (block == 0) { // no accumulation for first iteration + PACK(( llk_pack_reconfig_l1_acc(0) )); + } else if (block == 1) { + PACK(( llk_pack_reconfig_l1_acc(1) )); + } + #endif + + uint32_t start_dst_index = 0; + matmul_pack_tile(start_dst_index, mm_partials_cb_id, out_subblock_num_tiles); + + tile_regs_release(); + cb_push_back(mm_partials_cb_id, out_subblock_num_tiles); + + } + + in1_index_subblock_offset += out_subblock_w; + } + in0_index_subblock_offset += in0_subblock_num_tiles; + } + + + #ifdef PACKER_L1_ACC + #ifdef FUSE_BIAS + if (block < num_blocks - 1) { + //Wait for l1 accumulation to populate interm buffer, + //then pop to update fifo rd pointer + cb_wait_front(mm_partials_cb_id, out_block_num_tiles); + cb_pop_front(mm_partials_cb_id, out_block_num_tiles); + } + // never reload when with bias, bias uses interm buffer + enable_reload = false; + #else + //Last iteration does spill and reload to output buffer + if (block < num_blocks - 2) { + cb_wait_front(mm_partials_cb_id, out_block_num_tiles); + cb_pop_front(mm_partials_cb_id, out_block_num_tiles); + } + if (block == num_blocks - 2) { enable_reload = true; } // reload when last iteration + #endif + #else + if constexpr (spill) { enable_reload = true; } + #endif + + cb_pop_front(in0_cb_id, in0_block_num_tiles); + cb_pop_front(in1_cb_id, in1_block_num_tiles); + + } + + #ifdef FUSE_BIAS + #ifdef PACK_RELU + // if last block we pack the final result with relu enabled + PACK(( llk_pack_relu_config(ReluType::ZERO_RELU) )); + #endif + #ifdef PACKER_L1_ACC + PACK(( llk_pack_reconfig_l1_acc(0) )); + #endif + #if defined FP32_DEST_ACC_EN or defined PACKER_L1_ACC + PACK(( pack_reconfig_data_format(out_cb_id) )); + #endif + + unpack_reconfig_data_format(in1_cb_id, mm_partials_cb_id, in0_cb_id, bias_cb_id); + add_bcast_rows_init_short(); + // reconfigure unpacker df for src B + cb_wait_front(bias_cb_id, in1_per_core_w); + for (uint32_t in0_subblock = 0; in0_subblock < in0_num_subblocks; in0_subblock++) { + int in1_index_subblock_offset = 0; + for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { + // Redundant wait since we know data was just pushed + cb_wait_front(mm_partials_cb_id, out_subblock_num_tiles); + tile_regs_acquire(); + for (uint32_t i = 0, j = 0; j < out_subblock_h; j++) { + uint32_t bcast_tile_idx = in1_index_subblock_offset; + for (uint32_t k = 0; k < out_subblock_w; k++, i++) { + add_tiles_bcast_rows(mm_partials_cb_id, bias_cb_id, i, bcast_tile_idx, i); + bcast_tile_idx++; + } + } + // if there's no SFPU fusion, we commit the regs so packer can start packing + #ifndef SFPU_OP_INIT_ACTIVATION + tile_regs_commit(); + #endif + + cb_pop_front(mm_partials_cb_id, out_subblock_num_tiles); + + + // sfpu activation + #ifdef SFPU_OP_INIT_ACTIVATION + for (uint32_t i = 0; i < out_subblock_num_tiles; i++) { + SFPU_OP_FUNC_ACTIVATION + } + tile_regs_commit(); + #endif + 
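+                // At this point each dst register i holds partial[i] + bias[col(i)]:
+                // the single bias row in bias_cb_id was broadcast down the
+                // out_subblock_h rows of the sub-block by add_tiles_bcast_rows above,
+                // so all that remains is to pack the finished sub-block out.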
+                // Pack out to output buffer
+                cb_reserve_back(untilize_mode_out_cb_id, out_subblock_num_tiles);
+                tile_regs_wait();
+                for (uint32_t i = 0; i < out_subblock_num_tiles; i++) {
+                    pack_tile(i, untilize_mode_out_cb_id);
+                }
+                tile_regs_release();
+                cb_push_back(untilize_mode_out_cb_id, out_subblock_num_tiles);
+
+                in1_index_subblock_offset += out_subblock_w;
+            }
+        }
+        #endif // FUSE_BIAS
+        if constexpr(untilize_out) {
+            #ifdef PACK_RELU
+            PACK(( llk_pack_relu_config(ReluType::NO_RELU) ));
+            #endif // PACK_RELU
+            #ifndef FUSE_BIAS
+            unpack_reconfig_data_format_srca(in1_cb_id, mm_partials_cb_id);
+            #ifdef PACKER_L1_ACC
+            PACK(( llk_pack_reconfig_l1_acc(0) ));
+            #endif
+            #if defined FP32_DEST_ACC_EN or defined PACKER_L1_ACC
+            PACK(( pack_reconfig_data_format(out_cb_id) ));
+            #endif
+            #endif // FUSE_BIAS
+            pack_untilize_dst_init_short(out_cb_id);
+            copy_tile_to_dst_init_short();
+            for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) {
+                reblock_and_untilize<out_subblock_w, out_block_w>(
+                    in1_num_subblocks,
+                    out_subblock_num_tiles,
+                    out_subblock_h,
+                    mm_partials_cb_id,
+                    out_cb_id);
+            }
+            pack_untilize_uninit(mm_partials_cb_id);
+        }
+        if constexpr(batch > 1) {
+            // reconfigure init for matmul
+            mm_block_init_short(in0_cb_id, in1_cb_id, 0, out_subblock_w, out_subblock_h, in0_block_w);
+            #ifdef FUSE_BIAS
+            // reconfigure unpacker df for src A and src B
+            unpack_reconfig_data_format(mm_partials_cb_id, in1_cb_id, bias_cb_id, in0_cb_id);
+            #else
+            // reconfigure unpacker df for src A
+            unpack_reconfig_data_format_srca(mm_partials_cb_id, in1_cb_id);
+            #endif
+        }
+    }
+}
+}
diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation_inline_untilize_out.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation_inline_untilize_out.cpp
similarity index 100%
rename from tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation_inline_untilize_out.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation_inline_untilize_out.cpp
diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_8bank_output_tiles_partitioned.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_8bank_output_tiles_partitioned.cpp
similarity index 100%
rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_8bank_output_tiles_partitioned.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_8bank_output_tiles_partitioned.cpp
diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_interleaved.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_interleaved.cpp
similarity index 100%
rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_interleaved.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_interleaved.cpp
diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_single_core.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_single_core.cpp
similarity index 100%
rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_single_core.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_single_core.cpp
diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_single_core_tilize_untilize.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_single_core_tilize_untilize.cpp
similarity index 100%
rename
from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_single_core_tilize_untilize.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_single_core_tilize_untilize.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver_dram_sharded.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_receiver_dram_sharded.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver_dram_sharded.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_receiver_dram_sharded.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding_block_sharded.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding_block_sharded.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding_block_sharded.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding_block_sharded.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp 
b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_padding.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_padding.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_padding.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_interleaved.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_interleaved.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_interleaved.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_interleaved.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_single_core_tiled.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_single_core_tiled.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_single_core_tiled.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_single_core_tiled.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_tile_layout.cpp 
b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_tile_layout.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_tile_layout.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_tile_layout.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_tile_layout_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_tile_layout_padding.cpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_tile_layout_padding.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_tile_layout_padding.cpp diff --git a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp similarity index 99% rename from tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index ee05a7fd8f0..d91e2a6cfef 100644 --- a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_dnn/op_library/bmm/bmm_op.hpp" +#include "ttnn/operations/matmul/device/matmul_op.hpp" #include #include diff --git a/tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp similarity index 100% rename from tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp rename to ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core/bmm_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/matmul/device/multi_core/bmm_op_multi_core.cpp similarity index 96% rename from tt_eager/tt_dnn/op_library/bmm/multi_core/bmm_op_multi_core.cpp rename to ttnn/cpp/ttnn/operations/matmul/device/multi_core/bmm_op_multi_core.cpp index e488f14eaec..3e3315b527b 100644 --- a/tt_eager/tt_dnn/op_library/bmm/multi_core/bmm_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/multi_core/bmm_op_multi_core.cpp @@ -2,11 +2,11 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_dnn/op_library/bmm/bmm_op.hpp" #include "tt_dnn/op_library/work_split.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" +#include "ttnn/operations/matmul/device/matmul_op.hpp" using namespace tt::constants; @@ -89,7 +89,7 @@ operation::ProgramWithCallbacks matmul_multi_core(const Tensor &a, const Tensor auto reader_id = tt_metal::CreateKernel( program, - "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_8bank_output_tiles_partitioned.cpp", + "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_8bank_output_tiles_partitioned.cpp", all_cores, tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); @@ -108,7 +108,7 @@ operation::ProgramWithCallbacks matmul_multi_core(const Tensor &a, const Tensor auto eltwise_binary_kernel_group_1_id = tt_metal::CreateKernel( program, - "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm.cpp", + "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm.cpp", core_group_1, tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .compile_args = compute_args_group_1} ); @@ -123,7 +123,7 @@ operation::ProgramWithCallbacks matmul_multi_core(const Tensor &a, const Tensor auto eltwise_binary_kernel_group_2_id = tt_metal::CreateKernel( program, - "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm.cpp", + "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm.cpp", core_group_2, 
         tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .compile_args = compute_args_group_2}
     );
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse/bmm_op_multi_core_reuse.cpp b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse/bmm_op_multi_core_reuse.cpp
similarity index 97%
rename from tt_eager/tt_dnn/op_library/bmm/multi_core_reuse/bmm_op_multi_core_reuse.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse/bmm_op_multi_core_reuse.cpp
index 8044bd8d978..8ceaef7ad48 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse/bmm_op_multi_core_reuse.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse/bmm_op_multi_core_reuse.cpp
@@ -2,7 +2,8 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#include "tt_dnn/op_library/bmm/bmm_op.hpp"
+#include "ttnn/operations/matmul/device/matmul_op.hpp"
+
 #include "tt_dnn/op_library/work_split.hpp"
 #include "tt_dnn/op_library/operation.hpp"
 #include "tt_metal/host_api.hpp"
@@ -108,20 +109,20 @@ tt_metal::operation::ProgramWithCallbacks create_program(
     // Create reader and writer kernels per core
     auto mm_reader_kernel_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout.cpp",
         all_cores,
         tt_metal::ReaderDataMovementConfig(reader_compile_time_args));
 
     auto unary_writer_kernel_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_tile_layout.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_tile_layout.cpp",
         all_cores,
         tt_metal::WriterDataMovementConfig(writer_compile_time_args));
 
     // Create compute kernel
     auto mm_kernel_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm.cpp",
         all_cores,
         tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .compile_args = compute_kernel_args}
     );
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
similarity index 98%
rename from tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
index 1b700509f83..3057995bf46 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
@@ -5,7 +5,6 @@
 #include 
 #include "hostdevcommon/common_values.hpp"
-#include "tt_dnn/op_library/bmm/bmm_op.hpp"
 #include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
 #include "tt_dnn/op_library/operation.hpp"
 #include "tt_dnn/op_library/work_split.hpp"
@@ -13,6 +12,7 @@
 #include "tt_metal/detail/tt_metal.hpp"
 #include "tt_metal/detail/util.hpp"
 #include "tt_metal/host_api.hpp"
+#include "ttnn/operations/matmul/device/matmul_op.hpp"
 
 using namespace tt::constants;
 using namespace tt;
@@ -356,9 +356,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     auto mm_kernel_in0_mcast_cores_with_work_and_in_receiver_grid_id = tt_metal::CreateKernel(
         program,
         in0_is_sharded
-            ? "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/"
+            ? "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/"
               "reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp"
-            : "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp",
+            : "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp",
         in0_mcast_cores_with_work_and_in_receiver_grid,
         tt_metal::DataMovementConfig{
             .processor = tt_metal::DataMovementProcessor::RISCV_1,
@@ -374,7 +374,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         in0_sender_compile_time_args[1] = 1;  // core_in_in0_receiver_mcast_grid
         mm_kernel_in0_mcast_cores_without_work_and_in_receiver_grid_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/"
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/"
             "reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp",
             in0_mcast_cores_without_work_and_in_receiver_grid,
             tt_metal::DataMovementConfig{
@@ -388,7 +388,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         in0_sender_compile_time_args[1] = 0;  // core_in_in0_receiver_mcast_grid
         mm_kernel_in0_mcast_cores_without_work_and_not_in_receiver_grid_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/"
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/"
             "reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp",
             in0_mcast_cores_without_work_and_not_in_receiver_grid,
             tt_metal::DataMovementConfig{
@@ -403,7 +403,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     if (!in0_is_sharded and in0_mcast_receivers.num_cores() > 0) {
         mm_kernel_in0_receiver_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp",
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp",
             in0_mcast_receivers,
             tt_metal::DataMovementConfig{
                 .processor = tt_metal::DataMovementProcessor::RISCV_1,
@@ -413,7 +413,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     auto mm_kernel_in1_sender_writer_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp",
         all_cores_with_work,
         tt_metal::DataMovementConfig{
             .processor = tt_metal::DataMovementProcessor::RISCV_0,
@@ -458,7 +458,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     // bool math_approx_mode = false;
     auto mm_kernel = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
         all_cores_with_work,
         tt_metal::ComputeConfig{
             .math_fidelity = math_fidelity,
@@ -1067,7 +1067,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     auto mm_kernel_in0_sender_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp",
         all_cores,
         tt_metal::DataMovementConfig{
             .processor = tt_metal::DataMovementProcessor::RISCV_1,
@@ -1077,7 +1077,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     auto mm_kernel_in1_sender_writer_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp",
         in1_mcast_sender,
         tt_metal::DataMovementConfig{
             .processor = tt_metal::DataMovementProcessor::RISCV_0,
@@ -1089,7 +1089,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     if (in1_mcast_receivers.num_cores() > 0) {
         mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/"
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/"
             "reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
             in1_mcast_receivers,
             tt_metal::DataMovementConfig{
@@ -1138,7 +1138,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     // bool math_approx_mode = false;
     auto mm_kernel = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
         all_cores,
         tt_metal::ComputeConfig{
             .math_fidelity = math_fidelity,
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
similarity index 98%
rename from tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
index 86cd8d442a6..1055a942ec1 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
@@ -5,13 +5,13 @@
 #include 
 #include "hostdevcommon/common_values.hpp"
-#include "tt_dnn/op_library/bmm/bmm_op.hpp"
 #include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
 #include "tt_dnn/op_library/operation.hpp"
 #include "tt_metal/common/constants.hpp"
 #include "tt_metal/detail/tt_metal.hpp"
 #include "tt_metal/detail/util.hpp"
 #include "tt_metal/host_api.hpp"
+#include "ttnn/operations/matmul/device/matmul_op.hpp"
 
 using namespace tt::constants;
 using namespace tt;
@@ -496,7 +496,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
     if (in0_block_sharded) {
         mm_kernel_in0_sender_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/"
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/"
             "reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp",
             all_cores_with_work,  // in0_mcast_cores_with_work_and_in_receiver_grid
             tt_metal::DataMovementConfig{
@@ -509,7 +509,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
         in0_sender_compile_time_args[1] = 0;  // core_in_in0_receiver_mcast_grid
         mm_kernel_in0_mcast_cores_without_work_and_not_in_receiver_grid_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/"
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/"
             "reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp",
             in0_mcast_cores_without_work_and_not_in_receiver_grid.value(),
             tt_metal::DataMovementConfig{
@@ -521,7 +521,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
     } else {
         mm_kernel_in0_sender_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp",
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp",
             in0_sender_interleaved,
             tt_metal::DataMovementConfig{
                 .processor = tt_metal::DataMovementProcessor::RISCV_1,
@@ -532,7 +532,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
     auto mm_kernel_in1_sender_writer_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp",
         in1_sender,
         tt_metal::DataMovementConfig{
             .processor = tt_metal::DataMovementProcessor::RISCV_0,
@@ -544,7 +544,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
     if (in1_receiver.num_cores() > 0) {
         mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
             /* in0_sender_in1_receiver, // If not using half-half noc setup */
             in1_receiver,
             tt_metal::DataMovementConfig{
@@ -558,7 +558,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
     if (!in0_block_sharded and in0_receiver_interleaved.num_cores() > 0) {
         mm_kernel_in0_receiver_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp",
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp",
             /* in0_receiver_in1_sender, // If not using half-half noc setup */
             in0_receiver_interleaved,
             tt_metal::DataMovementConfig{
@@ -573,7 +573,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
     if (in0_receiver_in1_receiver_interleaved_other_cores.has_value()) {
         mm_kernel_in1_receiver_writer_other_noc_setup_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
             in0_receiver_in1_receiver_interleaved_other_cores.value(),
             tt_metal::DataMovementConfig{
                 .processor = tt_metal::DataMovementProcessor::RISCV_0,
@@ -583,7 +583,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
         mm_kernel_in0_receiver_other_noc_setup_id = tt_metal::CreateKernel(
             program,
-            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp",
+            "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp",
             in0_receiver_in1_receiver_interleaved_other_cores.value(),
             tt_metal::DataMovementConfig{
                 .processor = tt_metal::DataMovementProcessor::RISCV_1,
@@ -627,7 +627,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
     // bool math_approx_mode = false;
     auto mm_kernel = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
         all_cores_with_work,
         tt_metal::ComputeConfig{
             .math_fidelity = math_fidelity,
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
similarity index 99%
rename from tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
index 2fd60cce636..691ec69f292 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
@@ -5,7 +5,6 @@
 #include 
 #include "hostdevcommon/common_values.hpp"
-#include "tt_dnn/op_library/bmm/bmm_op.hpp"
 #include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
 #include "tt_dnn/op_library/operation.hpp"
 #include "tt_dnn/op_library/work_split.hpp"
@@ -13,6 +12,7 @@
 #include "tt_metal/detail/tt_metal.hpp"
 #include "tt_metal/detail/util.hpp"
 #include "tt_metal/host_api.hpp"
+#include "ttnn/operations/matmul/device/matmul_op.hpp"
 
 using namespace tt::constants;
 using namespace tt;
@@ -636,7 +636,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
     auto mm_kernel_in0_sender_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp",
         all_cores_in_rect_grid,
         tt_metal::DataMovementConfig{
             .processor = tt_metal::DataMovementProcessor::RISCV_1,
@@ -646,7 +646,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
     auto mm_kernel_in1_sender_writer_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp",
         all_cores_in_rect_grid,
         tt_metal::DataMovementConfig{
             .processor = tt_metal::DataMovementProcessor::RISCV_0,
@@ -685,7 +685,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
     // Create compute kernel
     auto mm_kernel = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
         // all_worker_cores,
         all_cores_in_rect_grid,
         tt_metal::ComputeConfig{
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
similarity index 98%
rename from tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
index 8ccc6545453..770623d41e2 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
@@ -2,7 +2,6 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#include "tt_dnn/op_library/bmm/bmm_op.hpp"
 #include "tt_dnn/op_library/work_split.hpp"
 #include "tt_dnn/op_library/operation.hpp"
@@ -11,6 +10,8 @@
 #include "tt_metal/detail/util.hpp"
 #include "tt_metal/detail/tt_metal.hpp"
 
+#include "ttnn/operations/matmul/device/matmul_op.hpp"
+
 using namespace tt::constants;
 using namespace tt;
@@ -160,14 +161,14 @@ operation::ProgramWithCallbacks create_program(
     KernelHandle mm_kernel_in0_reader_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0.cpp",
         all_cores,
         ReaderDataMovementConfig(reader_compile_time_args, mm_kernel_in0_reader_defines)
     );
 
     KernelHandle mm_kernel_in1_reader_writer_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp",
         all_cores,
         WriterDataMovementConfig(reader_writer_compile_time_args, mm_kernel_in1_reader_writer_defines)
     );
@@ -204,7 +205,7 @@ operation::ProgramWithCallbacks create_program(
     // Create compute kernel
     auto mm_kernel_group_1_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
         core_group_1,
         tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_1, .defines = mm_kernel_defines}
     );
@@ -231,7 +232,7 @@ operation::ProgramWithCallbacks create_program(
     };
     auto mm_kernel_group_2_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp",
         core_group_2,
         tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_2, .defines = mm_kernel_defines}
     );
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp
similarity index 97%
rename from tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp
index 7e1acf8d84d..f1b8d4f55b1 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp
@@ -2,7 +2,8 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#include "tt_dnn/op_library/bmm/bmm_op.hpp"
+#include "ttnn/operations/matmul/device/matmul_op.hpp"
+
 #include "tt_dnn/op_library/work_split.hpp"
 #include "tt_dnn/op_library/operation.hpp"
@@ -123,20 +124,20 @@ operation::ProgramWithCallbacks create_program(
     // Create reader and writer kernels per core
     auto mm_reader_kernel_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_padding.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_padding.cpp",
         all_cores,
         tt_metal::ReaderDataMovementConfig(reader_compile_time_args));
 
     auto unary_writer_kernel_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_tile_layout_padding.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_tile_layout_padding.cpp",
         all_cores,
         tt_metal::WriterDataMovementConfig(writer_compile_time_args));
 
     // Create compute kernel
     auto mm_kernel_id = tt_metal::CreateKernel(
         program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm.cpp",
+        "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm.cpp",
         all_cores,
         tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .compile_args = compute_kernel_args}
     );
diff --git a/tt_eager/tt_dnn/op_library/bmm/single_core/bmm_op_single_core_tilize_untilize.cpp b/ttnn/cpp/ttnn/operations/matmul/device/single_core/bmm_op_single_core_tilize_untilize.cpp
similarity index 98%
rename from tt_eager/tt_dnn/op_library/bmm/single_core/bmm_op_single_core_tilize_untilize.cpp
rename to ttnn/cpp/ttnn/operations/matmul/device/single_core/bmm_op_single_core_tilize_untilize.cpp
index 72797768ef2..efabbe3dd9a 100644
--- a/tt_eager/tt_dnn/op_library/bmm/single_core/bmm_op_single_core_tilize_untilize.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/single_core/bmm_op_single_core_tilize_untilize.cpp
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#include "tt_dnn/op_library/bmm/bmm_op.hpp"
+#include "ttnn/operations/matmul/device/matmul_op.hpp"
 #include "tt_dnn/op_library/run_operation.hpp"
 #include "tt_metal/host_api.hpp"
@@ -401,7 +401,7 @@ operation::ProgramWithCallbacks bmm_single_core_tilize_untilize(
     if (tilize_in0) {
         // in0 is row major, in1 is tiled
         // NOTE: this only makes sense for non-tile-shared datatypes for in0
-        reader_kernel = "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_single_core_tilize_untilize.cpp";
+        reader_kernel = "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_single_core_tilize_untilize.cpp";
         reader_rt_args = {
             // in0
             in0_dram_addr,
@@ -434,7 +434,7 @@ operation::ProgramWithCallbacks bmm_single_core_tilize_untilize(
         };
     } else {
         // in0 is tiled, in1 is tiled
-        reader_kernel = "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_single_core.cpp";
+        reader_kernel = "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_single_core.cpp";
         reader_rt_args = {
             // in0
             in0_dram_addr,
@@ -497,7 +497,7 @@ operation::ProgramWithCallbacks bmm_single_core_tilize_untilize(
         };
     } else {
         // out is tiled
-        writer_kernel = "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/writer_bmm_single_core_tiled.cpp";
+        writer_kernel = "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_single_core_tiled.cpp";
         writer_rt_args = {
             out_dram_addr,
             0, // UNUSED
diff --git a/ttnn/cpp/ttnn/operations/matmul.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/operations/matmul.cpp
rename to ttnn/cpp/ttnn/operations/matmul/matmul.cpp
diff --git a/ttnn/cpp/ttnn/operations/matmul.hpp b/ttnn/cpp/ttnn/operations/matmul/matmul.hpp
similarity index 97%
rename from ttnn/cpp/ttnn/operations/matmul.hpp
rename to ttnn/cpp/ttnn/operations/matmul/matmul.hpp
index 34a38744191..e1d112d7cf7 100644
--- a/ttnn/cpp/ttnn/operations/matmul.hpp
+++ b/ttnn/cpp/ttnn/operations/matmul/matmul.hpp
@@ -6,7 +6,7 @@
 #include "tt_eager/tensor/tensor_utils.hpp"
#include "tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp" -#include "tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp" +#include "ttnn/operations/matmul/device/matmul_op.hpp" #include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" #include "tt_metal/common/core_coord.h" #include "tt_metal/impl/dispatch/command_queue.hpp" diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.hpp b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.hpp new file mode 100644 index 00000000000..97069448faa --- /dev/null +++ b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.hpp @@ -0,0 +1,214 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/operations/matmul/matmul.hpp" +#include "ttnn/types.hpp" + +namespace py = pybind11; + +namespace ttnn { +namespace operations { +namespace matmul { + +using namespace tt::operations::primary; + +void py_module(py::module& module) { + py::class_(module, "MatmulProgramConfig") + .def("__repr__", [](const MatmulProgramConfig& config) { return fmt::format("{}", config); }); + + py::class_(module, "MatmulMultiCoreReuseProgramConfig") + .def( + py::init(), + py::kw_only(), + py::arg("compute_with_storage_grid_size"), + py::arg("in0_block_w").noconvert(), + py::arg("out_subblock_h").noconvert(), + py::arg("out_subblock_w").noconvert(), + py::arg("per_core_M").noconvert(), + py::arg("per_core_N").noconvert()) + .def_readwrite("compute_with_storage_grid_size", &MatmulMultiCoreReuseProgramConfig::compute_with_storage_grid_size) + .def_readwrite("in0_block_w", &MatmulMultiCoreReuseProgramConfig::in0_block_w) + .def_readwrite("out_subblock_h", &MatmulMultiCoreReuseProgramConfig::out_subblock_h) + .def_readwrite("out_subblock_w", &MatmulMultiCoreReuseProgramConfig::out_subblock_w) + .def_readwrite("per_core_M", &MatmulMultiCoreReuseProgramConfig::per_core_M) + .def_readwrite("per_core_N", &MatmulMultiCoreReuseProgramConfig::per_core_N) + .def("__repr__", [](const MatmulMultiCoreReuseProgramConfig& config) { return fmt::format("{}", config); }); + + py::class_(module, "MatmulMultiCoreReuseMultiCastProgramConfig") + .def( + py::init< + CoreCoord, + std::size_t, + std::size_t, + std::size_t, + std::size_t, + std::size_t, + bool, + std::optional, + bool>(), + py::kw_only(), + py::arg("compute_with_storage_grid_size"), + py::arg("in0_block_w").noconvert(), + py::arg("out_subblock_h").noconvert(), + py::arg("out_subblock_w").noconvert(), + py::arg("per_core_M").noconvert(), + py::arg("per_core_N").noconvert(), + py::arg("transpose_mcast").noconvert(), + py::arg("fused_activation"), + py::arg("fuse_batch").noconvert() = true) + .def_readwrite("compute_with_storage_grid_size", &MatmulMultiCoreReuseMultiCastProgramConfig::compute_with_storage_grid_size) + .def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCastProgramConfig::in0_block_w) + .def_readwrite("out_subblock_h", &MatmulMultiCoreReuseMultiCastProgramConfig::out_subblock_h) + .def_readwrite("out_subblock_w", &MatmulMultiCoreReuseMultiCastProgramConfig::out_subblock_w) + .def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCastProgramConfig::per_core_M) + .def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCastProgramConfig::per_core_N) + .def_readwrite("transpose_mcast", &MatmulMultiCoreReuseMultiCastProgramConfig::transpose_mcast) + .def_readwrite("fused_activation", &MatmulMultiCoreReuseMultiCastProgramConfig::fused_activation) + .def_readwrite("fuse_batch", &MatmulMultiCoreReuseMultiCastProgramConfig::fuse_batch) + 
.def("__repr__", [](const MatmulMultiCoreReuseMultiCastProgramConfig& config) { + return fmt::format("{}", config); + }); + + py::class_(module, "MatmulMultiCoreReuseMultiCast1DProgramConfig") + .def( + py::init< + CoreCoord, + std::size_t, + std::size_t, + std::size_t, + std::size_t, + std::size_t, + bool, + std::optional, + bool>(), + py::kw_only(), + py::arg("compute_with_storage_grid_size"), + py::arg("in0_block_w").noconvert(), + py::arg("out_subblock_h").noconvert(), + py::arg("out_subblock_w").noconvert(), + py::arg("per_core_M").noconvert(), + py::arg("per_core_N").noconvert(), + py::arg("fuse_batch").noconvert(), + py::arg("fused_activation"), + py::arg("mcast_in0").noconvert()) + .def_readwrite("compute_with_storage_grid_size", &MatmulMultiCoreReuseMultiCast1DProgramConfig::compute_with_storage_grid_size) + .def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::in0_block_w) + .def_readwrite("out_subblock_h", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_subblock_h) + .def_readwrite("out_subblock_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_subblock_w) + .def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCast1DProgramConfig::per_core_M) + .def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCast1DProgramConfig::per_core_N) + .def_readwrite("fuse_batch", &MatmulMultiCoreReuseMultiCast1DProgramConfig::fuse_batch) + .def_readwrite("fused_activation", &MatmulMultiCoreReuseMultiCast1DProgramConfig::fused_activation) + .def_readwrite("mcast_in0", &MatmulMultiCoreReuseMultiCast1DProgramConfig::mcast_in0) + .def("__repr__", [](const MatmulMultiCoreReuseMultiCast1DProgramConfig& config) { + return fmt::format("{}", config); + }); + + py::class_( + module, "MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig") + .def( + py::init< + std::size_t, + std::size_t, + std::size_t, + std::optional>(), + py::kw_only(), + py::arg("in0_block_w").noconvert(), + py::arg("per_core_M").noconvert(), + py::arg("per_core_N").noconvert(), + py::arg("fused_activation")) + .def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::in0_block_w) + .def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::per_core_M) + .def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::per_core_N) + .def_readwrite("fused_activation", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::fused_activation) + .def("__repr__", [](const MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig& config) { + return fmt::format("{}", config); + }); + + module.def( + "matmul", + [](const ttnn::Tensor& input_tensor_a, + const ttnn::Tensor& input_tensor_b, + const bool transpose_a = false, + const bool transpose_b = false, + const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG, + const std::optional dtype = std::nullopt, + const std::optional program_config = std::nullopt, + const std::optional& activation = std::nullopt, + const std::optional compute_kernel_config = std::nullopt, + const std::optional core_grid = std::nullopt) -> ttnn::Tensor { + return ttnn::operations::matmul::matmul( + input_tensor_a, + input_tensor_b, + /*bias=*/std::nullopt, + transpose_a, + transpose_b, + program_config, + memory_config, + dtype, + activation, + compute_kernel_config, + core_grid, + /*propagate_is_b_batched=*/true); + }, + py::arg("input_tensor_a"), + py::arg("input_tensor_b"), + py::kw_only(), + py::arg("transpose_a") = false, + py::arg("transpose_b") = false, + py::arg("memory_config") = DRAM_MEMORY_CONFIG, 
+ py::arg("dtype") = std::nullopt, + py::arg("program_config") = std::nullopt, + py::arg("activation") = std::nullopt, + py::arg("compute_kernel_config") = std::nullopt, + py::arg("core_grid") = std::nullopt); + + module.def( + "linear", + [](const ttnn::Tensor& input_tensor_a, + const ttnn::Tensor& input_tensor_b, + const std::optional& bias = std::nullopt, + const bool transpose_a = false, + const bool transpose_b = false, + const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG, + const std::optional dtype = std::nullopt, + const std::optional program_config = std::nullopt, + const std::optional& activation = std::nullopt, + const std::optional compute_kernel_config = std::nullopt, + const std::optional core_grid = std::nullopt) -> ttnn::Tensor { + return ttnn::operations::matmul::matmul( + input_tensor_a, + input_tensor_b, + bias, + transpose_a, + transpose_b, + program_config, + memory_config, + dtype, + activation, + compute_kernel_config, + core_grid); + }, + py::arg("input_tensor_a"), + py::arg("input_tensor_b"), + py::kw_only(), + py::arg("bias") = std::nullopt, + py::arg("transpose_a") = false, + py::arg("transpose_b") = false, + py::arg("memory_config") = DRAM_MEMORY_CONFIG, + py::arg("dtype") = std::nullopt, + py::arg("program_config") = std::nullopt, + py::arg("activation") = std::nullopt, + py::arg("compute_kernel_config") = std::nullopt, + py::arg("core_grid") = std::nullopt); +} + +} // namespace matmul +} // namespace operations +} // namespace ttnn diff --git a/ttnn/ttnn/operations/matmul.py b/ttnn/ttnn/operations/matmul.py index e21351a3a67..65304b9da22 100644 --- a/ttnn/ttnn/operations/matmul.py +++ b/ttnn/ttnn/operations/matmul.py @@ -7,14 +7,12 @@ import ttnn -MatmulProgramConfig = ttnn._tt_lib.operations.primary.MatmulProgramConfig -MatmulMultiCoreReuseProgramConfig = ttnn._tt_lib.operations.primary.MatmulMultiCoreReuseProgramConfig -MatmulMultiCoreReuseMultiCastProgramConfig = ttnn._tt_lib.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig -MatmulMultiCoreReuseMultiCast1DProgramConfig = ( - ttnn._tt_lib.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig -) +MatmulProgramConfig = ttnn._ttnn.operations.matmul.MatmulProgramConfig +MatmulMultiCoreReuseProgramConfig = ttnn._ttnn.operations.matmul.MatmulMultiCoreReuseProgramConfig +MatmulMultiCoreReuseMultiCastProgramConfig = ttnn._ttnn.operations.matmul.MatmulMultiCoreReuseMultiCastProgramConfig +MatmulMultiCoreReuseMultiCast1DProgramConfig = ttnn._ttnn.operations.matmul.MatmulMultiCoreReuseMultiCast1DProgramConfig MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig = ( - ttnn._tt_lib.operations.primary.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig + ttnn._ttnn.operations.matmul.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig )