Powerpax

A small collection of utility functions for JAX.

See the reference section for complete documentation. These utilities build on functionality in JAX and include:

chunked_vmap

Limit the number of vectorized steps evaluated in parallel, which can reduce the peak memory consumption of vmap.

sliced_scan

Keep a subset of iterations from a scan without first storing all intermediate steps.

checkpoint_chunked_scan

Perform a scan, inserting checkpoints at regular intervals.

Static

Wrap a value to treat it as a static member of a pytree without defining a custom class.
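
The ideas behind these four utilities can be sketched in plain JAX. The code below is an illustration only, not Powerpax's implementation or API: every name in it (`chunked_vmap_sketch`, `strided_scan_sketch`, `checkpointed_scan_sketch`, `StaticValue`, and the helpers) is hypothetical, and it assumes single-array inputs whose leading axis divides evenly into chunks. See the reference section for the actual signatures.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def split_leading_axis(xs, size):
    # Simplifying assumption: the leading axis divides evenly into blocks.
    n = xs.shape[0]
    assert n % size == 0, "sketch requires an evenly divisible leading axis"
    return xs.reshape((n // size, size) + xs.shape[1:])


def merge_leading_axes(ys):
    return ys.reshape((ys.shape[0] * ys.shape[1],) + ys.shape[2:])


def chunked_vmap_sketch(f, xs, chunk_size):
    # Chunks run sequentially via lax.map; vmap vectorizes within each chunk,
    # so only chunk_size evaluations of f are batched together at a time.
    chunks = split_leading_axis(xs, chunk_size)
    return merge_leading_axes(jax.lax.map(jax.vmap(f), chunks))


def strided_scan_sketch(step, init, xs, stride):
    # Keep only every stride-th output. The inner scan holds at most
    # `stride` outputs before all but the last are discarded.
    def run_block(carry, block):
        carry, ys = jax.lax.scan(step, carry, block)
        return carry, ys[-1]

    return jax.lax.scan(run_block, init, split_leading_axis(xs, stride))


def checkpointed_scan_sketch(step, init, xs, chunk_size):
    # jax.checkpoint makes the backward pass recompute each chunk's
    # intermediates instead of storing them during the forward pass.
    @jax.checkpoint
    def run_chunk(carry, chunk):
        return jax.lax.scan(step, carry, chunk)

    carry, ys = jax.lax.scan(run_chunk, init, split_leading_axis(xs, chunk_size))
    return carry, merge_leading_axes(ys)


@tree_util.register_pytree_node_class
class StaticValue:
    # Hypothetical wrapper: the value is stored as pytree aux data (part of
    # the tree structure), so it never appears among the leaves.
    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value  # no children; the value must be hashable

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


def cumulative_tanh(carry, x):
    carry = jnp.tanh(carry + x)
    return carry, carry


xs = jnp.linspace(-1.0, 1.0, 12)
batch = jnp.arange(32.0).reshape(8, 4)
init = jnp.zeros(())

# Chunked vmap: same result as a full vmap, two rows at a time.
chunked = chunked_vmap_sketch(jnp.sum, batch, chunk_size=2)

# Strided scan: matches slicing the full scan output after the fact.
_, all_ys = jax.lax.scan(cumulative_tanh, init, xs)
_, kept_ys = strided_scan_sketch(cumulative_tanh, init, xs, stride=4)

# Checkpointed scan: gradients agree with the plain scan.
grad_plain = jax.grad(lambda c: jnp.sum(jax.lax.scan(cumulative_tanh, c, xs)[1]))(init)
grad_ckpt = jax.grad(
    lambda c: jnp.sum(checkpointed_scan_sketch(cumulative_tanh, c, xs, 3)[1])
)(init)


# Static wrapper: the string is usable in Python control flow under jit.
@jax.jit
def scale(params):
    factor = 2.0 if params["mode"].value == "double" else 1.0
    return factor * params["x"]


params = {"mode": StaticValue("double"), "x": jnp.ones(3)}
leaves = jax.tree_util.tree_leaves(params)
scaled = scale(params)
```

In each scan sketch the outer `jax.lax.scan` iterates over blocks while an inner scan handles the steps within a block. That nesting is one simple way to bound what is stored per step; the library itself may handle remainders, pytrees, and arbitrary slices differently.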