# Source code for libai.layers.attention

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import math
from typing import Tuple

import oneflow as flow
from oneflow import nn

from .linear import Linear

@enum.unique
class AttnMaskType(enum.Enum):
    """Kind of attention mask an attention layer is configured for.

    The attention layer in this module compares its ``attn_mask_type``
    against these members to decide which fused softmax kernel may be used.
    """

    # Mask that hides padded positions.
    padding = 1
    # Mask that hides future positions (lower-triangular attention).
    causal = 2

class MultiheadAttention(nn.Module):
    """Multi-head attention layer, support self attention and cross attention.

    Args:
        hidden_size: size of hidden state.
        num_attention_heads: number of attention heads. Must divide
            ``hidden_size`` evenly.
        is_cross_attention: used to specify whether it is self attention or
            cross attention. Defaults to False.
        attention_dropout_prob: dropout probability of attention weights.
            Defaults to 0.0.
        output_dropout_prob: dropout probability of output. Defaults to 0.0.
        init_method: method to initialize the input layer weights.
            Defaults to ``init.xavier_normal_``.
        output_layer_init_method: method to initialize the output layer
            weights. If None, use ``init_method``.
        bias_dropout_fusion: whether to fuse add bias and dropout.
            Defaults to False.
        scale_mask_softmax_fusion: whether to fuse scale, mask and softmax.
            Defaults to False.
        apply_query_key_layer_scaling: if `True`, scaling the attention score
            by layer index (the normalization factor is divided by
            ``layer_idx + 1``). Defaults to False.
        attn_mask_type: an ``AttnMaskType`` member. It selects which fused
            softmax kernel is used when ``scale_mask_softmax_fusion`` is
            enabled. Defaults to ``AttnMaskType.padding``.
        layer_idx: a layer_idx sign which determines the placements.
            It will be used in pipeline parallelism. Defaults to 0.
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        is_cross_attention=False,
        attention_dropout_prob=0.0,
        output_dropout_prob=0.0,
        init_method=nn.init.xavier_normal_,
        output_layer_init_method=None,
        bias_dropout_fusion=False,
        scale_mask_softmax_fusion=False,
        apply_query_key_layer_scaling=False,
        attn_mask_type=AttnMaskType.padding,
        *,
        layer_idx=0
    ):
        super().__init__()
        self.hidden_size = hidden_size
        if output_layer_init_method is None:
            output_layer_init_method = init_method

        assert (
            hidden_size % num_attention_heads == 0
        ), "hidden_size must be divisible by num_attention_heads."

        self.num_heads = num_attention_heads
        self.head_size = hidden_size // num_attention_heads
        self.attn_mask_type = attn_mask_type

        self.attention_dropout_prob = attention_dropout_prob
        self.dropout = nn.Dropout(p=attention_dropout_prob)

        # Scores are scaled by 1 / sqrt(head_size). With query-key layer
        # scaling the factor is additionally divided by (layer_idx + 1);
        # `coeff` is what the softmax step uses to scale the scores back.
        self.norm_factor = 1.0 / math.sqrt(float(self.head_size))
        self.coeff = None
        if apply_query_key_layer_scaling:
            self.coeff = layer_idx + 1
            self.norm_factor /= self.coeff

        self.is_cross_attention = is_cross_attention
        self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
        self.bias_dropout_fusion = bias_dropout_fusion

        # With bias/dropout fusion the dropout is applied by the fused kernel
        # in forward(), so only the probability is kept instead of a module.
        if self.bias_dropout_fusion:
            self.output_dropout_prob = output_dropout_prob
        else:
            self.output_dropout = nn.Dropout(p=output_dropout_prob)

        if self.is_cross_attention:
            # Cross attention: query comes from hidden_states, key/value from
            # encoder_states, hence two separate projections.
            self.query = Linear(
                self.hidden_size,
                self.hidden_size,
                parallel="col",
                init_method=init_method,
                layer_idx=layer_idx,
            )
            self.key_value = Linear(
                self.hidden_size,
                self.hidden_size * 2,
                parallel="col",
                init_method=init_method,
                layer_idx=layer_idx,
            )
        else:
            # Self attention: one projection produces query, key and value.
            self.query_key_value = Linear(
                self.hidden_size,
                self.hidden_size * 3,
                parallel="col",
                init_method=init_method,
                layer_idx=layer_idx,
            )

        self.dense = Linear(
            self.hidden_size,
            self.hidden_size,
            parallel="row",
            init_method=output_layer_init_method,
            skip_bias_add=self.bias_dropout_fusion,
            layer_idx=layer_idx,
        )

    def forward(
        self,
        hidden_states: flow.Tensor,
        encoder_states: flow.Tensor = None,
        attention_mask: flow.Tensor = None,
        past_key_value: Tuple[flow.Tensor, flow.Tensor] = None,
        use_cache: bool = False,
    ):
        """Compute multi-head attention over ``hidden_states``.

        Args:
            hidden_states (flow.Tensor): shape is [bsz, tgt_len, hidden_size].
            encoder_states (flow.Tensor, optional): shape is
                [bsz, src_len, hidden_size]. Defaults to None.
            attention_mask (flow.Tensor, optional): shape is
                [bsz, 1, tgt_len, src_len]. It should be the combination of
                padding mask and causal mask. It is the padding mask of source
                input when used with self-attention in encoder. And it is the
                combination of padding mask of target input and causal mask
                when used with self-attention in decoder. It is the padding
                mask of source input when used with cross-attention in
                decoder. Defaults to None.
            past_key_value (Tuple[flow.Tensor, flow.Tensor], optional): tuple
                of key and value, each shape is
                [bsz, num_heads, src_len, head_size]. Defaults to None.
            use_cache (bool, optional): it will be set to True, when the model
                is in the inference phase and used for incremental decoding.
                Defaults to False.

        Returns:
            The attention output tensor, or the tuple
            ``(output, (key, value))`` when ``use_cache`` is True.

        Raises:
            ValueError: in cross attention, when both ``past_key_value`` and
                ``encoder_states`` are None.
        """
        # hidden_states, encoder_states: [S(0), B]
        # attention_mask: [S(0), B]
        if encoder_states is not None:
            encoder_states = encoder_states.to_global(placement=hidden_states.placement)

        if attention_mask is not None:
            attention_mask = attention_mask.to_global(placement=hidden_states.placement)

        bsz = hidden_states.size()[0]

        if self.is_cross_attention:
            # if it is cross attention, key and value should be calculated only once, and the
            # result can be reused.
            query = self.query(hidden_states)
            query = query.view(bsz, -1, self.num_heads, self.head_size)
            query = query.permute(0, 2, 1, 3)
            if past_key_value is not None:
                key, value = past_key_value
            elif encoder_states is not None:
                key_value = self.key_value(encoder_states)
                key_value = key_value.view(bsz, -1, self.num_heads, 2 * self.head_size)
                key_value = key_value.permute(0, 2, 1, 3)
                key, value = flow.chunk(key_value, chunks=2, dim=-1)
            else:
                raise ValueError(
                    "past_key_value and encoder_states cannot be None at the same time."
                )
        else:
            # if it is self attention, query, key, and value are all obtained from hidden_states.
            # when in the inference phase of an incremental decoder,
            # hidden_states is the last-added state,
            # the full key and value could be obtained by concatenating with past_key_value.
            query_key_value = self.query_key_value(hidden_states)
            query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)
            query_key_value = query_key_value.permute(
                0, 2, 1, 3
            )  # [bsz, num_heads, src_len, 3 * head_size]
            query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)
            if past_key_value is not None:
                # Prepend the cached keys/values along the sequence axis.
                past_key, past_value = past_key_value
                key = flow.cat((past_key, key), dim=2)
                value = flow.cat((past_value, value), dim=2)

        # query, key, value: [S(0), S(1)], shape: [bsz, num_heads, seq_length, head_size]
        if use_cache:
            past_key_value = (key, value)

        # [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)]
        attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor)

        # [S(0), S(1)] x [S(0), B] = [S(0), S(1)]
        attention_weights = self._attention_weights(attention_scores, attention_mask)

        # Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)]
        context = flow.matmul(attention_weights, value)
        # Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size]
        context = context.transpose(1, 2)

        # Concat multi-head results from
        # [bsz, tgt_len, num_heads, head_size] -> [bsz, tgt_len, num_heads * head_size]
        # SBP sign: [S(0), S(2)]
        # [S(0), S(2)] x [B, S(0)] = [S(0), P] -> [S(0), B]
        output = self.dense(context.flatten(2))

        if self.bias_dropout_fusion:
            # `dense` was built with skip_bias_add=True, so it hands back the
            # bias separately for the fused bias-add + dropout kernel.
            output, bias = output
            output = flow._C.fused_bias_add_dropout(
                output, bias, p=self.output_dropout_prob, axis=output.ndim - 1
            )
        else:
            output = self.output_dropout(output)

        if use_cache:
            output = (output, past_key_value)

        return output

    def _attention_weights(self, attention_scores, attention_mask):
        """Turn scaled attention scores into dropped-out attention weights.

        The fused kernels are used in the same situations as before: the
        scale/mask/softmax/dropout kernel for a padding-type layer that
        received a mask, and the tril kernel for a causal-type layer that
        received no mask. Every other combination takes the explicit
        mask + softmax + dropout path, so a result is always produced.
        """
        fusion = self.scale_mask_softmax_fusion

        if attention_mask is None:
            if fusion and self.attn_mask_type == AttnMaskType.causal:
                return flow._C.fused_scale_tril_softmax_mask_scale(
                    attention_scores,
                    p=self.attention_dropout_prob,
                    diagonal=0,
                    tril_scale_value=self.coeff,
                )[0]
            # NOTE(review): unlike the masked path below, `coeff` is not
            # multiplied back into the scores here -- kept as in the original,
            # confirm whether that is intended.
            attention_weights = flow.softmax(attention_scores, dim=-1)
            # [bsz, num_heads, tgt_len, src_len]
            return self.dropout(attention_weights)

        if fusion and self.attn_mask_type == AttnMaskType.padding:
            return flow._C.fused_scale_mask_softmax_dropout(
                attention_scores,
                attention_mask,
                fill_value=-10000.0,
                scale=self.coeff,
                p=self.attention_dropout_prob,
            )[0]

        # Unfused masked path (also the fallback when fusion is requested but
        # no fused kernel matches the mask type).
        if self.coeff is not None:
            # Undo the extra 1 / coeff folded into norm_factor.
            attention_scores *= self.coeff
        # Zero the masked-out scores, then push them to -10000 so that they
        # get (almost) no weight after the softmax.
        attention_scores = flow.mul(attention_scores, attention_mask)
        attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
        # TODO(xingyu.liao): graph will occur `where_scalar` errors
        # when using `masked_fill`
        attention_weights = flow.softmax(attention_scores, dim=-1)
        # [bsz, num_heads, tgt_len, src_len]
        return self.dropout(attention_weights)

    def extra_repr(self) -> str:
        """Return the summary shown for this layer in the module repr."""
        return "hidden_size={}, num_heads={}, is_cross_attention={}".format(
            self.hidden_size,
            self.num_heads,
            self.is_cross_attention,
        )