Skip to content

Commit

Permalink
#13857: Add WHB0 support for binary sfpu ops
Browse files Browse the repository at this point in the history
  • Loading branch information
rdjogoTT authored and KalaivaniMCW committed Nov 28, 2024
1 parent ec4bd12 commit 18de632
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 66 deletions.
Original file line number Diff line number Diff line change
@@ -1,70 +1,8 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include "llk_math_common_api.h"
#include "llk_math_eltwise_binary_sfpu.h"

/*************************************************************************
* LLK ELTWISE BINARY SFPU
*************************************************************************/

/**
 * Runs the binary SFPU operation `sfpu_op` on two DST tiles.
 *
 * Looks up the face count and face row dimension for operand id
 * `get_operand_id(0)` and forwards them, together with the tile indices,
 * vector mode and the six opaque parameters, to
 * `_llk_math_eltwise_binary_sfpu_`.
 *
 * NOTE(review): `operand` is accepted but never read; the operand id is always
 * derived from the literal 0. Confirm whether that is intentional.
 */
template <SfpuType sfpu_op, bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu(
    const uint operand,
    uint dst_index_a,
    uint dst_index_b,
    int vector_mode = (int)VectorMode::RC,
    uint param0 = 0,
    uint param1 = 0,
    uint param2 = 0,
    uint param3 = 0,
    uint param4 = 0,
    uint param5 = 0) {
    const std::uint32_t id = get_operand_id(0);
    const std::uint32_t faces = get_operand_num_faces(id);
    const std::uint32_t r_dim = get_operand_face_r_dim(id);

    _llk_math_eltwise_binary_sfpu_<sfpu_op, APPROXIMATE, DST_SYNC_MODE>(
        r_dim,
        faces,
        dst_index_a,
        dst_index_b,
        vector_mode,
        param0,
        param1,
        param2,
        param3,
        param4,
        param5);
}

/**
 * Initialises the binary SFPU path for `sfpu_op` by forwarding all six
 * parameters, unchanged, to `_llk_math_eltwise_binary_sfpu_init_`.
 */
template <SfpuType sfpu_op, bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_init(
    uint param0 = 0,
    uint param1 = 0,
    uint param2 = 0,
    uint param3 = 0,
    uint param4 = 0,
    uint param5 = 0) {
    _llk_math_eltwise_binary_sfpu_init_<sfpu_op, APPROXIMATE>(
        param0, param1, param2, param3, param4, param5);
}

