Skip to content

Commit

Permalink
Add aten::kthvalue (#1091)
Browse files Browse the repository at this point in the history
- [x] kthvalue.values

---------

Co-authored-by: mayuyuace <qiming1.zhang@intel.com>
  • Loading branch information
LuFinch and mayuyuace authored Nov 20, 2024
1 parent 2ea2405 commit 97fb903
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 2 deletions.
71 changes: 71 additions & 0 deletions src/ATen/native/xpu/Sorting.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@

#include <ATen/MemoryOverlap.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/op_registration/adaption.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Sorting.h>
#include <ATen/native/TensorIterator.h>

#include <ATen/native/SortingUtils.h>
#include <ATen/native/xpu/sycl/Sorting.h>
#include <comm/TensorInfo.h>
#include <comm/xpu_aten.h>

#include <ATen/ops/full.h>
#include <ATen/ops/kthvalue_native.h>
#include <ATen/ops/where.h>

namespace at {
Expand Down Expand Up @@ -129,5 +132,73 @@ Tensor nanmedian_xpu(const Tensor& self) {
return median_impl(self, /*ignore_nan=*/true);
}

// Worker behind the XPU kthvalue out-variant.
//
// Wraps `dim_` against the rank of `self`, validates that `k` lies in
// [1, slice length], sizes `values`/`indices` for a reduction over that
// dimension, and launches the selection kernel when `self` is non-empty.
// Unless `keepdim` is set, the reduced dimension is squeezed out of both
// outputs afterwards.
//
// Returns a tuple of references to `values` and `indices`.
// Raises (via TORCH_CHECK) when `k` is out of range or `self` has more than
// XPU_MAX_TENSORINFO_DIMS dimensions.
std::tuple<Tensor&, Tensor&> kthvalue_out_impl(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim_,
    bool keepdim) {
  const int64_t wrapped_dim = maybe_wrap_dim(dim_, self.dim());

  // A 0-dim input is treated as one slice of length 1.
  int64_t slicesize = 1;
  if (self.dim() != 0) {
    slicesize = self.size(wrapped_dim);
  }
  zero_numel_check_dims(self, wrapped_dim, "kthvalue()");

  TORCH_CHECK(
      k >= 1 && k <= slicesize,
      "kthvalue(): selected number k out of range for dimension ",
      wrapped_dim);

  at::assert_no_overlap(self, values);

  _reduction_with_indices_allocate_or_resize_output(
      values, indices, self, wrapped_dim, keepdim);

  // 0-dim, single-element input: the element itself is the result and its
  // index is 0, so no kernel launch is needed.
  const bool is_single_scalar = self.dim() == 0 && self.numel() == 1;
  if (is_single_scalar) {
    values.copy_(self);
    indices.zero_();
    return std::forward_as_tuple(values, indices);
  }

  TORCH_CHECK(
      self.dim() <= XPU_MAX_TENSORINFO_DIMS,
      "cannot operate on more than ",
      XPU_MAX_TENSORINFO_DIMS,
      " dimensions");

  // Nothing to select from an empty input; the outputs keep the shape they
  // were given above.
  if (self.numel() != 0) {
    at::native::xpu::launch_kthvalue_kernel(
        values, indices, self, wrapped_dim, k);
  }

  if (keepdim) {
    return std::forward_as_tuple(values, indices);
  }

  // Drop the reduced dimension from both outputs.
  values.squeeze_(wrapped_dim);
  indices.squeeze_(wrapped_dim);
  return std::forward_as_tuple(values, indices);
}

std::tuple<Tensor&, Tensor&> kthvalue_out_xpu(
const Tensor& self,
int64_t k,
int64_t dim,
bool keepdim,
Tensor& values,
Tensor& indices) {
// See note [Writing Nondeterministic Operations]
// If there are duplicate elements of the kth value, the procedure for
// choosing which of the duplicates to use for the indices output is
// nondeterministic.
at::globalContext().alertNotDeterministic("kthvalue XPU");
auto result = [&]() {
NoNamesGuard guard;
// `kthvalue_out_impl` expects contiguous in input `self`.
return kthvalue_out_impl(
values, indices, self.contiguous(), k, dim, keepdim);
}();
namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
return result;
}

} // namespace native
} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_flash_attention_forward",
"geqrf",
"index_reduce.out",
"kthvalue.values",
"linalg_cholesky_ex.L",
"_linalg_det.result",
"linalg_eig",
Expand Down
182 changes: 182 additions & 0 deletions src/ATen/native/xpu/sycl/Sorting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,102 @@ struct GatherMedianKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
sycl_local_acc_t<index_t> num_nan_;
};

