Skip to content

Commit 93f3880

Browse files
leofangclaudepre-commit-ci[bot]
authored
Add tests for ObjectCode.from_fatbin() using nvfatbin bindings (#1875)
* Add tests for ObjectCode.from_fatbin() using nvfatbin bindings Add availability detection for nvfatbin bindings and tests for loading fatbin code both from memory (bytes) and from file (str path). The fatbin fixture creates a multi-arch fatbin containing a cubin for the current device arch and PTX for a second arch, exercising the nvfatbin API (create, add_cubin, add_ptx, size, get, destroy). Partially addresses #663. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Fix fatbin fixture: compile PTX targeting second_arch The PTX was being compiled for the current device's arch but labeled as a different arch in nvfatbin, which could produce an invalid fatbin. Now compile PTX with ProgramOptions(arch=f"sm_{second_arch}") so the PTX actually targets the intended architecture. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Make cubin compilation arch explicit in fatbin fixture Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * [pre-commit.ci] auto code formatting --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent da79d63 commit 93f3880

1 file changed

Lines changed: 74 additions & 0 deletions

File tree

cuda_core/tests/test_module.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,20 @@
3333
"""
3434

3535

36+
def _is_nvfatbin_available():
37+
"""Check if nvfatbin bindings are available."""
38+
try:
39+
from cuda.bindings import nvfatbin
40+
41+
nvfatbin.version()
42+
return True
43+
except Exception:
44+
return False
45+
46+
47+
nvfatbin_available = pytest.mark.skipif(not _is_nvfatbin_available(), reason="nvfatbin bindings not available")
48+
49+
3650
@pytest.fixture(scope="module")
3751
def cuda12_4_prerequisite_check():
3852
return binding_version() >= (12, 0, 0) and driver_version() >= (12, 4, 0)
@@ -90,6 +104,44 @@ def get_saxpy_kernel_ltoir(init_cuda):
90104
return mod
91105

92106

107+
@pytest.fixture
108+
def get_saxpy_fatbin(init_cuda):
109+
from cuda.bindings import nvfatbin
110+
111+
dev = Device()
112+
arch = dev.arch
113+
114+
# Pick a second arch different from the current device
115+
second_arch = "75" if arch != "75" else "80"
116+
117+
# Compile to cubin for current device arch
118+
prog = Program(SAXPY_KERNEL, code_type="c++", options=ProgramOptions(arch=f"sm_{arch}"))
119+
mod = prog.compile(
120+
"cubin",
121+
name_expressions=("saxpy<float>", "saxpy<double>"),
122+
)
123+
cubin = mod.code
124+
sym_map = mod.symbol_mapping
125+
126+
# Compile to PTX targeting the second arch
127+
ptx_mod = Program(SAXPY_KERNEL, code_type="c++", options=ProgramOptions(arch=f"sm_{second_arch}")).compile(
128+
"ptx",
129+
name_expressions=("saxpy<float>", "saxpy<double>"),
130+
)
131+
ptx = ptx_mod.code
132+
133+
# Create fatbin with both cubin + PTX
134+
handle = nvfatbin.create([], 0)
135+
nvfatbin.add_cubin(handle, cubin, len(cubin), arch, "saxpy")
136+
nvfatbin.add_ptx(handle, ptx, len(ptx), second_arch, "saxpy", f"-arch=sm_{second_arch}")
137+
fatbin_size = nvfatbin.size(handle)
138+
fatbin = bytearray(fatbin_size)
139+
nvfatbin.get(handle, fatbin)
140+
nvfatbin.destroy(handle)
141+
142+
return bytes(fatbin), sym_map
143+
144+
93145
def test_get_kernel(init_cuda):
94146
kernel = """extern "C" __global__ void ABC() { }"""
95147

@@ -220,6 +272,28 @@ def test_object_code_load_ltoir_from_file(get_saxpy_kernel_ltoir, tmp_path):
220272
# ltoir doesn't support kernel retrieval directly as it's used for linking
221273

222274

275+
@nvfatbin_available
276+
def test_object_code_load_fatbin(get_saxpy_fatbin):
277+
fatbin, sym_map = get_saxpy_fatbin
278+
assert isinstance(fatbin, bytes)
279+
mod_obj = ObjectCode.from_fatbin(fatbin, symbol_mapping=sym_map)
280+
assert mod_obj.code == fatbin
281+
assert mod_obj.code_type == "fatbin"
282+
mod_obj.get_kernel("saxpy<double>") # force loading
283+
284+
285+
@nvfatbin_available
286+
def test_object_code_load_fatbin_from_file(get_saxpy_fatbin, tmp_path):
287+
fatbin, sym_map = get_saxpy_fatbin
288+
assert isinstance(fatbin, bytes)
289+
fatbin_file = tmp_path / "test.fatbin"
290+
fatbin_file.write_bytes(fatbin)
291+
mod_obj = ObjectCode.from_fatbin(str(fatbin_file), symbol_mapping=sym_map)
292+
assert mod_obj.code == str(fatbin_file)
293+
assert mod_obj.code_type == "fatbin"
294+
mod_obj.get_kernel("saxpy<double>") # force loading
295+
296+
223297
def test_saxpy_arguments(get_saxpy_kernel_cubin, cuda12_4_prerequisite_check):
224298
krn, _ = get_saxpy_kernel_cubin
225299

0 commit comments

Comments
 (0)