
Commit

fix diagnostics
Lzy17 committed Nov 8, 2024
1 parent 69803b6 commit 716e010
Showing 6 changed files with 270 additions and 29 deletions.
2 changes: 1 addition & 1 deletion bitsandbytes/cextension.py
@@ -37,7 +37,7 @@ def get_gpu_bnb_library_path(gpu_specs: GPUSpecs) -> Path:
     The library is not guaranteed to exist at the returned path.
     """
     library_name = f"libbitsandbytes_{gpu_specs.gpu_backend}{gpu_specs.backend_version_string}"
-    if not gpu_specs.has_blaslt:
+    if not gpu_specs.enable_blaslt:
         # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
         if gpu_specs.gpu_backend == "rocm":
             library_name += "_nohipblaslt"
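For illustration only, not part of this commit: a minimal sketch of how the renamed enable_blaslt property steers binary selection, mirroring the logic in get_gpu_bnb_library_path above. That get_gpu_specs returns None on GPU-less machines is inferred from the `if gpu_specs:` guard in diagnostics/main.py below.

from bitsandbytes.gpu_specs import get_gpu_specs

specs = get_gpu_specs()  # assumed to be None when torch sees no GPU
if specs is not None:
    library_name = f"libbitsandbytes_{specs.gpu_backend}{specs.backend_version_string}"
    if not specs.enable_blaslt:
        # Same fallback as above: devices without blaslt support get a *_no*blaslt binary.
        library_name += "_nohipblaslt" if specs.gpu_backend == "rocm" else "_nocublaslt"
    print(library_name)  # e.g. libbitsandbytes_cuda118 or libbitsandbytes_cuda118_nocublaslt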
30 changes: 15 additions & 15 deletions bitsandbytes/diagnostics/cuda.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
+from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path
 from bitsandbytes.consts import NONPYTORCH_DOC_URL
-from bitsandbytes.cuda_specs import CUDASpecs
+from bitsandbytes.gpu_specs import GPUSpecs
 from bitsandbytes.diagnostics.utils import print_dedented
 
 CUDART_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH")
@@ -109,13 +109,13 @@ def find_cudart_libraries() -> Iterator[Path]:
         yield from find_cuda_libraries_in_path_list(value)
 
 
-def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
+def _print_cuda_diagnostics(gpu_specs: GPUSpecs) -> None:
     print(
-        f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
-        f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
+        f"PyTorch settings found: CUDA_VERSION={gpu_specs.cuda_version_string}, "
+        f"Highest Compute Capability: {gpu_specs.highest_compute_capability}.",
     )
 
-    binary_path = get_cuda_bnb_library_path(cuda_specs)
+    binary_path = get_gpu_bnb_library_path(gpu_specs)
     if not binary_path.exists():
         print_dedented(
             f"""
@@ -128,7 +128,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
             """,
         )
 
-    cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
+    cuda_major, cuda_minor = gpu_specs.cuda_version_tuple
     if cuda_major < 11:
         print_dedented(
             """
@@ -140,7 +140,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
     print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
 
     # 7.5 is the minimum CC for cublaslt
-    if not cuda_specs.has_cublaslt:
+    if not gpu_specs.has_cublaslt:
         print_dedented(
             """
             WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
@@ -154,10 +154,10 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
     # (2) Multiple CUDA versions installed
 
 
-def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
-    print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")
+def _print_hip_diagnostics(gpu_specs: GPUSpecs) -> None:
+    print(f"PyTorch settings found: ROCM_VERSION={gpu_specs.cuda_version_string}")
 
-    binary_path = get_cuda_bnb_library_path(cuda_specs)
+    binary_path = get_gpu_bnb_library_path(gpu_specs)
     if not binary_path.exists():
         print_dedented(
             f"""
@@ -168,7 +168,7 @@ def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
             """,
         )
 
-    hip_major, hip_minor = cuda_specs.cuda_version_tuple
+    hip_major, hip_minor = gpu_specs.cuda_version_tuple
     if (hip_major, hip_minor) < (6, 1):
         print_dedented(
             """
@@ -177,11 +177,11 @@ def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
         )
 
 
-def print_diagnostics(cuda_specs: CUDASpecs) -> None:
+def print_diagnostics(gpu_specs: GPUSpecs) -> None:
     if HIP_ENVIRONMENT:
-        _print_hip_diagnostics(cuda_specs)
+        _print_hip_diagnostics(gpu_specs)
     else:
-        _print_cuda_diagnostics(cuda_specs)
+        _print_cuda_diagnostics(gpu_specs)
 
 
 def _print_cuda_runtime_diagnostics() -> None:
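Not shown in this diff: print_dedented comes from bitsandbytes/diagnostics/utils.py. A plausible stand-in, assuming it merely strips the common indentation of the triple-quoted messages before printing; the real helper may differ in detail:

import textwrap

def print_dedented(message: str) -> None:
    # Assumed behavior, inferred from how the diagnostics above use it.
    print(textwrap.dedent(message).strip())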
241 changes: 241 additions & 0 deletions bitsandbytes/diagnostics/gpu.py
@@ -0,0 +1,241 @@
+import logging
+import os
+from pathlib import Path
+from typing import Dict, Iterable, Iterator
+
+import torch
+
+from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path
+from bitsandbytes.consts import NONPYTORCH_DOC_URL
+from bitsandbytes.gpu_specs import GPUSpecs
+from bitsandbytes.diagnostics.utils import print_dedented
+
+GPU_RT_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH")
+
+GPU_RT_PATH_IGNORED_ENVVARS = {
+    "DBUS_SESSION_BUS_ADDRESS",  # hardware related
+    "GOOGLE_VM_CONFIG_LOCK_FILE",  # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks
+    "HOME",  # Linux shell default
+    "LESSCLOSE",
+    "LESSOPEN",  # related to the `less` command
+    "MAIL",  # something related to emails
+    "OLDPWD",
+    "PATH",  # this is for finding binaries, not libraries
+    "PWD",  # PWD: this is how the shell keeps track of the current working dir
+    "SHELL",  # binary for currently invoked shell
+    "SSH_AUTH_SOCK",  # SSH stuff, therefore unrelated
+    "SSH_TTY",
+    "TMUX",  # Terminal Multiplexer
+    "XDG_DATA_DIRS",  # XDG: Desktop environment stuff
+    "XDG_GREETER_DATA_DIR",  # XDG: Desktop environment stuff
+    "XDG_RUNTIME_DIR",
+    "_",  # current Python interpreter
+}
+
+logger = logging.getLogger(__name__)
+
+
+def get_runtime_lib_patterns() -> tuple:
+    if HIP_ENVIRONMENT:
+        return ("libamdhip64.so*",)
+    else:
+        return (
+            "cudart64*.dll",  # Windows
+            "libcudart*.so*",  # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc.
+            "nvcuda*.dll",  # Windows
+        )
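# Sketch, not part of the commit: these patterns are consumed with Path.glob in
# find_gpu_libraries_in_path_list below, e.g.
#     list(Path("/usr/local/cuda/lib64").glob("libcudart*.so*"))
# matches libcudart.so, libcudart.so.12, libcudart.so.12.1, and so on.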
+
+
+def find_gpu_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]:
+    for dir_string in paths_list_candidate.split(os.pathsep):
+        if not dir_string:
+            continue
+        if os.sep not in dir_string:
+            continue
+        try:
+            dir = Path(dir_string)
+            try:
+                if not dir.exists():
+                    logger.warning(f"The directory listed in your path is found to be non-existent: {dir}")
+                    continue
+            except OSError:  # Assume an esoteric error trying to poke at the directory
+                pass
+            for lib_pattern in get_runtime_lib_patterns():
+                for pth in dir.glob(lib_pattern):
+                    if pth.is_file() and not pth.is_symlink():
+                        yield pth
+        except (OSError, PermissionError):
+            pass
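# Usage sketch, not part of the commit: scan a colon-separated path list for
# runtime libraries with the generator above.
import os
from bitsandbytes.diagnostics.gpu import find_gpu_libraries_in_path_list

for lib in find_gpu_libraries_in_path_list(os.environ.get("LD_LIBRARY_PATH", "")):
    print(lib)  # e.g. /usr/local/cuda-12.1/lib64/libcudart.so.12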
+
+
+def is_relevant_candidate_env_var(env_var: str, value: str) -> bool:
+    return (
+        env_var in GPU_RT_PATH_PREFERRED_ENVVARS  # is a preferred location
+        or (
+            os.sep in value  # might contain a path
+            and env_var not in GPU_RT_PATH_IGNORED_ENVVARS  # not ignored
+            and "CONDA" not in env_var  # not another conda envvar
+            and "BASH_FUNC" not in env_var  # not a bash function defined via envvar
+            and "\n" not in value  # likely e.g. a script or something?
+        )
+    )
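# Worked example, not part of the commit: how the filter above classifies a few
# typical environment variables (assuming Linux, where os.sep == "/";
# MY_CUDA_DIR is a hypothetical variable).
from bitsandbytes.diagnostics.gpu import is_relevant_candidate_env_var

assert is_relevant_candidate_env_var("LD_LIBRARY_PATH", "/usr/local/cuda/lib64")  # preferred envvar
assert not is_relevant_candidate_env_var("PATH", "/usr/bin:/usr/local/bin")  # explicitly ignored
assert is_relevant_candidate_env_var("MY_CUDA_DIR", "/opt/cuda-12.1/lib64")  # value contains a path
assert not is_relevant_candidate_env_var("EDITOR", "vim")  # no path separator in value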
+
+
+def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]:
+    return {env_var: value for env_var, value in os.environ.items() if is_relevant_candidate_env_var(env_var, value)}
+
+
+def find_gpu_rt_libraries() -> Iterator[Path]:
+    """
+    Searches for CUDA/ROCm installations, in the following order of priority:
+        1. active conda env
+        2. LD_LIBRARY_PATH
+        3. any other env vars, while ignoring those that
+           - are known to be unrelated
+           - don't contain the path separator `/`
+    If multiple libraries are found in part 3, we optimistically try one,
+    while giving a warning message.
+    """
+    candidate_env_vars = get_potentially_lib_path_containing_env_vars()
+
+    for envvar in GPU_RT_PATH_PREFERRED_ENVVARS:
+        if envvar in candidate_env_vars:
+            directory = candidate_env_vars[envvar]
+            yield from find_gpu_libraries_in_path_list(directory)
+            candidate_env_vars.pop(envvar)
+
+    for env_var, value in candidate_env_vars.items():
+        yield from find_gpu_libraries_in_path_list(value)
+
+
+def _print_cuda_diagnostics(gpu_specs: GPUSpecs) -> None:
+    print(
+        f"PyTorch settings found: CUDA_VERSION={gpu_specs.backend_version_string}, "
+        f"Highest Compute Capability: {gpu_specs.highest_compute_capability}.",
+    )
+
+    binary_path = get_gpu_bnb_library_path(gpu_specs)
+    if not binary_path.exists():
+        print_dedented(
+            f"""
+            Library not found: {binary_path}. Maybe you need to compile it from source?
+            If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`,
+            for example, `make CUDA_VERSION=113`.
+            The CUDA version for the compile might depend on your conda install, if using conda.
+            Inspect CUDA version via `conda list | grep cuda`.
+            """,
+        )
+
+    cuda_major, cuda_minor = gpu_specs.backend_version_tuple
+    if cuda_major < 11:
+        print_dedented(
+            """
+            WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
+            You will only be able to use 8-bit optimizers and quantization routines!
+            """,
+        )
+
+    print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
+
+    # 7.5 is the minimum CC for cublaslt
+    if not gpu_specs.enable_blaslt:
+        print_dedented(
+            """
+            WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
+            If you run into issues with 8-bit matmul, you can try 4-bit quantization:
+            https://huggingface.co/blog/4bit-transformers-bitsandbytes
+            """,
+        )
+
+    # TODO:
+    # (1) CUDA missing cases (no CUDA installed, but CUDA driver (nvidia-smi) accessible)
+    # (2) Multiple CUDA versions installed
+
+
+def _print_hip_diagnostics(gpu_specs: GPUSpecs) -> None:
+    print(f"PyTorch settings found: ROCM_VERSION={gpu_specs.backend_version_string}")
+
+    binary_path = get_gpu_bnb_library_path(gpu_specs)
+    if not binary_path.exists():
+        print_dedented(
+            f"""
+            Library not found: {binary_path}.
+            Maybe you need to compile it from source? If you compiled from source, check that ROCM_VERSION
+            in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version
+            and rebuild bitsandbytes.
+            """,
+        )
+
+    hip_major, hip_minor = gpu_specs.backend_version_tuple
+    if (hip_major, hip_minor) < (6, 1):
+        print_dedented(
+            """
+            WARNING: bitsandbytes is fully supported only from ROCm 6.1.
+            """,
+        )
+
+
+def print_diagnostics(gpu_specs: GPUSpecs) -> None:
+    if HIP_ENVIRONMENT:
+        _print_hip_diagnostics(gpu_specs)
+    else:
+        _print_cuda_diagnostics(gpu_specs)
+
+
+def _print_cuda_runtime_diagnostics() -> None:
+    gpu_rt_paths = list(find_gpu_rt_libraries())
+    if not gpu_rt_paths:
+        print("WARNING! CUDA runtime files not found in any environmental path.")
+    elif len(gpu_rt_paths) > 1:
+        print_dedented(
+            f"""
+            Found duplicate CUDA runtime files (see below).
+            We select the PyTorch default CUDA runtime, which is {torch.version.cuda},
+            but this might mismatch with the CUDA version that is needed for bitsandbytes.
+            To override this behavior set the `BNB_CUDA_VERSION=<version string, e.g. 122>` environmental variable.
+            For example, if you want to use the CUDA version 122,
+                BNB_CUDA_VERSION=122 python ...
+            OR set the environmental variable in your .bashrc:
+                export BNB_CUDA_VERSION=122
+            In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
+                export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2
+            """,
+        )
+        for pth in gpu_rt_paths:
+            print(f"* Found CUDA runtime at: {pth}")
+
+
+def _print_hip_runtime_diagnostics() -> None:
+    gpu_rt_paths = list(find_gpu_rt_libraries())
+    if not gpu_rt_paths:
+        print("WARNING! ROCm runtime files not found in any environmental path.")
+    elif len(gpu_rt_paths) > 1:
+        print_dedented(
+            f"""
+            Found duplicate ROCm runtime files (see below).
+            We select the PyTorch default ROCm runtime, which is {torch.version.hip},
+            but this might mismatch with the ROCm version that is needed for bitsandbytes.
+            To resolve it, install PyTorch built for the ROCm version you want to use
+            and set LD_LIBRARY_PATH to your ROCm install path, e.g.
+                export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib
+            """,
+        )
+
+        for pth in gpu_rt_paths:
+            print(f"* Found ROCm runtime at: {pth}")
+
+
+def print_runtime_diagnostics() -> None:
+    if HIP_ENVIRONMENT:
+        _print_hip_runtime_diagnostics()
+    else:
+        _print_cuda_runtime_diagnostics()
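Taken as a whole, the new module can be exercised on its own; a minimal sketch, not part of this commit, assuming a bitsandbytes checkout where bitsandbytes.diagnostics.gpu is importable:

from bitsandbytes.diagnostics.gpu import find_gpu_rt_libraries, print_runtime_diagnostics

# Enumerate every CUDA/ROCm runtime visible through the environment: the conda env
# and LD_LIBRARY_PATH first, then any other plausible path-bearing variables.
for path in find_gpu_rt_libraries():
    print(path)

# Or run the missing/duplicate-runtime checks directly.
print_runtime_diagnostics()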
14 changes: 7 additions & 7 deletions bitsandbytes/diagnostics/main.py
@@ -5,7 +5,7 @@
 
 from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT
 from bitsandbytes.consts import PACKAGE_GITHUB_URL
-from bitsandbytes.cuda_specs import get_cuda_specs
+from bitsandbytes.gpu_specs import get_gpu_specs
 from bitsandbytes.diagnostics.cuda import (
     print_diagnostics,
     print_runtime_diagnostics,
@@ -50,20 +50,20 @@ def main():
     print_header("")
 
     print_header("OTHER")
-    cuda_specs = get_cuda_specs()
+    gpu_specs = get_gpu_specs()
     if HIP_ENVIRONMENT:
-        rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}',"
-        rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}"
+        rocm_specs = f" rocm_version_string='{gpu_specs.cuda_version_string}',"
+        rocm_specs += f" rocm_version_tuple={gpu_specs.cuda_version_tuple}"
         print(f"{BNB_BACKEND} specs:{rocm_specs}")
     else:
-        print(f"{BNB_BACKEND} specs:{cuda_specs}")
+        print(f"{BNB_BACKEND} specs:{gpu_specs}")
     if not torch.cuda.is_available():
         print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:")
         print(f"1. {BNB_BACKEND} driver not installed")
         print(f"2. {BNB_BACKEND} not installed")
         print(f"3. You have multiple conflicting {BNB_BACKEND} libraries")
-    if cuda_specs:
-        print_diagnostics(cuda_specs)
+    if gpu_specs:
+        print_diagnostics(gpu_specs)
     print_runtime_diagnostics()
     print_header("")
     print_header("DEBUG INFO END")
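End-user invocation is unchanged by the rename; assuming the usual module entry point wiring for this diagnostics main(), the report above is produced with:

python -m bitsandbytes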
2 changes: 1 addition & 1 deletion bitsandbytes/gpu_specs.py
@@ -15,7 +15,7 @@ class GPUSpecs:
     backend_version_tuple: Tuple[int, int]
 
     @property
-    def has_blaslt(self) -> bool:
+    def enable_blaslt(self) -> bool:
         if torch.version.hip:
             return self.highest_compute_capability >= 601
         else:
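For illustration, not part of the commit: the renamed property is a read-only capability check, so call sites switch from `specs.has_blaslt` to `specs.enable_blaslt` with no other changes. The CUDA branch of the property is collapsed above; the 7.5 compute-capability cutoff mentioned in the diagnostics messages is assumed here.

from bitsandbytes.gpu_specs import get_gpu_specs

specs = get_gpu_specs()  # assumed to be None when torch sees no GPU
if specs is not None and not specs.enable_blaslt:
    # Matches the fallback in cextension.py: a *_no*blaslt binary is loaded instead.
    print("blaslt-backed kernels unavailable; only the slow 8-bit matmul path is supported")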