Skip to content

Commit e98f257

Browse files
authored
Feat: streaming response (#249)
1 parent 365879e commit e98f257

1 file changed

Lines changed: 155 additions & 32 deletions

File tree

  • src/discord-cluster-manager/api

src/discord-cluster-manager/api/main.py

Lines changed: 155 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import base64
3+
import datetime
34
import json
45
import os
56
import time
@@ -8,6 +9,7 @@
89

910
from consts import SubmissionMode
1011
from fastapi import Depends, FastAPI, Header, HTTPException, UploadFile
12+
from fastapi.responses import StreamingResponse
1113
from submission import SubmissionRequest
1214
from utils import LeaderboardRankedEntry
1315

@@ -16,6 +18,13 @@
1618
app = FastAPI()
1719

1820

21+
def json_serializer(obj):
22+
"""JSON serializer for objects not serializable by default json code"""
23+
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
24+
return obj.isoformat()
25+
raise TypeError(f"Type {type(obj)} not serializable")
26+
27+
1928
bot_instance = None
2029

2130
_last_action = time.time()
@@ -213,15 +222,75 @@ async def cli_auth(auth_provider: str, code: str, state: str): # noqa: C901
213222
}
214223

215224

225+
async def _stream_submission_response(
226+
submission_request, user_info, submission_mode_enum, bot_instance
227+
):
228+
start_time = time.time()
229+
task: asyncio.Task | None = None
230+
try:
231+
task = asyncio.create_task(
232+
_run_submission(
233+
submission_request,
234+
user_info,
235+
submission_mode_enum,
236+
bot_instance,
237+
)
238+
)
239+
240+
while not task.done():
241+
elapsed_time = time.time() - start_time
242+
yield f"event: status\ndata: {json.dumps({'status': 'processing',
243+
'elapsed_time': round(elapsed_time, 2)}
244+
,default=json_serializer)}\n\n"
245+
246+
try:
247+
await asyncio.wait_for(asyncio.shield(task), timeout=15.0)
248+
except asyncio.TimeoutError:
249+
continue
250+
except asyncio.CancelledError:
251+
yield f"event: error\ndata: {json.dumps(
252+
{'status': 'error', 'detail': 'Submission cancelled'},
253+
default=json_serializer)}\n\n"
254+
return
255+
256+
result = await task
257+
result_data = {"status": "success", "results": [asdict(r) for r in result]}
258+
yield f"event: result\ndata: {json.dumps(result_data, default=json_serializer)}\n\n"
259+
260+
except HTTPException as http_exc:
261+
error_data = {
262+
"status": "error",
263+
"detail": http_exc.detail,
264+
"status_code": http_exc.status_code,
265+
}
266+
yield f"event: error\ndata: {json.dumps(error_data, default=json_serializer)}\n\n"
267+
except Exception as e:
268+
error_type = type(e).__name__
269+
error_data = {
270+
"status": "error",
271+
"detail": f"An unexpected error occurred: {error_type}",
272+
"raw_error": str(e),
273+
}
274+
yield f"event: error\ndata: {json.dumps(error_data, default=json_serializer)}\n\n"
275+
finally:
276+
if task and not task.done():
277+
task.cancel()
278+
try:
279+
await task
280+
except asyncio.CancelledError:
281+
pass
282+
283+
216284
@app.post("/{leaderboard_name}/{gpu_type}/{submission_mode}")
217285
async def run_submission( # noqa: C901
218286
leaderboard_name: str,
219287
gpu_type: str,
220288
submission_mode: str,
221289
file: UploadFile,
222-
user_info: Annotated[dict, Depends(validate_cli_header)], # Apply dependency
223-
) -> dict:
290+
user_info: Annotated[dict, Depends(validate_cli_header)],
291+
) -> StreamingResponse:
224292
"""An endpoint that runs a submission on a given leaderboard, runner, and GPU type.
293+
Streams status updates and the final result via Server-Sent Events (SSE).
225294
226295
Requires a valid X-Popcorn-Cli-Id header.
227296
@@ -236,64 +305,118 @@ async def run_submission( # noqa: C901
236305
HTTPException: If the bot is not initialized, or header/input is invalid.
237306
238307
Returns:
239-
dict: A dictionary containing the status of the submission and the result.
240-
See class `FullResult` for more details.
308+
StreamingResponse: A streaming response containing the status and results of the submission.
241309
"""
242310
await simple_rate_limit()
243311
user_name = user_info["user_name"]
244312
user_id = user_info["user_id"]
245313

246-
submission_mode_enum: SubmissionMode = SubmissionMode(submission_mode.lower())
314+
try:
315+
submission_mode_enum: SubmissionMode = SubmissionMode(submission_mode.lower())
316+
except ValueError:
317+
raise HTTPException(
318+
status_code=400, detail=f"Invalid submission mode value: '{submission_mode}'"
319+
) from None
320+
247321
if submission_mode_enum in [SubmissionMode.PROFILE]:
248-
raise HTTPException(status_code=400, detail="Profile submissions are not supported yet")
322+
raise HTTPException(
323+
status_code=400, detail="Profile submissions are not currently supported via API"
324+
)
249325

