diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 373a91875..4c705867d 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -30,28 +30,34 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: prefix = "rocm" if torch.version.hip else "cuda" library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}" - override_value = os.environ.get("BNB_CUDA_VERSION") + cuda_override_value = os.environ.get("BNB_CUDA_VERSION") rocm_override_value = os.environ.get("BNB_ROCM_VERSION") - if rocm_override_value and torch.version.hip: + if rocm_override_value: library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1) + if torch.version.cuda: + raise RuntimeError( + f"BNB_ROCM_VERSION={rocm_override_value} detected for CUDA!\n" + "Use BNB_CUDA_VERSION instead: export BNB_CUDA_VERSION=\n" + "Clear the variable and retry: unset BNB_ROCM_VERSION\n" + ) logger.warning( f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n" "This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n" - "If this was unintended set the BNB_ROCM_VERSION variable to an empty string: export BNB_ROCM_VERSION=\n" + "If this was unintended clear the variable and retry: unset BNB_ROCM_VERSION\n" ) - elif override_value: - library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) + elif cuda_override_value: + library_name = re.sub(r"cuda\d+", f"cuda{cuda_override_value}", library_name, count=1) if torch.version.hip: raise RuntimeError( - f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n" + f"BNB_CUDA_VERSION={cuda_override_value} detected for ROCm!\n" f"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=\n" - f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" + f"Clear the variable and retry: unset BNB_CUDA_VERSION\n" ) logger.warning( - f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" + f"WARNING: BNB_CUDA_VERSION={cuda_override_value} environment variable detected; loading {library_name}.\n" "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n" - "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" + "If this was unintended clear the variable and retry: unset BNB_CUDA_VERSION\n" ) return PACKAGE_DIR / library_name diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index f74f05634..492cc0e77 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,11 +1,12 @@ import pytest -from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path +from bitsandbytes.cextension import BNB_BACKEND, get_cuda_bnb_library_path from bitsandbytes.cuda_specs import CUDASpecs @pytest.fixture def cuda120_spec() -> CUDASpecs: + """Simulates torch+cuda12.0 and a representative Ampere-class capability.""" return CUDASpecs( cuda_version_string="120", highest_compute_capability=(8, 6), @@ -13,30 +14,42 @@ def cuda120_spec() -> CUDASpecs: ) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") +@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend") def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): + """Without overrides, library path uses the detected CUDA 12.0 version.""" + monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") +@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend") def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): + """BNB_CUDA_VERSION=110 overrides path selection to the CUDA 11.0 binary.""" monkeypatch.setenv("BNB_CUDA_VERSION", "110") assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? -# Simulates torch+rocm7.0 (PyTorch bundled ROCm) on a system with ROCm 7.2 +@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend") +def test_get_cuda_bnb_library_path_rejects_rocm_override(monkeypatch, cuda120_spec): + """BNB_ROCM_VERSION should be rejected on CUDA with a helpful error.""" + monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) + monkeypatch.setenv("BNB_ROCM_VERSION", "72") + with pytest.raises(RuntimeError, match=r"BNB_ROCM_VERSION.*detected for CUDA"): + get_cuda_bnb_library_path(cuda120_spec) + + @pytest.fixture def rocm70_spec() -> CUDASpecs: + """Simulates torch+rocm7.0 (bundled ROCm) when the system ROCm is newer.""" return CUDASpecs( - cuda_version_string="70", # from torch.version.hip == "7.0.x" - highest_compute_capability=(0, 0), # unused for ROCm library path resolution + cuda_version_string="70", + highest_compute_capability=(0, 0), cuda_version_tuple=(7, 0), ) -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") +@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend") def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec): """Without override, library path uses PyTorch's ROCm 7.0 version.""" monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) @@ -44,29 +57,17 @@ def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec): assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm70" -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") +@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend") def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog): """BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0.""" monkeypatch.setenv("BNB_ROCM_VERSION", "72") - monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) - assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72" assert "BNB_ROCM_VERSION" in caplog.text -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") +@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend") def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec): """BNB_CUDA_VERSION should be rejected on ROCm with a helpful error.""" monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) - monkeypatch.setenv("BNB_CUDA_VERSION", "72") + monkeypatch.setenv("BNB_CUDA_VERSION", "120") with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*detected for ROCm"): - get_cuda_bnb_library_path(rocm70_spec) - - -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") -def test_get_rocm_bnb_library_path_rocm_override_takes_priority(monkeypatch, rocm70_spec, caplog): - """When both are set, BNB_ROCM_VERSION wins if HIP_ENVIRONMENT is True.""" - monkeypatch.setenv("BNB_ROCM_VERSION", "72") - monkeypatch.setenv("BNB_CUDA_VERSION", "72") - assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72" - assert "BNB_ROCM_VERSION" in caplog.text - assert "BNB_CUDA_VERSION" not in caplog.text + get_cuda_bnb_library_path(rocm70_spec) \ No newline at end of file