44import json
55import os
66import time
7+ from contextlib import contextmanager
78from dataclasses import asdict
89from typing import Annotated , Optional
910
11+ from backend import KernelBackend
1012from consts import SubmissionMode
1113from fastapi import Depends , FastAPI , Header , HTTPException , UploadFile
1214from fastapi .responses import StreamingResponse
1517
1618from .utils import _handle_discord_oauth , _handle_github_oauth , _run_submission
1719
20+ # yes, we do want ... = Depends() in function signatures
21+ # ruff: noqa: B008
22+
1823app = FastAPI ()
1924
2025
@@ -25,7 +30,7 @@ def json_serializer(obj):
2530 raise TypeError (f"Type { type (obj )} not serializable" )
2631
2732
28- bot_instance = None
33+ backend_instance : KernelBackend = None
2934
3035_last_action = time .time ()
3136_submit_limiter = asyncio .Semaphore (3 )
@@ -50,13 +55,29 @@ async def simple_rate_limit():
5055 return
5156
5257
53- def init_api (_bot_instance ):
54- global bot_instance
55- bot_instance = _bot_instance
58+ def init_api (_backend_instance : KernelBackend ):
59+ global backend_instance
60+ backend_instance = _backend_instance
61+
62+
63+ @contextmanager
64+ def get_db ():
65+ """Database context manager with guaranteed error handling"""
66+ if not backend_instance :
67+ raise HTTPException (status_code = 500 , detail = "Bot instance not initialized" )
68+
69+ if not hasattr (backend_instance , "leaderboard_db" ):
70+ raise HTTPException (status_code = 500 , detail = "Database not initialized" )
71+
72+ with backend_instance .db as db :
73+ if db is None :
74+ raise HTTPException (status_code = 500 , detail = "Database connection failed" )
75+ yield db
5676
5777
5878async def validate_cli_header (
5979 x_popcorn_cli_id : Optional [str ] = Header (None , alias = "X-Popcorn-Cli-Id" ),
80+ db_context = Depends (get_db ),
6081) -> str :
6182 """
6283 FastAPI dependency to validate the X-Popcorn-Cli-Id header.
@@ -70,15 +91,8 @@ async def validate_cli_header(
7091 if not x_popcorn_cli_id :
7192 raise HTTPException (status_code = 400 , detail = "Missing X-Popcorn-Cli-Id header" )
7293
73- if not bot_instance or not hasattr (bot_instance , "leaderboard_db" ):
74- raise HTTPException (
75- status_code = 500 , detail = "Bot instance or database not initialized for validation"
76- )
77-
7894 try :
79- with bot_instance .leaderboard_db as db :
80- if db is None :
81- raise HTTPException (status_code = 500 , detail = "Database connection failed" )
95+ with db_context as db :
8296 user_info = db .validate_cli_id (x_popcorn_cli_id )
8397 except Exception as e :
8498 raise HTTPException (status_code = 500 , detail = f"Database error during validation: { e } " ) from e
@@ -90,7 +104,7 @@ async def validate_cli_header(
90104
91105
92106@app .get ("/auth/init" )
93- async def auth_init (provider : str ) -> dict :
107+ async def auth_init (provider : str , db_context = Depends ( get_db ) ) -> dict :
94108 if provider not in ["discord" , "github" ]:
95109 raise HTTPException (
96110 status_code = 400 , detail = "Invalid provider, must be 'discord' or 'github'"
@@ -110,14 +124,8 @@ async def auth_init(provider: str) -> dict:
110124
111125 state_uuid = str (uuid .uuid4 ())
112126
113- # Ensure bot_instance and leaderboard_db are available
114- if not bot_instance or not hasattr (bot_instance , "leaderboard_db" ):
115- raise HTTPException (status_code = 500 , detail = "Bot instance or database not initialized" )
116-
117127 try :
118- with bot_instance .leaderboard_db as db :
119- if db is None :
120- raise HTTPException (status_code = 500 , detail = "Database connection failed" )
128+ with db_context as db :
121129 # Assuming init_user_from_cli exists and handles DB interaction
122130 db .init_user_from_cli (state_uuid , provider )
123131 except AttributeError as e :
@@ -131,7 +139,7 @@ async def auth_init(provider: str) -> dict:
131139
132140
133141@app .get ("/auth/cli/{auth_provider}" )
134- async def cli_auth (auth_provider : str , code : str , state : str ): # noqa: C901
142+ async def cli_auth (auth_provider : str , code : str , state : str , db_context = Depends ( get_db ) ): # noqa: C901
135143 """
136144 Handle Discord/GitHub OAuth redirect. This endpoint receives the authorization code
137145 and state parameter from the OAuth flow.
@@ -194,13 +202,8 @@ async def cli_auth(auth_provider: str, code: str, state: str): # noqa: C901
194202 status_code = 500 , detail = "Failed to retrieve user ID or username from provider."
195203 )
196204
197- if not bot_instance or not hasattr (bot_instance , "leaderboard_db" ):
198- raise HTTPException (
199- status_code = 500 , detail = "Bot instance or database not initialized for update"
200- )
201-
202205 try :
203- with bot_instance . leaderboard_db as db :
206+ with db_context as db :
204207 if is_reset :
205208 db .reset_user_from_cli (user_id , cli_id , auth_provider )
206209 else :
@@ -223,7 +226,10 @@ async def cli_auth(auth_provider: str, code: str, state: str): # noqa: C901
223226
224227
225228async def _stream_submission_response (
226- submission_request , user_info , submission_mode_enum , bot_instance
229+ submission_request : SubmissionRequest ,
230+ user_info : dict ,
231+ submission_mode_enum : SubmissionMode ,
232+ backend : KernelBackend ,
227233):
228234 start_time = time .time ()
229235 task : asyncio .Task | None = None
@@ -233,15 +239,15 @@ async def _stream_submission_response(
233239 submission_request ,
234240 user_info ,
235241 submission_mode_enum ,
236- bot_instance ,
242+ backend ,
237243 )
238244 )
239245
240246 while not task .done ():
241247 elapsed_time = time .time () - start_time
242248 yield f"event: status\n data: { json .dumps ({'status' : 'processing' ,
243- 'elapsed_time' : round (elapsed_time , 2 )}
244- , default = json_serializer )} \n \n "
249+ 'elapsed_time' : round (elapsed_time , 2 )},
250+ default = json_serializer )} \n \n "
245251
246252 try :
247253 await asyncio .wait_for (asyncio .shield (task ), timeout = 15.0 )
@@ -292,6 +298,7 @@ async def run_submission( # noqa: C901
292298 submission_mode : str ,
293299 file : UploadFile ,
294300 user_info : Annotated [dict , Depends (validate_cli_header )],
301+ db_context = Depends (get_db ),
295302) -> StreamingResponse :
296303 """An endpoint that runs a submission on a given leaderboard, runner, and GPU type.
297304 Streams status updates and the final result via Server-Sent Events (SSE).
@@ -339,18 +346,8 @@ async def run_submission( # noqa: C901
339346 detail = f"Submission mode '{ submission_mode } ' is not supported for this endpoint" ,
340347 )
341348
342- if not bot_instance :
343- raise HTTPException (
344- status_code = 503 , detail = "Service temporarily unavailable: Bot not initialized"
345- )
346-
347349 try :
348- with bot_instance .leaderboard_db as db :
349- if db is None :
350- raise HTTPException (
351- status_code = 503 ,
352- detail = "Service temporarily unavailable: Database connection failed" ,
353- )
350+ with db_context as db :
354351 leaderboard_item = db .get_leaderboard (leaderboard_name )
355352 if not leaderboard_item :
356353 all_leaderboards = [lb ["name" ] for lb in db .get_leaderboards ()]
@@ -417,14 +414,14 @@ async def run_submission( # noqa: C901
417414 submission_request = submission_request ,
418415 user_info = {"user_id" : user_id , "user_name" : user_name },
419416 submission_mode_enum = submission_mode_enum ,
420- bot_instance = bot_instance ,
417+ backend = backend_instance ,
421418 )
422419
423420 return StreamingResponse (generator , media_type = "text/event-stream" )
424421
425422
426423@app .get ("/leaderboards" )
427- async def get_leaderboards ():
424+ async def get_leaderboards (db_context = Depends ( get_db ) ):
428425 """An endpoint that returns all leaderboards.
429426
430427 Returns:
@@ -433,19 +430,15 @@ async def get_leaderboards():
433430 and the GPU types that are available for submissions.
434431 """
435432 await simple_rate_limit ()
436- if not bot_instance or not hasattr (bot_instance , "leaderboard_db" ):
437- raise HTTPException (status_code = 500 , detail = "Bot instance or database not initialized" )
438433 try :
439- with bot_instance .leaderboard_db as db :
440- if db is None :
441- raise HTTPException (status_code = 500 , detail = "Database connection failed" )
434+ with db_context as db :
442435 return db .get_leaderboards ()
443436 except Exception as e :
444437 raise HTTPException (status_code = 500 , detail = f"Error fetching leaderboards: { e } " ) from e
445438
446439
447440@app .get ("/gpus/{leaderboard_name}" )
448- async def get_gpus (leaderboard_name : str ) -> list [str ]:
441+ async def get_gpus (leaderboard_name : str , db_context = Depends ( get_db ) ) -> list [str ]:
449442 """An endpoint that returns all GPU types that are available for a given leaderboard and runner.
450443
451444 Args:
@@ -456,14 +449,8 @@ async def get_gpus(leaderboard_name: str) -> list[str]:
456449 list[str]: A list of GPU types that are available for the given leaderboard and runner.
457450 """
458451 await simple_rate_limit ()
459- if not bot_instance or not hasattr (bot_instance , "leaderboard_db" ):
460- raise HTTPException (status_code = 500 , detail = "Bot instance or database not initialized" )
461-
462452 try :
463- with bot_instance .leaderboard_db as db :
464- if db is None :
465- raise HTTPException (status_code = 500 , detail = "Database connection failed" )
466-
453+ with db_context as db :
467454 # Validate leaderboard exists first
468455 leaderboard_names = [x ["name" ] for x in db .get_leaderboards ()]
469456 if leaderboard_name not in leaderboard_names :
@@ -482,15 +469,15 @@ async def get_gpus(leaderboard_name: str) -> list[str]:
482469
483470@app .get ("/submissions/{leaderboard_name}/{gpu_name}" )
484471async def get_submissions (
485- leaderboard_name : str , gpu_name : str , limit : int = None , offset : int = 0
472+ leaderboard_name : str ,
473+ gpu_name : str ,
474+ limit : int = None ,
475+ offset : int = 0 ,
476+ db_context = Depends (get_db ),
486477) -> list [LeaderboardRankedEntry ]:
487478 await simple_rate_limit ()
488- if not bot_instance or not hasattr (bot_instance , "leaderboard_db" ):
489- raise HTTPException (status_code = 500 , detail = "Bot instance or database not initialized" )
490479 try :
491- with bot_instance .leaderboard_db as db :
492- if db is None :
493- raise HTTPException (status_code = 500 , detail = "Database connection failed" )
480+ with db_context as db :
494481 # Add validation for leaderboard and GPU? Might be redundant if DB handles it.
495482 return db .get_leaderboard_submissions (
496483 leaderboard_name , gpu_name , limit = limit , offset = offset
@@ -500,15 +487,13 @@ async def get_submissions(
500487
501488
502489@app .get ("/submission_count/{leaderboard_name}/{gpu_name}" )
503- async def get_submission_count (leaderboard_name : str , gpu_name : str , user_id : str = None ) -> dict :
490+ async def get_submission_count (
491+ leaderboard_name : str , gpu_name : str , user_id : str = None , db_context = Depends (get_db )
492+ ) -> dict :
504493 """Get the total count of submissions for pagination"""
505494 await simple_rate_limit ()
506- if not bot_instance or not hasattr (bot_instance , "leaderboard_db" ):
507- raise HTTPException (status_code = 500 , detail = "Bot instance or database not initialized" )
508495 try :
509- with bot_instance .leaderboard_db as db :
510- if db is None :
511- raise HTTPException (status_code = 500 , detail = "Database connection failed" )
496+ with db_context as db :
512497 count = db .get_leaderboard_submission_count (leaderboard_name , gpu_name , user_id )
513498 return {"count" : count }
514499 except Exception as e :
0 commit comments