// SYCL kernel functor that, for one slice of the input per work-group,
// selects the rank-k value with radixSelect and records a position at which
// that value occurs.
//
// The slice handled by a work-group is its linear group id. Each work-item
// scans the slice starting at its local id and stepping by the local range,
// stopping at the first element equal to the selected value (NaN is treated
// as equal to NaN). Every work-item that finds a match writes the value and
// its own matching index to the slice's output slot, so with duplicates the
// stored index depends on which write lands last.
template <typename scalar_t, typename index_t, int Dim>
struct GatherKthValueKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
  void operator()(sycl::nd_item<1> item) const {
    const index_t slice = item.get_group_linear_id();

    // Translate the slice id into an element offset for each tensor.
    const index_t valuesOffset =
        IndexToOffset<scalar_t, index_t>::get(slice, values_);
    const index_t indicesOffset =
        IndexToOffset<int64_t, index_t>::get(slice, indices_);
    const index_t inputOffset =
        IndexToOffset<const scalar_t, index_t>::get(slice, input_);

    scalar_t* valueOut = values_data_ + valuesOffset;
    int64_t* indexOut = indices_data_ + indicesOffset;
    const scalar_t* sliceIn = in_data_ + inputOffset;

    // Select the rank-k value of this slice.
    // NOTE(review): the `false` template argument presumably picks the
    // ordering used by radixSelect (smallest-first for kthvalue) -- confirm
    // against radixSelect's declaration.
    scalar_t kValue = static_cast<scalar_t>(0);
    radixSelect<
        scalar_t,
        typename TopKTypeConfig<scalar_t>::RadixType,
        index_t,
        false>(
        sliceIn,
        k_,
        inputSliceSize_,
        inputWithinSliceStride_,
        smem_,
        &kValue,
        item);

    // Locate an occurrence of the selected value. The loop bound already
    // keeps `i` inside the slice, so no extra range check is needed.
    for (index_t i = item.get_local_id(0); i < inputSliceSize_;
         i += item.get_local_range(0)) {
      const scalar_t v = sliceIn[i * inputWithinSliceStride_];
      const bool bothNan = at::_isnan(v) && at::_isnan(kValue);
      if ((v == kValue) || bothNan) {
        valueOut[0] = kValue;
        indexOut[0] = i;
        break;
      }
    }
  }

  // Creates the 32-int local accessor that is passed to radixSelect.
  void sycl_ker_config_convention(sycl::handler& cgh) {
    smem_ = sycl_local_acc_t<int>(32, cgh);
  }

  GatherKthValueKernelFunctor(
      TensorInfo<scalar_t, index_t> values,
      TensorInfo<int64_t, index_t> indices,
      TensorInfo<const scalar_t, index_t> input,
      index_t inputSliceSize,
      index_t numInputSlices,
      index_t inputWithinSliceStride,
      index_t k,
      const scalar_t* in_data,
      scalar_t* values_data,
      int64_t* indices_data)
      : values_(values),
        indices_(indices),
        input_(input),
        inputSliceSize_(inputSliceSize),
        numInputSlices_(numInputSlices),
        inputWithinSliceStride_(inputWithinSliceStride),
        k_(k),
        in_data_(in_data),
        values_data_(values_data),
        indices_data_(indices_data) {}

 private:
  // Tensor descriptors used to turn a slice id into an element offset.
  TensorInfo<scalar_t, index_t> values_;
  TensorInfo<int64_t, index_t> indices_;
  TensorInfo<const scalar_t, index_t> input_;
  // Number of elements in one slice.
  index_t inputSliceSize_;
  // Stored but not read by operator().
  index_t numInputSlices_;
  // Element stride between consecutive entries of a slice.
  index_t inputWithinSliceStride_;
  // Rank forwarded to radixSelect.
  index_t k_;
  // Base pointers the per-slice offsets are applied to.
  const scalar_t* in_data_;
  scalar_t* values_data_;
  int64_t* indices_data_;
  // Scratch accessor for radixSelect; set in sycl_ker_config_convention.
  sycl_local_acc_t<int> smem_;
};

