File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -44,18 +44,22 @@ def _import_get_cuda_path_or_home():
4444 except ModuleNotFoundError as exc :
4545 if exc .name != "cuda.pathfinder" :
4646 raise
47+ from importlib .metadata import PackageNotFoundError , distribution
48+ from pathlib import Path
49+
4750 import cuda
4851
49- for p in sys .path :
50- sp_cuda = os .path .join (p , "cuda" )
51- if os .path .isdir (os .path .join (sp_cuda , "pathfinder" )):
52- cuda .__path__ = list (cuda .__path__ ) + [sp_cuda ]
53- break
54- else :
52+ try :
53+ dist = distribution ("cuda-pathfinder" )
54+ except PackageNotFoundError :
5555 raise ModuleNotFoundError (
5656 "cuda-pathfinder is not installed in the build environment. "
5757 "Ensure 'cuda-pathfinder>=1.5' is in build-system.requires."
58- )
58+ ) from None
59+ site_cuda = str (dist .locate_file (Path ("cuda" )))
60+ cuda_paths = list (cuda .__path__ )
61+ if site_cuda not in cuda_paths :
62+ cuda .__path__ = cuda_paths + [site_cuda ]
5963 import cuda .pathfinder
6064
6165 return cuda .pathfinder .get_cuda_path_or_home
Original file line number Diff line number Diff line change @@ -39,18 +39,22 @@ def _import_get_cuda_path_or_home():
3939 except ModuleNotFoundError as exc :
4040 if exc .name != "cuda.pathfinder" :
4141 raise
42+ from importlib .metadata import PackageNotFoundError , distribution
43+ from pathlib import Path
44+
4245 import cuda
4346
44- for p in sys .path :
45- sp_cuda = os .path .join (p , "cuda" )
46- if os .path .isdir (os .path .join (sp_cuda , "pathfinder" )):
47- cuda .__path__ = list (cuda .__path__ ) + [sp_cuda ]
48- break
49- else :
47+ try :
48+ dist = distribution ("cuda-pathfinder" )
49+ except PackageNotFoundError :
5050 raise ModuleNotFoundError (
5151 "cuda-pathfinder is not installed in the build environment. "
5252 "Ensure 'cuda-pathfinder>=1.5' is in build-system.requires."
53- )
53+ ) from None
54+ site_cuda = str (dist .locate_file (Path ("cuda" )))
55+ cuda_paths = list (cuda .__path__ )
56+ if site_cuda not in cuda_paths :
57+ cuda .__path__ = cuda_paths + [site_cuda ]
5458 import cuda .pathfinder
5559
5660 return cuda .pathfinder .get_cuda_path_or_home
You can’t perform that action at this time.
0 commit comments