Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ._string_similarity_map import StringSimilarityMap
from .utils.page_logger import PageLogger
from .utils.restricted_pickle import BASE_ALLOWED_PICKLE_GLOBALS, restricted_pickle_load


@dataclass
Expand Down Expand Up @@ -79,7 +80,10 @@ def __init__(
if (not reset) and os.path.exists(self.path_to_dict):
self.logger.info("\nLOADING MEMOS FROM DISK at {}".format(self.path_to_dict))
with open(self.path_to_dict, "rb") as f:
self.uid_memo_dict = pickle.load(f)
allowed = BASE_ALLOWED_PICKLE_GLOBALS | {
("autogen_ext.experimental.task_centric_memory._memory_bank", "Memo"),
}
self.uid_memo_dict = restricted_pickle_load(f, allowed_globals=allowed)
self.last_memo_id = len(self.uid_memo_dict)
self.logger.info("\n{} MEMOS LOADED".format(len(self.uid_memo_dict)))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from chromadb.config import Settings

from .utils.page_logger import PageLogger
from .utils.restricted_pickle import BASE_ALLOWED_PICKLE_GLOBALS, restricted_pickle_load


class StringSimilarityMap:
Expand Down Expand Up @@ -45,7 +46,7 @@ def __init__(self, reset: bool, path_to_db_dir: str, logger: PageLogger | None =
if (not reset) and os.path.exists(self.path_to_dict):
self.logger.debug("\nLOADING STRING SIMILARITY MAP FROM DISK at {}".format(self.path_to_dict))
with open(self.path_to_dict, "rb") as f:
self.uid_text_dict = pickle.load(f)
self.uid_text_dict = restricted_pickle_load(f, allowed_globals=BASE_ALLOWED_PICKLE_GLOBALS)
self.last_string_pair_id = len(self.uid_text_dict)
if len(self.uid_text_dict) > 0:
self.logger.debug("\n{} STRING PAIRS LOADED".format(len(self.uid_text_dict)))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

import pickle
from typing import Any, BinaryIO

# NOTE: The task-centric memory feature persists local state to disk. These files
# can be moved between projects or restored from shared storage, so loads should
# not execute arbitrary pickle globals.

# Baseline allow-list of (module, name) pairs: builtin container and scalar
# types only. Callers extend it (via set union) with any extra classes that a
# particular persisted file is expected to contain.
BASE_ALLOWED_PICKLE_GLOBALS: set[tuple[str, str]] = {
    ("builtins", builtin_name)
    for builtin_name in (
        "dict",
        "list",
        "set",
        "tuple",
        "str",
        "bytes",
        "bytearray",
        "int",
        "float",
        "bool",
    )
}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves only explicitly allow-listed globals.

    Every global the pickle stream asks for is checked as a ``(module, name)``
    pair against the allow-list; anything else raises
    ``pickle.UnpicklingError`` instead of being looked up.
    """

    def __init__(self, file: BinaryIO, allowed_globals: set[tuple[str, str]]) -> None:
        """Create an unpickler reading from ``file``.

        Args:
            file: Binary stream containing the pickle data.
            allowed_globals: ``(module, name)`` pairs that may be resolved.
        """
        super().__init__(file)
        # Take an immutable snapshot: the caller's set (possibly a shared
        # module-level baseline) may be mutated later, and that must not
        # silently widen what this unpickler is willing to resolve.
        self._allowed_globals = frozenset(allowed_globals)

    def find_class(self, module: str, name: str) -> Any:
        """Return the global ``module.name`` if it is allow-listed.

        Raises:
            pickle.UnpicklingError: If ``(module, name)`` is not in the
                allow-list.
        """
        if (module, name) not in self._allowed_globals:
            raise pickle.UnpicklingError(
                f"Blocked global during unpickle: {module}.{name}. "
                "Delete the persisted memory files or re-run with reset=True."
            )
        return super().find_class(module, name)


def restricted_pickle_load(file: BinaryIO, *, allowed_globals: set[tuple[str, str]]) -> Any:
    """Load one pickled object from ``file`` through a ``RestrictedUnpickler``.

    Args:
        file: Binary stream containing the pickle data.
        allowed_globals: ``(module, name)`` pairs the stream may reference.

    Returns:
        The object produced by the unpickler's ``load()``.
    """
    unpickler = RestrictedUnpickler(file, allowed_globals)
    return unpickler.load()

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import io
import os
import pickle

import pytest

pytest.importorskip("chromadb")


def test_restricted_pickle_load_blocks_unsafe_globals(monkeypatch: pytest.MonkeyPatch) -> None:
    """A pickle that reduces to an ``exec`` call is rejected and never executed."""
    from autogen_ext.experimental.task_centric_memory.utils.restricted_pickle import (
        BASE_ALLOWED_PICKLE_GLOBALS,
        restricted_pickle_load,
    )

    marker = "AUTOGEN_EXT_PICKLE_RCE_MARKER"
    # Make sure the marker is absent before the load is attempted.
    monkeypatch.delenv(marker, raising=False)

    class MarkerSettingPayload:
        """Object whose pickle form asks the loader to call ``exec``."""

        def __reduce__(self):  # noqa: ANN001
            # Non-destructive: sets an env var if executed.
            code = "import os; os.environ['AUTOGEN_EXT_PICKLE_RCE_MARKER']='1'"
            return (exec, (code,))

    malicious_bytes = pickle.dumps(MarkerSettingPayload())

    with pytest.raises(pickle.UnpicklingError):
        restricted_pickle_load(
            io.BytesIO(malicious_bytes),
            allowed_globals=BASE_ALLOWED_PICKLE_GLOBALS,
        )

    # The marker would be set only if the payload had actually been executed.
    assert os.environ.get(marker) is None


def test_restricted_pickle_load_allows_memo_dict_roundtrip() -> None:
    """A dict holding a ``Memo`` round-trips when ``Memo`` is allow-listed."""
    from autogen_ext.experimental.task_centric_memory._memory_bank import Memo
    from autogen_ext.experimental.task_centric_memory.utils.restricted_pickle import (
        BASE_ALLOWED_PICKLE_GLOBALS,
        restricted_pickle_load,
    )

    # Extend the baseline with the one project class the payload contains.
    memo_global = ("autogen_ext.experimental.task_centric_memory._memory_bank", "Memo")
    allowed = BASE_ALLOWED_PICKLE_GLOBALS | {memo_global}

    original = {"1": Memo(task=None, insight="hi")}
    stream = io.BytesIO(pickle.dumps(original))
    loaded = restricted_pickle_load(stream, allowed_globals=allowed)

    assert loaded == original