diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 2edede76d..26287ff4f 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -2,6 +2,7 @@ import json import re import traceback +import uuid from typing import TYPE_CHECKING, Any @@ -15,6 +16,8 @@ from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader from memos.mem_reader.utils import parse_json_result from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.plugins.hook_defs import H +from memos.plugins.hooks import trigger_single_hook from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH from memos.types import MessagesType @@ -58,6 +61,8 @@ def __init__(self, config: MultiModalStructMemReaderConfig): simple_config = SimpleStructMemReaderConfig(**config_dict) super().__init__(simple_config) + self.memory_version_switch = getattr(config, "memory_version_switch", "off") + # Image parser LLM (requires vision model) # Falls back to general_llm if not configured (general_llm itself falls back to main llm) self.image_parser_llm = ( @@ -124,21 +129,32 @@ def _split_large_memory_item( try: chunks = self.chunker.chunk(item_text) split_items = [] + source_info = dict(item.metadata.info or {}) + source_internal_info = dict(item.metadata.internal_info or {}) + ingest_batch_id = str(source_internal_info.get("ingest_batch_id") or uuid.uuid4()) + chunk_total = len(chunks) - def _create_chunk_item(chunk): + def _create_chunk_item(chunk_idx: int, chunk): # Different chunkers are not fully consistent: # some return Chunk-like objects with `.text`, while others return raw strings. chunk_text = chunk.text if hasattr(chunk, "text") else chunk if not chunk_text or not chunk_text.strip(): return None + chunk_info = { + "user_id": item.metadata.user_id, + "session_id": item.metadata.session_id, + **source_info, + } + chunk_internal_info = { + **source_internal_info, + "ingest_batch_id": ingest_batch_id, + "chunk_index": chunk_idx, + "chunk_total": chunk_total, + } # Create a new memory item for each chunk, preserving original metadata split_item = self._make_memory_item( value=chunk_text, - info={ - "user_id": item.metadata.user_id, - "session_id": item.metadata.session_id, - **(item.metadata.info or {}), - }, + info=chunk_info, memory_type=item.metadata.memory_type, tags=item.metadata.tags or [], key=item.metadata.key, @@ -146,11 +162,15 @@ def _create_chunk_item(chunk): background=item.metadata.background or "", need_embed=False, ) + split_item.metadata.internal_info = chunk_internal_info return split_item # Use thread pool to parallel process chunks, but keep the original order with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: - futures = [executor.submit(_create_chunk_item, chunk) for chunk in chunks] + futures = [ + executor.submit(_create_chunk_item, chunk_idx, chunk) + for chunk_idx, chunk in enumerate(chunks) + ] for future in futures: split_item = future.result() if split_item is not None: @@ -306,6 +326,7 @@ def _build_window_from_items( all_sources = [] roles = set() aggregated_file_ids: list[str] = [] + ingest_batch_ids: set[str] = set() for item in items: if item.memory: @@ -334,6 +355,11 @@ def _build_window_from_items( for fid in item_file_ids: if fid and fid not in aggregated_file_ids: aggregated_file_ids.append(fid) + item_internal_info = getattr(metadata, "internal_info", None) + if isinstance(item_internal_info, dict): + batch_id = item_internal_info.get("ingest_batch_id") + if batch_id: + ingest_batch_ids.add(str(batch_id)) # Determine memory_type based on roles (same logic as simple_struct) # UserMemory if only user role, else LongTermMemory @@ -368,7 +394,6 @@ def _build_window_from_items( info_ = info.copy() user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") - # Create memory item without embedding (set to None, will be filled in batch) aggregated_item = TextualMemoryItem( memory=merged_text, @@ -389,6 +414,10 @@ def _build_window_from_items( **extra_kwargs, ), ) + if len(ingest_batch_ids) == 1: + aggregated_item.metadata.internal_info = { + "ingest_batch_id": next(iter(ingest_batch_ids)) + } return aggregated_item @@ -458,6 +487,9 @@ def _get_llm_response( if self.config.remove_prompt_example and examples: prompt = prompt.replace(examples, "") + + logger.info(f"[MultiModalParser] Process String Fine Prompt: {prompt}") + messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) @@ -506,6 +538,7 @@ def _get_maybe_merged_memory( sources: list, **kwargs, ) -> dict: + # TODO: delete this function """ Check if extracted memory should be merged with similar existing memories. If merge is needed, return merged memory dict with merged_from field. @@ -520,102 +553,7 @@ def _get_maybe_merged_memory( Returns: Memory dict (possibly merged) with merged_from field if merged """ - # If no graph_db or user_name, return original - if not self.graph_db or "user_name" not in kwargs: - return extracted_memory_dict - user_name = kwargs.get("user_name") - - # Detect language - lang = "en" - if sources: - for source in sources: - if hasattr(source, "lang") and source.lang: - lang = source.lang - break - elif isinstance(source, dict) and source.get("lang"): - lang = source.get("lang") - break - if lang is None: - lang = detect_lang(mem_text) - - # Search for similar memories - merge_threshold = kwargs.get("merge_similarity_threshold", 0.3) - - try: - search_results = self.graph_db.search_by_embedding( - vector=self.embedder.embed(mem_text)[0], - top_k=20, - status="activated", - threshold=merge_threshold, - user_name=user_name, - ) - - if not search_results: - return extracted_memory_dict - - # Get full memory details - similar_memory_ids = [r["id"] for r in search_results if r.get("id")] - similar_memories_list = [ - self.graph_db.get_node(mem_id, include_embedding=False, user_name=user_name) - for mem_id in similar_memory_ids - ] - - # Filter out None and mode:fast memories - filtered_similar = [] - for mem in similar_memories_list: - if not mem: - continue - mem_metadata = mem.get("metadata", {}) - tags = mem_metadata.get("tags", []) - if isinstance(tags, list) and "mode:fast" in tags: - continue - filtered_similar.append( - { - "id": mem.get("id"), - "memory": mem.get("memory", ""), - } - ) - logger.info( - f"Valid similar memories for {mem_text} is " - f"{len(filtered_similar)}: {filtered_similar}" - ) - - if not filtered_similar: - return extracted_memory_dict - - # Create a temporary TextualMemoryItem for merge check - temp_memory_item = TextualMemoryItem( - memory=mem_text, - metadata=TreeNodeTextualMemoryMetadata( - user_id="", - session_id="", - memory_type=extracted_memory_dict.get("memory_type", "LongTermMemory"), - status="activated", - tags=extracted_memory_dict.get("tags", []), - key=extracted_memory_dict.get("key", ""), - ), - ) - - # Try to merge with LLM - merge_result = self._merge_memories_with_llm( - temp_memory_item, filtered_similar, lang=lang - ) - - if merge_result: - # Return merged memory dict - merged_dict = extracted_memory_dict.copy() - merged_content = merge_result.get("value", mem_text) - merged_dict["value"] = merged_content - merged_from_ids = merge_result.get("merged_from", []) - merged_dict["merged_from"] = merged_from_ids - return merged_dict - else: - return extracted_memory_dict - - except Exception as e: - logger.error(f"[MultiModalFine] Error in get_maybe_merged_memory: {e}") - # On error, return original - return extracted_memory_dict + return extracted_memory_dict def _merge_memories_with_llm( self, @@ -717,6 +655,35 @@ def _process_one_item( # Determine prompt type based on sources prompt_type = self._determine_prompt_type(sources) + # ========== Stage 0: Memory version async extraction/update pipeline ========== + if getattr(self, "memory_version_switch", "off") == "on": + try: + user_name = kwargs.get("user_name") + should_use_version_pipeline = trigger_single_hook( + H.MEMORY_VERSION_PREPARE_UPDATES, + item=fast_item, + user_name=user_name, + judge_llm=self.general_llm, + ) + if should_use_version_pipeline: + lang = detect_lang(kwargs.get("chat_history") or mem_str) + custom_tags_prompt_template = PROMPT_DICT["custom_tags"][lang] + new_items = trigger_single_hook( + H.MEMORY_VERSION_APPLY_UPDATES, + item=fast_item, + user_name=user_name, + version_llm=self.qwen_llm, + merge_llm=self.general_llm, + custom_tags=custom_tags, + custom_tags_prompt_template=custom_tags_prompt_template, + timeout_sec=30, + ) + return new_items + except RuntimeError as ex: + logger.warning(f"[MultiModalFine] Memory version hook unavailable: {ex}") + except Exception as ex: + logger.warning(f"[MultiModalFine] Fine memory version pipeline failed: {ex}") + # ========== Stage 1: Normal extraction (without reference) ========== try: resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type) @@ -727,14 +694,15 @@ def _process_one_item( if resp.get("memory list", []): for m in resp.get("memory list", []): try: - # Check and merge with similar memories if needed - m_maybe_merged = self._get_maybe_merged_memory( - extracted_memory_dict=m, - mem_text=m.get("value", ""), - sources=sources, - original_query=mem_str, - **kwargs, - ) + m_maybe_merged = m + if getattr(self, "memory_version_switch", "off") != "on": + m_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=m, + mem_text=m.get("value", ""), + sources=sources, + original_query=mem_str, + **kwargs, + ) # Normalize memory_type (same as simple_struct) memory_type = ( m_maybe_merged.get("memory_type", "LongTermMemory") @@ -752,8 +720,10 @@ def _process_one_item( background=resp.get("summary", ""), **extra_kwargs, ) - # Add merged_from to info if present - if "merged_from" in m_maybe_merged: + if ( + getattr(self, "memory_version_switch", "off") != "on" + and "merged_from" in m_maybe_merged + ): node.metadata.info = node.metadata.info or {} node.metadata.info["merged_from"] = m_maybe_merged["merged_from"] fine_items.append(node) @@ -762,13 +732,15 @@ def _process_one_item( elif resp.get("value") and resp.get("key"): try: # Check and merge with similar memories if needed - resp_maybe_merged = self._get_maybe_merged_memory( - extracted_memory_dict=resp, - mem_text=resp.get("value", "").strip(), - sources=sources, - original_query=mem_str, - **kwargs, - ) + resp_maybe_merged = resp + if getattr(self, "memory_version_switch", "off") != "on": + resp_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=resp, + mem_text=resp.get("value", "").strip(), + sources=sources, + original_query=mem_str, + **kwargs, + ) node = self._make_memory_item( value=resp_maybe_merged.get("value", "").strip(), info=info_per_item, @@ -779,8 +751,10 @@ def _process_one_item( background=resp.get("summary", ""), **extra_kwargs, ) - # Add merged_from to info if present - if "merged_from" in resp_maybe_merged: + if ( + getattr(self, "memory_version_switch", "off") != "on" + and "merged_from" in resp_maybe_merged + ): node.metadata.info = node.metadata.info or {} node.metadata.info["merged_from"] = resp_maybe_merged["merged_from"] fine_items.append(node) diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index a9b2c43a4..f34cf1efd 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -69,9 +69,9 @@ class ArchivedTextualMemory(BaseModel): memory: str | None = Field( default_factory=lambda: "", description="The content of the archived version of the memory." ) - update_type: Literal["conflict", "duplicate", "extract", "unrelated"] = Field( + update_type: Literal["conflict", "duplicate", "extract", "unrelated", "feedback"] = Field( default="unrelated", - description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`).", + description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`, `feedback`).", ) archived_memory_id: str | None = Field( default=None, @@ -106,15 +106,15 @@ class TextualMemoryMetadata(BaseModel): default=None, description="Whether or not the memory was created in fast mode, carrying raw memory contents that haven't been edited by llms yet.", ) - evolve_to: list[str] | None = Field( + evolve_to: list[str] = Field( default_factory=list, - description="Only valid if a node was once a (raw)fast node. Recording which new memory nodes it 'evolves' to after llm extraction.", + description="Recording which new memory nodes it 'evolves' to after llm extraction.", ) - version: int | None = Field( - default=None, + version: int = Field( + default=1, description="The version of the memory. Will be incremented when the memory is updated.", ) - history: list[ArchivedTextualMemory] | None = Field( + history: list[ArchivedTextualMemory] = Field( default_factory=list, description="Storing the archived versions of the memory. Only preserving core information of each version.", ) @@ -146,6 +146,10 @@ class TextualMemoryMetadata(BaseModel): default=None, description="Arbitrary key-value pairs for additional metadata.", ) + internal_info: dict | None = Field( + default=None, + description="Internal algorithm metadata reserved for system use.", + ) model_config = ConfigDict(extra="allow") diff --git a/tests/mem_reader/test_project_id_propagation.py b/tests/mem_reader/test_project_id_propagation.py index 5a17910ca..bf55aca46 100644 --- a/tests/mem_reader/test_project_id_propagation.py +++ b/tests/mem_reader/test_project_id_propagation.py @@ -53,6 +53,7 @@ def _make_fast_item( manager_user_id: str | None = MANAGER_USER_ID, project_id: str | None = PROJECT_ID, role: str = "user", + internal_info: dict | None = None, ) -> TextualMemoryItem: return TextualMemoryItem( memory=memory, @@ -63,6 +64,7 @@ def _make_fast_item( sources=[SourceMessage(type="chat", role=role, content=memory)], manager_user_id=manager_user_id, project_id=project_id, + internal_info=internal_info, ), ) @@ -216,6 +218,8 @@ def setUp(self): self.reader.graph_db = MagicMock() self.reader.oss_config = None self.reader.skills_dir_config = None + self.reader.memory_version_switch = "off" + self.reader.qwen_llm = MagicMock() # -- _build_window_from_items -------------------------------------------- def test_build_window_propagates_project_id(self): @@ -255,6 +259,50 @@ def test_build_window_picks_first_nonempty(self): self.assertIsNotNone(result) _assert_fields(self, result) + def test_split_large_memory_item_assigns_shared_ingest_batch_id(self): + self.reader._count_tokens = MagicMock(return_value=999) + self.reader.chunker.chunk.return_value = ["chunk one", "chunk two"] + + def fake_make_memory_item( + *, + value, + info, + memory_type, + tags, + key, + sources, + background, + need_embed, + ): + return TextualMemoryItem( + memory=value, + metadata=TreeNodeTextualMemoryMetadata( + user_id=info["user_id"], + session_id=info["session_id"], + memory_type=memory_type, + tags=tags, + key=key, + sources=sources, + background=background, + ), + ) + + self.reader._make_memory_item = MagicMock(side_effect=fake_make_memory_item) + source_item = _make_fast_item("very long source", internal_info={"origin": "doc"}) + + result = self.reader._split_large_memory_item(source_item, max_tokens=10) + + self.assertEqual(len(result), 2) + batch_ids = { + item.metadata.internal_info["ingest_batch_id"] + for item in result + if item.metadata.internal_info and item.metadata.internal_info.get("ingest_batch_id") + } + self.assertEqual(len(batch_ids), 1) + self.assertEqual({item.metadata.internal_info["chunk_index"] for item in result}, {0, 1}) + self.assertEqual({item.metadata.internal_info["chunk_total"] for item in result}, {2}) + self.assertEqual({item.metadata.internal_info["origin"] for item in result}, {"doc"}) + # -- _process_string_fine ------------------------------------------------ def test_process_string_fine_propagates_fields(self): """Fine string extraction must carry project_id/manager_user_id