Skip to content

Commit d9e86b4

Browse files
committed
move common code into generic RunProgressReporter
1 parent 4360e5e commit d9e86b4

2 files changed

Lines changed: 26 additions & 36 deletions

File tree

src/discord-cluster-manager/api/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ async def _run_submission(
186186
submission.code,
187187
submission.file_name,
188188
gpu,
189-
RunProgressReporterAPI(),
189+
RunProgressReporterAPI(f"{gpu.name} on {gpu.runner}"),
190190
req.task,
191191
mode,
192192
None,
@@ -200,7 +200,7 @@ async def _run_submission(
200200
submission.code,
201201
submission.file_name,
202202
gpu,
203-
RunProgressReporterAPI(),
203+
RunProgressReporterAPI(f"{gpu.name} on {gpu.runner} (secret)"),
204204
req.task,
205205
SubmissionMode.PRIVATE,
206206
req.secret_seed,

src/discord-cluster-manager/report.py

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -346,18 +346,35 @@ async def _update_message(self):
346346

347347

348348
class RunProgressReporter:
349+
def __init__(self, title: str):
350+
self.title = title
351+
self.lines = []
352+
349353
async def push(self, content: str | list[str]):
350-
raise NotImplementedError()
354+
if isinstance(content, str):
355+
self.lines.append(f"> {content}")
356+
else:
357+
for line in content:
358+
self.lines.append(f"> {line}")
359+
await self._update_message()
351360

352361
async def update(self, new_content: str):
353-
raise NotImplementedError()
362+
self.lines[-1] = f"> {new_content}"
363+
await self._update_message()
354364

355365
async def update_title(self, new_title):
356-
raise NotImplementedError()
366+
self.title = new_title
367+
await self._update_message()
368+
369+
def get_message(self):
370+
return str.join("\n", [f"**{self.title}**"] + self.lines)
357371

358372
async def generate_report(self, title: str, runs: dict[str, EvalResult]):
359373
raise NotImplementedError()
360374

375+
async def _update_message(self):
376+
raise NotImplementedError()
377+
361378

362379
class RunProgressReporterDiscord(RunProgressReporter):
363380
def __init__(
@@ -366,33 +383,13 @@ def __init__(
366383
interaction: discord.Interaction,
367384
title: str,
368385
):
369-
self.title = title
370-
self.lines = []
386+
super().__init__(title=title)
371387
self.root = root
372388
self.interaction = interaction
373389

374-
async def push(self, content: str | list[str]):
375-
if isinstance(content, str):
376-
self.lines.append(f"> {content}")
377-
else:
378-
for line in content:
379-
self.lines.append(f"> {line}")
380-
await self._update_message()
381-
382-
async def update(self, new_content: str):
383-
self.lines[-1] = f"> {new_content}"
384-
await self._update_message()
385-
386-
async def update_title(self, new_title):
387-
self.title = new_title
388-
await self._update_message()
389-
390390
async def _update_message(self):
391391
await self.root._update_message()
392392

393-
def get_message(self):
394-
return str.join("\n", [f"**{self.title}**"] + self.lines)
395-
396393
async def generate_report(self, title: str, runs: dict[str, EvalResult]):
397394
thread = await self.interaction.channel.create_thread(
398395
name=title,
@@ -405,17 +402,10 @@ async def generate_report(self, title: str, runs: dict[str, EvalResult]):
405402

406403

407404
class RunProgressReporterAPI(RunProgressReporter):
408-
def __init__(self):
409-
self.title = ""
410-
self.lines = []
405+
def __init__(self, title: str):
406+
super().__init__(title=title)
411407

412-
async def push(self, content: str | list[str]):
413-
pass
414-
415-
async def update(self, new_content: str):
416-
pass
417-
418-
async def update_title(self, new_title):
408+
async def _update_message(self):
419409
pass
420410

421411
async def generate_report(self, title: str, runs: dict[str, EvalResult]):

0 commit comments

Comments
 (0)