-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#13857: Add WHB0 support for binary sfpu ops
- Loading branch information
1 parent
ec4bd12
commit 18de632
Showing
8 changed files
with
251 additions
and
66 deletions.
There are no files selected for viewing
68 changes: 3 additions & 65 deletions
68
tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_sfpu_api.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,70 +1,8 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
#include "llk_math_common_api.h" | ||
#include "llk_math_eltwise_binary_sfpu.h" | ||
|
||
/************************************************************************* | ||
* LLK ELTWISE BINARY SFPU | ||
*************************************************************************/ | ||
|
||
template <SfpuType sfpu_op, bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu( | ||
const uint operand, | ||
uint dst_index_a, | ||
uint dst_index_b, | ||
int vector_mode = (int)VectorMode::RC, | ||
uint param0 = 0, | ||
uint param1 = 0, | ||
uint param2 = 0, | ||
uint param3 = 0, | ||
uint param4 = 0, | ||
uint param5 = 0) { | ||
const std::uint32_t operand_id = get_operand_id(0); | ||
const std::uint32_t num_faces = get_operand_num_faces(operand_id); | ||
const std::uint32_t face_r_dim = get_operand_face_r_dim(operand_id); | ||
|
||
_llk_math_eltwise_binary_sfpu_<sfpu_op, APPROXIMATE, DST_SYNC_MODE>( | ||
face_r_dim, num_faces, dst_index_a, dst_index_b, vector_mode, param0, param1, param2, param3, param4, param5); | ||
} | ||
|
||
template <SfpuType sfpu_op, bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_init( | ||
uint param0 = 0, uint param1 = 0, uint param2 = 0, uint param3 = 0, uint param4 = 0, uint param5 = 0) { | ||
_llk_math_eltwise_binary_sfpu_init_<sfpu_op, APPROXIMATE>(param0, param1, param2, param3, param4, param5); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_quant_int32( | ||
uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { | ||
llk_math_eltwise_binary_sfpu<SfpuType::quant_int32, APPROXIMATE>(dst_index_a, dst_index_b, vector_mode); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_quant_int32_init(const uint zero_point) { | ||
llk_math_eltwise_binary_sfpu_init<SfpuType::quant_int32, APPROXIMATE>(zero_point); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_requant_int32( | ||
uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { | ||
llk_math_eltwise_binary_sfpu<SfpuType::requant_int32, APPROXIMATE>(dst_index_a, dst_index_b, vector_mode); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_requant_int32_init(const uint zero_point) { | ||
llk_math_eltwise_binary_sfpu_init<SfpuType::requant_int32, APPROXIMATE>(zero_point); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_dequant_int32( | ||
uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { | ||
llk_math_eltwise_binary_sfpu<SfpuType::dequant_int32, APPROXIMATE>(dst_index_a, dst_index_b, vector_mode); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_dequant_int32_init(const uint zero_point) { | ||
llk_math_eltwise_binary_sfpu_init<SfpuType::dequant_int32, APPROXIMATE>(zero_point); | ||
} | ||
#include "llk_math_eltwise_binary_sfpu_init.h" | ||
#include "llk_math_eltwise_binary_sfpu_binop.h" |
24 changes: 24 additions & 0 deletions
24
tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "ckernel.h" | ||
#include "ckernel_defs.h" | ||
#include "noc_nonblocking_api.h" | ||
#include "sfpi.h" | ||
|
||
using namespace sfpi; | ||
|
||
namespace ckernel { | ||
namespace sfpu { | ||
|
||
template <bool APPROXIMATION_MODE, int BINOP_MODE, int ITERATIONS = 8> | ||
inline void calculate_sfpu_binary(const uint dst_offset) | ||
{ | ||
_calculate_sfpu_binary_<APPROXIMATION_MODE, BINOP_MODE, ITERATIONS>(dst_offset); | ||
} | ||
|
||
} // namespace sfpu | ||
} // namespace ckernel |
29 changes: 29 additions & 0 deletions
29
tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "llk_math_eltwise_binary_sfpu_init.h" | ||
#include "llk_math_eltwise_binary_sfpu_params.h" | ||
#include "ckernel_sfpu_binary.h" | ||
|
||
namespace ckernel { | ||
|
||
// New LLK SFPU APIs | ||
|
||
template <bool APPROXIMATE, int binop_mode> | ||
inline void llk_math_eltwise_binary_sfpu_binop(uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) { | ||
llk_math_eltwise_binary_sfpu_params<APPROXIMATE>( | ||
ckernel::sfpu::calculate_sfpu_binary<APPROXIMATE, binop_mode>, | ||
dst_index0, | ||
dst_index1, | ||
vector_mode); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_binop_init() { | ||
llk_math_eltwise_binary_sfpu_init<SfpuType::unused, APPROXIMATE>(); | ||
} | ||
|
||
} // namespace ckernel |
23 changes: 23 additions & 0 deletions
23
tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_init.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "llk_sfpu_types.h" | ||
#include "llk_math_eltwise_binary_sfpu.h" | ||
|
||
namespace ckernel { | ||
|
||
template <SfpuType sfpu_op, bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_init() { | ||
_llk_math_eltwise_binary_sfpu_init_<sfpu_op>(); | ||
} | ||
|
||
template <SfpuType sfpu_op, bool APPROXIMATE, class F, class ... ARGS> | ||
inline void llk_math_eltwise_binary_sfpu_init(F&& init_func, ARGS&& ... args) { | ||
_llk_math_eltwise_binary_sfpu_init_<sfpu_op>(); | ||
init_func(static_cast<ARGS&&>(args)...); | ||
} | ||
|
||
} |
63 changes: 63 additions & 0 deletions
63
...etal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_params.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
#include "llk_sfpu_types.h" | ||
#include "llk_math_eltwise_binary_sfpu.h" | ||
|
||
template <bool APPROXIMATE, class F, class ... ARGS> | ||
inline void llk_math_eltwise_binary_sfpu_params( | ||
F&& sfpu_func, | ||
uint dst_index0, | ||
uint dst_index1, | ||
int vector_mode = (int)VectorMode::RC, | ||
ARGS&& ... args) { | ||
|
||
uint dst_index = (dst_index0 <= dst_index1) ? dst_index0 : dst_index1; | ||
uint dst_offset = (dst_index0 > dst_index1) ? dst_index0 - dst_index1 : dst_index1 - dst_index0; | ||
|
||
math::set_dst_write_addr<DstTileLayout::Default, DstTileShape::Tile32x32>(dst_index); | ||
math::set_addr_mod_base(); | ||
|
||
TTI_STALLWAIT(p_stall::STALL_SFPU, p_stall::MATH); | ||
if (vector_mode == (int)VectorMode::R) { | ||
// Do a row vector, Face0 + Face1 -- first iteration (first row) | ||
const int ITERATIONS = 1; | ||
#pragma GCC unroll 0 | ||
for (int face = 0; face < 2; face++) { | ||
sfpu_func(dst_offset, static_cast<ARGS&&>(args)...); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
} | ||
// Skip the next 2 faces | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
} else if (vector_mode == (int)VectorMode::C) { | ||
// Do a column vector, Face0 + Face2 -- All iterations for full face | ||
#pragma GCC unroll 0 | ||
for (int face = 0; face < 2; face++) { | ||
sfpu_func(dst_offset, static_cast<ARGS&&>(args)...); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
} | ||
} else if (vector_mode == (int)VectorMode::RC) { | ||
// Do all four faces, and iterate through all 4 blocks of 4 rows each | ||
#pragma GCC unroll 0 | ||
for (int face = 0; face < 4; face++) { | ||
sfpu_func(dst_offset, static_cast<ARGS&&>(args)...); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); | ||
} | ||
} else { | ||
sfpu_func(dst_offset, static_cast<ARGS&&>(args)...); | ||
} | ||
math::clear_dst_reg_addr(); | ||
|
||
TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::WAIT_SFPU); | ||
math::clear_addr_mod_base(); | ||
} |
49 changes: 49 additions & 0 deletions
49
tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_quant.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "llk_math_eltwise_binary_sfpu_init.h" | ||
#include "llk_math_eltwise_binary_sfpu_params.h" | ||
// #include "ckernel_sfpu_.h" | ||
|
||
namespace ckernel { | ||
|
||
// New LLK SFPU APIs | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_quant_int32( | ||
uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { | ||
llk_math_eltwise_binary_sfpu<SfpuType::quant_int32, APPROXIMATE>(dst_index_a, dst_index_b, vector_mode); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_quant_int32_init(const uint zero_point) { | ||
llk_math_eltwise_binary_sfpu_init<SfpuType::quant_int32, APPROXIMATE>(zero_point); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_requant_int32( | ||
uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { | ||
llk_math_eltwise_binary_sfpu<SfpuType::requant_int32, APPROXIMATE>(dst_index_a, dst_index_b, vector_mode); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_requant_int32_init(const uint zero_point) { | ||
llk_math_eltwise_binary_sfpu_init<SfpuType::requant_int32, APPROXIMATE>(zero_point); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_dequant_int32( | ||
uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { | ||
llk_math_eltwise_binary_sfpu<SfpuType::dequant_int32, APPROXIMATE>(dst_index_a, dst_index_b, vector_mode); | ||
} | ||
|
||
template <bool APPROXIMATE> | ||
inline void llk_math_eltwise_binary_sfpu_dequant_int32_init(const uint zero_point) { | ||
llk_math_eltwise_binary_sfpu_init<SfpuType::dequant_int32, APPROXIMATE>(zero_point); | ||
} | ||
|
||
|
||
} // namespace ckernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "compute_kernel_api/common_globals.h" | ||
#ifdef TRISC_MATH | ||
#include "llk_math_eltwise_binary_sfpu_binop.h" | ||
#define MAIN math_main() | ||
#define MATH(x) x | ||
#else | ||
#define MATH(x) | ||
#endif | ||
|
||
namespace ckernel { | ||
|
||
/** | ||
* Performs an elementwise binop operation with the two inputs: y = binop(x0,x1) | ||
* Output overwrites first operand in DST. | ||
* | ||
* Return value: None | ||
* | ||
* | Argument | Description | Type | Valid Range | Required | | ||
* |----------------|-----------------------------------------------------------------------|----------|-------------------------------------------------------|----------| | ||
* | idst0 | The index of the tile in DST register buffer to use as first operand | uint32_t | Must be less than the size of the DST register buffer | True | | ||
* | idst1 | The index of the tile in DST register buffer to use as second operand | uint32_t | Must be less than the size of the DST register buffer | True | | ||
*/ | ||
enum { ADD_BINARY = 0, SUB_BINARY = 1, MUL_BINARY = 2, DIV_BINARY = 3, RSUB_BINARY = 4, POW_BINARY = 5 }; | ||
ALWI void add_binary_tile(uint32_t idst0, uint32_t idst1) { | ||
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, ADD_BINARY>(idst0, idst1))); | ||
} | ||
|
||
ALWI void sub_binary_tile(uint32_t idst0, uint32_t idst1) { | ||
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, SUB_BINARY>(idst0, idst1))); | ||
} | ||
|
||
ALWI void mul_binary_tile(uint32_t idst0, uint32_t idst1) { | ||
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, MUL_BINARY>(idst0, idst1))); | ||
} | ||
|
||
ALWI void div_binary_tile(uint32_t idst0, uint32_t idst1) { | ||
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, DIV_BINARY>(idst0, idst1))); | ||
} | ||
|
||
ALWI void rsub_binary_tile(uint32_t idst0, uint32_t idst1) { | ||
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, RSUB_BINARY>(idst0, idst1))); | ||
} | ||
|
||
ALWI void power_binary_tile(uint32_t idst0, uint32_t idst1) { | ||
MATH((llk_math_eltwise_binary_sfpu_binop<APPROX, POW_BINARY>(idst0, idst1))); | ||
} | ||
|
||
/** | ||
* Please refer to documentation for any_init. | ||
*/ | ||
ALWI void eltwise_binop_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init<APPROX>())); } | ||
|
||
} // namespace ckernel |
Submodule tt_llk_wormhole_b0
updated
3 files
+1 −0 | common/inc/ckernel_sfpu.h | |
+119 −0 | common/inc/sfpu/ckernel_sfpu_binary.h | |
+1 −0 | llk_lib/llk_math_eltwise_binary_sfpu.h |