Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions test/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent.futures
import threading
import unittest

import torchax


class TestThreading(unittest.TestCase):
def test_access_config_thread(reraise):
def test_access_config_thread(self):
torchax.default_env()

def task():
with reraise:
print(torchax.default_env().param)
print(torchax.default_env().param)

threads = []
for _ in range(5):
Expand All @@ -35,6 +35,24 @@ def task():
for thread in threads:
thread.join()

def test_thread_safe_init(self):
# Force a reset to simulate pristine state
torchax._env = None

def task():
return torchax.default_env()

with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(task) for _ in range(32)]
results = [f.result() for f in futures]

# All threads should return the same environment object
assert len(results) > 0
lead = results[0]
for r in results:
self.assertIsNotNone(r)
self.assertIs(r, lead)


if __name__ == "__main__":
unittest.main()
42 changes: 31 additions & 11 deletions torchax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import contextlib
import dataclasses
import os
import threading
from contextlib import contextmanager
from typing import Any

Expand All @@ -40,6 +41,7 @@
"default_env",
"extract_jax",
"enable_globally",
"disable_globally",
"save_checkpoint",
"load_checkpoint",
]
Expand All @@ -55,15 +57,31 @@
)
# torchax:oss-end

env = None
_env: tensor.Environment | None = None
_env_lock = threading.Lock()


def default_env():
global env
def default_env() -> tensor.Environment:
"""Returns the default environment.

if env is None:
env = tensor.Environment()
return env
The (global) environment is constructed lazily on the first call,
with default configuration. Construct it manually for advanced
configuration.
"""
global _env

if _env is None:
# The first thread that enters this block will create the environment.
# Other threads will wait for the lock to be released and then return
# the environment.
with _env_lock:
if _env is not None:
return _env

_env = tensor.Environment()

assert _env is not None
return _env


def extract_jax(mod: torch.nn.Module, env=None, *, dedup_parameters=True):
Expand Down Expand Up @@ -94,13 +112,15 @@ def jax_func(states, args, kwargs=None):
return states, jax_func


def enable_globally():
env = default_env().enable_torch_modes()
return env
def enable_globally() -> None:
"""Enables torchax globally."""

default_env().enable_torch_modes()


def disable_globally() -> None:
"""Disables torchax globally."""

def disable_globally():
global env
default_env().disable_torch_modes()


Expand Down