
BatchJAX

Description

BatchJAX is a library that allows JAX's vmap to be applied over Python lists and objax.ModuleList.
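The sketch below shows the general technique that makes this kind of batching possible in plain JAX. It is not BatchJAX's own API. A list of same-structured pytrees is stacked along a new leading axis, and jax.vmap maps over that axis. The helper names (stack_trees, unstack_tree, apply_model) and the toy linear model are hypothetical and exist only for illustration. The sketch does not cover objax.ModuleList.

```python
import jax
import jax.numpy as jnp

# Hypothetical list of per-model parameters; every entry has the same structure.
params_list = [
    {"w": jnp.full((3,), float(i)), "b": jnp.array(float(i))}
    for i in range(4)
]

def stack_trees(trees):
    """Stack a list of same-structured pytrees along a new leading axis (hypothetical helper)."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)

def unstack_tree(tree, n):
    """Split a stacked pytree back into a list of n pytrees (hypothetical helper)."""
    return [jax.tree_util.tree_map(lambda leaf: leaf[i], tree) for i in range(n)]

def apply_model(params, x):
    # Toy linear model: dot product of weights and input, plus bias.
    return jnp.dot(params["w"], x) + params["b"]

x = jnp.ones((3,))
stacked = stack_trees(params_list)

# Map over the leading axis of every leaf in `stacked`; `x` is shared (in_axes=None).
outputs = jax.vmap(apply_model, in_axes=(0, None))(stacked, x)

# Recover per-model parameters as a list again.
per_model = unstack_tree(stacked, len(params_list))
print(outputs)
```

Stacking trades a Python loop over models for a single vectorized call. The cost is that every list element must share the same tree structure and leaf shapes.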

Installation

pip install batchjax

Example

See batchjax_example.ipynb.