diff --git a/csrc/kernels.hip b/csrc/kernels.hip index f48f8b991..1f8c97e32 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -3093,9 +3093,9 @@ template __global__ void gemm_device(int M, __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; //__shared__ T smem_C[8*32]; - rocwmma::fragment a_frag; - rocwmma::fragment b_frag; - rocwmma::fragment c_frag; + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; rocwmma::fill_fragment(c_frag, 0.0f); int ticktock = 0; @@ -3272,7 +3272,7 @@ template __global__ void gemm_device(int M, // 129 mu if(warp_id == (WARPS-1)) - rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major); if(col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_A[warp_lane]; @@ -3322,9 +3322,9 @@ template __global__ void kgemm_4bit_inference(int M, i __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; __shared__ T smem_C[8*32]; - rocwmma::fragment a_frag; - rocwmma::fragment b_frag; - rocwmma::fragment c_frag; + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; rocwmma::fill_fragment(c_frag, 0.0f); for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) @@ -3494,7 +3494,7 @@ template __global__ void kgemm_4bit_inference(int M, i // 129 mu if(warp_id == (WARPS-1)) - rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); + rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major); //printnonzero(smem_C, 32, "");