#14032: add bfp8_pack_precise flag to compute config and tests
rdjogoTT committed Nov 12, 2024
1 parent a3ac7ad commit 280abd0
Showing 8 changed files with 184 additions and 8 deletions.
168 changes: 168 additions & 0 deletions tests/ttnn/unit_tests/test_bfp8_bf16_conversion.py
@@ -0,0 +1,168 @@
import pytest
import torch

import ttnn

TILE_HEIGHT = 32
TILE_WIDTH = 32

cpu_layout = ttnn.Layout.ROW_MAJOR
npu_layout = ttnn.Layout.TILE


def test_typecast_bf16_to_bfp8_b(device):
    torch.manual_seed(0)
    shape = [32, 32]

    # bf16 --> bfp8_b by cpu.
    torch_bf16 = torch.randn(shape, dtype=torch.bfloat16)
    bfp8_b_by_cpu = ttnn.Tensor(torch_bf16, ttnn.bfloat8_b).to(npu_layout)
    cpu_version = bfp8_b_by_cpu.to(cpu_layout).to_torch()

    # bf16 --> bfp8_b by npu
    tt_bf16 = ttnn.Tensor(torch_bf16, ttnn.bfloat16).to(npu_layout).to(device)
    bfp8_b_by_npu = ttnn.typecast(tt_bf16, ttnn.bfloat8_b)
    npu_version = bfp8_b_by_npu.cpu().to(cpu_layout).to_torch()

    passed = torch.equal(cpu_version, npu_version)
    # print(cpu_version[0, 0:16])
    # print(npu_version[0, 0:16])
    assert passed


def print_mismatches(cpu, npu, num_max_print):
    different_indices = (cpu != npu).nonzero(as_tuple=True)
    count = 0
    for idx in zip(*different_indices):
        count = count + 1
        print(f"idx={idx} cpu={cpu[idx]} npu={npu[idx]}")
        if count > num_max_print:
            break


@pytest.mark.parametrize("seed", [0, 2, 4, 6, 8])
@pytest.mark.parametrize("scale", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512])
@pytest.mark.parametrize("bias", [0, 1, 2, 4, 8, 16, 32, 64, 128])
def test_typecast_bf16_to_bfp8_b_various_input(seed, scale, bias, device):
    torch.manual_seed(seed)
    shape = [1024, 1024]

    low = bias - scale
    high = bias + scale
    torch_bf16 = torch.empty(shape).uniform_(low, high).to(torch.bfloat16)

    random_signs = torch.randint(0, 2, shape) * 2 - 1
    torch_bf16 = torch_bf16 * random_signs

    # bf16 --> bfp8_b by cpu.
    bfp8_b_by_cpu = ttnn.Tensor(torch_bf16, ttnn.bfloat8_b).to(npu_layout)
    cpu_version = bfp8_b_by_cpu.to(cpu_layout).to_torch()

    # bf16 --> bfp8_b by npu
    tt_bf16 = ttnn.Tensor(torch_bf16, ttnn.bfloat16).to(npu_layout).to(device)
    bfp8_b_by_npu = ttnn.typecast(tt_bf16, ttnn.bfloat8_b)
    npu_version = bfp8_b_by_npu.cpu().to(cpu_layout).to_torch()

    passed = torch.equal(cpu_version, npu_version)
    if not passed:
        print_mismatches(cpu_version, npu_version, 16)
    assert passed


@pytest.mark.parametrize("seed", [0])
@pytest.mark.parametrize("scale", [4])
@pytest.mark.parametrize("bias", [2])
# NaN becomes -Inf when converted to bfloat8_b format, skip testing
@pytest.mark.parametrize("insert_inf, insert_nan", [[True, False]]) # , [False, True], [True, True]])
def test_typecast_bf16_to_bfp8_b_with_inf_nan(seed, scale, bias, insert_inf, insert_nan, device):
    torch.manual_seed(seed)
    shape = [1024, 1024]

    low = bias - scale
    high = bias + scale

    torch_bf16 = torch.empty(shape).uniform_(low, high).to(torch.bfloat16)
    if insert_inf:
        # bfp8_b packs 16 elements into one shared-exponent block; set ~1/8 of all elements to inf
        num_inf = torch_bf16.numel() // 8
        inf_indices = torch.randint(0, torch_bf16.numel(), (num_inf,))
        torch_bf16.view(-1)[inf_indices] = float("inf")
    if insert_nan:
        num_nan = torch_bf16.numel() // 8
        nan_indices = torch.randint(0, torch_bf16.numel(), (num_nan,))
        torch_bf16.view(-1)[nan_indices] = float("nan")
    random_signs = torch.randint(0, 2, shape) * 2 - 1
    torch_bf16 = torch_bf16 * random_signs

    # bf16 --> bfp8_b by cpu.
    bfp8_b_by_cpu = ttnn.Tensor(torch_bf16, ttnn.bfloat8_b).to(npu_layout)
    cpu_version = bfp8_b_by_cpu.to(cpu_layout).to_torch()

    # bf16 --> bfp8_b by npu
    tt_bf16 = ttnn.Tensor(torch_bf16, ttnn.bfloat16).to(npu_layout).to(device)
    bfp8_b_by_npu = ttnn.typecast(tt_bf16, ttnn.bfloat8_b)
    npu_version = bfp8_b_by_npu.cpu().to(cpu_layout).to_torch()

    passed = torch.equal(cpu_version, npu_version)
    if not passed:
        print_mismatches(cpu_version, npu_version, 16)
    assert passed


def test_typecast_bfp8_b_to_bf16(device):
    torch.manual_seed(0)
    shape = [1024, 1024]

    # bfp8_b --> bf16 by cpu.
    torch_bf16 = torch.randn(shape, dtype=torch.bfloat16)
    bfp8_b = ttnn.Tensor(torch_bf16, ttnn.bfloat8_b).to(npu_layout)
    cpu_version = bfp8_b.to(cpu_layout).to_torch()

    # bfp8_b --> bf16 by npu.
    bf16_by_npu = ttnn.typecast(bfp8_b.to(device), ttnn.bfloat16)
    npu_version = bf16_by_npu.cpu().to(cpu_layout).to_torch()

    passed = torch.equal(cpu_version, npu_version)
    # print(cpu_version[0, 0:16])
    # print(npu_version[0, 0:16])
    assert passed


def test_typecast_fp32_to_bfp8_b(device):
    torch.manual_seed(0)
    shape = [32, 32]

    # fp32 --> bfp8_b by cpu.
    torch_fp32 = torch.randn(shape, dtype=torch.float32)
    bfp8_b_by_cpu = ttnn.Tensor(torch_fp32, ttnn.bfloat8_b).to(npu_layout)
    cpu_version = bfp8_b_by_cpu.to(cpu_layout).to_torch()

    # fp32 --> bfp8_b by npu
    tt_fp32 = ttnn.Tensor(torch_fp32, ttnn.float32).to(npu_layout).to(device)
    bfp8_b_by_npu = ttnn.typecast(tt_fp32, ttnn.bfloat8_b)
    npu_version = bfp8_b_by_npu.cpu().to(cpu_layout).to_torch()

    passed = torch.equal(cpu_version, npu_version)
    # print(cpu_version[0, 0:16])
    # print(npu_version[0, 0:16])
    assert passed


def test_typecast_bfp8_b_to_fp32(device):
    torch.manual_seed(0)
    shape = [1024, 1024]

    # bfp8_b --> fp32 by cpu.
    torch_fp32 = torch.randn(shape, dtype=torch.float32)
    bfp8_b = ttnn.Tensor(torch_fp32, ttnn.bfloat8_b).to(npu_layout)
    cpu_version = bfp8_b.to(cpu_layout).to_torch()

    # bfp8_b --> fp32 by npu.
    fp32_by_npu = ttnn.typecast(bfp8_b.to(device), ttnn.float32)
    npu_version = fp32_by_npu.cpu().to(cpu_layout).to_torch()

    passed = torch.equal(cpu_version, npu_version)
    # print(cpu_version[0, 0:16])
    # print(npu_version[0, 0:16])
    assert passed
1 change: 1 addition & 0 deletions tt_metal/impl/kernels/kernel.cpp
@@ -316,6 +316,7 @@ void ComputeKernel::set_build_options(JitBuildOptions &build_options) const {
    build_options.fp32_dest_acc_en = this->config_.fp32_dest_acc_en;
    build_options.dst_full_sync_en = this->config_.dst_full_sync_en;
    build_options.unpack_to_dest_mode = this->config_.unpack_to_dest_mode;
    build_options.bfp8_pack_precise = this->config_.bfp8_pack_precise;
}

void DataMovementKernel::generate_binaries(Device *device, JitBuildOptions &build_options) const {
1 change: 1 addition & 0 deletions tt_metal/impl/kernels/kernel_types.hpp
@@ -53,6 +53,7 @@ struct ComputeConfig {
    bool fp32_dest_acc_en = false;
    bool dst_full_sync_en = false;
    std::vector<UnpackToDestMode> unpack_to_dest_mode;
    bool bfp8_pack_precise = false;
    bool math_approx_mode = false;
    std::vector<uint32_t> compile_args;
    // Will cause CompileProgram to emit a file hlk_defines_generated.h
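
The new ComputeConfig field defaults to false, so existing kernels keep today's packing behaviour. As a usage reference, here is a minimal host-side sketch that opts a single compute kernel into precise bfp8 packing; the kernel source path and core coordinate are placeholders, and the surrounding program/device setup is assumed (the eltwise unary factory change at the bottom of this commit shows the same pattern in-tree):

// Sketch only: enable precise bfp8 packing for one compute kernel.
// The kernel path and core coordinate are hypothetical, not part of this commit.
#include "tt_metal/host_api.hpp"

using namespace tt::tt_metal;

KernelHandle create_precise_pack_kernel(Program& program) {
    return CreateKernel(
        program,
        "my_kernels/compute/my_compute_kernel.cpp",  // placeholder kernel source
        CoreCoord{0, 0},                             // single core, for illustration
        ComputeConfig{
            .math_fidelity = MathFidelity::HiFi4,
            .fp32_dest_acc_en = false,
            .bfp8_pack_precise = true,  // new flag from this commit; defaults to false
            .math_approx_mode = false,
        });
}
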
12 changes: 7 additions & 5 deletions tt_metal/jit_build/data_format.cpp
@@ -303,6 +303,7 @@ const DataFormat get_single_pack_src_format(
    DataFormat output_format,
    DataFormat unpack_conditional_dst_format,
    bool fp32_dest_acc_en,
    bool bfp8_pack_precise,
    bool int_fpu_en,
    tt::ARCH arch) {

@@ -334,7 +335,7 @@
        TT_FATAL(arch != tt::ARCH::GRAYSKULL, "Dest Fp32 mode is not supported for arch grayskull");

        if (is_bfp_format(output_format)) {
            pack_src_format = is_exp_b_format(output_format) ? DataFormat::Float16_b : DataFormat::Float16;
            pack_src_format = bfp8_pack_precise ? DataFormat::Float32 : (is_exp_b_format(output_format) ? DataFormat::Bfp8_b : DataFormat::Bfp8);
        } else if(is_exp_b_format(output_format) || (output_format == DataFormat::Float32)) {
            pack_src_format = output_format;
        } else if(output_format == DataFormat::Float16){
@@ -374,7 +375,7 @@
        }
        pack_src_format = unpack_conditional_dst_format;
    } else if (is_bfp_format(output_format)) {
        pack_src_format = is_exp_b_format(output_format) ? DataFormat::Float16_b : DataFormat::Float16;
        pack_src_format = bfp8_pack_precise ? (is_exp_b_format(output_format) ? DataFormat::Float16_b : DataFormat::Float16) : (is_exp_b_format(output_format) ? DataFormat::Bfp8_b : DataFormat::Bfp8);
    } else {
        pack_src_format = output_format;
    }
@@ -390,7 +391,7 @@
    DataFormat pack_src_format_tmp = output_format;

    if (is_bfp_format(output_format)) {
        pack_src_format_tmp = is_exp_b_format(output_format) ? DataFormat::Float16_b : DataFormat::Float16;
        pack_src_format_tmp = bfp8_pack_precise ? (is_exp_b_format(output_format) ? DataFormat::Float16_b : DataFormat::Float16) : (is_exp_b_format(output_format) ? DataFormat::Bfp8_b : DataFormat::Bfp8);
    }

    if (pack_src_format_tmp != DataFormat::Float32) {
@@ -413,6 +414,7 @@ std::vector<DataFormat> get_pack_src_formats(
    DataFormat output_formats[NUM_OPERANDS],
    DataFormat unpack_conditional_dst_format,
    bool fp32_dest_acc_en,
    bool bfp8_pack_precise,
    bool int_fpu_en,
    tt::ARCH arch
) {
@@ -421,14 +423,14 @@
    std::vector<DataFormat> pack_src_formats;
    DataFormat pack_src_format;
    for (int i = 0; i < NUM_OPERANDS; i++) {
        pack_src_format = get_single_pack_src_format(input_formats[i], pack_output_format, unpack_conditional_dst_format, fp32_dest_acc_en, int_fpu_en, arch);
        pack_src_format = get_single_pack_src_format(input_formats[i], pack_output_format, unpack_conditional_dst_format, fp32_dest_acc_en, bfp8_pack_precise, int_fpu_en, arch);
        pack_src_formats.push_back(pack_src_format);
    }

    // Intermediates
    for (int i = 0; i < NUM_OPERANDS; i++) {
        //Intermediates can be inputs & outputs to same op, provide same format per operand id
        pack_src_format = get_single_pack_src_format(intermed_formats[i], intermed_formats[i], unpack_conditional_dst_format, fp32_dest_acc_en, int_fpu_en, arch);
        pack_src_format = get_single_pack_src_format(intermed_formats[i], intermed_formats[i], unpack_conditional_dst_format, fp32_dest_acc_en, bfp8_pack_precise, int_fpu_en, arch);
        pack_src_formats.push_back(pack_src_format);
    }
    return pack_src_formats;
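
Condensed, the branches above pick the packer source format for bfp8 outputs as follows: with bfp8_pack_precise set, the pack source stays in a full-width float format (Float32 when fp32_dest_acc_en is set, otherwise Float16_b/Float16) so the conversion to the block-float output format is left to the pack stage; without it, the source format is already Bfp8_b/Bfp8. The helper below is a standalone illustration of that selection only; the enum and function names are local to the sketch, not tt-metal types:

// Illustration only: mirrors the bfp8 branch of the selection logic above.
// The enum and helper are local to this sketch, not tt-metal types.
#include <cstdio>

enum class Fmt { Float32, Float16, Float16_b, Bfp8, Bfp8_b };

Fmt bfp8_pack_src_format(bool exp_b, bool fp32_dest_acc_en, bool bfp8_pack_precise) {
    if (bfp8_pack_precise) {
        // Precise path: keep a wide float source for the packer.
        return fp32_dest_acc_en ? Fmt::Float32 : (exp_b ? Fmt::Float16_b : Fmt::Float16);
    }
    // Default path: pack directly from the block-float format.
    return exp_b ? Fmt::Bfp8_b : Fmt::Bfp8;
}

int main() {
    // bfloat8_b output, fp32 dest accumulation off, precise packing on -> Float16_b (index 2).
    std::printf("%d\n", static_cast<int>(bfp8_pack_src_format(true, false, true)));
    return 0;
}
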
2 changes: 1 addition & 1 deletion tt_metal/jit_build/data_format.hpp
@@ -62,7 +62,7 @@ const DataFormat get_single_pack_src_format(DataFormat input_format, DataFormat

std::vector<DataFormat> get_unpack_src_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS]);
std::vector<DataFormat> get_unpack_dst_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS], DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, std::vector<UnpackToDestMode> unpack_to_dest_mode, bool int_fpu_en = false);
std::vector<DataFormat> get_pack_src_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS], DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, bool int_fpu_en = false, tt::ARCH arch = tt::ARCH::GRAYSKULL);
std::vector<DataFormat> get_pack_src_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS], DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, bool bfp8_pack_precise, bool int_fpu_en = false, tt::ARCH arch = tt::ARCH::GRAYSKULL);
std::vector<DataFormat> get_pack_dst_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS]);

}
5 changes: 3 additions & 2 deletions tt_metal/jit_build/genfiles.cpp
@@ -272,14 +272,15 @@ static void emit_unpack_data_formats(
}

static std::pair<std::vector<DataFormat>, std::vector<DataFormat>> generate_pack_data_formats(
    tt_hlk_desc& desc, DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, const tt::ARCH arch) {
    tt_hlk_desc& desc, DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, bool bfp8_pack_precise, const tt::ARCH arch) {
    vector<DataFormat> src_formats = tt::get_pack_src_formats(
        desc.input_buf_dataformat_arr,
        desc.param_buf_dataformat_arr,
        desc.intermediate_buf_dataformat_arr,
        desc.output_buf_dataformat_arr,
        unpack_conditional_dst_format,
        fp32_dest_acc_en,
        bfp8_pack_precise,
        false,
        arch);

@@ -399,7 +400,7 @@ static void generate_data_format_descriptors(JitBuildOptions& options, const tt:

    vector<DataFormat> pack_src_formats_all_cbs, pack_dst_formats_all_cbs;
    tie(pack_src_formats_all_cbs, pack_dst_formats_all_cbs) =
        generate_pack_data_formats(desc, unpack_conditional_dst_format, options.fp32_dest_acc_en, arch);
        generate_pack_data_formats(desc, unpack_conditional_dst_format, options.fp32_dest_acc_en, options.bfp8_pack_precise, arch);

    // equalize "upack src" and "pack dst" data format vectors
    // both "unpack src" and "pack dst" refer to data in L1, "unpack src" == L1, and "pack dst" == L1
1 change: 1 addition & 0 deletions tt_metal/jit_build/settings.hpp
@@ -26,6 +26,7 @@ class JitBuildOptions {
    // We can keep for future WH support, otherwise not used in GS
    bool fp32_dest_acc_en;
    std::vector<UnpackToDestMode> unpack_to_dest_mode;
    bool bfp8_pack_precise;

    bool dst_full_sync_en;

@@ -101,6 +101,7 @@ UnaryProgramFactory::cached_program_t UnaryProgramFactory::create(
            .math_fidelity = MathFidelity::HiFi4,
            .fp32_dest_acc_en = args.fp32_dest_acc_en,
            .unpack_to_dest_mode = unpack_to_dest_mode,
            .bfp8_pack_precise = true,
            .math_approx_mode = math_approx_mode,
            .compile_args = compute_kernel_args_group_1,
            .defines = unary_defines});
@@ -119,6 +120,7 @@
            .math_fidelity = MathFidelity::HiFi4,
            .fp32_dest_acc_en = args.fp32_dest_acc_en,
            .unpack_to_dest_mode = unpack_to_dest_mode,
            .bfp8_pack_precise = true,
            .math_approx_mode = math_approx_mode,
            .compile_args = compute_kernel_args_group_2,
            .defines = unary_defines});
