Skip to content

Commit 02f413c

Browse files
committed
add a backend class that handles db and submission triggering
1 parent 4c1da07 commit 02f413c

8 files changed

Lines changed: 274 additions & 283 deletions

File tree

src/discord-cluster-manager/api/main.py

Lines changed: 54 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import json
55
import os
66
import time
7+
from contextlib import contextmanager
78
from dataclasses import asdict
89
from typing import Annotated, Optional
910

11+
from backend import KernelBackend
1012
from consts import SubmissionMode
1113
from fastapi import Depends, FastAPI, Header, HTTPException, UploadFile
1214
from fastapi.responses import StreamingResponse
@@ -15,6 +17,9 @@
1517

1618
from .utils import _handle_discord_oauth, _handle_github_oauth, _run_submission
1719

20+
# yes, we do want ... = Depends() in function signatures
21+
# ruff: noqa: B008
22+
1823
app = FastAPI()
1924

2025

@@ -25,7 +30,7 @@ def json_serializer(obj):
2530
raise TypeError(f"Type {type(obj)} not serializable")
2631

2732

28-
bot_instance = None
33+
backend_instance: KernelBackend = None
2934

3035
_last_action = time.time()
3136
_submit_limiter = asyncio.Semaphore(3)
@@ -50,13 +55,29 @@ async def simple_rate_limit():
5055
return
5156

5257

53-
def init_api(_bot_instance):
54-
global bot_instance
55-
bot_instance = _bot_instance
58+
def init_api(_backend_instance: KernelBackend):
59+
global backend_instance
60+
backend_instance = _backend_instance
61+
62+
63+
@contextmanager
64+
def get_db():
65+
"""Database context manager with guaranteed error handling"""
66+
if not backend_instance:
67+
raise HTTPException(status_code=500, detail="Bot instance not initialized")
68+
69+
if not hasattr(backend_instance, "leaderboard_db"):
70+
raise HTTPException(status_code=500, detail="Database not initialized")
71+
72+
with backend_instance.db as db:
73+
if db is None:
74+
raise HTTPException(status_code=500, detail="Database connection failed")
75+
yield db
5676

5777

5878
async def validate_cli_header(
5979
x_popcorn_cli_id: Optional[str] = Header(None, alias="X-Popcorn-Cli-Id"),
80+
db_context=Depends(get_db),
6081
) -> str:
6182
"""
6283
FastAPI dependency to validate the X-Popcorn-Cli-Id header.
@@ -70,15 +91,8 @@ async def validate_cli_header(
7091
if not x_popcorn_cli_id:
7192
raise HTTPException(status_code=400, detail="Missing X-Popcorn-Cli-Id header")
7293

73-
if not bot_instance or not hasattr(bot_instance, "leaderboard_db"):
74-
raise HTTPException(
75-
status_code=500, detail="Bot instance or database not initialized for validation"
76-
)
77-
7894
try:
79-
with bot_instance.leaderboard_db as db:
80-
if db is None:
81-
raise HTTPException(status_code=500, detail="Database connection failed")
95+
with db_context as db:
8296
user_info = db.validate_cli_id(x_popcorn_cli_id)
8397
except Exception as e:
8498
raise HTTPException(status_code=500, detail=f"Database error during validation: {e}") from e
@@ -90,7 +104,7 @@ async def validate_cli_header(
90104

91105

92106
@app.get("/auth/init")
93-
async def auth_init(provider: str) -> dict:
107+
async def auth_init(provider: str, db_context=Depends(get_db)) -> dict:
94108
if provider not in ["discord", "github"]:
95109
raise HTTPException(
96110
status_code=400, detail="Invalid provider, must be 'discord' or 'github'"
@@ -110,14 +124,8 @@ async def auth_init(provider: str) -> dict:
110124

111125
state_uuid = str(uuid.uuid4())
112126

113-
# Ensure bot_instance and leaderboard_db are available
114-
if not bot_instance or not hasattr(bot_instance, "leaderboard_db"):
115-
raise HTTPException(status_code=500, detail="Bot instance or database not initialized")
116-
117127
try:
118-
with bot_instance.leaderboard_db as db:
119-
if db is None:
120-
raise HTTPException(status_code=500, detail="Database connection failed")
128+
with db_context as db:
121129
# Assuming init_user_from_cli exists and handles DB interaction
122130
db.init_user_from_cli(state_uuid, provider)
123131
except AttributeError as e:
@@ -131,7 +139,7 @@ async def auth_init(provider: str) -> dict:
131139

132140

133141
@app.get("/auth/cli/{auth_provider}")
134-
async def cli_auth(auth_provider: str, code: str, state: str): # noqa: C901
142+
async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends(get_db)): # noqa: C901
135143
"""
136144
Handle Discord/GitHub OAuth redirect. This endpoint receives the authorization code
137145
and state parameter from the OAuth flow.
@@ -194,13 +202,8 @@ async def cli_auth(auth_provider: str, code: str, state: str): # noqa: C901
194202
status_code=500, detail="Failed to retrieve user ID or username from provider."
195203
)
196204

