
Enable matmul function #13

Merged · 11 commits · Mar 12, 2024
4 changes: 1 addition & 3 deletions bitsandbytes/autograd/_functions.py
@@ -224,10 +224,8 @@ def backward(ctx, grad_output):

def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
"""Important: Could I use igemmlt on ROCm? """
if torch.version.hip:
# Well, let's disable it for now
return False
return True
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
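With the diff's +/- markers lost in this rendering, the apparent effect of the hunk above is that `supports_igemmlt` no longer bails out early on ROCm. A minimal sketch of the resulting gating, assuming the `return True` shown above is the new ROCm branch (the CUDA device-name check continues below the hunk):

```python
import torch

def supports_igemmlt(device: torch.device) -> bool:
    """Check if this device supports the optimized int8 kernel (sketch)."""
    if torch.version.hip:
        # ROCm: this PR enables the optimized int8 path
        return True
    if torch.cuda.get_device_capability(device=device) < (7, 5):
        return False
    # ... the CUDA device-name blacklist continues past the hunk ...
    return True
```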
13 changes: 9 additions & 4 deletions bitsandbytes/functional.py
@@ -452,13 +452,13 @@ def get_transform_buffer(
rows = shape[0] * shape[1]
cols = shape[-1]

state = (shape, to_order)
if transpose:
# swap dims
tmp = rows
rows = cols
cols = tmp
state = (shape[::-1], to_order)
shape = shape[::-1]
state = (shape, to_order)

if to_order == "row" or to_order == "col":
return init_func(shape, dtype=dtype, device=device), state
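The fix keeps `shape` and `state` consistent when `transpose` is set: the reversed shape is now applied to `shape` itself, so the buffer allocated from it and the returned state describe the same transposed layout. A simplified sketch of the fixed 2D path, assuming a `torch.zeros`-style `init_func` as in typical callers:

```python
import torch

def get_transform_buffer_sketch(shape, dtype, device, to_order,
                                transpose=False, init_func=torch.zeros):
    """Simplified 2D version of the fixed allocation logic."""
    rows, cols = shape[0], shape[-1]  # feed the tiled-order paths below the hunk
    if transpose:
        # Swap the dims AND reverse the recorded shape together, so the
        # allocated buffer and the returned state agree on the layout.
        rows, cols = cols, rows
        shape = shape[::-1]
    state = (shape, to_order)
    if to_order in ("row", "col"):
        return init_func(shape, dtype=dtype, device=device), state
    raise ValueError(f"unsupported order: {to_order}")
```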
@@ -1952,6 +1952,8 @@ def mm_dequant(
new_col_stats=None,
bias=None
):
if HIP_ENVIRONMENT:
A, quant_state = nvidia_transform(A, "row", state=quant_state)
assert A.dtype == torch.int32
if bias is not None: assert bias.dtype == torch.float16
out_shape = quant_state[0]
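On ROCm, the int32 GEMM result is first transformed back to row order before the elementwise dequant. Per element, the kernel then computes `out[i, j] = A[i, j] * C * rowStats[i] * colStats[j] + bias[j]`. A pure-PyTorch reference, assuming the kernel constant `MM_DEQUANT_CONST` equals `1/(127*127)` (an assumption based on the int8 quantization range, not stated in this diff):

```python
import torch

MM_DEQUANT_CONST = 1.0 / (127.0 * 127.0)  # assumed value of the kernel constant

def mm_dequant_reference(A_i32, row_stats, col_stats, bias=None):
    """Elementwise dequant of a row-major int32 accumulator to fp16."""
    out = A_i32.float() * MM_DEQUANT_CONST
    out = out * row_stats[:, None] * col_stats[None, :]
    if bias is not None:
        out = out + bias.float()
    return out.half()
```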
@@ -2524,7 +2526,10 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
def extract_outliers(A, SA, idx):
shapeA = SA[0]
formatA = SA[1]
assert formatA in ["col_turing", "col_ampere"]
if not HIP_ENVIRONMENT:
assert formatA in ["col_turing", "col_ampere"]
else:
assert formatA in ["col"]
assert A.device.type == "cuda"

out = torch.zeros(
@@ -2539,7 +2544,7 @@ def extract_outliers(A, SA, idx):
ptrOut = get_ptr(out)

prev_device = pre_call(A.device)
if formatA == 'col_turing':
if formatA == 'col_turing' or HIP_ENVIRONMENT:
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
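Functionally, `extract_outliers` gathers whole columns of the int8 weight matrix at the outlier indices; the Turing/Ampere branches only differ in how they undo their tiled layouts first, while the ROCm path reads a plain "col"-order tensor. A row-major PyTorch reference of the result (layout handling elided):

```python
import torch

def extract_outliers_reference(A, idx):
    """A: (rows, cols) int8 row-major; idx: 1D tensor of column indices.

    Returns the (rows, idx_size) gather that every format-specific
    kernel path above is computing.
    """
    return A[:, idx.long()]
```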
70 changes: 55 additions & 15 deletions csrc/kernels.hip
@@ -2300,13 +2300,16 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd

const int n_out = numRows*numCols;

int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
//int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
// we have tiles of size numRows*32, thus col only increases every numRows
// num_row_tiles is the tiles after which the column increases by 32
// blockIdx.x is the index of the current tile
int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
//int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
// base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
//int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);

int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD;
int thread_offset = threadIdx.x * ITEMS_PER_THREAD;

// SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS
// subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
@@ -2321,33 +2324,59 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd

int local_values[ITEMS_PER_THREAD];
half local_output[ITEMS_PER_THREAD];
float local_rowStats[ITEMS_PER_THREAD];
__shared__ float smem_rowStats[SUBTILE_ROWS];
//float local_rowStats[ITEMS_PER_THREAD];
//__shared__ float smem_rowStats[SUBTILE_ROWS];

typedef hipcub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, hipcub::BLOCK_LOAD_DIRECT> LoadInt32;
typedef hipcub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
//typedef hipcub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
__shared__ typename LoadInt32::TempStorage loadint32;
__shared__ typename ExchangeInt32::TempStorage exchangeint32;
//__shared__ typename ExchangeInt32::TempStorage exchangeint32;


// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
float colStat = col >= numCols ? 0.0f : colStats[col];
float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);
//float colStat = col >= numCols ? 0.0f : colStats[col];
//float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);
int row_idx, col_idx;
float colStat[ITEMS_PER_THREAD];
float local_biasValue[ITEMS_PER_THREAD];
float rowStat[ITEMS_PER_THREAD];
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
row_idx = (block_offset + thread_offset + j) / numCols;
col_idx = (block_offset + thread_offset + j) % numCols;
colStat[j] = col_idx >= numCols ? 0.0f : colStats[col_idx];
local_biasValue[j] = ((bias == NULL) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]);
rowStat[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx];
}
// no block loads for rows for now -- keep it simple
for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
/*for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
{
// todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
int row = (base_row+j) % numRows; // wrap around
// each warp accesses the same element, for four consecutive elements
// todo: update description about striped shared memory, it is not needed
// rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consecutive elements
smem_rowStats[j] = rowStats[row];
}
}*/
__syncthreads();

int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset;
LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0);

#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*rowStat[j]*colStat[j]) + local_biasValue[j]);

// each block processes SUBTILE_ROWS*32 elements
const int items_per_load = THREADS*ITEMS_PER_THREAD;
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
int outIdx = block_offset + thread_offset + j;
if(outIdx < n_out)
out[outIdx] = local_output[j];
}
/*const int items_per_load = THREADS*ITEMS_PER_THREAD;
const int rows_per_load = items_per_load/32;

int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
@@ -2368,7 +2397,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];

#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue);
@@ -2388,7 +2417,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
}

row_offset += rows_per_load;
}
}*/
}
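The rewritten kernel drops the tile/subtile bookkeeping: each block covers THREADS * ITEMS_PER_THREAD consecutive elements of the flat row-major output, and every element recovers its own row and column by division/modulo. A plain-Python mirror of that indexing (reusing the same assumed MM_DEQUANT_CONST as above):

```python
MM_DEQUANT_CONST = 1.0 / (127.0 * 127.0)  # assumed kernel constant

def kdequant_reference(A_flat, row_stats, col_stats, bias, num_rows, num_cols):
    """Flat-indexed dequant: one loop iteration per int32 element."""
    n = num_rows * num_cols
    out = [0.0] * n
    for k in range(n):                      # k = block_offset + thread_offset + j
        r, c = k // num_cols, k % num_cols  # row/col recovered per element
        b = bias[c] if bias is not None else 0.0
        out[k] = A_flat[k] * MM_DEQUANT_CONST * row_stats[r] * col_stats[c] + b
    return out
```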


@@ -2974,7 +3003,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
{
int local_colidx = idx[blockIdx.x];

if(FORMAT==COL_TURING)
/*if(FORMAT==COL_TURING)
{
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
@@ -3030,6 +3059,17 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
int out_idx = (row*idx_size) + blockIdx.x;
out[out_idx] = val;
}
}*/

// Only the plain col format is used on ROCm
for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
{
// col-major offset
int offset = local_colidx * rowsA + row;

char val = A[offset];
int out_idx = (row*idx_size) + blockIdx.x;
out[out_idx] = val;
}
}
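The ROCm branch replaces the tiled Turing/Ampere address math with a direct column-major read: element (row, col) of a col-major matrix lives at offset `col * rowsA + row`. A Python mirror of the loop, with the outer iteration standing in for one block (one outlier column) and the inner loop for the threads striding over rows:

```python
def extract_outliers_colmajor(A_flat, rows, idx):
    """A_flat: col-major (rows x cols) bytes; idx: outlier column indices."""
    idx_size = len(idx)
    out = [0] * (rows * idx_size)          # row-major (rows x idx_size)
    for i, col in enumerate(idx):          # blockIdx.x
        for row in range(rows):            # threadIdx.x strided by blockDim.x
            out[row * idx_size + i] = A_flat[col * rows + row]
    return out
```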

31 changes: 20 additions & 11 deletions csrc/ops.hip
@@ -594,7 +594,9 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
}
else
{
has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_32F));
has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_8I));
hipblasOperation_t opA = HIPBLAS_OP_N;
has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA)));
has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_8I, m, n, ldc));
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
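For the int8-output path, the descriptor is now created with an HIP_R_8I scale type, and TRANSA is set explicitly to HIPBLAS_OP_N alongside the transposed B. Numerically, the configured operation is the usual int8 GEMM with int32 accumulation, roughly C = A @ B^T. A hedged PyTorch reference of that product (blocked layouts and scaling omitted):

```python
import torch

def int8_gemm_reference(A_i8: torch.Tensor, B_i8: torch.Tensor) -> torch.Tensor:
    """opA = N, opB = T: (m, k) int8 @ (n, k) int8 -> (m, n) int32."""
    return A_i8.to(torch.int32) @ B_i8.to(torch.int32).T
```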
@@ -662,14 +664,18 @@ int fill_up_to_nearest_multiple(int value, int multiple)
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols)
{
int threads = 512;
int tileCols = fill_up_to_nearest_multiple(numCols, 32);
int n = numRows*tileCols;
int subtile_rows = 128;
int tilesize = 32*subtile_rows;
int num_blocks = numRows/subtile_rows;
num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
num_blocks = num_blocks*(tileCols/32);
assert(threads <= tilesize);
//int tileCols = fill_up_to_nearest_multiple(numCols, 32);
//int n = numRows*tileCols;
int tileCols = numCols;
int n = numRows*numCols;
//int subtile_rows = 128;
//int tilesize = 32*subtile_rows;
//int num_blocks = numRows/subtile_rows;
//num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
//num_blocks = num_blocks*(tileCols/32);
//assert(threads <= tilesize);
int num_blocks = numRows * numCols / (threads * 4);
num_blocks += (numRows * numCols) % (threads * 4) == 0 ? 0 : 1;

hipLaunchKernelGGL(( kdequant_mm_int32_fp16<4, 128, 512>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
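The launcher now sizes the grid off the flat element count instead of row subtiles: with 512 threads and 4 items per thread (the <4, 128, 512> template arguments), each block handles 2048 elements, and the block count is the ceiling of n / 2048. The same computation in Python:

```python
import math

THREADS = 512
ITEMS_PER_THREAD = 4  # matches kdequant_mm_int32_fp16<4, 128, 512>

def dequant_num_blocks(num_rows: int, num_cols: int) -> int:
    """Each block covers THREADS * ITEMS_PER_THREAD consecutive elements."""
    return math.ceil(num_rows * num_cols / (THREADS * ITEMS_PER_THREAD))
```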
@@ -831,14 +837,17 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id

int num_blocks = idx_size;

if(FORMAT == COL_TURING)
/*if(FORMAT == COL_TURING)
{
tiledRows = fill_up_to_nearest_multiple(rows, 8);
}
else if(FORMAT == COL_AMPERE)
{
tiledRows = fill_up_to_nearest_multiple(rows, 32);
}
}*/

// ROCm uses the plain col format, so no row padding is needed
tiledRows = rows;

hipLaunchKernelGGL(( kExtractOutliers<FORMAT>), dim3(num_blocks), dim3(threads), 0, 0, A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(hipPeekAtLastError());
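With only the plain col format on ROCm, the host side no longer pads the row count to 8- or 32-row tiles; one block is launched per outlier column. A small sketch of the resulting launch geometry:

```python
def extract_outliers_launch_dims(rows: int, idx_size: int):
    """ROCm launch geometry: one block per outlier column, unpadded rows."""
    num_blocks = idx_size  # grid.x: one column gather per block
    tiled_rows = rows      # no fill_up_to_nearest_multiple on ROCm
    return num_blocks, tiled_rows
```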
1 change: 0 additions & 1 deletion tests/test_autograd.py
@@ -288,7 +288,6 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
)
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]

@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
@pytest.mark.parametrize(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
values,
3 changes: 0 additions & 3 deletions tests/test_functional.py
@@ -967,7 +967,6 @@ def test_bench_8bit_training(batch, seq, model, hidden):
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]

@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
inner = torch.randint(1, 128, size=(1,)).item()
@@ -1309,7 +1308,6 @@ def test_row_scale_bench(dim1, dim4, inner):
for vals in values
]

@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
@pytest.mark.parametrize(
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
values,
@@ -2042,7 +2040,6 @@ def quant_zp(x):
print(err1, err2, err3, err4, err5, err6)


@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_extract_outliers():
for i in range(k):
shapeA = (4096, 4096 * 4)