PyTorch Lightning Cheatsheet

A visual guide to PyTorch Lightning covering the LightningModule, the training, validation, and test steps, optimizers, the Trainer, logging metrics, callbacks like checkpointing and early stopping, DataModules, and saving, loading, and exporting models.

python
pytorch
lightning
deep-learning
cheatsheet

Author: James Balamuta

Published: June 10, 2026

PyTorch Lightning is a high-level training framework that organizes raw PyTorch into two pieces: a LightningModule that says what to compute, and a Trainer that owns how the loop runs. You subclass L.LightningModule, build your layers in __init__, define a forward, and fill in step hooks like training_step. Lightning then handles the engineering around it: the training loop, the zero_grad, backward, and step calls, device placement, mixed precision, and distributed training. The convention throughout this cheatsheet is import lightning as L (the modern unified package; import lightning.pytorch as L exposes the same LightningModule and Trainer, and import pytorch_lightning as pl is the legacy package name), alongside the usual import torch and from torch import nn. LitModel(L.LightningModule) is your subclass, model is an instance, and trainer is an L.Trainer(). The eight panels cover what you reach for day to day.

Complete PyTorch Lightning cheatsheet (light and dark mode): eight panels covering the LightningModule, step hooks, optimizers, the Trainer, logging, callbacks, DataModules, and save/load/export.

The LightningModule

A LightningModule is a regular torch.nn.Module with a fixed set of slots that say what to do rather than how to loop: you build your layers in __init__, define the forward pass, and Lightning fills in the engineering around it. Calling self.save_hyperparameters() records the constructor arguments so they travel inside every checkpoint and reload automatically.

LightningModule panel: subclass, __init__, save_hyperparameters, forward, instantiate.

One class that bundles model, data flow, and training logic.

import torch
from torch import nn
import lightning as L

class LitModel(L.LightningModule):           # subclass the base module
    def __init__(self, lr=1e-3):              # build layers in the constructor
        super().__init__()
        self.save_hyperparameters()           # save constructor args (lr) as hparams
        self.net = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        return self.net(x)                    # define the forward (inference) pass

model = LitModel(lr=1e-3)                      # instantiate the model

See LightningModule.

Step Hooks

Instead of writing a training loop, you implement training_step, validation_step, test_step, and predict_step, each of which receives one batch and a batch_idx and describes the work for a single batch. The training step returns the loss (Lightning calls backward and step for you); the other steps just log metrics or return predictions, and epoch-level hooks like on_train_epoch_end let you act once per epoch.

Step hooks panel: training_step, validation_step, test_step, predict_step, on_train_epoch_end, signature contract.

Lightning calls these per batch; you just return the loss.

import torch.nn.functional as F

def training_step(self, batch, batch_idx):    # train on one batch
    x, y = batch
    loss = F.cross_entropy(self(x), y)
    return loss                               # Lightning calls backward + step

def validation_step(self, batch, batch_idx):  # validate on one batch
    x, y = batch
    loss = F.cross_entropy(self(x), y)
    self.log("val_loss", loss)                # log instead of return

def test_step(self, batch, batch_idx):        # test on one batch (the test loop runs once)
    ...

def predict_step(self, batch, batch_idx):     # predict on one batch
    return self(batch)                        # assumes the predict loader yields inputs only

def on_train_epoch_end(self):                 # hook into the epoch end
    ...

See step hooks.

Optimizers & Schedulers

The single configure_optimizers method returns your optimizer (and optionally a learning-rate scheduler). By default Lightning calls zero_grad, backward, and step for you in the right order. When you need full control, as with GANs or multiple optimizers, set self.automatic_optimization = False and drive self.optimizers() yourself, calling self.manual_backward(loss) in place of loss.backward(). Gradient clipping and accumulation are configured on the Trainer, not in the module.

Optimizers panel: configure_optimizers, scheduler dict, interval/monitor, automatic loop, manual optimization, gradient clipping.

Return optimizers from one method; Lightning steps them.

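def configure_optimizers(self):               # configure one optimizer
    return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

def configure_optimizers(self):               # optimizer + LR scheduler
    opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt)
    return {"optimizer": opt,
            "lr_scheduler": {"scheduler": sched,        # step interval + monitored metric
                             "interval": "epoch", "monitor": "val_loss"}}

# loss.backward() and opt.step() are automatic by default

self.automatic_optimization = False           # in __init__: switch to manual optimization
opt = self.optimizers()                        # in training_step: fetch the optimizer
self.manual_backward(loss)                     # then drive the update yourself
opt.step()
opt.zero_grad()

L.Trainer(gradient_clip_val=1.0)               # clip gradients (set on the Trainer)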

See configure_optimizers.

The Trainer

The Trainer is the engine that owns the loop and the hardware: you set how long to train (max_epochs) and where (accelerator, devices), then call fit, validate, test, or predict. Because device placement, mixed precision, and distribution all live in the Trainer, the same LightningModule runs on CPU, one GPU, or many without code changes, and fast_dev_run=True runs a single batch through the whole pipeline to catch bugs in seconds.

Trainer panel: build trainer, pick hardware, fit, test, predict, fast_dev_run.

The engine that owns the loop and the hardware.

trainer = L.Trainer(max_epochs=10)                     # build the trainer
L.Trainer(accelerator="auto", devices="auto")          # pick hardware automatically
trainer.fit(model, train_dl, val_dl)                   # run the full fit loop
trainer.test(model, test_dl)                           # evaluate on a held-out set
preds = trainer.predict(model, dl)                     # generate predictions
L.Trainer(fast_dev_run=True)                           # smoke-test the whole pipeline

See Trainer.

Logging Metrics

Anywhere inside a step, call self.log("name", value). Lightning aggregates the value across batches, and with sync_dist=True across devices too. The on_step and on_epoch flags decide whether it records per step, per epoch, or both. You set the destination once on the Trainer with a logger (CSV, TensorBoard, Weights & Biases, MLflow). Logging a torchmetrics metric object lets the metric accumulate the correct value over an epoch instead of averaging per-batch numbers.

Logging panel: log a scalar, prog_bar, on_step/on_epoch, log_dict, pick a logger, torchmetrics object.

Call self.log; Lightning aggregates and routes it.

from lightning.pytorch.loggers import CSVLogger
import torchmetrics

self.log("train_loss", loss)                                  # log a scalar
self.log("val_acc", acc, prog_bar=True)                       # show in the progress bar
self.log("loss", loss, on_step=True, on_epoch=True)           # control step vs epoch reduce
self.log_dict({"acc": acc, "f1": f1})                         # log several at once

L.Trainer(logger=CSVLogger("logs"))                           # pick a backend logger

self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)   # in __init__
self.acc(logits, y)                                           # in a step: update the metric
self.log("acc", self.acc)                                     # log the metric object itself

See logging.

Callbacks

Callbacks inject behavior into the training loop without cluttering the model: ModelCheckpoint saves the best weights by a monitored metric, EarlyStopping halts training when it stops improving, and LearningRateMonitor logs the LR over time. You attach them as a list to Trainer(callbacks=[...]), resume an interrupted run by passing ckpt_path to fit (which restores optimizer state too), and write your own by subclassing L.Callback and filling in on_* hooks.

Callbacks panel: ModelCheckpoint, EarlyStopping, LearningRateMonitor, attach to trainer, resume from checkpoint, write your own.

Plug behavior into the loop without touching the model.

from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

ckpt = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)   # save the best checkpoint
es = EarlyStopping(monitor="val_loss", patience=3)                     # stop when it plateaus
lr_mon = LearningRateMonitor(logging_interval="step")                  # log the LR (needs a logger)

L.Trainer(callbacks=[ckpt, es, lr_mon])                                # attach them to the trainer
trainer.fit(model, dl, ckpt_path="last.ckpt")                    # resume from a checkpoint

class MyCb(L.Callback):                                          # write your own
    def on_train_epoch_end(self, trainer, pl_module):
        ...

See callbacks.

DataModule

A LightningDataModule packages all of your data logic (download, split, and dataloader construction) into one shareable, reproducible class. prepare_data runs on a single process (local rank 0 of each node, by default) for one-time downloads, so it should not assign state. setup(stage) builds the train/val/test splits on every process, and the *_dataloader methods return the loaders. You then pass datamodule=dm to trainer.fit (or validate, test, predict) instead of juggling loose DataLoaders.

DataModule panel: subclass, prepare_data, setup, train_dataloader, val/test/predict loaders, hand to the trainer.

Package download, split, and dataloaders in one class.

from torch.utils.data import DataLoader, random_split

class DataMod(L.LightningDataModule):                  # subclass the data base
    def prepare_data(self):                            # download once, on rank 0
        ...                                            # download here

    def setup(self, stage):                            # build splits per process
        self.train, self.val = random_split(...)

    def train_dataloader(self):                        # expose the train loader
        return DataLoader(self.train, batch_size=32, shuffle=True)

    def val_dataloader(self):                          # add val / test / predict loaders
        ...

dm = DataMod()                                         # instantiate the DataModule
trainer.fit(model, datamodule=dm)                      # hand it to the trainer

See LightningDataModule.

Save, Load & Export

A Lightning checkpoint is more than weights: it stores the state_dict, the saved hyperparameters, optimizer and scheduler state, and the epoch. As a result, LitModel.load_from_checkpoint(path) rebuilds a ready-to-use model without re-passing constructor arguments. For deployment, export the trained network with to_onnx (a portable graph for other runtimes) or torch.export.export (the TorchScript-free graph export). torch.export.save writes the exported program to a .pt2 archive that torch.export.load can reload without your LightningModule class. Running it without Python at all goes through ahead-of-time compilation such as AOTInductor. Note that LightningModule.to_torchscript is deprecated in Lightning 2.7 and removed in 2.8 because TorchScript itself is deprecated in PyTorch, so prefer torch.export.export.

Save, load, export panel: save_checkpoint, load_from_checkpoint, override hparams, raw state_dict, to_onnx, torch.export.

Checkpoints capture everything; export for serving.

trainer.save_checkpoint("model.ckpt")                            # save a checkpoint by hand
LitModel.load_from_checkpoint("model.ckpt")                      # reload weights + hparams
LitModel.load_from_checkpoint("model.ckpt", lr=1e-4)             # override hparams on load
torch.load("model.ckpt", weights_only=False)["state_dict"]       # get just the weights dict

model.to_onnx("model.onnx", input_sample, export_params=True)    # export to ONNX

ep = torch.export.export(model, (input_sample,))                 # export the graph (TorchScript-free)
torch.export.save(ep, "model.pt2")                               # self-contained .pt2 artifact

See checkpointing.

Quick Reference

What goes in a LightningModule.
Method | When Lightning calls it | You return / do
__init__ | Once, when you build the model | Define layers, call save_hyperparameters()
forward(x) | Inference, and inside your steps | The model output (logits)
training_step(batch, batch_idx) | Every training batch | return loss (or a dict with "loss")
validation_step(batch, batch_idx) | Every validation batch | Log metrics (no return needed)
test_step(batch, batch_idx) | Every test batch | Log metrics
predict_step(batch, batch_idx) | Every predict batch | return predictions
configure_optimizers() | Once, before fitting | Optimizer (and optional scheduler)

How you drive the Trainer.
Call | What it runs | Needs
trainer.fit(model, train_dl, val_dl) | Full train + validation loop | Train loader (val optional)
trainer.validate(model, val_dl) | Validation only | Val loader
trainer.test(model, test_dl) | Test only (run once) | Test loader
trainer.predict(model, dl) | Inference, returns outputs | Any loader
trainer.fit(..., ckpt_path="x.ckpt") | Resume from a checkpoint | A saved .ckpt

Frequently used Trainer arguments.
Flag | Meaning | Typical value
max_epochs | How long to train | 10
accelerator | Hardware kind | "auto", "gpu", "cpu"
devices | How many / which devices | "auto", 1, [0, 1]
precision | Numeric precision | "32-true", "16-mixed", "bf16-mixed"
callbacks | List of callbacks | [ckpt, early_stop]
logger | Where metrics go | CSVLogger(...), TensorBoardLogger(...)
fast_dev_run | Single-batch smoke test | True
gradient_clip_val | Clip gradient norm | 1.0
accumulate_grad_batches | Simulate a larger batch | 4

Controlling what self.log records.
Argument | Effect | Default
prog_bar=True | Show in the progress bar | False
on_step=True | Record every batch | True in training_step, False in validation/test steps
on_epoch=True | Record an epoch aggregate | False in training_step, True in validation/test steps
logger=True | Send to the configured logger | True
sync_dist=True | Reduce across devices (DDP) | False

Appendix: Sample Code

A complete, minimal LightningModule

The four slots that make a runnable model. save_hyperparameters() is what lets load_from_checkpoint work without re-passing lr.

import torch
from torch import nn
import torch.nn.functional as F
import lightning as L
import torchmetrics

class LitClassifier(L.LightningModule):
    def __init__(self, in_dim=784, hidden=128, n_classes=10, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()          # stored in every checkpoint
        self.net = nn.Sequential(
            nn.Flatten(),                    # (N, 1, 28, 28) images -> (N, 784)
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, n_classes),
        )
        self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss, prog_bar=True)
        return loss                          # Lightning calls backward + step

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("val_loss", F.cross_entropy(logits, y), prog_bar=True)
        self.acc(logits, y)
        self.log("val_acc", self.acc, prog_bar=True)   # torchmetrics object

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

Train it with a Trainer and callbacks

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

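L.seed_everything(42)                        # reproducible runs

model = LitClassifier()                      # the module defined above
# train_dl, val_dl, test_dl are assumed to be your DataLoaders; alternatively pass
# datamodule=MNISTData() (see the LightningDataModule below) to fit and test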

ckpt = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
early = EarlyStopping(monitor="val_loss", patience=3, mode="min")

trainer = L.Trainer(
    max_epochs=10,
    accelerator="auto",                      # CPU / GPU / MPS, picked for you
    devices="auto",
    callbacks=[ckpt, early],
)
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
trainer.test(model, dataloaders=test_dl)

print(ckpt.best_model_path)                  # path to the best checkpoint

A LightningDataModule

import lightning as L
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

class MNISTData(L.LightningDataModule):
    def __init__(self, data_dir="data", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.tf = transforms.ToTensor()

    def prepare_data(self):                   # runs once (download only)
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):              # runs on every process
        full = MNIST(self.data_dir, train=True, transform=self.tf)
        self.train, self.val = random_split(full, [55000, 5000])
        self.test = MNIST(self.data_dir, train=False, transform=self.tf)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)

# Usage: trainer.fit(model, datamodule=MNISTData())

Save, load, and export

# A checkpoint stores weights + hyperparameters + optimizer state + epoch.
trainer.save_checkpoint("model.ckpt")

# Reload a ready-to-use model; no need to re-pass constructor args.
best = LitClassifier.load_from_checkpoint("model.ckpt")

# Override a stored hyperparameter at load time:
tuned = LitClassifier.load_from_checkpoint("model.ckpt", lr=1e-4)

# Pull out just the raw weights (torch.load needs weights_only=False for a
# Lightning .ckpt, which holds non-tensor objects like hparams + optimizer state):
state_dict = torch.load("model.ckpt", weights_only=False)["state_dict"]

# Resume an interrupted run (restores optimizer state, not just weights):
# trainer.fit(model, datamodule=dm, ckpt_path="model.ckpt")

# Export for serving:
best.to_onnx("model.onnx", torch.randn(1, 784),
             export_params=True)                      # portable graph (needs `onnx`)

# TorchScript-free graph export (to_torchscript is deprecated in 2.7, gone in 2.8):
exported = torch.export.export(best, (torch.randn(1, 784),))
torch.export.save(exported, "model.pt2")              # self-contained .pt2

The same loop in raw PyTorch (the “why Lightning” comparison)

Lightning replaces this hand-written boilerplate with the four slots above:

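# Raw PyTorch: you own the loop, device moves, zero_grad, backward, step.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model = MyNet().to(device)                     # MyNet / train_dl: your own model and DataLoader
opt = torch.optim.Adam(model.parameters(), lr=1e-3)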
for epoch in range(10):
    model.train()
    for x, y in train_dl:
        x, y = x.to(device), y.to(device)     # Lightning does device moves
        opt.zero_grad()                        # Lightning does zero_grad
        loss = F.cross_entropy(model(x), y)
        loss.backward()                        # Lightning does backward
        opt.step()                             # Lightning does step
    # ... plus your own val loop, metric averaging, checkpointing, logging ...

References

PyTorch Lightning documentation

Related libraries and projects