250-
if submission_mode_enum not in [
326+
allowed_modes = [
251327
SubmissionMode.TEST,
252328
SubmissionMode.BENCHMARK,
253329
SubmissionMode.SCRIPT,
254330
SubmissionMode.LEADERBOARD,
255-
]:
256-
raise HTTPException(status_code=400, detail="Invalid submission mode")
331+
]
332+
if submission_mode_enum not in allowed_modes:
333+
raise HTTPException(
334+
status_code=400,
335+
detail=f"Submission mode '{submission_mode}' is not supported for this endpoint",
336+
)
257337

258338
if not bot_instance:
259-
raise HTTPException(status_code=500, detail="Bot not initialized")
339+
raise HTTPException(
340+
status_code=503, detail="Service temporarily unavailable: Bot not initialized"
341+
)
260342

261343
try:
262344
with bot_instance.leaderboard_db as db:
263345
if db is None:
264-
raise HTTPException(status_code=500, detail="Database connection failed")
265-
if not (leaderboard_item := db.get_leaderboard(leaderboard_name)):
266-
raise HTTPException(status_code=400, detail="Invalid leaderboard name")
267-
268-
gpus = leaderboard_item["gpu_types"]
346+
raise HTTPException(
347+
status_code=503,
348+
detail="Service temporarily unavailable: Database connection failed",
349+
)
350+
leaderboard_item = db.get_leaderboard(leaderboard_name)
351+
if not leaderboard_item:
352+
all_leaderboards = [lb["name"] for lb in db.get_leaderboards()]
353+
if leaderboard_name not in all_leaderboards:
354+
raise HTTPException(
355+
status_code=404, detail=f"Leaderboard '{leaderboard_name}' not found."
356+
)
357+
else:
358+
raise HTTPException(
359+
status_code=500,
360+
detail=f"Error retrieving details for leaderboard '{leaderboard_name}'.",
361+
)
362+
363+
gpus = leaderboard_item.get("gpu_types", [])
269364
if gpu_type not in gpus:
365+
supported_gpus = ", ".join(gpus) if gpus else "None"
270366
raise HTTPException(
271-
status_code=400, detail="This GPU is not supported for this leaderboard"
367+
status_code=400,
368+
detail=f"GPU type '{gpu_type}' is not supported for "
369+
f"leaderboard '{leaderboard_name}'. Supported GPUs: {supported_gpus}",
272370
)
371+
except HTTPException:
372+
raise
273373
except Exception as e:
274-
raise HTTPException(status_code=500, detail=f"Error fetching leaderboard data: {e}") from e
374+
raise HTTPException(
375+
status_code=500, detail=f"Internal server error while validating leaderboard/GPU: {e}"
376+
) from e
275377

276378
try:
277379
submission_content = await file.read()
380+
if not submission_content:
381+
raise HTTPException(
382+
status_code=400, detail="Empty file submitted. Please provide a file with code."
383+
)
384+
if len(submission_content) > 1_000_000:
385+
raise HTTPException(
386+
status_code=413, detail="Submission file is too large (limit: 1MB)."
387+
)
388+
389+
except HTTPException:
390+
raise
278391
except Exception as e:
279-
raise HTTPException(status_code=400, detail=f"Error building task config: {e}") from e
280-
281-
submission_request = SubmissionRequest(
282-
code=submission_content.decode("utf-8"),
283-
file_name=file.filename,
284-
user_id=user_id,
285-
gpus=[gpu_type],
286-
leaderboard=leaderboard_name,
287-
)
392+
raise HTTPException(status_code=400, detail=f"Error reading submission file: {e}") from e
393+
394+
try:
395+
submission_code = submission_content.decode("utf-8")
396+
submission_request = SubmissionRequest(
397+
code=submission_code,
398+
file_name=file.filename or "submission.py",
399+
user_id=user_id,
400+
gpus=[gpu_type],
401+
leaderboard=leaderboard_name,
402+
)
403+
except UnicodeDecodeError:
404+
raise HTTPException(
405+
status_code=400, detail="Failed to decode submission file content as UTF-8."
406+
) from None
407+
except Exception as e:
408+
raise HTTPException(
409+
status_code=500, detail=f"Internal server error creating submission request: {e}"
410+
) from e
288411

289-
result = await _run_submission(
290-
submission_request,
291-
{"user_id": user_id, "user_name": user_name},
292-
submission_mode_enum,
293-
bot_instance,
412+
generator = _stream_submission_response(
413+
submission_request=submission_request,
414+
user_info={"user_id": user_id, "user_name": user_name},
415+
submission_mode_enum=submission_mode_enum,
416+
bot_instance=bot_instance,
294417
)
295418

296-
return {"status": "success", "results": [asdict(r) for r in result]}
419+
return StreamingResponse(generator, media_type="text/event-stream")
297420

298421

299422
@app.get("/leaderboards")

0 commit comments

Comments
 (0)