diff --git a/fickling/cli.py b/fickling/cli.py index 24f87dc..bb7c1dd 100644 --- a/fickling/cli.py +++ b/fickling/cli.py @@ -1,194 +1,666 @@ +"""Fickling CLI - Pickle security analyzer with auto-detection.""" + from __future__ import annotations +import json import sys from argparse import ArgumentParser from ast import unparse +from pathlib import Path from . import __version__, fickle, tracing from .analysis import Severity, check_safety +from .constants import EXIT_CLEAN, EXIT_ERROR, EXIT_UNSAFE DEFAULT_JSON_OUTPUT_FILE = "safety_results.json" +# Commands that use the new subcommand interface +COMMANDS = {"check", "inject", "info", "create-polyglot"} -def main(argv: list[str] | None = None) -> int: - if argv is None: - argv = sys.argv +def _check_torch_available() -> bool: + """Check if PyTorch is available and provide helpful error message if not.""" + try: + import torch # noqa: F401 + + return True + except ImportError: + sys.stderr.write( + "Error: PyTorch is required for this command.\n" + "Please install it with: pip install fickling[torch]\n" + ) + return False + + +def _create_legacy_parser() -> ArgumentParser: + """Create parser for legacy CLI behavior (backward compatibility).""" parser = ArgumentParser( - description="fickling is a static analyzer and interpreter for Python pickle data" + prog="fickling", + description="Pickle security analyzer with auto-detection for PyTorch models", ) + parser.add_argument("--version", "-v", action="store_true", help="print version and exit") parser.add_argument( - "PICKLE_FILE", + "file", type=str, nargs="?", default="-", - help="path to the pickle file to either " - "analyze or create (default is '-' for " - "STDIN/STDOUT)", + help="file to analyze (default: stdin)", ) options = parser.add_mutually_exclusive_group() + options.add_argument( + "--check-safety", + "-s", + action="store_true", + help="(legacy) run safety analysis - prefer 'fickling check FILE'", + ) options.add_argument( "--inject", "-i", type=str, 
default=None, - help="inject the specified Python code to be run at the end of unpickling, " - "and output the resulting pickle data", + help="(legacy) inject code - prefer 'fickling inject FILE -c CODE -o OUT'", ) parser.add_argument( "--inject-target", type=int, default=0, - help="some machine learning frameworks stack multiple pickles into the same model file; " - "this option specifies the index of the pickle file in which to inject the code from the " - "`--inject` command (default is 0)", + help="index of stacked pickle to inject into (default: 0)", + ) + options.add_argument( + "--create", + "-c", + type=str, + default=None, + help="(legacy) create pickle from Python expression", ) - options.add_argument("--create", "-c", type=str, default=None) parser.add_argument( "--run-last", "-l", action="store_true", - help="used with --inject to have the injected code " - "run after the existing pickling code in " - "PICKLE_FILE (default is for the injected code " - "to be run before the existing code)", + help="run injected code after existing code", ) parser.add_argument( "--replace-result", "-r", action="store_true", - help=( - "used with --inject to replace the unpickling result of the code in PICKLE_FILE " - "with the return value of the injected code. Either way, the preexisting pickling " - "code is still executed." - ), - ) - options.add_argument( - "--check-safety", - "-s", - action="store_true", - help=( - "test if the given pickle file is known to be unsafe. If so, exit with non-zero " - "status. This test is not guaranteed correct; the pickle file may still be unsafe " - "even if this check exits with code zero." - ), + help="replace unpickle result with injected code return value", ) - parser.add_argument( "--json-output", type=str, default=None, - help="path to the output JSON file to store the analysis results from check-safety." 
- f"If not provided, a default file named {DEFAULT_JSON_OUTPUT_FILE} will be used.", + help=f"path to output JSON file (default: {DEFAULT_JSON_OUTPUT_FILE})", ) - parser.add_argument( "--print-results", "-p", action="store_true", - help="Print the analysis results to the console when checking safety.", + help="print analysis results to console", ) - parser.add_argument( "--trace", "-t", action="store_true", - help="print a runtime trace while interpreting the input pickle file", + help="print a runtime trace while interpreting", ) - parser.add_argument("--version", "-v", action="store_true", help="print the version and exit") + return parser + +def _create_command_parser() -> ArgumentParser: + """Create parser with subcommands for new flat command structure.""" + parser = ArgumentParser( + prog="fickling", + description="Pickle security analyzer with auto-detection for PyTorch models", + ) + parser.add_argument("--version", "-v", action="store_true", help="print version and exit") + + subparsers = parser.add_subparsers(dest="command", help="available commands") + + # === check: Safety analysis === + check_parser = subparsers.add_parser( + "check", + help="safety check any pickle/model file", + description="Run safety analysis on any pickle or PyTorch model file (auto-detects format)", + ) + check_parser.add_argument("file", type=str, help="file to check") + check_parser.add_argument("--json", action="store_true", help="output results as JSON") + check_parser.add_argument( + "--json-output", + type=str, + default=None, + help=f"path to output JSON file (default: {DEFAULT_JSON_OUTPUT_FILE})", + ) + check_parser.add_argument( + "--print-results", "-p", action="store_true", help="print detailed results to console" + ) + + # === inject: Payload injection === + inject_parser = subparsers.add_parser( + "inject", + help="inject payload into pickle/model file", + description="Inject Python code into a pickle or PyTorch model file (auto-detects format)", + ) + 
inject_parser.add_argument("file", type=str, help="file to inject into") + inject_parser.add_argument( + "-c", "--code", type=str, required=True, help="Python code to inject" + ) + inject_parser.add_argument("-o", "--output", type=str, required=True, help="output file path") + inject_parser.add_argument( + "--method", + type=str, + choices=["insertion", "combination"], + default="insertion", + help="injection method for PyTorch models (default: insertion)", + ) + inject_parser.add_argument( + "--run-last", + "-l", + action="store_true", + help="run injected code after existing code (default: before)", + ) + inject_parser.add_argument( + "--replace-result", + "-r", + action="store_true", + help="replace unpickle result with injected code return value", + ) + inject_parser.add_argument( + "--overwrite", action="store_true", help="overwrite output file if exists" + ) + + # === info: Format identification === + info_parser = subparsers.add_parser( + "info", + help="show format and properties of a file", + description="Identify file format and show properties (requires PyTorch for full detection)", + ) + info_parser.add_argument("file", type=str, help="file to analyze") + info_parser.add_argument("--json", action="store_true", help="output results as JSON") + info_parser.add_argument( + "-r", "--recursive", action="store_true", help="analyze recursively into archives" + ) + + # === create-polyglot: Polyglot creation === + polyglot_parser = subparsers.add_parser( + "create-polyglot", + help="create a polyglot file from two inputs", + description="Create a polyglot file by combining two PyTorch/pickle files", + ) + polyglot_parser.add_argument("file1", type=str, help="first input file") + polyglot_parser.add_argument("file2", type=str, help="second input file") + polyglot_parser.add_argument("-o", "--output", type=str, default=None, help="output file path") + polyglot_parser.add_argument( + "--quiet", "-q", action="store_true", help="suppress output messages" + ) + + 
return parser + + +# Flags that consume the next argument as a value +_FLAGS_WITH_VALUES = { + "--inject", + "-i", + "--inject-target", + "--create", + "-c", + "--json-output", + "--code", + "--output", + "-o", + "--method", +} + + +def _get_first_positional(argv: list[str]) -> str | None: + """Get the first non-flag argument (potential command or file). + + Skips values consumed by flags like --inject to avoid + misrouting when a flag value matches a command name. + """ + skip_next = False + for arg in argv[1:]: + if skip_next: + skip_next = False + continue + if arg in _FLAGS_WITH_VALUES: + skip_next = True + continue + if arg.startswith("-"): + continue + return arg + return None + + +def main(argv: list[str] | None = None) -> int: + """Main CLI entry point.""" + if argv is None: + argv = sys.argv + + # Check for version flag first + if "--version" in argv or "-v" in argv: + if len(argv) == 2: # Only version flag present + if sys.stdout.isatty(): + print(f"fickling version {__version__}") + else: + print(__version__) + return EXIT_CLEAN + + # Determine if we're using a new command or legacy CLI + first_positional = _get_first_positional(argv) + + if first_positional in COMMANDS: + # Use new command parser + parser = _create_command_parser() + args = parser.parse_args(argv[1:]) + + if args.command == "check": + return _handle_check(args) + if args.command == "inject": + return _handle_inject(args) + if args.command == "info": + return _handle_info(args) + if args.command == "create-polyglot": + return _handle_create_polyglot(args) + return EXIT_ERROR + + # Use legacy parser for backward compatibility + parser = _create_legacy_parser() args = parser.parse_args(argv[1:]) + return _handle_legacy(args) - if args.version: - if sys.stdout.isatty(): - print(f"fickling version {__version__}") + +def _handle_check(args) -> int: + """Handle 'fickling check FILE' - safety analysis with auto-detection.""" + file_path = Path(args.file) + if not file_path.exists(): + 
sys.stderr.write(f"Error: file not found: {args.file}\n") + return EXIT_ERROR + + json_output_path = args.json_output or DEFAULT_JSON_OUTPUT_FILE + print_results = getattr(args, "print_results", False) + + try: + from .loader import auto_load + + format_name, stacked_pickled = auto_load(file_path) + + if not getattr(args, "json", False): + print(f"Detected format: {format_name}") + + was_safe = True + all_results = [] + + for pickled in stacked_pickled: + safety_results = check_safety(pickled, json_output_path=json_output_path) + all_results.append(safety_results) + + if safety_results.severity > Severity.LIKELY_SAFE: + was_safe = False + + if getattr(args, "json", False): + result = { + "file": str(file_path), + "format": format_name, + "safe": was_safe, + "severity": max(r.severity.value for r in all_results), + "results": [r.to_dict() for r in all_results], + } + print(json.dumps(result, indent=2)) else: - print(__version__) - return 0 + if print_results: + for i, safety_results in enumerate(all_results): + if len(all_results) > 1: + print(f"\n--- Pickle {i} ---") + print(safety_results.to_string()) - if args.create is None: - if args.PICKLE_FILE == "-": - if hasattr(sys.stdin, "buffer") and sys.stdin.buffer is not None: - file = sys.stdin.buffer + if was_safe: + print("No unsafe operations detected.") else: - file = sys.stdin - else: - file = open(args.PICKLE_FILE, "rb") - try: - stacked_pickled = fickle.StackedPickle.load(file, fail_on_decode_error=False) - except fickle.PickleDecodeError as e: - sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") - if args.check_safety: sys.stderr.write( - "Parsing errors might be indicative of a maliciously crafted pickle file. 
DO NOT TRUST this file without performing further analysis!\n" + "\nWarning: Potentially unsafe operations detected.\n" + "Do not unpickle this file if it is from an untrusted source!\n" ) - sys.stderr.write( - "\n(If this is a valid pickle file, please report the error at https://github.com/trailofbits/fickling)\n" - ) - return 1 - finally: - file.close() - if args.inject is not None: - if args.inject_target >= len(stacked_pickled): - sys.stderr.write( - f"Error: --inject-target {args.inject_target} is too high; there are only " - f"{len(stacked_pickled)} stacked pickle files in the input\n" - ) - return 1 - if hasattr(sys.stdout, "buffer") and sys.stdout.buffer is not None: - buffer = sys.stdout.buffer - else: - buffer = sys.stdout - for pickled in stacked_pickled[: args.inject_target]: + return EXIT_CLEAN if was_safe else EXIT_UNSAFE + + except fickle.PickleDecodeError as e: + sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") + sys.stderr.write( + "Parsing errors might be indicative of a maliciously crafted pickle file. 
" + "DO NOT TRUST this file without performing further analysis!\n" + ) + return EXIT_ERROR + except FileNotFoundError as e: + sys.stderr.write(f"Error: {e}\n") + return EXIT_ERROR + except ValueError as e: + sys.stderr.write(f"Error loading file: {e}\n") + return EXIT_ERROR + except OSError as e: + sys.stderr.write(f"Error reading file: {e}\n") + return EXIT_ERROR + + +def _handle_inject(args) -> int: + """Handle 'fickling inject FILE -c CODE -o OUT' - payload injection with auto-detection.""" + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return EXIT_ERROR + + output_path = Path(args.output) + if output_path.exists() and not getattr(args, "overwrite", False): + sys.stderr.write(f"Error: output file already exists: {args.output}\n") + sys.stderr.write("Use --overwrite to replace it.\n") + return EXIT_ERROR + + try: + from .loader import auto_load + + format_name, stacked_pickled = auto_load(file_path) + print(f"Detected format: {format_name}") + + # For PyTorch ZIP formats, use PyTorchModelWrapper for proper injection + if format_name in ("PyTorch v1.3", "TorchScript v1.4", "TorchScript v1.3"): + if not _check_torch_available(): + return EXIT_ERROR + + from .pytorch import PyTorchModelWrapper + + method = getattr(args, "method", "insertion") + overwrite = getattr(args, "overwrite", False) + + wrapper = PyTorchModelWrapper(file_path, force=True) + wrapper.inject_payload(args.code, output_path, injection=method, overwrite=overwrite) + print(f"Payload injected successfully. 
Output: {output_path}") + return EXIT_CLEAN + + # For plain pickle, use direct injection + inject_target = getattr(args, "inject_target", 0) + if inject_target >= len(stacked_pickled): + sys.stderr.write( + f"Error: --inject-target {inject_target} is too high; there are only " + f"{len(stacked_pickled)} stacked pickle files in the input\n" + ) + return EXIT_ERROR + + if args.output == "-": + buffer = ( + sys.stdout.buffer + if hasattr(sys.stdout, "buffer") and sys.stdout.buffer + else sys.stdout + ) + should_close = False + else: + buffer = open(output_path, "wb") + should_close = True + + try: + for pickled in stacked_pickled[:inject_target]: pickled.dump(buffer) - pickled = stacked_pickled[args.inject_target] - if not isinstance(pickled[-1], fickle.Stop): - sys.stderr.write( - "Warning: The last opcode of the input file was expected to be STOP, but was " - f"in fact {pickled[-1].info.name}" - ) + + pickled = stacked_pickled[inject_target] pickled.insert_python_eval( - args.inject, - run_first=not args.run_last, - use_output_as_unpickle_result=args.replace_result, + args.code, + run_first=not getattr(args, "run_last", False), + use_output_as_unpickle_result=getattr(args, "replace_result", False), ) pickled.dump(buffer) - for pickled in stacked_pickled[args.inject_target + 1 :]: + + for pickled in stacked_pickled[inject_target + 1 :]: pickled.dump(buffer) - elif args.check_safety: - was_safe = True - json_output_path = args.json_output or DEFAULT_JSON_OUTPUT_FILE - for pickled in stacked_pickled: - safety_results = check_safety(pickled, json_output_path=json_output_path) - # Print results if requested - if args.print_results: - print(safety_results.to_string()) + print(f"Payload injected successfully. Output: {output_path}") + return EXIT_CLEAN + finally: + if should_close: + buffer.close() + + except fickle.PickleDecodeError as e: + sys.stderr.write(f"Fickling failed to parse this pickle file. 
Error: {e!s}\n") + return EXIT_ERROR + except FileNotFoundError as e: + sys.stderr.write(f"Error: {e}\n") + return EXIT_ERROR + except ValueError as e: + sys.stderr.write(f"Error: {e}\n") + return EXIT_ERROR + except OSError as e: + sys.stderr.write(f"Error: {e}\n") + return EXIT_ERROR - if safety_results.severity > Severity.LIKELY_SAFE: - was_safe = False - if args.print_results: - sys.stderr.write( - "Warning: Fickling detected that the pickle file may be unsafe.\n\n" - "Do not unpickle this file if it is from an untrusted source!\n\n" - ) - return [1, 0][was_safe] +def _handle_info(args) -> int: + """Handle 'fickling info FILE' - format identification and properties.""" + file_path = Path(args.file) + if not file_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file}\n") + return EXIT_ERROR + # Try to use polyglot module for detailed analysis (requires torch) + try: + from .polyglot import ( + find_file_properties, + find_file_properties_recursively, + identify_pytorch_file_format, + ) + except ImportError: + return _handle_info_basic(args, file_path) + + try: + formats = identify_pytorch_file_format(args.file, print_results=False) + recursive = getattr(args, "recursive", False) + + if recursive: + properties = find_file_properties_recursively(args.file, print_properties=False) else: - var_id = 0 - for i, pickled in enumerate(stacked_pickled): - interpreter = fickle.Interpreter( - pickled, first_variable_id=var_id, result_variable=f"result{i}" - ) - if args.trace: - trace = tracing.Trace(interpreter) - print(unparse(trace.run())) + properties = find_file_properties(args.file, print_properties=False) + + if getattr(args, "json", False): + result = { + "file": str(file_path), + "formats": formats, + "primary_format": formats[0] if formats else None, + "is_polyglot": len(formats) > 1, + "properties": properties, + } + print(json.dumps(result, indent=2)) + else: + if formats: + print(f"Format: {formats[0]}") + if len(formats) > 1: + print(f"Also 
valid as: {', '.join(formats[1:])}") + print("(This file may be a polyglot)") + else: + print("Format: pickle (no specific PyTorch format detected)") + + print("\nProperties:") + _print_properties(properties, indent=2) + + return EXIT_CLEAN + + except (fickle.PickleDecodeError, ValueError, OSError) as e: + sys.stderr.write(f"Error: {e}\n") + return EXIT_ERROR + + +def _handle_info_basic(args, file_path: Path) -> int: + """Handle 'fickling info' without PyTorch (basic pickle info only).""" + try: + with open(file_path, "rb") as f: + stacked = fickle.StackedPickle.load(f, fail_on_decode_error=False) + + if getattr(args, "json", False): + result = { + "file": str(file_path), + "formats": ["pickle"], + "primary_format": "pickle", + "is_polyglot": False, + "pickle_count": len(stacked), + } + print(json.dumps(result, indent=2)) + else: + print("Format: pickle") + print(f"Stacked pickles: {len(stacked)}") + print("\nNote: Install PyTorch for detailed format detection:") + print(" pip install fickling[torch]") + + return EXIT_CLEAN + except (fickle.PickleDecodeError, ValueError, OSError) as e: + sys.stderr.write(f"Error reading file: {e}\n") + return EXIT_ERROR + + +def _print_properties(properties: dict, indent: int = 0) -> None: + """Pretty-print file properties.""" + prefix = " " * indent + for key, value in properties.items(): + if key == "children" and isinstance(value, dict): + print(f"{prefix}{key}:") + for child_name, child_props in value.items(): + print(f"{prefix} {child_name}:") + if child_props is not None: + _print_properties(child_props, indent + 4) else: - print(unparse(interpreter.to_ast())) - var_id = interpreter.next_variable_id - else: + print(f"{prefix} (unable to read)") + else: + print(f"{prefix}{key}: {value}") + + +def _handle_create_polyglot(args) -> int: + """Handle 'fickling create-polyglot FILE1 FILE2 -o OUT'.""" + if not _check_torch_available(): + return EXIT_ERROR + + from .polyglot import create_polyglot + + file1_path = Path(args.file1) + 
file2_path = Path(args.file2) + + if not file1_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file1}\n") + return EXIT_ERROR + if not file2_path.exists(): + sys.stderr.write(f"Error: file not found: {args.file2}\n") + return EXIT_ERROR + + output_path = getattr(args, "output", None) + quiet = getattr(args, "quiet", False) + + try: + success = create_polyglot( + args.file1, args.file2, polyglot_file_name=output_path, print_results=not quiet + ) + + if success: + return EXIT_CLEAN + if not quiet: + sys.stderr.write("Failed to create polyglot. The file formats may not be compatible.\n") + return EXIT_ERROR + except (ValueError, OSError) as e: + sys.stderr.write(f"Error creating polyglot: {e}\n") + return EXIT_ERROR + + +def _open_input(file_arg: str): + """Open input file or return stdin buffer.""" + if file_arg == "-": + if hasattr(sys.stdin, "buffer") and sys.stdin.buffer is not None: + return sys.stdin.buffer + return sys.stdin + return open(file_arg, "rb") + + +def _handle_legacy(args) -> int: + """Handle legacy CLI behavior (backward compatibility).""" + # Handle --check-safety flag (supports stdin via '-' default) + if args.check_safety: + file = _open_input(args.file) + try: + stacked_pickled = fickle.StackedPickle.load(file, fail_on_decode_error=False) + except fickle.PickleDecodeError as e: + sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") + sys.stderr.write( + "Parsing errors might be indicative of a maliciously crafted " + "pickle file. 
DO NOT TRUST this file without performing " + "further analysis!\n" + ) + return EXIT_ERROR + finally: + if file not in (sys.stdin, getattr(sys.stdin, "buffer", None)): + file.close() + + was_safe = True + json_output_path = args.json_output or DEFAULT_JSON_OUTPUT_FILE + for pickled in stacked_pickled: + safety_results = check_safety(pickled, json_output_path=json_output_path) + + if args.print_results: + print(safety_results.to_string()) + + if safety_results.severity > Severity.LIKELY_SAFE: + was_safe = False + if args.print_results: + sys.stderr.write( + "Warning: Fickling detected that the pickle file " + "may be unsafe.\n\n" + "Do not unpickle this file if it is from an " + "untrusted source!\n\n" + ) + + return EXIT_CLEAN if was_safe else EXIT_UNSAFE + + # Handle --inject flag + if args.inject: + file = _open_input(args.file) + + try: + stacked_pickled = fickle.StackedPickle.load(file, fail_on_decode_error=False) + except fickle.PickleDecodeError as e: + sys.stderr.write(f"Fickling failed to parse this pickle file. 
Error: {e!s}\n") + return EXIT_ERROR + finally: + if file not in (sys.stdin, getattr(sys.stdin, "buffer", None)): + file.close() + + if args.inject_target >= len(stacked_pickled): + sys.stderr.write( + f"Error: --inject-target {args.inject_target} is too high; " + f"there are only {len(stacked_pickled)} stacked pickle " + f"files in the input\n" + ) + return EXIT_ERROR + + if hasattr(sys.stdout, "buffer") and sys.stdout.buffer is not None: + buffer = sys.stdout.buffer + else: + buffer = sys.stdout + + for pickled in stacked_pickled[: args.inject_target]: + pickled.dump(buffer) + + pickled = stacked_pickled[args.inject_target] + if not isinstance(pickled[-1], fickle.Stop): + sys.stderr.write( + "Warning: The last opcode of the input file was expected " + "to be STOP, but was in fact " + f"{pickled[-1].info.name}" + ) + + pickled.insert_python_eval( + args.inject, + run_first=not args.run_last, + use_output_as_unpickle_result=args.replace_result, + ) + pickled.dump(buffer) + + for pickled in stacked_pickled[args.inject_target + 1 :]: + pickled.dump(buffer) + + return EXIT_CLEAN + + # Handle --create flag + if args.create: pickled = fickle.Pickled( [ fickle.Global.create("__builtin__", "eval"), @@ -199,15 +671,46 @@ def main(argv: list[str] | None = None) -> int: fickle.Stop(), ] ) - if args.PICKLE_FILE == "-": - file = sys.stdout - if hasattr(file, "buffer") and file.buffer is not None: - file = file.buffer + if args.file == "-": + if hasattr(sys.stdout, "buffer") and sys.stdout.buffer is not None: + file = sys.stdout.buffer + else: + file = sys.stdout else: - file = open(args.PICKLE_FILE, "wb") + file = open(args.file, "wb") + try: pickled.dump(file) finally: + if file not in (sys.stdout, getattr(sys.stdout, "buffer", None)): + file.close() + + return EXIT_CLEAN + + # Default: decompile the file + file = _open_input(args.file) + + try: + stacked_pickled = fickle.StackedPickle.load(file, fail_on_decode_error=False) + except fickle.PickleDecodeError as e: + 
sys.stderr.write(f"Fickling failed to parse this pickle file. Error: {e!s}\n") + return EXIT_ERROR + finally: + if file not in (sys.stdin, getattr(sys.stdin, "buffer", None)): file.close() - return 0 + var_id = 0 + for i, pickled in enumerate(stacked_pickled): + interpreter = fickle.Interpreter( + pickled, + first_variable_id=var_id, + result_variable=f"result{i}", + ) + if args.trace: + trace = tracing.Trace(interpreter) + print(unparse(trace.run())) + else: + print(unparse(interpreter.to_ast())) + var_id = interpreter.next_variable_id + + return EXIT_CLEAN diff --git a/fickling/constants.py b/fickling/constants.py new file mode 100644 index 0000000..f0bdc1b --- /dev/null +++ b/fickling/constants.py @@ -0,0 +1,4 @@ +# ClamAV-compatible exit codes for CI/CD integration +EXIT_CLEAN = 0 # No issues found +EXIT_UNSAFE = 1 # Potentially malicious content detected +EXIT_ERROR = 2 # Scan error (parse failure, file not found, etc.) diff --git a/fickling/loader.py b/fickling/loader.py index 5863f4e..24c517d 100644 --- a/fickling/loader.py +++ b/fickling/loader.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import pickle from io import BytesIO +from pathlib import Path from fickling.analysis import Severity, check_safety from fickling.exception import UnsafeFileError -from fickling.fickle import Pickled +from fickling.fickle import Pickled, StackedPickle def load( @@ -70,3 +73,62 @@ def loads( json_output_path=json_output_path, **kwargs, ) + + +def auto_load(path: Path | str) -> tuple[str, StackedPickle]: + """ + Auto-detect file format and load the pickle content. + + This function automatically detects whether the file is a PyTorch model (ZIP format), + a plain pickle, or other supported formats, and returns the appropriate Pickled data. 
+ + Args: + path: Path to the file to load + + Returns: + A tuple of (format_name, pickled_data) where: + - format_name: A string describing the detected format (e.g., "PyTorch v1.3", "pickle") + - pickled_data: A StackedPickle containing one or more Pickled objects + + Raises: + ValueError: If the file format cannot be determined or is unsupported + """ + if isinstance(path, str): + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + # Try PyTorch ZIP formats first (most common for ML models) + try: + from fickling.polyglot import identify_pytorch_file_format + except ImportError: + # torch not installed, fall through to plain pickle handling + identify_pytorch_file_format = None + + if identify_pytorch_file_format is not None: + formats = identify_pytorch_file_format(path, print_results=False) + + if formats: + primary_format = formats[0] + + # Handle PyTorch v1.3 and TorchScript v1.4 (ZIP with data.pkl) + if primary_format in ("PyTorch v1.3", "TorchScript v1.4", "TorchScript v1.3"): + from fickling.pytorch import PyTorchModelWrapper + + wrapper = PyTorchModelWrapper(path, force=True) + return primary_format, StackedPickle([wrapper.pickled]) + + # Handle legacy formats as plain pickle + if primary_format == "PyTorch v0.1.10": + with open(path, "rb") as f: + stacked = StackedPickle.load(f, fail_on_decode_error=False) + return primary_format, stacked + + # Fall back to plain pickle + try: + with open(path, "rb") as f: + stacked = StackedPickle.load(f, fail_on_decode_error=False) + return "pickle", stacked + except (OSError, ValueError) as e: + raise ValueError(f"Unable to load file as pickle: {e}") from e diff --git a/test/test_cli.py b/test/test_cli.py new file mode 100644 index 0000000..631ebf9 --- /dev/null +++ b/test/test_cli.py @@ -0,0 +1,547 @@ +"""Tests for the fickling CLI.""" + +from __future__ import annotations + +import io +import json +import tempfile +from contextlib import redirect_stderr, 
redirect_stdout  # NOTE(review): tail of an import statement that starts before this chunk
from pathlib import Path
from pickle import dumps
from unittest import TestCase

import pytest

from fickling.cli import _get_first_positional, main
from fickling.constants import EXIT_CLEAN, EXIT_ERROR, EXIT_UNSAFE


def _run_cli(*args):
    """Run ``main`` with ``"fickling"`` as argv[0] and capture both streams.

    Args:
        *args: Command-line arguments that follow the program name.

    Returns:
        A ``(exit_code, stdout_text, stderr_text)`` tuple.

    Any exception raised by ``main`` (including ``SystemExit`` from argparse)
    propagates to the caller; the streams are restored either way.
    """
    out, err = io.StringIO(), io.StringIO()
    with redirect_stdout(out), redirect_stderr(err):
        code = main(["fickling", *args])
    return code, out.getvalue(), err.getvalue()


class _TempDirMixin:
    """Give each test a private temporary directory at ``self.tmppath``."""

    def setUp(self):
        super().setUp()
        self.tmpdir = tempfile.TemporaryDirectory()
        # Register the cleanup immediately: unittest skips tearDown() when a
        # later setUp step raises, but cleanups added here still run.
        self.addCleanup(self.tmpdir.cleanup)
        self.tmppath = Path(self.tmpdir.name)


class _PickleFileMixin(_TempDirMixin):
    """Additionally write a small benign pickle to ``self.pickle_file``."""

    def setUp(self):
        super().setUp()
        self.pickle_file = self.tmppath / "test.pkl"
        self.pickle_file.write_bytes(dumps({"test": "data"}))


class _TorchModelMixin(_TempDirMixin):
    """Skip the class without torch/torchvision; save a model to ``self.model_file``."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        pytest.importorskip("torch")
        pytest.importorskip("torchvision")

    def setUp(self):
        super().setUp()
        # Imported lazily so the module can be collected without torch installed.
        import torch
        import torchvision.models as models

        model = models.mobilenet_v2(weights=None)
        self.model_file = self.tmppath / "model.pth"
        torch.save(model, self.model_file)


class TestCLIBackwardCompatibility(_PickleFileMixin, TestCase):
    """Test that existing CLI behavior is preserved."""

    def test_version_flag(self):
        code, out, _ = _run_cli("--version")
        self.assertEqual(code, EXIT_CLEAN)
        # isatty() is False for StringIO, so output is just the version number
        version = out.strip()
        self.assertTrue(version)
        self.assertRegex(version, r"^\d+\.\d+\.\d+")

    def test_version_flag_short(self):
        code, out, _ = _run_cli("-v")
        self.assertEqual(code, EXIT_CLEAN)
        self.assertTrue(out.strip())

    def test_decompile_pickle(self):
        code, out, _ = _run_cli(str(self.pickle_file))
        self.assertEqual(code, EXIT_CLEAN)
        self.assertIn("result", out)
        self.assertIn("test", out)

    def test_check_safety_legacy_flag(self):
        code, _, _ = _run_cli("--check-safety", str(self.pickle_file))
        self.assertEqual(code, EXIT_CLEAN)

    def test_check_safety_legacy_stdin(self):
        """Legacy --check-safety supports stdin via default '-' file arg."""
        # Verify the parser accepts --check-safety without a file argument
        # (file defaults to "-" for stdin)
        from fickling.cli import _create_legacy_parser

        args = _create_legacy_parser().parse_args(["--check-safety"])
        self.assertTrue(args.check_safety)
        self.assertEqual(args.file, "-")

    def test_legacy_mutually_exclusive_flags(self):
        """--inject, --check-safety, --create are mutually exclusive."""
        with self.assertRaises(SystemExit):
            _run_cli("--check-safety", "--inject", "code", str(self.pickle_file))

    def test_help_flag(self):
        with self.assertRaises(SystemExit) as cm:
            _run_cli("--help")
        self.assertEqual(cm.exception.code, 0)


class TestGetFirstPositional(TestCase):
    """Test the _get_first_positional routing function."""

    def test_simple_command(self):
        self.assertEqual(
            _get_first_positional(["fickling", "check", "file.pkl"]),
            "check",
        )

    def test_flag_value_not_misrouted(self):
        """Flag values matching command names must not be treated as commands."""
        self.assertEqual(
            _get_first_positional(["fickling", "--inject", "check", "file.pkl"]),
            "file.pkl",
        )

    def test_short_flag_value_not_misrouted(self):
        self.assertEqual(
            _get_first_positional(["fickling", "-i", "check", "file.pkl"]),
            "file.pkl",
        )

    def test_no_positional(self):
        self.assertIsNone(
            _get_first_positional(["fickling", "--version"]),
        )

    def test_file_path_as_first_positional(self):
        self.assertEqual(
            _get_first_positional(["fickling", "file.pkl"]),
            "file.pkl",
        )

    def test_create_flag_value_skipped(self):
        self.assertEqual(
            _get_first_positional(["fickling", "--create", "expr", "out.pkl"]),
            "out.pkl",
        )


class TestCheckCommand(_PickleFileMixin, TestCase):
    """Test the 'fickling check' command."""

    def test_check_pickle(self):
        code, out, _ = _run_cli("check", str(self.pickle_file))
        self.assertEqual(code, EXIT_CLEAN)
        self.assertIn("Detected format", out)
        self.assertIn("pickle", out)
        self.assertIn("No unsafe operations detected", out)

    def test_check_json_output(self):
        code, out, _ = _run_cli("check", "--json", str(self.pickle_file))
        self.assertEqual(code, EXIT_CLEAN)
        data = json.loads(out)
        self.assertEqual(data["format"], "pickle")
        self.assertTrue(data["safe"])
        self.assertIn("severity", data)
        self.assertIn("results", data)

    def test_check_file_not_found(self):
        code, _, err = _run_cli("check", "/nonexistent/file.pkl")
        self.assertEqual(code, EXIT_ERROR)
        self.assertIn("file not found", err)

    def test_check_print_results(self):
        code, _, _ = _run_cli("check", "--print-results", str(self.pickle_file))
        self.assertEqual(code, EXIT_CLEAN)


class TestCheckCommandPyTorch(_TorchModelMixin, TestCase):
    """Test 'fickling check' on PyTorch models (requires torch)."""

    def test_check_pytorch_model(self):
        code, out, _ = _run_cli("check", str(self.model_file))
        # Either verdict is accepted; only format detection is pinned here.
        self.assertIn(code, [EXIT_CLEAN, EXIT_UNSAFE])
        self.assertIn("Detected format", out)
        self.assertIn("PyTorch", out)


class TestInjectCommand(_PickleFileMixin, TestCase):
    """Test the 'fickling inject' command."""

    def test_inject_pickle(self):
        output_file = self.tmppath / "injected.pkl"
        code, _, _ = _run_cli(
            "inject",
            str(self.pickle_file),
            "-c",
            "print('test')",
            "-o",
            str(output_file),
        )
        self.assertEqual(code, EXIT_CLEAN)
        self.assertTrue(output_file.exists())
        self.assertGreater(output_file.stat().st_size, 0)

    def test_inject_missing_output(self):
        with self.assertRaises(SystemExit):
            _run_cli("inject", str(self.pickle_file), "-c", "code")

    def test_inject_missing_code(self):
        # Output path lives in the temp dir so nothing can land in the CWD.
        with self.assertRaises(SystemExit):
            _run_cli("inject", str(self.pickle_file), "-o", str(self.tmppath / "out.pkl"))

    def test_inject_file_not_found(self):
        code, _, err = _run_cli(
            "inject",
            "/nonexistent/file.pkl",
            "-c",
            "code",
            "-o",
            # Kept inside the temp dir so a stray ./out.pkl cannot affect the result.
            str(self.tmppath / "out.pkl"),
        )
        self.assertEqual(code, EXIT_ERROR)
        self.assertIn("file not found", err)

    def test_inject_output_exists_no_overwrite(self):
        output_file = self.tmppath / "existing.pkl"
        output_file.write_bytes(b"existing")
        code, _, err = _run_cli(
            "inject",
            str(self.pickle_file),
            "-c",
            "print('test')",
            "-o",
            str(output_file),
        )
        self.assertEqual(code, EXIT_ERROR)
        self.assertIn("already exists", err)


class TestInjectCommandPyTorch(_TorchModelMixin, TestCase):
    """Test 'fickling inject' on PyTorch models (requires torch)."""

    def test_inject_pytorch_model(self):
        output_file = self.tmppath / "injected.pth"
        code, _, _ = _run_cli(
            "inject",
            str(self.model_file),
            "-c",
            "print('test')",
            "-o",
            str(output_file),
        )
        self.assertEqual(code, EXIT_CLEAN)
        self.assertTrue(output_file.exists())

    def test_inject_pytorch_combination_method(self):
        output_file = self.tmppath / "injected_combo.pth"
        code, _, _ = _run_cli(
            "inject",
            str(self.model_file),
            "-c",
            "print('test')",
            "-o",
            str(output_file),
            "--method",
            "combination",
        )
        self.assertEqual(code, EXIT_CLEAN)
        self.assertTrue(output_file.exists())


class TestInfoCommand(_PickleFileMixin, TestCase):
    """Test 'fickling info' on plain pickle (no torch required)."""

    def test_info_pickle(self):
        code, out, _ = _run_cli("info", str(self.pickle_file))
        self.assertEqual(code, EXIT_CLEAN)
        self.assertIn("Format:", out)

    def test_info_file_not_found(self):
        code, _, err = _run_cli("info", "/nonexistent/file.pth")
        self.assertEqual(code, EXIT_ERROR)
        self.assertIn("file not found", err)


class TestInfoCommandPyTorch(_TorchModelMixin, TestCase):
    """Test 'fickling info' on PyTorch models (requires torch)."""

    def test_info_pytorch_model(self):
        code, out, _ = _run_cli("info", str(self.model_file))
        self.assertEqual(code, EXIT_CLEAN)
        self.assertIn("Format:", out)
        self.assertIn("PyTorch", out)

    def test_info_json_output(self):
        code, out, _ = _run_cli("info", "--json", str(self.model_file))
        self.assertEqual(code, EXIT_CLEAN)
        data = json.loads(out)
        self.assertIn("formats", data)
        self.assertIn("primary_format", data)
        self.assertIn("properties", data)


class TestCreatePolyglotCommand(_TorchModelMixin, TestCase):
    """Test the 'fickling create-polyglot' command (requires torch)."""

    def test_create_polyglot_file_not_found(self):
        code, _, err = _run_cli(
            "create-polyglot",
            "/nonexistent/file1.pth",
            str(self.model_file),
        )
        self.assertEqual(code, EXIT_ERROR)
        self.assertIn("file not found", err)


class TestAutoLoad(_TempDirMixin, TestCase):
    """Test the auto_load() format detection function."""

    def test_auto_load_pickle(self):
        from fickling.loader import auto_load

        pickle_file = self.tmppath / "test.pkl"
        pickle_file.write_bytes(dumps({"key": "value"}))

        format_name, stacked = auto_load(pickle_file)
        self.assertEqual(format_name, "pickle")
        self.assertGreater(len(stacked), 0)

    def test_auto_load_file_not_found(self):
        from fickling.loader import auto_load

        with self.assertRaises(FileNotFoundError):
            auto_load(Path("/nonexistent/file.pkl"))

    def test_auto_load_invalid_file(self):
        from fickling.loader import auto_load

        bad_file = self.tmppath / "bad.pkl"
        bad_file.write_bytes(b"not a pickle at all")
        with self.assertRaises(ValueError):
            auto_load(bad_file)

    def test_auto_load_string_path(self):
        from fickling.loader import auto_load

        pickle_file = self.tmppath / "test.pkl"
        pickle_file.write_bytes(dumps([1, 2, 3]))

        # Only the detected format is checked when a str path is passed.
        format_name, _stacked = auto_load(str(pickle_file))
        self.assertEqual(format_name, "pickle")


class TestCLIExitCodes(_TempDirMixin, TestCase):
    """Test that exit codes follow ClamAV convention."""

    def test_safe_pickle_returns_exit_clean(self):
        pickle_file = self.tmppath / "safe.pkl"
        pickle_file.write_bytes(dumps({"safe": True}))

        code, _, _ = _run_cli("check", str(pickle_file))
        self.assertEqual(code, EXIT_CLEAN)

    def test_file_not_found_returns_exit_error(self):
        code, _, _ = _run_cli("check", "/nonexistent/file.pkl")
        self.assertEqual(code, EXIT_ERROR)

    def test_version_returns_exit_clean(self):
        code, _, _ = _run_cli("--version")
        self.assertEqual(code, EXIT_CLEAN)


class TestCLIErrorHandling(TestCase):
    """Test CLI error handling."""

    def test_nonexistent_pickle_file(self):
        # The legacy (no-subcommand) path is expected to let the error propagate.
        with self.assertRaises(FileNotFoundError):
            _run_cli("/nonexistent/file.pkl")