Skip to content

Commit 79f894f

Browse files
committed
tests: add coverage tests for cuda core
1 parent a8805e5 commit 79f894f

5 files changed

Lines changed: 422 additions & 1 deletion

File tree

cuda_core/tests/test_event.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

44

@@ -193,6 +193,54 @@ def test_event_type_safety(init_cuda):
193193
assert (event is None) is False
194194

195195

196+
def test_event_isub_not_implemented(init_cuda):
197+
"""Event.__isub__ returns NotImplemented for non-Event types."""
198+
device = Device()
199+
stream = device.create_stream()
200+
event = stream.record()
201+
result = event.__isub__(42)
202+
assert result is NotImplemented
203+
204+
205+
def test_event_rsub_not_implemented(init_cuda):
206+
"""Event.__rsub__ returns NotImplemented for non-Event types."""
207+
device = Device()
208+
stream = device.create_stream()
209+
event = stream.record()
210+
result = event.__rsub__(42)
211+
assert result is NotImplemented
212+
213+
214+
def test_event_get_ipc_descriptor_non_ipc(init_cuda):
215+
"""get_ipc_descriptor raises RuntimeError on a non-IPC event."""
216+
device = Device()
217+
stream = device.create_stream()
218+
event = stream.record()
219+
with pytest.raises(RuntimeError, match="not IPC-enabled"):
220+
event.get_ipc_descriptor()
221+
222+
223+
def test_event_is_done_false(init_cuda):
224+
"""Event.is_done returns False when captured work has not yet completed."""
225+
device = Device()
226+
latch = LatchKernel(device)
227+
stream = device.create_stream()
228+
latch.launch(stream)
229+
event = stream.record()
230+
# The latch holds the kernel; the event cannot be done yet.
231+
assert event.is_done is False
232+
latch.release()
233+
event.sync()
234+
235+
236+
def test_ipc_event_descriptor_direct_init():
237+
"""IPCEventDescriptor cannot be instantiated directly."""
238+
import cuda.core._event as _event_module
239+
240+
with pytest.raises(RuntimeError, match="cannot be instantiated directly"):
241+
_event_module.IPCEventDescriptor()
242+
243+
196244
# ============================================================================
197245
# Event Hash Tests
198246
# ============================================================================

cuda_core/tests/test_launcher.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,52 @@ def test_kernel_arg_unsupported_type():
387387

388388
with pytest.raises(TypeError, match="unsupported type"):
389389
ParamHolder(["not_a_valid_kernel_arg"])
390+
391+
392+
def test_kernel_arg_ctypes_subclass_isinstance_fallback():
393+
"""Subclassed ctypes types hit the isinstance fallback in prepare_ctypes_arg."""
394+
from cuda.core._kernel_arg_handler import ParamHolder
395+
396+
class MyInt32(ctypes.c_int32):
397+
pass
398+
399+
class MyFloat(ctypes.c_float):
400+
pass
401+
402+
class MyBool(ctypes.c_bool):
403+
pass
404+
405+
# These should NOT raise — they should be handled via isinstance fallback
406+
holder = ParamHolder([MyInt32(42), MyFloat(3.14), MyBool(True)])
407+
assert holder.ptr != 0
408+
409+
410+
def test_kernel_arg_numpy_subclass_isinstance_fallback():
411+
"""Subclassed numpy scalars hit the isinstance fallback in prepare_numpy_arg."""
412+
from cuda.core._kernel_arg_handler import ParamHolder
413+
414+
class MyInt32(np.int32):
415+
pass
416+
417+
class MyFloat32(np.float32):
418+
pass
419+
420+
holder = ParamHolder([MyInt32(7), MyFloat32(2.5)])
421+
assert holder.ptr != 0
422+
423+
424+
def test_kernel_arg_python_isinstance_fallbacks():
425+
"""Subclassed Python builtins hit the isinstance fallback in ParamHolder."""
426+
from cuda.core._kernel_arg_handler import ParamHolder
427+
428+
class MyBool(int):
429+
"""type(x) is not int, so fast path skips; isinstance(x, int) catches it."""
430+
431+
class MyFloat(float):
432+
pass
433+
434+
class MyComplex(complex):
435+
pass
436+
437+
holder = ParamHolder([MyBool(1), MyFloat(1.5), MyComplex(1 + 2j)])
438+
assert holder.ptr != 0

cuda_core/tests/test_linker.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,24 @@ def test_linker_logs_cached_after_link(compile_ptx_functions):
221221
# Calling again should return the same observable values.
222222
assert linker.get_error_log() == err_log
223223
assert linker.get_info_log() == info_log
224+
225+
226+
def test_linker_handle(compile_ptx_functions):
227+
"""Linker.handle returns a non-null handle object."""
228+
options = LinkerOptions(arch=ARCH)
229+
linker = Linker(*compile_ptx_functions, options=options)
230+
handle = linker.handle
231+
assert handle is not None
232+
assert int(handle) != 0
233+
234+
235+
@pytest.mark.skipif(is_culink_backend, reason="nvjitlink options only tested with nvjitlink backend")
236+
def test_linker_options_nvjitlink_options_as_str():
237+
"""_prepare_nvjitlink_options(as_bytes=False) returns plain strings."""
238+
opts = LinkerOptions(arch=ARCH, debug=True, lineinfo=True)
239+
options = opts._prepare_nvjitlink_options(as_bytes=False)
240+
assert isinstance(options, list)
241+
assert all(isinstance(o, str) for o in options)
242+
assert f"-arch={ARCH}" in options
243+
assert "-g" in options
244+
assert "-lineinfo" in options

