From 2bfa3472ecde8f3e4a0306b017826314c288b7c8 Mon Sep 17 00:00:00 2001
From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Date: Mon, 26 Aug 2024 19:27:12 +0000
Subject: [PATCH 1/6] docs: tweaks for multi-backend preview release prep

---
 docs/source/installation.mdx | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index 0e8da0cda..60419b38a 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -134,14 +134,23 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-11.7

 3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 11.7) and a different bitsandbytes library is loaded.

-## Multi-backend preview release compilation[[multi-backend]]
+## Multi-backend[[multi-backend]]
+
+> [!TIP]
+> This functionality is currently in preview and therefore not yet production-ready!

 Please follow these steps to install bitsandbytes with a device-specific backend other than CUDA:

+### Pip install the pre-built wheel (recommended for most)
+
+WIP (will be added in the coming days)
+
+### Compilation
+
-### AMD GPU
+#### AMD GPU

 bitsandbytes is fully supported from ROCm 6.1 onwards (currently in alpha release).

@@ -179,7 +188,7 @@ pip install -e .   # `-e` for "editable" install, when developing BNB (otherwise

-### Intel CPU
+#### Intel CPU

 > [!TIP]
 > Intel CPU backend only supports building from source; for now, please follow the instructions below.

@@ -200,6 +209,8 @@ pip install -e .   # `-e` for "editable" install, when developing BNB (otherwise

+#### Apple Silicon
+
 WIP

From c8383fbf65cee2bc61f7421dc9b57ad9e9447c1e Mon Sep 17 00:00:00 2001
From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Date: Thu, 29 Aug 2024 16:00:34 +0000
Subject: [PATCH 2/6] docs: get started on detailed multi-backend guide

---
 docs/source/_toctree.yml          |  2 ++
 docs/source/non_cuda_backends.mdx | 27 +++++++++++++++++++++++++++
 2 files changed, 29 insertions(+)
 create mode 100644 docs/source/non_cuda_backends.mdx

diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index fdfe19ee4..a72eb1967 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -12,6 +12,8 @@
     title: 8-bit optimizers
   - local: algorithms
     title: Algorithms
+  - local: non_cuda_backends
+    title: Non-CUDA compute backends
   - local: fsdp_qlora
     title: FSDP-QLoRA
   - local: integrations

diff --git a/docs/source/non_cuda_backends.mdx b/docs/source/non_cuda_backends.mdx
new file mode 100644
index 000000000..fca586534
--- /dev/null
+++ b/docs/source/non_cuda_backends.mdx
@@ -0,0 +1,27 @@
+# Multi-backend support (non-CUDA backends)
+
+As part of a recent refactoring effort, we will soon offer official multi-backend support. Currently, this feature is available in a preview alpha release, allowing us to gather early feedback from users to improve the functionality and identify any bugs.
+
+At present, the Intel CPU and AMD ROCm backends are considered fully functional. The Intel XPU backend has limited functionality and is less mature.
+
+Please refer to the [installation instructions](./installation#multi-backend) for details on installing the backend you intend to test (and hopefully provide feedback on).
+
+> [!TIP]
+> Apple Silicon support is planned for Q4 2024. We are actively seeking contributors to help implement this, develop a concrete plan, and create a detailed list of requirements. Due to limited resources, we rely on community contributions for this implementation effort. To discuss further, please share your thoughts in [this GitHub discussion](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1340) and tag `@Titus-von-Koeller` and `@matthewdouglas`. Thank you!
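+
+Once your chosen backend is installed, a quick way to sanity-check the preview is a small quantization round trip. The following is a minimal sketch and assumes a working multi-backend preview install; adjust `device` to the backend you are testing (note that PyTorch exposes ROCm devices as `"cuda"`):
+
+```python
+import torch
+import bitsandbytes.functional as F
+
+device = "cpu"  # e.g. the Intel CPU backend; use "cuda" on ROCm systems
+
+A = torch.randn(64, 64, dtype=torch.float32, device=device)
+quantized, quant_state = F.quantize_4bit(A, quant_type="nf4", blocksize=64)
+restored = F.dequantize_4bit(quantized, quant_state)
+
+# NF4 is lossy, so expect approximate agreement only.
+print("max abs error:", (A - restored).abs().max().item())
+```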
+
+## Alpha Release
+
+As we are currently in the alpha testing phase, bugs are expected, and performance might not meet expectations. However, this is exactly what we want to discover from **your** perspective as the end user!
+
+Please share and discuss your feedback with us here:
+
+- [GitHub Discussion: Multi-backend refactor: Alpha release ( AMD ROCm ONLY )](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1339)
+- [GitHub Discussion: Multi-backend refactor: Alpha release ( Intel ONLY )](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1338)
+
+Thank you for your support!
+
+## Benchmarks
+
+### Intel
+
+### AMD

From 3b94d626fdcde73b32586995828d68010668bedd Mon Sep 17 00:00:00 2001
From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com>
Date: Fri, 30 Aug 2024 01:25:43 +0800
Subject: [PATCH 3/6] rm warn for multi backend (#1336)

---
 bitsandbytes/cextension.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index cfeaf4f44..6c18275c6 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -106,10 +106,6 @@ def get_native_library() -> BNBNativeLibrary:
     if hasattr(dll, "get_context"):  # only a CUDA-built library exposes this
         return CudaBNBNativeLibrary(dll)

-    logger.warning(
-        "The installed version of bitsandbytes was compiled without GPU support. "
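
With the warning gone, a CPU-only native library is treated as a valid configuration rather than a broken install. Code that still needs to know whether the loaded binary has GPU support can inspect the type of the loaded library object instead; a minimal sketch, using the module-level `lib` from `bitsandbytes.cextension`:

```python
from bitsandbytes.cextension import CudaBNBNativeLibrary, lib

# get_native_library() returns CudaBNBNativeLibrary only when the native
# binary exposes CUDA entry points (see the diff above).
if isinstance(lib, CudaBNBNativeLibrary):
    print("native library built with GPU support")
else:
    print("CPU-only native library (a valid multi-backend configuration)")
```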
" - "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.", - ) return BNBNativeLibrary(dll) From 39097a6fae9951630e83baa7b6a34f569d91f1a9 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 30 Aug 2024 10:18:48 +0000 Subject: [PATCH 4/6] actions: update permissions for pr docs publishing --- .github/workflows/upload_pr_documentation.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml index 6497caf2d..707705297 100644 --- a/.github/workflows/upload_pr_documentation.yml +++ b/.github/workflows/upload_pr_documentation.yml @@ -6,6 +6,10 @@ on: types: - completed +permissions: + contents: read + pull-requests: write # Allows posting comments on pull requests + jobs: build: uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main From 27846533d19eed5c6ef3cb01e8ee237639069180 Mon Sep 17 00:00:00 2001 From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com> Date: Fri, 13 Sep 2024 15:49:28 +0800 Subject: [PATCH 5/6] fix nf4 memory issue by init op_context in forward (#1349) * fix nf4 memory issue by init op_context in forward * disable repack in init * fix code style --- bitsandbytes/backends/cpu_xpu_common.py | 19 ----------------- bitsandbytes/nn/modules.py | 27 +++++++++++++++++++++---- bitsandbytes/utils.py | 24 ++++++++++++++++++++++ 3 files changed, 47 insertions(+), 23 deletions(-) diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 0fcfffa07..0d865b541 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -370,25 +370,6 @@ def quantize_4bit_impl( quant_type=quant_type, ) - if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4": - # lowp_mode: lowest precision for computation - lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16 - state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( - out.reshape([input_shape[0], input_shape[1] // 2]), - ipex_cpu.quantization.WoqWeightDtype.NF4, - input_shape, # weight shape - absmax.view(input_shape[0], input_shape[1] // blocksize), # scales - None, # zero_points - None, # bias - None, # g_idx - None, # batch_size - blocksize, - int(lowp_mode), - -1, # act_quant_mode. 
From 45b7d14a9ae58927688c04dde6a8d70275abd0ae Mon Sep 17 00:00:00 2001
From: pnunna93 <104791500+pnunna93@users.noreply.github.com>
Date: Mon, 16 Sep 2024 11:45:43 -0500
Subject: [PATCH 6/6] AMD: Clarify diagnostic messages; free up disk space for CI build

* Add build job for rocm
* Add rocm build script
* Copy shared obj file into output_dir
* upload build artifacts and enable wheels build
* Remove cuda build temporarily
* Add ROCm version to .so filename
* Add rocm_version to whls build
* Revert "Remove cuda build temporarily"
  This reverts commit 1413c5f3a2aed51140b86daa8ee9283c67cce738.
* Add rocm_version env var
* Remove thrust header files
* Print node info
* print cuda node info
* Revert "print cuda node info"
  This reverts commit cdb209a2eb896d9c4166f53e9b2aa580c10e42c0.
* Revert "Print node info"
  This reverts commit 7e9a65c33f66fffcb14ee2438170718777c06022.
* Add rocm arch to compile command
* Rename .so files to rocm
* Update default gpu arch
* Skip cpu based igemmlt int tests on ROCm
* Update Documentation
* Update upstream repo name
* Update docs
* Update string format
  Co-authored-by: Aarni Koskela
* Remove pre-release option for torch install
* Update pytorch install path
  Co-authored-by: Titus <9048635+Titus-von-Koeller@users.noreply.github.com>
* Add messages for Heuristics error
* Remove toolcache for disk space
* print disk usage
* Clean disk space for linux
* Fix for ubuntu
* Add sudo for apt clean
* Update clean up disk list
* remove disk usage print
* Add BNB_BACKEND variable
* Update diagnostic functions for ROCm
* Fix tuple error
* Fix library detection bug for recursive and symlink cases
* fix pre-commit errors
* Remove recursive path lib search
* Create function for runtime lib patterns
* Update logger format
  Co-authored-by: Aarni Koskela
* Update error reporting
  Co-authored-by: Aarni Koskela
* Remove commented code
  Co-authored-by: Aarni Koskela
* Update error reporting
  Co-authored-by: Aarni Koskela
* Update error reporting
* Create hip diagnostics functions
* Fix Typo
* Fix pre-commit checks

---------

Co-authored-by: Aarni Koskela
Co-authored-by: Titus <9048635+Titus-von-Koeller@users.noreply.github.com>
---
 .github/workflows/python-package.yml | 21 +++++--
 bitsandbytes/cextension.py           | 11 ++--
 bitsandbytes/diagnostics/cuda.py     | 89 ++++++++++++++++++++++----
 bitsandbytes/diagnostics/main.py     | 31 ++++++----
 csrc/ops.hip                         | 26 +++++---
 5 files changed, 137 insertions(+), 41 deletions(-)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 91e6d82a6..d2da82501 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -116,10 +116,23 @@ jobs:
         uses: docker/setup-qemu-action@v2
       - name: Clean up disk space
         run: |
-          sudo rm -rf /usr/share/dotnet
-          sudo rm -rf /opt/ghc
-          sudo rm -rf "/usr/local/share/boost"
-          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+          sudo rm -rf \
+            /usr/share/dotnet \
+            /opt/ghc \
+            "/usr/local/share/boost" \
+            "$AGENT_TOOLSDIRECTORY" \
+            /opt/hostedtoolcache \
+            /opt/google/chrome \
+            /opt/microsoft/msedge \
+            /opt/microsoft/powershell \
+            /opt/pipx \
+            /usr/lib/mono \
+            /usr/local/julia* \
+            /usr/local/lib/android \
+            /usr/local/lib/node_modules \
+            /usr/local/share/chromium \
+            /usr/local/share/powershell \
+            /usr/share/swift
       - name: Build C++
         run: bash .github/scripts/build-rocm.sh
         env:

diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 6c18275c6..cc5d8deff 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -99,7 +99,7 @@ def get_native_library() -> BNBNativeLibrary:
     if cuda_binary_path.exists():
         binary_path = cuda_binary_path
     else:
-        logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path)
+        logger.warning("Could not find the bitsandbytes %s binary at %r", BNB_BACKEND, cuda_binary_path)
     logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
     dll = ct.cdll.LoadLibrary(str(binary_path))

@@ -116,21 +116,24 @@
         hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2])
         HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor
         BNB_HIP_VERSION_SHORT = f"{hip_major}{hip_minor}"
+        BNB_BACKEND = "ROCm"
     else:
         HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0
         BNB_HIP_VERSION_SHORT = ""
+        BNB_BACKEND = "CUDA"
+
     lib = get_native_library()
 except Exception as e:
     lib = None
     logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True)
     if torch.cuda.is_available():
         logger.warning(
-            """
-CUDA Setup failed despite CUDA being available. Please run the following command to get more information:
+            f"""
+{BNB_BACKEND} Setup failed despite {BNB_BACKEND} being available. Please run the following command to get more information:

 python -m bitsandbytes

-Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
+Inspect the output of the command and see if you can locate {BNB_BACKEND} libraries. You might need to add them
 to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
 and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
 """,

diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py
index 8974c6400..014b753a9 100644
--- a/bitsandbytes/diagnostics/cuda.py
+++ b/bitsandbytes/diagnostics/cuda.py
@@ -5,7 +5,7 @@
 import torch

-from bitsandbytes.cextension import get_cuda_bnb_library_path
+from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
 from bitsandbytes.consts import NONPYTORCH_DOC_URL
 from bitsandbytes.cuda_specs import CUDASpecs
 from bitsandbytes.diagnostics.utils import print_dedented
@@ -32,15 +32,20 @@
     "_",  # current Python interpreter
 }

-CUDA_RUNTIME_LIB_PATTERNS = (
-    "cudart64*.dll",  # Windows
-    "libcudart*.so*",  # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc.
-    "nvcuda*.dll",  # Windows
-)
-
 logger = logging.getLogger(__name__)


+def get_runtime_lib_patterns() -> tuple:
+    if HIP_ENVIRONMENT:
+        return ("libamdhip64.so*",)
+    else:
+        return (
+            "cudart64*.dll",  # Windows
+            "libcudart*.so*",  # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc.
+            "nvcuda*.dll",  # Windows
+        )
+
+
 def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]:
     for dir_string in paths_list_candidate.split(os.pathsep):
         if not dir_string:
@@ -55,9 +60,9 @@
                 continue
         except OSError:  # Assume an esoteric error trying to poke at the directory
             pass
-        for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS:
+        for lib_pattern in get_runtime_lib_patterns():
             for pth in dir.glob(lib_pattern):
-                if pth.is_file():
+                if pth.is_file() and not pth.is_symlink():
                     yield pth
     except (OSError, PermissionError):
         pass
@@ -104,7 +109,7 @@ def find_cudart_libraries() -> Iterator[Path]:
         yield from find_cuda_libraries_in_path_list(value)


-def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
+def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
     print(
         f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
         f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
@@ -149,10 +154,40 @@
 # (2) Multiple CUDA versions installed


-def print_cuda_runtime_diagnostics() -> None:
+def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
+    print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")
+
+    binary_path = get_cuda_bnb_library_path(cuda_specs)
+    if not binary_path.exists():
+        print_dedented(
+            f"""
+            Library not found: {binary_path}.
+            Maybe you need to compile it from source? If you compiled from source, check that ROCM_VERSION
+            in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version
+            and rebuild bitsandbytes.
+            """,
+        )
+
+    hip_major, hip_minor = cuda_specs.cuda_version_tuple
+    if (hip_major, hip_minor) < (6, 1):
+        print_dedented(
+            """
+            WARNING: bitsandbytes is fully supported only from ROCm 6.1.
+            """,
+        )
+
+
+def print_diagnostics(cuda_specs: CUDASpecs) -> None:
+    if HIP_ENVIRONMENT:
+        _print_hip_diagnostics(cuda_specs)
+    else:
+        _print_cuda_diagnostics(cuda_specs)
+
+
+def _print_cuda_runtime_diagnostics() -> None:
     cudart_paths = list(find_cudart_libraries())
     if not cudart_paths:
-        print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.")
+        print("WARNING! CUDA runtime files not found in any environmental path.")
     elif len(cudart_paths) > 1:
         print_dedented(
             f"""
@@ -174,3 +209,33 @@
     )
     for pth in cudart_paths:
         print(f"* Found CUDA runtime at: {pth}")
+
+
+def _print_hip_runtime_diagnostics() -> None:
+    cudart_paths = list(find_cudart_libraries())
+    if not cudart_paths:
+        print("WARNING! ROCm runtime files not found in any environmental path.")
+    elif len(cudart_paths) > 1:
+        print_dedented(
+            f"""
+            Found duplicate ROCm runtime files (see below).
+
+            We select the PyTorch default ROCm runtime, which is {torch.version.hip},
+            but this might mismatch with the ROCm version that is needed for bitsandbytes.
+
+            To resolve it, install PyTorch built for the ROCm version you want to use
+            and set LD_LIBRARY_PATH to your ROCm install path, e.g.
+            export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib
+            """,
+        )
+
+    for pth in cudart_paths:
+        print(f"* Found ROCm runtime at: {pth}")
+
+
+def print_runtime_diagnostics() -> None:
+    if HIP_ENVIRONMENT:
+        _print_hip_runtime_diagnostics()
+    else:
+        _print_cuda_runtime_diagnostics()

diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py
index 1ce096f69..8dc43ed2a 100644
--- a/bitsandbytes/diagnostics/main.py
+++ b/bitsandbytes/diagnostics/main.py
@@ -3,11 +3,12 @@
 import torch

+from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT
 from bitsandbytes.consts import PACKAGE_GITHUB_URL
 from bitsandbytes.cuda_specs import get_cuda_specs
 from bitsandbytes.diagnostics.cuda import (
-    print_cuda_diagnostics,
-    print_cuda_runtime_diagnostics,
+    print_diagnostics,
+    print_runtime_diagnostics,
 )
 from bitsandbytes.diagnostics.utils import print_dedented, print_header


 def sanity_check():
     from bitsandbytes.cextension import lib

     if lib is None:
+        compute_backend = "cuda" if not HIP_ENVIRONMENT else "hip"
         print_dedented(
-            """
+            f"""
             Couldn't load the bitsandbytes library, likely due to missing binaries.
             Please ensure bitsandbytes is properly installed.

-            For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND=cuda -S .`.
+            For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND={compute_backend} -S .`.
             See the documentation for more details if needed.

             Trying a simple check anyway, but this will likely fail...
@@ -49,19 +51,24 @@ def main():
     print_header("OTHER")

     cuda_specs = get_cuda_specs()
-    print("CUDA specs:", cuda_specs)
+    if HIP_ENVIRONMENT:
+        rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}',"
+        rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}"
+        print(f"{BNB_BACKEND} specs:{rocm_specs}")
+    else:
+        print(f"{BNB_BACKEND} specs:{cuda_specs}")
     if not torch.cuda.is_available():
-        print("Torch says CUDA is not available. Possible reasons:")
-        print("1. CUDA driver not installed")
-        print("2. CUDA not installed")
-        print("3. You have multiple conflicting CUDA libraries")
+        print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:")
+        print(f"1. {BNB_BACKEND} driver not installed")
+        print(f"2. {BNB_BACKEND} not installed")
+        print(f"3. You have multiple conflicting {BNB_BACKEND} libraries")
     if cuda_specs:
-        print_cuda_diagnostics(cuda_specs)
-        print_cuda_runtime_diagnostics()
+        print_diagnostics(cuda_specs)
+        print_runtime_diagnostics()
     print_header("")
     print_header("DEBUG INFO END")
     print_header("")
-    print("Checking that the library is importable and CUDA is callable...")
+    print(f"Checking that the library is importable and {BNB_BACKEND} is callable...")
     try:
         sanity_check()
         print("SUCCESS!")

diff --git a/csrc/ops.hip b/csrc/ops.hip
index 157e84629..4fdc3cbfa 100644
--- a/csrc/ops.hip
+++ b/csrc/ops.hip
@@ -576,6 +576,7 @@ template int igemmlt(hipblasLtHandl
     if (returnedAlgoCount == 0)
     {
       has_error = 1;
+      fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n");
     }
     else
     {
@@ -614,18 +615,25 @@ template int igemmlt(hipblasLtHandl
                                                           heuristicResult,
                                                           &returnedAlgoCount));

-      if(!SCALE_ROWS)
+      if (returnedAlgoCount == 0)
       {
-        float alpha = 1.0f, beta = 0.0f;
-
-        has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
+        has_error = 1;
+        fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n");
       }
       else
       {
-        //has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, hipblasLt_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
-        float beta = 0.0f;
-
-        has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
+        if(!SCALE_ROWS)
+        {
+          float alpha = 1.0f, beta = 0.0f;
+
+          has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
+        }
+        else
+        {
+          float beta = 0.0f;
+
+          has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
+        }
       }
     }

@@ -635,7 +643,7 @@
     if (Adesc) has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(Adesc));
     if (matmulDesc) has_error |= checkHipblasStatus(hipblasLtMatmulDescDestroy(matmulDesc));
     if(has_error == 1)
-      printf("error detected");
+      fprintf(stderr, "error detected\n");

     return has_error;
 #endif // NO_HIPBLASLT
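
With these changes in place, `python -m bitsandbytes` prints backend-appropriate diagnostics on both CUDA and ROCm systems. The new dispatchers can also be driven programmatically; a minimal sketch, assuming a checkout of this branch (the names are the ones introduced above):

```python
from bitsandbytes.cuda_specs import get_cuda_specs
from bitsandbytes.diagnostics.cuda import print_diagnostics, print_runtime_diagnostics

# On ROCm builds HIP_ENVIRONMENT is True and the HIP variants run;
# on CUDA builds the CUDA variants run instead.
specs = get_cuda_specs()  # None when PyTorch sees no accelerator
if specs:
    print_diagnostics(specs)
print_runtime_diagnostics()
```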