11import asyncio
2+ import base64
23import datetime
34import json
45import math
56import pprint
67import tempfile
78import zipfile
9+ import zlib
810from typing import Awaitable , Callable , Optional
911
1012import requests
2022from github import Github , UnknownObjectException , WorkflowRun
2123from report import RunProgressReporter
2224from run_eval import CompileResult , EvalResult , FullResult , RunResult , SystemInfo
23- from utils import get_github_branch_name , setup_logging
25+ from utils import setup_logging
2426
2527from .launcher import Launcher
2628
@@ -39,10 +41,11 @@ def get_timeout(config: dict) -> int:
3941
4042
4143class GitHubLauncher (Launcher ):
42- def __init__ (self , repo : str , token : str ):
44+ def __init__ (self , repo : str , token : str , branch : str ):
4345 super ().__init__ (name = "GitHub" , gpus = GitHubGPU )
4446 self .repo = repo
4547 self .token = token
48+ self .branch = branch
4649 self .trigger_limit = asyncio .Semaphore (1 )
4750
4851 async def run_submission (
@@ -71,10 +74,12 @@ async def run_submission(
7174 lang_name = {"py" : "Python" , "cu" : "CUDA" }[lang ]
7275
7376 logger .info (f"Attempting to trigger GitHub action for { lang_name } on { selected_workflow } " )
74- run = GitHubRun (self .repo , self .token , selected_workflow )
77+ run = GitHubRun (self .repo , self .token , self . branch , selected_workflow )
7578 logger .info (f"Successfully created GitHub run: { run .run_id } " )
7679
77- payload = json .dumps (config )
80+ payload = base64 .b64encode (zlib .compress (json .dumps (config ).encode ("utf-8" ))).decode (
81+ "utf-8"
82+ )
7883
7984 inputs = {"payload" : payload }
8085 if lang == "py" :
@@ -143,10 +148,11 @@ async def wait_callback(self, run: "GitHubRun", status: RunProgressReporter):
143148
144149
145150class GitHubRun :
146- def __init__ (self , repo : str , token : str , workflow_file : str ):
151+ def __init__ (self , repo : str , token : str , branch : str , workflow_file : str ):
147152 gh = Github (token )
148153 self .repo = gh .get_repo (repo )
149154 self .token = token
155+ self .branch = branch
150156 self .workflow_file = workflow_file
151157 self .run : Optional [WorkflowRun .WorkflowRun ] = None
152158 self .start_time = None
@@ -189,14 +195,14 @@ async def trigger(self, inputs: dict) -> bool:
189195 logger .error (f"Could not find workflow { self .workflow_file } " , exc_info = e )
190196 raise ValueError (f"Could not find workflow { self .workflow_file } " ) from e
191197
192- branch_name = get_github_branch_name ( )
198+ logger . info ( "Dispatching workflow %s on branch %s" , self . workflow_file , self . branch )
193199 logger .debug (
194200 "Dispatching workflow %s on branch %s with inputs %s" ,
195201 self .workflow_file ,
196- branch_name ,
202+ self . branch ,
197203 pprint .pformat (inputs ),
198204 )
199- success = await asyncio .to_thread (workflow .create_dispatch , branch_name , inputs = inputs )
205+ success = await asyncio .to_thread (workflow .create_dispatch , self . branch , inputs = inputs )
200206
201207 if success :
202208 wait_seconds = 5
@@ -248,7 +254,7 @@ async def trigger(self, inputs: dict) -> bool:
248254 return False
249255 else :
250256 logger .error (
251- f"Failed to dispatch workflow { self .workflow_file } on branch { branch_name } ."
257+ f"Failed to dispatch workflow { self .workflow_file } on branch { self . branch } ."
252258 )
253259 return False
254260
0 commit comments