Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions test/fixtures/generate_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10,<3.14"
# dependencies = ["torch>=2.10,<3", "torchvision>=0.25,<1"]
# ///
"""Generate TorchScript fixture files for tests.

These fixtures are checked into the repository so that tests don't need
to call torch.jit.script/torch.jit.save at runtime (which emits
deprecation warnings on Python 3.14+ and may be removed in future
PyTorch versions).

Usage:
uv run test/fixtures/generate_fixtures.py
"""

from pathlib import Path

import torch
import torchvision.models as models

FIXTURES_DIR = Path(__file__).parent


def main():
model = models.squeezenet1_0()
scripted = torch.jit.script(model)

out = FIXTURES_DIR / "squeezenet1_0_torchscript_v1_4.pt"
torch.jit.save(scripted, out)
print(f"Generated {out} ({out.stat().st_size} bytes)")


if __name__ == "__main__":
main()
Binary file added test/fixtures/squeezenet1_0_torchscript_v1_4.pt
Binary file not shown.
22 changes: 5 additions & 17 deletions test/test_polyglot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import random
import string
import sys
import tarfile
import tempfile
import unittest
Expand All @@ -15,7 +14,7 @@
import fickling.polyglot as polyglot
from fickling.polyglot import FileProperties

_lacks_torch_jit_support = sys.version_info >= (3, 14)
FIXTURES_DIR = Path(__file__).parent / "fixtures"


def _make_properties(**overrides):
Expand Down Expand Up @@ -79,17 +78,9 @@ def setUp(self):
self.filename_legacy_pickle = tmppath / "model_legacy_pickle.pth"
torch.save(model, self.filename_legacy_pickle, _use_new_zipfile_serialization=False)

if not _lacks_torch_jit_support:
# TorchScript v1.4
m = torch.jit.script(model)
self.filename_torchscript = tmppath / "model_torchscript.pt"
torch.jit.save(m, self.filename_torchscript)

# TorchScript v1.4 Dup
self.filename_torchscript_dup = tmppath / "model_torchscript_dup.pt"
torch.jit.save(m, self.filename_torchscript_dup)

self.standard_torchscript_polyglot_name = tmppath / "test_polyglot.pt"
# TorchScript v1.4 (pre-generated fixtures to avoid torch.jit deprecation warnings)
self.filename_torchscript = FIXTURES_DIR / "squeezenet1_0_torchscript_v1_4.pt"
self.standard_torchscript_polyglot_name = tmppath / "test_polyglot.pt"

# PyTorch v0.1.1
self.filename_legacy_tar = tmppath / "model_legacy_tar.pth"
Expand Down Expand Up @@ -132,7 +123,6 @@ def test_legacy_pickle(self):
formats = polyglot.identify_pytorch_file_format(self.filename_legacy_pickle)
self.assertEqual(formats, ["PyTorch v0.1.10"])

@unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+")
def test_torchscript(self):
formats = polyglot.identify_pytorch_file_format(self.filename_torchscript)
self.assertEqual(formats, ["TorchScript v1.4", "TorchScript v1.3", "PyTorch v1.3"])
Expand Down Expand Up @@ -200,7 +190,6 @@ def test_legacy_pickle_properties(self):
proper_result = _make_properties(is_valid_pickle=True)
self.assertEqual(properties, proper_result)

@unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+")
def test_torchscript_properties(self):
properties = polyglot.find_file_properties(self.filename_torchscript)
proper_result = _make_properties(
Expand All @@ -219,11 +208,10 @@ def test_zip_properties(self):
proper_result = _make_properties(is_standard_zip=True, is_standard_not_torch=True)
self.assertEqual(properties, proper_result)

@unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+")
def test_create_standard_torchscript_polyglot(self):
polyglot.create_polyglot(
self.filename_v1_3_dup,
self.filename_torchscript_dup,
self.filename_torchscript,
self.standard_torchscript_polyglot_name,
print_results=False,
)
Expand Down
11 changes: 3 additions & 8 deletions test/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
import tempfile
import unittest
from pathlib import Path
Expand All @@ -9,7 +8,7 @@
from fickling.fickle import Pickled
from fickling.pytorch import PyTorchModelWrapper

_lacks_torch_jit_support = sys.version_info >= (3, 14)
FIXTURES_DIR = Path(__file__).parent / "fixtures"


class TestPyTorchModule(unittest.TestCase):
Expand All @@ -21,10 +20,8 @@ def setUp(self):
self.filename_v1_3 = tmppath / "test_model.pth"
torch.save(model, self.filename_v1_3)

if not _lacks_torch_jit_support:
m = torch.jit.script(model)
self.torchscript_filename = tmppath / "test_model_torchscript.pth"
torch.jit.save(m, self.torchscript_filename)
# Pre-generated fixture to avoid torch.jit deprecation warnings
self.torchscript_filename = FIXTURES_DIR / "squeezenet1_0_torchscript_v1_4.pt"

def tearDown(self):
self.tmpdir.cleanup()
Expand All @@ -35,7 +32,6 @@ def test_wrapper(self):
except Exception as e: # noqa
self.fail(f"PyTorchModelWrapper was not able to load a PyTorch v1.3 file: {e}")

@unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+")
def test_torchscript_wrapper(self):
try:
PyTorchModelWrapper(self.torchscript_filename)
Expand All @@ -47,7 +43,6 @@ def test_pickled(self):
pickled_portion = result.pickled
self.assertIsInstance(pickled_portion, Pickled)

@unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+")
def test_torchscript_pickled(self):
result = PyTorchModelWrapper(self.torchscript_filename)
pickled_portion = result.pickled
Expand Down
Loading