JAX Cheatsheet

A visual guide to JAX covering arrays and devices, the NumPy-style API, automatic differentiation, JIT compilation, auto-vectorization with vmap, splittable random keys, pure control flow, and pytrees.

python
jax
cheatsheet
Author

James Balamuta

Published

June 14, 2026

JAX is “NumPy you can transform.” You write ordinary, pure functions over jnp arrays, then wrap them in higher-order transforms (grad to differentiate, jit to compile, and vmap to vectorize), each of which returns a brand new function. The whole library teaches that one wrapping mental model, plus the few rules that make it work: arrays are immutable, randomness is explicit, and traced functions must avoid Python side effects. If you know NumPy and SciPy, the API will feel familiar; what is new is the layer of transformations on top. The convention throughout is import jax, import jax.numpy as jnp, and from jax import grad, jit, vmap.

Complete JAX cheatsheet (light and dark modes): eight panels covering arrays and devices, the NumPy-style API, automatic differentiation, JIT compilation, vmap, random keys, control flow, and pytrees.


Arrays & Devices

A JAX array looks like a NumPy array but lives on whatever accelerator JAX found (CPU, GPU, or TPU) and is immutable. Note that the default float dtype is float32, not float64. Computation is dispatched asynchronously, so you call block_until_ready() (or convert back to NumPy) when you actually need the values on the host.

jax arrays panel: jnp.array, arange/linspace, asarray, devices, device_put, block_until_ready.

Make arrays, see your accelerator, move data onto it.

import jax
import jax.numpy as jnp

jnp.array([1., 2., 3.])                  # array from a list -> float32 by default
jnp.arange(5); jnp.linspace(0, 1, 3)     # range / evenly spaced values
jnp.asarray(np_array)                    # convert a NumPy array (host -> device)
jax.devices(); jax.default_backend()     # list devices -> [CpuDevice(id=0)], 'cpu'
jax.device_put(x, jax.devices()[0])      # place data on a chosen device
x.block_until_ready()                    # wait for async compute to finish

See Key concepts.
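
The calls above can be strung together into one short session. This is a minimal sketch, assuming a default install with 64-bit mode left off; the variable names are illustrative, and the device list depends on the machine it runs on.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3.])
print(x.dtype)                         # float32 unless jax_enable_x64 is set

devices = jax.devices()                # contents depend on the machine
x_dev = jax.device_put(x, devices[0])  # explicit placement on the first device

y = jnp.exp(x_dev) * 2                 # dispatched asynchronously
y.block_until_ready()                  # wait until the result is computed

host = np.asarray(y)                   # copy back to a host NumPy array
print(type(host), host.shape)
```

Converting with np.asarray also waits for the computation, so an explicit block_until_ready() mostly matters when timing code.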

NumPy-style API & Immutable Updates

jax.numpy (imported as jnp) re-implements most of the NumPy API (math, reductions, indexing, and linear algebra), so existing array code ports over with minimal changes. Because arrays cannot be mutated, in-place writes are replaced by the functional update syntax x.at[idx].set(...) and x.at[idx].add(...), each of which returns a new array and leaves the original untouched.

jax numpy panel: element-wise math, matmul, reductions, indexing, at[].set, at[].add.

jnp mirrors numpy, but arrays never change in place.

jnp.exp(x); x * 2 + 1            # element-wise math, vectorized
a @ b; jnp.dot(a, b)             # matrix multiply (sum of products)
x.sum(axis=0); x.mean()          # reduce over an axis -> [5 7 9]
m[1, 2]; m[:, 0]                 # indexing reads just like numpy
x.at[0].set(99)                 # functional "set" -> NEW array [99 1 2 3 4]
x.at[1].add(10)                 # functional "add" -> NEW array [0 11 2 3 4]

See Indexed update operators.
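
To see what "cannot be mutated" means in practice, the sketch below attempts a NumPy-style in-place write and then uses the functional form on a slice. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(5)                # [0 1 2 3 4]

rejected = False
try:
    x[0] = 99                    # NumPy-style in-place write
except TypeError:
    rejected = True              # JAX arrays do not support item assignment

y = x.at[0].set(99)              # new array, first element replaced
z = x.at[1:3].add(10)            # slices work too: adds 10 to elements 1 and 2

print(rejected, x, y, z)
```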

Automatic Differentiation

grad is a higher-order transform: give it a function that returns a scalar and it returns a new function computing the gradient, which you can nest (grad(grad(f))) or point at specific arguments with argnums. Use value_and_grad to get the loss and its gradient in one pass, and jax.jacobian or jax.hessian for full first- and second-order matrices.

jax grad panel: grad, value_and_grad, argnums, grad(grad), jacobian, hessian.

grad turns a function into its gradient function.

import jax
from jax import grad, value_and_grad

grad(f)(2.0)                            # gradient of a scalar function -> 12.0
value_and_grad(loss)(w, x, y)           # value and gradient together -> (loss, grad)
grad(loss, argnums=(0, 1))(w, b, x)     # differentiate chosen args -> (gw, gb)
grad(grad(f))(2.0)                      # higher-order (2nd) derivative -> 12.0
jax.jacobian(f)(x)                      # full Jacobian matrix
jax.hessian(g)(x)                       # Hessian (curvature) matrix

See Automatic differentiation.
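
The panel does not say which f produces 12.0 for both the first and second derivative at 2.0. One function that fits is f(x) = x**3, which is an assumption made here for illustration; g is likewise a hypothetical example function.

```python
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad

def f(x):
    return x ** 3                    # assumed example function

d1 = grad(f)(2.0)                    # 3 * x**2 at x = 2
d2 = grad(grad(f))(2.0)              # 6 * x at x = 2

def g(v):
    return jnp.sum(v ** 2)           # scalar-valued function of a vector

v = jnp.array([1., 2., 3.])
val, gv = value_and_grad(g)(v)       # value and gradient 2 * v in one pass
H = jax.hessian(g)(v)                # 2 * identity, shape (3, 3)
J = jax.jacobian(jnp.sin)(v)         # diagonal matrix of cos(v), shape (3, 3)

print(d1, d2, val, gv)
```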

JIT Compilation

jit traces your Python function on abstract “shaped” inputs, hands the resulting graph to the XLA compiler, and caches the compiled binary keyed by input shape and dtype, so the first call is slow and later calls with the same shapes and dtypes are fast. Because tracing runs your Python only once per cache entry, the function must be pure: no data-dependent Python if/for over array values and no side effects. Values that should be compile-time constants go in static_argnums.

jax jit panel: jit, trace-then-reuse, static_argnums, make_jaxpr, block_until_ready, side effects.

jit traces a function and compiles it with XLA.

import jax
from jax import jit
from functools import partial

jit(f)                                  # compile a function with XLA (or decorate with @jit)
f(x); f(x)                              # trace once, 2nd call fast (cached by shape+dtype)
partial(jit, static_argnums=(1,))       # mark compile-time constants (recompiles if changed)
jax.make_jaxpr(f)(x)                    # inspect the traced program (the IR)
f(x).block_until_ready()                # block before timing
# avoid print() / .item() inside a jit'd fn: print runs only while tracing, .item() fails on a traced value

See Just-in-time compilation.
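
The trace-then-reuse behavior can be observed directly by putting a Python side effect inside a jitted function. The sketch below uses hypothetical helper names (scale, power, trace_log) and is meant to illustrate the caching rule, not a recommended pattern for real code.

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit

trace_log = []                     # Python-side list, used only to observe tracing

@jit
def scale(x):
    trace_log.append(x.shape)      # side effect: runs while tracing, not on cached calls
    return x * 2.0 + 1.0

scale(jnp.ones(3))
scale(jnp.ones(3))                 # same shape and dtype: cached, no new trace
scale(jnp.ones(4))                 # new shape: traced and compiled again
n_traces = len(trace_log)

@partial(jit, static_argnums=(1,))
def power(x, n):
    out = jnp.ones_like(x)
    for _ in range(n):             # plain Python loop is fine: n is a compile-time constant
        out = out * x
    return out

cubes = power(jnp.array([2., 3.]), 3)

jaxpr = jax.make_jaxpr(lambda x: x * 2.0 + 1.0)(jnp.ones(3))
print(n_traces, cubes)
print(jaxpr)                       # the traced program for the same computation
```

Calling power with a different n triggers a fresh compile, which is the cost of marking an argument static.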

Auto-Vectorization (vmap)

vmap lets you write a function for a single example and then automatically adds a batch dimension, so you never hand-write a batch loop or fiddle with broadcasting. in_axes and out_axes say which axis of each argument and result is the batch axis (use None to hold an argument fixed), and vmap composes cleanly with grad and jit.

jax vmap panel: vmap, map over leading axis, in_axes, out_axes, vmap+grad, replace the for loop.

Write the single-example function; vmap adds the batch axis.

from jax import grad, vmap

vmap(f)(batch)                          # batch a per-example function
vmap(dot)(A, B)                         # map over the leading axis -> [3 3 3 3]
vmap(f, in_axes=(0, None))              # map arg 0 per-row, hold arg 1 fixed
vmap(f, out_axes=1)                     # stack results into columns
vmap(grad(f))(xs)                       # compose with grad: one gradient per input
# replaces: for x in batch: ...         # no Python batch loop: same results, less code

See Automatic vectorization.
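
Composing vmap with grad gives one gradient per example, with in_axes deciding which arguments are batched. This is a small sketch with a made-up squared-error loss; the function and data are assumptions chosen so the result is easy to check by hand.

```python
import jax.numpy as jnp
from jax import grad, vmap

def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2          # loss for ONE example

w = jnp.array([1., 2.])
xs = jnp.array([[1., 0.], [0., 1.], [1., 1.]])
ys = jnp.array([1., 2., 0.])

# hold w fixed (None), map over the leading axis of xs and ys
per_example = vmap(grad(loss), in_axes=(None, 0, 0))(w, xs, ys)

# same computation, but the batch axis is placed second in the output
as_columns = vmap(grad(loss), in_axes=(None, 0, 0), out_axes=1)(w, xs, ys)

print(per_example.shape, as_columns.shape)   # (3, 2) (2, 3)
```

The first two examples are fit exactly, so their gradients are zero; only the third contributes a nonzero row.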

Random Numbers (Explicit Keys)

JAX has no hidden global RNG state. Instead you create an explicit key from a seed and pass it into every sampling call, which makes randomness reproducible and parallel-safe. The golden rule is never reuse a key: split it into fresh independent keys (or fold_in a step index) before each new draw, because the same key always produces the same numbers.

jax random panel: key, split, normal, uniform/randint, fold_in, deterministic draws.

No global seed: you pass and split an explicit key.

key = jax.random.key(0)                       # create a key from a seed (reproducible)
k1, k2 = jax.random.split(key)                # split into independent keys (never reuse)
jax.random.normal(key, (3,))                  # standard normals -> [1.62 2.03 -0.43]
jax.random.uniform(key, (3,))                 # uniform [0,1) draws
jax.random.randint(key, (3,), 0, 10)          # integer draws -> [9 0 2]
jax.random.fold_in(key, step)                 # per-step key, stable per index

See Pseudorandom numbers.
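
A common way to follow the never-reuse rule is to carry one key forward and consume a fresh subkey for each draw. The sketch below shows that pattern alongside fold_in; the step indices 7 and 8 are arbitrary illustrative values.

```python
import jax

key = jax.random.key(0)
key, subkey = jax.random.split(key)        # carry `key` forward, consume `subkey`
noise = jax.random.normal(subkey, (3,))

keys = jax.random.split(key, 4)            # four keys at once

# fold_in derives a key from a base key and an integer, e.g. a training step
draw_a = jax.random.normal(jax.random.fold_in(key, 7), (2,))
draw_b = jax.random.normal(jax.random.fold_in(key, 7), (2,))  # same index: same draw
draw_c = jax.random.normal(jax.random.fold_in(key, 8), (2,))  # other index: other draw

print(noise.shape, keys.shape)
```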

Control Flow & Loops

Inside a jit trace, a regular Python if or for runs at trace time rather than on the device: a for loop is unrolled into the traced program, and an if on an array value raises an error (grad on its own tolerates Python control flow, but most code ends up under jit). Data-dependent control flow therefore uses the lax primitives cond, fori_loop, while_loop, and scan. scan is the workhorse for sequential models: it threads a carry through a sequence while collecting per-step outputs, and lax.stop_gradient blocks a value from contributing to the backward pass.

jax control flow panel: cond, fori_loop, while_loop, scan, avoid Python for, stop_gradient.

Inside jit, use lax primitives instead of Python branches.

from jax import lax

lax.cond(p, true_fn, false_fn, x)       # data-dependent branch -> 11
lax.fori_loop(0, 5, body, init)         # fixed-count loop (i: 0->4) -> 10
lax.while_loop(cond, body, init)        # loop while a condition holds -> 12
lax.scan(body, init, xs)                # carry + collect -> (carry 6, ys [0 1 3 6])
# avoid a Python `for` over a traced array: it unrolls or errors -> use scan
lax.stop_gradient(x)                    # block a value from the backward pass

See Control flow.
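
The panel lists results (11, 10, 12) without the functions behind them. The sketch below picks simple bodies that reproduce those numbers; they are assumptions for illustration, not the panel's actual definitions.

```python
import jax.numpy as jnp
from jax import grad, jit, lax

@jit
def bump(x):
    # data-dependent branch on a traced value
    return lax.cond(x > 0, lambda v: v + 1, lambda v: v - 1, x)

branch = bump(10)                                            # 10 > 0, so 10 + 1

# fixed-count loop: add i for i = 0..4
total = lax.fori_loop(0, 5, lambda i, acc: acc + i, jnp.int32(0))

# loop until the condition fails: 0, 3, 6, 9, 12
last = lax.while_loop(lambda v: v < 10, lambda v: v + 3, jnp.int32(0))

def h(x):
    return x * lax.stop_gradient(x)   # second factor is treated as a constant

dh = grad(h)(3.0)                     # 3.0 rather than 6.0

print(branch, total, last, dh)
```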

Pytrees & Parameters

A pytree is any nested structure of dicts, lists, and tuples whose leaves are arrays, which is exactly how you store model parameters and optimizer state. jax.tree.map applies a function to every leaf (across one or several matching trees) while preserving the structure, so a whole gradient-descent update or a grad result is just a tree the same shape as your parameters.

jax pytrees panel: parameter pytree, tree.map, tree.leaves, flatten, SGD step, grad returns a tree.

Nested dicts and lists of arrays are first-class; map over them.

params = {"w": ..., "b": ...}                       # a parameter pytree (dict of arrays)
jax.tree.map(lambda x: x*2, params)                 # map a function over every leaf
jax.tree.leaves(params)                             # list just the leaves [b, w]
jax.tree.flatten(params)                            # -> (leaves list, treedef) round-trip
jax.tree.map(lambda p, g: p - lr*g, params, grads)  # one SGD step over a tree
grad(loss)(params, batch)                           # grad returns a tree shaped like params

See Working with pytrees.
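
The flatten round-trip and leaf ordering are easier to see on a concrete tree. This is a minimal sketch with made-up parameter shapes; it assumes a JAX version that provides the jax.tree namespace used throughout this page.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

leaves, treedef = jax.tree.flatten(params)       # dict leaves come out in sorted key order
rebuilt = jax.tree.unflatten(treedef, leaves)    # same structure, same arrays

n_params = sum(leaf.size for leaf in jax.tree.leaves(params))
shapes = jax.tree.map(lambda a: a.shape, params) # structure kept, leaves replaced

print(n_params, shapes)
```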

Quick Reference

JAX’s composable transforms.

| Transform | Turns f into | Daily use |
| --- | --- | --- |
| jax.grad(f) | a function returning ∇f | gradients for optimization |
| jax.value_and_grad(f) | a function returning (f, ∇f) | loss + gradient in one pass |
| jax.jit(f) | an XLA-compiled f | speed; fuse and run on accelerator |
| jax.vmap(f) | a batched f | add a batch axis, drop the loop |
| jax.jacobian(f) / jax.hessian(f) | full Jacobian / Hessian | sensitivity, second-order methods |

Everyday JAX operations.

| Command | What it does | Area |
| --- | --- | --- |
| jnp.array(...) / jnp.arange | Build a device array | Arrays |
| jax.devices() | List CPU/GPU/TPU | Arrays |
| jax.device_put(x, dev) | Place data on a device | Arrays |
| x.block_until_ready() | Wait for async compute | Arrays |
| x.at[i].set(v) / .add(v) | Immutable indexed update | NumPy API |
| a @ b / x.sum(axis=0) | Matmul / axis reduction | NumPy API |
| grad(f)(x) | Gradient function | Autodiff |
| value_and_grad(loss)(p, b) | Loss and its gradient | Autodiff |
| jit(f) / @jit | Trace + compile with XLA | JIT |
| partial(jit, static_argnums=...) | Pin compile-time constants | JIT |
| make_jaxpr(f)(x) | Inspect the traced IR | JIT |
| vmap(f, in_axes=...) | Auto-vectorize over a batch | vmap |
| jax.random.key(0) | Make an RNG key | Random |
| jax.random.split(key) | Fork independent keys | Random |
| jax.random.normal(key, shape) | Sample from a distribution | Random |
| lax.scan(body, init, xs) | Carry + collect over a sequence | Control flow |
| lax.cond / fori_loop / while_loop | Traceable branch / loops | Control flow |
| lax.stop_gradient(x) | Block the backward pass | Control flow |
| jax.tree.map(fn, tree) | Apply over pytree leaves | Pytrees |
| jax.tree.leaves / flatten | Inspect / round-trip a pytree | Pytrees |

The mental-model gaps to mind when porting from NumPy.

| Topic | NumPy | JAX |
| --- | --- | --- |
| Mutation | x[0] = 9 (in place) | x = x.at[0].set(9) (new array) |
| Default float | float64 | float32 (set jax_enable_x64 for 64-bit) |
| Randomness | global seed | explicit key, must be split |
| Speed primitive | vectorize in C | jit to XLA, run on GPU/TPU |
| Gradients | none built in | grad, value_and_grad |
| Loops | Python for | lax.scan / fori_loop inside jit |

Appendix: Sample Code

The transform mental model (wrap a pure function)

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(w, b, x):
    return w * x + b

def loss(w, b, xs, ys):
    preds = predict(w, b, xs)
    return jnp.mean((preds - ys) ** 2)

xs = jnp.array([1., 2., 3.])
ys = jnp.array([2., 4., 6.])

# grad: a new function computing d(loss)/d(w, b)
gw, gb = grad(loss, argnums=(0, 1))(0.0, 0.0, xs, ys)

# jit: compile the loss for fast repeated calls
fast_loss = jit(loss)
fast_loss(0.0, 0.0, xs, ys).block_until_ready()

Immutability and functional updates

import jax.numpy as jnp

x = jnp.arange(5)        # [0 1 2 3 4]
y = x.at[0].set(99)      # NEW array [99 1 2 3 4]
z = x.at[1].add(10)      # NEW array [ 0 11 2 3 4]
# x is unchanged: [0 1 2 3 4]

A tiny gradient-descent loop (grad + jit + pytree)

import jax
import jax.numpy as jnp
from jax import grad, jit

# parameters as a pytree (dict of arrays)
params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
xs = jnp.array([1., 2., 3.])
ys = jnp.array([2., 4., 6.])

def loss(params, xs, ys):
    preds = params["w"] * xs + params["b"]
    return jnp.mean((preds - ys) ** 2)

@jit
def step(params, xs, ys, lr=0.1):
    grads = grad(loss)(params, xs, ys)          # gradient pytree, same shape as params
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)

for _ in range(200):
    params = step(params, xs, ys)
# params -> approximately {"w": 2.0, "b": 0.0}

Explicit randomness: split before you draw

import jax

key = jax.random.key(0)          # seed -> reproducible
k1, k2 = jax.random.split(key)   # two independent keys

a = jax.random.normal(k1, (3,))  # never reuse a key
b = jax.random.uniform(k2, (3,)) # different stream

# Same key -> same numbers (deterministic):
jax.random.normal(key, (3,))     # [ 1.6226  2.0253 -0.4336]
jax.random.normal(key, (3,))     # identical to the line above

vmap: write for one, run for a batch

import jax.numpy as jnp
from jax import vmap

def dot(a, b):
    return jnp.dot(a, b)         # written for a single pair of vectors

A = jnp.ones((4, 3))
B = jnp.ones((4, 3))
vmap(dot)(A, B)                  # [3. 3. 3. 3.], batched over the leading axis

# hold the second argument fixed:
vmap(lambda x, y: x + y, in_axes=(0, None))(jnp.arange(3), 10)  # [10 11 12]

lax.scan: the carry-and-collect loop

import jax.numpy as jnp
from jax import lax

def body(carry, x):
    carry = carry + x
    return carry, carry          # (new carry, per-step output)

total, prefix_sums = lax.scan(body, 0, jnp.arange(4))
# total -> 6, prefix_sums -> [0 1 3 6]

Enabling 64-bit precision (when float32 is not enough)

import jax
jax.config.update("jax_enable_x64", True)   # must run before array creation
import jax.numpy as jnp
jnp.array([1.0, 2.0]).dtype                  # now float64 (default is float32)
