3333_extensions = None
3434
3535
36- # Please keep in sync with the copy in cuda_core/build_hooks.py.
37- def _import_get_cuda_path_or_home ():
38- """Import get_cuda_path_or_home, working around PEP 517 namespace shadowing.
39-
40- See https://github.com/NVIDIA/cuda-python/issues/1824 for why this helper is needed.
41- """
42- try :
43- import cuda .pathfinder
44- except ModuleNotFoundError as exc :
45- if exc .name not in ("cuda" , "cuda.pathfinder" ):
46- raise
47- try :
48- import cuda
49- except ModuleNotFoundError :
50- cuda = None
51-
52- for p in sys .path :
53- sp_cuda = os .path .join (p , "cuda" )
54- if os .path .isdir (os .path .join (sp_cuda , "pathfinder" )):
55- cuda .__path__ = list (cuda .__path__ ) + [sp_cuda ]
56- break
57- else :
58- raise ModuleNotFoundError (
59- "cuda-pathfinder is not installed in the build environment. "
60- "Ensure 'cuda-pathfinder>=1.5' is in build-system.requires."
61- )
62- import cuda .pathfinder
63-
64- return cuda .pathfinder .get_cuda_path_or_home
65-
66-
6736@functools .cache
6837def _get_cuda_path () -> str :
69- get_cuda_path_or_home = _import_get_cuda_path_or_home ()
38+ from cuda .pathfinder import get_cuda_path_or_home
39+
7040 cuda_path = get_cuda_path_or_home ()
7141 if not cuda_path :
7242 raise RuntimeError ("Environment variable CUDA_PATH or CUDA_HOME is not set" )
@@ -266,7 +236,7 @@ def _generate_output(infile, template_vars):
266236
267237
268238def _rename_architecture_specific_files ():
269- path = os .path .join ("cuda" , "bindings" , "_internal" )
239+ path = os .path .join ("src" , " cuda" , "bindings" , "_internal" )
270240 if sys .platform == "linux" :
271241 src_files = glob .glob (os .path .join (path , "*_linux.pyx" ))
272242 elif sys .platform == "win32" :
@@ -290,7 +260,10 @@ def _prep_extensions(sources, libraries, include_dirs, library_dirs, extra_compi
290260 libraries = libraries if libraries else []
291261 exts = []
292262 for pyx in files :
293- mod_name = pyx .replace (".pyx" , "" ).replace (os .sep , "." ).replace ("/" , "." )
263+ mod_path = pyx .replace (".pyx" , "" ).replace (os .sep , "/" )
264+ if mod_path .startswith ("src/" ):
265+ mod_path = mod_path [len ("src/" ) :]
266+ mod_name = mod_path .replace ("/" , "." )
294267 exts .append (
295268 Extension (
296269 mod_name ,
@@ -346,12 +319,12 @@ def _build_cuda_bindings(strip=False):
346319
347320 # Generate code from .in templates
348321 path_list = [
349- os .path .join ("cuda" ),
350- os .path .join ("cuda" , "bindings" ),
351- os .path .join ("cuda" , "bindings" , "_bindings" ),
352- os .path .join ("cuda" , "bindings" , "_internal" ),
353- os .path .join ("cuda" , "bindings" , "_lib" ),
354- os .path .join ("cuda" , "bindings" , "utils" ),
322+ os .path .join ("src" , " cuda" ),
323+ os .path .join ("src" , " cuda" , "bindings" ),
324+ os .path .join ("src" , " cuda" , "bindings" , "_bindings" ),
325+ os .path .join ("src" , " cuda" , "bindings" , "_internal" ),
326+ os .path .join ("src" , " cuda" , "bindings" , "_lib" ),
327+ os .path .join ("src" , " cuda" , "bindings" , "utils" ),
355328 ]
356329 input_files = []
357330 for path in path_list :
@@ -375,6 +348,7 @@ def _build_cuda_bindings(strip=False):
375348 # Prepare compile/link arguments
376349 include_dirs = [
377350 os .path .dirname (sysconfig .get_path ("include" )),
351+ "src" ,
378352 ] + include_path_list
379353 library_dirs = [sysconfig .get_path ("platlib" ), os .path .join (os .sys .prefix , "lib" )]
380354 cudalib_subdirs = [r"lib\x64" ] if sys .platform == "win32" else ["lib64" , "lib" ]
@@ -415,21 +389,21 @@ def _cleanup_dst_files():
415389 # Build extension list
416390 extensions = []
417391 static_runtime_libraries = ["cudart_static" , "rt" ] if sys .platform == "linux" else ["cudart_static" ]
418- cuda_bindings_files = glob .glob ("cuda/bindings/*.pyx" )
392+ cuda_bindings_files = glob .glob ("src/ cuda/bindings/*.pyx" )
419393 if sys .platform == "win32" :
420394 cuda_bindings_files = [f for f in cuda_bindings_files if "cufile" not in f ]
421395 sources_list = [
422396 # private
423- (["cuda/bindings/_bindings/cydriver.pyx" , "cuda/bindings/_bindings/loader.cpp" ], None ),
424- (["cuda/bindings/_bindings/cynvrtc.pyx" ], None ),
425- (["cuda/bindings/_bindings/cyruntime.pyx" ], static_runtime_libraries ),
426- (["cuda/bindings/_bindings/cyruntime_ptds.pyx" ], static_runtime_libraries ),
397+ (["src/ cuda/bindings/_bindings/cydriver.pyx" , "src/ cuda/bindings/_bindings/loader.cpp" ], None ),
398+ (["src/ cuda/bindings/_bindings/cynvrtc.pyx" ], None ),
399+ (["src/ cuda/bindings/_bindings/cyruntime.pyx" ], static_runtime_libraries ),
400+ (["src/ cuda/bindings/_bindings/cyruntime_ptds.pyx" ], static_runtime_libraries ),
427401 # utils
428- (["cuda/bindings/utils/*.pyx" ], None ),
402+ (["src/ cuda/bindings/utils/*.pyx" ], None ),
429403 # public
430404 * (([f ], None ) for f in cuda_bindings_files ),
431405 # internal files used by generated bindings
432- (["cuda/bindings/_internal/utils.pyx" ], None ),
406+ (["src/ cuda/bindings/_internal/utils.pyx" ], None ),
433407 * (([f ], None ) for f in dst_files if f .endswith (".pyx" )),
434408 ]
435409
@@ -446,6 +420,7 @@ def _cleanup_dst_files():
446420 _extensions = cythonize (
447421 extensions ,
448422 nthreads = nthreads ,
423+ include_path = ["src" ],
449424 build_dir = "." if compile_for_coverage else "build/cython" ,
450425 compiler_directives = cython_directives ,
451426 ** extra_cythonize_kwargs ,
0 commit comments