Keras Cheatsheet

A visual guide to Keras 3 covering model building, the layer catalog, compile, fit/evaluate/predict, callbacks, save and load, transfer learning, and choosing a backend.

Tags: python, keras, deep-learning, cheatsheet

Author: James Balamuta

Published: June 12, 2026

Keras is the high-level deep-learning API for building, training, and shipping neural networks. A Keras model is a directed chain (or graph) of layers that maps input tensors to output tensors, and the daily workflow is short: build a model, compile it with an optimizer, loss, and metrics, then fit, evaluate, and predict. The headline feature of Keras 3 is that it is backend-agnostic: the exact same code runs on TensorFlow, JAX, or PyTorch, selected by the KERAS_BACKEND environment variable before you import keras. Throughout, data flows as NumPy arrays (or streamed datasets), and the backend-agnostic keras.ops namespace gives you a NumPy-like math API that works everywhere. This cheatsheet walks the eight steps you meet in order, from a blank model to a fine-tuned, saved network. For the classic deep-learning training loop written by hand, see the PyTorch and scikit-learn sheets.

Complete Keras cheatsheet (light mode): eight panels covering build a model, the layer catalog, compile, fit/evaluate/predict, callbacks, save and load, transfer learning, and backends.


Build a Model

A Keras model is a directed chain (or graph) of layers that maps input tensors to output tensors. keras.Sequential is the quick path for a simple linear stack, the functional API (keras.Model(inputs, outputs)) handles branches and multiple inputs or outputs, and subclassing keras.Model gives you a fully custom call for research code. Whichever you pick, model.summary() prints the layers, their output shapes, and the parameter counts so you can sanity-check the architecture before training.

keras build panel: Sequential, Input, Dense, functional API, subclassing, summary.

Three ways to assemble layers into a model.

import keras

model = keras.Sequential([...])                      # linear stack of layers
keras.Input(shape=(784,))                            # declare the input shape
keras.layers.Dense(128, activation="relu")           # one dense hidden layer
outputs = keras.layers.Dense(10)(x)                  # functional API: call a layer on a tensor
class Net(keras.Model): ...                          # subclass for a custom call()
model.summary()                                       # print layers, shapes, param counts

See the Sequential model guide and the Functional API guide.
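
To make the "branches and multiple inputs" point concrete, here is a minimal sketch of a two-input functional model. The input names, feature sizes, and layer widths are illustrative assumptions and do not come from the cheatsheet panels.

```python
import keras
import numpy as np

# Two separate inputs (names and sizes are hypothetical, chosen for illustration)
numeric_in = keras.Input(shape=(8,), name="numeric")
text_in = keras.Input(shape=(16,), name="text_features")

# Each input gets its own branch
a = keras.layers.Dense(16, activation="relu")(numeric_in)
b = keras.layers.Dense(16, activation="relu")(text_in)

# Merge the branches and attach a single linear output
merged = keras.layers.Concatenate()([a, b])
score = keras.layers.Dense(1, name="score")(merged)

two_input_model = keras.Model(inputs=[numeric_in, text_in], outputs=score)
two_input_model.summary()

# Inputs are passed as a list, in the same order as `inputs=`
batch = [np.zeros((4, 8), dtype="float32"), np.zeros((4, 16), dtype="float32")]
preds = two_input_model.predict(batch, verbose=0)
print(preds.shape)
```

A Sequential model cannot express this shape of graph, which is usually the signal to move to the functional API.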

The Layer Catalog

Layers are the reusable building blocks: each one holds its own weights and transforms a tensor, and you compose them like LEGO. Dense is the workhorse for tabular data and classifier heads, Conv2D and MaxPooling2D extract spatial features from images, Embedding and LSTM handle sequences and text, and utility layers like Flatten, Dropout, and Rescaling reshape, regularize, and normalize data inside the model graph.

keras layers panel: Dense, Conv2D, MaxPooling2D, Flatten, Dropout, Embedding, LSTM, Rescaling.

The building blocks you wire together.

import keras

keras.layers.Dense(64, activation="relu")     # fully connected layer
keras.layers.Conv2D(32, 3, activation="relu") # 2D convolution for images
keras.layers.MaxPooling2D(2)                   # downsample feature maps
keras.layers.Flatten()                         # unroll a grid into a 1D vector
keras.layers.Dropout(0.5)                      # regularize by dropping units
keras.layers.Embedding(10000, 64)             # token ids -> dense vectors
keras.layers.LSTM(64)                          # recurrent layer for sequences
keras.layers.Rescaling(1./255)                # scale inputs inside the graph

See the Layers API reference.
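
As a rough illustration of how the catalog layers compose, the sketch below stacks several of them into a small image classifier. The 28x28 grayscale input size and the random stand-in images are assumptions made only so the block runs on its own.

```python
import keras
import numpy as np

cnn = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),                 # assumed: 28x28 grayscale images
    keras.layers.Rescaling(1.0 / 255),              # 0-255 pixels -> 0-1
    keras.layers.Conv2D(32, 3, activation="relu"),  # -> (26, 26, 32)
    keras.layers.MaxPooling2D(2),                   # -> (13, 13, 32)
    keras.layers.Flatten(),                         # -> 5408 values per image
    keras.layers.Dropout(0.5),                      # only active during training
    keras.layers.Dense(10, activation="softmax"),   # 10 class probabilities
])
cnn.summary()

# Random stand-in images, just to check the shapes end to end
fake_images = np.random.randint(0, 256, size=(2, 28, 28, 1)).astype("float32")
probs = cnn.predict(fake_images, verbose=0)
print(probs.shape)
```

The shape comments assume the default channels-last layout and the default "valid" padding of Conv2D.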

Compile the Model

compile attaches the three pieces training needs: an optimizer (how to update weights, usually adam), a loss (the single number to minimize, matched to your task), and zero or more metrics (human-readable scores reported each epoch). Match the loss to the labels: sparse_categorical_crossentropy for integer multiclass labels, binary_crossentropy for two-class problems, and mse for regression. Switch to the object form when you need options like from_logits=True.

keras compile panel: compile, Adam, sparse_categorical_crossentropy, binary_crossentropy, metrics, from_logits.

Choose the optimizer, loss, and metrics before training.

import keras

model.compile(optimizer="adam",                       # wire up training
              loss="sparse_categorical_crossentropy", # multiclass loss (integer labels)
              metrics=["accuracy"])                    # extra metric reported per epoch
keras.optimizers.Adam(learning_rate=1e-3)             # pick an optimizer + learning rate
loss = "binary_crossentropy"                          # two-class (binary) classification loss
loss = "mse"                                          # regression: mean squared error loss
metrics = ["accuracy", keras.metrics.AUC()]           # track several metrics at once
keras.losses.SparseCategoricalCrossentropy(from_logits=True)  # object form, raw logits

See the model training APIs.

Train, Evaluate, Predict

fit runs the training loop for you, iterating over the data in mini-batches for a number of epochs and returning a History object whose history dict holds the per-epoch losses and metrics. Pass validation_split (or a separate validation set via validation_data) to watch for overfitting. After training, evaluate reports the loss and metrics on held-out test data, and predict runs a forward pass on new inputs to produce raw outputs such as class probabilities. All three accept NumPy arrays or a streamed dataset such as a tf.data.Dataset; validation_split itself only works with array inputs, so pass validation_data when you train from a dataset.

keras fit panel: fit, validation_split, history, evaluate, predict, tf.data dataset.

The fit/evaluate/predict loop over your data.

import keras

model.fit(x, y, epochs=10, batch_size=32)     # train on NumPy arrays
model.fit(..., validation_split=0.2)          # hold out 20% for validation
history = model.fit(...)                        # capture per-epoch loss/metrics
model.evaluate(x_test, y_test)                 # score on held-out test data
preds = model.predict(x_new)                   # forward pass on new inputs
model.fit(train_ds, epochs=10)                 # feed a streamed tf.data pipeline

See training and evaluation with the built-in methods.
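
To show what the returned History actually contains, here is a small sketch on synthetic regression data. The data, model, and variable names are assumptions made for illustration; only fit, evaluate, and predict come from the panel above.

```python
import keras
import numpy as np

# Synthetic data: y is a fixed linear function of 4 features
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 4)).astype("float32")
true_w = np.array([1.0, -2.0, 0.5, 3.0], dtype="float32")
y = (x @ true_w).reshape(-1, 1)

reg = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
reg.compile(optimizer="adam", loss="mse", metrics=["mae"])

history = reg.fit(x, y, epochs=3, batch_size=32,
                  validation_split=0.2, verbose=0)

# history.history maps each metric name to a list with one value per epoch
print(sorted(history.history))
print(history.history["val_loss"])

scores = reg.evaluate(x, y, return_dict=True, verbose=0)  # dict of loss + metrics
preds = reg.predict(x[:5], verbose=0)                     # NumPy array of outputs
print(scores, preds.shape)
```

Comparing the "loss" and "val_loss" lists epoch by epoch is the usual first check for overfitting.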

Callbacks

Callbacks are objects that Keras invokes at the start and end of each batch and epoch so you can observe and steer training without rewriting the loop. The daily set is EarlyStopping (halt when the monitored validation metric stops improving, and roll back to the best epoch's weights if you set restore_best_weights=True), ModelCheckpoint (save the best model to disk), ReduceLROnPlateau (cut the learning rate when progress stalls), and loggers like TensorBoard and CSVLogger. You collect them in a list and hand it to fit(..., callbacks=[...]).

keras callbacks panel: EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard, CSVLogger, callbacks list.

Hooks that watch and steer training between epochs.

import keras

keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)   # stop when val stalls
keras.callbacks.ModelCheckpoint("best.keras", save_best_only=True)     # save the best model
keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2)              # drop LR on a plateau
keras.callbacks.TensorBoard(log_dir="logs")                            # log to TensorBoard
keras.callbacks.CSVLogger("training.csv")                             # append metrics to a CSV
model.fit(..., callbacks=[es, ckpt, rlrop])                            # pass the list to fit

See the Callbacks API reference.
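
The snippet above passes es, ckpt, and rlrop without defining them. A minimal sketch of one way to wire them up follows, assuming a toy binary classifier on random data and a temporary directory for the checkpoint; the data, model, and paths are hypothetical.

```python
import os
import tempfile

import keras
import numpy as np

# Toy binary task: label is 1 when the features sum to a positive number
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 10)).astype("float32")
y = (x.sum(axis=1) > 0).astype("float32").reshape(-1, 1)

model = keras.Sequential([
    keras.Input(shape=(10,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

ckpt_path = os.path.join(tempfile.mkdtemp(), "best.keras")

# All three monitor val_loss, so fit() needs validation data
es = keras.callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                   restore_best_weights=True)
ckpt = keras.callbacks.ModelCheckpoint(ckpt_path, monitor="val_loss",
                                       save_best_only=True)
rlrop = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5,
                                          patience=2)

history = model.fit(x, y, epochs=5, validation_split=0.2,
                    callbacks=[es, ckpt, rlrop], verbose=0)
print(len(history.history["loss"]), os.path.exists(ckpt_path))
```

With only five epochs, early stopping will rarely trigger here; the point is the wiring, not the result.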

Save & Load

Saving the whole model to the modern .keras file captures everything needed to resume: the architecture, the trained weights, and the compile configuration. keras.models.load_model brings it back ready to predict or keep training. When you only want the learned numbers, save_weights writes a .weights.h5 file and load_weights reads it back into a model with the same architecture. to_json returns just the structure as a JSON string, and model.export(...) writes an inference-only artifact (a TensorFlow SavedModel by default) for production serving.

keras save and load panel: save, load_model, save_weights, load_weights, to_json, export.

Persist a model, its weights, or its config to disk.

import keras

model.save("model.keras")                              # whole model (arch + weights + compile)
model = keras.models.load_model("model.keras")         # load it back, ready to use
model.save_weights("m.weights.h5")                     # save the learned weights only
model.load_weights("m.weights.h5")                     # restore weights into a model
json_str = model.to_json()                             # architecture as JSON (no weights)
model.export("served_model")                           # export a SavedModel for serving

See serialization and saving.
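
Because to_json stores only the structure, it is typically paired with a weights file. The sketch below rebuilds a model from its JSON string and then loads the saved weights into it. It relies on keras.models.model_from_json, which the panel does not mention, and uses a tiny hypothetical model and a temporary path.

```python
import os
import tempfile

import keras
import numpy as np

original = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(2),
])

weights_path = os.path.join(tempfile.mkdtemp(), "m.weights.h5")
json_str = original.to_json()          # architecture only, as a string
original.save_weights(weights_path)    # learned numbers only

# Rebuild the structure (weights are freshly initialized), then restore the weights
rebuilt = keras.models.model_from_json(json_str)
rebuilt.load_weights(weights_path)

x = np.ones((3, 4), dtype="float32")
same = np.allclose(original.predict(x, verbose=0),
                   rebuilt.predict(x, verbose=0))
print(same)
```

For most projects a single model.save("model.keras") is simpler; the split form is mainly useful when the architecture and the weights are stored or versioned separately.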

Transfer Learning

Instead of training from scratch, you load a network already trained on a huge dataset from keras.applications, drop its original classifier with include_top=False, and attach a small new head for your classes. The recipe is to freeze the backbone (base.trainable = False) and train only the head, then optionally unfreeze and fine-tune the whole thing at a very small learning rate, recompiling after you change trainable so the change takes effect. Always run inputs through the model’s matching preprocess_input so they are scaled the way the backbone expects.

keras transfer learning panel: MobileNetV2, freeze, head, preprocess_input, train head, unfreeze and fine-tune.

Reuse a pretrained network and fine-tune a small head.

import keras

base = keras.applications.MobileNetV2(weights="imagenet", include_top=False)  # pretrained backbone
base.trainable = False                                  # freeze the backbone
keras.layers.GlobalAveragePooling2D()                   # add your own head: pool
keras.layers.Dense(n, activation="softmax")             # add your own head: classify
keras.applications.mobilenet_v2.preprocess_input(x)     # match the model's input range
model.compile(...); model.fit(train_ds, epochs=5)       # train just the new head
base.trainable = True                                   # unfreeze, then recompile with Adam(1e-5)

See the transfer learning and fine-tuning guide.

Backends & Config

Keras 3 is a single high-level API that runs on top of TensorFlow, JAX, or PyTorch, chosen by setting the KERAS_BACKEND environment variable before you import keras. Write your model once against keras.layers and the backend-agnostic keras.ops namespace (a NumPy-like API), confirm the engine with keras.backend.backend(), lock reproducibility with keras.utils.set_random_seed, and reach for keras.datasets (MNIST, CIFAR, IMDB) when you just need data to experiment.

keras backend panel: KERAS_BACKEND, backend(), keras.ops, set_random_seed, set_floatx, datasets.

One Keras 3 API, your choice of TensorFlow, JAX, or PyTorch.

import os
os.environ["KERAS_BACKEND"] = "jax"           # pick the backend (before importing keras)
import keras

keras.backend.backend()                        # confirm the active backend -> 'jax'
keras.ops.matmul(a, b)                         # backend-agnostic tensor math (NumPy-like)
keras.ops.relu(x)                              # same ops on every backend
keras.utils.set_random_seed(42)                # make a run reproducible
keras.config.set_floatx("float32")             # set default float precision
keras.datasets.mnist.load_data()               # built-in datasets to start

See getting started with Keras 3.

Quick Reference

Key Keras 3 operations.

| Command | What it does | Area |
| --- | --- | --- |
| keras.Sequential([...]) | Linear stack of layers | Build |
| keras.Input(shape=...) | Declare the input tensor shape | Build |
| keras.Model(inputs, outputs) | Functional API model (a graph) | Build |
| model.summary() | Print layers, shapes, param counts | Build |
| keras.layers.Dense(n, activation=...) | Fully connected layer | Layers |
| keras.layers.Conv2D(f, k) | 2D convolution for images | Layers |
| model.compile(optimizer, loss, metrics) | Configure training | Compile |
| model.fit(x, y, epochs=, batch_size=) | Train the model | Train |
| model.evaluate(x_test, y_test) | Score on held-out data | Train |
| model.predict(x_new) | Forward pass on new inputs | Predict |
| keras.callbacks.EarlyStopping(...) | Stop when val stops improving | Callbacks |
| keras.callbacks.ModelCheckpoint(...) | Save the best model | Callbacks |
| model.save("model.keras") | Save whole model to disk | Save |
| keras.models.load_model("model.keras") | Load a saved model | Load |
| keras.applications.MobileNetV2(...) | Pretrained backbone | Transfer |
| base.trainable = False | Freeze a backbone | Transfer |
| os.environ["KERAS_BACKEND"] = "jax" | Choose TF / JAX / PyTorch | Backend |
| keras.utils.set_random_seed(42) | Reproducible runs | Backend |

Picking the loss and final layer for your problem.

| Task | Labels look like | loss= | Last layer activation |
| --- | --- | --- | --- |
| Binary classification | 0 / 1 | "binary_crossentropy" | sigmoid (1 unit) |
| Multiclass (integer labels) | 3, 7, 1 | "sparse_categorical_crossentropy" | softmax (n units) |
| Multiclass (one-hot labels) | [0,0,1,0] | "categorical_crossentropy" | softmax (n units) |
| Regression | continuous numbers | "mse" (or "mae") | none / linear |

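
The two multiclass rows differ only in how the labels are encoded. As a small check, the sketch below converts integer labels to one-hot with keras.utils.to_categorical and confirms that both losses give the same value on the same predictions. The labels and probabilities are made-up numbers.

```python
import keras
import numpy as np

int_labels = np.array([2, 0, 1])                                 # integer class ids
onehot_labels = keras.utils.to_categorical(int_labels, num_classes=3)
print(onehot_labels)   # one row per sample, a single 1 in the label's column

# Made-up predicted probabilities for 3 samples and 3 classes
probs = np.array([[0.1, 0.2, 0.7],
                  [0.8, 0.1, 0.1],
                  [0.3, 0.4, 0.3]], dtype="float32")

sparse_loss = float(keras.losses.SparseCategoricalCrossentropy()(int_labels, probs))
onehot_loss = float(keras.losses.CategoricalCrossentropy()(onehot_labels, probs))
print(sparse_loss, onehot_loss)   # the two values should agree closely
```

Integer labels with the sparse loss avoid building the one-hot matrix, which tends to matter when the number of classes is large.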
How the different save calls differ.

| Call | Writes | Contains |
| --- | --- | --- |
| model.save("m.keras") | a single .keras zip | architecture + weights + compile config |
| model.save_weights("m.weights.h5") | a .weights.h5 file | learned weights only |
| model.to_json() | a JSON string | architecture only (no weights) |
| model.export("dir") | a SavedModel folder | inference graph for serving |

Appendix: Sample Code

The full workflow in one block (MNIST)

The canonical “hello world” of Keras: build, compile, fit, evaluate, predict.

import keras
import numpy as np

# 1. Data: 28x28 grayscale digits, 10 classes
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test  = x_test.reshape(-1, 784).astype("float32") / 255.0

# 2. Build a simple classifier
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(128, activation="relu"),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation="softmax"),
])
model.summary()

# 3. Compile: optimizer + loss + metric
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",   # integer labels
    metrics=["accuracy"],
)

# 4. Train, with a validation split and early stopping
history = model.fit(
    x_train, y_train,
    epochs=10, batch_size=128, validation_split=0.1,
    callbacks=[keras.callbacks.EarlyStopping(patience=2,
                                             restore_best_weights=True)],
)

# 5. Evaluate and predict
loss, acc = model.evaluate(x_test, y_test)
probs = model.predict(x_test[:5])
print(probs.argmax(axis=1))   # predicted classes for 5 digits

The three ways to build a model

import keras

# (a) Sequential: a simple linear stack
seq = keras.Sequential([
    keras.Input(shape=(32,)),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(1),
])

# (b) Functional API: graphs, branches, multiple inputs/outputs
inputs = keras.Input(shape=(32,))
x = keras.layers.Dense(64, activation="relu")(inputs)
outputs = keras.layers.Dense(1)(x)
func = keras.Model(inputs, outputs)

# (c) Subclassing: a custom forward pass for research code
class MLP(keras.Model):
    def __init__(self):
        super().__init__()
        self.h = keras.layers.Dense(64, activation="relu")
        self.out = keras.layers.Dense(1)
    def call(self, x):
        return self.out(self.h(x))

sub = MLP()

Save and reload round-trip

import keras

# Save EVERYTHING (architecture + weights + compile state)
model.save("model.keras")
restored = keras.models.load_model("model.keras")   # ready to predict or resume

# Save only the learned weights
model.save_weights("model.weights.h5")
restored.load_weights("model.weights.h5")

# Export a SavedModel for serving (TensorFlow Serving, etc.)
model.export("served_model")

Transfer learning skeleton

import keras

# 1. Pretrained backbone, classifier head removed
base = keras.applications.MobileNetV2(
    input_shape=(160, 160, 3),
    include_top=False,
    weights="imagenet",
)
base.trainable = False           # freeze it

# 2. Attach a small head for YOUR classes
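model = keras.Sequential([
    keras.Input(shape=(160, 160, 3)),
    # MobileNetV2 expects inputs in [-1, 1]; this assumes train_ds yields raw
    # 0-255 pixels (drop this layer if your pipeline already applies preprocess_input)
    keras.layers.Rescaling(1.0 / 127.5, offset=-1),
    base,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(5, activation="softmax"),
])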
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(train_ds, epochs=5)          # train the head only

# 3. Optional: unfreeze and fine-tune at a tiny learning rate
base.trainable = True
model.compile(optimizer=keras.optimizers.Adam(1e-5),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(train_ds, epochs=5)          # fine-tune end to end

Choosing a backend (set before importing keras)

import os
os.environ["KERAS_BACKEND"] = "jax"   # or "tensorflow" or "torch"

import keras
print(keras.backend.backend())        # -> 'jax'

# The same model code now runs on JAX. Backend-agnostic math:
import keras.ops as ops
ops.matmul(ops.ones((2, 3)), ops.ones((3, 4))).shape   # (2, 4)

keras.utils.set_random_seed(42)       # reproducible across the stack

References

Keras documentation

Project