Keras is the high-level deep-learning API for building, training, and shipping neural networks. A Keras model is a directed chain (or graph) of layers that maps input tensors to output tensors, and the daily workflow is short: build a model, compile it with an optimizer, loss, and metrics, then fit, evaluate, and predict. The headline feature of Keras 3 is that it is backend-agnostic: the exact same code runs on TensorFlow, JAX, or PyTorch, selected by the KERAS_BACKEND environment variable before you import keras. Throughout, data flows as NumPy arrays (or streamed datasets), and the backend-agnostic keras.ops namespace gives you a NumPy-like math API that works everywhere. This cheatsheet walks the eight steps you meet in order, from a blank model to a fine-tuned, saved network. For the classic deep-learning training loop written by hand, see the PyTorch and scikit-learn sheets.
Build a Model
A Keras model is a directed chain (or graph) of layers that maps input tensors to output tensors. keras.Sequential is the quick path for a simple linear stack, the functional API (keras.Model(inputs, outputs)) handles branches and multiple inputs or outputs, and subclassing keras.Model gives you a fully custom call for research code. Whichever you pick, model.summary() prints the layers, their output shapes, and the parameter counts so you can sanity-check the architecture before training.
import keras
model = keras.Sequential([...]) # linear stack of layers
keras.Input(shape=(784,)) # declare the input shape
keras.layers.Dense(128, activation="relu") # one dense hidden layer
outputs = keras.layers.Dense(10)(x) # functional API: call a layer on a tensor
class Net(keras.Model): ... # subclass for a custom call()
model.summary() # print layers, shapes, param counts
See the Sequential model guide and the Functional API guide.
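The functional API's support for branches and multiple inputs, mentioned above, can be sketched as follows. This is a minimal illustration rather than a recommended architecture: the input names (`numeric`, `aux`), the feature sizes, and the layer widths are all assumptions made up for the example.

```python
import keras

# Two hypothetical inputs (names and sizes are illustrative assumptions).
num_in = keras.Input(shape=(16,), name="numeric")
aux_in = keras.Input(shape=(8,), name="aux")

# Each branch gets its own hidden layer, then the branches are merged.
a = keras.layers.Dense(32, activation="relu")(num_in)
b = keras.layers.Dense(32, activation="relu")(aux_in)
merged = keras.layers.Concatenate()([a, b])
outputs = keras.layers.Dense(1, name="score")(merged)

# A functional model is defined by its input and output tensors.
model = keras.Model(inputs=[num_in, aux_in], outputs=outputs)
model.summary()  # sanity-check shapes and parameter counts before training
```

A Sequential model cannot express this shape because it has two entry points; the functional form only needs the tensors at both ends of the graph.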
The Layer Catalog
Layers are the reusable building blocks: each one holds its own weights and transforms a tensor, and you compose them like LEGO. Dense is the workhorse for tabular data and classifier heads, Conv2D and MaxPooling2D extract spatial features from images, Embedding and LSTM handle sequences and text, and utility layers like Flatten, Dropout, and Rescaling reshape, regularize, and normalize data inside the model graph.
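To make the composition concrete, here is a small sketch that chains several of these layers into an image classifier and pushes a random batch through it. The 28x28 grayscale input, the filter count, and the 10-class head are assumptions chosen for illustration, and the model is untrained, so the outputs are meaningless beyond their shape.

```python
import numpy as np
import keras

model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),                 # assumed input: 28x28 grayscale
    keras.layers.Rescaling(1.0 / 255),              # normalize pixels inside the graph
    keras.layers.Conv2D(8, 3, activation="relu"),   # -> (26, 26, 8) with default "valid" padding
    keras.layers.MaxPooling2D(2),                   # -> (13, 13, 8)
    keras.layers.Flatten(),                         # -> (1352,)
    keras.layers.Dropout(0.5),                      # only active during training
    keras.layers.Dense(10, activation="softmax"),   # classifier head
])

# A random batch of 4 fake images, just to exercise the forward pass.
batch = np.random.randint(0, 256, size=(4, 28, 28, 1)).astype("float32")
probs = model.predict(batch, verbose=0)
print(probs.shape)
```

Each row of `probs` sums to 1 because the last layer is a softmax; Dropout is skipped at predict time, so repeated calls on the same batch return the same values.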
import keras
keras.layers.Dense(64, activation="relu") # fully connected layer
keras.layers.Conv2D(32, 3, activation="relu") # 2D convolution for images
keras.layers.MaxPooling2D(2) # downsample feature maps
keras.layers.Flatten() # unroll a grid into a 1D vector
keras.layers.Dropout(0.5) # regularize by dropping units
keras.layers.Embedding(10000, 64) # token ids -> dense vectors
keras.layers.LSTM(64) # recurrent layer for sequences
keras.layers.Rescaling(1./255) # scale inputs inside the graph
Compile the Model
compile attaches the three pieces training needs: an optimizer (how to update weights, usually adam), a loss (the single number to minimize, matched to your task), and zero or more metrics (human-readable scores reported each epoch). Match the loss to the labels: sparse_categorical_crossentropy for integer multiclass labels, binary_crossentropy for two-class problems, and mse for regression. Switch to the object form when you need options like from_logits=True.
import keras
model.compile(optimizer="adam", # wire up training
              loss="sparse_categorical_crossentropy", # multiclass loss (integer labels)
              metrics=["accuracy"]) # extra metric reported per epoch
keras.optimizers.Adam(learning_rate=1e-3) # pick an optimizer + learning rate
loss = "binary_crossentropy" # binary (two-class) classification loss
loss = "mse" # regression: mean squared error loss
metrics = ["accuracy", keras.metrics.AUC()] # track several metrics at once
keras.losses.SparseCategoricalCrossentropy(from_logits=True) # object form, raw logits
Train, Evaluate, Predict
fit runs the training loop for you, iterating over the data in mini-batches for a number of epochs and returning a History object whose history attribute holds the per-epoch losses and metrics. Pass validation_split (or a separate validation set via validation_data) to watch for overfitting. After training, evaluate reports the loss and metrics on held-out test data, and predict runs a forward pass on new inputs to produce the model's raw outputs, such as class probabilities. All three accept NumPy arrays or streamed tf.data datasets; note that validation_split works only with array inputs, so pass validation_data when you train from a dataset.
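The fit, evaluate, predict cycle can be sketched end to end on synthetic data. Everything about the data here is an assumption for illustration: 500 random 4-feature rows with a label that is 1 when the features sum to a positive number. No dataset is downloaded.

```python
import numpy as np
import keras

keras.utils.set_random_seed(0)
rng = np.random.default_rng(0)

# Synthetic binary task (illustrative only): label is 1 when the row sum is positive.
x = rng.normal(size=(500, 4)).astype("float32")
y = (x.sum(axis=1) > 0).astype("int32")
x_train, y_train = x[:400], y[:400]
x_test, y_test = x[400:], y[400:]

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# fit returns a History; .history maps each metric name to a per-epoch list.
history = model.fit(x_train, y_train, epochs=5, batch_size=32,
                    validation_split=0.2, verbose=0)
print(sorted(history.history))

# evaluate returns [loss, *metrics]; predict returns the raw sigmoid outputs.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
preds = model.predict(x_test[:3], verbose=0)
print(test_loss, test_acc, preds.ravel())
```

Comparing `history.history["loss"]` with `history.history["val_loss"]` epoch by epoch is the usual first check for overfitting.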
import keras
model.fit(x, y, epochs=10, batch_size=32) # train on NumPy arrays
model.fit(..., validation_split=0.2) # hold out 20% for validation
history = model.fit(...) # capture per-epoch loss/metrics
model.evaluate(x_test, y_test) # score on held-out test data
preds = model.predict(x_new) # forward pass on new inputs
model.fit(train_ds, epochs=10) # feed a streamed tf.data pipeline
Callbacks
Callbacks are objects that Keras invokes at the start and end of each batch and epoch so you can observe and steer training without rewriting the loop. The daily set is EarlyStopping (halt when the monitored validation metric stops improving and, with restore_best_weights=True, rewind to the weights of the best epoch), ModelCheckpoint (save the best model to disk), ReduceLROnPlateau (cut the learning rate when progress stalls), and loggers like TensorBoard and CSVLogger. You collect them in a list and hand it to fit(..., callbacks=[...]).
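A minimal sketch of the callback list in action, assuming a toy model and pure-noise labels so that validation loss has little room to improve. The file names and the temporary directory are illustrative choices, and the exact epoch at which training halts is not guaranteed.

```python
import os
import tempfile
import numpy as np
import keras

keras.utils.set_random_seed(0)
rng = np.random.default_rng(0)

# Pure-noise labels (an assumption for the demo): nothing real to learn.
x = rng.normal(size=(300, 6)).astype("float32")
y = rng.integers(0, 2, size=(300,)).astype("int32")

model = keras.Sequential([
    keras.Input(shape=(6,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

workdir = tempfile.mkdtemp()                       # scratch directory for outputs
ckpt_path = os.path.join(workdir, "best.keras")
csv_path = os.path.join(workdir, "training.csv")

callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                  restore_best_weights=True),
    keras.callbacks.ModelCheckpoint(ckpt_path, monitor="val_loss",
                                    save_best_only=True),
    keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
    keras.callbacks.CSVLogger(csv_path),
]

# epochs=50 is only an upper bound; EarlyStopping may end training sooner.
history = model.fit(x, y, epochs=50, validation_split=0.2,
                    callbacks=callbacks, verbose=0)
epochs_run = len(history.history["loss"])
print("epochs actually run:", epochs_run)
```

All three steering callbacks monitor `val_loss` here, so `fit` needs validation data (via `validation_split` or `validation_data`) for them to have anything to watch.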
import keras
keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True) # stop when val stalls
keras.callbacks.ModelCheckpoint("best.keras", save_best_only=True) # save the best model
keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2) # drop LR on a plateau
keras.callbacks.TensorBoard(log_dir="logs") # log to TensorBoard
keras.callbacks.CSVLogger("training.csv") # append metrics to a CSV
model.fit(..., callbacks=[es, ckpt, rlrop]) # pass the list to fit
Save & Load
Saving the whole model to the modern .keras file captures everything needed to resume: the architecture, the trained weights, and the compile configuration. keras.models.load_model brings it back ready to predict or keep training. When you only want the learned numbers, save_weights and load_weights write a .weights.h5 file, to_json serializes just the structure, and model.export(...) writes a framework-level SavedModel for production serving.
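The difference between the whole-model file and the weights-only file can be checked with a round trip. This sketch uses a throwaway model and a temporary directory (both assumptions for the demo) and skips `model.export`, which depends on the active backend.

```python
import os
import tempfile
import numpy as np
import keras

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

x = np.random.rand(5, 4).astype("float32")
reference = model.predict(x, verbose=0)
workdir = tempfile.mkdtemp()

# 1. Whole model: architecture + weights + compile configuration.
full_path = os.path.join(workdir, "model.keras")
model.save(full_path)
restored = keras.models.load_model(full_path)
whole_model_matches = np.allclose(reference, restored.predict(x, verbose=0), atol=1e-6)

# 2. Weights only: rebuild the architecture from JSON, then load the weights.
weights_path = os.path.join(workdir, "m.weights.h5")
model.save_weights(weights_path)
clone = keras.models.model_from_json(model.to_json())  # same structure, fresh weights
clone.load_weights(weights_path)
weights_only_matches = np.allclose(reference, clone.predict(x, verbose=0), atol=1e-6)

print(whole_model_matches, weights_only_matches)
```

The clone built from JSON has no compile configuration, so it can predict immediately but needs its own `compile` call before further training.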
import keras
model.save("model.keras") # whole model (arch + weights + compile)
model = keras.models.load_model("model.keras") # load it back, ready to use
model.save_weights("m.weights.h5") # save the learned weights only
model.load_weights("m.weights.h5") # restore weights into a model
json_str = model.to_json() # architecture as JSON (no weights)
model.export("served_model") # export a SavedModel for serving
Transfer Learning
Instead of training from scratch, you load a network already trained on a huge dataset from keras.applications, drop its original classifier with include_top=False, and attach a small new head for your classes. The recipe is to freeze the backbone (base.trainable = False) and train only the head, then optionally unfreeze and fine-tune the whole thing at a very small learning rate. Always run inputs through the model’s matching preprocess_input so they are scaled the way the backbone expects.
import keras
base = keras.applications.MobileNetV2(weights="imagenet", include_top=False) # pretrained backbone
base.trainable = False # freeze the backbone
keras.layers.GlobalAveragePooling2D() # add your own head: pool
keras.layers.Dense(n, activation="softmax") # add your own head: classify
keras.applications.mobilenet_v2.preprocess_input(x) # match the model's input range
model.compile(...); model.fit(train_ds, epochs=5) # train just the new head
base.trainable = True # unfreeze, then recompile with Adam(1e-5)
Backends & Config
Keras 3 is a single high-level API that runs on top of TensorFlow, JAX, or PyTorch, chosen by setting the KERAS_BACKEND environment variable before you import keras. Write your model once against keras.layers and the backend-agnostic keras.ops namespace (a NumPy-like API), confirm the engine with keras.backend.backend(), lock reproducibility with keras.utils.set_random_seed, and reach for keras.datasets (MNIST, CIFAR, IMDB) when you just need data to experiment.
import os
os.environ["KERAS_BACKEND"] = "jax" # pick the backend (before importing keras)
import keras
keras.backend.backend() # confirm the active backend -> 'jax'
keras.ops.matmul(a, b) # backend-agnostic tensor math (NumPy-like)
keras.ops.relu(x) # same ops on every backend
keras.utils.set_random_seed(42) # make a run reproducible
keras.config.set_floatx("float32") # set default float precision
keras.datasets.mnist.load_data() # built-in datasets to start
Quick Reference
| Command | What it does | Area |
|---|---|---|
| keras.Sequential([...]) | Linear stack of layers | Build |
| keras.Input(shape=...) | Declare the input tensor shape | Build |
| keras.Model(inputs, outputs) | Functional API model (a graph) | Build |
| model.summary() | Print layers, shapes, param counts | Build |
| keras.layers.Dense(n, activation=...) | Fully connected layer | Layers |
| keras.layers.Conv2D(f, k) | 2D convolution for images | Layers |
| model.compile(optimizer, loss, metrics) | Configure training | Compile |
| model.fit(x, y, epochs=, batch_size=) | Train the model | Train |
| model.evaluate(x_test, y_test) | Score on held-out data | Train |
| model.predict(x_new) | Forward pass on new inputs | Predict |
| keras.callbacks.EarlyStopping(...) | Stop when val stops improving | Callbacks |
| keras.callbacks.ModelCheckpoint(...) | Save the best model | Callbacks |
| model.save("model.keras") | Save whole model to disk | Save |
| keras.models.load_model("model.keras") | Load a saved model | Load |
| keras.applications.MobileNetV2(...) | Pretrained backbone | Transfer |
| base.trainable = False | Freeze a backbone | Transfer |
| os.environ["KERAS_BACKEND"] = "jax" | Choose TF / JAX / PyTorch | Backend |
| keras.utils.set_random_seed(42) | Reproducible runs | Backend |
| Task | Labels look like | loss= | Last layer activation |
|---|---|---|---|
| Binary classification | 0 / 1 | "binary_crossentropy" | sigmoid (1 unit) |
| Multiclass (integer labels) | 3, 7, 1 | "sparse_categorical_crossentropy" | softmax (n units) |
| Multiclass (one-hot labels) | [0,0,1,0] | "categorical_crossentropy" | softmax (n units) |
| Regression | continuous numbers | "mse" (or "mae") | none / linear |
| Call | Writes | Contains |
|---|---|---|
| model.save("m.keras") | a single .keras zip | architecture + weights + compile config |
| model.save_weights("m.weights.h5") | a .weights.h5 file | learned weights only |
| model.to_json() | a JSON string | architecture only (no weights) |
| model.export("dir") | a SavedModel folder | inference graph for serving |
Appendix: Sample Code
The full workflow in one block (MNIST)
The canonical “hello world” of Keras: build, compile, fit, evaluate, predict.
import keras
import numpy as np

# 1. Data: 28x28 grayscale digits, 10 classes
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

# 2. Build a simple classifier
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(128, activation="relu"),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation="softmax"),
])
model.summary()

# 3. Compile: optimizer + loss + metric
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy", # integer labels
    metrics=["accuracy"],
)

# 4. Train, with a validation split and early stopping
history = model.fit(
    x_train, y_train,
    epochs=10, batch_size=128, validation_split=0.1,
    callbacks=[keras.callbacks.EarlyStopping(patience=2,
                                             restore_best_weights=True)],
)

# 5. Evaluate and predict
loss, acc = model.evaluate(x_test, y_test)
probs = model.predict(x_test[:5])
print(probs.argmax(axis=1)) # predicted classes for 5 digits
The three ways to build a model
import keras

# (a) Sequential: a simple linear stack
seq = keras.Sequential([
    keras.Input(shape=(32,)),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(1),
])

# (b) Functional API: graphs, branches, multiple inputs/outputs
inputs = keras.Input(shape=(32,))
x = keras.layers.Dense(64, activation="relu")(inputs)
outputs = keras.layers.Dense(1)(x)
func = keras.Model(inputs, outputs)

# (c) Subclassing: a custom forward pass for research code
class MLP(keras.Model):
    def __init__(self):
        super().__init__()
        self.h = keras.layers.Dense(64, activation="relu")
        self.out = keras.layers.Dense(1)

    def call(self, x):
        return self.out(self.h(x))

sub = MLP()
Save and reload round-trip
import keras
# Save EVERYTHING (architecture + weights + compile state)
model.save("model.keras")
restored = keras.models.load_model("model.keras") # ready to predict or resume
# Save only the learned weights
model.save_weights("model.weights.h5")
restored.load_weights("model.weights.h5")
# Export a SavedModel for serving (TensorFlow Serving, etc.)
model.export("served_model")
Transfer learning skeleton
import keras

# 1. Pretrained backbone, classifier head removed
base = keras.applications.MobileNetV2(
    input_shape=(160, 160, 3),
    include_top=False,
    weights="imagenet",
)
base.trainable = False # freeze it

# 2. Attach a small head for YOUR classes
model = keras.Sequential([
    base,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(5, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(train_ds, epochs=5) # train the head only

# 3. Optional: unfreeze and fine-tune at a tiny learning rate
base.trainable = True
model.compile(optimizer=keras.optimizers.Adam(1e-5),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(train_ds, epochs=5) # fine-tune end to end
Choosing a backend (set before importing keras)
import os
os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch"
import keras
print(keras.backend.backend()) # -> 'jax'
# The same model code now runs on JAX. Backend-agnostic math:
import keras.ops as ops
ops.matmul(ops.ones((2, 3)), ops.ones((3, 4))).shape # (2, 4)
keras.utils.set_random_seed(42) # reproducible across the stack
References
Keras documentation
- Keras documentation home, getting started, and getting started with Keras 3
- The Sequential model guide, the Functional API guide, and subclassing layers and models
- Layers API, model training APIs, and training with the built-in methods
- Optimizers, losses, metrics, and callbacks
- Serialization and saving, model saving APIs, transfer learning, and Keras Applications
- Keras ops API, backend and config utilities, Python and seeding utilities, and built-in datasets
Project