Source code for lr_range_test.helpers

from typing import Optional, List, Dict

import ignite
import torch

from lr_range_test.lr_range import AutomaticLRRangeTest, InteractiveLRRangeTest
from lr_range_test.type_aliases import OptimizerType, DataLoaderType, LossFnType


def lr_range_test(optimizer: OptimizerType,
                  model: torch.nn.Module,
                  train_loader: DataLoaderType,
                  loss_fn: LossFnType,
                  eval_metric: Optional[ignite.metrics.Metric] = None,
                  test_loader: Optional[DataLoaderType] = None,
                  lr_min: float = 1e-7,
                  lr_max: float = 1e1,
                  num_steps: int = 50,
                  smooth_f: float = .05,
                  diverge_th: float = 5.,
                  wd_values: Optional[List[float]] = None,
                  pbar: bool = False,
                  automatic: bool = False,
                  descending: bool = True,
                  device: str = 'cuda') -> Dict[str, float]:
    """
    The function expects a ``model`` and an ``optimizer`` for which to perform the test.
    The model is optimized with respect to a loss function ``loss_fn``. Training data is
    loaded from an iterable (or a standard PyTorch ``DataLoader``) called ``train_loader``.

    If no ``test_loader`` is specified, the loss is the batch loss after each step. If
    ``test_loader`` is specified, the loss is computed and averaged over the entire test set.

    The learning rate is increased exponentially from ``lr_min`` to ``lr_max`` over
    ``num_steps`` iterations. The recorded loss is smoothed with an exponential moving
    average whose alpha coefficient is ``smooth_f``. Training stops early if the loss
    diverges by a factor of more than ``diverge_th`` from the best recorded loss.

    A custom evaluation metric such as accuracy can be specified with an
    `ignite metric <https://pytorch.org/ignite/metrics.html>`_. If the metric is expected
    to increase during training (e.g. accuracy), set ``descending`` to ``False``.

    The test runs in either interactive or automatic mode, depending on the value of
    ``automatic``. The results are returned as a dictionary.

    :param automatic: whether to perform an automatic LR range test instead of an interactive one
    :param model: a torch module receiving inputs and outputting predictions
    :param eval_metric: an ignite metric to use when evaluating on ``test_loader``
    :param optimizer: the optimizer to use for the LR range test
    :param train_loader: an iterable that loads data to feed to the trainer
    :param test_loader: an iterable that loads data to feed to the evaluator
    :param loss_fn: an objective function taking model outputs and targets and returning a loss
    :param device: the device to run training/evaluation on (default: ``'cuda'``)
    :param descending: whether the chosen metric/loss should decrease (e.g. accuracy should not)
    :param pbar: whether to show a progress bar during training
    :param wd_values: the weight decay values to test
    :param diverge_th: the factor by which the current metric must differ from the best
        recorded value for the metric to be considered diverged
    :param num_steps: the number of steps over which to increase the LR
    :param lr_max: the LR to end on
    :param lr_min: the LR to start from
    :param smooth_f: the alpha coefficient for the exponential moving average
    :return: a dictionary with the results
    """
    # Pick the tester implementation according to the requested mode.
    if automatic:
        tester = AutomaticLRRangeTest(optimizer=optimizer, model=model, loss_fn=loss_fn,
                                      train_loader=train_loader, test_loader=test_loader,
                                      eval_metric=eval_metric, descending=descending,
                                      device=device)
    else:
        tester = InteractiveLRRangeTest(optimizer=optimizer, model=model, loss_fn=loss_fn,
                                        train_loader=train_loader, test_loader=test_loader,
                                        eval_metric=eval_metric, descending=descending,
                                        device=device)
    # Both testers share the same run() interface for the LR sweep itself.
    return tester.run(lr_min=lr_min, lr_max=lr_max, num_steps=num_steps, smooth_f=smooth_f,
                      diverge_th=diverge_th, wd_values=wd_values, pbar=pbar)
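The exponential sweep described in the docstring can be sketched in plain Python. This is a minimal illustration, not the library's implementation. The helper name ``exponential_lr_schedule`` is hypothetical, and the sketch assumes the first step uses exactly ``lr_min`` and the last step uses exactly ``lr_max``, so consecutive learning rates differ by a constant factor.

```python
from typing import List


def exponential_lr_schedule(lr_min: float, lr_max: float, num_steps: int) -> List[float]:
    """Return num_steps learning rates spaced geometrically from lr_min to lr_max.

    Hypothetical helper for illustration. It assumes step 0 uses lr_min and the
    final step uses lr_max (an assumption, not the library's documented behaviour).
    """
    if lr_min <= 0 or lr_max <= lr_min:
        raise ValueError("expected 0 < lr_min < lr_max")
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if num_steps == 1:
        return [lr_min]
    ratio = lr_max / lr_min
    # Linear in log-space: lr_i = lr_min * ratio ** (i / (num_steps - 1)).
    return [lr_min * ratio ** (i / (num_steps - 1)) for i in range(num_steps)]


lrs = exponential_lr_schedule(1e-7, 1e1, 50)
print(len(lrs), lrs[0])  # 50 1e-07
```

With the defaults of ``lr_range_test`` (``lr_min=1e-7``, ``lr_max=1e1``, ``num_steps=50``), this schedule multiplies the learning rate by (1e8)^(1/49) at each step, which is roughly 1.46. A geometric sweep spends equal numbers of steps in each order of magnitude. A linear sweep would instead spend almost all of its steps near ``lr_max``.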
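The loss smoothing and early-stopping rule from the docstring can be sketched in the same way. The helper ``smoothed_until_divergence`` is hypothetical, and it makes three assumptions:

- ``smooth_f`` weights the newest raw value in the moving average.
- A descending loss counts as diverged once it exceeds ``diverge_th`` times the best smoothed value.
- An ascending metric (``descending=False``) counts as diverged once it falls below the best value divided by ``diverge_th``.

The library may define these details differently.

```python
from typing import List, Optional, Sequence


def smoothed_until_divergence(values: Sequence[float], smooth_f: float = 0.05,
                              diverge_th: float = 5.0,
                              descending: bool = True) -> List[float]:
    """Smooth raw per-step values with an EMA and stop once they diverge.

    Hypothetical helper for illustration; the divergence rules are assumptions.
    The value that triggers the stop is kept as the last history entry.
    """
    history: List[float] = []
    best: Optional[float] = None
    for raw in values:
        # EMA: the newest raw value gets weight smooth_f, the running average the rest.
        value = raw if not history else smooth_f * raw + (1.0 - smooth_f) * history[-1]
        history.append(value)
        if descending:
            # Losses: track the minimum, stop once we exceed diverge_th times it.
            best = value if best is None else min(best, value)
            diverged = value > diverge_th * best
        else:
            # Ascending metrics (assumed rule): stop once we drop below best / diverge_th.
            best = value if best is None else max(best, value)
            diverged = value < best / diverge_th
        if diverged:
            break
    return history


print(smoothed_until_divergence([1.0, 0.5, 0.25, 10.0, 100.0], smooth_f=0.5))
# [1.0, 0.75, 0.5, 5.25]
```

In the call above, the spike to 10.0 pushes the smoothed loss to 5.25. That exceeds 5 times the best value of 0.5, so the sweep stops before the raw value of 100.0 is seen. A small ``smooth_f`` such as the default ``.05`` makes the curve much less sensitive to noisy batch losses, but it also delays the divergence check by several steps. The call uses ``0.5`` only to keep the numbers short.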