Source code for libai.optim.build

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections import defaultdict
from typing import Any, Dict, List

import oneflow as flow

from libai.config import instantiate
from libai.layers import LayerNorm

# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/solver/build.py
# --------------------------------------------------------


def build_optimizer(cfg, model):
    """
    Build an optimizer from config.
    """
    cfg.params.model = model
    optim = instantiate(cfg)
    return optim
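

# Illustrative sketch (not part of the original module): one way an optimizer config can be
# assembled and handed to ``build_optimizer``. ``LazyCall`` is assumed to come from LiBai's
# lazy config system, and the hyperparameter values below are placeholders.
def _example_build_optimizer(model):
    from libai.config import LazyCall  # assumed import path

    optim_cfg = LazyCall(flow.optim.AdamW)(
        params=LazyCall(get_default_optimizer_params)(
            # ``cfg.params.model`` is filled in by ``build_optimizer`` at call time.
            weight_decay_norm=0.0,
            weight_decay_bias=0.0,
        ),
        lr=1e-4,
        weight_decay=0.01,
    )
    return build_optimizer(optim_cfg, model)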


def get_default_optimizer_params(
    model,
    base_lr=None,
    weight_decay=None,
    weight_decay_norm=None,
    weight_decay_bias=None,
    clip_grad_max_norm=None,
    clip_grad_norm_type=None,
    overrides=None,
):
    """
    Get default param list for optimizer, with support for a few types of overrides.
    If no overrides are needed, it is equivalent to `model.parameters()`.

    Arguments:
        base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
        weight_decay: weight decay for every group by default. Can be omitted to use the one
            in optimizer.
        weight_decay_norm: override weight decay for params in normalization layers
        weight_decay_bias: override weight decay for bias parameters
        clip_grad_max_norm: max norm used for gradient clipping in every group. Only takes
            effect when ``clip_grad_norm_type`` is also given.
        clip_grad_norm_type: norm type used for gradient clipping in every group. Only takes
            effect when ``clip_grad_max_norm`` is also given.
        overrides: if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
            weight decay values for all module parameters named `embedding`.

    For common transformer models, ``weight_decay_norm`` and ``weight_decay_bias``
    are usually set to 0.

    Example:
    ::
        flow.optim.AdamW(
            get_default_optimizer_params(model, weight_decay_norm=0, weight_decay_bias=0),
            lr=0.01,
            weight_decay=1e-4
        )
    """
    if overrides is None:
        overrides = {}
    defaults = {}
    if base_lr is not None:
        defaults["lr"] = base_lr
    if weight_decay is not None:
        defaults["weight_decay"] = weight_decay
    if clip_grad_max_norm is not None and clip_grad_norm_type is not None:
        defaults["clip_grad_max_norm"] = clip_grad_max_norm
        defaults["clip_grad_norm_type"] = clip_grad_norm_type

    bias_overrides = {}
    if weight_decay_bias is not None:
        bias_overrides["weight_decay"] = weight_decay_bias
    if len(bias_overrides):
        if "bias" in overrides:
            raise ValueError("Conflicting overrides for 'bias'")
        overrides["bias"] = bias_overrides

    norm_module_types = (
        LayerNorm,
        flow.nn.BatchNorm1d,
        flow.nn.BatchNorm2d,
        flow.nn.BatchNorm3d,
        flow.nn.GroupNorm,
        flow.nn.InstanceNorm1d,
        flow.nn.InstanceNorm2d,
        flow.nn.InstanceNorm3d,
        flow.nn.FusedBatchNorm1d,
        flow.nn.FusedBatchNorm2d,
        flow.nn.FusedBatchNorm3d,
    )

    params = []
    memo = set()
    for module in model.modules():
        for model_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            hyperparams = copy.copy(defaults)
            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
                hyperparams["weight_decay"] = weight_decay_norm
            hyperparams.update(overrides.get(model_param_name, {}))
            params.append({"params": [value], **hyperparams})
    return reduce_param_groups(params)
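

# Illustrative sketch (not part of the original module): on a toy model, bias and
# normalization parameters pick up the overridden weight decay, and groups that end up with
# identical hyperparameters are merged by ``reduce_param_groups``.
def _example_default_optimizer_params():
    toy = flow.nn.Sequential(flow.nn.Linear(16, 16), flow.nn.BatchNorm1d(16))
    groups = get_default_optimizer_params(
        toy,
        base_lr=1e-3,
        weight_decay=0.01,
        weight_decay_norm=0.0,
        weight_decay_bias=0.0,
    )
    # Two groups are expected here: the Linear weight with weight_decay=0.01, and the
    # Linear bias plus the BatchNorm parameters with weight_decay=0.0.
    return groups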


def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Transform parameter groups into per-parameter structure.
    Later items in `params` can overwrite parameters set in previous items.
    """
    ret = defaultdict(dict)
    for item in params:
        assert "params" in item
        cur_params = {x: y for x, y in item.items() if x != "params"}
        for param in item["params"]:
            ret[param].update({"params": [param], **cur_params})
    return list(ret.values())


def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Reorganize the parameter groups and merge duplicated groups.

    The number of parameter groups needs to be as small as possible in order to use the
    OneFlow multi-tensor optimizer efficiently. Therefore, instead of creating one parameter
    group per single parameter, we reorganize the parameter groups and merge duplicated
    groups, which speeds up the multi-tensor optimizer significantly.
    """
    params = _expand_param_groups(params)
    groups = defaultdict(list)  # re-group all parameter groups by their hyperparams
    for item in params:
        cur_params = tuple((x, y) for x, y in item.items() if x != "params")
        groups[cur_params].extend(item["params"])
    ret = []
    for param_keys, param_values in groups.items():
        cur = {kv[0]: kv[1] for kv in param_keys}
        cur["params"] = param_values
        ret.append(cur)
    return ret
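

# Illustrative sketch (not part of the original module): ``reduce_param_groups`` collapses
# groups that share identical hyperparameters, so the optimizer sees as few groups as
# possible. The parameters below are stand-ins created only for this example.
def _example_reduce_param_groups():
    w1 = flow.nn.Parameter(flow.ones(2, 2))
    w2 = flow.nn.Parameter(flow.ones(2, 2))
    b = flow.nn.Parameter(flow.zeros(2))
    groups = [
        {"params": [w1], "lr": 1e-3, "weight_decay": 0.01},
        {"params": [w2], "lr": 1e-3, "weight_decay": 0.01},
        {"params": [b], "lr": 1e-3, "weight_decay": 0.0},
    ]
    # ``w1`` and ``w2`` share identical hyperparameters and collapse into a single group,
    # while ``b`` keeps its own: the result contains 2 groups instead of 3.
    return reduce_param_groups(groups)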