Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions test/test_jittable_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import unittest

import jax
import torch

import torchax
from torchax import interop


Expand Down Expand Up @@ -64,6 +67,74 @@ def outer_function(model, x):

assert torch.equal(output, expected_output)

def test_take_rng_requires_rng_kwarg(self):
class PlusOne(torch.nn.Module):
def forward(self, x):
return x + 1

jittable_module = interop.JittableModule(PlusOne(), take_rng=True)
with self.assertRaisesRegex(TypeError, "requires a `rng` kwarg"):
jittable_module(torch.ones(2, 2))

def test_take_rng_removes_rng_before_module_call(self):
class PlusOne(torch.nn.Module):
def forward(self, x):
return x + 1

jittable_module = interop.JittableModule(PlusOne(), take_rng=True)
x = torch.randn(2, 2)
expected = x + 1
output = jittable_module.functional_call(
"forward", jittable_module.params, jittable_module.buffers, x, rng=object()
)
assert torch.equal(output, expected)

def test_take_rng_controls_random_ops(self):
torchax.enable_globally()

class RandomOut(torch.nn.Module):
def forward(self, x):
return torch.randn_like(x)

model = RandomOut().to("jax")
jittable_module = interop.JittableModule(model, take_rng=True)
x = torch.ones(16, 16).to("jax")

same_rng_1 = jittable_module(x, rng=jax.random.PRNGKey(0))
same_rng_2 = jittable_module(x, rng=jax.random.PRNGKey(0))
different_rng = jittable_module(x, rng=jax.random.PRNGKey(1))

self.assertTrue(torch.equal(same_rng_1, same_rng_2))
self.assertFalse(torch.equal(same_rng_1, different_rng))

def test_take_rng_controls_random_ops_for_jitted_functional_call(self):
torchax.enable_globally()

class RandomOut(torch.nn.Module):
def forward(self, x):
return torch.randn_like(x)

model = RandomOut().to("jax")
jittable_module = interop.JittableModule(model, take_rng=True)
x = torch.ones(16, 16).to("jax")

jitted_functional = interop.jax_jit(
functools.partial(jittable_module.functional_call, "forward")
)

same_rng_1 = jitted_functional(
jittable_module.params, jittable_module.buffers, x, rng=jax.random.PRNGKey(0)
)
same_rng_2 = jitted_functional(
jittable_module.params, jittable_module.buffers, x, rng=jax.random.PRNGKey(0)
)
different_rng = jitted_functional(
jittable_module.params, jittable_module.buffers, x, rng=jax.random.PRNGKey(1)
)

self.assertTrue(torch.equal(same_rng_1, same_rng_2))
self.assertFalse(torch.equal(same_rng_1, different_rng))


if __name__ == "__main__":
unittest.main()
31 changes: 28 additions & 3 deletions torchax/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ def set_one(module, prefix):


class JittableModule(torch.nn.Module):
def __init__(self, m: torch.nn.Module, extra_jit_args=None, dedup_parameters=True):
def __init__(
self,
m: torch.nn.Module,
extra_jit_args=None,
dedup_parameters=True,
take_rng: bool = False,
):
if extra_jit_args is None:
extra_jit_args = {}
super().__init__()
Expand All @@ -82,6 +88,7 @@ def __init__(self, m: torch.nn.Module, extra_jit_args=None, dedup_parameters=Tru
self._jitted = {}

self._extra_jit_args = extra_jit_args
self._take_rng = take_rng

self._extra_dumped_weights = {}

Expand All @@ -104,10 +111,20 @@ def __class__(self):
return self._model.__class__

def __call__(self, *args, **kwargs):
if self._take_rng and "rng" not in kwargs:
raise TypeError("JittableModule(..., take_rng=True) requires a `rng` kwarg.")
return self.forward(*args, **kwargs)

def functional_call(self, method_or_name, params, buffers, *args, **kwargs):
kwargs = kwargs or {}
rng = None
if self._take_rng:
if "rng" not in kwargs:
raise TypeError(
"JittableModule(..., take_rng=True) requires a `rng` kwarg in functional_call."
)
kwargs = copy.copy(kwargs)
rng = _jax_view(kwargs.pop("rng"))
params_copy = copy.copy(params)
params_copy.update(buffers)
# reinflate the state dict so there are not any missing keys
Expand All @@ -124,8 +141,16 @@ def functional_call(self, method_or_name, params, buffers, *args, **kwargs):
)
method = method_or_name
args = (self._model,) + args
with torch_stateless._reparametrize_module(self._model, params_copy):
res = method(*args, **kwargs)
with torchax.default_env() as env:
if rng is None:
with torch_stateless._reparametrize_module(self._model, params_copy):
res = method(*args, **kwargs)
else:
with (
env.override_property(prng=rng),
torch_stateless._reparametrize_module(self._model, params_copy),
):
res = method(*args, **kwargs)
return res

def jittable_call(self, method_name: str, *args, **kwargs):
Expand Down