BackboneFinetuning

class lightning.pytorch.callbacks.BackboneFinetuning(unfreeze_backbone_at_epoch=10, lambda_func=<function multiplicative>, backbone_initial_ratio_lr=0.1, backbone_initial_lr=None, should_align=True, initial_denom_lr=10.0, train_bn=True, verbose=False, rounding=12)[source]

Bases: BaseFinetuning

Finetune a backbone model based on a user-defined learning rate schedule.

When should_align is set to True and the backbone learning rate reaches the current model learning rate, the backbone learning rate is kept aligned with it for the rest of the training.

Parameters:
  • unfreeze_backbone_at_epoch (int) – Epoch at which the backbone will be unfrozen.

  • lambda_func (Callable) – Scheduling function for increasing the backbone learning rate.

  • backbone_initial_ratio_lr (float) – Used to scale down the backbone learning rate compared to the rest of the model.

  • backbone_initial_lr (Optional[float]) – Optional initial learning rate for the backbone. By default, current_learning_rate * backbone_initial_ratio_lr is used.

  • should_align (bool) – Whether to align with the current learning rate when the backbone learning rate reaches it.

  • initial_denom_lr (float) – When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr.

  • train_bn (bool) – Whether to make Batch Normalization trainable.

  • verbose (bool) – Display the current learning rate for the model and the backbone.

  • rounding (int) – Precision for displaying the learning rate.

Example:

>>> import torch
>>> import torch.nn as nn
>>> from lightning.pytorch import LightningModule, Trainer
>>> from lightning.pytorch.callbacks import BackboneFinetuning
>>> import torchvision.models as models
>>>
>>> class TransferLearningModel(LightningModule):
...     def __init__(self, num_classes=10):
...         super().__init__()
...         # REQUIRED: Your model must have a 'backbone' attribute
...         self.backbone = models.resnet50(weights="DEFAULT")
...         # Remove the final classification layer from backbone
...         self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
...
...         # Add your task-specific head
...         self.head = nn.Sequential(
...             nn.Flatten(),
...             nn.Linear(2048, 512),
...             nn.ReLU(),
...             nn.Linear(512, num_classes)
...         )
...
...     def forward(self, x):
...         # Extract features with backbone
...         features = self.backbone(x)
...         # Classify with head
...         return self.head(features)
...
...     def configure_optimizers(self):
...         # Initially only optimize the head - backbone will be added by callback
...         return torch.optim.Adam(self.head.parameters(), lr=1e-3)
...
>>> # Setup the callback
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(
...     unfreeze_backbone_at_epoch=10,  # Start unfreezing at epoch 10
...     lambda_func=multiplicative,     # Gradually increase backbone LR
...     backbone_initial_ratio_lr=0.1,  # Start backbone at 10% of head LR
... )
>>> model = TransferLearningModel()
>>> trainer = Trainer(callbacks=[backbone_finetuning])
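
With this configuration, the backbone is unfrozen at epoch 10 at 10% of the head learning rate, and its learning rate is then multiplied by 1.5 every epoch until it reaches the head learning rate, where should_align keeps it for the rest of training. The following is only an illustrative sketch of that progression in plain Python, not the callback's own code:

>>> backbone_lr, head_lr = 1e-4, 1e-3   # backbone starts at 10% of the head LR
>>> for _ in range(7):
...     backbone_lr = min(backbone_lr * 1.5, head_lr)  # lambda_func step, then align
>>> backbone_lr == head_lr
True
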
finetune_function(pl_module, epoch, optimizer)[source]

Called when the epoch begins. Unfreezes the backbone once unfreeze_backbone_at_epoch is reached and updates the backbone learning rate on every following epoch.

Return type:

None
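
For orientation, here is a simplified sketch of what this hook does, assuming the parameter and attribute names described above; it is an illustration of the behavior, not the actual source:

def finetune_function(self, pl_module, epoch, optimizer):
    # Sketch: unfreeze the backbone at the configured epoch, then grow its LR.
    if epoch == self.unfreeze_backbone_at_epoch:
        current_lr = optimizer.param_groups[0]["lr"]
        initial_lr = (
            self.backbone_initial_lr
            if self.backbone_initial_lr is not None
            else current_lr * self.backbone_initial_ratio_lr
        )
        self.previous_backbone_lr = initial_lr
        # Inherited from BaseFinetuning: makes the backbone trainable and adds a
        # new optimizer param group for it at the given learning rate.
        self.unfreeze_and_add_param_group(
            pl_module.backbone, optimizer, initial_lr,
            initial_denom_lr=self.initial_denom_lr, train_bn=self.train_bn,
        )
    elif epoch > self.unfreeze_backbone_at_epoch:
        current_lr = optimizer.param_groups[0]["lr"]
        next_lr = self.lambda_func(epoch + 1) * self.previous_backbone_lr
        if self.should_align and next_lr > current_lr:
            next_lr = current_lr  # align with the model LR once reached
        optimizer.param_groups[-1]["lr"] = next_lr
        self.previous_backbone_lr = next_lr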

freeze_before_training(pl_module)[source]

Override to add your freeze logic.

Return type:

None
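
For this callback, the freeze step amounts to freezing the backbone with the inherited freeze helper. A minimal sketch, assuming train_bn is forwarded as configured:

def freeze_before_training(self, pl_module):
    # Freeze all backbone parameters before training starts; BatchNorm layers
    # stay trainable when train_bn=True.
    self.freeze(pl_module.backbone, train_bn=self.train_bn)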

load_state_dict(state_dict)[source]

Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.

Parameters:

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type:

None
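
This hook is rarely called by hand: when resuming from a checkpoint, the Trainer restores the callback state automatically, so the finetuning schedule continues where it stopped. A hedged usage example (the checkpoint path below is a placeholder):

trainer = Trainer(callbacks=[BackboneFinetuning()])
trainer.fit(model, ckpt_path="path/to/last.ckpt")  # placeholder path; callback state is reloaded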

on_fit_start(trainer, pl_module)[source]

Checks that the LightningModule has an nn.Module backbone attribute before training starts.

Raises:

MisconfigurationException – If LightningModule has no nn.Module backbone attribute.

Return type:

None
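
In other words, the required backbone attribute is validated before training starts. A small hypothetical illustration (not meant to be run):

class NoBackbone(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)  # no `backbone` attribute

# Trainer(callbacks=[BackboneFinetuning()]).fit(NoBackbone())
# raises MisconfigurationException at fit start because no nn.Module
# attribute named `backbone` exists on the module.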

state_dict()[source]

Called when saving a checkpoint, implement to generate callback’s state_dict.

Return type:

dict[str, Any]

Returns:

A dictionary containing callback state.
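
Together with load_state_dict above, this is the standard callback-state pattern: whatever state_dict returns is written into the Trainer checkpoint and handed back to load_state_dict on resume. A minimal, self-contained sketch with an illustrative custom callback (hypothetical class and keys, not part of this API):

from lightning.pytorch.callbacks import Callback

class EpochCounter(Callback):
    def __init__(self):
        self.count = 0

    def on_train_epoch_end(self, trainer, pl_module):
        self.count += 1

    def state_dict(self):
        # Persisted into the Trainer checkpoint alongside the model weights.
        return {"count": self.count}

    def load_state_dict(self, state_dict):
        # Called when resuming from a checkpoint; restores the saved value.
        self.count = state_dict["count"]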