Powerpax
A small collection of utility functions for JAX.
See the reference section for complete documentation. These utilities build on functionality in JAX and include:
chunked_vmap
Limit the number of vectorized steps evaluated in parallel, which can reduce peak memory consumption with vmap.
sliced_scan
Keep a subset of iterations from a scan without first storing all intermediate steps.
checkpoint_chunked_scan
Perform a scan, inserting checkpoints at regular intervals.
Static
Wrap a value to treat it as a static member of a pytree without defining a custom class.
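To make the chunking and checkpointing ideas above concrete, here is a minimal sketch written in plain JAX. It does not use or reproduce Powerpax's API: the helper names `chunked_vmap_sketch` and `checkpointed_scan_sketch`, their signatures, and the requirement that the input length be divisible by `chunk_size` are all assumptions made for illustration. The first helper also assumes `f` returns a single array rather than a pytree. See the reference section for the actual functions and their arguments.

```python
import jax
import jax.numpy as jnp


def chunked_vmap_sketch(f, xs, chunk_size):
    """Apply f along the leading axis of xs, vectorizing chunk_size items at a time.

    Hypothetical helper, not Powerpax's API.
    Assumes xs.shape[0] is divisible by chunk_size and f returns one array.
    """
    n = xs.shape[0]
    chunks = xs.reshape(n // chunk_size, chunk_size, *xs.shape[1:])
    # lax.map iterates over chunks sequentially; vmap vectorizes within a chunk,
    # so only chunk_size evaluations of f are live at once.
    out = jax.lax.map(jax.vmap(f), chunks)
    return out.reshape(n, *out.shape[2:])


def checkpointed_scan_sketch(step, init, xs, chunk_size):
    """Run a scan in chunks, checkpointing at each chunk boundary.

    Hypothetical helper, not Powerpax's API.
    Assumes xs.shape[0] is divisible by chunk_size.
    """
    n = xs.shape[0]
    chunks = xs.reshape(n // chunk_size, chunk_size, *xs.shape[1:])

    @jax.checkpoint
    def run_chunk(carry, chunk):
        # Inner steps are recomputed during the backward pass instead of stored.
        return jax.lax.scan(step, carry, chunk)

    carry, ys = jax.lax.scan(run_chunk, init, chunks)
    return carry, ys.reshape(n, *ys.shape[2:])


def f(x):
    return jnp.sin(x) ** 2


def step(carry, x):
    carry = 0.9 * carry + x
    return carry, carry


xs = jnp.arange(8.0)
vmapped = chunked_vmap_sketch(f, xs, chunk_size=2)
final, history = checkpointed_scan_sketch(step, jnp.zeros(()), xs, chunk_size=4)
```

Both helpers return the same values as an unchunked `jax.vmap` or `jax.lax.scan`. The difference is in how much intermediate state is held in memory at once, which is the trade-off the utilities listed above are described as addressing.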