Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cuda_core/cuda/core/_linker.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ from dataclasses import dataclass
from typing import Union
from warnings import warn

from cuda.pathfinder import optional_cuda_import
from cuda.pathfinder._optional_cuda_import import _optional_cuda_import
from cuda.core._device import Device
from cuda.core._module import ObjectCode
from cuda.core._utils.clear_error_support import assert_type
Expand Down Expand Up @@ -650,7 +650,7 @@ def _decide_nvjitlink_or_driver() -> bool:
" For best results, consider upgrading to a recent version of"
)

nvjitlink_module = optional_cuda_import(
nvjitlink_module = _optional_cuda_import(
"cuda.bindings.nvjitlink",
probe_function=lambda module: module.version(), # probe triggers nvJitLink runtime load
)
Expand Down
4 changes: 2 additions & 2 deletions cuda_core/cuda/core/_program.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import threading
from warnings import warn

from cuda.bindings import driver, nvrtc
from cuda.pathfinder import optional_cuda_import
from cuda.pathfinder._optional_cuda_import import _optional_cuda_import

from libcpp.vector cimport vector

Expand Down Expand Up @@ -485,7 +485,7 @@ def _get_nvvm_module():
"Please update cuda-bindings to use NVVM features."
)

nvvm = optional_cuda_import(
nvvm = _optional_cuda_import(
"cuda.bindings.nvvm",
probe_function=lambda module: module.version(), # probe triggers libnvvm load
)
Expand Down
20 changes: 10 additions & 10 deletions cuda_core/tests/test_optional_dependency_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ def _patch_driver_version(monkeypatch, version=13000):
def test_get_nvvm_module_reraises_nested_module_not_found(monkeypatch):
monkeypatch.setattr(_program, "get_binding_version", lambda: (12, 9))

def fake_optional_cuda_import(modname, probe_function=None):
def fake__optional_cuda_import(modname, probe_function=None):
assert modname == "cuda.bindings.nvvm"
assert probe_function is not None
err = ModuleNotFoundError("No module named 'not_a_real_dependency'")
err.name = "not_a_real_dependency"
raise err

monkeypatch.setattr(_program, "optional_cuda_import", fake_optional_cuda_import)
monkeypatch.setattr(_program, "_optional_cuda_import", fake__optional_cuda_import)

with pytest.raises(ModuleNotFoundError, match="not_a_real_dependency") as excinfo:
_program._get_nvvm_module()
Expand All @@ -64,12 +64,12 @@ def fake_optional_cuda_import(modname, probe_function=None):
def test_get_nvvm_module_reports_missing_nvvm_module(monkeypatch):
monkeypatch.setattr(_program, "get_binding_version", lambda: (12, 9))

def fake_optional_cuda_import(modname, probe_function=None):
def fake__optional_cuda_import(modname, probe_function=None):
assert modname == "cuda.bindings.nvvm"
assert probe_function is not None
return None

monkeypatch.setattr(_program, "optional_cuda_import", fake_optional_cuda_import)
monkeypatch.setattr(_program, "_optional_cuda_import", fake__optional_cuda_import)

with pytest.raises(RuntimeError, match="cuda.bindings.nvvm"):
_program._get_nvvm_module()
Expand All @@ -78,12 +78,12 @@ def fake_optional_cuda_import(modname, probe_function=None):
def test_get_nvvm_module_handles_missing_libnvvm(monkeypatch):
monkeypatch.setattr(_program, "get_binding_version", lambda: (12, 9))

def fake_optional_cuda_import(modname, probe_function=None):
def fake__optional_cuda_import(modname, probe_function=None):
assert modname == "cuda.bindings.nvvm"
assert probe_function is not None
return None

monkeypatch.setattr(_program, "optional_cuda_import", fake_optional_cuda_import)
monkeypatch.setattr(_program, "_optional_cuda_import", fake__optional_cuda_import)

with pytest.raises(RuntimeError, match="libnvvm"):
_program._get_nvvm_module()
Expand All @@ -92,14 +92,14 @@ def fake_optional_cuda_import(modname, probe_function=None):
def test_decide_nvjitlink_or_driver_reraises_nested_module_not_found(monkeypatch):
_patch_driver_version(monkeypatch)

def fake_optional_cuda_import(modname, probe_function=None):
def fake__optional_cuda_import(modname, probe_function=None):
assert modname == "cuda.bindings.nvjitlink"
assert probe_function is not None
err = ModuleNotFoundError("No module named 'not_a_real_dependency'")
err.name = "not_a_real_dependency"
raise err

monkeypatch.setattr(_linker, "optional_cuda_import", fake_optional_cuda_import)
monkeypatch.setattr(_linker, "_optional_cuda_import", fake__optional_cuda_import)

with pytest.raises(ModuleNotFoundError, match="not_a_real_dependency") as excinfo:
_linker._decide_nvjitlink_or_driver()
Expand All @@ -109,12 +109,12 @@ def fake_optional_cuda_import(modname, probe_function=None):
def test_decide_nvjitlink_or_driver_falls_back_when_module_missing(monkeypatch):
_patch_driver_version(monkeypatch)

def fake_optional_cuda_import(modname, probe_function=None):
def fake__optional_cuda_import(modname, probe_function=None):
assert modname == "cuda.bindings.nvjitlink"
assert probe_function is not None
return None

monkeypatch.setattr(_linker, "optional_cuda_import", fake_optional_cuda_import)
monkeypatch.setattr(_linker, "_optional_cuda_import", fake__optional_cuda_import)

with pytest.warns(RuntimeWarning, match="cuda.bindings.nvjitlink is not available"):
use_driver_backend = _linker._decide_nvjitlink_or_driver()
Expand Down
1 change: 0 additions & 1 deletion cuda_pathfinder/cuda/pathfinder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
locate_nvidia_header_directory as locate_nvidia_header_directory,
)
from cuda.pathfinder._headers.supported_nvidia_headers import SUPPORTED_HEADERS_CTK as _SUPPORTED_HEADERS_CTK
from cuda.pathfinder._optional_cuda_import import optional_cuda_import as optional_cuda_import
from cuda.pathfinder._static_libs.find_bitcode_lib import (
SUPPORTED_BITCODE_LIBS as _SUPPORTED_BITCODE_LIBS,
)
Expand Down
2 changes: 1 addition & 1 deletion cuda_pathfinder/cuda/pathfinder/_optional_cuda_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cuda.pathfinder._dynamic_libs.load_dl_common import DynamicLibNotFoundError


