Source code for libai.models.vision_transformer

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import oneflow as flow
import oneflow.nn as nn
from flowvision.layers.weight_init import trunc_normal_

import libai.utils.distributed as dist
from libai.config.config import configurable
from libai.layers import LayerNorm, Linear, PatchEmbedding, TransformerLayer


class VisionTransformer(nn.Module):
    """Vision Transformer in LiBai.

    LiBai's implementation of:
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>`_

    Args:
        img_size (int, tuple(int)): input image size
        patch_size (int, tuple(int)): patch size
        in_chans (int): number of input channels
        embed_dim (int): embedding dimension
        depth (int): depth of transformer
        num_heads (int): number of attention heads
        mlp_ratio (float): ratio of mlp hidden dim to embedding dim
        drop_rate (float): dropout rate
        attn_drop_rate (float): attention dropout rate
        drop_path_rate (float): stochastic depth rate
        num_classes (int): number of classes for classification head
        loss_func (callable, optional): loss function for computing the total loss
            between logits and labels
    """

    @configurable
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4.0,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        num_classes=1000,
        loss_func=None,
    ):
        super().__init__()
        self.img_size = img_size
        self.num_classes = num_classes
        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        ffn_size = int(embed_dim * mlp_ratio)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(
            flow.zeros(
                1,
                1,
                embed_dim,
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                placement=dist.get_layer_placement(0),
            )
        )
        self.pos_embed = nn.Parameter(
            flow.zeros(
                1,
                num_patches + 1,
                embed_dim,
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                placement=dist.get_layer_placement(0),
            )
        )
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in flow.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.Sequential(
            *[
                TransformerLayer(
                    hidden_size=embed_dim,
                    ffn_hidden_size=ffn_size,
                    num_attention_heads=num_heads,
                    attention_dropout_prob=attn_drop_rate,
                    output_dropout_prob=drop_rate,
                    drop_path_prob=dpr[i],
                    layer_idx=i,
                )
                for i in range(depth)
            ]
        )
        self.norm = LayerNorm(embed_dim, layer_idx=-1)
        self.head = Linear(embed_dim, num_classes, layer_idx=-1)

        # loss func
        self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func

        # weight init
        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    @classmethod
    def from_config(cls, cfg):
        return {
            "img_size": cfg.img_size,
            "patch_size": cfg.patch_size,
            "in_chans": cfg.in_chans,
            "embed_dim": cfg.embed_dim,
            "depth": cfg.depth,
            "num_heads": cfg.num_heads,
            "mlp_ratio": cfg.mlp_ratio,
            "drop_rate": cfg.drop_rate,
            "attn_drop_rate": cfg.attn_drop_rate,
            "drop_path_rate": cfg.drop_path_rate,
            "num_classes": cfg.num_classes,
            "loss_func": cfg.loss_func,
        }

    def forward_features(self, x):
        # patch embedding
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(
            x.shape[0], -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        cls_token = cls_token.to_global(sbp=x.sbp, placement=cls_token.placement)
        x = flow.cat((cls_token, x), dim=1)

        # position embedding
        pos_embed = self.pos_embed.expand(x.shape[0], -1, -1)
        pos_embed = pos_embed.to_global(sbp=x.sbp, placement=pos_embed.placement)
        x = self.pos_drop(x + pos_embed)

        # transformer block
        x = self.blocks(x)
        return x

    def forward_head(self, x):
        x = self.norm(x)
        outcome = x[:, 0]  # classify on the class token
        outcome = self.head(outcome)
        return outcome
    def forward(self, images, labels=None):
        """
        Args:
            images (flow.Tensor): training samples.
            labels (flow.LongTensor, optional): training targets.

        Returns:
            dict: A dict containing :code:`loss_value` or :code:`logits`
            depending on training or evaluation mode:
            :code:`{"losses": loss_value}` when training,
            :code:`{"prediction_scores": logits}` when evaluating.
        """
        x = self.forward_features(images)
        x = self.forward_head(x)

        if labels is not None and self.training:
            losses = self.loss_func(x, labels)
            return {"losses": losses}
        else:
            return {"prediction_scores": x}
    @staticmethod
    def set_pipeline_stage_id(model):
        dist_utils = dist.get_dist_util()

        # Set pipeline parallelism stage_id
        if hasattr(model.pos_embed, "config"):
            # Old API in OneFlow 0.8
            for module_block in model.modules():
                if isinstance(module_block.origin, PatchEmbedding):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.origin, TransformerLayer):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )

            # Set pos_embed and cls_token stage id
            model.pos_embed.config.set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.cls_token.config.set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.pos_drop.config.set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.norm.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.head.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.loss_func.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
        else:
            # Newer OneFlow API: graph block attributes are accessed through
            # the GraphModule / GraphTensor proxies returned by .to(...)
            for module_block in model.modules():
                if isinstance(module_block.to(nn.Module), PatchEmbedding):
                    module_block.to(flow.nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.to(nn.Module), TransformerLayer):
                    module_block.to(flow.nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )

            # Set pos_embed and cls_token stage id
            model.pos_embed.to(flow.nn.graph.GraphTensor).set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.cls_token.to(flow.nn.graph.GraphTensor).set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.pos_drop.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.norm.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.head.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.loss_func.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
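

# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module). It assumes a
# LiBai distributed context has already been initialized so that
# dist.get_layer_placement / dist.get_nd_sbp resolve (the single-device
# defaults suffice for a quick local test); the batch shape and sbp choice
# below are assumptions for demonstration, not values from the LiBai source.
if __name__ == "__main__":
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=192, depth=12, num_heads=3)
    model.eval()

    # Dummy batch of two 224x224 RGB images, converted to a global tensor so
    # it matches the model's global (sbp + placement) parameters.
    images = flow.randn(2, 3, 224, 224).to_global(
        sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast]),
        placement=dist.get_layer_placement(0),
    )

    # In eval mode, forward returns {"prediction_scores": logits}.
    outputs = model(images)
    print(outputs["prediction_scores"].shape)  # expected: (2, 1000)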