14. JAX

This lecture provides a short introduction to Google JAX.

JAX is a high-performance scientific computing library that provides

  • a NumPy-like interface that can automatically parallelize across CPUs and GPUs,

  • a just-in-time compiler for accelerating a large range of numerical operations, and

  • automatic differentiation.

Increasingly, JAX also maintains and provides more specialized scientific computing routines, such as those originally found in SciPy.

In addition to what’s in Anaconda, this lecture will need the following libraries:

!pip install jax quantecon


GPU

This lecture was built using a machine with access to a GPU — although it will also run without one.

Google Colab has a free tier with GPUs that you can access as follows:

  1. Click on the “play” icon top right

  2. Select Colab

  3. Set the runtime environment to include a GPU

14.1. JAX as a NumPy Replacement

One of the attractive features of JAX is that, whenever possible, its array processing operations conform to the NumPy API.

This means that, in many cases, we can use JAX as a drop-in NumPy replacement.

Let’s look at the similarities and differences between JAX and NumPy.

14.1.1. Similarities

We’ll use the following imports:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import quantecon as qe

Notice that we import jax.numpy as jnp, which provides a NumPy-like interface.

Here are some standard array operations using jnp:

a = jnp.asarray((1.0, 3.2, -1.5))
print(a)
[ 1.   3.2 -1.5]
print(jnp.sum(a))
2.6999998
print(jnp.mean(a))
0.9
print(jnp.dot(a, a))
13.490001

However, the array object a is not a NumPy array:

a
Array([ 1. ,  3.2, -1.5], dtype=float32)
type(a)
jaxlib._jax.ArrayImpl
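Because JAX and NumPy arrays are distinct types, code that mixes the two libraries sometimes needs explicit conversions. The sketch below shows the standard conversions in both directions; the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np

a = jnp.asarray((1.0, 3.2, -1.5))      # JAX array on the default device

# JAX -> NumPy: copies the data to host memory if it lives on an accelerator
a_np = np.asarray(a)
print(type(a_np))                        # <class 'numpy.ndarray'>

# NumPy -> JAX: places the data on the default device (CPU or GPU)
b = jnp.asarray(np.arange(3.0))
print(type(b))
```

Transfers between host and device are not free, so in practice it pays to convert once and keep the data on the device for the bulk of the computation.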

Even functions that reduce an array to a single value return JAX arrays (zero-dimensional ones) rather than Python or NumPy scalars.

jnp.sum(a)
Array(2.6999998, dtype=float32)

Operations on higher dimensional arrays are also similar to NumPy:

A = jnp.ones((2, 2))
B = jnp.identity(2)
A @ B
Array([[1., 1.],
       [1., 1.]], dtype=float32)

JAX’s array interface also provides the linalg subpackage:

jnp.linalg.inv(B)   # Inverse of identity is identity
Array([[1., 0.],
       [0., 1.]], dtype=float32)
eigvals, eigvecs = jnp.linalg.eigh(B)  # Computes eigenvalues and eigenvectors
eigvals
Array([0.99999994, 0.99999994], dtype=float32)
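As a slightly less trivial illustration of jnp.linalg, the sketch below solves a small linear system and checks that the eigendecomposition from eigh reconstructs a symmetric matrix. The matrix and right-hand side are made-up values chosen so the solution is easy to verify by hand (x = 2, y = 3).

```python
import jax.numpy as jnp

# A small symmetric system A x = b (illustrative values)
A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

# Solving directly is generally preferable to forming inv(A) @ b
x = jnp.linalg.solve(A, b)
print(x)                                    # approximately [2., 3.]

# eigh returns eigenvalues w and orthonormal eigenvectors V with A = V diag(w) V^T
w, V = jnp.linalg.eigh(A)
A_rebuilt = V @ jnp.diag(w) @ V.T
print(jnp.allclose(A_rebuilt, A, atol=1e-4))
```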

14.1.2. Differences

Let’s now look at some differences between JAX and NumPy array operations.

14.1.2.1. Speed!

Let’s say we want to evaluate the cosine function at many points.

n = 50_000_000
x = np.linspace(0, 10, n)

14.1.2.1.1. With NumPy

Let’s try with NumPy:

with qe.Timer():
    y = np.cos(x)
0.6820 seconds elapsed

And one more time.

with qe.Timer():
    y = np.cos(x)
0.6836 seconds elapsed

Here

  • NumPy uses a pre-built binary for applying cosine to an array of floats

  • The binary runs on the local machine’s CPU

14.1.2.1.2. With JAX

Now let’s try with JAX.

x = jnp.linspace(0, 10, n)

Let’s time the same procedure.

with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.0848 seconds elapsed

Note

Here, in order to measure actual speed, we use jax.block_until_ready to hold the interpreter until the results of the computation are available.

This is necessary because JAX uses asynchronous dispatch, which allows the Python interpreter to run ahead of numerical computations.

For non-timed code, you can drop the line containing block_until_ready.
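If you prefer not to depend on qe.Timer, the same blocking pattern can be wrapped in a small helper. The function time_jax below is hypothetical (it is not part of JAX or quantecon); it simply stops the clock only after jax.block_until_ready has returned, so the measured time includes the device work and not just its dispatch.

```python
import time

import jax
import jax.numpy as jnp

def time_jax(fn, *args):
    """Hypothetical helper: time one call of fn, including device execution."""
    start = time.perf_counter()
    result = fn(*args)
    # Without this line we might only time the dispatch, not the computation
    result = jax.block_until_ready(result)
    elapsed = time.perf_counter() - start
    return result, elapsed

x = jnp.linspace(0, 10, 1_000_000)
y, t_first = time_jax(jnp.cos, x)    # includes compilation for this shape
y, t_second = time_jax(jnp.cos, x)   # reuses the compiled code
print(f"first call: {t_first:.4f}s, second call: {t_second:.4f}s")
```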

And let’s time it again.

with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.0018 seconds elapsed

On a GPU, this code runs much faster than its NumPy equivalent.

Also, the second run is typically faster than the first because of JIT compilation.

This is because even built-in functions like jnp.cos are JIT-compiled, and the first run includes compile time.

Why would JAX want to JIT-compile built-in functions like jnp.cos instead of just providing pre-compiled versions, like NumPy?

The reason is that the JIT compiler wants to specialize on the size of the array being used (as well as the data type).

The size matters for generating optimized code because efficient parallelization requires matching the size of the task to the available hardware.

We can verify the claim that JAX specializes on array size by changing the input size and watching the runtimes.

x = jnp.linspace(0, 10, n + 1)
with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.0566 seconds elapsed
with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
0.0021 seconds elapsed

The runtime jumps on the first call with the new array size and then falls again on the second call (this is more obvious on a GPU).

This is in line with the discussion above – the first run after changing array size shows compilation overhead.

Further discussion of JIT compilation is provided below.

14.1.2.2. Precision

Another difference between NumPy and JAX is that JAX uses 32-bit floats by default.

This is because JAX is often used for GPU computing, and most GPU computations use 32-bit floats.

Using 32-bit floats can lead to significant speed gains with a small loss of precision.

However, for some calculations precision matters.

In these cases, 64-bit floats can be enabled via the command

jax.config.update("jax_enable_x64", True)

Let’s check this works:

jnp.ones(3)
Array([1., 1., 1.], dtype=float64)
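To see what is at stake, the sketch below adds a tiny increment to 1.0 at both precisions. A 32-bit float carries roughly 7 significant decimal digits, so an increment of 1e-8 is lost entirely, while a 64-bit float (roughly 16 digits) retains it. The value 1e-8 is just an illustrative choice.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)   # needed for float64 arrays, as above

tiny = 1e-8
x32 = jnp.array(1.0, dtype=jnp.float32) + tiny   # stays float32
x64 = jnp.array(1.0, dtype=jnp.float64) + tiny   # stays float64

print(x32 - 1.0)   # 0.0: the increment is rounded away in float32
print(x64 - 1.0)   # roughly 1e-8
```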

14.1.2.3. Immutability

As a NumPy replacement, a more significant difference is that arrays are treated as immutable.

For example, with NumPy we can write

a = np.linspace(0, 1, 3)
a
array([0. , 0.5, 1. ])

and then mutate the data in memory:

a[0] = 1
a
array([1. , 0.5, 1. ])

In JAX this fails!

a = jnp.linspace(0, 1, 3)
a
Array([0. , 0.5, 1. ], dtype=float64)
try:
    a[0] = 1
except Exception as e:
    print(e)
JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

The designers of JAX chose to make arrays immutable because JAX uses a functional programming style, which we discuss below.

14.1.2.4. A workaround

JAX does provide an alternative to in-place array modification via the at property, which performs a functional update: it returns a new array and leaves the original untouched.

a = jnp.linspace(0, 1, 3)

Applying at[0].set(1) returns a new copy of a with the first element set to 1:

a = a.at[0].set(1)
a
Array([1. , 0.5, 1. ], dtype=float64)

Obviously, there are downsides to using at:

  • the syntax is cumbersome, and

  • we want to avoid creating a fresh array in memory every time we change a single value!

Hence, for the most part, we try to avoid this syntax.

(Although it can in fact be efficient inside JIT-compiled functions – but let’s put this aside for now.)
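Besides set, the at property supports other update operations such as add and multiply, and it accepts slices as well as single indices. The sketch below shows a few of them and confirms that the original array is never modified. The jitted function set_first is a hypothetical name; inside jit, XLA may be able to carry out such an update without a full copy, which is the efficiency point made above.

```python
import jax
import jax.numpy as jnp

a = jnp.linspace(0, 1, 3)        # holds 0.0, 0.5, 1.0

b = a.at[0].set(1.0)             # copy with a[0] replaced: 1.0, 0.5, 1.0
c = a.at[1:].add(10.0)           # copy with 10 added to a[1:]: 0.0, 10.5, 11.0
d = a.at[2].multiply(3.0)        # copy with a[2] tripled: 0.0, 0.5, 3.0

print(a)                         # unchanged: still 0.0, 0.5, 1.0

@jax.jit
def set_first(x, value):
    # Functional update inside a compiled function
    return x.at[0].set(value)

print(set_first(a, 7.0))         # 7.0, 0.5, 1.0
```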

14.2. Functional Programming

From JAX’s documentation:

When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.

In other words, JAX assumes a functional programming style.

14.2.1. Pure functions

The major implication is that JAX functions should be pure.

Pure functions have the following characteristics:

  1. Deterministic

  2. No side effects

Deterministic means

  • Same input \(\implies\) same output

  • Outputs do not depend on global state

In particular, pure functions will always return the same result if invoked with the same inputs.

No side effects means that the function

  • Won’t change global state

  • Won’t modify data passed to the function (immutable data)

14.2.2. Examples

Here’s an example of a non-pure function

tax_rate = 0.1
prices = [10.0, 20.0]

def add_tax(prices):
    for i, price in enumerate(prices):
        prices[i] = price * (1 + tax_rate)
    print('Post-tax prices: ', prices)
    return prices

This function fails to be pure because

  • side effects: it modifies its argument in place (here, the global list prices) and prints to the screen

  • non-deterministic: its output depends on the global variable tax_rate, so changing tax_rate changes the output even for the same input prices.

Here’s a pure version

tax_rate = 0.1
prices = (10.0, 20.0)

def add_tax_pure(prices, tax_rate):
    new_prices = [price * (1 + tax_rate) for price in prices]
    return new_prices

This pure version makes all dependencies explicit through function arguments, and doesn’t modify any external state.
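The same pure function translates directly into JAX. In the sketch below (the name add_tax_jax is our own), prices is a JAX array, the tax rate is an explicit argument, and the function can be JIT-compiled precisely because it neither reads global state nor mutates its input.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_tax_jax(prices, tax_rate):
    # Pure: the output depends only on the arguments, and prices is not modified
    return prices * (1 + tax_rate)

prices = jnp.array([10.0, 20.0])
new_prices = add_tax_jax(prices, 0.1)
print(new_prices)   # approximately 11.0, 22.0
print(prices)       # unchanged: 10.0, 20.0
```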

14.2.3. Why Functional Programming?

JAX represents functions as computational graphs, which are then compiled or transformed (e.g., differentiated).

These computational graphs describe how a given set of inputs is transformed into an output.

They are pure by construction.

JAX uses a functional programming style so that user-built functions map directly into the graph-theoretic representations supported by JAX.
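You can inspect the graph JAX builds for a function with jax.make_jaxpr, which traces the function on an example input and returns its intermediate representation (a "jaxpr"). The function g below is a hypothetical example; the printed jaxpr lists the primitive operations (such as sin, mul and add) that make up g.

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.sin(x) * 2 + x

# Trace g with an example scalar input and print the resulting graph
jaxpr = jax.make_jaxpr(g)(1.0)
print(jaxpr)
```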

14.3. Random numbers

Random number generation in JAX differs significantly from the patterns found in NumPy or MATLAB.

At first you might find the syntax rather verbose.

But the syntax and semantics are necessary to maintain the functional programming style we just discussed.

Moreover, full control of random state is essential for parallel programming, such as when we want to run independent experiments across multiple threads.

14.3.1. Random number generation

In JAX, the state of the random number generator is controlled explicitly.

First we produce a key, which seeds the random number generator.

seed = 1234
key = jax.random.key(seed)

Now we can use the key to generate some random numbers:

x = jax.random.normal(key, (3, 3))
x
Array([[-0.54019824,  0.43957585, -0.01978102],
       [ 0.90665474, -0.90831359,  1.32846635],
       [ 0.20408174,  0.93096529,  3.30373914]], dtype=float64)

If we use the same key again, we initialize at the same seed, so the random numbers are the same:

jax.random.normal(key, (3, 3))
Array([[-0.54019824,  0.43957585, -0.01978102],
       [ 0.90665474, -0.90831359,  1.32846635],
       [ 0.20408174,  0.93096529,  3.30373914]], dtype=float64)

To produce a (quasi-) independent draw, one option is to “split” the existing key:

key, subkey = jax.random.split(key)
jax.random.normal(key, (3, 3))
Array([[ 1.24104247,  0.12018902, -2.23990047],
       [ 0.70507261, -0.85702845, -1.24582014],
       [ 0.38454486,  1.32117717,  0.56866901]], dtype=float64)
jax.random.normal(subkey, (3, 3))
Array([[ 0.07627173, -1.30349831,  0.86524323],
       [-0.75550773,  0.63958052,  0.47052126],
       [-1.72866044, -1.14696564, -1.23328892]], dtype=float64)
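split is not limited to two outputs: passing a count returns an array of that many new keys at once, which is convenient when we know in advance how many independent streams we need. The sketch below uses 4 keys purely for illustration.

```python
import jax

key = jax.random.key(1234)

# Produce 4 new keys in a single call; with jax.random.key this is a key array of shape (4,)
keys = jax.random.split(key, 4)
print(keys.shape)   # (4,)

# Each key yields its own draw
draws = [jax.random.normal(keys[i]) for i in range(4)]
print(draws)
```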

The following diagram illustrates how split produces a tree of keys from a single root, with each key generating independent random draws.


fig, ax = plt.subplots(figsize=(8, 4))
ax.set_xlim(-0.5, 6.5)
ax.set_ylim(-0.5, 3.5)
ax.set_aspect('equal')
ax.axis('off')

box_style = dict(boxstyle="round,pad=0.3", facecolor="white",
                 edgecolor="black", linewidth=1.5)
box_used = dict(boxstyle="round,pad=0.3", facecolor="#d4edda",
                edgecolor="black", linewidth=1.5)

# Root key
ax.text(3, 3, "key₀", ha='center', va='center', fontsize=11,
        bbox=box_style)

# Level 1
ax.annotate("", xy=(1.5, 2), xytext=(3, 2.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.annotate("", xy=(4.5, 2), xytext=(3, 2.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.text(1.5, 2, "key₁", ha='center', va='center', fontsize=11,
        bbox=box_style)
ax.text(4.5, 2, "subkey₁", ha='center', va='center', fontsize=11,
        bbox=box_used)
ax.text(5.7, 2, "→ draw", ha='left', va='center', fontsize=10,
        color='green')

# Label the split
ax.text(2, 2.65, "split", ha='center', va='center', fontsize=9,
        fontstyle='italic', color='gray')

# Level 2
ax.annotate("", xy=(0.5, 1), xytext=(1.5, 1.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.annotate("", xy=(2.5, 1), xytext=(1.5, 1.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.text(0.5, 1, "key₂", ha='center', va='center', fontsize=11,
        bbox=box_style)
ax.text(2.5, 1, "subkey₂", ha='center', va='center', fontsize=11,
        bbox=box_used)
ax.text(3.7, 1, "→ draw", ha='left', va='center', fontsize=10,
        color='green')

ax.text(0.7, 1.65, "split", ha='center', va='center', fontsize=9,
        fontstyle='italic', color='gray')

# Level 3
ax.annotate("", xy=(0, 0), xytext=(0.5, 0.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.annotate("", xy=(1.5, 0), xytext=(0.5, 0.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.text(0, 0, "key₃", ha='center', va='center', fontsize=11,
        bbox=box_style)
ax.text(1.5, 0, "subkey₃", ha='center', va='center', fontsize=11,
        bbox=box_used)
ax.text(2.7, 0, "→ draw", ha='left', va='center', fontsize=10,
        color='green')
ax.text(0, 0.65, "split", ha='center', va='center', fontsize=9,
        fontstyle='italic', color='gray')

ax.text(3, -0.5, "⋮", ha='center', va='center', fontsize=14)

ax.set_title("PRNG Key Splitting Tree", fontsize=13, pad=10)
plt.tight_layout()
plt.show()
[Figure: PRNG Key Splitting Tree]

This syntax will seem unusual for a NumPy or MATLAB user, but it will make a lot of sense when we progress to parallel programming.

The function below produces k (quasi-) independent random n × n matrices using split.

def gen_random_matrices(key, n=2, k=3):
    matrices = []
    for _ in range(k):
        key, subkey = jax.random.split(key)
        A = jax.random.uniform(subkey, (n, n))
        matrices.append(A)
        print(A)
    return matrices
seed = 42
key = jax.random.key(seed)
matrices = gen_random_matrices(key)
[[0.74211901 0.54715578]
 [0.05988742 0.32206803]]
[[0.65877976 0.57087415]
 [0.97301903 0.10138266]]
[[0.68745522 0.25974132]
 [0.06595873 0.83589118]]

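Alternatively, when iterating in a loop, we can use fold_in, which deterministically derives a new key from a base key and an integer (here the loop index), so the base key never needs to be updated: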

def gen_random_matrices(key, n=2, k=3):
    matrices = []
    for i in range(k):
        step_key = jax.random.fold_in(key, i)
        A = jax.random.uniform(step_key, (n, n))
        matrices.append(A)
        print(A)
    return matrices
key = jax.random.key(seed)
matrices = gen_random_matrices(key)
[[0.23566993 0.39719189]
 [0.95367373 0.42397776]]
[[0.74211901 0.54715578]
 [0.05988742 0.32206803]]
[[0.37386727 0.66444882]
 [0.80253222 0.42934555]]
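Because keys are ordinary inputs, we can also drop the Python loop entirely: split the key once into k keys and map a single-matrix function over them with jax.vmap (discussed in the vmap section of this lecture). The function random_matrix below is a hypothetical helper; this is a sketch of the pattern, not a replacement for gen_random_matrices.

```python
import jax
import jax.numpy as jnp

def random_matrix(key, n=2):
    # One n x n matrix of uniform draws generated from a single key
    return jax.random.uniform(key, (n, n))

key = jax.random.key(42)
keys = jax.random.split(key, 3)            # one key per matrix
matrices = jax.vmap(random_matrix)(keys)   # all three matrices in one call
print(matrices.shape)                      # (3, 2, 2)
```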

14.3.2. Why explicit random state?

Why does JAX require this somewhat verbose approach to random number generation?

One reason is to maintain pure functions.

Let’s see how random number generation relates to pure functions by comparing NumPy and JAX.

14.3.2.1. NumPy’s approach

In NumPy’s legacy random number generation API (which mimics MATLAB), generation works by maintaining hidden global state.

Each time we call a random function, this state is updated:

np.random.seed(42)
print(np.random.randn())   # Updates state of random number generator
print(np.random.randn())   # Updates state of random number generator
0.4967141530112327
-0.13826430117118466

Each call returns a different value, even though we’re calling the same function with the same inputs (no arguments).

This function is not pure because:

  • It’s non-deterministic: same inputs (none, in this case) give different outputs

  • It has side effects: it modifies the global random number generator state

14.3.2.2. JAX’s approach

As we saw above, JAX takes a different approach, making randomness explicit through keys.

For example,

def random_sum_jax(key):
    key1, key2 = jax.random.split(key)
    x = jax.random.normal(key1)
    y = jax.random.normal(key2)
    return x + y

With the same key, we always get the same result:

key = jax.random.key(42)
random_sum_jax(key)
Array(-0.07040872, dtype=float64)
random_sum_jax(key)
Array(-0.07040872, dtype=float64)

To get new draws we need to supply a new key.

The function random_sum_jax is pure because:

  • It’s deterministic: same key always produces same output

  • No side effects: no hidden state is modified
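Purity is also what makes random_sum_jax safe to JIT-compile: the compiled version gives the same result for the same key, and a different key yields a new draw. In the sketch below, random_sum_jit is our own name for the compiled function, and the seeds 42 and 43 are arbitrary.

```python
import jax
import jax.numpy as jnp

def random_sum_jax(key):
    key1, key2 = jax.random.split(key)
    x = jax.random.normal(key1)
    y = jax.random.normal(key2)
    return x + y

random_sum_jit = jax.jit(random_sum_jax)

key = jax.random.key(42)
r1 = random_sum_jit(key)
r2 = random_sum_jit(key)                   # same key: same result
r3 = random_sum_jit(jax.random.key(43))    # different key: a new draw
print(r1, r2, r3)
```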

The explicitness of JAX brings significant benefits:

  • Reproducibility: Easy to reproduce results by reusing keys

  • Parallelization: Each thread can have its own key without conflicts

  • Debugging: No hidden state makes code easier to reason about

  • JIT compatibility: The compiler can optimize pure functions more aggressively

The last point is expanded on in the next section.

14.4. JIT Compilation

The JAX just-in-time (JIT) compiler accelerates execution by generating efficient machine code that varies with both task size and hardware.

We saw the power of JAX’s JIT compiler combined with parallel hardware above, when we applied cos to a large array.

Let’s try the same thing with a more complex function.

14.4.1. Evaluating a more complicated function

Consider the function

def f(x):
    y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
    return y

14.4.1.1. With NumPy

We’ll try first with NumPy:

n = 50_000_000
x = np.linspace(0, 10, n)
with qe.Timer():
    y = f(x)
2.4429 seconds elapsed

14.4.1.2. With JAX

Now let’s try again with JAX.

As a first pass, we replace np with jnp throughout:

def f(x):
    y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2
    return y

Now let’s time it.

x = jnp.linspace(0, 10, n)
with qe.Timer():
    y = f(x)
    jax.block_until_ready(y);
0.4820 seconds elapsed
with qe.Timer():
    y = f(x)
    jax.block_until_ready(y);
0.1024 seconds elapsed

The outcome is similar to the cos example — JAX is faster, especially on the second run after JIT compilation.

However, with JAX, we have another trick up our sleeve — we can JIT-compile the entire function, not just individual operations.

14.4.2. Compiling the whole function

The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing array operations into a single optimized kernel.

Let’s try this with the function f:

f_jax = jax.jit(f)
with qe.Timer():
    y = f_jax(x)
    jax.block_until_ready(y);
0.1565 seconds elapsed
with qe.Timer():
    y = f_jax(x)
    jax.block_until_ready(y);
0.0367 seconds elapsed

The runtime has improved again — now because we fused all the operations, allowing the compiler to optimize more aggressively.

For example, the compiler can eliminate multiple calls to the hardware accelerator and the creation of a number of intermediate arrays.

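Incidentally, a more common way to target a function for the JIT compiler is the decorator syntax, shown here with the body of the JAX version of f from above:

@jax.jit
def f(x):
    return jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2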

14.4.3. How JIT compilation works

When we apply jax.jit to a function, JAX traces it: instead of executing the operations immediately, it records the sequence of operations as a computational graph and hands that graph to the XLA compiler.

XLA then fuses and optimizes the operations into a single compiled kernel tailored to the available hardware (CPU, GPU, or TPU).

The first call to a JIT-compiled function incurs compilation overhead, but subsequent calls with the same input shapes and types reuse the cached compiled code and run at full speed.
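We can observe this caching directly. The sketch below uses a deliberately impure Python side effect (incrementing a hypothetical global counter, trace_count) purely as a probe: ordinary Python code inside a jitted function runs only while JAX is tracing, so the counter records how many times the function was traced and compiled.

```python
import jax
import jax.numpy as jnp

trace_count = 0   # hypothetical counter, used only to observe tracing

@jax.jit
def h(x):
    global trace_count
    trace_count += 1          # Python side effect: runs only during tracing
    return jnp.sum(x ** 2)

h(jnp.ones(3))    # first call with shape (3,): traced and compiled
h(jnp.ones(3))    # same shape and dtype: cached code reused, no new trace
h(jnp.ones(4))    # new shape (4,): traced and compiled again
print(trace_count)   # 2
```

This is exactly the kind of impurity that makes jitted code behave unexpectedly, as the next section shows.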

14.4.4. Compiling non-pure functions

Now that we’ve seen how powerful JIT compilation can be, it’s important to understand its relationship with pure functions.

While JAX will not usually throw errors when compiling impure functions, execution becomes unpredictable.

Here’s an illustration of this fact, using global variables:

a = 1  # global

@jax.jit
def f(x):
    return a + x
x = jnp.ones(2)
f(x)
Array([2., 2.], dtype=float64)

In the code above, the value of the global a (namely 1) is read when f is traced and baked into the compiled function as a constant.

Even if we change a, the output of f will not be affected — as long as the same compiled version is called.

a = 42
f(x)
Array([2., 2.], dtype=float64)

Changing the shape of the input triggers a fresh compilation of the function, at which time the change in the value of a takes effect:

x = jnp.ones(3)
f(x)
Array([43., 43., 43.], dtype=float64)

Moral of the story: write pure functions when using JAX!
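The fix is to make the dependency explicit. In the sketch below (f_pure is our own name), a becomes an argument, so every call uses the value actually passed. Changing a from 1 to 42 does not even trigger recompilation, because only the value changes, not its shape or type.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_pure(x, a):
    # a is an argument, so each call sees its current value
    return a + x

x = jnp.ones(2)
print(f_pure(x, 1))    # every entry equals 2
print(f_pure(x, 42))   # every entry equals 43
```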

14.4.5. Summary

Now we can see why both developers and compilers benefit from pure functions.

We love pure functions because they

  • Help testing: each function can operate in isolation

  • Promote deterministic behavior and hence reproducibility

  • Prevent bugs that arise from mutating shared state

The compiler loves pure functions and functional programming because

  • Data dependencies are explicit, which helps with optimizing complex computations

  • Pure functions are easier to differentiate (autodiff)

  • Pure functions are easier to parallelize and optimize (don’t depend on shared mutable state)

14.5. Vectorization with vmap

Another powerful JAX transformation is jax.vmap, which automatically vectorizes a function written for a single input so that it operates over batches.

This avoids the need to manually write vectorized code or use explicit loops.

14.5.1. A simple example

Suppose we have a function that computes summary statistics for a single array:

def summary(x):
    return jnp.mean(x), jnp.median(x)

We can apply it to a single vector:

x = jnp.array([1.0, 2.0, 5.0])
summary(x)
(Array(2.66666667, dtype=float64), Array(2., dtype=float64))

Now suppose we have a matrix and want to compute these statistics for each row.

Without vmap, we’d need an explicit loop:

X = jnp.array([[1.0, 2.0, 5.0],
               [4.0, 5.0, 6.0],
               [1.0, 8.0, 9.0]])

for row in X:
    print(summary(row))
(Array(2.66666667, dtype=float64), Array(2., dtype=float64))
(Array(5., dtype=float64), Array(5., dtype=float64))
(Array(6., dtype=float64), Array(8., dtype=float64))

However, a Python loop processes one row at a time in the interpreter, so JAX cannot compile or parallelize it as a single operation.

Using vmap keeps the computation on the accelerator and composes with other JAX transformations like jit and grad:

batch_summary = jax.vmap(summary)
batch_summary(X)
(Array([2.66666667, 5.        , 6.        ], dtype=float64),
 Array([2., 5., 8.], dtype=float64))

The function summary was written for a single array, and vmap automatically lifted it to operate row-wise over a matrix — no loops, no reshaping.
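By default vmap maps over the leading axis of every argument, but the in_axes parameter changes this. The sketch below maps summary over columns instead of rows (in_axes=1), and then batches one argument while broadcasting another (in_axes=(0, None)); the helper shift is hypothetical.

```python
import jax
import jax.numpy as jnp

def summary(x):
    return jnp.mean(x), jnp.median(x)

X = jnp.array([[1.0, 2.0, 5.0],
               [4.0, 5.0, 6.0],
               [1.0, 8.0, 9.0]])

# in_axes=1: map over the second axis, i.e. over columns
col_summary = jax.vmap(summary, in_axes=1)
means, medians = col_summary(X)
print(means)     # column means: 2, 5, 20/3
print(medians)   # column medians: 1, 5, 6

def shift(x, c):
    return x - c

# in_axes=(0, None): batch over rows of X, reuse the same c for every row
shifted = jax.vmap(shift, in_axes=(0, None))(X, 1.0)
```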

14.5.2. Combining transformations

One of JAX’s strengths is that transformations compose naturally.

For example, we can JIT-compile a vectorized function:

fast_batch_summary = jax.jit(jax.vmap(summary))
fast_batch_summary(X)
(Array([2.66666667, 5.        , 6.        ], dtype=float64),
 Array([2., 5., 8.], dtype=float64))

This composition of jit, vmap, and (as we’ll see next) grad is central to JAX’s design and makes it especially powerful for scientific computing and machine learning.
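As a small sketch of all three transformations together, take the function f(x) = x^2 / 2 used in the autodiff preview. jax.grad differentiates a scalar function at a single point, vmap lifts that derivative to a whole grid, and jit compiles the combined pipeline. The grid of 5 points is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

def f(x):
    return (x ** 2) / 2          # f'(x) = x

# grad: derivative at one point; vmap: over a grid; jit: compile the result
f_prime_batch = jax.jit(jax.vmap(jax.grad(f)))

x_grid = jnp.linspace(-4, 4, 5)
print(f_prime_batch(x_grid))     # equal to x_grid, since f'(x) = x
```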

14.6. Automatic differentiation: a preview

JAX can use automatic differentiation to compute gradients.

This can be extremely useful for optimization and solving nonlinear systems.

Here’s a simple illustration involving the function \(f(x) = x^2 / 2\):

def f(x):
    return (x**2) / 2

f_prime = jax.grad(f)
f_prime(10.0)
Array(10., dtype=float64, weak_type=True)

Let’s plot the function and derivative, noting that \(f'(x) = x\).

fig, ax = plt.subplots()
x_grid = jnp.linspace(-4, 4, 200)
ax.plot(x_grid, f(x_grid), label="$f$")
ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$")
ax.legend(loc='upper center')
plt.show()
[Figure: f and its derivative f′ on the interval [-4, 4]]

Automatic differentiation is a deep topic with many applications in economics and finance. We provide a more thorough treatment in our lecture on autodiff.

14.7. Exercises

Exercise 14.1

In the Exercise section of our lecture on Numba, we used Monte Carlo to price a European call option.

The code was accelerated by Numba-based multithreading.

Try writing a version of this operation for JAX, using all the same parameters.