Skip to content

Commit

Permalink
#12124: support moreh_nll_loss support large weight (#12126)
Browse files Browse the repository at this point in the history
* #12124: support moreh_nll_loss support large weight
  • Loading branch information
hschoi4448 authored Aug 31, 2024
1 parent 5a4fc17 commit c2c334a
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,7 @@ def run_moreh_nll_loss_backward(shape, ignore_index, reduction_mean, none_weight

@pytest.mark.parametrize(
"shape",
[
(5, 10),
(3000, 100),
(200, 100, 90),
(5, 50, 2, 7, 50, 70),
],
[[5, 10], [3000, 100], [200, 100, 90], [5, 50, 2, 7, 50, 70]],
)
@pytest.mark.parametrize("ignore_index", [1])
@pytest.mark.parametrize("reduction", ["mean", "sum"])
Expand All @@ -158,9 +153,9 @@ def test_moreh_nll_loss(shape, ignore_index, reduction, none_weight, device):
@pytest.mark.parametrize(
"shape",
[
(5, 10),
(5, 6, 7),
(5, 6, 8, 9),
[5, 10],
[5, 6, 7],
[5, 6, 8, 9],
],
)
@pytest.mark.parametrize("reduction", ["mean", "sum"])
Expand All @@ -172,15 +167,17 @@ def test_moreh_nll_loss_callback(shape, reduction, none_weight, device, use_prog

for _ in range(2):
run_moreh_nll_loss(shape, ignore_idx, reduction, none_weight, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_npu(torch_dummy, device)


@pytest.mark.parametrize(
"shape",
[
(400, 300),
(20, 300, 320),
(3, 4, 32 * 5, 32 * 6),
(5, 2, 5, 40, 70),
[400, 300],
[20, 300, 320],
[3, 4, 32 * 5, 32 * 6],
[5, 2, 5, 40, 70],
],
)
@pytest.mark.parametrize("ignore_index", [1])
Expand All @@ -195,9 +192,9 @@ def test_moreh_nll_loss_backward(shape, ignore_index, reduction_mean, none_weigh
@pytest.mark.parametrize(
"shape",
[
(2, 3),
(2, 3, 4),
(2, 3, 5, 4),
[2, 3],
[2, 3, 4],
[2, 3, 5, 4],
],
)
@pytest.mark.parametrize("reduction_mean", [True, False])
Expand All @@ -209,14 +206,16 @@ def test_moreh_nll_loss_backward_test_callback(shape, reduction_mean, none_weigh

for _ in range(2):
run_moreh_nll_loss_backward(shape, ignore_index, reduction_mean, none_weight, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_npu(torch_dummy, device)


@pytest.mark.parametrize(
"shape",
[
(5, 10),
(10, 20, 30),
(10, 20, 30, 40),
[5, 10],
[10, 20, 30],
[10, 20, 30, 40],
],
)
@pytest.mark.parametrize("ignore_index", [1])
Expand All @@ -236,9 +235,9 @@ def test_moreh_nll_loss_compute_kernel_options(
@pytest.mark.parametrize(
"shape",
[
(5, 10),
(10, 20, 30),
(10, 20, 30, 40),
[5, 10],
[10, 20, 30],
[10, 20, 30, 40],
],
)
@pytest.mark.parametrize("reduction_mean", [True, False])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"

void kernel_main() {
uint32_t i = 0;
auto target_addr = get_arg_val<uint32_t>(i++);
auto weight_addr = get_arg_val<uint32_t>(i++);
auto ignore_index = static_cast<int32_t>(get_arg_val<uint32_t>(i++));
auto num_units_per_core = get_arg_val<uint32_t>(i++);
auto start_id = get_arg_val<uint32_t>(i++);
auto N = get_arg_val<uint32_t>(i++);
auto C = get_arg_val<uint32_t>(i++);
auto weight_num_tile = get_arg_val<uint32_t>(i++);
auto element_size = get_arg_val<uint32_t>(i++);
auto target_element_size = get_arg_val<uint32_t>(i++);

constexpr uint32_t cb_target = tt::CB::c_in0;
constexpr uint32_t cb_weight = tt::CB::c_in1;

constexpr uint32_t cb_output = tt::CB::c_out0;

// ublocks size defined in tiles
const uint32_t target_tile_bytes = get_tile_size(cb_target);

constexpr bool target_is_dram = get_compile_time_arg_val(0) == 1;
#if defined(WEIGHT)
constexpr bool weight_is_dram = get_compile_time_arg_val(1) == 1;
constexpr bool weight_has_value = get_compile_time_arg_val(2) == 1;
#endif

const InterleavedAddrGen<target_is_dram> addrg_target = {
.bank_base_address = target_addr, .page_size = target_tile_bytes};

#if defined(WEIGHT)
const uint32_t weight_tile_bytes = get_tile_size(cb_weight);
auto weight_element_size = weight_tile_bytes / 1024;
const DataFormat weight_data_format = get_dataformat(cb_weight);
const InterleavedAddrGen<weight_is_dram> addrg_weight = {
.bank_base_address = weight_addr,
.page_size = weight_tile_bytes,
};
#endif

constexpr uint32_t onetile = 1;

Scalar one, zero;
one.f = 1.0f;
zero.f = 0.0f;

const auto u16_one = uint16_t(one.u >> 16);
const auto u16_zero = uint16_t(zero.u >> 16);

uint32_t end_id = start_id + num_units_per_core;
for (uint32_t i = start_id; i < end_id; ++i) {
// target: (N, d1, d2, .. dk)
uint32_t target_noc_id = i;
read_tile(cb_target, addrg_target, target_noc_id);

cb_reserve_back(cb_output, onetile);
cb_wait_front(cb_target, onetile);

auto output_l1_ptr = get_write_ptr<uint16_t>(cb_output);
auto target_l1_ptr = get_read_ptr<int32_t>(cb_target);

for (uint32_t h = 0; h < TILE_HEIGHT; h++) {
for (uint32_t w = 0; w < TILE_WIDTH; w++) {
uint32_t inout_idx = h * TILE_WIDTH + w;
int32_t target_val = target_l1_ptr[inout_idx];
if (target_val != ignore_index) {
if (0 <= target_val && target_val < static_cast<int32_t>(C)) {
#if defined(WEIGHT)
uint32_t target_idx = target_val;

uint32_t noc_id = target_idx / TILE_WIDTH;
uint32_t weight_tilized_idx = get_tilized_idx(0, target_idx);
read_value(cb_weight, addrg_weight, noc_id, weight_tilized_idx);

cb_wait_front(cb_weight, onetile);
auto weight_l1_ptr = get_read_ptr<uint16_t>(cb_weight);

output_l1_ptr[inout_idx] = weight_l1_ptr[weight_tilized_idx];

cb_pop_front(cb_weight, onetile);
#else
output_l1_ptr[inout_idx] = u16_one;
#endif
} else {
output_l1_ptr[inout_idx] = u16_zero;
}
} else {
output_l1_ptr[inout_idx] = u16_zero;
}
}
}
cb_push_back(cb_output, onetile);

cb_pop_front(cb_target, onetile);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/run_operation.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/work_split.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "ttnn/run_operation.hpp"

using namespace tt::constants;
using namespace std;
Expand Down Expand Up @@ -48,28 +48,56 @@ operation::ProgramWithCallbacks moreh_nll_loss_step1_impl(
auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] =
split_work_to_cores(core_range, units_to_divide);

auto arch = target.device()->arch();
auto* device = target.device();
auto arch = device->arch();
auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] =
get_compute_kernel_config_args(arch, compute_kernel_config);

Program program = Program();

// create circular buffers
tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());

auto fp32_dest_acc_en_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format;

uint32_t weight_num_tile = div_up(channel_size, TILE_WIDTH);
CreateCircularBuffer(
program,
all_cores,
data_format,
{
{CB::c_in0, 1, tt::DataFormat::Int32}, // traget
{CB::c_in1, weight_num_tile}, // weight
{CB::c_intermed0, 1, fp32_dest_acc_en_data_format}, // tmp_weight
{CB::c_out0, 1}, // output
});
const auto target_data_format = tt_metal::datatype_to_dataformat_converter(target.get_dtype());
const auto data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());
const auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format;

const auto target_tile_size = tt_metal::detail::TileSize(target_data_format);
const auto data_tile_size = tt_metal::detail::TileSize(data_format);
const auto intermed_tile_size = tt_metal::detail::TileSize(intermed_data_format);

const uint32_t available_L1 = device->l1_size_per_core() - L1_UNRESERVED_BASE;

uint32_t target_num_tile = 1;
uint32_t weight_num_tile = weight_has_value ? div_up(channel_size, TILE_WIDTH) : 0;
uint32_t intermed_num_tile = 1;
uint32_t output_num_tile = 1;
uint32_t cb_usage = target_num_tile * target_tile_size + weight_num_tile * data_tile_size +
intermed_num_tile * intermed_tile_size + output_num_tile * data_tile_size;

const bool use_large_algorithm = cb_usage >= available_L1;;

if (use_large_algorithm) {
CreateCircularBuffer(
program,
all_cores,
data_format,
{
{CB::c_in0, 1, tt::DataFormat::Int32}, // traget
{CB::c_in1, 1}, // weight
{CB::c_intermed0, 1, intermed_data_format}, // tmp_weight
{CB::c_out0, 1}, // output
});
} else {
CreateCircularBuffer(
program,
all_cores,
data_format,
{
{CB::c_in0, 1, tt::DataFormat::Int32}, // traget
{CB::c_in1, weight_num_tile}, // weight
{CB::c_intermed0, 1, intermed_data_format}, // tmp_weight
{CB::c_out0, 1}, // output
});
}

// create read/write kernel
const std::vector<uint32_t> reader_compile_time_args{
Expand All @@ -89,19 +117,17 @@ operation::ProgramWithCallbacks moreh_nll_loss_step1_impl(
if (fp32_dest_acc_en) {
reader_defines["FP32_DEST_ACC_EN"] = 1;
}

auto reader_kernel_id = CreateReadKernel(
program,
"ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/reader_moreh_nll_loss_step1.cpp",
all_cores,
reader_compile_time_args,
reader_defines);
auto writer_kernel_id = CreateWriteKernel(
program,
"ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/writer_moreh_nll_loss_step1.cpp",
all_cores,
writer_compile_time_args,
writer_defines);
const auto reader_kernel_file =
use_large_algorithm ? "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/reader_moreh_nll_loss_step1_large.cpp"
: "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/reader_moreh_nll_loss_step1.cpp";
const auto writer_kernel_file =
"ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/"
"writer_moreh_nll_loss_step1.cpp";

auto reader_kernel_id =
CreateReadKernel(program, reader_kernel_file, all_cores, reader_compile_time_args, reader_defines);
auto writer_kernel_id =
CreateWriteKernel(program, writer_kernel_file, all_cores, writer_compile_time_args, writer_defines);

const auto target_addr = target.buffer()->address();
const auto weight_addr = weight_has_value ? weight.value().buffer()->address() : 0;
Expand Down

0 comments on commit c2c334a

Please sign in to comment.