@@ -53,6 +53,17 @@ def __init__(self, verbose=False):
5353 self .mpi_exec = "mpiexec"
5454 else :
5555 self .mpi_exec = "mpirun"
56+ self .platform = platform .system ()
57+
58+ # Detect MPI implementation to choose compatible flags
59+ self .mpi_env_mode = "unknown" # one of: openmpi, mpich, unknown
60+ self .mpi_np_flag = "-np"
61+ if self .platform == "Windows" :
62+ # MSMPI uses -env and -n
63+ self .mpi_env_mode = "mpich"
64+ self .mpi_np_flag = "-n"
65+ else :
66+ self .mpi_env_mode , self .mpi_np_flag = self .__detect_mpi_impl ()
5667
5768 @staticmethod
5869 def __get_project_path ():
@@ -88,6 +99,81 @@ def __run_exec(self, command):
8899 if result .returncode != 0 :
89100 raise Exception (f"Subprocess return { result .returncode } ." )
90101
102+ def __detect_mpi_impl (self ):
103+ """Detect MPI implementation and return (env_mode, np_flag).
104+ env_mode: 'openmpi' -> use '-x VAR', 'mpich' -> use '-genvlist VAR1,VAR2', 'unknown' -> pass no env flags.
105+ np_flag: '-np' for OpenMPI/unknown, '-n' for MPICH-family.
106+ """
107+ probes = (["--version" ], ["-V" ], ["-v" ], ["--help" ], ["-help" ])
108+ out = ""
109+ for args in probes :
110+ try :
111+ proc = subprocess .run (
112+ [self .mpi_exec ] + list (args ),
113+ stdout = subprocess .PIPE ,
114+ stderr = subprocess .STDOUT ,
115+ text = True ,
116+ )
117+ out = (proc .stdout or "" ).lower ()
118+ if out :
119+ break
120+ except Exception :
121+ continue
122+
123+ if "open mpi" in out or "ompi" in out :
124+ return "openmpi" , "-np"
125+ if (
126+ "hydra" in out
127+ or "mpich" in out
128+ or "intel(r) mpi" in out
129+ or "intel mpi" in out
130+ ):
131+ return "mpich" , "-n"
132+ return "unknown" , "-np"
133+
134+ def __build_mpi_cmd (self , ppc_num_proc , additional_mpi_args ):
135+ base = [self .mpi_exec ] + shlex .split (additional_mpi_args )
136+
137+ if self .platform == "Windows" :
138+ # MS-MPI style
139+ env_args = [
140+ "-env" ,
141+ "PPC_NUM_THREADS" ,
142+ self .__ppc_env ["PPC_NUM_THREADS" ],
143+ "-env" ,
144+ "OMP_NUM_THREADS" ,
145+ self .__ppc_env ["OMP_NUM_THREADS" ],
146+ ]
147+ np_args = ["-n" , ppc_num_proc ]
148+ return base + env_args + np_args
149+
150+ # Non-Windows
151+ if self .mpi_env_mode == "openmpi" :
152+ env_args = [
153+ "-x" ,
154+ "PPC_NUM_THREADS" ,
155+ "-x" ,
156+ "OMP_NUM_THREADS" ,
157+ ]
158+ np_flag = "-np"
159+ elif self .mpi_env_mode == "mpich" :
160+ # Explicitly set env variables for all ranks
161+ env_args = [
162+ "-env" ,
163+ "PPC_NUM_THREADS" ,
164+ self .__ppc_env ["PPC_NUM_THREADS" ],
165+ "-env" ,
166+ "OMP_NUM_THREADS" ,
167+ self .__ppc_env ["OMP_NUM_THREADS" ],
168+ ]
169+ np_flag = "-n"
170+ else :
171+ # Unknown MPI flavor: rely on environment inheritance and default to -np
172+ env_args = []
173+ np_flag = "-np"
174+
175+ return base + env_args + [np_flag , ppc_num_proc ]
176+
91177 @staticmethod
92178 def __get_gtest_settings (repeats_count , type_task ):
93179 command = [
@@ -133,10 +219,7 @@ def run_processes(self, additional_mpi_args):
133219 raise EnvironmentError (
134220 "Required environment variable 'PPC_NUM_PROC' is not set."
135221 )
136-
137- mpi_running = (
138- [self .mpi_exec ] + shlex .split (additional_mpi_args ) + ["-np" , ppc_num_proc ]
139- )
222+ mpi_running = self .__build_mpi_cmd (ppc_num_proc , additional_mpi_args )
140223 if not self .__ppc_env .get ("PPC_ASAN_RUN" ):
141224 for task_type in ["all" , "mpi" ]:
142225 self .__run_exec (
@@ -147,7 +230,7 @@ def run_processes(self, additional_mpi_args):
147230
148231 def run_performance (self ):
149232 if not self .__ppc_env .get ("PPC_ASAN_RUN" ):
150- mpi_running = [ self .mpi_exec , "-np" , self . __ppc_num_proc ]
233+ mpi_running = self .__build_mpi_cmd ( self . __ppc_num_proc , "" )
151234 for task_type in ["all" , "mpi" ]:
152235 self .__run_exec (
153236 mpi_running
0 commit comments