// kernel to find the median, and its index, of the values along dimension dim
template <typename scalar_t, typename index_t, int Dim>
void gatherMedian(
Expand Down Expand Up @@ -312,6 +408,39 @@ void gatherMedian(
numInputSlices * local_size, local_size, getCurrentSYCLQueue(), kfn);
}

// Finds the rank-k element, and an index of it, in every slice of `input`
// along the selected dimension.
//
// Builds a GatherKthValueKernelFunctor over the three tensor descriptors and
// submits it on the current SYCL queue with one work-group per input slice,
// each sized to the maximum work-group size reported for the kernel.
// Results are written through `kthValue` and `indices`.
template <typename scalar_t, typename index_t, int Dim>
void gatherKthValue(
    TensorInfo<const scalar_t, index_t> input,
    index_t inputSliceSize,
    index_t k,
    index_t numInputSlices,
    index_t inputWithinSliceStride,
    TensorInfo<scalar_t, index_t> kthValue,
    TensorInfo<int64_t, index_t> indices) {
  using KernelFn = GatherKthValueKernelFunctor<scalar_t, index_t, Dim>;

  // The functor takes the descriptors and, separately, their base data
  // pointers.
  KernelFn kfn(
      kthValue,
      indices,
      input,
      inputSliceSize,
      numInputSlices,
      inputWithinSliceStride,
      k,
      input.data,
      kthValue.data,
      indices.data);

  int64_t group_size = syclMaxWorkGroupSize(kfn);
  // Global range = one full work-group for each slice.
  sycl_kernel_submit(
      numInputSlices * group_size, group_size, getCurrentSYCLQueue(), kfn);
}

struct MedianLauncher {
bool ignore_nan;

Expand All @@ -338,6 +467,34 @@ struct MedianLauncher {
}
};

// Launcher object passed to run_launcher for kthvalue: carries the requested
// rank `k` and forwards the prepared tensor descriptors to gatherKthValue.
struct KthValueLauncher {
  int64_t k;

  KthValueLauncher(int64_t k) : k(k) {}

  // Forwards the tensor descriptors and slice geometry to gatherKthValue.
  // `collapse_values_dim` and `collapse_indices_dim` are accepted but not
  // used here; only `collapse_self_dim` is read, to look up the stride.
  template <typename scalar_t, typename index_t, int all_dims>
  inline void launch(
      TensorInfo<scalar_t, index_t> values_info,
      int collapse_values_dim,
      TensorInfo<int64_t, index_t> indices_info,
      int collapse_indices_dim,
      TensorInfo<const scalar_t, index_t> self_info,
      int collapse_self_dim,
      int64_t num_slices,
      int64_t slice_size) {
    // gatherKthValue takes these quantities as index_t.
    const index_t slice_len = static_cast<index_t>(slice_size);
    const index_t rank = static_cast<index_t>(k);
    const index_t slice_count = static_cast<index_t>(num_slices);

    // The actual dimension that the k-selection is running in may have
    // changed from collapseDims(), so the stride is read at the collapsed
    // position.
    const index_t within_slice_stride = self_info.strides[collapse_self_dim];

    gatherKthValue<scalar_t, index_t, all_dims>(
        self_info,
        slice_len,
        rank,
        slice_count,
        within_slice_stride,
        values_info,
        indices_info);
  }
};

