@@ -71,16 +71,17 @@ def compile_cuda(
7171 - NVRTC is a "runtime" compilation library and can be faster for JIT compilation.
7272 - NVRTC requires cuda-python: pip install cuda-python
7373 """
74- # TODO: if need NVSHMEM for compilation, fall back to NVCC because support for NVRTC
75- # is not yet implemented
7674 use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in code
77- if compiler == "nvcc" or use_nvshmem :
78- return _compile_cuda_nvcc (code , target_format , arch , options , path_target , use_nvshmem )
75+
76+ if compiler == "nvcc" :
77+ result = _compile_cuda_nvcc (code , target_format , arch , options , path_target , use_nvshmem )
7978 elif compiler == "nvrtc" :
80- return _compile_cuda_nvrtc (code , target_format , arch , options )
79+ result = _compile_cuda_nvrtc (code , target_format , arch , options , path_target , use_nvshmem )
8180 else :
8281 raise ValueError (f"cuda compiler must be 'nvcc' or 'nvrtc', got: { compiler } " )
8382
83+ return result
84+
8485
8586def _compile_cuda_nvcc (
8687 code ,
@@ -235,7 +236,9 @@ def _compile_cuda_nvcc(
235236 return data
236237
237238
238- def _compile_cuda_nvrtc (code , target_format = None , arch = None , options = None ):
239+ def _compile_cuda_nvrtc (
240+ code , target_format = None , arch = None , options = None , path_target = None , use_nvshmem = False
241+ ):
239242 """Compile CUDA code using NVRTC (NVIDIA Runtime Compilation).
240243
241244 Parameters
@@ -248,6 +251,10 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
248251 Target architecture (e.g., "sm_80"). Auto-detected if None.
249252 options : str or list of str, optional
250253 Additional NVRTC options.
254+ path_target : str, optional
255+ Output file path. If provided, the compiled binary is written to this path.
256+ use_nvshmem : bool, optional
257+ Whether NVSHMEM is used. Default: False
251258
252259 Returns
253260 -------
@@ -264,8 +271,20 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
264271 "See: https://nvidia.github.io/cuda-python/"
265272 ) from e
266273
267- # Default target format
268- if target_format is None :
274+ # For NVSHMEM, we also need the CUDA driver API to initialize the context for linking
275+ if use_nvshmem :
276+ import importlib .util # pylint: disable=import-outside-toplevel
277+
278+ if importlib .util .find_spec ("cuda.bindings.driver" ) is None :
279+ raise RuntimeError (
280+ "Failed to compile CUDA with NVRTC+NVSHMEM because the `cuda-python` package "
281+ "is not available.\n "
282+ "Please install it with: pip install cuda-python\n "
283+ "See: https://nvidia.github.io/cuda-python/"
284+ )
285+
286+ # NVSHMEM requires linking with device library, which always produces cubin
287+ if use_nvshmem or target_format is None :
269288 target_format = "cubin"
270289
271290 # Validate target_format (NVRTC doesn't support fatbin)
@@ -287,6 +306,11 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
287306 compute_version = get_target_compute_version (Target .current (allow_none = True ))
288307 arch = f"sm_{ '' .join (compute_version .split ('.' ))} "
289308
309+ # Get NVSHMEM paths if needed
310+ nvshmem_include_path , nvshmem_lib_path = None , None
311+ if use_nvshmem :
312+ nvshmem_include_path , nvshmem_lib_path = find_nvshmem_paths ()
313+
290314 # Strip host-only headers for NVRTC. NVRTC compiles device code and does not
291315 # require the CUDA driver header or host C++ headers.
292316 headers_to_strip = {"#include <cuda.h>" }
@@ -304,6 +328,47 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
304328 "};\n \n " + code_filtered
305329 )
306330
331+ # Add standard type definitions and compatibility macros that NVRTC doesn't provide.
332+ nvrtc_preamble = """#include <cuda/std/cstdint>
333+ using cuda::std::uint8_t;
334+ using cuda::std::uint16_t;
335+ using cuda::std::uint32_t;
336+ using cuda::std::uint64_t;
337+ using cuda::std::int8_t;
338+ using cuda::std::int16_t;
339+ using cuda::std::int32_t;
340+ using cuda::std::int64_t;
341+
342+ // NVRTC uses asm/volatile instead of __asm__/__volatile__
343+ #ifndef __asm__
344+ #define __asm__ asm
345+ #endif
346+ #ifndef __volatile__
347+ #define __volatile__ volatile
348+ #endif
349+
350+ """
351+ code_filtered = nvrtc_preamble + code_filtered
352+
353+ # For NVSHMEM, add preamble to map cuda::std type traits to std namespace.
354+ # NVSHMEM headers require std:: type traits but NVRTC uses cuda::std::.
355+ if use_nvshmem :
356+ nvshmem_preamble = """#include <cuda/std/type_traits>
357+
358+ // Map cuda::std type traits to std namespace for NVSHMEM headers
359+ namespace std {
360+ using cuda::std::is_integral;
361+ using cuda::std::is_signed;
362+ using cuda::std::is_unsigned;
363+ using cuda::std::is_floating_point;
364+ using cuda::std::is_same;
365+ using cuda::std::enable_if;
366+ using cuda::std::conditional;
367+ }
368+
369+ """
370+ code_filtered = nvshmem_preamble + code_filtered
371+
307372 # Create NVRTC program
308373 # Use "tvm_kernels.cu" for consistency with nvcc path
309374 result , prog = nvrtc .nvrtcCreateProgram (
@@ -319,6 +384,9 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
319384 b"-default-device" ,
320385 ]
321386
387+ if use_nvshmem :
388+ compile_opts .extend ([b"-rdc" , b"true" ])
389+
322390 # Add CUDA include paths. NVRTC needs explicit include paths for CUDA headers.
323391 # Standard installations: cuda_path/include
324392 # Conda/architecture-specific installations: cuda_path/targets/<arch>/include
@@ -339,6 +407,12 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
339407 if os .path .isdir (arch_include ):
340408 include_paths .append (arch_include )
341409
410+ # Check for CCCL include directory (required for cuda/std/cstdint and type_traits)
411+ # CCCL provides standard library functionality for device code
412+ cccl_include = os .path .join (arch_include , "cccl" ) if os .path .isdir (arch_include ) else None
413+ if cccl_include and os .path .isdir (cccl_include ):
414+ include_paths .append (cccl_include )
415+
342416 # Verify we can find essential CUDA headers
343417 if not any (os .path .isfile (os .path .join (p , "cuda_runtime.h" )) for p in include_paths ):
344418 raise RuntimeError (
@@ -351,6 +425,26 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
351425 for include_path in include_paths :
352426 compile_opts .append (f"-I{ include_path } " .encode ())
353427
428+ # Add NVSHMEM include path
429+ if use_nvshmem and nvshmem_include_path :
430+ compile_opts .append (f"-I{ nvshmem_include_path } " .encode ())
431+
432+ # For NVSHMEM, add deprecation and type conversion macros
433+ if use_nvshmem :
434+ compile_opts .extend (
435+ [
436+ # Define deprecation macros as empty (not properly defined in NVRTC context)
437+ b"-D__NV_SILENCE_DEPRECATION_BEGIN=" ,
438+ b"-D__NV_SILENCE_DEPRECATION_END=" ,
439+ b"-D__NV_SILENCE_HOST_DEPRECATION_BEGIN=" ,
440+ b"-D__NV_SILENCE_HOST_DEPRECATION_END=" ,
441+ # Disable FP8/FP6/FP4 extended types that cause issues with NVRTC
442+ b"-D__CUDA_NO_FP8_CONVERSIONS__" ,
443+ b"-D__CUDA_NO_FP6_CONVERSIONS__" ,
444+ b"-D__CUDA_NO_FP4_CONVERSIONS__" ,
445+ ]
446+ )
447+
354448 compile_opts .extend (
355449 [
356450 b"-U__CUDA_NO_HALF_OPERATORS__" ,
@@ -363,12 +457,40 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
363457 ]
364458 )
365459
366- # Add user-provided options
460+ # Add user-provided options, filtering out nvcc-specific flags that nvrtc doesn't support
367461 if options :
462+ nvcc_only_prefixes = (
463+ "-c" ,
464+ "-O" ,
465+ "-std" ,
466+ "--std" ,
467+ "-Xcompiler" ,
468+ "-Xlinker" ,
469+ "-Xarchive" ,
470+ "-Xcudafe" ,
471+ "-Xptxas" ,
472+ "--compile" ,
473+ "--compiler-options" ,
474+ "--linker-options" ,
475+ "-fPIC" ,
476+ "-shared" ,
477+ "-o" ,
478+ )
368479 if isinstance (options , str ):
369- compile_opts .append (options .encode ())
370- else :
371- compile_opts .extend ([opt .encode () if isinstance (opt , str ) else opt for opt in options ])
480+ options = [options ]
481+ for opt in options :
482+ if isinstance (opt , str ):
483+ opt_str = opt
484+ elif isinstance (opt , bytes ):
485+ opt_str = opt .decode ()
486+ else :
487+ opt_str = str (opt )
488+ skip = any (
489+ opt_str .startswith (prefix ) or opt_str == prefix for prefix in nvcc_only_prefixes
490+ )
491+ if skip :
492+ continue
493+ compile_opts .append (opt .encode () if isinstance (opt , str ) else opt )
372494
373495 # Compile
374496 (result ,) = nvrtc .nvrtcCompileProgram (prog , len (compile_opts ), compile_opts )
@@ -410,10 +532,94 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
410532 nvrtc .nvrtcDestroyProgram (prog )
411533 raise RuntimeError (f"Failed to get PTX: { nvrtc .nvrtcGetErrorString (result )} " )
412534
413- # Clean up
535+ # Clean up NVRTC program
414536 nvrtc .nvrtcDestroyProgram (prog )
415537
416- return bytearray (binary_buf )
538+ # Link stage for NVSHMEM
539+ if use_nvshmem :
540+ binary_buf = _link_nvshmem_nvrtc (binary_buf , nvshmem_lib_path )
541+
542+ if path_target :
543+ with open (path_target , "wb" ) as f :
544+ f .write (binary_buf )
545+ return binary_buf
546+
547+
548+ def _link_nvshmem_nvrtc (binary_buf , nvshmem_lib_path ):
549+ """Link compiled CUBIN with NVSHMEM device library using CUDA driver API."""
550+ import ctypes # pylint: disable=import-outside-toplevel
551+
552+ from cuda .bindings import driver as cu # pylint: disable=import-outside-toplevel
553+
554+ # cuLinkCreate requires a valid CUDA context.
555+ # Always create a fresh context for linking to avoid issues with stale contexts
556+ # in multi-process environments like Disco workers.
557+ (result ,) = cu .cuInit (0 )
558+ if result != cu .CUresult .CUDA_SUCCESS :
559+ raise RuntimeError (f"Failed to initialize CUDA: { result } " )
560+
561+ result , device = cu .cuDeviceGet (0 )
562+ if result != cu .CUresult .CUDA_SUCCESS :
563+ raise RuntimeError (f"Failed to get CUDA device: { result } " )
564+
565+ result , context = cu .cuCtxCreate (None , 0 , device )
566+ if result != cu .CUresult .CUDA_SUCCESS :
567+ raise RuntimeError (f"Failed to create CUDA context: { result } " )
568+
569+ try :
570+ # Create linker
571+ result , link_state = cu .cuLinkCreate (0 , [], [])
572+ if result != cu .CUresult .CUDA_SUCCESS :
573+ raise RuntimeError (f"Failed to create CUDA linker: { result } " )
574+
575+ try :
576+ # Add our compiled CUBIN
577+ (result ,) = cu .cuLinkAddData (
578+ link_state ,
579+ cu .CUjitInputType .CU_JIT_INPUT_CUBIN ,
580+ binary_buf ,
581+ len (binary_buf ),
582+ b"tvm_kernels.cubin" ,
583+ 0 ,
584+ [],
585+ [],
586+ )
587+ if result != cu .CUresult .CUDA_SUCCESS :
588+ raise RuntimeError (f"Failed to add CUBIN to linker: { result } " )
589+
590+ # Add NVSHMEM device library
591+ nvshmem_device_lib = os .path .join (nvshmem_lib_path , "libnvshmem_device.a" )
592+ if not os .path .exists (nvshmem_device_lib ):
593+ raise RuntimeError (f"NVSHMEM device library not found: { nvshmem_device_lib } " )
594+
595+ (result ,) = cu .cuLinkAddFile (
596+ link_state ,
597+ cu .CUjitInputType .CU_JIT_INPUT_LIBRARY ,
598+ nvshmem_device_lib .encode (),
599+ 0 ,
600+ [],
601+ [],
602+ )
603+ if result != cu .CUresult .CUDA_SUCCESS :
604+ raise RuntimeError (f"Failed to add NVSHMEM device library: { result } " )
605+
606+ # Complete linking
607+ result , linked_cubin , linked_size = cu .cuLinkComplete (link_state )
608+ if result != cu .CUresult .CUDA_SUCCESS :
609+ raise RuntimeError (f"Failed to complete NVSHMEM linking: { result } " )
610+
611+ # Copy linked binary before destroying linker
612+ binary_buf = bytearray (ctypes .string_at (linked_cubin , linked_size ))
613+ if not binary_buf :
614+ raise RuntimeError ("Compilation error: empty result is generated" )
615+ finally :
616+ # Clean up linker
617+ cu .cuLinkDestroy (link_state )
618+ finally :
619+ # Clean up context
620+ cu .cuCtxDestroy (context )
621+
622+ return binary_buf
417623
418624
419625def find_cuda_path ():
0 commit comments