@@ -346,18 +346,35 @@ async def _update_message(self):
346346
347347
348348class RunProgressReporter :
349+ def __init__ (self , title : str ):
350+ self .title = title
351+ self .lines = []
352+
349353 async def push (self , content : str | list [str ]):
350- raise NotImplementedError ()
354+ if isinstance (content , str ):
355+ self .lines .append (f"> { content } " )
356+ else :
357+ for line in content :
358+ self .lines .append (f"> { line } " )
359+ await self ._update_message ()
351360
352361 async def update (self , new_content : str ):
353- raise NotImplementedError ()
362+ self .lines [- 1 ] = f"> { new_content } "
363+ await self ._update_message ()
354364
355365 async def update_title (self , new_title ):
356- raise NotImplementedError ()
366+ self .title = new_title
367+ await self ._update_message ()
368+
369+ def get_message (self ):
370+ return str .join ("\n " , [f"**{ self .title } **" ] + self .lines )
357371
358372 async def generate_report (self , title : str , runs : dict [str , EvalResult ]):
359373 raise NotImplementedError ()
360374
375+ async def _update_message (self ):
376+ raise NotImplementedError ()
377+
361378
362379class RunProgressReporterDiscord (RunProgressReporter ):
363380 def __init__ (
@@ -366,33 +383,13 @@ def __init__(
366383 interaction : discord .Interaction ,
367384 title : str ,
368385 ):
369- self .title = title
370- self .lines = []
386+ super ().__init__ (title = title )
371387 self .root = root
372388 self .interaction = interaction
373389
374- async def push (self , content : str | list [str ]):
375- if isinstance (content , str ):
376- self .lines .append (f"> { content } " )
377- else :
378- for line in content :
379- self .lines .append (f"> { line } " )
380- await self ._update_message ()
381-
382- async def update (self , new_content : str ):
383- self .lines [- 1 ] = f"> { new_content } "
384- await self ._update_message ()
385-
386- async def update_title (self , new_title ):
387- self .title = new_title
388- await self ._update_message ()
389-
390390 async def _update_message (self ):
391391 await self .root ._update_message ()
392392
393- def get_message (self ):
394- return str .join ("\n " , [f"**{ self .title } **" ] + self .lines )
395-
396393 async def generate_report (self , title : str , runs : dict [str , EvalResult ]):
397394 thread = await self .interaction .channel .create_thread (
398395 name = title ,
@@ -405,17 +402,10 @@ async def generate_report(self, title: str, runs: dict[str, EvalResult]):
405402
406403
407404class RunProgressReporterAPI (RunProgressReporter ):
408- def __init__ (self ):
409- self .title = ""
410- self .lines = []
405+ def __init__ (self , title : str ):
406+ super ().__init__ (title = title )
411407
412- async def push (self , content : str | list [str ]):
413- pass
414-
415- async def update (self , new_content : str ):
416- pass
417-
418- async def update_title (self , new_title ):
408+ async def _update_message (self ):
419409 pass
420410
421411 async def generate_report (self , title : str , runs : dict [str , EvalResult ]):
0 commit comments