/**
 * Runs the `SfpuType::quant_int32` binary SFPU operation on the DST tiles
 * `dst_index_a` and `dst_index_b` using the given vector mode.
 */
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_quant_int32(
    uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) {
    // llk_math_eltwise_binary_sfpu takes `operand` as its FIRST parameter.
    // Passing only (dst_index_a, dst_index_b, vector_mode) shifted every
    // argument by one slot, so vector_mode ended up being used as dst_index_b.
    // The callee derives its operand id from the literal 0, so 0 is passed
    // explicitly here to keep the remaining arguments aligned.
    llk_math_eltwise_binary_sfpu<SfpuType::quant_int32, APPROXIMATE>(
        0 /* operand */, dst_index_a, dst_index_b, vector_mode);
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_quant_int32_init(const uint zero_point) {
llk_math_eltwise_binary_sfpu_init<SfpuType::quant_int32, APPROXIMATE>(zero_point);
}

/**
 * Runs the `SfpuType::requant_int32` binary SFPU operation on the DST tiles
 * `dst_index_a` and `dst_index_b` using the given vector mode.
 */
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_requant_int32(
    uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) {
    // llk_math_eltwise_binary_sfpu takes `operand` as its FIRST parameter.
    // Passing only (dst_index_a, dst_index_b, vector_mode) shifted every
    // argument by one slot, so vector_mode ended up being used as dst_index_b.
    // The callee derives its operand id from the literal 0, so 0 is passed
    // explicitly here to keep the remaining arguments aligned.
    llk_math_eltwise_binary_sfpu<SfpuType::requant_int32, APPROXIMATE>(
        0 /* operand */, dst_index_a, dst_index_b, vector_mode);
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_requant_int32_init(const uint zero_point) {
llk_math_eltwise_binary_sfpu_init<SfpuType::requant_int32, APPROXIMATE>(zero_point);
}

/**
 * Runs the `SfpuType::dequant_int32` binary SFPU operation on the DST tiles
 * `dst_index_a` and `dst_index_b` using the given vector mode.
 */
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_dequant_int32(
    uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) {
    // llk_math_eltwise_binary_sfpu takes `operand` as its FIRST parameter.
    // Passing only (dst_index_a, dst_index_b, vector_mode) shifted every
    // argument by one slot, so vector_mode ended up being used as dst_index_b.
    // The callee derives its operand id from the literal 0, so 0 is passed
    // explicitly here to keep the remaining arguments aligned.
    llk_math_eltwise_binary_sfpu<SfpuType::dequant_int32, APPROXIMATE>(
        0 /* operand */, dst_index_a, dst_index_b, vector_mode);
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_dequant_int32_init(const uint zero_point) {
llk_math_eltwise_binary_sfpu_init<SfpuType::dequant_int32, APPROXIMATE>(zero_point);
}
#include "llk_math_eltwise_binary_sfpu_init.h"
#include "llk_math_eltwise_binary_sfpu_binop.h"
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "sfpi.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

/**
 * Thin forwarding wrapper: invokes `_calculate_sfpu_binary_` with the same
 * template arguments and the given `dst_offset`.
 *
 * ITERATIONS defaults to 8 when the caller does not specify it.
 */
template <bool APPROXIMATION_MODE, int BINOP_MODE, int ITERATIONS = 8>
inline void calculate_sfpu_binary(const uint dst_offset) {
    _calculate_sfpu_binary_<APPROXIMATION_MODE, BINOP_MODE, ITERATIONS>(dst_offset);
}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_math_eltwise_binary_sfpu_init.h"
#include "llk_math_eltwise_binary_sfpu_params.h"
#include "ckernel_sfpu_binary.h"

namespace ckernel {

// New LLK SFPU APIs

/**
 * Applies the binary SFPU kernel selected by `binop_mode` to the DST tiles
 * `dst_index0` and `dst_index1`.
 *
 * The work is delegated to `llk_math_eltwise_binary_sfpu_params`, which is
 * handed `ckernel::sfpu::calculate_sfpu_binary<APPROXIMATE, binop_mode>` as
 * the per-face kernel.
 *
 * The default for `vector_mode` carries an explicit `(int)` cast so that it is
 * spelled the same way as every other VectorMode default in the LLK wrappers
 * and does not depend on an implicit enum-to-int conversion.
 */
template <bool APPROXIMATE, int binop_mode>
inline void llk_math_eltwise_binary_sfpu_binop(
    uint dst_index0, uint32_t dst_index1, int vector_mode = (int)VectorMode::RC) {
    llk_math_eltwise_binary_sfpu_params<APPROXIMATE>(
        ckernel::sfpu::calculate_sfpu_binary<APPROXIMATE, binop_mode>,
        dst_index0,
        dst_index1,
        vector_mode);
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_binop_init() {
llk_math_eltwise_binary_sfpu_init<SfpuType::unused, APPROXIMATE>();
}

} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_sfpu_types.h"
#include "llk_math_eltwise_binary_sfpu.h"

namespace ckernel {

/**
 * Initialises the binary SFPU path for `sfpu_op` by calling
 * `_llk_math_eltwise_binary_sfpu_init_<sfpu_op>()`.
 *
 * `APPROXIMATE` is part of the template signature but is not used here.
 */
template <SfpuType sfpu_op, bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_init()
{
    _llk_math_eltwise_binary_sfpu_init_<sfpu_op>();
}

/**
 * Initialises the binary SFPU path for `sfpu_op`, then runs a caller-supplied
 * init routine.
 *
 * First calls `_llk_math_eltwise_binary_sfpu_init_<sfpu_op>()`, then invokes
 * `init_func` with `args` forwarded to it. `APPROXIMATE` is part of the
 * template signature but is not used here.
 */
template <SfpuType sfpu_op, bool APPROXIMATE, class F, class... ARGS>
inline void llk_math_eltwise_binary_sfpu_init(F&& init_func, ARGS&&... args)
{
    _llk_math_eltwise_binary_sfpu_init_<sfpu_op>();

    // static_cast<ARGS&&> forwards each argument with its original value category.
    init_func(static_cast<ARGS&&>(args)...);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include "llk_sfpu_types.h"
#include "llk_math_eltwise_binary_sfpu.h"

/**
 * Drives a binary SFPU kernel over the faces selected by `vector_mode`.
 *
 * The smaller of the two tile indices is used as the DST base address and the
 * absolute distance between the indices is passed to `sfpu_func` as its first
 * argument, followed by the forwarded `args`.
 *
 * Calls to `sfpu_func` per mode:
 *   - VectorMode::R  : 2 calls
 *   - VectorMode::C  : 2 calls
 *   - VectorMode::RC : 4 calls
 *   - anything else  : 1 call, with no register-counter advance
 *
 * NOTE(review): because the base is min(dst_index0, dst_index1), the kernel
 * sees the same (base, offset) pair regardless of argument order. Confirm that
 * callers relying on "first operand is overwritten" always pass
 * dst_index0 < dst_index1.
 *
 * NOTE(review): `args` are forwarded on every call inside the loops. This is
 * presumably only used with trivially copyable values; an rvalue consumed by
 * `sfpu_func` would be reused afterwards.
 *
 * @param sfpu_func   Callable invoked as sfpu_func(dst_offset, args...).
 * @param dst_index0  Index of one DST tile.
 * @param dst_index1  Index of the other DST tile.
 * @param vector_mode One of the VectorMode values, as int (default RC).
 * @param args        Extra arguments forwarded to `sfpu_func`.
 */
template <bool APPROXIMATE, class F, class ... ARGS>
inline void llk_math_eltwise_binary_sfpu_params(
    F&& sfpu_func,
    uint dst_index0,
    uint dst_index1,
    int vector_mode = (int)VectorMode::RC,
    ARGS&& ... args) {

    // Base = lower tile index; offset = absolute distance between the two tiles.
    uint dst_index = (dst_index0 <= dst_index1) ? dst_index0 : dst_index1;
    uint dst_offset = (dst_index0 > dst_index1) ? dst_index0 - dst_index1 : dst_index1 - dst_index0;

    math::set_dst_write_addr<DstTileLayout::Default, DstTileShape::Tile32x32>(dst_index);
    math::set_addr_mod_base();

    TTI_STALLWAIT(p_stall::STALL_SFPU, p_stall::MATH);
    if (vector_mode == (int)VectorMode::R) {
        // Do a row vector, Face0 + Face1 -- first iteration (first row)
        // NOTE(review): an unused local `ITERATIONS = 1` used to live here and
        // was never passed to sfpu_func, so the kernel runs with its own
        // iteration count. Confirm whether a single iteration was intended.
#pragma GCC unroll 0
        for (int face = 0; face < 2; face++) {
            sfpu_func(dst_offset, static_cast<ARGS&&>(args)...);
            TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
            TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
        }
        // Skip the next 2 faces
        TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
        TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
        TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
        TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
    } else if (vector_mode == (int)VectorMode::C) {
        // Do a column vector, Face0 + Face2 -- All iterations for full face
#pragma GCC unroll 0
        for (int face = 0; face < 2; face++) {
            sfpu_func(dst_offset, static_cast<ARGS&&>(args)...);
            // Four advances per call here versus two in the R/RC branches;
            // presumably this steps over the face that is not processed.
            TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
            TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
            TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
            TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
        }
    } else if (vector_mode == (int)VectorMode::RC) {
        // Do all four faces, and iterate through all 4 blocks of 4 rows each
#pragma GCC unroll 0
        for (int face = 0; face < 4; face++) {
            sfpu_func(dst_offset, static_cast<ARGS&&>(args)...);
            TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
            TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
        }
    } else {
        // Any other mode: a single kernel call with no counter advance.
        sfpu_func(dst_offset, static_cast<ARGS&&>(args)...);
    }
    math::clear_dst_reg_addr();

    TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::WAIT_SFPU);
    math::clear_addr_mod_base();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_math_eltwise_binary_sfpu_init.h"
#include "llk_math_eltwise_binary_sfpu_params.h"
// #include "ckernel_sfpu_.h"

namespace ckernel {

// New LLK SFPU APIs
//
// Thin wrappers for the int32 quantisation family (quant / requant / dequant).
// Each operation has a compute wrapper taking two DST tile indices and a
// vector mode, and an init wrapper taking a zero point.
//
// The compute wrappers pass an explicit operand of 0 as the first argument:
// llk_math_eltwise_binary_sfpu takes `operand` first, so the former
// three-argument call shifted every argument by one slot and vector_mode was
// consumed as dst_index_b.
//
// NOTE(review): the init wrappers pass `zero_point` as the only argument to
// llk_math_eltwise_binary_sfpu_init. If the overload in scope is the
// (F&& init_func, ARGS&&... args) form, `zero_point` would be treated as a
// callable. Confirm which overload is intended before instantiating these.

/** Runs the `SfpuType::quant_int32` operation on DST tiles a and b. */
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_quant_int32(
    uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) {
    llk_math_eltwise_binary_sfpu<SfpuType::quant_int32, APPROXIMATE>(
        0 /* operand */, dst_index_a, dst_index_b, vector_mode);
}

/** Initialises the `SfpuType::quant_int32` operation with `zero_point`. */
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_quant_int32_init(const uint zero_point) {
    llk_math_eltwise_binary_sfpu_init<SfpuType::quant_int32, APPROXIMATE>(zero_point);
}

/** Runs the `SfpuType::requant_int32` operation on DST tiles a and b. */
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_requant_int32(
    uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) {
    llk_math_eltwise_binary_sfpu<SfpuType::requant_int32, APPROXIMATE>(
        0 /* operand */, dst_index_a, dst_index_b, vector_mode);
}

/** Initialises the `SfpuType::requant_int32` operation with `zero_point`. */
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_requant_int32_init(const uint zero_point) {
    llk_math_eltwise_binary_sfpu_init<SfpuType::requant_int32, APPROXIMATE>(zero_point);
}

/** Runs the `SfpuType::dequant_int32` operation on DST tiles a and b. */
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_dequant_int32(
    uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) {
    llk_math_eltwise_binary_sfpu<SfpuType::dequant_int32, APPROXIMATE>(
        0 /* operand */, dst_index_a, dst_index_b, vector_mode);
}

/** Initialises the `SfpuType::dequant_int32` operation with `zero_point`. */
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_dequant_int32_init(const uint zero_point) {
    llk_math_eltwise_binary_sfpu_init<SfpuType::dequant_int32, APPROXIMATE>(zero_point);
}

} // namespace ckernel
59 changes: 59 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_binary_sfpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_binary_sfpu_binop.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif

namespace ckernel {

/**
 * Selector values for the binary SFPU operation. Each *_binary_tile function
 * below passes one of these as the `binop_mode` template argument of
 * llk_math_eltwise_binary_sfpu_binop.
 */
enum { ADD_BINARY = 0, SUB_BINARY = 1, MUL_BINARY = 2, DIV_BINARY = 3, RSUB_BINARY = 4, POW_BINARY = 5 };
/**
 * Performs an elementwise binop operation with the two inputs: y = binop(x0,x1)
 * Output overwrites first operand in DST.
 *
 * This description applies to every *_binary_tile function in this header
 * (add, sub, mul, div, rsub, power); they differ only in the selector passed
 * to the LLK call. The call is wrapped in MATH(...), which expands to its
 * argument only when TRISC_MATH is defined and to nothing otherwise.
 *
 * NOTE(review): the exact operand order for the non-commutative operations
 * (sub, div, rsub, power) is defined by the SFPU kernel, not by this header.
 * Confirm it against the kernel implementation.
 *
 * Return value: None
 *
 * | Argument       | Description                                                           | Type     | Valid Range                                           | Required |
 * |----------------|-----------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
 * | idst0          | The index of the tile in DST register buffer to use as first operand  | uint32_t | Must be less than the size of the DST register buffer | True     |
 * | idst1          | The index of the tile in DST register buffer to use as second operand | uint32_t | Must be less than the size of the DST register buffer | True     |
 */
ALWI void add_binary_tile(uint32_t idst0, uint32_t idst1) {
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, ADD_BINARY>(idst0, idst1)));
}

// Same contract as add_binary_tile, using the SUB_BINARY selector.
ALWI void sub_binary_tile(uint32_t idst0, uint32_t idst1) {
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, SUB_BINARY>(idst0, idst1)));
}

// Same contract as add_binary_tile, using the MUL_BINARY selector.
ALWI void mul_binary_tile(uint32_t idst0, uint32_t idst1) {
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, MUL_BINARY>(idst0, idst1)));
}

// Same contract as add_binary_tile, using the DIV_BINARY selector.
ALWI void div_binary_tile(uint32_t idst0, uint32_t idst1) {
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, DIV_BINARY>(idst0, idst1)));
}

// Same contract as add_binary_tile, using the RSUB_BINARY selector.
ALWI void rsub_binary_tile(uint32_t idst0, uint32_t idst1) {
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, RSUB_BINARY>(idst0, idst1)));
}

// Same contract as add_binary_tile, using the POW_BINARY selector.
ALWI void power_binary_tile(uint32_t idst0, uint32_t idst1) {
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, POW_BINARY>(idst0, idst1)));
}

/**
 * Initialises the binary SFPU path shared by all *_binary_tile functions in
 * this header by calling llk_math_eltwise_binary_sfpu_binop_init.
 * Please refer to documentation for any_init.
 */
ALWI void eltwise_binop_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init<APPROX>())); }

} // namespace ckernel

0 comments on commit 18de632

Please sign in to comment.