Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,34 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
prefix = "rocm" if torch.version.hip else "cuda"
library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"

override_value = os.environ.get("BNB_CUDA_VERSION")
cuda_override_value = os.environ.get("BNB_CUDA_VERSION")
rocm_override_value = os.environ.get("BNB_ROCM_VERSION")

if rocm_override_value and torch.version.hip:
if rocm_override_value:
library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1)
if torch.version.cuda:
raise RuntimeError(
f"BNB_ROCM_VERSION={rocm_override_value} detected for CUDA!\n"
"Use BNB_CUDA_VERSION instead: export BNB_CUDA_VERSION=<version>\n"
"Clear the variable and retry: unset BNB_ROCM_VERSION\n"
)
logger.warning(
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n"
"This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n"
"If this was unintended set the BNB_ROCM_VERSION variable to an empty string: export BNB_ROCM_VERSION=\n"
"If this was unintended clear the variable and retry: unset BNB_ROCM_VERSION\n"
)
elif override_value:
library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
elif cuda_override_value:
library_name = re.sub(r"cuda\d+", f"cuda{cuda_override_value}", library_name, count=1)
if torch.version.hip:
raise RuntimeError(
f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n"
f"BNB_CUDA_VERSION={cuda_override_value} detected for ROCm!\n"
f"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=<version>\n"
f"Clear the variable and retry: export BNB_CUDA_VERSION=\n"
f"Clear the variable and retry: unset BNB_CUDA_VERSION\n"
)
logger.warning(
f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
f"WARNING: BNB_CUDA_VERSION={cuda_override_value} environment variable detected; loading {library_name}.\n"
"This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
"If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
"If this was unintended clear the variable and retry: unset BNB_CUDA_VERSION\n"
)

return PACKAGE_DIR / library_name
Expand Down
47 changes: 24 additions & 23 deletions tests/test_cuda_setup_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,73 @@
import pytest

from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
from bitsandbytes.cextension import BNB_BACKEND, get_cuda_bnb_library_path
from bitsandbytes.cuda_specs import CUDASpecs


@pytest.fixture
def cuda120_spec() -> CUDASpecs:
"""Simulates torch+cuda12.0 and a representative Ampere-class capability."""
return CUDASpecs(
cuda_version_string="120",
highest_compute_capability=(8, 6),
cuda_version_tuple=(12, 0),
)


@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
"""Without overrides, library path uses the detected CUDA 12.0 version."""
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"


@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
"""BNB_CUDA_VERSION=110 overrides path selection to the CUDA 11.0 binary."""
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?


# Simulates torch+rocm7.0 (PyTorch bundled ROCm) on a system with ROCm 7.2
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
def test_get_cuda_bnb_library_path_rejects_rocm_override(monkeypatch, cuda120_spec):
"""BNB_ROCM_VERSION should be rejected on CUDA with a helpful error."""
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
with pytest.raises(RuntimeError, match=r"BNB_ROCM_VERSION.*detected for CUDA"):
get_cuda_bnb_library_path(cuda120_spec)


@pytest.fixture
def rocm70_spec() -> CUDASpecs:
"""Simulates torch+rocm7.0 (bundled ROCm) when the system ROCm is newer."""
return CUDASpecs(
cuda_version_string="70", # from torch.version.hip == "7.0.x"
highest_compute_capability=(0, 0), # unused for ROCm library path resolution
cuda_version_string="70",
highest_compute_capability=(0, 0),
cuda_version_tuple=(7, 0),
)


@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec):
"""Without override, library path uses PyTorch's ROCm 7.0 version."""
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm70"


@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog):
"""BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0."""
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
assert "BNB_ROCM_VERSION" in caplog.text


@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec):
"""BNB_CUDA_VERSION should be rejected on ROCm with a helpful error."""
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
monkeypatch.setenv("BNB_CUDA_VERSION", "72")
monkeypatch.setenv("BNB_CUDA_VERSION", "120")
with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*detected for ROCm"):
get_cuda_bnb_library_path(rocm70_spec)


@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
def test_get_rocm_bnb_library_path_rocm_override_takes_priority(monkeypatch, rocm70_spec, caplog):
"""When both are set, BNB_ROCM_VERSION wins if HIP_ENVIRONMENT is True."""
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
monkeypatch.setenv("BNB_CUDA_VERSION", "72")
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
assert "BNB_ROCM_VERSION" in caplog.text
assert "BNB_CUDA_VERSION" not in caplog.text
get_cuda_bnb_library_path(rocm70_spec)