-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#8662: add initial argmax op single core kernel implementation
Implemented on a single RISC-V core. Works only for argmax(dim=None). Outputs a uint32_t tensor.
- Loading branch information
1 parent
696dc36
commit 11822c5
Showing
7 changed files
with
510 additions
and
0 deletions.
There are no files selected for viewing
76 changes: 76 additions & 0 deletions
76
tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax_int.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
import pytest | ||
import tt_lib | ||
import ttnn | ||
from loguru import logger | ||
from tests.tt_eager.python_api_testing.sweep_tests import comparison_funcs | ||
|
||
|
||
@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 1, 10])),
        (torch.Size([1, 1, 10, 20])),
        (torch.Size([1, 1, 30, 4])),
        (torch.Size([1, 4, 3, 6])),
        (torch.Size([5, 4, 3, 20])),
        (torch.Size([2, 4, 3, 2])),
        (torch.Size([1, 1, 3, 8])),
        (torch.Size([1, 1, 1, 24])),
        (torch.Size([1, 1, 4, 8])),
        (torch.Size([1, 2, 2, 8])),
        (torch.Size([1, 2, 2, 4])),
    ),
)
@pytest.mark.parametrize("dim", (None,))
@pytest.mark.parametrize("memconfig", (ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG))
class TestArgmax:
    """Compare ``tt_lib.tensor.argmax_int`` against ``torch.argmax``.

    Only ``dim=None`` is parametrized, for a set of small 4D bfloat16 inputs
    placed in either L1 or DRAM.
    """

    def test_argmax(self, input_shapes, dim, memconfig, device):
        """Run argmax on device and check it against the torch golden result.

        Args:
            input_shapes: ``torch.Size`` of the random bfloat16 input.
            dim: Reduction dimension forwarded to both implementations.
            memconfig: Memory config used when moving the input to ``device``.
            device: Device fixture the input tensor is moved to.
        """
        # Fixed seed so every parametrization sees reproducible data.
        torch.manual_seed(10)
        input_data = torch.randn(input_shapes).bfloat16()

        input_tensor = tt_lib.tensor.Tensor(input_data, tt_lib.tensor.DataType.BFLOAT16).to(device, memconfig)

        tt_output_tensor_on_device = tt_lib.tensor.argmax_int(input_tensor, dim=dim)
        golden_tensor = torch.argmax(input_data, dim=dim)

        # The device output is compared as-is. The earlier dim-dependent slicing
        # was discarded before the comparison, so it has been removed.
        # TODO: re-introduce output slicing here once dims other than None are tested.
        tt_out_tensor = tt_output_tensor_on_device.cpu().to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()

        comp_pass, comp_out = comparison_funcs.comp_pcc(golden_tensor, tt_out_tensor, pcc=0.99)
        comp_all, _ = comparison_funcs.comp_allclose(golden_tensor, tt_out_tensor, atol=0, rtol=0)

        logger.info(comp_pass)
        logger.info(comp_all)
        logger.info(comp_out)

        # Either the PCC check or the exact (atol=0, rtol=0) match is accepted.
        assert comp_pass or comp_all
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <algorithm> | ||
|
||
#include "tt_dnn/op_library/math.hpp" | ||
#include "tt_dnn/op_library/risc_v/risc_v_op.hpp" | ||
#include "tt_dnn/op_library/work_split.hpp" | ||
#include "tt_metal/common/constants.hpp" | ||
#include "tt_metal/detail/util.hpp" | ||
#include "tt_metal/host_api.hpp" | ||
|
||
using namespace tt::constants; | ||
|
||
namespace tt { | ||
|
||
namespace tt_metal { | ||
|
||
// Builds the program for the argmax op.
//
// Despite the name, the work is currently split as a single unit
// (num_units == 1). One reader kernel does the whole reduction and also
// writes the result; there is no separate compute or writer kernel.
//
// input  - source tensor; its 4D legacy shape is read as (B, C, H, W).
// output - destination tensor; its last-dim byte size is passed to the kernel
//          as the output stick size.
// dim    - optional reduction dim. It is forwarded to the kernel together with
//          a flag telling whether it was absent (reduce over all elements).
//
// Returns the program plus a callback that refreshes the buffer addresses in
// the reader kernel's runtime args.
operation::ProgramWithCallbacks argmax_multi_core(
    const Tensor &input, const Tensor &output, std::optional<const int> dim) {
    tt_metal::Program program{};

    tt::DataFormat input_cb_data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype());
    // One "unit" (stick) is a full innermost row of the tensor, in bytes.
    uint32_t input_unit_size = input.get_legacy_shape()[-1] * input.element_size();
    tt::DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());
    uint32_t output_unit_size = output.get_legacy_shape()[-1] * output.element_size();

