Skip to content

Commit

Permalink
hipify wmma datatype
Browse files Browse the repository at this point in the history
  • Loading branch information
Lzy17 committed Feb 7, 2024
1 parent a84c369 commit b044010
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions csrc/kernels.hip
Original file line number Diff line number Diff line change
Expand Up @@ -3093,9 +3093,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
//__shared__ T smem_C[8*32];

rocwmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
rocwmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
rocwmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
rocwmma::fragment<rocwmma::matrix_a, 8, 32, 16, half, rocwmma::row_major> a_frag;
rocwmma::fragment<rocwmma::matrix_b, 8, 32, 16, half, rocwmma::col_major> b_frag;
rocwmma::fragment<rocwmma::accumulator, 8, 32, 16, half> c_frag;
rocwmma::fill_fragment(c_frag, 0.0f);

int ticktock = 0;
Expand Down Expand Up @@ -3272,7 +3272,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,

// 129 mu
if(warp_id == (WARPS-1))
rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major);

if(col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_A[warp_lane];
Expand Down Expand Up @@ -3322,9 +3322,9 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
__shared__ T smem_C[8*32];

rocwmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
rocwmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
rocwmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
rocwmma::fragment<rocwmma::matrix_a, 8, 32, 16, half, rocwmma::row_major> a_frag;
rocwmma::fragment<rocwmma::matrix_b, 8, 32, 16, half, rocwmma::col_major> b_frag;
rocwmma::fragment<rocwmma::accumulator, 8, 32, 16, half> c_frag;
rocwmma::fill_fragment(c_frag, 0.0f);

for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
Expand Down Expand Up @@ -3494,7 +3494,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i

// 129 mu
if(warp_id == (WARPS-1))
rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major);

//printnonzero<T>(smem_C, 32, "");

Expand Down

0 comments on commit b044010

Please sign in to comment.