Skip to content

Commit

Permalink
adding arch detect to rocm6.2_internal_testing
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Apr 26, 2024
1 parent a36bd1d commit 9bcbf40
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
9 changes: 9 additions & 0 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import ctypes as ct
import os
import torch
import subprocess
import re

from pathlib import Path
from warnings import warn
Expand Down Expand Up @@ -30,11 +32,18 @@
lib.get_context.restype = ct.c_void_p

# Backend detection.
#   HIP_ENVIRONMENT -- True when torch reports a HIP (ROCm) build.
#   ROCM_GPU_ARCH   -- "gfx…" name parsed from `rocminfo` output on HIP
#                      builds; stays "unknown" on CUDA builds or when the
#                      name cannot be determined.
HIP_ENVIRONMENT = False
ROCM_GPU_ARCH = "unknown"
if torch.version.cuda:
    lib.get_cusparse.restype = ct.c_void_p
elif torch.version.hip:
    HIP_ENVIRONMENT = True
    lib.get_hipsparse.restype = ct.c_void_p
    # Best-effort architecture probe: a missing or non-executable `rocminfo`
    # must not make importing the package fail, so fall back to "unknown".
    try:
        result = subprocess.run(['rocminfo'], capture_output=True, text=True)
    except OSError as err:
        warn(f"Could not run rocminfo to detect the ROCm GPU architecture: {err}")
    else:
        # Architecture names may carry a trailing letter (e.g. "gfx90a"), so
        # the suffix must accept letters as well as digits; a digits-only
        # pattern would truncate "gfx90a" to "gfx90".
        match = re.search(r'Name:\s+gfx([0-9a-zA-Z]+)', result.stdout)
        if match:
            ROCM_GPU_ARCH = "gfx" + match.group(1)

lib.cget_managed_ptr.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
Expand Down
3 changes: 2 additions & 1 deletion tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.cextension import HIP_ENVIRONMENT
from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
from scipy.stats import norm

torch.set_printoptions(
Expand Down Expand Up @@ -2543,6 +2543,7 @@ def test_managed():
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
@pytest.mark.skipif(HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", reason="this test is not supported on ROCm with gfx90a architecture yet")
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242))
Expand Down

0 comments on commit 9bcbf40

Please sign in to comment.