Skip to content

Commit cf1301d

Browse files
Zeyi-Linclaude
andauthored
feat: code quality improve and fix key section error (#1476)
* refactor: optimize wandb sync code quality - Replace fragile class name string matching with isinstance checks - Simplify redundant if-else branches - Move numpy import to function top to avoid repeated imports - Replace print() with swanlab logging system Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor: optimize tensorboard sync code quality - Add proper None checks for global_step parameter - Fix numpy import to avoid repeated imports - Fix NCHW format handling to correctly extract first image - Add error logging for image conversion failures - Improve error handling with try-except blocks Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * del wandb summary * fix section split * section split unit test --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 807c4dc commit cf1301d

5 files changed

Lines changed: 101 additions & 62 deletions

File tree

swanlab/converter/wb/wb_local_converter.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,6 @@ def _pre_upload_cb(n):
351351
progress = None
352352
task_id = None
353353
record_count = 0
354-
last_summary = {}
355-
last_summary_step = 0
356354

357355
if Progress is not None:
358356
progress = Progress(
@@ -474,20 +472,7 @@ def _pre_upload_cb(n):
474472
_log_scalars_direct(scalar_dict, step)
475473
if media_dict:
476474
swanlab_run.log(media_dict, step=step)
477-
last_summary_step = step
478475
del scalar_dict, media_dict
479-
elif record_type == "summary":
480-
# Accumulate into last_summary; only log once after the loop ends.
481-
# wandb writes a summary record after every step, so calling log() here
482-
# would result in N redundant log calls (N = number of steps).
483-
for item in record_pb.summary.update:
484-
key = item.key or '/'.join(item.nested_key)
485-
if not key or key.startswith('_') or '/' in key:
486-
continue
487-
try:
488-
last_summary[key] = float(item.value_json)
489-
except (ValueError, TypeError):
490-
pass
491476

492477
# GC every GC_INTERVAL records to reduce overhead
493478
record_count += 1

swanlab/data/run/key.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,9 @@ def create_column(
165165
# PUBLIC 可选是否传递名称,如果 key 包含斜杠,则使用斜杠前的部分作为section的名称
166166
# CUSTOM 时如果 key 包含斜杠,则使用斜杠前的部分作为section的名称,并且将 section_type 设置为 PUBLIC
167167
if section_type in ["PUBLIC", "CUSTOM"]:
168-
split_key = key.split("/")
169-
if len(split_key) > 1 and split_key[0]:
170-
# 如果key包含斜杠,则使用斜杠前的部分作为section的名称
171-
result.section = split_key[0]
168+
if "/" in key:
169+
# 如果key包含斜杠,则使用最后一个斜杠前的部分作为section的名称
170+
result.section = key.rsplit("/", 1)[0]
172171
section_type: SectionType = "PUBLIC"
173172
else:
174173
result.section = None

swanlab/sync/tensorboard.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22
import swanlab
3+
from swanlab.log import swanlog as swl
34

45

56
def _extract_args(args, kwargs, param_names):
@@ -69,7 +70,7 @@ def patched_add_scalar(self, *args, **kwargs):
6970
)
7071

7172
data = {tag: scalar_value}
72-
swanlab.log(data=data, step=int(global_step))
73+
swanlab.log(data=data, step=int(global_step) if global_step is not None else None)
7374

7475
return original_add_scalar(self, *args, **kwargs)
7576

@@ -83,42 +84,53 @@ def patched_add_scalars(self, *args, **kwargs):
8384
)
8485
for dict_tag, value in scalar_value_dict.items():
8586
data = {f"{tag}/{dict_tag}": value}
86-
swanlab.log(data=data, step=int(global_step))
87+
swanlab.log(data=data, step=int(global_step) if global_step is not None else None)
8788
return original_add_scalars(self, *args, **kwargs)
8889

8990
@functools.wraps(original_add_image)
9091
def patched_add_image(self, *args, **kwargs):
9192
if types_set is not None and 'image' not in types_set:
9293
return original_add_image(self, *args, **kwargs)
93-
import numpy as np
94+
95+
try:
96+
import numpy as np
97+
except ImportError:
98+
np = None
99+
100+
if np is None:
101+
swl.warning("numpy not available, skipping image conversion")
102+
return original_add_image(self, *args, **kwargs)
94103

95104
tag, img_tensor, global_step, dataformats = _extract_args(
96105
args, kwargs, ['tag', 'img_tensor', 'global_step', 'dataformats']
97106
)
98-
dataformats = dataformats or 'CHW' # 设置默认值
99-
100-
# Convert to numpy array if it's a tensor
101-
if hasattr(img_tensor, 'cpu'):
102-
img_tensor = img_tensor.cpu()
103-
if hasattr(img_tensor, 'numpy'):
104-
img_tensor = img_tensor.numpy()
105-
106-
# Handle different input formats
107-
if dataformats == 'CHW':
108-
# Convert CHW to HWC for swanlab
109-
img_tensor = np.transpose(img_tensor, (1, 2, 0))
110-
elif dataformats == 'NCHW':
111-
# Take first image if batch dimension exists and convert to HWC
112-
img_tensor = np.transpose(img_tensor, (1, 2, 0))
113-
elif dataformats == 'HW':
114-
# Add channel dimension for grayscale
115-
img_tensor = np.expand_dims(img_tensor, axis=-1)
116-
elif dataformats == 'HWC':
117-
# Already in correct format
118-
pass
119-
120-
data = {tag: swanlab.Image(img_tensor)}
121-
swanlab.log(data=data, step=int(global_step))
107+
dataformats = dataformats or 'CHW'
108+
109+
try:
110+
# Convert to numpy array if it's a tensor
111+
if hasattr(img_tensor, 'cpu'):
112+
img_tensor = img_tensor.cpu()
113+
if hasattr(img_tensor, 'numpy'):
114+
img_tensor = img_tensor.numpy()
115+
116+
# Handle different input formats
117+
if dataformats == 'CHW':
118+
# Convert CHW to HWC for swanlab
119+
img_tensor = np.transpose(img_tensor, (1, 2, 0))
120+
elif dataformats == 'NCHW':
121+
# Take first image if batch dimension exists and convert to HWC
122+
img_tensor = np.transpose(img_tensor[0], (1, 2, 0))
123+
elif dataformats == 'HW':
124+
# Add channel dimension for grayscale
125+
img_tensor = np.expand_dims(img_tensor, axis=-1)
126+
elif dataformats == 'HWC':
127+
# Already in correct format
128+
pass
129+
130+
data = {tag: swanlab.Image(img_tensor)}
131+
swanlab.log(data=data, step=int(global_step) if global_step is not None else None)
132+
except Exception as e:
133+
swl.warning(f"Failed to convert image for tag '{tag}': {e}")
122134

123135
return original_add_image(self, *args, **kwargs)
124136

@@ -130,7 +142,7 @@ def patched_add_text(self, *args, **kwargs):
130142
args, kwargs, ['tag', 'text_string', 'global_step']
131143
)
132144
data = {tag: swanlab.Text(text_string)}
133-
swanlab.log(data=data, step=int(global_step))
145+
swanlab.log(data=data, step=int(global_step) if global_step is not None else None)
134146
return original_add_text(self, *args, **kwargs)
135147