    tt_metal::Device *device = output.device();

    auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
    uint32_t num_cores_x = compute_with_storage_grid_size.x;
    uint32_t num_cores_y = compute_with_storage_grid_size.y;
    uint32_t num_units = 1;  // single-core
    auto [num_cores, all_cores, core_group_1, core_group_2, num_units_per_core_group_1, num_units_per_core_group_2] =
        split_work_to_cores(compute_with_storage_grid_size, num_units);

    const auto &input_shape = input.get_legacy_shape();
    const uint32_t B = input_shape[0];
    const uint32_t C = input_shape[1];
    const uint32_t H = input_shape[2];
    const uint32_t W = input_shape[3];

    // Input scratch buffer. The reader kernel fetches one stick at a time into
    // this CB, and input_unit_size already covers a whole W-element stick, so a
    // single unit is enough. Multiplying by W again would over-allocate L1 by a
    // factor of W.
    uint32_t src0_cb_index = CB::c_in0;
    uint32_t num_input_units = 1;
    uint32_t aligned_input_unit_size = round_up_to_mul32(input_unit_size * num_input_units);
    tt_metal::CircularBufferConfig cb_src0_config =
        tt_metal::CircularBufferConfig(aligned_input_unit_size, {{src0_cb_index, input_cb_data_format}})
            .set_page_size(src0_cb_index, aligned_input_unit_size);
    auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

    // Scratch buffer the kernel uses to stage the result before writing it out.
    uint32_t intermed0_cb_index = CB::c_intermed0;
    uint32_t num_intermed0_units = B * C * H * W;
    // TODO: output stick size should be output dim tensor innermost dim
    tt_metal::CircularBufferConfig intermed0_cb_config =
        tt_metal::CircularBufferConfig(
            num_intermed0_units * output.element_size(), {{intermed0_cb_index, output_cb_data_format}})
            .set_page_size(intermed0_cb_index, output.element_size());  // page size shouldn't matter here
    auto cb_intermed0 = tt_metal::CreateCircularBuffer(program, all_cores, intermed0_cb_config);

    // NOTE: there is no writer kernel and no output CB for now; the reader
    // kernel writes the result to the output buffer itself.

    auto src_buffer = input.buffer();
    auto dst_buffer = output.buffer();
    const bool src_is_dram = src_buffer->buffer_type() == tt_metal::BufferType::DRAM;
    const bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM;

    // Order must match the get_compile_time_arg_val() indices in the kernel.
    // Explicit casts avoid implicit narrowing inside the braced initializer.
    std::vector<uint32_t> reader_compile_time_args = {
        src0_cb_index,
        intermed0_cb_index,
        static_cast<uint32_t>(src_is_dram),
        static_cast<uint32_t>(dst_is_dram),
        input_unit_size,
        output_unit_size,
        B,
        C,
        H,
        W,
        static_cast<uint32_t>(dim.value_or(0)),
        static_cast<uint32_t>(not dim.has_value()),  // 1 => reduce over all elements
    };

