From abcbdf06c25fc2eaad61582867255c824c16ff7c Mon Sep 17 00:00:00 2001 From: Radomir Djogo Date: Tue, 12 Nov 2024 21:00:04 +0000 Subject: [PATCH] #14032: set bfp8_pack_precise based on op and dtype --- ttnn/cpp/ttnn/operations/copy.hpp | 11 ++++++----- .../eltwise/unary/device/unary_device_operation.cpp | 2 ++ .../eltwise/unary/device/unary_device_operation.hpp | 1 + .../unary/device/unary_device_operation_types.hpp | 1 + .../eltwise/unary/device/unary_program_factory.cpp | 4 ++-- .../unary/device/unary_sharded_program_factory.cpp | 1 + ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp | 3 ++- 7 files changed, 15 insertions(+), 8 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/copy.hpp b/ttnn/cpp/ttnn/operations/copy.hpp index ab75133a5f3f..9991455a4fab 100644 --- a/ttnn/cpp/ttnn/operations/copy.hpp +++ b/ttnn/cpp/ttnn/operations/copy.hpp @@ -23,18 +23,19 @@ inline Tensor copy_impl( const std::vector& op_chain, const std::optional& memory_config = std::nullopt, const std::optional& optional_output_tensor = std::nullopt) { - DataType output_dtype = (op_chain[0].op_type == ttnn::operations::unary::UnaryOpType::TYPECAST) ? static_cast(op_chain[0].params[1]) : input_tensor.get_dtype(); - bool preserve_fp32_precision = (op_chain[0].op_type == ttnn::operations::unary::UnaryOpType::TYPECAST) and (input_tensor.get_dtype() == DataType::FLOAT32); + DataType output_dtype = (op_chain[0].op_type == unary::UnaryOpType::TYPECAST) ? static_cast(op_chain[0].params[1]) : input_tensor.get_dtype(); + auto arch = input_tensor.device()->arch(); + bool preserve_fp32_precision = (arch != tt::ARCH::GRAYSKULL) and (input_tensor.get_dtype() == DataType::FLOAT32); bool fp32_dest_acc_en = preserve_fp32_precision or output_dtype == DataType::UINT32 or output_dtype == DataType::INT32 or output_dtype == DataType::FLOAT32 or input_tensor.get_dtype() == DataType::UINT32 or - input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to - // DST directly, fp32 is converted to fp16b + input_tensor.get_dtype() == DataType::INT32; + bool bfp8_pack_precise = (op_chain[0].op_type == unary::UnaryOpType::TYPECAST && output_dtype == DataType::BFLOAT8_B); auto output_memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(input_tensor.memory_config()); - return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, optional_output_tensor); + return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, bfp8_pack_precise, optional_output_tensor); } } // namespace detail diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp index 255ca4595048..179077d05073 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp @@ -192,6 +192,7 @@ UnaryDeviceOperation::invoke( const MemoryConfig& output_memory_config, bool fp32_dest_acc_en, bool preserve_fp32_precision, + bool bfp8_pack_precise, const std::optional& preallocated_output) { return { operation_attributes_t{ @@ -200,6 +201,7 @@ UnaryDeviceOperation::invoke( .output_memory_config = output_memory_config, .fp32_dest_acc_en = fp32_dest_acc_en, .preserve_fp32_precision = preserve_fp32_precision, + .bfp8_pack_precise = bfp8_pack_precise, }, tensor_args_t{.input = input, .preallocated_output = preallocated_output}}; } diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.hpp index 30cb9296c916..a8bdafcf64ba 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.hpp @@ -46,6 +46,7 @@ struct UnaryDeviceOperation { const MemoryConfig& output_memory_config, bool fp32_dest_acc_en, bool preserve_fp32_precision, + bool bfp8_pack_precise, const std::optional& preallocated_output); }; diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp index 95d100a9c852..3c9ce09fb752 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp @@ -18,6 +18,7 @@ struct operation_attributes_t { const MemoryConfig output_memory_config; const bool fp32_dest_acc_en = false; const bool preserve_fp32_precision = false; + const bool bfp8_pack_precise = false; }; struct tensor_args_t { diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_program_factory.cpp index 3c56302305c5..ab8166c1f4c8 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_program_factory.cpp @@ -101,7 +101,7 @@ UnaryProgramFactory::cached_program_t UnaryProgramFactory::create( .math_fidelity = MathFidelity::HiFi4, .fp32_dest_acc_en = args.fp32_dest_acc_en, .unpack_to_dest_mode = unpack_to_dest_mode, - .bfp8_pack_precise = true, + .bfp8_pack_precise = args.bfp8_pack_precise, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_1, .defines = unary_defines}); @@ -120,7 +120,7 @@ UnaryProgramFactory::cached_program_t UnaryProgramFactory::create( .math_fidelity = MathFidelity::HiFi4, .fp32_dest_acc_en = args.fp32_dest_acc_en, .unpack_to_dest_mode = unpack_to_dest_mode, - .bfp8_pack_precise = true, + .bfp8_pack_precise = args.bfp8_pack_precise, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_2, .defines = unary_defines}); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_sharded_program_factory.cpp index e2f771f37f66..b693504d98ad 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_sharded_program_factory.cpp @@ -130,6 +130,7 @@ UnaryShardedProgramFactory::cached_program_t UnaryShardedProgramFactory::create( .math_fidelity = MathFidelity::HiFi4, .fp32_dest_acc_en = args.fp32_dest_acc_en, .unpack_to_dest_mode = unpack_to_dest_mode, + .bfp8_pack_precise = args.bfp8_pack_precise, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_1, .defines = unary_defines}); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp index 7a40003fa527..e68ec9535d60 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp @@ -31,9 +31,10 @@ inline Tensor unary_impl( output_dtype == DataType::FLOAT32 or input_tensor.get_dtype() == DataType::UINT32 or input_tensor.get_dtype() == DataType::INT32; + bool bfp8_pack_precise = (op_chain[0].op_type == UnaryOpType::TYPECAST && output_dtype == DataType::BFLOAT8_B); auto output_memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(input_tensor.memory_config()); - return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, optional_output_tensor); + return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, bfp8_pack_precise, optional_output_tensor); } } // namespace detail