Pytrees
It is convenient to store data, model parameters, gradients, etc. in container structures such as lists or dicts. JAX has its own container-like structure, the pytree, which is flexible, can be nested, and is supported by many JAX functions, making for convenient workflows.
This section introduces pytrees and how they work.
A tree-like structure
The pytree container registry contains, by default, lists, tuples, and dicts. It can be extended to other containers (a sketch of this is given below).
Objects whose type is in the registry are pytrees whose contents are themselves pytrees; any other object is a leaf pytree (so pytrees are recursive).
Pytrees are great for holding data and parameters, keeping everything organized, even for complex models. The leaves are usually arrays. Many JAX functions can be applied to pytrees.
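As a quick illustration of extending the registry, here is a minimal sketch registering a hypothetical Point class with jax.tree_util.register_pytree_node (Point is made up for this example; only the registration call is part of the JAX API):

from dataclasses import dataclass

import jax

@dataclass
class Point:  # hypothetical custom container
    x: float
    y: float

jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),            # flatten: children and auxiliary data
    lambda aux, children: Point(*children),  # unflatten: rebuild a Point from children
)

jax.tree.leaves([Point(1., 2.), Point(3., 4.)])  # -> [1.0, 2.0, 3.0, 4.0]

Once registered, Point instances behave like any other container in the registry.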
Examples of pytrees:
(1, 2, 3)
[1, 1., "string", True]
jnp.arange(2)
{'key1': 3.4, 'key2': 6.}
[3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]
(3, 2, (6, 0, 9), 2, ())
jnp.zeros(3)
Extracting leaves
Trees can be flattened and their leaves extracted into a list with jax.tree.leaves:

jax.tree.leaves([3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}])
[3.0, 1, 2, 'val1', 'val2', 'val3']
Let’s create a list of pytrees and extract their leaves to look at more examples:
import jax
import jax.numpy as jnp
list_trees = [
    (1, 2, 3),
    [1, 1., "string", True],
    jnp.arange(2),
    {'key1': 3.4, 'key2': 6.},
    [3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}],
    (3, 2, (6, 0, 9), 2, ()),
    jnp.zeros(3)
]

for pytree in list_trees:
    leaves = jax.tree.leaves(pytree)
    print(f"{len(leaves)} leaves: {leaves}")
3 leaves: [1, 2, 3]
4 leaves: [1, 1.0, 'string', True]
1 leaves: [Array([0, 1], dtype=int32)]
2 leaves: [3.4, 6.0]
6 leaves: [3.0, 1, 2, 'val1', 'val2', 'val3']
6 leaves: [3, 2, 6, 0, 9, 2]
1 leaves: [Array([0., 0., 0.], dtype=float32)]
Be careful that leaves are not the same as container elements (each case is checked in the snippet below):
- while an array contains many elements, it is a single leaf,
- while a nested list or tuple represents a single element of its parent container, all of its own elements are leaves,
- an empty tuple or list is a pytree without children and is not counted as a leaf.
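A quick check of all three caveats (the array is a single leaf, the elements of the nested list and tuple are leaves, and the empty tuple contributes nothing):

jax.tree.leaves([jnp.zeros(4), [1, (2, 3)], ()])

[Array([0., 0., 0., 0.], dtype=float32), 1, 2, 3]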
Contrast this with the length (i.e. the number of elements) of each container:

for pytree in list_trees:
    print(f"{len(pytree)} elements")
3 elements
4 elements
2 elements
2 elements
3 elements
5 elements
3 elements
Structure of pytrees
As we just saw, JAX can extract the leaves of pytrees. This is useful to run functions on them. But JAX also records their structure and is able to recreate them. The structure can be obtained with jax.tree.structure:

jax.tree.structure([3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}])
PyTreeDef([*, (*, *), {'key1': *, 'key2': *, 'key3': *}])
So each pytree can be turned, with jax.tree.flatten, into a tuple made of the list of its leaves and its structure, and that tuple can be turned back into the pytree with jax.tree.unflatten.

jax.tree.flatten([3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}])
([3.0, 1, 2, 'val1', 'val2', 'val3'],
PyTreeDef([*, (*, *), {'key1': *, 'key2': *, 'key3': *}]))
values, structure = jax.tree.flatten(
    [3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]
)

jax.tree.unflatten(structure, values)
[3.0, (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]
The path to each leaf can be obtained with jax.tree_util.tree_flatten_with_path:

jax.tree_util.tree_flatten_with_path(
    [3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]
)
([((SequenceKey(idx=0),), 3.0),
((SequenceKey(idx=1), SequenceKey(idx=0)), 1),
((SequenceKey(idx=1), SequenceKey(idx=1)), 2),
((SequenceKey(idx=2), DictKey(key='key1')), 'val1'),
((SequenceKey(idx=2), DictKey(key='key2')), 'val2'),
((SequenceKey(idx=2), DictKey(key='key3')), 'val3')],
PyTreeDef([*, (*, *), {'key1': *, 'key2': *, 'key3': *}]))
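When printing or debugging, the key paths can be rendered as strings. A minimal sketch using jax.tree_util.keystr (output omitted; it prints one path per leaf, such as [2]['key1']):

path_vals, _ = jax.tree_util.tree_flatten_with_path(
    [3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]
)
for path, leaf in path_vals:
    # keystr turns a tuple of key entries into a readable string
    print(jax.tree_util.keystr(path), '->', leaf)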
Pytree operations
JAX can run operations on pytrees. Let’s create a few pytrees to play with:
tree1 = {'key1': 1., 'key2': 2., 'key3': 3.}
tree2 = {'key1': 4., 'key2': 5., 'key3': 6.}
tree3 = {'key1': 7., 'key2': 8., 'key3': 9.}
jax.tree.map allows applying a function to each leaf of a tree:

jax.tree.map(lambda x: 3 * x, tree1)
{'key1': 3.0, 'key2': 6.0, 'key3': 9.0}
As long as pytrees share the same structure (including the same dict keys), operations combining multiple pytrees also work:

jax.tree.map(lambda x, y, z: x * y + z, tree1, tree2, tree3)
{'key1': 11.0, 'key2': 18.0, 'key3': 27.0}
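Leaves can also be aggregated into a single value. A small sketch, assuming a JAX version recent enough to expose jax.tree.reduce (older versions spell it jax.tree_util.tree_reduce):

# fold an addition over the leaves of tree1, starting from 0.
jax.tree.reduce(lambda acc, leaf: acc + leaf, tree1, 0.)  # -> 6.0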
Here are a few more examples:
tree4 = [[1, 1, 1], (2, 2, 2, 2), 3]
tree5 = [[0, 5, 1], (2, 2, 2, 2), 3]
tree6 = [[0, 5, 1, 2], (2, 2, 2), 3]

jax.tree.map(lambda x, y: x + y, tree4, tree5)
[[1, 6, 2], (4, 4, 4, 4), 6]
This won’t work though as the structures are different:
jax.tree.map(lambda x, y: x + y, tree5, tree6)
ValueError: Tuple arity mismatch: 3 != 4; tuple: (2, 2, 2).
Pytree transposition
A list of pytrees can be transposed into a pytree of lists.
Let’s create a list with a few of our previous pytrees:
trees = [tree1, tree2, tree3]
print(trees)

[{'key1': 1.0, 'key2': 2.0, 'key3': 3.0}, {'key1': 4.0, 'key2': 5.0, 'key3': 6.0}, {'key1': 7.0, 'key2': 8.0, 'key3': 9.0}]
Here is how to transpose this list of pytrees:

jax.tree.map(lambda *x: list(x), *trees)

{'key1': [1.0, 4.0, 7.0], 'key2': [2.0, 5.0, 8.0], 'key3': [3.0, 6.0, 9.0]}
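The same transposition can be done with a dedicated helper, exposed as jax.tree.transpose in recent JAX versions (a sketch; it needs the outer and inner structures given explicitly):

jax.tree.transpose(
    jax.tree.structure([0, 0, 0]),  # outer structure: a list of three items
    jax.tree.structure(tree1),      # inner structure: a dict with our three keys
    trees,
)  # -> {'key1': [1.0, 4.0, 7.0], 'key2': [2.0, 5.0, 8.0], 'key3': [3.0, 6.0, 9.0]}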
Pytrees in NN
Pytrees are very useful when using JAX for deep learning. Our course on DL with Flax will show this, but below is a basic example modified from the JAX documentation.
import jax
import jax.numpy as jnp
from jax import random
The parameters of a multi-layer perceptron can be initialized with:
def init_params(layer_width):
    params = []
    key = random.PRNGKey(11)
    for n_in, n_out in zip(layer_width[:-1], layer_width[1:]):
        # draw a fresh subkey for each layer so the weights are independent
        key, subkey = random.split(key)
        params.append(
            dict(
                weights=random.normal(subkey, (n_in, n_out)) * jnp.sqrt(2 / n_in),
                biases=jnp.ones(n_out)
            )
        )
    return params

params = init_params([1, 128, 128, 1])
params is a pytree:

jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
{'biases': (128,), 'weights': (128, 128)},
{'biases': (1,), 'weights': (128, 1)}]
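Since params is a pytree, counting the model's parameters is a one-liner over its leaves (a quick sketch; it sums the sizes of the weight and bias arrays of the three layers):

# total number of parameters: sum of the sizes of all leaf arrays
sum(leaf.size for leaf in jax.tree.leaves(params))  # -> 16897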
To train our MLP, we need to define a function for the forward pass:
@jax.jit
def forward(params, x):
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
    return x @ last['weights'] + last['biases']
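A quick sanity check of the shapes (illustrative; any batch of inputs works as long as it has a single feature, since the first layer expects one input):

forward(params, jnp.ones((4, 1))).shape  # -> (4, 1)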
And a loss function:
@jax.jit
def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)
Then we choose a learning rate and define an update function that computes the gradients by backpropagation and applies one step of gradient descent:
lr = 0.0001

@jax.jit
def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree.map(
        lambda p, g: p - lr * g, params, grads
    )
Because jax.grad can accept pytrees, we can create a new pytree grads, with the same structure as params, by passing the params pytree to it. The gradient descent step can then be applied across both pytrees thanks to jax.tree.map.
Then of course we could train our model:
key = random.PRNGKey(3)
key, subkey = random.split(key)

x = random.normal(subkey, (128, 1))
y = x ** 2

for _ in range(1000):
    params = update(params, x, y)
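If you want to check that training is working, a minimal variant of the loop records the loss at each step (a sketch; it costs one extra forward pass per iteration):

losses = []
for _ in range(1000):
    params = update(params, x, y)
    losses.append(loss_fn(params, x, y))

print(losses[0], losses[-1])  # the loss should decrease over training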