From 1f63c45ccba04aa15045abd3796075a37195a867 Mon Sep 17 00:00:00 2001 From: Dan Guido Date: Thu, 22 Jan 2026 01:17:23 -0500 Subject: [PATCH 1/3] Add PyTorch and polyglot subcommands to CLI Adds new subcommands to expose PyTorchModelWrapper and polyglot module functionality via the CLI while maintaining full backward compatibility. New commands: - fickling pytorch identify FILE - Detect PyTorch format(s) - fickling pytorch show FILE - Decompile internal pickle - fickling pytorch check-safety FILE - Safety check internal pickle - fickling pytorch inject FILE ... - Inject payload into model - fickling polyglot identify FILE - Identify all possible formats - fickling polyglot properties FILE - File property analysis - fickling polyglot create F1 F2 -o O - Create polyglot file All commands support --json output and gracefully handle missing torch dependency with helpful installation instructions. Closes #101 Co-Authored-By: Claude Opus 4.5 --- fickling/cli.py | 183 +++++++++++++++++++++++--- fickling/cli_polyglot.py | 167 ++++++++++++++++++++++++ fickling/cli_pytorch.py | 200 ++++++++++++++++++++++++++++ test/test_cli.py | 273 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 807 insertions(+), 16 deletions(-) create mode 100644 fickling/cli_polyglot.py create mode 100644 fickling/cli_pytorch.py create mode 100644 test/test_cli.py diff --git a/fickling/cli.py b/fickling/cli.py index 24f87dca..8c09bf8e 100644 --- a/fickling/cli.py +++ b/fickling/cli.py @@ -10,13 +10,8 @@ DEFAULT_JSON_OUTPUT_FILE = "safety_results.json" -def main(argv: list[str] | None = None) -> int: - if argv is None: - argv = sys.argv - - parser = ArgumentParser( - description="fickling is a static analyzer and interpreter for Python pickle data" - ) +def _add_pickle_arguments(parser: ArgumentParser) -> None: + """Add the standard pickle-related arguments to a parser.""" parser.add_argument( "PICKLE_FILE", type=str, @@ -95,17 +90,10 @@ def main(argv: list[str] | None = None) -> int: action="store_true", help="print a runtime trace while interpreting the input pickle file", ) - parser.add_argument("--version", "-v", action="store_true", help="print the version and exit") - - args = parser.parse_args(argv[1:]) - if args.version: - if sys.stdout.isatty(): - print(f"fickling version {__version__}") - else: - print(__version__) - return 0 +def _handle_pickle_command(args) -> int: + """Handle the standard pickle decompilation/injection/safety commands.""" if args.create is None: if args.PICKLE_FILE == "-": if hasattr(sys.stdin, "buffer") and sys.stdin.buffer is not None: @@ -211,3 +199,166 @@ def main(argv: list[str] | None = None) -> int: file.close() return 0 + + +SUBCOMMANDS = {"pytorch", "polyglot"} + + +def _create_pickle_parser() -> ArgumentParser: + """Create parser for the original pickle commands (backward compatibility).""" + parser = ArgumentParser( + description="fickling is a static analyzer and interpreter for Python pickle data" + ) + parser.add_argument("--version", "-v", action="store_true", help="print the version and exit") + _add_pickle_arguments(parser) + return parser + + +def _create_subcommand_parser() -> ArgumentParser: + """Create parser with subcommands for PyTorch and polyglot operations.""" + parser = ArgumentParser( + description="fickling is a static analyzer and interpreter for Python pickle data" + ) + parser.add_argument("--version", "-v", action="store_true", help="print the version and exit") + + subparsers = parser.add_subparsers(dest="command", help="available commands") + + # PyTorch subcommand + pytorch_parser = subparsers.add_parser("pytorch", help="PyTorch model operations") + _setup_pytorch_subcommand(pytorch_parser) + + # Polyglot subcommand + polyglot_parser = subparsers.add_parser("polyglot", help="polyglot detection and creation") + _setup_polyglot_subcommand(polyglot_parser) + + return parser + + +def _get_first_positional(argv: list[str]) -> str | None: + """Get the first non-flag argument (potential subcommand or file).""" + for arg in argv[1:]: + if not arg.startswith("-"): + return arg + return None + + +def main(argv: list[str] | None = None) -> int: + if argv is None: + argv = sys.argv + + # Check for version flag first + if "--version" in argv or "-v" in argv: + if sys.stdout.isatty(): + print(f"fickling version {__version__}") + else: + print(__version__) + return 0 + + # Determine if we're using a subcommand or the original CLI + first_positional = _get_first_positional(argv) + + if first_positional in SUBCOMMANDS: + # Use subcommand parser + parser = _create_subcommand_parser() + args = parser.parse_args(argv[1:]) + + if args.command == "pytorch": + from .cli_pytorch import handle_pytorch_command + + return handle_pytorch_command(args) + if args.command == "polyglot": + from .cli_polyglot import handle_polyglot_command + + return handle_polyglot_command(args) + # Should not reach here + return 1 + # Use original pickle parser for backward compatibility + parser = _create_pickle_parser() + args = parser.parse_args(argv[1:]) + return _handle_pickle_command(args) + + +def _setup_pytorch_subcommand(parser: ArgumentParser) -> None: + """Set up the pytorch subcommand with its sub-subcommands.""" + subparsers = parser.add_subparsers(dest="pytorch_command", help="pytorch operations") + + # identify + identify_parser = subparsers.add_parser("identify", help="detect PyTorch file format(s)") + identify_parser.add_argument("file", type=str, help="path to the PyTorch model file") + identify_parser.add_argument("--json", action="store_true", help="output results as JSON") + + # show + show_parser = subparsers.add_parser("show", help="decompile internal pickle from PyTorch model") + show_parser.add_argument("file", type=str, help="path to the PyTorch model file") + show_parser.add_argument( + "--force", "-f", action="store_true", help="force processing unsupported formats" + ) + show_parser.add_argument("--trace", "-t", action="store_true", help="print a runtime trace") + + # check-safety + safety_parser = subparsers.add_parser( + "check-safety", help="run safety analysis on internal pickle" + ) + safety_parser.add_argument("file", type=str, help="path to the PyTorch model file") + safety_parser.add_argument( + "--force", "-f", action="store_true", help="force processing unsupported formats" + ) + safety_parser.add_argument( + "--json-output", + type=str, + default=None, + help="path to output JSON file for analysis results", + ) + safety_parser.add_argument( + "--print-results", "-p", action="store_true", help="print results to console" + ) + + # inject + inject_parser = subparsers.add_parser("inject", help="inject payload into PyTorch model") + inject_parser.add_argument("file", type=str, help="path to the PyTorch model file") + inject_parser.add_argument("-o", "--output", type=str, required=True, help="output file path") + inject_parser.add_argument( + "-c", "--code", type=str, required=True, help="Python code to inject" + ) + inject_parser.add_argument( + "--method", + type=str, + choices=["insertion", "combination"], + default="insertion", + help="injection method (default: insertion)", + ) + inject_parser.add_argument( + "--force", "-f", action="store_true", help="force processing unsupported formats" + ) + inject_parser.add_argument( + "--overwrite", action="store_true", help="overwrite original file with output" + ) + + +def _setup_polyglot_subcommand(parser: ArgumentParser) -> None: + """Set up the polyglot subcommand with its sub-subcommands.""" + subparsers = parser.add_subparsers(dest="polyglot_command", help="polyglot operations") + + # identify + identify_parser = subparsers.add_parser( + "identify", help="identify all possible PyTorch file formats" + ) + identify_parser.add_argument("file", type=str, help="path to the file to identify") + identify_parser.add_argument("--json", action="store_true", help="output results as JSON") + + # properties + properties_parser = subparsers.add_parser("properties", help="analyze file properties") + properties_parser.add_argument("file", type=str, help="path to the file to analyze") + properties_parser.add_argument( + "-r", "--recursive", action="store_true", help="analyze recursively into archives" + ) + properties_parser.add_argument("--json", action="store_true", help="output results as JSON") + + # create + create_parser = subparsers.add_parser("create", help="create a polyglot file") + create_parser.add_argument("file1", type=str, help="first input file") + create_parser.add_argument("file2", type=str, help="second input file") + create_parser.add_argument("-o", "--output", type=str, default=None, help="output file path") + create_parser.add_argument( + "--quiet", "-q", action="store_true", help="suppress output messages" + ) diff --git a/fickling/cli_polyglot.py b/fickling/cli_polyglot.py new file mode 100644 index 00000000..b6d54b02 --- /dev/null +++ b/fickling/cli_polyglot.py @@ -0,0 +1,167 @@ +"""CLI handlers for polyglot detection and creation operations.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path + + +def _check_torch_available() -> bool: + """Check if PyTorch is available and provide helpful error message if not.""" + try: + import torch # noqa: F401 + + return True + except ImportError: + sys.stderr.write( + "Error: PyTorch is required for this command.\n" + "Please install it with: pip install fickling[torch]\n" + ) + return False + + +def handle_polyglot_command(args) -> int: + """Handle the polyglot subcommand and its sub-subcommands.""" + if args.polyglot_command is None: + sys.stderr.write("Error: polyglot subcommand required.\n") + sys.stderr.write("Available commands: identify, properties, create\n") + sys.stderr.write("Use 'fickling polyglot --help' for more information.\n") + return 1 + + if args.polyglot_command == "identify": + return _handle_polyglot_identify(args) + if args.polyglot_command == "properties": + return _handle_polyglot_properties(args) + if args.polyglot_command == "create": + return _handle_polyglot_create(args) + sys.stderr.write(f"Error: unknown polyglot command '{args.polyglot_command}'\n") + return 1 + + +def _handle_polyglot_identify(args) -> int: + """Handle 'fickling polyglot identify FILE'.""" + if not _check_torch_available(): + return 1 + + from .polyglot import identify_pytorch_file_format + + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return 1 + + try: + formats = identify_pytorch_file_format(args.file, print_results=False) + + if hasattr(args, "json") and args.json: + result = { + "file": str(file_path), + "formats": formats, + "primary_format": formats[0] if formats else None, + "is_polyglot": len(formats) > 1, + } + print(json.dumps(result, indent=2)) + else: + if formats: + print(f"Identified format(s) for {args.file}:") + for i, fmt in enumerate(formats): + prefix = " [primary]" if i == 0 else " [also] " + print(f"{prefix} {fmt}") + if len(formats) > 1: + print("\n Note: Multiple formats detected - this may be a polyglot file.") + else: + print(f"No PyTorch formats detected for {args.file}") + print("This file may not be a PyTorch file, or it may be in an unsupported format.") + + return 0 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error identifying file: {e}\n") + return 1 + + +def _handle_polyglot_properties(args) -> int: + """Handle 'fickling polyglot properties FILE'.""" + if not _check_torch_available(): + return 1 + + from .polyglot import find_file_properties, find_file_properties_recursively + + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return 1 + + recursive = getattr(args, "recursive", False) + + try: + if recursive: + properties = find_file_properties_recursively(args.file, print_properties=False) + else: + properties = find_file_properties(args.file, print_properties=False) + + if hasattr(args, "json") and args.json: + result = { + "file": str(file_path), + "properties": properties, + } + print(json.dumps(result, indent=2)) + else: + print(f"File properties for {args.file}:") + _print_properties(properties, indent=2) + + return 0 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error analyzing file properties: {e}\n") + return 1 + + +def _print_properties(properties: dict, indent: int = 0) -> None: + """Pretty-print file properties.""" + prefix = " " * indent + for key, value in properties.items(): + if key == "children" and isinstance(value, dict): + print(f"{prefix}{key}:") + for child_name, child_props in value.items(): + print(f"{prefix} {child_name}:") + if child_props is not None: + _print_properties(child_props, indent + 4) + else: + print(f"{prefix} (unable to read)") + else: + print(f"{prefix}{key}: {value}") + + +def _handle_polyglot_create(args) -> int: + """Handle 'fickling polyglot create FILE1 FILE2 -o OUTPUT'.""" + if not _check_torch_available(): + return 1 + + from .polyglot import create_polyglot + + file1_path = Path(args.file1) + file2_path = Path(args.file2) + + if not file1_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file1}\n") + return 1 + if not file2_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file2}\n") + return 1 + + output_path = getattr(args, "output", None) + quiet = getattr(args, "quiet", False) + + try: + success = create_polyglot( + args.file1, args.file2, polyglot_file_name=output_path, print_results=not quiet + ) + + if success: + return 0 + if not quiet: + sys.stderr.write("Failed to create polyglot. The file formats may not be compatible.\n") + return 1 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error creating polyglot: {e}\n") + return 1 diff --git a/fickling/cli_pytorch.py b/fickling/cli_pytorch.py new file mode 100644 index 00000000..494965ba --- /dev/null +++ b/fickling/cli_pytorch.py @@ -0,0 +1,200 @@ +"""CLI handlers for PyTorch model operations.""" + +from __future__ import annotations + +import json +import sys +from ast import unparse +from pathlib import Path + +from .analysis import Severity, check_safety +from .cli import DEFAULT_JSON_OUTPUT_FILE +from .fickle import Interpreter +from .tracing import Trace + + +def _check_torch_available() -> bool: + """Check if PyTorch is available and provide helpful error message if not.""" + try: + import torch # noqa: F401 + + return True + except ImportError: + sys.stderr.write( + "Error: PyTorch is required for this command.\n" + "Please install it with: pip install fickling[torch]\n" + ) + return False + + +def handle_pytorch_command(args) -> int: + """Handle the pytorch subcommand and its sub-subcommands.""" + if args.pytorch_command is None: + sys.stderr.write("Error: pytorch subcommand required.\n") + sys.stderr.write("Available commands: identify, show, check-safety, inject\n") + sys.stderr.write("Use 'fickling pytorch --help' for more information.\n") + return 1 + + if args.pytorch_command == "identify": + return _handle_pytorch_identify(args) + if args.pytorch_command == "show": + return _handle_pytorch_show(args) + if args.pytorch_command == "check-safety": + return _handle_pytorch_check_safety(args) + if args.pytorch_command == "inject": + return _handle_pytorch_inject(args) + sys.stderr.write(f"Error: unknown pytorch command '{args.pytorch_command}'\n") + return 1 + + +def _handle_pytorch_identify(args) -> int: + """Handle 'fickling pytorch identify FILE'.""" + if not _check_torch_available(): + return 1 + + from .pytorch import PyTorchModelWrapper + + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return 1 + + try: + wrapper = PyTorchModelWrapper(file_path, force=True) + formats = wrapper.validate_file_format() + + formats_list = list(formats) + if hasattr(args, "json") and args.json: + result = { + "file": str(file_path), + "formats": formats_list, + "primary_format": formats_list[0] if formats_list else None, + } + print(json.dumps(result, indent=2)) + else: + if formats_list: + print(f"Detected format(s) for {args.file}:") + for i, fmt in enumerate(formats_list): + prefix = " [primary]" if i == 0 else " [also] " + print(f"{prefix} {fmt}") + else: + print(f"No PyTorch formats detected for {args.file}") + + return 0 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error identifying file: {e}\n") + return 1 + + +def _handle_pytorch_show(args) -> int: + """Handle 'fickling pytorch show FILE'.""" + if not _check_torch_available(): + return 1 + + from .pytorch import PyTorchModelWrapper + + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return 1 + + force = getattr(args, "force", False) + + try: + wrapper = PyTorchModelWrapper(file_path, force=force) + pickled = wrapper.pickled + + interpreter = Interpreter(pickled) + if getattr(args, "trace", False): + trace = Trace(interpreter) + print(unparse(trace.run())) + else: + print(unparse(interpreter.to_ast())) + + return 0 + except ValueError as e: + sys.stderr.write(f"Error: {e}\n") + sys.stderr.write("Use --force to attempt processing anyway.\n") + return 1 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error reading PyTorch model: {e}\n") + return 1 + + +def _handle_pytorch_check_safety(args) -> int: + """Handle 'fickling pytorch check-safety FILE'.""" + if not _check_torch_available(): + return 1 + + from .pytorch import PyTorchModelWrapper + + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return 1 + + force = getattr(args, "force", False) + json_output_path = getattr(args, "json_output", None) or DEFAULT_JSON_OUTPUT_FILE + print_results = getattr(args, "print_results", False) + + try: + wrapper = PyTorchModelWrapper(file_path, force=force) + pickled = wrapper.pickled + + safety_results = check_safety(pickled, json_output_path=json_output_path) + + if print_results: + print(safety_results.to_string()) + + if safety_results.severity > Severity.LIKELY_SAFE: + if print_results: + sys.stderr.write( + "Warning: Fickling detected that the PyTorch model may be unsafe.\n\n" + "Do not load this model if it is from an untrusted source!\n\n" + ) + return 1 + return 0 + except ValueError as e: + sys.stderr.write(f"Error: {e}\n") + sys.stderr.write("Use --force to attempt processing anyway.\n") + return 1 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error checking PyTorch model safety: {e}\n") + return 1 + + +def _handle_pytorch_inject(args) -> int: + """Handle 'fickling pytorch inject FILE -o OUTPUT -c CODE'.""" + if not _check_torch_available(): + return 1 + + from .pytorch import PyTorchModelWrapper + + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return 1 + + output_path = Path(args.output) + if output_path.exists() and not getattr(args, "overwrite", False): + sys.stderr.write(f"Error: output file already exists: {args.output}\n") + sys.stderr.write("Use --overwrite to replace the original file.\n") + return 1 + + force = getattr(args, "force", False) + method = getattr(args, "method", "insertion") + overwrite = getattr(args, "overwrite", False) + code = args.code + + try: + wrapper = PyTorchModelWrapper(file_path, force=force) + wrapper.inject_payload(code, output_path, injection=method, overwrite=overwrite) + print(f"Payload injected successfully. Output written to: {output_path}") + return 0 + except ValueError as e: + sys.stderr.write(f"Error: {e}\n") + sys.stderr.write("Use --force to attempt processing anyway.\n") + return 1 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error injecting payload: {e}\n") + return 1 diff --git a/test/test_cli.py b/test/test_cli.py new file mode 100644 index 00000000..1d456030 --- /dev/null +++ b/test/test_cli.py @@ -0,0 +1,273 @@ +"""Tests for the fickling CLI.""" + +from __future__ import annotations + +import io +import tempfile +from contextlib import redirect_stderr, redirect_stdout +from pathlib import Path +from pickle import dumps +from unittest import TestCase + +import torch +import torchvision.models as models + +from fickling.cli import main + + +class TestCLIBackwardCompatibility(TestCase): + """Test that existing CLI behavior is preserved.""" + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.tmppath = Path(self.tmpdir.name) + + # Create a simple pickle file + self.pickle_file = self.tmppath / "test.pkl" + with open(self.pickle_file, "wb") as f: + f.write(dumps({"test": "data"})) + + def tearDown(self): + self.tmpdir.cleanup() + + def test_version_flag(self): + """Test --version flag.""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "--version"]) + self.assertEqual(result, 0) + output = stdout.getvalue() + # Should contain version number + self.assertTrue(output.strip()) + + def test_decompile_pickle(self): + """Test basic pickle decompilation.""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", str(self.pickle_file)]) + self.assertEqual(result, 0) + output = stdout.getvalue() + # Should contain decompiled code + self.assertIn("result", output) + + def test_check_safety_safe_file(self): + """Test --check-safety on a safe pickle file.""" + result = main(["fickling", "--check-safety", str(self.pickle_file)]) + self.assertEqual(result, 0) + + def test_help_flag(self): + """Test --help flag.""" + with self.assertRaises(SystemExit) as cm: + main(["fickling", "--help"]) + self.assertEqual(cm.exception.code, 0) + + +class TestCLISubcommandRouting(TestCase): + """Test that subcommand routing works correctly.""" + + def test_pytorch_subcommand_no_args(self): + """Test 'fickling pytorch' with no arguments.""" + stderr = io.StringIO() + with redirect_stderr(stderr): + result = main(["fickling", "pytorch"]) + self.assertEqual(result, 1) + self.assertIn("pytorch subcommand required", stderr.getvalue()) + + def test_polyglot_subcommand_no_args(self): + """Test 'fickling polyglot' with no arguments.""" + stderr = io.StringIO() + with redirect_stderr(stderr): + result = main(["fickling", "polyglot"]) + self.assertEqual(result, 1) + self.assertIn("polyglot subcommand required", stderr.getvalue()) + + +class TestPyTorchCLI(TestCase): + """Test PyTorch CLI subcommands.""" + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.tmppath = Path(self.tmpdir.name) + + # Create a PyTorch model file + model = models.mobilenet_v2(weights=None) + self.model_file = self.tmppath / "model.pth" + torch.save(model, self.model_file) + + def tearDown(self): + self.tmpdir.cleanup() + + def test_pytorch_identify(self): + """Test 'fickling pytorch identify'.""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "pytorch", "identify", str(self.model_file)]) + self.assertEqual(result, 0) + output = stdout.getvalue() + self.assertIn("Detected format", output) + self.assertIn("PyTorch", output) + + def test_pytorch_identify_json(self): + """Test 'fickling pytorch identify --json'.""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "pytorch", "identify", "--json", str(self.model_file)]) + self.assertEqual(result, 0) + output = stdout.getvalue() + # Should be valid JSON + import json + + data = json.loads(output) + self.assertIn("formats", data) + self.assertIn("primary_format", data) + + def test_pytorch_show(self): + """Test 'fickling pytorch show'.""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "pytorch", "show", str(self.model_file)]) + self.assertEqual(result, 0) + output = stdout.getvalue() + # Should contain decompiled code + self.assertIn("result", output) + + def test_pytorch_check_safety(self): + """Test 'fickling pytorch check-safety'.""" + result = main(["fickling", "pytorch", "check-safety", str(self.model_file)]) + # Result can be 0 (safe) or 1 (potentially unsafe) - just verify it runs + self.assertIn(result, [0, 1]) + + def test_pytorch_inject(self): + """Test 'fickling pytorch inject'.""" + output_file = self.tmppath / "injected.pth" + result = main( + [ + "fickling", + "pytorch", + "inject", + str(self.model_file), + "-o", + str(output_file), + "-c", + "print('test')", + ] + ) + self.assertEqual(result, 0) + self.assertTrue(output_file.exists()) + + def test_pytorch_inject_combination_method(self): + """Test 'fickling pytorch inject --method combination'.""" + output_file = self.tmppath / "injected_combo.pth" + result = main( + [ + "fickling", + "pytorch", + "inject", + str(self.model_file), + "-o", + str(output_file), + "-c", + "print('test')", + "--method", + "combination", + ] + ) + self.assertEqual(result, 0) + self.assertTrue(output_file.exists()) + + def test_pytorch_file_not_found(self): + """Test PyTorch commands with non-existent file.""" + stderr = io.StringIO() + with redirect_stderr(stderr): + result = main(["fickling", "pytorch", "identify", "/nonexistent/file.pth"]) + self.assertEqual(result, 1) + self.assertIn("file not found", stderr.getvalue()) + + +class TestPolyglotCLI(TestCase): + """Test polyglot CLI subcommands.""" + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.tmppath = Path(self.tmpdir.name) + + # Create a PyTorch model file + model = models.mobilenet_v2(weights=None) + self.model_file = self.tmppath / "model.pth" + torch.save(model, self.model_file) + + def tearDown(self): + self.tmpdir.cleanup() + + def test_polyglot_identify(self): + """Test 'fickling polyglot identify'.""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "polyglot", "identify", str(self.model_file)]) + self.assertEqual(result, 0) + output = stdout.getvalue() + self.assertIn("Identified format", output) + + def test_polyglot_identify_json(self): + """Test 'fickling polyglot identify --json'.""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "polyglot", "identify", "--json", str(self.model_file)]) + self.assertEqual(result, 0) + output = stdout.getvalue() + # Should be valid JSON + import json + + data = json.loads(output) + self.assertIn("formats", data) + self.assertIn("is_polyglot", data) + + def test_polyglot_properties(self): + """Test 'fickling polyglot properties'.""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "polyglot", "properties", str(self.model_file)]) + self.assertEqual(result, 0) + output = stdout.getvalue() + self.assertIn("File properties", output) + + def test_polyglot_properties_json(self): + """Test 'fickling polyglot properties --json'.""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "polyglot", "properties", "--json", str(self.model_file)]) + self.assertEqual(result, 0) + output = stdout.getvalue() + # Should be valid JSON + import json + + data = json.loads(output) + self.assertIn("properties", data) + + def test_polyglot_file_not_found(self): + """Test polyglot commands with non-existent file.""" + stderr = io.StringIO() + with redirect_stderr(stderr): + result = main(["fickling", "polyglot", "identify", "/nonexistent/file.pth"]) + self.assertEqual(result, 1) + self.assertIn("file not found", stderr.getvalue()) + + +class TestCLIErrorHandling(TestCase): + """Test CLI error handling.""" + + def test_nonexistent_pickle_file(self): + """Test decompiling non-existent file raises FileNotFoundError.""" + # The original CLI raises FileNotFoundError for non-existent files + with self.assertRaises(FileNotFoundError): + main(["fickling", "/nonexistent/file.pkl"]) + + def test_pytorch_inject_missing_output(self): + """Test pytorch inject without required --output flag.""" + with self.assertRaises(SystemExit): + main(["fickling", "pytorch", "inject", "file.pth", "-c", "code"]) + + def test_pytorch_inject_missing_code(self): + """Test pytorch inject without required --code flag.""" + with self.assertRaises(SystemExit): + main(["fickling", "pytorch", "inject", "file.pth", "-o", "out.pth"]) From 313981c63306e2fd732a780ab25d184915838038 Mon Sep 17 00:00:00 2001 From: Dan Guido Date: Thu, 22 Jan 2026 01:40:33 -0500 Subject: [PATCH 2/3] Redesign CLI with unified format-agnostic interface Replace nested subcommands (pytorch/polyglot) with flat, auto-detecting commands: - fickling check FILE: Safety check any pickle/model (auto-detects format) - fickling inject FILE -c CODE -o OUT: Inject payload into any format - fickling info FILE: Show format and properties - fickling create-polyglot F1 F2 -o OUT: Create polyglot files Key changes: - Add auto_load() to loader.py for automatic format detection - Rewrite cli.py with flat command structure - Delete cli_pytorch.py and cli_polyglot.py (absorbed into cli.py) - Maintain full backward compatibility with legacy flags - Update tests for new interface The tool now figures out if input is PyTorch, TorchScript, plain pickle, etc. without users needing to specify format explicitly. Co-Authored-By: Claude Opus 4.5 --- fickling/cli.py | 796 ++++++++++++++++++++++++++------------- fickling/cli_polyglot.py | 167 -------- fickling/cli_pytorch.py | 200 ---------- fickling/loader.py | 65 +++- test/test_cli.py | 285 ++++++++------ 5 files changed, 772 insertions(+), 741 deletions(-) delete mode 100644 fickling/cli_polyglot.py delete mode 100644 fickling/cli_pytorch.py diff --git a/fickling/cli.py b/fickling/cli.py index 8c09bf8e..7b75e876 100644 --- a/fickling/cli.py +++ b/fickling/cli.py @@ -1,364 +1,636 @@ +"""Fickling CLI - Pickle security analyzer with auto-detection.""" + from __future__ import annotations +import json import sys from argparse import ArgumentParser from ast import unparse +from pathlib import Path from . import __version__, fickle, tracing from .analysis import Severity, check_safety DEFAULT_JSON_OUTPUT_FILE = "safety_results.json" +# Commands that use the new subcommand interface +COMMANDS = {"check", "inject", "info", "create-polyglot"} + + +def _check_torch_available() -> bool: + """Check if PyTorch is available and provide helpful error message if not.""" + try: + import torch # noqa: F401 + + return True + except ImportError: + sys.stderr.write( + "Error: PyTorch is required for this command.\n" + "Please install it with: pip install fickling[torch]\n" + ) + return False -def _add_pickle_arguments(parser: ArgumentParser) -> None: - """Add the standard pickle-related arguments to a parser.""" + +def _create_legacy_parser() -> ArgumentParser: + """Create parser for legacy CLI behavior (backward compatibility).""" + parser = ArgumentParser( + prog="fickling", + description="Pickle security analyzer with auto-detection for PyTorch models", + ) + parser.add_argument("--version", "-v", action="store_true", help="print version and exit") parser.add_argument( - "PICKLE_FILE", + "file", type=str, nargs="?", default="-", - help="path to the pickle file to either " - "analyze or create (default is '-' for " - "STDIN/STDOUT)", + help="file to analyze (default: stdin)", + ) + parser.add_argument( + "--check-safety", + "-s", + action="store_true", + help="(legacy) run safety analysis - prefer 'fickling check FILE'", ) - options = parser.add_mutually_exclusive_group() - options.add_argument( + parser.add_argument( "--inject", "-i", type=str, default=None, - help="inject the specified Python code to be run at the end of unpickling, " - "and output the resulting pickle data", + help="(legacy) inject code - prefer 'fickling inject FILE -c CODE -o OUT'", ) parser.add_argument( "--inject-target", type=int, default=0, - help="some machine learning frameworks stack multiple pickles into the same model file; " - "this option specifies the index of the pickle file in which to inject the code from the " - "`--inject` command (default is 0)", + help="index of stacked pickle to inject into (default: 0)", + ) + parser.add_argument( + "--create", + "-c", + type=str, + default=None, + help="(legacy) create pickle from Python expression", ) - options.add_argument("--create", "-c", type=str, default=None) parser.add_argument( "--run-last", "-l", action="store_true", - help="used with --inject to have the injected code " - "run after the existing pickling code in " - "PICKLE_FILE (default is for the injected code " - "to be run before the existing code)", + help="run injected code after existing code", ) parser.add_argument( "--replace-result", "-r", action="store_true", - help=( - "used with --inject to replace the unpickling result of the code in PICKLE_FILE " - "with the return value of the injected code. Either way, the preexisting pickling " - "code is still executed." - ), + help="replace unpickle result with injected code return value", ) - options.add_argument( - "--check-safety", - "-s", - action="store_true", - help=( - "test if the given pickle file is known to be unsafe. If so, exit with non-zero " - "status. This test is not guaranteed correct; the pickle file may still be unsafe " - "even if this check exits with code zero." - ), - ) - parser.add_argument( "--json-output", type=str, default=None, - help="path to the output JSON file to store the analysis results from check-safety." - f"If not provided, a default file named {DEFAULT_JSON_OUTPUT_FILE} will be used.", + help=f"path to output JSON file (default: {DEFAULT_JSON_OUTPUT_FILE})", ) - parser.add_argument( "--print-results", "-p", action="store_true", - help="Print the analysis results to the console when checking safety.", + help="print analysis results to console", ) - parser.add_argument( "--trace", "-t", action="store_true", - help="print a runtime trace while interpreting the input pickle file", + help="print a runtime trace while interpreting", ) + return parser -def _handle_pickle_command(args) -> int: - """Handle the standard pickle decompilation/injection/safety commands.""" - if args.create is None: - if args.PICKLE_FILE == "-": - if hasattr(sys.stdin, "buffer") and sys.stdin.buffer is not None: - file = sys.stdin.buffer +def _create_command_parser() -> ArgumentParser: + """Create parser with subcommands for new flat command structure.""" + parser = ArgumentParser( + prog="fickling", + description="Pickle security analyzer with auto-detection for PyTorch models", + ) + parser.add_argument("--version", "-v", action="store_true", help="print version and exit") + + subparsers = parser.add_subparsers(dest="command", help="available commands") + + # === check: Safety analysis === + check_parser = subparsers.add_parser( + "check", + help="safety check any pickle/model file", + description="Run safety analysis on any pickle or PyTorch model file (auto-detects format)", + ) + check_parser.add_argument("file", type=str, help="file to check") + check_parser.add_argument("--json", action="store_true", help="output results as JSON") + check_parser.add_argument( + "--json-output", + type=str, + default=None, + help=f"path to output JSON file (default: {DEFAULT_JSON_OUTPUT_FILE})", + ) + check_parser.add_argument( + "--print-results", "-p", action="store_true", help="print detailed results to console" + ) + + # === inject: Payload injection === + inject_parser = subparsers.add_parser( + "inject", + help="inject payload into pickle/model file", + description="Inject Python code into a pickle or PyTorch model file (auto-detects format)", + ) + inject_parser.add_argument("file", type=str, help="file to inject into") + inject_parser.add_argument( + "-c", "--code", type=str, required=True, help="Python code to inject" + ) + inject_parser.add_argument("-o", "--output", type=str, required=True, help="output file path") + inject_parser.add_argument( + "--method", + type=str, + choices=["insertion", "combination"], + default="insertion", + help="injection method for PyTorch models (default: insertion)", + ) + inject_parser.add_argument( + "--run-last", + "-l", + action="store_true", + help="run injected code after existing code (default: before)", + ) + inject_parser.add_argument( + "--replace-result", + "-r", + action="store_true", + help="replace unpickle result with injected code return value", + ) + inject_parser.add_argument( + "--overwrite", action="store_true", help="overwrite output file if exists" + ) + + # === info: Format identification === + info_parser = subparsers.add_parser( + "info", + help="show format and properties of a file", + description="Identify file format and show properties (requires PyTorch for full detection)", + ) + info_parser.add_argument("file", type=str, help="file to analyze") + info_parser.add_argument("--json", action="store_true", help="output results as JSON") + info_parser.add_argument( + "-r", "--recursive", action="store_true", help="analyze recursively into archives" + ) + + # === create-polyglot: Polyglot creation === + polyglot_parser = subparsers.add_parser( + "create-polyglot", + help="create a polyglot file from two inputs", + description="Create a polyglot file by combining two PyTorch/pickle files", + ) + polyglot_parser.add_argument("file1", type=str, help="first input file") + polyglot_parser.add_argument("file2", type=str, help="second input file") + polyglot_parser.add_argument("-o", "--output", type=str, default=None, help="output file path") + polyglot_parser.add_argument( + "--quiet", "-q", action="store_true", help="suppress output messages" + ) + + return parser + + +def _get_first_positional(argv: list[str]) -> str | None: + """Get the first non-flag argument (potential command or file).""" + for arg in argv[1:]: + if not arg.startswith("-"): + return arg + return None + + +def main(argv: list[str] | None = None) -> int: + """Main CLI entry point.""" + if argv is None: + argv = sys.argv + + # Check for version flag first + if "--version" in argv or "-v" in argv: + if len(argv) == 2: # Only version flag present + if sys.stdout.isatty(): + print(f"fickling version {__version__}") else: - file = sys.stdin + print(__version__) + return 0 + + # Determine if we're using a new command or legacy CLI + first_positional = _get_first_positional(argv) + + if first_positional in COMMANDS: + # Use new command parser + parser = _create_command_parser() + args = parser.parse_args(argv[1:]) + + if args.command == "check": + return _handle_check(args) + if args.command == "inject": + return _handle_inject(args) + if args.command == "info": + return _handle_info(args) + if args.command == "create-polyglot": + return _handle_create_polyglot(args) + return 1 + + # Use legacy parser for backward compatibility + parser = _create_legacy_parser() + args = parser.parse_args(argv[1:]) + return _handle_legacy(args) + + +def _handle_check(args) -> int: + """Handle 'fickling check FILE' - safety analysis with auto-detection.""" + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return 1 + + json_output_path = args.json_output or DEFAULT_JSON_OUTPUT_FILE + print_results = getattr(args, "print_results", False) + + try: + from .loader import auto_load + + format_name, stacked_pickled = auto_load(file_path) + + if not getattr(args, "json", False): + print(f"Detected format: {format_name}") + + was_safe = True + all_results = [] + + for pickled in stacked_pickled: + safety_results = check_safety(pickled, json_output_path=json_output_path) + all_results.append(safety_results) + + if safety_results.severity > Severity.LIKELY_SAFE: + was_safe = False + + if getattr(args, "json", False): + result = { + "file": str(file_path), + "format": format_name, + "safe": was_safe, + "severity": max(r.severity.value for r in all_results), + "results": [r.to_dict() for r in all_results], + } + print(json.dumps(result, indent=2)) else: - file = open(args.PICKLE_FILE, "rb") - try: - stacked_pickled = fickle.StackedPickle.load(file, fail_on_decode_error=False) - except fickle.PickleDecodeError as e: - sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") - if args.check_safety: - sys.stderr.write( - "Parsing errors might be indicative of a maliciously crafted pickle file. DO NOT TRUST this file without performing further analysis!\n" - ) - sys.stderr.write( - "\n(If this is a valid pickle file, please report the error at https://github.com/trailofbits/fickling)\n" - ) - return 1 - finally: - file.close() + if print_results: + for i, safety_results in enumerate(all_results): + if len(all_results) > 1: + print(f"\n--- Pickle {i} ---") + print(safety_results.to_string()) - if args.inject is not None: - if args.inject_target >= len(stacked_pickled): + if was_safe: + print("No unsafe operations detected.") + else: sys.stderr.write( - f"Error: --inject-target {args.inject_target} is too high; there are only " - f"{len(stacked_pickled)} stacked pickle files in the input\n" + "\nWarning: Potentially unsafe operations detected.\n" + "Do not unpickle this file if it is from an untrusted source!\n" ) + + return 0 if was_safe else 1 + + except FileNotFoundError as e: + sys.stderr.write(f"Error: {e}\n") + return 1 + except ValueError as e: + sys.stderr.write(f"Error loading file: {e}\n") + return 1 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error: {e}\n") + return 1 + + +def _handle_inject(args) -> int: + """Handle 'fickling inject FILE -c CODE -o OUT' - payload injection with auto-detection.""" + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return 1 + + output_path = Path(args.output) + if output_path.exists() and not getattr(args, "overwrite", False): + sys.stderr.write(f"Error: output file already exists: {args.output}\n") + sys.stderr.write("Use --overwrite to replace it.\n") + return 1 + + try: + from .loader import auto_load + + format_name, stacked_pickled = auto_load(file_path) + print(f"Detected format: {format_name}") + + # For PyTorch ZIP formats, use PyTorchModelWrapper for proper injection + if format_name in ("PyTorch v1.3", "TorchScript v1.4", "TorchScript v1.3"): + if not _check_torch_available(): return 1 - if hasattr(sys.stdout, "buffer") and sys.stdout.buffer is not None: - buffer = sys.stdout.buffer - else: - buffer = sys.stdout - for pickled in stacked_pickled[: args.inject_target]: + + from .pytorch import PyTorchModelWrapper + + method = getattr(args, "method", "insertion") + overwrite = getattr(args, "overwrite", False) + + wrapper = PyTorchModelWrapper(file_path, force=True) + wrapper.inject_payload(args.code, output_path, injection=method, overwrite=overwrite) + print(f"Payload injected successfully. Output: {output_path}") + return 0 + + # For plain pickle, use direct injection + if args.output == "-": + buffer = ( + sys.stdout.buffer + if hasattr(sys.stdout, "buffer") and sys.stdout.buffer + else sys.stdout + ) + should_close = False + else: + buffer = open(output_path, "wb") + should_close = True + + try: + inject_target = getattr(args, "inject_target", 0) + if inject_target >= len(stacked_pickled): + inject_target = 0 + + for pickled in stacked_pickled[:inject_target]: pickled.dump(buffer) - pickled = stacked_pickled[args.inject_target] - if not isinstance(pickled[-1], fickle.Stop): - sys.stderr.write( - "Warning: The last opcode of the input file was expected to be STOP, but was " - f"in fact {pickled[-1].info.name}" - ) + + pickled = stacked_pickled[inject_target] pickled.insert_python_eval( - args.inject, - run_first=not args.run_last, - use_output_as_unpickle_result=args.replace_result, + args.code, + run_first=not getattr(args, "run_last", False), + use_output_as_unpickle_result=getattr(args, "replace_result", False), ) pickled.dump(buffer) - for pickled in stacked_pickled[args.inject_target + 1 :]: + + for pickled in stacked_pickled[inject_target + 1 :]: pickled.dump(buffer) - elif args.check_safety: - was_safe = True - json_output_path = args.json_output or DEFAULT_JSON_OUTPUT_FILE - for pickled in stacked_pickled: - safety_results = check_safety(pickled, json_output_path=json_output_path) - - # Print results if requested - if args.print_results: - print(safety_results.to_string()) - if safety_results.severity > Severity.LIKELY_SAFE: - was_safe = False - if args.print_results: - sys.stderr.write( - "Warning: Fickling detected that the pickle file may be unsafe.\n\n" - "Do not unpickle this file if it is from an untrusted source!\n\n" - ) + print(f"Payload injected successfully. Output: {output_path}") + return 0 + finally: + if should_close: + buffer.close() - return [1, 0][was_safe] + except FileNotFoundError as e: + sys.stderr.write(f"Error: {e}\n") + return 1 + except ValueError as e: + sys.stderr.write(f"Error: {e}\n") + return 1 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error injecting payload: {e}\n") + return 1 - else: - var_id = 0 - for i, pickled in enumerate(stacked_pickled): - interpreter = fickle.Interpreter( - pickled, first_variable_id=var_id, result_variable=f"result{i}" - ) - if args.trace: - trace = tracing.Trace(interpreter) - print(unparse(trace.run())) - else: - print(unparse(interpreter.to_ast())) - var_id = interpreter.next_variable_id - else: - pickled = fickle.Pickled( - [ - fickle.Global.create("__builtin__", "eval"), - fickle.Mark(), - fickle.Unicode(args.create.encode("utf-8")), - fickle.Tuple(), - fickle.Reduce(), - fickle.Stop(), - ] + +def _handle_info(args) -> int: + """Handle 'fickling info FILE' - format identification and properties.""" + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return 1 + + # Try to use polyglot module for detailed analysis (requires torch) + try: + from .polyglot import ( + find_file_properties, + find_file_properties_recursively, + identify_pytorch_file_format, ) - if args.PICKLE_FILE == "-": - file = sys.stdout - if hasattr(file, "buffer") and file.buffer is not None: - file = file.buffer + + formats = identify_pytorch_file_format(args.file, print_results=False) + recursive = getattr(args, "recursive", False) + + if recursive: + properties = find_file_properties_recursively(args.file, print_properties=False) else: - file = open(args.PICKLE_FILE, "wb") - try: - pickled.dump(file) - finally: - file.close() + properties = find_file_properties(args.file, print_properties=False) + + if getattr(args, "json", False): + result = { + "file": str(file_path), + "formats": formats, + "primary_format": formats[0] if formats else None, + "is_polyglot": len(formats) > 1, + "properties": properties, + } + print(json.dumps(result, indent=2)) + else: + if formats: + print(f"Format: {formats[0]}") + if len(formats) > 1: + print(f"Also valid as: {', '.join(formats[1:])}") + print("(This file may be a polyglot)") + else: + print("Format: pickle (no specific PyTorch format detected)") - return 0 + print("\nProperties:") + _print_properties(properties, indent=2) + return 0 -SUBCOMMANDS = {"pytorch", "polyglot"} + except ImportError: + # torch not installed - provide basic info + try: + with open(file_path, "rb") as f: + stacked = fickle.StackedPickle.load(f, fail_on_decode_error=False) + + if getattr(args, "json", False): + result = { + "file": str(file_path), + "formats": ["pickle"], + "primary_format": "pickle", + "is_polyglot": False, + "pickle_count": len(stacked), + } + print(json.dumps(result, indent=2)) + else: + print("Format: pickle") + print(f"Stacked pickles: {len(stacked)}") + print("\nNote: Install PyTorch for detailed format detection:") + print(" pip install fickling[torch]") + + return 0 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error reading file: {e}\n") + return 1 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error: {e}\n") + return 1 -def _create_pickle_parser() -> ArgumentParser: - """Create parser for the original pickle commands (backward compatibility).""" - parser = ArgumentParser( - description="fickling is a static analyzer and interpreter for Python pickle data" - ) - parser.add_argument("--version", "-v", action="store_true", help="print the version and exit") - _add_pickle_arguments(parser) - return parser +def _print_properties(properties: dict, indent: int = 0) -> None: + """Pretty-print file properties.""" + prefix = " " * indent + for key, value in properties.items(): + if key == "children" and isinstance(value, dict): + print(f"{prefix}{key}:") + for child_name, child_props in value.items(): + print(f"{prefix} {child_name}:") + if child_props is not None: + _print_properties(child_props, indent + 4) + else: + print(f"{prefix} (unable to read)") + else: + print(f"{prefix}{key}: {value}") -def _create_subcommand_parser() -> ArgumentParser: - """Create parser with subcommands for PyTorch and polyglot operations.""" - parser = ArgumentParser( - description="fickling is a static analyzer and interpreter for Python pickle data" - ) - parser.add_argument("--version", "-v", action="store_true", help="print the version and exit") - subparsers = parser.add_subparsers(dest="command", help="available commands") +def _handle_create_polyglot(args) -> int: + """Handle 'fickling create-polyglot FILE1 FILE2 -o OUT'.""" + if not _check_torch_available(): + return 1 - # PyTorch subcommand - pytorch_parser = subparsers.add_parser("pytorch", help="PyTorch model operations") - _setup_pytorch_subcommand(pytorch_parser) + from .polyglot import create_polyglot - # Polyglot subcommand - polyglot_parser = subparsers.add_parser("polyglot", help="polyglot detection and creation") - _setup_polyglot_subcommand(polyglot_parser) + file1_path = Path(args.file1) + file2_path = Path(args.file2) - return parser + if not file1_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file1}\n") + return 1 + if not file2_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file2}\n") + return 1 + output_path = getattr(args, "output", None) + quiet = getattr(args, "quiet", False) -def _get_first_positional(argv: list[str]) -> str | None: - """Get the first non-flag argument (potential subcommand or file).""" - for arg in argv[1:]: - if not arg.startswith("-"): - return arg - return None + try: + success = create_polyglot( + args.file1, args.file2, polyglot_file_name=output_path, print_results=not quiet + ) + if success: + return 0 + if not quiet: + sys.stderr.write("Failed to create polyglot. The file formats may not be compatible.\n") + return 1 + except Exception as e: # noqa: BLE001 + sys.stderr.write(f"Error creating polyglot: {e}\n") + return 1 -def main(argv: list[str] | None = None) -> int: - if argv is None: - argv = sys.argv - # Check for version flag first - if "--version" in argv or "-v" in argv: - if sys.stdout.isatty(): - print(f"fickling version {__version__}") +def _handle_legacy(args) -> int: + """Handle legacy CLI behavior (backward compatibility).""" + # Handle --check-safety flag + if args.check_safety: + if args.file and args.file != "-": + # Create a fake args object for _handle_check + class CheckArgs: + pass + + check_args = CheckArgs() + check_args.file = args.file + check_args.json = False + check_args.json_output = args.json_output + check_args.print_results = args.print_results + return _handle_check(check_args) + sys.stderr.write("Error: file path required with --check-safety\n") + return 1 + + # Handle --inject flag + if args.inject: + # For legacy inject, output goes to stdout + if args.file == "-": + file = sys.stdin.buffer if hasattr(sys.stdin, "buffer") else sys.stdin else: - print(__version__) - return 0 + file = open(args.file, "rb") - # Determine if we're using a subcommand or the original CLI - first_positional = _get_first_positional(argv) + try: + stacked_pickled = fickle.StackedPickle.load(file, fail_on_decode_error=False) + except fickle.PickleDecodeError as e: + sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") + return 1 + finally: + if file not in (sys.stdin, sys.stdin.buffer): + file.close() - if first_positional in SUBCOMMANDS: - # Use subcommand parser - parser = _create_subcommand_parser() - args = parser.parse_args(argv[1:]) + if args.inject_target >= len(stacked_pickled): + sys.stderr.write( + f"Error: --inject-target {args.inject_target} is too high; there are only " + f"{len(stacked_pickled)} stacked pickle files in the input\n" + ) + return 1 - if args.command == "pytorch": - from .cli_pytorch import handle_pytorch_command + buffer = sys.stdout.buffer if hasattr(sys.stdout, "buffer") else sys.stdout - return handle_pytorch_command(args) - if args.command == "polyglot": - from .cli_polyglot import handle_polyglot_command + for pickled in stacked_pickled[: args.inject_target]: + pickled.dump(buffer) - return handle_polyglot_command(args) - # Should not reach here - return 1 - # Use original pickle parser for backward compatibility - parser = _create_pickle_parser() - args = parser.parse_args(argv[1:]) - return _handle_pickle_command(args) + pickled = stacked_pickled[args.inject_target] + if not isinstance(pickled[-1], fickle.Stop): + sys.stderr.write( + "Warning: The last opcode of the input file was expected to be STOP, but was " + f"in fact {pickled[-1].info.name}" + ) + pickled.insert_python_eval( + args.inject, + run_first=not args.run_last, + use_output_as_unpickle_result=args.replace_result, + ) + pickled.dump(buffer) -def _setup_pytorch_subcommand(parser: ArgumentParser) -> None: - """Set up the pytorch subcommand with its sub-subcommands.""" - subparsers = parser.add_subparsers(dest="pytorch_command", help="pytorch operations") + for pickled in stacked_pickled[args.inject_target + 1 :]: + pickled.dump(buffer) - # identify - identify_parser = subparsers.add_parser("identify", help="detect PyTorch file format(s)") - identify_parser.add_argument("file", type=str, help="path to the PyTorch model file") - identify_parser.add_argument("--json", action="store_true", help="output results as JSON") + return 0 - # show - show_parser = subparsers.add_parser("show", help="decompile internal pickle from PyTorch model") - show_parser.add_argument("file", type=str, help="path to the PyTorch model file") - show_parser.add_argument( - "--force", "-f", action="store_true", help="force processing unsupported formats" - ) - show_parser.add_argument("--trace", "-t", action="store_true", help="print a runtime trace") + # Handle --create flag + if args.create: + pickled = fickle.Pickled( + [ + fickle.Global.create("__builtin__", "eval"), + fickle.Mark(), + fickle.Unicode(args.create.encode("utf-8")), + fickle.Tuple(), + fickle.Reduce(), + fickle.Stop(), + ] + ) + if args.file == "-": + file = sys.stdout.buffer if hasattr(sys.stdout, "buffer") else sys.stdout + else: + file = open(args.file, "wb") - # check-safety - safety_parser = subparsers.add_parser( - "check-safety", help="run safety analysis on internal pickle" - ) - safety_parser.add_argument("file", type=str, help="path to the PyTorch model file") - safety_parser.add_argument( - "--force", "-f", action="store_true", help="force processing unsupported formats" - ) - safety_parser.add_argument( - "--json-output", - type=str, - default=None, - help="path to output JSON file for analysis results", - ) - safety_parser.add_argument( - "--print-results", "-p", action="store_true", help="print results to console" - ) + try: + pickled.dump(file) + finally: + if file not in (sys.stdout, sys.stdout.buffer): + file.close() - # inject - inject_parser = subparsers.add_parser("inject", help="inject payload into PyTorch model") - inject_parser.add_argument("file", type=str, help="path to the PyTorch model file") - inject_parser.add_argument("-o", "--output", type=str, required=True, help="output file path") - inject_parser.add_argument( - "-c", "--code", type=str, required=True, help="Python code to inject" - ) - inject_parser.add_argument( - "--method", - type=str, - choices=["insertion", "combination"], - default="insertion", - help="injection method (default: insertion)", - ) - inject_parser.add_argument( - "--force", "-f", action="store_true", help="force processing unsupported formats" - ) - inject_parser.add_argument( - "--overwrite", action="store_true", help="overwrite original file with output" - ) + return 0 + # Default: decompile the file + if args.file == "-": + file = sys.stdin.buffer if hasattr(sys.stdin, "buffer") else sys.stdin + else: + file = open(args.file, "rb") -def _setup_polyglot_subcommand(parser: ArgumentParser) -> None: - """Set up the polyglot subcommand with its sub-subcommands.""" - subparsers = parser.add_subparsers(dest="polyglot_command", help="polyglot operations") + try: + stacked_pickled = fickle.StackedPickle.load(file, fail_on_decode_error=False) + except fickle.PickleDecodeError as e: + sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") + return 1 + finally: + if file not in (sys.stdin, sys.stdin.buffer): + file.close() - # identify - identify_parser = subparsers.add_parser( - "identify", help="identify all possible PyTorch file formats" - ) - identify_parser.add_argument("file", type=str, help="path to the file to identify") - identify_parser.add_argument("--json", action="store_true", help="output results as JSON") + var_id = 0 + for i, pickled in enumerate(stacked_pickled): + interpreter = fickle.Interpreter( + pickled, first_variable_id=var_id, result_variable=f"result{i}" + ) + if args.trace: + trace = tracing.Trace(interpreter) + print(unparse(trace.run())) + else: + print(unparse(interpreter.to_ast())) + var_id = interpreter.next_variable_id - # properties - properties_parser = subparsers.add_parser("properties", help="analyze file properties") - properties_parser.add_argument("file", type=str, help="path to the file to analyze") - properties_parser.add_argument( - "-r", "--recursive", action="store_true", help="analyze recursively into archives" - ) - properties_parser.add_argument("--json", action="store_true", help="output results as JSON") - - # create - create_parser = subparsers.add_parser("create", help="create a polyglot file") - create_parser.add_argument("file1", type=str, help="first input file") - create_parser.add_argument("file2", type=str, help="second input file") - create_parser.add_argument("-o", "--output", type=str, default=None, help="output file path") - create_parser.add_argument( - "--quiet", "-q", action="store_true", help="suppress output messages" - ) + return 0 diff --git a/fickling/cli_polyglot.py b/fickling/cli_polyglot.py deleted file mode 100644 index b6d54b02..00000000 --- a/fickling/cli_polyglot.py +++ /dev/null @@ -1,167 +0,0 @@ -"""CLI handlers for polyglot detection and creation operations.""" - -from __future__ import annotations - -import json -import sys -from pathlib import Path - - -def _check_torch_available() -> bool: - """Check if PyTorch is available and provide helpful error message if not.""" - try: - import torch # noqa: F401 - - return True - except ImportError: - sys.stderr.write( - "Error: PyTorch is required for this command.\n" - "Please install it with: pip install fickling[torch]\n" - ) - return False - - -def handle_polyglot_command(args) -> int: - """Handle the polyglot subcommand and its sub-subcommands.""" - if args.polyglot_command is None: - sys.stderr.write("Error: polyglot subcommand required.\n") - sys.stderr.write("Available commands: identify, properties, create\n") - sys.stderr.write("Use 'fickling polyglot --help' for more information.\n") - return 1 - - if args.polyglot_command == "identify": - return _handle_polyglot_identify(args) - if args.polyglot_command == "properties": - return _handle_polyglot_properties(args) - if args.polyglot_command == "create": - return _handle_polyglot_create(args) - sys.stderr.write(f"Error: unknown polyglot command '{args.polyglot_command}'\n") - return 1 - - -def _handle_polyglot_identify(args) -> int: - """Handle 'fickling polyglot identify FILE'.""" - if not _check_torch_available(): - return 1 - - from .polyglot import identify_pytorch_file_format - - file_path = Path(args.file) - if not file_path.exists(): - sys.stderr.write(f"Error: file not found: {args.file}\n") - return 1 - - try: - formats = identify_pytorch_file_format(args.file, print_results=False) - - if hasattr(args, "json") and args.json: - result = { - "file": str(file_path), - "formats": formats, - "primary_format": formats[0] if formats else None, - "is_polyglot": len(formats) > 1, - } - print(json.dumps(result, indent=2)) - else: - if formats: - print(f"Identified format(s) for {args.file}:") - for i, fmt in enumerate(formats): - prefix = " [primary]" if i == 0 else " [also] " - print(f"{prefix} {fmt}") - if len(formats) > 1: - print("\n Note: Multiple formats detected - this may be a polyglot file.") - else: - print(f"No PyTorch formats detected for {args.file}") - print("This file may not be a PyTorch file, or it may be in an unsupported format.") - - return 0 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error identifying file: {e}\n") - return 1 - - -def _handle_polyglot_properties(args) -> int: - """Handle 'fickling polyglot properties FILE'.""" - if not _check_torch_available(): - return 1 - - from .polyglot import find_file_properties, find_file_properties_recursively - - file_path = Path(args.file) - if not file_path.exists(): - sys.stderr.write(f"Error: file not found: {args.file}\n") - return 1 - - recursive = getattr(args, "recursive", False) - - try: - if recursive: - properties = find_file_properties_recursively(args.file, print_properties=False) - else: - properties = find_file_properties(args.file, print_properties=False) - - if hasattr(args, "json") and args.json: - result = { - "file": str(file_path), - "properties": properties, - } - print(json.dumps(result, indent=2)) - else: - print(f"File properties for {args.file}:") - _print_properties(properties, indent=2) - - return 0 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error analyzing file properties: {e}\n") - return 1 - - -def _print_properties(properties: dict, indent: int = 0) -> None: - """Pretty-print file properties.""" - prefix = " " * indent - for key, value in properties.items(): - if key == "children" and isinstance(value, dict): - print(f"{prefix}{key}:") - for child_name, child_props in value.items(): - print(f"{prefix} {child_name}:") - if child_props is not None: - _print_properties(child_props, indent + 4) - else: - print(f"{prefix} (unable to read)") - else: - print(f"{prefix}{key}: {value}") - - -def _handle_polyglot_create(args) -> int: - """Handle 'fickling polyglot create FILE1 FILE2 -o OUTPUT'.""" - if not _check_torch_available(): - return 1 - - from .polyglot import create_polyglot - - file1_path = Path(args.file1) - file2_path = Path(args.file2) - - if not file1_path.exists(): - sys.stderr.write(f"Error: file not found: {args.file1}\n") - return 1 - if not file2_path.exists(): - sys.stderr.write(f"Error: file not found: {args.file2}\n") - return 1 - - output_path = getattr(args, "output", None) - quiet = getattr(args, "quiet", False) - - try: - success = create_polyglot( - args.file1, args.file2, polyglot_file_name=output_path, print_results=not quiet - ) - - if success: - return 0 - if not quiet: - sys.stderr.write("Failed to create polyglot. The file formats may not be compatible.\n") - return 1 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error creating polyglot: {e}\n") - return 1 diff --git a/fickling/cli_pytorch.py b/fickling/cli_pytorch.py deleted file mode 100644 index 494965ba..00000000 --- a/fickling/cli_pytorch.py +++ /dev/null @@ -1,200 +0,0 @@ -"""CLI handlers for PyTorch model operations.""" - -from __future__ import annotations - -import json -import sys -from ast import unparse -from pathlib import Path - -from .analysis import Severity, check_safety -from .cli import DEFAULT_JSON_OUTPUT_FILE -from .fickle import Interpreter -from .tracing import Trace - - -def _check_torch_available() -> bool: - """Check if PyTorch is available and provide helpful error message if not.""" - try: - import torch # noqa: F401 - - return True - except ImportError: - sys.stderr.write( - "Error: PyTorch is required for this command.\n" - "Please install it with: pip install fickling[torch]\n" - ) - return False - - -def handle_pytorch_command(args) -> int: - """Handle the pytorch subcommand and its sub-subcommands.""" - if args.pytorch_command is None: - sys.stderr.write("Error: pytorch subcommand required.\n") - sys.stderr.write("Available commands: identify, show, check-safety, inject\n") - sys.stderr.write("Use 'fickling pytorch --help' for more information.\n") - return 1 - - if args.pytorch_command == "identify": - return _handle_pytorch_identify(args) - if args.pytorch_command == "show": - return _handle_pytorch_show(args) - if args.pytorch_command == "check-safety": - return _handle_pytorch_check_safety(args) - if args.pytorch_command == "inject": - return _handle_pytorch_inject(args) - sys.stderr.write(f"Error: unknown pytorch command '{args.pytorch_command}'\n") - return 1 - - -def _handle_pytorch_identify(args) -> int: - """Handle 'fickling pytorch identify FILE'.""" - if not _check_torch_available(): - return 1 - - from .pytorch import PyTorchModelWrapper - - file_path = Path(args.file) - if not file_path.exists(): - sys.stderr.write(f"Error: file not found: {args.file}\n") - return 1 - - try: - wrapper = PyTorchModelWrapper(file_path, force=True) - formats = wrapper.validate_file_format() - - formats_list = list(formats) - if hasattr(args, "json") and args.json: - result = { - "file": str(file_path), - "formats": formats_list, - "primary_format": formats_list[0] if formats_list else None, - } - print(json.dumps(result, indent=2)) - else: - if formats_list: - print(f"Detected format(s) for {args.file}:") - for i, fmt in enumerate(formats_list): - prefix = " [primary]" if i == 0 else " [also] " - print(f"{prefix} {fmt}") - else: - print(f"No PyTorch formats detected for {args.file}") - - return 0 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error identifying file: {e}\n") - return 1 - - -def _handle_pytorch_show(args) -> int: - """Handle 'fickling pytorch show FILE'.""" - if not _check_torch_available(): - return 1 - - from .pytorch import PyTorchModelWrapper - - file_path = Path(args.file) - if not file_path.exists(): - sys.stderr.write(f"Error: file not found: {args.file}\n") - return 1 - - force = getattr(args, "force", False) - - try: - wrapper = PyTorchModelWrapper(file_path, force=force) - pickled = wrapper.pickled - - interpreter = Interpreter(pickled) - if getattr(args, "trace", False): - trace = Trace(interpreter) - print(unparse(trace.run())) - else: - print(unparse(interpreter.to_ast())) - - return 0 - except ValueError as e: - sys.stderr.write(f"Error: {e}\n") - sys.stderr.write("Use --force to attempt processing anyway.\n") - return 1 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error reading PyTorch model: {e}\n") - return 1 - - -def _handle_pytorch_check_safety(args) -> int: - """Handle 'fickling pytorch check-safety FILE'.""" - if not _check_torch_available(): - return 1 - - from .pytorch import PyTorchModelWrapper - - file_path = Path(args.file) - if not file_path.exists(): - sys.stderr.write(f"Error: file not found: {args.file}\n") - return 1 - - force = getattr(args, "force", False) - json_output_path = getattr(args, "json_output", None) or DEFAULT_JSON_OUTPUT_FILE - print_results = getattr(args, "print_results", False) - - try: - wrapper = PyTorchModelWrapper(file_path, force=force) - pickled = wrapper.pickled - - safety_results = check_safety(pickled, json_output_path=json_output_path) - - if print_results: - print(safety_results.to_string()) - - if safety_results.severity > Severity.LIKELY_SAFE: - if print_results: - sys.stderr.write( - "Warning: Fickling detected that the PyTorch model may be unsafe.\n\n" - "Do not load this model if it is from an untrusted source!\n\n" - ) - return 1 - return 0 - except ValueError as e: - sys.stderr.write(f"Error: {e}\n") - sys.stderr.write("Use --force to attempt processing anyway.\n") - return 1 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error checking PyTorch model safety: {e}\n") - return 1 - - -def _handle_pytorch_inject(args) -> int: - """Handle 'fickling pytorch inject FILE -o OUTPUT -c CODE'.""" - if not _check_torch_available(): - return 1 - - from .pytorch import PyTorchModelWrapper - - file_path = Path(args.file) - if not file_path.exists(): - sys.stderr.write(f"Error: file not found: {args.file}\n") - return 1 - - output_path = Path(args.output) - if output_path.exists() and not getattr(args, "overwrite", False): - sys.stderr.write(f"Error: output file already exists: {args.output}\n") - sys.stderr.write("Use --overwrite to replace the original file.\n") - return 1 - - force = getattr(args, "force", False) - method = getattr(args, "method", "insertion") - overwrite = getattr(args, "overwrite", False) - code = args.code - - try: - wrapper = PyTorchModelWrapper(file_path, force=force) - wrapper.inject_payload(code, output_path, injection=method, overwrite=overwrite) - print(f"Payload injected successfully. Output written to: {output_path}") - return 0 - except ValueError as e: - sys.stderr.write(f"Error: {e}\n") - sys.stderr.write("Use --force to attempt processing anyway.\n") - return 1 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error injecting payload: {e}\n") - return 1 diff --git a/fickling/loader.py b/fickling/loader.py index 5863f4e1..7fae3761 100644 --- a/fickling/loader.py +++ b/fickling/loader.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import pickle from io import BytesIO +from pathlib import Path from fickling.analysis import Severity, check_safety from fickling.exception import UnsafeFileError -from fickling.fickle import Pickled +from fickling.fickle import Pickled, StackedPickle def load( @@ -70,3 +73,63 @@ def loads( json_output_path=json_output_path, **kwargs, ) + + +def auto_load(path: Path | str) -> tuple[str, StackedPickle]: + """ + Auto-detect file format and load the pickle content. + + This function automatically detects whether the file is a PyTorch model (ZIP format), + a plain pickle, or other supported formats, and returns the appropriate Pickled data. + + Args: + path: Path to the file to load + + Returns: + A tuple of (format_name, pickled_data) where: + - format_name: A string describing the detected format (e.g., "PyTorch v1.3", "pickle") + - pickled_data: A StackedPickle containing one or more Pickled objects + + Raises: + ValueError: If the file format cannot be determined or is unsupported + """ + if isinstance(path, str): + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + # Try PyTorch ZIP formats first (most common for ML models) + try: + from fickling.polyglot import identify_pytorch_file_format + + formats = identify_pytorch_file_format(path, print_results=False) + + if formats: + primary_format = formats[0] + + # Handle PyTorch v1.3 and TorchScript v1.4 (ZIP with data.pkl) + if primary_format in ("PyTorch v1.3", "TorchScript v1.4", "TorchScript v1.3"): + from fickling.pytorch import PyTorchModelWrapper + + wrapper = PyTorchModelWrapper(path, force=True) + # Return as StackedPickle for consistency + return primary_format, StackedPickle([wrapper.pickled]) + + # Handle legacy formats as plain pickle + if primary_format == "PyTorch v0.1.10": + with open(path, "rb") as f: + stacked = StackedPickle.load(f, fail_on_decode_error=False) + return primary_format, stacked + + except ImportError: + # torch not installed, fall through to plain pickle handling + pass + + # Fall back to plain pickle + try: + with open(path, "rb") as f: + stacked = StackedPickle.load(f, fail_on_decode_error=False) + return "pickle", stacked + except Exception as e: + raise ValueError(f"Unable to load file as pickle: {e}") from e diff --git a/test/test_cli.py b/test/test_cli.py index 1d456030..7a1bdc01 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -3,6 +3,7 @@ from __future__ import annotations import io +import json import tempfile from contextlib import redirect_stderr, redirect_stdout from pathlib import Path @@ -50,9 +51,11 @@ def test_decompile_pickle(self): # Should contain decompiled code self.assertIn("result", output) - def test_check_safety_safe_file(self): - """Test --check-safety on a safe pickle file.""" - result = main(["fickling", "--check-safety", str(self.pickle_file)]) + def test_check_safety_legacy_flag(self): + """Test --check-safety on a safe pickle file (legacy syntax).""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "--check-safety", str(self.pickle_file)]) self.assertEqual(result, 0) def test_help_flag(self): @@ -62,33 +65,18 @@ def test_help_flag(self): self.assertEqual(cm.exception.code, 0) -class TestCLISubcommandRouting(TestCase): - """Test that subcommand routing works correctly.""" - - def test_pytorch_subcommand_no_args(self): - """Test 'fickling pytorch' with no arguments.""" - stderr = io.StringIO() - with redirect_stderr(stderr): - result = main(["fickling", "pytorch"]) - self.assertEqual(result, 1) - self.assertIn("pytorch subcommand required", stderr.getvalue()) - - def test_polyglot_subcommand_no_args(self): - """Test 'fickling polyglot' with no arguments.""" - stderr = io.StringIO() - with redirect_stderr(stderr): - result = main(["fickling", "polyglot"]) - self.assertEqual(result, 1) - self.assertIn("polyglot subcommand required", stderr.getvalue()) - - -class TestPyTorchCLI(TestCase): - """Test PyTorch CLI subcommands.""" +class TestCheckCommand(TestCase): + """Test the 'fickling check' command.""" def setUp(self): self.tmpdir = tempfile.TemporaryDirectory() self.tmppath = Path(self.tmpdir.name) + # Create a simple pickle file + self.pickle_file = self.tmppath / "test.pkl" + with open(self.pickle_file, "wb") as f: + f.write(dumps({"test": "data"})) + # Create a PyTorch model file model = models.mobilenet_v2(weights=None) self.model_file = self.tmppath / "model.pth" @@ -97,77 +85,119 @@ def setUp(self): def tearDown(self): self.tmpdir.cleanup() - def test_pytorch_identify(self): - """Test 'fickling pytorch identify'.""" + def test_check_pickle(self): + """Test 'fickling check' on a pickle file.""" stdout = io.StringIO() with redirect_stdout(stdout): - result = main(["fickling", "pytorch", "identify", str(self.model_file)]) + result = main(["fickling", "check", str(self.pickle_file)]) self.assertEqual(result, 0) output = stdout.getvalue() self.assertIn("Detected format", output) + self.assertIn("pickle", output) + + def test_check_pytorch_model(self): + """Test 'fickling check' on a PyTorch model (auto-detection).""" + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "check", str(self.model_file)]) + # Result can be 0 (safe) or 1 (potentially unsafe) - just verify it runs + self.assertIn(result, [0, 1]) + output = stdout.getvalue() + self.assertIn("Detected format", output) self.assertIn("PyTorch", output) - def test_pytorch_identify_json(self): - """Test 'fickling pytorch identify --json'.""" + def test_check_json_output(self): + """Test 'fickling check --json'.""" stdout = io.StringIO() with redirect_stdout(stdout): - result = main(["fickling", "pytorch", "identify", "--json", str(self.model_file)]) + result = main(["fickling", "check", "--json", str(self.pickle_file)]) self.assertEqual(result, 0) output = stdout.getvalue() # Should be valid JSON - import json - data = json.loads(output) - self.assertIn("formats", data) - self.assertIn("primary_format", data) + self.assertIn("format", data) + self.assertIn("safe", data) + self.assertIn("severity", data) + + def test_check_file_not_found(self): + """Test 'fickling check' with non-existent file.""" + stderr = io.StringIO() + with redirect_stderr(stderr): + result = main(["fickling", "check", "/nonexistent/file.pkl"]) + self.assertEqual(result, 1) + self.assertIn("file not found", stderr.getvalue()) + + +class TestInjectCommand(TestCase): + """Test the 'fickling inject' command.""" + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.tmppath = Path(self.tmpdir.name) + + # Create a simple pickle file + self.pickle_file = self.tmppath / "test.pkl" + with open(self.pickle_file, "wb") as f: + f.write(dumps({"test": "data"})) + + # Create a PyTorch model file + model = models.mobilenet_v2(weights=None) + self.model_file = self.tmppath / "model.pth" + torch.save(model, self.model_file) + + def tearDown(self): + self.tmpdir.cleanup() - def test_pytorch_show(self): - """Test 'fickling pytorch show'.""" + def test_inject_pickle(self): + """Test 'fickling inject' on a pickle file.""" + output_file = self.tmppath / "injected.pkl" stdout = io.StringIO() with redirect_stdout(stdout): - result = main(["fickling", "pytorch", "show", str(self.model_file)]) + result = main( + [ + "fickling", + "inject", + str(self.pickle_file), + "-c", + "print('test')", + "-o", + str(output_file), + ] + ) self.assertEqual(result, 0) - output = stdout.getvalue() - # Should contain decompiled code - self.assertIn("result", output) - - def test_pytorch_check_safety(self): - """Test 'fickling pytorch check-safety'.""" - result = main(["fickling", "pytorch", "check-safety", str(self.model_file)]) - # Result can be 0 (safe) or 1 (potentially unsafe) - just verify it runs - self.assertIn(result, [0, 1]) + self.assertTrue(output_file.exists()) - def test_pytorch_inject(self): - """Test 'fickling pytorch inject'.""" + def test_inject_pytorch_model(self): + """Test 'fickling inject' on a PyTorch model.""" output_file = self.tmppath / "injected.pth" - result = main( - [ - "fickling", - "pytorch", - "inject", - str(self.model_file), - "-o", - str(output_file), - "-c", - "print('test')", - ] - ) + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main( + [ + "fickling", + "inject", + str(self.model_file), + "-c", + "print('test')", + "-o", + str(output_file), + ] + ) self.assertEqual(result, 0) self.assertTrue(output_file.exists()) - def test_pytorch_inject_combination_method(self): - """Test 'fickling pytorch inject --method combination'.""" + def test_inject_pytorch_combination_method(self): + """Test 'fickling inject --method combination' on a PyTorch model.""" output_file = self.tmppath / "injected_combo.pth" result = main( [ "fickling", - "pytorch", "inject", str(self.model_file), - "-o", - str(output_file), "-c", "print('test')", + "-o", + str(output_file), "--method", "combination", ] @@ -175,22 +205,47 @@ def test_pytorch_inject_combination_method(self): self.assertEqual(result, 0) self.assertTrue(output_file.exists()) - def test_pytorch_file_not_found(self): - """Test PyTorch commands with non-existent file.""" + def test_inject_missing_output(self): + """Test inject without required --output flag.""" + with self.assertRaises(SystemExit): + main(["fickling", "inject", str(self.pickle_file), "-c", "code"]) + + def test_inject_missing_code(self): + """Test inject without required --code flag.""" + with self.assertRaises(SystemExit): + main(["fickling", "inject", str(self.pickle_file), "-o", "out.pkl"]) + + def test_inject_file_not_found(self): + """Test inject with non-existent file.""" stderr = io.StringIO() with redirect_stderr(stderr): - result = main(["fickling", "pytorch", "identify", "/nonexistent/file.pth"]) + result = main( + [ + "fickling", + "inject", + "/nonexistent/file.pkl", + "-c", + "code", + "-o", + "out.pkl", + ] + ) self.assertEqual(result, 1) self.assertIn("file not found", stderr.getvalue()) -class TestPolyglotCLI(TestCase): - """Test polyglot CLI subcommands.""" +class TestInfoCommand(TestCase): + """Test the 'fickling info' command.""" def setUp(self): self.tmpdir = tempfile.TemporaryDirectory() self.tmppath = Path(self.tmpdir.name) + # Create a simple pickle file + self.pickle_file = self.tmppath / "test.pkl" + with open(self.pickle_file, "wb") as f: + f.write(dumps({"test": "data"})) + # Create a PyTorch model file model = models.mobilenet_v2(weights=None) self.model_file = self.tmppath / "model.pth" @@ -199,56 +254,74 @@ def setUp(self): def tearDown(self): self.tmpdir.cleanup() - def test_polyglot_identify(self): - """Test 'fickling polyglot identify'.""" + def test_info_pickle(self): + """Test 'fickling info' on a pickle file.""" stdout = io.StringIO() with redirect_stdout(stdout): - result = main(["fickling", "polyglot", "identify", str(self.model_file)]) + result = main(["fickling", "info", str(self.pickle_file)]) self.assertEqual(result, 0) output = stdout.getvalue() - self.assertIn("Identified format", output) - - def test_polyglot_identify_json(self): - """Test 'fickling polyglot identify --json'.""" - stdout = io.StringIO() - with redirect_stdout(stdout): - result = main(["fickling", "polyglot", "identify", "--json", str(self.model_file)]) - self.assertEqual(result, 0) - output = stdout.getvalue() - # Should be valid JSON - import json + self.assertIn("Format:", output) - data = json.loads(output) - self.assertIn("formats", data) - self.assertIn("is_polyglot", data) - - def test_polyglot_properties(self): - """Test 'fickling polyglot properties'.""" + def test_info_pytorch_model(self): + """Test 'fickling info' on a PyTorch model.""" stdout = io.StringIO() with redirect_stdout(stdout): - result = main(["fickling", "polyglot", "properties", str(self.model_file)]) + result = main(["fickling", "info", str(self.model_file)]) self.assertEqual(result, 0) output = stdout.getvalue() - self.assertIn("File properties", output) + self.assertIn("Format:", output) + self.assertIn("PyTorch", output) - def test_polyglot_properties_json(self): - """Test 'fickling polyglot properties --json'.""" + def test_info_json_output(self): + """Test 'fickling info --json'.""" stdout = io.StringIO() with redirect_stdout(stdout): - result = main(["fickling", "polyglot", "properties", "--json", str(self.model_file)]) + result = main(["fickling", "info", "--json", str(self.model_file)]) self.assertEqual(result, 0) output = stdout.getvalue() # Should be valid JSON - import json - data = json.loads(output) + self.assertIn("formats", data) + self.assertIn("primary_format", data) self.assertIn("properties", data) - def test_polyglot_file_not_found(self): - """Test polyglot commands with non-existent file.""" + def test_info_file_not_found(self): + """Test info with non-existent file.""" + stderr = io.StringIO() + with redirect_stderr(stderr): + result = main(["fickling", "info", "/nonexistent/file.pth"]) + self.assertEqual(result, 1) + self.assertIn("file not found", stderr.getvalue()) + + +class TestCreatePolyglotCommand(TestCase): + """Test the 'fickling create-polyglot' command.""" + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.tmppath = Path(self.tmpdir.name) + + # Create a PyTorch model file + model = models.mobilenet_v2(weights=None) + self.model_file = self.tmppath / "model.pth" + torch.save(model, self.model_file) + + def tearDown(self): + self.tmpdir.cleanup() + + def test_create_polyglot_file_not_found(self): + """Test create-polyglot with non-existent file.""" stderr = io.StringIO() with redirect_stderr(stderr): - result = main(["fickling", "polyglot", "identify", "/nonexistent/file.pth"]) + result = main( + [ + "fickling", + "create-polyglot", + "/nonexistent/file1.pth", + str(self.model_file), + ] + ) self.assertEqual(result, 1) self.assertIn("file not found", stderr.getvalue()) @@ -261,13 +334,3 @@ def test_nonexistent_pickle_file(self): # The original CLI raises FileNotFoundError for non-existent files with self.assertRaises(FileNotFoundError): main(["fickling", "/nonexistent/file.pkl"]) - - def test_pytorch_inject_missing_output(self): - """Test pytorch inject without required --output flag.""" - with self.assertRaises(SystemExit): - main(["fickling", "pytorch", "inject", "file.pth", "-c", "code"]) - - def test_pytorch_inject_missing_code(self): - """Test pytorch inject without required --code flag.""" - with self.assertRaises(SystemExit): - main(["fickling", "pytorch", "inject", "file.pth", "-o", "out.pth"]) From df1ed6634d4ee36bb32d5bd227964d4b9ef399f9 Mon Sep 17 00:00:00 2001 From: Dan Guido Date: Fri, 20 Feb 2026 12:20:01 -0700 Subject: [PATCH 3/3] Fix review issues in CLI redesign - Restore ClamAV-compatible exit codes (EXIT_CLEAN=0, EXIT_UNSAFE=1, EXIT_ERROR=2) throughout all CLI handlers - Fix _get_first_positional() to skip flag values, preventing misrouting when flag arguments match command names (e.g. --inject "check") - Restore stdin support for legacy --check-safety (was rejecting '-') - Add mutually exclusive group to legacy parser for --inject/--check-safety/--create - Error on invalid --inject-target instead of silently resetting to 0 - Narrow broad except Exception catches to specific types (PickleDecodeError, FileNotFoundError, ValueError, OSError) - Restore PickleDecodeError security warning in check handler - Narrow auto_load() ImportError catch to only the import statement - Guard torch imports with pytest.importorskip() for test portability - Add auto_load() tests, _get_first_positional() tests, exit code tests - Strengthen test assertions (verify output content, not just return codes) Co-Authored-By: Claude Opus 4.6 --- fickling/cli.py | 288 ++++++++++++++++++---------- fickling/constants.py | 4 + fickling/loader.py | 11 +- test/test_cli.py | 437 +++++++++++++++++++++++++++++++----------- 4 files changed, 517 insertions(+), 223 deletions(-) create mode 100644 fickling/constants.py diff --git a/fickling/cli.py b/fickling/cli.py index 7b75e876..bb7c1dd3 100644 --- a/fickling/cli.py +++ b/fickling/cli.py @@ -10,6 +10,7 @@ from . import __version__, fickle, tracing from .analysis import Severity, check_safety +from .constants import EXIT_CLEAN, EXIT_ERROR, EXIT_UNSAFE DEFAULT_JSON_OUTPUT_FILE = "safety_results.json" @@ -45,13 +46,14 @@ def _create_legacy_parser() -> ArgumentParser: default="-", help="file to analyze (default: stdin)", ) - parser.add_argument( + options = parser.add_mutually_exclusive_group() + options.add_argument( "--check-safety", "-s", action="store_true", help="(legacy) run safety analysis - prefer 'fickling check FILE'", ) - parser.add_argument( + options.add_argument( "--inject", "-i", type=str, @@ -64,7 +66,7 @@ def _create_legacy_parser() -> ArgumentParser: default=0, help="index of stacked pickle to inject into (default: 0)", ) - parser.add_argument( + options.add_argument( "--create", "-c", type=str, @@ -194,11 +196,38 @@ def _create_command_parser() -> ArgumentParser: return parser +# Flags that consume the next argument as a value +_FLAGS_WITH_VALUES = { + "--inject", + "-i", + "--inject-target", + "--create", + "-c", + "--json-output", + "--code", + "--output", + "-o", + "--method", +} + + def _get_first_positional(argv: list[str]) -> str | None: - """Get the first non-flag argument (potential command or file).""" + """Get the first non-flag argument (potential command or file). + + Skips values consumed by flags like --inject to avoid + misrouting when a flag value matches a command name. + """ + skip_next = False for arg in argv[1:]: - if not arg.startswith("-"): - return arg + if skip_next: + skip_next = False + continue + if arg in _FLAGS_WITH_VALUES: + skip_next = True + continue + if arg.startswith("-"): + continue + return arg return None @@ -214,7 +243,7 @@ def main(argv: list[str] | None = None) -> int: print(f"fickling version {__version__}") else: print(__version__) - return 0 + return EXIT_CLEAN # Determine if we're using a new command or legacy CLI first_positional = _get_first_positional(argv) @@ -232,7 +261,7 @@ def main(argv: list[str] | None = None) -> int: return _handle_info(args) if args.command == "create-polyglot": return _handle_create_polyglot(args) - return 1 + return EXIT_ERROR # Use legacy parser for backward compatibility parser = _create_legacy_parser() @@ -245,7 +274,7 @@ def _handle_check(args) -> int: file_path = Path(args.file) if not file_path.exists(): sys.stderr.write(f"Error: file not found: {args.file}\n") - return 1 + return EXIT_ERROR json_output_path = args.json_output or DEFAULT_JSON_OUTPUT_FILE print_results = getattr(args, "print_results", False) @@ -292,17 +321,24 @@ def _handle_check(args) -> int: "Do not unpickle this file if it is from an untrusted source!\n" ) - return 0 if was_safe else 1 + return EXIT_CLEAN if was_safe else EXIT_UNSAFE + except fickle.PickleDecodeError as e: + sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") + sys.stderr.write( + "Parsing errors might be indicative of a maliciously crafted pickle file. " + "DO NOT TRUST this file without performing further analysis!\n" + ) + return EXIT_ERROR except FileNotFoundError as e: sys.stderr.write(f"Error: {e}\n") - return 1 + return EXIT_ERROR except ValueError as e: sys.stderr.write(f"Error loading file: {e}\n") - return 1 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error: {e}\n") - return 1 + return EXIT_ERROR + except OSError as e: + sys.stderr.write(f"Error reading file: {e}\n") + return EXIT_ERROR def _handle_inject(args) -> int: @@ -310,13 +346,13 @@ def _handle_inject(args) -> int: file_path = Path(args.file) if not file_path.exists(): sys.stderr.write(f"Error: file not found: {args.file}\n") - return 1 + return EXIT_ERROR output_path = Path(args.output) if output_path.exists() and not getattr(args, "overwrite", False): sys.stderr.write(f"Error: output file already exists: {args.output}\n") sys.stderr.write("Use --overwrite to replace it.\n") - return 1 + return EXIT_ERROR try: from .loader import auto_load @@ -327,7 +363,7 @@ def _handle_inject(args) -> int: # For PyTorch ZIP formats, use PyTorchModelWrapper for proper injection if format_name in ("PyTorch v1.3", "TorchScript v1.4", "TorchScript v1.3"): if not _check_torch_available(): - return 1 + return EXIT_ERROR from .pytorch import PyTorchModelWrapper @@ -337,9 +373,17 @@ def _handle_inject(args) -> int: wrapper = PyTorchModelWrapper(file_path, force=True) wrapper.inject_payload(args.code, output_path, injection=method, overwrite=overwrite) print(f"Payload injected successfully. Output: {output_path}") - return 0 + return EXIT_CLEAN # For plain pickle, use direct injection + inject_target = getattr(args, "inject_target", 0) + if inject_target >= len(stacked_pickled): + sys.stderr.write( + f"Error: --inject-target {inject_target} is too high; there are only " + f"{len(stacked_pickled)} stacked pickle files in the input\n" + ) + return EXIT_ERROR + if args.output == "-": buffer = ( sys.stdout.buffer @@ -352,10 +396,6 @@ def _handle_inject(args) -> int: should_close = True try: - inject_target = getattr(args, "inject_target", 0) - if inject_target >= len(stacked_pickled): - inject_target = 0 - for pickled in stacked_pickled[:inject_target]: pickled.dump(buffer) @@ -371,20 +411,23 @@ def _handle_inject(args) -> int: pickled.dump(buffer) print(f"Payload injected successfully. Output: {output_path}") - return 0 + return EXIT_CLEAN finally: if should_close: buffer.close() + except fickle.PickleDecodeError as e: + sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") + return EXIT_ERROR except FileNotFoundError as e: sys.stderr.write(f"Error: {e}\n") - return 1 + return EXIT_ERROR except ValueError as e: sys.stderr.write(f"Error: {e}\n") - return 1 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error injecting payload: {e}\n") - return 1 + return EXIT_ERROR + except OSError as e: + sys.stderr.write(f"Error: {e}\n") + return EXIT_ERROR def _handle_info(args) -> int: @@ -392,7 +435,7 @@ def _handle_info(args) -> int: file_path = Path(args.file) if not file_path.exists(): sys.stderr.write(f"Error: file not found: {args.file}\n") - return 1 + return EXIT_ERROR # Try to use polyglot module for detailed analysis (requires torch) try: @@ -401,7 +444,10 @@ def _handle_info(args) -> int: find_file_properties_recursively, identify_pytorch_file_format, ) + except ImportError: + return _handle_info_basic(args, file_path) + try: formats = identify_pytorch_file_format(args.file, print_results=False) recursive = getattr(args, "recursive", False) @@ -431,37 +477,38 @@ def _handle_info(args) -> int: print("\nProperties:") _print_properties(properties, indent=2) - return 0 + return EXIT_CLEAN - except ImportError: - # torch not installed - provide basic info - try: - with open(file_path, "rb") as f: - stacked = fickle.StackedPickle.load(f, fail_on_decode_error=False) - - if getattr(args, "json", False): - result = { - "file": str(file_path), - "formats": ["pickle"], - "primary_format": "pickle", - "is_polyglot": False, - "pickle_count": len(stacked), - } - print(json.dumps(result, indent=2)) - else: - print("Format: pickle") - print(f"Stacked pickles: {len(stacked)}") - print("\nNote: Install PyTorch for detailed format detection:") - print(" pip install fickling[torch]") + except (fickle.PickleDecodeError, ValueError, OSError) as e: + sys.stderr.write(f"Error: {e}\n") + return EXIT_ERROR - return 0 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error reading file: {e}\n") - return 1 - except Exception as e: # noqa: BLE001 - sys.stderr.write(f"Error: {e}\n") - return 1 +def _handle_info_basic(args, file_path: Path) -> int: + """Handle 'fickling info' without PyTorch (basic pickle info only).""" + try: + with open(file_path, "rb") as f: + stacked = fickle.StackedPickle.load(f, fail_on_decode_error=False) + + if getattr(args, "json", False): + result = { + "file": str(file_path), + "formats": ["pickle"], + "primary_format": "pickle", + "is_polyglot": False, + "pickle_count": len(stacked), + } + print(json.dumps(result, indent=2)) + else: + print("Format: pickle") + print(f"Stacked pickles: {len(stacked)}") + print("\nNote: Install PyTorch for detailed format detection:") + print(" pip install fickling[torch]") + + return EXIT_CLEAN + except (fickle.PickleDecodeError, ValueError, OSError) as e: + sys.stderr.write(f"Error reading file: {e}\n") + return EXIT_ERROR def _print_properties(properties: dict, indent: int = 0) -> None: @@ -483,7 +530,7 @@ def _print_properties(properties: dict, indent: int = 0) -> None: def _handle_create_polyglot(args) -> int: """Handle 'fickling create-polyglot FILE1 FILE2 -o OUT'.""" if not _check_torch_available(): - return 1 + return EXIT_ERROR from .polyglot import create_polyglot @@ -492,10 +539,10 @@ def _handle_create_polyglot(args) -> int: if not file1_path.exists(): sys.stderr.write(f"Error: file not found: {args.file1}\n") - return 1 + return EXIT_ERROR if not file2_path.exists(): sys.stderr.write(f"Error: file not found: {args.file2}\n") - return 1 + return EXIT_ERROR output_path = getattr(args, "output", None) quiet = getattr(args, "quiet", False) @@ -506,58 +553,88 @@ def _handle_create_polyglot(args) -> int: ) if success: - return 0 + return EXIT_CLEAN if not quiet: sys.stderr.write("Failed to create polyglot. The file formats may not be compatible.\n") - return 1 - except Exception as e: # noqa: BLE001 + return EXIT_ERROR + except (ValueError, OSError) as e: sys.stderr.write(f"Error creating polyglot: {e}\n") - return 1 + return EXIT_ERROR + + +def _open_input(file_arg: str): + """Open input file or return stdin buffer.""" + if file_arg == "-": + if hasattr(sys.stdin, "buffer") and sys.stdin.buffer is not None: + return sys.stdin.buffer + return sys.stdin + return open(file_arg, "rb") def _handle_legacy(args) -> int: """Handle legacy CLI behavior (backward compatibility).""" - # Handle --check-safety flag + # Handle --check-safety flag (supports stdin via '-' default) if args.check_safety: - if args.file and args.file != "-": - # Create a fake args object for _handle_check - class CheckArgs: - pass - - check_args = CheckArgs() - check_args.file = args.file - check_args.json = False - check_args.json_output = args.json_output - check_args.print_results = args.print_results - return _handle_check(check_args) - sys.stderr.write("Error: file path required with --check-safety\n") - return 1 + file = _open_input(args.file) + try: + stacked_pickled = fickle.StackedPickle.load(file, fail_on_decode_error=False) + except fickle.PickleDecodeError as e: + sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") + sys.stderr.write( + "Parsing errors might be indicative of a maliciously crafted " + "pickle file. DO NOT TRUST this file without performing " + "further analysis!\n" + ) + return EXIT_ERROR + finally: + if file not in (sys.stdin, getattr(sys.stdin, "buffer", None)): + file.close() + + was_safe = True + json_output_path = args.json_output or DEFAULT_JSON_OUTPUT_FILE + for pickled in stacked_pickled: + safety_results = check_safety(pickled, json_output_path=json_output_path) + + if args.print_results: + print(safety_results.to_string()) + + if safety_results.severity > Severity.LIKELY_SAFE: + was_safe = False + if args.print_results: + sys.stderr.write( + "Warning: Fickling detected that the pickle file " + "may be unsafe.\n\n" + "Do not unpickle this file if it is from an " + "untrusted source!\n\n" + ) + + return EXIT_CLEAN if was_safe else EXIT_UNSAFE # Handle --inject flag if args.inject: - # For legacy inject, output goes to stdout - if args.file == "-": - file = sys.stdin.buffer if hasattr(sys.stdin, "buffer") else sys.stdin - else: - file = open(args.file, "rb") + file = _open_input(args.file) try: stacked_pickled = fickle.StackedPickle.load(file, fail_on_decode_error=False) except fickle.PickleDecodeError as e: sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") - return 1 + return EXIT_ERROR finally: - if file not in (sys.stdin, sys.stdin.buffer): + if file not in (sys.stdin, getattr(sys.stdin, "buffer", None)): file.close() if args.inject_target >= len(stacked_pickled): sys.stderr.write( - f"Error: --inject-target {args.inject_target} is too high; there are only " - f"{len(stacked_pickled)} stacked pickle files in the input\n" + f"Error: --inject-target {args.inject_target} is too high; " + f"there are only {len(stacked_pickled)} stacked pickle " + f"files in the input\n" ) - return 1 + return EXIT_ERROR - buffer = sys.stdout.buffer if hasattr(sys.stdout, "buffer") else sys.stdout + if hasattr(sys.stdout, "buffer") and sys.stdout.buffer is not None: + buffer = sys.stdout.buffer + else: + buffer = sys.stdout for pickled in stacked_pickled[: args.inject_target]: pickled.dump(buffer) @@ -565,8 +642,9 @@ class CheckArgs: pickled = stacked_pickled[args.inject_target] if not isinstance(pickled[-1], fickle.Stop): sys.stderr.write( - "Warning: The last opcode of the input file was expected to be STOP, but was " - f"in fact {pickled[-1].info.name}" + "Warning: The last opcode of the input file was expected " + "to be STOP, but was in fact " + f"{pickled[-1].info.name}" ) pickled.insert_python_eval( @@ -579,7 +657,7 @@ class CheckArgs: for pickled in stacked_pickled[args.inject_target + 1 :]: pickled.dump(buffer) - return 0 + return EXIT_CLEAN # Handle --create flag if args.create: @@ -594,37 +672,39 @@ class CheckArgs: ] ) if args.file == "-": - file = sys.stdout.buffer if hasattr(sys.stdout, "buffer") else sys.stdout + if hasattr(sys.stdout, "buffer") and sys.stdout.buffer is not None: + file = sys.stdout.buffer + else: + file = sys.stdout else: file = open(args.file, "wb") try: pickled.dump(file) finally: - if file not in (sys.stdout, sys.stdout.buffer): + if file not in (sys.stdout, getattr(sys.stdout, "buffer", None)): file.close() - return 0 + return EXIT_CLEAN # Default: decompile the file - if args.file == "-": - file = sys.stdin.buffer if hasattr(sys.stdin, "buffer") else sys.stdin - else: - file = open(args.file, "rb") + file = _open_input(args.file) try: stacked_pickled = fickle.StackedPickle.load(file, fail_on_decode_error=False) except fickle.PickleDecodeError as e: sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") - return 1 + return EXIT_ERROR finally: - if file not in (sys.stdin, sys.stdin.buffer): + if file not in (sys.stdin, getattr(sys.stdin, "buffer", None)): file.close() var_id = 0 for i, pickled in enumerate(stacked_pickled): interpreter = fickle.Interpreter( - pickled, first_variable_id=var_id, result_variable=f"result{i}" + pickled, + first_variable_id=var_id, + result_variable=f"result{i}", ) if args.trace: trace = tracing.Trace(interpreter) @@ -633,4 +713,4 @@ class CheckArgs: print(unparse(interpreter.to_ast())) var_id = interpreter.next_variable_id - return 0 + return EXIT_CLEAN diff --git a/fickling/constants.py b/fickling/constants.py new file mode 100644 index 00000000..f0bdc1b4 --- /dev/null +++ b/fickling/constants.py @@ -0,0 +1,4 @@ +# ClamAV-compatible exit codes for CI/CD integration +EXIT_CLEAN = 0 # No issues found +EXIT_UNSAFE = 1 # Potentially malicious content detected +EXIT_ERROR = 2 # Scan error (parse failure, file not found, etc.) diff --git a/fickling/loader.py b/fickling/loader.py index 7fae3761..24c517d4 100644 --- a/fickling/loader.py +++ b/fickling/loader.py @@ -102,7 +102,11 @@ def auto_load(path: Path | str) -> tuple[str, StackedPickle]: # Try PyTorch ZIP formats first (most common for ML models) try: from fickling.polyglot import identify_pytorch_file_format + except ImportError: + # torch not installed, fall through to plain pickle handling + identify_pytorch_file_format = None + if identify_pytorch_file_format is not None: formats = identify_pytorch_file_format(path, print_results=False) if formats: @@ -113,7 +117,6 @@ def auto_load(path: Path | str) -> tuple[str, StackedPickle]: from fickling.pytorch import PyTorchModelWrapper wrapper = PyTorchModelWrapper(path, force=True) - # Return as StackedPickle for consistency return primary_format, StackedPickle([wrapper.pickled]) # Handle legacy formats as plain pickle @@ -122,14 +125,10 @@ def auto_load(path: Path | str) -> tuple[str, StackedPickle]: stacked = StackedPickle.load(f, fail_on_decode_error=False) return primary_format, stacked - except ImportError: - # torch not installed, fall through to plain pickle handling - pass - # Fall back to plain pickle try: with open(path, "rb") as f: stacked = StackedPickle.load(f, fail_on_decode_error=False) return "pickle", stacked - except Exception as e: + except (OSError, ValueError) as e: raise ValueError(f"Unable to load file as pickle: {e}") from e diff --git a/test/test_cli.py b/test/test_cli.py index 7a1bdc01..631ebf9f 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -10,10 +10,10 @@ from pickle import dumps from unittest import TestCase -import torch -import torchvision.models as models +import pytest -from fickling.cli import main +from fickling.cli import _get_first_positional, main +from fickling.constants import EXIT_CLEAN, EXIT_ERROR, EXIT_UNSAFE class TestCLIBackwardCompatibility(TestCase): @@ -23,7 +23,6 @@ def setUp(self): self.tmpdir = tempfile.TemporaryDirectory() self.tmppath = Path(self.tmpdir.name) - # Create a simple pickle file self.pickle_file = self.tmppath / "test.pkl" with open(self.pickle_file, "wb") as f: f.write(dumps({"test": "data"})) @@ -32,39 +31,108 @@ def tearDown(self): self.tmpdir.cleanup() def test_version_flag(self): - """Test --version flag.""" stdout = io.StringIO() with redirect_stdout(stdout): result = main(["fickling", "--version"]) - self.assertEqual(result, 0) - output = stdout.getvalue() - # Should contain version number - self.assertTrue(output.strip()) + self.assertEqual(result, EXIT_CLEAN) + # isatty() is False for StringIO, so output is just the version number + output = stdout.getvalue().strip() + self.assertTrue(output) + self.assertRegex(output, r"^\d+\.\d+\.\d+") + + def test_version_flag_short(self): + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "-v"]) + self.assertEqual(result, EXIT_CLEAN) + output = stdout.getvalue().strip() + self.assertTrue(output) def test_decompile_pickle(self): - """Test basic pickle decompilation.""" stdout = io.StringIO() with redirect_stdout(stdout): result = main(["fickling", str(self.pickle_file)]) - self.assertEqual(result, 0) + self.assertEqual(result, EXIT_CLEAN) output = stdout.getvalue() - # Should contain decompiled code self.assertIn("result", output) + self.assertIn("test", output) def test_check_safety_legacy_flag(self): - """Test --check-safety on a safe pickle file (legacy syntax).""" stdout = io.StringIO() with redirect_stdout(stdout): result = main(["fickling", "--check-safety", str(self.pickle_file)]) - self.assertEqual(result, 0) + self.assertEqual(result, EXIT_CLEAN) + + def test_check_safety_legacy_stdin(self): + """Legacy --check-safety supports stdin via default '-' file arg.""" + # Verify the parser accepts --check-safety without a file argument + # (file defaults to "-" for stdin) + from fickling.cli import _create_legacy_parser + + parser = _create_legacy_parser() + args = parser.parse_args(["--check-safety"]) + self.assertTrue(args.check_safety) + self.assertEqual(args.file, "-") + + def test_legacy_mutually_exclusive_flags(self): + """--inject, --check-safety, --create are mutually exclusive.""" + with self.assertRaises(SystemExit): + main( + [ + "fickling", + "--check-safety", + "--inject", + "code", + str(self.pickle_file), + ] + ) def test_help_flag(self): - """Test --help flag.""" with self.assertRaises(SystemExit) as cm: main(["fickling", "--help"]) self.assertEqual(cm.exception.code, 0) +class TestGetFirstPositional(TestCase): + """Test the _get_first_positional routing function.""" + + def test_simple_command(self): + self.assertEqual( + _get_first_positional(["fickling", "check", "file.pkl"]), + "check", + ) + + def test_flag_value_not_misrouted(self): + """Flag values matching command names must not be treated as commands.""" + self.assertEqual( + _get_first_positional(["fickling", "--inject", "check", "file.pkl"]), + "file.pkl", + ) + + def test_short_flag_value_not_misrouted(self): + self.assertEqual( + _get_first_positional(["fickling", "-i", "check", "file.pkl"]), + "file.pkl", + ) + + def test_no_positional(self): + self.assertIsNone( + _get_first_positional(["fickling", "--version"]), + ) + + def test_file_path_as_first_positional(self): + self.assertEqual( + _get_first_positional(["fickling", "file.pkl"]), + "file.pkl", + ) + + def test_create_flag_value_skipped(self): + self.assertEqual( + _get_first_positional(["fickling", "--create", "expr", "out.pkl"]), + "out.pkl", + ) + + class TestCheckCommand(TestCase): """Test the 'fickling check' command.""" @@ -72,61 +140,79 @@ def setUp(self): self.tmpdir = tempfile.TemporaryDirectory() self.tmppath = Path(self.tmpdir.name) - # Create a simple pickle file self.pickle_file = self.tmppath / "test.pkl" with open(self.pickle_file, "wb") as f: f.write(dumps({"test": "data"})) - # Create a PyTorch model file - model = models.mobilenet_v2(weights=None) - self.model_file = self.tmppath / "model.pth" - torch.save(model, self.model_file) - def tearDown(self): self.tmpdir.cleanup() def test_check_pickle(self): - """Test 'fickling check' on a pickle file.""" stdout = io.StringIO() with redirect_stdout(stdout): result = main(["fickling", "check", str(self.pickle_file)]) - self.assertEqual(result, 0) + self.assertEqual(result, EXIT_CLEAN) output = stdout.getvalue() self.assertIn("Detected format", output) self.assertIn("pickle", output) - - def test_check_pytorch_model(self): - """Test 'fickling check' on a PyTorch model (auto-detection).""" - stdout = io.StringIO() - with redirect_stdout(stdout): - result = main(["fickling", "check", str(self.model_file)]) - # Result can be 0 (safe) or 1 (potentially unsafe) - just verify it runs - self.assertIn(result, [0, 1]) - output = stdout.getvalue() - self.assertIn("Detected format", output) - self.assertIn("PyTorch", output) + self.assertIn("No unsafe operations detected", output) def test_check_json_output(self): - """Test 'fickling check --json'.""" stdout = io.StringIO() with redirect_stdout(stdout): result = main(["fickling", "check", "--json", str(self.pickle_file)]) - self.assertEqual(result, 0) - output = stdout.getvalue() - # Should be valid JSON - data = json.loads(output) - self.assertIn("format", data) - self.assertIn("safe", data) + self.assertEqual(result, EXIT_CLEAN) + data = json.loads(stdout.getvalue()) + self.assertEqual(data["format"], "pickle") + self.assertTrue(data["safe"]) self.assertIn("severity", data) + self.assertIn("results", data) def test_check_file_not_found(self): - """Test 'fickling check' with non-existent file.""" stderr = io.StringIO() with redirect_stderr(stderr): result = main(["fickling", "check", "/nonexistent/file.pkl"]) - self.assertEqual(result, 1) + self.assertEqual(result, EXIT_ERROR) self.assertIn("file not found", stderr.getvalue()) + def test_check_print_results(self): + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "check", "--print-results", str(self.pickle_file)]) + self.assertEqual(result, EXIT_CLEAN) + + +class TestCheckCommandPyTorch(TestCase): + """Test 'fickling check' on PyTorch models (requires torch).""" + + @classmethod + def setUpClass(cls): + pytest.importorskip("torch") + pytest.importorskip("torchvision") + + def setUp(self): + import torch + import torchvision.models as models + + self.tmpdir = tempfile.TemporaryDirectory() + self.tmppath = Path(self.tmpdir.name) + + model = models.mobilenet_v2(weights=None) + self.model_file = self.tmppath / "model.pth" + torch.save(model, self.model_file) + + def tearDown(self): + self.tmpdir.cleanup() + + def test_check_pytorch_model(self): + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "check", str(self.model_file)]) + self.assertIn(result, [EXIT_CLEAN, EXIT_UNSAFE]) + output = stdout.getvalue() + self.assertIn("Detected format", output) + self.assertIn("PyTorch", output) + class TestInjectCommand(TestCase): """Test the 'fickling inject' command.""" @@ -135,21 +221,14 @@ def setUp(self): self.tmpdir = tempfile.TemporaryDirectory() self.tmppath = Path(self.tmpdir.name) - # Create a simple pickle file self.pickle_file = self.tmppath / "test.pkl" with open(self.pickle_file, "wb") as f: f.write(dumps({"test": "data"})) - # Create a PyTorch model file - model = models.mobilenet_v2(weights=None) - self.model_file = self.tmppath / "model.pth" - torch.save(model, self.model_file) - def tearDown(self): self.tmpdir.cleanup() def test_inject_pickle(self): - """Test 'fickling inject' on a pickle file.""" output_file = self.tmppath / "injected.pkl" stdout = io.StringIO() with redirect_stdout(stdout): @@ -164,11 +243,78 @@ def test_inject_pickle(self): str(output_file), ] ) - self.assertEqual(result, 0) + self.assertEqual(result, EXIT_CLEAN) self.assertTrue(output_file.exists()) + self.assertGreater(output_file.stat().st_size, 0) + + def test_inject_missing_output(self): + with self.assertRaises(SystemExit): + main(["fickling", "inject", str(self.pickle_file), "-c", "code"]) + + def test_inject_missing_code(self): + with self.assertRaises(SystemExit): + main(["fickling", "inject", str(self.pickle_file), "-o", "out.pkl"]) + + def test_inject_file_not_found(self): + stderr = io.StringIO() + with redirect_stderr(stderr): + result = main( + [ + "fickling", + "inject", + "/nonexistent/file.pkl", + "-c", + "code", + "-o", + "out.pkl", + ] + ) + self.assertEqual(result, EXIT_ERROR) + self.assertIn("file not found", stderr.getvalue()) + + def test_inject_output_exists_no_overwrite(self): + output_file = self.tmppath / "existing.pkl" + output_file.write_bytes(b"existing") + stderr = io.StringIO() + with redirect_stderr(stderr): + result = main( + [ + "fickling", + "inject", + str(self.pickle_file), + "-c", + "print('test')", + "-o", + str(output_file), + ] + ) + self.assertEqual(result, EXIT_ERROR) + self.assertIn("already exists", stderr.getvalue()) + + +class TestInjectCommandPyTorch(TestCase): + """Test 'fickling inject' on PyTorch models (requires torch).""" + + @classmethod + def setUpClass(cls): + pytest.importorskip("torch") + pytest.importorskip("torchvision") + + def setUp(self): + import torch + import torchvision.models as models + + self.tmpdir = tempfile.TemporaryDirectory() + self.tmppath = Path(self.tmpdir.name) + + model = models.mobilenet_v2(weights=None) + self.model_file = self.tmppath / "model.pth" + torch.save(model, self.model_file) + + def tearDown(self): + self.tmpdir.cleanup() def test_inject_pytorch_model(self): - """Test 'fickling inject' on a PyTorch model.""" output_file = self.tmppath / "injected.pth" stdout = io.StringIO() with redirect_stdout(stdout): @@ -183,11 +329,10 @@ def test_inject_pytorch_model(self): str(output_file), ] ) - self.assertEqual(result, 0) + self.assertEqual(result, EXIT_CLEAN) self.assertTrue(output_file.exists()) def test_inject_pytorch_combination_method(self): - """Test 'fickling inject --method combination' on a PyTorch model.""" output_file = self.tmppath / "injected_combo.pth" result = main( [ @@ -202,107 +347,97 @@ def test_inject_pytorch_combination_method(self): "combination", ] ) - self.assertEqual(result, 0) + self.assertEqual(result, EXIT_CLEAN) self.assertTrue(output_file.exists()) - def test_inject_missing_output(self): - """Test inject without required --output flag.""" - with self.assertRaises(SystemExit): - main(["fickling", "inject", str(self.pickle_file), "-c", "code"]) - - def test_inject_missing_code(self): - """Test inject without required --code flag.""" - with self.assertRaises(SystemExit): - main(["fickling", "inject", str(self.pickle_file), "-o", "out.pkl"]) - - def test_inject_file_not_found(self): - """Test inject with non-existent file.""" - stderr = io.StringIO() - with redirect_stderr(stderr): - result = main( - [ - "fickling", - "inject", - "/nonexistent/file.pkl", - "-c", - "code", - "-o", - "out.pkl", - ] - ) - self.assertEqual(result, 1) - self.assertIn("file not found", stderr.getvalue()) - class TestInfoCommand(TestCase): - """Test the 'fickling info' command.""" + """Test 'fickling info' on plain pickle (no torch required).""" def setUp(self): self.tmpdir = tempfile.TemporaryDirectory() self.tmppath = Path(self.tmpdir.name) - # Create a simple pickle file self.pickle_file = self.tmppath / "test.pkl" with open(self.pickle_file, "wb") as f: f.write(dumps({"test": "data"})) - # Create a PyTorch model file - model = models.mobilenet_v2(weights=None) - self.model_file = self.tmppath / "model.pth" - torch.save(model, self.model_file) - def tearDown(self): self.tmpdir.cleanup() def test_info_pickle(self): - """Test 'fickling info' on a pickle file.""" stdout = io.StringIO() with redirect_stdout(stdout): result = main(["fickling", "info", str(self.pickle_file)]) - self.assertEqual(result, 0) + self.assertEqual(result, EXIT_CLEAN) output = stdout.getvalue() self.assertIn("Format:", output) + def test_info_file_not_found(self): + stderr = io.StringIO() + with redirect_stderr(stderr): + result = main(["fickling", "info", "/nonexistent/file.pth"]) + self.assertEqual(result, EXIT_ERROR) + self.assertIn("file not found", stderr.getvalue()) + + +class TestInfoCommandPyTorch(TestCase): + """Test 'fickling info' on PyTorch models (requires torch).""" + + @classmethod + def setUpClass(cls): + pytest.importorskip("torch") + pytest.importorskip("torchvision") + + def setUp(self): + import torch + import torchvision.models as models + + self.tmpdir = tempfile.TemporaryDirectory() + self.tmppath = Path(self.tmpdir.name) + + model = models.mobilenet_v2(weights=None) + self.model_file = self.tmppath / "model.pth" + torch.save(model, self.model_file) + + def tearDown(self): + self.tmpdir.cleanup() + def test_info_pytorch_model(self): - """Test 'fickling info' on a PyTorch model.""" stdout = io.StringIO() with redirect_stdout(stdout): result = main(["fickling", "info", str(self.model_file)]) - self.assertEqual(result, 0) + self.assertEqual(result, EXIT_CLEAN) output = stdout.getvalue() self.assertIn("Format:", output) self.assertIn("PyTorch", output) def test_info_json_output(self): - """Test 'fickling info --json'.""" stdout = io.StringIO() with redirect_stdout(stdout): result = main(["fickling", "info", "--json", str(self.model_file)]) - self.assertEqual(result, 0) - output = stdout.getvalue() - # Should be valid JSON - data = json.loads(output) + self.assertEqual(result, EXIT_CLEAN) + data = json.loads(stdout.getvalue()) self.assertIn("formats", data) self.assertIn("primary_format", data) self.assertIn("properties", data) - def test_info_file_not_found(self): - """Test info with non-existent file.""" - stderr = io.StringIO() - with redirect_stderr(stderr): - result = main(["fickling", "info", "/nonexistent/file.pth"]) - self.assertEqual(result, 1) - self.assertIn("file not found", stderr.getvalue()) - class TestCreatePolyglotCommand(TestCase): - """Test the 'fickling create-polyglot' command.""" + """Test the 'fickling create-polyglot' command (requires torch).""" + + @classmethod + def setUpClass(cls): + pytest.importorskip("torch") + pytest.importorskip("torchvision") def setUp(self): + import torch + import torchvision.models as models + self.tmpdir = tempfile.TemporaryDirectory() self.tmppath = Path(self.tmpdir.name) - # Create a PyTorch model file model = models.mobilenet_v2(weights=None) self.model_file = self.tmppath / "model.pth" torch.save(model, self.model_file) @@ -311,7 +446,6 @@ def tearDown(self): self.tmpdir.cleanup() def test_create_polyglot_file_not_found(self): - """Test create-polyglot with non-existent file.""" stderr = io.StringIO() with redirect_stderr(stderr): result = main( @@ -322,15 +456,92 @@ def test_create_polyglot_file_not_found(self): str(self.model_file), ] ) - self.assertEqual(result, 1) + self.assertEqual(result, EXIT_ERROR) self.assertIn("file not found", stderr.getvalue()) +class TestAutoLoad(TestCase): + """Test the auto_load() format detection function.""" + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.tmppath = Path(self.tmpdir.name) + + def tearDown(self): + self.tmpdir.cleanup() + + def test_auto_load_pickle(self): + from fickling.loader import auto_load + + pickle_file = self.tmppath / "test.pkl" + with open(pickle_file, "wb") as f: + f.write(dumps({"key": "value"})) + + format_name, stacked = auto_load(pickle_file) + self.assertEqual(format_name, "pickle") + self.assertGreater(len(stacked), 0) + + def test_auto_load_file_not_found(self): + from fickling.loader import auto_load + + with self.assertRaises(FileNotFoundError): + auto_load(Path("/nonexistent/file.pkl")) + + def test_auto_load_invalid_file(self): + from fickling.loader import auto_load + + bad_file = self.tmppath / "bad.pkl" + bad_file.write_bytes(b"not a pickle at all") + with self.assertRaises(ValueError): + auto_load(bad_file) + + def test_auto_load_string_path(self): + from fickling.loader import auto_load + + pickle_file = self.tmppath / "test.pkl" + with open(pickle_file, "wb") as f: + f.write(dumps([1, 2, 3])) + + format_name, stacked = auto_load(str(pickle_file)) + self.assertEqual(format_name, "pickle") + + +class TestCLIExitCodes(TestCase): + """Test that exit codes follow ClamAV convention.""" + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.tmppath = Path(self.tmpdir.name) + + def tearDown(self): + self.tmpdir.cleanup() + + def test_safe_pickle_returns_exit_clean(self): + pickle_file = self.tmppath / "safe.pkl" + with open(pickle_file, "wb") as f: + f.write(dumps({"safe": True})) + + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "check", str(pickle_file)]) + self.assertEqual(result, EXIT_CLEAN) + + def test_file_not_found_returns_exit_error(self): + stderr = io.StringIO() + with redirect_stderr(stderr): + result = main(["fickling", "check", "/nonexistent/file.pkl"]) + self.assertEqual(result, EXIT_ERROR) + + def test_version_returns_exit_clean(self): + stdout = io.StringIO() + with redirect_stdout(stdout): + result = main(["fickling", "--version"]) + self.assertEqual(result, EXIT_CLEAN) + + class TestCLIErrorHandling(TestCase): """Test CLI error handling.""" def test_nonexistent_pickle_file(self): - """Test decompiling non-existent file raises FileNotFoundError.""" - # The original CLI raises FileNotFoundError for non-existent files with self.assertRaises(FileNotFoundError): main(["fickling", "/nonexistent/file.pkl"])