Optimizers and Schedulers: Homework

Implementing Cosine Annealing and the OneCycle scheduler from scratch
from functools import partial

import matplotlib.pyplot as plt
import torch
from torch import nn
# Apply matplotlib's "ggplot" style sheet to the figures drawn below.
plt.style.use("ggplot")

Recall that we want something that fits into this interface:

class BaseSchedulerCB(Callback):
    """Callback that builds an LR scheduler at fit time and steps it.

    ``scheduler_f`` is called as ``scheduler_f(learn.opt, **kwargs)`` in
    ``before_fit``, so the scheduler wraps whichever optimizer the learner
    holds when fitting starts. ``_step`` advances the scheduler only while
    ``learn.training`` is true. Which hook invokes ``_step`` is not shown
    here (presumably a subclass's per-batch or per-epoch hook -- TODO confirm).

    NOTE(review): ``Callback`` is not defined in this file; presumably it
    comes from the course's training library.
    """

    def __init__(self, scheduler_f, **kwargs):
        self.scheduler_f, self.sched_kwargs = scheduler_f, kwargs
        self.sched = None  # created lazily in before_fit

    def before_fit(self, learn):
        build = self.scheduler_f
        self.sched = build(learn.opt, **self.sched_kwargs)  # 👈

    def _step(self, learn):
        if not learn.training:
            return
        self.sched.step()  # 👈

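That is, the scheduler factory is called with the optimizer as its first argument (`scheduler_f(learn.opt, **kwargs)`), and the object it returns must provide a `.step()` method.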

We also need to implement the LRScheduler interface:

torch.optim.lr_scheduler.LRScheduler??
Init signature:
torch.optim.lr_scheduler.LRScheduler(
    optimizer,
    last_epoch=-1,
    verbose=False,
)
Docstring:      <no docstring>
Source:        
class LRScheduler:
    def __init__(self, optimizer, last_epoch=-1, verbose=False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer
        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   f"in param_groups[{i}] when resuming an optimizer")
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch
        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method
            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method
            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)
            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper
        self.optimizer.step = with_counter(self.optimizer.step)
        self.verbose = verbose
        self._initial_step()
    def _initial_step(self):
        """Initialize step counts and performs a step"""
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()
    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
    def load_state_dict(self, state_dict):
        """Loads the schedulers state.
        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)
    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        return self._last_lr
    def get_lr(self):
        # Compute learning rate using chainable form of the scheduler
        raise NotImplementedError
    def print_lr(self, is_verbose, group, lr, epoch=None):
        """Display the current learning rate.
        """
        if is_verbose:
            if epoch is None:
                print(f'Adjusting learning rate of group {group} to {lr:.4e}.')
            else:
                epoch_str = ("%.2f" if isinstance(epoch, float) else
                             "%.5d") % epoch
                print(f'Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}.')
    def step(self, epoch=None):
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1
        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = self._get_closed_form_lr()
                else:
                    values = self.get_lr()
        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr
            self.print_lr(self.verbose, i, lr, epoch)
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
File:           ~/miniforge3/envs/slowai/lib/python3.10/site-packages/torch/optim/lr_scheduler.py
Type:           type
Subclasses:     _LRScheduler, LambdaLR, MultiplicativeLR, StepLR, MultiStepLR, ConstantLR, LinearLR, ExponentialLR, SequentialLR, PolynomialLR, ...

First, let’s write a helper that steps a scheduler and plots its learning rate over time.

def plot_scheduler(sched, n_batches):
    """Advance ``sched`` by ``n_batches`` steps and plot the learning rates.

    The list returned by ``sched.get_last_lr()`` (one entry per param group)
    is recorded once before stepping and again after each of the
    ``n_batches`` steps, giving ``n_batches + 1`` points on a new 4x4 figure.
    Each step calls ``optimizer.step()`` before ``sched.step()``, which is the
    order ``LRScheduler.step`` expects and otherwise warns about. The optimizer
    step really runs, so parameters may change if they still hold gradients.
    """

    def _advance():
        # Optimizer first, then scheduler, then read back the new LR(s).
        sched.optimizer.step()
        sched.step()
        return sched.get_last_lr()

    _, ax = plt.subplots(figsize=(4, 4))
    history = [sched.get_last_lr()]
    history.extend(_advance() for _ in range(n_batches))
    ax.plot(history)
    ax.set(xlabel="Time", ylabel="LR")

Now we can run some dummy training, both to produce a plot and to check that the scheduler actually works inside a training loop.

# Toy regression data: 100 single-feature samples, y = 2x + 1 plus
# standard-normal noise.
X = torch.randn(100, 1)
y = torch.add(2 * X + 1, torch.randn(100, 1))


class LinearRegressionModel(nn.Module):
    """A single affine layer mapping ``input_dim`` features to ``output_dim`` outputs."""

    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the affine layer to ``x``."""
        out = self.linear(x)
        return out


def train(scheduler_f, nbatches=100, lr=0.01):
    """Fit a toy linear regression while stepping a scheduler, then plot its LRs.

    Uses the module-level ``X``/``y`` data and ``LinearRegressionModel``.
    Every batch is a full pass over ``X``. The scheduler is stepped after each
    ``optimizer.step()``, so the plotted LRs are the ones actually used for
    training: one point per batch plus the LR after the final step
    (``nbatches + 1`` points, one line per param group).

    Args:
        scheduler_f: Callable taking an optimizer and returning an LR
            scheduler (e.g. a ``functools.partial`` of a scheduler class).
        nbatches: Number of optimization steps to run and plot.
        lr: Learning rate given to the SGD optimizer. A scheduler may
            overwrite it as soon as it is constructed.
    """
    model = LinearRegressionModel(1, 1)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = scheduler_f(optimizer)
    lrs = []
    for _ in range(nbatches):
        # Record the LR(s) this batch is trained with.
        lrs.append(scheduler.get_last_lr())
        outputs = model(X)
        loss = criterion(outputs, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # PyTorch expects scheduler.step() to follow optimizer.step().
        scheduler.step()
    lrs.append(scheduler.get_last_lr())

    # Plot the recorded history instead of stepping the scheduler further:
    # precomputed schedules would otherwise be driven past their end.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.plot(lrs)
    ax.set(xlabel="Time", ylabel="LR")
# Baseline: PyTorch's built-in cosine annealing over all n_batches steps.
n_batches = 100
scheduler_f = partial(
    torch.optim.lr_scheduler.CosineAnnealingLR,
    T_max=n_batches,
)
train(scheduler_f, nbatches=n_batches)

Part I: Cosine Annealing

# Preview of the cosine-annealing curve: the LR falls from lr_max to lr_min
# over t_max steps. The curve is named lr_curve rather than y so that the
# regression target y used by train() is not overwritten.
t_max = 100
lr_max = 2.5
lr_min = 0.5
x = torch.arange(t_max)
lr_curve = (1 + torch.cos(x * torch.pi / t_max)) / 2 * (lr_max - lr_min) + lr_min
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(x, lr_curve)
ax.set_ylim(0, 5);

class CosineAnnealingLRScheduler(torch.optim.lr_scheduler.LRScheduler):
    """Cosine-anneal every param group's LR from ``lr_max`` down to ``lr_min``.

    The schedule is precomputed as ``t_max + 1`` values,
    ``lr_t = lr_min + (lr_max - lr_min) * (1 + cos(pi * t / t_max)) / 2`` for
    ``t = 0 .. t_max``. The LR right after construction is ``lr_max``, and
    the LR after ``t_max`` further steps is ``lr_min``. Stepping past
    ``t_max`` holds the LR at ``lr_min``. The optimizer's own ``lr`` is
    ignored.

    Args:
        optimizer: Wrapped optimizer.
        lr_max: LR at the start of the schedule.
        lr_min: LR at the end of the schedule.
        t_max: Number of scheduler steps over which to anneal; must be >= 1.
        last_epoch: Index of the last step when resuming; ``-1`` starts fresh.
        verbose: Forwarded to ``LRScheduler`` only when truthy. Recent
            PyTorch releases deprecate the argument and warn when it is passed.

    Raises:
        ValueError: If ``t_max < 1``.
    """

    def __init__(
        self,
        optimizer,
        lr_max: float,
        lr_min: float,
        t_max: int,
        last_epoch=-1,
        verbose=False,
    ):
        if t_max < 1:
            raise ValueError(f"t_max must be >= 1, got {t_max}")
        # Note that the superclass constructor calls .step() on the instance,
        # so we need one additional learning rate beyond the number of steps
        # associated with each batch. self.lrs must exist before
        # super().__init__() because that initial step calls get_lr().
        ts = torch.arange(t_max + 1, dtype=torch.float64)
        curve = (1 + torch.cos(ts * torch.pi / t_max)) / 2 * (lr_max - lr_min) + lr_min
        # Plain Python floats, so param_group["lr"] never holds a tensor.
        self.lrs = curve.tolist()
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # last_epoch is 0 after the constructor's initial step and grows by one
        # per step (or is set directly by step(epoch)). Clamp it so stepping
        # beyond the schedule holds the final LR instead of raising IndexError.
        idx = min(max(self.last_epoch, 0), len(self.lrs) - 1)
        lr = self.lrs[idx]
        return [lr for _ in self.optimizer.param_groups]
# Train with the from-scratch cosine scheduler, annealing from 1.0 to 0.5.
scheduler_f = partial(
    CosineAnnealingLRScheduler,
    t_max=n_batches,
    lr_min=0.5,
    lr_max=1.0,
)
train(scheduler_f, nbatches=n_batches)

Part II: 1 cycle

# Preview of the one-cycle curve: linear warm-up from lr_start towards lr_max
# over the first half, then linear decay from lr_max to lr_end. The warm-up
# excludes its endpoint so the peak LR appears only once.
t_max = 100
lr_start, lr_max, lr_end = 0.1, 0.8, 0.01
lra = torch.linspace(lr_start, lr_max, t_max // 2 + 1)[:-1]
lrb = torch.linspace(lr_max, lr_end, t_max - t_max // 2)
lrs = torch.cat((lra, lrb))
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(torch.arange(len(lrs)), lrs)
ax.set_ylim(0, 1);

class OneCycleLRScheduler(torch.optim.lr_scheduler.LRScheduler):
    """Linear one-cycle LR schedule: warm up to a peak, then anneal.

    Every param group's LR rises linearly from ``lr_start`` and reaches
    ``lr_max`` at step ``t_max // 2``, then falls linearly to ``lr_end`` at
    step ``t_max``. The schedule is precomputed as ``t_max + 1`` values (the
    base-class constructor consumes the first one), and ``lr_max`` appears
    exactly once. For ``t_max == 1`` the warm-up is empty and ``lr_start`` is
    unused. Stepping past ``t_max`` holds the LR at ``lr_end``. The
    optimizer's own ``lr`` is ignored.

    Args:
        optimizer: Wrapped optimizer.
        lrs: ``(lr_start, lr_max, lr_end)`` triple.
        t_max: Number of scheduler steps in the cycle; must be >= 1.
        last_epoch: Index of the last step when resuming; ``-1`` starts fresh.
        verbose: Forwarded to ``LRScheduler`` only when truthy. Recent
            PyTorch releases deprecate the argument and warn when it is passed.

    Raises:
        ValueError: If ``t_max < 1``.
    """

    def __init__(
        self,
        optimizer,
        lrs,
        t_max: int,
        last_epoch=-1,
        verbose=False,
    ):
        if t_max < 1:
            raise ValueError(f"t_max must be >= 1, got {t_max}")
        lr_start, lr_max, lr_end = lrs
        n_warmup = t_max // 2
        # Warm-up excludes its endpoint so lr_max is not repeated at the peak.
        warmup = torch.linspace(lr_start, lr_max, n_warmup + 1, dtype=torch.float64)[:-1]
        # Note that the superclass constructor calls .step() on the instance,
        # so the decay leg carries one extra LR beyond the per-batch steps.
        anneal = torch.linspace(lr_max, lr_end, t_max - n_warmup + 1, dtype=torch.float64)
        # Must exist before super().__init__(), whose initial step calls
        # get_lr(); stored as Python floats so param_group["lr"] is not a tensor.
        self.lrs = torch.cat((warmup, anneal)).tolist()
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # last_epoch is 0 after the constructor's initial step and advances by
        # one per step. Clamp so stepping beyond the cycle holds the final LR.
        idx = min(max(self.last_epoch, 0), len(self.lrs) - 1)
        return [self.lrs[idx] for _ in self.optimizer.param_groups]
# Train with the from-scratch one-cycle scheduler: 0.2 -> 0.8 -> 0.02.
scheduler_f = partial(
    OneCycleLRScheduler,
    t_max=n_batches,
    lrs=(0.2, 0.8, 0.02),
)
train(scheduler_f, nbatches=n_batches)
101