#9492: move matmul code to ttnn directory hierarchy (#10015)
* #9492: move matmul code to ttnn directory hierarchy

* #9492: Create copy of bmm zm fused bias kernel in tests because tests can't reference ttnn
bbradelTT authored Jul 8, 2024
1 parent 7d03e02 commit f7c10a2
Showing 52 changed files with 704 additions and 296 deletions.
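
Nearly all of the churn below is a mechanical namespace and path move: call sites switch from tt::operations::primary::matmul (declared in tt_dnn/op_library/bmm/bmm_op.hpp) to ttnn::operations::matmul::matmul (declared in ttnn/operations/matmul/matmul.hpp), which additionally takes explicit transpose flags. A minimal before/after sketch of the pattern repeated throughout; the parameter names are inferred from the call sites in this diff, not quoted from a header:

    // Before: matmul lived in tt_eager's op_library.
    #include "tt_dnn/op_library/bmm/bmm_op.hpp"
    auto out = tt::operations::primary::matmul(
        a, b, /*bias=*/std::nullopt, /*program_config=*/std::nullopt, mem_config);

    // After: matmul lives in the ttnn hierarchy and takes explicit transpose flags.
    #include "ttnn/operations/matmul/matmul.hpp"
    auto out = ttnn::operations::matmul::matmul(
        a, b,
        /*bias=*/std::nullopt,
        /*transpose_a=*/false,
        /*transpose_b=*/false,
        /*program_config=*/std::nullopt,
        mem_config);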
40 changes: 32 additions & 8 deletions tests/tt_eager/integration_tests/test_bert.cpp
@@ -7,14 +7,14 @@
#include "tensor/host_buffer/types.hpp"
#include "tensor/tensor.hpp"
#include "tt_dnn/op_library/bcast/bcast_op.hpp"
#include "tt_dnn/op_library/bmm/bmm_op.hpp"
#include "tt_dnn/op_library/layernorm/layernorm_op.hpp"
#include "tt_dnn/op_library/operation.hpp"
#include "ttnn/cpp/ttnn/operations/normalization/softmax/softmax.hpp"
#include "tt_dnn/op_library/transformer_tms/transformer_tms.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_numpy/functions.hpp"
#include "ttnn/operations/matmul/matmul.hpp"

using Parameters = std::map<std::string, Tensor>;

@@ -35,10 +35,12 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param
.transpose_mcast = false,
.fused_activation = std::nullopt,
};
auto fused_qkv_matmul_output = tt::operations::primary::matmul(
auto fused_qkv_matmul_output = ttnn::operations::matmul::matmul(
hidden_states,
parameters.at(fmt::format("fused_qkv_weight_{}", encoder_index)),
parameters.at(fmt::format("fused_qkv_bias_{}", encoder_index)),
/*transpose_a=*/false,
/*transpose_b=*/false,
fused_qkv_matmul_program_config,
l1_memory_config
);
@@ -56,7 +58,15 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param
.per_core_M = 12,
.per_core_N = 12,
};
auto pre_softmax_bmm_matmul = tt::operations::primary::matmul(query, key, std::nullopt /*bias*/, pre_softmax_bmm_program_config, dram_memory_config);
auto pre_softmax_bmm_matmul = ttnn::operations::matmul::matmul(
query,
key,
/*bias=*/std::nullopt,
/*transpose_a=*/false,
/*transpose_b=*/false,
pre_softmax_bmm_program_config,
dram_memory_config
);
query.deallocate();
key.deallocate();

@@ -72,7 +82,15 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param
.per_core_M = 12,
.per_core_N = 2,
};
auto post_softmax_bmm_output = tt::operations::primary::matmul(pre_softmax_bmm_matmul, value, std::nullopt /*bias*/, post_softmax_bmm_program_config, l1_memory_config);
auto post_softmax_bmm_output = ttnn::operations::matmul::matmul(
pre_softmax_bmm_matmul,
value,
/*bias=*/std::nullopt,
/*transpose_a=*/false,
/*transpose_b=*/false,
post_softmax_bmm_program_config,
l1_memory_config
);
pre_softmax_bmm_matmul.deallocate();
value.deallocate();

@@ -91,10 +109,12 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param
.transpose_mcast = false,
.fused_activation = std::nullopt,
};
auto selfout_bmm_output = tt::operations::primary::matmul(
auto selfout_bmm_output = ttnn::operations::matmul::matmul(
concat_heads_output,
parameters.at(fmt::format("selfout_weight_{}", encoder_index)),
parameters.at(fmt::format("selfout_bias_{}", encoder_index)),
/*transpose_a=*/false,
/*transpose_b=*/false,
selfout_bmm_program_config,
l1_memory_config
);
@@ -123,10 +143,12 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param
.transpose_mcast = false,
.fused_activation = UnaryWithParam(UnaryOpType::GELU,1.0f),
};
auto ff1_matmul_output = tt::operations::primary::matmul(
auto ff1_matmul_output = ttnn::operations::matmul::matmul(
attention_layernorm_output,
parameters.at(fmt::format("ff1_weight_{}", encoder_index)),
parameters.at(fmt::format("ff1_bias_{}", encoder_index)),
/*transpose_a=*/false,
/*transpose_b=*/false,
ff1_matmul_program_config,
dram_memory_config
);
@@ -142,10 +164,12 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param
.transpose_mcast = false,
.fused_activation = std::nullopt,
};
auto ff2_matmul_output = tt::operations::primary::matmul(
auto ff2_matmul_output = ttnn::operations::matmul::matmul(
ff1_matmul_output,
parameters.at(fmt::format("ff2_weight_{}", encoder_index)),
parameters.at(fmt::format("ff2_bias_{}", encoder_index)),
/*transpose_a=*/false,
/*transpose_b=*/false,
ff2_matmul_program_config,
l1_memory_config
);
@@ -169,7 +193,7 @@ Tensor encoder(Tensor&& hidden_states, const Tensor& attention_mask, const Param

Tensor qa_head(Tensor&& hidden_states, const Parameters& parameters) {

auto output = tt::operations::primary::matmul(hidden_states, parameters.at("qa_head_weight"));
auto output = ttnn::operations::matmul::matmul(hidden_states, parameters.at("qa_head_weight"), /*bias=*/std::nullopt);
hidden_states.deallocate();


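
Taken together, the call sites in test_bert.cpp imply a free-function signature along the following lines. This is a reconstruction from usage only; the argument types and defaults are assumptions, not a quote of ttnn/operations/matmul/matmul.hpp:

    // Hypothetical declaration inferred from the call sites above; the real
    // header may differ in parameter types, names, and default arguments.
    namespace ttnn::operations::matmul {

    Tensor matmul(
        const Tensor& input_tensor_a,
        const Tensor& input_tensor_b,
        std::optional<const Tensor> bias = std::nullopt,
        bool transpose_a = false,
        bool transpose_b = false,
        std::optional<MatmulProgramConfig> program_config = std::nullopt,
        std::optional<MemoryConfig> memory_config = std::nullopt);

    }  // namespace ttnn::operations::matmul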
2 changes: 1 addition & 1 deletion tests/tt_eager/ops/test_bmm_op.cpp
@@ -4,7 +4,7 @@

#include "tt_metal/host_api.hpp"
#include "tensor/tensor.hpp"
#include "tt_dnn/op_library/bmm/bmm_op.hpp"
#include "ttnn/operations/matmul/device/matmul_op.hpp"
#include "common/constants.hpp"
#include "tt_numpy/functions.hpp"

@@ -1027,7 +1027,7 @@ tt_metal::Program create_program_single_core (
auto mm_kernel_id = tt_metal::CreateKernel(
program,
matmul_block ?
"tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp" :
"tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation_copy.cpp" :
"tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation.cpp",
all_cores,
tt_metal::ComputeConfig{
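
This hunk is the point of the second commit message: the microbenchmark builds inside the tt_metal test tree, which cannot depend on sources under ttnn/, so the fused-bias matmul kernel is duplicated into the test tree as bmm_large_block_zm_fused_bias_activation_copy.cpp rather than referenced from the kernel's new ttnn location.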
9 changes: 0 additions & 9 deletions tt_eager/tt_dnn/op_library/CMakeLists.txt
@@ -55,15 +55,6 @@ set(TT_DNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_h/bcast_op_sharded_h_optimised.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_w/bcast_op_multi_core_w.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_hw/bcast_op_multi_core_hw.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm/bmm_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm/single_core/bmm_op_single_core_tilize_untilize.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core/bmm_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse/bmm_op_multi_core_reuse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/downsample/downsample_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/optimized_conv_op.cpp
1 change: 0 additions & 1 deletion tt_eager/tt_dnn/op_library/complex/complex_ops.cpp
@@ -4,7 +4,6 @@

#include "tt_dnn/op_library/complex/complex_ops.hpp"
#include "tt_dnn/op_library/concat/concat_op.hpp"
#include "tt_dnn/op_library/bmm/bmm_op.hpp"
#include "tt_dnn/op_library/reshape/reshape_op.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"
#include "tt_numpy/functions.hpp"
12 changes: 10 additions & 2 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -5,7 +5,6 @@
#include "tt_dnn/op_library/composite/composite_ops.hpp"

#include "tt_dnn/op_library/auto_format.hpp"
#include "tt_dnn/op_library/bmm/bmm_op.hpp"
#include "tt_dnn/op_library/concat/concat_op.hpp"
#include "tt_dnn/op_library/copy/copy_op.hpp"
#include "tt_dnn/op_library/math.hpp"
@@ -23,6 +22,7 @@

#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/matmul/matmul.hpp"

namespace tt {

@@ -1769,7 +1769,15 @@ Tensor _outer(Tensor& a, Tensor& b, const MemoryConfig& output_mem_config) {
}
}

return tt::operations::primary::matmul(a_slim, b_slim, std::nullopt, std::nullopt, output_mem_config);
return ttnn::operations::matmul::matmul(
a_slim,
b_slim,
/*bias=*/std::nullopt,
/*transpose_a=*/false,
/*transpose_b=*/false,
/*program_config=*/std::nullopt,
output_mem_config
);
}
Tensor outer(Tensor& a, Tensor& b, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _outer)(a, b, output_mem_config);
12 changes: 10 additions & 2 deletions tt_eager/tt_dnn/op_library/fully_connected/fully_connected_op.cpp
@@ -6,14 +6,22 @@
#include <type_traits>

#include "tt_metal/host_api.hpp"
#include "tt_dnn/op_library/bmm/bmm_op.hpp"
#include "tt_dnn/op_library/bcast/bcast_op.hpp"
#include "ttnn/operations/matmul/matmul.hpp"

namespace tt {
namespace tt_metal {

Tensor fully_connected_(const Tensor& act, const Tensor& weights, std::optional<std::reference_wrapper<const Tensor>> bias, const MemoryConfig& output_mem_config) {
Tensor mm_output = tt::operations::primary::matmul(act, weights, /*bias=*/std::nullopt, /*program_config=*/std::nullopt, output_mem_config);
Tensor mm_output = ttnn::operations::matmul::matmul(
act,
weights,
/*bias=*/std::nullopt,
/*transpose_a=*/false,
/*transpose_b=*/false,
/*program_config=*/std::nullopt,
output_mem_config
);
if (bias) {
return bcast(mm_output, bias.value(), BcastOpMath::ADD, BcastOpDim::H, output_mem_config);
}
113 changes: 0 additions & 113 deletions tt_eager/tt_lib/csrc/operations/primary/module.hpp
@@ -8,7 +8,6 @@
#include <pybind11/stl.h>

#include "transformers/module.hpp"
#include "tt_dnn/op_library/bmm/bmm_op.hpp"
#include "tt_dnn/op_library/groupnorm/groupnorm_op.hpp"
#include "tt_dnn/op_library/layernorm/layernorm_op.hpp"
#include "tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.hpp"
@@ -53,118 +52,6 @@ void py_module(py::module& m_primary) {
auto m_transformers = m_primary.def_submodule("transformers", "Primary transformers operations");
transformers::py_module(m_transformers);

py::class_<MatmulProgramConfig>(m_primary, "MatmulProgramConfig")
.def("__repr__", [](const MatmulProgramConfig& config) { return fmt::format("{}", config); });

py::class_<MatmulMultiCoreReuseProgramConfig>(m_primary, "MatmulMultiCoreReuseProgramConfig")
.def(
py::init<CoreCoord, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t>(),
py::kw_only(),
py::arg("compute_with_storage_grid_size"),
py::arg("in0_block_w").noconvert(),
py::arg("out_subblock_h").noconvert(),
py::arg("out_subblock_w").noconvert(),
py::arg("per_core_M").noconvert(),
py::arg("per_core_N").noconvert())
.def_readwrite("compute_with_storage_grid_size", &MatmulMultiCoreReuseProgramConfig::compute_with_storage_grid_size)
.def_readwrite("in0_block_w", &MatmulMultiCoreReuseProgramConfig::in0_block_w)
.def_readwrite("out_subblock_h", &MatmulMultiCoreReuseProgramConfig::out_subblock_h)
.def_readwrite("out_subblock_w", &MatmulMultiCoreReuseProgramConfig::out_subblock_w)
.def_readwrite("per_core_M", &MatmulMultiCoreReuseProgramConfig::per_core_M)
.def_readwrite("per_core_N", &MatmulMultiCoreReuseProgramConfig::per_core_N)
.def("__repr__", [](const MatmulMultiCoreReuseProgramConfig& config) { return fmt::format("{}", config); });

py::class_<MatmulMultiCoreReuseMultiCastProgramConfig>(m_primary, "MatmulMultiCoreReuseMultiCastProgramConfig")
.def(
py::init<
CoreCoord,
std::size_t,
std::size_t,
std::size_t,
std::size_t,
std::size_t,
bool,
std::optional<UnaryWithParam>,
bool>(),
py::kw_only(),
py::arg("compute_with_storage_grid_size"),
py::arg("in0_block_w").noconvert(),
py::arg("out_subblock_h").noconvert(),
py::arg("out_subblock_w").noconvert(),
py::arg("per_core_M").noconvert(),
py::arg("per_core_N").noconvert(),
py::arg("transpose_mcast").noconvert(),
py::arg("fused_activation"),
py::arg("fuse_batch").noconvert() = true)
.def_readwrite("compute_with_storage_grid_size", &MatmulMultiCoreReuseMultiCastProgramConfig::compute_with_storage_grid_size)
.def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCastProgramConfig::in0_block_w)
.def_readwrite("out_subblock_h", &MatmulMultiCoreReuseMultiCastProgramConfig::out_subblock_h)
.def_readwrite("out_subblock_w", &MatmulMultiCoreReuseMultiCastProgramConfig::out_subblock_w)
.def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCastProgramConfig::per_core_M)
.def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCastProgramConfig::per_core_N)
.def_readwrite("transpose_mcast", &MatmulMultiCoreReuseMultiCastProgramConfig::transpose_mcast)
.def_readwrite("fused_activation", &MatmulMultiCoreReuseMultiCastProgramConfig::fused_activation)
.def_readwrite("fuse_batch", &MatmulMultiCoreReuseMultiCastProgramConfig::fuse_batch)
.def("__repr__", [](const MatmulMultiCoreReuseMultiCastProgramConfig& config) {
return fmt::format("{}", config);
});

py::class_<MatmulMultiCoreReuseMultiCast1DProgramConfig>(m_primary, "MatmulMultiCoreReuseMultiCast1DProgramConfig")
.def(
py::init<
CoreCoord,
std::size_t,
std::size_t,
std::size_t,
std::size_t,
std::size_t,
bool,
std::optional<UnaryWithParam>,
bool>(),
py::kw_only(),
py::arg("compute_with_storage_grid_size"),
py::arg("in0_block_w").noconvert(),
py::arg("out_subblock_h").noconvert(),
py::arg("out_subblock_w").noconvert(),
py::arg("per_core_M").noconvert(),
py::arg("per_core_N").noconvert(),
py::arg("fuse_batch").noconvert(),
py::arg("fused_activation"),
py::arg("mcast_in0").noconvert())
.def_readwrite("compute_with_storage_grid_size", &MatmulMultiCoreReuseMultiCast1DProgramConfig::compute_with_storage_grid_size)
.def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::in0_block_w)
.def_readwrite("out_subblock_h", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_subblock_h)
.def_readwrite("out_subblock_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_subblock_w)
.def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCast1DProgramConfig::per_core_M)
.def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCast1DProgramConfig::per_core_N)
.def_readwrite("fuse_batch", &MatmulMultiCoreReuseMultiCast1DProgramConfig::fuse_batch)
.def_readwrite("fused_activation", &MatmulMultiCoreReuseMultiCast1DProgramConfig::fused_activation)
.def_readwrite("mcast_in0", &MatmulMultiCoreReuseMultiCast1DProgramConfig::mcast_in0)
.def("__repr__", [](const MatmulMultiCoreReuseMultiCast1DProgramConfig& config) {
return fmt::format("{}", config);
});

py::class_<MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig>(
m_primary, "MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig")
.def(
py::init<
std::size_t,
std::size_t,
std::size_t,
std::optional<UnaryWithParam>>(),
py::kw_only(),
py::arg("in0_block_w").noconvert(),
py::arg("per_core_M").noconvert(),
py::arg("per_core_N").noconvert(),
py::arg("fused_activation"))
.def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::in0_block_w)
.def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::per_core_M)
.def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::per_core_N)
.def_readwrite("fused_activation", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::fused_activation)
.def("__repr__", [](const MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig& config) {
return fmt::format("{}", config);
});

py::class_<LayerNormDefaultProgramConfig>(m_primary, "LayerNormDefaultProgramConfig").def(py::init<>());

py::class_<LayerNormShardedMultiCoreProgramConfig>(m_primary, "LayerNormShardedMultiCoreProgramConfig")
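
The bindings deleted above do not simply vanish: the pybind changes further below drop pybind11/operations/matmul.hpp from ttnn's operations __init__.hpp and pull in ttnn/operations/matmul/matmul_pybind.hpp instead, so the program-config classes are presumably re-registered under the ttnn module. A hedged sketch of what that re-registration could look like for one of the removed classes, reusing the deleted binding code; the namespace, header contents, and py_module entry point are assumptions:

    // Hypothetical shape of ttnn/operations/matmul/matmul_pybind.hpp (assumed),
    // mirroring the binding deleted from tt_lib's primary module above.
    #include <pybind11/pybind11.h>
    #include <fmt/format.h>
    #include "ttnn/operations/matmul/device/matmul_op.hpp"  // assumed home of the config structs

    namespace py = pybind11;

    namespace ttnn::operations::matmul {

    void py_module(py::module& module) {
        py::class_<MatmulMultiCoreReuseProgramConfig>(module, "MatmulMultiCoreReuseProgramConfig")
            .def(
                py::init<CoreCoord, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t>(),
                py::kw_only(),
                py::arg("compute_with_storage_grid_size"),
                py::arg("in0_block_w").noconvert(),
                py::arg("out_subblock_h").noconvert(),
                py::arg("out_subblock_w").noconvert(),
                py::arg("per_core_M").noconvert(),
                py::arg("per_core_N").noconvert())
            .def("__repr__", [](const MatmulMultiCoreReuseProgramConfig& config) {
                return fmt::format("{}", config);
            });
        // ...the remaining MatmulProgramConfig variants, bound the same way.
    }

    }  // namespace ttnn::operations::matmul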
11 changes: 10 additions & 1 deletion ttnn/CMakeLists.txt
@@ -3,7 +3,16 @@ set(TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/async_runtime.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/op_library/to_layout/to_layout_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/matmul_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/single_core/bmm_op_single_core_tilize_untilize.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core/bmm_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse/bmm_op_multi_core_reuse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse_padding/bmm_op_multi_core_reuse_padding.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/__init__.hpp
@@ -13,7 +13,6 @@
#include "pybind11/operations/core.hpp"
#include "pybind11/operations/creation.hpp"
#include "pybind11/operations/kv_cache.hpp"
#include "pybind11/operations/matmul.hpp"
#include "pybind11/operations/maxpool2d.hpp"
#include "pybind11/operations/pool.hpp"
#include "pybind11/operations/ternary.hpp"
@@ -27,6 +26,7 @@
#include "ttnn/operations/eltwise/ternary_backward/ternary_backward_pybind.hpp"
#include "ttnn/operations/data_movement/data_movement_pybind.hpp"
#include "ttnn/operations/embedding/embedding_ops_pybind.hpp"
#include "ttnn/operations/matmul/matmul_pybind.hpp"


namespace py = pybind11;