Conversation
Thanks for the PR. I've noted the changes.
Sorry for the late reply. I think the current design is not a good fit for this library; we need a more solid design that gives an easier and friendlier interface for handling randomness in a jit-ed func/Module. Here I propose another design and a minimal working version; maybe we can start from this? In PyTorch, the RNG state is an implicit global state that is manipulated behind the scenes, whereas JAX honors the pure-function design and requires no implicit state.

# we can call the model with only the activation when not jit-ed
out = model(activation)
# passing the rng explicitly is not easy to understand, and breaks the signature of forward
out = jittable_module(activation, rng=jax.random.PRNGKey(0))
# ideally, the jit-ed module should be used in the same way
out = jittable_module(activation)
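For intuition, the implicit-vs-explicit contrast can be shown with only Python's standard library (a toy analogy, not torchax code; the names here are invented for illustration): `random.seed` mutates a hidden global state, which is the PyTorch style, while threading an explicit generator through the call mimics JAX's pure-function style with PRNG keys.

```python
import random

# PyTorch-style: implicit global RNG state, mutated behind the scenes.
random.seed(0)
a = random.random()
random.seed(0)
b = random.random()
assert a == b  # same hidden state => same draw

# JAX-style: the RNG state is an explicit value the caller passes in.
def noisy_layer(x, rng):
    # all randomness flows through the rng argument; no hidden state
    return x + rng.random()

assert noisy_layer(1.0, random.Random(42)) == noisy_layer(1.0, random.Random(42))
```

Under the explicit style, reproducibility is a property of the value you pass in, which is why a jitted function can take the key as a traced argument instead of baking one in.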
I think a good, complete solution to the randomness problem should go further than this; here is a partial solution.

# in interop.py
# (relies on interop.py's existing imports: collections, copy, functools, torch,
# torch_stateless, jax_jit, torch_view, _jax_view, tensor, extract_all_buffers)


def _build_env() -> torchax.tensor.Environment:
    return torchax.default_env()


class JittableModule(torch.nn.Module):

    def __init__(
        self,
        m: torch.nn.Module,
        extra_jit_args=None,
        dedup_parameters=True,
        env_builder: Callable[[], tensor.Environment] = _build_env,
    ):
        if extra_jit_args is None:
            extra_jit_args = {}
        super().__init__()
        self.params, self.buffers = extract_all_buffers(m)
        self._model = m
        self._jitted = {}
        self._extra_jit_args = extra_jit_args
        self._env = env_builder()
        self._extra_dumped_weights = {}
        if dedup_parameters:
            temp = collections.defaultdict(list)
            for k, v in self.params.items():
                temp[id(v)].append(k)
            for v in temp.values():
                if len(v) > 1:
                    # duplicated weights with different names
                    self._extra_dumped_weights[v[0]] = v[1:]
                    for extra_key in v[1:]:
                        del self.params[extra_key]

    @property
    def __class__(self):
        # Lie about the class type so that
        # isinstance(jittable_module, self._model.__class__) works
        return self._model.__class__

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def functional_call(self, method_or_name, params, buffers, rng, *args, **kwargs):
        kwargs = kwargs or {}
        params_copy = copy.copy(params)
        params_copy.update(buffers)
        # reinflate the state dict so there are no missing keys
        for k, v in self._extra_dumped_weights.items():
            for new_key in v:
                params_copy[new_key] = params_copy[k]
        if isinstance(method_or_name, str):
            method = getattr(self._model, method_or_name)
        else:
            if not callable(method_or_name):
                raise TypeError(
                    f"method_or_name should be a callable or a string, got {type(method_or_name)}"
                )
            method = method_or_name
            args = (self._model,) + args
        with (
            self._env.override_property(prng=_jax_view(rng)),
            torch_stateless._reparametrize_module(self._model, params_copy),
        ):
            res = method(*args, **kwargs)
        return res

    def jittable_call(self, method_name: str, *args, **kwargs):
        if method_name not in self._jitted:
            jitted = jax_jit(
                functools.partial(self.functional_call, method_name),
                kwargs_for_jax_jit=self._extra_jit_args,
            )

            def jitted_forward(*args, **kwargs):
                rng = torch_view(self._env.param.prng)
                return jitted(self.params, self.buffers, rng, *args, **kwargs)

            self._jitted[method_name] = jitted_forward
        return self._jitted[method_name](*args, **kwargs)
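The lazy per-method caching in jittable_call can be sketched on its own. Below is a hypothetical stdlib-only sketch (class and attribute names invented, not torchax code) of the same pattern: build the expensive jitted wrapper once per method name, then reuse it on every later call.

```python
class LazyJitCache:
    """Caches one compiled wrapper per method name, like _jitted above."""

    def __init__(self, model):
        self._model = model
        self._jitted = {}
        self.compile_count = 0  # how many times we actually "compiled"

    def call(self, method_name, *args, **kwargs):
        if method_name not in self._jitted:
            # stands in for the expensive jax.jit trace in jittable_call
            self.compile_count += 1
            self._jitted[method_name] = getattr(self._model, method_name)
        return self._jitted[method_name](*args, **kwargs)


class Toy:
    def forward(self, x):
        return x * 2


cache = LazyJitCache(Toy())
assert cache.call("forward", 3) == 6
assert cache.call("forward", 4) == 8
assert cache.compile_count == 1  # compiled once, then reused
```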
# in test/test_jittable_module.py
class TestClass(unittest.TestCase):

    def test_take_rng_controls_random_ops(self):
        torchax.enable_globally()
        env = torchax.default_env()

        class RandomOut(torch.nn.Module):
            def forward(self, x):
                return torch.randn_like(x)

        model = RandomOut().to("jax")
        jittable_module = interop.JittableModule(model)
        x = torch.ones(16, 16).to("jax")
        with env.override_property(prng=jax.random.PRNGKey(0)):
            same_rng_1 = jittable_module(x)
        with env.override_property(prng=jax.random.PRNGKey(0)):
            same_rng_2 = jittable_module(x)
        with env.override_property(prng=jax.random.PRNGKey(2)):
            different_rng = jittable_module(x)
        self.assertTrue(torch.equal(same_rng_1, same_rng_2))
        self.assertFalse(torch.equal(same_rng_1, different_rng))

Note: here I use
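The dedup_parameters bookkeeping and the reinflation loop in functional_call form a round trip that can be shown in isolation. A hypothetical stdlib-only sketch (function names invented): entries sharing the same object are stored once, and the dropped alias names are restored before the call so no key is missing.

```python
import collections
import copy

def dedup(params):
    """Keep one entry per unique object; record the dropped alias names."""
    groups = collections.defaultdict(list)
    for name, value in params.items():
        groups[id(value)].append(name)
    deduped = dict(params)
    aliases = {}  # kept name -> list of dropped alias names
    for names in groups.values():
        if len(names) > 1:
            aliases[names[0]] = names[1:]
            for extra in names[1:]:
                del deduped[extra]
    return deduped, aliases

def reinflate(deduped, aliases):
    """Restore the dropped aliases so there are no missing keys."""
    full = copy.copy(deduped)
    for kept, dropped in aliases.items():
        for name in dropped:
            full[name] = full[kept]
    return full

# e.g. tied embedding / output-head weights share one tensor object
shared = object()
params = {"emb.weight": shared, "lm_head.weight": shared, "bias": object()}
deduped, aliases = dedup(params)
assert set(deduped) == {"emb.weight", "bias"}
restored = reinflate(deduped, aliases)
assert restored.keys() == params.keys()
assert restored["lm_head.weight"] is restored["emb.weight"]
```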
wdhongtw left a comment
Need some sort of change. See other comments.
Force-pushed 72f5236 to 7b28014
Hi @wdhongtw The changes you propose make sense, and I made them accordingly. Now forward doesn't change its signature but gets the rng from the env and passes it down to functional_call; and functional_call is what gets jitted, so it will not hardcode an rng.
Force-pushed 7b28014 to 88d24bc
wdhongtw left a comment
Still some questions. 🙂
        return self._model.__class__

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
Not quite sure about this removal of __call__; I think it resolved some issue around the backward pass? If so, I think this removal deserves a separate PR. 👍
            args = (self._model,) + args
        with torch_stateless._reparametrize_module(self._model, params_copy):
            res = method(*args, **kwargs)
        with self._env as env:
Is entering the Environment required here? Shouldn't we require users to activate the context explicitly through torchax.enable_globally or with torchax.default_env, as the usage guide in the README shows?
    @contextlib.contextmanager
    def with_rng(self, rng):
        with self._env.override_property(prng=_jax_view(rng)):
What's the expected type of rng here? Can we add an explicit type hint for this rng?
If it's a torch.Tensor, then the usage in the test file seems incorrect; if it's a jax.Array, then the _jax_view is redundant.
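For context, the override_property pattern discussed throughout this thread is a scoped attribute swap. A minimal stdlib sketch, assuming only that the environment stores the prng as a plain attribute (this toy Environment is invented for illustration, not torchax's actual class):

```python
import contextlib

class Environment:
    """Toy stand-in for an environment object with a scoped override."""

    def __init__(self):
        self.prng = None

    @contextlib.contextmanager
    def override_property(self, **overrides):
        # save current values, install overrides, restore on exit
        saved = {name: getattr(self, name) for name in overrides}
        for name, value in overrides.items():
            setattr(self, name, value)
        try:
            yield self
        finally:
            for name, value in saved.items():
                setattr(self, name, value)

env = Environment()
with env.override_property(prng="key-0"):
    assert env.prng == "key-0"
assert env.prng is None  # restored after the block
```

The try/finally guarantees the old value comes back even if the body raises, which is what makes nested overrides (as in the test above) safe.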
Add an option for JittableModule to take in an rng. This PR addresses the issue in #17, where models with dropout want different rng behavior on every run.