cuda_core/tests/test_program.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,3 +773,107 @@ def test_program_options_as_bytes_nvvm_unsupported_option():
773773
options = ProgramOptions(arch="sm_80", lineinfo=True)
774774
with pytest.raises(CUDAError, match="not supported by NVVM backend"):
775775
options.as_bytes("nvvm")
776+
777+
778+
def test_program_options_repr():
779+
"""ProgramOptions.__repr__ returns a human-readable string."""
780+
opts = ProgramOptions(name="mykernel", arch="sm_80")
781+
r = repr(opts)
782+
assert "ProgramOptions" in r
783+
assert "mykernel" in r
784+
assert "sm_80" in r
785+
786+
787+
def test_program_options_bad_define_macro_short_tuple():
788+
"""define_macro with a 1-element tuple raises RuntimeError."""
789+
opts = ProgramOptions(name="test", arch="sm_80", define_macro=("ONLY_NAME",))
790+
with pytest.raises(RuntimeError, match="Expected define_macro tuple"):
791+
opts.as_bytes("nvrtc")
792+
793+
794+
def test_program_options_bad_define_macro_non_str_value():
795+
"""define_macro tuple with a non-string value raises RuntimeError."""
796+
opts = ProgramOptions(name="test", arch="sm_80", define_macro=("MY_MACRO", 99))
797+
with pytest.raises(RuntimeError, match="Expected define_macro tuple"):
798+
opts.as_bytes("nvrtc")
799+
800+
801+
def test_program_options_bad_define_macro_list_non_str():
802+
"""define_macro list containing a non-str/non-tuple item raises RuntimeError."""
803+
opts = ProgramOptions(name="test", arch="sm_80", define_macro=[42])
804+
with pytest.raises(RuntimeError, match="Expected define_macro"):
805+
opts.as_bytes("nvrtc")
806+
807+
808+
def test_program_options_bad_define_macro_list_bad_tuple():
809+
"""define_macro list with a malformed tuple inside raises RuntimeError."""
810+
opts = ProgramOptions(name="test", arch="sm_80", define_macro=[("ONLY_NAME",)])
811+
with pytest.raises(RuntimeError, match="Expected define_macro"):
812+
opts.as_bytes("nvrtc")
813+
814+
815+
def test_ptx_program_extra_sources_unsupported(ptx_code_object):
816+
"""PTX backend raises ValueError when extra_sources is specified."""
817+
options = ProgramOptions(extra_sources=[("module1", b"data")])
818+
with pytest.raises(ValueError, match="extra_sources is not supported by the PTX backend"):
819+
Program(ptx_code_object.code.decode(), "ptx", options)
820+
821+
822+
def test_ptx_program_handle_is_linker_handle(init_cuda, ptx_code_object):
823+
"""Program.handle for the PTX backend delegates to the linker handle."""
824+
program = Program(ptx_code_object.code.decode(), "ptx")
825+
handle = program.handle
826+
assert handle is not None
827+
assert int(handle) != 0
828+
program.close()
829+
830+
831+
@nvvm_available
832+
def test_nvvm_program_wrong_code_type():
833+
"""NVVM backend raises TypeError when code is not str/bytes/bytearray."""
834+
with pytest.raises(TypeError, match="NVVM IR code must be provided as str, bytes, or bytearray"):
835+
Program(42, "nvvm")
836+
837+
838+
def test_extra_sources_not_sequence():
839+
"""extra_sources must be a sequence; non-sequence raises TypeError."""
840+
with pytest.raises(TypeError, match="extra_sources must be a sequence of 2-tuples"):
841+
ProgramOptions(name="test", arch="sm_80", extra_sources=42)
842+
843+
844+
def test_extra_sources_bad_module_not_tuple():
845+
"""extra_sources items must be 2-tuples; non-tuple item raises TypeError."""
846+
with pytest.raises(TypeError, match="Each extra module must be a 2-tuple"):
847+
ProgramOptions(name="test", arch="sm_80", extra_sources=["not_a_tuple"])
848+
849+
850+
def test_extra_sources_bad_module_name_not_str():
851+
"""extra_sources module name must be a string; non-str raises TypeError."""
852+
with pytest.raises(TypeError, match="Module name at index 0 must be a string"):
853+
ProgramOptions(name="test", arch="sm_80", extra_sources=[(42, b"source")])
854+
855+
856+
def test_extra_sources_bad_module_source_wrong_type():
857+
"""extra_sources module source must be str/bytes/bytearray."""
858+
with pytest.raises(TypeError, match="Module source at index 0 must be str"):
859+
ProgramOptions(name="test", arch="sm_80", extra_sources=[("mod", 42)])
860+
861+
862+
def test_extra_sources_empty_source():
863+
"""extra_sources module source cannot be empty bytes."""
864+
with pytest.raises(ValueError, match="Module source for 'mod'.*cannot be empty"):
865+
ProgramOptions(name="test", arch="sm_80", extra_sources=[("mod", b"")])
866+
867+
868+
def test_nvrtc_compile_with_logs_capture(init_cuda):
869+
"""Program.compile with logs= exercises the NVRTC program-log reading path."""
870+
import io
871+
872+
# #warning generates a non-empty NVRTC program log, ensuring logsize > 1.
873+
code = '#warning "test log capture"\nextern "C" __global__ void my_kernel() {}'
874+
program = Program(code, "c++")
875+
logs = io.StringIO()
876+
result = program.compile("ptx", logs=logs)
877+
assert isinstance(result, ObjectCode)
878+
assert logs.getvalue(), "Expected non-empty compilation log from #warning directive"
879+
program.close()

0 commit comments

Comments
 (0)