JAX is “NumPy you can transform.” You write ordinary, pure functions over jnp arrays, then wrap them in higher-order transforms (grad to differentiate, jit to compile, and vmap to vectorize), each of which returns a brand-new function. The whole library teaches that one wrapping mental model, plus the few rules that make it work: arrays are immutable, randomness is explicit, and traced functions must avoid Python side effects. If you know NumPy and SciPy, the API will feel familiar; what is new is the layer of transformations on top. The convention throughout is import jax, import jax.numpy as jnp, and from jax import grad, jit, vmap.
Arrays & Devices
A JAX array looks like a NumPy array but is immutable and lives on whatever device JAX found (CPU, GPU, or TPU). Note that the default float dtype is float32, not float64. Computation is dispatched asynchronously, so you call block_until_ready() (or convert back to NumPy) when you actually need the values on the host.
import jax
import jax.numpy as jnp
jnp.array([1., 2., 3.]) # array from a list -> float32 by default
jnp.arange(5); jnp.linspace(0, 1, 3) # range / evenly spaced values
jnp.asarray(np_array) # convert a NumPy array (host -> device)
jax.devices(); jax.default_backend() # list devices and backend name -> e.g. [CpuDevice(id=0)], 'cpu'
jax.device_put(x, jax.devices()[0]) # place data on a chosen device
x.block_until_ready() # wait for async compute to finish
See Key concepts.
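The device and async-dispatch behavior described above can be sketched as follows. This is a minimal illustration, assuming default settings (64-bit mode not enabled); np_array and the other variable names are made up for the example, and the reported backend depends on the machine.

```python
import jax
import jax.numpy as jnp
import numpy as np

np_array = np.arange(6.0).reshape(2, 3)   # float64 array on the host
x = jnp.asarray(np_array)                 # JAX array; float32 unless x64 is enabled

print(x.dtype)                 # float32 under default settings
print(jax.default_backend())   # 'cpu', 'gpu', or 'tpu', depending on the machine

# Place the data explicitly on the first available device.
x_dev = jax.device_put(x, jax.devices()[0])

# Dispatch is asynchronous: this line may return before the result is ready.
y = jnp.dot(x_dev, x_dev.T)
y.block_until_ready()          # wait before timing or reading the result

y_host = np.asarray(y)         # converting to NumPy also waits for completion
```

Calling block_until_ready() matters mostly for benchmarking; ordinary code that prints or converts the result synchronizes implicitly.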
NumPy-style API & Immutable Updates
jax.numpy (imported as jnp) re-implements most of the NumPy API (math, reductions, indexing, and linear algebra), so existing array code ports over with minimal changes. Because arrays cannot be mutated, in-place writes are replaced by the functional update syntax x.at[idx].set(...) and x.at[idx].add(...), each of which returns a new array and leaves the original untouched.
jnp.exp(x); x * 2 + 1 # element-wise math, vectorized
a @ b; jnp.dot(a, b) # matrix multiply (sum of products)
x.sum(axis=0); x.mean() # reduce over an axis -> [5 7 9]
m[1, 2]; m[:, 0] # indexing reads just like numpy
x.at[0].set(99) # functional "set" -> NEW array [99 1 2 3 4]
x.at[1].add(10) # functional "add" -> NEW array [0 11 2 3 4]
Automatic Differentiation
grad is a higher-order transform: give it a function that returns a scalar and it returns a new function computing the gradient, which you can nest (grad(grad(f))) or point at specific arguments with argnums. Use value_and_grad to get the loss and its gradient in one pass, and jax.jacobian or jax.hessian for full first- and second-order matrices.
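As one concrete instance of these transforms, the sketch below assumes f(x) = x**3 (which is consistent with the 12.0 results quoted in the snippets) and a second illustrative function g; both are hypothetical examples, not functions from the library.

```python
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad

def f(x):
    return x ** 3            # assumed example: f'(x) = 3x^2, f''(x) = 6x

df = grad(f)                 # new function computing f'
d2f = grad(grad(f))          # nesting grad gives the second derivative

print(df(2.0))               # 12.0
print(d2f(2.0))              # 12.0

def g(v):
    return jnp.sum(v ** 2)   # scalar output of a vector input

v = jnp.array([1.0, 2.0, 3.0])
value, gradient = value_and_grad(g)(v)   # 14.0 and [2. 4. 6.]
H = jax.hessian(g)(v)                    # 2 * identity, shape (3, 3)
J = jax.jacobian(jnp.sin)(v)             # diagonal matrix of cos(v), shape (3, 3)
```

grad requires a scalar-valued function, which is why the vector-valued jnp.sin goes through jax.jacobian instead.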
from jax import grad, value_and_grad
grad(f)(2.0) # gradient of a scalar function -> 12.0
value_and_grad(loss)(w, x, y) # value and gradient together -> (loss, grad)
grad(loss, argnums=(0, 1))(w, b, x) # differentiate chosen args -> (gw, gb)
grad(grad(f))(2.0) # higher-order (2nd) derivative -> 12.0
jax.jacobian(f)(x) # full Jacobian matrix
jax.hessian(g)(x) # Hessian (curvature) matrix
JIT Compilation
jit traces your Python function on abstract “shaped” inputs, hands the resulting graph to the XLA compiler, and caches the compiled binary keyed by input shape and dtype, so the first call is slow and later calls with the same shapes and dtypes are fast. Because the Python body runs only during tracing, the function must be pure: no data-dependent Python if/for over array values and no side effects. Arguments that must be compile-time constants are marked with static_argnums.
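A minimal sketch of these rules, using two hypothetical functions (relu_sum and power_sum are names invented for this example). The first keeps a value-dependent choice inside the traced graph with jnp.where; the second marks an integer argument as static so that a Python loop over it is legal.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit

@jit
def relu_sum(x):
    # jnp.where keeps the selection inside the traced graph,
    # avoiding a Python `if` on array values.
    return jnp.sum(jnp.where(x > 0, x, 0.0))

@partial(jit, static_argnums=(1,))
def power_sum(x, n):
    # n is a static Python int, so this loop is unrolled at trace time.
    out = jnp.zeros_like(x)
    for k in range(1, n + 1):
        out = out + x ** k
    return out

x = jnp.array([-1.0, 2.0, 3.0])
print(relu_sum(x))                   # 5.0
print(power_sum(x, 2))               # x + x**2 -> [ 0.  6. 12.]
print(jax.make_jaxpr(relu_sum)(x))   # inspect the traced program
```

Calling power_sum with a different n triggers a fresh trace and compile, so static arguments are best reserved for values that change rarely.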
from jax import jit
from functools import partial
jit(f) # or @jit as a decorator: compile a function with XLA
f(x); f(x) # trace once, 2nd call fast (cached by shape+dtype)
partial(jit, static_argnums=(1,)) # mark compile-time constants (recompiles if changed)
jax.make_jaxpr(f)(x) # inspect the traced program (the IR)
f(x).block_until_ready() # block before timing
# avoid print() / .item() inside a jit'd fn: print runs only at trace time, .item() fails on a traced value
Auto-Vectorization (vmap)
vmap lets you write a function for a single example and then automatically adds a batch dimension, so you never hand-write a batch loop or fiddle with broadcasting. in_axes and out_axes say which axis of each argument and result is the batch axis (use None to hold an argument fixed), and vmap composes cleanly with grad and jit.
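The composition with grad and jit can be sketched like this. The loss function and data are hypothetical, chosen so the per-example gradients are easy to check by hand (the gradient with respect to w is 2 * dot(w, x) * x).

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def loss(w, x):
    # Written for ONE example: w and x are both 1-D vectors.
    return jnp.dot(w, x) ** 2

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])   # batch of 3 examples

# Per-example gradients w.r.t. w: hold w fixed, map over the rows of xs.
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0)))(w, xs)
print(per_example_grads)       # [[2. 0.] [0. 4.] [6. 6.]], shape (3, 2)

# out_axes=1 places the batch axis second in the result.
transposed = vmap(grad(loss), in_axes=(None, 0), out_axes=1)(w, xs)
print(transposed.shape)        # (2, 3)
```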
from jax import vmap
vmap(f)(batch) # batch a per-example function
vmap(dot)(A, B) # map over the leading axis -> [3 3 3 3]
vmap(f, in_axes=(0, None)) # map arg 0 per-row, hold arg 1 fixed
vmap(f, out_axes=1) # stack results into columns
vmap(grad(f))(xs) # compose with grad: one gradient per input
# replaces: for x in batch: ... # no Python batch loop, less code, typically as fast as hand-batched code
Random Numbers (Explicit Keys)
JAX has no hidden global RNG state. Instead you create an explicit key from a seed and pass it into every sampling call, which makes randomness reproducible and parallel-safe. The golden rule is never reuse a key: split it into fresh independent keys (or fold_in a step index) before each new draw, because the same key always produces the same numbers.
key = jax.random.key(0) # create a key from a seed (reproducible)
k1, k2 = jax.random.split(key) # split into independent keys (never reuse)
jax.random.normal(key, (3,)) # standard normals -> [1.62 2.03 -0.43]
jax.random.uniform(key, (3,)) # uniform [0,1) draws
jax.random.randint(key, (3,), 0, 10) # integer draws -> [9 0 2]
jax.random.fold_in(key, step) # per-step key, stable per index
See Pseudorandom numbers.
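The fold_in pattern mentioned above can be sketched as follows. noise_for_step is a hypothetical helper written for this example; it shows that a step index deterministically selects its own key from a single root key.

```python
import jax
import jax.numpy as jnp

root = jax.random.key(0)

def noise_for_step(step):
    # Derive a per-step key from the root key and the step index.
    step_key = jax.random.fold_in(root, step)
    return jax.random.normal(step_key, (2,))

a = noise_for_step(0)
b = noise_for_step(1)
a_again = noise_for_step(0)    # same step index -> same key -> same numbers

print(jnp.array_equal(a, a_again))   # True

# split can also produce several keys at once.
keys = jax.random.split(root, 3)     # three independent keys
```

fold_in is convenient in training loops where carrying a freshly split key from step to step would be awkward.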
Control Flow & Loops
Inside a jit trace, a regular Python if or for runs at trace time rather than in the compiled program: a for loop is unrolled, and an if on an array value fails. Data-dependent control flow therefore uses the lax primitives cond, fori_loop, while_loop, and scan. scan is the workhorse for sequential models: it threads a carry through a sequence while collecting per-step outputs. Separately, lax.stop_gradient blocks a value from contributing to the backward pass.
from jax import lax
lax.cond(p, true_fn, false_fn, x) # data-dependent branch -> 11
lax.fori_loop(0, 5, body, init) # fixed-count loop (i: 0->4) -> 10
lax.while_loop(cond, body, init) # loop while a condition holds -> 12
lax.scan(body, init, xs) # carry + collect -> (carry 6, ys [0 1 3 6])
# avoid a Python `for` over a traced array: it unrolls or errors -> use scan
jax.lax.stop_gradient(x) # block a value from the backward pass
See Control flow.
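The primitives other than scan can be sketched as below. The function names and inputs are assumptions chosen for this example so that the results match the 11, 10, and 12 quoted in the snippets.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, lax

@jit
def branch(x):
    # The predicate is evaluated at run time, inside the compiled program.
    return lax.cond(x > 0, lambda v: v + 1, lambda v: v - 1, x)

@jit
def sum_indices(init):
    # The fori_loop body takes (index, carry) and returns the new carry.
    return lax.fori_loop(0, 5, lambda i, acc: acc + i, init)

@jit
def double_until(x):
    # Keep doubling while the value is below 10.
    return lax.while_loop(lambda v: v < 10, lambda v: v * 2, x)

print(branch(10))          # 11
print(sum_indices(0))      # 0 + 1 + 2 + 3 + 4 = 10
print(double_until(3))     # 3 -> 6 -> 12

def f(x):
    # The stop_gradient term is treated as a constant by grad.
    return x ** 2 + lax.stop_gradient(x ** 3)

print(grad(f)(2.0))        # 4.0, only the x**2 term contributes
```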
Pytrees & Parameters
A pytree is any nested structure of dicts, lists, and tuples whose leaves are arrays, which is exactly how you store model parameters and optimizer state. jax.tree.map applies a function to every leaf (across one or several matching trees) while preserving the structure, so a whole gradient-descent update is a single tree.map call, and a grad result is just a tree with the same structure as your parameters.
params = {"w": ..., "b": ...} # a parameter pytree (dict of arrays)
jax.tree.map(lambda x: x*2, params) # map a function over every leaf
jax.tree.leaves(params) # list just the leaves [b, w]
jax.tree.flatten(params) # -> (leaves list, treedef) round-trip
jax.tree.map(lambda p, g: p - lr*g, params, grads) # one SGD step over a tree
grad(loss)(params, batch) # grad returns a tree shaped like params
See Working with pytrees.
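The flatten round-trip can be sketched with a small hypothetical parameter dict. The example assumes a JAX version that provides the jax.tree namespace, as the snippets above already do.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}

# Dict keys are visited in sorted order, so the leaves come out as [b, w].
leaves, treedef = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef, leaves)    # same structure as params

doubled = jax.tree.map(lambda x: x * 2, params)
shapes = jax.tree.map(lambda x: x.shape, params)
print(shapes)    # {'b': (), 'w': (2,)}
```

The treedef records only the container structure, which is what lets transforms such as grad and jit accept and return whole parameter trees.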
Quick Reference
| Transform | Turns f into | Daily use |
|---|---|---|
| jax.grad(f) | a function returning ∇f | gradients for optimization |
| jax.value_and_grad(f) | a function returning (f, ∇f) | loss + gradient in one pass |
| jax.jit(f) | an XLA-compiled f | speed; fuse and run on accelerator |
| jax.vmap(f) | a batched f | add a batch axis, drop the loop |
| jax.jacobian(f) / jax.hessian(f) | full Jacobian / Hessian | sensitivity, second-order methods |

| Command | What it does | Area |
|---|---|---|
| jnp.array(...) / jnp.arange | Build a device array | Arrays |
| jax.devices() | List CPU/GPU/TPU | Arrays |
| jax.device_put(x, dev) | Place data on a device | Arrays |
| x.block_until_ready() | Wait for async compute | Arrays |
| x.at[i].set(v) / .add(v) | Immutable indexed update | NumPy API |
| a @ b / x.sum(axis=0) | Matmul / axis reduction | NumPy API |
| grad(f)(x) | Gradient function | Autodiff |
| value_and_grad(loss)(p, b) | Loss and its gradient | Autodiff |
| jit(f) / @jit | Trace + compile with XLA | JIT |
| partial(jit, static_argnums=...) | Pin compile-time constants | JIT |
| make_jaxpr(f)(x) | Inspect the traced IR | JIT |
| vmap(f, in_axes=...) | Auto-vectorize over a batch | vmap |
| jax.random.key(0) | Make an RNG key | Random |
| jax.random.split(key) | Fork independent keys | Random |
| jax.random.normal(key, shape) | Sample from a distribution | Random |
| lax.scan(body, init, xs) | Carry + collect over a sequence | Control flow |
| lax.cond / fori_loop / while_loop | Traceable branch / loops | Control flow |
| lax.stop_gradient(x) | Block the backward pass | Control flow |
| jax.tree.map(fn, tree) | Apply over pytree leaves | Pytrees |
| jax.tree.leaves / flatten | Inspect / round-trip a pytree | Pytrees |

| Topic | NumPy | JAX |
|---|---|---|
| Mutation | x[0] = 9 (in place) | x = x.at[0].set(9) (new array) |
| Default float | float64 | float32 (set jax_enable_x64 for 64-bit) |
| Randomness | global seed | explicit key, must be split |
| Speed primitive | vectorize in C | jit to XLA, run on GPU/TPU |
| Gradients | none built in | grad, value_and_grad |
| Loops | Python for | lax.scan / fori_loop inside jit |

Appendix: Sample Code
The transform mental model (wrap a pure function)
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
def predict(w, b, x):
    return w * x + b
def loss(w, b, xs, ys):
    preds = predict(w, b, xs)
    return jnp.mean((preds - ys) ** 2)
xs = jnp.array([1., 2., 3.])
ys = jnp.array([2., 4., 6.])
# grad: a new function computing d(loss)/d(w, b)
gw, gb = grad(loss, argnums=(0, 1))(0.0, 0.0, xs, ys)
# jit: compile the loss for fast repeated calls
fast_loss = jit(loss)
fast_loss(0.0, 0.0, xs, ys).block_until_ready()
Immutability and functional updates
import jax.numpy as jnp
x = jnp.arange(5) # [0 1 2 3 4]
y = x.at[0].set(99) # NEW array [99 1 2 3 4]
z = x.at[1].add(10) # NEW array [ 0 11 2 3 4]
# x is unchanged: [0 1 2 3 4]
A tiny gradient-descent loop (grad + jit + pytree)
import jax
import jax.numpy as jnp
from jax import grad, jit
# parameters as a pytree (dict of arrays)
params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
xs = jnp.array([1., 2., 3.])
ys = jnp.array([2., 4., 6.])
def loss(params, xs, ys):
    preds = params["w"] * xs + params["b"]
    return jnp.mean((preds - ys) ** 2)
@jit
def step(params, xs, ys, lr=0.1):
    grads = grad(loss)(params, xs, ys) # gradient pytree, same shape as params
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)
for _ in range(200):
    params = step(params, xs, ys)
# params -> approximately {"w": 2.0, "b": 0.0}
Explicit randomness: split before you draw
import jax
key = jax.random.key(0) # seed -> reproducible
k1, k2 = jax.random.split(key) # two independent keys
a = jax.random.normal(k1, (3,)) # never reuse a key
b = jax.random.uniform(k2, (3,)) # different stream
# Same key -> same numbers (deterministic):
jax.random.normal(key, (3,)) # [ 1.6226 2.0253 -0.4336]
jax.random.normal(key, (3,)) # identical to the line above
vmap: write for one, run for a batch
import jax.numpy as jnp
from jax import vmap
def dot(a, b):
    return jnp.dot(a, b) # written for a single pair of vectors
A = jnp.ones((4, 3))
B = jnp.ones((4, 3))
vmap(dot)(A, B) # [3. 3. 3. 3.], batched over the leading axis
# hold the second argument fixed:
vmap(lambda x, y: x + y, in_axes=(0, None))(jnp.arange(3), 10) # [10 11 12]
lax.scan: the carry-and-collect loop
import jax.numpy as jnp
from jax import lax
def body(carry, x):
    carry = carry + x
    return carry, carry # (new carry, per-step output)
total, prefix_sums = lax.scan(body, 0, jnp.arange(4))
# total -> 6, prefix_sums -> [0 1 3 6]
Enabling 64-bit precision (when float32 is not enough)
import jax
jax.config.update("jax_enable_x64", True) # must run before array creation
import jax.numpy as jnp
jnp.array([1.0, 2.0]).dtype # now float64 (default is float32)
References
JAX documentation
- JAX documentation home, Quickstart / how to think in JAX, and Installation
- Key concepts, Indexed update operators, and JAX - The Sharp Bits
- jax.numpy API reference, Automatic differentiation, and the autodiff cookbook
- Just-in-time compilation, Automatic vectorization, and Pseudorandom numbers
- Control flow, the jax.lax reference, Working with pytrees, and the jax.tree_util reference
Project and ecosystem