Enable pack hpu and from_quantized #7

Draft · wants to merge 3 commits into base: main
33 changes: 23 additions & 10 deletions auto_gptq/modeling/_base.py
@@ -41,14 +41,15 @@
MARLIN_AVAILABLE,
QIGEN_AVAILABLE,
TRITON_AVAILABLE,
HPU_AVAILABLE,
dynamically_import_QuantLinear,
)
from ..utils.marlin_utils import (
_validate_marlin_compatibility,
_validate_marlin_device_support,
prepare_model_for_marlin_load,
)
from ._const import CPU, CUDA_0, SUPPORTED_MODELS
from ._const import CPU, CUDA_0, HPU, SUPPORTED_MODELS
from ._utils import (
autogptq_post_init,
find_layers,
@@ -75,6 +76,7 @@
logger.addHandler(handler)
logger.setLevel(logging.INFO)

device_to = CUDA_0 if not HPU_AVAILABLE else HPU

def nested_move_to_device(v, device):
if isinstance(v, torch.Tensor):
@@ -198,7 +200,7 @@ def quantize(
logger.info(f"truly offloading {name} to cpu with hook.")
module = get_module_by_name_suffix(self.model, name)
remove_hook_from_module(module, recurse=True)
accelerate.cpu_offload_with_hook(module, CUDA_0)
accelerate.cpu_offload_with_hook(module, device_to)

layer_inputs = []
attention_masks = []
@@ -244,7 +246,7 @@ def store_input_hook(_, args, kwargs):

force_layer_back_to_cpu = False
if get_device(layers[0]) == CPU:
layers[0] = layers[0].to(CUDA_0)
layers[0] = layers[0].to(device_to)
force_layer_back_to_cpu = True

ori_outside_layer_module_devices = {}
@@ -277,7 +279,8 @@ def store_input_hook(_, args, kwargs):
if module is not None:
move_to_device(module, ori_outside_layer_module_devices[module_name])

torch.cuda.empty_cache()
if not HPU_AVAILABLE:
torch.cuda.empty_cache()

inside_layer_modules = self.inside_layer_modules
if not self.quantize_config.true_sequential:
@@ -288,7 +291,7 @@
layer = layers[i]
force_layer_back_to_cpu = False
if get_device(layer) == CPU:
move_to_device(layer, CUDA_0)
move_to_device(layer, device_to)
force_layer_back_to_cpu = True
cur_layer_device = get_device(layer)

@@ -372,7 +375,8 @@ def tmp(_, inp, out):
del gptq
del layer_inputs
layer_inputs, layer_outputs = layer_outputs, [] # TODO: is it really OK to cache only the first positional argument?
torch.cuda.empty_cache()
if not HPU_AVAILABLE:
torch.cuda.empty_cache()

pack_model(
model=self.model,
@@ -393,7 +397,8 @@ def tmp(_, inp, out):

self._quantized = True

torch.cuda.empty_cache()
if not HPU_AVAILABLE:
torch.cuda.empty_cache()

@property
def device(self):
@@ -597,8 +602,12 @@ def from_pretrained(
):
"""load un-quantized pretrained model to cpu"""

if not torch.cuda.is_available():
if not HPU_AVAILABLE and not torch.cuda.is_available():
raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")
elif HPU_AVAILABLE:
import habana_frameworks.torch.hpu as hpu
if not hpu.is_available():
raise EnvironmentError("Load pretrained model to do quantization requires HPU available.")

def skip(*args, **kwargs):
pass
@@ -666,7 +675,8 @@ def skip(*args, **kwargs):
model_init_kwargs["device_map"] = None
model_init_kwargs["low_cpu_mem_usage"] = False

torch.cuda.empty_cache()
if not HPU_AVAILABLE:
torch.cuda.empty_cache()

merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)
@@ -816,7 +826,10 @@ def from_quantized(
# format marlin requires marlin kernel
use_marlin = True

marlin_compatible = _validate_marlin_device_support()
if not HPU_AVAILABLE:
marlin_compatible = _validate_marlin_device_support()
else:
marlin_compatible = False
if use_marlin and not MARLIN_AVAILABLE:
raise TypeError("use_marlin is true but Marlin is not available due to cuda/device support.")

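The `_base.py` changes all follow one pattern: select the working accelerator once, and guard CUDA-only calls such as `torch.cuda.empty_cache()` behind the `HPU_AVAILABLE` flag. A minimal sketch of that pattern, assuming the flag and device constants added elsewhere in this PR (`empty_accelerator_cache` is an illustrative helper, not part of the diff):

```python
import torch

from auto_gptq.utils.import_utils import HPU_AVAILABLE
from auto_gptq.modeling._const import CUDA_0, HPU

# Pick the accelerator once; quantization moves layers to `device_to` and back to CPU.
device_to = CUDA_0 if not HPU_AVAILABLE else HPU

def empty_accelerator_cache() -> None:
    # Only CUDA exposes an explicit cache-release call; on HPU this is a no-op.
    if not HPU_AVAILABLE:
        torch.cuda.empty_cache()
```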
3 changes: 2 additions & 1 deletion auto_gptq/modeling/_const.py
@@ -5,6 +5,7 @@

CPU = device("cpu")
CUDA_0 = device("cuda:0")
HPU = device("hpu")

SUPPORTED_MODELS = [
"bloom",
@@ -49,4 +50,4 @@

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]
__all__ = ["CPU", "CUDA_0", "HPU", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]
98 changes: 90 additions & 8 deletions auto_gptq/nn_modules/qlinear/qlinear_hpu.py
@@ -114,14 +114,96 @@ def post_init(self):
self._preprocessing()

def pack(self, linear, scales, zeros, g_idx):
#TODO: implement
raise NotImplementedError("QuantLinear HPU currently doesn't support packing")

def set_packed(self, qlinear_cls):
self.qweight = qlinear_cls.qweight
self.qzeros = qlinear_cls.qzeros
self.scales = qlinear_cls.scales
self.bias = qlinear_cls.bias
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()

scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().to(dtype=linear.weight.dtype)
if linear.bias is not None:
self.bias = linear.bias.clone().to(dtype=linear.weight.dtype)

intweight = []
for idx in range(self.infeatures):
g_idx = idx // self.group_size
intweight.append(torch.round((W[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)

i = 0
row = 0
qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
elif self.bits == 3:
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i))
i += 10
qweight[row] |= intweight[i] << 30
row += 1
qweight[row] |= (intweight[i] >> 2) & 1
i += 1
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i) + 1)
i += 10
qweight[row] |= intweight[i] << 31
row += 1
qweight[row] |= (intweight[i] >> 1) & 0x3
i += 1
for j in range(i, i + 10):
qweight[row] |= intweight[j] << (3 * (j - i) + 2)
i += 10
row += 1
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")

qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
elif self.bits == 3:
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
i += 10
qzeros[:, col] |= zeros[:, i] << 30
col += 1
qzeros[:, col] |= (zeros[:, i] >> 2) & 1
i += 1
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
i += 10
qzeros[:, col] |= zeros[:, i] << 31
col += 1
qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
i += 1
for j in range(i, i + 10):
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
i += 10
col += 1
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")

qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)

def forward(self, x):
x_dtype = x.dtype
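The new `pack()` follows the reference QuantLinear layout: the float weights are quantized against `scales`/`zeros`, and the resulting integers (and the zero points) are bit-packed into 32-bit words, `32 // bits` values per word for 2/4/8-bit plus an interleaved scheme for 3-bit. A self-contained sketch of the power-of-two packing with a round-trip check (helper names are illustrative, not part of the PR):

```python
import numpy as np

def pack_rows(intweight: np.ndarray, bits: int = 4) -> np.ndarray:
    # Pack 32 // bits quantized values per 32-bit word along the row axis,
    # mirroring the 2/4/8-bit branch of pack() above.
    assert bits in (2, 4, 8)
    pack_factor = 32 // bits
    rows, cols = intweight.shape
    assert rows % pack_factor == 0
    values = intweight.astype(np.uint32)
    packed = np.zeros((rows // pack_factor, cols), dtype=np.uint32)
    for row in range(packed.shape[0]):
        for k in range(pack_factor):
            packed[row] |= values[row * pack_factor + k] << (bits * k)
    return packed.astype(np.int32)

def unpack_rows(packed: np.ndarray, bits: int = 4) -> np.ndarray:
    # Inverse of pack_rows, useful for sanity-checking the layout.
    pack_factor = 32 // bits
    mask = (1 << bits) - 1
    as_u32 = packed.astype(np.uint32)
    out = np.zeros((as_u32.shape[0] * pack_factor, as_u32.shape[1]), dtype=np.uint32)
    for row in range(as_u32.shape[0]):
        for k in range(pack_factor):
            out[row * pack_factor + k] = (as_u32[row] >> (bits * k)) & mask
    return out

# Round-trip check on random 4-bit values.
w = np.random.randint(0, 16, size=(32, 8)).astype(np.uint32)
assert np.array_equal(unpack_rows(pack_rows(w)), w)
```

The 3-bit branch in the diff keeps the same idea but spreads each value across word boundaries, which is why it needs the extra shift-and-carry steps.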
13 changes: 11 additions & 2 deletions auto_gptq/quantization/gptq.py
@@ -9,6 +9,10 @@

from .quantizer import Quantizer

from ..utils.import_utils import (
HPU_AVAILABLE,
)


logger = getLogger(__name__)

@@ -166,7 +170,11 @@ def fasterquant(
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
logger.debug(torch.sum(Losses))

torch.cuda.synchronize()
if not HPU_AVAILABLE:
torch.cuda.synchronize()
else:
import habana_frameworks.torch.hpu as hpu
hpu.synchronize()
logger.info(f"duration: {(time.time() - tick)}")
logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")

@@ -200,7 +208,8 @@ def free(self):
self.H = None
self.Losses = None
self.Trace = None
torch.cuda.empty_cache()
if not HPU_AVAILABLE:
torch.cuda.empty_cache()


__all__ = ["GPTQ"]
12 changes: 7 additions & 5 deletions auto_gptq/utils/import_utils.py
@@ -52,6 +52,12 @@
MARLIN_AVAILABLE = False
MARLIN_EXCEPTION = e

try:
import habana_frameworks.torch.hpu # noqa: F401
HPU_AVAILABLE = True
except Exception as e:
HPU_AVAILABLE = False


logger = getLogger(__name__)

@@ -67,11 +73,7 @@ def dynamically_import_QuantLinear(
use_marlin: bool = False,
use_tritonv2: bool = False,
):
try:
import habana_frameworks.torch.hpu # noqa: F401
except Exception as e:
pass
else:
if HPU_AVAILABLE:
from ..nn_modules.qlinear.qlinear_hpu import QuantLinear
return QuantLinear
if use_qigen:
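With the probe hoisted to module level, `HPU_AVAILABLE` can be imported anywhere instead of repeating the try/except, and `dynamically_import_QuantLinear` short-circuits to the HPU kernel when it is set. A sketch of that selection logic in isolation (`pick_quant_linear` is an illustrative stand-in, not the library function):

```python
try:
    import habana_frameworks.torch.hpu  # noqa: F401
    HPU_AVAILABLE = True
except Exception:
    HPU_AVAILABLE = False

def pick_quant_linear():
    # Prefer the HPU kernel when Habana frameworks import cleanly; otherwise fall
    # back to the existing selection logic (old CUDA kernel shown here for brevity).
    if HPU_AVAILABLE:
        from auto_gptq.nn_modules.qlinear.qlinear_hpu import QuantLinear
    else:
        from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear
    return QuantLinear
```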
8 changes: 3 additions & 5 deletions tests/test_hpu_linear.py
@@ -159,16 +159,14 @@ def test_qlinear_hpu(bits, group_size, infeatures, outfeatures, bias, scales_val
zeros = torch.full((infeatures // group_size, outfeatures), 1, dtype=torch.int32)

htcore.mark_step()
quant_hpu.pack(linear, s.clone().detach().T, zeros.clone().detach().T, g_idx=None)
htcore.mark_step()
quant_hpu.to("hpu")

quant_ref_cuda_old.pack(linear, s.clone().detach().T, zeros.clone().detach().T, g_idx=None)
htcore.mark_step()
quant_ref_cuda_old.to("hpu")

#TODO: pack independently
quant_hpu.set_packed(quant_ref_cuda_old)
htcore.mark_step()
quant_hpu.to("hpu")

out_ref_cuda_old = quant_ref_cuda_old(input)
htcore.mark_step()
quant_hpu.post_init()
4 changes: 1 addition & 3 deletions tests/test_q4.py
@@ -7,6 +7,7 @@
from auto_gptq.nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear as TritonV2QuantLinear
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
import habana_frameworks.torch.core as htcore


try:
Expand Down Expand Up @@ -2295,14 +2296,11 @@ def test_bias(self, in_device, model_dtype):
self.skipTest("Can not run this test on HPU")
else:
raise e

for _, param in model_q.named_parameters():
self.assertTrue(param.device != torch.device("meta"))

for _, param in model_q.named_buffers():
self.assertTrue(param.device != torch.device("meta"))

self.assertTrue(torch.count_nonzero(model_q.model.transformer.h[0].attn.c_proj.bias) > 0)
self.assertTrue(torch.count_nonzero(model_q.model.transformer.h[0].attn.c_attn.bias) > 0)

tokenizer_kwargs = {
21 changes: 16 additions & 5 deletions tests/test_quantization.py
@@ -9,10 +9,18 @@
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.quantization import CHECKPOINT_FORMAT, QUANT_CONFIG_FILENAME, BaseQuantizeConfig

try:
import habana_frameworks.torch.core as htcore
HPU_AVAILABLE = True
except Exception as e:
HPU_AVAILABLE = False

class TestQuantization(unittest.TestCase):
@parameterized.expand([(False,), (True,)])
def test_quantize(self, use_marlin: bool):
if HPU_AVAILABLE and use_marlin:
self.skipTest("HPU does not support marlin.")

pretrained_model_dir = "saibo/llama-1B"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
@@ -43,9 +51,11 @@ def test_quantize(self, use_marlin: bool):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device="cuda:0", use_marlin=use_marlin)
device = None if HPU_AVAILABLE else "cuda:0"
model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device=device, use_marlin=use_marlin)
del model
torch.cuda.empty_cache()
if not HPU_AVAILABLE:
torch.cuda.empty_cache()

# test compat: 1) with simple dict type 2) is_marlin_format
compat_quantize_config = {
@@ -54,11 +64,12 @@
"desc_act": False,
"is_marlin_format": use_marlin,
}
model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device="cuda:0", quantize_config=compat_quantize_config)
model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device=device, quantize_config=compat_quantize_config)
assert(isinstance(model.quantize_config, BaseQuantizeConfig))

del model
torch.cuda.empty_cache()
if not HPU_AVAILABLE:
torch.cuda.empty_cache()

# test checkpoint_format hint to from_quantized()
os.remove(f"{tmpdirname}/{QUANT_CONFIG_FILENAME}")
@@ -68,7 +79,7 @@
"group_size": 128,
"desc_act": False,
}
model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device="cuda:0",
model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device=device,
quantize_config=compat_quantize_config,
checkpoint_format=CHECKPOINT_FORMAT.MARLIN if use_marlin else None)
assert (isinstance(model.quantize_config, BaseQuantizeConfig))
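For reference, the device handling exercised by the updated test looks like this from a caller's point of view (the model path is a placeholder; passing `device=None` on HPU leaves placement to the loader, as in the test above):

```python
from auto_gptq import AutoGPTQForCausalLM

try:
    import habana_frameworks.torch.core as htcore  # noqa: F401
    HPU_AVAILABLE = True
except Exception:
    HPU_AVAILABLE = False

device = None if HPU_AVAILABLE else "cuda:0"
model = AutoGPTQForCausalLM.from_quantized("path/to/quantized-model", device=device)
```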