
add rng option to jittablemodule #85

Open

qihqi wants to merge 3 commits into google:main from qihqi:codex/add-rng-option-to-jittablemodule

Conversation

@qihqi
Collaborator

@qihqi qihqi commented Mar 31, 2026

Add an option for JittableModule to take in an rng. This PR addresses the issue in #17, where models with dropout want different rng behavior on every run.

@qihqi qihqi requested review from QiliangCui and wdhongtw March 31, 2026 00:34
@wdhongtw
Collaborator

Thanks for the PR; I've noted the changes.
I need a few business days for a thorough review and will get back to it soon.

@wdhongtw
Collaborator

wdhongtw commented Apr 7, 2026

Sorry for the late reply.

I don't think this design direction fits the library well. We need a more solid design with an easier, friendlier interface for handling randomness in jit-ed functions/Modules.

Here I propose another design and a minimal working version, maybe we can start from this?

@qihqi
cc @QiliangCui


For PyTorch, the RNG state is an implicit global state, manipulated by torch.manual_seed and consumed by the various torch functions that require an RNG source.

JAX, by contrast, honors a pure-function design and allows no implicit state. torchax should be a bridge aligned with PyTorch's design, filling the gap between PyTorch and JAX. Based on this, I don't think it's a good idea to introduce a new rng keyword argument on the wrapped module when it needs to be jit-ed: it reduces the usability of the torchax bridge, since it becomes harder to write device-agnostic applications. (Well... it's still an improvement, as the functionality before this PR is broken.)

# we can call the model with only activation when not jit-ed
out = model(activation)

# It's not easy to understand, and breaks the signature of *forward*.
out = jittable_module(activation, rng=jax.random.PRNGKey(0))

# ideally, the jit-ed module should be used in the same way.
out = jittable_module(activation)

I think the good/complete solution to address the randomness problem should

  • The free function (jax_jit) and nn.Module (JittableModule) should be used in the same way before and after JAX JIT
    • Once the func/module is constructed, the usage remains the same.
  • The seed is traced properly, so that whether JAX JIT is applied is transparent to any computation using rand-related APIs, and the computation is the same before and after JAX JIT
    • No seed yields a different value each time; the same seed yields a deterministic value
    • (Nice to have) works properly both when JAX JIT is applied and when JAX_DISABLE_JIT=True
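The environment-scoped RNG pattern described by the bullets above can be sketched in plain Python. This toy `Environment`/`override_property` pair only mirrors the torchax names for illustration; it is a self-contained stand-in, not torchax's actual implementation:

```python
import contextlib
import random


class Environment:
    """Toy stand-in for an environment holding an implicit RNG state."""

    def __init__(self):
        self.prng = None  # None -> fresh randomness on every call

    @contextlib.contextmanager
    def override_property(self, **overrides):
        # Temporarily swap properties (here: the prng seed), then restore.
        saved = {k: getattr(self, k) for k in overrides}
        try:
            for k, v in overrides.items():
                setattr(self, k, v)
            yield self
        finally:
            for k, v in saved.items():
                setattr(self, k, v)


def randn_like_stub(env, n):
    # Seeded -> deterministic sequence; unseeded -> different values each call.
    rng = random.Random(env.prng) if env.prng is not None else random.Random()
    return [rng.random() for _ in range(n)]


env = Environment()
with env.override_property(prng=0):
    a = randn_like_stub(env, 4)
with env.override_property(prng=0):
    b = randn_like_stub(env, 4)
with env.override_property(prng=2):
    c = randn_like_stub(env, 4)

assert a == b   # same seed -> same values
assert a != c   # different seed -> different values
```

The caller's code is identical with or without seeding; only the surrounding context changes, which is the property the bullets ask for.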

Here is a partial solution.
I try to capture the RNG state (in env) into JittableModule and include the RNG state as part of the jit-ed function.

# in interop.py

def _build_env() -> torchax.tensor.Environment:
  return torchax.default_env()


class JittableModule(torch.nn.Module):
  def __init__(
    self,
    m: torch.nn.Module,
    extra_jit_args=None,
    dedup_parameters=True,
    env_builder: Callable[[], tensor.Environment] = _build_env,
  ):
    if extra_jit_args is None:
      extra_jit_args = {}
    super().__init__()
    self.params, self.buffers = extract_all_buffers(m)
    self._model = m
    self._jitted = {}

    self._extra_jit_args = extra_jit_args
    self._env = env_builder()

    self._extra_dumped_weights = {}

    if dedup_parameters:
      temp = collections.defaultdict(list)
      for k, v in self.params.items():
        temp[id(v)].append(k)

      for v in temp.values():
        if len(v) > 1:
          # duplicated weights with different name
          self._extra_dumped_weights[v[0]] = v[1:]
          for extra_keys in v[1:]:
            del self.params[extra_keys]

  @property
  def __class__(self):
    # Lie about the class type so that
    # isinstance(jittable_module, self._model.__class__) works
    return self._model.__class__

  def __call__(self, *args, **kwargs):
    return self.forward(*args, **kwargs)

  def functional_call(self, method_or_name, params, buffers, rng, *args, **kwargs):
    kwargs = kwargs or {}
    params_copy = copy.copy(params)
    params_copy.update(buffers)
    # reinflate the state dict so there are not any missing keys
    for k, v in self._extra_dumped_weights.items():
      for new_key in v:
        params_copy[new_key] = params_copy[k]

    if isinstance(method_or_name, str):
      method = getattr(self._model, method_or_name)
    else:
      if not callable(method_or_name):
        raise TypeError(
          f"method_or_name should be a callable or a string, got {type(method_or_name)}"
        )
      method = method_or_name
      args = (self._model,) + args
    with (
      self._env.override_property(prng=_jax_view(rng)),
      torch_stateless._reparametrize_module(self._model, params_copy)
    ):
      res = method(*args, **kwargs)
    return res

  def jittable_call(self, method_name: str, *args, **kwargs):
    if method_name not in self._jitted:
      func = functools.partial(self.functional_call, method_name)
      jitted = jax_jit(
        func,
        kwargs_for_jax_jit=self._extra_jit_args,
      )

      def jitted_forward(*args, **kwargs):
        rng = torch_view(self._env.param.prng)
        return jitted(self.params, self.buffers, rng, *args, **kwargs)

      self._jitted[method_name] = jitted_forward
    return self._jitted[method_name](*args, **kwargs)


# in test/test_jittable_module.py
class TestClass(unittest.TestCase):

  def test_take_rng_controls_random_ops(self):
    torchax.enable_globally()
    env = torchax.default_env()

    class RandomOut(torch.nn.Module):
      def forward(self, x):
        return torch.randn_like(x)

    model = RandomOut().to("jax")
    jittable_module = interop.JittableModule(model)
    x = torch.ones(16, 16).to("jax")

    with env.override_property(prng=jax.random.PRNGKey(0)):
      same_rng_1 = jittable_module(x)
    with env.override_property(prng=jax.random.PRNGKey(0)):
      same_rng_2 = jittable_module(x)
    with env.override_property(prng=jax.random.PRNGKey(2)):
      different_rng = jittable_module(x)

    self.assertTrue(torch.equal(same_rng_1, same_rng_2))
    self.assertFalse(torch.equal(same_rng_1, different_rng))

Note: here I use env.override_property(prng=...) to seed manually, as there seems to be no way to register a handler for torch.manual_seed. Please correct me if I'm wrong.
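For reference, the kind of hook that would make torch.manual_seed feed the environment could look roughly like this. This is purely hypothetical: `register_seed_hook` and the module-level `manual_seed` below are stand-in names, and nothing in torchax currently exposes such an API:

```python
# Toy illustration: route a torch.manual_seed-style call into an
# environment's prng state via a registered hook.
_seed_hooks = []


def register_seed_hook(fn):
    # Hypothetical registration point; torchax has no such API today.
    _seed_hooks.append(fn)


def manual_seed(seed):
    # Stand-in for torch.manual_seed: notify every registered hook.
    for hook in _seed_hooks:
        hook(seed)


class Env:
    prng = None


env = Env()
register_seed_hook(lambda seed: setattr(env, "prng", seed))

manual_seed(42)
assert env.prng == 42  # the environment now sees the user's seed
```

With a hook like this, users would keep calling the familiar torch.manual_seed and the bridge would translate it into environment state behind the scenes.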

Collaborator

@wdhongtw wdhongtw left a comment


Need some sort of change. See other comments.

@qihqi qihqi force-pushed the codex/add-rng-option-to-jittablemodule branch 2 times, most recently from 72f5236 to 7b28014 Compare April 11, 2026 00:28
@qihqi
Collaborator Author

qihqi commented Apr 11, 2026

Hi @wdhongtw

The changes you proposed make sense, and I made them accordingly. forward no longer changes its signature; instead it gets the rng from the env and passes it down to functional_call, and functional_call is what gets jitted, so it won't hardcode an rng.
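The key point here, reading the rng from the environment at call time rather than baking it in when the jitted wrapper is built, can be seen in a minimal closure. The names below (`make_jitted_forward`, `Env`) are illustrative stand-ins, not torchax's:

```python
class Env:
    def __init__(self):
        self.prng = 0


def make_jitted_forward(env, jitted):
    # The closure holds a reference to env, not a snapshot of env.prng,
    # so later seed changes are picked up on every call.
    def jitted_forward(x):
        return jitted(env.prng, x)
    return jitted_forward


env = Env()
forward = make_jitted_forward(env, lambda rng, x: (rng, x))

assert forward(1) == (0, 1)  # uses the current seed, 0
env.prng = 7                 # reseeding after construction...
assert forward(1) == (7, 1)  # ...is visible without rebuilding forward
```

Because the rng is also an explicit argument to the jitted callable, JAX traces it as an input rather than a constant.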

@qihqi qihqi force-pushed the codex/add-rng-option-to-jittablemodule branch from 7b28014 to 88d24bc Compare April 11, 2026 00:33
Collaborator

@wdhongtw wdhongtw left a comment


Still some questions. 🙂‍↕️

return self._model.__class__

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
Collaborator


Not quite sure about this removal of __call__; I think it resolved some issue around the backward pass?

If so, I think this removal deserves a separate PR. 👍

args = (self._model,) + args
with torch_stateless._reparametrize_module(self._model, params_copy):
res = method(*args, **kwargs)
with self._env as env:
Collaborator


Is entering the Environment required here? Shouldn't we require users to activate the context explicitly, through torchax.enable_globally or with torchax.default_env(), as the usage guide in the README suggests?


@contextlib.contextmanager
def with_rng(self, rng):
with self._env.override_property(prng=_jax_view(rng)):
Collaborator


What's the expected type of rng here? Can we add explicit type hint for this rng?

If it's a torch.Tensor, then the usage in the test file seems incorrect. If it's a jax.Array, then the _jax_view is redundant.
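If the intended type is jax.Array, the annotated version might look like the sketch below. Everything here (`_Env`, `JittableModuleSketch`) is an illustrative stand-in, and the annotation is written as a string so the snippet runs without JAX installed:

```python
import contextlib


class _Env:
    # Minimal stand-in for torchax's Environment (illustrative only).
    prng = None

    @contextlib.contextmanager
    def override_property(self, **kw):
        saved = {k: getattr(self, k) for k in kw}
        try:
            for k, v in kw.items():
                setattr(self, k, v)
            yield
        finally:
            for k, v in saved.items():
                setattr(self, k, v)


class JittableModuleSketch:
    def __init__(self):
        self._env = _Env()

    @contextlib.contextmanager
    def with_rng(self, rng: "jax.Array"):
        # If rng is already a jax.Array, no _jax_view conversion is needed.
        with self._env.override_property(prng=rng):
            yield


m = JittableModuleSketch()
with m.with_rng("stand-in-key"):  # a real key would be jax.random.PRNGKey(0)
    assert m._env.prng == "stand-in-key"
assert m._env.prng is None  # the previous prng is restored on exit
```

Making the expected type explicit in the signature would settle which of the two readings above is intended.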