void launch_median_kernel(
const TensorBase& vals,
const TensorBase& inds,
Expand All @@ -362,6 +519,31 @@ void launch_median_kernel(
});
}

// Host-side entry point for the kthvalue kernel.
//
// Dispatches on the scalar type of `self` (all standard types plus Half and
// BFloat16) and on the index width, then hands the tensors to run_launcher
// with a KthValueLauncher carrying `k`.
void launch_kthvalue_kernel(
    const TensorBase& values,
    const TensorBase& indices,
    const TensorBase& self,
    int64_t dim,
    int64_t k) {
  // 32-bit indexing is used only when all three tensors allow it.
  const bool all_fit_32bit = canUse32BitIndexMath(self) &&
      canUse32BitIndexMath(values) && canUse32BitIndexMath(indices);
  const auto index_dtype = all_fit_32bit ? ScalarType::Int : ScalarType::Long;

  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      self.scalar_type(),
      "kthvalue_xpu",
      [&] {
        AT_DISPATCH_INDEX_TYPES(index_dtype, "kth_value_launcher_xpu", [&] {
          run_launcher<scalar_t, index_t>(
              values, indices, self, dim, KthValueLauncher(k));
        });
      });
}

} // namespace at::native::xpu

#pragma GCC diagnostic pop
Expand Down
7 changes: 7 additions & 0 deletions src/ATen/native/xpu/sycl/Sorting.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,11 @@ TORCH_XPU_API void launch_median_kernel(
int64_t dim,
bool ignore_nan);

// Launches the XPU kthvalue kernel.
//
// For every slice of `self` along `dim`, selects the element of rank `k` and
// writes it to `values`, with a position of that element written to
// `indices`. The caller in Sorting.cpp validates `k` as 1-based and sizes
// both outputs before calling; this function does not resize them.
TORCH_XPU_API void launch_kthvalue_kernel(
const TensorBase& values,
const TensorBase& indices,
const TensorBase& self,
int64_t dim,
int64_t k);

} // namespace at::native::xpu
6 changes: 6 additions & 0 deletions test/xpu/extended/skip_list_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@
"test_compare_cpu_sub_xpu_float16",
# different results for value index due to unstable sort.
# XPU and CUDA have the same result.
"test_compare_cpu_kthvalue_xpu_bfloat16",
"test_compare_cpu_kthvalue_xpu_int16",
"test_compare_cpu_kthvalue_xpu_int32",
"test_compare_cpu_kthvalue_xpu_int64",
"test_compare_cpu_kthvalue_xpu_int8",
"test_compare_cpu_kthvalue_xpu_uint8",
"test_compare_cpu_median_xpu_int16",
"test_compare_cpu_median_xpu_int32",
"test_compare_cpu_median_xpu_int64",
Expand Down
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"masked_select",
"isin",
"isnan",
"kthvalue",
"lcm",
"le",
"log",
Expand Down
16 changes: 15 additions & 1 deletion yaml/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3093,7 +3093,21 @@
variants: function
dispatch:
CompositeExplicitAutograd: _foreach_copy


- func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
variants: function, method
dispatch:
CompositeExplicitAutograd: kthvalue

- func: kthvalue.values(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
dispatch:
XPU: kthvalue_out_xpu

- func: kthvalue.dimname(Tensor self, int k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
variants: function, method

- func: kthvalue.dimname_out(Tensor self, int k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)

- func: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
dispatch:
XPU: layer_norm_xpu
Expand Down

0 comments on commit 97fb903

Please sign in to comment.