Skip to content

Commit 807c4dc

Browse files
authored
feat: wandb local converter table (#1469)
* try * refactor * fix
1 parent 208f929 commit 807c4dc

1 file changed

Lines changed: 114 additions & 58 deletions

File tree

swanlab/converter/wb/wb_local_converter.py

Lines changed: 114 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import yaml
3636
import swanlab
37+
from swanlab import echarts
3738
from swanlab.log import swanlog as swl
3839
from swanlab.data.porter import DataPorter
3940
from swanlab.env import create_time
@@ -143,6 +144,28 @@ def _proto_items_to_dict(self, items) -> dict:
143144
swl.warning(f"Could not decode json for key '{key}': {item.value_json}")
144145
return mapping
145146

147+
def _validate_path(self, base_dir: str, file_path: str) -> Optional[str]:
148+
"""Validate that file_path stays within base_dir to prevent path traversal."""
149+
if not file_path:
150+
return None
151+
abs_base = os.path.abspath(base_dir)
152+
abs_path = os.path.abspath(os.path.join(base_dir, file_path))
153+
if not abs_path.startswith(abs_base + os.sep) and abs_path != abs_base:
154+
swl.warning(f"Path traversal attempt detected: {file_path}")
155+
return None
156+
return abs_path
157+
158+
def _filter_text_columns(self, columns, data):
159+
"""Filter out non-text columns (images, media) from table data."""
160+
non_text_indices = {
161+
i for row in data for i, cell in enumerate(row) if isinstance(cell, dict) and "_type" in cell
162+
}
163+
text_indices = [i for i in range(len(columns)) if i not in non_text_indices]
164+
return (
165+
[columns[i] for i in text_indices],
166+
[[row[i] for i in text_indices if i < len(row)] for row in data]
167+
)
168+
146169
def _parse_run(self, run_dir: str):
147170
"""Parses a single wandb run directory and converts it to a SwanLab run."""
148171
wandb_files = glob.glob(os.path.join(run_dir, "*.wandb"))
@@ -166,38 +189,38 @@ def _parse_run(self, run_dir: str):
166189

167190
# Direct-write state (bypasses swanlab_run.log() overhead for scalar metrics)
168191
porter = None # set after swanlab.init()
169-
_conv_column_kids = {} # key -> kid str
170-
_conv_epoch_counters = {} # key -> cumulative epoch count
171-
_upload_pre_count = [0] # items uploaded before upload progress bar appears
172-
_total_scalars_logged = [0] # total scalar data points logged
192+
column_kids = {} # key -> kid str
193+
epoch_counters = {} # key -> cumulative epoch count
194+
upload_pre_count = 0 # items uploaded before upload progress bar appears
195+
total_scalars_logged = 0 # total scalar data points logged
173196

174197
def _log_scalars_direct(scalars, step):
175198
"""Write float scalars directly to the porter, bypassing log() overhead."""
199+
nonlocal total_scalars_logged
176200
if not scalars or porter is None:
177201
return
178202
for key in scalars:
179-
if key not in _conv_column_kids:
180-
kid = len(_conv_column_kids)
181-
_conv_column_kids[key] = str(kid)
203+
if key not in column_kids:
204+
kid = len(column_kids)
205+
column_kids[key] = str(kid)
182206
split_key = key.split("/")
183207
sname = split_key[0] if len(split_key) > 1 and split_key[0] else None
184208
porter.trace_column(ColumnInfo(
185209
key=key, kid=str(kid), name=key, cls='CUSTOM',
186210
chart_type=ChartType.LINE, chart_reference='STEP',
187211
section_name=sname, section_type="PUBLIC",
188212
))
189-
for key in scalars:
190-
_conv_epoch_counters[key] = _conv_epoch_counters.get(key, 0) + 1
191-
_total_scalars_logged[0] += len(scalars)
192-
porter.trace_scalars_step(step, scalars, dict(_conv_epoch_counters), create_time())
213+
epoch_counters[key] = epoch_counters.get(key, 0) + 1
214+
total_scalars_logged += len(scalars)
215+
porter.trace_scalars_step(step, scalars, dict(epoch_counters), create_time())
193216

194217
def _finish_with_progress():
195218
"""Run swanlab_run.finish() while showing a Rich upload progress bar."""
196219
_pool = porter._pool if porter is not None else None
197220
if _pool is None:
198221
swanlab_run.finish()
199222
return
200-
total = len(_conv_column_kids) + _total_scalars_logged[0]
223+
total = len(column_kids) + total_scalars_logged
201224
up = Progress(
202225
TextColumn("[bold green]{task.description}"),
203226
BarColumn(bar_width=40),
@@ -206,32 +229,34 @@ def _finish_with_progress():
206229
TimeRemainingColumn(),
207230
)
208231
up.start()
209-
t = up.add_task("Uploading to SwanLab", total=total, completed=_upload_pre_count[0])
232+
t = up.add_task("Uploading to SwanLab", total=total, completed=upload_pre_count)
210233

211-
_last_completed = [_upload_pre_count[0]]
212-
_stall_check_time = [time.time()]
234+
last_completed = upload_pre_count
235+
stall_check_time = time.time()
213236

214237
def _upload_cb(n):
215-
_last_completed[0] += n
216-
_stall_check_time[0] = time.time()
238+
nonlocal last_completed, stall_check_time
239+
last_completed += n
240+
stall_check_time = time.time()
217241
up.update(t, advance=n)
218242

219243
_pool.collector.upload_callback = _upload_cb
220244

221245
# Monitor for stalls and update description
222246
import threading
223-
_stop_monitor = [False]
247+
stop_monitor = False
224248
def _monitor_stalls():
225-
while not _stop_monitor[0]:
249+
nonlocal stop_monitor
250+
while not stop_monitor:
226251
time.sleep(1)
227-
if _stop_monitor[0]:
252+
if stop_monitor:
228253
break
229-
elapsed = time.time() - _stall_check_time[0]
230-
if elapsed > 5 and _last_completed[0] < total:
231-
remaining = total - _last_completed[0]
254+
elapsed = time.time() - stall_check_time
255+
if elapsed > 5 and last_completed < total:
256+
remaining = total - last_completed
232257
batch_info = f"batch ~{min(1000, remaining)} items" if remaining > 0 else "final batch"
233258
up.update(t, description=f"[bold yellow]Uploading to SwanLab (processing {batch_info}, {int(elapsed)}s)")
234-
elif _last_completed[0] < total:
259+
elif last_completed < total:
235260
up.update(t, description="[bold green]Uploading to SwanLab")
236261

237262
monitor_thread = threading.Thread(target=_monitor_stalls, daemon=True)
@@ -244,7 +269,7 @@ def _monitor_stalls():
244269
swanlab_run.finish()
245270
finally:
246271
swl.info = _orig_info
247-
_stop_monitor[0] = True
272+
stop_monitor = True
248273
up.update(t, completed=total, description="[bold green]Uploading to SwanLab")
249274
up.stop()
250275

@@ -283,7 +308,8 @@ def initialize_swanlab_run_if_needed():
283308
# Set pre-count callback: track items uploaded during parse before the upload bar appears
284309
if porter is not None and porter._pool is not None:
285310
def _pre_upload_cb(n):
286-
_upload_pre_count[0] += n
311+
nonlocal upload_pre_count
312+
upload_pre_count += n
287313
porter._pool.collector.upload_callback = _pre_upload_cb
288314

289315
# 恢复进度条
@@ -369,39 +395,79 @@ def _pre_upload_cb(n):
369395
initialize_swanlab_run_if_needed()
370396
scalar_dict = {}
371397
media_dict = {}
398+
grouped_items = {}
372399
step = 0
373-
# Single-pass: parse proto items directly, fast-path float() for scalars
400+
# First pass: group items by base key
374401
for item in record_pb.history.item:
375402
key = item.key or '/'.join(item.nested_key)
376403
if not key:
377404
continue
378-
vj = item.value_json
405+
value_json = item.value_json
379406
if key == '_step':
380407
try:
381-
step = int(float(vj))
408+
step = int(float(value_json))
382409
except (ValueError, TypeError):
383410
pass
384411
continue
385412
if key.startswith('_'):
386413
continue
387-
# Fast path: direct float conversion (avoids full JSON parse for scalars)
388-
try:
389-
scalar_dict[key] = float(vj)
390-
continue
391-
except (ValueError, TypeError):
392-
pass
393-
# Slow path: full JSON parse for complex types (media, etc.)
394-
try:
395-
value = _json_loads(vj)
396-
except (ValueError, Exception):
397-
continue
398-
if isinstance(value, int):
399-
scalar_dict[key] = float(value)
400-
elif isinstance(value, dict) and "_type" in value:
401-
media_type = value["_type"]
402-
path = os.path.join(files_root_dir, value.get("path", ""))
403-
if os.path.exists(path) and media_type == "image-file":
404-
media_dict[key] = swanlab.Image(path)
414+
# Check if key has nested structure (e.g., "table/_type")
415+
if '/' in key:
416+
base_key, sub_key = key.split('/', 1)
417+
if base_key not in grouped_items:
418+
grouped_items[base_key] = {}
419+
try:
420+
grouped_items[base_key][sub_key] = _json_loads(value_json)
421+
except (ValueError, Exception):
422+
grouped_items[base_key][sub_key] = value_json
423+
else:
424+
# Fast path: direct float conversion for scalars
425+
try:
426+
scalar_dict[key] = float(value_json)
427+
continue
428+
except (ValueError, TypeError):
429+
pass
430+
# Slow path: full JSON parse
431+
try:
432+
value = _json_loads(value_json)
433+
if isinstance(value, int):
434+
scalar_dict[key] = float(value)
435+
elif isinstance(value, dict) and "path" in value:
436+
validated_path = self._validate_path(files_root_dir, value["path"])
437+
if validated_path and os.path.exists(validated_path):
438+
if value.get("_type") == "image-file":
439+
media_dict[key] = swanlab.Image(validated_path)
440+
elif value.get("_type") == "audio-file":
441+
media_dict[key] = swanlab.Audio(validated_path)
442+
except (ValueError, Exception):
443+
pass
444+
445+
# Second pass: process grouped items for tables and media
446+
for base_key, props in grouped_items.items():
447+
if props.get('_type') == 'table-file' and 'path' in props:
448+
validated_path = self._validate_path(files_root_dir, props['path'])
449+
if validated_path and os.path.exists(validated_path):
450+
try:
451+
with open(validated_path, 'r', encoding='utf-8') as f:
452+
table_data = _json_loads(f.read())
453+
columns = table_data.get("columns", [])
454+
data = table_data.get("data", [])
455+
# Filter text-only columns
456+
filtered_cols, filtered_data = self._filter_text_columns(columns, data)
457+
if filtered_cols:
458+
table = echarts.Table()
459+
table.add(filtered_cols, filtered_data)
460+
media_dict[base_key] = table
461+
except Exception as e:
462+
swl.warning(f"Failed to parse table from {validated_path}: {e}")
463+
elif props.get('_type') == 'image-file' and 'path' in props:
464+
validated_path = self._validate_path(files_root_dir, props['path'])
465+
if validated_path and os.path.exists(validated_path):
466+
media_dict[base_key] = swanlab.Image(validated_path)
467+
elif props.get('_type') == 'audio-file' and 'path' in props:
468+
validated_path = self._validate_path(files_root_dir, props['path'])
469+
if validated_path and os.path.exists(validated_path):
470+
media_dict[base_key] = swanlab.Audio(validated_path)
405471

406472
if scalar_dict or media_dict:
407473
if scalar_dict:
@@ -416,17 +482,13 @@ def _pre_upload_cb(n):
416482
# would result in N redundant log calls (N = number of steps).
417483
for item in record_pb.summary.update:
418484
key = item.key or '/'.join(item.nested_key)
419-
if not key or key.startswith('_'):
485+
if not key or key.startswith('_') or '/' in key:
420486
continue
421487
try:
422488
last_summary[key] = float(item.value_json)
423489
except (ValueError, TypeError):
424490
pass
425491

426-
# 清理公共变量,释放内存
427-
del record_pb
428-
del record_bin
429-
430492
# GC every GC_INTERVAL records to reduce overhead
431493
record_count += 1
432494
if record_count % GC_INTERVAL == 0:
@@ -436,12 +498,6 @@ def _pre_upload_cb(n):
436498
if progress is not None:
437499
progress.stop()
438500

439-
# Log 最终的 summary 数据(只 log 一次,避免对每条 summary record 都 log)
440-
if last_summary:
441-
initialize_swanlab_run_if_needed()
442-
if swanlab_run:
443-
_log_scalars_direct(last_summary, last_summary_step)
444-
445501
if swanlab_run:
446502
swl.info(f"Finished Parsing run: {run_metadata['name']}")
447503
_finish_with_progress()

0 commit comments

Comments
 (0)