Skip to content

Commit 72dc4dc

Browse files
authored
feat: swanlab.save() file uploader (#1515)
* feat: add file uploader skeleton * fix: update uploader * refactor: adjust save path * refactor: simplify save_file model definition * refactor: simplify model definition * fix: slot args * fix: alignment with interface params * fix: default mime_type * chore: add default mime_type * fix: mime_type args * refactor: poll file_sig with timer * chore: add save hints * chore: modify resp field * fix: rename save_utils ut * chore: add matrix test on multiplatform * Revert "chore: add matrix test on multiplatform" This reverts commit 05c35b2. * fix: symlink unused * chore: simplify & add warning * feat: add upload progress * chore: update doc string of save_progress * feat: add file max_size constraint * refactor: save logic Normalize and harden save/upload logic: _iter_files now returns a list, FileUploadManager short-circuits in disabled mode earlier, multipart upload was refactored into _upload_multipart with proper part buffering and failure handling, and a _mark_failed helper was added to reliably report FAILED state. DirWatcher now guards _target_modes with a dedicated lock and exposes _get/_set/_clear helpers to ensure thread-safe reads/writes and avoid races when resolving symlink vs copy targets. Progress start logic was simplified to call Status.start() without touching private rich internals. SDK save path resolution gains _infer_default_save_base_path to better infer base_path for absolute globs; save() now returns List[str] and short-circuits when mode is disabled. Unit tests were added/updated to cover these behaviors. * fix: progress upload sync * fix: future error * feat: constraint file size and count
1 parent 2b75ff8 commit 72dc4dc

File tree

18 files changed

+2030
-41
lines changed

18 files changed

+2030
-41
lines changed

swanlab/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
"merge_settings",
3434
"init",
3535
"log",
36+
"save",
3637
"register_callbacks",
3738
"finish",
3839
"Audio",

swanlab/core_python/api/experiment/__init__.py

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55
@description: 定义实验相关的后端API接口
66
"""
77

8-
from typing import Literal, Dict, TYPE_CHECKING, List, Union
8+
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Union
99

1010
from swanlab.core_python.api.type import RunType
11-
from .utils import to_camel_case, parse_column_type
11+
12+
from .utils import (
13+
parse_column_type,
14+
to_camel_case,
15+
unwrap_api_payload,
16+
)
1217

1318
if TYPE_CHECKING:
1419
from swanlab.core_python.client import Client
@@ -35,7 +40,7 @@ def update_experiment_state(
3540
username: str,
3641
projname: str,
3742
cuid: str,
38-
state: Literal['FINISHED', 'CRASHED', 'ABORTED'],
43+
state: Literal["FINISHED", "CRASHED", "ABORTED"],
3944
finished_at: str = None,
4045
):
4146
"""
@@ -76,7 +81,7 @@ def get_project_experiments(
7681
- 'job_type': 按任务类型筛选,值为字符串
7782
"""
7883
# 特殊筛选条件配置:用户侧 key -> 后端 key 和操作符
79-
SPECIAL_FILTER_CONFIG = {
84+
special_filter_config = {
8085
"group": {"key": "cluster", "op": "EQ"},
8186
"tags": {"key": "labels", "op": "IN"},
8287
"name": {"key": "name", "op": "EQ"},
@@ -88,33 +93,39 @@ def get_project_experiments(
8893

8994
if filters:
9095
for key, value in filters.items():
91-
if key in SPECIAL_FILTER_CONFIG:
96+
if key in special_filter_config:
9297
# 特殊字段处理
93-
config = SPECIAL_FILTER_CONFIG[key]
98+
config = special_filter_config[key]
9499
# tags 需要转换为列表
95-
filter_value = list(value) if key == "tags" and isinstance(value, (list, tuple)) else [value]
100+
filter_value = (
101+
list(value)
102+
if key == "tags" and isinstance(value, (list, tuple))
103+
else [value]
104+
)
96105
parsed_filters.append(
97106
{
98107
"key": config["key"],
99108
"active": True,
100109
"value": filter_value,
101110
"op": config["op"],
102-
"type": 'STABLE',
111+
"type": "STABLE",
103112
}
104113
)
105114
else:
106115
# 常规字段处理
107116
parsed_filters.append(
108117
{
109-
"key": to_camel_case(key) if parse_column_type(key) == 'STABLE' else key.split('.', 1)[-1],
118+
"key": to_camel_case(key)
119+
if parse_column_type(key) == "STABLE"
120+
else key.split(".", 1)[-1],
110121
"active": True,
111122
"value": [value],
112-
"op": 'EQ',
123+
"op": "EQ",
113124
"type": parse_column_type(key),
114125
}
115126
)
116127

117-
res = client.post(f"/project/{path}/runs/shows", data={'filters': parsed_filters})
128+
res = client.post(f"/project/{path}/runs/shows", data={"filters": parsed_filters})
118129
return res[0]
119130

120131

@@ -125,7 +136,7 @@ def get_single_experiment(client: "Client", *, path: str) -> RunType:
125136
:param client: 已登录的客户端实例
126137
:param path: 实验路径 username/project/expid
127138
"""
128-
proj_path, expid = path.rsplit('/', 1)
139+
proj_path, expid = path.rsplit("/", 1)
129140
res = client.get(f"/project/{proj_path}/runs/{expid}")
130141
return res[0]
131142

@@ -137,7 +148,7 @@ def get_experiment_metrics(client: "Client", *, expid: str, key: str) -> Dict[st
137148
:param expid: 实验cuid
138149
:param key: 指定字段列表
139150
"""
140-
res = client.get(f"/experiment/{expid}/column/csv", params={'key': key})
151+
res = client.get(f"/experiment/{expid}/column/csv", params={"key": key})
141152
return res[0]
142153

143154

@@ -147,15 +158,78 @@ def delete_experiment(client: "Client", *, path: str):
147158
:param client: 已登录的客户端实例
148159
:param path: 实验路径 'username/project/expid'
149160
"""
150-
proj_path, expid = path.rsplit('/', 1)
161+
proj_path, expid = path.rsplit("/", 1)
151162
client.delete(f"/project/{proj_path}/runs/{expid}")
152163

153164

165+
def prepare_upload(
166+
client: "Client", exp_id: str, files: Iterable[Dict[str, object]]
167+
) -> List[str]:
168+
"""
169+
创建普通文件上传任务,返回预签名上传地址列表。
170+
"""
171+
payload_files = list(files)
172+
if not payload_files:
173+
return []
174+
data, _ = client.post(
175+
f"/experiment/{exp_id}/files/prepare", {"files": payload_files}
176+
)
177+
result = unwrap_api_payload(data)
178+
if isinstance(result, dict):
179+
urls = result.get("urls", [])
180+
return urls if isinstance(urls, list) else []
181+
return []
182+
183+
184+
def complete_upload(
185+
client: "Client", exp_id: str, files: Iterable[Dict[str, object]]
186+
) -> None:
187+
"""
188+
标记普通文件上传完成。
189+
"""
190+
payload_files = list(files)
191+
if not payload_files:
192+
return
193+
client.post(f"/experiment/{exp_id}/files/complete", {"files": payload_files})
194+
195+
196+
def prepare_multipart(
197+
client: "Client", exp_id: str, file: Dict[str, object]
198+
) -> Dict[str, object]:
199+
"""
200+
创建分片上传任务,返回 uploadId 和分片上传地址列表。
201+
"""
202+
data, _ = client.post(
203+
f"/experiment/{exp_id}/files/prepare-multipart",
204+
{"files": [file]},
205+
)
206+
result = unwrap_api_payload(data)
207+
if isinstance(result, dict):
208+
files = result.get("files", [])
209+
if files and isinstance(files, list):
210+
return files[0]
211+
raise ValueError("Multipart prepare API returned empty file payloads.")
212+
213+
214+
def complete_multipart(client: "Client", exp_id: str, file: Dict[str, object]) -> None:
215+
"""
216+
标记分片上传完成。
217+
"""
218+
client.post(
219+
f"/experiment/{exp_id}/files/complete-multipart",
220+
{"files": [file]},
221+
)
222+
223+
154224
__all__ = [
155225
"send_experiment_heartbeat",
156226
"update_experiment_state",
157227
"get_project_experiments",
158228
"get_single_experiment",
159229
"get_experiment_metrics",
160230
"delete_experiment",
231+
"prepare_upload",
232+
"complete_upload",
233+
"prepare_multipart",
234+
"complete_multipart",
161235
]

swanlab/core_python/api/experiment/utils.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,69 @@
55
@description: 实验相关的后端API接口中的工具函数
66
"""
77

8+
from typing import Dict, List, Optional, Tuple
9+
810
from swanlab.core_python.api.type import ColumnType
911

1012

1113
# 从前缀中获取指标类型
1214
def parse_column_type(column: str) -> ColumnType:
13-
column_type = column.split('.', 1)[0]
14-
if column_type == 'summary':
15-
return 'SCALAR'
16-
elif column_type == 'config':
17-
return 'CONFIG'
15+
column_type = column.split(".", 1)[0]
16+
if column_type == "summary":
17+
return "SCALAR"
18+
elif column_type == "config":
19+
return "CONFIG"
1820
else:
19-
return 'STABLE'
21+
return "STABLE"
2022

2123

2224
# 将下划线命名转化为驼峰命名
2325
def to_camel_case(name: str) -> str:
24-
return ''.join([w.capitalize() if i > 0 else w for i, w in enumerate(name.split('_'))])
26+
return "".join(
27+
[w.capitalize() if i > 0 else w for i, w in enumerate(name.split("_"))]
28+
)
29+
30+
31+
def unwrap_api_payload(data):
32+
if (
33+
isinstance(data, dict)
34+
and "data" in data
35+
and isinstance(data["data"], (dict, list))
36+
):
37+
return data["data"]
38+
return data
39+
40+
41+
def extract_upload_id(payload: Dict[str, object]) -> Optional[str]:
42+
upload_id = payload.get("uploadId")
43+
if isinstance(upload_id, str) and upload_id:
44+
return upload_id
45+
return None
46+
47+
48+
49+
def extract_part_urls(payload: Dict[str, object]) -> List[Tuple[int, str]]:
50+
parts = payload.get("parts")
51+
if not isinstance(parts, list):
52+
raise ValueError("Multipart upload URLs are missing in prepare response.")
53+
54+
resolved = []
55+
for part in parts:
56+
if not isinstance(part, dict):
57+
raise ValueError("Multipart prepare response contains invalid part data.")
58+
part_number = part.get("partNumber")
59+
url = part.get("url")
60+
if not isinstance(part_number, int) or not isinstance(url, str) or not url:
61+
raise ValueError("Invalid partNumber or url in multipart response.")
62+
resolved.append((part_number, url))
63+
64+
return sorted(resolved, key=lambda item: item[0])
65+
66+
67+
__all__ = [
68+
"parse_column_type",
69+
"to_camel_case",
70+
"unwrap_api_payload",
71+
"extract_upload_id",
72+
"extract_part_urls",
73+
]

swanlab/core_python/api/service.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import time
99
from concurrent.futures import ThreadPoolExecutor
1010
from io import BytesIO
11-
from typing import List, Tuple
11+
from typing import List, Optional, Tuple
1212

1313
import requests
1414
from requests.exceptions import RequestException
@@ -18,7 +18,9 @@
1818
from ...toolkit.models.data import MediaBuffer
1919

2020

21-
def upload_file(*, url: str, buffer: BytesIO, max_retries=3):
21+
MIME_TYPE_DEFAULT: str = "application/octet-stream"
22+
23+
def upload_file(*, url: str, buffer: BytesIO, max_retries=3, mime_type: str=MIME_TYPE_DEFAULT) -> Optional[str]:
2224
"""
2325
上传文件到COS
2426
:param url: COS上传URL
@@ -33,13 +35,16 @@ def upload_file(*, url: str, buffer: BytesIO, max_retries=3):
3335
response = session.put(
3436
url,
3537
data=buffer,
36-
headers={'Content-Type': 'application/octet-stream'},
38+
headers={"Content-Type": mime_type},
3739
timeout=30,
3840
)
3941
response.raise_for_status()
40-
return
42+
etag = response.headers.get("ETag")
43+
return etag if etag else None
4144
except RequestException:
42-
swanlog.warning("Upload attempt {} failed for URL: {}".format(attempt, url))
45+
swanlog.warning(
46+
"Upload attempt {} failed for URL: {}".format(attempt, url)
47+
)
4348
# 如果是最后一次尝试,抛出异常
4449
if attempt == max_retries:
4550
raise
@@ -57,10 +62,10 @@ def upload_to_cos(client: Client, *, cuid: str, buffers: List[MediaBuffer]):
5762
failed_buffers: List[Tuple[str, MediaBuffer]] = []
5863
# 1. 后端签名
5964
data, _ = client.post(
60-
'/resources/presigned/put',
65+
"/resources/presigned/put",
6166
{"experimentId": cuid, "paths": [buffer.file_name for buffer in buffers]},
6267
)
63-
urls: List[str] = data['urls']
68+
urls: List[str] = data["urls"]
6469
# 2. 并发上传
6570
# executor.submit可能会失败,因为线程数有限或者线程池已经关闭
6671
# 来自此issue: https://github.com/SwanHubX/SwanLab/issues/889,此时需要一个个发送
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from .manager import DirWatcher, FileUploadManager
2+
from .model import SaveFileState, SaveFileModel
3+
from .utils import (
4+
collect_save_files,
5+
compute_md5,
6+
file_signature,
7+
guess_mime_type,
8+
validate_glob_path,
9+
)
10+
11+
__all__ = [
12+
"SaveFileState",
13+
"SaveFileModel",
14+
"collect_save_files",
15+
"validate_glob_path",
16+
"compute_md5",
17+
"guess_mime_type",
18+
"file_signature",
19+
"FileUploadManager",
20+
"DirWatcher",
21+
]

0 commit comments

Comments
 (0)