Does Torchax have an advantage over Torch2Jax (https://github.com/rdyro/torch2jax) if we want to call PyTorch functionality from within a JIT-compiled JAX program?
Would the way Torchax works cause issues if one wanted to pass PyTorch tensors to PyKeOps from within a JittableModule?
https://www.kernel-operations.io/keops/python/LazyTensor.html