Source code for libai.models.bert_model

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import oneflow as flow
from oneflow import nn

from libai.config import configurable
from libai.layers import (
    Embedding,
    LayerNorm,
    Linear,
    LMLogits,
    ParallelCrossEntropyLoss,
    TransformerLayer,
    VocabEmbedding,
    build_activation,
)
from libai.layers.attention import AttnMaskType
from libai.utils import distributed as dist

from .utils import init_method_normal, scaled_init_method_normal


class BertExtendedAttnMask(nn.Module):
    def forward(self, attention_mask):
        # Build a 3D attention mask from the 2D padding mask,
        # then add a broadcast dimension for the attention heads.
        # [b, 1, s]
        attention_mask_b1s = attention_mask.unsqueeze(1)
        # [b, s, 1]
        attention_mask_bs1 = attention_mask.unsqueeze(2)
        # [b, s, s]
        attention_mask_bss = attention_mask_b1s * attention_mask_bs1
        # [b, 1, s, s]
        extended_attention_mask = attention_mask_bss.unsqueeze(1)

        return extended_attention_mask
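
# A minimal shape check for BertExtendedAttnMask (an illustrative sketch only, using a
# local tensor with made-up sizes; real inputs in LiBai are global tensors):
#
#   mask = flow.ones(2, 8, dtype=flow.long)   # [batch=2, seq=8] padding mask
#   ext = BertExtendedAttnMask()(mask)
#   assert ext.shape == (2, 1, 8, 8)          # [b, 1, s, s]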


class BertEmbeddings(nn.Module):
    def __init__(
        self,
        vocab_size,
        hidden_size,
        max_sequence_length,
        embedding_dropout_prob,
        num_tokentypes=0,
        init_method=nn.init.xavier_normal_,
        amp_enabled=False,
    ):
        super().__init__()
        self.vocab_embeddings = VocabEmbedding(
            vocab_size, hidden_size, init_method=init_method, amp_enabled=amp_enabled
        )
        self.position_embeddings = Embedding(
            max_sequence_length, hidden_size, init_method=init_method, amp_enabled=amp_enabled
        )

        # NOTE(l1aoxingyu): position_ids gets sbp sign [B, B] initially because it is a
        # 1D tensor from 0 to max_sequence_length; if it were set to [S(0), B] here, it
        # would be split along the first dim of the hierarchy before being expanded to
        # the batch shape in forward().
        self.position_ids = flow.arange(
            max_sequence_length,
            dtype=flow.long,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=dist.get_layer_placement(0),
        ).unsqueeze(0)
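        # After unsqueeze(0), position_ids has shape [1, max_sequence_length] with sbp
        # sign [B, B]; forward() slices it to the actual sequence length and converts it
        # to the sbp of input_ids before the position embedding lookup.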

        if num_tokentypes > 0:
            self.tokentype_embeddings = Embedding(
                num_tokentypes, hidden_size, init_method=init_method, amp_enabled=amp_enabled
            )
            self.tokentype_ids = flow.zeros(
                self.position_ids.size(),
                dtype=flow.long,
                sbp=self.position_ids.sbp,
                placement=self.position_ids.placement,
            )
        else:
            self.tokentype_embeddings = None

        self.embedding_dropout = nn.Dropout(embedding_dropout_prob)

    def forward(self, input_ids, tokentype_ids=None, position_ids=None):
        seq_length = input_ids.size()[1]

        word_embeddings = self.vocab_embeddings(input_ids)
        if position_ids is None:
            # Change position_ids sbp sign: [B, B] -> [S(0), B]
            position_ids = (
                self.position_ids[:, :seq_length].expand_as(input_ids).to_global(sbp=input_ids.sbp)
            )
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = word_embeddings + position_embeddings

        if self.tokentype_embeddings is not None:
            if tokentype_ids is None:
                tokentype_ids = (
                    self.tokentype_ids[:, :seq_length]
                    .expand_as(input_ids)
                    .to_global(sbp=input_ids.sbp)
                )
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)

        embeddings = self.embedding_dropout(embeddings)
        return embeddings

    def word_embeddings(self):
        return self.vocab_embeddings.weight
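
# A hedged construction sketch for BertEmbeddings (made-up sizes; assumes LiBai's
# distributed environment has been initialized, since the embedding tables and
# position_ids are built as global tensors):
#
#   emb = BertEmbeddings(vocab_size=30522, hidden_size=768, max_sequence_length=512,
#                        embedding_dropout_prob=0.1, num_tokentypes=2)
#   embeddings = emb(input_ids)               # [bsz, seq_len, hidden_size]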


class BertLMPredictionHead(nn.Module):
    def __init__(self, hidden_size, init_method):
        super().__init__()
        self.dense = Linear(
            hidden_size,
            hidden_size,
            bias=True,
            parallel="data",
            init_method=init_method,
            layer_idx=-1,
        )
        self.activation_func = build_activation("gelu")
        self.layernorm = LayerNorm((hidden_size,), layer_idx=-1)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation_func(hidden_states)
        hidden_states = hidden_states.to_global(
            grad_sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.split(2)])
        )

        # NOTE(l1aoxingyu): hidden_states shape is [B, S, H] whose sbp sign: [S(0), S(2)]
        # Change from [S(0), S(2)] -> [S(0), B] because layernorm cannot get inputs with sbp S(2)
        hidden_states = hidden_states.to_global(
            sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast])
        )
        hidden_states = self.layernorm(hidden_states)
        return hidden_states
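
# Note: BertLMPredictionHead keeps the shape [bsz, seq_len, hidden_size]; the projection
# to vocabulary logits happens later in BertPreTrainingHeads via LMLogits, which reuses
# the word embedding weights (weight tying).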


class BertPooler(nn.Module):
    """Pooler layer.

    Pool the hidden state of the first token and
    apply a linear transformation followed by a tanh activation.

    Args:
        hidden_size: hidden state feature dimension
    """

    def __init__(self, hidden_size, init_method):
        super().__init__()
        self.dense = Linear(
            hidden_size,
            hidden_size,
            bias=True,
            parallel="col",
            init_method=init_method,
            layer_idx=-1,
        )
        self.activation_func = build_activation("tanh")

    def forward(self, hidden_states):
        """Just "pool" the model by simply taking the [CLS] token corresponding
        to the first token."""
        # hidden_states: [bsz, seq_len, hidden_size]
        select_token_tensor = hidden_states[:, 0, :]
        pooled_output = self.dense(select_token_tensor)
        pooled_output = self.activation_func(pooled_output)
        return pooled_output
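
# A hedged usage sketch for BertPooler (hypothetical sizes; assumes LiBai's distributed
# environment has been initialized, since libai.layers.Linear builds global weights):
#
#   pooler = BertPooler(hidden_size=768, init_method=init_method_normal(0.02))
#   pooled = pooler(hidden_states)            # [bsz, 768], tanh-activated [CLS] state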


class BertLoss(nn.Module):
    def __init__(self, add_binary_head):
        super().__init__()
        self.add_binary_head = add_binary_head
        self.lm_loss = ParallelCrossEntropyLoss()

    def forward(self, lm_output, lm_labels, loss_mask, binary_logits, ns_labels):
        lm_labels = lm_labels.to_global(placement=lm_output.placement)
        loss_mask = loss_mask.to_global(placement=lm_output.placement)
        binary_logits = binary_logits.to_global(placement=lm_output.placement)
        ns_labels = ns_labels.to_global(placement=lm_output.placement)
        lm_loss = self.lm_loss(lm_output, lm_labels)
        loss_mask = loss_mask.float()
        # Change loss_mask.sum() sbp sign from [P, B] -> [B, B]
        # because (lm_loss * loss_mask) / loss_mask.sum() cannot accept P / P
        denominator = (
            loss_mask.sum().to_global(sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]))
            + 1e-7
        )
        masked_lm_loss = flow.sum(lm_loss.view(-1) * loss_mask.view(-1)) / denominator
        # NOTE(l1aoxingyu): Change lm loss sbp sign [P, P] -> [P, B] to add with sop loss
        # whose sbp sign: [P, B]
        masked_lm_loss = masked_lm_loss.to_global(
            sbp=dist.get_nd_sbp([flow.sbp.partial_sum, flow.sbp.broadcast])
        )

        loss_dict = {"lm_loss": masked_lm_loss}

        if self.add_binary_head:
            sop_loss = flow._C.cross_entropy(
                binary_logits, ns_labels, ignore_index=-1, reduction="none"
            ).mean()
            loss_dict["sop_loss"] = sop_loss
        return loss_dict
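
# The returned dict holds "lm_loss" and, when add_binary_head is enabled, "sop_loss".
# A hedged sketch of how a caller might combine them (how losses are actually reduced
# is up to the training loop):
#
#   loss_dict = bert_loss(lm_output, lm_labels, loss_mask, binary_logits, ns_labels)
#   total_loss = sum(loss_dict.values())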


class BertModel(nn.Module):
    """The bare Bert Model transformer outputting raw hidden-states without
    any specific head on top.

    Args:
        vocab_size (int): The size of vocabulary file.
        hidden_size (int): The size of hidden states.
        hidden_layers (int): The number of ``TransformerLayer`` in the encoder.
        num_attention_heads (int):
            The number of attention heads for each attention layer of ``TransformerLayer``.
        intermediate_size (int):
            The size of the intermediate layer in the feed-forward network for each
            ``TransformerLayer``.
        hidden_dropout_prob (float, optional):
            The dropout ratio for the output of each TransformerLayer. Defaults to 0.0.
        attention_probs_dropout_prob (float, optional):
            The dropout ratio for the output of each attention layer in ``TransformerLayer``.
            Defaults to 0.0.
        max_position_embeddings (int):
            Max sequence length of input, defines the shape of Position Embeddings
            in ``BertEmbedding``.
        num_tokentypes (int, optional): Number of segment token indices. Defaults to 2.
        add_pooling_layer (bool, optional):
            Whether or not to add a pooling layer on top of the sequence of hidden-states.
            Defaults to ``True``.
        initializer_range (float, optional):
            Sigma of the normal distribution in the initialization method. Defaults to 0.02.
        layernorm_eps (float, optional): The epsilon of LayerNorm layer. Defaults to 1e-12.
        bias_gelu_fusion (bool, optional):
            Whether or not to fuse the computing of bias and gelu. Defaults to ``True``.
        bias_dropout_fusion (bool, optional):
            Whether or not to fuse the computing of dropout and bias. Defaults to ``True``.
        scale_mask_softmax_fusion (bool, optional):
            Whether to fuse the computing of mask and softmax in attention layers.
            Defaults to ``True``.
        apply_query_key_layer_scaling (bool, optional):
            Whether or not to use layer-index-related scaling in computing attention scores.
            If ``True``, the scaling factor equals sqrt(d) * (layer_index + 1).
            Defaults to ``True``.
        apply_residual_post_layernorm (bool, optional):
            If set to ``True``, use the original BERT residual connection ordering;
            otherwise use the Megatron-style residual connection, which is more stable
            when scaling model size (see https://arxiv.org/pdf/1909.08053.pdf).
            Defaults to ``False``.
        amp_enabled (bool, optional):
            Whether or not to set fp16 for the embedding weights in the BERT model.
            Defaults to ``False``.
    """

    @configurable
    def __init__(
        self,
        vocab_size,
        hidden_size,
        hidden_layers,
        num_attention_heads,
        intermediate_size,
        hidden_dropout_prob,
        attention_probs_dropout_prob,
        max_position_embeddings,
        num_tokentypes=2,
        add_pooling_layer=True,
        initializer_range=0.02,
        layernorm_eps=1e-12,
        bias_gelu_fusion=True,
        bias_dropout_fusion=True,
        scale_mask_softmax_fusion=True,
        apply_query_key_layer_scaling=True,
        apply_residual_post_layernorm=False,
        amp_enabled=False,
    ):
        super().__init__()
        init_method = init_method_normal(initializer_range)
        scaled_init_method = scaled_init_method_normal(initializer_range, hidden_layers)

        # Embeddings
        self.embeddings = BertEmbeddings(
            vocab_size,
            hidden_size,
            max_position_embeddings,
            hidden_dropout_prob,
            num_tokentypes,
            init_method,
            amp_enabled,
        )

        # Mask generation
        self.extended_attn_mask = BertExtendedAttnMask()

        # Encoders
        self.encoders = nn.ModuleList(
            [
                TransformerLayer(
                    hidden_size,
                    intermediate_size,
                    num_attention_heads,
                    attention_dropout_prob=attention_probs_dropout_prob,
                    output_dropout_prob=hidden_dropout_prob,
                    layernorm_epsilon=layernorm_eps,
                    bias_gelu_fusion=bias_gelu_fusion,
                    bias_dropout_fusion=bias_dropout_fusion,
                    scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                    apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                    init_method=init_method,
                    output_layer_init_method=scaled_init_method,
                    apply_residual_post_layernorm=apply_residual_post_layernorm,
                    attn_mask_type=AttnMaskType.padding,  # bert mask type
                    layer_idx=i,
                )
                for i in range(hidden_layers)
            ]
        )
        self.final_layernorm = LayerNorm((hidden_size,), eps=layernorm_eps, layer_idx=-1)

        self.pooler = BertPooler(hidden_size, init_method) if add_pooling_layer else None

    @classmethod
    def from_config(cls, cfg):
        return {
            "vocab_size": cfg.vocab_size,
            "hidden_size": cfg.hidden_size,
            "hidden_layers": cfg.hidden_layers,
            "num_attention_heads": cfg.num_attention_heads,
            "intermediate_size": cfg.intermediate_size,
            "hidden_dropout_prob": cfg.hidden_dropout_prob,
            "attention_probs_dropout_prob": cfg.attention_probs_dropout_prob,
            "max_position_embeddings": cfg.max_position_embeddings,
            "num_tokentypes": cfg.num_tokentypes,
            "add_pooling_layer": cfg.add_pooling_layer,
            "initializer_range": cfg.initializer_range,
            "layernorm_eps": cfg.layernorm_eps,
            "bias_gelu_fusion": cfg.bias_gelu_fusion,
            "bias_dropout_fusion": cfg.bias_dropout_fusion,
            "scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
            "apply_query_key_layer_scaling": cfg.apply_query_key_layer_scaling,
            "apply_residual_post_layernorm": cfg.apply_residual_post_layernorm,
            "amp_enabled": cfg.amp_enabled,
        }

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        """
        Args:
            input_ids (flow.LongTensor): Indices of input sequence tokens in vocabulary.
            attention_mask (flow.BoolTensor):
                Mask to avoid performing attention on padding token indices.
                Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

            tokentype_ids (flow.LongTensor, optional):
                Segment token indices to indicate first and second portions of the inputs.
                Indices are selected in `[0, 1]`. Defaults to None.
        """
        extended_attention_mask = self.extended_attn_mask(attention_mask)
        embedding_output = self.embeddings(input_ids, tokentype_ids)

        hidden_states = embedding_output
        for layer in self.encoders:
            hidden_states = layer(hidden_states, extended_attention_mask)
        encoder_output = self.final_layernorm(hidden_states)
        pooled_output = self.pooler(encoder_output) if self.pooler is not None else None
        return encoder_output, pooled_output

    def word_embeddings_weight(self):
        return self.embeddings.word_embeddings()
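
# A hedged usage sketch for BertModel. The ``@configurable`` decorator allows building the
# model either from explicit keyword arguments or from a cfg whose fields match
# ``from_config``; the sizes below are made up for illustration and assume LiBai's
# distributed environment has been initialized:
#
#   model = BertModel(
#       vocab_size=30522,
#       hidden_size=768,
#       hidden_layers=12,
#       num_attention_heads=12,
#       intermediate_size=3072,
#       hidden_dropout_prob=0.1,
#       attention_probs_dropout_prob=0.1,
#       max_position_embeddings=512,
#   )
#   encoder_output, pooled_output = model(input_ids, attention_mask)
#   # encoder_output: [bsz, seq_len, hidden_size]; pooled_output: [bsz, hidden_size]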


class BertPreTrainingHeads(nn.Module):
    def __init__(self, vocab_size, hidden_size, init_method, add_binary_head=True):
        super().__init__()
        self.predictions = BertLMPredictionHead(hidden_size, init_method)
        self.seq_relationship = Linear(
            hidden_size,
            2,
            bias=True,
            parallel="data",
            init_method=init_method,
            layer_idx=-1,
        )
        self.lm_logits = LMLogits(vocab_size, bias=True)
        self.loss_func = BertLoss(add_binary_head)

    def forward(
        self,
        sequence_output,
        pooled_output,
        word_embeddings_weight,
        ns_labels,
        lm_labels,
        loss_mask,
    ):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        prediction_scores = self.lm_logits(prediction_scores, word_embeddings_weight)

        if lm_labels is not None:
            return self.loss_func(
                prediction_scores, lm_labels, loss_mask, seq_relationship_score, ns_labels
            )

        return {
            "prediction_scores": prediction_scores,
            "seq_relationship_score": seq_relationship_score,
        }


class BertForPreTraining(nn.Module):
    """Bert Model with two heads on top as done during the pretraining:
    a `masked language modeling` head and a `next sentence prediction (classification)` head.
    """

    def __init__(self, cfg):
        super().__init__()
        self.bert = BertModel(cfg)
        self.cls_head = BertPreTrainingHeads(
            cfg.vocab_size,
            cfg.hidden_size,
            init_method_normal(cfg.initializer_range),
            cfg.add_binary_head,
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        tokentype_ids=None,
        ns_labels=None,
        lm_labels=None,
        loss_mask=None,
    ):
        """
        Args:
            input_ids (flow.LongTensor): Indices of input sequence tokens in vocabulary.
            attention_mask (flow.BoolTensor):
                Mask to avoid performing attention on padding token indices.
                Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

            tokentype_ids (flow.LongTensor, optional):
                Segment token indices to indicate first and second portions of the inputs.
                Indices are selected in `[0, 1]`. Defaults to None.
            ns_labels (flow.LongTensor, optional):
                Labels for computing the next sequence prediction (classification) loss.
                Input should be a sequence pair (see `input_ids` docstring).
                Indices should be in `[0, 1]`:

                - 0 indicates sequence B is a continuation of sequence A,
                - 1 indicates sequence B is a random sequence.

            lm_labels (flow.LongTensor, optional):
                Labels for computing the masked language modeling loss.
                Indices should be in `[-1, 0, ..., config.vocab_size]`.
            loss_mask (flow.BoolTensor, optional):
                Mask to avoid computing loss on ignored tokens. Tokens with indices set to
                `-1` are ignored (masked); the loss is only computed for tokens with labels
                in `[0, ..., config.vocab_size]`.
        """
        input_ids = input_ids.to_global(placement=dist.get_layer_placement(0))
        attention_mask = attention_mask.to_global(placement=dist.get_layer_placement(0))
        # Guard against tokentype_ids being None, since it defaults to None in the signature.
        if tokentype_ids is not None:
            tokentype_ids = tokentype_ids.to_global(placement=dist.get_layer_placement(0))

        outputs = self.bert(input_ids, attention_mask, tokentype_ids)
        sequence_output, pooled_output = outputs[:2]

        return self.cls_head(
            sequence_output,
            pooled_output,
            self.bert.word_embeddings_weight(),
            ns_labels,
            lm_labels,
            loss_mask,
        )

    @staticmethod
    def set_pipeline_stage_id(model):
        dist_utils = dist.get_dist_util()

        # Set pipeline parallelism stage_id
        if hasattr(model.bert.final_layernorm, "config"):
            # Old API in OneFlow 0.8
            for module_block in model.modules():
                # module.origin can get the original module
                if isinstance(module_block.origin, BertEmbeddings):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.origin, BertExtendedAttnMask):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.origin, TransformerLayer):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
                elif isinstance(module_block.origin, BertPooler):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )
                elif isinstance(module_block.origin, BertPreTrainingHeads):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )

            # Set the last layernorm stage id
            model.bert.final_layernorm.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
        else:
            for module_block in model.modules():
                if isinstance(module_block.to(nn.Module), BertEmbeddings):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.to(nn.Module), BertExtendedAttnMask):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.to(nn.Module), TransformerLayer):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
                elif isinstance(module_block.to(nn.Module), BertPooler):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )
                elif isinstance(module_block.to(nn.Module), BertPreTrainingHeads):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )

            # Set the last layernorm stage id
            model.bert.final_layernorm.to(nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
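
# A hedged sketch of how BertForPreTraining is typically driven (tensor names follow the
# forward signature above; inputs are assumed to come from LiBai's data pipeline):
#
#   model = BertForPreTraining(cfg)
#   loss_dict = model(input_ids, attention_mask, tokentype_ids,
#                     ns_labels=ns_labels, lm_labels=lm_labels, loss_mask=loss_mask)
#   # loss_dict -> {"lm_loss": ..., "sop_loss": ...} when lm_labels is provided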


class BertForClassification(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.num_labels = cfg.num_labels
        self.bert = BertModel(cfg)
        self.classifier = Linear(
            cfg.hidden_size,
            cfg.num_labels,
            bias=True,
            parallel="row",
            init_method=init_method_normal(cfg.initializer_range),
            layer_idx=-1,
        )
        classifier_dropout = (
            cfg.classifier_dropout
            if cfg.classifier_dropout is not None
            else cfg.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

    def forward(self, input_ids, attention_mask, tokentype_ids=None, labels=None, **kwargs):
        labels = labels if labels is not None else kwargs.get("ns_labels")
        outputs = self.bert(input_ids, attention_mask, tokentype_ids)

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            loss = loss.to_global(sbp=dist.get_nd_sbp([flow.sbp.partial_sum, flow.sbp.broadcast]))
            return {"cls_loss": loss}
        else:
            return {"logits": logits}
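
# A hedged usage note for BertForClassification: when `labels` (or `ns_labels`) is given,
# the forward pass returns {"cls_loss": ...}; otherwise it returns {"logits": ...} with
# shape [bsz, cfg.num_labels].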