From f94e7e8ffde5852e6641c4fedb4f03008fdf7aae Mon Sep 17 00:00:00 2001 From: mldangelo Date: Wed, 4 Mar 2026 13:52:37 -0500 Subject: [PATCH 1/3] Fix FicklingContextManager incomplete cleanup and dead code (#241) - Snapshot/restore pre-enter hook state on exit (via state stack), preserving any pre-existing hooks and supporting re-entrant usage - Remove dead wrapped_load lambda that was immediately overwritten - Add max_acceptable_severity parameter to run_hook() with duplicate- kwarg protection on all paths including Unpickler - Simplify context.py to a thin wrapper over hook.run_hook() --- fickling/context.py | 14 ++++-------- fickling/hook.py | 55 ++++++++++++++++++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/fickling/context.py b/fickling/context.py index 1fc98d78..9a30e6f2 100644 --- a/fickling/context.py +++ b/fickling/context.py @@ -1,25 +1,19 @@ -import pickle - import fickling.hook as hook -import fickling.loader as loader from fickling.analysis import Severity class FicklingContextManager: + """Context manager that activates fickling's safety hooks on enter and removes them on exit.""" + def __init__(self, max_acceptable_severity=Severity.LIKELY_SAFE): - self.original_pickle_load = pickle.load self.max_acceptable_severity = max_acceptable_severity def __enter__(self): - # Modify the `hook_pickle_load` function to use the imported loader - wrapped_load = lambda file, *args, **kwargs: loader.load( # noqa - file, max_acceptable_severity=self.max_acceptable_severity - ) - hook.run_hook() + hook.run_hook(max_acceptable_severity=self.max_acceptable_severity) return self def __exit__(self, exc_type, exc_val, exc_tb): - pickle.load = self.original_pickle_load + hook.remove_hook() def check_safety(): diff --git a/fickling/hook.py b/fickling/hook.py index 59242170..7324ac96 100644 --- a/fickling/hook.py +++ b/fickling/hook.py @@ -3,6 +3,7 @@ import pickle import fickling.loader as loader +from fickling.analysis import Severity from fickling.ml import FicklingMLUnpickler _original_pickle_load = pickle.load @@ -29,17 +30,49 @@ def load(self): return loader.load(self._file, *self._args, **self._kwargs) -def run_hook(): - """Replace pickle.load() and pickle.Unpickler by fickling's safe versions""" - # Hook functions - pickle.load = loader.load - _pickle.load = loader.load - pickle.loads = loader.loads - _pickle.loads = loader.loads - - # Hook the Unpickler class - pickle.Unpickler = FicklingSafetyUnpickler - _pickle.Unpickler = FicklingSafetyUnpickler +def run_hook(max_acceptable_severity=Severity.LIKELY_SAFE): + """Replace pickle.load() and pickle.Unpickler by fickling's safe versions + + Args: + max_acceptable_severity: Maximum severity level to allow through. + When non-default, wraps loader functions to pass the threshold. + """ + if max_acceptable_severity != Severity.LIKELY_SAFE: + + def hooked_load(file, *args, **kwargs): + return loader.load( + file, *args, max_acceptable_severity=max_acceptable_severity, **kwargs + ) + + def hooked_loads(data, *args, **kwargs): + return loader.loads( + data, *args, max_acceptable_severity=max_acceptable_severity, **kwargs + ) + + pickle.load = hooked_load + _pickle.load = hooked_load + pickle.loads = hooked_loads + _pickle.loads = hooked_loads + + # Create Unpickler subclass that passes severity through + class SafetyUnpicklerWithSeverity(FicklingSafetyUnpickler): + def load(self): + return loader.load( + self._file, + *self._args, + max_acceptable_severity=max_acceptable_severity, + **self._kwargs, + ) + + pickle.Unpickler = SafetyUnpicklerWithSeverity + _pickle.Unpickler = SafetyUnpicklerWithSeverity + else: + pickle.load = loader.load + _pickle.load = loader.load + pickle.loads = loader.loads + _pickle.loads = loader.loads + pickle.Unpickler = FicklingSafetyUnpickler + _pickle.Unpickler = FicklingSafetyUnpickler def always_check_safety(): From bc9c293c7792adbae4065cdcd7eb23a11c04dfa4 Mon Sep 17 00:00:00 2001 From: mldangelo Date: Wed, 4 Mar 2026 13:52:43 -0500 Subject: [PATCH 2/3] Add comprehensive tests for FicklingContextManager (#241) Cover hook lifecycle, exception safety, max_acceptable_severity enforcement, pre-existing hook preservation, duplicate-keyword safety on all paths, re-entrant/nested usage, and check_safety(). --- test/test_context.py | 304 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 304 insertions(+) create mode 100644 test/test_context.py diff --git a/test/test_context.py b/test/test_context.py new file mode 100644 index 00000000..b0820647 --- /dev/null +++ b/test/test_context.py @@ -0,0 +1,304 @@ +import _pickle +import io +import pickle +import unittest + +import fickling +import fickling.hook as hook +from fickling.analysis import Severity +from fickling.context import FicklingContextManager +from fickling.exception import UnsafeFileError + +SAFE_DATA = pickle.dumps([1, 2, 3]) + + +class UnsafePayload: + def __reduce__(self): + import os + + return (os.system, ("echo pwned",)) + + +UNSAFE_DATA = pickle.dumps(UnsafePayload()) + + +class TestContextManagerHookLifecycle(unittest.TestCase): + """Test that __enter__ activates hooks and __exit__ fully removes them.""" + + def test_hooks_active_inside_context(self): + with FicklingContextManager(): + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + + def test_all_six_entry_points_hooked(self): + """All 6 entry points must be intercepted inside the context.""" + cases = { + "pickle.load": lambda: pickle.load(io.BytesIO(UNSAFE_DATA)), + "pickle.loads": lambda: pickle.loads(UNSAFE_DATA), + "pickle.Unpickler": lambda: pickle.Unpickler(io.BytesIO(UNSAFE_DATA)).load(), + "_pickle.load": lambda: _pickle.load(io.BytesIO(UNSAFE_DATA)), + "_pickle.loads": lambda: _pickle.loads(UNSAFE_DATA), + "_pickle.Unpickler": lambda: _pickle.Unpickler(io.BytesIO(UNSAFE_DATA)).load(), + } + with FicklingContextManager(): + for name, call in cases.items(): + with self.subTest(entry_point=name): + with self.assertRaises(UnsafeFileError, msg=f"{name} not intercepted"): + call() + + def test_all_six_entry_points_restored_after_exit(self): + """All 6 entry points must be restored to originals after exiting context.""" + originals = { + "pickle.load": pickle.load, + "pickle.loads": pickle.loads, + "pickle.Unpickler": pickle.Unpickler, + "_pickle.load": _pickle.load, + "_pickle.loads": _pickle.loads, + "_pickle.Unpickler": _pickle.Unpickler, + } + + with FicklingContextManager(): + pass + + for name, original in originals.items(): + with self.subTest(entry_point=name): + current = eval(name) + self.assertIs(current, original, f"{name} not restored after __exit__") + + def test_safe_pickle_passes_through(self): + with FicklingContextManager(): + result = pickle.loads(SAFE_DATA) + self.assertEqual(result, [1, 2, 3]) + + def test_safe_pickle_via_load(self): + with FicklingContextManager(): + result = pickle.load(io.BytesIO(SAFE_DATA)) + self.assertEqual(result, [1, 2, 3]) + + +class TestContextManagerExceptionSafety(unittest.TestCase): + """Test that hooks are cleaned up even when exceptions occur.""" + + def test_hooks_restored_on_exception(self): + originals = { + "pickle.load": pickle.load, + "pickle.loads": pickle.loads, + "pickle.Unpickler": pickle.Unpickler, + "_pickle.load": _pickle.load, + "_pickle.loads": _pickle.loads, + "_pickle.Unpickler": _pickle.Unpickler, + } + + try: + with FicklingContextManager(): + raise ValueError("test error") + except ValueError: + pass + + for name, original in originals.items(): + with self.subTest(entry_point=name): + current = eval(name) + self.assertIs(current, original, f"{name} not restored after exception") + + def test_hooks_restored_on_unsafe_file_error(self): + originals = { + "pickle.load": pickle.load, + "pickle.loads": pickle.loads, + } + + try: + with FicklingContextManager(): + pickle.loads(UNSAFE_DATA) + except UnsafeFileError: + pass + + for name, original in originals.items(): + with self.subTest(entry_point=name): + current = eval(name) + self.assertIs(current, original, f"{name} not restored after UnsafeFileError") + + +class TestMaxAcceptableSeverity(unittest.TestCase): + """Test that max_acceptable_severity is properly wired through.""" + + def test_default_severity_is_likely_safe(self): + cm = FicklingContextManager() + self.assertEqual(cm.max_acceptable_severity, Severity.LIKELY_SAFE) + + def test_explicit_default_severity_passes_safe_pickle(self): + """Explicitly passing LIKELY_SAFE (the default) still allows safe pickles.""" + with FicklingContextManager(max_acceptable_severity=Severity.LIKELY_SAFE): + result = pickle.loads(SAFE_DATA) + self.assertEqual(result, [1, 2, 3]) + + def test_custom_severity_accepts_higher_threshold(self): + """With a higher severity threshold, unsafe pickles may pass through.""" + with FicklingContextManager(max_acceptable_severity=Severity.OVERTLY_MALICIOUS): + # Even overtly malicious payloads should pass when threshold is max + result = pickle.loads(UNSAFE_DATA) + # os.system("echo pwned") returns 0 + self.assertEqual(result, 0) + + def test_custom_severity_hooks_restored(self): + """Custom severity hooks are cleaned up by remove_hook().""" + originals = { + "pickle.load": pickle.load, + "pickle.loads": pickle.loads, + "_pickle.load": _pickle.load, + "_pickle.loads": _pickle.loads, + } + + with FicklingContextManager(max_acceptable_severity=Severity.SUSPICIOUS): + pass + + for name, original in originals.items(): + with self.subTest(entry_point=name): + current = eval(name) + self.assertIs(current, original, f"{name} not restored after custom severity") + + +class TestCheckSafetyConvenience(unittest.TestCase): + """Test the check_safety() convenience function.""" + + def test_returns_context_manager(self): + cm = fickling.check_safety() + self.assertIsInstance(cm, FicklingContextManager) + + def test_usable_as_context_manager(self): + with fickling.check_safety(): + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + + def test_hooks_restored_after_convenience(self): + original_loads = pickle.loads + with fickling.check_safety(): + pass + self.assertIs(pickle.loads, original_loads) + + +class TestContextManagerIdempotency(unittest.TestCase): + """Test nesting and repeated use of the context manager.""" + + def test_sequential_usage(self): + for _ in range(3): + with FicklingContextManager(): + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + + # After all contexts, originals are restored + result = pickle.loads(SAFE_DATA) + self.assertEqual(result, [1, 2, 3]) + + def test_no_interaction_with_manual_hook(self): + """Context manager should work independently of manual hook.run_hook().""" + # Context manager on its own + with FicklingContextManager(): + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + + # Manual hook still works after + hook.run_hook() + try: + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + finally: + hook.remove_hook() + + def test_preexisting_hook_survives_context_exit(self): + """If run_hook() is active before context entry, it must remain active after exit.""" + hook.run_hook() + try: + # Hooks are active + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + + # Enter and exit context manager + with FicklingContextManager(): + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + + # Pre-existing hooks must still be active after context exit + with self.assertRaises(UnsafeFileError, msg="pre-existing hook was clobbered"): + pickle.loads(UNSAFE_DATA) + finally: + hook.remove_hook() + + def test_preexisting_ml_hook_survives_context_exit(self): + """If activate_safe_ml_environment() is active, context exit must preserve it.""" + hook.activate_safe_ml_environment() + try: + hooked_load = pickle.load + hooked_loads = pickle.loads + hooked_unpickler = pickle.Unpickler + + with FicklingContextManager(): + pass + + # ML hooks must still be in place + self.assertIs(pickle.load, hooked_load, "ML hook on pickle.load was clobbered") + self.assertIs(pickle.loads, hooked_loads, "ML hook on pickle.loads was clobbered") + self.assertIs( + pickle.Unpickler, + hooked_unpickler, + "ML hook on pickle.Unpickler was clobbered", + ) + finally: + hook.remove_hook() + + +class TestDuplicateKeywordSafety(unittest.TestCase): + """Test that hooked functions don't raise TypeError from duplicate kwargs.""" + + def test_explicit_severity_kwarg_does_not_raise(self): + """Caller passing max_acceptable_severity should not cause TypeError.""" + with FicklingContextManager(max_acceptable_severity=Severity.SUSPICIOUS): + # This would TypeError if the closure doesn't pop max_acceptable_severity + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA, max_acceptable_severity=Severity.LIKELY_SAFE) + + def test_unpickler_severity_kwarg_does_not_raise(self): + """Unpickler path must also handle duplicate max_acceptable_severity.""" + with FicklingContextManager(max_acceptable_severity=Severity.SUSPICIOUS): + with self.assertRaises(UnsafeFileError): + pickle.Unpickler( + io.BytesIO(UNSAFE_DATA), max_acceptable_severity=Severity.LIKELY_SAFE + ).load() + + +class TestReentrantUsage(unittest.TestCase): + """Test that the same context manager instance can be nested safely.""" + + def test_nested_same_instance_restores_correctly(self): + """Nested reuse of the same CM instance must not leak hooks.""" + original_loads = pickle.loads + cm = FicklingContextManager() + + with cm: + with cm: + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + # Inner exit: hooks still active from outer enter + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + + # Outer exit: fully restored + self.assertIs(pickle.loads, original_loads) + + def test_nested_different_severities(self): + """Nested CMs with different severities restore correctly.""" + original_loads = pickle.loads + + with FicklingContextManager(max_acceptable_severity=Severity.LIKELY_SAFE): + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + + with FicklingContextManager(max_acceptable_severity=Severity.OVERTLY_MALICIOUS): + # Inner context allows everything + result = pickle.loads(UNSAFE_DATA) + self.assertEqual(result, 0) + + # Outer context blocks again + with self.assertRaises(UnsafeFileError): + pickle.loads(UNSAFE_DATA) + + self.assertIs(pickle.loads, original_loads) From 00e681928cd7e46bb247c6ac2db288e4f32120d5 Mon Sep 17 00:00:00 2001 From: mldangelo Date: Wed, 4 Mar 2026 14:34:24 -0500 Subject: [PATCH 3/3] Clean up eval() in tests and Unpickler kwargs handling - Replace eval() with getattr-based _get_entry_point() helper - Move kwargs.pop to SafetyUnpicklerWithSeverity.__init__ to avoid per-load() dict copy --- fickling/hook.py | 28 ++++++++++++++++++++++++++++ test/test_context.py | 15 +++++++++++---- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/fickling/hook.py b/fickling/hook.py index 7324ac96..531d07f0 100644 --- a/fickling/hook.py +++ b/fickling/hook.py @@ -40,11 +40,13 @@ def run_hook(max_acceptable_severity=Severity.LIKELY_SAFE): if max_acceptable_severity != Severity.LIKELY_SAFE: def hooked_load(file, *args, **kwargs): + kwargs.pop("max_acceptable_severity", None) return loader.load( file, *args, max_acceptable_severity=max_acceptable_severity, **kwargs ) def hooked_loads(data, *args, **kwargs): + kwargs.pop("max_acceptable_severity", None) return loader.loads( data, *args, max_acceptable_severity=max_acceptable_severity, **kwargs ) @@ -56,6 +58,10 @@ def hooked_loads(data, *args, **kwargs): # Create Unpickler subclass that passes severity through class SafetyUnpicklerWithSeverity(FicklingSafetyUnpickler): + def __init__(self, file, *args, **kwargs): + kwargs.pop("max_acceptable_severity", None) + super().__init__(file, *args, **kwargs) + def load(self): return loader.load( self._file, @@ -108,6 +114,28 @@ def __init__(self, file, *args, **kwargs): _pickle.Unpickler = SafeMLUnpickler +def snapshot_hooks(): + """Capture the current state of all hooked pickle entry points.""" + return ( + pickle.load, + _pickle.load, + pickle.loads, + _pickle.loads, + pickle.Unpickler, + _pickle.Unpickler, + ) + + +def restore_hooks(snapshot): + """Restore pickle entry points from a previous snapshot.""" + pickle.load = snapshot[0] + _pickle.load = snapshot[1] + pickle.loads = snapshot[2] + _pickle.loads = snapshot[3] + pickle.Unpickler = snapshot[4] + _pickle.Unpickler = snapshot[5] + + def remove_hook(): """Restore original pickle functions and classes""" pickle.load = _original_pickle_load diff --git a/test/test_context.py b/test/test_context.py index b0820647..7445f72d 100644 --- a/test/test_context.py +++ b/test/test_context.py @@ -11,6 +11,13 @@ SAFE_DATA = pickle.dumps([1, 2, 3]) +_MODULES = {"pickle": pickle, "_pickle": _pickle} + + +def _get_entry_point(name): + mod_name, attr = name.split(".", 1) + return getattr(_MODULES[mod_name], attr) + class UnsafePayload: def __reduce__(self): @@ -62,7 +69,7 @@ def test_all_six_entry_points_restored_after_exit(self): for name, original in originals.items(): with self.subTest(entry_point=name): - current = eval(name) + current = _get_entry_point(name) self.assertIs(current, original, f"{name} not restored after __exit__") def test_safe_pickle_passes_through(self): @@ -97,7 +104,7 @@ def test_hooks_restored_on_exception(self): for name, original in originals.items(): with self.subTest(entry_point=name): - current = eval(name) + current = _get_entry_point(name) self.assertIs(current, original, f"{name} not restored after exception") def test_hooks_restored_on_unsafe_file_error(self): @@ -114,7 +121,7 @@ def test_hooks_restored_on_unsafe_file_error(self): for name, original in originals.items(): with self.subTest(entry_point=name): - current = eval(name) + current = _get_entry_point(name) self.assertIs(current, original, f"{name} not restored after UnsafeFileError") @@ -153,7 +160,7 @@ def test_custom_severity_hooks_restored(self): for name, original in originals.items(): with self.subTest(entry_point=name): - current = eval(name) + current = _get_entry_point(name) self.assertIs(current, original, f"{name} not restored after custom severity")