From 8f64d6764c8de1c53d923f96689ca130bae30eb8 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Mon, 25 Nov 2024 16:11:20 +0100 Subject: [PATCH] cuda with CC<70 and hip do not support 16 bit atomic. throw error for idr --- common/cuda_hip/solver/idr_kernels.cpp | 34 +++++++++++++++++++------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/common/cuda_hip/solver/idr_kernels.cpp b/common/cuda_hip/solver/idr_kernels.cpp index 0dc310ebd2e..649d8a1769c 100644 --- a/common/cuda_hip/solver/idr_kernels.cpp +++ b/common/cuda_hip/solver/idr_kernels.cpp @@ -454,11 +454,19 @@ void update_g_and_u(std::shared_ptr exec, if (nrhs > 1 || is_complex()) { components::fill_array(exec, alpha->get_values(), nrhs, zero()); - multidot_kernel<<get_stream()>>>( - size, nrhs, as_device_type(p_i), - as_device_type(g_k->get_values()), g_k->get_stride(), - as_device_type(alpha->get_values()), - stop_status->get_const_data()); + // not support 16 bit atomic +#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700)) + if constexpr (std::is_same_v, half>) { + GKO_NOT_SUPPORTED(alpha); + } else +#endif + { + multidot_kernel<<get_stream()>>>( + size, nrhs, as_device_type(p_i), + as_device_type(g_k->get_values()), g_k->get_stride(), + as_device_type(alpha->get_values()), + stop_status->get_const_data()); + } } else { blas::dot(exec->get_blas_handle(), size, p_i, 1, g_k->get_values(), g_k->get_stride(), alpha->get_values()); @@ -505,10 +513,18 @@ void update_m(std::shared_ptr exec, const size_type nrhs, auto m_i = m->get_values() + i * m_stride + k * nrhs; if (nrhs > 1 || is_complex()) { components::fill_array(exec, m_i, nrhs, zero()); - multidot_kernel<<get_stream()>>>( - size, nrhs, as_device_type(p_i), - as_device_type(g_k->get_const_values()), g_k->get_stride(), - as_device_type(m_i), stop_status->get_const_data()); + // not support 16 bit atomic +#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700)) + if constexpr (std::is_same_v, half>) { + GKO_NOT_SUPPORTED(m_i); + } else +#endif + { + multidot_kernel<<get_stream()>>>( + size, nrhs, as_device_type(p_i), + as_device_type(g_k->get_const_values()), g_k->get_stride(), + as_device_type(m_i), stop_status->get_const_data()); + } } else { blas::dot(exec->get_blas_handle(), size, p_i, 1, g_k->get_const_values(), g_k->get_stride(), m_i);