# Pytrees

It is convenient to store data, model parameters, gradients, etc. in container structures such as lists or dicts. JAX has a container-like structure, the *pytree*, that is flexible, can be nested, and is supported by many JAX functions, making for convenient workflows.

This section introduces pytrees and how they work.

## A tree-like structure

The pytree container registry contains, by default, *lists*, *tuples*, and *dicts*. It can be extended to other containers.

Objects whose type is in the pytree container registry are pytrees; any other object is a leaf pytree (pytrees are thus defined recursively).
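As a sketch of how the registry can be extended, a custom class can be registered with `jax.tree_util.register_pytree_node_class` (the `Point` class below is a made-up example, not part of JAX):

```python
import jax

# hypothetical container class, registered so JAX treats it as a pytree node
@jax.tree_util.register_pytree_node_class
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # return the children (leaves or sub-trees) and static auxiliary data
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

print(jax.tree.leaves([Point(1., 2.), 3.]))  # [1.0, 2.0, 3.0]
```

Once registered, `Point` instances behave like any other container: `jax.tree.map` traverses into them and rebuilds them.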

Pytrees are great for holding data and parameters, keeping everything organized, even for complex models. The leaves are usually made of arrays. Many JAX functions can be applied to pytrees.

Examples of pytrees:

```
(1, 2, 3),
[1, 1., "string", True],
jnp.arange(2),
{'key1': 3.4, 'key2': 6.},
[3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}],
(3, 2, (6, 0), 2, ()),
jnp.zeros(3)
```

## Extracting leaves

Trees can be flattened and their leaves extracted into a list with `jax.tree.leaves`:

`jax.tree.leaves([3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}])`

`[3.0, 1, 2, 'val1', 'val2', 'val3']`

Let’s create a list of pytrees and extract their leaves to look at more examples:

```
import jax
import jax.numpy as jnp
list_trees = [
    (1, 2, 3),
    [1, 1., "string", True],
    jnp.arange(2),
    {'key1': 3.4, 'key2': 6.},
    [3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}],
    (3, 2, (6, 0, 9), 2, ()),
    jnp.zeros(3)
]

for pytree in list_trees:
    leaves = jax.tree.leaves(pytree)
    print(f"{len(leaves)} leaves: {leaves}")
```

```
3 leaves: [1, 2, 3]
4 leaves: [1, 1.0, 'string', True]
1 leaves: [Array([0, 1], dtype=int32)]
2 leaves: [3.4, 6.0]
6 leaves: [3.0, 1, 2, 'val1', 'val2', 'val3']
6 leaves: [3, 2, 6, 0, 9, 2]
1 leaves: [Array([0., 0., 0.], dtype=float32)]
```

Be careful that leaves are not the same as container elements:

- while an array contains many elements, it is a single leaf,
- while a nested list or tuple represents a single element of its parent container, all the elements of nested tuples and lists are leaves,
- an empty tuple or list is a pytree without children and is not counted as a leaf.
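These three points can be checked directly:

```python
import jax
import jax.numpy as jnp

# a whole array is a single leaf, despite containing five elements
print(len(jax.tree.leaves([jnp.arange(5)])))  # 1
# elements of nested containers all count as leaves
print(len(jax.tree.leaves([1, (2, 3)])))      # 3
# empty containers contribute no leaves at all
print(jax.tree.leaves([(), [], {}]))          # []
```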

Contrast this with the length of each container (i.e. its number of elements):

```
for pytree in list_trees:
    print(f"{len(pytree)} elements")
```

```
3 elements
4 elements
2 elements
2 elements
3 elements
5 elements
3 elements
```

## Structure of pytrees

As we just saw, JAX can extract the leaves of pytrees. This is useful to run functions on them. But JAX also records their structure and is able to recreate them. The structure can be obtained with `jax.tree.structure`:

`jax.tree.structure([3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}])`

`PyTreeDef([*, (*, *), {'key1': *, 'key2': *, 'key3': *}])`

So each pytree can be flattened into a pair made of the list of its leaves and its structure, and that pair can be turned back into the original pytree:

`jax.tree.flatten([3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}])`

```
([3.0, 1, 2, 'val1', 'val2', 'val3'],
PyTreeDef([*, (*, *), {'key1': *, 'key2': *, 'key3': *}]))
```

```
values, structure = jax.tree.flatten(
    [3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]
)
jax.tree.unflatten(structure, values)
```

`[3.0, (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]`

The path to each leaf can be obtained with `jax.tree_util.tree_flatten_with_path`:

```
jax.tree_util.tree_flatten_with_path(
    [3., (1, 2), {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'}]
)
```

```
([((SequenceKey(idx=0),), 3.0),
((SequenceKey(idx=1), SequenceKey(idx=0)), 1),
((SequenceKey(idx=1), SequenceKey(idx=1)), 2),
((SequenceKey(idx=2), DictKey(key='key1')), 'val1'),
((SequenceKey(idx=2), DictKey(key='key2')), 'val2'),
((SequenceKey(idx=2), DictKey(key='key3')), 'val3')],
PyTreeDef([*, (*, *), {'key1': *, 'key2': *, 'key3': *}]))
```
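A related helper, `jax.tree_util.tree_map_with_path`, passes each leaf's path to the mapped function, and `jax.tree_util.keystr` turns a path into a readable string. A small sketch (the `tree` below is a made-up example):

```python
import jax

tree = {'a': 1, 'b': (2, 3)}

# label each leaf with a string form of its path
labeled = jax.tree_util.tree_map_with_path(
    lambda path, x: f"{jax.tree_util.keystr(path)}={x}", tree
)
print(labeled)
```

The result has the same structure as `tree`, with each leaf replaced by a string such as `['a']=1`.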

## Pytree operations

JAX can run operations on pytrees. Let’s create a few pytrees to play with:

```
tree1 = {'key1': 1., 'key2': 2., 'key3': 3.}
tree2 = {'key1': 4., 'key2': 5., 'key3': 6.}
tree3 = {'key1': 7., 'key2': 8., 'key3': 9.}
```

`jax.tree.map` allows applying a function to each leaf of a tree:

`jax.tree.map(lambda x: 3 * x, tree1)`

`{'key1': 3.0, 'key2': 6.0, 'key3': 9.0}`

As long as the pytrees share the same structure (including the same dict keys), operations combining multiple pytrees also work:

`jax.tree.map(lambda x, y, z: x * y + z, tree1, tree2, tree3)`

`{'key1': 11.0, 'key2': 18.0, 'key3': 27.0}`

Here are a few more examples:

```
tree4 = [[1, 1, 1], (2, 2, 2, 2), 3]
tree5 = [[0, 5, 1], (2, 2, 2, 2), 3]
tree6 = [[0, 5, 1, 2], (2, 2, 2), 3]
```

`jax.tree.map(lambda x, y: x + y, tree4, tree5)`

`[[1, 6, 2], (4, 4, 4, 4), 6]`

This won’t work though as the structures are different:

`jax.tree.map(lambda x, y: x + y, tree5, tree6)`

`ValueError: Tuple arity mismatch: 3 != 4; tuple: (2, 2, 2).`
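Besides `jax.tree.map`, `jax.tree.reduce` aggregates all the leaves of a single pytree into one value, for instance to sum them:

```python
import jax

tree4 = [[1, 1, 1], (2, 2, 2, 2), 3]

# fold all 8 leaves with addition: 1 + 1 + 1 + 2 + 2 + 2 + 2 + 3
total = jax.tree.reduce(lambda a, b: a + b, tree4)
print(total)  # 14
```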

## Pytree transposition

A list of pytrees can be transposed into a pytree of lists.

Let’s create a list with a few of our previous pytrees:

```
trees = [tree1, tree2, tree3]
print(trees)
```

`[{'key1': 1.0, 'key2': 2.0, 'key3': 3.0}, {'key1': 4.0, 'key2': 5.0, 'key3': 6.0}, {'key1': 7.0, 'key2': 8.0, 'key3': 9.0}]`

Here is how to transpose this list of pytrees:

`jax.tree.map(lambda *x: list(x), *trees)`

`{'key1': [1.0, 4.0, 7.0], 'key2': [2.0, 5.0, 8.0], 'key3': [3.0, 6.0, 9.0]}`
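The same transposition can also be expressed more explicitly with `jax.tree.transpose`, which takes the outer and inner tree structures:

```python
import jax

tree1 = {'key1': 1., 'key2': 2., 'key3': 3.}
tree2 = {'key1': 4., 'key2': 5., 'key3': 6.}
tree3 = {'key1': 7., 'key2': 8., 'key3': 9.}
trees = [tree1, tree2, tree3]

outer = jax.tree.structure([0 for _ in trees])  # a list of 3 items
inner = jax.tree.structure(tree1)               # a dict with 3 keys
print(jax.tree.transpose(outer, inner, trees))
```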

## Pytrees in neural networks

Pytrees are very useful when using JAX for deep learning. Our course on DL with Flax will show this, but below is a basic example modified from the JAX documentation.

```
import jax
import jax.numpy as jnp
from jax import random
```

The parameters of a multi-layer perceptron can be initialized with:

```
def init_params(layer_width):
    params = []
    key = random.PRNGKey(11)
    for n_in, n_out in zip(layer_width[:-1], layer_width[1:]):
        # split a fresh subkey for each layer instead of reusing one key
        key, subkey = random.split(key)
        params.append(dict(
            weights=random.normal(subkey, (n_in, n_out)) * jnp.sqrt(2 / n_in),
            biases=jnp.ones(n_out),
        ))
    return params

params = init_params([1, 128, 128, 1])
```

`params` is a pytree:

`jax.tree.map(lambda x: x.shape, params)`

```
[{'biases': (128,), 'weights': (1, 128)},
{'biases': (128,), 'weights': (128, 128)},
{'biases': (1,), 'weights': (128, 1)}]
```
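Because `params` is a pytree, utilities like `jax.tree.leaves` make it easy to, say, count the total number of parameters. Sketched here on a smaller, hypothetical pytree with the same layout:

```python
import jax
import jax.numpy as jnp

# a small stand-in parameter pytree (same layout as `params` above)
tiny_params = [
    {'weights': jnp.zeros((1, 4)), 'biases': jnp.zeros(4)},
    {'weights': jnp.zeros((4, 1)), 'biases': jnp.zeros(1)},
]

# sum the sizes of all leaf arrays: 4 + 4 + 4 + 1
n_params = sum(leaf.size for leaf in jax.tree.leaves(tiny_params))
print(n_params)  # 13
```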

To train our MLP, we need to define a function for the forward pass:

```
@jax.jit
def forward(params, x):
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
    return x @ last['weights'] + last['biases']
```

And a loss function:

```
@jax.jit
def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)
```

Then we choose a learning rate and define a function for the backpropagation:

```
lr = 0.0001

@jax.jit
def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree.map(
        lambda p, g: p - lr * g, params, grads
    )
```

Because `jax.grad` can accept pytrees, we can create a new pytree `grads` by passing the `params` pytree to it. The gradient descent update can then be applied to both pytrees thanks to `jax.tree.map`.
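Here is a minimal, self-contained illustration of `jax.grad` returning a pytree of gradients that mirrors the structure of its input (the one-point affine "model" is made up for this example):

```python
import jax
import jax.numpy as jnp

toy_params = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}

def toy_loss(p):
    # squared error of a tiny affine model at a single input point
    return (p['w'] * 3.0 + p['b'] - 10.0) ** 2

# grads has the same dict structure as toy_params
toy_grads = jax.grad(toy_loss)(toy_params)
print(toy_grads)  # {'b': -6.0, 'w': -18.0}

# a gradient descent step combines the two pytrees leaf by leaf
print(jax.tree.map(lambda p, g: p - 0.1 * g, toy_params, toy_grads))
```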

Then of course we could train our model:

```
key = random.PRNGKey(3)
key, subkey = random.split(key)

x = random.normal(subkey, (128, 1))
y = x ** 2

for _ in range(1000):
    params = update(params, x, y)
```