Source code for libai.data.samplers.samplers

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import oneflow as flow
from oneflow.utils.data import Sampler


class CyclicSampler(Sampler):
    """
    This sampler supports cyclic sampling, and it is also compatible with
    non-data parallelism and data parallelism.

    Arguments:
        dataset: dataset to be sampled.
        micro_batch_size: batch size per model instance;
            global_batch_size is micro_batch_size times data_parallel_size.
        shuffle: whether to shuffle the dataset.
        consumed_samples: the number of samples that have already been trained at the
            current time, used for resuming training (default: ``0``).
        data_parallel_rank: local rank for data parallelism.
        data_parallel_size: the size of data parallelism.
        seed: random seed, used for reproducing experiments (default: ``0``).
    """

    def __init__(
        self,
        dataset,
        micro_batch_size,
        shuffle=False,
        consumed_samples=0,
        data_parallel_rank=0,
        data_parallel_size=1,
        seed=0,
    ):
        self.dataset = dataset
        self.data_size = len(self.dataset)
        self.shuffle = shuffle

        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_size = micro_batch_size
        self.actual_batch_size = self.micro_batch_size * self.data_parallel_size
        self.data_size_per_epoch = self.data_size // self.actual_batch_size * self.micro_batch_size
        self.consumed_samples = consumed_samples

        self.seed = seed

    def __iter__(self):
        """Divide the data into data_parallel_size buckets and shuffle it
        if `shuffle` is set to `True`. Each processor samples from its own
        bucket, and the data_loader will load the corresponding data.
        """
        epoch = self.consumed_samples // self.data_size_per_epoch
        current_epoch_samples = self.consumed_samples % self.data_size_per_epoch
        batch = []

        while True:
            bucket_offset = current_epoch_samples // self.data_parallel_size
            start_idx = self.data_parallel_rank * self.data_size_per_epoch

            if self.shuffle:
                generator = flow.Generator()
                generator.manual_seed(self.seed + epoch)
                random_idx = flow.randperm(self.data_size_per_epoch, generator=generator).tolist()
                indices = [start_idx + x for x in random_idx[bucket_offset:]]
            else:
                seq_idx = flow.arange(self.data_size_per_epoch).tolist()
                indices = [start_idx + x for x in seq_idx[bucket_offset:]]

            epoch += 1

            if hasattr(self.dataset, "supports_prefetch") and self.dataset.supports_prefetch:
                self.dataset.prefetch(indices)

            for idx in indices:
                batch.append(idx)
                if len(batch) == self.micro_batch_size:
                    self.consumed_samples += self.actual_batch_size
                    yield batch
                    batch = []

            current_epoch_samples = 0

    def __len__(self):
        return self.data_size

    def set_consumed_samples(self, consumed_samples):
        """You can recover the training iteration by setting `consumed_samples`."""
        self.consumed_samples = consumed_samples

    def set_epoch(self, epoch):
        """Used for restoring training status."""
        self.epoch = epoch
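
# Illustrative usage sketch (not part of the original module): drive
# CyclicSampler by hand to see the index batches one data-parallel rank would
# receive. The in-memory ``list(range(16))`` dataset, the helper name
# ``_demo_cyclic_sampler``, and the ``itertools.islice`` cap are assumptions
# made for demonstration; in real training the sampler is typically passed to
# a DataLoader as a batch sampler rather than iterated directly.
def _demo_cyclic_sampler(data_parallel_rank=0, data_parallel_size=2):
    import itertools

    dataset = list(range(16))  # any object implementing __len__ works
    sampler = CyclicSampler(
        dataset,
        micro_batch_size=2,
        shuffle=True,
        data_parallel_rank=data_parallel_rank,
        data_parallel_size=data_parallel_size,
        seed=1234,
    )
    # To resume training, pass the checkpointed value instead of 0 so the
    # sampler skips the samples that were already consumed.
    sampler.set_consumed_samples(0)

    # CyclicSampler.__iter__ loops forever, so cap the iteration when inspecting.
    for batch in itertools.islice(iter(sampler), 6):
        print(batch)  # each batch is a list of micro_batch_size dataset indices
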
class SingleRoundSampler(Sampler):
    """
    This sampler supports single-round sampling, and it is also compatible with
    non-data parallelism and data parallelism.

    Arguments:
        dataset: dataset to be sampled.
        micro_batch_size: batch size per model instance;
            global_batch_size is micro_batch_size times data_parallel_size.
        shuffle: whether to shuffle the dataset.
        data_parallel_rank: local rank for data parallelism.
        data_parallel_size: the size of data parallelism.
        seed: random seed, used for reproducing experiments (default: ``0``).
        drop_last: whether to drop the remaining data (default: ``False``).
    """

    def __init__(
        self,
        dataset,
        micro_batch_size,
        shuffle=False,
        data_parallel_rank=0,
        data_parallel_size=1,
        seed=0,
        drop_last=False,
    ):
        self.dataset = dataset
        self.data_size = len(self.dataset)
        self.shuffle = shuffle

        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_size = micro_batch_size

        self.seed = seed
        self.drop_last = drop_last

    def __iter__(self):
        bucket_size = self.data_size // self.data_parallel_size
        remain = self.data_size % self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        if self.data_parallel_rank < remain:
            bucket_size += 1
        start_idx += min(self.data_parallel_rank, remain)

        if self.shuffle:
            generator = flow.Generator()
            generator.manual_seed(self.seed)
            random_idx = flow.randperm(bucket_size, generator=generator).tolist()
            indices = [start_idx + x for x in random_idx]
        else:
            seq_idx = flow.arange(bucket_size).tolist()
            indices = [start_idx + x for x in seq_idx]

        if hasattr(self.dataset, "supports_prefetch") and self.dataset.supports_prefetch:
            self.dataset.prefetch(indices)

        batch = []
        for idx in indices:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                yield batch
                batch = []

        if not self.drop_last:
            if self.data_parallel_rank >= remain and remain > 0:
                batch.append(0)
            if len(batch) > 0:
                yield batch

    def __len__(self):
        global_batch_size = self.micro_batch_size * self.data_parallel_size
        if self.drop_last:
            return self.data_size // global_batch_size
        else:
            return (self.data_size + global_batch_size - 1) // global_batch_size
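
# Illustrative usage sketch (not part of the original module): SingleRoundSampler
# walks the dataset exactly once, which suits evaluation-style passes. The helper
# name ``_demo_single_round_sampler`` and the toy ``list(range(9))`` dataset are
# assumptions for demonstration; with ``drop_last=False``, ranks whose bucket is
# one sample smaller pad their final batch with index 0 so that every rank yields
# the same number of batches.
def _demo_single_round_sampler():
    dataset = list(range(9))  # any object implementing __len__ works
    for rank in range(2):
        sampler = SingleRoundSampler(
            dataset,
            micro_batch_size=4,
            shuffle=False,
            data_parallel_rank=rank,
            data_parallel_size=2,
            seed=0,
            drop_last=False,
        )
        # e.g. rank 0 -> [[0, 1, 2, 3], [4]], rank 1 -> [[5, 6, 7, 8], [0]]
        print(rank, list(sampler))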