Skip to content

Commit

Permalink
#8662: add initial argmax op single core kernel implementation
Browse files Browse the repository at this point in the history
Implmented on single riscv. Working only for argmax(dim=None). Outputs uint32_t tensor.
  • Loading branch information
TT-BrianLiu authored and xanderchin committed Jun 5, 2024
1 parent 696dc36 commit 11822c5
Show file tree
Hide file tree
Showing 7 changed files with 510 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
import ttnn
from loguru import logger
from tests.tt_eager.python_api_testing.sweep_tests import comparison_funcs


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 1, 10])),
(torch.Size([1, 1, 10, 20])),
(torch.Size([1, 1, 30, 4])),
(torch.Size([1, 4, 3, 6])),
(torch.Size([5, 4, 3, 20])),
(torch.Size([2, 4, 3, 2])),
(torch.Size([1, 1, 3, 8])),
(torch.Size([1, 1, 1, 24])),
(torch.Size([1, 1, 4, 8])),
(torch.Size([1, 2, 2, 8])),
(torch.Size([1, 2, 2, 4])),
),
)
@pytest.mark.parametrize("dim", (None,))
@pytest.mark.parametrize("memconfig", (ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG))
class TestArgmax:
def test_argmax(self, input_shapes, dim, memconfig, device):
torch.manual_seed(10)
input_data = torch.randn(input_shapes).bfloat16()

# DEBUG
# input_data = torch.randn(input_shapes).bfloat16()
# lin = torch.arange(24)
# input_data = torch.reshape(lin, input_shapes).bfloat16()

input_tensor = tt_lib.tensor.Tensor(input_data, tt_lib.tensor.DataType.BFLOAT16).to(device, memconfig)

tt_output_tensor_on_device = tt_lib.tensor.argmax_int(input_tensor, dim=dim)
tt_out_tensor = tt_output_tensor_on_device.cpu().to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
golden_tensor = torch.argmax(input_data, dim=dim)
if dim == 1 or dim == -3 or dim == 0 or dim == -4:
tt_out_tensor = tt_out_tensor[0, :, 0 : input_shapes[2], 0 : input_shapes[3]]
else:
if input_shapes[1] != 1 or input_shapes[0] != 1:
if dim == 2 or dim == -2:
tt_out_tensor = tt_out_tensor[0, :, :, 0 : input_shapes[3]]
else:
tt_out_tensor = tt_out_tensor[0, :, :, 0 : input_shapes[2]]
else:
if dim == 2 or dim == -2:
tt_out_tensor = tt_out_tensor[0, 0, 0, 0 : input_shapes[3]]
else:
tt_out_tensor = tt_out_tensor[0, 0, 0, 0 : input_shapes[2]]

pt_out_tensor = golden_tensor
tt_out_tensor = tt_output_tensor_on_device.cpu().to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
comp_pass, comp_out = comparison_funcs.comp_pcc(pt_out_tensor, tt_out_tensor, pcc=0.99)
comp_all, _ = comparison_funcs.comp_allclose(pt_out_tensor, tt_out_tensor, atol=0, rtol=0)

# DEBUG
# print(pt_out_tensor)
# print(tt_out_tensor)
# flat = torch.flatten(input_data)
# print(flat)
# print(torch.topk(flat, 8))

logger.info(comp_pass)
logger.info(comp_all)
logger.info(comp_out)
status = comp_pass | comp_all
assert status
2 changes: 2 additions & 0 deletions tt_eager/tt_dnn/op_library/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ set(TT_DNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/scan/scan_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/topk/topk_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/topk/single_core/single_core_topk.cpp
${CMAKE_CURRENT_SOURCE_DIR}/risc_v/risc_v_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/risc_v/argmax_op.cpp
)

add_library(tt_dnn OBJECT ${TT_DNN_SRCS})
Expand Down
130 changes: 130 additions & 0 deletions tt_eager/tt_dnn/op_library/risc_v/argmax_op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <algorithm>

#include "tt_dnn/op_library/math.hpp"
#include "tt_dnn/op_library/risc_v/risc_v_op.hpp"
#include "tt_dnn/op_library/work_split.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/detail/util.hpp"
#include "tt_metal/host_api.hpp"

using namespace tt::constants;

