Skip to content

Commit 2004a8b

Browse files
authored
[NVRTC] Add NVSHMEM support to NVRTC compilation path (#18681)
1 parent 74adf2d commit 2004a8b

File tree

4 files changed

+379
-30
lines changed

4 files changed

+379
-30
lines changed

python/tvm/contrib/nvcc.py

Lines changed: 220 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,17 @@ def compile_cuda(
7171
- NVRTC is a "runtime" compilation library and can be faster for JIT compilation.
7272
- NVRTC requires cuda-python: pip install cuda-python
7373
"""
74-
# TODO: if need NVSHMEM for compilation, fall back to NVCC because support for NVRTC
75-
# is not yet implemented
7674
use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in code
77-
if compiler == "nvcc" or use_nvshmem:
78-
return _compile_cuda_nvcc(code, target_format, arch, options, path_target, use_nvshmem)
75+
76+
if compiler == "nvcc":
77+
result = _compile_cuda_nvcc(code, target_format, arch, options, path_target, use_nvshmem)
7978
elif compiler == "nvrtc":
80-
return _compile_cuda_nvrtc(code, target_format, arch, options)
79+
result = _compile_cuda_nvrtc(code, target_format, arch, options, path_target, use_nvshmem)
8180
else:
8281
raise ValueError(f"cuda compiler must be 'nvcc' or 'nvrtc', got: {compiler}")
8382

83+
return result
84+
8485

8586
def _compile_cuda_nvcc(
8687
code,
@@ -235,7 +236,9 @@ def _compile_cuda_nvcc(
235236
return data
236237

237238

238-
def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
239+
def _compile_cuda_nvrtc(
240+
code, target_format=None, arch=None, options=None, path_target=None, use_nvshmem=False
241+
):
239242
"""Compile CUDA code using NVRTC (NVIDIA Runtime Compilation).
240243
241244
Parameters
@@ -248,6 +251,10 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
248251
Target architecture (e.g., "sm_80"). Auto-detected if None.
249252
options : str or list of str, optional
250253
Additional NVRTC options.
254+
path_target : str, optional
255+
Output file path. If provided, the compiled binary is written to this path.
256+
use_nvshmem : bool, optional
257+
Whether NVSHMEM is used. Default: False
251258
252259
Returns
253260
-------
@@ -264,8 +271,20 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
264271
"See: https://nvidia.github.io/cuda-python/"
265272
) from e
266273

267-
# Default target format
268-
if target_format is None:
274+
# For NVSHMEM, we also need the CUDA driver API to initialize the context for linking
275+
if use_nvshmem:
276+
import importlib.util # pylint: disable=import-outside-toplevel
277+
278+
if importlib.util.find_spec("cuda.bindings.driver") is None:
279+
raise RuntimeError(
280+
"Failed to compile CUDA with NVRTC+NVSHMEM because the `cuda-python` package "
281+
"is not available.\n"
282+
"Please install it with: pip install cuda-python\n"
283+
"See: https://nvidia.github.io/cuda-python/"
284+
)
285+
286+
# NVSHMEM requires linking with device library, which always produces cubin
287+
if use_nvshmem or target_format is None:
269288
target_format = "cubin"
270289

271290
# Validate target_format (NVRTC doesn't support fatbin)
@@ -287,6 +306,11 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
287306
compute_version = get_target_compute_version(Target.current(allow_none=True))
288307
arch = f"sm_{''.join(compute_version.split('.'))}"
289308

309+
# Get NVSHMEM paths if needed
310+
nvshmem_include_path, nvshmem_lib_path = None, None
311+
if use_nvshmem:
312+
nvshmem_include_path, nvshmem_lib_path = find_nvshmem_paths()
313+
290314
# Strip host-only headers for NVRTC. NVRTC compiles device code and does not
291315
# require the CUDA driver header or host C++ headers.
292316
headers_to_strip = {"#include <cuda.h>"}
@@ -304,6 +328,47 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
304328
"};\n\n" + code_filtered
305329
)
306330

