Author

Marie-Hélène Burle

NumPy is a popular scientific computing library for Python that sits at the core of many other libraries. JAX uses a NumPy-inspired API. There are, however, important differences, which we will explore in this section.

## A NumPy-inspired API

NumPy being so popular, JAX comes with a convenient high-level API that mimics NumPy: `jax.numpy`.

Being familiar with NumPy is thus an advantage to get started with JAX. The NumPy quickstart is a useful resource.

For more efficient, finer-grained control, JAX also comes with a lower-level API: `jax.lax`.
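As a quick taste of the difference, `jax.lax` is stricter than `jax.numpy`: for instance, it does not promote mixed dtypes automatically. A minimal sketch:

```python
import jax.numpy as jnp
from jax import lax

# jax.numpy follows NumPy's type promotion rules
loose = jnp.add(1, 1.0)  # mixed int/float is fine

# jax.lax requires explicit, matching dtypes
# (lax.add(1, 1.0) would raise a dtype error)
strict = lax.add(jnp.float32(1), jnp.float32(1))
```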

NumPy

```python
import numpy as np

print(np.array([(1, 2, 3), (4, 5, 6)]))
```

```
[[1 2 3]
 [4 5 6]]
```

```python
print(np.zeros((2, 3)))
```

```
[[0. 0. 0.]
 [0. 0. 0.]]
```

```python
print(np.ones((2, 3, 2)))
```

```
[[[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]]
```

```python
print(np.arange(24).reshape(2, 3, 4))
```

```
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
```

```python
print(np.linspace(0, 2, 9))
```

```
[0.   0.25 0.5  0.75 1.   1.25 1.5  1.75 2.  ]
```

```python
print(np.linspace(0, 2, 9)[::-1])
```

```
[2.   1.75 1.5  1.25 1.   0.75 0.5  0.25 0.  ]
```

JAX NumPy

```python
import jax.numpy as jnp

print(jnp.array([(1, 2, 3), (4, 5, 6)]))
```

```
[[1 2 3]
 [4 5 6]]
```

```python
print(jnp.zeros((2, 3)))
```

```
[[0. 0. 0.]
 [0. 0. 0.]]
```

```python
print(jnp.ones((2, 3, 2)))
```

```
[[[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]]
```

```python
print(jnp.arange(24).reshape(2, 3, 4))
```

```
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
```

```python
print(jnp.linspace(0, 2, 9))
```

```
[0.   0.25 0.5  0.75 1.   1.25 1.5  1.75 2.  ]
```

```python
print(jnp.linspace(0, 2, 9)[::-1])
```

```
[2.   1.75 1.5  1.25 1.   0.75 0.5  0.25 0.  ]
```

Despite the similarities, there are important differences between JAX and NumPy.

## Differences with NumPy

### Different types

```python
type(np.zeros((2, 3))) == type(jnp.zeros((2, 3)))
```

```
False
```

```python
type(np.zeros((2, 3)))
```

```
<class 'numpy.ndarray'>
```

```python
type(jnp.zeros((2, 3)))
```

```
<class 'jaxlib.xla_extension.ArrayImpl'>
```

### Different default data types

```python
np.zeros((2, 3)).dtype
```

```
dtype('float64')
```

```python
jnp.zeros((2, 3)).dtype
```

```
dtype('float32')
```
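If you need double precision, JAX can be switched to 64-bit mode with `jax.config.update`. A minimal sketch (the assumption here is that this runs at startup, before any arrays have been created):

```python
import jax

# must be set before any JAX arrays are created
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.zeros((2, 3))
print(x.dtype)  # now float64 instead of the default float32
```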

### Functionally pure functions

More importantly, only functionally pure functions—that is, functions for which the outputs are only based on the inputs and which have no side effects—can be used with JAX.

#### Outputs only based on inputs

Consider the function:

```python
def impure_func(x):
    return a + x
```

which uses the variable `a` from the global environment.

This function is not functionally pure because the outputs (the results of the function) do not solely depend on the arguments (the values given to `x`) passed to it. They also depend on the value of `a`.

Remember how tracing works: new inputs with the same shape and dtype use the cached compiled program directly. If the value of `a` changes in the global environment, a new tracing is not triggered and the cached compiled program uses the old value of `a` (the one that was used during tracing).

It is only if the code is run on an input `x` with a different shape and/or dtype that tracing happens again and that the new value for `a` takes effect.

To demonstrate this, we need to use JIT compilation, which we will explain in a later section.

```python
from jax import jit

a = jnp.ones(3)
print(a)
```

```
[1. 1. 1.]
```

```python
def impure_func(x):
    return a + x

print(jit(impure_func)(jnp.ones(3)))
```

```
[2. 2. 2.]
```

All good here because this is the first run (tracing).

Now, let’s change the value of `a` to an array of zeros:

```python
a = jnp.zeros(3)
print(a)
```

```
[0. 0. 0.]
```

And rerun the same code:

```python
print(jit(impure_func)(jnp.ones(3)))
```

```
[2. 2. 2.]
```

We should have an array of ones, but we get the same result we got earlier. Why? Because we are running a cached program with the value that `a` had during tracing.

The new value for `a` will only take effect if we re-trigger tracing by changing the shape and/or dtype of `x`:

```python
a = jnp.zeros(4)
print(a)
```

```
[0. 0. 0. 0.]
```

```python
print(jit(impure_func)(jnp.ones(4)))
```

```
[1. 1. 1. 1.]
```

Passing `impure_func()` an argument of a different shape forced retracing, so the new value of `a` was picked up.
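The safe alternative is to make the function pure by passing `a` as an argument. Here is a sketch: since `a` is now traced like any other input, its new values take effect without retracing:

```python
from jax import jit
import jax.numpy as jnp

@jit
def pure_func(x, a):
    # all inputs are explicit arguments: no hidden global state
    return a + x

a = jnp.ones(3)
first = pure_func(jnp.ones(3), a)   # [2. 2. 2.]

a = jnp.zeros(3)
second = pure_func(jnp.ones(3), a)  # [1. 1. 1.]: the new value is used
```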

#### No side effects

A function is said to have a side effect if it changes something outside of its local environment (if it does anything besides returning an output).

Examples of side effects include:

• printing to standard output/shell,
• reading from file/writing to file,
• modifying a global variable.

In JAX, side effects happen during the first run (tracing) but not on subsequent runs. You thus cannot rely on side effects in your code.

```python
def impure_func(a, b):
    print("Calculating sum...")
    return a + b

print(jit(impure_func)(jnp.arange(3), jnp.arange(3)))
```

```
Calculating sum...
[0 2 4]
```

Printing (the side effect) happened here because this is the first run.

Let’s rerun the function:

```python
print(jit(impure_func)(jnp.arange(3), jnp.arange(3)))
```

```
[0 2 4]
```

This time, no printing…
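If you do need printing from inside compiled code (for debugging), JAX provides `jax.debug.print`, which is staged into the computation and runs on every call. A minimal sketch:

```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def debug_func(a, b):
    # unlike Python's print, this runs on every call, not just during tracing
    jax.debug.print("Calculating sum...")
    return a + b

result = debug_func(jnp.arange(3), jnp.arange(3))
result = debug_func(jnp.arange(3), jnp.arange(3))  # prints again
```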

### Pseudorandom number generation (PRNG)

Programming languages usually come with automated pseudorandom number generators seeded with nondeterministic data from the operating system. They are extremely convenient, but they are slow, rely on a sequential global state, and are problematic in parallel executions.

JAX relies on an explicitly set random state called a key.

```python
from jax import random

key = random.PRNGKey(18)
print(key)
```

```
[ 0 18]
```

Each time you call a random function, you need a subkey split from your key. Keys should only ever be used once in your code. The key is what makes your code reproducible, but you don’t want to reuse it within your code as it would create spurious correlations.

Here is the workflow:

• you split your key into a new key and a subkey,
• you discard the old key (it was consumed by the split, so its entropy budget, so to speak, has been used),
• you use the subkey to run your random function and keep the new key for a future split.

To make sure you don't reuse the old key, you can overwrite it with the new one:

```python
key, subkey = random.split(key)
print(key)
```

```
[4197003906 1654466292]
```

That’s the value of our new key for future splits.

```python
print(subkey)
```

```
[1685972163 1654824463]
```

This is the value of the subkey that we can use to call a random function:

```python
print(random.normal(subkey))
```

```
1.1437175
```
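When you need several random draws, you can split a key into many subkeys at once by passing a number to `random.split`. A sketch of this pattern:

```python
from jax import random

key = random.PRNGKey(18)

# split into one carry-over key and three single-use subkeys
key, *subkeys = random.split(key, 4)

# each draw gets its own subkey; rerunning from the same seed
# reproduces exactly the same numbers
samples = [random.normal(k) for k in subkeys]
```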

### Immutable arrays

JAX arrays are immutable.

In NumPy, you can modify ndarrays:

```python
a = np.arange(5)
a[0] = 9
print(a)
```

```
[9 1 2 3 4]
```

This is impossible in JAX:

```python
a = jnp.arange(5)
a[0] = 9
```

```
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
```

Instead, you need to create a copy of the array that includes the mutation. This is done with the `.at[]` syntax:

```python
b = a.at[0].set(9)
print(b)
```

```
[9 1 2 3 4]
```

Of course, if you want to modify the array in place, you can overwrite `a`:

```python
a = a.at[0].set(9)
```
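`.at[]` supports more than `set()`. Other out-of-place update methods include `add()`, `multiply()`, `min()`, and `max()`. A short sketch:

```python
import jax.numpy as jnp

a = jnp.arange(5)               # [0 1 2 3 4]

added = a.at[0].add(10)         # [10  1  2  3  4]
scaled = a.at[1:3].multiply(2)  # [0 2 4 3 4]

# `a` itself is untouched
```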

### Input control

NumPy's fundamental object is the ndarray, but NumPy is lax about the type of input it accepts:

```python
type([1.0, 2.0])
```

```
<class 'list'>
```

```python
np.sum([1.0, 2.0])
```

```
3.0
```

```python
type((1.0, 2.0))
```

```
<class 'tuple'>
```

```python
np.sum((1.0, 2.0))
```

```
3.0
```

To avoid inefficiencies, JAX will only accept arrays:

```python
jnp.sum([1.0, 2.0])
```

```
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
```

```python
jnp.sum((1.0, 2.0))
```

```
TypeError: sum requires ndarray or scalar arguments, got <class 'tuple'> at position 0.
```
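The fix is to convert the sequence explicitly, for instance with `jnp.asarray()`:

```python
import jax.numpy as jnp

# convert the list to a JAX array before summing
total = jnp.sum(jnp.asarray([1.0, 2.0]))  # 3.0
```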

### Out of bounds indexing

NumPy raises an error if you try to index out of bounds:

```python
print(np.arange(5))
```

```
[0 1 2 3 4]
```

```python
print(np.arange(5)[10])
```

```
IndexError: index 10 is out of bounds for axis 0 with size 5
```

Be aware that JAX will not raise an error. Instead, it silently clamps the index and returns the value at the closest in-bounds position:

```python
print(jnp.arange(5)[10])
```

```
4
```
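If clamping is not what you want, the `.at[]` accessor lets you choose what an out-of-bounds read returns, for example a fill value. A sketch:

```python
import jax.numpy as jnp

a = jnp.arange(5)

clamped = a[10]                                    # 4 (index clamped to the last element)
filled = a.at[10].get(mode="fill", fill_value=-1)  # -1 instead of clamping
```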

## Why all these constraints?

The more constraints you add to a programming language, the more optimization you can get from the compiler. Speed comes at the cost of convenience.

For instance, consider a Python list. It is an extremely convenient and flexible object: heterogeneous, mutable… You can do anything with it. But computations on lists are extremely slow.

NumPy's ndarrays are more constrained (homogeneous), but the type constraint permits a much faster implementation (NumPy is written in C and Fortran as well as Python), with vectorization, other optimizations, and greatly improved performance.

JAX takes this further: thanks to an intermediate representation and very strict constraints on types, pure functional programming, etc., yet more optimizations can be achieved, and you can optimize your own functions with JIT compilation and XLA. Ultimately, this is what makes JAX so fast.