1818from memos .templates .mem_reader_prompts import MEMORY_MERGE_PROMPT_EN , MEMORY_MERGE_PROMPT_ZH
1919from memos .templates .tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN , TOOL_TRAJECTORY_PROMPT_ZH
2020from memos .types import MessagesType
21- from memos .utils import timed
21+ from memos .utils import timed , timed_stage
2222
2323
2424if TYPE_CHECKING :
@@ -75,6 +75,30 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
7575 direct_markdown_hostnames = direct_markdown_hostnames ,
7676 )
7777
def _embed_memory_items(self, items: list[TextualMemoryItem]) -> None:
    """Populate ``metadata.embedding`` on each memory item, in-place.

    Strategy: try one batched embedder call for all non-empty items; if
    that raises, fall back to embedding each item individually. Every
    failure is logged and swallowed so the caller always continues.
    """
    candidates = [item for item in items if item and item.memory]
    if not candidates:
        return

    payload = [item.memory for item in candidates]
    try:
        vectors = self.embedder.embed(payload)
        # strict=True: a length mismatch from the embedder raises and is
        # handled by the per-item fallback below.
        for item, vector in zip(candidates, vectors, strict=True):
            item.metadata.embedding = vector
    except Exception as e:
        logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
        logger.warning("[EMBED_FALLBACK] batch_size=%d", len(payload))
        for item in candidates:
            try:
                item.metadata.embedding = self.embedder.embed([item.memory])[0]
            except Exception as e2:
                logger.error(f"[MultiModalStruct] Error computing embedding for item: {e2}")