197-
if not bot_instance or not hasattr(bot_instance, "leaderboard_db"):
198-
raise HTTPException(
199-
status_code=500, detail="Bot instance or database not initialized for update"
200-
)
201-
202205
try:
203-
with bot_instance.leaderboard_db as db:
206+
with db_context as db:
204207
if is_reset:
205208
db.reset_user_from_cli(user_id, cli_id, auth_provider)
206209
else:
@@ -223,7 +226,10 @@ async def cli_auth(auth_provider: str, code: str, state: str): # noqa: C901
223226

224227

225228
async def _stream_submission_response(
226-
submission_request, user_info, submission_mode_enum, bot_instance
229+
submission_request: SubmissionRequest,
230+
user_info: dict,
231+
submission_mode_enum: SubmissionMode,
232+
backend: KernelBackend,
227233
):
228234
start_time = time.time()
229235
task: asyncio.Task | None = None
@@ -233,15 +239,15 @@ async def _stream_submission_response(
233239
submission_request,
234240
user_info,
235241
submission_mode_enum,
236-
bot_instance,
242+
backend,
237243
)
238244
)
239245

240246
while not task.done():
241247
elapsed_time = time.time() - start_time
242248
yield f"event: status\ndata: {json.dumps({'status': 'processing',
243-
'elapsed_time': round(elapsed_time, 2)}
244-
,default=json_serializer)}\n\n"
249+
'elapsed_time': round(elapsed_time, 2)},
250+
default=json_serializer)}\n\n"
245251