136148
def patched_close(self):

swanlab/sync/wandb.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import swanlab
2+
from swanlab.log import swanlog as swl
23

34
def _extract_args(args, kwargs, param_names):
45
"""
@@ -62,6 +63,8 @@ def sync_wandb(
6263
try:
6364
import wandb
6465
from wandb import sdk as wandb_sdk
66+
from wandb import Image as WandbImage
67+
WANDB_IMAGE_AVAILABLE = True
6568
except ImportError:
6669
raise ImportError("please install wandb first, command: `pip install wandb`")
6770

@@ -94,9 +97,7 @@ def patched_init(*args, **kwargs):
9497

9598
if wandb_run is False:
9699
kwargs["mode"] = "offline"
97-
return original_init(*args, **kwargs)
98-
else:
99-
return original_init(*args, **kwargs)
100+
return original_init(*args, **kwargs)
100101

101102
def patched_config_update(self, *args, **kwargs):
102103
d, _ = _extract_args(args, kwargs, ['d', 'allow_val_change'])
@@ -107,45 +108,51 @@ def patched_config_update(self, *args, **kwargs):
107108

108109
def patched_log(self, *args, **kwargs):
109110
data, step, commit, sync = _extract_args(args, kwargs, ['data', 'step', 'commit', 'sync'])
110-
111+
111112
if data is None:
112113
return original_log(self, *args, **kwargs)
113-
114+
115+
# Import numpy once
116+
try:
117+
import numpy as np
118+
except ImportError:
119+
np = None
120+
114121
# 处理数据,支持 wandb.Image
115122
processed_data = {}
116123
for key, value in data.items():
117124
if isinstance(value, (int, float, bool, str)):
118125
# 标量类型直接保留
119126
processed_data[key] = value
120-
elif hasattr(value, '__class__') and value.__class__.__name__ == 'Image' and hasattr(value, 'image'):
127+
elif WANDB_IMAGE_AVAILABLE and isinstance(value, WandbImage):
121128
# 检测是否为 wandb.Image
122129
try:
130+
if np is None:
131+
swl.warning(f"numpy not available, skipping wandb.Image conversion for key '{key}'")
132+
continue
123133
# 获取 wandb.Image 的图像数据
124134
if value.image is not None:
125-
# 将 PIL Image 转换为 numpy 数组
126-
import numpy as np
127135
img_array = np.array(value.image)
128-
129-
# 创建 swanlab.Image
130136
caption = getattr(value, '_caption', None)
131137
swanlab_image = swanlab.Image(img_array, caption=caption)
132138
processed_data[key] = swanlab_image
133139
else:
134140
# 如果 image 为 None,尝试使用 _image
135141
if hasattr(value, '_image') and value._image is not None:
136-
import numpy as np
137142
img_array = np.array(value._image)
138143
caption = getattr(value, '_caption', None)
139144
swanlab_image = swanlab.Image(img_array, caption=caption)
140145
processed_data[key] = swanlab_image
141146
except Exception as e:
142147
# 如果转换失败,记录错误但继续处理其他数据
143-
print(f"Warning: Failed to convert wandb.Image for key '{key}': {e}")
148+
swl.warning(f"Failed to convert wandb.Image for key '{key}': {e}")
144149
continue
145-
elif isinstance(value, list) and value and hasattr(value[0], '__class__') and value[0].__class__.__name__ == 'Image':
150+
elif isinstance(value, list) and value and WANDB_IMAGE_AVAILABLE and isinstance(value[0], WandbImage):
146151
# 检测是否为 wandb.Image 列表
147152
try:
148-
import numpy as np
153+
if np is None:
154+
swl.warning(f"numpy not available, skipping wandb.Image list conversion for key '{key}'")
155+
continue
149156
swanlab_images = []
150157
for v in value:
151158
if hasattr(v, 'image') and v.image is not None:
@@ -160,7 +167,7 @@ def patched_log(self, *args, **kwargs):
160167
processed_data[key] = swanlab_images
161168
except Exception as e:
162169
# 如果转换失败,记录错误但继续处理其他数据
163-
print(f"Warning: Failed to convert wandb.Image list for key '{key}': {e}")
170+
swl.warning(f"Failed to convert wandb.Image list for key '{key}': {e}")
164171
continue
165172

166173
if processed_data:

test/unit/data/run/test_key.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
from swanlab.data.run.key import SwanLabKey
11+
from swanlab.data.modules import Line, DataWrapper
1112
from swanlab.toolkit import ChartType
1213
from tutils.setup import UseMockRunState
1314

@@ -196,3 +197,38 @@ def test_step_not_none(self):
196197
assert len(key_obj.steps) == 101, "Steps should contain one entry when step is provided"
197198
assert 1 in key_obj.steps, "Step 1 should be present in the steps"
198199
assert 101 not in key_obj.steps, "Step 10 should be present in the steps"
200+
201+
@pytest.mark.parametrize(
202+
"key, expected_section",
203+
[
204+
("loss/metrics/label1", "loss/metrics"), # 多个斜杠,取最后一个斜杠前的部分
205+
("train/loss", "train"), # 单个斜杠
206+
("accuracy", "default"), # 无斜杠,保持默认
207+
("a/b/c/d", "a/b/c"), # 更多斜杠
208+
],
209+
)
210+
def test_section_split(self, key, expected_section):
211+
"""
212+
测试 section 按最后一个斜杠分割的逻辑
213+
"""
214+
with UseMockRunState() as run_state:
215+
# 创建 DataWrapper 并解析
216+
data = DataWrapper(key=key, data=[Line(0.5)])
217+
data.parse(step=0, key=key)
218+
219+
# 创建 SwanLabKey 对象
220+
swanlab_key = SwanLabKey(key=key, media_dir=run_state.store.media_dir, log_dir=run_state.store.log_dir)
221+
222+
# 创建列信息
223+
column_info = swanlab_key.create_column(
224+
key=key,
225+
name=None,
226+
column_class="CUSTOM",
227+
column_config=None,
228+
section_type="PUBLIC",
229+
data=data,
230+
num=0
231+
)
232+
233+
# 验证 section_name
234+
assert column_info.section_name == expected_section

0 commit comments

Comments
 (0)