14. JAX

In addition to what’s in Anaconda, this lecture will need the following libraries:

!pip install jax quantecon

Successfully installed jax-0.8.1 jaxlib-0.8.1 ml_dtypes-0.5.4 opt_einsum-3.4.0 quantecon-0.10.1

This lecture provides a short introduction to Google JAX.

Here we are focused on using JAX on the CPU, rather than on accelerators such as GPUs or TPUs.

This means we will only see a small amount of the possible benefits from using JAX.

However, JAX seamlessly handles transitions across different hardware platforms.

As a result, if you run this code on a machine with a GPU and a GPU-aware version of JAX installed, your code will be automatically accelerated and you will receive the full benefits.

For a discussion of JAX on GPUs, see our JAX lecture series.

14.1. JAX as a NumPy Replacement

One of the attractive features of JAX is that, whenever possible, it conforms to the NumPy API for array operations.

This means that, to a large extent, we can use JAX as a drop-in NumPy replacement.

Let’s look at the similarities and differences between JAX and NumPy.

14.1.1. Similarities

We’ll use the following imports:

import jax
import quantecon as qe

In addition, we replace import numpy as np with

import jax.numpy as jnp

Now we can use jnp in place of np for the usual array operations:

a = jnp.asarray((1.0, 3.2, -1.5))
print(a)
[ 1.   3.2 -1.5]
print(jnp.sum(a))
2.6999998
print(jnp.mean(a))
0.9
print(jnp.dot(a, a))
13.490001

However, the array object a is not a NumPy array:

a
Array([ 1. ,  3.2, -1.5], dtype=float32)
type(a)
jaxlib._jax.ArrayImpl

Even scalar-valued maps on arrays return JAX arrays.

jnp.sum(a)
Array(2.6999998, dtype=float32)
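Because JAX arrays are not NumPy arrays, it is sometimes handy to convert back, for example before passing results to a library that expects NumPy. The sketch below uses np.asarray to get an ordinary NumPy array and float to pull a Python scalar out of a zero-dimensional JAX array. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.asarray((1.0, 3.2, -1.5))   # a JAX array
a_np = np.asarray(a)                # convert to an ordinary NumPy ndarray
total = jnp.sum(a)                  # a zero-dimensional JAX array
total_float = float(total)          # a plain Python float

print(type(a_np), type(total_float))
```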

Operations on higher dimensional arrays are also similar to NumPy:

A = jnp.ones((2, 2))
B = jnp.identity(2)
A @ B
Array([[1., 1.],
       [1., 1.]], dtype=float32)

JAX’s array interface also provides the linalg subpackage:

jnp.linalg.inv(B)   # Inverse of identity is identity
Array([[1., 0.],
       [0., 1.]], dtype=float32)
jnp.linalg.eigh(B)  # Computes eigenvalues and eigenvectors
EighResult(eigenvalues=Array([1., 1.], dtype=float32), eigenvectors=Array([[1., 0.],
       [0., 1.]], dtype=float32))

14.1.2. Differences

Let’s now look at some differences between JAX and NumPy array operations.

14.1.2.1. Precision

One difference between NumPy and JAX is that JAX uses 32-bit floats by default.

This is because JAX is often used for GPU computing, and GPU computations commonly use 32-bit floats.

Using 32-bit floats can lead to significant speed gains at the cost of a small loss of precision.

However, for some calculations precision matters.

In these cases, 64-bit floats can be enabled via the command

jax.config.update("jax_enable_x64", True)

Let’s check this works:

jnp.ones(3)
Array([1., 1., 1.], dtype=float64)
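To see what the loss of precision means concretely, the sketch below adds a tiny increment to 1.0 at both precisions. Single precision has a machine epsilon of roughly 1.2e-7, so an increment of 1e-10 is rounded away, while double precision keeps it. The value of tiny is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

# 64-bit dtypes are only honored once this flag is set
jax.config.update("jax_enable_x64", True)

# Add the same tiny increment to 1.0 at each precision
tiny = 1e-10
x32 = jnp.array(1.0, dtype=jnp.float32) + jnp.array(tiny, dtype=jnp.float32)
x64 = jnp.array(1.0, dtype=jnp.float64) + jnp.array(tiny, dtype=jnp.float64)

# float32 loses the increment; float64 retains it
print(x32 == 1.0, x64 == 1.0)
print(jnp.finfo(jnp.float32).eps, jnp.finfo(jnp.float64).eps)
```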

14.1.2.2. Immutability

A more significant difference from NumPy is that JAX arrays are treated as immutable.

For example, with NumPy we can write

import numpy as np
a = np.linspace(0, 1, 3)
a
array([0. , 0.5, 1. ])

and then mutate the data in memory:

a[0] = 1
a
array([1. , 0.5, 1. ])

In JAX this fails:

a = jnp.linspace(0, 1, 3)
a
Array([0. , 0.5, 1. ], dtype=float64)
a[0] = 1
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[20], line 1
----> 1 a[0] = 1

File ~/miniconda3/envs/quantecon/lib/python3.13/site-packages/jax/_src/numpy/array_methods.py:621, in _unimplemented_setitem(self, i, x)
    617 def _unimplemented_setitem(self, i, x):
    618   msg = ("JAX arrays are immutable and do not support in-place item assignment."
    619          " Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method:"
    620          " https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html")
--> 621   raise TypeError(msg.format(type(self)))

TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

In line with immutability, JAX array methods never operate in place. Compare sort in NumPy and in JAX:

a = np.array((2, 1))
a.sort()    # NumPy sorts in place, so this mutates a
a
array([1, 2])
a = jnp.array((2, 1))
a_new = a.sort()   # Instead, the sort method returns a new sorted array
a, a_new
(Array([2, 1], dtype=int64), Array([1, 2], dtype=int64))

The designers of JAX chose to make arrays immutable because JAX uses a functional programming style.

This design choice has important implications, which we explore next!

14.1.2.3. A workaround

JAX does, however, provide a substitute for in-place array modification, via the at property.

a = jnp.linspace(0, 1, 3)

Applying at[0].set(1) returns a new copy of a with the first element set to 1:

a = a.at[0].set(1)
a
Array([1. , 0.5, 1. ], dtype=float64)
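The at property supports more than set. The sketch below tries a few of its update methods (set, add, multiply) and its get method for indexed reads. Each call returns a fresh array, which the final print confirms by showing that the original a is unchanged.

```python
import jax.numpy as jnp

a = jnp.linspace(0, 1, 3)        # [0. , 0.5, 1. ]

b = a.at[0].set(1.0)             # replace one element
c = a.at[1:].add(10.0)           # add to every element of a slice
d = a.at[2].multiply(3.0)        # scale one element
e = a.at[0].get()                # indexed read, equivalent to a[0]

# Every call returned a new array; a itself is untouched
print(a, b, c, d, e)
```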

There are clear downsides to using at:

  • the syntax is cumbersome, and

  • it creates a fresh array in memory every time we change a single value, which we want to avoid!

Hence, for the most part, we try to avoid this syntax.

(Although it can in fact be efficient inside JIT-compiled functions – but let’s put this aside for now.)

14.2. Functional Programming

From JAX’s documentation:

When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.

In other words, JAX assumes a functional programming style.

14.2.1. Pure functions

The major implication is that JAX functions should be pure.

Pure functions have the following characteristics:

  1. Deterministic

  2. No side effects

Deterministic means

  • Same input \(\implies\) same output

  • Outputs do not depend on global state

In particular, pure functions will always return the same result if invoked with the same inputs.

No side effects means that the function

  • Won’t change global state

  • Won’t modify data passed to the function (immutable data)

14.2.2. Examples

Here’s an example of a non-pure function:

tax_rate = 0.1
prices = [10.0, 20.0]

def add_tax(prices):
    for i, price in enumerate(prices):
        prices[i] = price * (1 + tax_rate)
    print('Post-tax prices: ', prices)
    return prices

This function fails to be pure because

  • side effects: it modifies its argument in place, so calling add_tax(prices) overwrites the global list prices

  • non-deterministic: a change to the global variable tax_rate will modify function outputs, even with the same input list prices.

Here’s a pure version:

tax_rate = 0.1
prices = (10.0, 20.0)

def add_tax_pure(prices, tax_rate):
    new_prices = [price * (1 + tax_rate) for price in prices]
    return new_prices

This pure version makes all dependencies explicit through function arguments, and doesn’t modify any external state.
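To see the difference in behavior, the sketch below calls each function twice on the same starting data. It uses a trimmed copy of the two functions above (the print call is dropped), and the names shared, first and second are illustrative only.

```python
tax_rate = 0.1

def add_tax(prices):
    # Impure: reads the global tax_rate and overwrites the caller's list
    for i, price in enumerate(prices):
        prices[i] = price * (1 + tax_rate)
    return prices

def add_tax_pure(prices, tax_rate):
    # Pure: every input is an argument, and a new list is returned
    return [price * (1 + tax_rate) for price in prices]

# Two calls to the impure version tax the same list twice
shared = [10.0, 20.0]
add_tax(shared)
add_tax(shared)
print(shared)          # roughly [12.1, 24.2]: the caller's data has changed

# The pure version gives the same answer each time and leaves its input alone
prices = [10.0, 20.0]
first = add_tax_pure(prices, 0.1)
second = add_tax_pure(prices, 0.1)
print(first == second, prices)
```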

Now that we understand what pure functions are, let’s explore how JAX’s approach to random numbers maintains this purity.

14.3. Random numbers

Random numbers are rather different in JAX, compared to what you find in NumPy or Matlab.

At first you might find the syntax rather verbose.

But you will soon realize that the syntax and semantics are necessary in order to maintain the functional programming style we just discussed.

Moreover, full control of random state is essential for parallel programming, such as when we want to run independent experiments along multiple threads.

14.3.1. Random number generation

In JAX, the state of the random number generator is controlled explicitly.

First we produce a key, which seeds the random number generator.

seed = 1234
key = jax.random.PRNGKey(seed)

Now we can use the key to generate some random numbers:

x = jax.random.normal(key, (3, 3))
x
Array([[-0.54019824,  0.43957585, -0.01978102],
       [ 0.90665474, -0.90831359,  1.32846635],
       [ 0.20408174,  0.93096529,  3.30373914]], dtype=float64)

If we use the same key again, we initialize at the same seed, so the random numbers are the same:

jax.random.normal(key, (3, 3))
Array([[-0.54019824,  0.43957585, -0.01978102],
       [ 0.90665474, -0.90831359,  1.32846635],
       [ 0.20408174,  0.93096529,  3.30373914]], dtype=float64)

To produce a (quasi-) independent draw, one option is to “split” the existing key:

key, subkey = jax.random.split(key)
jax.random.normal(key, (3, 3))
Array([[ 1.24104247,  0.12018902, -2.23990047],
       [ 0.70507261, -0.85702845, -1.24582014],
       [ 0.38454486,  1.32117717,  0.56866901]], dtype=float64)
jax.random.normal(subkey, (3, 3))
Array([[ 0.07627173, -1.30349831,  0.86524323],
       [-0.75550773,  0.63958052,  0.47052126],
       [-1.72866044, -1.14696564, -1.23328892]], dtype=float64)

This syntax will seem unusual for a NumPy or Matlab user — but will make a lot of sense when we progress to parallel programming.

The function below produces k (quasi-) independent random n x n matrices using split.

def gen_random_matrices(key, n=2, k=3):
    matrices = []
    for _ in range(k):
        key, subkey = jax.random.split(key)
        A = jax.random.uniform(subkey, (n, n))
        matrices.append(A)
        print(A)
    return matrices
seed = 42
key = jax.random.PRNGKey(seed)
matrices = gen_random_matrices(key)
[[0.74211901 0.54715578]
 [0.05988742 0.32206803]]
[[0.65877976 0.57087415]
 [0.97301903 0.10138266]]
[[0.68745522 0.25974132]
 [0.06595873 0.83589118]]
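jax.random.split also accepts a second argument giving the number of keys to produce, so all k subkeys can be created in one call rather than by splitting repeatedly in a loop. The sketch below is a variant of the function above built that way. The name gen_random_matrices_batched is hypothetical, and its draws need not match the matrices printed above because the subkeys are derived differently.

```python
import jax
import jax.numpy as jnp

def gen_random_matrices_batched(key, n=2, k=3):
    # One call to split produces k subkeys, stacked along the first axis
    subkeys = jax.random.split(key, k)
    # Iterating over the stacked keys yields one key per matrix
    return [jax.random.uniform(subkey, (n, n)) for subkey in subkeys]

key = jax.random.PRNGKey(42)
matrices = gen_random_matrices_batched(key)
matrices_again = gen_random_matrices_batched(key)   # same key, same draws
```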

We can also use fold_in when iterating in a loop:

def gen_random_matrices(key, n=2, k=3):
    matrices = []
    for i in range(k):
        step_key = jax.random.fold_in(key, i)
        A = jax.random.uniform(step_key, (n, n))
        matrices.append(A)
        print(A)
    return matrices
key = jax.random.PRNGKey(seed)
matrices = gen_random_matrices(key)
[[0.23566993 0.39719189]
 [0.95367373 0.42397776]]
[[0.74211901 0.54715578]
 [0.05988742 0.32206803]]
[[0.37386727 0.66444882]
 [0.80253222 0.42934555]]
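One practical property of fold_in is that the key for step i depends only on the original key and on i. The sketch below checks this by regenerating the step-2 draw directly, without replaying steps 0 and 1. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# Draws for steps 0, 1, 2 generated in a loop, as above
loop_draws = [jax.random.uniform(jax.random.fold_in(key, i), (2, 2)) for i in range(3)]

# The key for step 2 depends only on (key, 2), so it can be rebuilt on its own
direct_draw = jax.random.uniform(jax.random.fold_in(key, 2), (2, 2))
```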

14.3.2. Why explicit random state?

Why does JAX require this somewhat verbose approach to random number generation?

One reason is to maintain pure functions.

Let’s see how random number generation relates to pure functions by comparing NumPy and JAX.

14.3.2.1. NumPy’s approach

In NumPy, random number generation works by maintaining hidden global state.

Each time we call a random function, this state is updated:

np.random.seed(42)
print(np.random.randn())   # Updates state of random number generator
print(np.random.randn())   # Updates state of random number generator
0.4967141530112327
-0.13826430117118466

Each call returns a different value, even though we’re calling the same function with the same inputs (no arguments).

This function is not pure because:

  • It’s non-deterministic: same inputs (none, in this case) give different outputs

  • It has side effects: it modifies the global random number generator state

14.3.2.2. JAX’s approach

As we saw above, JAX takes a different approach, making randomness explicit through keys.

For example,

def random_sum_jax(key):
    key1, key2 = jax.random.split(key)
    x = jax.random.normal(key1)
    y = jax.random.normal(key2)
    return x + y

With the same key, we always get the same result:

key = jax.random.PRNGKey(42)
random_sum_jax(key)
Array(-0.07040872, dtype=float64)
random_sum_jax(key)
Array(-0.07040872, dtype=float64)

To get new draws we need to supply a new key.

The function random_sum_jax is pure because:

  • It’s deterministic: same key always produces same output

  • No side effects: no hidden state is modified
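Since random_sum_jax only produces new draws when it receives a new key, a common pattern is to thread the key through a loop: split it at each step, pass one half to the function, and carry the other half forward. The sketch below does this with a hypothetical helper draw_sums, reusing a copy of random_sum_jax from above. The whole sequence stays reproducible from the starting key.

```python
import jax

def random_sum_jax(key):
    key1, key2 = jax.random.split(key)
    return jax.random.normal(key1) + jax.random.normal(key2)

def draw_sums(key, num_draws):
    # Thread the key through the loop: use a fresh subkey at each step
    # and keep the other half of the split for the next iteration
    draws = []
    for _ in range(num_draws):
        key, subkey = jax.random.split(key)
        draws.append(random_sum_jax(subkey))
    return draws

key = jax.random.PRNGKey(42)
sums = draw_sums(key, 4)
print([float(s) for s in sums])
```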

The explicitness of JAX brings significant benefits:

  • Reproducibility: Easy to reproduce results by reusing keys

  • Parallelization: Each thread can have its own key without conflicts

  • Debugging: No hidden state makes code easier to reason about

  • JIT compatibility: The compiler can optimize pure functions more aggressively

The last point is expanded on in the next section.

14.4. JIT compilation

The JAX just-in-time (JIT) compiler accelerates execution by generating efficient machine code that varies with both task size and hardware.

14.4.1. A simple example

Let’s say we want to evaluate the cosine function at many points.

n = 50_000_000
x = np.linspace(0, 10, n)

14.4.1.1. With NumPy

Let’s try with NumPy

with qe.Timer():
    y = np.cos(x)
0.83 seconds elapsed

And one more time.

with qe.Timer():
    y = np.cos(x)
0.88 seconds elapsed

Here NumPy uses a pre-built binary file, compiled from carefully written low-level code, for applying cosine to an array of floats.

This binary file ships with NumPy.

14.4.1.2. With JAX

Now let’s try with JAX.

x = jnp.linspace(0, 10, n)

Let’s time the same procedure.

with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.36 seconds elapsed

Note

Here, in order to measure actual speed, we use the block_until_ready method to hold the interpreter until the results of the computation are returned.

This is necessary because JAX uses asynchronous dispatch, which allows the Python interpreter to run ahead of numerical computations.

For non-timed code, you can drop the line containing block_until_ready.
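If you want to time JAX code without qe.Timer, the same idea can be packaged with the standard library's time.perf_counter. The helper time_call below is hypothetical, a sketch that records the wall-clock time of several calls and waits on each result with jax.block_until_ready so the asynchronous work is included in the measurement.

```python
import time
import jax
import jax.numpy as jnp

def time_call(fn, *args, repeats=3):
    """Hypothetical helper: wall-clock seconds for each of `repeats` calls."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args)
        jax.block_until_ready(result)   # wait for the asynchronous computation
        timings.append(time.perf_counter() - start)
    return timings

x = jnp.linspace(0, 10, 1_000_000)
timings = time_call(jnp.cos, x)
print(timings)   # the first entry typically includes compile time
```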

And let’s time it again.

with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.35 seconds elapsed

If you are running this on a GPU, the code will run much faster than its NumPy equivalent, which ran on the CPU.

Even on a CPU-only machine with many cores, the second JAX run should be substantially faster than NumPy.

Also, typically, the second run is faster than the first.

(This might not be noticeable on the CPU, but it should definitely be noticeable on the GPU.)

This is because even built-in functions like jnp.cos are JIT-compiled, and the first run includes compile time.

Why would JAX want to JIT-compile built-in functions like jnp.cos instead of just providing pre-compiled versions, like NumPy?

The reason is that the JIT compiler wants to specialize on the size of the array being used (as well as the data type).

The size matters for generating optimized code because efficient parallelization requires matching the size of the task to the available hardware.

That’s why JAX waits to see the size of the array before compiling — which requires a JIT-compiled approach instead of supplying precompiled binaries.

14.4.1.3. Changing array sizes

Here we change the input size and watch the runtimes.

x = jnp.linspace(0, 10, n + 1)
with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.37 seconds elapsed
with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.35 seconds elapsed

Typically, the run time increases and then falls again (this will be more obvious on the GPU).

This is because the JIT compiler specializes on array size to exploit parallelization — and hence generates fresh compiled code when the array size changes.
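One way to watch this specialization happen is to count how often JAX traces a jitted function. The sketch below increments a Python counter inside the function body. This is a deliberate side effect used only for illustration: the statement runs while JAX traces the function, not each time the compiled code executes, so the counter reveals when a new compiled version was needed. The names cos_jit and trace_count are hypothetical.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def cos_jit(x):
    global trace_count
    # Runs only during tracing, which happens once per new input shape/dtype
    trace_count += 1
    return jnp.cos(x)

cos_jit(jnp.ones(10))   # first call: trace and compile for shape (10,)
cos_jit(jnp.ones(10))   # same shape and dtype: reuse the compiled code
cos_jit(jnp.ones(11))   # new shape: trace and compile again
print(trace_count)
```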

14.4.2. Evaluating a more complicated function

Let’s try the same thing with a more complex function.

def f(x):
    y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2
    return y

14.4.2.1. With NumPy

We’ll try first with NumPy

n = 50_000_000
x = np.linspace(0, 10, n)
with qe.Timer():
    y = f(x)
2.90 seconds elapsed

14.4.2.2. With JAX

Now let’s try again with JAX.

As a first pass, we replace np with jnp throughout:

def f(x):
    y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - 0.1 * x**2
    return y

Now let’s time it.

x = jnp.linspace(0, 10, n)
with qe.Timer():
    y = f(x)
    jax.block_until_ready(y);
1.19 seconds elapsed
with qe.Timer():
    y = f(x)
    jax.block_until_ready(y);
1.07 seconds elapsed

The outcome is similar to the cos example — JAX is faster, especially if you use a GPU and especially on the second run.

Moreover, with JAX, we have another trick up our sleeve:

14.4.3. Compiling the Whole Function

The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing the array operations they contain into a single optimized kernel.

Let’s try this with the function f:

f_jax = jax.jit(f)
with qe.Timer():
    y = f_jax(x)
    jax.block_until_ready(y);
0.58 seconds elapsed
with qe.Timer():
    y = f_jax(x)
    jax.block_until_ready(y);
0.55 seconds elapsed

The runtime has improved again — now because we fused all the operations, allowing the compiler to optimize more aggressively.

For example, the compiler can eliminate multiple calls to the hardware accelerator and the creation of a number of intermediate arrays.
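To see which operations jax.jit has available to fuse, jax.make_jaxpr prints the sequence of primitive operations that JAX records when it traces f. The sketch below uses a small input for readability. Note that the jaxpr is an intermediate representation, and the fusion decisions themselves are made later, when XLA compiles it, so they are not visible here.

```python
import jax
import jax.numpy as jnp

def f(x):
    y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - 0.1 * x**2
    return y

x = jnp.linspace(0, 10, 5)
# The jaxpr lists the primitives (cos, sin, sqrt, mul, add, ...) traced from f
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)
```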

Incidentally, a more common syntax when targeting a function for the JIT compiler is

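@jax.jit
def f(x):
    # Same body as before; the decorator replaces the separate call jax.jit(f)
    y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - 0.1 * x**2
    return y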

14.4.4. Compiling non-pure functions

Now that we’ve seen how powerful JIT compilation can be, it’s important to understand its relationship with pure functions.

While JAX will not usually throw errors when compiling impure functions, execution becomes unpredictable.

Here’s an illustration of this fact, using global variables:

a = 1  # global

@jax.jit
def f(x):
    return a + x
x = jnp.ones(2)
f(x)
Array([2., 2.], dtype=float64)

In the code above, the global value a=1 is fused into the jitted function.

Even if we change a, the output of f will not be affected — as long as the same compiled version is called.

a = 42
f(x)
Array([2., 2.], dtype=float64)

Changing the shape of the input (here, its length) triggers a fresh compilation of the function, at which time the change in the value of a takes effect:

x = jnp.ones(3)
f(x)
Array([43., 43., 43.], dtype=float64)

Moral of the story: write pure functions when using JAX!
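A minimal fix for the example above is to make a an explicit argument. The jitted function then reads the value passed in on each call rather than whatever value was captured at trace time. Since the input shapes and types do not change between the two calls below, no recompilation should be needed. The name f_pure is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_pure(x, a):
    # a is an explicit argument, so every call uses the value supplied
    return a + x

x = jnp.ones(2)
out_1 = f_pure(x, 1.0)     # [2., 2.]
out_42 = f_pure(x, 42.0)   # [43., 43.]
```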

14.4.5. Summary

Now we can see why both developers and compilers benefit from pure functions.

We love pure functions because they

  • Help testing: each function can operate in isolation

  • Promote deterministic behavior and hence reproducibility

  • Prevent bugs that arise from mutating shared state

The compiler loves pure functions and functional programming because

  • Data dependencies are explicit, which helps with optimizing complex computations

  • Pure functions are easier to differentiate (autodiff)

  • Pure functions are easier to parallelize and optimize (don’t depend on shared mutable state)

14.5. Gradients

JAX can use automatic differentiation to compute gradients.

This can be extremely useful for optimization and solving nonlinear systems.

We will see significant applications later in this lecture series.

For now, here’s a very simple illustration involving the function

def f(x):
    return (x**2) / 2

Let’s take the derivative:

f_prime = jax.grad(f)
f_prime(10.0)
Array(10., dtype=float64, weak_type=True)
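jax.grad works the same way for functions of a vector, as long as the output is a scalar. As a sketch, take the hypothetical function g(x) = (x_1^2 + ... + x_n^2) / 2, whose gradient is x itself, so the result is easy to check.

```python
import jax
import jax.numpy as jnp

def g(x):
    # g(x) = (x_1^2 + ... + x_n^2) / 2, whose gradient is the vector x
    return jnp.sum(x**2) / 2

grad_g = jax.grad(g)
x = jnp.array([1.0, -2.0, 3.0])
print(grad_g(x))
```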

Let’s plot the function and derivative, noting that \(f'(x) = x\).

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x_grid = jnp.linspace(-4, 4, 200)
ax.plot(x_grid, f(x_grid), label="$f$")
ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
ax.legend(loc='upper center')
plt.show()
(Figure: f and its derivative f' plotted on the grid from -4 to 4.)
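The plotting code above calls f_prime once per grid point in a Python loop. As an alternative not otherwise covered in this lecture, jax.vmap can evaluate the derivative over the whole grid in a single vectorized call. The sketch below assumes the same f and grid as above, and checks the result against f'(x) = x.

```python
import jax
import jax.numpy as jnp

def f(x):
    return (x**2) / 2

f_prime = jax.grad(f)
x_grid = jnp.linspace(-4, 4, 200)

# Map f_prime over the leading axis of x_grid in one vectorized call
f_prime_vec = jax.vmap(f_prime)
derivs = f_prime_vec(x_grid)
```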

We defer further exploration of automatic differentiation with JAX until Adventures with Autodiff.

14.6. Exercises

Exercise 14.1

In the Exercise section of our lecture on Numba, we used Monte Carlo to price a European call option.

The code was accelerated by Numba-based multithreading.

Try writing a version of this operation for JAX, using all the same parameters.