Source code for libai.engine.hooks

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import logging
import math
import operator
import time
from collections import Counter

import oneflow as flow

from libai.evaluation import flatten_results_dict
from libai.utils import distributed as dist
from libai.utils.checkpoint import Checkpointer
from libai.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from libai.utils.events import EventWriter
from libai.utils.timer import Timer

from .trainer import HookBase

# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/hooks.py
# --------------------------------------------------------

"""
Implement some common hooks.
"""
logger = logging.getLogger(__name__)


class CallbackHook(HookBase):
    """
    Create a hook using callback functions provided by the user.
    """

    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
        """
        Each argument is a function that takes one argument: the trainer.
        """
        self._before_train = before_train
        self._before_step = before_step
        self._after_step = after_step
        self._after_train = after_train

    def before_train(self):
        if self._before_train:
            self._before_train(self.trainer)

    def after_train(self):
        if self._after_train:
            self._after_train(self.trainer)
        # The functions may be closures that hold reference to the trainer.
        # Therefore, delete them to avoid circular reference.
        del self._before_train, self._after_train
        del self._before_step, self._after_step

    def before_step(self):
        if self._before_step:
            self._before_step(self.trainer)

    def after_step(self):
        if self._after_step:
            self._after_step(self.trainer)
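
A minimal usage sketch; ``trainer`` is assumed to be a libai trainer exposing ``register_hooks``, and the callbacks are purely illustrative::

    # Illustrative only: each callback receives the trainer instance.
    trainer.register_hooks(
        [
            CallbackHook(
                before_train=lambda t: logger.info("training starts at iter %d", t.start_iter),
                after_step=lambda t: logger.info("finished iter %d", t.iter),
            )
        ]
    )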

class IterationTimer(HookBase):
    """
    Track the time spent for each iteration (each run_step call in the trainer).
    Print a summary at the end of training.

    This hook uses the time between the call to its :meth:`before_step`
    and :meth:`after_step` methods.
    Under the convention that :meth:`before_step` of all hooks should only
    take a negligible amount of time, the :class:`IterationTimer` hook should be
    placed at the beginning of the list of hooks to obtain accurate timing.
    """

    def __init__(self, warmup_iter=3):
        """
        Args:
            warmup_iter (int): the number of iterations at the beginning to exclude
                from timing.
        """
        self._warmup_iter = warmup_iter
        self._step_timer = Timer()

    def before_train(self):
        self._start_time = time.perf_counter()
        self._total_timer = Timer()
        self._total_timer.pause()

    def after_train(self):
        total_time = time.perf_counter() - self._start_time
        total_time_minus_hooks = self._total_timer.seconds()
        hook_time = total_time - total_time_minus_hooks

        num_iter = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter

        if num_iter > 0 and total_time_minus_hooks > 0:
            # Speed is meaningful only after warmup.
            # NOTE this format is parsed by grep in some scripts.
            logger.info(
                "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
                    num_iter,
                    str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
                    total_time_minus_hooks / num_iter,
                )
            )

        logger.info(
            "Total training time: {} ({} on hooks)".format(
                str(datetime.timedelta(seconds=int(total_time))),
                str(datetime.timedelta(seconds=int(hook_time))),
            )
        )

    def before_step(self):
        self._step_timer.reset()
        self._total_timer.resume()

    def after_step(self):
        # +1 because we're in after_step
        iter_done = self.trainer.iter - self.trainer.start_iter + 1
        if iter_done >= self._warmup_iter:
            sec = self._step_timer.seconds()
            self.trainer.storage.put_scalars(time=sec)
        else:
            self._start_time = time.perf_counter()
            self._total_timer.reset()

        self._total_timer.pause()
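
As the docstring notes, the timing is accurate only when this hook runs first, so that the ``before_step`` overhead of the other hooks is excluded; a sketch (``trainer`` assumed as above)::

    trainer.register_hooks(
        [
            IterationTimer(warmup_iter=3),
            # ... all other hooks are registered after the timer ...
        ]
    )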

class PeriodicWriter(HookBase):
    """
    Write events to EventStorage periodically.

    It is executed every ``period`` iterations and after the last iteration.
    """

    def __init__(self, writers, period=20):
        """
        Args:
            writers (list[EventWriter]): a list of EventWriter objects
            period (int): the write period, in iterations
        """
        self._writers = writers
        for w in writers:
            assert isinstance(w, EventWriter), w
        self._period = period

    def after_step(self):
        if (self.trainer.iter + 1) % self._period == 0 or (
            self.trainer.iter == self.trainer.max_iter - 1
        ):
            for writer in self._writers:
                writer.write()

    def after_train(self):
        for writer in self._writers:
            writer.close()

class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
    """
    Same as :class:`libai.utils.checkpoint.PeriodicCheckpointer`, but as a hook.

    Note that when used as a hook, it is unable to save additional data
    other than what's defined by the given `checkpointer`.

    It is executed every ``period`` iterations and after the last iteration.
    """

    def before_train(self):
        self.max_iter = self.trainer.max_iter

    def after_step(self):
        self.step(self.trainer.iter)
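
A usage sketch; the ``Checkpointer`` construction below uses placeholder arguments (``model``, ``"output"``) and is only meant to show how the hook is wired in::

    # Sketch: save a checkpoint every 5000 iterations and after the last one.
    checkpointer = Checkpointer(model, "output")
    trainer.register_hooks([PeriodicCheckpointer(checkpointer, period=5000)])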

class BestCheckpointer(HookBase):
    """
    Checkpoints best weights based off a given metric.

    This hook should be used in conjunction with, and executed after, the hook
    that produces the metric, e.g. `EvalHook`.
    """

    def __init__(
        self,
        eval_period: int,
        checkpointer: Checkpointer,
        val_metric: str,
        mode: str = "max",
        file_prefix: str = "model_best",
    ) -> None:
        """
        Args:
            eval_period (int): the period `EvalHook` is set to run.
            checkpointer: the checkpointer object used to save checkpoints.
            val_metric (str): validation metric to track for best checkpoint, e.g. "acc@1"
            mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be
                maximized or minimized, e.g. for "acc@1" it should be "max"
            file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best"
        """
        self._period = eval_period
        self._val_metric = val_metric
        assert mode in [
            "max",
            "min",
        ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.'
        if mode == "max":
            self._compare = operator.gt
        else:
            self._compare = operator.lt
        self._checkpointer = checkpointer
        self._file_prefix = file_prefix
        self.best_metric = None
        self.best_iter = None

    def _update_best(self, val, iteration):
        if math.isnan(val) or math.isinf(val):
            return False
        self.best_metric = val
        self.best_iter = iteration
        return True

    def _best_checking(self):
        metric_tuple = self.trainer.storage.latest().get(self._val_metric)
        flag = flow.zeros(1)
        if dist.is_main_process():
            if metric_tuple is None:
                logger.warning(
                    f"Given val metric {self._val_metric} does not seem to be computed/stored. "
                    "Will not be checkpointed based on that."
                )
            else:
                latest_metric, metric_iter = metric_tuple
                if self.best_metric is None:
                    if self._update_best(latest_metric, metric_iter):
                        flag = flag + 1
                        logger.info(
                            f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
                        )
                elif self._compare(latest_metric, self.best_metric):
                    flag = flag + 1
                    logger.info(
                        f"Saved best model as latest eval score for {self._val_metric} is "
                        f"{latest_metric:0.5f}, better than last best score "
                        f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
                    )
                    self._update_best(latest_metric, metric_iter)
                else:
                    logger.info(
                        f"Not saving as latest eval score for "
                        f"{self._val_metric} is {latest_metric:0.5f}, "
                        f"not better than best score {self.best_metric:0.5f} "
                        f"@ iteration {self.best_iter}."
                    )
        dist.synchronize()
        flag = flag.to_global(
            sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cpu")
        )
        if flag.to_local().item() == 1:
            self._checkpointer.save(f"{self._file_prefix}")

    def after_step(self):
        # same conditions as `EvalHook`
        next_iter = self.trainer.iter + 1
        if (
            self._period > 0
            and next_iter % self._period == 0
            and next_iter != self.trainer.max_iter
        ):
            self._best_checking()

    def after_train(self):
        # same conditions as `EvalHook`
        if self.trainer.iter + 1 >= self.trainer.max_iter:
            self._best_checking()
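
Because this hook reads the metric that ``EvalHook`` puts into the trainer storage, it must be registered after that hook and use the same period. A sketch with assumed names (``eval_function``, ``checkpointer``, and the ``"Acc@1"`` metric key are placeholders)::

    trainer.register_hooks(
        [
            EvalHook(1000, eval_function),
            BestCheckpointer(1000, checkpointer, "Acc@1", mode="max"),
        ]
    )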

class EvalHook(HookBase):
    """
    Run an evaluation function periodically, and at the end of training.

    It is executed every ``eval_period`` iterations and after the last iteration.
    """

    def __init__(self, eval_period, eval_function):
        """
        Args:
            eval_period (int): the period to run `eval_function`.
            eval_function (callable): a function which takes no arguments, and
                returns a nested dict of evaluation metrics.

        Note:
            This hook must be enabled in all or none workers.
            If you would like only certain workers to perform evaluation,
            give other workers a no-op function (`eval_function=lambda: None`).
        """
        self._period = eval_period
        self._func = eval_function

    def _do_eval(self):
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                try:
                    v = float(v)
                except Exception:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    )
            self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

        # Evaluation may take different time among workers.
        # A barrier makes them start the next iteration together.
        dist.synchronize()

    def after_step(self):
        next_iter = self.trainer.iter + 1
        if self._period > 0 and next_iter % self._period == 0:
            # do the last eval in after_train
            if next_iter != self.trainer.max_iter:
                self._do_eval()

    def after_train(self):
        # This condition is to prevent the eval from running after a failed training.
        if self.trainer.iter + 1 >= self.trainer.max_iter:
            self._do_eval()
        # func is likely a closure that holds reference to the trainer,
        # therefore we clean it to avoid circular reference in the end.
        del self._func
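
A sketch of the ``eval_function`` contract: it takes no arguments and returns a nested dict whose leaves can be cast to ``float``. The metric names below are placeholders::

    def eval_function():
        # Placeholder metrics; real code would run an evaluator here.
        return {"test": {"Acc@1": 0.0, "Acc@5": 0.0}}

    # Ranks that should not evaluate can be given `lambda: None` instead.
    trainer.register_hooks([EvalHook(1000, eval_function)])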

class LRScheduler(HookBase):
    """
    A hook which executes a oneflow builtin LR scheduler and summarizes the LR.

    It is executed after every iteration.
    """

    def __init__(self, optimizer=None, scheduler=None):
        """
        Args:
            optimizer (flow.optim.Optimizer):
            scheduler (flow.optim.LRScheduler): if a :class:`ParamScheduler` object,
                it defines the multiplier over the base LR in the optimizer.

        If any argument is not given, will try to obtain it from the trainer.
        """
        self._optimizer = optimizer
        self._scheduler = scheduler

    def before_train(self):
        self._optimizer = self._optimizer or self.trainer.optimizer
        self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)

    @staticmethod
    def get_best_param_group_id(optimizer):
        # NOTE: some heuristics on what LR to summarize
        # summarize the param group with most parameters
        largest_group = max(len(g["params"]) for g in optimizer.state_dict()["param_groups"])

        if largest_group == 1:
            # If all groups have one parameter,
            # then find the most common initial LR, and use it for summary.
            lr_count = Counter(
                [g["_options"]["lr"] for g in optimizer.state_dict()["param_groups"]]
            )
            lr = lr_count.most_common()[0][0]
            for i, g in enumerate(optimizer.state_dict()["param_groups"]):
                if g["_options"]["lr"] == lr:
                    return i
        else:
            for i, g in enumerate(optimizer.state_dict()["param_groups"]):
                if len(g["params"]) == largest_group:
                    return i

    def after_step(self):
        lr = self.scheduler.get_last_lr()[self._best_param_group_id]
        self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
        self.scheduler.step()

    @property
    def scheduler(self):
        return self._scheduler or self.trainer.lr_scheduler

    def state_dict(self):
        if isinstance(self.scheduler, flow.optim.lr_scheduler._LRScheduler):
            return self.scheduler.state_dict()
        return {}

    def load_state_dict(self, state_dict):
        if isinstance(self.scheduler, flow.optim.lr_scheduler._LRScheduler):
            logger.info("Loading scheduler from state_dict ...")
            self.scheduler.load_state_dict(state_dict)
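
A sketch: constructed without arguments, the hook falls back to ``trainer.optimizer`` and ``trainer.lr_scheduler`` in ``before_train`` and records the summarized learning rate under the ``"lr"`` scalar key (``trainer`` assumed as above)::

    trainer.register_hooks([LRScheduler()])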