331+
# Add standard type definitions and compatibility macros that NVRTC doesn't provide.
332+
nvrtc_preamble = """#include <cuda/std/cstdint>
333+
using cuda::std::uint8_t;
334+
using cuda::std::uint16_t;
335+
using cuda::std::uint32_t;
336+
using cuda::std::uint64_t;
337+
using cuda::std::int8_t;
338+
using cuda::std::int16_t;
339+
using cuda::std::int32_t;
340+
using cuda::std::int64_t;
341+
342+
// NVRTC uses asm/volatile instead of __asm__/__volatile__
343+
#ifndef __asm__
344+
#define __asm__ asm
345+
#endif
346+
#ifndef __volatile__
347+
#define __volatile__ volatile
348+
#endif
349+
350+
"""
351+
code_filtered = nvrtc_preamble + code_filtered
352+
353+
# For NVSHMEM, add preamble to map cuda::std type traits to std namespace.
354+
# NVSHMEM headers require std:: type traits but NVRTC uses cuda::std::.
355+
if use_nvshmem:
356+
nvshmem_preamble = """#include <cuda/std/type_traits>
357+
358+
// Map cuda::std type traits to std namespace for NVSHMEM headers
359+
namespace std {
360+
using cuda::std::is_integral;
361+
using cuda::std::is_signed;
362+
using cuda::std::is_unsigned;
363+
using cuda::std::is_floating_point;
364+
using cuda::std::is_same;
365+
using cuda::std::enable_if;
366+
using cuda::std::conditional;
367+
}
368+
369+
"""
370+
code_filtered = nvshmem_preamble + code_filtered
371+
307372
# Create NVRTC program
308373
# Use "tvm_kernels.cu" for consistency with nvcc path
309374
result, prog = nvrtc.nvrtcCreateProgram(
@@ -319,6 +384,9 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
319384
b"-default-device",
320385
]
321386

387+
if use_nvshmem:
388+
compile_opts.extend([b"-rdc", b"true"])
389+
322390
# Add CUDA include paths. NVRTC needs explicit include paths for CUDA headers.
323391
# Standard installations: cuda_path/include
324392
# Conda/architecture-specific installations: cuda_path/targets/<arch>/include
@@ -339,6 +407,12 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
339407
if os.path.isdir(arch_include):
340408
include_paths.append(arch_include)
341409