78102 def _split_large_memory_item (
79103 self , item : TextualMemoryItem , max_tokens : int
80104 ) -> list [TextualMemoryItem ]:
@@ -203,13 +227,8 @@ def _concat_multi_modal_memories(
203227 # If only one item after processing, compute embedding and return
204228 if len (processed_items ) == 1 :
205229 single_item = processed_items [0 ]
206- if single_item and single_item .memory :
207- try :
208- single_item .metadata .embedding = self .embedder .embed ([single_item .memory ])[0 ]
209- except Exception as e :
210- logger .error (
211- f"[MultiModalStruct] Error computing embedding for single item: { e } "
212- )
230+ with timed_stage ("add" , "embedding" , window_count = 1 ):
231+ self ._embed_memory_items ([single_item ])
213232 return processed_items
214233
215234 windows = []
@@ -260,31 +279,8 @@ def _concat_multi_modal_memories(
260279 windows .append (window )
261280
262281 # Batch compute embeddings for all windows
263- if windows :
264- # Collect all valid windows that need embedding
265- valid_windows = [w for w in windows if w and w .memory ]
266-
267- if valid_windows :
268- # Collect all texts that need embedding
269- texts_to_embed = [w .memory for w in valid_windows ]
270-
271- # Batch compute all embeddings at once
272- try :
273- embeddings = self .embedder .embed (texts_to_embed )
274- # Fill embeddings back into memory items
275- for window , embedding in zip (valid_windows , embeddings , strict = True ):
276- window .metadata .embedding = embedding
277- except Exception as e :
278- logger .error (f"[MultiModalStruct] Error batch computing embeddings: { e } " )
279- # Fallback: compute embeddings individually
280- for window in valid_windows :
281- if window .memory :
282- try :
283- window .metadata .embedding = self .embedder .embed ([window .memory ])[0 ]
284- except Exception as e2 :
285- logger .error (
286- f"[MultiModalStruct] Error computing embedding for item: { e2 } "
287- )
282+ with timed_stage ("add" , "embedding" , window_count = len (windows )):
283+ self ._embed_memory_items (windows )
288284
289285 return windows
290286
@@ -984,49 +980,49 @@ def _process_multi_modal_data(
984980 # must pop here, avoid add to info, only used in sync fine mode
985981 custom_tags = info .pop ("custom_tags" , None ) if isinstance (info , dict ) else None
986982
987- # Use MultiModalParser to parse the scene data
988- # If it's a list, parse each item; otherwise parse as single message
989- if isinstance (scene_data_info , list ):
990- # Pre-expand multimodal messages
991- expanded_messages = self ._expand_multimodal_messages (scene_data_info )
992-
993- # Parse each message in the list
994- all_memory_items = []
995- # Use thread pool to parse each message in parallel, but keep the original order
996- with ContextThreadPoolExecutor (max_workers = 30 ) as executor :
997- # submit tasks and keep the original order
998- futures = [
999- executor .submit (
1000- self .multi_modal_parser .parse ,
1001- msg ,
1002- info ,
1003- mode = "fast" ,
1004- need_emb = False ,
1005- ** kwargs ,
1006- )
1007- for msg in expanded_messages
1008- ]
1009- # collect results in original order
1010- for future in futures :
1011- try :
1012- items = future .result ()
1013- all_memory_items .extend (items )
1014- except Exception as e :
1015- logger .error (f"[MultiModalFine] Error in parallel parsing: { e } " )
1016- else :
1017- # Parse as single message
1018- all_memory_items = self .multi_modal_parser .parse (
1019- scene_data_info , info , mode = "fast" , need_emb = False , ** kwargs
1020- )
1021- fast_memory_items = self ._concat_multi_modal_memories (all_memory_items )
983+ # Stage: parse — parallel message parsing + sliding-window aggregation
984+ with timed_stage ("add" , "parse" ) as ts_parse :
985+ if isinstance (scene_data_info , list ):
986+ expanded_messages = self ._expand_multimodal_messages (scene_data_info )
987+ ts_parse .set (msg_count = len (expanded_messages ))
988+
989+ all_memory_items = []
990+ with ContextThreadPoolExecutor (max_workers = 30 ) as executor :
991+ futures = [
992+ executor .submit (
993+ self .multi_modal_parser .parse ,
994+ msg ,
995+ info ,
996+ mode = "fast" ,
997+ need_emb = False ,
998+ ** kwargs ,
999+ )
1000+ for msg in expanded_messages
1001+ ]
1002+ for future in futures :
1003+ try :
1004+ items = future .result ()
1005+ all_memory_items .extend (items )
1006+ except Exception as e :
1007+ logger .error (f"[MultiModalFine] Error in parallel parsing: { e } " )
1008+ else :
1009+ ts_parse .set (msg_count = 1 )
1010+ all_memory_items = self .multi_modal_parser .parse (
1011+ scene_data_info , info , mode = "fast" , need_emb = False , ** kwargs
1012+ )
1013+
1014+ fast_memory_items = self ._concat_multi_modal_memories (all_memory_items )
1015+ ts_parse .set (window_count = len (fast_memory_items ))
1016+
10221017 if mode == "fast" :
10231018 return fast_memory_items
1024- else :
1025- non_file_url_fast_items = [
1026- item for item in fast_memory_items if not self ._is_file_url_only_item (item )
1027- ]
10281019
1029- # Part A: call llm in parallel using thread pool
1020+ # Stage: llm_extract — fine mode 4-way parallel LLM + per-source serial
1021+ non_file_url_fast_items = [
1022+ item for item in fast_memory_items if not self ._is_file_url_only_item (item )
1023+ ]
1024+
1025+ with timed_stage ("add" , "llm_extract" ) as ts_llm :
10301026 fine_memory_items = []
10311027
10321028 with ContextThreadPoolExecutor (max_workers = 4 ) as executor :
@@ -1057,7 +1053,6 @@ def _process_multi_modal_data(
10571053 ** kwargs ,
10581054 )
10591055
1060- # Collect results
10611056 fine_memory_items_string_parser = future_string .result ()
10621057 fine_memory_items_tool_trajectory_parser = future_tool .result ()
10631058 fine_memory_items_skill_memory_parser = future_skill .result ()
@@ -1068,21 +1063,25 @@ def _process_multi_modal_data(
10681063 fine_memory_items .extend (fine_memory_items_skill_memory_parser )
10691064 fine_memory_items .extend (fine_memory_items_pref_parser )
10701065
1071- # Part B: get fine multimodal items
1072- for fast_item in fast_memory_items :
1073- sources = fast_item .metadata .sources
1074- for source in sources :
1075- lang = getattr (source , "lang" , "en" )
1076- items = self .multi_modal_parser .process_transfer (
1077- source ,
1078- context_items = [fast_item ],
1079- custom_tags = custom_tags ,
1080- info = info ,
1081- lang = lang ,
1082- user_context = kwargs .get ("user_context" ),
1083- )
1084- fine_memory_items .extend (items )
1085- return fine_memory_items
1066+ # Part B: per-source serial processing
1067+ with timed_stage ("add" , "per_source" ) as ts_ps :
1068+ for fast_item in fast_memory_items :
1069+ sources = fast_item .metadata .sources
1070+ for source in sources :
1071+ lang = getattr (source , "lang" , "en" )
1072+ items = self .multi_modal_parser .process_transfer (
1073+ source ,
1074+ context_items = [fast_item ],
1075+ custom_tags = custom_tags ,
1076+ info = info ,
1077+ lang = lang ,
1078+ user_context = kwargs .get ("user_context" ),
1079+ )
1080+ fine_memory_items .extend (items )
1081+
1082+ ts_llm .set (fine_memory_count = len (fine_memory_items ), per_source_ms = ts_ps .duration_ms )
1083+
1084+ return fine_memory_items
10861085
10871086 @timed
10881087 def _process_transfer_multi_modal_data (
0 commit comments