File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -656,13 +656,14 @@ def check_cuda_runtime():
656656 driver_version = ctypes .c_int ()
657657 runtime_version = ctypes .c_int ()
658658
659- if cuda .cudaDriverGetVersion (ctypes .byref (driver_version )) == 0 and \
660- cuda .cudaRuntimeGetVersion (ctypes .byref (runtime_version )) == 0 :
661- driver_version = driver_version .value
662- runtime_version = runtime_version .value
663-
664- driver_v = parse (str (driver_version / 1000 ))
665- runtime_v = parse (str (runtime_version / 1000 ))
659+ # Check the get*Version call succeeds and is a non-zero value
660+ call_success = cuda .cudaDriverGetVersion (ctypes .byref (driver_version )) == 0
661+ call_success &= cuda .cudaRuntimeGetVersion (ctypes .byref (runtime_version )) == 0
662+ call_success &= bool (driver_version .value )
663+
664+ if call_success :
665+ driver_v = parse (str (driver_version .value / 1000 ))
666+ runtime_v = parse (str (runtime_version .value / 1000 ))
666667 # First check the "major" version, known to be incompatible
667668 if driver_v .major < runtime_v .major :
668669 raise RuntimeError (
You can’t perform that action at this time.
0 commit comments