From 11822c5b20dc2cacbc896260a38cd085bd7e0437 Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Tue, 4 Jun 2024 22:12:07 +0000 Subject: [PATCH] #8662: add initial argmax op single core kernel implementation Implmented on single riscv. Working only for argmax(dim=None). Outputs uint32_t tensor. --- .../pytests/tt_dnn/test_argmax_int.py | 76 +++++++++ tt_eager/tt_dnn/op_library/CMakeLists.txt | 2 + .../tt_dnn/op_library/risc_v/argmax_op.cpp | 130 +++++++++++++++ .../kernels/reader_argmax_interleaved.cpp | 151 ++++++++++++++++++ .../tt_dnn/op_library/risc_v/risc_v_op.cpp | 78 +++++++++ .../tt_dnn/op_library/risc_v/risc_v_op.hpp | 55 +++++++ .../csrc/tt_lib_bindings_tensor_dm_ops.cpp | 18 +++ 7 files changed, 510 insertions(+) create mode 100644 tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax_int.py create mode 100644 tt_eager/tt_dnn/op_library/risc_v/argmax_op.cpp create mode 100644 tt_eager/tt_dnn/op_library/risc_v/kernels/reader_argmax_interleaved.cpp create mode 100644 tt_eager/tt_dnn/op_library/risc_v/risc_v_op.cpp create mode 100644 tt_eager/tt_dnn/op_library/risc_v/risc_v_op.hpp diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax_int.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax_int.py new file mode 100644 index 00000000000..147e38dd62a --- /dev/null +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax_int.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import tt_lib +import ttnn +from loguru import logger +from tests.tt_eager.python_api_testing.sweep_tests import comparison_funcs + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 1, 10])), + (torch.Size([1, 1, 10, 20])), + (torch.Size([1, 1, 30, 4])), + (torch.Size([1, 4, 3, 6])), + (torch.Size([5, 4, 3, 20])), + (torch.Size([2, 4, 3, 2])), + (torch.Size([1, 1, 3, 8])), + (torch.Size([1, 1, 1, 24])), + (torch.Size([1, 1, 4, 8])), + (torch.Size([1, 2, 2, 8])), + (torch.Size([1, 2, 2, 4])), + ), +) +@pytest.mark.parametrize("dim", (None,)) +@pytest.mark.parametrize("memconfig", (ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG)) +class TestArgmax: + def test_argmax(self, input_shapes, dim, memconfig, device): + torch.manual_seed(10) + input_data = torch.randn(input_shapes).bfloat16() + + # DEBUG + # input_data = torch.randn(input_shapes).bfloat16() + # lin = torch.arange(24) + # input_data = torch.reshape(lin, input_shapes).bfloat16() + + input_tensor = tt_lib.tensor.Tensor(input_data, tt_lib.tensor.DataType.BFLOAT16).to(device, memconfig) + + tt_output_tensor_on_device = tt_lib.tensor.argmax_int(input_tensor, dim=dim) + tt_out_tensor = tt_output_tensor_on_device.cpu().to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch() + golden_tensor = torch.argmax(input_data, dim=dim) + if dim == 1 or dim == -3 or dim == 0 or dim == -4: + tt_out_tensor = tt_out_tensor[0, :, 0 : input_shapes[2], 0 : input_shapes[3]] + else: + if input_shapes[1] != 1 or input_shapes[0] != 1: + if dim == 2 or dim == -2: + tt_out_tensor = tt_out_tensor[0, :, :, 0 : input_shapes[3]] + else: + tt_out_tensor = tt_out_tensor[0, :, :, 0 : input_shapes[2]] + else: + if dim == 2 or dim == -2: + tt_out_tensor = tt_out_tensor[0, 0, 0, 0 : input_shapes[3]] + else: + tt_out_tensor = tt_out_tensor[0, 0, 0, 0 : input_shapes[2]] + + pt_out_tensor = golden_tensor + tt_out_tensor = tt_output_tensor_on_device.cpu().to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch() + comp_pass, comp_out = comparison_funcs.comp_pcc(pt_out_tensor, tt_out_tensor, pcc=0.99) + comp_all, _ = comparison_funcs.comp_allclose(pt_out_tensor, tt_out_tensor, atol=0, rtol=0) + + # DEBUG + # print(pt_out_tensor) + # print(tt_out_tensor) + # flat = torch.flatten(input_data) + # print(flat) + # print(torch.topk(flat, 8)) + + logger.info(comp_pass) + logger.info(comp_all) + logger.info(comp_out) + status = comp_pass | comp_all + assert status diff --git a/tt_eager/tt_dnn/op_library/CMakeLists.txt b/tt_eager/tt_dnn/op_library/CMakeLists.txt index 6f56c4579a5..515d330a701 100644 --- a/tt_eager/tt_dnn/op_library/CMakeLists.txt +++ b/tt_eager/tt_dnn/op_library/CMakeLists.txt @@ -210,6 +210,8 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/scan/scan_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/topk/topk_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/topk/single_core/single_core_topk.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/risc_v/risc_v_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/risc_v/argmax_op.cpp ) add_library(tt_dnn OBJECT ${TT_DNN_SRCS}) diff --git a/tt_eager/tt_dnn/op_library/risc_v/argmax_op.cpp b/tt_eager/tt_dnn/op_library/risc_v/argmax_op.cpp new file mode 100644 index 00000000000..79667a4ddc1 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/risc_v/argmax_op.cpp @@ -0,0 +1,130 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "tt_dnn/op_library/math.hpp" +#include "tt_dnn/op_library/risc_v/risc_v_op.hpp" +#include "tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" + +using namespace tt::constants; + +namespace tt { + +namespace tt_metal { + +operation::ProgramWithCallbacks argmax_multi_core( + const Tensor &input, const Tensor &output, std::optional dim) { + tt_metal::Program program{}; + + tt::DataFormat input_cb_data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); + uint32_t input_unit_size = input.get_legacy_shape()[-1] * input.element_size(); + tt::DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_unit_size = output.get_legacy_shape()[-1] * output.element_size(); + + tt_metal::Device *device = output.device(); + + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_units = 1; // single-core + auto [num_cores, all_cores, core_group_1, core_group_2, num_units_per_core_group_1, num_units_per_core_group_2] = + split_work_to_cores(compute_with_storage_grid_size, num_units); + + const auto &input_shape = input.get_legacy_shape(); + const uint32_t B = input_shape[0]; + const uint32_t C = input_shape[1]; + const uint32_t H = input_shape[2]; + const uint32_t W = input_shape[3]; + + uint32_t src0_cb_index = CB::c_in0; + uint32_t num_input_units = W; + uint32_t aligned_input_unit_size = round_up_to_mul32(input_unit_size * num_input_units); + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig( + aligned_input_unit_size, {{src0_cb_index, input_cb_data_format}}) + .set_page_size(src0_cb_index, aligned_input_unit_size); + auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + + uint32_t intermed0_cb_index = CB::c_intermed0; + uint32_t num_intermed0_units = B * C * H * W; + // TODO: output stick size should be output dim tensor innermost dim + tt_metal::CircularBufferConfig intermed0_cb_config = + tt_metal::CircularBufferConfig( + num_intermed0_units * output.element_size(), {{intermed0_cb_index, output_cb_data_format}}) + .set_page_size(intermed0_cb_index, output.element_size()); /// page size shouldn't matter here + auto cb_intermed0 = tt_metal::CreateCircularBuffer(program, all_cores, intermed0_cb_config); + + /* NO WRITER FOR NOW + uint32_t output_cb_index = 16; // same as input cb + uint32_t num_output_units = 2; + uint32_t aligned_output_unit_size = round_up_to_mul32(output_unit_size); + tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_units * + aligned_output_unit_size, {{output_cb_index, output_cb_data_format}}) .set_page_size(output_cb_index, + aligned_output_unit_size); auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config); + */ + + auto src_buffer = input.buffer(); + auto dst_buffer = output.buffer(); + bool src_is_dram = src_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + + std::vector reader_compile_time_args = { + src0_cb_index, + intermed0_cb_index, + src_is_dram, + dst_is_dram, + input_unit_size, + output_unit_size, + B, + C, + H, + W, + dim.value_or(0), + (uint32_t) (not dim.has_value()), + }; + + std::map kernel_defines; + tt_metal::KernelHandle reader_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/risc_v/kernels/reader_argmax_interleaved.cpp", + all_cores, + tt_metal::ReaderDataMovementConfig(reader_compile_time_args, kernel_defines)); + + uint32_t g1_numcores = core_group_1.num_cores(); + uint32_t g2_numcores = core_group_2.num_cores(); + auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false); + + for (uint32_t i = 0; i < cores.size(); ++i) { + const CoreCoord &core = cores.at(i); + + tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, {src_buffer->address(), dst_buffer->address()}); + } + + auto override_runtime_args_callback = [reader_kernel_id, cores]( + const Program &program, + const std::vector &input_buffers, + const std::vector &output_buffers) { + auto src_buffer = input_buffers.at(0); + + auto dst_buffer = output_buffers.at(0); + + for (const auto &core : cores) { + { + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + runtime_args[1] = dst_buffer->address(); + } + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + +} // namespace tt_metal + +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/risc_v/kernels/reader_argmax_interleaved.cpp b/tt_eager/tt_dnn/op_library/risc_v/kernels/reader_argmax_interleaved.cpp new file mode 100644 index 00000000000..03a2b90a443 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/risc_v/kernels/reader_argmax_interleaved.cpp @@ -0,0 +1,151 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +// #include "debug/dprint.h" + +// Function to compare two bfloat16 values using integer arithmetic +bool bfloat16_greater(uint16_t bf16_a, uint16_t bf16_b) { + // Extract signs + uint16_t sign_a = (bf16_a >> 15) & 0x1; + uint16_t sign_b = (bf16_b >> 15) & 0x1; + + uint16_t exp_a = (bf16_a >> 7) & 0xFF; + uint16_t exp_b = (bf16_b >> 7) & 0xFF; + + uint16_t man_a = bf16_a & 0x7F; + uint16_t man_b = bf16_b & 0x7F; + + // TODO: Investigate subnormal support + // uint16_t subnormal_a = (exp_a == 0x00); + // uint16_t subnormal_b = (exp_b == 0x00); + + // DPRINT << HEX() << (bf16_a) << " > " << bf16_b << ENDL(); + // DPRINT << HEX() << (sign_a) << " signs " << sign_b << ENDL(); + // DPRINT << HEX() << (exp_a) << " exp " << exp_b << ENDL(); + // DPRINT << HEX() << (man_a) << " man " << man_b << ENDL(); + + // If signs are different, the one without the sign bit is greater + if (sign_a != sign_b) { + // DPRINT << "sign_b > sign_a: " << (int)(sign_b > sign_a) << ENDL(); + return sign_b > sign_a; + } + + // If signs are the same, compare the exponent and mantissa + if (sign_a == 0) { // Positive numbers + if(exp_a == exp_b) { + // DPRINT << "man_a > man_b: " << (int)(man_a > man_b) << ENDL(); + return man_a > man_b; + } + // DPRINT << "exp_a > exp_b: " << (int)(exp_a > exp_b) << ENDL(); + return exp_a > exp_b; + } else { // Negative numbers + if(exp_a == exp_b) { + // DPRINT << "man_a < man_b: " << (int)(man_a < man_b) << ENDL(); + return man_a < man_b; + } + // DPRINT << "exp_a < exp_b: " << (int)(exp_a < exp_b) << ENDL(); + return exp_a < exp_b; + } +} + +void kernel_main() { + uint32_t src_addr = get_arg_val(0); + uint32_t dst_addr = get_arg_val(1); + + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); + constexpr uint32_t cb_id_intermed0 = get_compile_time_arg_val(1); + constexpr bool src0_is_dram = (bool)get_compile_time_arg_val(2); + constexpr bool dst_is_dram = (bool)get_compile_time_arg_val(3); + constexpr uint32_t in0_stick_size = get_compile_time_arg_val(4); + constexpr uint32_t out_stick_size = get_compile_time_arg_val(5); + constexpr uint32_t B = get_compile_time_arg_val(6); + constexpr uint32_t C = get_compile_time_arg_val(7); + constexpr uint32_t H = get_compile_time_arg_val(8); + constexpr uint32_t W = get_compile_time_arg_val(9); + constexpr uint32_t dim = get_compile_time_arg_val(10); + constexpr uint32_t all = get_compile_time_arg_val(11); + + const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = in0_stick_size}; + + const InterleavedAddrGen s_out = {.bank_base_address = dst_addr, .page_size = out_stick_size}; + + + // Use cb as L1 scratch memory + uint32_t out_addr = get_write_ptr(cb_id_intermed0); + volatile tt_l1_ptr uint32_t* max_vals = reinterpret_cast(out_addr); + + // Use cb as L1 scratch memory + uint32_t cb_addr = get_write_ptr(cb_id_in0); + volatile tt_l1_ptr uint16_t* stick = reinterpret_cast(cb_addr); + + //cb_reserve_back(cb_id_intermed0, C*H*W); + //uint32_t indicies_addr = get_write_ptr(cb_id_intermed0); + //volatile tt_l1_ptr uint32_t* max_indices = reinterpret_cast(cb_addr); + + uint32_t max_index = 0; + uint32_t max_val = 0; + uint32_t index_counter = 0; + for(uint32_t l = 0; l < B; l ++) { + for(uint32_t k = 0; k < C; k++) { + for(uint32_t j = 0; j < H; j++) { + // load stick + // DPRINT << (l*C*H + k*H + j) << ENDL(); + noc_async_read_page(l*C*H + k*H + j, s0, cb_addr); + noc_async_read_barrier(); + for(uint32_t i = 0; i < W; i++) { + if constexpr (all) { + uint16_t val = stick[i]; + if(bfloat16_greater(val, max_val)) { + // DPRINT << "new max " << HEX() << (val) << "\nGT old max " << (max_val) << ENDL(); + // DPRINT << "new idx " << DEC() << (index_counter) << "\nGT old idx " << (max_index) << ENDL(); + // DPRINT << DEC() << (max_index) << ENDL(); + max_index = index_counter; + max_val = val; + } + // DPRINT << "[" << index_counter << "] = " << HEX() << (val) << ENDL(); + index_counter++; + } + else { + /* + if(dim == 3) { + if(bfloat16_greater(bfloat16_max_vals[l][k][j] < stick[i]) { + bfloat16_max_vals[l][k][j] = stick[i]; + max_indices[l][k][j] = i; + } + } + else if(dim == 2) { + if(bfloat16_max_vals[l][k][i] < stick[i]) { + bfloat16_max_vals[l][k][i] = stick[i]; + max_indices[l][k][i] = j; + } + } + else if(dim == 1) { + if(bfloat16_max_vals[l][j][i] < stick[i]) { + bfloat16_max_vals[l][j][i] = stick[i]; + max_indices[l][j][i] = k; + } + } + else if(dim == 0) { + if(bfloat16_greater(stick[i], bfloat16_max_vals[k][j][i])) { + bfloat16_max_vals[k][j][i] = stick[i]; + max_indices[k][j][i] = l; + } + } + */ + } + } + } + } + } + + // TODO: Generalize write for argmax for other dims + max_vals[0] = max_index; + uint64_t dst_noc_addr = get_noc_addr(0, s_out); + noc_async_write(out_addr, dst_noc_addr, out_stick_size); + noc_async_write_barrier(); +} diff --git a/tt_eager/tt_dnn/op_library/risc_v/risc_v_op.cpp b/tt_eager/tt_dnn/op_library/risc_v/risc_v_op.cpp new file mode 100644 index 00000000000..6bdbe93fe50 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/risc_v/risc_v_op.cpp @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_dnn/op_library/risc_v/risc_v_op.hpp" + +#include "third_party/magic_enum/magic_enum.hpp" +#include "tt_metal/host_api.hpp" + +namespace tt { + +namespace tt_metal { + +void ArgMax::validate(const std::vector &input_tensors) const { + const auto& input_tensor_a = input_tensors.at(0); + TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Input to argmax need to be on device!"); + TT_FATAL(input_tensor_a.buffer() != nullptr , "Input to argmax need to be allocated in buffers on device!"); + + TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16, "Only BFLOAT16 is supported for inputs!"); + TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Only INTERLEAVED memory layout is supported for inputs!"); + TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Only ROW_MAJOR layout is supported for inputs!"); + + TT_FATAL(this->output_dtype == DataType::UINT32, "Only UINT32 is supported for outputs!"); + TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Only INTERLEAVED memory layout is supported for outputs!"); + + if (this->dim.has_value()) { + const uint32_t input_rank = input_tensor_a.get_legacy_shape().rank(); + const uint32_t normalized_dim = dim.value() < 0 ? dim.value() + input_rank : dim.value(); + TT_FATAL(normalized_dim >= 0, fmt::format("Invalid dim for argmax: {}!", dim.value())); + TT_FATAL(normalized_dim < input_rank, fmt::format("Invalid dim for argmax: {}!", dim.value())); + } +} + +std::vector ArgMax::compute_output_shapes(const std::vector &input_tensors) const { + auto input_shape = input_tensors[0].get_legacy_shape(); + if (this->dim.has_value()) { + // TODO: There seems to be an underflow issue with directly modifying last two dims + if (this->dim.value() == -1 or this->dim.value() == 3) { + Shape output_shape({input_shape[0], input_shape[1], input_shape[2], 1}); + return {output_shape}; + } else if (this->dim.value() == -2 or this->dim.value() == 2) { + Shape output_shape({input_shape[0], input_shape[1], 1, input_shape[3]}); + return {output_shape}; + } else { + input_shape[this->dim.value()] = 1; + return {input_shape}; + } + } else { + Shape output_shape({1, 1, 1, 1}); + return {output_shape}; + } +} + +std::vector ArgMax::create_output_tensors(const std::vector &input_tensors) const { + const auto &input_tensor = input_tensors[0]; + return operation::generic_create_output_tensors( + *this, input_tensors, this->output_dtype, input_tensor.get_layout(), this->output_mem_config); +} + +operation::ProgramWithCallbacks ArgMax::create_program( + const std::vector &input_tensors, std::vector &output_tensors) const { + const auto &input_tensor = input_tensors.at(0); + const auto &output_tensor = output_tensors.at(0); + + return argmax_multi_core(input_tensor, output_tensor, this->dim); +} + +tt::stl::reflection::Attributes ArgMax::attributes() const { + return { + {"output_dtype", this->output_dtype}, + {"output_mem_config", this->output_mem_config}, + {"dim", this->dim}, + }; +} + +} // namespace tt_metal + +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/risc_v/risc_v_op.hpp b/tt_eager/tt_dnn/op_library/risc_v/risc_v_op.hpp new file mode 100644 index 00000000000..2285ee2bf9e --- /dev/null +++ b/tt_eager/tt_dnn/op_library/risc_v/risc_v_op.hpp @@ -0,0 +1,55 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/run_operation.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" + +using namespace tt::constants; + +namespace tt { + +namespace tt_metal { + +operation::ProgramWithCallbacks argmax_multi_core( + const Tensor& input, const Tensor& output, std::optional dim); + +struct ArgMax { + const DataType output_dtype; + const MemoryConfig output_mem_config; + std::optional dim; + + void validate(const std::vector& input_tensors) const; + std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector create_output_tensors(const std::vector& input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, std::vector& output_tensors) const; + tt::stl::reflection::Attributes attributes() const; +}; + +inline Tensor argmax_int( + const Tensor& input_tensor, std::optional dim, const MemoryConfig& output_mem_config) { + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + operation::launch_op( + [output_mem_config, dim]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { + const auto& input_tensor = input_tensors.at(0); + return operation::run(ArgMax{tt::tt_metal::DataType::UINT32, output_mem_config, dim}, {input_tensor}); + }, + {input_tensor}, + output_tensors); + return output_tensors.at(0); +} + +} // namespace tt_metal + +} // namespace tt diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp index b63b38e6f02..5d53c2196c5 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp @@ -25,6 +25,7 @@ #include "tt_dnn/op_library/sharded_partial/sharded_op_partial.hpp" #include "tt_dnn/op_library/all_gather/all_gather_op.hpp" #include "tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp" +#include "tt_dnn/op_library/risc_v/risc_v_op.hpp" namespace tt::tt_metal::detail{ @@ -392,6 +393,23 @@ namespace tt::tt_metal::detail{ "output_dtype", "DataType of output tensor", "DataType", "Default is None (use input dtype)", "No" )doc"); + m_tensor.def("argmax_int", &argmax_int, + py::arg("input").noconvert(), py::arg("dim").noconvert() = std::nullopt, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + Returns the indices of the maximum value of elements in the ``input`` tensor + If no ``dim`` is provided, it will return the indices of maximum value of all elements in given ``input`` + + Input tensor must have BFLOAT16 data type. + + Output tensor will have UINT16 data type. + + .. csv-table:: + :header: "Argument", "Description", "Data type", "Valid range", "Required" + + "input", "Tensor argmax is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" + "dim", "Dimension to perform argmax", "int", "", "No" + "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + )doc"); + // *** experimental operations *** m_tensor.def("fill_rm", &fill_rm, py::arg("N"), py::arg("C"), py::arg("H"), py::arg("W"), py::arg("hOnes"), py::arg("wOnes"), py::arg("any").noconvert(), py::arg("val_hi"), py::arg("val_lo"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(