# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
import shutil
from collections import defaultdict
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple
import numpy as np
import oneflow as flow
from oneflow import nn
from termcolor import colored
import libai.utils.distributed as dist
from libai.utils.file_io import HTTPURLHandler, PathManagerBase
class _IncompatibleKeys(
NamedTuple(
# pyre-fixme[10]: Name `IncompatibleKeys` is used but not defined.
"IncompatibleKeys",
[
("missing_keys", List[str]),
("unexpected_keys", List[str]),
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
("incorrect_shapes", List[Tuple]),
],
)
):
pass
class Checkpointer(object):
    """
    A checkpointer that can save/load model as well as extra checkpointable
    objects.

    A checkpoint is stored as a directory ``<save_dir>/<name>`` that holds one
    ``flow.save`` entry per saved object (``model`` plus every registered
    checkpointable and every extra keyword passed to :meth:`save`). The name of
    the most recent checkpoint is recorded in ``<save_dir>/last_checkpoint``.
    """

    # NOTE: only support data_parallel for saving model
    # TODO: save model: support model_parallel and pipeline parallel
    def __init__(
        self,
        model: nn.Module,
        save_dir: str = "",
        *,
        save_to_disk: bool = True,
        **checkpointables: object,
    ):
        """
        Args:
            model (nn.Module): model.
            save_dir (str): a directory to save and find checkpoints.
            save_to_disk (bool): if True, save checkpoint to disk, otherwise
                disable saving for this checkpointer.
                NOTE(review): the flag is stored on the instance but no method
                of this class reads it, so :meth:`save` currently always
                writes. Confirm whether that is intended.
            checkpointables (object): any checkpointable objects, i.e., objects
                that have the `state_dict()` and `load_state_dict()` method. For
                example, it can be used like
                `Checkpointer(model, "dir", optimizer=optimizer)`.
        """
        self.model = model
        # Shallow copy so later changes to the caller's mapping do not leak in.
        self.checkpointables = copy.copy(checkpointables)
        self.logger = logging.getLogger(__name__)
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        # Default PathManager, support HTTP URLs.
        # A user may want to use a different project-specific PathManagerBase.
        self.path_manager: PathManagerBase = PathManagerBase()
        self.path_manager.register_handler(HTTPURLHandler())

    def save(self, name: str, **kwargs: Dict[str, str]):
        """
        Dump model and checkpointables to the directory ``<save_dir>/<name>``.

        Every entry is written with ``flow.save(..., global_dst_rank=0)``.
        Unless ``name`` is ``"model_best"``, the checkpoint is also recorded as
        the latest one via :meth:`tag_last_checkpoint`.

        Args:
            name (str): name of the checkpoint directory. It must be a plain
                basename (no path separators); this is asserted.
            kwargs (dict): extra arbitrary data to save. An entry named
                ``"iteration"`` is accepted but not written.
        """
        data = {}
        data["model"] = self.model.state_dict()
        for key, obj in self.checkpointables.items():
            data[key] = obj.state_dict()
        data.update(kwargs)

        basename = name
        save_dir = os.path.join(self.save_dir, basename)
        assert os.path.basename(save_dir) == basename, basename
        if not self.path_manager.exists(save_dir):
            self.path_manager.mkdirs(save_dir)
        self.logger.info("Saving checkpoint to {}".format(save_dir))

        for save_name in data:
            # The iteration is not stored as an entry; `_load_file` derives it
            # from the checkpoint directory name instead.
            if save_name == "iteration":
                continue
            save_file = os.path.join(save_dir, save_name)
            # NOTE(review): the original comment here read "If directory
            # existing, remove it for saving", but the code only calls
            # `mkdirs` on the already existing path and removes nothing.
            # Behavior is kept as-is; confirm which one was intended.
            if self.path_manager.exists(save_file):
                self.path_manager.mkdirs(save_file)
            flow.save(data[save_name], save_file, global_dst_rank=0)

        if basename != "model_best":
            self.tag_last_checkpoint(basename)

    def load(self, path: str, checkpointables: Optional[List[str]] = None) -> object:
        """
        Load from the given checkpoint. When path points to network file, this
        function has to be called on all ranks.

        Args:
            path (str): path or url to the checkpoint. If empty, will not load
                anything.
            checkpointables (list): List of checkpointable names to load. If not
                specified (None), will load all the possible checkpointables.

        Returns:
            dict:
                extra data loaded from the checkpoint that has not been
                processed. For example, those saved with
                :meth:`.save(**extra_data)`.
        """
        if not path:
            # no checkpoint provided
            self.logger.info("No checkpoint found. Training model from scratch")
            return {}
        self.logger.info("Loading checkpoint from {}".format(path))

        checkpoint = self._load_file(path)
        incompatible = self._load_model(checkpoint)
        if incompatible is not None:  # handle some existing subclasses that returns None
            self._log_incompatible_keys(incompatible)

        for key in self.checkpointables if checkpointables is None else checkpointables:
            if key in checkpoint:  # pyre-ignore
                self.logger.info("Loading {} from {}".format(key, path))
                obj = self.checkpointables[key]
                obj.load_state_dict(checkpoint.pop(key))  # pyre-ignore

        # return any further checkpoint data
        return checkpoint

    def has_checkpoint(self):
        """
        Returns:
            bool: whether a checkpoint exists in the target directory, i.e.
            whether ``<save_dir>/last_checkpoint`` exists.
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        return self.path_manager.exists(save_file)

    def get_checkpoint_file(self):
        """
        Read the ``last_checkpoint`` tag on rank 0 and share it with all ranks.

        Returns:
            str: The latest checkpoint path in the target directory, or ``""``
            when the tag file cannot be read (or is empty).
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        # load checkpoint file in rank0
        if flow.env.get_rank() == 0:
            try:
                with self.path_manager.open(save_file, "r") as f:
                    last_saved = f.read().strip()
            except IOError:
                # if file doesn't exist, maybe because it has just been
                # deleted by a separate process. Do NOT return here: the other
                # ranks still call the broadcast below, so rank 0 has to take
                # part in it and hand them an "empty" marker instead.
                last_saved = ""
        else:
            last_saved = None
        # broadcast checkpoint file to other ranks
        last_saved = dist.broadcast_py_object(last_saved, src=0)
        if not last_saved:
            return ""
        return os.path.join(self.save_dir, last_saved)

    def resume_or_load(self, path: str, *, resume: bool = True):
        """
        If `resume` is True, this method attempts to resume from the last
        checkpoint (if exists). Otherwise, load checkpoint from the given path.
        This is useful when restarting an interrupted training job.

        Args:
            path (str): path to the checkpoint.
            resume (bool): if True, resume from the last checkpoint if it exists.

        Returns:
            same as :meth:`load`.
        """
        if resume and self.has_checkpoint():
            path = self.get_checkpoint_file()
            return self.load(path)
        else:
            # Not resuming: only the model weights are restored, none of the
            # registered checkpointables.
            return self.load(path, checkpointables=[])

    def tag_last_checkpoint(self, last_filename_basename: str):
        """
        Tag the last checkpoint by writing its name to
        ``<save_dir>/last_checkpoint``.

        Args:
            last_filename_basename (str): the basename of the last filename.
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        with self.path_manager.open(save_file, "w") as f:
            f.write(last_filename_basename)  # pyre-ignore

    def _load_file(self, f: str):
        """
        Load a checkpoint directory. Can be overwritten by subclasses to support
        different formats.

        Args:
            f (str): a locally mounted path of a checkpoint directory.

        Returns:
            dict: with keys "model" and optionally others that are saved by
                the checkpointer dict["model"] must be a dict which maps strings
                to flow.Tensor or numpy arrays. The key "iter" is always set:
                it is parsed from the text after the last "_" in ``f`` and
                falls back to 0 when that text is not an integer.
        """
        data = {}
        keys = self.path_manager.ls(f)
        # broadcast checkpointer keys to other ranks
        keys = dist.broadcast_py_object(keys, src=0)
        for key in keys:
            data[key] = flow.load(os.path.join(f, key), global_src_rank=0)

        try:
            # Ignore trailing separators so "model_0000099/" still yields 99.
            data["iter"] = int(f.rstrip("/\\").split("_")[-1])
        except ValueError:
            # Only a failed int() conversion is expected here (e.g. a name
            # such as "model_final"); anything else should surface.
            self.logger.info(f"iter info in {f} not found, set iter to 0")
            data["iter"] = 0
        return data

    def _load_model(self, checkpoint: Any):
        """
        Load weights from a checkpoint.

        Entries whose shape differs from the model's are dropped (and
        reported) instead of being loaded; the rest is loaded with
        ``strict=False``.

        Args:
            checkpoint (Any): checkpoint contains the weights. Its "model"
                entry is popped.

        Returns:
            _IncompatibleKeys: missing keys, unexpected keys and the
            shape-mismatched entries that were skipped.
        """
        checkpoint_state_dict = checkpoint.pop("model")

        # if the state_dict comes from a model that was wrapped in a
        # DataParallel or DistributedDataParallel during serialization,
        # remove the "module" prefix before performing the matching.
        # This has to happen before the tensor conversion below, because that
        # step looks every key up in the model's own state dict.
        _strip_prefix_if_present(checkpoint_state_dict, "module.")
        self._convert_ndarray_to_tensor(checkpoint_state_dict)

        model_state_dict = self.model.state_dict()
        incorrect_shapes = []
        for k in list(checkpoint_state_dict.keys()):
            if k in model_state_dict:
                shape_model = tuple(model_state_dict[k].shape)
                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                if shape_model != shape_checkpoint:
                    incorrect_shapes.append((k, shape_checkpoint, shape_model))
                    checkpoint_state_dict.pop(k)

        incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
        return _IncompatibleKeys(
            missing_keys=incompatible.missing_keys,
            unexpected_keys=incompatible.unexpected_keys,
            incorrect_shapes=incorrect_shapes,
        )

    def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
        """
        Log information about the incompatible keys returned by ``_load_model``.
        """
        for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
            self.logger.warning(
                "Skip loading parameter '{}' to the model due to incompatible "
                "shapes: {} in the checkpoint but {} in the "
                "model! You might want to double check if this is expected.".format(
                    k, shape_checkpoint, shape_model
                )
            )
        if incompatible.missing_keys:
            # A key is not really missing if the same tensor was loaded under
            # another name.
            missing_keys = _filter_reused_missing_keys(self.model, incompatible.missing_keys)
            if missing_keys:
                self.logger.info(get_missing_parameters_message(missing_keys))
        if incompatible.unexpected_keys:
            self.logger.info(get_unexpected_parameters_message(incompatible.unexpected_keys))

    def _convert_ndarray_to_tensor(self, state_dict: dict):
        """
        In-place convert all numpy arrays in the state_dict to flow tensor, and
        turn local tensors into global tensors that follow the model.

        Args:
            state_dict (dict): a state-dict to be loaded to the model.

        Raises:
            ValueError: if a value is neither a numpy array nor a flow tensor.
        """
        # model could be an OrderedDict with _metadata attribute
        # (as returned by oneflow's state_dict()). We should preserve these
        # properties.
        # Fetch the model's state dict once instead of rebuilding it per key.
        model_state_dict = self.model.state_dict()
        for k in list(state_dict.keys()):
            v = state_dict[k]
            if isinstance(v, np.ndarray):
                # numpy arrays have no `is_global`; make them tensors first.
                v = flow.tensor(v)
                state_dict[k] = v
            elif not isinstance(v, flow.Tensor):
                raise ValueError("Unsupported type found in checkpoint! {}: {}".format(k, type(v)))
            # If it's local tensor, convert it to global tensor using the sbp
            # and placement of the model tensor with the same key.
            if not v.is_global and k in model_state_dict:
                model_v = model_state_dict[k]
                state_dict[k] = v.to_global(sbp=model_v.sbp, placement=model_v.placement)
class PeriodicCheckpointer:
    """
    Save checkpoints periodically. When `.step(iteration)` is called, it will
    execute `checkpointer.save` on the given checkpointer, if iteration is a
    multiple of period or if `max_iter` is reached.
    """

    def __init__(
        self,
        checkpointer: Checkpointer,
        period: int,
        max_iter: Optional[int] = None,
        max_to_keep: Optional[int] = None,
        file_prefix: str = "model",
    ):
        """
        Args:
            checkpointer (Any): the checkpointer object used to save
                checkpoints. Its ``path_manager`` is reused by this object.
            period (int): the period to save checkpoint.
            max_iter (int): maximum number of iterations. When it is reached,
                a checkpoint named "{file_prefix}_final" will be saved.
            max_to_keep (int): when given, must be positive; periodic
                checkpoints beyond this count are deleted, oldest first.
            file_prefix (str): prefix of the checkpoint names.
        """
        self.checkpointer = checkpointer
        self.path_manager: PathManagerBase = checkpointer.path_manager
        self.file_prefix = file_prefix
        self.period = int(period)
        self.max_iter = max_iter
        if max_to_keep is not None:
            assert max_to_keep > 0
        self.max_to_keep = max_to_keep
        # Paths of the periodic checkpoints written so far, oldest first.
        self.recent_checkpoints: List[str] = []

    def step(self, iteration: int, **kwargs: Any):
        """
        Perform the appropriate action at the given iteration.

        Args:
            iteration (int): the current iteration, ranged in [0, max_iter-1].
            kwargs (Any): extra data to save, same as in
                :meth:`Checkpointer.save`.
        """
        iteration = int(iteration)
        extra_state = {"iteration": iteration, **kwargs}

        if (iteration + 1) % self.period == 0:
            current_name = "{}_{:07d}".format(self.file_prefix, iteration)
            self.checkpointer.save(current_name, **extra_state)

            if self.max_to_keep is not None:
                self.recent_checkpoints.append(self.checkpointer.get_checkpoint_file())
                if len(self.recent_checkpoints) > self.max_to_keep:
                    oldest = self.recent_checkpoints.pop(0)
                    # Only the main process deletes, and never the checkpoint
                    # that was written just now.
                    should_delete = (
                        dist.is_main_process()
                        and self.path_manager.exists(oldest)
                        and not oldest.endswith(current_name)
                    )
                    if should_delete:
                        if self.path_manager.isfile(oldest):
                            self.path_manager.rm(oldest)
                        else:
                            shutil.rmtree(oldest)

        if self.max_iter is not None and iteration >= self.max_iter - 1:
            self.checkpointer.save(f"{self.file_prefix}_final", **extra_state)

    def save(self, name: str, **kwargs: Any):
        """
        Same argument as :meth:`Checkpointer.save`.
        Use this method to manually save checkpoints outside the schedule.

        Args:
            name (str): file name.
            kwargs (Any): extra data to save, same as in
                :meth:`Checkpointer.save`.
        """
        self.checkpointer.save(name, **kwargs)
def _filter_reused_missing_keys(model: nn.Module, keys: List[str]) -> List[str]:
    """
    Drop "missing" keys whose tensor is also reachable under a name that is
    not missing, i.e. keys that have effectively been loaded with another name.

    Args:
        model (nn.Module): the model the keys belong to.
        keys (list[str]): keys reported as missing.

    Returns:
        list[str]: the keys that are still considered missing.
    """
    still_missing = set(keys)

    # tensor -> every fully qualified name that points to it
    aliases = defaultdict(set)
    for module_prefix, module in _named_modules_with_dup(model):
        dotted_prefix = module_prefix + "." if module_prefix else ""
        members = list(module.named_parameters(recurse=False))
        members += list(module.named_buffers(recurse=False))  # pyre-ignore
        for name, tensor in members:
            aliases[tensor].add(dotted_prefix + name)

    for names in aliases.values():
        missing_aliases = names & still_missing
        # Some, but not all, aliases are missing: the tensor was loaded through
        # one of the other names, so none of its aliases counts as missing.
        if missing_aliases and missing_aliases != names:
            still_missing -= missing_aliases
    return list(still_missing)
def get_missing_parameters_message(keys: List[str]) -> str:
    """
    Build a log message listing the keys that exist in the model but were not
    found in a checkpoint. Keys are grouped by common prefix and colored blue.

    Args:
        keys (list[str]): List of keys that were not found in the checkpoint.

    Returns:
        str: message.
    """
    header = "Some model parameters or buffers are not found in the checkpoint:\n"
    entries = [
        " " + colored(prefix + _group_to_str(suffixes), "blue")
        for prefix, suffixes in _group_checkpoint_keys(keys).items()
    ]
    return header + "\n".join(entries)
def get_unexpected_parameters_message(keys: List[str]) -> str:
    """
    Build a log message listing the keys that exist in the checkpoint but are
    not used by the model. Keys are grouped by common prefix and colored
    magenta.

    Args:
        keys (list[str]): List of keys that were not found in the model.

    Returns:
        str: message.
    """
    header = "The checkpoint state_dict contains keys that are not used by the model:\n"
    entries = [
        " " + colored(prefix + _group_to_str(suffixes), "magenta")
        for prefix, suffixes in _group_checkpoint_keys(keys).items()
    ]
    return header + "\n".join(entries)
def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
"""
Strip the prefix in metadata, if any.
Args:
state_dict (OrderedDict): a state-dict to be loaded to the model.
prefix (str): prefix.
"""
keys = sorted(state_dict.keys())
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
return
for key in keys:
newkey = key[len(prefix) :]
state_dict[newkey] = state_dict.pop(key)
# also strip the prefix in metadata, if any..
try:
metadata = state_dict._metadata # pyre-ignore
except AttributeError:
pass
else:
for key in list(metadata.keys()):
# for the metadata dict, the key can be:
# '': for the DDP module, which we want to remove.
# 'module': for the actual model.
# 'module.xx.xx': for the rest.
if len(key) == 0:
continue
newkey = key[len(prefix) :]
metadata[newkey] = metadata.pop(key)
def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
"""
Group keys based on common prefixes. A prefix is the string up to the final
"." in each key.
Args:
keys (list[str]): list of parameter names, i.e. keys in the model
checkpoint dict.
Returns:
dict[list]: keys with common prefixes are grouped into lists.
"""
groups = defaultdict(list)
for key in keys:
pos = key.rfind(".")
if pos >= 0:
head, tail = key[:pos], [key[pos + 1 :]]
else:
head, tail = key, []
groups[head].extend(tail)
return groups
def _group_to_str(group: List[str]) -> str:
"""
Format a group of parameter name suffixes into a loggable string.
Args:
group (list[str]): list of parameter name suffixes.
Returns:
str: formated string.
"""
if len(group) == 0:
return ""
if len(group) == 1:
return "." + group[0]
return ".{" + ", ".join(group) + "}"
def _named_modules_with_dup(model: nn.Module, prefix: str = "") -> Iterable[Tuple[str, nn.Module]]:
    """
    Walk ``model`` and all of its submodules depth-first, yielding
    ``(qualified_name, module)`` pairs. Unlike `model.named_modules()`, a
    module registered under several names is yielded once per name.
    """
    yield prefix, model
    for child_name, child in model._modules.items():  # pyre-ignore
        if child is not None:
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            yield from _named_modules_with_dup(child, child_prefix)