Skip to content

Commit 12dbcd4

Browse files
authored
feat: api support more filter params (#1510)
* tags and group * job_type, name, user.username
1 parent e5e12e5 commit 12dbcd4

File tree

2 files changed

+51
-16
lines changed

2 files changed

+51
-16
lines changed

swanlab/api/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,13 @@ def runs(self, path: str, filters: Dict[str, object] = None) -> Experiments:
140140
"""
141141
获取指定项目下的所有实验信息
142142
:param path: 项目路径,格式为 'username/project'
143+
:param filters: 筛选实验的条件,可选。支持以下特殊 key:
144+
- 'group': 按分组名称筛选,值为字符串
145+
- 'tags': 按标签筛选,值为字符串列表
146+
- 'name': 按实验名筛选,值为字符串
147+
- 'username': 按创建人筛选,值为字符串
148+
- 'job_type': 按任务类型筛选,值为字符串
143149
:return: Experiments 实例,可遍历获取实验信息
144-
:param filters: 筛选实验的条件,可选
145150
"""
146151
return Experiments(self._client, path=path, login_info=self._login_info, filters=filters)
147152

swanlab/core_python/api/experiment/__init__.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,22 +68,52 @@ def get_project_experiments(
6868
若有实验分组,则返回一个字典,使用时需递归展平实验数据
6969
:param client: 已登录的客户端实例
7070
:param path: 项目路径 username/project
71-
:param filters: 筛选实验的条件,可选
71+
:param filters: 筛选实验的条件,可选。支持以下特殊 key:
72+
- 'group': 按分组名称筛选,值为字符串
73+
- 'tags': 按标签筛选,值为字符串列表
74+
- 'name': 按实验名筛选,值为字符串
75+
- 'username': 按创建人筛选,值为字符串
76+
- 'job_type': 按任务类型筛选,值为字符串
7277
"""
73-
parsed_filters = (
74-
[
75-
{
76-
"key": to_camel_case(key) if parse_column_type(key) == 'STABLE' else key.split('.', 1)[-1],
77-
"active": True,
78-
"value": [value],
79-
"op": 'EQ',
80-
"type": parse_column_type(key),
81-
}
82-
for key, value in filters.items()
83-
]
84-
if filters
85-
else []
86-
)
78+
# 特殊筛选条件配置:用户侧 key -> 后端 key 和操作符
79+
SPECIAL_FILTER_CONFIG = {
80+
"group": {"key": "cluster", "op": "EQ"},
81+
"tags": {"key": "labels", "op": "IN"},
82+
"name": {"key": "name", "op": "EQ"},
83+
"username": {"key": "user.username", "op": "EQ"},
84+
"job_type": {"key": "job", "op": "EQ"},
85+
}
86+
87+
parsed_filters = []
88+
89+
if filters:
90+
for key, value in filters.items():
91+
if key in SPECIAL_FILTER_CONFIG:
92+
# 特殊字段处理
93+
config = SPECIAL_FILTER_CONFIG[key]
94+
# tags 需要转换为列表
95+
filter_value = list(value) if key == "tags" and isinstance(value, (list, tuple)) else [value]
96+
parsed_filters.append(
97+
{
98+
"key": config["key"],
99+
"active": True,
100+
"value": filter_value,
101+
"op": config["op"],
102+
"type": 'STABLE',
103+
}
104+
)
105+
else:
106+
# 常规字段处理
107+
parsed_filters.append(
108+
{
109+
"key": to_camel_case(key) if parse_column_type(key) == 'STABLE' else key.split('.', 1)[-1],
110+
"active": True,
111+
"value": [value],
112+
"op": 'EQ',
113+
"type": parse_column_type(key),
114+
}
115+
)
116+
87117
res = client.post(f"/project/{path}/runs/shows", data={'filters': parsed_filters})
88118
return res[0]
89119

0 commit comments

Comments
 (0)