3434
3535import yaml
3636import swanlab
37+ from swanlab import echarts
3738from swanlab .log import swanlog as swl
3839from swanlab .data .porter import DataPorter
3940from swanlab .env import create_time
@@ -143,6 +144,28 @@ def _proto_items_to_dict(self, items) -> dict:
143144 swl .warning (f"Could not decode json for key '{ key } ': { item .value_json } " )
144145 return mapping
145146
def _validate_path(self, base_dir: str, file_path: str) -> Optional[str]:
    """Resolve ``file_path`` against ``base_dir`` and reject paths that escape it.

    The check is lexical: both paths are normalised with ``os.path.abspath``
    (which collapses ``..`` segments) and the candidate is accepted only if
    ``base_dir`` is a path-component ancestor of it (or the same path).
    Symbolic links are not resolved.

    Args:
        base_dir: Directory that the returned path must stay inside.
        file_path: Path to validate, normally relative to ``base_dir``.
            An absolute path is accepted only if it already lies in
            ``base_dir``.

    Returns:
        The absolute, normalised path when it is inside ``base_dir``;
        ``None`` when ``file_path`` is empty, is not a string/path-like
        object, or points outside ``base_dir``.
    """
    # Values parsed from JSON may be missing or of an unexpected type
    # (number, list, ...); there is nothing meaningful to validate then.
    if not file_path or not isinstance(file_path, (str, os.PathLike)):
        return None

    abs_base = os.path.abspath(base_dir)
    abs_path = os.path.abspath(os.path.join(abs_base, os.fspath(file_path)))

    # Compare whole path components instead of string prefixes. A prefix test
    # against ``abs_base + os.sep`` wrongly rejects everything when the base is
    # a filesystem root ("/" + "/" == "//") and is case-sensitive even on
    # case-insensitive filesystems.
    try:
        common = os.path.commonpath([abs_base, abs_path])
    except ValueError:
        # Raised e.g. on Windows when the two paths live on different drives.
        common = None

    if common is None or os.path.normcase(common) != os.path.normcase(abs_base):
        swl.warning(f"Path traversal attempt detected: {file_path}")
        return None
    return abs_path
157+
def _filter_text_columns(self, columns, data):
    """Drop every column that holds a typed (media) cell in at least one row.

    A cell counts as non-text when it is a ``dict`` containing a ``"_type"``
    key. A column index is removed from the header and from every row as soon
    as any single row has such a cell at that index.

    Args:
        columns: Sequence of column headers.
        data: Sequence of rows, each row being a sequence of cells.

    Returns:
        A ``(columns, rows)`` tuple of new lists containing only the kept
        columns. Rows shorter than the header simply contribute the cells
        they have.
    """
    # Collect the indices of all columns that contain at least one media cell.
    media_indices = set()
    for row in data:
        for idx, cell in enumerate(row):
            if isinstance(cell, dict) and "_type" in cell:
                media_indices.add(idx)

    kept_indices = [idx for idx in range(len(columns)) if idx not in media_indices]
    kept_columns = [columns[idx] for idx in kept_indices]

    kept_rows = []
    for row in data:
        width = len(row)
        # Skip indices the row does not have so ragged rows do not raise.
        kept_rows.append([row[idx] for idx in kept_indices if idx < width])

    return kept_columns, kept_rows
