# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# ResMLP Model
# References:
# resmlp: https://github.com/facebookresearch/deit/blob/main/resmlp_models.py
# --------------------------------------------------------
import oneflow as flow
import oneflow.nn as nn
from flowvision.layers.weight_init import trunc_normal_
import libai.utils.distributed as dist
from libai.config import configurable
from libai.layers import MLP, DropPath, LayerNorm, Linear, PatchEmbedding
class Affine(nn.Module):
    """Channel-wise affine transform, Aff(x) = alpha * x + beta, which
    ResMLP uses in place of normalization layers."""

    def __init__(self, dim, *, layer_idx=0):
        super().__init__()
self.alpha = nn.Parameter(
flow.ones(
dim,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
)
self.beta = nn.Parameter(
flow.zeros(
dim,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
)
self.layer_idx = layer_idx
def forward(self, x):
x = x.to_global(placement=dist.get_layer_placement(self.layer_idx))
return self.alpha * x + self.beta
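

# Note: `Affine` above implements the paper's Aff operator,
# Aff(x) = alpha * x + beta, broadcast over the batch and patch dimensions.
# A minimal sketch of what the forward pass computes (assumes LiBai's
# default single-device placement; shapes are illustrative):
#
#   aff = Affine(dim=384, layer_idx=0)
#   x = flow.ones(
#       2, 196, 384,  # (batch, patches, channels)
#       placement=dist.get_layer_placement(0),
#       sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
#   )
#   y = aff(x)  # same shape; per-channel scale and shift
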
class layers_scale_mlp_blocks(nn.Module):
    """A single ResMLP block: a cross-patch (token-mixing) linear layer and a
    per-patch MLP, each preceded by an Affine transform and wrapped in a
    residual connection weighted by a learnable LayerScale parameter."""

    def __init__(
        self, dim, drop=0.0, drop_path=0.0, init_values=1e-4, num_patches=196, *, layer_idx=0
    ):
        super().__init__()
self.norm1 = Affine(dim, layer_idx=layer_idx)
        # cross-patch (token-mixing) linear layer; kept under the name `attn`
        # to stay checkpoint-compatible with the reference implementation
        self.attn = Linear(num_patches, num_patches, layer_idx=layer_idx)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = Affine(dim, layer_idx=layer_idx)
self.mlp = MLP(hidden_size=dim, ffn_hidden_size=int(4.0 * dim), layer_idx=layer_idx)
self.gamma_1 = nn.Parameter(
init_values
* flow.ones(
dim,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(layer_idx),
),
requires_grad=True,
)
self.gamma_2 = nn.Parameter(
init_values
* flow.ones(
dim,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(layer_idx),
),
requires_grad=True,
)
self.layer_idx = layer_idx
def forward(self, x):
x = x.to_global(placement=dist.get_layer_placement(self.layer_idx))
x = x + self.drop_path(
self.gamma_1 * self.attn(self.norm1(x).transpose(1, 2)).transpose(1, 2)
)
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
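

# Data flow through one block, with x of shape (B, N, D) for N patches and D
# channels (a sketch of the forward pass above):
#
#   norm1(x)           # (B, N, D)  channel-wise affine
#   .transpose(1, 2)   # (B, D, N)  so the linear layer mixes patches
#   attn(...)          # (B, D, N)  cross-patch linear
#   .transpose(1, 2)   # (B, N, D)  back to channels-last
#   x + drop_path(gamma_1 * ...)    # LayerScale-weighted residual
#
# followed by the per-patch MLP branch weighted by gamma_2.
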
class ResMLP(nn.Module):
"""ResMLP in LiBai.
LiBai's implementation of:
`ResMLP: Feedforward networks for image classification with data-efficient training
<https://arxiv.org/abs/2105.03404>`_
Args:
img_size (int, tuple(int)): input image size
patch_size (int, tuple(int)): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
        depth (int): number of residual MLP blocks
drop_rate (float): dropout rate
drop_path_rate (float): stochastic depth rate
        init_scale (float): initial value of the LayerScale parameters (gamma_1/gamma_2)
num_classes (int): number of classes for classification head
loss_func (callable, optional): loss function for computing the total loss
between logits and labels
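
    Example:

    .. code-block:: python

        # a minimal construction sketch; the kwargs shown are the defaults
        model = ResMLP(img_size=224, patch_size=16, embed_dim=768, depth=12)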
"""
@configurable
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
drop_rate=0.0,
drop_path_rate=0.0,
init_scale=1e-4,
num_classes=1000,
loss_func=None,
):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.patch_embed = PatchEmbedding(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
        dpr = [drop_path_rate for _ in range(depth)]  # constant (non-decayed) drop path rate
self.blocks = nn.ModuleList(
[
layers_scale_mlp_blocks(
dim=embed_dim,
drop=drop_rate,
drop_path=dpr[i],
init_values=init_scale,
num_patches=num_patches,
layer_idx=i,
)
for i in range(depth)
]
)
self.norm = Affine(embed_dim, layer_idx=-1)
self.head = (
Linear(embed_dim, num_classes, layer_idx=-1) if num_classes > 0 else nn.Identity()
)
# loss func
self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func
# weight init
self.apply(self._init_weights)
@classmethod
def from_config(cls, cfg):
return {
"img_size": cfg.img_size,
"patch_size": cfg.patch_size,
"in_chans": cfg.in_chans,
"embed_dim": cfg.embed_dim,
"depth": cfg.depth,
"drop_rate": cfg.drop_rate,
"drop_path_rate": cfg.drop_path_rate,
"init_scale": cfg.init_scale,
"num_classes": cfg.num_classes,
"loss_func": cfg.loss_func,
}
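
    # `from_config` maps a config node to constructor kwargs; together with
    # `@configurable`, the model can be built from a config as well as from
    # explicit kwargs. A hypothetical LazyConfig sketch:
    #
    #   from libai.config import LazyCall
    #   model_cfg = LazyCall(ResMLP)(img_size=224, patch_size=16, depth=12)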
def _init_weights(self, m):
if isinstance(m, Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_features(self, x):
x = self.patch_embed(x)
# layer scale mlp blocks
        for blk in self.blocks:
x = blk(x)
return x
    def forward_head(self, x):
        B = x.shape[0]
        x = self.norm(x)
        # average-pool over the patch dimension; the (B, 1, C) reshape and
        # token-0 indexing mirror the reference implementation
        x = x.mean(dim=1).reshape(B, 1, -1)
        return self.head(x[:, 0])
    def forward(self, images, labels=None):
"""
Args:
images (flow.Tensor): training samples.
labels (flow.LongTensor, optional): training targets
Returns:
dict:
A dict containing :code:`loss_value` or :code:`logits`
depending on training or evaluation mode.
:code:`{"losses": loss_value}` when training,
:code:`{"prediction_scores": logits}` when evaluating.
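
        Example:

        .. code-block:: python

            # an eval-mode call sketch; `images` is (B, 3, 224, 224)
            logits = model(images)["prediction_scores"]  # (B, num_classes)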
"""
x = self.forward_features(images)
x = self.forward_head(x)
if labels is not None and self.training:
losses = self.loss_func(x, labels)
return {"losses": losses}
else:
return {"prediction_scores": x}
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
# Set pipeline parallelism stage_id
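        # (Sketch) patch_embed is pinned to the first stage, each block to the
        # stage owning its layer_idx, and norm/head/loss_func to the last
        # stage; the concrete stage for a given layer_idx comes from
        # `dist_utils.get_layer_stage_id`.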
if hasattr(model.loss_func, "config"):
# Old API in OneFlow 0.8
for module_block in model.modules():
                # `module_block.origin` returns the original (eager) module
if isinstance(module_block.origin, PatchEmbedding):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, layers_scale_mlp_blocks):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
# Set norm and head stage id
model.norm.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
else:
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), PatchEmbedding):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), layers_scale_mlp_blocks):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
# Set norm and head stage id
model.norm.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
@staticmethod
def set_activation_checkpoint(model):
for module_block in model.modules():
if hasattr(module_block, "origin"):
# Old API in OneFlow 0.8
if isinstance(module_block.origin, layers_scale_mlp_blocks):
module_block.config.activation_checkpointing = True
else:
if isinstance(module_block.to(nn.Module), layers_scale_mlp_blocks):
module_block.to(nn.graph.GraphModule).activation_checkpointing = True
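

if __name__ == "__main__":
    # Smoke-test sketch: assumes a single-process launch so that LiBai's
    # default distributed setup maps the global tensors to one device
    # (a multi-device run would need `dist.setup_dist_util` first).
    model = ResMLP(depth=2, num_classes=10)
    model.eval()
    images = flow.randn(
        2,
        3,
        224,
        224,
        placement=dist.get_layer_placement(0),
        sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
    )
    out = model(images)
    print(out["prediction_scores"].shape)  # expected: (2, 10)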