Reference
Useful utilities for JAX.
This package provides a small set of useful utilities for working with JAX, for example functions to limit the memory consumption of jax.vmap() (ppx.chunked_vmap) or to keep the outputs of only a subset of jax.lax.scan() iterations (ppx.sliced_scan).
Vectorization
- powerpax.chunked_vmap(fun, chunk_size)
Like jax.vmap() but limited to batches of chunk_size steps.
This function behaves like jax.vmap() in that it vectorizes a function to apply over batches of inputs. However, unlike vmap, which carries out all calculations in parallel, this version performs at most chunk_size steps at once and uses jax.lax.scan() to loop over the chunks.
This is useful when the calculations in fun involve large intermediate values that could exhaust available memory. With chunked_vmap it is possible to place an upper bound on the peak memory used by intermediate results while preserving some of the performance benefits of vmap, particularly on GPUs.
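The scan-over-chunks idea described above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not powerpax's implementation: chunked_vmap_sketch is a hypothetical name, it accepts positional arguments only, assumes a non-empty batch, and handles leftover rows (when the batch size is not a multiple of chunk_size) with one final ordinary vmap call.

```python
import jax
import jax.numpy as jnp


def chunked_vmap_sketch(fun, chunk_size):
    """Hypothetical sketch: vmap `fun` over the leading axis, chunk_size rows at a time."""
    batched = jax.vmap(fun)

    def wrapped(*args):
        n = jax.tree.leaves(args)[0].shape[0]  # assumes n > 0
        n_chunks = n // chunk_size
        n_full = n_chunks * chunk_size
        outs = []

        if n_chunks > 0:
            # (n_full, ...) -> (n_chunks, chunk_size, ...)
            chunks = jax.tree.map(
                lambda leaf: leaf[:n_full].reshape((n_chunks, chunk_size) + leaf.shape[1:]),
                args,
            )

            def body(carry, chunk_args):
                # Each scan step vectorizes only chunk_size rows at once.
                return carry, batched(*chunk_args)

            _, out = jax.lax.scan(body, None, chunks)
            # (n_chunks, chunk_size, ...) -> (n_full, ...)
            outs.append(
                jax.tree.map(lambda leaf: leaf.reshape((n_full,) + leaf.shape[2:]), out)
            )

        if n_full < n:
            # Leftover rows (fewer than chunk_size) get one final plain vmap call.
            outs.append(batched(*jax.tree.map(lambda leaf: leaf[n_full:], args)))

        if len(outs) == 1:
            return outs[0]
        return jax.tree.map(lambda a, b: jnp.concatenate([a, b]), *outs)

    return wrapped


def f(x, y):
    return jnp.sin(x) * y


xs = jnp.arange(10.0)
ys = jnp.arange(10.0) + 1.0
# Three scan steps of 3 rows each, then 1 leftover row.
out = chunked_vmap_sketch(f, chunk_size=3)(xs, ys)
```

Peak memory for intermediates inside fun then scales with chunk_size rather than with the full batch, at the cost of running the chunks sequentially.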
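- Parameters:
fun (function) – The function to vectorize.
chunk_size (int) – The maximum number of steps to compute in parallel at once.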
- Returns:
A wrapped version of fun which vectorizes over leading axes of each input.
- Return type:
function
Note
Unlike jax.vmap(), this function does not allow specifying in_axes or out_axes. It vectorizes all parameters over their first axis. Use jnp.moveaxis and functools.partial() to map over other axes or to leave a parameter un-vectorized.
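The partial/moveaxis pattern from the note above can be illustrated as follows. So that the snippet runs without powerpax installed, it uses jax.vmap with its default in_axes=0 as a stand-in (leading_map, a hypothetical name), since that also maps every argument over its leading axis; powerpax.chunked_vmap(fun, chunk_size) would be substituted in the same position. The example emulates in_axes=(None, 1) and out_axes=1.

```python
import functools

import jax
import jax.numpy as jnp

# Stand-in for a mapper that vectorizes every argument over its leading axis.
# With powerpax installed, one could instead use:
#     leading_map = lambda fun: powerpax.chunked_vmap(fun, 2)
leading_map = jax.vmap


def matvec(matrix, v):
    return matrix @ v


matrix = jnp.arange(12.0).reshape(3, 4)  # shared argument, not vectorized
vs = jnp.arange(20.0).reshape(4, 5)      # five vectors stored along axis 1

# functools.partial binds `matrix`, so only `v` is mapped.
fun = functools.partial(matvec, matrix)
# jnp.moveaxis brings the mapped axis (axis 1) of `vs` to the front.
out = leading_map(fun)(jnp.moveaxis(vs, 1, 0))  # shape (5, 3)
# Moving the output's mapped axis back emulates out_axes=1.
out_t = jnp.moveaxis(out, 0, 1)                 # shape (3, 5)
```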
Loops
- powerpax.sliced_scan(f, init, xs, length=None, reverse=False, unroll=1, *, start=None, stop=None, step=None)
Slice the output of jax.lax.scan() without first collecting all iterations.
Using this function is equivalent to:
    carry, ys = jax.lax.scan(f, init, xs, length, reverse, unroll)
    ys = jax.tree.map(lambda leaf: leaf[start:stop:step], ys)
except that it does not first produce a complete ys. Instead, the loop is split into separate scan phases (nested as needed) so that only the required steps are collected. See jax.tree.map() for information on its effect in the example above.
Most arguments are as in jax.lax.scan(). The new parameters for this function are start, stop, and step.
- Parameters:
f (function) – A function suitable for use with jax.lax.scan(). Namely, it takes two arguments (a carry and x) and returns two values (an updated carry and y).
init (object) – A JAX pytree initializing the carry.
xs (object) – A JAX pytree over which to loop. If not None, the loop scans over the leading dimension of each leaf array.
length (int, optional) – Integer specifying the number of iterations. Useful if xs is None. If both length and xs are provided, the implied loop iteration counts must match.
reverse (bool, optional) – If False (default), the loop proceeds in normal forward order. Otherwise, the loop starts at the end of each input array in xs and fills ys from right to left.
unroll (int, optional) – An integer allowing greater loop unrolling. Note that this function applies the unrolling to each internal scan, adjusting it so that unroll provides an upper bound on the number of unrolled steps in the innermost loop when scans are nested.
start (int, optional) – The index at which the slice starts.
stop (int, optional) – The index at which the slice stops (exclusive, as in Python slicing).
step (int, optional) – The step size for the slice.
- Returns:
A tuple (carry, ys) where ys has been sliced by start, stop, and step.
- Return type:
tuple
Note
The slicing applies only to ys; the carry value is still updated by all loop iterations, even those whose y outputs are skipped by the slice.
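The phase-splitting idea can be sketched in plain JAX for the simplest case. This is an illustration under stated assumptions, not powerpax's implementation: sliced_scan_sketch and its helpers are hypothetical names, and the sketch supports only forward loops with xs not None, a positive step, and a slice that selects at least one iteration (length, reverse, and unroll are not handled). It runs the steps before the first selected index without storing ys, then an outer scan over chunks of step iterations that keeps only one y per chunk, then the trailing steps, so the carry still sees every iteration.

```python
import jax
import jax.numpy as jnp


def _num_steps(xs):
    # Number of loop iterations implied by the leading axis of xs.
    return jax.tree.leaves(xs)[0].shape[0]


def _take(xs, index):
    # Index (or slice) every leaf of the pytree xs along its leading axis.
    return jax.tree.map(lambda leaf: leaf[index], xs)


def _scan_discard(f, carry, xs):
    # Run f over the leading axis of xs, keeping only the final carry.
    if _num_steps(xs) == 0:
        return carry

    def body(c, x):
        c, _ = f(c, x)
        return c, None  # drop y so scan does not stack it

    carry, _ = jax.lax.scan(body, carry, xs)
    return carry


def sliced_scan_sketch(f, init, xs, start=None, stop=None, step=None):
    # Hypothetical sketch: forward order only, xs must not be None,
    # step must be positive, and the slice must select at least one step.
    length = _num_steps(xs)
    start, stop, step = slice(start, stop, step).indices(length)
    kept = range(start, stop, step)
    if step <= 0 or len(kept) == 0:
        raise NotImplementedError("sketch only handles non-empty forward slices")
    n_out, last = len(kept), kept[-1]

    # Phase 1: iterations before the first selected one; their ys are dropped.
    carry = _scan_discard(f, init, _take(xs, slice(0, start)))
    # Phase 2: the first selected iteration.
    carry, y_first = f(carry, _take(xs, start))
    ys = jax.tree.map(lambda leaf: leaf[None], y_first)

    if n_out > 1:
        # Phase 3: n_out - 1 chunks of `step` iterations; in each chunk the
        # first step - 1 ys are dropped and only the last one is kept.
        def chunk_body(c, chunk_xs):
            c = _scan_discard(f, c, _take(chunk_xs, slice(0, step - 1)))
            return f(c, _take(chunk_xs, step - 1))

        middle = _take(xs, slice(start + 1, last + 1))
        chunks = jax.tree.map(
            lambda leaf: leaf.reshape((n_out - 1, step) + leaf.shape[1:]), middle
        )
        carry, ys_rest = jax.lax.scan(chunk_body, carry, chunks)
        ys = jax.tree.map(lambda a, b: jnp.concatenate([a, b]), ys, ys_rest)

    # Phase 4: the remaining iterations, so the carry sees every step.
    carry = _scan_discard(f, carry, _take(xs, slice(last + 1, length)))
    return carry, ys


def f(carry, x):
    carry = carry + x
    return carry, 2.0 * carry


xs = jnp.arange(10.0)
init = jnp.zeros((), jnp.float32)
# carry == 45.0 and ys == [6., 20., 42., 72.]
carry, ys = sliced_scan_sketch(f, init, xs, start=2, stop=9, step=2)
```

The result matches the full-scan-then-slice equivalence shown earlier, but at most one y per chunk is ever stacked.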
- powerpax.checkpoint_chunked_scan(f, init, xs, length=None, reverse=False, unroll=1, *, chunk_size=None)
Perform a scan inserting checkpoints every chunk_size steps.
This function performs a normal scan loop but inserts checkpoints at regular intervals. This can reduce peak memory use (at the cost of recomputation) when computing gradients through the loop.
Most arguments are as in jax.lax.scan(). This function has one added parameter, chunk_size.
- Parameters:
f (function) – A function suitable for use with jax.lax.scan(). Namely, it takes two arguments (a carry and x) and returns two values (an updated carry and y).
init (object) – A JAX pytree initializing the carry.
xs (object) – A JAX pytree over which to loop. If not None, the loop scans over the leading dimension of each leaf array.
length (int, optional) – Integer specifying the number of iterations. Useful if xs is None. If both length and xs are provided, the implied loop iteration counts must match.
reverse (bool, optional) – If False (default), the loop proceeds in normal forward order. Otherwise, the loop starts at the end of each input array in xs and fills ys from right to left.
unroll (int, optional) – An integer allowing greater loop unrolling. Note that this function applies the unrolling to each internal scan, adjusting it so that unroll provides an upper bound on the number of unrolled steps in the innermost loop when scans are nested.
chunk_size (int, optional) – The interval at which to insert checkpoints. A checkpoint is inserted every chunk_size steps, starting with the first step. If this parameter is not specified, the entire scan is treated as one chunk with a single checkpoint at the start.
- Returns:
A tuple (carry, ys).
- Return type:
tuple
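A common way to get this behavior in plain JAX is a nested scan: split the loop into chunks of chunk_size steps, wrap the per-chunk inner scan in jax.checkpoint, and drive the chunks with an outer jax.lax.scan. The sketch below illustrates that pattern under stated assumptions; it is not necessarily how powerpax implements it. checkpoint_chunked_scan_sketch is a hypothetical name, the loop length is assumed to be a multiple of chunk_size, and length, reverse, and unroll are not handled.

```python
import jax
import jax.numpy as jnp


def checkpoint_chunked_scan_sketch(f, init, xs, chunk_size):
    # Hypothetical sketch: xs is a pytree of arrays whose leading dimension
    # is divisible by chunk_size; no length/reverse/unroll support.
    length = jax.tree.leaves(xs)[0].shape[0]
    n_chunks = length // chunk_size
    assert n_chunks * chunk_size == length, "sketch needs an exact split"

    # (length, ...) -> (n_chunks, chunk_size, ...)
    chunked_xs = jax.tree.map(
        lambda leaf: leaf.reshape((n_chunks, chunk_size) + leaf.shape[1:]), xs
    )

    @jax.checkpoint
    def run_chunk(carry, chunk):
        # Under jax.checkpoint only this function's inputs (the carry entering
        # the chunk and the chunk's xs) are kept as residuals; values computed
        # inside the chunk are recomputed on the backward pass.
        return jax.lax.scan(f, carry, chunk)

    carry, ys = jax.lax.scan(run_chunk, init, chunked_xs)
    # (n_chunks, chunk_size, ...) -> (length, ...)
    ys = jax.tree.map(lambda leaf: leaf.reshape((length,) + leaf.shape[2:]), ys)
    return carry, ys


def f(carry, x):
    carry = jnp.tanh(carry * x + 1.0)
    return carry, carry ** 2


xs = jnp.linspace(0.1, 1.0, 12)
init = jnp.asarray(0.5, dtype=jnp.float32)


def loss_ckpt(c0):
    carry, ys = checkpoint_chunked_scan_sketch(f, c0, xs, chunk_size=4)
    return carry + ys.sum()


def loss_ref(c0):
    carry, ys = jax.lax.scan(f, c0, xs)
    return carry + ys.sum()


carry, ys = checkpoint_chunked_scan_sketch(f, init, xs, chunk_size=4)
grad_ckpt = jax.grad(loss_ckpt)(init)  # should match jax.grad(loss_ref)(init)
```

The forward results and gradients are unchanged; only the trade-off between stored intermediates and recomputation differs.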