diff --git a/test/fixtures/generate_fixtures.py b/test/fixtures/generate_fixtures.py new file mode 100644 index 0000000..b51f110 --- /dev/null +++ b/test/fixtures/generate_fixtures.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10,<3.14" +# dependencies = ["torch>=2.10,<3", "torchvision>=0.25,<1"] +# /// +"""Generate TorchScript fixture files for tests. + +These fixtures are checked into the repository so that tests don't need +to call torch.jit.script/torch.jit.save at runtime (which emits +deprecation warnings on Python 3.14+ and may be removed in future +PyTorch versions). + +Usage: + uv run test/fixtures/generate_fixtures.py +""" + +from pathlib import Path + +import torch +import torchvision.models as models + +FIXTURES_DIR = Path(__file__).parent + + +def main(): + model = models.squeezenet1_0() + scripted = torch.jit.script(model) + + out = FIXTURES_DIR / "squeezenet1_0_torchscript_v1_4.pt" + torch.jit.save(scripted, out) + print(f"Generated {out} ({out.stat().st_size} bytes)") + + +if __name__ == "__main__": + main() diff --git a/test/fixtures/squeezenet1_0_torchscript_v1_4.pt b/test/fixtures/squeezenet1_0_torchscript_v1_4.pt new file mode 100644 index 0000000..7a5eb6e Binary files /dev/null and b/test/fixtures/squeezenet1_0_torchscript_v1_4.pt differ diff --git a/test/test_polyglot.py b/test/test_polyglot.py index 242d79d..6aa1e44 100644 --- a/test/test_polyglot.py +++ b/test/test_polyglot.py @@ -1,6 +1,5 @@ import random import string -import sys import tarfile import tempfile import unittest @@ -15,7 +14,7 @@ import fickling.polyglot as polyglot from fickling.polyglot import FileProperties -_lacks_torch_jit_support = sys.version_info >= (3, 14) +FIXTURES_DIR = Path(__file__).parent / "fixtures" def _make_properties(**overrides): @@ -79,17 +78,9 @@ def setUp(self): self.filename_legacy_pickle = tmppath / "model_legacy_pickle.pth" torch.save(model, self.filename_legacy_pickle, _use_new_zipfile_serialization=False) - if not _lacks_torch_jit_support: - # TorchScript v1.4 - m = torch.jit.script(model) - self.filename_torchscript = tmppath / "model_torchscript.pt" - torch.jit.save(m, self.filename_torchscript) - - # TorchScript v1.4 Dup - self.filename_torchscript_dup = tmppath / "model_torchscript_dup.pt" - torch.jit.save(m, self.filename_torchscript_dup) - - self.standard_torchscript_polyglot_name = tmppath / "test_polyglot.pt" + # TorchScript v1.4 (pre-generated fixtures to avoid torch.jit deprecation warnings) + self.filename_torchscript = FIXTURES_DIR / "squeezenet1_0_torchscript_v1_4.pt" + self.standard_torchscript_polyglot_name = tmppath / "test_polyglot.pt" # PyTorch v0.1.1 self.filename_legacy_tar = tmppath / "model_legacy_tar.pth" @@ -132,7 +123,6 @@ def test_legacy_pickle(self): formats = polyglot.identify_pytorch_file_format(self.filename_legacy_pickle) self.assertEqual(formats, ["PyTorch v0.1.10"]) - @unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+") def test_torchscript(self): formats = polyglot.identify_pytorch_file_format(self.filename_torchscript) self.assertEqual(formats, ["TorchScript v1.4", "TorchScript v1.3", "PyTorch v1.3"]) @@ -200,7 +190,6 @@ def test_legacy_pickle_properties(self): proper_result = _make_properties(is_valid_pickle=True) self.assertEqual(properties, proper_result) - @unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+") def test_torchscript_properties(self): properties = polyglot.find_file_properties(self.filename_torchscript) proper_result = _make_properties( @@ -219,11 +208,10 @@ def test_zip_properties(self): proper_result = _make_properties(is_standard_zip=True, is_standard_not_torch=True) self.assertEqual(properties, proper_result) - @unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+") def test_create_standard_torchscript_polyglot(self): polyglot.create_polyglot( self.filename_v1_3_dup, - self.filename_torchscript_dup, + self.filename_torchscript, self.standard_torchscript_polyglot_name, print_results=False, ) diff --git a/test/test_pytorch.py b/test/test_pytorch.py index ef8c76f..fa73d74 100644 --- a/test/test_pytorch.py +++ b/test/test_pytorch.py @@ -1,4 +1,3 @@ -import sys import tempfile import unittest from pathlib import Path @@ -9,7 +8,7 @@ from fickling.fickle import Pickled from fickling.pytorch import PyTorchModelWrapper -_lacks_torch_jit_support = sys.version_info >= (3, 14) +FIXTURES_DIR = Path(__file__).parent / "fixtures" class TestPyTorchModule(unittest.TestCase): @@ -21,10 +20,8 @@ def setUp(self): self.filename_v1_3 = tmppath / "test_model.pth" torch.save(model, self.filename_v1_3) - if not _lacks_torch_jit_support: - m = torch.jit.script(model) - self.torchscript_filename = tmppath / "test_model_torchscript.pth" - torch.jit.save(m, self.torchscript_filename) + # Pre-generated fixture to avoid torch.jit deprecation warnings + self.torchscript_filename = FIXTURES_DIR / "squeezenet1_0_torchscript_v1_4.pt" def tearDown(self): self.tmpdir.cleanup() @@ -35,7 +32,6 @@ def test_wrapper(self): except Exception as e: # noqa self.fail(f"PyTorchModelWrapper was not able to load a PyTorch v1.3 file: {e}") - @unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+") def test_torchscript_wrapper(self): try: PyTorchModelWrapper(self.torchscript_filename) @@ -47,7 +43,6 @@ def test_pickled(self): pickled_portion = result.pickled self.assertIsInstance(pickled_portion, Pickled) - @unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+") def test_torchscript_pickled(self): result = PyTorchModelWrapper(self.torchscript_filename) pickled_portion = result.pickled