# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import logging
import math
import operator
import time
from collections import Counter

import oneflow as flow

from libai.evaluation import flatten_results_dict
from libai.utils import distributed as dist
from libai.utils.checkpoint import Checkpointer
from libai.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from libai.utils.events import EventWriter
from libai.utils.timer import Timer

from .trainer import HookBase

# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/hooks.py
# --------------------------------------------------------
"""
Implement some common hooks.
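
Example (illustrative; assumes a trainer built on
``libai.engine.trainer.TrainerBase``, which registers hooks via ``register_hooks``)::

    trainer.register_hooks([IterationTimer(), LRScheduler()])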
"""

logger = logging.getLogger(__name__)


class CallbackHook(HookBase):
"""
Create a hook using callback functions provided by the user.
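
    Example (a minimal sketch; each callback receives the trainer instance)::

        hook = CallbackHook(after_step=lambda trainer: print("iter", trainer.iter))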
"""

    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
"""
Each argument is a function that takes one argument: the trainer.
"""
self._before_train = before_train
self._before_step = before_step
self._after_step = after_step
self._after_train = after_train

    def before_train(self):
if self._before_train:
self._before_train(self.trainer)

    def after_train(self):
if self._after_train:
self._after_train(self.trainer)
# The functions may be closures that hold reference to the trainer
# Therefore, delete them to avoid circular reference.
del self._before_train, self._after_train
del self._before_step, self._after_step

    def before_step(self):
if self._before_step:
self._before_step(self.trainer)

    def after_step(self):
if self._after_step:
self._after_step(self.trainer)


class IterationTimer(HookBase):
"""
Track the time spent for each iteration (each run_step call in the trainer).
    Print a summary at the end of training.
    This hook uses the time between the call to its :meth:`before_step`
    and :meth:`after_step` methods.
    Under the convention that :meth:`before_step` of all hooks should only
    take a negligible amount of time, the :class:`IterationTimer` hook should be
    placed at the beginning of the list of hooks to obtain accurate timing.
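
    Example (a sketch; register the timer first so other hooks' overhead is
    excluded from the per-step timing)::

        trainer.register_hooks([IterationTimer(warmup_iter=5), LRScheduler()])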
"""

    def __init__(self, warmup_iter=3):
"""
Args:
warmup_iter (int): the number of iterations at the beginning to exclude
from timing.
"""
self._warmup_iter = warmup_iter
self._step_timer = Timer()

    def before_train(self):
self._start_time = time.perf_counter()
self._total_timer = Timer()
self._total_timer.pause()

    def after_train(self):
total_time = time.perf_counter() - self._start_time
total_time_minus_hooks = self._total_timer.seconds()
hook_time = total_time - total_time_minus_hooks
num_iter = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter
if num_iter > 0 and total_time_minus_hooks > 0:
# Speed is meaningful only after warmup
# NOTE this format is parsed by grep in some scripts
logger.info(
"Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
num_iter,
str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
total_time_minus_hooks / num_iter,
)
)
logger.info(
"Total training time: {} ({} on hooks)".format(
str(datetime.timedelta(seconds=int(total_time))),
str(datetime.timedelta(seconds=int(hook_time))),
)
)

    def before_step(self):
self._step_timer.reset()
self._total_timer.resume()

    def after_step(self):
# +1 because we're in after_step
iter_done = self.trainer.iter - self.trainer.start_iter + 1
if iter_done >= self._warmup_iter:
sec = self._step_timer.seconds()
self.trainer.storage.put_scalars(time=sec)
else:
self._start_time = time.perf_counter()
self._total_timer.reset()
self._total_timer.pause()


class PeriodicWriter(HookBase):
"""
Write events to EventStorage periodically.
It is executed every ``period`` iterations and after the last iteration.
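
    Example (a sketch with a hypothetical writer; :class:`EventWriter` is the
    base class imported at the top of this file)::

        class PrintWriter(EventWriter):
            def write(self):
                ...  # e.g. read the latest scalars from the trainer's storage

        hook = PeriodicWriter([PrintWriter()], period=20)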
"""

    def __init__(self, writers, period=20):
"""
Args:
writers (list[EventWriter]): a list of EventWriter objects
            period (int): the interval (in iterations) between writes; defaults to 20.
"""
self._writers = writers
for w in writers:
assert isinstance(w, EventWriter), w
self._period = period

    def after_step(self):
if (self.trainer.iter + 1) % self._period == 0 or (
self.trainer.iter == self.trainer.max_iter - 1
):
for writer in self._writers:
writer.write()

    def after_train(self):
for writer in self._writers:
writer.close()


class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
"""
Same as :class:`libai.utils.checkpoint.PeriodicCheckpointer`, but as a hook.
Note that when used as a hook,
it is unable to save additional data other than what's defined
by the given `checkpointer`.
It is executed every ``period`` iterations and after the last iteration.
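
    Example (illustrative; the constructor is inherited from
    :class:`libai.utils.checkpoint.PeriodicCheckpointer`)::

        trainer.register_hooks([PeriodicCheckpointer(checkpointer, period=5000)])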
"""

    def before_train(self):
self.max_iter = self.trainer.max_iter

    def after_step(self):
self.step(self.trainer.iter)


class BestCheckpointer(HookBase):
"""
    Checkpoint the best model weights based on a given metric.
    This hook should be used in conjunction with, and executed after, the hook
    that produces the metric, e.g. `EvalHook`.
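
    Example (a sketch; assumes evaluation runs every ``eval_period`` iterations
    and stores an "acc@1" scalar in the trainer's storage)::

        BestCheckpointer(eval_period, checkpointer, "acc@1", mode="max")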
"""

    def __init__(
self,
eval_period: int,
checkpointer: Checkpointer,
val_metric: str,
mode: str = "max",
file_prefix: str = "model_best",
) -> None:
"""
Args:
eval_period (int): the period `EvalHook` is set to run.
checkpointer: the checkpointer object used to save checkpoints.
val_metric (str): validation metric to track for best checkpoint, e.g. "acc@1"
            mode (str): one of {'max', 'min'}. Controls whether the chosen val metric should be
                maximized or minimized, e.g. for "acc@1" it should be "max".
file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best"
"""
self._period = eval_period
self._val_metric = val_metric
        assert mode in [
            "max",
            "min",
        ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of "max", "min".'
if mode == "max":
self._compare = operator.gt
else:
self._compare = operator.lt
self._checkpointer = checkpointer
self._file_prefix = file_prefix
self.best_metric = None
self.best_iter = None

    def _update_best(self, val, iteration):
        # Ignore NaN/Inf metric values; they never count as an improvement.
        if math.isnan(val) or math.isinf(val):
            return False
self.best_metric = val
self.best_iter = iteration
return True

    def _best_checking(self):
        metric_tuple = self.trainer.storage.latest().get(self._val_metric)
        # Only the main process decides whether a new best was reached; the
        # decision is broadcast to all ranks through `flag` so that every rank
        # calls `save` together.
        flag = flow.zeros(1)
        if dist.is_main_process():
if metric_tuple is None:
logger.warning(
f"Given val metric {self._val_metric} does not seem to be computed/stored. "
"Will not be checkpointed based on that."
)
else:
latest_metric, metric_iter = metric_tuple
if self.best_metric is None:
if self._update_best(latest_metric, metric_iter):
flag = flag + 1
logger.info(
f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
)
elif self._compare(latest_metric, self.best_metric):
flag = flag + 1
logger.info(
f"Saved best model as latest eval score for {self._val_metric} is "
f"{latest_metric:0.5f}, better than last best score "
f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
)
self._update_best(latest_metric, metric_iter)
else:
logger.info(
f"Not saving as latest eval score for "
f"{self._val_metric} is {latest_metric:0.5f}, "
f"not better than best score {self.best_metric:0.5f} "
f"@ iteration {self.best_iter}."
)
dist.synchronize()
flag = flag.to_global(
sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cpu")
)
if flag.to_local().item() == 1:
            self._checkpointer.save(self._file_prefix)

    def after_step(self):
# same conditions as `EvalHook`
next_iter = self.trainer.iter + 1
if (
self._period > 0
and next_iter % self._period == 0
and next_iter != self.trainer.max_iter
):
self._best_checking()

    def after_train(self):
# same conditions as `EvalHook`
if self.trainer.iter + 1 >= self.trainer.max_iter:
self._best_checking()


class EvalHook(HookBase):
"""
Run an evaluation function periodically, and at the end of training.
It is executed every ``eval_period`` iterations and after the last iteration.
"""

    def __init__(self, eval_period, eval_function):
"""
Args:
eval_period (int): the period to run `eval_function`.
eval_function (callable): a function which takes no arguments, and
returns a nested dict of evaluation metrics.
Note:
            This hook must be enabled in either all workers or none.
If you would like only certain workers to perform evaluation,
give other workers a no-op function (`eval_function=lambda: None`).
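
        Example (illustrative; ``do_test`` is a hypothetical function that
        returns a nested dict of metrics)::

            EvalHook(eval_period=5000, eval_function=lambda: do_test(model))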
"""
self._period = eval_period
self._func = eval_function

    def _do_eval(self):
results = self._func()
if results:
assert isinstance(
results, dict
), "Eval function must return a dict. Got {} instead.".format(results)
flattened_results = flatten_results_dict(results)
for k, v in flattened_results.items():
try:
v = float(v)
except Exception:
raise ValueError(
"[EvalHook] eval_function should return a nested dict of float. "
"Got '{}: {}' instead.".format(k, v)
)
self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
# Evaluation may take different time among workers.
# A barrier make them start the next iteration together.
dist.synchronize()

    def after_step(self):
next_iter = self.trainer.iter + 1
if self._period > 0 and next_iter % self._period == 0:
# do the last eval in after_train
if next_iter != self.trainer.max_iter:
self._do_eval()

    def after_train(self):
# This condition is to prevent the eval from running after a failed training
if self.trainer.iter + 1 >= self.trainer.max_iter:
self._do_eval()
# func is likely a closure that holds reference to the trainer
# therefore we clean it to avoid circular reference in the end
del self._func


class LRScheduler(HookBase):
"""
    A hook which executes a OneFlow built-in LR scheduler and summarizes the LR.
It is executed after every iteration.
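
    Example (a sketch; with no arguments, the optimizer and scheduler are
    obtained from the trainer)::

        trainer.register_hooks([LRScheduler()])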
"""

    def __init__(self, optimizer=None, scheduler=None):
"""
        Args:
            optimizer (flow.optim.Optimizer): the optimizer whose LR is summarized.
            scheduler (flow.optim.lr_scheduler._LRScheduler): the scheduler to step
                after every iteration.
        If any argument is not given, will try to obtain it from the trainer.
"""
self._optimizer = optimizer
self._scheduler = scheduler

    def before_train(self):
self._optimizer = self._optimizer or self.trainer.optimizer
self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)

    @staticmethod
def get_best_param_group_id(optimizer):
# NOTE: some heuristics on what LR to summarize
# summarize the param group with most parameters
largest_group = max(len(g["params"]) for g in optimizer.state_dict()["param_groups"])
if largest_group == 1:
# If all groups have one parameter,
# then find the most common initial LR, and use it for summary
lr_count = Counter(
[g["_options"]["lr"] for g in optimizer.state_dict()["param_groups"]]
)
lr = lr_count.most_common()[0][0]
for i, g in enumerate(optimizer.state_dict()["param_groups"]):
if g["_options"]["lr"] == lr:
return i
else:
for i, g in enumerate(optimizer.state_dict()["param_groups"]):
if len(g["params"]) == largest_group:
return i

    def after_step(self):
lr = self.scheduler.get_last_lr()[self._best_param_group_id]
self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
self.scheduler.step()

    @property
def scheduler(self):
return self._scheduler or self.trainer.lr_scheduler

    def state_dict(self):
if isinstance(self.scheduler, flow.optim.lr_scheduler._LRScheduler):
return self.scheduler.state_dict()
return {}

    def load_state_dict(self, state_dict):
if isinstance(self.scheduler, flow.optim.lr_scheduler._LRScheduler):
logger.info("Loading scheduler from state_dict ...")
self.scheduler.load_state_dict(state_dict)