    std::map<string, string> kernel_defines;
    tt_metal::KernelHandle reader_kernel_id = tt_metal::CreateKernel(
        program,
        "tt_eager/tt_dnn/op_library/risc_v/kernels/reader_argmax_interleaved.cpp",
        all_cores,
        tt_metal::ReaderDataMovementConfig(reader_compile_time_args, kernel_defines));

    auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false);

    // Runtime args: [0] source buffer address, [1] destination buffer address.
    for (const CoreCoord &core : cores) {
        tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, {src_buffer->address(), dst_buffer->address()});
    }

    // Patches the buffer addresses in the existing runtime args with those of
    // the buffers passed to the callback.
    auto override_runtime_args_callback = [reader_kernel_id, cores](
                                              const Program &program,
                                              const std::vector<Buffer *> &input_buffers,
                                              const std::vector<Buffer *> &output_buffers) {
        auto src_buffer = input_buffers.at(0);
        auto dst_buffer = output_buffers.at(0);

        for (const auto &core : cores) {
            auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
            runtime_args[0] = src_buffer->address();
            runtime_args[1] = dst_buffer->address();
        }
    };

    return {std::move(program), override_runtime_args_callback};
}
|
||
} // namespace tt_metal | ||
|
||
} // namespace tt |
151 changes: 151 additions & 0 deletions
151
tt_eager/tt_dnn/op_library/risc_v/kernels/reader_argmax_interleaved.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
|
||
#include "dataflow_api.h" | ||
|
||
// #include "debug/dprint.h" | ||
|
||
// Returns true when bfloat16 value `bf16_a` is strictly greater than `bf16_b`,
// using only integer arithmetic on the raw 16-bit encodings
// (1 sign bit, 8 exponent bits, 7 mantissa bits).
//
// +0.0 and -0.0 compare equal, so neither is "greater" than the other.
// NaN encodings get no special treatment: they are ordered by their bit
// pattern like any other value.
// TODO: Investigate subnormal support
bool bfloat16_greater(uint16_t bf16_a, uint16_t bf16_b) {
    constexpr uint16_t sign_mask = 0x8000;
    constexpr uint16_t magnitude_mask = 0x7FFF;

    // Both operands are zeros of either sign: treat them as equal.
    if (((bf16_a | bf16_b) & magnitude_mask) == 0) {
        return false;
    }

    // Map each encoding to an unsigned key whose integer order matches the
    // floating-point order:
    //  - non-negative values: set the sign bit, so they sort above all
    //    negatives and keep their exponent/mantissa order;
    //  - negative values: invert all bits, so a larger magnitude yields a
    //    smaller key.
    // The casts are needed because ~ and | promote uint16_t to int.
    const uint16_t key_a = (bf16_a & sign_mask) ? static_cast<uint16_t>(~bf16_a)
                                                 : static_cast<uint16_t>(bf16_a | sign_mask);
    const uint16_t key_b = (bf16_b & sign_mask) ? static_cast<uint16_t>(~bf16_b)
                                                 : static_cast<uint16_t>(bf16_b | sign_mask);

    return key_a > key_b;
}
|
||
void kernel_main() { | ||
uint32_t src_addr = get_arg_val<uint32_t>(0); | ||
uint32_t dst_addr = get_arg_val<uint32_t>(1); | ||
|
||
constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); | ||
constexpr uint32_t cb_id_intermed0 = get_compile_time_arg_val(1); | ||
constexpr bool src0_is_dram = (bool)get_compile_time_arg_val(2); | ||
constexpr bool dst_is_dram = (bool)get_compile_time_arg_val(3); | ||
constexpr uint32_t in0_stick_size = get_compile_time_arg_val(4); | ||
constexpr uint32_t out_stick_size = get_compile_time_arg_val(5); | ||
constexpr uint32_t B = get_compile_time_arg_val(6); | ||
constexpr uint32_t C = get_compile_time_arg_val(7); | ||
constexpr uint32_t H = get_compile_time_arg_val(8); | ||
constexpr uint32_t W = get_compile_time_arg_val(9); | ||
constexpr uint32_t dim = get_compile_time_arg_val(10); | ||
constexpr uint32_t all = get_compile_time_arg_val(11); | ||
|
||
const InterleavedAddrGen<src0_is_dram> s0 = {.bank_base_address = src_addr, .page_size = in0_stick_size}; | ||
|
||
const InterleavedAddrGen<dst_is_dram> s_out = {.bank_base_address = dst_addr, .page_size = out_stick_size}; | ||
|
||
|
||
// Use cb as L1 scratch memory | ||
uint32_t out_addr = get_write_ptr(cb_id_intermed0); | ||
volatile tt_l1_ptr uint32_t* max_vals = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(out_addr); | ||
|
||
// Use cb as L1 scratch memory | ||
uint32_t cb_addr = get_write_ptr(cb_id_in0); | ||
volatile tt_l1_ptr uint16_t* stick = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(cb_addr); | ||
|
||
//cb_reserve_back(cb_id_intermed0, C*H*W); | ||
//uint32_t indicies_addr = get_write_ptr(cb_id_intermed0); | ||
//volatile tt_l1_ptr uint32_t* max_indices = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(cb_addr); | ||
|
||
uint32_t max_index = 0; | ||
uint32_t max_val = 0; | ||
uint32_t index_counter = 0; | ||
for(uint32_t l = 0; l < B; l ++) { | ||
for(uint32_t k = 0; k < C; k++) { | ||
for(uint32_t j = 0; j < H; j++) { | ||
// load stick | ||
// DPRINT << (l*C*H + k*H + j) << ENDL(); | ||
noc_async_read_page(l*C*H + k*H + j, s0, cb_addr); | ||
noc_async_read_barrier(); | ||
for(uint32_t i = 0; i < W; i++) { | ||
if constexpr (all) { | ||
uint16_t val = stick[i]; | ||
if(bfloat16_greater(val, max_val)) { | ||
// DPRINT << "new max " << HEX() << (val) << "\nGT old max " << (max_val) << ENDL(); | ||
// DPRINT << "new idx " << DEC() << (index_counter) << "\nGT old idx " << (max_index) << ENDL(); | ||
// DPRINT << DEC() << (max_index) << ENDL(); | ||
max_index = index_counter; | ||
max_val = val; | ||
} | ||
// DPRINT << "[" << index_counter << "] = " << HEX() << (val) << ENDL(); | ||
index_counter++; | ||
} | ||
else { | ||
/* | ||
if(dim == 3) { | ||
if(bfloat16_greater(bfloat16_max_vals[l][k][j] < stick[i]) { | ||
bfloat16_max_vals[l][k][j] = stick[i]; | ||
max_indices[l][k][j] = i; | ||
} | ||
} | ||
else if(dim == 2) { | ||
if(bfloat16_max_vals[l][k][i] < stick[i]) { | ||
bfloat16_max_vals[l][k][i] = stick[i]; | ||
max_indices[l][k][i] = j; | ||
} | ||
} | ||
else if(dim == 1) { | ||
if(bfloat16_max_vals[l][j][i] < stick[i]) { | ||
bfloat16_max_vals[l][j][i] = stick[i]; | ||
max_indices[l][j][i] = k; | ||
} | ||
} | ||
else if(dim == 0) { | ||
if(bfloat16_greater(stick[i], bfloat16_max_vals[k][j][i])) { | ||
bfloat16_max_vals[k][j][i] = stick[i]; | ||
max_indices[k][j][i] = l; | ||
} | ||
} | ||
*/ | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
// TODO: Generalize write for argmax for other dims | ||
max_vals[0] = max_index; | ||
uint64_t dst_noc_addr = get_noc_addr(0, s_out); | ||
noc_async_write(out_addr, dst_noc_addr, out_stick_size); | ||
noc_async_write_barrier(); | ||
} |
Oops, something went wrong.