BackboneFinetuning
- class lightning.pytorch.callbacks.BackboneFinetuning(unfreeze_backbone_at_epoch=10, lambda_func=<function multiplicative>, backbone_initial_ratio_lr=0.1, backbone_initial_lr=None, should_align=True, initial_denom_lr=10.0, train_bn=True, verbose=False, rounding=12)[source]

Bases: BaseFinetuning
Finetune a backbone model based on a user-defined learning rate schedule. When the backbone learning rate reaches the current model learning rate and should_align is set to True, it will align with it for the rest of the training.

Parameters:
- unfreeze_backbone_at_epoch (int) – Epoch at which the backbone will be unfrozen.
- lambda_func (Callable) – Scheduling function for increasing the backbone learning rate.
- backbone_initial_ratio_lr (float) – Used to scale down the backbone learning rate compared to the rest of the model.
- backbone_initial_lr (Optional[float]) – Optional initial learning rate for the backbone. By default, the current learning rate scaled by backbone_initial_ratio_lr is used.
- should_align (bool) – Whether to align with the current learning rate when the backbone learning rate reaches it.
- initial_denom_lr (float) – When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr.
- train_bn (bool) – Whether to make Batch Normalization trainable.
- verbose (bool) – Display the current learning rate for the model and backbone.
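To build intuition for the schedule described above, the following sketch simulates in plain Python how the backbone learning rate would evolve: it starts scaled down relative to the model learning rate, grows by lambda_func each epoch after unfreezing, and with should_align=True stops growing once it reaches the model learning rate. The fixed model learning rate of 1e-3 and the factor of 1.5 are assumptions chosen to mirror the example below; this is an illustration of the documented behaviour, not the callback's internal code.

    model_lr = 1e-3                    # assumed model (head) learning rate
    ratio = 0.1                        # backbone_initial_ratio_lr
    should_align = True
    lambda_func = lambda epoch: 1.5    # same multiplicative factor as the example below

    backbone_lr = model_lr * ratio     # backbone starts at 10% of the model LR
    for epoch in range(10, 17):        # epochs after unfreeze_backbone_at_epoch=10
        print(f"epoch {epoch}: backbone_lr={backbone_lr:.6f}")
        backbone_lr = backbone_lr * lambda_func(epoch)
        if should_align and backbone_lr > model_lr:
            backbone_lr = model_lr     # align with the model LR from here on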
Example:
>>> import torch
>>> import torch.nn as nn
>>> from lightning.pytorch import LightningModule, Trainer
>>> from lightning.pytorch.callbacks import BackboneFinetuning
>>> import torchvision.models as models
>>>
>>> class TransferLearningModel(LightningModule):
...     def __init__(self, num_classes=10):
...         super().__init__()
...         # REQUIRED: Your model must have a 'backbone' attribute
...         self.backbone = models.resnet50(weights="DEFAULT")
...         # Remove the final classification layer from backbone
...         self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
...
...         # Add your task-specific head
...         self.head = nn.Sequential(
...             nn.Flatten(),
...             nn.Linear(2048, 512),
...             nn.ReLU(),
...             nn.Linear(512, num_classes)
...         )
...
...     def forward(self, x):
...         # Extract features with backbone
...         features = self.backbone(x)
...         # Classify with head
...         return self.head(features)
...
...     def configure_optimizers(self):
...         # Initially only optimize the head - backbone will be added by callback
...         return torch.optim.Adam(self.head.parameters(), lr=1e-3)
...
>>> # Setup the callback
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(
...     unfreeze_backbone_at_epoch=10,   # Start unfreezing at epoch 10
...     lambda_func=multiplicative,      # Gradually increase backbone LR
...     backbone_initial_ratio_lr=0.1,   # Start backbone at 10% of head LR
... )
>>> model = TransferLearningModel()
>>> trainer = Trainer(callbacks=[backbone_finetuning])
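The example above derives the backbone learning rate from the head learning rate via backbone_initial_ratio_lr. As a variation, the sketch below pins the starting backbone learning rate explicitly with backbone_initial_lr and turns on verbose logging; all arguments used are those listed in the signature above, and the specific values are illustrative only.

>>> backbone_finetuning = BackboneFinetuning(
...     unfreeze_backbone_at_epoch=5,      # unfreeze the backbone earlier
...     lambda_func=lambda epoch: 1.2,     # gentler ramp-up of the backbone LR
...     backbone_initial_lr=1e-5,          # explicit starting LR instead of a ratio
...     should_align=True,                 # cap the backbone LR at the model LR
...     train_bn=False,                    # keep Batch Normalization frozen
...     verbose=True,                      # log model and backbone learning rates
... )
>>> trainer = Trainer(callbacks=[backbone_finetuning])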
- load_state_dict(state_dict)[source]

Called when loading a checkpoint; implement to reload the callback state given the callback's state_dict.
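Since load_state_dict is invoked when a checkpoint is loaded, the callback's finetuning progress carries over when training is resumed from a Trainer checkpoint. A minimal resume sketch, assuming a checkpoint file of your own (the path below is a placeholder and dataloaders are omitted for brevity):

>>> model = TransferLearningModel()
>>> trainer = Trainer(callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=10)])
>>> trainer.fit(model, ckpt_path="path/to/last.ckpt")  # placeholder path; pass your dataloaders as usual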