|
23 | 23 | check_desc_format, |
24 | 24 | check_tags_format, |
25 | 25 | check_run_id_format, |
| 26 | + check_color_format, |
26 | 27 | ) |
27 | 28 | from swanlab.log import swanlog |
28 | 29 | from swanlab.swanlab_settings import Settings, get_settings, set_settings, read_folder_settings |
@@ -108,6 +109,7 @@ def init( |
108 | 109 | id: str = None, |
109 | 110 | resume: Union[Literal['must', 'allow', 'never'], bool] = None, |
110 | 111 | reinit: bool = None, |
| 112 | + color: str = None, |
111 | 113 | **kwargs, |
112 | 114 | ) -> SwanLabRun: |
113 | 115 | """ |
@@ -188,6 +190,10 @@ def init( |
188 | 190 | The run ID of the previous run, which is used to resume the previous run. |
189 | 191 | reinit : bool, optional |
190 | 192 | Whether to reinitialize SwanLabRun, the default is False. |
| 193 | + color : str, optional |
| 194 | + The experiment color displayed in the web interface. |
| 195 | + Supports preset colors (e.g., "green"), RGB format (e.g., "rgb(82,141,89)"), |
| 196 | + or hex format (e.g., "#528d59", "528d59"). |
191 | 197 | """ |
192 | 198 | # 一个进程同时只能有一个实验在运行 |
193 | 199 | if SwanLabRun.is_started(): |
@@ -256,6 +262,7 @@ def init( |
256 | 262 | resume = _load_from_env(SwanLabEnv.RESUME.value, resume) |
257 | 263 | id = _load_from_env(SwanLabEnv.RUN_ID.value, id) |
258 | 264 | logdir = _load_from_env(SwanLabEnv.SWANLOG_FOLDER.value, logdir) |
| 265 | + color = _load_from_env(SwanLabEnv.EXP_COLOR.value, color) |
259 | 266 | # 2. 部分格式校验 |
260 | 267 | # 2.1 校验项目名称,默认实验名称为当前目录名 |
261 | 268 | project = project if project else os.path.basename(os.getcwd()) |
@@ -298,6 +305,13 @@ def init( |
298 | 305 | if tags[i] != new_tags[i]: |
299 | 306 | swanlog.warning("The tag has been truncated automatically.") |
300 | 307 | tags[i] = new_tags[i] |
| 308 | + # 2.7 校验颜色格式 |
| 309 | + if color: |
| 310 | + try: |
| 311 | + color = check_color_format(color) |
| 312 | + except ValueError as e: |
| 313 | + swanlog.warning(f"Invalid color format: {e}, will use random color") |
| 314 | + color = None |
301 | 315 | # 3. 校验回调函数 |
302 | 316 | callbacks = check_callback_format(self.cbs + callbacks) |
303 | 317 | self.cbs = [] |
@@ -352,6 +366,7 @@ def init( |
352 | 366 | run_store.tags = tags |
353 | 367 | run_store.description = description |
354 | 368 | run_store.run_name = experiment_name |
| 369 | + run_store.run_colors = (color, color) if color else None |
355 | 370 | run_store.swanlog_dir = logdir |
356 | 371 | # 2. 启动操作员,注册运行实例 |
357 | 372 | operator = _create_operator(mode, login_info, callbacks) |
|
0 commit comments