Skip to content

Commit

Permalink
#14032: add bfp8_pack_precise flag to compute config
Browse files Browse the repository at this point in the history
  • Loading branch information
rdjogoTT committed Nov 6, 2024
1 parent 5209f9d commit 73b2dd2
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 8 deletions.
1 change: 1 addition & 0 deletions tt_metal/impl/kernels/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ void ComputeKernel::set_build_options(JitBuildOptions &build_options) const {
build_options.fp32_dest_acc_en = this->config_.fp32_dest_acc_en;
build_options.dst_full_sync_en = this->config_.dst_full_sync_en;
build_options.unpack_to_dest_mode = this->config_.unpack_to_dest_mode;
build_options.bfp8_pack_precise = this->config_.bfp8_pack_precise;
}

void DataMovementKernel::generate_binaries(Device *device, JitBuildOptions &build_options) const {
Expand Down
1 change: 1 addition & 0 deletions tt_metal/impl/kernels/kernel_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct ComputeConfig {
bool fp32_dest_acc_en = false;
bool dst_full_sync_en = false;
std::vector<UnpackToDestMode> unpack_to_dest_mode;
bool bfp8_pack_precise = false;
bool math_approx_mode = false;
std::vector<uint32_t> compile_args;
// Will cause CompileProgram to emit a file hlk_defines_generated.h
Expand Down
12 changes: 7 additions & 5 deletions tt_metal/jit_build/data_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ const DataFormat get_single_pack_src_format(
DataFormat output_format,
DataFormat unpack_conditional_dst_format,
bool fp32_dest_acc_en,
bool bfp8_pack_precise,
bool int_fpu_en,
tt::ARCH arch) {

Expand Down Expand Up @@ -334,7 +335,7 @@ const DataFormat get_single_pack_src_format(
TT_FATAL(arch != tt::ARCH::GRAYSKULL, "Dest Fp32 mode is not supported for arch grayskull");

if (is_bfp_format(output_format)) {
pack_src_format = is_exp_b_format(output_format) ? DataFormat::Float16_b : DataFormat::Float16;
pack_src_format = bfp8_pack_precise ? DataFormat::Float32 : (is_exp_b_format(output_format) ? DataFormat::Bfp8_b : DataFormat::Bfp8);
} else if(is_exp_b_format(output_format) || (output_format == DataFormat::Float32)) {
pack_src_format = output_format;
} else if(output_format == DataFormat::Float16){
Expand Down Expand Up @@ -374,7 +375,7 @@ const DataFormat get_single_pack_src_format(
}
pack_src_format = unpack_conditional_dst_format;
} else if (is_bfp_format(output_format)) {
pack_src_format = is_exp_b_format(output_format) ? DataFormat::Float16_b : DataFormat::Float16;
pack_src_format = bfp8_pack_precise ? (is_exp_b_format(output_format) ? DataFormat::Float16_b : DataFormat::Float16) : (is_exp_b_format(output_format) ? DataFormat::Bfp8_b : DataFormat::Bfp8);
} else {
pack_src_format = output_format;
}
Expand All @@ -390,7 +391,7 @@ const DataFormat get_single_pack_src_format(
DataFormat pack_src_format_tmp = output_format;

if (is_bfp_format(output_format)) {
pack_src_format_tmp = is_exp_b_format(output_format) ? DataFormat::Float16_b : DataFormat::Float16;
pack_src_format_tmp = bfp8_pack_precise ? (is_exp_b_format(output_format) ? DataFormat::Float16_b : DataFormat::Float16) : (is_exp_b_format(output_format) ? DataFormat::Bfp8_b : DataFormat::Bfp8);
}

if (pack_src_format_tmp != DataFormat::Float32) {
Expand All @@ -413,6 +414,7 @@ std::vector<DataFormat> get_pack_src_formats(
DataFormat output_formats[NUM_OPERANDS],
DataFormat unpack_conditional_dst_format,
bool fp32_dest_acc_en,
bool bfp8_pack_precise,
bool int_fpu_en,
tt::ARCH arch
) {
Expand All @@ -421,14 +423,14 @@ std::vector<DataFormat> get_pack_src_formats(
std::vector<DataFormat> pack_src_formats;
DataFormat pack_src_format;
for (int i = 0; i < NUM_OPERANDS; i++) {
pack_src_format = get_single_pack_src_format(input_formats[i], pack_output_format, unpack_conditional_dst_format, fp32_dest_acc_en, int_fpu_en, arch);
pack_src_format = get_single_pack_src_format(input_formats[i], pack_output_format, unpack_conditional_dst_format, fp32_dest_acc_en, bfp8_pack_precise, int_fpu_en, arch);
pack_src_formats.push_back(pack_src_format);
}

// Intermediates
for (int i = 0; i < NUM_OPERANDS; i++) {
//Intermediates can be inputs & outputs to same op, provide same format per operand id
pack_src_format = get_single_pack_src_format(intermed_formats[i], intermed_formats[i], unpack_conditional_dst_format, fp32_dest_acc_en, int_fpu_en, arch);
pack_src_format = get_single_pack_src_format(intermed_formats[i], intermed_formats[i], unpack_conditional_dst_format, fp32_dest_acc_en, bfp8_pack_precise, int_fpu_en, arch);
pack_src_formats.push_back(pack_src_format);
}
return pack_src_formats;
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/jit_build/data_format.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ const DataFormat get_single_pack_src_format(DataFormat input_format, DataFormat

std::vector<DataFormat> get_unpack_src_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS]);
std::vector<DataFormat> get_unpack_dst_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS], DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, std::vector<UnpackToDestMode> unpack_to_dest_mode, bool int_fpu_en = false);
std::vector<DataFormat> get_pack_src_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS], DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, bool int_fpu_en = false, tt::ARCH arch = tt::ARCH::GRAYSKULL);
std::vector<DataFormat> get_pack_src_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS], DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, bool bfp8_pack_precise, bool int_fpu_en = false, tt::ARCH arch = tt::ARCH::GRAYSKULL);
std::vector<DataFormat> get_pack_dst_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS]);

}
5 changes: 3 additions & 2 deletions tt_metal/jit_build/genfiles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,14 +271,15 @@ static void emit_unpack_data_formats(
}

