diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index 026a4c9a9..7cc88326d 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -322,7 +322,10 @@ def mm_dequant( def extract_outliers(self, A, SA, idx): shapeA = SA[0] formatA = SA[1] - assert formatA in ["col_turing", "col_ampere"] + if not HIP_ENVIRONMENT: + assert formatA in ["col_turing", "col_ampere"] + else: + assert formatA in ["col"] assert A.device.type == "cuda" out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) @@ -336,7 +339,7 @@ def extract_outliers(self, A, SA, idx): prev_device = pre_call(A.device) - if formatA == "col_turing": + if formatA == "col_turing" or HIP_ENVIRONMENT: lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) @@ -355,6 +358,8 @@ def quantize_4bit( quant_type="fp4", quant_storage=torch.uint8, ) -> Tuple[torch.Tensor, QuantState]: + if HIP_ENVIRONMENT: + blocksize = 128 if A.device.type != "cuda": raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") if quant_type not in ["fp4", "nf4"]: @@ -372,7 +377,10 @@ def quantize_4bit( mod = dtype2bytes[quant_storage] * 2 out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + if not HIP_ENVIRONMENT: + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + else: + assert blocksize in [4096, 2048, 1024, 512, 256, 128] prev_device = pre_call(A.device) is_on_gpu([A, out, absmax]) @@ -446,9 +454,14 @@ def dequantize_4bit( blocksize: int = 64, quant_type="fp4", ) -> torch.Tensor: - if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + if HIP_ENVIRONMENT: + blocksize = 128 + supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] + if HIP_ENVIRONMENT: + supported_blocksizes = supported_blocksizes[:-1] + if 
blocksize not in supported_blocksizes: raise ValueError( - f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]" + f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}" ) if quant_type not in ["fp4", "nf4"]: