15. JAX

Note

This lecture is built using hardware that has access to a GPU. This means that

  1. the lecture might be significantly slower when running on your machine, and

  2. the code is well-suited to execution with Google Colab.

This lecture provides a short introduction to Google JAX.

15.1. Overview

Let’s start with an overview of JAX.

15.1.1. Capabilities

JAX is a Python library initially developed by Google to support in-house artificial intelligence and machine learning.

JAX provides data types, functions and a compiler for fast linear algebra operations and automatic differentiation.

Loosely speaking, JAX is like NumPy with the addition of

  • automatic differentiation

  • automated GPU/TPU support

  • a just-in-time compiler

One of the great benefits of JAX is that the same code can be run either on the CPU or on a hardware accelerator, such as a GPU or TPU.

For example, JAX automatically builds and deploys kernels on the GPU whenever an accessible device is detected.
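As a quick check of which hardware JAX has picked up, the sketch below lists the visible devices and the default backend. Both `jax.devices()` and `jax.default_backend()` are part of the public JAX API; the printed values depend on your machine, so no output is shown.

```python
import jax
import jax.numpy as jnp

# Devices visible to JAX; on a machine without an accelerator
# this list contains a single CPU device.
devices = jax.devices()
print(devices)

# Name of the backend JAX will use by default
# (typically "cpu", "gpu" or "tpu").
backend = jax.default_backend()
print(backend)

# The same array code runs unchanged on whichever backend was found.
x = jnp.arange(3.0)
print(jnp.sum(x))
```

If the backend is reported as the CPU on a machine that has a GPU, the usual cause is an installation of JAX without GPU support.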

15.1.2. History

In 2015, Google open-sourced part of its AI infrastructure called TensorFlow.

Around two years later, Facebook open-sourced PyTorch beta, an alternative AI framework which is regarded as developer-friendly and more Pythonic than TensorFlow.

By 2019, PyTorch was surging in popularity, adopted by Uber, Airbnb, Tesla and many other companies.

In 2020, Google launched JAX as an open-source framework, simultaneously beginning to shift away from TPUs to Nvidia GPUs.

In the last few years, uptake of Google JAX has accelerated rapidly, bringing attention back to Google-based machine learning architectures.

15.1.3. Installation

JAX can be installed with or without GPU support by following the install guide.

Note that JAX is pre-installed with GPU support on Google Colab.

If you do not have your own GPU, we recommend that you run this lecture on Colab.

15.2. JAX as a NumPy Replacement

One way to use JAX is as a plug-in NumPy replacement. Let’s look at the similarities and differences.

15.2.1. Similarities

The following import is standard, replacing import numpy as np:

import jax
import jax.numpy as jnp

Now we can use jnp in place of np for the usual array operations:

a = jnp.asarray((1.0, 3.2, -1.5))
print(a)
[ 1.   3.2 -1.5]
print(jnp.sum(a))
2.6999998
print(jnp.mean(a))
0.9
print(jnp.dot(a, a))
13.490001

However, the array object a is not a NumPy array:

a
DeviceArray([ 1. ,  3.2, -1.5], dtype=float32)
type(a)
jaxlib.xla_extension.DeviceArray

Even scalar-valued maps on arrays return objects of type DeviceArray:

jnp.sum(a)
DeviceArray(2.6999998, dtype=float32)

The term Device refers to the hardware accelerator (GPU or TPU), although JAX falls back to the CPU if no accelerator is detected.

(In the terminology of GPUs, the “host” is the machine that launches GPU operations, while the “device” is the GPU itself.)

Note

A DeviceArray is a future: it allows Python to continue execution when the results of computation are not available immediately.

This means that Python can dispatch more jobs without waiting for the computation results to be returned by the device.

This feature is called asynchronous dispatch, which hides Python overheads and reduces wait time.

Operations on higher dimensional arrays are also similar to NumPy:

A = jnp.ones((2, 2))
B = jnp.identity(2)
A @ B
DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)
from jax.numpy import linalg
linalg.solve(B, A)
DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)
linalg.eigh(B)  # Computes eigenvalues and eigenvectors
(DeviceArray([0.99999994, 0.99999994], dtype=float32),
 DeviceArray([[1., 0.],
              [0., 1.]], dtype=float32))

15.2.2. Differences

One difference between NumPy and JAX is that JAX uses 32 bit floats by default, on the CPU as well as on accelerators.

This is standard for GPU computing and can lead to significant speed gains with small loss of precision.

However, for some calculations precision matters. In these cases 64 bit floats can be enforced via the command

jax.config.update("jax_enable_x64", True)

Let’s check this works:

jnp.ones(3)
DeviceArray([1., 1., 1.], dtype=float64)
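Enabling 64 bit mode changes the default only; individual arrays can still be created in 32 bit precision by passing a dtype. The following is a minimal sketch, assuming the flag is set at the top of the session (the JAX documentation recommends setting it at startup, before any arrays are created). The names `x64` and `x32` are illustrative.

```python
import jax

# Set the flag before creating any arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x64 = jnp.ones(3)                      # default dtype is now float64
x32 = jnp.ones(3, dtype=jnp.float32)   # explicit request for float32

print(x64.dtype, x32.dtype)

# Mixing the two promotes the result to the wider type
print((x64 + x32).dtype)
```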

As a NumPy replacement, a more significant difference is that arrays are treated as immutable.

For example, with NumPy we can write

import numpy as np
a = np.linspace(0, 1, 3)
a
array([0. , 0.5, 1. ])

and then mutate the data in memory:

a[0] = 1
a
array([1. , 0.5, 1. ])

In JAX this fails:

a = jnp.linspace(0, 1, 3)
a
DeviceArray([0. , 0.5, 1. ], dtype=float64)
a[0] = 1
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_2517/3686271957.py in <module>
----> 1 a[0] = 1

/__w/lecture-python-programming.myst/lecture-python-programming.myst/3/envs/quantecon/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _unimplemented_setitem(self, i, x)
   4946          "or another .at[] method: "
   4947          "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html")
-> 4948   raise TypeError(msg.format(type(self)))
   4949 
   4950 def _operator_round(number: ArrayLike, ndigits: Optional[int] = None) -> Array:

TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In line with immutability, JAX does not support in-place operations:

a = np.array((2, 1))
a.sort()
a
array([1, 2])
a = jnp.array((2, 1))
a_new = a.sort()
a, a_new
(DeviceArray([2, 1], dtype=int64), DeviceArray([1, 2], dtype=int64))

The designers of JAX chose to make arrays immutable because JAX uses a functional programming style. More on this below.

Note that, while in-place mutation is not supported, the at method returns an updated copy of the array, as in

a = jnp.linspace(0, 1, 3)
id(a)
140290482633104
a
DeviceArray([0. , 0.5, 1. ], dtype=float64)
a.at[0].set(1)
DeviceArray([1. , 0.5, 1. ], dtype=float64)

The name a still refers to the original array, whose identity is unchanged; the updated values live in the new array returned by set:

id(a)
140290482633104
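To make the behavior of the at syntax concrete, here is a small sketch. Each `.at[...]` call returns a new array and leaves the original untouched; the names `b` and `c` are illustrative.

```python
import jax.numpy as jnp

a = jnp.linspace(0, 1, 3)

# Each call returns a new array; a itself is never modified
b = a.at[0].set(1.0)     # replace the first element
c = a.at[1:].add(10.0)   # add 10 to the last two elements

print(a)   # original values
print(b)
print(c)
```

To keep the update, rebind the name, as in `a = a.at[0].set(1.0)`.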

15.3. Random Numbers

Random numbers are also a bit different in JAX, relative to NumPy. Typically, in JAX, the state of the random number generator needs to be controlled explicitly.

import jax.random as random

First we produce a key, which seeds the random number generator.

key = random.PRNGKey(1)
type(key)
jaxlib.xla_extension.DeviceArray
print(key)
[0 1]

Now we can use the key to generate some random numbers:

x = random.normal(key, (3, 3))
x
DeviceArray([[-1.35247421, -0.2712502 , -0.02920518],
             [ 0.34706456,  0.5464053 , -1.52325812],
             [ 0.41677264, -0.59710138, -0.5678208 ]], dtype=float64)

If we use the same key again, we initialize at the same seed, so the random numbers are the same:

random.normal(key, (3, 3))
DeviceArray([[-1.35247421, -0.2712502 , -0.02920518],
             [ 0.34706456,  0.5464053 , -1.52325812],
             [ 0.41677264, -0.59710138, -0.5678208 ]], dtype=float64)

To produce a (quasi-) independent draw, best practice is to “split” the existing key:

key, subkey = random.split(key)
random.normal(key, (3, 3))
DeviceArray([[ 1.85374374, -0.37683949, -0.61276867],
             [-1.91829718,  0.27219409,  0.54922246],
             [ 0.40451442, -0.58726839, -0.63967753]], dtype=float64)
random.normal(subkey, (3, 3))
DeviceArray([[-0.4300635 ,  0.22778552,  0.57241269],
             [-0.15969178,  0.46719192,  0.21165091],
             [ 0.84118631,  1.18671326, -0.16607783]], dtype=float64)

The function below produces k (quasi-) independent random n x n matrices using this procedure.

def gen_random_matrices(key, n, k):
    matrices = []
    for _ in range(k):
        key, subkey = random.split(key)
        matrices.append(random.uniform(subkey, (n, n)))
    return matrices
matrices = gen_random_matrices(key, 2, 2)
for A in matrices:
    print(A)
[[0.97440813 0.3838544 ]
 [0.9790686  0.99981046]]
[[0.3473302  0.17157842]
 [0.89346686 0.01403153]]
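The Python loop above can also be replaced by a vectorized version. The sketch below is one possible approach, not the method used in this lecture: it splits the key into k subkeys in a single call and maps the sampler over them with `jax.vmap`. The function name `gen_random_matrices_vmap` is hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as random

def gen_random_matrices_vmap(key, n, k):
    # Produce k subkeys in one call (leading axis has length k)
    subkeys = random.split(key, k)
    # Draw one n x n uniform matrix per subkey
    draw = lambda sk: random.uniform(sk, (n, n))
    return jax.vmap(draw)(subkeys)

key = random.PRNGKey(1)
stacked = gen_random_matrices_vmap(key, 2, 3)
print(stacked.shape)   # one array with the matrices stacked along axis 0
```

The result is a single array of shape (k, n, n) rather than a list, which is usually more convenient for further JAX operations.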

One point to remember is that JAX expects tuples to describe array shapes, even for flat arrays. Hence, to get a one-dimensional array of normal random draws we use (len, ) for the shape, as in

random.normal(key, (5, ))
DeviceArray([-0.64377279,  0.76961857, -0.29809604,  0.47858776,
             -2.00591299], dtype=float64)

15.4. JIT Compilation

The JAX JIT compiler accelerates logic within functions by fusing linear algebra operations into a single, highly optimized kernel that the host can launch on the GPU / TPU (or CPU if no accelerator is detected).

Consider the following pure Python function.

def f(x, p=1000):
    return sum((k*x for k in range(p)))

Let’s build an array to call the function on.

n = 50_000_000
x = jnp.ones(n)

How long does the function take to execute?

%time f(x).block_until_ready()
CPU times: user 455 ms, sys: 180 ms, total: 635 ms
Wall time: 3.37 s
DeviceArray([499500., 499500., 499500., ..., 499500., 499500., 499500.],            dtype=float64)

Note

With asynchronous dispatch, the %time magic is only evaluating the time to dispatch by the Python interpreter, without taking into account the computation time on the device.

Here, to measure the actual speed, the block_until_ready() method makes Python wait until the computation results are ready, so that the reported time includes the computation on the device.

This code is not particularly fast.

While it is run on the GPU (since x is a DeviceArray), each vector k * x has to be instantiated before the final sum is computed.

If we JIT-compile the function with JAX, then the operations are fused and no intermediate arrays are created.

f_jit = jax.jit(f)   # target for JIT compilation

Let’s run once to compile it:

f_jit(x)
DeviceArray([499500., 499500., 499500., ..., 499500., 499500., 499500.],            dtype=float64)

And now let’s time it.

%time f_jit(x).block_until_ready()
CPU times: user 1.15 ms, sys: 293 µs, total: 1.44 ms
Wall time: 36 ms
DeviceArray([499500., 499500., 499500., ..., 499500., 499500., 499500.],            dtype=float64)
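In the example above, p is never passed explicitly, so it keeps its default value. If we want to pass p at call time, it has to be marked as static, because `range(p)` needs a concrete Python integer while the function is being traced. The following is a sketch of one way to do this with the `static_argnames` argument of `jax.jit`; the name `f_static` and the small test input are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("p",))
def f_static(x, p=1000):
    # p is static, so range(p) sees an ordinary Python int.
    # The loop is unrolled at trace time into p multiply-adds.
    return sum(k * x for k in range(p))

x_small = jnp.ones(3)
out = f_static(x_small, p=4)   # each entry equals 0 + 1 + 2 + 3 = 6
print(out)
```

Each new value of p triggers a fresh compilation, and since the loop is unrolled during tracing, large values of p lengthen compile time.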

15.5. Functional Programming

From JAX’s documentation:

When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.

In other words, JAX assumes a functional programming style.

The major implication is that JAX functions should be pure:

  • no dependence on global variables

  • no side effects

“A pure function will always return the same result if invoked with the same inputs.”

JAX will not usually throw errors when compiling impure functions but execution becomes unpredictable.

Here’s an illustration of this fact, using global variables:

a = 1  # global

@jax.jit
def f(x):
    return a + x
x = jnp.ones(2)
f(x)
DeviceArray([2., 2.], dtype=float64)

In the code above, the global value a=1 is fused into the jitted function.

Even if we change a, the output of f will not be affected — as long as the same compiled version is called.

a = 42
f(x)
DeviceArray([2., 2.], dtype=float64)

Changing the shape of the input triggers a fresh compilation of the function, at which time the change in the value of a takes effect:

x = jnp.ones(3)
f(x)
DeviceArray([43., 43., 43.], dtype=float64)

Moral of the story: write pure functions when using JAX!
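One way to make the function above pure is to pass the value as an argument instead of reading it from the global scope. A minimal sketch, where the name `f_pure` is illustrative:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_pure(x, a):
    # The output depends only on the arguments
    return a + x

x = jnp.ones(2)
out_1 = f_pure(x, 1)
out_42 = f_pure(x, 42)   # the new value of a is respected
print(out_1)
print(out_42)
```

Because a is now an ordinary traced argument, changing its value changes the result without relying on recompilation.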

15.6. Gradients

JAX can use automatic differentiation to compute gradients.

This can be extremely useful in optimization, root finding and other applications.

Here’s a very simple illustration, involving the function

def f(x):
    return (x**2) / 2

Let’s take the derivative:

f_prime = jax.grad(f)
f_prime(10.0)
DeviceArray(10., dtype=float64, weak_type=True)

Let’s plot the function and derivative, noting that \(f'(x) = x\).

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x_grid = jnp.linspace(-4, 4, 200)
ax.plot(x_grid, f(x_grid), label="$f$")
ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
ax.legend(loc='upper center')
plt.show()
[Figure: \(f\) and \(f'\) plotted over \([-4, 4]\)]
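The list comprehension in the plotting code calls the derivative once per grid point. As an alternative sketch, `jax.vmap` can vectorize the derivative so that it is evaluated over the whole grid in one call; the names `f_prime_vec` and `derivs` are illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):
    return (x**2) / 2

# jax.grad differentiates a scalar-valued function;
# jax.vmap maps the result over an array of inputs.
f_prime_vec = jax.vmap(jax.grad(f))

x_grid = jnp.linspace(-4, 4, 200)
derivs = f_prime_vec(x_grid)

# Since f'(x) = x, the derivative values should match the grid
print(jnp.allclose(derivs, x_grid))
```

The array `derivs` can be passed directly to `ax.plot` in place of the list.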

15.7. Exercises

Exercise 15.1

Recall that Newton’s method for solving for the root of \(f\) involves iterating on

\[ q(x) = x - \frac{f(x)}{f'(x)} \]

Write a function called newton that takes a function \(f\) plus a guess \(x_0\) and returns an approximate fixed point of \(q\), which is an approximate root of \(f\).

Your newton implementation should use automatic differentiation to calculate \(f'\).

Test your newton method on the function shown below.

f = lambda x: jnp.sin(4 * (x - 1/4)) + x + x**20 - 1
x = jnp.linspace(0, 1, 100)

fig, ax = plt.subplots()
ax.plot(x, f(x), label='$f(x)$')
ax.axhline(ls='--', c='k')
ax.set_xlabel('$x$', fontsize=12)
ax.set_ylabel('$f(x)$', fontsize=12)
ax.legend(fontsize=12)
plt.show()
[Figure: \(f(x)\) plotted over \([0, 1]\), with a dashed horizontal line at zero]

Exercise 15.2

In an earlier exercise on parallelization, we used Monte Carlo to price a European call option.

The code was accelerated by Numba-based multithreading.

Try writing a version of this operation for JAX, using all the same parameters.

If you are running your code on a GPU, you should be able to achieve significantly faster execution.