Source code for libai.evaluation.utils

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections.abc import Mapping

import oneflow as flow

from libai.utils import distributed as dist


def pad_batch(x_dict, batch_size, last_batch_lack, is_last_batch):
    """Pad every tensor in ``x_dict`` along the batch dimension up to ``batch_size``.

    Returns the padded dict together with the number of valid (non-padding)
    samples, so the caller can discard the padded entries after inference.
    """
    x = list(x_dict.values())[0]
    tensor_batch = x.shape[0]
    assert tensor_batch <= batch_size

    # A full, non-final batch needs no padding; every sample is valid.
    if tensor_batch == batch_size and not is_last_batch:
        return x_dict, batch_size

    valid_sample = tensor_batch - last_batch_lack
    data_parallel_size = dist.get_data_parallel_size()
    assert tensor_batch % data_parallel_size == 0
    tensor_micro_batch_size = tensor_batch // data_parallel_size
    padded_dict = {}
    for key, xi in x_dict.items():
        pad_shape = (batch_size, *xi.shape[1:])
        # Gather the full tensor onto every device before padding it locally.
        local_xi = xi.to_global(
            sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cuda")
        ).to_local()
        padded_xi = flow.zeros(pad_shape, dtype=xi.dtype, device="cuda")
        padded_xi[:tensor_batch, ...] = padded_xi[:tensor_batch, ...] + local_xi
        # Shift the tail forward one slot per missing sample so the padding
        # lines up with the data-parallel micro-batch boundaries.
        for i in range(last_batch_lack - 1):
            start_idx = tensor_micro_batch_size * (data_parallel_size - i - 1) - 1
            padded_xi[start_idx:-1] = padded_xi[start_idx + 1 :]
        # Restore the original SBP signature and placement.
        padded_xi = padded_xi.to_global(
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), placement=xi.placement
        ).to_global(sbp=xi.sbp)
        padded_dict[key] = padded_xi
    return padded_dict, valid_sample
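
if __name__ == "__main__":
    # A minimal, hypothetical sketch of calling pad_batch on the final
    # evaluation batch. It assumes a single-GPU OneFlow run with CUDA
    # available and libai's default distributed config (data-parallel size 1);
    # the key name "input_ids" and all shapes are made up for illustration.
    placement = flow.env.all_device_placement("cuda")
    batch = {
        "input_ids": flow.ones(8, 16, dtype=flow.int64).to_global(
            sbp=flow.sbp.split(0), placement=placement
        )
    }
    # Suppose the sampler filled the last batch with one duplicate sample:
    # 8 tensors arrive but only 7 of them are valid.
    padded, valid = pad_batch(batch, batch_size=8, last_batch_lack=1, is_last_batch=True)
    print(padded["input_ids"].shape, valid)  # (8, 16) 7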


def flatten_results_dict(results):
    """
    Expand a hierarchical dict of scalars into a flat dict of scalars.
    If results[k1][k2][k3] = v, the returned dict will have the entry
    {"k1/k2/k3": v}.

    Args:
        results (dict):
    """
    r = {}
    for k, v in results.items():
        if isinstance(v, Mapping):
            v = flatten_results_dict(v)
            for kk, vv in v.items():
                r[k + "/" + kk] = vv
        else:
            r[k] = v
    return r
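
if __name__ == "__main__":
    # A quick illustration of flatten_results_dict (pure Python, no OneFlow
    # needed); the metric names below are made up for the example.
    nested = {"accuracy": {"top1": 0.76, "top5": 0.93}, "loss": 1.24}
    print(flatten_results_dict(nested))
    # -> {'accuracy/top1': 0.76, 'accuracy/top5': 0.93, 'loss': 1.24}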