14. JAX#
In addition to what’s in Anaconda, this lecture will need the following libraries:
!pip install jax quantecon
This lecture provides a short introduction to Google JAX.
Here we are focused on using JAX on the CPU, rather than on accelerators such as GPUs or TPUs.
This means we will only see a small amount of the possible benefits from using JAX.
However, JAX seamlessly handles transitions across different hardware platforms.
As a result, if you run this code on a machine with a GPU and a GPU-aware version of JAX installed, your code will be automatically accelerated and you will receive the full benefits.
For a discussion of JAX on GPUs, see our JAX lecture series.
14.1. JAX as a NumPy Replacement#
One of the attractive features of JAX is that, whenever possible, it conforms to the NumPy API for array operations.
This means that, to a large extent, we can use JAX as a drop-in NumPy replacement.
Let’s look at the similarities and differences between JAX and NumPy.
14.1.1. Similarities#
We’ll use the following imports
import jax
import quantecon as qe
In addition, we replace import numpy as np with
import jax.numpy as jnp
Now we can use jnp in place of np for the usual array operations:
a = jnp.asarray((1.0, 3.2, -1.5))
print(a)
[ 1. 3.2 -1.5]
print(jnp.sum(a))
2.6999998
print(jnp.mean(a))
0.9
print(jnp.dot(a, a))
13.490001
However, the array object a is not a NumPy array:
a
Array([ 1. , 3.2, -1.5], dtype=float32)
type(a)
jaxlib._jax.ArrayImpl
Even scalar-valued maps on arrays return JAX arrays.
jnp.sum(a)
Array(2.6999998, dtype=float32)
Operations on higher dimensional arrays are also similar to NumPy:
A = jnp.ones((2, 2))
B = jnp.identity(2)
A @ B
Array([[1., 1.],
[1., 1.]], dtype=float32)
JAX’s array interface also provides the linalg subpackage:
jnp.linalg.inv(B) # Inverse of identity is identity
Array([[1., 0.],
[0., 1.]], dtype=float32)
jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors
EighResult(eigenvalues=Array([1., 1.], dtype=float32), eigenvectors=Array([[1., 0.],
[0., 1.]], dtype=float32))
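Other routines from the linalg namespace work the same way; for instance (a small addition of ours), solving a linear system:

b = jnp.array([1.0, 2.0])
jnp.linalg.solve(B, b)   # solves B @ x = b; B is the identity, so this returns b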
14.1.2. Differences#
Let’s now look at some differences between JAX and NumPy array operations.
14.1.2.1. Precision#
One difference between NumPy and JAX is that JAX uses 32-bit floats by default.
This is because JAX is often used for GPU computing, and most GPU computations use 32-bit floats.
Using 32-bit floats can lead to significant speed gains with only a small loss of precision.
However, for some calculations precision matters.
In these cases 64-bit floats can be enforced via the command
jax.config.update("jax_enable_x64", True)
Let’s check this works:
jnp.ones(3)
Array([1., 1., 1.], dtype=float64)
14.1.2.2. Immutability#
A more significant difference when using JAX as a NumPy replacement is that JAX arrays are treated as immutable.
For example, with NumPy we can write
import numpy as np
a = np.linspace(0, 1, 3)
a
array([0. , 0.5, 1. ])
and then mutate the data in memory:
a[0] = 1
a
array([1. , 0.5, 1. ])
In JAX this fails:
a = jnp.linspace(0, 1, 3)
a
Array([0. , 0.5, 1. ], dtype=float64)
a[0] = 1
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[20], line 1
----> 1 a[0] = 1
File ~/miniconda3/envs/quantecon/lib/python3.13/site-packages/jax/_src/numpy/array_methods.py:621, in _unimplemented_setitem(self, i, x)
617 def _unimplemented_setitem(self, i, x):
618 msg = ("JAX arrays are immutable and do not support in-place item assignment."
619 " Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method:"
620 " https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html")
--> 621 raise TypeError(msg.format(type(self)))
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html
In line with immutability, JAX does not support in-place operations:
a = np.array((2, 1))
a.sort()  # NumPy's sort mutates a in place
a
array([1, 2])
a = jnp.array((2, 1))
a_new = a.sort() # Instead, the sort method returns a new sorted array
a, a_new
(Array([2, 1], dtype=int64), Array([1, 2], dtype=int64))
The designers of JAX chose to make arrays immutable because JAX uses a functional programming style.
This design choice has important implications, which we explore next!
14.1.2.3. A workaround#
We note that JAX does provide a version of in-place array modification
using the at method.
a = jnp.linspace(0, 1, 3)
Applying at[0].set(1) returns a new copy of a with the first element set to 1
a = a.at[0].set(1)
a
Array([1. , 0.5, 1. ], dtype=float64)
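The at property supports other update methods besides set; for example (a small addition of ours), add returns a new array with an entry incremented:

a.at[1].add(10)   # Array([ 1. , 10.5,  1. ]); a itself is unchanged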
Obviously, there are downsides to using at:
The syntax is cumbersome and
we want to avoid creating fresh arrays in memory every time we change a single value!
Hence, for the most part, we try to avoid this syntax.
(Although it can in fact be efficient inside JIT-compiled functions – but let’s put this aside for now.)
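For the curious, here is a minimal sketch (our own addition) of what that looks like; inside a JIT-compiled function the compiler can typically perform the update without materializing an extra copy:

@jax.jit
def set_first(a, value):
    # A functional update; under jit this usually compiles to an in-place write
    return a.at[0].set(value)

set_first(jnp.linspace(0, 1, 3), 1.0)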
14.2. Functional Programming#
From JAX’s documentation:
When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.
In other words, JAX assumes a functional programming style.
14.2.1. Pure functions#
The major implication is that JAX functions should be pure.
Pure functions have the following characteristics:
Deterministic
No side effects
Deterministic means
Same input \(\implies\) same output
Outputs do not depend on global state
In particular, pure functions will always return the same result if invoked with the same inputs.
No side effects means that the function
Won’t change global state
Won’t modify data passed to the function (immutable data)
14.2.2. Examples#
Here’s an example of a non-pure function
tax_rate = 0.1
prices = [10.0, 20.0]
def add_tax(prices):
    for i, price in enumerate(prices):
        prices[i] = price * (1 + tax_rate)
    print('Post-tax prices: ', prices)
    return prices
This function fails to be pure because
side effects — it modifies the global variable prices
non-deterministic — a change to the global variable tax_rate will modify function outputs, even with the same input array prices
Here’s a pure version
tax_rate = 0.1
prices = (10.0, 20.0)
def add_tax_pure(prices, tax_rate):
    new_prices = [price * (1 + tax_rate) for price in prices]
    return new_prices
This pure version makes all dependencies explicit through function arguments, and doesn’t modify any external state.
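For example (an illustrative call of our own), all inputs are passed explicitly and the original prices tuple is left untouched:

add_tax_pure(prices, tax_rate)   # approximately [11.0, 22.0]; prices itself is unchanged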
Now that we understand what pure functions are, let’s explore how JAX’s approach to random numbers maintains this purity.
14.3. Random numbers#
Random numbers are rather different in JAX, compared to what you find in NumPy or Matlab.
At first you might find the syntax rather verbose.
But you will soon realize that the syntax and semantics are necessary in order to maintain the functional programming style we just discussed.
Moreover, full control of random state is essential for parallel programming, such as when we want to run independent experiments across multiple threads.
14.3.1. Random number generation#
In JAX, the state of the random number generator is controlled explicitly.
First we produce a key, which seeds the random number generator.
seed = 1234
key = jax.random.PRNGKey(seed)
Now we can use the key to generate some random numbers:
x = jax.random.normal(key, (3, 3))
x
Array([[-0.54019824, 0.43957585, -0.01978102],
[ 0.90665474, -0.90831359, 1.32846635],
[ 0.20408174, 0.93096529, 3.30373914]], dtype=float64)
If we use the same key again, we initialize at the same seed, so the random numbers are the same:
jax.random.normal(key, (3, 3))
Array([[-0.54019824, 0.43957585, -0.01978102],
[ 0.90665474, -0.90831359, 1.32846635],
[ 0.20408174, 0.93096529, 3.30373914]], dtype=float64)
To produce a (quasi-) independent draw, one option is to “split” the existing key:
key, subkey = jax.random.split(key)
jax.random.normal(key, (3, 3))
Array([[ 1.24104247, 0.12018902, -2.23990047],
[ 0.70507261, -0.85702845, -1.24582014],
[ 0.38454486, 1.32117717, 0.56866901]], dtype=float64)
jax.random.normal(subkey, (3, 3))
Array([[ 0.07627173, -1.30349831, 0.86524323],
[-0.75550773, 0.63958052, 0.47052126],
[-1.72866044, -1.14696564, -1.23328892]], dtype=float64)
This syntax will seem unusual for a NumPy or Matlab user — but will make a lot of sense when we progress to parallel programming.
The function below produces k (quasi-) independent random n x n matrices using split.
def gen_random_matrices(key, n=2, k=3):
    matrices = []
    for _ in range(k):
        key, subkey = jax.random.split(key)
        A = jax.random.uniform(subkey, (n, n))
        matrices.append(A)
        print(A)
    return matrices
seed = 42
key = jax.random.PRNGKey(seed)
matrices = gen_random_matrices(key)
[[0.74211901 0.54715578]
[0.05988742 0.32206803]]
[[0.65877976 0.57087415]
[0.97301903 0.10138266]]
[[0.68745522 0.25974132]
[0.06595873 0.83589118]]
We can also use fold_in when iterating in a loop:
def gen_random_matrices(key, n=2, k=3):
    matrices = []
    for i in range(k):
        step_key = jax.random.fold_in(key, i)
        A = jax.random.uniform(step_key, (n, n))
        matrices.append(A)
        print(A)
    return matrices
key = jax.random.PRNGKey(seed)
matrices = gen_random_matrices(key)
[[0.23566993 0.39719189]
[0.95367373 0.42397776]]
[[0.74211901 0.54715578]
[0.05988742 0.32206803]]
[[0.37386727 0.66444882]
[0.80253222 0.42934555]]
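As an aside (a small sketch of our own, not part of the original example), split can also produce several subkeys in a single call:

keys = jax.random.split(key, 3)                  # three independent subkeys at once
[jax.random.uniform(k, (2, 2)) for k in keys]    # one matrix per subkey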
14.3.2. Why explicit random state?#
Why does JAX require this somewhat verbose approach to random number generation?
One reason is to maintain pure functions.
Let’s see how random number generation relates to pure functions by comparing NumPy and JAX.
14.3.2.1. NumPy’s approach#
In NumPy, random number generation works by maintaining hidden global state.
Each time we call a random function, this state is updated:
np.random.seed(42)
print(np.random.randn()) # Updates state of random number generator
print(np.random.randn()) # Updates state of random number generator
0.4967141530112327
-0.13826430117118466
Each call returns a different value, even though we’re calling the same function with the same inputs (no arguments).
This function is not pure because:
It’s non-deterministic: same inputs (none, in this case) give different outputs
It has side effects: it modifies the global random number generator state
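To see the contrast clearly, here is a NumPy-style helper (a sketch of our own, mirroring the JAX function defined below); it is not pure, because it reads and updates NumPy's hidden global state:

def random_sum_numpy():
    # Same inputs (none), different output on every call: hidden state is mutated
    return np.random.randn() + np.random.randn()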
14.3.2.2. JAX’s approach#
As we saw above, JAX takes a different approach, making randomness explicit through keys.
For example,
def random_sum_jax(key):
    key1, key2 = jax.random.split(key)
    x = jax.random.normal(key1)
    y = jax.random.normal(key2)
    return x + y
With the same key, we always get the same result:
key = jax.random.PRNGKey(42)
random_sum_jax(key)
Array(-0.07040872, dtype=float64)
random_sum_jax(key)
Array(-0.07040872, dtype=float64)
To get new draws we need to supply a new key.
The function random_sum_jax is pure because:
It’s deterministic: same key always produces same output
No side effects: no hidden state is modified
The explicitness of JAX brings significant benefits:
Reproducibility: Easy to reproduce results by reusing keys
Parallelization: Each thread can have its own key without conflicts
Debugging: No hidden state makes code easier to reason about
JIT compatibility: The compiler can optimize pure functions more aggressively
The last point is expanded on in the next section.
14.4. JIT compilation#
The JAX just-in-time (JIT) compiler accelerates execution by generating efficient machine code that varies with both task size and hardware.
14.4.1. A simple example#
Let’s say we want to evaluate the cosine function at many points.
n = 50_000_000
x = np.linspace(0, 10, n)
14.4.1.1. With NumPy#
Let’s try with NumPy
with qe.Timer():
    y = np.cos(x)
0.83 seconds elapsed
And one more time.
with qe.Timer():
    y = np.cos(x)
0.88 seconds elapsed
Here NumPy uses a pre-built binary file, compiled from carefully written low-level code, for applying cosine to an array of floats.
This binary file ships with NumPy.
14.4.1.2. With JAX#
Now let’s try with JAX.
x = jnp.linspace(0, 10, n)
Let’s time the same procedure.
with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.36 seconds elapsed
Note
Here, in order to measure actual speed, we use block_until_ready to hold the interpreter until the results of the computation are returned.
This is necessary because JAX uses asynchronous dispatch, which allows the Python interpreter to run ahead of numerical computations.
For non-timed code, you can drop the line containing block_until_ready.
And let’s time it again.
with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.35 seconds elapsed
If you are running this on a GPU, the code will run much faster than its NumPy equivalent, which ran on the CPU.
Even if you are running on a machine with many CPUs, the second JAX run should be substantially faster than the NumPy runs.
Also, the second JAX run is typically faster than the first.
(This might not be noticeable on the CPU but it should definitely be noticeable on the GPU.)
This is because even built-in functions like jnp.cos are JIT-compiled — and the first run includes compile time.
Why would JAX want to JIT-compile built-in functions like jnp.cos instead of just providing pre-compiled versions, as NumPy does?
The reason is that the JIT compiler wants to specialize on the size of the array being used (as well as the data type).
The size matters for generating optimized code because efficient parallelization requires matching the size of the task to the available hardware.
That’s why JAX waits to see the size of the array before compiling — which requires a JIT-compiled approach instead of supplying precompiled binaries.
14.4.1.3. Changing array sizes#
Here we change the input size and watch the runtimes.
x = jnp.linspace(0, 10, n + 1)
with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.37 seconds elapsed
with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.35 seconds elapsed
Typically, the run time increases and then falls again (this will be more obvious on the GPU).
This is because the JIT compiler specializes on array size to exploit parallelization — and hence generates fresh compiled code when the array size changes.
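One way to see this specialization directly (a small sketch of our own, not part of the original lecture) is to put a Python print statement inside a jitted function; the print executes only when JAX traces and compiles the function:

@jax.jit
def g(x):
    # Trace-time side effect: runs only when JAX compiles for a new shape
    print("tracing g for shape", x.shape)
    return jnp.cos(x)

g(jnp.ones(3))   # prints: tracing g for shape (3,)
g(jnp.ones(3))   # no print: the cached compilation is reused
g(jnp.ones(4))   # prints again: a new shape triggers recompilation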
14.4.2. Evaluating a more complicated function#
Let’s try the same thing with a more complex function.
def f(x):
    y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2
    return y
14.4.2.1. With NumPy#
We’ll try first with NumPy
n = 50_000_000
x = np.linspace(0, 10, n)
with qe.Timer():
    y = f(x)
2.90 seconds elapsed
14.4.2.2. With JAX#
Now let’s try again with JAX.
As a first pass, we replace np with jnp throughout:
def f(x):
    y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - 0.1 * x**2
    return y
Now let’s time it.
x = jnp.linspace(0, 10, n)
with qe.Timer():
    y = f(x)
    jax.block_until_ready(y);
1.19 seconds elapsed
with qe.Timer():
    y = f(x)
    jax.block_until_ready(y);
1.07 seconds elapsed
The outcome is similar to the cos example — JAX is faster, especially if you
use a GPU and especially on the second run.
Moreover, with JAX, we have another trick up our sleeve:
14.4.3. Compiling the Whole Function#
The JAX just-in-time (JIT) compiler can also accelerate entire functions by fusing their array operations into a single optimized kernel.
Let’s try this with the function f:
f_jax = jax.jit(f)
with qe.Timer():
    y = f_jax(x)
    jax.block_until_ready(y);
0.58 seconds elapsed
with qe.Timer():
    y = f_jax(x)
    jax.block_until_ready(y);
0.55 seconds elapsed
The runtime has improved again — now because we fused all the operations, allowing the compiler to optimize more aggressively.
For example, the compiler can eliminate multiple calls to the hardware accelerator and the creation of a number of intermediate arrays.
Incidentally, a more common syntax when targeting a function for the JIT compiler is
@jax.jit
def f(x):
    pass  # put function body here
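For instance (our own restatement of the function used above), f could have been written as:

@jax.jit
def f(x):
    return jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - 0.1 * x**2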
14.4.4. Compiling non-pure functions#
Now that we’ve seen how powerful JIT compilation can be, it’s important to understand its relationship with pure functions.
While JAX will not usually throw errors when compiling impure functions, execution becomes unpredictable.
Here’s an illustration of this fact, using global variables:
a = 1 # global
@jax.jit
def f(x):
    return a + x
x = jnp.ones(2)
f(x)
Array([2., 2.], dtype=float64)
In the code above, the global value a=1 is fused into the jitted function.
Even if we change a, the output of f will not be affected — as long as the same compiled version is called.
a = 42
f(x)
Array([2., 2.], dtype=float64)
Changing the dimension of the input triggers a fresh compilation of the function, at which time the change in the value of a takes effect:
x = jnp.ones(3)
f(x)
Array([43., 43., 43.], dtype=float64)
Moral of the story: write pure functions when using JAX!
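For completeness, here is a pure rewrite of the function above (a sketch of our own): passing a as an explicit argument means changes to it always take effect, with no recompilation surprises.

@jax.jit
def f_pure(x, a):
    # All inputs are explicit, so no global value is baked into the compiled code
    return a + x

f_pure(jnp.ones(3), 42)   # -> [43., 43., 43.]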
14.4.5. Summary#
Now we can see why both developers and compilers benefit from pure functions.
We love pure functions because they
Help testing: each function can operate in isolation
Promote deterministic behavior and hence reproducibility
Prevent bugs that arise from mutating shared state
The compiler loves pure functions and functional programming because
Data dependencies are explicit, which helps with optimizing complex computations
Pure functions are easier to differentiate (autodiff)
Pure functions are easier to parallelize and optimize (don’t depend on shared mutable state)
14.5. Gradients#
JAX can use automatic differentiation to compute gradients.
This can be extremely useful for optimization and solving nonlinear systems.
We will see significant applications later in this lecture series.
For now, here’s a very simple illustration involving the function
def f(x):
    return (x**2) / 2
Let’s take the derivative:
f_prime = jax.grad(f)
f_prime(10.0)
Array(10., dtype=float64, weak_type=True)
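As an aside (our own small addition), jax.grad composes, so higher-order derivatives come for free:

f_second = jax.grad(jax.grad(f))   # second derivative of x**2 / 2, which is identically 1
f_second(10.0)                     # evaluates to 1.0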
Let’s plot the function and derivative, noting that \(f'(x) = x\).
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
x_grid = jnp.linspace(-4, 4, 200)
ax.plot(x_grid, f(x_grid), label="$f$")
ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
ax.legend(loc='upper center')
plt.show()
We defer further exploration of automatic differentiation with JAX until Adventures with Autodiff.
14.6. Exercises#
Exercise 14.1
In the Exercise section of our lecture on Numba, we used Monte Carlo to price a European call option.
The code was accelerated by Numba-based multithreading.
Try writing a version of this operation for JAX, using all the same parameters.
Solution to Exercise 14.1
Here is one solution:
M = 10_000_000
n, β, K = 20, 0.99, 100
μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0
@jax.jit
def compute_call_price_jax(β=β,
                           μ=μ,
                           S0=S0,
                           h0=h0,
                           K=K,
                           n=n,
                           ρ=ρ,
                           ν=ν,
                           M=M,
                           key=jax.random.PRNGKey(1)):

    s = jnp.full(M, np.log(S0))
    h = jnp.full(M, h0)
    for t in range(n):
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (2, M))
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
    expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
    return β**n * expectation
Let’s run it once to compile it:
with qe.Timer():
    compute_call_price_jax().block_until_ready()
16.83 seconds elapsed
And now let’s time it:
with qe.Timer():
    compute_call_price_jax().block_until_ready()
14.83 seconds elapsed
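As a variation (a sketch of our own, not part of the original solution), the Python for loop can be replaced with jax.lax.fori_loop, which keeps the whole time loop inside the compiled computation and uses fold_in to derive a key for each step; for brevity, the parameter values defined above are read as globals and baked in at compile time.

@jax.jit
def compute_call_price_fori(key=jax.random.PRNGKey(1)):

    def body(t, carry):
        s, h = carry
        step_key = jax.random.fold_in(key, t)      # a fresh key for each step
        Z = jax.random.normal(step_key, (2, M))
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
        return s, h

    init = (jnp.full(M, np.log(S0)),
            jnp.full(M, h0, dtype=float))          # float dtype keeps the loop carry stable
    s, h = jax.lax.fori_loop(0, n, body, init)
    expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
    return β**n * expectation

with qe.Timer():
    compute_call_price_fori().block_until_ready()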