\n",
" \n",
" \n",
" \n",
"

"
]
},
{
"cell_type": "markdown",
"id": "a641b22e",
"metadata": {},
"source": [
"# Numba"
]
},
{
"cell_type": "markdown",
"id": "d3360516",
"metadata": {},
"source": [
"## Contents\n",
"\n",
"- [Numba](#Numba)\n",
"  - [Overview](#Overview)\n",
"  - [Compiling Functions](#Compiling-Functions)\n",
"  - [Decorator Notation](#Decorator-Notation)\n",
"  - [Type Inference](#Type-Inference)\n",
"  - [Compiling Classes](#Compiling-Classes)\n",
"  - [Alternatives to Numba](#Alternatives-to-Numba)\n",
"  - [Summary and Comments](#Summary-and-Comments)\n",
"  - [Exercises](#Exercises)"
]
},
{
"cell_type": "markdown",
"id": "d1fa7871",
"metadata": {},
"source": [
"In addition to what’s in Anaconda, this lecture will need the following libraries:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eebc84b4",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"!pip install quantecon"
]
},
{
"cell_type": "markdown",
"id": "1f6235ef",
"metadata": {},
"source": [
"Please also make sure that you have the latest version of Anaconda, since old\n",
"versions are a [common source of errors](https://python-programming.quantecon.org/troubleshooting.html).\n",
"\n",
"Let’s start with some imports:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72b13785",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import quantecon as qe\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['figure.figsize'] = (10,6)"
]
},
{
"cell_type": "markdown",
"id": "eb7e867e",
"metadata": {},
"source": [
"## Overview\n",
"\n",
"In an [earlier lecture](https://python-programming.quantecon.org/need_for_speed.html) we learned about vectorization, which is one method to improve speed and efficiency in numerical work.\n",
"\n",
"Vectorization involves sending array processing\n",
"operations in batch to efficient low-level code.\n",
"\n",
"However, as [discussed previously](https://python-programming.quantecon.org/need_for_speed.html#numba-p-c-vectorization), vectorization has several weaknesses.\n",
"\n",
"One is that it is highly memory-intensive when working with large amounts of data.\n",
"\n",
"Another is that not every algorithm can be fully vectorized.\n",
"\n",
"In fact, for some algorithms, vectorization is ineffective.\n",
"\n",
"Fortunately, a Python library called [Numba](http://numba.pydata.org/)\n",
"solves many of these problems.\n",
"\n",
"It does so through something called **just in time (JIT) compilation**.\n",
"\n",
"The key idea is to compile functions to native machine code instructions on the fly.\n",
"\n",
"When it succeeds, the compiled code is extremely fast.\n",
"\n",
"Numba is specifically designed for numerical work and can also do other tricks such as [multithreading](https://en.wikipedia.org/wiki/Multithreading_%28computer_architecture%29).\n",
"\n",
"Numba will be a key part of our lectures — especially those lectures involving dynamic programming.\n",
"\n",
"This lecture introduces the main ideas.\n",
"\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"id": "e7792290",
"metadata": {},
"source": [
"## Compiling Functions\n",
"\n",
"\n",
"\n",
"As stated above, Numba’s primary use is compiling functions to fast native\n",
"machine code during runtime.\n",
"\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"id": "b8ba3136",
"metadata": {},
"source": [
"### An Example\n",
"\n",
"Let’s consider a problem that is difficult to vectorize: generating the trajectory of a difference equation given an initial condition.\n",
"\n",
"We will take the difference equation to be the quadratic map\n",
"\n",
"$$\n",
"x_{t+1} = \\alpha x_t (1 - x_t)\n",
"$$\n",
"\n",
"In what follows we set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd91b402",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"α = 4.0"
]
},
{
"cell_type": "markdown",
"id": "b56fddc4",
"metadata": {},
"source": [
"Here’s the plot of a typical trajectory, starting from $ x_0 = 0.1 $, with $ t $ on the x-axis"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4daadaf",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"def qm(x0, n):\n",
"    x = np.empty(n+1)\n",
"    x[0] = x0\n",
"    for t in range(n):\n",
"        x[t+1] = α * x[t] * (1 - x[t])\n",
"    return x\n",
"\n",
"x = qm(0.1, 250)\n",
"fig, ax = plt.subplots()\n",
"ax.plot(x, 'b-', lw=2, alpha=0.8)\n",
"ax.set_xlabel('$t$', fontsize=12)\n",
"ax.set_ylabel('$x_{t}$', fontsize = 12)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "1e1b8e43",
"metadata": {},
"source": [
"To speed the function `qm` up using Numba, our first step is"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4ca74cb",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"from numba import njit\n",
"\n",
"qm_numba = njit(qm)"
]
},
{
"cell_type": "markdown",
"id": "7c032ee5",
"metadata": {},
"source": [
"The function `qm_numba` is a version of `qm` that is “targeted” for\n",
"JIT-compilation.\n",
"\n",
"We will explain what this means momentarily.\n",
"\n",
"Let’s time and compare identical function calls across these two versions, starting with the original function `qm`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b862077",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"n = 10_000_000\n",
"\n",
"qe.tic()\n",
"qm(0.1, int(n))\n",
"time1 = qe.toc()"
]
},
{
"cell_type": "markdown",
"id": "52e94973",
"metadata": {},
"source": [
"Now let’s try `qm_numba`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb826b17",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"qe.tic()\n",
"qm_numba(0.1, int(n))\n",
"time2 = qe.toc()"
]
},
{
"cell_type": "markdown",
"id": "14b17be0",
"metadata": {},
"source": [
"This is already a massive speed gain, even though the first call includes the time Numba needs to compile the function.\n",
"\n",
"Subsequent calls run even faster because the compiled code is already in memory:\n",
"\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60ed57c9",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"qe.tic()\n",
"qm_numba(0.1, int(n))\n",
"time3 = qe.toc()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e32b28d",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"time1 / time3 # Calculate speed gain"
]
},
{
"cell_type": "markdown",
"id": "2548ad9b",
"metadata": {},
"source": [
"This kind of speed gain is huge relative to how simple and clear the implementation is."
]
},
{
"cell_type": "markdown",
"id": "5b494040",
"metadata": {},
"source": [
"### How and When it Works\n",
"\n",
"Numba attempts to generate fast machine code using the infrastructure provided by the [LLVM Project](http://llvm.org/).\n",
"\n",
"It does this by inferring type information on the fly.\n",
"\n",
"(See our [earlier lecture](https://python-programming.quantecon.org/need_for_speed.html) on scientific computing for a discussion of types.)\n",
"\n",
"The basic idea is this:\n",
"\n",
"- Python is very flexible and hence we could call the function `qm` with many\n",
"  types.\n",
"  - e.g., `x0` could be a NumPy array or a list, `n` could be an integer or a float, etc.\n",
"- This makes it hard to *pre*-compile the function.\n",
"- However, when we do actually call the function, say by executing `qm(0.5, 10)`,\n",
"  the types of `x0` and `n` become clear.\n",
"- Moreover, the types of other variables in `qm` can be inferred once the input is known.\n",
"- So the strategy of Numba and other JIT compilers is to wait until this\n",
"  moment, and *then* compile the function.\n",
"\n",
"\n",
"That’s why it is called “just-in-time” compilation.\n",
"\n",
"Note that, if you make the call `qm(0.5, 10)` and then follow it with `qm(0.9, 20)`, compilation only takes place on the first call, because both calls use the same argument types.\n",
"\n",
"The compiled code is then cached and recycled as required."
]
},
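{
"cell_type": "markdown",
"id": "a1c0de01",
"metadata": {},
"source": [
"The caching behavior described above can be checked directly.\n",
"\n",
"The following is a minimal sketch rather than part of the original example: the function `sum_squares` is hypothetical, and the only Numba feature it relies on is the `signatures` attribute of a jitted function, which lists the argument types compiled so far."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1c0de02",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"from numba import njit\n",
"\n",
"@njit\n",
"def sum_squares(n):  # hypothetical function, used only for illustration\n",
"    total = 0.0\n",
"    for i in range(n):\n",
"        total += i * i\n",
"    return total\n",
"\n",
"print(len(sum_squares.signatures))  # nothing has been compiled before the first call\n",
"sum_squares(10)                     # first call with an integer triggers compilation\n",
"print(len(sum_squares.signatures))\n",
"sum_squares(20)                     # same argument type, so the compiled code is reused\n",
"print(len(sum_squares.signatures))"
]
},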
{
"cell_type": "markdown",
"id": "18669f8f",
"metadata": {},
"source": [
"## Decorator Notation\n",
"\n",
"In the code above we created a JIT compiled version of `qm` via the call"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "431938f4",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"qm_numba = njit(qm)"
]
},
{
"cell_type": "markdown",
"id": "27a85580",
"metadata": {},
"source": [
"In practice this would typically be done using an alternative *decorator* syntax.\n",
"\n",
"(We will explain all about decorators in a [later lecture](https://python-programming.quantecon.org/python_advanced_features.html) but you can skip the details at this stage.)\n",
"\n",
"Let’s see how this is done.\n",
"\n",
"To target a function for JIT compilation we can put `@njit` before the function definition.\n",
"\n",
"Here’s what this looks like for `qm`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5544a36",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"@njit\n",
"def qm(x0, n):\n",
"    x = np.empty(n+1)\n",
"    x[0] = x0\n",
"    for t in range(n):\n",
"        x[t+1] = α * x[t] * (1 - x[t])\n",
"    return x"
]
},
{
"cell_type": "markdown",
"id": "d8b1deb4",
"metadata": {},
"source": [
"This is equivalent to `qm = njit(qm)`.\n",
"\n",
"The following now uses the jitted version:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70dd906e",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"%%time \n",
"\n",
"qm(0.1, 100_000)"
]
},
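{
"cell_type": "markdown",
"id": "a1c0de03",
"metadata": {},
"source": [
"To see why the decorator form is equivalent to calling `njit(qm)` directly, here is a small pure Python sketch.\n",
"\n",
"The decorator `announce` and the functions `double` and `triple` are hypothetical names introduced only for illustration; they have nothing to do with Numba itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1c0de04",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"def announce(func):  # hypothetical decorator, for illustration only\n",
"    def wrapper(*args):\n",
"        print('calling', func.__name__)\n",
"        return func(*args)\n",
"    return wrapper\n",
"\n",
"@announce\n",
"def double(x):\n",
"    return 2 * x\n",
"\n",
"def triple(x):\n",
"    return 3 * x\n",
"\n",
"triple = announce(triple)  # same effect as placing @announce above the definition\n",
"\n",
"print(double(2), triple(2))"
]
},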
{
"cell_type": "markdown",
"id": "fff38865",
"metadata": {},
"source": [
"Numba also provides several decorator arguments that can accelerate computation and cache compiled functions; see the [performance tips](https://numba.readthedocs.io/en/stable/user/performance-tips.html) in the Numba documentation.\n",
"\n",
"In the [following lecture on parallelization](https://python-programming.quantecon.org/parallelization.html#parallel), we will discuss how to use the `parallel` argument to achieve automatic parallelization."
]
},
{
"cell_type": "markdown",
"id": "aa3b48b3",
"metadata": {},
"source": [
"## Type Inference\n",
"\n",
"Clearly type inference is a key part of JIT compilation.\n",
"\n",
"As you can imagine, inferring types is easier for simple Python objects (e.g., simple scalar data types such as floats and integers).\n",
"\n",
"Numba also plays well with NumPy arrays.\n",
"\n",
"In an ideal setting, Numba can infer all necessary type information.\n",
"\n",
"This allows it to generate native machine code, without having to call the Python runtime environment.\n",
"\n",
"In such a setting, Numba will be on par with machine code from low-level languages.\n",
"\n",
"When Numba cannot infer all type information, it will raise an error.\n",
"\n",
"For example, in the case below, Numba is unable to determine the type of function `mean` when compiling the function `bootstrap`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b12ea4d5",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"@njit\n",
"def bootstrap(data, statistics, n):\n",
"    bootstrap_stat = np.empty(n)   # one statistic per resample\n",
"    m = len(data)                  # size of each resample\n",
"    for i in range(n):\n",
"        resample = np.random.choice(data, size=m, replace=True)\n",
"        bootstrap_stat[i] = statistics(resample)\n",
"    return bootstrap_stat\n",
"\n",
"def mean(data):\n",
"    return np.mean(data)\n",
"\n",
"data = np.array([2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2])\n",
"n_resamples = 10\n",
"\n",
"print('Type of function:', type(mean))\n",
"\n",
"# This call fails because `mean` is a plain Python function\n",
"try:\n",
"    bootstrap(data, mean, n_resamples)\n",
"except Exception as e:\n",
"    print(e)"
]
},
{
"cell_type": "markdown",
"id": "8e87bd5f",
"metadata": {},
"source": [
"But Numba recognizes JIT-compiled functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9001cac7",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"@njit\n",
"def mean(data):\n",
"    return np.mean(data)\n",
"\n",
"print('Type of function:', type(mean))\n",
"\n",
"%time bootstrap(data, mean, n_resamples)"
]
},
{
"cell_type": "markdown",
"id": "5c5ff89c",
"metadata": {},
"source": [
"We can check the signature of the JIT-compiled function"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6084a924",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"bootstrap.signatures"
]
},
{
"cell_type": "markdown",
"id": "07587452",
"metadata": {},
"source": [
"The function `bootstrap` takes one `float64` floating point array, one function called `mean` and an `int64` integer.\n",
"\n",
"Now let’s see what happens when we change the inputs.\n",
"\n",
"Running it again with a larger integer for `n` and a different set of data does not change the signature of the function."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68b5e0a3",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"data = np.array([4.1, 1.1, 2.3, 1.9, 0.1, 2.8, 1.2])\n",
"%time bootstrap(data, mean, 100)\n",
"bootstrap.signatures"
]
},
{
"cell_type": "markdown",
"id": "d2acd213",
"metadata": {},
"source": [
"As expected, the second run is much faster.\n",
"\n",
"Let’s try to change the data again and use an integer array as data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc39600f",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"data = np.array([1, 2, 3, 4, 5], dtype=np.int64)\n",
"%time bootstrap(data, mean, 100)\n",
"bootstrap.signatures"
]
},
{
"cell_type": "markdown",
"id": "f3458b0c",
"metadata": {},
"source": [
"Note that a second signature is added.\n",
"\n",
"It also takes longer to run, suggesting that Numba recompiles this function as the type changes.\n",
"\n",
"Overall, type inference helps Numba to achieve its performance, but it also limits what Numba supports and sometimes requires careful type checks.\n",
"\n",
"You can refer to the list of supported Python and NumPy features [here](https://numba.pydata.org/numba-doc/dev/reference/pysupported.html)."
]
},
{
"cell_type": "markdown",
"id": "8fdd4ab1",
"metadata": {},
"source": [
"## Compiling Classes\n",
"\n",
"As mentioned above, at present Numba can only compile a subset of Python.\n",
"\n",
"However, that subset is ever expanding.\n",
"\n",
"For example, Numba is now quite effective at compiling classes.\n",
"\n",
"If a class is successfully compiled, then its methods act as JIT-compiled\n",
"functions.\n",
"\n",
"To give one example, let’s consider the class for analyzing the Solow growth model we\n",
"created in [this lecture](https://python-programming.quantecon.org/python_oop.html).\n",
"\n",
"To compile this class we use the `@jitclass` decorator:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f43f69ff",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"from numba import float64\n",
"from numba.experimental import jitclass"
]
},
{
"cell_type": "markdown",
"id": "aae70dd9",
"metadata": {},
"source": [
"Notice that we also imported something called `float64`.\n",
"\n",
"This is a data type representing standard floating point numbers.\n",
"\n",
"We are importing it here because Numba needs a bit of extra help with types when it tries to deal with classes.\n",
"\n",
"Here’s our code:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ae895253",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"solow_data = [\n",
"    ('n', float64),\n",
"    ('s', float64),\n",
"    ('δ', float64),\n",
"    ('α', float64),\n",
"    ('z', float64),\n",
"    ('k', float64)\n",
"]\n",
"\n",
"@jitclass(solow_data)\n",
"class Solow:\n",
"    r\"\"\"\n",
"    Implements the Solow growth model with the update rule\n",
"\n",
"        k_{t+1} = [(s z k^α_t) + (1 - δ)k_t] /(1 + n)\n",
"\n",
"    \"\"\"\n",
"    def __init__(self, n=0.05,  # population growth rate\n",
"                       s=0.25,  # savings rate\n",
"                       δ=0.1,   # depreciation rate\n",
"                       α=0.3,   # capital share\n",
"                       z=2.0,   # productivity\n",
"                       k=1.0):  # current capital stock\n",
"\n",
"        self.n, self.s, self.δ, self.α, self.z = n, s, δ, α, z\n",
"        self.k = k\n",
"\n",
"    def h(self):\n",
"        \"Evaluate the h function\"\n",
"        # Unpack parameters (get rid of self to simplify notation)\n",
"        n, s, δ, α, z = self.n, self.s, self.δ, self.α, self.z\n",
"        # Apply the update rule\n",
"        return (s * z * self.k**α + (1 - δ) * self.k) / (1 + n)\n",
"\n",
"    def update(self):\n",
"        \"Update the current state (i.e., the capital stock).\"\n",
"        self.k = self.h()\n",
"\n",
"    def steady_state(self):\n",
"        \"Compute the steady state value of capital.\"\n",
"        # Unpack parameters (get rid of self to simplify notation)\n",
"        n, s, δ, α, z = self.n, self.s, self.δ, self.α, self.z\n",
"        # Compute and return steady state\n",
"        return ((s * z) / (n + δ))**(1 / (1 - α))\n",
"\n",
"    def generate_sequence(self, t):\n",
"        \"Generate and return a time series of length t\"\n",
"        path = []\n",
"        for i in range(t):\n",
"            path.append(self.k)\n",
"            self.update()\n",
"        return path"
]
},
{
"cell_type": "markdown",
"id": "0afe64b8",
"metadata": {},
"source": [
"First we specified the types of the instance data for the class in\n",
"`solow_data`.\n",
"\n",
"After that, targeting the class for JIT compilation only requires adding\n",
"`@jitclass(solow_data)` before the class definition.\n",
"\n",
"When we call the methods in the class, the methods are compiled just like functions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67407496",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"s1 = Solow()\n",
"s2 = Solow(k=8.0)\n",
"\n",
"T = 60\n",
"fig, ax = plt.subplots()\n",
"\n",
"# Plot the common steady state value of capital\n",
"ax.plot([s1.steady_state()]*T, 'k-', label='steady state')\n",
"\n",
"# Plot time series for each economy\n",
"for s in s1, s2:\n",
"    lb = f'capital series from initial state {s.k}'\n",
"    ax.plot(s.generate_sequence(T), 'o-', lw=2, alpha=0.6, label=lb)\n",
"ax.set_ylabel('$k_{t}$', fontsize=12)\n",
"ax.set_xlabel('$t$', fontsize=12)\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "d98b3a78",
"metadata": {},
"source": [
"## Alternatives to Numba\n",
"\n",
"\n",
"\n",
"There are additional options for accelerating Python loops.\n",
"\n",
"Here we quickly review them.\n",
"\n",
"However, we do so only for interest and completeness.\n",
"\n",
"If you prefer, you can safely skip this section."
]
},
{
"cell_type": "markdown",
"id": "9d7b92bb",
"metadata": {},
"source": [
"### Cython\n",
"\n",
"Like Numba, [Cython](http://cython.org/) provides an approach to generating fast compiled code that can be used from Python.\n",
"\n",
"As was the case with Numba, a key problem is the fact that Python is dynamically typed.\n",
"\n",
"As you’ll recall, Numba solves this problem (where possible) by inferring type.\n",
"\n",
"Cython’s approach is different — programmers add type definitions directly to their “Python” code.\n",
"\n",
"As such, the Cython language can be thought of as Python with type definitions.\n",
"\n",
"In addition to a language specification, Cython is also a language translator, transforming Cython code into optimized C and C++ code.\n",
"\n",
"Cython also takes care of building language extensions — the wrapper code that interfaces between the resulting compiled code and Python.\n",
"\n",
"While Cython has certain advantages, we generally find it both slower and more\n",
"cumbersome than Numba."
]
},
{
"cell_type": "markdown",
"id": "d1702476",
"metadata": {},
"source": [
"### Interfacing with Fortran via F2Py\n",
"\n",
"\n",
"\n",
"If you are comfortable writing Fortran you will find it very easy to create\n",
"extension modules from Fortran code using [F2Py](https://docs.scipy.org/doc/numpy/f2py/).\n",
"\n",
"F2Py is a Fortran-to-Python interface generator that is particularly simple to\n",
"use.\n",
"\n",
"Robert Johansson provides a [nice introduction](http://nbviewer.jupyter.org/github/jrjohansson/scientific-python-lectures/blob/master/Lecture-6A-Fortran-and-C.ipynb)\n",
"to F2Py, among other things.\n",
"\n",
"Recently, [a Jupyter cell magic for Fortran](http://nbviewer.jupyter.org/github/mgaitan/fortran_magic/blob/master/documentation.ipynb) has been developed — you might want to give it a try."
]
},
{
"cell_type": "markdown",
"id": "1897ae0e",
"metadata": {},
"source": [
"## Summary and Comments\n",
"\n",
"Let’s review the above and add some cautionary notes."
]
},
{
"cell_type": "markdown",
"id": "191a5000",
"metadata": {},
"source": [
"### Limitations\n",
"\n",
"As we’ve seen, Numba needs to infer type information on\n",
"all variables to generate fast machine-level instructions.\n",
"\n",
"For simple routines, Numba infers types very well.\n",
"\n",
"For larger ones, or for routines using external libraries, it can easily fail.\n",
"\n",
"Hence, it’s prudent when using Numba to focus on speeding up small, time-critical snippets of code.\n",
"\n",
"This will give you much better performance than blanketing your Python programs with `@njit` statements."
]
},
{
"cell_type": "markdown",
"id": "c860661b",
"metadata": {},
"source": [
"### A Gotcha: Global Variables\n",
"\n",
"Here’s another thing to be careful about when using Numba.\n",
"\n",
"Consider the following example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6906053",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"a = 1\n",
"\n",
"@njit\n",
"def add_a(x):\n",
"    return a + x\n",
"\n",
"print(add_a(10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cf26d67",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"a = 2\n",
"\n",
"print(add_a(10))"
]
},
{
"cell_type": "markdown",
"id": "4c534245",
"metadata": {},
"source": [
"Notice that changing the global had no effect on the value returned by the\n",
"function.\n",
"\n",
"When Numba compiles machine code for functions, it treats global variables as constants to ensure type stability."
]
},
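{
"cell_type": "markdown",
"id": "a1c0de05",
"metadata": {},
"source": [
"One way to avoid this surprise is to pass the value as an argument instead of reading it from the global scope.\n",
"\n",
"The sketch below uses a hypothetical variant named `add_a_arg`; it is one possible workaround under that assumption, not the only one."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1c0de06",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"from numba import njit\n",
"\n",
"@njit\n",
"def add_a_arg(x, a):  # hypothetical variant that receives a as an argument\n",
"    return a + x\n",
"\n",
"a = 1\n",
"print(add_a_arg(10, a))\n",
"\n",
"a = 2\n",
"print(add_a_arg(10, a))  # the updated value of a is passed in and used"
]
},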
{
"cell_type": "markdown",
"id": "843a6675",
"metadata": {},
"source": [
"## Exercises"
]
},
{
"cell_type": "markdown",
"id": "1619550c",
"metadata": {},
"source": [
"## Exercise 16.1\n",
"\n",
"[Previously](https://python-programming.quantecon.org/python_by_example.html#pbe_ex5) we considered how to approximate $ \\pi $ by\n",
"Monte Carlo.\n",
"\n",
"Use the same idea here, but make the code efficient using Numba.\n",
"\n",
"Compare speed with and without Numba when the sample size is large."
]
},
{
"cell_type": "markdown",
"id": "9a2f28a0",
"metadata": {},
"source": [
"## Solution to [Exercise 16.1](https://python-programming.quantecon.org/#speed_ex1)\n",
"\n",
"Here is one solution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eab81db3",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"from random import uniform\n",
"\n",
"@njit\n",
"def calculate_pi(n=1_000_000):\n",
"    count = 0\n",
"    for i in range(n):\n",
"        u, v = uniform(0, 1), uniform(0, 1)\n",
"        d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2)\n",
"        if d < 0.5:\n",
"            count += 1\n",
"\n",
"    area_estimate = count / n\n",
"    return area_estimate * 4  # dividing by radius**2"
]
},
{
"cell_type": "markdown",
"id": "fcdb98d1",
"metadata": {},
"source": [
"Now let’s see how fast it runs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "113025aa",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"%time calculate_pi()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9fe89a8",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"%time calculate_pi()"
]
},
{
"cell_type": "markdown",
"id": "a6d0b1d5",
"metadata": {},
"source": [
"If we switch off JIT compilation by removing `@njit`, the code takes around\n",
"150 times as long on our machine.\n",
"\n",
"So we get a speed gain of 2 orders of magnitude, which is huge, by adding five\n",
"characters."
]
},
{
"cell_type": "markdown",
"id": "a679c038",
"metadata": {},
"source": [
"## Exercise 16.2\n",
"\n",
"In the [Introduction to Quantitative Economics with Python](https://python-intro.quantecon.org) lecture series you can\n",
"learn all about finite-state Markov chains.\n",
"\n",
"For now, let’s just concentrate on simulating a very simple example of such a chain.\n",
"\n",
"Suppose that the volatility of returns on an asset can be in one of two regimes — high or low.\n",
"\n",
"The transition probabilities across states are as follows\n",
"\n",
"![https://python-programming.quantecon.org/_static/lecture_specific/sci_libs/nfs_ex1.png](https://python-programming.quantecon.org/_static/lecture_specific/sci_libs/nfs_ex1.png)\n",
"\n",
" \n",
"For example, let the period length be one day, and suppose the current state is high.\n",
"\n",
"We see from the graph that the state tomorrow will be\n",
"\n",
"- high with probability 0.8 \n",
"- low with probability 0.2 \n",
"\n",
"\n",
"Your task is to simulate a sequence of daily volatility states according to this rule.\n",
"\n",
"Set the length of the sequence to `n = 1_000_000` and start in the high state.\n",
"\n",
"Implement a pure Python version and a Numba version, and compare speeds.\n",
"\n",
"To test your code, evaluate the fraction of time that the chain spends in the low state.\n",
"\n",
"If your code is correct, it should be about 2/3.\n",
"\n",
"- Represent the low state as 0 and the high state as 1. \n",
"- If you want to store integers in a NumPy array and then apply JIT compilation, use `x = np.empty(n, dtype=np.int_)`. "
]
},
{
"cell_type": "markdown",
"id": "d7311b82",
"metadata": {},
"source": [
"## Solution to [Exercise 16.2](https://python-programming.quantecon.org/#speed_ex2)\n",
"\n",
"We let\n",
"\n",
"- 0 represent “low” \n",
"- 1 represent “high” "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ab28186",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"p, q = 0.1, 0.2 # Prob of leaving low and high state respectively"
]
},
{
"cell_type": "markdown",
"id": "010e1071",
"metadata": {},
"source": [
"Here’s a pure Python version of the function"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f241185",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"def compute_series(n):\n",
"    x = np.empty(n, dtype=np.int_)\n",
"    x[0] = 1  # Start in state 1\n",
"    U = np.random.uniform(0, 1, size=n)\n",
"    for t in range(1, n):\n",
"        current_x = x[t-1]\n",
"        if current_x == 0:\n",
"            x[t] = U[t] < p\n",
"        else:\n",
"            x[t] = U[t] > q\n",
"    return x"
]
},
{
"cell_type": "markdown",
"id": "79099cdd",
"metadata": {},
"source": [
"Let’s run this code and check that the fraction of time spent in the low\n",
"state is about 0.666"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee6472fc",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"n = 1_000_000\n",
"x = compute_series(n)\n",
"print(np.mean(x == 0)) # Fraction of time x is in state 0"
]
},
{
"cell_type": "markdown",
"id": "c07f9719",
"metadata": {},
"source": [
"This is (approximately) the right output.\n",
"\n",
"Now let’s time it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3effcde",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"qe.tic()\n",
"compute_series(n)\n",
"qe.toc()"
]
},
{
"cell_type": "markdown",
"id": "d75e6ae1",
"metadata": {},
"source": [
"Next let’s implement a Numba version, which is easy"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee702d8e",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"compute_series_numba = njit(compute_series)"
]
},
{
"cell_type": "markdown",
"id": "7db55023",
"metadata": {},
"source": [
"Let’s check we still get the right numbers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb8f26fe",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"x = compute_series_numba(n)\n",
"print(np.mean(x == 0))"
]
},
{
"cell_type": "markdown",
"id": "82342b59",
"metadata": {},
"source": [
"Let’s see the time"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88363344",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"qe.tic()\n",
"compute_series_numba(n)\n",
"qe.toc()"
]
},
{
"cell_type": "markdown",
"id": "d56168b7",
"metadata": {},
"source": [
"This is a nice speed improvement for one line of code!"
]
}
],
"metadata": {
"date": 1710455932.216583,
"filename": "numba.md",
"kernelspec": {
"display_name": "Python",
"language": "python3",
"name": "python3"
},
"title": "Numba"
},
"nbformat": 4,
"nbformat_minor": 5
}