PyTorch Lightning is the high-level training framework that organizes raw PyTorch into two pieces: a LightningModule that says what to compute, and a Trainer that owns how the loop runs. You subclass L.LightningModule, build your layers in __init__, define a forward, and fill in step hooks like training_step; Lightning then handles the engineering around it: the loop, zero_grad, backward, step, device placement, mixed precision, and distribution. The convention for the series is import lightning as L (the modern unified package; import lightning.pytorch as L is equivalent and import pytorch_lightning as pl is the legacy alias), alongside the usual import torch and from torch import nn. Throughout, LitModel(L.LightningModule) is your subclass, model is an instance, and trainer is an L.Trainer(). This cheatsheet walks through the eight things you reach for daily.
The LightningModule
A LightningModule is a regular torch.nn.Module with a fixed set of slots that say what to do rather than how to loop: you build your layers in __init__, define the forward pass, and Lightning fills in the engineering around it. Calling self.save_hyperparameters() records the constructor arguments so they travel inside every checkpoint and reload automatically.
import torch
from torch import nn
import lightning as L
class LitModel(L.LightningModule):  # subclass the base module
    def __init__(self, lr=1e-3):  # build layers in the constructor
        super().__init__()
        self.save_hyperparameters()  # save constructor args (here lr) as hparams
        self.net = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        return self.net(x)  # define the forward (inference) pass

model = LitModel(lr=1e-3)  # instantiate the model
See LightningModule.
Step Hooks
Instead of writing a training loop, you implement training_step, validation_step, test_step, and predict_step, each of which receives one batch and a batch_idx and describes the work for a single batch. The training step returns the loss (Lightning calls backward and step for you); the other steps just log metrics or return predictions, and epoch-level hooks like on_train_epoch_end let you act once per epoch.
import torch.nn.functional as F
def training_step(self, batch, batch_idx):  # train on one batch
    x, y = batch
    loss = F.cross_entropy(self(x), y)
    return loss  # Lightning calls backward + step

def validation_step(self, batch, batch_idx):  # validate on one batch
    x, y = batch
    loss = F.cross_entropy(self(x), y)
    self.log("val_loss", loss)  # log instead of return

def test_step(self, batch, batch_idx):  # test on one batch (run once)
    ...

def predict_step(self, batch, batch_idx):  # predict on one batch
    return self(batch)

def on_train_epoch_end(self):  # hook into the epoch end
    ...
See step hooks.
Optimizers & Schedulers
The single configure_optimizers method returns your optimizer (and optionally a learning-rate scheduler), and by default Lightning runs zero_grad, backward, step for you in the right order. When you need full control, such as GANs or multiple optimizers, set self.automatic_optimization = False and drive self.optimizers() yourself; gradient clipping and accumulation are configured on the Trainer, not in the module.
def configure_optimizers(self):  # configure one optimizer
    return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

def configure_optimizers(self):  # variant: add an LR scheduler
    opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt)  # one scheduler that needs "monitor"
    return {"optimizer": opt,
            "lr_scheduler": {"scheduler": sched, "interval": "epoch", "monitor": "val_loss"}}  # step interval + monitor

# loss.backward() and opt.step() are automatic by default
self.automatic_optimization = False  # in __init__: manual mode (GANs, multiple optimizers)
opt = self.optimizers()  # in training_step: you call opt.step() yourself
L.Trainer(gradient_clip_val=1.0)  # clip gradients (set on the Trainer)
See configure_optimizers.
The Trainer
The Trainer is the engine that owns the loop and the hardware: you set how long to train (max_epochs) and where (accelerator, devices), then call fit, validate, test, or predict. Because device placement, mixed precision, and distribution all live in the Trainer, the same LightningModule runs on CPU, one GPU, or many without code changes, and fast_dev_run=True runs a single batch through the whole pipeline to catch bugs in seconds.
trainer = L.Trainer(max_epochs=10) # build the trainer
L.Trainer(accelerator="auto", devices="auto") # pick hardware automatically
trainer.fit(model, train_dl, val_dl) # run the full fit loop
trainer.test(model, test_dl) # evaluate on a held-out set
preds = trainer.predict(model, dl) # generate predictions
L.Trainer(fast_dev_run=True) # smoke-test the whole pipeline
See Trainer.
Logging Metrics
Anywhere inside a step you call self.log("name", value), and Lightning handles aggregation across batches and devices, deciding whether to record per step or per epoch via on_step and on_epoch. Where those numbers go is set once on the Trainer with a logger (CSV, TensorBoard, Weights and Biases, MLflow), and logging a torchmetrics metric object lets the metric accumulate the correct value over an epoch instead of averaging per-batch numbers.
from lightning.pytorch.loggers import CSVLogger
import torchmetrics
self.log("train_loss", loss) # log a scalar
self.log("val_acc", acc, prog_bar=True) # show in the progress bar
self.log("loss", loss, on_step=True, on_epoch=True) # control step vs epoch reduce
self.log_dict({"acc": acc, "f1": f1}) # log several at once
L.Trainer(logger=CSVLogger("logs")) # pick a backend logger
self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=10) # create the metric in __init__
self.acc(logits, y) # update it inside a step
self.log("acc", self.acc) # track a real metric object
See logging.
Callbacks
Callbacks inject behavior into the training loop without cluttering the model: ModelCheckpoint saves the best weights by a monitored metric, EarlyStopping halts training when it stops improving, and LearningRateMonitor logs the LR over time. You attach them as a list to Trainer(callbacks=[...]), resume an interrupted run by passing ckpt_path to fit (which restores optimizer state too), and write your own by subclassing L.Callback and filling in on_* hooks.
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
ckpt = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1) # save the best checkpoint
es = EarlyStopping(monitor="val_loss", patience=3) # stop when it plateaus
lr = LearningRateMonitor(logging_interval="step") # log the learning rate
trainer = L.Trainer(callbacks=[ckpt, es, lr]) # attach them to the trainer
trainer.fit(model, dl, ckpt_path="last.ckpt") # resume from a checkpoint
class MyCb(L.Callback): # write your own
    def on_train_epoch_end(self, trainer, pl_module):
        ...
