diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py
index 5914754b..c93b1e47 100644
--- a/auto_gptq/modeling/_base.py
+++ b/auto_gptq/modeling/_base.py
@@ -41,6 +41,7 @@
     MARLIN_AVAILABLE,
     QIGEN_AVAILABLE,
     TRITON_AVAILABLE,
+    HPU_AVAILABLE,
     dynamically_import_QuantLinear,
 )
 from ..utils.marlin_utils import (
@@ -48,7 +49,7 @@
     _validate_marlin_device_support,
     prepare_model_for_marlin_load,
 )
-from ._const import CPU, CUDA_0, SUPPORTED_MODELS
+from ._const import CPU, CUDA_0, HPU, SUPPORTED_MODELS
 from ._utils import (
     autogptq_post_init,
     find_layers,
@@ -75,6 +76,7 @@
 logger.addHandler(handler)
 logger.setLevel(logging.INFO)

+device_to = CUDA_0 if not HPU_AVAILABLE else HPU

 def nested_move_to_device(v, device):
     if isinstance(v, torch.Tensor):
@@ -198,7 +200,7 @@ def quantize(
                 logger.info(f"truly offloading {name} to cpu with hook.")
                 module = get_module_by_name_suffix(self.model, name)
                 remove_hook_from_module(module, recurse=True)
-                accelerate.cpu_offload_with_hook(module, CUDA_0)
+                accelerate.cpu_offload_with_hook(module, device_to)

         layer_inputs = []
         attention_masks = []
@@ -244,7 +246,7 @@ def store_input_hook(_, args, kwargs):

         force_layer_back_to_cpu = False
         if get_device(layers[0]) == CPU:
-            layers[0] = layers[0].to(CUDA_0)
+            layers[0] = layers[0].to(device_to)
             force_layer_back_to_cpu = True

         ori_outside_layer_module_devices = {}
@@ -277,7 +279,8 @@
             if module is not None:
                 move_to_device(module, ori_outside_layer_module_devices[module_name])

-        torch.cuda.empty_cache()
+        if not HPU_AVAILABLE:
+            torch.cuda.empty_cache()

         inside_layer_modules = self.inside_layer_modules
         if not self.quantize_config.true_sequential:
@@ -288,7 +291,7 @@
             layer = layers[i]
             force_layer_back_to_cpu = False
             if get_device(layer) == CPU:
-                move_to_device(layer, CUDA_0)
+                move_to_device(layer, device_to)
                 force_layer_back_to_cpu = True

             cur_layer_device = get_device(layer)
@@ -372,7 +375,8 @@ def tmp(_, inp, out):
             del gptq
             del layer_inputs
             layer_inputs, layer_outputs = layer_outputs, []  # TODO: is it really OK to cache only the first positional argument?
-            torch.cuda.empty_cache()
+            if not HPU_AVAILABLE:
+                torch.cuda.empty_cache()

         pack_model(
             model=self.model,
@@ -393,7 +397,8 @@
         self._quantized = True

-        torch.cuda.empty_cache()
+        if not HPU_AVAILABLE:
+            torch.cuda.empty_cache()

     @property
     def device(self):
@@ -597,8 +602,12 @@ def from_pretrained(
     ):
         """load un-quantized pretrained model to cpu"""
-        if not torch.cuda.is_available():
+        if not HPU_AVAILABLE and not torch.cuda.is_available():
             raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")
+        elif HPU_AVAILABLE:
+            import habana_frameworks.torch.hpu as hpu
+            if not hpu.is_available():
+                raise EnvironmentError("Load pretrained model to do quantization requires HPU available.")

         def skip(*args, **kwargs):
             pass
@@ -666,7 +675,8 @@ def skip(*args, **kwargs):
         model_init_kwargs["device_map"] = None
         model_init_kwargs["low_cpu_mem_usage"] = False

-        torch.cuda.empty_cache()
+        if not HPU_AVAILABLE:
+            torch.cuda.empty_cache()

         merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
         model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)
@@ -816,7 +826,10 @@ def from_quantized(
             # format marlin requires marlin kernel
             use_marlin = True

-        marlin_compatible = _validate_marlin_device_support()
+        if not HPU_AVAILABLE:
+            marlin_compatible = _validate_marlin_device_support()
+        else:
+            marlin_compatible = False

         if use_marlin and not MARLIN_AVAILABLE:
             raise TypeError("use_marlin is true but Marlin is not available due to cuda/device support.")
diff --git a/auto_gptq/modeling/_const.py b/auto_gptq/modeling/_const.py
index 5355f6ee..f24fbec2 100644
--- a/auto_gptq/modeling/_const.py
+++ b/auto_gptq/modeling/_const.py
@@ -5,6 +5,7 @@

 CPU = device("cpu")
 CUDA_0 = device("cuda:0")
+HPU = device("hpu")

 SUPPORTED_MODELS = [
     "bloom",
@@ -49,4 +50,4 @@

 EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

-__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]
+__all__ = ["CPU", "CUDA_0", "HPU", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]
diff --git a/auto_gptq/nn_modules/qlinear/qlinear_hpu.py b/auto_gptq/nn_modules/qlinear/qlinear_hpu.py
index 757f3aaf..4447969d 100644
--- a/auto_gptq/nn_modules/qlinear/qlinear_hpu.py
+++ b/auto_gptq/nn_modules/qlinear/qlinear_hpu.py
@@ -114,14 +114,96 @@ def post_init(self):
         self._preprocessing()

     def pack(self, linear, scales, zeros, g_idx):
-        #TODO: implement
-        raise NotImplementedError("QuantLinear HPU currently doesn't support packing")
-
-    def set_packed(self, qlinear_cls):
-        self.qweight = qlinear_cls.qweight
-        self.qzeros = qlinear_cls.qzeros
-        self.scales = qlinear_cls.scales
-        self.bias = qlinear_cls.bias
+        W = linear.weight.data.clone()
+        if isinstance(linear, nn.Conv2d):
+            W = W.flatten(1)
+        if isinstance(linear, transformers.pytorch_utils.Conv1D):
+            W = W.t()
+
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
+        self.scales = scales.clone().to(dtype=linear.weight.dtype)
+        if linear.bias is not None:
+            self.bias = linear.bias.clone().to(dtype=linear.weight.dtype)
+
+        intweight = []
+        for idx in range(self.infeatures):
+            g_idx = idx // self.group_size
+            intweight.append(torch.round((W[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, None])
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.numpy().astype(np.uint32)
+
+        i = 0
+        row = 0
+        qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
+        while row < qweight.shape[0]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qweight[row] |= intweight[j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                row += 1
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i))
+                i += 10
+                qweight[row] |= intweight[i] << 30
+                row += 1
+                qweight[row] |= (intweight[i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
+                i += 10
+                qweight[row] |= intweight[i] << 31
+                row += 1
+                qweight[row] |= (intweight[i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
+                i += 10
+                row += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qweight = qweight.astype(np.int32)
+        self.qweight = torch.from_numpy(qweight)
+
+        zeros -= 1
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                col += 1
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 30
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 31
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
+                i += 10
+                col += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)

     def forward(self, x):
         x_dtype = x.dtype
diff --git a/auto_gptq/quantization/gptq.py b/auto_gptq/quantization/gptq.py
index cda3e7ac..cd9ea47a 100644
--- a/auto_gptq/quantization/gptq.py
+++ b/auto_gptq/quantization/gptq.py
@@ -9,6 +9,10 @@
 from .quantizer import Quantizer

+from ..utils.import_utils import (
+    HPU_AVAILABLE,
+)
+
 logger = getLogger(__name__)
@@ -166,7 +170,11 @@ def fasterquant(
             logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
             logger.debug(torch.sum(Losses))

-        torch.cuda.synchronize()
+        if not HPU_AVAILABLE:
+            torch.cuda.synchronize()
+        else:
+            import habana_frameworks.torch.hpu as hpu
+            hpu.synchronize()
         logger.info(f"duration: {(time.time() - tick)}")
         logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")
@@ -200,7 +208,8 @@ def free(self):
         self.H = None
         self.Losses = None
         self.Trace = None
-        torch.cuda.empty_cache()
+        if not HPU_AVAILABLE:
+            torch.cuda.empty_cache()


 __all__ = ["GPTQ"]
diff --git a/auto_gptq/utils/import_utils.py b/auto_gptq/utils/import_utils.py
index 0f0f1f58..d13004ea 100644
--- a/auto_gptq/utils/import_utils.py
+++ b/auto_gptq/utils/import_utils.py
@@ -52,6 +52,12 @@
     MARLIN_AVAILABLE = False
     MARLIN_EXCEPTION = e

+try:
+    import habana_frameworks.torch.hpu  # noqa: F401
+    HPU_AVAILABLE = True
+except Exception as e:
+    HPU_AVAILABLE = False
+
 logger = getLogger(__name__)
@@ -67,11 +73,7 @@ def dynamically_import_QuantLinear(
     use_qigen: bool = False,
     use_marlin: bool = False,
     use_tritonv2: bool = False,
 ):
-    try:
-        import habana_frameworks.torch.hpu  # noqa: F401
-    except Exception as e:
-        pass
-    else:
+    if HPU_AVAILABLE:
         from ..nn_modules.qlinear.qlinear_hpu import QuantLinear
         return QuantLinear
     if use_qigen:
diff --git a/tests/test_hpu_linear.py b/tests/test_hpu_linear.py
index 62a37f5b..adb903a2 100644
--- a/tests/test_hpu_linear.py
+++ b/tests/test_hpu_linear.py
@@ -159,16 +159,14 @@ def test_qlinear_hpu(bits, group_size, infeatures, outfeatures, bias, scales_val
     zeros = torch.full((infeatures // group_size, outfeatures), 1, dtype=torch.int32)
     htcore.mark_step()

+    quant_hpu.pack(linear, s.clone().detach().T, zeros.clone().detach().T, g_idx=None)
+    htcore.mark_step()
+    quant_hpu.to("hpu")

     quant_ref_cuda_old.pack(linear, s.clone().detach().T, zeros.clone().detach().T, g_idx=None)
     htcore.mark_step()
     quant_ref_cuda_old.to("hpu")

-    #TODO: pack independently
-    quant_hpu.set_packed(quant_ref_cuda_old)
-    htcore.mark_step()
-    quant_hpu.to("hpu")
-
     out_ref_cuda_old = quant_ref_cuda_old(input)
     htcore.mark_step()
     quant_hpu.post_init()
diff --git a/tests/test_q4.py b/tests/test_q4.py
index b367eb5b..e225b187 100644
--- a/tests/test_q4.py
+++ b/tests/test_q4.py
@@ -7,6 +7,7 @@
 from auto_gptq.nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
 from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear as TritonV2QuantLinear
 from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
+import habana_frameworks.torch.core as htcore


 try:
@@ -2295,14 +2296,11 @@ def test_bias(self, in_device, model_dtype):
                     self.skipTest("Can not run this test on HPU")
                 else:
                     raise e
-
         for _, param in model_q.named_parameters():
             self.assertTrue(param.device != torch.device("meta"))
         for _, param in model_q.named_buffers():
             self.assertTrue(param.device != torch.device("meta"))
-
-
         self.assertTrue(torch.count_nonzero(model_q.model.transformer.h[0].attn.c_proj.bias) > 0)
         self.assertTrue(torch.count_nonzero(model_q.model.transformer.h[0].attn.c_attn.bias) > 0)

         tokenizer_kwargs = {
diff --git a/tests/test_quantization.py b/tests/test_quantization.py
index 1fac9775..bcbba953 100644
--- a/tests/test_quantization.py
+++ b/tests/test_quantization.py
@@ -9,10 +9,18 @@
 from auto_gptq import AutoGPTQForCausalLM
 from auto_gptq.quantization import CHECKPOINT_FORMAT, QUANT_CONFIG_FILENAME, BaseQuantizeConfig

+try:
+    import habana_frameworks.torch.core as htcore  # noqa: F401
+    HPU_AVAILABLE = True
+except Exception as e:
+    HPU_AVAILABLE = False

 class TestQuantization(unittest.TestCase):
     @parameterized.expand([(False,), (True,)])
     def test_quantize(self, use_marlin: bool):
+        if HPU_AVAILABLE and use_marlin:
+            self.skipTest("HPU does not support marlin.")
+
         pretrained_model_dir = "saibo/llama-1B"

         tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
@@ -43,9 +51,11 @@ def test_quantize(self, use_marlin: bool):
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)

-            model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device="cuda:0", use_marlin=use_marlin)
+            device = None if HPU_AVAILABLE else "cuda:0"
+            model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device=device, use_marlin=use_marlin)
             del model
-            torch.cuda.empty_cache()
+            if not HPU_AVAILABLE:
+                torch.cuda.empty_cache()

             # test compat: 1) with simple dict type 2) is_marlin_format
             compat_quantize_config = {
@@ -54,11 +64,12 @@
                 "bits": 4,
                 "group_size": 128,
                 "desc_act": False,
                 "is_marlin_format": use_marlin,
             }
-            model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device="cuda:0", quantize_config=compat_quantize_config)
+            model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device=device, quantize_config=compat_quantize_config)
             assert(isinstance(model.quantize_config, BaseQuantizeConfig))
             del model
-            torch.cuda.empty_cache()
+            if not HPU_AVAILABLE:
+                torch.cuda.empty_cache()

             # test checkinpoint_format hint to from_quantized()
             os.remove(f"{tmpdirname}/{QUANT_CONFIG_FILENAME}")
@@ -68,7 +79,7 @@
                 "group_size": 128,
                 "desc_act": False,
             }
-            model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device="cuda:0",
+            model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device=device,
                                                        quantize_config=compat_quantize_config,
                                                        checkpoint_format=CHECKPOINT_FORMAT.MARLIN if use_marlin else None)
             assert (isinstance(model.quantize_config, BaseQuantizeConfig))
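
A minimal usage sketch of the device selection this patch introduces, assuming the patched auto_gptq is installed on a host with either a CUDA device or a Gaudi HPU (habana_frameworks importable). It reuses the saibo/llama-1B checkpoint from the tests above; the output directory name is hypothetical, and the snippet illustrates the HPU_AVAILABLE / device_to fallback rather than an official example.

# Sketch only: relies on the device_to fallback added in auto_gptq/modeling/_base.py,
# so no explicit .to("hpu") call is needed during quantization.
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.utils.import_utils import HPU_AVAILABLE  # flag introduced by this patch

pretrained_model_dir = "saibo/llama-1B"  # same checkpoint the tests above use
quantized_model_dir = "llama-1B-4bit"    # hypothetical output directory

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
examples = [tokenizer("auto-gptq is an easy-to-use model quantization library.")]

quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

# Layers are moved to CUDA_0 or HPU internally via device_to while quantizing.
model.quantize(examples)
model.save_quantized(quantized_model_dir)

# The updated tests pass device=None on HPU and "cuda:0" otherwise.
device = None if HPU_AVAILABLE else "cuda:0"
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device=device)

On an HPU host, from_quantized() picks up the HPU QuantLinear through dynamically_import_QuantLinear, which this patch now gates on HPU_AVAILABLE.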