namespace tt {

namespace tt_metal {

operation::ProgramWithCallbacks argmax_multi_core(
const Tensor &input, const Tensor &output, std::optional<const int> dim) {
tt_metal::Program program{};

tt::DataFormat input_cb_data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype());
uint32_t input_unit_size = input.get_legacy_shape()[-1] * input.element_size();
tt::DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());
uint32_t output_unit_size = output.get_legacy_shape()[-1] * output.element_size();

tt_metal::Device *device = output.device();

auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
uint32_t num_units = 1; // single-core
auto [num_cores, all_cores, core_group_1, core_group_2, num_units_per_core_group_1, num_units_per_core_group_2] =
split_work_to_cores(compute_with_storage_grid_size, num_units);

const auto &input_shape = input.get_legacy_shape();
const uint32_t B = input_shape[0];
const uint32_t C = input_shape[1];
const uint32_t H = input_shape[2];
const uint32_t W = input_shape[3];

uint32_t src0_cb_index = CB::c_in0;
uint32_t num_input_units = W;
uint32_t aligned_input_unit_size = round_up_to_mul32(input_unit_size * num_input_units);
tt_metal::CircularBufferConfig cb_src0_config =
tt_metal::CircularBufferConfig(
aligned_input_unit_size, {{src0_cb_index, input_cb_data_format}})
.set_page_size(src0_cb_index, aligned_input_unit_size);
auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

uint32_t intermed0_cb_index = CB::c_intermed0;
uint32_t num_intermed0_units = B * C * H * W;
// TODO: output stick size should be output dim tensor innermost dim
tt_metal::CircularBufferConfig intermed0_cb_config =
tt_metal::CircularBufferConfig(
num_intermed0_units * output.element_size(), {{intermed0_cb_index, output_cb_data_format}})
.set_page_size(intermed0_cb_index, output.element_size()); /// page size shouldn't matter here
auto cb_intermed0 = tt_metal::CreateCircularBuffer(program, all_cores, intermed0_cb_config);

/* NO WRITER FOR NOW
uint32_t output_cb_index = 16; // same as input cb
uint32_t num_output_units = 2;
uint32_t aligned_output_unit_size = round_up_to_mul32(output_unit_size);
tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_units *
aligned_output_unit_size, {{output_cb_index, output_cb_data_format}}) .set_page_size(output_cb_index,
aligned_output_unit_size); auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config);
*/

auto src_buffer = input.buffer();
auto dst_buffer = output.buffer();
bool src_is_dram = src_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;

std::vector<uint32_t> reader_compile_time_args = {
src0_cb_index,
intermed0_cb_index,
src_is_dram,
dst_is_dram,
input_unit_size,
output_unit_size,
B,
C,
H,
W,
dim.value_or(0),
(uint32_t) (not dim.has_value()),
};

std::map<string, string> kernel_defines;
tt_metal::KernelHandle reader_kernel_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/risc_v/kernels/reader_argmax_interleaved.cpp",
all_cores,
tt_metal::ReaderDataMovementConfig(reader_compile_time_args, kernel_defines));

uint32_t g1_numcores = core_group_1.num_cores();
uint32_t g2_numcores = core_group_2.num_cores();
auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false);

for (uint32_t i = 0; i < cores.size(); ++i) {
const CoreCoord &core = cores.at(i);

tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, {src_buffer->address(), dst_buffer->address()});
}

auto override_runtime_args_callback = [reader_kernel_id, cores](
const Program &program,
const std::vector<Buffer *> &input_buffers,
const std::vector<Buffer *> &output_buffers) {
auto src_buffer = input_buffers.at(0);

auto dst_buffer = output_buffers.at(0);

for (const auto &core : cores) {
{
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = src_buffer->address();
runtime_args[1] = dst_buffer->address();
}
}
};

return {std::move(program), override_runtime_args_callback};
}

} // namespace tt_metal

} // namespace tt
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>

#include "dataflow_api.h"

// #include "debug/dprint.h"

// Function to compare two bfloat16 values using integer arithmetic
bool bfloat16_greater(uint16_t bf16_a, uint16_t bf16_b) {
// Extract signs
uint16_t sign_a = (bf16_a >> 15) & 0x1;
uint16_t sign_b = (bf16_b >> 15) & 0x1;

uint16_t exp_a = (bf16_a >> 7) & 0xFF;
uint16_t exp_b = (bf16_b >> 7) & 0xFF;

uint16_t man_a = bf16_a & 0x7F;
uint16_t man_b = bf16_b & 0x7F;

// TODO: Investigate subnormal support
// uint16_t subnormal_a = (exp_a == 0x00);
// uint16_t subnormal_b = (exp_b == 0x00);

// DPRINT << HEX() << (bf16_a) << " > " << bf16_b << ENDL();
// DPRINT << HEX() << (sign_a) << " signs " << sign_b << ENDL();
// DPRINT << HEX() << (exp_a) << " exp " << exp_b << ENDL();
// DPRINT << HEX() << (man_a) << " man " << man_b << ENDL();

// If signs are different, the one without the sign bit is greater
if (sign_a != sign_b) {
// DPRINT << "sign_b > sign_a: " << (int)(sign_b > sign_a) << ENDL();
return sign_b > sign_a;
}

// If signs are the same, compare the exponent and mantissa
if (sign_a == 0) { // Positive numbers
if(exp_a == exp_b) {
// DPRINT << "man_a > man_b: " << (int)(man_a > man_b) << ENDL();
return man_a > man_b;
}
// DPRINT << "exp_a > exp_b: " << (int)(exp_a > exp_b) << ENDL();
return exp_a > exp_b;
} else { // Negative numbers
if(exp_a == exp_b) {
// DPRINT << "man_a < man_b: " << (int)(man_a < man_b) << ENDL();
return man_a < man_b;
}
// DPRINT << "exp_a < exp_b: " << (int)(exp_a < exp_b) << ENDL();
return exp_a < exp_b;
}
}

