From 43957e2523e9530adb2019fa730b1b43b988bf5a Mon Sep 17 00:00:00 2001
From: Kanya-Mo <167922169+Kanya-Mo@users.noreply.github.com>
Date: Wed, 27 Nov 2024 22:37:46 -0800
Subject: [PATCH] Add upsample_aa op series. (#1106)

- [x] _upsample_bicubic2d_aa
- [x] _upsample_bicubic2d_aa.out
- [x] _upsample_bicubic2d_aa_backward
- [x] _upsample_bicubic2d_aa_backward.grad_input
- [x] _upsample_bilinear2d_aa
- [x] _upsample_bilinear2d_aa.out
- [x] _upsample_bilinear2d_aa_backward
- [x] _upsample_bilinear2d_aa_backward.grad_input

---------

Co-authored-by: Yutao Xu
---
 src/ATen/native/xpu/UpSample.h             | 109 +++
 src/ATen/native/xpu/UpSampleBicubic2d.cpp  |  28 +
 src/ATen/native/xpu/UpSampleBilinear2d.cpp |  27 +
 src/ATen/native/xpu/XPUFallback.template   |   2 -
 .../xpu/sycl/UpSampleBilinear2dKernels.cpp | 627 ++++++++++++++++++
 .../xpu/sycl/UpSampleBilinear2dKernels.h   |  34 +
 test/xpu/xpu_test_utils.py                 |   1 +
 yaml/native/native_functions.yaml          |  40 ++
 8 files changed, 866 insertions(+), 2 deletions(-)

diff --git a/src/ATen/native/xpu/UpSample.h b/src/ATen/native/xpu/UpSample.h
index 447eacff2..ef9696f41 100644
--- a/src/ATen/native/xpu/UpSample.h
+++ b/src/ATen/native/xpu/UpSample.h
@@ -316,4 +316,113 @@ static void upsample_increment_value_bounded(
   return {nbatch, channels, output_width};
 }
 
+namespace upsample_antialias {
+
+// taken from
+// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
+// src/libImaging/Resample.c#L20-L29
+struct BilinearFilterFunctor {
+  template <typename accscalar_t>
+  accscalar_t operator()(accscalar_t x) const {
+    if (x < 0) {
+      x = -x;
+    }
+    if (x < 1) {
+      return 1 - x;
+    }
+    return 0;
+  }
+
+  static const int size = 2;
+};
+
+// taken from
+// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
+// src/libImaging/Resample.c#L46-L62
+struct BicubicFilterFunctor {
+  template <typename accscalar_t>
+  accscalar_t operator()(accscalar_t x) const {
+    // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+    const accscalar_t a = -0.5;
+    if (x < 0) {
+      x = -x;
+    }
+    if (x < 1) {
+      return ((a + 2) * x - (a + 3)) * x * x + 1;
+    }
+    if (x < 2) {
+      return (((x - 5) * x + 8) * x - 4) * a;
+    }
+    return 0;
+  }
+
+  static const int size = 4;
+};
+
+template <typename accscalar_t>
+static inline void _compute_weights_span(
+    const int i,
+    const int input_size,
+    const accscalar_t scale,
+    const accscalar_t support,
+    int& xmin,
+    int& xsize,
+    accscalar_t& center) {
+  center = scale * (i + static_cast<accscalar_t>(0.5));
+  xmin =
+      max(static_cast<int>(center - support + static_cast<accscalar_t>(0.5)),
+          static_cast<int>(0));
+  xsize =
+      min(static_cast<int>(center + support + static_cast<accscalar_t>(0.5)),
+          input_size) -
+      xmin;
+}
+
+template <typename scalar_t, typename accscalar_t, typename interp_filter_t>
+static inline void _compute_weights(
+    scalar_t* wt_ptr,
+    const accscalar_t scale,
+    int interp_size,
+    const interp_filter_t& interp_filter,
+    accscalar_t xmin_m_center,
+    int xsize) {
+  accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
+  accscalar_t total_w = 0.0;
+  int j = 0;
+  for (j = 0; j < xsize; j++) {
+    accscalar_t w = interp_filter(
+        (j + xmin_m_center + static_cast<accscalar_t>(0.5)) * invscale);
+    wt_ptr[j] = static_cast<scalar_t>(w);
+    total_w += w;
+  }
+  for (j = 0; j < xsize; j++) {
+    if (total_w != 0.0) {
+      wt_ptr[j] /= total_w;
+    }
+  }
+  for (; j < interp_size; j++) {
+    wt_ptr[j] = static_cast<scalar_t>(0.0);
+  }
+}
+
+template <typename scalar_t, typename accscalar_t>
+static inline accscalar_t interpolate_aa_single_dim(
+    const scalar_t* src,
+    const scalar_t* weights,
+    int size) {
+  scalar_t t = static_cast<accscalar_t>(*src);
+  scalar_t wts = static_cast<accscalar_t>(weights[0]);
+  accscalar_t output = t * wts;
+
+  int j = 1;
+  for (; j < size; j++) {
+    wts = static_cast<accscalar_t>(weights[j]);
+    t = static_cast<accscalar_t>(*(src + j));
+    output += t * wts;
+  }
+  return output;
+}
+
+} // namespace upsample_antialias
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/UpSampleBicubic2d.cpp b/src/ATen/native/xpu/UpSampleBicubic2d.cpp
index b0baf0969..7e0e4de40 100644
--- a/src/ATen/native/xpu/UpSampleBicubic2d.cpp
+++ b/src/ATen/native/xpu/UpSampleBicubic2d.cpp
@@ -2,10 +2,13 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
 
+#include
+#include
 namespace at {
 namespace native {
 TORCH_IMPL_FUNC(upsample_bicubic2d_out_xpu)
@@ -37,5 +40,30 @@ TORCH_IMPL_FUNC(upsample_bicubic2d_backward_out_xpu)
       scales_h,
       scales_w);
 }
+
+TORCH_IMPL_FUNC(_upsample_bicubic2d_aa_out_xpu)
+(const Tensor& input,
+ IntArrayRef output_size,
+ bool align_corners,
+ std::optional<double> scales_h,
+ std::optional<double> scales_w,
+ const Tensor& output) {
+  xpu::_upsample_bicubic2d_aa_out_kernel(
+      output, input, output_size, align_corners, scales_h, scales_w);
+}
+
+TORCH_IMPL_FUNC(_upsample_bicubic2d_aa_backward_out_xpu)
+(const Tensor& grad_output,
+ IntArrayRef output_size,
+ IntArrayRef input_size,
+ bool align_corners,
+ std::optional<double> scales_h,
+ std::optional<double> scales_w,
+ const Tensor& grad_input) {
+  // Nondeterministic because of atomicAdd usage
+  globalContext().alertNotDeterministic("upsample_bicubic2d_aa_backward_out_xpu");
+  xpu::_upsample_bicubic2d_aa_backward_out_kernel(
+      grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
+}
 } // namespace native
 } // namespace at
