# Source code for libai.config.config

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import os

import pkg_resources
from omegaconf import OmegaConf

from .lazy import LazyConfig

# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/config/config.py
# --------------------------------------------------------


def configurable(init_func=None, *, from_config=None):
    """
    Decorate a function or a class's ``__init__`` method so that it can be called
    with an omegaconf :class:`DictConfig` object, using a :func:`from_config`
    function that translates the config into keyword arguments.

    Examples:

    .. code-block:: python

        # Usage 1: Decorator on __init__:
        class A:
            @configurable
            def __init__(self, a, b=2, c=3):
                pass

            @classmethod
            def from_config(cls, cfg):  # 'cfg' must be the first argument
                # Returns kwargs to be passed to __init__
                return {"a": cfg.A, "b": cfg.B}

        a1 = A(a=1, b=2)  # regular construction
        a2 = A(cfg)  # construct with a cfg
        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite

        # Usage 2: Decorator on any function. Needs an extra from_config argument:
        @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
        def a_func(a, b=2, c=3):
            pass

        a1 = a_func(a=1, b=2)  # regular call
        a2 = a_func(cfg)  # call with a cfg
        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite

    Args:
        init_func (callable): a class's ``__init__`` method in usage 1. The class
            must have a ``from_config`` classmethod which takes ``cfg`` as the
            first argument.
        from_config (callable): the from_config function in usage 2. It must
            take ``cfg`` as its first argument.

    Returns:
        In usage 1, the wrapped ``__init__``. In usage 2, a decorator that wraps
        the target function and attaches ``from_config`` to it as the attribute
        ``wrapped.from_config``. ``@configurable()`` with no arguments returns
        ``configurable`` itself, so it behaves like the bare ``@configurable``.

    Raises:
        AssertionError: at decoration time, if the decorator is misused. Examples
            are decorating something other than an ``__init__`` function, or
            passing a ``from_config`` that is not a function.
        AttributeError: in usage 1, when the decorated class has no
            ``from_config`` attribute. Raised when the class is constructed.
        TypeError: in usage 1, when ``from_config`` is not a classmethod. Raised
            when the class is constructed.
    """
    if init_func is not None:
        assert (
            inspect.isfunction(init_func)
            and from_config is None
            and init_func.__name__ == "__init__"
        ), "Incorrect use of @configurable. Check API documentation for examples."

        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            # Looked up on the type (not the instance) so that a bound classmethod
            # is obtained; `inspect.ismethod` below then rejects staticmethods and
            # plain functions.
            try:
                from_config_func = type(self).from_config
            except AttributeError as e:
                raise AttributeError(
                    "Class with @configurable must have a 'from_config' classmethod."
                ) from e
            if not inspect.ismethod(from_config_func):
                raise TypeError("Class with @configurable must have a 'from_config' classmethod.")

            if _called_with_cfg(*args, **kwargs):
                # Translate the config (plus any overrides) into explicit kwargs.
                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
                init_func(self, **explicit_args)
            else:
                init_func(self, *args, **kwargs)

        return wrapped

    else:
        if from_config is None:
            return configurable  # @configurable() is made equivalent to @configurable
        assert inspect.isfunction(
            from_config
        ), "from_config argument of configurable must be a function!"

        def wrapper(orig_func):
            @functools.wraps(orig_func)
            def wrapped(*args, **kwargs):
                if _called_with_cfg(*args, **kwargs):
                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                    return orig_func(**explicit_args)
                else:
                    return orig_func(*args, **kwargs)

            # Expose the translator so callers can reuse it directly.
            wrapped.from_config = from_config
            return wrapped

        return wrapper


def _get_args_from_config(from_config_func, *args, **kwargs): """ Use `from_config` to obtain explicit arguments. Returns: dict: arguments to be used for cls.__init__ """ signature = inspect.signature(from_config_func) if list(signature.parameters.keys())[0] != "cfg": if inspect.isfunction(from_config_func): name = from_config_func.__name__ else: name = f"{from_config_func.__self__}.from_config" raise TypeError(f"{name} must take 'cfg' as the first argument!") support_var_arg = any( param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD] for param in signature.parameters.values() ) if support_var_arg: # forward all arguments to from_config, if from_config accepts them ret = from_config_func(*args, **kwargs) else: # forward supported arguments to from_config supported_arg_names = set(signature.parameters.keys()) extra_kwargs = {} for name in list(kwargs.keys()): if name not in supported_arg_names: extra_kwargs[name] = kwargs.pop(name) ret = from_config_func(*args, **kwargs) # forward the other arguments to __init__ ret.update(extra_kwargs) return ret def _called_with_cfg(*args, **kwargs): """ Returns: bool: whether the arguments contain CfgNode and should be considered forwarded to from_config. """ from omegaconf import DictConfig if len(args) and isinstance(args[0], DictConfig): return True if isinstance(kwargs.pop("cfg", None), DictConfig): return True # `from_config`'s first argument is forced to be "cfg". # So the above check covers all cases. return False
def try_get_key(cfg, *keys, default=None):
    """
    Return the value of the first key in ``keys`` that is present in ``cfg``.

    Each key is looked up in order with ``OmegaConf.select``. The first lookup
    that finds a value wins, even if that value is ``None``. If no key matches,
    ``default`` is returned.
    """
    # A private sentinel distinguishes "key absent" from a stored None value.
    missing = object()
    for key in keys:
        found = OmegaConf.select(cfg, key, default=missing)
        if found is not missing:
            return found
    return default


def get_config(config_path):
    """
    Load a config bundled with LiBai by its path inside the "configs/" directory.

    Args:
        config_path (str): config file name relative to libai's "configs/"
            directory, e.g., "common/models/bert.py"

    Returns:
        omegaconf.DictConfig: the object produced by ``LazyConfig.load`` for
            that file.

    Raises:
        RuntimeError: if the resolved file does not exist.
    """
    relative_path = os.path.join("configs", config_path)
    cfg_file = pkg_resources.resource_filename("libai.config", relative_path)
    if os.path.exists(cfg_file):
        return LazyConfig.load(cfg_file)
    raise RuntimeError("{} not available in LiBai configs!".format(config_path))