Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions torchax/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import functools
from functools import wraps
from inspect import signature
from typing import TypeVar, overload

import jax
import jax.numpy as jnp
Expand All @@ -29,6 +30,9 @@
from torchax.ops import mappings
from torchax.types import JaxCallable, JaxValue, TorchCallable, TorchValue

_T = TypeVar("_T")


try:
from jax import shard_map as shard_map # for jax since v0.8.0
except ImportError:
Expand Down Expand Up @@ -177,6 +181,40 @@ def _torch_view(t: JaxValue) -> TorchValue:


torch_view = functools.partial(pytree.tree_map, _torch_view)
"""Tree-map "JAX to PyTorch" transformation.

Use torch_view_elem for better typing support if we don't need tree-map.
"""


@overload
def torch_view_elem(v: jax.Array) -> torch.Tensor: ...


@overload
def torch_view_elem(v: jnp.dtype) -> torch.dtype: ...


@overload
def torch_view_elem(v: JaxCallable) -> TorchCallable: ...


@overload
def torch_view_elem(v: _T) -> _T: ...


def torch_view_elem(v: JaxValue) -> TorchValue:
"""Apply "JAX to PyTorch" transformation.

* jax.Array -> torch.Tensor: tensor transform
* jnp.dtype -> torch.dtype: dtype transform
* JaxCallable -> TorchCallable: function transform
* T -> T: other regular types remain unchanged

Function transform is not very friendly for typing system.
Write a dedicate function if needed.
"""
return _torch_view(v)


def _jax_view(t: TorchValue) -> JaxValue:
Expand All @@ -196,6 +234,40 @@ def _jax_view(t: TorchValue) -> JaxValue:


jax_view = functools.partial(pytree.tree_map, _jax_view)
"""Tree-map "PyTorch to JAX" transformation.

Use jax_view_elem for better typing support if we don't need tree-map.
"""


@overload
def jax_view_elem(v: torch.Tensor) -> jax.Array: ...


@overload
def jax_view_elem(v: torch.dtype) -> jnp.dtype: ...


@overload
def jax_view_elem(v: TorchCallable) -> JaxCallable: ...


@overload
def jax_view_elem(v: _T) -> _T: ...


def jax_view_elem(v: TorchValue) -> JaxValue:
"""Apply "PyTorch to JAX" transformation.

* torch.Tensor -> jax.Array: tensor transform
* torch.dtype -> jnp.dtype: dtype transform
* TorchCallable -> JaxCallable: function transform
* T -> T: other regular types remain unchanged

Function transform is not very friendly for typing system.
Write a dedicate function if needed.
"""
return _jax_view(v)


def call_jax(
Expand Down