168+
146169 def _parse_run (self , run_dir : str ):
147170 """Parses a single wandb run directory and converts it to a SwanLab run."""
148171 wandb_files = glob .glob (os .path .join (run_dir , "*.wandb" ))
@@ -166,38 +189,38 @@ def _parse_run(self, run_dir: str):
166189
167190 # Direct-write state (bypasses swanlab_run.log() overhead for scalar metrics)
168191 porter = None # set after swanlab.init()
169- _conv_column_kids = {} # key -> kid str
170- _conv_epoch_counters = {} # key -> cumulative epoch count
171- _upload_pre_count = [ 0 ] # items uploaded before upload progress bar appears
172- _total_scalars_logged = [ 0 ] # total scalar data points logged
192+ column_kids = {} # key -> kid str
193+ epoch_counters = {} # key -> cumulative epoch count
194+ upload_pre_count = 0 # items uploaded before upload progress bar appears
195+ total_scalars_logged = 0 # total scalar data points logged
173196
174197 def _log_scalars_direct (scalars , step ):
175198 """Write float scalars directly to the porter, bypassing log() overhead."""
199+ nonlocal total_scalars_logged
176200 if not scalars or porter is None :
177201 return
178202 for key in scalars :
179- if key not in _conv_column_kids :
180- kid = len (_conv_column_kids )
181- _conv_column_kids [key ] = str (kid )
203+ if key not in column_kids :
204+ kid = len (column_kids )
205+ column_kids [key ] = str (kid )
182206 split_key = key .split ("/" )
183207 sname = split_key [0 ] if len (split_key ) > 1 and split_key [0 ] else None
184208 porter .trace_column (ColumnInfo (
185209 key = key , kid = str (kid ), name = key , cls = 'CUSTOM' ,
186210 chart_type = ChartType .LINE , chart_reference = 'STEP' ,
187211 section_name = sname , section_type = "PUBLIC" ,
188212 ))
189- for key in scalars :
190- _conv_epoch_counters [key ] = _conv_epoch_counters .get (key , 0 ) + 1
191- _total_scalars_logged [0 ] += len (scalars )
192- porter .trace_scalars_step (step , scalars , dict (_conv_epoch_counters ), create_time ())
213+ epoch_counters [key ] = epoch_counters .get (key , 0 ) + 1
214+ total_scalars_logged += len (scalars )
215+ porter .trace_scalars_step (step , scalars , dict (epoch_counters ), create_time ())
193216
194217 def _finish_with_progress ():
195218 """Run swanlab_run.finish() while showing a Rich upload progress bar."""
196219 _pool = porter ._pool if porter is not None else None
197220 if _pool is None :
198221 swanlab_run .finish ()
199222 return
200- total = len (_conv_column_kids ) + _total_scalars_logged [ 0 ]
223+ total = len (column_kids ) + total_scalars_logged
201224 up = Progress (
202225 TextColumn ("[bold green]{task.description}" ),
203226 BarColumn (bar_width = 40 ),
@@ -206,32 +229,34 @@ def _finish_with_progress():
206229 TimeRemainingColumn (),
207230 )
208231 up .start ()
209- t = up .add_task ("Uploading to SwanLab" , total = total , completed = _upload_pre_count [ 0 ] )
232+ t = up .add_task ("Uploading to SwanLab" , total = total , completed = upload_pre_count )
210233
211- _last_completed = [ _upload_pre_count [ 0 ]]
212- _stall_check_time = [ time .time ()]
234+ last_completed = upload_pre_count
235+ stall_check_time = time .time ()
213236
214237 def _upload_cb (n ):
215- _last_completed [0 ] += n
216- _stall_check_time [0 ] = time .time ()
238+ nonlocal last_completed , stall_check_time
239+ last_completed += n
240+ stall_check_time = time .time ()
217241 up .update (t , advance = n )
218242
219243 _pool .collector .upload_callback = _upload_cb
220244
221245 # Monitor for stalls and update description
222246 import threading
223- _stop_monitor = [ False ]
247+ stop_monitor = False
224248 def _monitor_stalls ():
225- while not _stop_monitor [0 ]:
249+ nonlocal stop_monitor
250+ while not stop_monitor :
226251 time .sleep (1 )
227- if _stop_monitor [ 0 ] :
252+ if stop_monitor :
228253 break
229- elapsed = time .time () - _stall_check_time [ 0 ]
230- if elapsed > 5 and _last_completed [ 0 ] < total :
231- remaining = total - _last_completed [ 0 ]
254+ elapsed = time .time () - stall_check_time
255+ if elapsed > 5 and last_completed < total :
256+ remaining = total - last_completed
232257 batch_info = f"batch ~{ min (1000 , remaining )} items" if remaining > 0 else "final batch"
233258 up .update (t , description = f"[bold yellow]Uploading to SwanLab (processing { batch_info } , { int (elapsed )} s)" )
234- elif _last_completed [ 0 ] < total :
259+ elif last_completed < total :
235260 up .update (t , description = "[bold green]Uploading to SwanLab" )
236261
237262 monitor_thread = threading .Thread (target = _monitor_stalls , daemon = True )
@@ -244,7 +269,7 @@ def _monitor_stalls():
244269 swanlab_run .finish ()
245270 finally :
246271 swl .info = _orig_info
247- _stop_monitor [ 0 ] = True
272+ stop_monitor = True
248273 up .update (t , completed = total , description = "[bold green]Uploading to SwanLab" )
249274 up .stop ()
250275
@@ -283,7 +308,8 @@ def initialize_swanlab_run_if_needed():
283308 # Set pre-count callback: track items uploaded during parse before the upload bar appears
284309 if porter is not None and porter ._pool is not None :
285310 def _pre_upload_cb (n ):
286- _upload_pre_count [0 ] += n
311+ nonlocal upload_pre_count
312+ upload_pre_count += n
287313 porter ._pool .collector .upload_callback = _pre_upload_cb
288314
289315 # 恢复进度条
@@ -369,39 +395,79 @@ def _pre_upload_cb(n):
369395 initialize_swanlab_run_if_needed ()
370396 scalar_dict = {}
371397 media_dict = {}
398+ grouped_items = {}
372399 step = 0
373- # Single- pass: parse proto items directly, fast-path float() for scalars
400+ # First pass: group items by base key
374401 for item in record_pb .history .item :
375402 key = item .key or '/' .join (item .nested_key )
376403 if not key :
377404 continue
378- vj = item .value_json
405+ value_json = item .value_json
379406 if key == '_step' :
380407 try :
381- step = int (float (vj ))
408+ step = int (float (value_json ))
382409 except (ValueError , TypeError ):
383410 pass
384411 continue
385412 if key .startswith ('_' ):
386413 continue
387- # Fast path: direct float conversion (avoids full JSON parse for scalars)
388- try :
389- scalar_dict [key ] = float (vj )
390- continue
391- except (ValueError , TypeError ):
392- pass
393- # Slow path: full JSON parse for complex types (media, etc.)
394- try :
395- value = _json_loads (vj )
396- except (ValueError , Exception ):
397- continue
398- if isinstance (value , int ):
399- scalar_dict [key ] = float (value )
400- elif isinstance (value , dict ) and "_type" in value :
401- media_type = value ["_type" ]
402- path = os .path .join (files_root_dir , value .get ("path" , "" ))
403- if os .path .exists (path ) and media_type == "image-file" :
404- media_dict [key ] = swanlab .Image (path )
414+ # Check if key has nested structure (e.g., "table/_type")
415+ if '/' in key :
416+ base_key , sub_key = key .split ('/' , 1 )
417+ if base_key not in grouped_items :
418+ grouped_items [base_key ] = {}
419+ try :
420+ grouped_items [base_key ][sub_key ] = _json_loads (value_json )
421+ except (ValueError , Exception ):
422+ grouped_items [base_key ][sub_key ] = value_json
423+ else :
424+ # Fast path: direct float conversion for scalars
425+ try :
426+ scalar_dict [key ] = float (value_json )
427+ continue
428+ except (ValueError , TypeError ):
429+ pass
430+ # Slow path: full JSON parse
431+ try :
432+ value = _json_loads (value_json )
433+ if isinstance (value , int ):
434+ scalar_dict [key ] = float (value )
435+ elif isinstance (value , dict ) and "path" in value :
436+ validated_path = self ._validate_path (files_root_dir , value ["path" ])
437+ if validated_path and os .path .exists (validated_path ):
438+ if value .get ("_type" ) == "image-file" :
439+ media_dict [key ] = swanlab .Image (validated_path )
440+ elif value .get ("_type" ) == "audio-file" :
441+ media_dict [key ] = swanlab .Audio (validated_path )
442+ except (ValueError , Exception ):
443+ pass
444+
445+ # Second pass: process grouped items for tables and media
446+ for base_key , props in grouped_items .items ():
447+ if props .get ('_type' ) == 'table-file' and 'path' in props :
448+ validated_path = self ._validate_path (files_root_dir , props ['path' ])
449+ if validated_path and os .path .exists (validated_path ):
450+ try :
451+ with open (validated_path , 'r' , encoding = 'utf-8' ) as f :
452+ table_data = _json_loads (f .read ())
453+ columns = table_data .get ("columns" , [])
454+ data = table_data .get ("data" , [])
455+ # Filter text-only columns
456+ filtered_cols , filtered_data = self ._filter_text_columns (columns , data )
457+ if filtered_cols :
458+ table = echarts .Table ()
459+ table .add (filtered_cols , filtered_data )
460+ media_dict [base_key ] = table
461+ except Exception as e :
462+ swl .warning (f"Failed to parse table from { validated_path } : { e } " )
463+ elif props .get ('_type' ) == 'image-file' and 'path' in props :
464+ validated_path = self ._validate_path (files_root_dir , props ['path' ])
465+ if validated_path and os .path .exists (validated_path ):
466+ media_dict [base_key ] = swanlab .Image (validated_path )
467+ elif props .get ('_type' ) == 'audio-file' and 'path' in props :
468+ validated_path = self ._validate_path (files_root_dir , props ['path' ])
469+ if validated_path and os .path .exists (validated_path ):
470+ media_dict [base_key ] = swanlab .Audio (validated_path )
405471
406472 if scalar_dict or media_dict :
407473 if scalar_dict :
@@ -416,17 +482,13 @@ def _pre_upload_cb(n):
416482 # would result in N redundant log calls (N = number of steps).
417483 for item in record_pb .summary .update :
418484 key = item .key or '/' .join (item .nested_key )
419- if not key or key .startswith ('_' ):
485+ if not key or key .startswith ('_' ) or '/' in key :
420486 continue
421487 try :
422488 last_summary [key ] = float (item .value_json )
423489 except (ValueError , TypeError ):
424490 pass
425491
426- # 清理公共变量,释放内存
427- del record_pb
428- del record_bin
429-
430492 # GC every GC_INTERVAL records to reduce overhead
431493 record_count += 1
432494 if record_count % GC_INTERVAL == 0 :
@@ -436,12 +498,6 @@ def _pre_upload_cb(n):
436498 if progress is not None :
437499 progress .stop ()
438500
439- # Log 最终的 summary 数据(只 log 一次,避免对每条 summary record 都 log)
440- if last_summary :
441- initialize_swanlab_run_if_needed ()
442- if swanlab_run :
443- _log_scalars_direct (last_summary , last_summary_step )
444-
445501 if swanlab_run :
446502 swl .info (f"Finished Parsing run: { run_metadata ['name' ]} " )
447503 _finish_with_progress ()
0 commit comments