
Commit

Merge pull request #6 from ROCm/rocwmma_merge
Fix wmma api parity
Lzy17 authored Feb 19, 2024
2 parents ffb0c5d + b044010 commit 2b77380
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions csrc/kernels.hip
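The change is a mechanical rename: every call into the `wmma::` namespace (CUDA's WMMA interface) in csrc/kernels.hip now targets AMD's rocWMMA library instead. As a purely illustrative sketch, not part of this commit and assuming both toolchains are available, the same parity could also be kept behind a namespace alias so the call sites stay untouched:

// Hypothetical alternative, not from this commit: hide the rename behind a
// namespace alias keyed on the HIP platform macro, so kernel code keeps
// writing wmma:: on both back ends.
#if defined(__HIP_PLATFORM_AMD__)
  #include <rocwmma/rocwmma.hpp>   // AMD matrix-core WMMA implementation
  namespace wmma = rocwmma;
#else
  #include <mma.h>                 // CUDA tensor-core WMMA implementation
  namespace wmma = nvcuda::wmma;
#endif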
@@ -3095,10 +3095,10 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
//__shared__ T smem_C[8*32];

- wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
- wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
- wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
- wmma::fill_fragment(c_frag, 0.0f);
+ rocwmma::fragment<rocwmma::matrix_a, 8, 32, 16, half, rocwmma::row_major> a_frag;
+ rocwmma::fragment<rocwmma::matrix_b, 8, 32, 16, half, rocwmma::col_major> b_frag;
+ rocwmma::fragment<rocwmma::accumulator, 8, 32, 16, half> c_frag;
+ rocwmma::fill_fragment(c_frag, 0.0f);

int ticktock = 0;
int idx = 0 + threadIdx.x;
@@ -3253,9 +3253,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
if(warp_id == (WARPS-1))
for(int k = 0; k < batch_size_warps; k++)
{
- wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
- wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
- wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+ rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+ rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+ rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}

@@ -3267,14 +3267,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
ticktock = ticktock == 0 ? 1 : 0;
for(int k = 0; k < batch_size_warps; k++)
{
- wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
- wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
- wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+ rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+ rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+ rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

// 129 mu
if(warp_id == (WARPS-1))
- wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
+ rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major);

if(col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_A[warp_lane];
@@ -3324,10 +3324,10 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
__shared__ T smem_C[8*32];

- wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
- wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
- wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
- wmma::fill_fragment(c_frag, 0.0f);
+ rocwmma::fragment<rocwmma::matrix_a, 8, 32, 16, half, rocwmma::row_major> a_frag;
+ rocwmma::fragment<rocwmma::matrix_b, 8, 32, 16, half, rocwmma::col_major> b_frag;
+ rocwmma::fragment<rocwmma::accumulator, 8, 32, 16, half> c_frag;
+ rocwmma::fill_fragment(c_frag, 0.0f);

for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
smem_C[i] = 0.0f;
@@ -3468,9 +3468,9 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
if(warp_id == (WARPS-1))
for(int k = 0; k < batch_size_warps; k++)
{
- wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
- wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
- wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+ rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+ rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+ rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}

@@ -3489,14 +3489,14 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
{
//if(warp_lane == 0)
//printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
- wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
- wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
- wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+ rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+ rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+ rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

// 129 mu
if(warp_id == (WARPS-1))
- wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
+ rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major);

//printnonzero<T>(smem_C, 32, "");

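For reference, the rocWMMA calls introduced above mirror the CUDA WMMA API one-to-one. Below is a minimal, self-contained sketch, not taken from kernels.hip; the 16x16x16 tile shape, float accumulator, kernel name, and leading-dimension parameters are illustrative assumptions showing one wavefront running the same fragment/load/mma/store sequence on a single tile.

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <rocwmma/rocwmma.hpp>

// Hypothetical one-tile kernel: each wavefront computes a 16x16 block of C
// from one 16x16 A tile (row-major) and one 16x16 B tile (col-major).
__global__ void tile_gemm_fp16(const half* A, const half* B, float* C,
                               int lda, int ldb, int ldc)
{
    rocwmma::fragment<rocwmma::matrix_a, 16, 16, 16, half, rocwmma::row_major> a_frag;
    rocwmma::fragment<rocwmma::matrix_b, 16, 16, 16, half, rocwmma::col_major> b_frag;
    rocwmma::fragment<rocwmma::accumulator, 16, 16, 16, float> c_frag;

    rocwmma::fill_fragment(c_frag, 0.0f);                // zero the accumulator
    rocwmma::load_matrix_sync(a_frag, A, lda);           // cooperative load of the A tile
    rocwmma::load_matrix_sync(b_frag, B, ldb);           // cooperative load of the B tile
    rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);   // c_frag += a_frag * b_frag
    rocwmma::store_matrix_sync(C, c_frag, ldc, rocwmma::mem_row_major);
}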
