Skip to content

Commit 930635c

Browse files
committed
rearrange reporting code
1 parent 3892756 commit 930635c

2 files changed

Lines changed: 48 additions & 33 deletions

File tree

src/discord-cluster-manager/cogs/submit_cog.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
from consts import GPU, GPU_TO_SM, RankCriterion, SubmissionMode, get_gpu_by_name
1212
from discord import app_commands
1313
from discord.ext import commands
14-
from report import MultiProgressReporter, RunProgressReporter, make_short_report
14+
from report import (
15+
MultiProgressReporter,
16+
RunProgressReporter,
17+
generate_report,
18+
make_short_report,
19+
)
1520
from run_eval import FullResult
1621
from task import LeaderboardTask
1722
from utils import (
@@ -224,8 +229,9 @@ async def _handle_submission(
224229
# does the last message of the short report start with ✅ or ❌?
225230
verdict = short_report[-1][0]
226231
id_str = f"{verdict}" if submission_id == -1 else f"{verdict} #{submission_id}"
227-
await reporter.generate_report(
228-
f"{id_str} {name} on {gpu_type.name} ({launcher.name})", result
232+
await reporter.display_report(
233+
f"{id_str} {name} on {gpu_type.name} ({launcher.name})",
234+
generate_report(result),
229235
)
230236
except Exception as E:
231237
logger.error("Error generating report. Result: %s", result, exc_info=E)

src/discord-cluster-manager/report.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -260,80 +260,86 @@ def generate_system_info(system: SystemInfo):
260260
"""
261261

262262

263-
def generate_report(reporter: "RunResultReport", runs: dict[str, EvalResult]): # noqa: C901
263+
def generate_report(result: FullResult) -> RunResultReport: # noqa: C901
264+
runs = result.runs
265+
report = RunResultReport()
266+
report.add_text(generate_system_info(result.system))
267+
264268
if "test" in runs:
265269
test_run = runs["test"]
266270

267271
if test_run.compilation is not None and not test_run.compilation.success:
268-
_generate_compile_report(reporter, test_run.compilation)
269-
return
272+
_generate_compile_report(report, test_run.compilation)
273+
return report
270274

271275
test_run = test_run.run
272276

273277
if not test_run.success:
274-
_generate_crash_report(reporter, test_run)
275-
return
278+
_generate_crash_report(report, test_run)
279+
return report
276280

277281
if not test_run.passed:
278-
_generate_test_report(reporter, test_run)
279-
return
282+
_generate_test_report(report, test_run)
283+
return report
280284
else:
281285
num_tests = int(test_run.result.get("test-count", 0))
282-
reporter.add_log(f"✅ Passed {num_tests}/{num_tests} tests", make_test_log(test_run))
286+
report.add_log(f"✅ Passed {num_tests}/{num_tests} tests", make_test_log(test_run))
283287

284288
if "benchmark" in runs:
285289
bench_run = runs["benchmark"]
286290
if bench_run.compilation is not None and not bench_run.compilation.success:
287-
_generate_compile_report(reporter, bench_run.compilation)
288-
return
291+
_generate_compile_report(report, bench_run.compilation)
292+
return report
289293

290294
bench_run = bench_run.run
291295
if not bench_run.success:
292-
_generate_crash_report(reporter, bench_run)
293-
return
296+
_generate_crash_report(report, bench_run)
297+
return report
294298

295-
reporter.add_log(
299+
report.add_log(
296300
"Benchmarks",
297301
make_benchmark_log(bench_run),
298302
)
299303

300304
if "leaderboard" in runs:
301305
bench_run = runs["leaderboard"]
302306
if bench_run.compilation is not None and not bench_run.compilation.success:
303-
_generate_compile_report(reporter, bench_run.compilation)
304-
return
307+
_generate_compile_report(report, bench_run.compilation)
308+
return report
305309

306310
bench_run = bench_run.run
307311
if not bench_run.success:
308-
_generate_crash_report(reporter, bench_run)
309-
return
312+
_generate_crash_report(report, bench_run)
313+
return report
310314

311-
reporter.add_log(
315+
report.add_log(
312316
"Ranked Benchmark",
313317
make_benchmark_log(bench_run),
314318
)
315319

316320
if "script" in runs:
317321
run = runs["script"]
318322
if run.compilation is not None and not run.compilation.success:
319-
_generate_compile_report(reporter, run.compilation)
320-
return
323+
_generate_compile_report(report, run.compilation)
324+
return report
321325

322326
run = run.run
323327
# OK, we were successful
324328
message = "# Success!\n"
325329
message += "Command "
326330
message += f"```bash\n{_limit_length(run.command, 1000)}```\n"
327331
message += f"ran successfully in {run.duration:.2} seconds.\n"
328-
reporter.add_text(message)
332+
report.add_text(message)
329333

330334
if len(runs) == 1:
331335
run = next(iter(runs.values()))
332336
if len(run.run.stderr.strip()) > 0:
333-
reporter.add_log("Program stderr", run.run.stderr.strip())
337+
report.add_log("Program stderr", run.run.stderr.strip())
334338

335339
if len(run.run.stdout.strip()) > 0:
336-
reporter.add_log("Program stdout", run.run.stdout.strip())
340+
report.add_log("Program stdout", run.run.stdout.strip())
341+
342+
return report
337343

338344

339345
class MultiProgressReporter:
@@ -389,7 +395,7 @@ async def update_title(self, new_title):
389395
def get_message(self):
390396
return str.join("\n", [f"**{self.title}**"] + self.lines)
391397

392-
async def generate_report(self, title: str, result: FullResult):
398+
async def display_report(self, title: str, report: RunResultReport):
393399
raise NotImplementedError()
394400

395401
async def _update_message(self):
@@ -410,16 +416,13 @@ def __init__(
410416
async def _update_message(self):
411417
await self.root._update_message()
412418

413-
async def generate_report(self, title: str, result: FullResult):
419+
async def display_report(self, title: str, report: RunResultReport):
414420
thread = await self.interaction.channel.create_thread(
415421
name=title,
416422
type=discord.ChannelType.private_thread,
417423
auto_archive_duration=1440,
418424
)
419425
await thread.add_user(self.interaction.user)
420-
report = RunResultReport()
421-
report.add_text(generate_system_info(result.system))
422-
generate_report(report, result.runs)
423426
message = ""
424427
for part in report.data:
425428
if isinstance(part, Text):
@@ -439,9 +442,15 @@ async def generate_report(self, title: str, result: FullResult):
439442
class RunProgressReporterAPI(RunProgressReporter):
440443
def __init__(self, title: str):
441444
super().__init__(title=title)
445+
self.long_report = ""
442446

443447
async def _update_message(self):
444448
pass
445449

446-
async def generate_report(self, title: str, runs: dict[str, EvalResult]):
447-
pass
450+
async def display_report(self, title: str, report: RunResultReport):
451+
for part in report.data:
452+
if isinstance(part, Text):
453+
self.long_report += part.text
454+
elif isinstance(part, Log):
455+
self.long_report += f"\n\n## {part.header}:\n"
456+
self.long_report += f"```\n{part.content}```"

0 commit comments

Comments
 (0)