From 716e01040320afbb0fcf841d1a47088d81973172 Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Fri, 8 Nov 2024 19:23:59 +0000 Subject: [PATCH] fix diagnostics --- bitsandbytes/cextension.py | 2 +- bitsandbytes/diagnostics/cuda.py | 30 ++-- bitsandbytes/diagnostics/gpu.py | 241 +++++++++++++++++++++++++++++ bitsandbytes/diagnostics/main.py | 14 +- bitsandbytes/gpu_specs.py | 2 +- tests/test_cuda_setup_evaluator.py | 10 +- 6 files changed, 270 insertions(+), 29 deletions(-) create mode 100644 bitsandbytes/diagnostics/gpu.py diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 532f6970b..d863ad41e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -37,7 +37,7 @@ def get_gpu_bnb_library_path(gpu_specs: GPUSpecs) -> Path: The library is not guaranteed to exist at the returned path. """ library_name = f"libbitsandbytes_{gpu_specs.gpu_backend}{gpu_specs.backend_version_string}" - if not gpu_specs.has_blaslt: + if not gpu_specs.enable_blaslt: # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt if gpu_specs.gpu_backend == "rocm": library_name += "_nohipblaslt" diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index 014b753a9..0e8593fdd 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -5,9 +5,9 @@ import torch -from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path from bitsandbytes.consts import NONPYTORCH_DOC_URL -from bitsandbytes.cuda_specs import CUDASpecs +from bitsandbytes.gpu_specs import GPUSpecs from bitsandbytes.diagnostics.utils import print_dedented CUDART_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH") @@ -109,13 +109,13 @@ def find_cudart_libraries() -> Iterator[Path]: yield from find_cuda_libraries_in_path_list(value) -def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: +def _print_cuda_diagnostics(gpu_specs: GPUSpecs) -> None: print( - f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " - f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", + f"PyTorch settings found: CUDA_VERSION={gpu_specs.cuda_version_string}, " + f"Highest Compute Capability: {gpu_specs.highest_compute_capability}.", ) - binary_path = get_cuda_bnb_library_path(cuda_specs) + binary_path = get_gpu_bnb_library_path(gpu_specs) if not binary_path.exists(): print_dedented( f""" @@ -128,7 +128,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: """, ) - cuda_major, cuda_minor = cuda_specs.cuda_version_tuple + cuda_major, cuda_minor = gpu_specs.cuda_version_tuple if cuda_major < 11: print_dedented( """ @@ -140,7 +140,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") # 7.5 is the minimum CC for cublaslt - if not cuda_specs.has_cublaslt: + if not gpu_specs.has_cublaslt: print_dedented( """ WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! @@ -154,10 +154,10 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: # (2) Multiple CUDA versions installed -def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: - print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") +def _print_hip_diagnostics(gpu_specs: GPUSpecs) -> None: + print(f"PyTorch settings found: ROCM_VERSION={gpu_specs.cuda_version_string}") - binary_path = get_cuda_bnb_library_path(cuda_specs) + binary_path = get_gpu_bnb_library_path(gpu_specs) if not binary_path.exists(): print_dedented( f""" @@ -168,7 +168,7 @@ def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: """, ) - hip_major, hip_minor = cuda_specs.cuda_version_tuple + hip_major, hip_minor = gpu_specs.cuda_version_tuple if (hip_major, hip_minor) < (6, 1): print_dedented( """ @@ -177,11 +177,11 @@ def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: ) -def print_diagnostics(cuda_specs: CUDASpecs) -> None: +def print_diagnostics(gpu_specs: GPUSpecs) -> None: if HIP_ENVIRONMENT: - _print_hip_diagnostics(cuda_specs) + _print_hip_diagnostics(gpu_specs) else: - _print_cuda_diagnostics(cuda_specs) + _print_cuda_diagnostics(gpu_specs) def _print_cuda_runtime_diagnostics() -> None: diff --git a/bitsandbytes/diagnostics/gpu.py b/bitsandbytes/diagnostics/gpu.py new file mode 100644 index 000000000..18db9592c --- /dev/null +++ b/bitsandbytes/diagnostics/gpu.py @@ -0,0 +1,241 @@ +import logging +import os +from pathlib import Path +from typing import Dict, Iterable, Iterator + +import torch + +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path +from bitsandbytes.consts import NONPYTORCH_DOC_URL +from bitsandbytes.gpu_specs import GPUSpecs +from bitsandbytes.diagnostics.utils import print_dedented + +GPU_RT_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH") + +GPU_RT_PATH_IGNORED_ENVVARS = { + "DBUS_SESSION_BUS_ADDRESS", # hardware related + "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks + "HOME", # Linux shell default + "LESSCLOSE", + "LESSOPEN", # related to the `less` command + "MAIL", # something related to emails + "OLDPWD", + "PATH", # this is for finding binaries, not libraries + "PWD", # PWD: this is how the shell keeps track of the current working dir + "SHELL", # binary for currently invoked shell + "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated + "SSH_TTY", + "TMUX", # Terminal Multiplexer + "XDG_DATA_DIRS", # XDG: Desktop environment stuff + "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff + "XDG_RUNTIME_DIR", + "_", # current Python interpreter +} + +logger = logging.getLogger(__name__) + + +def get_runtime_lib_patterns() -> tuple: + if HIP_ENVIRONMENT: + return ("libamdhip64.so*",) + else: + return ( + "cudart64*.dll", # Windows + "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. + "nvcuda*.dll", # Windows + ) + + +def find_gpu_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]: + for dir_string in paths_list_candidate.split(os.pathsep): + if not dir_string: + continue + if os.sep not in dir_string: + continue + try: + dir = Path(dir_string) + try: + if not dir.exists(): + logger.warning(f"The directory listed in your path is found to be non-existent: {dir}") + continue + except OSError: # Assume an esoteric error trying to poke at the directory + pass + for lib_pattern in get_runtime_lib_patterns(): + for pth in dir.glob(lib_pattern): + if pth.is_file() and not pth.is_symlink(): + yield pth + except (OSError, PermissionError): + pass + + +def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: + return ( + env_var in GPU_RT_PATH_PREFERRED_ENVVARS # is a preferred location + or ( + os.sep in value # might contain a path + and env_var not in GPU_RT_PATH_IGNORED_ENVVARS # not ignored + and "CONDA" not in env_var # not another conda envvar + and "BASH_FUNC" not in env_var # not a bash function defined via envvar + and "\n" not in value # likely e.g. a script or something? + ) + ) + + +def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: + return {env_var: value for env_var, value in os.environ.items() if is_relevant_candidate_env_var(env_var, value)} + + +def find_gpu_rt_libraries() -> Iterator[Path]: + """ + Searches for a cuda installations, in the following order of priority: + 1. active conda env + 2. LD_LIBRARY_PATH + 3. any other env vars, while ignoring those that + - are known to be unrelated + - don't contain the path separator `/` + + If multiple libraries are found in part 3, we optimistically try one, + while giving a warning message. + """ + candidate_env_vars = get_potentially_lib_path_containing_env_vars() + + for envvar in GPU_RT_PATH_PREFERRED_ENVVARS: + if envvar in candidate_env_vars: + directory = candidate_env_vars[envvar] + yield from find_gpu_libraries_in_path_list(directory) + candidate_env_vars.pop(envvar) + + for env_var, value in candidate_env_vars.items(): + yield from find_gpu_libraries_in_path_list(value) + + +def _print_cuda_diagnostics(gpu_specs: GPUSpecs) -> None: + print( + f"PyTorch settings found: CUDA_VERSION={gpu_specs.backend_version_string}, " + f"Highest Compute Capability: {gpu_specs.highest_compute_capability}.", + ) + + binary_path = get_gpu_bnb_library_path(gpu_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. Maybe you need to compile it from source? + If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`, + for example, `make CUDA_VERSION=113`. + + The CUDA version for the compile might depend on your conda install, if using conda. + Inspect CUDA version via `conda list | grep cuda`. + """, + ) + + cuda_major, cuda_minor = gpu_specs.backend_version_tuple + if cuda_major < 11: + print_dedented( + """ + WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8(). + You will be only to use 8-bit optimizers and quantization routines! + """, + ) + + print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") + + # 7.5 is the minimum CC for cublaslt + if not gpu_specs.enable_blaslt: + print_dedented( + """ + WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! + If you run into issues with 8-bit matmul, you can try 4-bit quantization: + https://huggingface.co/blog/4bit-transformers-bitsandbytes + """, + ) + + # TODO: + # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + +def _print_hip_diagnostics(gpu_specs: GPUSpecs) -> None: + print(f"PyTorch settings found: ROCM_VERSION={gpu_specs.backend_version_string}") + + binary_path = get_gpu_bnb_library_path(gpu_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. + Maybe you need to compile it from source? If you compiled from source, check that ROCM_VERSION + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version + and rebuild bitsandbytes. + """, + ) + + hip_major, hip_minor = gpu_specs.backend_version_tuple + if (hip_major, hip_minor) < (6, 1): + print_dedented( + """ + WARNING: bitsandbytes is fully supported only from ROCm 6.1. + """, + ) + + +def print_diagnostics(gpu_specs: GPUSpecs) -> None: + if HIP_ENVIRONMENT: + _print_hip_diagnostics(gpu_specs) + else: + _print_cuda_diagnostics(gpu_specs) + + +def _print_cuda_runtime_diagnostics() -> None: + gpu_rt_paths = list(find_gpu_rt_libraries()) + if not gpu_rt_paths: + print("WARNING! CUDA runtime files not found in any environmental path.") + elif len(gpu_rt_paths) > 1: + print_dedented( + f""" + Found duplicate CUDA runtime files (see below). + + We select the PyTorch default CUDA runtime, which is {torch.version.cuda}, + but this might mismatch with the CUDA version that is needed for bitsandbytes. + To override this behavior set the `BNB_CUDA_VERSION=` environmental variable. + + For example, if you want to use the CUDA version 122, + BNB_CUDA_VERSION=122 python ... + + OR set the environmental variable in your .bashrc: + export BNB_CUDA_VERSION=122 + + In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g. + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2, + """, + ) + for pth in gpu_rt_paths: + print(f"* Found CUDA runtime at: {pth}") + + +def _print_hip_runtime_diagnostics() -> None: + gpu_rt_paths = list(find_gpu_rt_libraries()) + if not gpu_rt_paths: + print("WARNING! ROCm runtime files not found in any environmental path.") + elif len(gpu_rt_paths) > 1: + print_dedented( + f""" + Found duplicate ROCm runtime files (see below). + + We select the PyTorch default ROCm runtime, which is {torch.version.hip}, + but this might mismatch with the ROCm version that is needed for bitsandbytes. + + To resolve it, install PyTorch built for the ROCm version you want to use + + and set LD_LIBRARY_PATH to your ROCm install path, e.g. + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib, + """, + ) + + for pth in gpu_rt_paths: + print(f"* Found ROCm runtime at: {pth}") + + +def print_runtime_diagnostics() -> None: + if HIP_ENVIRONMENT: + _print_hip_runtime_diagnostics() + else: + _print_cuda_runtime_diagnostics() diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 8dc43ed2a..65e0fe924 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -5,7 +5,7 @@ from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT from bitsandbytes.consts import PACKAGE_GITHUB_URL -from bitsandbytes.cuda_specs import get_cuda_specs +from bitsandbytes.gpu_specs import get_gpu_specs from bitsandbytes.diagnostics.cuda import ( print_diagnostics, print_runtime_diagnostics, @@ -50,20 +50,20 @@ def main(): print_header("") print_header("OTHER") - cuda_specs = get_cuda_specs() + gpu_specs = get_gpu_specs() if HIP_ENVIRONMENT: - rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," - rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" + rocm_specs = f" rocm_version_string='{gpu_specs.cuda_version_string}'," + rocm_specs += f" rocm_version_tuple={gpu_specs.cuda_version_tuple}" print(f"{BNB_BACKEND} specs:{rocm_specs}") else: - print(f"{BNB_BACKEND} specs:{cuda_specs}") + print(f"{BNB_BACKEND} specs:{gpu_specs}") if not torch.cuda.is_available(): print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") print(f"1. {BNB_BACKEND} driver not installed") print(f"2. {BNB_BACKEND} not installed") print(f"3. You have multiple conflicting {BNB_BACKEND} libraries") - if cuda_specs: - print_diagnostics(cuda_specs) + if gpu_specs: + print_diagnostics(gpu_specs) print_runtime_diagnostics() print_header("") print_header("DEBUG INFO END") diff --git a/bitsandbytes/gpu_specs.py b/bitsandbytes/gpu_specs.py index b01a38390..822ad3fb2 100644 --- a/bitsandbytes/gpu_specs.py +++ b/bitsandbytes/gpu_specs.py @@ -15,7 +15,7 @@ class GPUSpecs: backend_version_tuple: Tuple[int, int] @property - def has_blaslt(self) -> bool: + def enable_blaslt(self) -> bool: if torch.version.hip: return self.highest_compute_capability >= 601 else: diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 42749ef00..a8597acae 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,12 +1,12 @@ import pytest from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path -from bitsandbytes.cuda_specs import CUDASpecs +from bitsandbytes.gpu_specs import GPUSpecs @pytest.fixture -def cuda120_spec() -> CUDASpecs: - return CUDASpecs( +def cuda120_spec() -> GPUSpecs: + return GPUSpecs( cuda_version_string="120", highest_compute_capability=(8, 6), cuda_version_tuple=(12, 0), @@ -14,8 +14,8 @@ def cuda120_spec() -> CUDASpecs: @pytest.fixture -def cuda111_noblas_spec() -> CUDASpecs: - return CUDASpecs( +def cuda111_noblas_spec() -> GPUSpecs: + return GPUSpecs( cuda_version_string="111", highest_compute_capability=(7, 2), cuda_version_tuple=(11, 1),