Pytrees

Author

Marie-Hélène Burle

It is convenient to store data, model parameters, gradients, etc. in container structures such as lists or dicts. JAX has a container-like structure, the pytree, that is flexible, can be nested, and is supported by many JAX functions, which makes for convenient workflows.

This section introduces pytrees and how they work.

A tree-like structure

The pytree container registry contains, by default, lists, tuples, and dicts. It can be extended to other containers.

Any object whose type is in the pytree container registry is a pytree, and its children are themselves pytrees (the definition is recursive). Any object whose type is not in the registry is treated as a leaf.

Pytrees are great for holding data and parameters, keeping everything organized, even for complex models. The leaves are usually made of arrays. Many JAX functions can be applied to pytrees.
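As a sketch of how the registry can be extended, the example below registers a hypothetical Pair class (invented for illustration, not part of JAX) with jax.tree_util.register_pytree_node_class. Once registered, JAX treats Pair objects as containers rather than leaves. The example uses jax.tree.leaves, which is covered in the section on extracting leaves.

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Pair:
  """Hypothetical container with two children (illustration only)."""

  def __init__(self, left, right):
    self.left = left
    self.right = right

  def tree_flatten(self):
    # Return the children and some static auxiliary data (none here)
    return (self.left, self.right), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # Rebuild the container from its children
    return cls(*children)


pair = Pair(1.0, [2.0, 3.0])
leaves = jax.tree.leaves(pair)
print(leaves)  # [1.0, 2.0, 3.0]
```

Without the registration, the whole Pair object would be returned as a single leaf.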

Examples of pytrees:

(1, 2, 3),
[1, 1., "string", True],
jnp.arange(2),
{'key1': 3.4, 'key2': 6.},
[3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}],
(3, 2, (6, 0), 2, ()),
jnp.zeros(3)

Let’s exit our previous interactive job, which used a GPU:

exit

Then start an interactive job with a CPU:

salloc --time=2:0:0 --mem-per-cpu=5500M

Load the ipython module:

module load ipython-kernel/3.11

Activate the Python virtual environment:

source /project/60055/env/bin/activate

Launch IPython:

ipython

Extracting leaves

Trees can be flattened and their leaves extracted into a list with jax.tree.leaves:

jax.tree.leaves([3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}])
[3.0, 1, 2, 'val1', 'val2', 'val3']

Let’s create a list of pytrees and extract their leaves to look at more examples:

import jax
import jax.numpy as jnp

list_trees = [
    (1, 2, 3),
    [1, 1., "string", True],
    jnp.arange(2),
    {'key1': 3.4, 'key2': 6.},
    [3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}],
    (3, 2, (6, 0, 9), 2, ()),
    jnp.zeros(3)
    ]

for pytree in list_trees:
  leaves = jax.tree.leaves(pytree)
  print(f"{len(leaves)} leaves: {leaves}")
3 leaves: [1, 2, 3]
4 leaves: [1, 1.0, 'string', True]
1 leaves: [Array([0, 1], dtype=int32)]
2 leaves: [3.4, 6.0]
6 leaves: [3.0, 1, 2, 'val1', 'val2', 'val3']
6 leaves: [3, 2, 6, 0, 9, 2]
1 leaves: [Array([0., 0., 0.], dtype=float32)]

Note that leaves are not the same as container elements:

  • while an array contains many elements, it is a single leaf,
  • while a nested list or tuple represents a single element of the parent container, each of its elements is a separate leaf,
  • an empty tuple or list is a pytree without children and is not counted as a leaf.

Contrast this with the length (i.e. the number of elements of containers):

for pytree in list_trees:
  print(f"{len(pytree)} elements")
3 elements
4 elements
2 elements
2 elements
3 elements
5 elements
3 elements

Structure of pytrees

As we just saw, JAX can extract the leaves of pytrees. This is useful to run functions on them. But JAX also records their structure and is able to recreate them. The structure can be obtained with jax.tree.structure:

jax.tree.structure([3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}])
PyTreeDef([*, (*, *), {'key1': *, 'key2': *, 'key3': *}])

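Each pytree can thus be split into a tuple holding the list of its leaves and its structure with jax.tree.flatten, and the pytree can be rebuilt from these two pieces with jax.tree.unflatten (note that the structure comes first in the arguments of jax.tree.unflatten):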

jax.tree.flatten([3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}])
([3.0, 1, 2, 'val1', 'val2', 'val3'],
 PyTreeDef([*, (*, *), {'key1': *, 'key2': *, 'key3': *}]))
values, structure = jax.tree.flatten(
    [3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]
)
jax.tree.unflatten(structure, values)
[3.0, (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]
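To illustrate why recording the structure is useful, here is a minimal sketch that flattens a pytree, transforms the leaves as a plain Python list, then rebuilds a pytree with the original structure. The tree and the doubling function are arbitrary choices made for this example.

```python
import jax

tree = [3., (1, 2), {'key1': 10, 'key2': 20}]

# Split the pytree into its leaves and its structure
leaves, structure = jax.tree.flatten(tree)

# Work on the leaves as an ordinary list
doubled = [2 * leaf for leaf in leaves]

# Rebuild a pytree with the same structure from the new leaves
rebuilt = jax.tree.unflatten(structure, doubled)
print(rebuilt)  # [6.0, (2, 4), {'key1': 20, 'key2': 40}]
```

This flatten, transform, unflatten sequence gives the same result as jax.tree.map(lambda x: 2 * x, tree).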

The path to each leaf can be obtained with jax.tree_util.tree_flatten_with_path:

jax.tree_util.tree_flatten_with_path(
    [3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]
)
([((SequenceKey(idx=0),), 3.0),
  ((SequenceKey(idx=1), SequenceKey(idx=0)), 1),
  ((SequenceKey(idx=1), SequenceKey(idx=1)), 2),
  ((SequenceKey(idx=2), DictKey(key='key1')), 'val1'),
  ((SequenceKey(idx=2), DictKey(key='key2')), 'val2'),
  ((SequenceKey(idx=2), DictKey(key='key3')), 'val3')],
 PyTreeDef([*, (*, *), {'key1': *, 'key2': *, 'key3': *}]))
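The key objects in these paths are verbose. As a small sketch, jax.tree_util.keystr can turn each path into a compact string that resembles the Python indexing needed to reach the leaf. The tree below is a shortened version of the example above.

```python
import jax

tree = [3., (1, 2), {'key1': 'val1'}]

# Get (path, leaf) pairs; the structure is not needed here
paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)

# Convert each path (a tuple of key objects) into a readable string
readable = [
    (jax.tree_util.keystr(path), leaf) for path, leaf in paths_and_leaves
]

for path, leaf in readable:
  print(f"tree{path} = {leaf!r}")
```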

Pytree operations

JAX can run operations on pytrees. Let’s create a few pytrees to play with:

tree1 = {'key1': 1., 'key2': 2., 'key3': 3.}
tree2 = {'key1': 4., 'key2': 5., 'key3': 6.}
tree3 = {'key1': 7., 'key2': 8., 'key3': 9.}

jax.tree.map applies a function to each leaf of a tree:

jax.tree.map(lambda x: 3 * x, tree1)
{'key1': 3.0, 'key2': 6.0, 'key3': 9.0}

As long as the pytrees share the same structure (including the same dict keys), operations combining multiple pytrees also work:

jax.tree.map(lambda x, y, z: x * y + z, tree1, tree2, tree3)
{'key1': 11.0, 'key2': 18.0, 'key3': 27.0}

Here are a few more examples:

tree4 = [[1, 1, 1], (2, 2, 2, 2), 3]
tree5 = [[0, 5, 1], (2, 2, 2, 2), 3]
tree6 = [[0, 5, 1, 2], (2, 2, 2), 3]
jax.tree.map(lambda x, y: x + y, tree4, tree5)
[[1, 6, 2], (4, 4, 4, 4), 6]

This won’t work though as the structures are different:

jax.tree.map(lambda x, y: x + y, tree5, tree6)
ValueError: Tuple arity mismatch: 3 != 4; tuple: (2, 2, 2).
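One way to guard against this error is to compare the structures before mapping. The sketch below is only an illustration (the variable names are made up for this example); it relies on the fact that the objects returned by jax.tree.structure can be compared with ==.

```python
import jax

tree5 = [[0, 5, 1], (2, 2, 2, 2), 3]
tree6 = [[0, 5, 1, 2], (2, 2, 2), 3]

# Both trees have 8 leaves, but the leaves are grouped differently
same_structure = jax.tree.structure(tree5) == jax.tree.structure(tree6)
print(same_structure)  # False

if same_structure:
  result = jax.tree.map(lambda x, y: x + y, tree5, tree6)
else:
  result = None
  print("Structures differ: skipping the element-wise sum")
```

Having the same number of leaves is not enough: the nesting must match exactly.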

Pytree transposition

A list of pytrees can be transposed into a pytree of lists.

Let’s create a list with a few of our previous pytrees:

trees = [tree1, tree2, tree3]
print(trees)
[{'key1': 1.0, 'key2': 2.0, 'key3': 3.0}, {'key1': 4.0, 'key2': 5.0, 'key3': 6.0}, {'key1': 7.0, 'key2': 8.0, 'key3': 9.0}]

Here is how to transpose this list of pytrees:

jax.tree.map(lambda *x: list(x), *trees)
{'key1': [1.0, 4.0, 7.0], 'key2': [2.0, 5.0, 8.0], 'key3': [3.0, 6.0, 9.0]}
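The reverse operation (from a pytree of lists back to a list of pytrees) can be sketched with jax.tree.transpose, which takes the outer structure, the inner structure, and the pytree to transpose. The dict of lists below is written by hand for this example, and the placeholder values (0) used to build the two structures are arbitrary.

```python
import jax

dict_of_lists = {
    'key1': [1., 4., 7.],
    'key2': [2., 5., 8.],
    'key3': [3., 6., 9.]
}

# Outer structure: a dict with three keys
outer = jax.tree.structure({'key1': 0, 'key2': 0, 'key3': 0})
# Inner structure: a list of three elements
inner = jax.tree.structure([0, 0, 0])

list_of_dicts = jax.tree.transpose(outer, inner, dict_of_lists)
print(list_of_dicts)
```

The result is a list of three dicts, each with the keys key1, key2, and key3.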

Pytrees in NN

Pytrees are very useful when using JAX for deep learning. Our course on DL with Flax will show this, but below is a basic example modified from the JAX documentation.

import jax
import jax.numpy as jnp
from jax import random

The parameters of a multi-layer perceptron can be initialized with:

def init_params(layer_width):
  params = []
  key = random.PRNGKey(11)
  for n_in, n_out in zip(layer_width[:-1], layer_width[1:]):
    # Split inside the loop so that each layer gets its own subkey
    key, subkey = random.split(key)
    params.append(
        dict(weights=random.normal(subkey, (n_in, n_out)) * jnp.sqrt(2/n_in),
             biases=jnp.ones(n_out)
            )
    )
  return params

params = init_params([1, 128, 128, 1])

params is a pytree:

jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]
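Because params is a pytree, generic pytree tools work on it directly. As a small sketch, the total number of trainable parameters can be counted from the leaves. To keep the example self-contained, the arrays below are zero-filled stand-ins with the same shapes as the real parameters.

```python
import jax
import jax.numpy as jnp

layer_width = [1, 128, 128, 1]

# Stand-in parameters: same shapes as above, filled with zeros
params_like = [
    dict(weights=jnp.zeros((n_in, n_out)), biases=jnp.zeros(n_out))
    for n_in, n_out in zip(layer_width[:-1], layer_width[1:])
]

# Each leaf is an array; sum their sizes to count the parameters
n_params = sum(leaf.size for leaf in jax.tree.leaves(params_like))
print(n_params)  # 16897
```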

To train our MLP, we need to define a function for the forward pass:

@jax.jit
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

And a loss function:

@jax.jit
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

Then we choose a learning rate and define a function that computes the gradients and performs one gradient descent step:

lr = 0.0001

@jax.jit
def update(params, x, y):
  grads = jax.grad(loss_fn)(params, x, y)
  return jax.tree.map(
      lambda p, g: p - lr * g, params, grads
  )

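Because jax.grad accepts pytrees, differentiating loss_fn with respect to the params pytree returns a new pytree, grads, with the same structure as params.

Since both pytrees share the same structure, jax.tree.map can apply the gradient descent update to each pair of matching leaves.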
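The following toy sketch checks this on a much smaller problem. The parameter names (w, b), the loss function, and the step size are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Toy parameters stored in a dict (a pytree with two leaves)
toy_params = {'w': jnp.array([1., 2.]), 'b': jnp.array(0.5)}
x = jnp.array([1., 1.])

def toy_loss(p, x):
  return jnp.sum((p['w'] * x + p['b']) ** 2)

# The gradients come back as a pytree shaped like toy_params
grads = jax.grad(toy_loss)(toy_params, x)
print(jax.tree.structure(grads) == jax.tree.structure(toy_params))  # True

# One gradient descent step, applied leaf by leaf
new_params = jax.tree.map(lambda p, g: p - 0.1 * g, toy_params, grads)
```

By default, jax.grad differentiates with respect to the first argument only, which is why the parameters are passed first.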

We can then train our model:

key = random.PRNGKey(3)
key, subkey = random.split(key)

x = random.normal(subkey, (128, 1))
y = x ** 2

for _ in range(1000):
  params = update(params, x, y)