def optional_cuda_import(
def _optional_cuda_import(
fully_qualified_modname: str,
*,
probe_function: Callable[[ModuleType], object] | None = None,
Expand Down
1 change: 0 additions & 1 deletion cuda_pathfinder/docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ locating NVIDIA C/C++ header directories, and finding CUDA binary utilities.

SUPPORTED_NVIDIA_LIBNAMES
load_nvidia_dynamic_lib
optional_cuda_import
LoadedDL
DynamicLibNotFoundError
DynamicLibUnknownError
Expand Down
6 changes: 2 additions & 4 deletions cuda_pathfinder/docs/source/release/1.4.2-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,5 @@
Highlights
----------

* Add ``optional_cuda_import()`` to support robust optional imports of CUDA
Python modules. It returns ``None`` when the requested module is absent or a
probe hits ``DynamicLibNotFoundError``, while still re-raising unrelated
``ModuleNotFoundError`` exceptions (for missing transitive dependencies).
* Privatize ``optional_cuda_import()`` (renamed to ``_optional_cuda_import()``)
to remove it from the public API surface.
23 changes: 12 additions & 11 deletions cuda_pathfinder/tests/test_optional_cuda_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,33 @@
import pytest

import cuda.pathfinder._optional_cuda_import as optional_import_mod
from cuda.pathfinder import DynamicLibNotFoundError, optional_cuda_import
from cuda.pathfinder import DynamicLibNotFoundError
from cuda.pathfinder._optional_cuda_import import _optional_cuda_import


def test_optional_cuda_import_returns_module_when_available(monkeypatch):
def test__optional_cuda_import_returns_module_when_available(monkeypatch):
fake_module = types.SimpleNamespace(__name__="cuda.bindings.nvvm")
monkeypatch.setattr(optional_import_mod.importlib, "import_module", lambda _name: fake_module)

result = optional_cuda_import("cuda.bindings.nvvm")
result = _optional_cuda_import("cuda.bindings.nvvm")

assert result is fake_module


def test_optional_cuda_import_returns_none_when_module_missing(monkeypatch):
def test__optional_cuda_import_returns_none_when_module_missing(monkeypatch):
def fake_import_module(name):
err = ModuleNotFoundError("No module named 'cuda.bindings.nvvm'")
err.name = name
raise err

monkeypatch.setattr(optional_import_mod.importlib, "import_module", fake_import_module)

result = optional_cuda_import("cuda.bindings.nvvm")
result = _optional_cuda_import("cuda.bindings.nvvm")

assert result is None


def test_optional_cuda_import_reraises_nested_module_not_found(monkeypatch):
def test__optional_cuda_import_reraises_nested_module_not_found(monkeypatch):
def fake_import_module(_name):
err = ModuleNotFoundError("No module named 'not_a_real_dependency'")
err.name = "not_a_real_dependency"
Expand All @@ -40,28 +41,28 @@ def fake_import_module(_name):
monkeypatch.setattr(optional_import_mod.importlib, "import_module", fake_import_module)

with pytest.raises(ModuleNotFoundError, match="not_a_real_dependency") as excinfo:
optional_cuda_import("cuda.bindings.nvvm")
_optional_cuda_import("cuda.bindings.nvvm")
assert excinfo.value.name == "not_a_real_dependency"


def test_optional_cuda_import_returns_none_when_probe_finds_missing_dynamic_lib(monkeypatch):
def test__optional_cuda_import_returns_none_when_probe_finds_missing_dynamic_lib(monkeypatch):
fake_module = types.SimpleNamespace(__name__="cuda.bindings.nvvm")
monkeypatch.setattr(optional_import_mod.importlib, "import_module", lambda _name: fake_module)

def probe(_module):
raise DynamicLibNotFoundError("libnvvm missing")

result = optional_cuda_import("cuda.bindings.nvvm", probe_function=probe)
result = _optional_cuda_import("cuda.bindings.nvvm", probe_function=probe)

assert result is None


def test_optional_cuda_import_reraises_non_pathfinder_probe_error(monkeypatch):
def test__optional_cuda_import_reraises_non_pathfinder_probe_error(monkeypatch):
fake_module = types.SimpleNamespace(__name__="cuda.bindings.nvvm")
monkeypatch.setattr(optional_import_mod.importlib, "import_module", lambda _name: fake_module)

def probe(_module):
raise RuntimeError("unexpected probe failure")

with pytest.raises(RuntimeError, match="unexpected probe failure"):
optional_cuda_import("cuda.bindings.nvvm", probe_function=probe)
_optional_cuda_import("cuda.bindings.nvvm", probe_function=probe)
Loading