See callbacks.
DataModule
A LightningDataModule packages all of your data logic (download, split, and dataloader construction) into one shareable, reproducible class. prepare_data runs once for one-time downloads, setup(stage) builds the train/val/test splits per process, and the *_dataloader methods return the loaders; you then pass datamodule=dm to the Trainer instead of juggling loose DataLoaders.
from torch.utils.data import DataLoader, random_split
class DataMod(L.LightningDataModule):  # subclass the DataModule base
    def prepare_data(self):  # download once, on rank 0
        ...  # download here

    def setup(self, stage):  # build splits per process
        self.train, self.val = random_split(...)

    def train_dataloader(self):  # expose the train loader
        return DataLoader(self.train, batch_size=32, shuffle=True)

    def val_dataloader(self):  # add val / test / predict loaders
        ...

dm = DataMod()  # instantiate the data module
trainer.fit(model, datamodule=dm)  # hand it to the trainer
See LightningDataModule.
Save, Load & Export
A Lightning checkpoint is more than weights: it stores the state_dict, the saved hyperparameters, optimizer and scheduler state, and the epoch, so LitModel.load_from_checkpoint(path) rebuilds a ready-to-use model without re-passing constructor arguments. For deployment you export the trained network with to_onnx (a portable graph for other runtimes) or torch.export.export (the current TorchScript-free graph export, which torch.export.save writes to a .pt2 file that torch.export.load restores). Note that LightningModule.to_torchscript is deprecated in Lightning 2.7 and removed in 2.8 because TorchScript itself is deprecated in PyTorch, so prefer torch.export.export.
trainer.save_checkpoint("model.ckpt") # save a checkpoint by hand
LitModel.load_from_checkpoint("model.ckpt") # reload weights + hparams
LitModel.load_from_checkpoint("model.ckpt", lr=1e-4) # override hparams on load
torch.load("model.ckpt", weights_only=False)["state_dict"] # get just the weights dict
input_sample = torch.randn(1, 784) # example input matching the model's input shape
model.to_onnx("model.onnx", input_sample, export_params=True) # export to ONNX
ep = torch.export.export(model, (input_sample,)) # export the graph (TorchScript-free)
torch.export.save(ep, "model.pt2") # self-contained .pt2 artifact
See checkpointing.
Quick Reference
| Method | When Lightning calls it | You return / do |
|---|---|---|
| `__init__` | Once, when you build the model | Define layers, save_hyperparameters() |
| `forward(x)` | Inference / inside your steps | The model output (logits) |
| `training_step(batch, batch_idx)` | Every train batch | return loss (or a dict with "loss") |
| `validation_step(batch, batch_idx)` | Every val batch | Log metrics (no return needed) |
| `test_step(batch, batch_idx)` | Every test batch | Log metrics |
| `predict_step(batch, batch_idx)` | Every predict batch | return predictions |
| `configure_optimizers()` | Once, before fitting | Optimizer (and optional scheduler) |

| Call | What it runs | Needs |
|---|---|---|
| `trainer.fit(model, train_dl, val_dl)` | Full train + validation loop | Train loader (val optional) |
| `trainer.validate(model, val_dl)` | Validation only | Val loader |
| `trainer.test(model, test_dl)` | Test only (run once) | Test loader |
| `trainer.predict(model, dl)` | Inference, returns outputs | Any loader |
| `trainer.fit(..., ckpt_path="x.ckpt")` | Resume from a checkpoint | A saved .ckpt |

