Powerpax
A small collection of utility functions for JAX.
See the reference section for complete documentation. These utilities build on functionality in JAX and include:
- chunked_vmap: Limit the number of vectorized steps evaluated in parallel, which can reduce the peak memory consumption of vmap.
- sliced_scan: Keep a subset of iterations from a scan without first storing all intermediate steps.
- checkpoint_chunked_scan: Perform a scan, inserting checkpoints at regular intervals.
- Static: Wrap a value to treat it as a static member of a pytree without defining a custom class.
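The techniques behind these utilities can be illustrated with standard JAX primitives (jax.vmap, jax.lax.map, jax.lax.scan, jax.checkpoint and jax.tree_util.register_pytree_node_class). The following is a minimal sketch, not the Powerpax implementation: the names chunked_vmap_sketch, sliced_scan_sketch, checkpoint_chunked_scan_sketch and StaticSketch, and all of their parameters, are hypothetical, and the real functions may take different arguments. The sketch assumes the vmapped function takes and returns a single array, that the scan length is a multiple of the slice interval, and that the input length is a multiple of the chunk size.

```python
import jax
import jax.numpy as jnp


# chunked_vmap: hypothetical sketch, not the Powerpax API.
def chunked_vmap_sketch(fn, chunk_size):
    """vmap `fn` over the leading axis, at most `chunk_size` elements at a time."""
    def wrapped(xs):
        n = xs.shape[0]
        pad = (-n) % chunk_size
        # Pad the leading axis with zeros so it splits evenly into chunks.
        padded = jnp.concatenate([xs, jnp.zeros((pad,) + xs.shape[1:], xs.dtype)])
        chunks = padded.reshape((-1, chunk_size) + xs.shape[1:])
        # lax.map runs the chunks one after another; vmap vectorizes within a chunk.
        out = jax.lax.map(jax.vmap(fn), chunks)
        # Merge the chunk axis back and drop the results for padded entries.
        return out.reshape((-1,) + out.shape[2:])[:n]
    return wrapped


# sliced_scan: hypothetical sketch, not the Powerpax API.
def sliced_scan_sketch(step, init, length, every):
    """Run `length` steps of `step(carry) -> carry`, keeping only the carry
    after every `every`-th step. Returns (final_carry, kept_carries)."""
    assert length % every == 0

    def inner(carry, _):
        return step(carry), None

    def outer(carry, _):
        # The inner scan emits no per-step outputs, only its final carry.
        carry, _ = jax.lax.scan(inner, carry, None, length=every)
        return carry, carry

    return jax.lax.scan(outer, init, None, length=length // every)


# checkpoint_chunked_scan: hypothetical sketch, not the Powerpax API.
def checkpoint_chunked_scan_sketch(f, init, xs, chunk_size):
    """Same result as lax.scan(f, init, xs), but each chunk of `chunk_size`
    steps is wrapped in jax.checkpoint."""
    n = xs.shape[0]
    assert n % chunk_size == 0
    xs_chunked = xs.reshape((n // chunk_size, chunk_size) + xs.shape[1:])

    @jax.checkpoint
    def run_chunk(carry, x_chunk):
        # With the default checkpoint policy, no intermediates of this chunk are
        # saved for reverse-mode AD; they are recomputed on the backward pass.
        return jax.lax.scan(f, carry, x_chunk)

    final, ys = jax.lax.scan(run_chunk, init, xs_chunked)
    # Flatten the (num_chunks, chunk_size, ...) outputs back to (n, ...).
    return final, ys.reshape((n,) + ys.shape[2:])


# Static: hypothetical sketch, not the Powerpax API.
@jax.tree_util.register_pytree_node_class
class StaticSketch:
    """Stores a hashable value as pytree aux data, so it contributes no leaves
    and jit treats it as static (a new value triggers a retrace)."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # No children (leaves); the wrapped value travels as aux data.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


@jax.jit
def apply(x, mode):
    # mode.value is a plain Python string at trace time, so Python branching works.
    return 2 * x if mode.value == "double" else x


# Usage
xs = jnp.arange(10.0)
print(chunked_vmap_sketch(jnp.sin, chunk_size=4)(xs))
final, kept = sliced_scan_sketch(lambda c: c + 1.0, jnp.float32(0.0), length=12, every=4)
print(kept)  # carries after steps 4, 8 and 12
print(apply(jnp.ones(3), StaticSketch("double")))
```

In this sketch, sliced_scan_sketch stores only length // every carries instead of one per step, and checkpoint_chunked_scan_sketch trades extra recomputation on the backward pass for storing only the carries at chunk boundaries.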