Skip to content

Commit b083cdb

Browse files
Zeyi-Linclaude
andauthored
feat: add color parameter to swanlab.init for experiment color control (#1481)
* feat: add color parameter to swanlab.init for experiment color control - Add color parameter supporting preset colors, RGB, and hex formats - Add SWANLAB_EXP_COLOR environment variable - Support color formats: preset names (green, blue, etc), rgb(r,g,b), #hex - Update callbackers to respect user-defined colors Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor: improve color format validation code quality - Merge RGB regex patterns using optional group (?:rgb\s*)? - Use re.fullmatch for hex color validation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * improve --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent cf1301d commit b083cdb

File tree

6 files changed

+75
-6
lines changed

6 files changed

+75
-6
lines changed

swanlab/data/callbacker/disabled.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
class DisabledCallback(SwanLabRunCallback):
1414
def on_init(self, proj_name: str, workspace: str, public: bool = None, logdir: str = None, *args, **kwargs):
1515
self.run_store.run_name = "run-disabled"
16-
self.run_store.run_colors = N.generate_colors(0)
16+
if self.run_store.run_colors is None:
17+
self.run_store.run_colors = N.generate_colors(0)
1718
self.run_store.run_id = N.generate_run_id()
1819
self.run_store.new = True
1920

swanlab/data/callbacker/local.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def on_init(self, proj_name: str, workspace: str, public: bool = None, logdir: s
7878
run_store.tags = [] if run_store.tags is None else run_store.tags
7979
exp_count = random.randint(0, 20)
8080
run_store.run_name = N.generate_name(exp_count) if run_store.run_name is None else run_store.run_name
81-
run_store.run_colors = generate_colors(random.randint(0, 20))
81+
if run_store.run_colors is None:
82+
run_store.run_colors = generate_colors(exp_count)
8283
run_store.run_id = N.generate_run_id()
8384
run_store.new = True
8485

swanlab/data/callbacker/offline.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,11 @@ def on_init(self, proj_name: str, workspace: str, public: bool = None, logdir: s
3030
run_store.workspace = workspace
3131
run_store.visibility = public
3232
run_store.tags = [] if run_store.tags is None else run_store.tags
33-
# 设置颜色,随机生成一个
33+
# 设置颜色和名称
3434
exp_count = random.randint(0, 20)
35-
run_store.run_colors = N.generate_colors(exp_count)
36-
# 设置名称,随机生成
35+
if run_store.run_colors is None:
36+
run_store.run_colors = N.generate_colors(exp_count)
3737
run_store.run_name = N.generate_name(exp_count) if run_store.run_name is None else run_store.run_name
38-
run_store.run_colors = N.generate_colors(exp_count)
3938
run_store.run_id = N.generate_run_id()
4039
run_store.new = True
4140

swanlab/data/sdk.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
check_desc_format,
2424
check_tags_format,
2525
check_run_id_format,
26+
check_color_format,
2627
)
2728
from swanlab.log import swanlog
2829
from swanlab.swanlab_settings import Settings, get_settings, set_settings, read_folder_settings
@@ -108,6 +109,7 @@ def init(
108109
id: str = None,
109110
resume: Union[Literal['must', 'allow', 'never'], bool] = None,
110111
reinit: bool = None,
112+
color: str = None,
111113
**kwargs,
112114
) -> SwanLabRun:
113115
"""
@@ -188,6 +190,10 @@ def init(
188190
The run ID of the previous run, which is used to resume the previous run.
189191
reinit : bool, optional
190192
Whether to reinitialize SwanLabRun, the default is False.
193+
color : str, optional
194+
The experiment color displayed in the web interface.
195+
Supports preset colors (e.g., "green"), RGB format (e.g., "rgb(82,141,89)"),
196+
or hex format (e.g., "#528d59", "528d59").
191197
"""
192198
# 一个进程同时只能有一个实验在运行
193199
if SwanLabRun.is_started():
@@ -256,6 +262,7 @@ def init(
256262
resume = _load_from_env(SwanLabEnv.RESUME.value, resume)
257263
id = _load_from_env(SwanLabEnv.RUN_ID.value, id)
258264
logdir = _load_from_env(SwanLabEnv.SWANLOG_FOLDER.value, logdir)
265+
color = _load_from_env(SwanLabEnv.EXP_COLOR.value, color)
259266
# 2. 部分格式校验
260267
# 2.1 校验项目名称,默认实验名称为当前目录名
261268
project = project if project else os.path.basename(os.getcwd())
@@ -298,6 +305,13 @@ def init(
298305
if tags[i] != new_tags[i]:
299306
swanlog.warning("The tag has been truncated automatically.")
300307
tags[i] = new_tags[i]
308+
# 2.7 校验颜色格式
309+
if color:
310+
try:
311+
color = check_color_format(color)
312+
except ValueError as e:
313+
swanlog.warning(f"Invalid color format: {e}, will use random color")
314+
color = None
301315
# 3. 校验回调函数
302316
callbacks = check_callback_format(self.cbs + callbacks)
303317
self.cbs = []
@@ -352,6 +366,7 @@ def init(
352366
run_store.tags = tags
353367
run_store.description = description
354368
run_store.run_name = experiment_name
369+
run_store.run_colors = (color, color) if color else None
355370
run_store.swanlog_dir = logdir
356371
# 2. 启动操作员,注册运行实例
357372
operator = _create_operator(mode, login_info, callbacks)

swanlab/env.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ class SwanLabEnv(enum.Enum):
133133
"""
134134
实验标签,用于标注当前实验,多个标签用逗号分隔
135135
"""
136+
EXP_COLOR = "SWANLAB_EXP_COLOR"
137+
"""
138+
实验颜色,用于控制实验在网页端的显示颜色
139+
"""
136140
DISABLE_GIT = "SWANLAB_DISABLE_GIT"
137141
"""
138142
禁用Git功能,设置为true时不会采集Git信息

swanlab/formatter.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,52 @@ def check_callback_format(callback: Optional[Union[SwanKitCallback, List[SwanKit
256256
if not isinstance(callback, list):
257257
raise TypeError(f"Only support SwanKitCallback or List[SwanKitCallback], but got {type(callback)}")
258258
return callback
259+
260+
261+
PRESET_COLORS = {
262+
"green": "#528d59",
263+
"blue": "#587ad2",
264+
"red": "#c24d46",
265+
"cyan": "#6ebad3",
266+
"orange": "#dfb142",
267+
"purple": "#6d4ba4",
268+
"pink": "#d47694",
269+
"brown": "#905f4a",
270+
"gray": "#989fa3",
271+
}
272+
273+
274+
def check_color_format(color: str) -> Optional[str]:
275+
"""
276+
验证并转换颜色格式为十六进制
277+
支持:预置颜色名、RGB格式、十六进制
278+
:param color: 颜色字符串
279+
:return: 十六进制颜色字符串
280+
:raises ValueError: 颜色格式不正确
281+
"""
282+
if not color:
283+
return None
284+
285+
color = color.strip().lower()
286+
287+
# 预置颜色
288+
if color in PRESET_COLORS:
289+
return PRESET_COLORS[color]
290+
291+
# RGB格式
292+
rgb_match = re.match(r'(?:rgb\s*)?\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)', color)
293+
if rgb_match:
294+
r, g, b = map(int, rgb_match.groups())
295+
if all(0 <= v <= 255 for v in [r, g, b]):
296+
return f"#{r:02x}{g:02x}{b:02x}"
297+
raise ValueError(f"RGB values must be 0-255: ({r}, {g}, {b})")
298+
299+
# 十六进制格式
300+
hex_color = color.lstrip("#")
301+
if len(hex_color) == 3:
302+
hex_color = "".join(c * 2 for c in hex_color)
303+
304+
if re.fullmatch(r'[0-9a-f]{6}', hex_color):
305+
return f"#{hex_color}"
306+
307+
raise ValueError(f"Invalid color format: {color}")

0 commit comments

Comments
 (0)