Transfer Learning¶
Audience: Users looking to use pretrained models with Lightning.
Use any PyTorch nn.Module¶
Any model that is a PyTorch nn.Module can be used with Lightning (because LightningModules are nn.Modules also).
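For instance, a plain nn.Module (pretrained or not) can simply be assigned as a submodule of a LightningModule. Here is a minimal sketch, where LitClassifier is a hypothetical wrapper and net is any nn.Module you bring:

import torch
from torch import nn
from lightning.pytorch import LightningModule


class LitClassifier(LightningModule):
    def __init__(self, net: nn.Module):
        super().__init__()
        # any plain nn.Module works here, pretrained or not
        self.net = net

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.net.parameters(), lr=1e-3)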
Use a pretrained LightningModule¶
Let’s use the AutoEncoder as a feature extractor in a separate model.
class Encoder(torch.nn.Module):
    ...


class AutoEncoder(LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()


class CIFAR10Classifier(LightningModule):
    def __init__(self):
        super().__init__()
        # init the pretrained LightningModule
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH).encoder

        # the encoder is a plain nn.Module, so freeze its parameters explicitly
        self.feature_extractor.eval()
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...
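A matching training step for CIFAR10Classifier could look like this (a sketch, not part of the original snippet; the cross-entropy loss is an assumption):

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # only the new classifier head receives gradients; the encoder is frozen
        return nn.functional.cross_entropy(logits, y)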
We used our pretrained AutoEncoder (a LightningModule) for transfer learning!
Example: Imagenet (Computer Vision)¶
import torch
import torch.nn as nn
import torchvision.models as models
from lightning.pytorch import LightningModule


class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        self.feature_extractor.eval()

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        ...
Finetune the model:
model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)
And use it to predict your data of interest:
model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()
x = some_images_from_cifar10()
predictions = model(x)
We used a model pretrained on ImageNet, finetuned it on CIFAR-10, and used it to predict on CIFAR-10. Outside of academic benchmarks, you would finetune on your own small dataset and predict on your own data.
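For completeness, here is one way to feed CIFAR-10 to trainer.fit above (a sketch; the transform, batch size, and data directory are arbitrary choices):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()
train_set = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

trainer.fit(model, train_dataloaders=train_loader)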
Example: BERT (NLP)¶
Lightning is completely agnostic to what’s used for transfer learning so long as it is a torch.nn.Module subclass.
Here’s a model that uses Hugging Face transformers:
from transformers import BertModel


class BertMNLIFinetuner(LightningModule):
    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
        self.bert.train()
        self.W = nn.Linear(self.bert.config.hidden_size, 3)
        self.num_classes = 3

    def forward(self, input_ids, attention_mask, token_type_ids):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # use the hidden state of the [CLS] token as the sequence representation
        h_cls = out.last_hidden_state[:, 0]
        logits = self.W(h_cls)
        return logits, out.attentions
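To call the model, tokenize a premise/hypothesis pair with the matching tokenizer (a minimal sketch; the example sentences are arbitrary):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
batch = tokenizer("A soccer game.", "People are playing a sport.", return_tensors="pt")

model = BertMNLIFinetuner()
logits, attn = model(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])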
Automated Finetuning with Callbacks¶
PyTorch Lightning provides the BackboneFinetuning callback to automate the finetuning process. This callback gradually unfreezes your model’s backbone during training, which is particularly useful when working with large pretrained models: you start training with a frozen backbone and then progressively unfreeze layers to finetune the model.
The BackboneFinetuning callback expects your model to have a specific structure:
class MyModel(LightningModule):
    def __init__(self):
        super().__init__()

        # REQUIRED: Your model must have a 'backbone' attribute
        # This should be the pretrained part you want to finetune
        self.backbone = some_pretrained_model

        # Your task-specific layers (head, classifier, etc.)
        self.head = nn.Linear(backbone_features, num_classes)

    def configure_optimizers(self):
        # Only optimize the head initially - the backbone will be added automatically
        return torch.optim.Adam(self.head.parameters(), lr=1e-3)
Example: Computer Vision with ResNet¶
Here’s a complete example showing how to use BackboneFinetuning for computer vision:
import torch
import torch.nn as nn
import torchvision.models as models
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BackboneFinetuning


class ResNetClassifier(LightningModule):
    def __init__(self, num_classes=10, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        # Create backbone from pretrained ResNet
        resnet = models.resnet50(weights="DEFAULT")
        # Remove the final classification layer
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Add custom classification head
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(resnet.fc.in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        # Extract features with backbone
        features = self.backbone(x)
        # Classify with head
        return self.head(features)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # Initially only train the head - the backbone will be added by the callback
        return torch.optim.Adam(self.head.parameters(), lr=self.hparams.learning_rate)


# Setup the finetuning callback
backbone_finetuning = BackboneFinetuning(
    unfreeze_backbone_at_epoch=10,  # Start unfreezing backbone at epoch 10
    lambda_func=lambda epoch: 1.5,  # Gradually increase backbone learning rate
    backbone_initial_ratio_lr=0.1,  # Backbone starts at 10% of head learning rate
    should_align=True,  # Align rates when backbone rate reaches head rate
    verbose=True,  # Print learning rates during training
)

model = ResNetClassifier()
trainer = Trainer(callbacks=[backbone_finetuning], max_epochs=20)
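Training then runs as usual: the callback keeps the backbone frozen until epoch 10, then adds it to the optimizer at 10% of the head’s learning rate and scales it up each epoch. In the sketch below, train_loader is a hypothetical DataLoader of (image, label) batches, such as the CIFAR-10 loader shown earlier:

trainer.fit(model, train_dataloaders=train_loader)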
Custom Finetuning Strategies¶
For more control, you can create custom finetuning strategies by subclassing BaseFinetuning:
from lightning.pytorch.callbacks.finetuning import BaseFinetuning


class CustomFinetuning(BaseFinetuning):
    def __init__(self, unfreeze_at_epoch=5, layers_per_epoch=2):
        super().__init__()
        self.unfreeze_at_epoch = unfreeze_at_epoch
        self.layers_per_epoch = layers_per_epoch

    def freeze_before_training(self, pl_module):
        # Freeze the entire backbone initially
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, epoch, optimizer):
        # Gradually unfreeze a few layers per epoch, from the top layers down
        if epoch >= self.unfreeze_at_epoch:
            backbone_children = list(pl_module.backbone.children())
            # layers already unfrozen in previous epochs
            num_unfrozen = (epoch - self.unfreeze_at_epoch) * self.layers_per_epoch
            if num_unfrozen >= len(backbone_children):
                return
            # Unfreeze only the next batch of layers and add them to the optimizer
            stop = len(backbone_children) - num_unfrozen
            start = max(0, stop - self.layers_per_epoch)
            self.unfreeze_and_add_param_group(
                backbone_children[start:stop], optimizer, lr=1e-4
            )
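The custom callback is attached the same way as any other (a sketch, reusing the ResNetClassifier model and the hypothetical train_loader from above):

trainer = Trainer(
    callbacks=[CustomFinetuning(unfreeze_at_epoch=5, layers_per_epoch=2)],
    max_epochs=30,
)
trainer.fit(model, train_dataloaders=train_loader)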