static std::pair<std::vector<DataFormat>, std::vector<DataFormat>> generate_pack_data_formats(
tt_hlk_desc& desc, DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, const tt::ARCH arch) {
tt_hlk_desc& desc, DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, bool bfp8_pack_precise, const tt::ARCH arch) {
vector<DataFormat> src_formats = tt::get_pack_src_formats(
desc.input_buf_dataformat_arr,
desc.param_buf_dataformat_arr,
desc.intermediate_buf_dataformat_arr,
desc.output_buf_dataformat_arr,
unpack_conditional_dst_format,
fp32_dest_acc_en,
bfp8_pack_precise,
false,
arch);

Expand Down Expand Up @@ -398,7 +399,7 @@ static void generate_data_format_descriptors(JitBuildOptions& options, const tt:

vector<DataFormat> pack_src_formats_all_cbs, pack_dst_formats_all_cbs;
tie(pack_src_formats_all_cbs, pack_dst_formats_all_cbs) =
generate_pack_data_formats(desc, unpack_conditional_dst_format, options.fp32_dest_acc_en, arch);
generate_pack_data_formats(desc, unpack_conditional_dst_format, options.fp32_dest_acc_en, options.bfp8_pack_precise, arch);

// equalize "unpack src" and "pack dst" data format vectors
// both "unpack src" and "pack dst" refer to data in L1, "unpack src" == L1, and "pack dst" == L1
Expand Down
1 change: 1 addition & 0 deletions tt_metal/jit_build/settings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class JitBuildOptions {
// We can keep for future WH support, otherwise not used in GS
bool fp32_dest_acc_en;
std::vector<UnpackToDestMode> unpack_to_dest_mode;
bool bfp8_pack_precise;

bool dst_full_sync_en;

Expand Down

0 comments on commit 73b2dd2

Please sign in to comment.