Merge remote-tracking branch 'upstream/multi-backend-refactor' into enable_6.2_packaging
pnunna93 committed Sep 19, 2024
2 parents 7e787da + 45b7d14 commit cd026c3
Showing 8 changed files with 94 additions and 30 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/upload_pr_documentation.yml
@@ -6,6 +6,10 @@ on:
types:
- completed

permissions:
contents: read
pull-requests: write # Allows posting comments on pull requests

jobs:
build:
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
19 changes: 0 additions & 19 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -370,25 +370,6 @@ def quantize_4bit_impl(
quant_type=quant_type,
)

if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
# lowp_mode: lowest precision for computation
lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
out.reshape([input_shape[0], input_shape[1] // 2]),
ipex_cpu.quantization.WoqWeightDtype.NF4,
input_shape, # weight shape
absmax.view(input_shape[0], input_shape[1] // blocksize), # scales
None, # zero_points
None, # bias
None, # g_idx
None, # batch_size
blocksize,
int(lowp_mode),
-1, # act_quant_mode. -1 means don't quant activation
)
state.absmax = torch.Tensor()
return torch.empty([1, 0], dtype=torch.uint8), state

return out.unsqueeze(0), state


4 changes: 0 additions & 4 deletions bitsandbytes/cextension.py
@@ -106,10 +106,6 @@ def get_native_library() -> BNBNativeLibrary:
if hasattr(dll, "get_context"): # only a CUDA-built library exposes this
return CudaBNBNativeLibrary(dll)

logger.warning(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.",
)
return BNBNativeLibrary(dll)


27 changes: 23 additions & 4 deletions bitsandbytes/nn/modules.py
@@ -19,6 +19,7 @@
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
OutlierTracer,
enable_ipex_fusion,
)

T = TypeVar("T", bound="torch.nn.Module")
@@ -444,17 +445,35 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
save weight and bias,
then fill state_dict with components of quant_state
"""
if (
getattr(self.weight, "quant_state", None) is not None
and getattr(self.weight.quant_state, "op_context", None) is not None
):
context = self.weight.quant_state.op_context
self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])

super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias

if getattr(self.weight, "quant_state", None) is not None:
if (
self.weight.quant_state.absmax.shape.numel() == 0
and getattr(self.weight.quant_state, "op_context", None) is not None
):
self.weight.quant_state.absmax = context.get_scales().reshape(-1)
delattr(self.weight.quant_state, "op_context")
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
if getattr(self.weight.quant_state, "op_context", None) is not None:
context = self.weight.quant_state.op_context
destination[prefix + "weight." + "absmax"] = context.get_scales().reshape(-1)
self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])

def forward(self, x: torch.Tensor):
# Check if ipex fusion can be used
if (
x.device.type == "cpu"
and not hasattr(self.weight.quant_state, "op_context")
and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
and self.weight.quant_state.quant_type == "nf4"
):
enable_ipex_fusion(self.weight, self.weight.quant_state)

# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
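To make the new CPU path above easier to follow, here is a minimal, hypothetical usage sketch (not part of this diff). It assumes a multi-backend bitsandbytes build and `intel_extension_for_pytorch` >= 2.3; on the first forward pass on CPU, the NF4 weight is expected to be prepacked via `enable_ipex_fusion` (defined in `bitsandbytes/utils.py` below). Layer sizes and dtypes are illustrative.

```python
# Hypothetical sketch (not part of this commit): exercising the CPU/IPEX path in Linear4bit.forward().
# Assumes a multi-backend bitsandbytes build and intel_extension_for_pytorch >= 2.3 on CPU.
import torch
import bitsandbytes as bnb

# NF4 layer whose input-feature count is a multiple of the default blocksize (64),
# matching the shape check added to forward() above.
layer = bnb.nn.Linear4bit(4096, 4096, bias=False, compute_dtype=torch.bfloat16, quant_type="nf4")
layer = layer.to("cpu")  # moving to the target device triggers 4-bit quantization in the multi-backend build

x = torch.randn(2, 4096, dtype=torch.bfloat16)
with torch.no_grad():
    out = layer(x)  # first call is expected to set quant_state.op_context via enable_ipex_fusion

print(out.shape)  # expected: torch.Size([2, 4096])
```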
24 changes: 24 additions & 0 deletions bitsandbytes/utils.py
@@ -200,6 +200,30 @@ def unpack_tensor_to_dict(tensor_data):
return unpacked_dict


def enable_ipex_fusion(weight, quant_state):
from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq

if _ipex_cpu_version_prereq(2, 3):
import intel_extension_for_pytorch as ipex

lowp_mode = ipex.quantization.WoqLowpMode.BF16
quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
ipex.quantization.WoqWeightDtype.NF4,
quant_state.shape, # weight shape
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
None, # zero_points
None, # bias
None, # g_idx
None, # batch_size
quant_state.blocksize,
int(lowp_mode),
-1, # act_quant_mode. -1 means don't quant activation
)
quant_state.absmax = torch.Tensor()
weight.data = torch.empty([1, 0], dtype=torch.uint8)


class QuantState:
"""container for quantization state components to work with Params4bit and similar classes"""

2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -12,6 +12,8 @@
title: 8-bit optimizers
- local: algorithms
title: Algorithms
- local: non_cuda_backends
title: Non-CUDA compute backends
- local: fsdp_qlora
title: FSDP-QLoRA
- local: integrations
17 changes: 14 additions & 3 deletions docs/source/installation.mdx
@@ -134,14 +134,23 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-11.7

3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 11.7) and a different bitsandbytes library is loaded.

## Multi-backend preview release compilation[[multi-backend]]
## Multi-backend[[multi-backend]]

> [!TIP]
> This functionality is currently in preview and therefore not yet production-ready!
Please follow these steps to install bitsandbytes with device-specific backend support other than CUDA:

### Pip install the pre-built wheel (recommended for most)

WIP (will be added in the coming days)

### Compilation

<hfoptions id="backend">
<hfoption id="AMD ROCm">

### AMD GPU
#### AMD GPU

bitsandbytes is fully supported from ROCm 6.1 onwards (currently in alpha release).

@@ -179,7 +188,7 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise
</hfoption>
<hfoption id="Intel CPU + GPU">

### Intel CPU
#### Intel CPU

> [!TIP]
> Intel CPU backend only supports building from source; for now, please follow the instructions below.
@@ -200,6 +209,8 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise
</hfoption>
<hfoption id="Apple Silicon (MPS)">

#### Apple Silicon

WIP

</hfoption>
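Once any of the backends above is installed, a quick way to verify the install is a 4-bit quantize/dequantize round trip with the functional API. This is an illustrative smoke test, not part of the diff; it uses the CPU device, but the same calls should dispatch to whichever backend matches the tensor's device.

```python
# Hypothetical post-install smoke test (not part of this commit).
import torch
import bitsandbytes.functional as F

w = torch.randn(128, 128, dtype=torch.float32, device="cpu")

# NF4 quantization with the default blocksize of 64, then reconstruction.
packed, quant_state = F.quantize_4bit(w, quant_type="nf4")
w_hat = F.dequantize_4bit(packed, quant_state)

print(packed.dtype, packed.shape)        # packed uint8 storage
print((w - w_hat).abs().mean().item())   # small reconstruction error expected
```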
27 changes: 27 additions & 0 deletions docs/source/non_cuda_backends.mdx
@@ -0,0 +1,27 @@
# Multi-backend support (non-CUDA backends)

As part of a recent refactoring effort, we will soon offer official multi-backend support. Currently, this feature is available in a preview alpha release, allowing us to gather early feedback from users to improve the functionality and identify any bugs.

At present, the Intel CPU and AMD ROCm backends are considered fully functional. The Intel XPU backend has limited functionality and is less mature.

Please refer to the [installation instructions](./installation#multi-backend) for details on installing the backend you intend to test (and hopefully provide feedback on).
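As a concrete example of what to test, the sketch below loads a small model in 4-bit through the `transformers` integration on a non-CUDA device. This is illustrative and not part of this commit: it assumes a `transformers`/`accelerate` version that already supports the non-CUDA bitsandbytes backends, and the model name is just a placeholder.

```python
# Hypothetical end-to-end check via transformers (not part of this commit).
# Assumes transformers + accelerate with non-CUDA bitsandbytes support installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_id = "facebook/opt-350m"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="cpu",  # or the device of the backend you are testing
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("bitsandbytes multi-backend test:", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```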

> [!Tip]
> Apple Silicon support is planned for Q4 2024. We are actively seeking contributors to help implement this, develop a concrete plan, and create a detailed list of requirements. Due to limited resources, we rely on community contributions for this implementation effort. To get involved, please share your thoughts in [this GitHub discussion](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1340) and tag `@Titus-von-Koeller` and `@matthewdouglas`. Thank you!

## Alpha Release

As we are currently in the alpha testing phase, bugs are expected, and performance might not meet expectations. However, this is exactly what we want to discover from **your** perspective as the end user!

Please share and discuss your feedback with us here:

- [GitHub Discussion: Multi-backend refactor: Alpha release ( AMD ROCm ONLY )](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1339)
- [GitHub Discussion: Multi-backend refactor: Alpha release ( Intel ONLY )](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1338)

Thank you for your support!

## Benchmarks

### Intel

### AMD