| Flag | Meaning | Typical value |
|---|---|---|
| `max_epochs` | How long to train | 10 |
| `accelerator` | Hardware kind | "auto", "gpu", "cpu" |
| `devices` | How many / which devices | "auto", 1, [0, 1] |
| `precision` | Numeric precision | "32-true", "16-mixed", "bf16-mixed" |
| `callbacks` | List of callbacks | [ckpt, early_stop] |
| `logger` | Where metrics go | CSVLogger(...), TensorBoardLogger(...) |
| `fast_dev_run` | Single-batch smoke test | True |
| `gradient_clip_val` | Clip gradient norm | 1.0 |
| `accumulate_grad_batches` | Simulate a larger batch | 4 |

| Argument | Effect | Default |
|---|---|---|
| `prog_bar=True` | Show in the progress bar | False |
| `on_step=True` | Record every batch | varies by step |
| `on_epoch=True` | Record an epoch aggregate | varies by step |
| `logger=True` | Send to the configured logger | True |
| `sync_dist=True` | Reduce across devices (DDP) | False |
Appendix: Sample Code
A complete, minimal LightningModule
The four slots that make a runnable model. save_hyperparameters() is what lets load_from_checkpoint work without re-passing lr.
import torch
from torch import nn
import torch.nn.functional as F
import lightning as L
import torchmetrics
class LitClassifier(L.LightningModule):
    def __init__(self, in_dim=784, hidden=128, n_classes=10, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()  # stored in every checkpoint
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, n_classes),
        )
        self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss, prog_bar=True)
        return loss  # Lightning calls backward + step

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("val_loss", F.cross_entropy(logits, y), prog_bar=True)
        self.acc(logits, y)
        self.log("val_acc", self.acc, prog_bar=True)  # torchmetrics object

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
Train it with a Trainer and callbacks
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
L.seed_everything(42) # reproducible runs
model = LitClassifier() # the module defined above
# train_dl, val_dl, and test_dl are your own DataLoaders
ckpt = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
early = EarlyStopping(monitor="val_loss", patience=3, mode="min")
trainer = L.Trainer(
    max_epochs=10,
    accelerator="auto", # CPU / GPU / MPS, picked for you
    devices="auto",
    callbacks=[ckpt, early],
)
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
trainer.test(model, dataloaders=test_dl)
print(ckpt.best_model_path) # path to the best checkpoint
A LightningDataModule
import lightning as L
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
class MNISTData(L.LightningDataModule):
    def __init__(self, data_dir="data", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.tf = transforms.ToTensor()

    def prepare_data(self):  # runs once (download only)
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):  # runs on every process
        full = MNIST(self.data_dir, train=True, transform=self.tf)
        self.train, self.val = random_split(full, [55000, 5000])
        self.test = MNIST(self.data_dir, train=False, transform=self.tf)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)

# Usage: trainer.fit(model, datamodule=MNISTData())
Save, load, and export
# A checkpoint stores weights + hyperparameters + optimizer state + epoch.
trainer.save_checkpoint("model.ckpt")
# Reload a ready-to-use model; no need to re-pass constructor args.
best = LitClassifier.load_from_checkpoint("model.ckpt")
# Override a stored hyperparameter at load time:
tuned = LitClassifier.load_from_checkpoint("model.ckpt", lr=1e-4)
# Pull out just the raw weights (torch.load needs weights_only=False for a
# Lightning .ckpt, which holds non-tensor objects like hparams + optimizer state):
state_dict = torch.load("model.ckpt", weights_only=False)["state_dict"]
# Resume an interrupted run (restores optimizer state, not just weights):
# trainer.fit(model, datamodule=dm, ckpt_path="model.ckpt")
# Export for serving:
best.to_onnx("model.onnx", torch.randn(1, 784),
export_params=True) # portable graph (needs `onnx`)
# TorchScript-free graph export (to_torchscript is deprecated in 2.7, gone in 2.8):
exported = torch.export.export(best, (torch.randn(1, 784),))
torch.export.save(exported, "model.pt2") # self-contained .pt2
The same loop in raw PyTorch (the “why Lightning” comparison)
Lightning replaces this hand-written boilerplate with the four slots above:
# Raw PyTorch: you own the loop, device moves, zero_grad, backward, step.
model = MyNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(10):
    model.train()
    for x, y in train_dl:
        x, y = x.to(device), y.to(device)  # Lightning does device moves
        opt.zero_grad()  # Lightning does zero_grad
        loss = F.cross_entropy(model(x), y)
        loss.backward()  # Lightning does backward
        opt.step()  # Lightning does step
# ... plus your own val loop, metric averaging, checkpointing, logging ...
References
PyTorch Lightning documentation
- Documentation home and Lightning in 15 minutes
- LightningModule, step hooks, and configure_optimizers
- Manual optimization, the Trainer API, and accelerators and devices
- Logging, loggers, and callbacks
- ModelCheckpoint, EarlyStopping, and the LightningDataModule
- Checkpointing, export to ONNX, torch.export, and the LightningCLI
Related libraries and project