Skip to content

Commit f8ab962

Browse files
CaralHsi and harvey_xiang authored
feat: add stage log (#1486)
* fix: lint * feat: add addMessage Stage log * feat: add addMessage Stage log * feat: optimized embedding item --------- Co-authored-by: harvey_xiang <harvey_xiang22@163.com>
1 parent 96a1dd6 commit f8ab962

7 files changed

Lines changed: 1142 additions & 173 deletions

File tree

src/memos/chunkers/sentence_chunker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,5 @@ def chunk(self, text: str) -> list[str] | list[Chunk]:
5353
chunks.append(chunk)
5454

5555
logger.debug(f"Generated {len(chunks)} chunks from input text")
56+
5657
return chunks

src/memos/mem_reader/multi_modal_struct.py

Lines changed: 88 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH
1919
from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
2020
from memos.types import MessagesType
21-
from memos.utils import timed
21+
from memos.utils import timed, timed_stage
2222

2323

2424
if TYPE_CHECKING:
@@ -75,6 +75,30 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
7575
direct_markdown_hostnames=direct_markdown_hostnames,
7676
)
7777

78+
def _embed_memory_items(self, items: list[TextualMemoryItem]) -> None:
    """Populate ``metadata.embedding`` for each usable item, in place.

    A single batched embedder call is tried first; if that fails, every
    item is embedded one at a time. All errors are logged and swallowed so
    the caller's control flow is never interrupted.
    """
    # Only items that exist and carry non-empty memory text are embeddable.
    candidates = [item for item in items if item and item.memory]
    if not candidates:
        return

    payloads = [item.memory for item in candidates]
    try:
        # Batch path: one embedder call for all texts, written back in order.
        vectors = self.embedder.embed(payloads)
        for item, vector in zip(candidates, vectors, strict=True):
            item.metadata.embedding = vector
    except Exception as e:
        logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
        logger.warning("[EMBED_FALLBACK] batch_size=%d", len(payloads))
        # Per-item fallback: a single bad item must not block the rest.
        for item in candidates:
            try:
                item.metadata.embedding = self.embedder.embed([item.memory])[0]
            except Exception as e2:
                logger.error(f"[MultiModalStruct] Error computing embedding for item: {e2}")
