diff --git a/test/test_jittable_module.py b/test/test_jittable_module.py index 304c103..ca3674b 100644 --- a/test/test_jittable_module.py +++ b/test/test_jittable_module.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import unittest +import jax import torch +import torchax from torchax import interop @@ -64,6 +67,74 @@ def outer_function(model, x): assert torch.equal(output, expected_output) + def test_take_rng_requires_rng_kwarg(self): + class PlusOne(torch.nn.Module): + def forward(self, x): + return x + 1 + + jittable_module = interop.JittableModule(PlusOne(), take_rng=True) + with self.assertRaisesRegex(TypeError, "requires a `rng` kwarg"): + jittable_module(torch.ones(2, 2)) + + def test_take_rng_removes_rng_before_module_call(self): + class PlusOne(torch.nn.Module): + def forward(self, x): + return x + 1 + + jittable_module = interop.JittableModule(PlusOne(), take_rng=True) + x = torch.randn(2, 2) + expected = x + 1 + output = jittable_module.functional_call( + "forward", jittable_module.params, jittable_module.buffers, x, rng=object() + ) + assert torch.equal(output, expected) + + def test_take_rng_controls_random_ops(self): + torchax.enable_globally() + + class RandomOut(torch.nn.Module): + def forward(self, x): + return torch.randn_like(x) + + model = RandomOut().to("jax") + jittable_module = interop.JittableModule(model, take_rng=True) + x = torch.ones(16, 16).to("jax") + + same_rng_1 = jittable_module(x, rng=jax.random.PRNGKey(0)) + same_rng_2 = jittable_module(x, rng=jax.random.PRNGKey(0)) + different_rng = jittable_module(x, rng=jax.random.PRNGKey(1)) + + self.assertTrue(torch.equal(same_rng_1, same_rng_2)) + self.assertFalse(torch.equal(same_rng_1, different_rng)) + + def test_take_rng_controls_random_ops_for_jitted_functional_call(self): + torchax.enable_globally() + + class RandomOut(torch.nn.Module): + def forward(self, x): + return torch.randn_like(x) + + model = RandomOut().to("jax") + jittable_module = interop.JittableModule(model, take_rng=True) + x = torch.ones(16, 16).to("jax") + + jitted_functional = interop.jax_jit( + functools.partial(jittable_module.functional_call, "forward") + ) + + same_rng_1 = jitted_functional( + jittable_module.params, jittable_module.buffers, x, rng=jax.random.PRNGKey(0) + ) + same_rng_2 = jitted_functional( + jittable_module.params, jittable_module.buffers, x, rng=jax.random.PRNGKey(0) + ) + different_rng = jitted_functional( + jittable_module.params, jittable_module.buffers, x, rng=jax.random.PRNGKey(1) + ) + + self.assertTrue(torch.equal(same_rng_1, same_rng_2)) + self.assertFalse(torch.equal(same_rng_1, different_rng)) + if __name__ == "__main__": unittest.main() diff --git a/torchax/interop.py b/torchax/interop.py index f4c77b0..ec9a64d 100644 --- a/torchax/interop.py +++ b/torchax/interop.py @@ -73,7 +73,13 @@ def set_one(module, prefix): class JittableModule(torch.nn.Module): - def __init__(self, m: torch.nn.Module, extra_jit_args=None, dedup_parameters=True): + def __init__( + self, + m: torch.nn.Module, + extra_jit_args=None, + dedup_parameters=True, + take_rng: bool = False, + ): if extra_jit_args is None: extra_jit_args = {} super().__init__() @@ -82,6 +88,7 @@ def __init__(self, m: torch.nn.Module, extra_jit_args=None, dedup_parameters=Tru self._jitted = {} self._extra_jit_args = extra_jit_args + self._take_rng = take_rng self._extra_dumped_weights = {} @@ -104,10 +111,20 @@ def __class__(self): return self._model.__class__ def __call__(self, *args, **kwargs): + if self._take_rng and "rng" not in kwargs: + raise TypeError("JittableModule(..., take_rng=True) requires a `rng` kwarg.") return self.forward(*args, **kwargs) def functional_call(self, method_or_name, params, buffers, *args, **kwargs): kwargs = kwargs or {} + rng = None + if self._take_rng: + if "rng" not in kwargs: + raise TypeError( + "JittableModule(..., take_rng=True) requires a `rng` kwarg in functional_call." + ) + kwargs = copy.copy(kwargs) + rng = _jax_view(kwargs.pop("rng")) params_copy = copy.copy(params) params_copy.update(buffers) # reinflate the state dict so there are not any missing keys @@ -124,8 +141,16 @@ def functional_call(self, method_or_name, params, buffers, *args, **kwargs): ) method = method_or_name args = (self._model,) + args - with torch_stateless._reparametrize_module(self._model, params_copy): - res = method(*args, **kwargs) + with torchax.default_env() as env: + if rng is None: + with torch_stateless._reparametrize_module(self._model, params_copy): + res = method(*args, **kwargs) + else: + with ( + env.override_property(prng=rng), + torch_stateless._reparametrize_module(self._model, params_copy), + ): + res = method(*args, **kwargs) return res def jittable_call(self, method_name: str, *args, **kwargs):