Skip to content

Commit

Permalink
update extract_outliers, quantize_4bit, dequantize_4bit
Browse files Browse the repository at this point in the history
  • Loading branch information
lcskrishna committed May 6, 2024
1 parent 7835282 commit 765bfc8
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions bitsandbytes/backends/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ def mm_dequant(
def extract_outliers(self, A, SA, idx):
shapeA = SA[0]
formatA = SA[1]
assert formatA in ["col_turing", "col_ampere"]
if not HIP_ENVIRONMENT:
assert formatA in ["col_turing", "col_ampere"]
else:
assert formatA in ["col"]
assert A.device.type == "cuda"

out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
Expand All @@ -336,7 +339,7 @@ def extract_outliers(self, A, SA, idx):

prev_device = pre_call(A.device)

if formatA == "col_turing":
if formatA == "col_turing" or HIP_ENVIRONMENT:
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
Expand All @@ -355,6 +358,8 @@ def quantize_4bit(
quant_type="fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
if HIP_ENVIRONMENT:
blocksize = 128
if A.device.type != "cuda":
raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
if quant_type not in ["fp4", "nf4"]:
Expand All @@ -372,7 +377,10 @@ def quantize_4bit(
mod = dtype2bytes[quant_storage] * 2
out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)

assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
if not HIP_ENVIRONMENT:
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
else:
assert blocksize in [4096, 2048, 1024, 512, 256, 128]

prev_device = pre_call(A.device)
is_on_gpu([A, out, absmax])
Expand Down Expand Up @@ -446,9 +454,14 @@ def dequantize_4bit(
blocksize: int = 64,
quant_type="fp4",
) -> torch.Tensor:
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
if HIP_ENVIRONMENT:
blocksize = 128
supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
if HIP_ENVIRONMENT:
supported_blocksizes = supported_blocksizes[:-1]
if blocksize not in supported_blocksizes:
raise ValueError(
f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
)

if quant_type not in ["fp4", "nf4"]:
Expand Down

0 comments on commit 765bfc8

Please sign in to comment.