Reference

Useful utilities for JAX.

This package provides a small set of utilities for working with JAX. Examples include functions to limit the memory consumption of jax.vmap() (ppx.chunked_vmap) and to keep the outputs of only a subset of jax.lax.scan() iterations (ppx.sliced_scan).

Vectorization

powerpax.chunked_vmap(fun, chunk_size)

Like jax.vmap() but limited to batches of chunk_size steps.

This function behaves like jax.vmap() in that it vectorizes a function to apply over batches of inputs. However, unlike vmap, which carries out all calculations in parallel, this version performs at most chunk_size steps at once and uses jax.lax.scan() to loop over the chunks.

This is useful in cases where the calculations in fun involve large intermediate values which can exhaust available memory. With chunked_vmap it is possible to place an upper bound on peak memory use from the intermediate results while preserving some of the performance benefits of vmap, particularly on GPUs.
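The idea described above can be sketched with plain JAX. The helper chunked_vmap_sketch below is a hypothetical stand-in written for illustration, not powerpax's implementation: it assumes a single array argument whose leading axis is an exact multiple of chunk_size, vectorizes within each chunk using jax.vmap(), and loops over the chunks using jax.lax.scan().

```python
import jax
import jax.numpy as jnp


def chunked_vmap_sketch(fun, chunk_size):
    """Hypothetical illustration: vmap within chunks, scan across chunks.

    Assumes one array argument whose leading axis is divisible by chunk_size.
    """

    def wrapped(xs):
        n = xs.shape[0]
        assert n % chunk_size == 0, "sketch only handles exact multiples"
        # Split the batch axis into (number of chunks, chunk_size).
        chunks = xs.reshape((n // chunk_size, chunk_size) + xs.shape[1:])

        def body(carry, chunk):
            # Only chunk_size evaluations of fun are vectorized at once.
            return carry, jax.vmap(fun)(chunk)

        _, ys = jax.lax.scan(body, None, chunks)
        # Merge the chunk axes back into a single batch axis.
        return ys.reshape((n,) + ys.shape[2:])

    return wrapped


def outer_energy(x):
    # Builds a (d, d) intermediate for every input row.
    return jnp.sum(jnp.outer(x, x) ** 2)


xs = jnp.arange(12.0).reshape(6, 2)
full = jax.vmap(outer_energy)(xs)
chunked = chunked_vmap_sketch(outer_energy, chunk_size=2)(xs)
```

Both calls should produce the same values; the chunked version only materializes the intermediates for chunk_size rows at a time.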

Parameters:
  • fun (function) – Function to be mapped over additional axes.

  • chunk_size (int) – Upper limit on the size of chunks to be vectorized over. Inputs larger than this will be processed with an outer scan loop.

Returns:

A wrapped version of fun which vectorizes over leading axes of each input.

Return type:

function

Note

Unlike jax.vmap(), this function does not allow specifying in_axes or out_axes. It vectorizes all parameters over the first axis. Use jnp.moveaxis() and functools.partial() to map over other axes or to leave a parameter un-vectorized.
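The workaround in the note can be sketched as follows. To keep the example runnable with JAX alone, jax.vmap() with its default in_axes=0 is used as a stand-in for a mapper that always vectorizes every argument over the first axis; the function scale_and_sum and the array names are made up for illustration.

```python
import functools

import jax
import jax.numpy as jnp


def scale_and_sum(x, weights):
    return jnp.sum(x * weights)


weights = jnp.array([1.0, 2.0, 3.0])
batch_last = jnp.arange(12.0).reshape(3, 4)  # batch lives on axis 1

# Stand-in (assumption): a mapper that vectorizes every positional
# argument over axis 0, like jax.vmap with its default in_axes.
map_axis0 = jax.vmap

# Bind weights so it is not vectorized, and move the batch axis to the front.
fixed = functools.partial(scale_and_sum, weights=weights)
batch_first = jnp.moveaxis(batch_last, 1, 0)  # shape (4, 3)
out = map_axis0(fixed)(batch_first)

# Reference result using in_axes directly.
expected = jax.vmap(scale_and_sum, in_axes=(1, None))(batch_last, weights)
```

The same pattern should carry over to a wrapper that offers no in_axes argument, since both adjustments happen before the mapped function is called.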

Loops

powerpax.sliced_scan(f, init, xs, length=None, reverse=False, unroll=1, *, start=None, stop=None, step=None)

Slice the output of jax.lax.scan() without first collecting all iterations.

Using this function is equivalent to:

carry, ys = jax.lax.scan(f, init, xs, length, reverse, unroll)
ys = jax.tree.map(lambda leaf: leaf[start:stop:step], ys)

except that it does not first produce a complete ys. The loop is split into separate scan phases (nested as needed) to collect only the required steps. See jax.tree.map() for information on its effect in the above example.
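As a rough illustration of splitting the loop into phases, the sketch below handles only a non-negative start in a forward scan. It is written with plain JAX under those assumptions and is not powerpax's implementation: a first scan advances the carry while discarding outputs, and a second scan collects the outputs that the slice keeps.

```python
import jax
import jax.numpy as jnp


def running_sum(carry, x):
    carry = carry + x
    return carry, carry


xs = jnp.arange(1.0, 9.0)  # 1, 2, ..., 8
init = jnp.zeros(())
start = 5

# Reference: run the full scan, then slice the stacked outputs.
carry_ref, ys_full = jax.lax.scan(running_sum, init, xs)
ys_ref = jax.tree.map(lambda leaf: leaf[start:], ys_full)


def discard_y(carry, x):
    # Same carry update as running_sum, but no output is stored.
    carry, _ = running_sum(carry, x)
    return carry, None


# Phase 1: advance the carry through the skipped steps.
carry_mid, _ = jax.lax.scan(discard_y, init, xs[:start])
# Phase 2: collect outputs only for the steps inside the slice.
carry_out, ys = jax.lax.scan(running_sum, carry_mid, xs[start:])
```

The two-phase version never allocates the five skipped outputs, yet it should return the same carry and the same sliced ys as the reference.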

Most arguments are as in jax.lax.scan(). New parameters for this function are start, stop, and step.

Parameters:
  • f (function) – A function suitable for use with jax.lax.scan(). Namely, takes two arguments (a carry and x) and returns two values (an updated carry and y).

  • init (object) – A JAX pytree initializing the carry.

  • xs (object) – A JAX pytree over which to loop. If not None the loop scans over the leading dimension of each leaf array.

  • length (int, optional) – Integer specifying the number of iterations. Useful if xs is None. If both length and xs are provided, the implied loop iteration counts must match.

  • reverse (bool, optional) – If False (default) the loop proceeds in normal forward order. Otherwise the loop will start at the end of each input array in xs and fill ys from right to left.

  • unroll (int, optional) – An integer allowing greater loop unrolling. This function applies the unrolling to each internal scan, adjusted so that unroll is an upper bound on the number of unrolled steps for the innermost loop when scans are nested.

  • start (int, optional) – The index at which the slice starts.

  • stop (int, optional) – The index at which the slice stops (exclusive, as in leaf[start:stop:step]).

  • step (int, optional) – The step size for the slice.

Returns:

A tuple (carry, ys) where ys has been sliced by start, stop and step.

Return type:

object, object

Note

The slicing applies only to ys; the carry value is still updated by all loop iterations, even if their y outputs are skipped by the slice.
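The behaviour in the note can be illustrated with a nested scan in plain JAX. This is a sketch under simplifying assumptions (forward order, a positive step that divides the length exactly, start and stop left at their defaults) and not powerpax's implementation; count_and_emit is a made-up example function whose carry counts iterations.

```python
import jax
import jax.numpy as jnp


def count_and_emit(carry, x):
    # The carry counts how many iterations have run.
    return carry + 1, x * 10.0


xs = jnp.arange(6.0)
init = jnp.zeros((), jnp.int32)
step = 2

# Reference: full scan, then slice ys with [::step].
carry_ref, ys_full = jax.lax.scan(count_and_emit, init, xs)
ys_ref = ys_full[::step]


def discard_y(carry, x):
    carry, _ = count_and_emit(carry, x)
    return carry, None


def outer(carry, block):
    # Keep the output of the first iteration in each block of `step` items...
    carry, y = count_and_emit(carry, block[0])
    # ...and run the remaining iterations for their carry updates only.
    carry, _ = jax.lax.scan(discard_y, carry, block[1:])
    return carry, y


blocks = xs.reshape(-1, step)
carry_out, ys = jax.lax.scan(outer, init, blocks)
```

Here ys holds three entries while the final carry is 6, showing that every iteration updated the carry even though half of the outputs were never stored.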

powerpax.checkpoint_chunked_scan(f, init, xs, length=None, reverse=False, unroll=1, *, chunk_size=None)

Perform a scan inserting checkpoints every chunk_size steps.

This function performs a normal scan loop, but inserts checkpoints at regular intervals. This can reduce peak memory use (at the cost of recomputation) when computing gradients through the loop.

Most arguments are as in jax.lax.scan(). This function has one added parameter, chunk_size.

Parameters:
  • f (function) – A function suitable for use with jax.lax.scan(). Namely, takes two arguments (a carry and x) and returns two values (an updated carry and y).

  • init (object) – A JAX pytree initializing the carry.

  • xs (object) – A JAX pytree over which to loop. If not None the loop scans over the leading dimension of each leaf array.

  • length (int, optional) – Integer specifying the number of iterations. Useful if xs is None. If both length and xs are provided, the implied loop iteration counts must match.

  • reverse (bool, optional) – If False (default) the loop proceeds in normal forward order. Otherwise the loop will start at the end of each input array in xs and fill ys from right to left.

  • unroll (int, optional) – An integer allowing greater loop unrolling. This function applies the unrolling to each internal scan, adjusted so that unroll is an upper bound on the number of unrolled steps for the innermost loop when scans are nested.

  • chunk_size (int, optional) – The interval at which to insert checkpoints. Every chunk_size steps a checkpoint will be inserted, starting with the first step. If this parameter is not specified, the entire scan is treated as one chunk with one checkpoint at the start.

Returns:

A tuple (carry, ys).

Return type:

object, object
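One way to picture checkpointing every chunk_size steps is an outer scan over chunks whose body is wrapped in jax.checkpoint(). The helper below is a hypothetical sketch in plain JAX, not powerpax's implementation, and assumes a single array xs whose length is an exact multiple of chunk_size.

```python
import jax
import jax.numpy as jnp


def step_fn(carry, x):
    carry = jnp.tanh(carry * x)
    return carry, carry


def checkpoint_chunked_scan_sketch(f, init, xs, chunk_size):
    """Hypothetical illustration: checkpoint at the start of each chunk."""
    n = xs.shape[0]
    assert n % chunk_size == 0, "sketch only handles exact multiples"
    chunks = xs.reshape((n // chunk_size, chunk_size) + xs.shape[1:])

    @jax.checkpoint
    def run_chunk(carry, chunk):
        # Intermediates inside a chunk are recomputed during the backward pass.
        return jax.lax.scan(f, carry, chunk)

    carry, ys = jax.lax.scan(run_chunk, init, chunks)
    return carry, ys.reshape((n,) + ys.shape[2:])


xs = jnp.linspace(0.5, 1.5, 8)


def loss_plain(c0):
    _, ys = jax.lax.scan(step_fn, c0, xs)
    return jnp.sum(ys)


def loss_chunked(c0):
    _, ys = checkpoint_chunked_scan_sketch(step_fn, c0, xs, chunk_size=4)
    return jnp.sum(ys)


c0 = jnp.asarray(0.3)
g_plain = jax.grad(loss_plain)(c0)
g_chunked = jax.grad(loss_chunked)(c0)
```

The gradients should agree; the difference lies in what the backward pass stores versus recomputes, which is the trade-off described above.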

Pytrees

class powerpax.Static(value)

Treat a value as static when processing a pytree.

This class wraps value and instructs JAX to treat it as a static value during pytree processing.

Parameters:

value (object) – The object to wrap and treat as static (must be hashable).

value

The wrapped, static value.

Type:

object
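To show what treating a value as static means during pytree processing, the sketch below defines a hypothetical class StaticSketch with JAX's own pytree registration. It is an illustration of the concept, not powerpax's implementation: the wrapped value is stored as auxiliary data (which is why it must be hashable) and the node has no array leaves.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class StaticSketch:
    """Hypothetical wrapper: value becomes part of the tree structure."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # No children; the value travels as hashable auxiliary data.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


params = {"weights": jnp.ones(3), "activation": StaticSketch("relu")}

# Only the array is a leaf; the wrapped string is skipped by tree operations.
leaves = jax.tree_util.tree_leaves(params)
doubled = jax.tree.map(lambda leaf: 2 * leaf, params)


@jax.jit
def apply(p, x):
    h = p["weights"] * x
    # The static value is an ordinary Python object at trace time.
    return jnp.maximum(h, 0.0) if p["activation"].value == "relu" else h


out = apply(params, jnp.array([-1.0, 2.0, 3.0]))
```

Because the wrapped value is part of the tree structure rather than a leaf, it can pass through jax.jit() and jax.tree.map() without being traced or transformed.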