Skip to content

Commit f74040b

Browse files
authored
feat: add parallel run support and update run_id validation (#1491)
* Add parallel run support and update run_id validation Introduce a parallel run option and related env var, propagate slug usage, and relax run_id validation: - Add a `parallel` parameter to SwanLabInitializer (and SWANLAB_RUN_PARALLEL env) to enable shared parallel runs; when enabled it forces mode='cloud', resume='allow', and generates an id if missing. - Load `parallel` from config/env and validate it during initialization; minor warning/formatting tweaks. - Add ExperimentInfo.slug property and use it in Client.web_exp_url to prefer exp.slug over exp_id when available. - Update run_id validation: allow lengths 1–64 and disallow characters '/ \ # ? % :', with corresponding updates to tests. - Add missing import (random) required for id generation. Tests updated to reflect new run_id rules and additional valid/invalid cases. * Update sdk.py * Update sdk.py
1 parent b083cdb commit f74040b

File tree

6 files changed

+56
-22
lines changed

6 files changed

+56
-22
lines changed

swanlab/core_python/client/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def web_proj_url(self) -> str:
113113

114114
@property
115115
def web_exp_url(self) -> str:
116-
return f"{self.web_proj_url}/runs/{self.exp_id}"
116+
return f"{self.web_proj_url}/runs/{self.exp.slug or self.exp_id}"
117117

118118
# ---------------------------------- http方法 ----------------------------------
119119

swanlab/core_python/client/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def flag_id(self):
8383
def cuid(self):
8484
return self.__data["cuid"]
8585

86+
@property
87+
def slug(self):
88+
return self.__data["slug"]
89+
8690
@property
8791
def name(self):
8892
return self.__data["name"]

swanlab/data/sdk.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
在此处封装swanlab在日志记录模式下的各种接口
99
"""
1010
import os
11+
import random
12+
import secrets
1113
import time
1214
from datetime import datetime
1315
from typing import Union, Dict, Literal, List
@@ -108,6 +110,7 @@ def init(
108110
settings: Settings = None,
109111
id: str = None,
110112
resume: Union[Literal['must', 'allow', 'never'], bool] = None,
113+
parallel: Union[Literal['none', 'shared'], bool] = None,
111114
reinit: bool = None,
112115
color: str = None,
113116
**kwargs,
@@ -186,6 +189,10 @@ def init(
186189
- never: You cannot pass the `id` parameter, and a new run will be created.
187190
You can also pass a boolean value, where `True` is equivalent to 'allow' and `False` is equivalent to 'never'.
188191
[Notice that] This parameter is only valid when mode='cloud'
192+
parallel : Literal['none', 'shared'], optional
193+
Whether to run experiments in parallel or not:
194+
- none: Run experiments sequentially.
195+
- shared: Run experiments in parallel, equivalent to `mode='cloud' and resume='allow'`.
189196
id : str, optional
190197
The run ID of the previous run, which is used to resume the previous run.
191198
reinit : bool, optional
@@ -242,6 +249,7 @@ def init(
242249
public = _load_from_dict(load_data, "private", public)
243250
id = _load_from_dict(load_data, "id", id)
244251
resume = _load_from_dict(load_data, "resume", resume)
252+
parallel = _load_from_dict(load_data, "parallel", parallel)
245253
# 1.2 初始化confi参数
246254
config = _init_config(config)
247255
# 如果config是以下几个类别之一,则抛出异常
@@ -263,6 +271,7 @@ def init(
263271
id = _load_from_env(SwanLabEnv.RUN_ID.value, id)
264272
logdir = _load_from_env(SwanLabEnv.SWANLOG_FOLDER.value, logdir)
265273
color = _load_from_env(SwanLabEnv.EXP_COLOR.value, color)
274+
parallel = _load_from_env(SwanLabEnv.RUN_PARALLEL.value, parallel)
266275
# 2. 部分格式校验
267276
# 2.1 校验项目名称,默认实验名称为当前目录名
268277
project = project if project else os.path.basename(os.getcwd())
@@ -273,7 +282,9 @@ def init(
273282
# 2.2 校验实验名称
274283
# 处理空字符串或纯空格字符串情况
275284
if experiment_name is not None and not experiment_name.strip():
276-
swanlog.warning("The experiment name is an empty or whitespace-only string, automatically converted to None.")
285+
swanlog.warning(
286+
"The experiment name is an empty or whitespace-only string, automatically converted to None."
287+
)
277288
experiment_name = None
278289
if experiment_name:
279290
e = check_exp_name_format(experiment_name)
@@ -315,6 +326,11 @@ def init(
315326
# 3. 校验回调函数
316327
callbacks = check_callback_format(self.cbs + callbacks)
317328
self.cbs = []
329+
# 4. 校验并行模式
330+
if str(parallel).lower() in ["shared", "true", "yes"]:
331+
resume = "allow"
332+
mode = "cloud"
333+
id = id or secrets.token_hex(4)
318334
# 5. 校验mode参数并适配 backup 模式
319335
mode = "cloud" if mode == "online" else mode
320336
mode, login_info = _init_mode(mode, folder_settings.mode)
@@ -325,7 +341,9 @@ def init(
325341
resume = resume or 'never'
326342
# 非 cloud 模式下,resume 只支持 'never'
327343
if resume in ('must', 'allow') and mode != "cloud":
328-
swanlog.warning(f"resume='{resume}' is only supported in cloud mode, automatically switch to resume='never'.")
344+
swanlog.warning(
345+
f"resume='{resume}' is only supported in cloud mode, automatically switch to resume='never'."
346+
)
329347
resume = 'never'
330348
id = None
331349
# 根据 resume 的最终值进行校验

swanlab/env.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ class SwanLabEnv(enum.Enum):
141141
"""
142142
禁用Git功能,设置为true时不会采集Git信息
143143
"""
144+
RUN_PARALLEL = "SWANLAB_RUN_PARALLEL"
145+
"""
146+
是否开启并行模式,设置为true时会并行运行实验,等同于mode='cloud' and resume='allow'
147+
"""
144148

145149
@staticmethod
146150
def is_hostname(value: str) -> bool:

swanlab/formatter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,19 @@ def check_group_format(group: str, auto_cut: bool = True) -> str:
188188
def check_run_id_format(run_id: str = None) -> Optional[str]:
189189
"""
190190
检查运行ID格式,要求:
191-
1. 只能包含小写字母和数字
192-
2. 长度为21个字符
191+
1. 长度为1-64个字符
192+
2. 不能包含 /, \\, #, ?, %, : 字符
193193
:param run_id: 运行ID字符串
194194
:return: str 检查后的字符串
195195
:raises ValueError: 如果运行ID不符合要求
196196
"""
197197
if not run_id:
198198
return None
199199
run_id_str = str(run_id)
200-
if not re.match(r"^[a-z0-9]{21}$", run_id_str):
201-
raise ValueError(f"id `{run_id}` is invalid, it must be 21 characters of lowercase letters and digits")
200+
if not (1 <= len(run_id_str) <= 64):
201+
raise ValueError(f"id `{run_id}` is invalid, length must be between 1 and 64 characters")
202+
if re.search(r'[/\\#?%:]', run_id_str):
203+
raise ValueError(f"id `{run_id}` is invalid, it must not contain /, \\, #, ?, %, or :")
202204
return run_id_str
203205

204206

test/unit/test_formatter.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -255,30 +255,36 @@ def test_tag_no_cut(self, value: str):
255255

256256
class TestRunIdFormat:
257257
@staticmethod
258-
def test_run_id_format_valid_run_id():
258+
def test_run_id_format_valid_short():
259+
assert check_run_id_format("abc") == "abc"
260+
261+
@staticmethod
262+
def test_run_id_format_valid_21_chars():
259263
assert check_run_id_format("abcdefghijklmnopqrstu") == "abcdefghijklmnopqrstu"
260264

261265
@staticmethod
262-
def test_run_id_format_invalid_length():
263-
with pytest.raises(
264-
ValueError, match=r"id .* is invalid, it must be 21 characters of lowercase letters and digits"
265-
):
266-
check_run_id_format("shortid")
266+
def test_run_id_format_valid_64_chars():
267+
value = "a" * 64
268+
assert check_run_id_format(value) == value
267269

268270
@staticmethod
269-
def test_run_id_format_invalid_characters():
270-
with pytest.raises(
271-
ValueError, match=r"id .* is invalid, it must be 21 characters of lowercase letters and digits"
272-
):
273-
check_run_id_format("abc123!@#def456ghi789")
271+
def test_run_id_format_valid_mixed_chars():
272+
assert check_run_id_format("Hello_World-2024.run") == "Hello_World-2024.run"
274273

275274
@staticmethod
276-
def test_run_id_format_none_input():
277-
assert check_run_id_format(None) is None
275+
def test_run_id_format_invalid_too_long():
276+
with pytest.raises(ValueError, match=r"id .* is invalid, length must be between 1 and 64 characters"):
277+
check_run_id_format("a" * 65)
278278

279279
@staticmethod
280-
def test_run_id_format_numeric_input():
281-
assert check_run_id_format(123456789012345678901) == "123456789012345678901"
280+
@pytest.mark.parametrize("invalid_id", ["abc/def", "abc\\def", "abc#def", "abc?def", "abc%def", "abc:def"])
281+
def test_run_id_format_invalid_characters(invalid_id):
282+
with pytest.raises(ValueError, match=r"id .* is invalid, it must not contain"):
283+
check_run_id_format(invalid_id)
284+
285+
@staticmethod
286+
def test_run_id_format_none_input():
287+
assert check_run_id_format(None) is None
282288

283289
@staticmethod
284290
def test_run_id_format_empty_string():

0 commit comments

Comments
 (0)