# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections.abc import Mapping

import oneflow as flow

from libai.utils import distributed as dist


def pad_batch(x_dict, batch_size, last_batch_lack, is_last_batch):
    """Pad each tensor in ``x_dict`` along the batch dimension up to ``batch_size``.

    This keeps the last (incomplete) batch of an epoch evenly divisible across
    data-parallel ranks. The returned ``valid_sample`` tells callers how many
    samples in the padded batch are real, so padding can be ignored in metrics.
    """
    # Use the first tensor in the dict to read off the incoming batch size.
    x = list(x_dict.values())[0]
    tensor_batch = x.shape[0]
    assert tensor_batch <= batch_size

    # Full, non-final batches need no padding.
    if tensor_batch == batch_size and not is_last_batch:
        return x_dict, batch_size

    valid_sample = tensor_batch - last_batch_lack
    data_parallel_size = dist.get_data_parallel_size()
    assert tensor_batch % data_parallel_size == 0
    tensor_micro_batch_size = tensor_batch // data_parallel_size
    padded_dict = {}
    for key, xi in x_dict.items():
        pad_shape = (batch_size, *xi.shape[1:])
        # Broadcast the global tensor to every device, then drop to a local
        # tensor so each rank can build the padded batch independently.
        local_xi = xi.to_global(
            sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cuda")
        ).to_local()
        # Copy the existing samples into a zero-filled tensor of the full size.
        padded_xi = flow.zeros(pad_shape, dtype=xi.dtype, device="cuda")
        padded_xi[:tensor_batch, ...] = padded_xi[:tensor_batch, ...] + local_xi
        # The last ``last_batch_lack`` data-parallel shards each end in one
        # padding sample; shift them out (the final shard's pad already sits
        # at the batch tail) so the real samples are contiguous at the front.
        for i in range(last_batch_lack - 1):
            start_idx = tensor_micro_batch_size * (data_parallel_size - i - 1) - 1
            padded_xi[start_idx:-1] = padded_xi[start_idx + 1 :]
        # Convert back to a global tensor with the input's placement and SBP.
        padded_xi = padded_xi.to_global(
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), placement=xi.placement
        ).to_global(sbp=xi.sbp)
        padded_dict[key] = padded_xi
return padded_dict, valid_sample
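

# A minimal sketch of how ``pad_batch`` might be driven, kept in comments
# because it needs a CUDA device and an initialized LiBai distributed
# context. The field name "input_ids", the shapes, and the argument values
# are illustrative assumptions, not part of the LiBai API:
#
#     batch = {
#         "input_ids": flow.ones(
#             (6, 128),
#             dtype=flow.long,
#             sbp=flow.sbp.split(0),
#             placement=flow.env.all_device_placement("cuda"),
#         )
#     }
#     # Pad 6 incoming samples (one of them a sampler duplicate) up to a
#     # full batch of 8 for the final iteration of an epoch.
#     padded, valid = pad_batch(batch, batch_size=8, last_batch_lack=1, is_last_batch=True)
#     assert padded["input_ids"].shape[0] == 8
#     assert valid == 5  # 6 incoming samples minus 1 padding sample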


def flatten_results_dict(results):
"""
Expand a hierarchical dict of scalars into a flat dict of scalars.
If results[k1][k2][k3] = v, the returned dict will have the entry
{"k1/k2/k3": v}.
Args:
results (dict):
"""
r = {}
for k, v in results.items():
if isinstance(v, Mapping):
v = flatten_results_dict(v)
for kk, vv in v.items():
r[k + "/" + kk] = vv
else:
r[k] = v
return r
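

if __name__ == "__main__":
    # A quick, self-contained check of ``flatten_results_dict``. The metric
    # names and values below are made-up illustrations, not LiBai output.
    nested = {"top1": 76.2, "ImageNet": {"val": {"top1": 76.2, "top5": 93.1}}}
    print(flatten_results_dict(nested))
    # Prints: {'top1': 76.2, 'ImageNet/val/top1': 76.2, 'ImageNet/val/top5': 93.1}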