55import math
66import pprint
77import tempfile
8+ import uuid
89import zipfile
910import zlib
1011from typing import Awaitable , Callable , Optional
@@ -46,7 +47,6 @@ def __init__(self, repo: str, token: str, branch: str):
4647 self .repo = repo
4748 self .token = token
4849 self .branch = branch
49- self .trigger_limit = asyncio .Semaphore (1 )
5050
5151 async def run_submission (
5252 self , config : dict , gpu_type : GPU , status : RunProgressReporter
@@ -87,11 +87,8 @@ async def run_submission(
8787 if gpu_vendor == "AMD" :
8888 inputs ["runner" ] = runner_name
8989
90- async with self .trigger_limit : # DO NOT REMOVE, PREVENTS A RACE CONDITION
91- if not await run .trigger (inputs ):
92- raise RuntimeError (
93- "Failed to trigger GitHub Action. Please check the configuration."
94- )
90+ if not await run .trigger (inputs ):
91+ raise RuntimeError ("Failed to trigger GitHub Action. Please check the configuration." )
9592
9693 await status .push ("⏳ Waiting for workflow to start..." )
9794 logger .info ("Waiting for workflow to start..." )
@@ -188,21 +185,39 @@ async def trigger(self, inputs: dict) -> bool:
188185
189186 Returns: Whether the run was successfully triggered,
190187 """
188+ run_id = str (uuid .uuid4 ())
189+
190+ inputs_with_run_id = {** inputs , "run_id" : run_id }
191+
192+ if self .workflow_file == "amd_workflow.yml" :
193+ expected_run_name = f"AMD Job - { run_id } "
194+ elif self .workflow_file == "nvidia_workflow.yml" :
195+ expected_run_name = f"NVIDIA Job - { run_id } "
196+ else :
197+ raise ValueError (f"Unknown workflow file: { self .workflow_file } " )
198+
191199 trigger_time = datetime .datetime .now (datetime .timezone .utc )
192200 try :
193201 workflow = await asyncio .to_thread (self .repo .get_workflow , self .workflow_file )
194202 except UnknownObjectException as e :
195203 logger .error (f"Could not find workflow { self .workflow_file } " , exc_info = e )
196204 raise ValueError (f"Could not find workflow { self .workflow_file } " ) from e
197205
198- logger .info ("Dispatching workflow %s on branch %s" , self .workflow_file , self .branch )
206+ logger .info (
207+ "Dispatching workflow %s on branch %s with run_id %s" ,
208+ self .workflow_file ,
209+ self .branch ,
210+ run_id ,
211+ )
199212 logger .debug (
200213 "Dispatching workflow %s on branch %s with inputs %s" ,
201214 self .workflow_file ,
202215 self .branch ,
203- pprint .pformat (inputs ),
216+ pprint .pformat (inputs_with_run_id ),
217+ )
218+ success = await asyncio .to_thread (
219+ workflow .create_dispatch , self .branch , inputs = inputs_with_run_id
204220 )
205- success = await asyncio .to_thread (workflow .create_dispatch , self .branch , inputs = inputs )
206221
207222 if success :
208223 wait_seconds = 5
@@ -214,28 +229,27 @@ async def trigger(self, inputs: dict) -> bool:
214229 workflow .get_runs , event = "workflow_dispatch"
215230 )
216231
217- logger .info (
218- f"Checking recent workflow_dispatch runs after { trigger_time .isoformat ()} ..."
219- )
232+ logger .info (f"Looking for workflow run with name: '{ expected_run_name } '" )
220233 found_run = None
221234 runs_checked = 0
222235 try :
223236 run_iterator = recent_runs_paginated .__iter__ ()
224- while runs_checked < 50 :
237+ while runs_checked < 100 :
225238 try :
226239 run = next (run_iterator )
227240 runs_checked += 1
228241 logger .debug (
229- f"Checking run { run .id } created at { run .created_at .isoformat ()} "
242+ f"Checking run { run .id } with name '{ run .name } '"
243+ f" created at { run .created_at .isoformat ()} "
230244 )
231- if run .created_at .replace (
245+ if run .name == expected_run_name and run . created_at .replace (
232246 tzinfo = datetime .timezone .utc
233- ) > trigger_time - datetime .timedelta (seconds = 2 ):
247+ ) > trigger_time - datetime .timedelta (seconds = 30 ):
234248 found_run = run
235- logger .info (f"Found matching workflow run: ID { found_run . id } " )
236- break
237- else :
238- logger . info ( f"Run { run . id } is older than trigger time, stopping check." )
249+ logger .info (
250+ f"Found matching workflow run: ID { found_run . id } "
251+ f"with name ' { found_run . name } '"
252+ )
239253 break
240254 except StopIteration :
241255 logger .debug ("Reached end of recent runs list." )
@@ -249,7 +263,8 @@ async def trigger(self, inputs: dict) -> bool:
249263 return True
250264 else :
251265 logger .warning (
252- f"Could not find a workflow run created after { trigger_time .isoformat ()} ."
266+ f"Could not find a workflow run with name '{ expected_run_name } ' "
267+ f"created after { trigger_time .isoformat ()} ."
253268 )
254269 return False
255270 else :
0 commit comments