Skip to content

Commit

Permalink
aarch64: enable matmul bf16f32 format desc
Browse files Browse the repository at this point in the history
  • Loading branch information
aditew01 committed Nov 8, 2024
1 parent 274a419 commit 4a026ff
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions include/ideep/operators/matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,7 @@ struct matmul_forward : public dnnl::matmul,
attr_t& bias_attr = param.bias_attr;
op_attr = attr;
auto dst_data_type = data_type::f32;
auto wei_data_type = data_type::f32;

tensor::dims src_dims = src.get_dims();
tensor::dims dst_dims = {src_dims[0]};
Expand All @@ -818,23 +819,42 @@ struct matmul_forward : public dnnl::matmul,
// introduces *an extra reorder* afterwards. Here we keep the weight format
// untouched thanks to optimizations for both plain and transposed formats
// in DNNL.
IDEEP_ENFORCE(weights.get_data_type() == data_type::f32 ||
weights.get_data_type() == data_type::bf16 ||
weights.get_data_type() == data_type::f16,
"Incorrect data type in weights");
dst_data_type = src.get_data_type() == data_type::bf16 ?
data_type::bf16 :
((src.get_data_type() == data_type::f16) ?
data_type::f16 : data_type::f32);
src_desc = src.get_desc().to_type(dst_data_type);
IDEEP_ENFORCE(
weights.get_data_type() == data_type::f32 ||
weights.get_data_type() == data_type::bf16 ||
weights.get_data_type() == data_type::f16,
"Incorrect data type in weights");

#ifdef __aarch64__

dst_data_type = src.get_data_type() == data_type::bf16
? ((dst.get_data_type() == data_type::f32) ? data_type::f32
: data_type::bf16)
: ((src.get_data_type() == data_type::f16) ? data_type::f16
: data_type::f32);
src_desc = (src.get_data_type() == data_type::bf16 &&
dst_data_type == data_type::f32)
? src.get_desc()
: src.get_desc().to_type(dst_data_type);
// for aarch64 ACL backend with fixed format kernels, the weights are
// always in blocked layout, so, set the descriptor to tag::any for the backend
// to decide the format
weights_desc = tensor::desc(weights.get_dims(), dst_data_type, tag::any);
wei_data_type = (src.get_data_type() == data_type::bf16 &&
dst_data_type == data_type::f32) ? data_type::bf16 : dst_data_type;
weights_desc = tensor::desc(weights.get_dims(), wei_data_type, tag::any);
#else

dst_data_type = src.get_data_type() == data_type::bf16
? ((dst.get_data_type() == data_type::f32) ? data_type::f32
: data_type::bf16)
: ((src.get_data_type() == data_type::f16) ? data_type::f16
: data_type::f32);

src_desc = (src.get_data_type() == data_type::bf16 &&
dst_data_type == data_type::f32)
? src.get_desc()
: src.get_desc().to_type(dst_data_type);

// For fp32 matmul, weight (2nd input) is usually not in blocked layout
// Plain layout runs faster as of oneDNN 3.0
// Should use tag::any to query blocked layout if there is perf gain later
Expand Down Expand Up @@ -1529,4 +1549,4 @@ struct matmul_forward : public dnnl::matmul,

} // namespace ideep

#endif
#endif

0 comments on commit 4a026ff

Please sign in to comment.