Optimizers and Schedulers: Homework
Implementing Cosine Annealing and the OneCycle scheduler from scratch

from functools import partial

import matplotlib.pyplot as plt
import torch
from torch import nn

plt.style.use("ggplot")
Recall that we want something that slots into this interface:
class BaseSchedulerCB(Callback):
    """Base callback class for schedulers"""

    def __init__(self, scheduler_f, **kwargs):
        self.scheduler_f = scheduler_f
        self.sched_kwargs = kwargs
        self.sched = None

    def before_fit(self, learn):
        self.sched = self.scheduler_f(learn.opt, **self.sched_kwargs)  # 👈

    def _step(self, learn):
        if learn.training:
            self.sched.step()  # 👈
That is:
- It takes the optimizer as the first argument of its constructor
- It steps once per batch (or once per epoch, if overridden). This is implemented in the LRScheduler superclass.
We also need to implement the LRScheduler interface: .get_lr() returns the optimizer learning rate as a function of the internal property _step_count.
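For instance, a minimal, purely illustrative scheduler that satisfies this contract could simply return a constant learning rate; the only hard requirement is that get_lr() produces one value per parameter group (the name ConstantLRScheduler here is just for this sketch):

class ConstantLRScheduler(torch.optim.lr_scheduler.LRScheduler):
    """Toy example: ignores _step_count and always returns the same LR."""

    def __init__(self, optimizer, lr=0.01, last_epoch=-1, verbose=False):
        # Set attributes *before* calling super().__init__, because the
        # superclass constructor immediately calls .step() -> .get_lr()
        self.lr = lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        # One learning rate per parameter group
        return [self.lr for _ in self.optimizer.param_groups]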
torch.optim.lr_scheduler.LRScheduler??
Init signature: torch.optim.lr_scheduler.LRScheduler(optimizer, last_epoch=-1, verbose=False)
Docstring:      <no docstring>
Source:
class LRScheduler:

    def __init__(self, optimizer, last_epoch=-1, verbose=False):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   f"in param_groups[{i}] when resuming an optimizer")
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.verbose = verbose

        self._initial_step()

    def _initial_step(self):
        """Initialize step counts and performs a step"""
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        return self._last_lr

    def get_lr(self):
        # Compute learning rate using chainable form of the scheduler
        raise NotImplementedError

    def print_lr(self, is_verbose, group, lr, epoch=None):
        """Display the current learning rate.
        """
        if is_verbose:
            if epoch is None:
                print(f'Adjusting learning rate of group {group} to {lr:.4e}.')
            else:
                epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
                print(f'Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}.')

    def step(self, epoch=None):
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = self._get_closed_form_lr()
                else:
                    values = self.get_lr()

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr
            self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
File:           ~/miniforge3/envs/slowai/lib/python3.10/site-packages/torch/optim/lr_scheduler.py
Type:           type
Subclasses:     _LRScheduler, LambdaLR, MultiplicativeLR, StepLR, MultiStepLR, ConstantLR, LinearLR, ExponentialLR, SequentialLR, PolynomialLR, ...
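The warnings in step() encode the call-order contract: since PyTorch 1.1.0, optimizer.step() must come before lr_scheduler.step(), otherwise the first value of the schedule is skipped. Here is a small self-contained sketch of that ordering with a built-in scheduler (the throwaway model, StepLR choice, and dummy loss are just for illustration):

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for _ in range(30):
    loss = model(torch.randn(8, 1)).pow(2).mean()  # dummy loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()   # increments optimizer._step_count via the with_counter wrapper
    scheduler.step()   # then advance the schedule
print(scheduler.get_last_lr())  # roughly 0.0125 after three halvings of 0.1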
First, let’s write helpers.
def plot_scheduler(sched, n_batches):
    fig, ax = plt.subplots(figsize=(4, 4))
    lrs = []
    lrs.append(sched.get_last_lr())
    for _ in range(n_batches):
        sched.optimizer.step()
        sched.step()
        lrs.append(sched.get_last_lr())
    ax.plot(lrs)
    ax.set(xlabel="Time", ylabel="LR")
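As a quick smoke test of the helper with a built-in scheduler (the throwaway parameter and the choice of ExponentialLR here are just for illustration):

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.0)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.95)
plot_scheduler(sched, n_batches=50)  # should draw a smooth exponential decay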
Now we can do some dummy training to help with plotting and to make sure the scheduler actually works inside a training loop.
X = torch.randn(100, 1)  # 100 samples with 1 feature
y = 2 * X + 1 + torch.randn(100, 1)  # Add some noise
class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)
def train(scheduler_f, nbatches=100, lr=0.01):
    model = LinearRegressionModel(1, 1)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = scheduler_f(optimizer)
    lrs = []
    for epoch in range(nbatches):
        lrs.append(scheduler.get_last_lr())  # record the LR seen by each batch
        outputs = model(X)
        loss = criterion(outputs, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    plot_scheduler(scheduler, nbatches)
n_batches = 100
scheduler_f = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=n_batches)
train(scheduler_f, n_batches)
Part I: Cosine Annealing
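The target schedule is half a cosine wave that starts at lr_max and decays to lr_min over t_max steps (the same closed form as PyTorch's CosineAnnealingLR):

$$
\mathrm{lr}(t) = \mathrm{lr}_{\min} + \tfrac{1}{2}\left(\mathrm{lr}_{\max} - \mathrm{lr}_{\min}\right)\left(1 + \cos\frac{\pi t}{t_{\max}}\right),
\qquad t = 0, 1, \dots, t_{\max}.
$$

The cell below just plots this curve to check that it looks right before wrapping it in a scheduler.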
t_max = 100
lr_max = 2.5
lr_min = 0.5
x = torch.arange(t_max)
ys = (1 + torch.cos(x * 3.141 / t_max)) / 2 * (lr_max - lr_min) + lr_min  # use ys, not y, to avoid clobbering the training targets

fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(x, ys)
ax.set_ylim(0, 5);
class CosineAnnealingLRScheduler(torch.optim.lr_scheduler.LRScheduler):
    def __init__(
        self,
        optimizer,
        lr_max: float,
        lr_min: float,
        t_max: int,
        last_epoch=-1,
        verbose=False,
    ):
        # Note that the superclass constructor calls .step() on the instance,
        # so we need one additional learning rate beyond the number of steps
        # associated with each batch
        xs = torch.arange(t_max + 1).float()
        self.lrs = (1 + torch.cos(xs * 3.141 / t_max)) / 2 * (lr_max - lr_min) + lr_min
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        return [self.lrs[self._step_count - 1] for _ in self.optimizer.param_groups]
scheduler_f = partial(
    CosineAnnealingLRScheduler,
    lr_max=1.0,
    lr_min=0.5,
    t_max=n_batches,
)
train(scheduler_f, n_batches)
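As a quick sanity check of the off-by-one handling (illustrative only, with a throwaway optimizer): immediately after construction the superclass has already called .step() once, so _step_count is 1 and the current LR is the first table entry, lr_max; after t_max more steps it should land (approximately) on lr_min.

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.0)
sched = CosineAnnealingLRScheduler(opt, lr_max=1.0, lr_min=0.5, t_max=10)
assert sched._step_count == 1
assert abs(sched.get_last_lr()[0] - 1.0) < 1e-3  # starts at lr_max
for _ in range(10):
    opt.step()
    sched.step()
assert abs(sched.get_last_lr()[0] - 0.5) < 1e-3  # ends near lr_min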
Part II: 1 cycle
t_max = 100
lr_start, lr_max, lr_end = 0.1, 0.8, 0.01
lra = torch.linspace(lr_start, lr_max, t_max // 2)
lrb = torch.linspace(lr_max, lr_end, t_max // 2 + t_max % 2)
lrs = torch.cat((lra, lrb))

fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(x, lrs)
ax.set_ylim(0, 1);
class OneCycleLRScheduler(torch.optim.lr_scheduler.LRScheduler):
    def __init__(
        self,
        optimizer,
        lrs,
        t_max: int,
        last_epoch=-1,
        verbose=False,
    ):
        # Note that the superclass constructor calls .step() on the instance,
        # so we need one additional learning rate beyond the number of steps
        # associated with each batch
        lr_start, lr_max, lr_end = lrs
        lra = torch.linspace(lr_start, lr_max, t_max // 2)
        lrb = torch.linspace(lr_max, lr_end, t_max // 2 + (t_max % 2) + 1)
        self.lrs = torch.cat((lra, lrb))
        print(len(self.lrs))
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        return [self.lrs[self._step_count - 1] for _ in self.optimizer.param_groups]
scheduler_f = partial(
    OneCycleLRScheduler,
    lrs=(0.2, 0.8, 0.02),
    t_max=n_batches,
)
train(scheduler_f, n_batches)
101

The printed 101 is len(self.lrs): t_max + 1 learning rates, one more than the number of batches, covering the extra .step() performed by the superclass constructor.
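For comparison, PyTorch ships its own torch.optim.lr_scheduler.OneCycleLR, which adds a warm-up fraction (pct_start), momentum cycling, and a choice of cosine or linear annealing. A rough equivalent of the run above (the parameter choices here are a sketch, not a faithful match of the hand-rolled ramps):

scheduler_f = partial(
    torch.optim.lr_scheduler.OneCycleLR,
    max_lr=0.8,
    total_steps=n_batches,
    anneal_strategy="linear",  # linear ramps, like the implementation above
    cycle_momentum=False,      # plain SGD here, so don't cycle momentum
)
train(scheduler_f, n_batches)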