Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions fickling/context.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
import pickle

import fickling.hook as hook
import fickling.loader as loader
from fickling.analysis import Severity


class FicklingContextManager:
"""Context manager that activates fickling's safety hooks on enter and removes them on exit."""

def __init__(self, max_acceptable_severity=Severity.LIKELY_SAFE):
self.original_pickle_load = pickle.load
self.max_acceptable_severity = max_acceptable_severity

def __enter__(self):
# Modify the `hook_pickle_load` function to use the imported loader
wrapped_load = lambda file, *args, **kwargs: loader.load( # noqa
file, max_acceptable_severity=self.max_acceptable_severity
)
hook.run_hook()
hook.run_hook(max_acceptable_severity=self.max_acceptable_severity)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
pickle.load = self.original_pickle_load
hook.remove_hook()


def check_safety():
Expand Down
81 changes: 71 additions & 10 deletions fickling/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pickle

import fickling.loader as loader
from fickling.analysis import Severity
from fickling.ml import FicklingMLUnpickler

_original_pickle_load = pickle.load
Expand All @@ -29,17 +30,55 @@ def load(self):
return loader.load(self._file, *self._args, **self._kwargs)


def run_hook():
"""Replace pickle.load() and pickle.Unpickler by fickling's safe versions"""
# Hook functions
pickle.load = loader.load
_pickle.load = loader.load
pickle.loads = loader.loads
_pickle.loads = loader.loads
def run_hook(max_acceptable_severity=Severity.LIKELY_SAFE):
"""Replace pickle.load() and pickle.Unpickler by fickling's safe versions

# Hook the Unpickler class
pickle.Unpickler = FicklingSafetyUnpickler
_pickle.Unpickler = FicklingSafetyUnpickler
Args:
max_acceptable_severity: Maximum severity level to allow through.
When non-default, wraps loader functions to pass the threshold.
"""
if max_acceptable_severity != Severity.LIKELY_SAFE:

def hooked_load(file, *args, **kwargs):
kwargs.pop("max_acceptable_severity", None)
return loader.load(
file, *args, max_acceptable_severity=max_acceptable_severity, **kwargs
)

def hooked_loads(data, *args, **kwargs):
kwargs.pop("max_acceptable_severity", None)
return loader.loads(
data, *args, max_acceptable_severity=max_acceptable_severity, **kwargs
)

pickle.load = hooked_load
_pickle.load = hooked_load
pickle.loads = hooked_loads
_pickle.loads = hooked_loads

# Create Unpickler subclass that passes severity through
class SafetyUnpicklerWithSeverity(FicklingSafetyUnpickler):
def __init__(self, file, *args, **kwargs):
kwargs.pop("max_acceptable_severity", None)
super().__init__(file, *args, **kwargs)

def load(self):
return loader.load(
self._file,
*self._args,
max_acceptable_severity=max_acceptable_severity,
**self._kwargs,
)

pickle.Unpickler = SafetyUnpicklerWithSeverity
_pickle.Unpickler = SafetyUnpicklerWithSeverity
else:
pickle.load = loader.load
_pickle.load = loader.load
pickle.loads = loader.loads
_pickle.loads = loader.loads
pickle.Unpickler = FicklingSafetyUnpickler
_pickle.Unpickler = FicklingSafetyUnpickler


def always_check_safety():
Expand Down Expand Up @@ -75,6 +114,28 @@ def __init__(self, file, *args, **kwargs):
_pickle.Unpickler = SafeMLUnpickler


def snapshot_hooks():
"""Capture the current state of all hooked pickle entry points."""
return (
pickle.load,
_pickle.load,
pickle.loads,
_pickle.loads,
pickle.Unpickler,
_pickle.Unpickler,
)


def restore_hooks(snapshot):
"""Restore pickle entry points from a previous snapshot."""
pickle.load = snapshot[0]
_pickle.load = snapshot[1]
pickle.loads = snapshot[2]
_pickle.loads = snapshot[3]
pickle.Unpickler = snapshot[4]
_pickle.Unpickler = snapshot[5]


def remove_hook():
"""Restore original pickle functions and classes"""
pickle.load = _original_pickle_load
Expand Down
Loading