diff --git a/src/ATen/native/xpu/UpSampleBilinear2d.cpp b/src/ATen/native/xpu/UpSampleBilinear2d.cpp
index 67fed551c..ee8c37ac0 100644
--- a/src/ATen/native/xpu/UpSampleBilinear2d.cpp
+++ b/src/ATen/native/xpu/UpSampleBilinear2d.cpp
@@ -6,6 +6,8 @@
 #include
 #include
 
+#include
+#include
 
 namespace at {
 namespace native {
@@ -38,5 +40,30 @@ TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_xpu)
       scales_w);
 }
 
+TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_out_xpu)
+(const Tensor& input,
+ IntArrayRef output_size,
+ bool align_corners,
+ std::optional<double> scales_h,
+ std::optional<double> scales_w,
+ const Tensor& output) {
+  xpu::_upsample_bilinear2d_aa_out_kernel(
+      output, input, output_size, align_corners, scales_h, scales_w);
+}
+
+TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_backward_out_xpu)
+(const Tensor& grad_output,
+ IntArrayRef output_size,
+ IntArrayRef input_size,
+ bool align_corners,
+ std::optional<double> scales_h,
+ std::optional<double> scales_w,
+ const Tensor& grad_input) {
+  // Nondeterministic because of atomicAdd usage
+  globalContext().alertNotDeterministic("upsample_bilinear2d_aa_backward_out_xpu");
+  xpu::_upsample_bilinear2d_aa_backward_out_kernel(
+      grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
+}
+
 } // namespace native
 } // namespace at
