
Commit e6656f3

harvey_xiang committed

feat: optimized embedding item

1 parent deb9d1a · commit e6656f3

1 file changed

Lines changed: 26 additions & 43 deletions

File tree

src/memos/mem_reader/multi_modal_struct.py

@@ -75,6 +75,30 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
             direct_markdown_hostnames=direct_markdown_hostnames,
         )
 
+    def _embed_memory_items(self, items: list[TextualMemoryItem]) -> None:
+        """Compute embeddings for a list of memory items in-place.
+
+        Attempts a single batch call first; falls back to per-item calls if the
+        batch fails. Errors are logged but never raised so callers always
+        continue normally.
+        """
+        valid = [w for w in items if w and w.memory]
+        if not valid:
+            return
+        texts = [w.memory for w in valid]
+        try:
+            embeddings = self.embedder.embed(texts)
+            for w, emb in zip(valid, embeddings, strict=True):
+                w.metadata.embedding = emb
+        except Exception as e:
+            logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
+            logger.warning("[EMBED_FALLBACK] batch_size=%d", len(texts))
+            for w in valid:
+                try:
+                    w.metadata.embedding = self.embedder.embed([w.memory])[0]
+                except Exception as e2:
+                    logger.error(f"[MultiModalStruct] Error computing embedding for item: {e2}")
+
     def _split_large_memory_item(
         self, item: TextualMemoryItem, max_tokens: int
     ) -> list[TextualMemoryItem]:
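The extracted helper replaces two near-identical blocks with a single batch-first, per-item-fallback routine. A minimal, runnable sketch of that behavior follows; Meta, Item, FlakyEmbedder, and Reader are hypothetical stand-ins for this illustration only, and just the _embed_memory_items body is copied from the diff above:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sketch")


class Meta:
    def __init__(self):
        self.embedding = None


class Item:
    def __init__(self, memory):
        self.memory = memory
        self.metadata = Meta()


class FlakyEmbedder:
    # Fails on multi-text batches but succeeds on single-text calls,
    # to exercise the fallback path.
    def embed(self, texts):
        if len(texts) > 1:
            raise RuntimeError("batch too large")
        return [[0.0, 1.0]]  # one dummy vector per input text


class Reader:
    def __init__(self):
        self.embedder = FlakyEmbedder()

    def _embed_memory_items(self, items):
        # Body copied from the diff, minus the type hints.
        valid = [w for w in items if w and w.memory]
        if not valid:
            return
        texts = [w.memory for w in valid]
        try:
            embeddings = self.embedder.embed(texts)
            for w, emb in zip(valid, embeddings, strict=True):
                w.metadata.embedding = emb
        except Exception as e:
            logger.error(f"Error batch computing embeddings: {e}")
            logger.warning("[EMBED_FALLBACK] batch_size=%d", len(texts))
            for w in valid:
                try:
                    w.metadata.embedding = self.embedder.embed([w.memory])[0]
                except Exception as e2:
                    logger.error(f"Error computing embedding for item: {e2}")


items = [Item("hello"), Item(""), Item("world")]
Reader()._embed_memory_items(items)
print([i.metadata.embedding for i in items])
# [[0.0, 1.0], None, [0.0, 1.0]] -- the empty item is skipped, the failed
# batch is retried item by item, and nothing raises to the caller.

Catching a broad Exception and degrading to per-item calls trades throughput for resilience: one bad input no longer sinks the whole batch, and the [EMBED_FALLBACK] warning keeps the slow path observable in logs.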
@@ -204,15 +228,7 @@ def _concat_multi_modal_memories(
         if len(processed_items) == 1:
             single_item = processed_items[0]
             with timed_stage("add", "embedding", window_count=1):
-                if single_item and single_item.memory:
-                    try:
-                        single_item.metadata.embedding = self.embedder.embed([single_item.memory])[
-                            0
-                        ]
-                    except Exception as e:
-                        logger.error(
-                            f"[MultiModalStruct] Error computing embedding for single item: {e}"
-                        )
+                self._embed_memory_items([single_item])
             return processed_items
 
         windows = []
@@ -264,40 +280,7 @@ def _concat_multi_modal_memories(
 
         # Batch compute embeddings for all windows
         with timed_stage("add", "embedding", window_count=len(windows)):
-            if windows:
-                valid_windows = [w for w in windows if w and w.memory]
-
-                if valid_windows:
-                    texts_to_embed = [w.memory for w in valid_windows]
-
-                    try:
-                        embeddings = self.embedder.embed(texts_to_embed)
-                        for window, embedding in zip(valid_windows, embeddings, strict=True):
-                            window.metadata.embedding = embedding
-                    except Exception as e:
-                        logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
-                        logger.warning("[EMBED_FALLBACK] batch_size=%d", len(texts_to_embed))
-                        for window in valid_windows:
-                            if window.memory:
-                                try:
-                                    window.metadata.embedding = self.embedder.embed(
-                                        [window.memory]
-                                    )[0]
-                                except Exception as e2:
-                                    logger.error(
-                                        "[MultiModalStruct] Error computing embedding"
-                                        f" for item: {e2}"
-                                    )
-
-            # [EMBED_MISSING] alert if any window has no embedding
-            null_count = sum(1 for w in windows if w and w.metadata.embedding is None)
-            if null_count > 0:
-                logger.warning(
-                    "[EMBED_MISSING] window_count=%d null_count=%d ratio=%.2f",
-                    len(windows),
-                    null_count,
-                    null_count / max(len(windows), 1),
-                )
+            self._embed_memory_items(windows)
 
         return windows
 