15. NumPy vs Numba vs JAX#

In the preceding lectures, we’ve discussed three core libraries for scientific and numerical computing: NumPy, Numba, and JAX.

Which one should we use in any given situation?

This lecture addresses that question, at least partially, by discussing some use cases.

Before getting started, we note that the first two are a natural pair: NumPy and Numba play well together.

JAX, on the other hand, stands alone.

When considering each approach, we will consider not just efficiency and memory footprint but also clarity and ease of use.

In addition to what’s in Anaconda, this lecture will need the following libraries:

!pip install quantecon jax


We will use the following imports.

import random
import numpy as np
import quantecon as qe
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D
from matplotlib import cm
import jax
import jax.numpy as jnp

15.1. Vectorized operations#

Some operations can be perfectly vectorized — all loops are easily eliminated and numerical operations are reduced to calculations on arrays.

In this case, which approach is best?

15.1.1. Problem Statement#

Consider the problem of maximizing a function \(f\) of two variables \((x,y)\) over the square \([-a, a] \times [-a, a]\).

For \(f\) and \(a\) let’s choose

\[ f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2} \quad \text{and} \quad a = 3 \]

Here’s a plot of \(f\)

def f(x, y):
    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)

xgrid = np.linspace(-3, 3, 50)
ygrid = xgrid
x, y = np.meshgrid(xgrid, ygrid)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(x,
                y,
                f(x, y),
                rstride=2, cstride=2,
                cmap=cm.jet,
                alpha=0.7,
                linewidth=0.25)
ax.set_zlim(-0.5, 1.0)
ax.set_xlabel('$x$', fontsize=14)
ax.set_ylabel('$y$', fontsize=14)
plt.show()
[Figure: 3D surface plot of \(f\) over \([-3, 3] \times [-3, 3]\)]

For the sake of this exercise, we’re going to use brute force for the maximization.

  1. Evaluate \(f\) for all \((x,y)\) in a grid on the square.

  2. Return the maximum of observed values.

Just to illustrate the idea, here’s a non-vectorized version that uses Python loops.

grid = np.linspace(-3, 3, 50)
m = -np.inf
for x in grid:
    for y in grid:
        z = f(x, y)
        if z > m:
            m = z

15.1.2. NumPy vectorization#

If we switch to NumPy-style vectorization we can use a much larger grid and the code executes relatively quickly.

Here we use np.meshgrid to create two-dimensional input grids x and y such that f(x, y) generates all evaluations on the product grid.

(This strategy dates back to MATLAB.)
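
To see what np.meshgrid does, here’s a tiny sketch on a three-point grid (small, xs and ys are illustrative names of our own):

small = np.linspace(-1, 1, 3)
xs, ys = np.meshgrid(small, small)
print(xs)               # x values vary across columns
print(ys)               # y values vary down rows
print(f(xs, ys).shape)  # (3, 3): f evaluated at all 9 grid pairs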

grid = np.linspace(-3, 3, 3_000)
x, y = np.meshgrid(grid, grid)

with qe.Timer(precision=8):
    np.max(f(x, y))
0.16827178 seconds elapsed

In the vectorized version, all the looping takes place in compiled code.

Moreover, NumPy uses implicit multithreading, so that at least some parallelization occurs.

Note

If you have a system monitor such as htop (Linux/Mac) or perfmon (Windows), then try running this and then observing the load on your CPUs.

(You will probably need to bump up the grid size to see large effects.)

The output typically shows that the operation is successfully distributed across multiple threads.

(The parallelization cannot be highly efficient because the binary is compiled before it sees the size of the arrays x and y.)

15.1.3. A Comparison with Numba#

Now let’s see if we can achieve better performance using Numba with a simple loop.

import numba

@numba.jit
def compute_max_numba(grid):
    m = -np.inf
    for x in grid:
        for y in grid:
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > m:
                m = z
    return m

grid = np.linspace(-3, 3, 3_000)

with qe.Timer(precision=8):
    compute_max_numba(grid)
0.23649049 seconds elapsed
with qe.Timer(precision=8):
    compute_max_numba(grid)
0.10795808 seconds elapsed

Depending on your machine, the Numba version can be a bit slower or a bit faster than NumPy.

On one hand, NumPy combines efficient arithmetic (like Numba) with some multithreading (unlike this Numba code), which provides an advantage.

On the other hand, the Numba routine uses much less memory, since we are only working with a single one-dimensional grid.
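
To make the memory comparison concrete, here’s a quick check using the arrays defined above (a rough sketch; the exact figures depend on your dtype):

print(grid.nbytes)           # 3_000 float64 values: 24_000 bytes
print(x.nbytes + y.nbytes)   # two 3_000 x 3_000 meshes: 144_000_000 bytes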

15.1.4. Parallelized Numba#

Now let’s try parallelization with Numba using prange:

First we parallelize just the outer loop.

@numba.jit(parallel=True)
def compute_max_numba_parallel(grid):
    n = len(grid)
    m = -np.inf
    for i in numba.prange(n):
        for j in range(n):
            x = grid[i]
            y = grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > m:
                m = z
    return m

with qe.Timer(precision=8):
    compute_max_numba_parallel(grid)
0.26112270 seconds elapsed
with qe.Timer(precision=8):
    compute_max_numba_parallel(grid)
0.00013709 seconds elapsed

Next we parallelize both loops.

@numba.jit(parallel=True)
def compute_max_numba_parallel_nested(grid):
    n = len(grid)
    m = -np.inf
    for i in numba.prange(n):
        for j in numba.prange(n):
            x = grid[i]
            y = grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > m:
                m = z
    return m

with qe.Timer(precision=8):
    compute_max_numba_parallel_nested(grid)
0.26212025 seconds elapsed
with qe.Timer(precision=8):
    compute_max_numba_parallel_nested(grid)
0.00011683 seconds elapsed

Depending on your machine, you might or might not see large benefits from parallelization here.

If you have a small number of cores, the overhead of thread management and synchronization can overwhelm the benefits of parallel execution.

For more powerful machines and larger grid sizes, parallelization can generate large speed gains.
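
One way to experiment on your own machine is to vary Numba’s thread count (set_num_threads and get_num_threads are part of Numba’s public API; the timings will of course differ across machines):

print(numba.get_num_threads())     # default thread count on this machine
numba.set_num_threads(2)           # restrict Numba to two threads
with qe.Timer(precision=8):
    compute_max_numba_parallel(grid)
numba.set_num_threads(numba.config.NUMBA_NUM_THREADS)  # restore the default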

15.1.5. Vectorized code with JAX#

In most ways, vectorization is the same in JAX as it is in NumPy.

But there are also some differences, which we highlight here.

Let’s start with the function.

@jax.jit
def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

As with NumPy, to get the right shapes and reproduce the nested-loop calculation over the product grid, we can use a meshgrid operation designed for this purpose:

grid = jnp.linspace(-3, 3, 3_000)
x_mesh, y_mesh = jnp.meshgrid(grid, grid)

with qe.Timer(precision=8):
    z_mesh = f(x_mesh, y_mesh).block_until_ready()
0.07069302 seconds elapsed

Let’s run again to eliminate compile time.

with qe.Timer(precision=8):
    z_mesh = f(x_mesh, y_mesh).block_until_ready()
0.03460550 seconds elapsed

Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.

The compilation overhead is a one-time cost that pays off when the function is called repeatedly.
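
Note that the cached compilation is specialized to input shapes: calling f on arrays of a new shape triggers a fresh compile. Here’s a small check (a sketch; small_grid is an illustrative name of our own):

small_grid = jnp.linspace(-3, 3, 100)
xs, ys = jnp.meshgrid(small_grid, small_grid)
with qe.Timer(precision=8):
    f(xs, ys).block_until_ready()   # compiles again for the new shape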

15.1.6. JAX plus vmap#

There is one problem with both the NumPy code and the JAX code:

While the flat arrays are low-memory

grid.nbytes 
12000

the mesh grids are memory intensive

x_mesh.nbytes + y_mesh.nbytes
72000000

This extra memory usage can be a big problem in actual research calculations.

Fortunately, JAX admits a different approach using jax.vmap.

15.1.6.1. Version 1#

Here’s one way we can apply vmap.

# Set up f to compute f(x, y) at every x for any given y
f_vec_x = lambda y: f(grid, y)
# Vectorize this operation over all y
f_vec = jax.vmap(f_vec_x)
# Compute result at all y
z_vmap = f_vec(grid)

Let’s see the timing:

with qe.Timer(precision=8):
    z_vmap = f_vec(grid)
    z_vmap.block_until_ready()
0.03802609 seconds elapsed

Let’s check we got the right result:

jnp.allclose(z_mesh, z_vmap)
Array(True, dtype=bool)

The execution time is similar to that of the mesh operation, but we are using much less memory.

In addition, vmap allows us to break vectorization up into stages, which is often easier to comprehend than the traditional approach.

This will become more obvious when we tackle larger problems.

15.1.6.2. Version 2#

Here’s a more generic approach to using vmap that we often use in the lectures.

First we vectorize in y.

f_vec_y = jax.vmap(f, in_axes=(None, 0))

In the line above, (None, 0) indicates that we are vectorizing in the second argument, which is y.

Next, we vectorize in the first argument, which is x.

f_vec = jax.vmap(f_vec_y, in_axes=(0, None))
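
To see how the two in_axes specifications combine, here’s a minimal sketch with a toy function (g is an illustrative name of our own):

g = lambda a, b: a + b
g_y = jax.vmap(g, in_axes=(None, 0))      # map over b, hold a fixed
print(g_y(1.0, jnp.arange(3.0)))          # [1. 2. 3.]
g_xy = jax.vmap(g_y, in_axes=(0, None))   # then map over a
print(g_xy(jnp.arange(2.0), jnp.arange(3.0)).shape)   # (2, 3)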

With this construction, we can now call \(f\) directly on flat (low memory) arrays.

x, y = grid, grid
with qe.Timer(precision=8):
    z_vmap = f_vec(x, y).block_until_ready()
0.05407190 seconds elapsed

Let’s run it again to eliminate compilation time:

with qe.Timer(precision=8):
    z_vmap = f_vec(x, y).block_until_ready()
0.02737856 seconds elapsed

Let’s check we got the right result:

jnp.allclose(z_mesh, z_vmap)
Array(True, dtype=bool)

15.1.7. Summary#

In our view, JAX is the winner for vectorized operations.

It dominates NumPy both in terms of speed (via JIT-compilation and parallelization) and memory efficiency (via vmap).

Moreover, the vmap approach can sometimes lead to significantly clearer code.

While Numba is impressive, the beauty of JAX is that, with fully vectorized operations, we can run exactly the same code on machines with hardware accelerators and reap all the benefits without paying extra cost.

15.2. Sequential operations#

Some operations are inherently sequential – and hence difficult or impossible to vectorize.

In this case NumPy is a poor option and we are left with the choice of Numba or JAX.

To compare these choices, we will revisit the problem of iterating on the quadratic map that we saw in our Numba lecture.

15.2.1. Numba Version#

Here’s the Numba version.

@numba.jit
def qm(x0, n, α=4.0):
    x = np.empty(n+1)
    x[0] = x0
    for t in range(n):
        x[t+1] = α * x[t] * (1 - x[t])
    return x

Let’s generate a time series of length 10,000,000 and time the execution:

n = 10_000_000

with qe.Timer(precision=8):
    x = qm(0.1, n)
0.14098239 seconds elapsed

Let’s run it again to eliminate compilation time:

with qe.Timer(precision=8):
    x = qm(0.1, n)
0.02677727 seconds elapsed

Numba handles this sequential operation very efficiently.

Notice that the second run is significantly faster after JIT compilation completes.

Numba’s compilation is typically quite fast, and the resulting code performance is excellent for sequential operations like this one.
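
As an aside, if the one-time compilation cost matters across sessions, Numba’s cache=True flag (a real option of numba.jit) persists the compiled machine code to disk. A minimal sketch:

@numba.jit(cache=True)   # persist compiled code across Python sessions
def qm_cached(x0, n, α=4.0):
    x = np.empty(n+1)
    x[0] = x0
    for t in range(n):
        x[t+1] = α * x[t] * (1 - x[t])
    return x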

15.2.2. JAX Version#

Now let’s create a JAX version using lax.scan:

(We’ll hold n static because it determines the output array size, so JAX needs to specialize the compiled code on its value.)

from jax import lax
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def qm_jax(x0, n, α=4.0):
    def update(x, t):
        x_new = α * x * (1 - x)
        return x_new, x_new

    _, x = lax.scan(update, x0, jnp.arange(n))
    return jnp.concatenate([jnp.array([x0]), x])

This code is not easy to read but, in essence, lax.scan repeatedly calls update and accumulates the returned x_new values into an array.
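
If it helps, here is a rough pure-Python sketch of what lax.scan computes (ignoring compilation and typing details):

def scan_sketch(f, init, xs):
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)   # f returns (new_carry, output)
        ys.append(y)
    return carry, np.stack(ys)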

Let’s time it with the same parameters:

with qe.Timer(precision=8):
    x_jax = qm_jax(0.1, n).block_until_ready()
0.10437250 seconds elapsed

Let’s run it again to eliminate compilation overhead:

with qe.Timer(precision=8):
    x_jax = qm_jax(0.1, n).block_until_ready()
0.05768394 seconds elapsed

JAX is also efficient for this sequential operation.

Both JAX and Numba deliver strong performance after compilation, with Numba typically (but not always) offering slightly better speeds on purely sequential operations.

15.2.3. Summary#

While both Numba and JAX deliver strong performance for sequential operations, there are significant differences in code readability and ease of use.

The Numba version is straightforward and natural to read: we simply allocate an array and fill it element by element using a standard Python loop.

This is exactly how most programmers think about the algorithm.

The JAX version, on the other hand, requires using lax.scan, which is significantly less intuitive.

Additionally, JAX’s immutable arrays mean we cannot simply update array elements in place, making it hard to directly replicate the algorithm used by Numba.
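
For instance, item assignment on a JAX array raises an error; the functional .at syntax (part of JAX’s API) returns a modified copy instead:

a = jnp.zeros(3)
# a[0] = 1.0            # raises TypeError: JAX arrays are immutable
a = a.at[0].set(1.0)    # returns a new array with the update applied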

For this type of sequential operation, Numba is the clear winner in terms of code clarity and ease of implementation, as well as high performance.