Source code for libai.models.t5_model

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import oneflow as flow
import oneflow.nn as nn

from libai.config import configurable
from libai.layers import (
    Embedding,
    LayerNorm,
    LMLogits,
    ParallelCrossEntropyLoss,
    TransformerLayer,
    VocabEmbedding,
)
from libai.layers.attention import AttnMaskType
from libai.models.utils import init_method_normal, scaled_init_method_normal
from libai.utils import distributed as dist


class ExtendedMask(flow.nn.Module):
    def forward(self, attention_mask):
        # Insert an extra dimension (dim 1) so the mask broadcasts across all attention heads.
        return attention_mask.unsqueeze(1)


class T5Embedding(flow.nn.Module):
    def __init__(
        self,
        hidden_size,
        vocab_size,
        max_sequence_length,
        embedding_dropout_prob,
        init_method=flow.nn.init.xavier_normal_,
        amp_enabled=False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        self.word_embeddings = VocabEmbedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size,
            init_method=init_method,
            amp_enabled=amp_enabled,
        )
        self.position_embeddings = Embedding(
            num_embeddings=max_sequence_length,
            embedding_dim=hidden_size,
            init_method=init_method,
            amp_enabled=amp_enabled,
        )
        self.position_ids = flow.arange(
            max_sequence_length,
            dtype=flow.long,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=dist.get_layer_placement(0),
        ).unsqueeze(0)

        self.embedding_dropout = flow.nn.Dropout(embedding_dropout_prob)

    def forward(self, input_ids, past_length=0):
        seq_length = input_ids.size()[1]

        # Offset by past_length so incremental decoding continues from the cached positions.
        position_ids = self.position_ids[:, past_length : past_length + seq_length]
        position_ids = position_ids.expand_as(input_ids).to_global(sbp=input_ids.sbp)

        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = word_embeddings + position_embeddings
        embeddings = self.embedding_dropout(embeddings)
        return embeddings
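
# A minimal, illustrative sketch of how T5Embedding is used (the hidden_size / vocab_size /
# sequence-length values below are assumptions for the example, not defaults of this module):
#
#   embedding = T5Embedding(
#       hidden_size=768, vocab_size=32128, max_sequence_length=512, embedding_dropout_prob=0.1
#   )
#   # input_ids: a global flow.long tensor of shape [batch_size, seq_length]
#   # out = embedding(input_ids)
#   # out = dropout(word_embeddings(input_ids) + position_embeddings(position_ids)),
#   # with shape [batch_size, seq_length, hidden_size].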


class T5Model(flow.nn.Module):
    """T5 Model that outputs logits.

    Args:
        vocab_size (int): The size of vocabulary file.
        hidden_size (int): The size of hidden states.
        hidden_layers (int): The number of ``TransformerLayer`` in the encoder and decoder.
        num_attention_heads (int):
            The number of attention heads for each attention layer of ``TransformerLayer``.
        intermediate_size (int):
            The size of the intermediate layer in the feed-forward network
            of each ``TransformerLayer``.
        embedding_dropout_prob (float):
            The dropout ratio for the output of the ``T5Embedding`` layer.
        hidden_dropout_prob (float):
            The dropout ratio for the output of each ``TransformerLayer``.
        attention_probs_dropout_prob (float):
            The dropout ratio for the output of each attention layer in ``TransformerLayer``.
        max_position_embeddings (int):
            Max sequence length of input, defines the shape of the position embeddings
            in ``T5Embedding``.
        initializer_range (float, optional):
            Sigma of the normal distribution used by the initialization method. Defaults to 0.02.
        layernorm_eps (float, optional): The epsilon of the LayerNorm layer. Defaults to 1e-12.
        bias_gelu_fusion (bool, optional):
            Whether or not to fuse the computing of bias and gelu. Defaults to ``False``.
        bias_dropout_fusion (bool, optional):
            Whether or not to fuse the computing of dropout and bias. Defaults to ``False``.
        scale_mask_softmax_fusion (bool, optional):
            Whether to fuse the computing of mask and softmax in attention layers.
            Defaults to ``False``.
        apply_query_key_layer_scaling (bool, optional):
            Whether or not to use layer-index-related scaling in computing attention scores.
            If ``True``, the scaling factor equals sqrt(d) * (layer_index + 1).
            Defaults to ``True``.
        apply_residual_post_layernorm (bool, optional):
            If ``True``, use the original BERT residual connection ordering; otherwise use the
            Megatron-style residual connection, which is more stable when scaling model size,
            introduced in https://arxiv.org/pdf/1909.08053.pdf. Defaults to ``False``.
        amp_enabled (bool, optional):
            Whether or not to set fp16 for the embedding weight in the T5 model.
            Defaults to ``False``.
""" @configurable def __init__( self, vocab_size, hidden_size, hidden_layers, num_attention_heads, intermediate_size, embedding_dropout_prob, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, initializer_range=0.02, layernorm_eps=1e-12, bias_gelu_fusion=False, bias_dropout_fusion=False, scale_mask_softmax_fusion=False, apply_query_key_layer_scaling=True, apply_residual_post_layernorm=False, amp_enabled=False, ) -> None: super().__init__() init_method = init_method_normal(initializer_range) scaled_init_method = scaled_init_method_normal(initializer_range, hidden_layers) self.embedding = T5Embedding( hidden_size=hidden_size, vocab_size=vocab_size, max_sequence_length=max_position_embeddings, embedding_dropout_prob=embedding_dropout_prob, init_method=init_method, amp_enabled=amp_enabled, ) self.extended_attn_mask = ExtendedMask() encoder_layers = flow.nn.ModuleList( [ TransformerLayer( hidden_size=hidden_size, ffn_hidden_size=intermediate_size, num_attention_heads=num_attention_heads, is_decoder=False, attention_dropout_prob=attention_probs_dropout_prob, output_dropout_prob=hidden_dropout_prob, layernorm_epsilon=layernorm_eps, init_method=init_method, output_layer_init_method=scaled_init_method, bias_gelu_fusion=bias_gelu_fusion, bias_dropout_fusion=bias_dropout_fusion, scale_mask_softmax_fusion=scale_mask_softmax_fusion, apply_query_key_layer_scaling=apply_query_key_layer_scaling, apply_residual_post_layernorm=apply_residual_post_layernorm, attn_mask_type=AttnMaskType.padding, layer_idx=i, ) for i in range(hidden_layers) ] ) encoder_final_layernorm = LayerNorm( (hidden_size,), eps=layernorm_eps, layer_idx=hidden_layers - 1, ) self.encoder = flow.nn.Sequential() self.encoder.add_module("layers", encoder_layers) self.encoder.add_module("final_layernorm", encoder_final_layernorm) decoder_layers = flow.nn.ModuleList( [ TransformerLayer( hidden_size=hidden_size, ffn_hidden_size=intermediate_size, num_attention_heads=num_attention_heads, is_decoder=True, attention_dropout_prob=attention_probs_dropout_prob, output_dropout_prob=hidden_dropout_prob, layernorm_epsilon=layernorm_eps, init_method=init_method, output_layer_init_method=scaled_init_method, bias_gelu_fusion=bias_gelu_fusion, bias_dropout_fusion=bias_dropout_fusion, scale_mask_softmax_fusion=scale_mask_softmax_fusion, apply_query_key_layer_scaling=apply_query_key_layer_scaling, attn_mask_type=AttnMaskType.padding, layer_idx=i, ) for i in range(hidden_layers, 2 * hidden_layers) ] ) decoder_final_layernorm = LayerNorm( (hidden_size,), eps=layernorm_eps, layer_idx=2 * hidden_layers - 1, ) self.decoder = flow.nn.Sequential() self.decoder.add_module("layers", decoder_layers) self.decoder.add_module("final_layernorm", decoder_final_layernorm) self.past_key_values = [None] * len(self.decoder.layers) self.encoder_states = None self.past_length = 0 self.lm_head = LMLogits(vocab_size, bias=True) @classmethod def from_config(cls, cfg): return { "vocab_size": cfg.vocab_size, "hidden_size": cfg.hidden_size, "hidden_layers": cfg.hidden_layers, "num_attention_heads": cfg.num_attention_heads, "intermediate_size": cfg.intermediate_size, "embedding_dropout_prob": cfg.embedding_dropout_prob, "hidden_dropout_prob": cfg.hidden_dropout_prob, "attention_probs_dropout_prob": cfg.attention_probs_dropout_prob, "max_position_embeddings": cfg.max_position_embeddings, "initializer_range": cfg.initializer_range, "layernorm_eps": cfg.layernorm_eps, "bias_gelu_fusion": cfg.bias_gelu_fusion, "bias_dropout_fusion": cfg.bias_dropout_fusion, 
"scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion, "apply_query_key_layer_scaling": cfg.apply_query_key_layer_scaling, "apply_residual_post_layernorm": cfg.apply_residual_post_layernorm, "amp_enabled": cfg.amp_enabled, }

    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        encoder_attn_mask,
        decoder_attn_mask,
        encoder_decoder_attn_mask,
        use_cache=False,
    ):
        """
        Args:
            encoder_input_ids (flow.LongTensor):
                Indices of input sequence tokens in vocabulary for the encoder.
            decoder_input_ids (flow.LongTensor):
                Indices of input sequence tokens in vocabulary for the decoder.
            encoder_attn_mask (flow.BoolTensor):
                Mask for the encoder to avoid performing attention on padding token indices.
                Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

            decoder_attn_mask (flow.BoolTensor):
                Mask for the decoder to avoid performing attention on subsequent token indices.
                Mask values have the same meaning as in `encoder_attn_mask`.
            encoder_decoder_attn_mask (flow.BoolTensor):
                Mask for the decoder to avoid performing attention on encoder padding token
                indices. Mask values have the same meaning as in `encoder_attn_mask`.
            use_cache (bool, optional):
                Set to ``True`` when the model is in the inference phase and used for
                incremental decoding. Defaults to ``False``.

        Returns:
            flow.Tensor: logits
        """
        encoder_input_ids = encoder_input_ids.to_global(placement=dist.get_layer_placement(0))
        decoder_input_ids = decoder_input_ids.to_global(placement=dist.get_layer_placement(0))
        encoder_attn_mask = encoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
        decoder_attn_mask = decoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
        encoder_decoder_attn_mask = encoder_decoder_attn_mask.to_global(
            placement=dist.get_layer_placement(0)
        )

        if use_cache and self.encoder_states is not None:
            # Reuse the cached encoder states during incremental decoding.
            encoder_states = self.encoder_states
        else:
            self.set_cache(encoder_states=None, past_key_values=None)
            encoder_attn_mask = self.extended_attn_mask(encoder_attn_mask)
            enc_embedding_output = self.embedding(encoder_input_ids)
            enc_hidden_states = enc_embedding_output
            for layer in self.encoder.layers:
                enc_hidden_states = layer(enc_hidden_states, encoder_attn_mask)
            encoder_states = self.encoder.final_layernorm(enc_hidden_states)

        decoder_attn_mask = self.extended_attn_mask(decoder_attn_mask)
        encoder_decoder_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask)

        dec_embedding_output = self.embedding(decoder_input_ids, self.past_length)
        dec_hidden_states = dec_embedding_output
        if use_cache:
            presents = []
        for layer, past_key_value in zip(self.decoder.layers, self.past_key_values):
            dec_hidden_states = layer(
                dec_hidden_states,
                decoder_attn_mask,
                encoder_states,
                encoder_decoder_attn_mask,
                past_key_value=past_key_value,
                use_cache=use_cache,
            )
            if use_cache:
                dec_hidden_states, present = dec_hidden_states
                presents.append(present)
        if use_cache:
            self.set_cache(encoder_states, past_key_values=presents)

        decoder_states = self.decoder.final_layernorm(dec_hidden_states)
        logits = self.lm_head(decoder_states, self.embedding.word_embeddings.weight)
        return logits

    def set_cache(self, encoder_states, past_key_values):
        self.encoder_states = encoder_states
        self.past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]

        if past_key_values is None:
            past_key_values = [None] * len(self.decoder.layers)
        assert len(past_key_values) == len(self.decoder.layers), (
            f"past_key_values's length {len(past_key_values)} doesn't match "
            f"the number of decoder layers {len(self.decoder.layers)}"
        )
        self.past_key_values = past_key_values
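
# A minimal sketch of calling T5Model directly (hypothetical sizes; the attention masks follow
# the convention documented in forward(): 1 = attend, 0 = masked):
#
#   model = T5Model(
#       vocab_size=32128, hidden_size=768, hidden_layers=12, num_attention_heads=12,
#       intermediate_size=3072, embedding_dropout_prob=0.1, hidden_dropout_prob=0.1,
#       attention_probs_dropout_prob=0.1, max_position_embeddings=512,
#   )
#   logits = model(
#       encoder_input_ids, decoder_input_ids,
#       encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask,
#   )
#   # logits holds vocab-sized scores for every decoder position.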


class T5Loss(flow.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lm_loss = ParallelCrossEntropyLoss()

    def forward(self, logits, lm_labels, loss_mask):
        lm_loss = self.lm_loss(logits, lm_labels)
        loss_mask = loss_mask.to_global(placement=lm_loss.placement)
        loss_mask = loss_mask.float()
        denominator = loss_mask.sum().to_global(
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
        )
        lm_loss = flow._C.amp_white_identity(lm_loss)
        lm_loss = flow._C.amp_black_identity(lm_loss)
        masked_lm_loss = flow.sum(lm_loss.view(-1) * loss_mask.view(-1)) / denominator
        masked_lm_loss = masked_lm_loss.to_global(
            sbp=dist.get_nd_sbp([flow.sbp.partial_sum, flow.sbp.broadcast])
        )
        return {"masked_lm_loss": masked_lm_loss}
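
# Worked example of the masking in T5Loss (illustrative numbers only): with per-token losses
# lm_loss = [2.0, 1.0, 3.0] and loss_mask = [1, 1, 0], the result is
# (2.0 * 1 + 1.0 * 1 + 3.0 * 0) / (1 + 1) = 1.5, i.e. masked-out positions contribute to
# neither the numerator nor the denominator.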


class T5ForPreTraining(flow.nn.Module):
    """T5 Model with classification head on top."""

    def __init__(self, cfg) -> None:
        super().__init__()
        self.t5_model = T5Model(cfg)
        self.loss_func = T5Loss()

    def set_cache(self, encoder_states, past_key_values):
        self.t5_model.set_cache(encoder_states, past_key_values)

    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        encoder_attn_mask,
        decoder_attn_mask,
        encoder_decoder_attn_mask,
        lm_labels=None,
        loss_mask=None,
        use_cache=False,
    ):
        """
        Args:
            encoder_input_ids (flow.LongTensor):
                Indices of input sequence tokens in vocabulary for the encoder.
            decoder_input_ids (flow.LongTensor):
                Indices of input sequence tokens in vocabulary for the decoder.
            encoder_attn_mask (flow.BoolTensor):
                Mask for the encoder to avoid performing attention on padding token indices.
                Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

            decoder_attn_mask (flow.BoolTensor):
                Mask for the decoder to avoid performing attention on subsequent token indices.
                Mask values have the same meaning as in `encoder_attn_mask`.
            encoder_decoder_attn_mask (flow.BoolTensor):
                Mask for the decoder to avoid performing attention on encoder padding token
                indices. Mask values have the same meaning as in `encoder_attn_mask`.
            lm_labels (flow.LongTensor, optional):
                Labels for computing the masked language modeling loss.
                Indices should be in `[-1, 0, ..., config.vocab_size]`. ``None`` for evaluation.
            loss_mask (flow.BoolTensor, optional):
                Mask to avoid computing the loss on ignored tokens.
                Tokens with indices set to `-1` are ignored (masked); the loss is only computed
                for tokens with labels in `[0, ..., config.vocab_size]`. ``None`` for evaluation.
            use_cache (bool, optional):
                Set to ``True`` when the model is in the inference phase and used for
                incremental decoding. Defaults to ``False``.

        Returns:
            dict: A dict containing :code:`loss_value` or :code:`logits`
            depending on training or evaluation mode:
            :code:`{"masked_lm_loss": loss_value}` when training,
            :code:`{"prediction_scores": logits}` when evaluating.
        """
        logits = self.t5_model(
            encoder_input_ids,
            decoder_input_ids,
            encoder_attn_mask,
            decoder_attn_mask,
            encoder_decoder_attn_mask,
            use_cache=use_cache,
        )

        if lm_labels is not None:
            lm_loss = self.loss_func(logits, lm_labels, loss_mask)
            return lm_loss
        else:
            return {
                "prediction_scores": logits,
            }

    @staticmethod
    def set_pipeline_stage_id(model):
        dist_utils = dist.get_dist_util()

        # Set pipeline parallelism stage_id
        if hasattr(model.t5_model.encoder.final_layernorm, "config"):
            # Old API in OneFlow 0.8
            for module_block in model.modules():
                if isinstance(module_block.origin, T5Embedding):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.origin, ExtendedMask):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.origin, TransformerLayer):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
                elif isinstance(module_block.origin, LMLogits):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )
                elif isinstance(module_block.origin, T5Loss):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )

            model.t5_model.encoder.final_layernorm.config.set_stage(
                dist_utils.get_layer_stage_id(model.t5_model.encoder.final_layernorm.layer_idx),
                dist.get_layer_placement(model.t5_model.encoder.final_layernorm.layer_idx),
            )
            model.t5_model.decoder.final_layernorm.config.set_stage(
                dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
                dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
            )
        else:
            for module_block in model.modules():
                if isinstance(module_block.to(nn.Module), T5Embedding):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.to(nn.Module), ExtendedMask):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.to(nn.Module), TransformerLayer):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
                elif isinstance(module_block.to(nn.Module), LMLogits):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )
                elif isinstance(module_block.to(nn.Module), T5Loss):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )

            model.t5_model.encoder.final_layernorm.to(nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(model.t5_model.encoder.final_layernorm.layer_idx),
                dist.get_layer_placement(model.t5_model.encoder.final_layernorm.layer_idx),
            )
            model.t5_model.decoder.final_layernorm.to(nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
                dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
            )
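
# A minimal end-to-end sketch for T5ForPreTraining (the `cfg` object here is hypothetical; it
# only needs the fields read by T5Model.from_config()):
#
#   model = T5ForPreTraining(cfg)
#   out = model(
#       encoder_input_ids, decoder_input_ids,
#       encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask,
#       lm_labels=lm_labels, loss_mask=loss_mask,
#   )
#   # training:   out == {"masked_lm_loss": ...}
#   # evaluation: call without lm_labels; out == {"prediction_scores": logits}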