78102
def _split_large_memory_item(
79103
self, item: TextualMemoryItem, max_tokens: int
80104
) -> list[TextualMemoryItem]:
@@ -203,13 +227,8 @@ def _concat_multi_modal_memories(
203227
# If only one item after processing, compute embedding and return
204228
if len(processed_items) == 1:
205229
single_item = processed_items[0]
206-
if single_item and single_item.memory:
207-
try:
208-
single_item.metadata.embedding = self.embedder.embed([single_item.memory])[0]
209-
except Exception as e:
210-
logger.error(
211-
f"[MultiModalStruct] Error computing embedding for single item: {e}"
212-
)
230+
with timed_stage("add", "embedding", window_count=1):
231+
self._embed_memory_items([single_item])
213232
return processed_items
214233

215234
windows = []
@@ -260,31 +279,8 @@ def _concat_multi_modal_memories(
260279
windows.append(window)
261280

262281
# Batch compute embeddings for all windows
263-
if windows:
264-
# Collect all valid windows that need embedding
265-
valid_windows = [w for w in windows if w and w.memory]
266-
267-
if valid_windows:
268-
# Collect all texts that need embedding
269-
texts_to_embed = [w.memory for w in valid_windows]
270-
271-
# Batch compute all embeddings at once
272-
try:
273-
embeddings = self.embedder.embed(texts_to_embed)
274-
# Fill embeddings back into memory items
275-
for window, embedding in zip(valid_windows, embeddings, strict=True):
276-
window.metadata.embedding = embedding
277-
except Exception as e:
278-
logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
279-
# Fallback: compute embeddings individually
280-
for window in valid_windows:
281-
if window.memory:
282-
try:
283-
window.metadata.embedding = self.embedder.embed([window.memory])[0]
284-
except Exception as e2:
285-
logger.error(
286-
f"[MultiModalStruct] Error computing embedding for item: {e2}"
287-
)
282+
with timed_stage("add", "embedding", window_count=len(windows)):
283+
self._embed_memory_items(windows)
288284

289285
return windows
290286

@@ -984,49 +980,49 @@ def _process_multi_modal_data(
984980
# must pop here, avoid add to info, only used in sync fine mode
985981
custom_tags = info.pop("custom_tags", None) if isinstance(info, dict) else None
986982

987-
# Use MultiModalParser to parse the scene data
988-
# If it's a list, parse each item; otherwise parse as single message
989-
if isinstance(scene_data_info, list):
990-
# Pre-expand multimodal messages
991-
expanded_messages = self._expand_multimodal_messages(scene_data_info)
992-
993-
# Parse each message in the list
994-
all_memory_items = []
995-
# Use thread pool to parse each message in parallel, but keep the original order
996-
with ContextThreadPoolExecutor(max_workers=30) as executor:
997-
# submit tasks and keep the original order
998-
futures = [
999-
executor.submit(
1000-
self.multi_modal_parser.parse,
1001-
msg,
1002-
info,
1003-
mode="fast",
1004-
need_emb=False,
1005-
**kwargs,
1006-
)
1007-
for msg in expanded_messages
1008-
]
1009-
# collect results in original order
1010-
for future in futures:
1011-
try:
1012-
items = future.result()
1013-
all_memory_items.extend(items)
1014-
except Exception as e:
1015-
logger.error(f"[MultiModalFine] Error in parallel parsing: {e}")
1016-
else:
1017-
# Parse as single message
1018-
all_memory_items = self.multi_modal_parser.parse(
1019-
scene_data_info, info, mode="fast", need_emb=False, **kwargs
1020-
)
1021-
fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
983+
# Stage: parse — parallel message parsing + sliding-window aggregation
984+
with timed_stage("add", "parse") as ts_parse:
985+
if isinstance(scene_data_info, list):
986+
expanded_messages = self._expand_multimodal_messages(scene_data_info)
987+
ts_parse.set(msg_count=len(expanded_messages))
988+
989+
all_memory_items = []
990+
with ContextThreadPoolExecutor(max_workers=30) as executor:
991+
futures = [
992+
executor.submit(
993+
self.multi_modal_parser.parse,
994+
msg,
995+
info,
996+
mode="fast",
997+
need_emb=False,
998+
**kwargs,
999+
)
1000+
for msg in expanded_messages
1001+
]
1002+
for future in futures:
1003+
try:
1004+
items = future.result()
1005+
all_memory_items.extend(items)
1006+
except Exception as e:
1007+
logger.error(f"[MultiModalFine] Error in parallel parsing: {e}")
1008+
else:
1009+
ts_parse.set(msg_count=1)
1010+
all_memory_items = self.multi_modal_parser.parse(
1011+
scene_data_info, info, mode="fast", need_emb=False, **kwargs
1012+
)
1013+
1014+
fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
1015+
ts_parse.set(window_count=len(fast_memory_items))
1016+
10221017
if mode == "fast":
10231018
return fast_memory_items
1024-
else:
1025-
non_file_url_fast_items = [
1026-
item for item in fast_memory_items if not self._is_file_url_only_item(item)
1027-
]
10281019

1029-
# Part A: call llm in parallel using thread pool
1020+
# Stage: llm_extract — fine mode 4-way parallel LLM + per-source serial
1021+
non_file_url_fast_items = [
1022+
item for item in fast_memory_items if not self._is_file_url_only_item(item)
1023+
]
1024+
1025+
with timed_stage("add", "llm_extract") as ts_llm:
10301026
fine_memory_items = []
10311027

10321028
with ContextThreadPoolExecutor(max_workers=4) as executor:
@@ -1057,7 +1053,6 @@ def _process_multi_modal_data(
10571053
**kwargs,
10581054
)
10591055

1060-
# Collect results
10611056
fine_memory_items_string_parser = future_string.result()
10621057
fine_memory_items_tool_trajectory_parser = future_tool.result()
10631058
fine_memory_items_skill_memory_parser = future_skill.result()
@@ -1068,21 +1063,25 @@ def _process_multi_modal_data(
10681063
fine_memory_items.extend(fine_memory_items_skill_memory_parser)
10691064
fine_memory_items.extend(fine_memory_items_pref_parser)
10701065

1071-
# Part B: get fine multimodal items
1072-
for fast_item in fast_memory_items:
1073-
sources = fast_item.metadata.sources
1074-
for source in sources:
1075-
lang = getattr(source, "lang", "en")
1076-
items = self.multi_modal_parser.process_transfer(
1077-
source,
1078-
context_items=[fast_item],
1079-
custom_tags=custom_tags,
1080-
info=info,
1081-
lang=lang,
1082-
user_context=kwargs.get("user_context"),
1083-
)
1084-
fine_memory_items.extend(items)
1085-
return fine_memory_items
1066+
# Part B: per-source serial processing
1067+
with timed_stage("add", "per_source") as ts_ps:
1068+
for fast_item in fast_memory_items:
1069+
sources = fast_item.metadata.sources
1070+
for source in sources:
1071+
lang = getattr(source, "lang", "en")
1072+
items = self.multi_modal_parser.process_transfer(
1073+
source,
1074+
context_items=[fast_item],
1075+
custom_tags=custom_tags,
1076+
info=info,
1077+
lang=lang,
1078+
user_context=kwargs.get("user_context"),
1079+
)
1080+
fine_memory_items.extend(items)
1081+
1082+
ts_llm.set(fine_memory_count=len(fine_memory_items), per_source_ms=ts_ps.duration_ms)
1083+
1084+
return fine_memory_items
10861085

10871086
@timed
10881087
def _process_transfer_multi_modal_data(

src/memos/multi_mem_cube/composite_cube.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from memos.context.context import ContextThreadPoolExecutor
88
from memos.multi_mem_cube.views import MemCubeView
9+
from memos.utils import timed_stage
910

1011

1112
if TYPE_CHECKING:
@@ -27,13 +28,18 @@ class CompositeCubeView(MemCubeView):
2728

2829
def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]:
    """Fan the add request out to every child cube view, in order.

    Results from all views are concatenated into a single list; the whole
    fan-out runs under one "multi_cube" timing stage that records the
    number of cubes involved.
    """
    collected: list[dict[str, Any]] = []
    total = len(self.cube_views)

    with timed_stage("add", "multi_cube", cube_count=total):
        for position, view in enumerate(self.cube_views, start=1):
            # Progress log: which cube is being handled out of how many.
            self.logger.info(
                "[CompositeCubeView] fan-out add to cube=%s (%d/%d)",
                view.cube_id,
                position,
                total,
            )
            collected.extend(view.add_memories(add_req))

    return collected
3945

0 commit comments

Comments (0)