From c2b3782d750104131a24077e854f5dd647452be9 Mon Sep 17 00:00:00 2001 From: qihqi Date: Mon, 30 Mar 2026 17:25:18 -0700 Subject: [PATCH 1/2] Add jax_jit functional_call RNG test --- test/test_jittable_module.py | 71 ++++++++++++++++++++++++++++++++++++ torchax/interop.py | 31 ++++++++++++++-- 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/test/test_jittable_module.py b/test/test_jittable_module.py index 304c103..eeaa0c5 100644 --- a/test/test_jittable_module.py +++ b/test/test_jittable_module.py @@ -13,9 +13,12 @@ # limitations under the License. import unittest +import functools +import jax import torch +import torchax from torchax import interop @@ -64,6 +67,74 @@ def outer_function(model, x): assert torch.equal(output, expected_output) + def test_take_rng_requires_rng_kwarg(self): + class PlusOne(torch.nn.Module): + def forward(self, x): + return x + 1 + + jittable_module = interop.JittableModule(PlusOne(), take_rng=True) + with self.assertRaisesRegex(TypeError, "requires a `rng` kwarg"): + jittable_module(torch.ones(2, 2)) + + def test_take_rng_removes_rng_before_module_call(self): + class PlusOne(torch.nn.Module): + def forward(self, x): + return x + 1 + + jittable_module = interop.JittableModule(PlusOne(), take_rng=True) + x = torch.randn(2, 2) + expected = x + 1 + output = jittable_module.functional_call( + "forward", jittable_module.params, jittable_module.buffers, x, rng=object() + ) + assert torch.equal(output, expected) + + def test_take_rng_controls_random_ops(self): + torchax.enable_globally() + + class RandomOut(torch.nn.Module): + def forward(self, x): + return torch.randn_like(x) + + model = RandomOut().to("jax") + jittable_module = interop.JittableModule(model, take_rng=True) + x = torch.ones(16, 16).to("jax") + + same_rng_1 = jittable_module(x, rng=jax.random.PRNGKey(0)) + same_rng_2 = jittable_module(x, rng=jax.random.PRNGKey(0)) + different_rng = jittable_module(x, rng=jax.random.PRNGKey(1)) + + self.assertTrue(torch.equal(same_rng_1, same_rng_2)) + self.assertFalse(torch.equal(same_rng_1, different_rng)) + + def test_take_rng_controls_random_ops_for_jitted_functional_call(self): + torchax.enable_globally() + + class RandomOut(torch.nn.Module): + def forward(self, x): + return torch.randn_like(x) + + model = RandomOut().to("jax") + jittable_module = interop.JittableModule(model, take_rng=True) + x = torch.ones(16, 16).to("jax") + + jitted_functional = interop.jax_jit( + functools.partial(jittable_module.functional_call, "forward") + ) + + same_rng_1 = jitted_functional( + jittable_module.params, jittable_module.buffers, x, rng=jax.random.PRNGKey(0) + ) + same_rng_2 = jitted_functional( + jittable_module.params, jittable_module.buffers, x, rng=jax.random.PRNGKey(0) + ) + different_rng = jitted_functional( + jittable_module.params, jittable_module.buffers, x, rng=jax.random.PRNGKey(1) + ) + + self.assertTrue(torch.equal(same_rng_1, same_rng_2)) + self.assertFalse(torch.equal(same_rng_1, different_rng)) + if __name__ == "__main__": unittest.main() diff --git a/torchax/interop.py b/torchax/interop.py index f4c77b0..ec9a64d 100644 --- a/torchax/interop.py +++ b/torchax/interop.py @@ -73,7 +73,13 @@ def set_one(module, prefix): class JittableModule(torch.nn.Module): - def __init__(self, m: torch.nn.Module, extra_jit_args=None, dedup_parameters=True): + def __init__( + self, + m: torch.nn.Module, + extra_jit_args=None, + dedup_parameters=True, + take_rng: bool = False, + ): if extra_jit_args is None: extra_jit_args = {} super().__init__() @@ -82,6 +88,7 @@ def __init__(self, m: torch.nn.Module, extra_jit_args=None, dedup_parameters=Tru self._jitted = {} self._extra_jit_args = extra_jit_args + self._take_rng = take_rng self._extra_dumped_weights = {} @@ -104,10 +111,20 @@ def __class__(self): return self._model.__class__ def __call__(self, *args, **kwargs): + if self._take_rng and "rng" not in kwargs: + raise TypeError("JittableModule(..., take_rng=True) requires a `rng` kwarg.") return self.forward(*args, **kwargs) def functional_call(self, method_or_name, params, buffers, *args, **kwargs): kwargs = kwargs or {} + rng = None + if self._take_rng: + if "rng" not in kwargs: + raise TypeError( + "JittableModule(..., take_rng=True) requires a `rng` kwarg in functional_call." + ) + kwargs = copy.copy(kwargs) + rng = _jax_view(kwargs.pop("rng")) params_copy = copy.copy(params) params_copy.update(buffers) # reinflate the state dict so there are not any missing keys @@ -124,8 +141,16 @@ def functional_call(self, method_or_name, params, buffers, *args, **kwargs): ) method = method_or_name args = (self._model,) + args - with torch_stateless._reparametrize_module(self._model, params_copy): - res = method(*args, **kwargs) + with torchax.default_env() as env: + if rng is None: + with torch_stateless._reparametrize_module(self._model, params_copy): + res = method(*args, **kwargs) + else: + with ( + env.override_property(prng=rng), + torch_stateless._reparametrize_module(self._model, params_copy), + ): + res = method(*args, **kwargs) return res def jittable_call(self, method_name: str, *args, **kwargs): From f73bac469a066368bbf6c8aa3e609e5020757c18 Mon Sep 17 00:00:00 2001 From: qihqi Date: Mon, 30 Mar 2026 17:29:55 -0700 Subject: [PATCH 2/2] Fix Ruff import ordering in jittable module tests --- test/test_jittable_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_jittable_module.py b/test/test_jittable_module.py index eeaa0c5..ca3674b 100644 --- a/test/test_jittable_module.py +++ b/test/test_jittable_module.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest import functools +import unittest import jax import torch