Relation to NumPy

Author

Marie-Hélène Burle

NumPy is a popular Python library for scientific computing and sits at the core of many other libraries. JAX uses a NumPy-inspired API. There are, however, important differences that we will explore in this section.

First, let’s start an interactive job:

salloc --time=2:0:0 --mem-per-cpu=3500M

Nowadays, IPython (Interactive Python) is best known as the kernel Jupyter uses to run Python. Before Jupyter existed, however, IPython was created as a better command shell than the default Python shell. For interactive Python sessions in the command line, it is more convenient than plain Python (syntax highlighting, tab completion, magic commands) with no downside, so we will use it for this course.

For this, we need to load the ipython-kernel module. To see what versions are available, you can run:

module spider ipython-kernel

Let’s load the latest module:

module load ipython-kernel/3.11

Now, let’s activate the Python virtual environment:

source /project/60055/env/bin/activate

Finally, we can launch IPython:

ipython

A NumPy-inspired API

NumPy being so popular, JAX comes with a convenient high-level API that closely mirrors NumPy's: jax.numpy.

Being familiar with NumPy is thus an advantage when getting started with JAX. The NumPy quickstart is a useful resource.

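JAX also comes with a lower-level, stricter API: jax.lax. jax.numpy is itself built on top of jax.lax, which gives more control at the cost of convenience.

Here are a few array creation functions in NumPy, followed by their jax.numpy equivalents: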

import numpy as np
print(np.array([(1, 2, 3), (4, 5, 6)]))
[[1 2 3]
 [4 5 6]]
print(np.zeros((2, 3)))
[[0. 0. 0.]
 [0. 0. 0.]]
print(np.ones((2, 3, 2)))
[[[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]]
print(np.arange(24).reshape(2, 3, 4))
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
print(np.linspace(0, 2, 9))
[0.   0.25 0.5  0.75 1.   1.25 1.5  1.75 2.  ]
print(np.linspace(0, 2, 9)[::-1])
[2.   1.75 1.5  1.25 1.   0.75 0.5  0.25 0.  ]

import jax.numpy as jnp
print(jnp.array([(1, 2, 3), (4, 5, 6)]))
[[1 2 3]
 [4 5 6]]
print(jnp.zeros((2, 3)))
[[0. 0. 0.]
 [0. 0. 0.]]
print(jnp.ones((2, 3, 2)))
[[[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]]
print(jnp.arange(24).reshape(2, 3, 4))
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
print(jnp.linspace(0, 2, 9))
[0.   0.25 0.5  0.75 1.   1.25 1.5  1.75 2.  ]
print(jnp.linspace(0, 2, 9)[::-1])
[2.   1.75 1.5  1.25 1.   0.75 0.5  0.25 0.  ]

Despite the similarities, there are important differences between JAX and NumPy.

Differences with NumPy

Different types

type(np.zeros((2, 3))) == type(jnp.zeros((2, 3)))
False
type(np.zeros((2, 3)))
numpy.ndarray
type(jnp.zeros((2, 3)))
jaxlib.xla_extension.ArrayImpl
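Since the two array types differ, converting between them is explicit. A minimal sketch using jnp.asarray and np.asarray (the variable names are illustrative):

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(3.0)        # numpy.ndarray
x_jax = jnp.asarray(x_np)    # JAX array (float32 unless 64-bit mode is enabled)
x_back = np.asarray(x_jax)   # back to a numpy.ndarray

print(type(x_np).__name__, type(x_back).__name__)  # ndarray ndarray
```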

Different default data types

np.zeros((2, 3)).dtype
dtype('float64')
jnp.zeros((2, 3)).dtype
dtype('float32')

Lower numerical precision improves speed and reduces memory usage, usually with no noticeable loss of accuracy when training neural networks, so it is a net benefit. Having been built with deep learning in mind, JAX uses defaults that align with those of other DL libraries (e.g. PyTorch, TensorFlow).
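If you do need 64-bit floats, JAX lets you opt in through its jax_enable_x64 configuration flag. A minimal sketch, assuming the flag is set at startup, before any JAX arrays are created:

```python
import jax
import jax.numpy as jnp

# Opt in to 64-bit types; do this at program startup
jax.config.update("jax_enable_x64", True)

x = jnp.zeros((2, 3))
print(x.dtype)  # float64
```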

Immutable arrays

In NumPy, you can modify ndarrays:

a = np.arange(5)
a[0] = 9
print(a)
[9 1 2 3 4]

JAX arrays are immutable:

a = jnp.arange(5)
a[0] = 9
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

Instead, you create a modified copy of the array, which is what the .at[] syntax does:

b = a.at[0].set(9)
print(b)
[9 1 2 3 4]

If you no longer need the original, you can simply overwrite a:

a = a.at[0].set(9)
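Besides set, .at[] supports other out-of-place updates such as add and multiply, and it accepts slices. A short sketch (the array names are illustrative); note that the original array is never modified:

```python
import jax.numpy as jnp

a = jnp.arange(5)

b = a.at[0].set(9)        # replace element 0
c = a.at[1:3].add(10)     # add 10 to elements 1 and 2
d = a.at[-1].multiply(2)  # double the last element

print(a.tolist())  # [0, 1, 2, 3, 4]  (a itself is unchanged)
print(b.tolist())  # [9, 1, 2, 3, 4]
print(c.tolist())  # [0, 11, 12, 3, 4]
print(d.tolist())  # [0, 1, 2, 3, 8]
```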

Pseudorandom number generation

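Programming languages usually come with an automatic pseudorandom number generator (PRNG) seeded from nondeterministic data provided by the operating system. These generators are extremely convenient, but they rely on a hidden global state that is updated sequentially with every call, which makes results hard to reproduce and is problematic in parallel executions.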

JAX relies on an explicitly set random state called a key.

from jax import random

key = random.PRNGKey(18)
print(key)
[ 0 18]

Each time you call a random function, you need a new subkey split from your key. Keys should only ever be used once in your code: the key is what makes your code reproducible, but reusing it within your code would create spurious correlations.
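To see why reuse is a problem, here is a small sketch: a JAX random function is deterministic given its key, so calling it twice with the same key returns the same value, while two subkeys obtained from a split give different values. The first part deliberately reuses a key, purely for illustration.

```python
from jax import random

key = random.PRNGKey(18)

# Anti-pattern, for illustration only: the same key gives the same number
x1 = random.normal(key)
x2 = random.normal(key)
print(x1 == x2)  # True

# Splitting first gives two distinct subkeys, hence different numbers
subkey1, subkey2 = random.split(key)
print(random.normal(subkey1) == random.normal(subkey2))  # False
```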

Here is the workflow:

  • you split your key into a new key and one or multiple subkeys,
  • you discard the old key (because it was used to do the split—so its entropy budget, so to speak, has been used),
  • you use the subkey(s) to run your random function(s) and keep the new key for a future split.
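The workflow above can be sketched as a loop that draws one sample per iteration (the samples list is just an illustrative container):

```python
from jax import random

key = random.PRNGKey(18)
samples = []

for _ in range(3):
    # Keep the new `key` for future splits; use `subkey` exactly once
    key, subkey = random.split(key)
    samples.append(float(random.normal(subkey)))

print(samples)  # three different values
```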

Subkeys are keys like any other: the distinction is purely a matter of terminology.

To make sure not to reuse the old key, you can overwrite it with the new one:

key, subkey = random.split(key)
print(key)
[4197003906 1654466292]

That’s the value of our new key for future splits.

print(subkey)
[1685972163 1654824463]

This is the value of the subkey that we can use to call a random function.

Let’s use that subkey now:

print(random.normal(subkey))
1.1437175

To split your key into more subkeys, pass the total number of keys you want as the second argument to random.split:

key, subkey1, subkey2, subkey3 = random.split(key, 4)
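Each subkey can then feed a different random function. A minimal sketch using the shape argument to draw several values per call (the variable names are illustrative):

```python
from jax import random

key = random.PRNGKey(18)
key, subkey1, subkey2, subkey3 = random.split(key, 4)

x = random.normal(subkey1, shape=(3,))      # 3 standard normal draws
u = random.uniform(subkey2, shape=(2, 2))   # 2x2 draws from [0, 1)
n = random.randint(subkey3, shape=(5,), minval=0, maxval=10)  # 5 integers in [0, 10)

print(x.shape, u.shape, n.shape)  # (3,) (2, 2) (5,)
```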

Strict input control

NumPy’s fundamental object is the ndarray, but NumPy is very tolerant as to the type of input.

np.sum([1.0, 2.0])  # here we are using a list
np.float64(3.0)
np.sum((1.0, 2.0))  # here is a tuple
np.float64(3.0)

To avoid inefficiencies, JAX only accepts arrays (or scalars):

jnp.sum([1.0, 2.0])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
jnp.sum((1.0, 2.0))
TypeError: sum requires ndarray or scalar arguments, got <class 'tuple'> at position 0.
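The fix is to convert the Python sequence to a JAX array explicitly before calling the function. A minimal sketch:

```python
import jax.numpy as jnp

values = [1.0, 2.0]  # a plain Python list

total = jnp.sum(jnp.array(values))  # convert explicitly, then sum
print(total)  # 3.0
```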

Out of bounds indexing

NumPy raises an error if you try to index out of bounds:

print(np.arange(5)[10])
IndexError: index 10 is out of bounds for axis 0 with size 5

Be aware that JAX does not raise an error. Instead, it silently clamps the index to the bounds of the array, so here it returns the last element:

print(jnp.arange(5)[10])
4
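Retrieval is not the only case: out-of-bounds updates made with .at[] are silently dropped, and .at[].get accepts a mode argument if you prefer an explicit fill value over clamping. A short sketch; fill_value=-1 is an arbitrary choice for illustration:

```python
import jax.numpy as jnp

x = jnp.arange(5)

print(x[10])                                      # clamped: returns the last element
print(x.at[10].get(mode="fill", fill_value=-1))   # returns the fill value instead
print(x.at[10].set(99).tolist())                  # update dropped: [0, 1, 2, 3, 4]
```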

Functionally pure functions

More importantly, JAX transformations are designed for functionally pure functions, that is, functions whose outputs depend only on their inputs and which have no side effects. Impure functions can silently give wrong results.

Outputs only based on inputs

Consider the function:

def f(x):
    return a + x

which uses the variable a from the global environment.

This function is not functionally pure because the outputs (the results of the function) do not solely depend on the arguments (the values given to x) passed to it. They also depend on the value of a.

Remember how tracing works: new inputs with the same shape and dtype use the cached compiled program directly. If the value of a changes in the global environment, a new tracing is not triggered and the cached compiled program uses the old value of a (the one that was used during tracing).

It is only if the code is run on an input x with a different shape and/or dtype that tracing happens again and that the new value for a takes effect.

To demonstrate this, we need JIT compilation, which we will explain in a later section.

from jax import jit

a = jnp.ones(3)
print(a)
[1. 1. 1.]
def f(x):
    return a + x

print(jit(f)(jnp.ones(3)))
[2. 2. 2.]

All good here because this is the first run (tracing).

Now, let’s change the value of a to an array of zeros:

a = jnp.zeros(3)
print(a)
[0. 0. 0.]

And rerun the same code:

print(jit(f)(jnp.ones(3)))
[2. 2. 2.]

We would expect an array of ones, but we get the same result as before. Why? Because we are running a cached program that uses the value a had during tracing.

The new value for a will only take effect if we re-trigger tracing by changing the shape and/or dtype of x:

a = jnp.zeros(4)
print(a)
[0. 0. 0. 0.]
print(jit(f)(jnp.ones(4)))
[1. 1. 1. 1.]

Passing an argument of a different shape to f forced recompilation. Using a different data type (e.g. with jnp.arange(3)) would have done the same.
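The pure alternative is to pass a as an argument, so that it becomes part of the traced inputs and its current value is used on every call. A minimal sketch; f_pure is a hypothetical name:

```python
import jax.numpy as jnp
from jax import jit

@jit
def f_pure(a, x):
    # `a` is now an input, so no stale global value is baked in
    return a + x

print(f_pure(jnp.ones(3), jnp.ones(3)))   # [2. 2. 2.]
print(f_pure(jnp.zeros(3), jnp.ones(3)))  # [1. 1. 1.]
```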

No side effects

A function is said to have a side effect if it changes something outside of its local environment (if it does anything besides returning an output).

Examples of side effects include:

  • printing to standard output/shell,
  • reading from file/writing to file,
  • modifying a global variable.

With JIT compilation, side effects happen during the first run (tracing), but not on subsequent runs. You thus cannot rely on side effects in your code.

def f(a, b):
    print("Calculating sum")
    return a + b

print(jit(f)(jnp.arange(3), jnp.arange(3)))
Calculating sum
[0 2 4]

Printing (the side effect) happened here because this is the first run.

Let’s rerun the function:

print(jit(f)(jnp.arange(3), jnp.arange(3)))
[0 2 4]

This time, no printing.
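If you need output from inside a compiled function on every call (for instance while debugging), JAX provides jax.debug.print, which is staged into the compiled program rather than executed only during tracing. A minimal sketch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(a, b):
    total = a + b
    # Unlike print, this runs every time the compiled function runs
    jax.debug.print("Calculating sum: {}", total)
    return total

f(jnp.arange(3), jnp.arange(3))  # prints
f(jnp.arange(3), jnp.arange(3))  # prints again
```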

Understanding jaxprs

Jaxprs (JAX expressions) are created by tracers wrapping the Python code during tracing (the first run). They record the shape and data type of arrays as well as the operations performed on these arrays. Jaxprs do not, however, record values: this keeps the compiled program general enough to be rerun with any new arrays of the same shape and data type, without rerunning the slow Python code or recompiling.

Jaxprs also do not treat external variables as inputs: the values of arrays that are not arguments are captured as constants at trace time. Nor do jaxprs contain any information on side effects.

Jaxprs can be visualized with the jax.make_jaxpr function:

import jax

x = jnp.array([1., 4., 3.])
y = jnp.array([8., 1., 2.])

def f(x, y):
    return 2 * x**2 + y

jax.make_jaxpr(f)(x, y) 
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = integer_pow[y=2] a
    d:f32[3] = mul 2.0 c
    e:f32[3] = add d b
  in (e,) }

Let’s add a print function to f:

def f(x, y):
    print("This is a function with side-effect")
    return 2 * x**2 + y

jax.make_jaxpr(f)(x, y)
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = integer_pow[y=2] a
    d:f32[3] = mul 2.0 c
    e:f32[3] = add d b
  in (e,) }

The jaxpr is exactly the same. This is why printing will happen during tracing (when the Python code is run), but not afterwards (when the compiled code using the jaxpr is run).
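The fact that jaxprs depend only on shapes and dtypes, not on values, can be checked directly by comparing printed jaxprs. A short sketch; the input values are arbitrary:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x**2 + y

jaxpr_a = jax.make_jaxpr(f)(jnp.array([1., 4., 3.]), jnp.array([8., 1., 2.]))
jaxpr_b = jax.make_jaxpr(f)(jnp.array([0., 0., 0.]), jnp.array([5., 6., 7.]))
jaxpr_c = jax.make_jaxpr(f)(jnp.ones(4), jnp.ones(4))

print(str(jaxpr_a) == str(jaxpr_b))  # True: same shapes and dtypes, different values
print(str(jaxpr_a) == str(jaxpr_c))  # False: the shape changed from 3 to 4
```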

Why the constraints?

The more constraints you add to a programming language, the more optimization you can get from the compiler. Speed comes at the cost of convenience.

For instance, consider a Python list. It is an extremely convenient and flexible object: heterogeneous, mutable… You can do anything with it. But computations on lists are extremely slow.

NumPy’s ndarrays are more constrained (homogeneous), but this type constraint permits a much faster implementation (NumPy's core is written in compiled languages rather than Python) with vectorization, optimizations, and greatly improved performance.

JAX takes it further: by using an intermediate representation and very strict constraints (on types, pure functional programming, etc.), it can achieve yet more optimizations, and you can optimize your own functions with JIT compilation through the XLA compiler. Ultimately, this is what makes JAX so fast.