From ec32fc1c4067e6fb1f3f98a2bbccc859e796afdb Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:06:17 +0000 Subject: [PATCH 01/11] Enabled igemmlt in matmul --- bitsandbytes/autograd/_functions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index c8d50ea86..59b0ac7b2 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -224,10 +224,8 @@ def backward(ctx, grad_output): def supports_igemmlt(device: torch.device) -> bool: """check if this device supports the optimized int8 kernel""" - """Important: Could I use igemmlt on ROCm? """ if torch.version.hip: - #Well, lets currently disable it - return False + return True if torch.cuda.get_device_capability(device=device) < (7, 5): return False device_name = torch.cuda.get_device_name(device=device) From 4536b251209cbb2b6085a85ed24c8895a41ead0d Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:08:21 +0000 Subject: [PATCH 02/11] Fix shape issue in transform function --- bitsandbytes/functional.py | 4 ++-- tests/test_functional.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 2146176ed..a0f2a6c49 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -452,13 +452,13 @@ def get_transform_buffer( rows = shape[0] * shape[1] cols = shape[-1] - state = (shape, to_order) if transpose: # swap dims tmp = rows rows = cols cols = tmp - state = (shape[::-1], to_order) + shape = shape[::-1] + state = (shape, to_order) if to_order == "row" or to_order == "col": return init_func(shape, dtype=dtype, device=device), state diff --git a/tests/test_functional.py b/tests/test_functional.py index f914820fe..05b52103e 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1309,7 +1309,6 @@ def test_row_scale_bench(dim1, dim4, inner): for vals in values ] -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize( "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, From 66e34c18d5cac3b28b46f09f00da2f25d41dac7f Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:11:50 +0000 Subject: [PATCH 03/11] Enable igemmlt int8 output --- bitsandbytes/functional.py | 7 +++++-- csrc/ops.hip | 4 +++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a0f2a6c49..d19f88f83 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2524,7 +2524,10 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): def extract_outliers(A, SA, idx): shapeA = SA[0] formatA = SA[1] - assert formatA in ["col_turing", "col_ampere"] + if not HIP_ENVIRONMENT: + assert formatA in ["col_turing", "col_ampere"] + else: + assert formatA in ["col"] assert A.device.type == "cuda" out = torch.zeros( @@ -2539,7 +2542,7 @@ def extract_outliers(A, SA, idx): ptrOut = get_ptr(out) prev_device = pre_call(A.device) - if formatA == 'col_turing': + if formatA == 'col_turing' or HIP_ENVIRONMENT: lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) diff --git a/csrc/ops.hip b/csrc/ops.hip index 27e479573..8e1347840 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -594,7 +594,9 @@ template int igemmlt(hipblasLtHandl } else { - has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_32F)); + has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_8I)); + hipblasOperation_t opA = HIPBLAS_OP_N; + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA))); has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_8I, m, n, ldc)); has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); From 7e5e223118fe2ce336613630721f7663c7e1530e Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:24:06 +0000 Subject: [PATCH 04/11] Add col format for extract outliers --- csrc/kernels.hip | 13 ++++++++++++- csrc/ops.hip | 7 +++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 458f7f1c0..50e66d87d 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -2974,7 +2974,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * { int local_colidx = idx[blockIdx.x]; - if(FORMAT==COL_TURING) + /*if(FORMAT==COL_TURING) { // TURING FORMAT: // 8*32 tiles with 4*4 subtiles @@ -3030,6 +3030,17 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * int out_idx = (row*idx_size) + blockIdx.x; out[out_idx] = val; } + }*/ + + //Only col format is used on ROCm + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + //col-major offset + int offset = local_colidx * rowsA + row; + + char val = A[offset]; + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; } } diff --git a/csrc/ops.hip b/csrc/ops.hip index 8e1347840..41b29ba39 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -833,14 +833,17 @@ template void extractOutliers(char * A, int *idx, char *out, int id int num_blocks = idx_size; - if(FORMAT == COL_TURING) + /*if(FORMAT == COL_TURING) { tiledRows = fill_up_to_nearest_multiple(rows, 8); } else if(FORMAT == COL_AMPERE) { tiledRows = fill_up_to_nearest_multiple(rows, 32); - } + }*/ + + //for col format on ROCm + tiledRows = rows; hipLaunchKernelGGL(( kExtractOutliers), dim3(num_blocks), dim3(threads), 0, 0, A, idx, out, idx_size, rows, cols, tiledRows, tiledCols); CUDA_CHECK_RETURN(hipPeekAtLastError()); From 2e42adb8c7993f466d63cff3b81accf007701610 Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:25:56 +0000 Subject: [PATCH 05/11] Enable dequant_mm --- bitsandbytes/functional.py | 2 ++ csrc/kernels.hip | 57 ++++++++++++++++++++++++++++---------- csrc/ops.hip | 19 +++++++------ tests/test_functional.py | 1 - 4 files changed, 56 insertions(+), 23 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d19f88f83..8858e846a 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1952,6 +1952,8 @@ def mm_dequant( new_col_stats=None, bias=None ): + if HIP_ENVIRONMENT: + A, quant_state = nvidia_transform(A, "row", state = quant_state) assert A.dtype == torch.int32 if bias is not None: assert bias.dtype == torch.float16 out_shape = quant_state[0] diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 50e66d87d..723504fa8 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -2300,13 +2300,16 @@ template __global__ void kd const int n_out = numRows*numCols; - int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); + //int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); // we have tiles of size numRows*32, thus col only increases every numRows // num_row_tiles is the tiles after which the column increases by 32 // blockIdx.x is the index of the current tile - int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); + //int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached - int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); + //int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); + + int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; + int thread_offset = threadIdx.x * ITEMS_PER_THREAD; // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD @@ -2321,20 +2324,33 @@ template __global__ void kd int local_values[ITEMS_PER_THREAD]; half local_output[ITEMS_PER_THREAD]; - float local_rowStats[ITEMS_PER_THREAD]; - __shared__ float smem_rowStats[SUBTILE_ROWS]; + //float local_rowStats[ITEMS_PER_THREAD]; + //__shared__ float smem_rowStats[SUBTILE_ROWS]; typedef hipcub::BlockLoad LoadInt32; - typedef hipcub::BlockExchange ExchangeInt32; + //typedef hipcub::BlockExchange ExchangeInt32; __shared__ typename LoadInt32::TempStorage loadint32; - __shared__ typename ExchangeInt32::TempStorage exchangeint32; + //__shared__ typename ExchangeInt32::TempStorage exchangeint32; // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - float colStat = col >= numCols ? 0.0f : colStats[col]; - float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); + //float colStat = col >= numCols ? 0.0f : colStats[col]; + //float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); + int row_idx, col_idx; + float colStat[ITEMS_PER_THREAD]; + float local_biasValue[ITEMS_PER_THREAD]; + float rowStat[ITEMS_PER_THREAD]; + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + row_idx = (block_offset + thread_offset + j) / numCols; + col_idx = (block_offset + thread_offset + j) % numCols; + colStat[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; + local_biasValue[j] = ((bias == NULL) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); + rowStat[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + } // no block loads for rows for now -- keep it simple - for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) + /*for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) { // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? int row = (base_row+j) % numRows; // wrap around @@ -2342,12 +2358,25 @@ template __global__ void kd // todo: update description about striped shared memory, it is not needed // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements smem_rowStats[j] = rowStats[row]; - } + }*/ __syncthreads(); + int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset; + LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*rowStat[j]*colStat[j]) + local_biasValue[j]); + // each block processes SUBTILE_ROWS*32 elements - const int items_per_load = THREADS*ITEMS_PER_THREAD; + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = block_offset + thread_offset + j; + if(outIdx< n_out) + out[outIdx] = local_output[j]; + } + /*const int items_per_load = THREADS*ITEMS_PER_THREAD; const int rows_per_load = items_per_load/32; int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile @@ -2368,7 +2397,7 @@ template __global__ void kd #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; - + #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue); @@ -2388,7 +2417,7 @@ template __global__ void kd } row_offset += rows_per_load; - } + }*/ } diff --git a/csrc/ops.hip b/csrc/ops.hip index 41b29ba39..aa16e9c3f 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -664,14 +664,17 @@ int fill_up_to_nearest_multiple(int value, int multiple) void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols) { int threads = 512; - int tileCols = fill_up_to_nearest_multiple(numCols, 32); - int n = numRows*tileCols; - int subtile_rows = 128; - int tilesize = 32*subtile_rows; - int num_blocks = numRows/subtile_rows; - num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; - num_blocks = num_blocks*(tileCols/32); - assert(threads <= tilesize); + //int tileCols = fill_up_to_nearest_multiple(numCols, 32); + //int n = numRows*tileCols; + int n = numRows*numCols; + //int subtile_rows = 128; + //int tilesize = 32*subtile_rows; + //int num_blocks = numRows/subtile_rows; + //num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; + //num_blocks = num_blocks*(tileCols/32); + //assert(threads <= tilesize); + int num_blocks = numRows * numCols / (threads * 4); + num_blocks += (numRows * numCols) % (threads * 4) == 0 ? 0 : 1; hipLaunchKernelGGL(( kdequant_mm_int32_fp16<4, 128, 512>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); CUDA_CHECK_RETURN(hipPeekAtLastError()); diff --git a/tests/test_functional.py b/tests/test_functional.py index 05b52103e..6d7ace64b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -967,7 +967,6 @@ def test_bench_8bit_training(batch, seq, model, hidden): values = list(product(dim1, dim4, dims, formatB, has_bias)) names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values] -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names) def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): inner = torch.randint(1, 128, size=(1,)).item() From e32d2770bcfb2ad2e02d65d1ba81bb7bb4287799 Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:26:42 +0000 Subject: [PATCH 06/11] Enable matmullt tests --- tests/test_autograd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 4c7e2b9df..7c28dc436 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -288,7 +288,6 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): ) names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values] -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias", values, From 8206bd18cedc7555d5dd656db28c510d39b408ae Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:27:15 +0000 Subject: [PATCH 07/11] Enabled linear_serialization tests --- tests/test_linear8bitlt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 6d5fc6a82..b75fa4efd 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -68,7 +68,6 @@ def test_linear_no_igemmlt(): assert linear_custom.state.CxB is None -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt", list(product([False, True], [False, True], [False, True], [False, True]))) def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt): From 973a9f8c882612bb13935d39cf5bbfb21f17e907 Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:28:57 +0000 Subject: [PATCH 08/11] fix error with dequant_mm change --- csrc/ops.hip | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/ops.hip b/csrc/ops.hip index aa16e9c3f..cb0acf851 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -666,6 +666,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, int threads = 512; //int tileCols = fill_up_to_nearest_multiple(numCols, 32); //int n = numRows*tileCols; + int tileCols = numCols; int n = numRows*numCols; //int subtile_rows = 128; //int tilesize = 32*subtile_rows; From 387a9b79659b7ca999a12d42b366629fcdc11079 Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:36:15 +0000 Subject: [PATCH 09/11] Enable extract outliers test --- tests/test_functional.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 6d7ace64b..01a4f3f77 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2040,7 +2040,6 @@ def quant_zp(x): print(err1, err2, err3, err4, err5, err6) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_extract_outliers(): for i in range(k): shapeA = (4096, 4096 * 4) From 93dfb51a012786e33c737542336054f3fa2174ee Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:38:14 +0000 Subject: [PATCH 10/11] Enable test overflow --- tests/test_functional.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 01a4f3f77..291a2ea24 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1356,7 +1356,6 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for vals in values ] -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_overflow(): formatB = F.get_special_format_str() print(formatB) From 90bbdc609291ebf3a263134f9091be191214f0be Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 18:47:57 +0000 Subject: [PATCH 11/11] Skip overflow and linear serialization for now --- tests/test_functional.py | 1 + tests/test_linear8bitlt.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/test_functional.py b/tests/test_functional.py index 291a2ea24..01a4f3f77 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1356,6 +1356,7 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for vals in values ] +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_overflow(): formatB = F.get_special_format_str() print(formatB) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index b75fa4efd..6d5fc6a82 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -68,6 +68,7 @@ def test_linear_no_igemmlt(): assert linear_custom.state.CxB is None +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt", list(product([False, True], [False, True], [False, True], [False, True]))) def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):