Transfer Learning

Audience: Users looking to use pretrained models with Lightning.


Use any PyTorch nn.Module

Any model that is a PyTorch nn.Module can be used with Lightning (because LightningModules are nn.Modules also).
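
For example, any torchvision model can be dropped straight into a LightningModule. A minimal sketch (the model, loss, and optimizer choices here are illustrative assumptions):

import torch
import torchvision.models as models
from lightning.pytorch import LightningModule


class LitResnet(LightningModule):
    def __init__(self):
        super().__init__()
        # any plain nn.Module works as a submodule of a LightningModule
        self.model = models.resnet18(weights=None)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.model(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)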


Use a pretrained LightningModule

Let’s use the AutoEncoder as a feature extractor in a separate model.

import torch
from torch import nn
from lightning.pytorch import LightningModule


class Encoder(torch.nn.Module):
    ...


class Decoder(torch.nn.Module):
    ...


class AutoEncoder(LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()


class CIFAR10Classifier(LightningModule):
    def __init__(self):
        super().__init__()
        # init the pretrained LightningModule and keep only its encoder
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH).encoder
        # the encoder is a plain nn.Module, so freeze it by disabling gradients
        self.feature_extractor.requires_grad_(False)
        self.feature_extractor.eval()

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...

We used our pretrained AutoEncoder (a LightningModule) for transfer learning!
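
To train the classifier, the module also needs a training step and an optimizer. A minimal sketch of the methods you would add to CIFAR10Classifier (the loss and optimizer choices are assumptions):

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.classifier(self.feature_extractor(x))
        loss = nn.functional.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # only the new classifier head is optimized; the encoder stays frozen
        return torch.optim.Adam(self.classifier.parameters(), lr=1e-3)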


Example: ImageNet (Computer Vision)

import torchvision.models as models


class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        self.feature_extractor.eval()

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        ...

Finetune

model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)

And use it to predict on your data of interest

model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()

x = some_images_from_cifar10()
predictions = model(x)

We took a model pretrained on ImageNet, finetuned it on CIFAR-10, and used it to predict on CIFAR-10. Outside of an academic setting, you would finetune on your own small dataset and predict on your own data.
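
As a sketch of that workflow, a small custom image dataset could be loaded with torchvision and passed to the Trainer. This assumes ImagenetTransferLearning is also given a training_step and configure_optimizers, and "path/to/your/images" is a placeholder for your own dataset folder:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# placeholder path; any folder of class-labelled images works with ImageFolder
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = datasets.ImageFolder("path/to/your/images", transform=transform)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

model = ImagenetTransferLearning()
trainer = Trainer(max_epochs=5)
trainer.fit(model, train_loader)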


Example: BERT (NLP)

Lightning is completely agnostic to what’s used for transfer learning so long as it is a torch.nn.Module subclass.

Here’s a model that uses Hugging Face Transformers.

from transformers import BertModel


class BertMNLIFinetuner(LightningModule):
    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
        self.bert.train()
        self.W = nn.Linear(self.bert.config.hidden_size, 3)
        self.num_classes = 3

    def forward(self, input_ids, attention_mask, token_type_ids):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # use the hidden state of the [CLS] token as the sentence-pair representation
        h_cls = out.last_hidden_state[:, 0]
        logits = self.W(h_cls)
        return logits, out.attentions
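
As a quick usage sketch (the tokenizer and the example premise/hypothesis pair are assumptions, not part of the model above):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
# MNLI-style premise/hypothesis pair
batch = tokenizer(
    ["A soccer game with multiple males playing."],
    ["Some men are playing a sport."],
    padding=True,
    return_tensors="pt",
)

model = BertMNLIFinetuner()
logits, attn = model(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])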

Automated Finetuning with Callbacks

PyTorch Lightning provides the BackboneFinetuning callback to automate the finetuning process. The callback keeps your model’s backbone frozen at the start of training and, at a chosen epoch, unfreezes it and gradually ramps up its learning rate. This is particularly useful with large pretrained models: you first train only the new head, then finetune the backbone without disrupting its pretrained weights.

The BackboneFinetuning callback expects your model to have a specific structure:

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()

        # REQUIRED: Your model must have a 'backbone' attribute
        # This should be the pretrained part you want to finetune
        self.backbone = some_pretrained_model

        # Your task-specific layers (head, classifier, etc.)
        self.head = nn.Linear(backbone_features, num_classes)

    def configure_optimizers(self):
        # Only optimize the head initially - backbone will be added automatically
        return torch.optim.Adam(self.head.parameters(), lr=1e-3)
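
With that structure in place, attaching the callback to the Trainer is enough to enable the schedule (the epoch below is just an illustrative choice):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BackboneFinetuning

trainer = Trainer(callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=10)])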

Example: Computer Vision with ResNet

Here’s a complete example showing how to use BackboneFinetuning for computer vision:

import torch
import torch.nn as nn
import torchvision.models as models
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BackboneFinetuning


class ResNetClassifier(LightningModule):
    def __init__(self, num_classes=10, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        # Create backbone from pretrained ResNet
        resnet = models.resnet50(weights="DEFAULT")
        # Remove the final classification layer
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Add custom classification head
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(resnet.fc.in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # Extract features with backbone
        features = self.backbone(x)
        # Classify with head
        return self.head(features)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # Initially only train the head - backbone will be added by callback
        return torch.optim.Adam(self.head.parameters(), lr=self.hparams.learning_rate)


# Setup the finetuning callback
backbone_finetuning = BackboneFinetuning(
    unfreeze_backbone_at_epoch=10,  # Start unfreezing backbone at epoch 10
    lambda_func=lambda epoch: 1.5,  # Gradually increase backbone learning rate
    backbone_initial_ratio_lr=0.1,  # Backbone starts at 10% of head learning rate
    should_align=True,  # Align rates when backbone rate reaches head rate
    verbose=True  # Print learning rates during training
)

model = ResNetClassifier()
trainer = Trainer(callbacks=[backbone_finetuning], max_epochs=20)
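
Training then starts as usual; train_loader below is a placeholder for a DataLoader over your dataset:

trainer.fit(model, train_loader)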

Custom Finetuning Strategies

For more control, you can create custom finetuning strategies by subclassing BaseFinetuning:

from lightning.pytorch.callbacks.finetuning import BaseFinetuning


class CustomFinetuning(BaseFinetuning):
    def __init__(self, unfreeze_at_epoch=5, layers_per_epoch=2):
        super().__init__()
        self.unfreeze_at_epoch = unfreeze_at_epoch
        self.layers_per_epoch = layers_per_epoch

    def freeze_before_training(self, pl_module):
        # Freeze the entire backbone initially
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, epoch, optimizer):
        # Gradually unfreeze the backbone from the top down,
        # `layers_per_epoch` layers at a time, starting at `unfreeze_at_epoch`
        if epoch >= self.unfreeze_at_epoch:
            backbone_children = list(pl_module.backbone.children())

            # layers already unfrozen in previous epochs
            already_unfrozen = (epoch - self.unfreeze_at_epoch) * self.layers_per_epoch
            if already_unfrozen >= len(backbone_children):
                return

            # unfreeze the next chunk of layers and add them to the optimizer
            stop = len(backbone_children) - already_unfrozen
            start = max(stop - self.layers_per_epoch, 0)
            for layer in backbone_children[start:stop]:
                self.unfreeze_and_add_param_group(layer, optimizer, lr=1e-4)
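
The custom callback plugs into the Trainer like any other callback; the settings below are example values and train_loader is again a placeholder:

model = ResNetClassifier()
trainer = Trainer(callbacks=[CustomFinetuning(unfreeze_at_epoch=5, layers_per_epoch=2)], max_epochs=20)
trainer.fit(model, train_loader)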