246252
try:
247253
await asyncio.wait_for(asyncio.shield(task), timeout=15.0)
@@ -292,6 +298,7 @@ async def run_submission( # noqa: C901
292298
submission_mode: str,
293299
file: UploadFile,
294300
user_info: Annotated[dict, Depends(validate_cli_header)],
301+
db_context=Depends(get_db),
295302
) -> StreamingResponse:
296303
"""An endpoint that runs a submission on a given leaderboard, runner, and GPU type.
297304
Streams status updates and the final result via Server-Sent Events (SSE).
@@ -339,18 +346,8 @@ async def run_submission( # noqa: C901
339346
detail=f"Submission mode '{submission_mode}' is not supported for this endpoint",
340347
)
341348

342-
if not bot_instance:
343-
raise HTTPException(
344-
status_code=503, detail="Service temporarily unavailable: Bot not initialized"
345-
)
346-
347349
try:
348-
with bot_instance.leaderboard_db as db:
349-
if db is None:
350-
raise HTTPException(
351-
status_code=503,
352-
detail="Service temporarily unavailable: Database connection failed",
353-
)
350+
with db_context as db:
354351
leaderboard_item = db.get_leaderboard(leaderboard_name)
355352
if not leaderboard_item:
356353
all_leaderboards = [lb["name"] for lb in db.get_leaderboards()]
@@ -417,14 +414,14 @@ async def run_submission( # noqa: C901
417414
submission_request=submission_request,
418415
user_info={"user_id": user_id, "user_name": user_name},
419416
submission_mode_enum=submission_mode_enum,
420-
bot_instance=bot_instance,
417+
backend=backend_instance,
421418
)
422419

423420
return StreamingResponse(generator, media_type="text/event-stream")
424421

425422

426423
@app.get("/leaderboards")
427-
async def get_leaderboards():
424+
async def get_leaderboards(db_context=Depends(get_db)):
428425
"""An endpoint that returns all leaderboards.
429426
430427
Returns:
@@ -433,19 +430,15 @@ async def get_leaderboards():
433430
and the GPU types that are available for submissions.
434431
"""
435432
await simple_rate_limit()
436-
if not bot_instance or not hasattr(bot_instance, "leaderboard_db"):
437-
raise HTTPException(status_code=500, detail="Bot instance or database not initialized")
438433
try:
439-
with bot_instance.leaderboard_db as db:
440-
if db is None:
441-
raise HTTPException(status_code=500, detail="Database connection failed")
434+
with db_context as db:
442435
return db.get_leaderboards()
443436
except Exception as e:
444437
raise HTTPException(status_code=500, detail=f"Error fetching leaderboards: {e}") from e
445438

446439

447440
@app.get("/gpus/{leaderboard_name}")
448-
async def get_gpus(leaderboard_name: str) -> list[str]:
441+
async def get_gpus(leaderboard_name: str, db_context=Depends(get_db)) -> list[str]:
449442
"""An endpoint that returns all GPU types that are available for a given leaderboard and runner.
450443
451444
Args:
@@ -456,14 +449,8 @@ async def get_gpus(leaderboard_name: str) -> list[str]:
456449
list[str]: A list of GPU types that are available for the given leaderboard and runner.
457450
"""
458451
await simple_rate_limit()
459-
if not bot_instance or not hasattr(bot_instance, "leaderboard_db"):
460-
raise HTTPException(status_code=500, detail="Bot instance or database not initialized")
461-
462452
try:
463-
with bot_instance.leaderboard_db as db:
464-
if db is None:
465-
raise HTTPException(status_code=500, detail="Database connection failed")
466-
453+
with db_context as db:
467454
# Validate leaderboard exists first
468455
leaderboard_names = [x["name"] for x in db.get_leaderboards()]
469456
if leaderboard_name not in leaderboard_names:
@@ -482,15 +469,15 @@ async def get_gpus(leaderboard_name: str) -> list[str]:
482469

483470
@app.get("/submissions/{leaderboard_name}/{gpu_name}")
484471
async def get_submissions(
485-
leaderboard_name: str, gpu_name: str, limit: int = None, offset: int = 0
472+
leaderboard_name: str,
473+
gpu_name: str,
474+
limit: int = None,
475+
offset: int = 0,
476+
db_context=Depends(get_db),
486477
) -> list[LeaderboardRankedEntry]:
487478
await simple_rate_limit()
488-
if not bot_instance or not hasattr(bot_instance, "leaderboard_db"):
489-
raise HTTPException(status_code=500, detail="Bot instance or database not initialized")
490479
try:
491-
with bot_instance.leaderboard_db as db:
492-
if db is None:
493-
raise HTTPException(status_code=500, detail="Database connection failed")
480+
with db_context as db:
494481
# Add validation for leaderboard and GPU? Might be redundant if DB handles it.
495482
return db.get_leaderboard_submissions(
496483
leaderboard_name, gpu_name, limit=limit, offset=offset
@@ -500,15 +487,13 @@ async def get_submissions(
500487

501488

502489
@app.get("/submission_count/{leaderboard_name}/{gpu_name}")
503-
async def get_submission_count(leaderboard_name: str, gpu_name: str, user_id: str = None) -> dict:
490+
async def get_submission_count(
491+
leaderboard_name: str, gpu_name: str, user_id: str = None, db_context=Depends(get_db)
492+
) -> dict:
504493
"""Get the total count of submissions for pagination"""
505494
await simple_rate_limit()
506-
if not bot_instance or not hasattr(bot_instance, "leaderboard_db"):
507-
raise HTTPException(status_code=500, detail="Bot instance or database not initialized")
508495
try:
509-
with bot_instance.leaderboard_db as db:
510-
if db is None:
511-
raise HTTPException(status_code=500, detail="Database connection failed")
496+
with db_context as db:
512497
count = db.get_leaderboard_submission_count(leaderboard_name, gpu_name, user_id)
513498
return {"count": count}
514499
except Exception as e:

src/discord-cluster-manager/api/utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import List
44

55
import requests
6+
from backend import KernelBackend
67
from consts import SubmissionMode, get_gpu_by_name
78
from env import (
89
CLI_DISCORD_CLIENT_ID,
@@ -139,23 +140,21 @@ async def _handle_github_oauth(code: str, redirect_uri: str) -> tuple[str, str]:
139140

140141

141142
async def _run_submission(
142-
submission: SubmissionRequest, user_info: dict, mode: SubmissionMode, bot
143+
submission: SubmissionRequest, user_info: dict, mode: SubmissionMode, backend: KernelBackend
143144
):
144145
try:
145-
req = prepare_submission(submission, bot.leaderboard_db)
146+
req = prepare_submission(submission, backend.db)
146147
except Exception as e:
147148
raise HTTPException(status_code=400, detail=str(e)) from e
148149

149150
selected_gpus = [get_gpu_by_name(gpu) for gpu in req.gpus]
150151
if len(selected_gpus) > 1 or selected_gpus[0] is None:
151152
raise HTTPException(status_code=400, detail="Invalid GPU type")
152153

153-
command = bot.get_cog("SubmitCog").submit_leaderboard
154-
155154
user_name = user_info["user_name"]
156155
user_id = user_info["user_id"]
157156

158-
with bot.leaderboard_db as db:
157+
with backend.db as db:
159158
sub_id = db.create_submission(
160159
leaderboard=req.leaderboard,
161160
file_name=submission.file_name,
@@ -176,7 +175,7 @@ def add_reporter(title: str):
176175

177176
try:
178177
tasks = [
179-
command(
178+
backend.submit_leaderboard(
180179
sub_id,
181180
submission.code,
182181
submission.file_name,
@@ -190,7 +189,7 @@ def add_reporter(title: str):
190189

191190
if mode == SubmissionMode.LEADERBOARD:
192191
tasks += [
193-
command(
192+
backend.submit_leaderboard(
194193
sub_id,
195194
submission.code,
196195
submission.file_name,
@@ -204,7 +203,7 @@ def add_reporter(title: str):
204203

205204
results = await asyncio.gather(*tasks)
206205
finally:
207-
with bot.leaderboard_db as db:
206+
with backend.db as db:
208207
db.mark_submission_done(sub_id)
209208

210209
return results, [rep.get_message() + "\n" + rep.long_report for rep in reporters]

0 commit comments

Comments
 (0)