void kernel_main() {
uint32_t src_addr = get_arg_val<uint32_t>(0);
uint32_t dst_addr = get_arg_val<uint32_t>(1);

constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0);
constexpr uint32_t cb_id_intermed0 = get_compile_time_arg_val(1);
constexpr bool src0_is_dram = (bool)get_compile_time_arg_val(2);
constexpr bool dst_is_dram = (bool)get_compile_time_arg_val(3);
constexpr uint32_t in0_stick_size = get_compile_time_arg_val(4);
constexpr uint32_t out_stick_size = get_compile_time_arg_val(5);
constexpr uint32_t B = get_compile_time_arg_val(6);
constexpr uint32_t C = get_compile_time_arg_val(7);
constexpr uint32_t H = get_compile_time_arg_val(8);
constexpr uint32_t W = get_compile_time_arg_val(9);
constexpr uint32_t dim = get_compile_time_arg_val(10);
constexpr uint32_t all = get_compile_time_arg_val(11);

const InterleavedAddrGen<src0_is_dram> s0 = {.bank_base_address = src_addr, .page_size = in0_stick_size};

const InterleavedAddrGen<dst_is_dram> s_out = {.bank_base_address = dst_addr, .page_size = out_stick_size};


// Use cb as L1 scratch memory
uint32_t out_addr = get_write_ptr(cb_id_intermed0);
volatile tt_l1_ptr uint32_t* max_vals = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(out_addr);

// Use cb as L1 scratch memory
uint32_t cb_addr = get_write_ptr(cb_id_in0);
volatile tt_l1_ptr uint16_t* stick = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(cb_addr);

//cb_reserve_back(cb_id_intermed0, C*H*W);
//uint32_t indicies_addr = get_write_ptr(cb_id_intermed0);
//volatile tt_l1_ptr uint32_t* max_indices = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(cb_addr);

uint32_t max_index = 0;
uint32_t max_val = 0;
uint32_t index_counter = 0;
for(uint32_t l = 0; l < B; l ++) {
for(uint32_t k = 0; k < C; k++) {
for(uint32_t j = 0; j < H; j++) {
// load stick
// DPRINT << (l*C*H + k*H + j) << ENDL();
noc_async_read_page(l*C*H + k*H + j, s0, cb_addr);
noc_async_read_barrier();
for(uint32_t i = 0; i < W; i++) {
if constexpr (all) {
uint16_t val = stick[i];
if(bfloat16_greater(val, max_val)) {
// DPRINT << "new max " << HEX() << (val) << "\nGT old max " << (max_val) << ENDL();
// DPRINT << "new idx " << DEC() << (index_counter) << "\nGT old idx " << (max_index) << ENDL();
// DPRINT << DEC() << (max_index) << ENDL();
max_index = index_counter;
max_val = val;
}
// DPRINT << "[" << index_counter << "] = " << HEX() << (val) << ENDL();
index_counter++;
}
else {
/*
if(dim == 3) {
if(bfloat16_greater(bfloat16_max_vals[l][k][j] < stick[i]) {
bfloat16_max_vals[l][k][j] = stick[i];
max_indices[l][k][j] = i;
}
}
else if(dim == 2) {
if(bfloat16_max_vals[l][k][i] < stick[i]) {
bfloat16_max_vals[l][k][i] = stick[i];
max_indices[l][k][i] = j;
}
}
else if(dim == 1) {
if(bfloat16_max_vals[l][j][i] < stick[i]) {
bfloat16_max_vals[l][j][i] = stick[i];
max_indices[l][j][i] = k;
}
}
else if(dim == 0) {
if(bfloat16_greater(stick[i], bfloat16_max_vals[k][j][i])) {
bfloat16_max_vals[k][j][i] = stick[i];
max_indices[k][j][i] = l;
}
}
*/
}
}
}
}
}

// TODO: Generalize write for argmax for other dims
max_vals[0] = max_index;
uint64_t dst_noc_addr = get_noc_addr(0, s_out);
noc_async_write(out_addr, dst_noc_addr, out_stick_size);
noc_async_write_barrier();
}
Loading

0 comments on commit 11822c5

Please sign in to comment.