From 5db44cd30355a849120e0633152956eb7b51b511 Mon Sep 17 00:00:00 2001 From: jiayisun Date: Fri, 2 Apr 2021 16:05:30 +0800 Subject: [PATCH] add attr: fuse_gelu fuse_elu fuse_sigmoid fuse_clamp fuse_swish --- include/ideep/attributes.hpp | 44 ++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/include/ideep/attributes.hpp b/include/ideep/attributes.hpp index 63ab78d6..f1cadefc 100644 --- a/include/ideep/attributes.hpp +++ b/include/ideep/attributes.hpp @@ -41,6 +41,50 @@ struct attr_t : public dnnl::primitive_attr { return attr; } + static attr_t fuse_gelu(float scale = 1.0, float alpha = 0.f, + float beta = 0.f) { + attr_t attr; + post_ops po; + po.append_eltwise(scale, algorithm::eltwise_gelu_erf, alpha, beta); + attr.set_post_ops(po); + return attr; + } + + static attr_t fuse_elu(float scale = 1.0, float alpha = 0.f, + float beta = 1.0) { + attr_t attr; + post_ops po; + po.append_eltwise(scale, algorithm::eltwise_elu, alpha, beta); + attr.set_post_ops(po); + return attr; + } + + static attr_t fuse_sigmoid(float scale = 1.0, float alpha = 1.0, + float beta = 0.f) { + attr_t attr; + post_ops po; + po.append_eltwise(scale, algorithm::eltwise_logistic, alpha, beta); + attr.set_post_ops(po); + return attr; + } + + static attr_t fuse_clamp(float lower_bound = -1.0, float upper_bound = 1.0) { + attr_t attr; + post_ops po; + po.append_eltwise(1.0, algorithm::eltwise_clip, lower_bound, upper_bound); + attr.set_post_ops(po); + return attr; + } + + static attr_t fuse_swish(float scale = 1.0, float alpha = 1.0, + float beta = 0.f) { + attr_t attr; + post_ops po; + po.append_eltwise(scale, algorithm::eltwise_swish, alpha, beta); + attr.set_post_ops(po); + return attr; + } + static attr_t residual(float sum_scale = 1.0, float relu_scale = 1.0, float alpha = 0.f, float beta = 0.f) { attr_t attr;