diff --git
a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 10e16e2dc..8492a98be 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -189,10 +189,8 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "_thnn_fused_gru_cell", "_to_sparse_csr", "triangular_solve.X", - "_upsample_bilinear2d_aa.out", "_validate_compressed_sparse_indices", "vdot", - "_upsample_bicubic2d_aa.out", }; for (auto& op_name : fallback_list) { m.impl( diff --git a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp index e5a717495..cd52a2a4e 100644 --- a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -740,6 +741,632 @@ void upsample_bilinear2d_backward_out_kernel( }); } +template +struct UpsampleGen2dAaKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<3> item) const { + const int output_x = item.get_global_id(2); + const int output_y = item.get_global_id(1); + + const int interp_height = (int)ceilf(support_h_) * 2 + 1; + const int interp_width = (int)ceilf(support_w_) * 2 + 1; + + auto ptr = + (scalar_t*)shared_.template get_multi_ptr() + .get(); + scalar_t* wx = ptr + interp_width * item.get_local_id(2); + scalar_t* wy = ptr + interp_width * item.get_local_range(2) + + interp_height * item.get_local_id(1); + const int offset = interp_width * item.get_local_range(2) + + interp_height * item.get_local_range(1); + scalar_t* buffer2 = ptr + offset + + interp_height * + (item.get_local_id(2) + + item.get_local_id(1) * item.get_local_range(2)); + + int xmin, xsize, ymin, ysize; + accscalar_t xcenter, ycenter; + + if (output_x < output_width_ && output_y < output_height_) { + upsample_antialias::_compute_weights_span( + output_x, + input_width_, + width_scale_, + support_w_, + xmin, + xsize, + xcenter); + upsample_antialias::_compute_weights_span( + output_y, + input_height_, + height_scale_, + support_h_, + ymin, + ysize, + ycenter); + + if (item.get_local_id(1) == 0) { + // All threadIdx.y have the same wx weights + upsample_antialias::_compute_weights( + wx, + width_scale_, + interp_width, + interp_filter_, + xmin - xcenter, + xsize); + } + + if (item.get_local_id(2) == 0) { + // All threadIdx.x have the same wy weights + upsample_antialias::_compute_weights( + wy, + height_scale_, + interp_height, + interp_filter_, + ymin - ycenter, + ysize); + } + } + + item.barrier(sycl_local_fence); + + if (output_x < output_width_ && output_y < output_height_) { + const scalar_t* buffer1; + auto odata = odata_; + + // Parallelized across batch/channels + for (int i = item.get_group(0); i < batchsize_ * channels_; + i += item.get_global_range(0)) { + int n = i / channels_; + int c = i % channels_; + // interpolate on y-axis for ymin to ymin + ysize + for (int y = 0; y < ysize; y++) { + buffer1 = &(idata_[n][c][ymin + y][xmin]); + buffer2[y] = static_cast( + upsample_antialias:: + interpolate_aa_single_dim( + buffer1, wx, xsize)); + } + odata[n][c][output_y][output_x] = static_cast( + upsample_antialias:: + interpolate_aa_single_dim( + buffer2, wy, ysize)); + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = sycl_local_acc_t(local_size_, cgh); + } + + UpsampleGen2dAaKernelFunctor( + const accscalar_t height_scale, + const accscalar_t width_scale, + const PackedTensorAccessor idata, + 
PackedTensorAccessor odata, + InterpFilter interp_filter, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t batchsize, + int64_t channels, + const accscalar_t support_h, + const accscalar_t support_w, + int64_t local_size) + : height_scale_(height_scale), + width_scale_(width_scale), + idata_(idata), + odata_(odata), + interp_filter_(interp_filter), + input_height_(input_height), + input_width_(input_width), + output_height_(output_height), + output_width_(output_width), + batchsize_(batchsize), + channels_(channels), + support_h_(support_h), + support_w_(support_w), + local_size_(local_size) {} + + private: + const accscalar_t height_scale_; + const accscalar_t width_scale_; + const PackedTensorAccessor idata_; + PackedTensorAccessor odata_; + InterpFilter interp_filter_; + int64_t input_height_; + int64_t input_width_; + int64_t output_height_; + int64_t output_width_; + int64_t batchsize_; + int64_t channels_; + const accscalar_t support_h_; + const accscalar_t support_w_; + int64_t local_size_; + sycl_local_acc_t shared_; +}; + +template +struct UpsampleGen2dAaBackwardKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<3> item) const { + const int output_x = item.get_global_id(2); + const int output_y = item.get_global_id(1); + + const int interp_height = (int)ceilf(support_h_) * 2 + 1; + const int interp_width = (int)ceilf(support_w_) * 2 + 1; + + auto ptr = + (scalar_t*)shared_.template get_multi_ptr() + .get(); + scalar_t* wx = ptr + interp_width * item.get_local_id(2); + scalar_t* wy = ptr + interp_width * item.get_local_range(2) + + interp_height * item.get_local_id(1); + + int xmin, xsize, ymin, ysize; + accscalar_t xcenter, ycenter; + if (output_x < output_width_ && output_y < output_height_) { + upsample_antialias::_compute_weights_span( + output_x, + input_width_, + width_scale_, + support_w_, + xmin, + xsize, + xcenter); + upsample_antialias::_compute_weights_span( + output_y, + input_height_, + height_scale_, + support_h_, + ymin, + ysize, + ycenter); + + if (item.get_local_id(1) == 0) { + // All threadIdx.y have the same wx weights + upsample_antialias::_compute_weights( + wx, + width_scale_, + interp_width, + interp_filter_, + xmin - xcenter, + xsize); + } + + if (item.get_local_id(2) == 0) { + // All threadIdx.x have the same wy weights + upsample_antialias::_compute_weights( + wy, + height_scale_, + interp_height, + interp_filter_, + ymin - ycenter, + ysize); + } + } + + item.barrier(sycl_local_fence); + + if (output_x < output_width_ && output_y < output_height_) { + // Parallelized across batch/channels + auto idata = idata_; + for (int i = item.get_group(0); i < batchsize_ * channels_; + i += item.get_global_range(0)) { + int n = i / channels_; + int c = i % channels_; + scalar_t out_value = odata_[n][c][output_y][output_x]; + for (int y = 0; y < ysize; y++) { + for (int x = 0; x < xsize; x++) { + upsample_increment_value_bounded( + idata, + n, + c, + input_height_, + input_width_, + ymin + y, + xmin + x, + wx[x] * wy[y] * out_value); + } + } + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = sycl_local_acc_t(local_size_, cgh); + } + + UpsampleGen2dAaBackwardKernelFunctor( + const accscalar_t height_scale, + const accscalar_t width_scale, + PackedTensorAccessor idata, + const PackedTensorAccessor odata, + InterpFilter interp_filter, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t 
batchsize, + int64_t channels, + const accscalar_t support_h, + const accscalar_t support_w, + int64_t local_size) + : height_scale_(height_scale), + width_scale_(width_scale), + idata_(idata), + odata_(odata), + interp_filter_(interp_filter), + input_height_(input_height), + input_width_(input_width), + output_height_(output_height), + output_width_(output_width), + batchsize_(batchsize), + channels_(channels), + support_h_(support_h), + support_w_(support_w), + local_size_(local_size) {} + + private: + const accscalar_t height_scale_; + const accscalar_t width_scale_; + PackedTensorAccessor idata_; + const PackedTensorAccessor odata_; + InterpFilter interp_filter_; + int64_t input_height_; + int64_t input_width_; + int64_t output_height_; + int64_t output_width_; + int64_t batchsize_; + int64_t channels_; + const accscalar_t support_h_; + const accscalar_t support_w_; + int64_t local_size_; + sycl_local_acc_t shared_; +}; + +template +void launch_upsample_gen2d_aa_kernel( + const accscalar_t height_scale, + const accscalar_t width_scale, + const PackedTensorAccessor idata, + PackedTensorAccessor odata, + InterpFilter interp_filter, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + const accscalar_t support_h, + const accscalar_t support_w) { + auto queue = getCurrentSYCLQueue(); + + const int interp_height = (int)ceilf(support_h) * 2 + 1; + const int interp_width = (int)ceilf(support_w) * 2 + 1; + + auto sharedMemPerBlock = syclLocalMemSize(); + auto total_threads = syclMaxWorkItemsPerTile(); + int maxThreadsPerBlock = std::min( + syclMaxWorkGroupSize< + UpsampleGen2dAaKernelFunctor>(), + 256); // 256 performs better + int block_x = syclMaxSubGroupSize(); + + int numer = + sharedMemPerBlock * 1.0 / sizeof(scalar_t) - interp_width * block_x; + int denom = interp_height * (block_x + 1); + int block_y = lastPow2((unsigned int)(numer / denom)); + block_y = std::min(maxThreadsPerBlock / block_x, block_y); + + int grid_x = std::min( + total_threads, (output_width + block_x - 1) / block_x * block_x); + int grid_y = std::min( + total_threads / grid_x, + (output_height + block_y - 1) / block_y * block_y); + int grid_z = + std::min(total_threads / grid_x / grid_y, nbatch * channels); + + int64_t weights_per_block = interp_width * block_x + interp_height * block_y; + weights_per_block += interp_height * block_y * block_x; + int64_t shmem_size = weights_per_block * sizeof(scalar_t); + TORCH_CHECK( + shmem_size <= sharedMemPerBlock, + "Provided interpolation parameters can not be handled with current algorithm implementation. ", + "Please reduce the scale factor. 
Too much shared memory required: ", + shmem_size, + " vs ", + sharedMemPerBlock); + + UpsampleGen2dAaKernelFunctor kfn( + height_scale, + width_scale, + idata, + odata, + interp_filter, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + support_h, + support_w, + weights_per_block); + + sycl_kernel_submit( + sycl::range<3>(grid_z, grid_y, grid_x), + sycl::range<3>(1, block_y, block_x), + queue, + kfn); +} + +template +void launch_upsample_gen2d_aa_backward_kernel( + const accscalar_t height_scale, + const accscalar_t width_scale, + PackedTensorAccessor idata, + const PackedTensorAccessor odata, + InterpFilter interp_filter, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + const accscalar_t support_h, + const accscalar_t support_w) { + auto queue = getCurrentSYCLQueue(); + + auto sharedMemPerBlock = syclLocalMemSize(); + auto total_threads = syclMaxWorkItemsPerTile(); + int maxThreadsPerBlock = std::min( + syclMaxWorkGroupSize< + UpsampleGen2dAaKernelFunctor>(), + 256); // 256 performs better + int block_x = syclMaxSubGroupSize(); + int block_y = maxThreadsPerBlock / block_x; + + int grid_x = std::min( + total_threads, (output_width + block_x - 1) / block_x * block_x); + int grid_y = std::min( + total_threads / grid_x, + (output_height + block_y - 1) / block_y * block_y); + int grid_z = + std::min(total_threads / grid_x / grid_y, nbatch * channels); + + const int interp_height = (int)ceilf(support_h) * 2 + 1; + const int interp_width = (int)ceilf(support_w) * 2 + 1; + + int64_t weights_per_block = interp_width * block_x + interp_height * block_y; + int64_t shmem_size = weights_per_block * sizeof(scalar_t); + TORCH_CHECK( + shmem_size <= sharedMemPerBlock, + "Provided interpolation parameters can not be handled with current algorithm implementation. ", + "Please reduce the scale factor. Too much shared memory required: ", + shmem_size, + " vs ", + sharedMemPerBlock); + + UpsampleGen2dAaBackwardKernelFunctor kfn( + height_scale, + width_scale, + idata, + odata, + interp_filter, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + support_h, + support_w, + weights_per_block); + + sycl_kernel_submit( + sycl::range<3>(grid_z, grid_y, grid_x), + sycl::range<3>(1, block_y, block_x), + queue, + kfn); +} + +template +void upsample_gen2d_aa_out_kernel( + const Tensor& output, + const Tensor& input_, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2}; + checkAllSameGPU(__func__, {input_arg, output_arg}); + + // TODO: remove this when the kernel is updated to support the channels_last + // memory format. + auto output_c = output.is_contiguous() + ? 
output + : at::empty(output.sizes(), output.options()); + auto input = input_.contiguous(); + int output_height = output_size[0]; + int output_width = output_size[1]; + int input_height = input.size(2); + int input_width = input.size(3); + int nbatch = input.size(0); + int channels = input.size(1); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "upsample_bilinear2d_xpu", + [&] { + using accscalar_t = acc_type_device; + auto idata = input.packed_accessor64(); + auto odata = output_c.packed_accessor64(); + + const accscalar_t height_scale = area_pixel_compute_scale( + input_height, output_height, align_corners, scales_h); + const accscalar_t width_scale = area_pixel_compute_scale( + input_width, output_width, align_corners, scales_w); + + auto interp_filter = InterpFilter(); + const accscalar_t support_h = static_cast( + (height_scale >= 1.0) ? (interp_filter.size * 0.5) * height_scale + : interp_filter.size * 0.5); + const accscalar_t support_w = static_cast( + (width_scale >= 1.0) ? (interp_filter.size * 0.5) * width_scale + : interp_filter.size * 0.5); + launch_upsample_gen2d_aa_kernel( + height_scale, + width_scale, + idata, + odata, + interp_filter, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + support_h, + support_w); + }); + + if (!output.is_contiguous()) { + output.copy_(output_c); + } +} + +template +void upsample_gen2d_aa_backward_out_kernel( + const Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + TensorArg grad_input_arg{grad_input, "grad_input", 1}, + grad_output_arg{grad_output_, "grad_output_", 2}; + checkAllSameGPU( + "upsample_gen2d_backward_out_cuda", {grad_output_arg, grad_input_arg}); + + int output_height = output_size[0]; + int output_width = output_size[1]; + int input_height = input_size[2]; + int input_width = input_size[3]; + int nbatch = input_size[0]; + int channels = input_size[1]; + + Tensor grad_output = grad_output_.contiguous(); + grad_input.zero_(); + + if (grad_output.sizes() == grad_input.sizes()) { + grad_input.copy_(grad_output_); + return; + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + grad_output.scalar_type(), + "upsample_bilinear2d_xpu", + [&] { + using accscalar_t = acc_type_device; + auto idata = grad_input.packed_accessor64(); + auto odata = grad_output.packed_accessor64(); + + const accscalar_t height_scale = area_pixel_compute_scale( + input_height, output_height, align_corners, scales_h); + const accscalar_t width_scale = area_pixel_compute_scale( + input_width, output_width, align_corners, scales_w); + + auto interp_filter = InterpFilter(); + const accscalar_t support_h = static_cast( + (height_scale >= 1.0) ? (interp_filter.size * 0.5) * height_scale + : interp_filter.size * 0.5); + const accscalar_t support_w = static_cast( + (width_scale >= 1.0) ? 
(interp_filter.size * 0.5) * width_scale + : interp_filter.size * 0.5); + launch_upsample_gen2d_aa_backward_kernel( + height_scale, + width_scale, + idata, + odata, + interp_filter, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + support_h, + support_w); + }); +} + +void _upsample_bilinear2d_aa_out_kernel( + const Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + return upsample_gen2d_aa_out_kernel< + upsample_antialias::BilinearFilterFunctor>( + output, input, output_size, align_corners, scales_h, scales_w); +} + +void _upsample_bilinear2d_aa_backward_out_kernel( + const Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + return upsample_gen2d_aa_backward_out_kernel< + upsample_antialias::BilinearFilterFunctor>( + grad_input, + grad_output, + output_size, + input_size, + align_corners, + scales_h, + scales_w); +} + +void _upsample_bicubic2d_aa_out_kernel( + const Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + return upsample_gen2d_aa_out_kernel( + output, input, output_size, align_corners, scales_h, scales_w); +} + +void _upsample_bicubic2d_aa_backward_out_kernel( + const Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + return upsample_gen2d_aa_backward_out_kernel< + upsample_antialias::BicubicFilterFunctor>( + grad_input, + grad_output, + output_size, + input_size, + align_corners, + scales_h, + scales_w); +} + } // namespace at::native::xpu #pragma GCC diagnostic pop diff --git a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h index aa5ee2c09..d7ae0dcf1 100644 --- a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h +++ b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h @@ -21,4 +21,38 @@ TORCH_XPU_API void upsample_bilinear2d_backward_out_kernel( c10::optional scales_h, c10::optional scales_w); +TORCH_XPU_API void _upsample_bilinear2d_aa_out_kernel( + const Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w); + +TORCH_XPU_API void _upsample_bilinear2d_aa_backward_out_kernel( + const Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w); + +TORCH_XPU_API void _upsample_bicubic2d_aa_out_kernel( + const Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w); + +TORCH_XPU_API void _upsample_bicubic2d_aa_backward_out_kernel( + const Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w); + } // namespace at::native::xpu diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 05a5b8e73..6c31415cc 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -209,6 +209,7 @@ "nn.functional.pad", "nn.functional.interpolate", "nn.functional.upsample_bilinear", + "_upsample_bilinear2d_aa", 
"nn.functional.upsample_nearest", "nn.functional.nll_loss", "nn.functional.smooth_l1_loss", diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 999dcaf28..e3bec5484 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -4768,6 +4768,26 @@ python_module: nn structured_delegate: upsample_bicubic2d_backward.grad_input +- func: _upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + XPU: _upsample_bicubic2d_aa_out_xpu + +- func: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bicubic2d_aa.out + +- func: _upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + XPU: _upsample_bicubic2d_aa_backward_out_xpu + +- func: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bicubic2d_aa_backward.grad_input + - func: upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn structured: True @@ -4788,6 +4808,26 @@ python_module: nn structured_delegate: upsample_bilinear2d_backward.grad_input +- func: _upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + XPU: _upsample_bilinear2d_aa_out_xpu + +- func: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bilinear2d_aa.out + +- func: _upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + XPU: _upsample_bilinear2d_aa_backward_out_xpu + +- func: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bilinear2d_aa_backward.grad_input + - func: native_norm(Tensor self, Scalar p=2) -> Tensor dispatch: SparseXPU: norm_sparse