410+
# Check for CCCL include directory (required for cuda/std/cstdint and type_traits)
411+
# CCCL provides standard library functionality for device code
412+
cccl_include = os.path.join(arch_include, "cccl") if os.path.isdir(arch_include) else None
413+
if cccl_include and os.path.isdir(cccl_include):
414+
include_paths.append(cccl_include)
415+
342416
# Verify we can find essential CUDA headers
343417
if not any(os.path.isfile(os.path.join(p, "cuda_runtime.h")) for p in include_paths):
344418
raise RuntimeError(
@@ -351,6 +425,26 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
351425
for include_path in include_paths:
352426
compile_opts.append(f"-I{include_path}".encode())
353427

428+
# Add NVSHMEM include path
429+
if use_nvshmem and nvshmem_include_path:
430+
compile_opts.append(f"-I{nvshmem_include_path}".encode())
431+
432+
# For NVSHMEM, add deprecation and type conversion macros
433+
if use_nvshmem:
434+
compile_opts.extend(
435+
[
436+
# Define deprecation macros as empty (not properly defined in NVRTC context)
437+
b"-D__NV_SILENCE_DEPRECATION_BEGIN=",
438+
b"-D__NV_SILENCE_DEPRECATION_END=",
439+
b"-D__NV_SILENCE_HOST_DEPRECATION_BEGIN=",
440+
b"-D__NV_SILENCE_HOST_DEPRECATION_END=",
441+
# Disable FP8/FP6/FP4 extended types that cause issues with NVRTC
442+
b"-D__CUDA_NO_FP8_CONVERSIONS__",
443+
b"-D__CUDA_NO_FP6_CONVERSIONS__",
444+
b"-D__CUDA_NO_FP4_CONVERSIONS__",
445+
]
446+
)
447+
354448
compile_opts.extend(
355449
[
356450
b"-U__CUDA_NO_HALF_OPERATORS__",
@@ -363,12 +457,40 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
363457
]
364458
)
365459

366-
# Add user-provided options
460+
# Add user-provided options, filtering out nvcc-specific flags that nvrtc doesn't support
367461
if options:
462+
nvcc_only_prefixes = (
463+
"-c",
464+
"-O",
465+
"-std",
466+
"--std",
467+
"-Xcompiler",
468+
"-Xlinker",
469+
"-Xarchive",
470+
"-Xcudafe",
471+
"-Xptxas",
472+
"--compile",
473+
"--compiler-options",
474+
"--linker-options",
475+
"-fPIC",
476+
"-shared",
477+
"-o",
478+
)
368479
if isinstance(options, str):
369-
compile_opts.append(options.encode())
370-
else:
371-
compile_opts.extend([opt.encode() if isinstance(opt, str) else opt for opt in options])
480+
options = [options]
481+
for opt in options:
482+
if isinstance(opt, str):
483+
opt_str = opt
484+
elif isinstance(opt, bytes):
485+
opt_str = opt.decode()
486+
else:
487+
opt_str = str(opt)
488+
skip = any(
489+
opt_str.startswith(prefix) or opt_str == prefix for prefix in nvcc_only_prefixes
490+
)
491+
if skip:
492+
continue
493+
compile_opts.append(opt.encode() if isinstance(opt, str) else opt)
372494

373495
# Compile
374496
(result,) = nvrtc.nvrtcCompileProgram(prog, len(compile_opts), compile_opts)
@@ -410,10 +532,94 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
410532
nvrtc.nvrtcDestroyProgram(prog)
411533
raise RuntimeError(f"Failed to get PTX: {nvrtc.nvrtcGetErrorString(result)}")
412534

413-
# Clean up
535+
# Clean up NVRTC program
414536
nvrtc.nvrtcDestroyProgram(prog)
415537

416-
return bytearray(binary_buf)
538+
# Link stage for NVSHMEM
539+
if use_nvshmem:
540+
binary_buf = _link_nvshmem_nvrtc(binary_buf, nvshmem_lib_path)
541+
542+
if path_target:
543+
with open(path_target, "wb") as f:
544+
f.write(binary_buf)
545+
return binary_buf
546+
547+
548+
def _link_nvshmem_nvrtc(binary_buf, nvshmem_lib_path):
    """Link a compiled CUBIN with the NVSHMEM device library via the CUDA driver API.

    Parameters
    ----------
    binary_buf : bytes or bytearray
        The CUBIN produced by NVRTC compilation.
    nvshmem_lib_path : str
        Directory expected to contain ``libnvshmem_device.a``.

    Returns
    -------
    bytearray
        The fully linked CUBIN.

    Raises
    ------
    RuntimeError
        If CUDA initialization, context creation, any linker step fails,
        the NVSHMEM device library is missing, or the linked result is empty.
    """
    import ctypes  # pylint: disable=import-outside-toplevel

    # NOTE(review): requires the `cuda-python` package; callers are expected to
    # have checked availability before invoking this helper — confirm.
    from cuda.bindings import driver as cu  # pylint: disable=import-outside-toplevel

    # cuLinkCreate requires a valid CUDA context.
    # Always create a fresh context for linking to avoid issues with stale contexts
    # in multi-process environments like Disco workers.
    (result,) = cu.cuInit(0)
    if result != cu.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"Failed to initialize CUDA: {result}")

    # Link on device 0 unconditionally; the linked CUBIN targets the arch chosen
    # at compile time, not this device.
    result, device = cu.cuDeviceGet(0)
    if result != cu.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"Failed to get CUDA device: {result}")

    result, context = cu.cuCtxCreate(None, 0, device)
    if result != cu.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"Failed to create CUDA context: {result}")

    try:
        # Create linker with no JIT options (empty option/value lists).
        result, link_state = cu.cuLinkCreate(0, [], [])
        if result != cu.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"Failed to create CUDA linker: {result}")

        try:
            # Add our compiled CUBIN. The name "tvm_kernels.cubin" is only a
            # label used by the linker in diagnostics.
            (result,) = cu.cuLinkAddData(
                link_state,
                cu.CUjitInputType.CU_JIT_INPUT_CUBIN,
                binary_buf,
                len(binary_buf),
                b"tvm_kernels.cubin",
                0,
                [],
                [],
            )
            if result != cu.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to add CUBIN to linker: {result}")

            # Add the NVSHMEM device library; fail early with a clear message
            # if the static archive is not present at the expected path.
            nvshmem_device_lib = os.path.join(nvshmem_lib_path, "libnvshmem_device.a")
            if not os.path.exists(nvshmem_device_lib):
                raise RuntimeError(f"NVSHMEM device library not found: {nvshmem_device_lib}")

            (result,) = cu.cuLinkAddFile(
                link_state,
                cu.CUjitInputType.CU_JIT_INPUT_LIBRARY,
                nvshmem_device_lib.encode(),
                0,
                [],
                [],
            )
            if result != cu.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to add NVSHMEM device library: {result}")

            # Complete linking; returns a pointer/size into linker-owned memory.
            result, linked_cubin, linked_size = cu.cuLinkComplete(link_state)
            if result != cu.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to complete NVSHMEM linking: {result}")

            # Copy linked binary before destroying linker — cuLinkDestroy frees
            # the buffer that linked_cubin points into.
            binary_buf = bytearray(ctypes.string_at(linked_cubin, linked_size))
            if not binary_buf:
                raise RuntimeError("Compilation error: empty result is generated")
        finally:
            # Clean up linker (result intentionally ignored during cleanup).
            cu.cuLinkDestroy(link_state)
    finally:
        # Clean up the context we created above.
        cu.cuCtxDestroy(context)

    return binary_buf
417623

418624

419625
def find_cuda_path():

0 commit comments

Comments
 (0)