Skip to content

Commit

Permalink
cuda with CC<70 and hip do not support 16 bit atomic. throw error for…
Browse files Browse the repository at this point in the history
… idr
  • Loading branch information
yhmtsai committed Nov 28, 2024
1 parent ab82457 commit 8f64d67
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions common/cuda_hip/solver/idr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,11 +454,19 @@ void update_g_and_u(std::shared_ptr<const DefaultExecutor> exec,
if (nrhs > 1 || is_complex<ValueType>()) {
components::fill_array(exec, alpha->get_values(), nrhs,
zero<ValueType>());
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
size, nrhs, as_device_type(p_i),
as_device_type(g_k->get_values()), g_k->get_stride(),
as_device_type(alpha->get_values()),
stop_status->get_const_data());
// not support 16 bit atomic
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
GKO_NOT_SUPPORTED(alpha);
} else
#endif
{
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
size, nrhs, as_device_type(p_i),
as_device_type(g_k->get_values()), g_k->get_stride(),
as_device_type(alpha->get_values()),
stop_status->get_const_data());
}
} else {
blas::dot(exec->get_blas_handle(), size, p_i, 1, g_k->get_values(),
g_k->get_stride(), alpha->get_values());
Expand Down Expand Up @@ -505,10 +513,18 @@ void update_m(std::shared_ptr<const DefaultExecutor> exec, const size_type nrhs,
auto m_i = m->get_values() + i * m_stride + k * nrhs;
if (nrhs > 1 || is_complex<ValueType>()) {
components::fill_array(exec, m_i, nrhs, zero<ValueType>());
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
size, nrhs, as_device_type(p_i),
as_device_type(g_k->get_const_values()), g_k->get_stride(),
as_device_type(m_i), stop_status->get_const_data());
// not support 16 bit atomic
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
GKO_NOT_SUPPORTED(m_i);
} else
#endif
{
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
size, nrhs, as_device_type(p_i),
as_device_type(g_k->get_const_values()), g_k->get_stride(),
as_device_type(m_i), stop_status->get_const_data());
}
} else {
blas::dot(exec->get_blas_handle(), size, p_i, 1,
g_k->get_const_values(), g_k->get_stride(), m_i);
Expand Down

0 comments on commit 8f64d67

Please sign in to comment.