Skip to content

Commit 5b6a756

Browse files
Nexisato and SAKURA-CAT authored
feat: better metric strategy(#1473)
* feat: support overwrite metric * fix: overwrite existing if explicitly set step * feat: support overwrite under any circumstance * fix: step auto increment * Support metric overwrite and preserve epochs Allow duplicated steps to overwrite existing metric entries while preserving the original epoch mapping. Key changes: add metric_overwrite flag to MetricInfo, maintain per-key _step_epochs and _step_summary_values in SwanLabKey, rebuild summaries on overwrite, and update in-memory collections to replace existing step entries. LocalRunCallback now rewrites log slice files when overwriting (helper _rewrite_metric_file), and several docstrings/log levels updated to reflect overwrite semantics. Tests updated to assert overwrite behavior and epoch preservation. * fix: memory reduce * chore: update warning * refactor: rebuild by step * feat: add parallel run support and update run_id validation (#1491) * Add parallel run support and update run_id validation Introduce a parallel run option and related env var, propagate slug usage, and relax run_id validation: - Add a `parallel` parameter to SwanLabInitializer (and SWANLAB_RUN_PARALLEL env) to enable shared parallel runs; when enabled it forces mode='cloud', resume='allow', and generates an id if missing. - Load `parallel` from config/env and validate it during initialization; minor warning/formatting tweaks. - Add ExperimentInfo.slug property and use it in Client.web_exp_url to prefer exp.slug over exp_id when available. - Update run_id validation: allow lengths 1–64 and disallow characters '/ \ # ? % :', with corresponding updates to tests. - Add missing import (random) required for id generation. Tests updated to reflect new run_id rules and additional valid/invalid cases. 
* Update sdk.py * Update sdk.py * feat: support overwrite metric * fix: overwrite existing if explicitly set step * feat: support overwrite under any circumstance * fix: step auto increment * Support metric overwrite and preserve epochs Allow duplicated steps to overwrite existing metric entries while preserving the original epoch mapping. Key changes: add metric_overwrite flag to MetricInfo, maintain per-key _step_epochs and _step_summary_values in SwanLabKey, rebuild summaries on overwrite, and update in-memory collections to replace existing step entries. LocalRunCallback now rewrites log slice files when overwriting (helper _rewrite_metric_file), and several docstrings/log levels updated to reflect overwrite semantics. Tests updated to assert overwrite behavior and epoch preservation. * fix: memory reduce * chore: update warning * refactor: rebuild by step --------- Co-authored-by: Kang Li <79990647+SAKURA-CAT@users.noreply.github.com>
1 parent 95f58eb commit 5b6a756

File tree

11 files changed

+382
-58
lines changed

11 files changed

+382
-58
lines changed

swanlab/data/callbacker/local.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import json
3737
import os
38+
import tempfile
3839
from datetime import datetime
3940
from typing import Tuple, Optional, TextIO
4041
from swanlab.toolkit import RuntimeInfo, MetricInfo
@@ -130,11 +131,50 @@ def on_metric_create(self, metric_info: MetricInfo, *args, **kwargs):
130131
os.makedirs(os.path.dirname(metric_info.summary_file_path), exist_ok=True)
131132
with open(metric_info.summary_file_path, "w+", encoding="utf-8") as f:
132133
f.write(json.dumps(metric_info.metric_summary, ensure_ascii=False))
133-
with open(metric_info.metric_file_path, "a", encoding="utf-8") as f:
134-
f.write(json.dumps(metric_info.metric, ensure_ascii=False) + "\n")
134+
if metric_info.metric_overwrite:
135+
self._rewrite_metric_file(metric_info)
136+
else:
137+
with open(metric_info.metric_file_path, "a", encoding="utf-8") as f:
138+
f.write(json.dumps(metric_info.metric, ensure_ascii=False) + "\n")
135139
# ---------------------------------- 保存媒体字节流数据 ----------------------------------
136140
self.porter.trace_metric(metric_info)
137141

142+
@staticmethod
def _rewrite_metric_file(metric_info: "MetricInfo") -> None:
    """
    Rewrite the metric slice file so that the record of the overwritten step is replaced.

    The file at ``metric_info.metric_file_path`` is read line by line (one JSON document per line)
    and copied into a temporary file in the same directory. The first record whose ``"index"``
    equals ``metric_info.metric_step`` is replaced by the serialized ``metric_info.metric``; any
    further records with the same index are dropped. If no record matches, the new record is
    appended. Lines that are not valid JSON (or are not JSON objects) are copied through unchanged.
    Finally the temporary file replaces the original via ``os.replace``.

    If the slice file does not exist yet, it is simply created with the new record.

    :param metric_info: metric whose ``metric``, ``metric_step`` and ``metric_file_path`` are used
    :raises OSError: if reading, writing or replacing the file fails; the temporary file is removed
    """
    serialized = json.dumps(metric_info.metric, ensure_ascii=False) + "\n"
    metric_path = metric_info.metric_file_path

    try:
        f_in = open(metric_path, "r", encoding="utf-8")
    except FileNotFoundError:
        # Nothing to overwrite on disk yet: create the slice file with just this record.
        with open(metric_path, "w", encoding="utf-8") as f:
            f.write(serialized)
        return

    dir_path = os.path.dirname(metric_path)
    tmp_path = None
    try:
        with f_in, tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", dir=dir_path, delete=False) as tmp:
            tmp_path = tmp.name
            replaced = False

            for line in f_in:
                # A final line without a newline would otherwise be glued to whatever record
                # is written after it.
                if not line.endswith("\n"):
                    line += "\n"
                try:
                    existing = json.loads(line)
                except json.JSONDecodeError:
                    swanlog.warning(f"Failed to decode JSON from line in {metric_path}: {line.strip()}")
                    tmp.write(line)
                    continue

                # Valid JSON that is not an object cannot carry an "index"; keep it untouched.
                if not isinstance(existing, dict) or existing.get("index") != metric_info.metric_step:
                    tmp.write(line)
                elif not replaced:
                    tmp.write(serialized)
                    replaced = True
                # Further records with the same index are intentionally dropped.

            if not replaced:
                tmp.write(serialized)

            # NamedTemporaryFile creates the file readable/writable by the owner only; carry over
            # the permission bits of the file being replaced so the rewrite does not change them.
            try:
                os.chmod(tmp_path, os.fstat(f_in.fileno()).st_mode & 0o777)
            except OSError as e:
                swanlog.debug(f"Failed to preserve permissions of {metric_path}: {e}")

        os.replace(tmp_path, metric_path)
    except BaseException:
        # delete=False means nobody else removes the temporary file if the rewrite fails.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
        raise
177+
138178
def on_stop(self, error: str = None, *args, **kwargs):
139179
"""
140180
训练结束,取消系统回调

swanlab/data/run/exp.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,29 @@ def _add(
124124
key_obj: SwanLabKey = self._keys.get(key_index, None)
125125

126126
# ---------------------------------- 包装器解析 ----------------------------------
127-
127+
explicit_step = step is not None
128128
if step is not None and not isinstance(step, int):
129129
swanlog.warning(f"Step {step} is not int, SwanLab will set it automatically.")
130130
step = None
131+
explicit_step = False
132+
131133
if key_obj is None:
132-
step = 0 if step is None or not isinstance(step, int) else step
134+
step = 0 if step is None else step
133135
else:
134-
step = len(key_obj.steps) if step is None else step
136+
if step is None:
137+
# 修复隐式步数的无限覆盖和乱序陷阱:
138+
# 若曾跨步长显式写入,len() 可能会落后于真实的 max step,由此引发相同 step 的持续覆盖
139+
current_len = len(key_obj.steps)
140+
max_step = max(key_obj.steps) if key_obj.steps else -1
141+
step = max(current_len, max_step + 1)
142+
135143
if step in key_obj.steps:
136-
swanlog.debug(f"Step {step} on key {key} already exists, ignored.")
137-
return MetricErrorInfo(column_info=key_obj.column_info, error=DataWrapper.create_duplicate_error())
144+
# 允许 overwrite,但区分显式指定和隐式的碰撞
145+
if explicit_step:
146+
swanlog.debug(f"Step {step} on key {key} already exists, overwriting.")
147+
else:
148+
swanlog.warning(f"Implicit step {step} on key {key} resolved as overwrite, but expected to append.")
138149
data.parse(step=step, key=key)
139-
140150
# ---------------------------------- 图表创建 ----------------------------------
141151

142152
if key_obj is None:

swanlab/data/run/key.py

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import json
99
import math
10-
from typing import Optional, Tuple
10+
from typing import Dict, Optional, Tuple
1111

1212
from swanlab.data.modules import DataWrapper, Line
1313
from swanlab.env import create_time
@@ -42,6 +42,10 @@ def __init__(
4242
self.key = key
4343
# 当前 key 包含的 step
4444
self.steps = set()
45+
# 当前 key 的 step 与首次写入 epoch 的映射,重复 step 覆盖时需要保持 epoch 不变
46+
self._step_epochs: Dict[int, int] = {}
47+
# 当前 key 的 step 与摘要值的映射,用于覆盖时重建 summary
48+
self._step_summary_values: Dict[int, Optional[object]] = {}
4549
self.column_info: Optional[ColumnInfo] = None
4650
self._media_dir = media_dir
4751
self._log_dir = log_dir
@@ -110,22 +114,24 @@ def add(self, data: DataWrapper) -> MetricInfo:
110114
# 4. 更新 summary 并添加数据
111115
# 如果为Line且为NaN或者INF,不更新summary
112116
r = result.strings or result.float
113-
if not data.type == Line or r not in [Line.nan, Line.inf]:
114-
if self._summary.get("max") is None or r > self._summary["max"]:
115-
self._summary["max"] = r
116-
self._summary["max_step"] = result.step
117-
if self._summary.get("min") is None or r < self._summary["min"]:
118-
self._summary["min"] = r
119-
self._summary["min_step"] = result.step
120-
self._summary["num"] = self._summary.get("num", 0) + 1
121-
self.steps.add(result.step)
122-
swanlog.debug(f"Add data, key: {self.key}, step: {result.step}, data: {r}")
123-
if len(self._collection["data"]) >= self.__slice_size:
124-
self._collection = self.__new_metric_collection()
125-
117+
overwrite = result.step in self._step_epochs
118+
if not overwrite:
119+
self.steps.add(result.step)
120+
self._step_epochs[result.step] = len(self.steps)
121+
epoch = self._step_epochs[result.step]
126122
new_data = self.__new_metric(result.step, r, more=result.more)
127-
self._collection["data"].append(new_data)
128-
epoch = len(self.steps)
123+
124+
# 覆盖写入时,只有当前 step 持有的 extremum 被削弱/移除时才需要全量重建。
125+
needs_rebuild = overwrite and self._should_rebuild_summary_on_overwrite(result.step, data.type, r)
126+
self._set_summary_value(result.step, data.type, r)
127+
if needs_rebuild:
128+
self._rebuild_summary()
129+
else:
130+
self._update_summary_incremental(result.step, data.type, r)
131+
self._update_collection(new_data, result.step, overwrite)
132+
swanlog.debug(
133+
f"{'Overwrite' if overwrite else 'Add'} data, key: {self.key}, step: {result.step}, data: {r}"
134+
)
129135
mu = math.ceil(epoch / self.__slice_size)
130136
return MetricInfo(
131137
column_info=self.column_info,
@@ -137,8 +143,67 @@ def add(self, data: DataWrapper) -> MetricInfo:
137143
metric_file_name=str(mu * self.__slice_size) + ".log",
138144
swanlab_logdir=self._log_dir,
139145
swanlab_media_dir=self._media_dir if result.buffers else None,
146+
metric_overwrite=overwrite,
140147
)
141148

149+
def _set_summary_value(self, step: int, data_type, value) -> None:
    """
    Record the value that ``step`` contributes to the summary.

    A ``Line`` value equal to ``Line.nan`` or ``Line.inf`` is stored as ``None``;
    every other value is stored as-is.

    :param step: step whose entry in ``_step_summary_values`` is written
    :param data_type: type of the logged data, compared against ``Line``
    :param value: parsed value of the data point
    """
    excluded = data_type == Line and value in [Line.nan, Line.inf]
    self._step_summary_values[step] = None if excluded else value
154+
155+
def _should_rebuild_summary_on_overwrite(self, step: int, data_type, value) -> bool:
    """
    Decide whether overwriting ``step`` with ``value`` requires a full summary rebuild.

    A rebuild is needed only when ``step`` currently holds the recorded max or min and the
    new value no longer justifies that position: it is a ``Line`` NaN/INF value, it is
    smaller than the recorded max, or it is larger than the recorded min.

    :param step: step being overwritten
    :param data_type: type of the logged data, compared against ``Line``
    :param value: new value for the step
    :return: True if the summary must be rebuilt from scratch, False otherwise
    """
    holds_max = step == self._summary.get("max_step")
    holds_min = step == self._summary.get("min_step")

    if data_type == Line and value in [Line.nan, Line.inf]:
        # The step drops out of the summary entirely; that matters only if it was an extremum.
        return holds_max or holds_min

    if holds_max:
        recorded_max = self._summary.get("max")
        if recorded_max is not None and value < recorded_max:
            return True

    if holds_min:
        recorded_min = self._summary.get("min")
        if recorded_min is not None and value > recorded_min:
            return True

    return False
170+
171+
def _update_summary_incremental(self, step: int, data_type, value) -> None:
    """
    Fold one data point into the running summary without a full rebuild.

    ``num`` is always refreshed to ``len(self.steps)``, matching what ``_rebuild_summary``
    writes, so both code paths report the same count for the same state. ``Line`` values
    equal to ``Line.nan`` or ``Line.inf`` do not take part in max/min.

    :param step: step of the data point
    :param data_type: type of the logged data, compared against ``Line``
    :param value: parsed value of the data point
    """
    # Refresh the count before the NaN/INF early return; otherwise "num" would lag behind
    # len(self.steps) until the next regular value arrives.
    self._summary["num"] = len(self.steps)
    if data_type == Line and value in [Line.nan, Line.inf]:
        return
    if self._summary.get("max") is None or value > self._summary["max"]:
        self._summary["max"] = value
        self._summary["max_step"] = step
    if self._summary.get("min") is None or value < self._summary["min"]:
        self._summary["min"] = value
        self._summary["min_step"] = step
181+
182+
def _rebuild_summary(self) -> None:
    """
    Recompute ``self._summary`` from scratch out of the per-step summary values.

    Steps are visited in ascending epoch order (as recorded in ``_step_epochs``); steps
    whose stored summary value is missing or ``None`` are skipped. ``num`` is set to
    ``len(self.steps)``. The previous summary dict is replaced as a whole.
    """
    rebuilt = {"num": len(self.steps)}
    steps_by_epoch = sorted(self._step_epochs, key=self._step_epochs.__getitem__)
    for current_step in steps_by_epoch:
        candidate = self._step_summary_values.get(current_step)
        if candidate is None:
            continue
        if rebuilt.get("max") is None or candidate > rebuilt["max"]:
            rebuilt["max"], rebuilt["max_step"] = candidate, current_step
        if rebuilt.get("min") is None or candidate < rebuilt["min"]:
            rebuilt["min"], rebuilt["min_step"] = candidate, current_step
    self._summary = rebuilt
195+
196+
def _update_collection(self, new_data: dict, step: int, overwrite: bool) -> None:
    """
    Put ``new_data`` into the in-memory metric collection.

    If the current collection already contains an entry whose ``"index"`` equals ``step``,
    that entry is replaced. Otherwise an overwrite is a no-op here, while a new data point
    is appended, starting a fresh collection first when the current one already holds
    ``__slice_size`` entries.

    :param new_data: metric record to store
    :param step: step of the record
    :param overwrite: whether this write overwrites an already-known step
    """
    entries = self._collection["data"]
    position = next((pos for pos, entry in enumerate(entries) if entry["index"] == step), None)
    if position is not None:
        entries[position] = new_data
        return
    if overwrite:
        # The overwritten step is not part of the current in-memory collection.
        return
    if len(entries) >= self.__slice_size:
        self._collection = self.__new_metric_collection()
    self._collection["data"].append(new_data)
206+
142207
def create_column(
143208
self,
144209
key: str,
@@ -301,8 +366,9 @@ def mock_from_remote(
301366
section_type=section_type,
302367
)
303368
key_obj.column_info = column_info
304-
# 5. 设置当前步数,resume 后不允许设置历史步数,所以需要覆盖
369+
# 5. 「同步云端最新 step」设置当前步数,resume 后不允许设置历史步数,所以需要覆盖
305370
if step is not None:
306371
for i in range(step + 1):
307372
key_obj.steps.add(i)
373+
key_obj._step_epochs[i] = i + 1
308374
return key_obj, column_info

swanlab/data/run/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def log(self, data: dict, step: int = None):
272272
For nested dicts, keys will be joined with dots (e.g., {'a': {'b': 1}} becomes {'a.b': 1}).
273273
step : int, optional
274274
The step number of the current data, if not provided, it will be automatically incremented.
275-
If step is duplicated, the data will be ignored.
275+
If step is duplicated, the latest data will overwrite the previous data on that step.
276276
277277
Raises
278278
----------

swanlab/data/sdk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def log(
460460
The value must be a `float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
461461
step : int, optional
462462
The step number of the current data, if not provided, it will be automatically incremented.
463-
If step is duplicated, the data will be ignored.
463+
If step is duplicated, the latest data will overwrite the previous data on that step.
464464
print_to_console : bool, optional
465465
Whether to print the data to the console, the default is False.
466466
"""

swanlab/integration/accelerate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs):
104104
The value must be a `float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
105105
step : int, optional
106106
The step number of the current data, if not provided, it will be automatically incremented.
107-
If step is duplicated, the data will be ignored.
107+
If step is duplicated, the latest data will overwrite the previous data on that step.
108108
kwargs:
109109
Additional key word arguments passed along to the `swanlab.log` method. Likes:
110110
print_to_console : bool, optional

swanlab/toolkit/models/metric.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def __init__(
171171
metric_file_name: Optional[str],
172172
swanlab_logdir: Optional[str],
173173
swanlab_media_dir: Optional[str],
174+
metric_overwrite: bool = False,
174175
error: Optional[ParseErrorInfo] = None,
175176
):
176177
"""
@@ -184,6 +185,7 @@ def __init__(
184185
:param metric_file_name: 此指标的文件名
185186
:param swanlab_logdir: swanlab在本次实验的log文件夹路径
186187
:param swanlab_media_dir: swanlab在本次实验的media文件夹路径
188+
:param metric_overwrite: 当前指标是否覆盖了已有 step
187189
:param error: 创建此指标时的错误信息
188190
"""
189191
self.error = error
@@ -193,6 +195,7 @@ def __init__(
193195
self.metric_summary = metric_summary
194196
self.metric_step = metric_step
195197
self.metric_epoch = metric_epoch
198+
self.metric_overwrite = metric_overwrite
196199
_id = self.column_info.kid
197200
self.metric_file_path = None if self.is_error else os.path.join(swanlab_logdir, _id, metric_file_name)
198201
self.summary_file_path = None if self.is_error else os.path.join(swanlab_logdir, _id, self.__SUMMARY_NAME)
@@ -252,5 +255,6 @@ def __init__(self, column_info: ColumnInfo, error: ParseErrorInfo):
252255
None,
253256
None,
254257
None,
258+
False,
255259
error,
256260
)

test/metrics/echarts/calendar_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// @description: 本文件是对于echarts的 Calendar 图表测试 , 文件名不叫calendar是因为和库文件重名
66
"""
77
# ---------------------------------------------- Calendar - Calendar_heatmap ----------------------------------------------
8+
import tutils as T
89
import random
910
import datetime
1011

0 commit comments

Comments
 (0)