diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index c8d50ea86..59b0ac7b2 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -224,10 +224,8 @@ def backward(ctx, grad_output):
 
 def supports_igemmlt(device: torch.device) -> bool:
     """check if this device supports the optimized int8 kernel"""
-    """Important: Could I use igemmlt on ROCm? """
     if torch.version.hip:
-        #Well, lets currently disable it
-        return False
+        return True
     if torch.cuda.get_device_capability(device=device) < (7, 5):
         return False
     device_name = torch.cuda.get_device_name(device=device)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 2146176ed..8858e846a 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -452,13 +452,13 @@ def get_transform_buffer(
         rows = shape[0] * shape[1]
     cols = shape[-1]
 
-    state = (shape, to_order)
     if transpose:
         # swap dims
         tmp = rows
         rows = cols
         cols = tmp
-        state = (shape[::-1], to_order)
+        shape = shape[::-1]
+    state = (shape, to_order)
 
     if to_order == "row" or to_order == "col":
         return init_func(shape, dtype=dtype, device=device), state
@@ -1952,6 +1952,8 @@ def mm_dequant(
     new_col_stats=None,
     bias=None
 ):
+    if HIP_ENVIRONMENT:
+        A, quant_state = nvidia_transform(A, "row", state = quant_state)
     assert A.dtype == torch.int32
     if bias is not None: assert bias.dtype == torch.float16
     out_shape = quant_state[0]
@@ -2524,7 +2526,10 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
 def extract_outliers(A, SA, idx):
     shapeA = SA[0]
     formatA = SA[1]
-    assert formatA in ["col_turing", "col_ampere"]
+    if not HIP_ENVIRONMENT:
+        assert formatA in ["col_turing", "col_ampere"]
+    else:
+        assert formatA in ["col"]
     assert A.device.type == "cuda"
 
     out = torch.zeros(
@@ -2539,7 +2544,7 @@ def extract_outliers(A, SA, idx):
     ptrOut = get_ptr(out)
 
     prev_device = pre_call(A.device)
-    if formatA == 'col_turing':
+    if formatA == 'col_turing' or HIP_ENVIRONMENT:
         lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
     elif formatA == "col_ampere":
         lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
diff --git a/csrc/kernels.hip b/csrc/kernels.hip
index 458f7f1c0..723504fa8 100644
--- a/csrc/kernels.hip
+++ b/csrc/kernels.hip
@@ -2300,13 +2300,16 @@ template __global__ void kd
 
   const int n_out = numRows*numCols;
 
-  int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
+  //int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
   // we have tiles of size numRows*32, thus col only increases every numRows
   // num_row_tiles is the tiles after which the column increases by 32
   // blockIdx.x is the index of the current tile
-  int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
+  //int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
   // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
-  int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
+  //int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
+
+  int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD;
+  int thread_offset = threadIdx.x * ITEMS_PER_THREAD;
 
   // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS
   // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
@@ -2321,20 +2324,33 @@ template __global__ void kd
 
   int local_values[ITEMS_PER_THREAD];
   half local_output[ITEMS_PER_THREAD];
-  float local_rowStats[ITEMS_PER_THREAD];
-  __shared__ float smem_rowStats[SUBTILE_ROWS];
+  //float local_rowStats[ITEMS_PER_THREAD];
+  //__shared__ float smem_rowStats[SUBTILE_ROWS];
 
   typedef hipcub::BlockLoad LoadInt32;
-  typedef hipcub::BlockExchange ExchangeInt32;
+  //typedef hipcub::BlockExchange ExchangeInt32;
   __shared__ typename LoadInt32::TempStorage loadint32;
-  __shared__ typename ExchangeInt32::TempStorage exchangeint32;
+  //__shared__ typename ExchangeInt32::TempStorage exchangeint32;
 
   // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
-  float colStat = col >= numCols ? 0.0f : colStats[col];
-  float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);
+  //float colStat = col >= numCols ? 0.0f : colStats[col];
+  //float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);
+  int row_idx, col_idx;
+  float colStat[ITEMS_PER_THREAD];
+  float local_biasValue[ITEMS_PER_THREAD];
+  float rowStat[ITEMS_PER_THREAD];
+  #pragma unroll ITEMS_PER_THREAD
+  for(int j = 0; j < ITEMS_PER_THREAD; j++)
+  {
+    row_idx = (block_offset + thread_offset + j) / numCols;
+    col_idx = (block_offset + thread_offset + j) % numCols;
+    colStat[j] = col_idx >= numCols ? 0.0f : colStats[col_idx];
+    local_biasValue[j] = ((bias == NULL) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]);
+    rowStat[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx];
+  }
   // no block loads for rows for now -- keep it simple
-  for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
+  /*for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
   {
     // todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
     int row = (base_row+j) % numRows; // wrap around
@@ -2342,12 +2358,25 @@ template __global__ void kd
     // todo: update description about striped shared memory, it is not needed
     // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements
     smem_rowStats[j] = rowStats[row];
-  }
+  }*/
   __syncthreads();
 
+  int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset;
+  LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0);
+  #pragma unroll ITEMS_PER_THREAD
+  for(int j = 0; j < ITEMS_PER_THREAD; j++)
+    local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*rowStat[j]*colStat[j]) + local_biasValue[j]);
 
   // each block processes SUBTILE_ROWS*32 elements
-  const int items_per_load = THREADS*ITEMS_PER_THREAD;
+  #pragma unroll ITEMS_PER_THREAD
+  for(int j = 0; j < ITEMS_PER_THREAD; j++)
+  {
+    int outIdx = block_offset + thread_offset + j;
+    if(outIdx< n_out)
+      out[outIdx] = local_output[j];
+  }
+
+  /*const int items_per_load = THREADS*ITEMS_PER_THREAD;
   const int rows_per_load = items_per_load/32;
 
   int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
@@ -2368,7 +2397,7 @@ template __global__ void kd
     #pragma unroll ITEMS_PER_THREAD
     for(int j = 0; j < ITEMS_PER_THREAD; j++)
       local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];
-
+    
     #pragma unroll ITEMS_PER_THREAD
     for(int j = 0; j < ITEMS_PER_THREAD; j++)
       local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue);
@@ -2388,7 +2417,7 @@ template __global__ void kd
     }
 
     row_offset += rows_per_load;
-  }
+  }*/
 }
 
 
@@ -2974,7 +3003,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *
 {
   int local_colidx = idx[blockIdx.x];
 
-  if(FORMAT==COL_TURING)
+  /*if(FORMAT==COL_TURING)
   {
     // TURING FORMAT:
     // 8*32 tiles with 4*4 subtiles
@@ -3030,6 +3059,17 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *
       int out_idx = (row*idx_size) + blockIdx.x;
       out[out_idx] = val;
     }
+  }*/
+
+  //Only col format is used on ROCm
+  for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
+  {
+    //col-major offset
+    int offset = local_colidx * rowsA + row;
+
+    char val = A[offset];
+    int out_idx = (row*idx_size) + blockIdx.x;
+    out[out_idx] = val;
   }
 }
 
diff --git a/csrc/ops.hip b/csrc/ops.hip
index 27e479573..cb0acf851 100644
--- a/csrc/ops.hip
+++ b/csrc/ops.hip
@@ -594,7 +594,9 @@ template int igemmlt(hipblasLtHandl
     }
     else
     {
-      has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_32F));
+      has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_8I));
+      hipblasOperation_t opA = HIPBLAS_OP_N;
+      has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA)));
       has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
       has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_8I, m, n, ldc));
       has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
@@ -662,14 +664,18 @@ int fill_up_to_nearest_multiple(int value, int multiple)
 void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols)
 {
   int threads = 512;
-  int tileCols = fill_up_to_nearest_multiple(numCols, 32);
-  int n = numRows*tileCols;
-  int subtile_rows = 128;
-  int tilesize = 32*subtile_rows;
-  int num_blocks = numRows/subtile_rows;
-  num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
-  num_blocks = num_blocks*(tileCols/32);
-  assert(threads <= tilesize);
+  //int tileCols = fill_up_to_nearest_multiple(numCols, 32);
+  //int n = numRows*tileCols;
+  int tileCols = numCols;
+  int n = numRows*numCols;
+  //int subtile_rows = 128;
+  //int tilesize = 32*subtile_rows;
+  //int num_blocks = numRows/subtile_rows;
+  //num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
+  //num_blocks = num_blocks*(tileCols/32);
+  //assert(threads <= tilesize);
+  int num_blocks = numRows * numCols / (threads * 4);
+  num_blocks += (numRows * numCols) % (threads * 4) == 0 ? 0 : 1;
 
   hipLaunchKernelGGL(( kdequant_mm_int32_fp16<4, 128, 512>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
   CUDA_CHECK_RETURN(hipPeekAtLastError());
@@ -831,14 +837,17 @@ template void extractOutliers(char * A, int *idx, char *out, int id
 
   int num_blocks = idx_size;
 
-  if(FORMAT == COL_TURING)
+  /*if(FORMAT == COL_TURING)
   {
     tiledRows = fill_up_to_nearest_multiple(rows, 8);
   }
   else if(FORMAT == COL_AMPERE)
   {
     tiledRows = fill_up_to_nearest_multiple(rows, 32);
-  }
+  }*/
+
+  //for col format on ROCm
+  tiledRows = rows;
 
   hipLaunchKernelGGL(( kExtractOutliers), dim3(num_blocks), dim3(threads), 0, 0, A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
   CUDA_CHECK_RETURN(hipPeekAtLastError());
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 4c7e2b9df..7c28dc436 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -288,7 +288,6 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
 )
 names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]
 
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.parametrize(
     "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
     values,
diff --git a/tests/test_functional.py b/tests/test_functional.py
index f914820fe..01a4f3f77 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -967,7 +967,6 @@ def test_bench_8bit_training(batch, seq, model, hidden):
 values = list(product(dim1, dim4, dims, formatB, has_bias))
 names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]
 
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
 def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
     inner = torch.randint(1, 128, size=(1,)).item()
@@ -1309,7 +1308,6 @@ def test_row_scale_bench(dim1, dim4, inner):
     for vals in values
 ]
 
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.parametrize(
     "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
     values,
@@ -2042,7 +2040,6 @@ def quant_zp(x):
     print(err1, err2, err3, err4, err5, err6)
 
 
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 def test_extract_outliers():
     for i in range(k):
         shapeA = (4096, 4096 * 4)
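
For reference: the rewritten kdequant_mm_int32_fp16 drops the col32 sub-tile layout and treats A as a flat row-major int32 buffer, scaling element i by rowStats[i / numCols] and colStats[i % numCols]; the ROCm kExtractOutliers path now simply gathers columns from a col-major buffer. The sketch below is a hypothetical host-side reference for those two mappings, not code from this change; the function names, the torch-based layout handling, and the 1/(127*127) value for MM_DEQUANT_CONST are assumptions for illustration.

import torch

# Assumed to match the kernel's MM_DEQUANT_CONST (1/(127*127)).
MM_DEQUANT_CONST = 1.0 / (127.0 * 127.0)

def dequant_mm_int32_fp16_ref(A, row_stats, col_stats, bias=None):
    # A: (numRows, numCols) int32, row-major. Mirrors the new flat indexing:
    # out[i] = A[i] * C * rowStats[i // numCols] * colStats[i % numCols] + bias[i % numCols]
    out = A.float() * MM_DEQUANT_CONST * row_stats.view(-1, 1) * col_stats.view(1, -1)
    if bias is not None:
        out += bias.float().view(1, -1)
    return out.half()

def extract_outliers_col_ref(A_col_flat, rows, idx):
    # A_col_flat: flat column-major buffer of a (rows, cols) int8 matrix, as in the
    # "col" format the ROCm path asserts on; element (r, c) lives at offset c*rows + r.
    cols = A_col_flat.numel() // rows
    A = A_col_flat.view(cols, rows).t()   # logical (rows, cols) view of the col-major buffer
    return A[:, idx].contiguous()         # out[row, j] = A[row, idx[j]], matching out_idx = row*idx_size + j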