PyTorch ResNet-18 Source Code Walkthrough

ResNet-18 Network Architecture Diagram
ResNet was proposed by He Kaiming et al. at Microsoft Research. Paper: "Deep Residual Learning for Image Recognition" (https://arxiv.org/abs/1512.03385).
ResNet Code
In PyTorch's torchvision the following models are defined:
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']
ResNet Declaration
This article only covers ResNet-18. It is called as follows:
from torchvision import models
resnet_18 = models.resnet18(pretrained=True)
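As a quick sanity check (a minimal sketch; it assumes torchvision is installed and downloads the weights on first use), a dummy forward pass should yield 1000 ImageNet logits:

import torch

resnet_18.eval()
with torch.no_grad():
    out = resnet_18(torch.randn(1, 3, 224, 224))  # one fake 224x224 RGB image
print(out.shape)  # torch.Size([1, 1000])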
Here pretrained indicates whether to load the model pre-trained on ImageNet. The resnet18 model is defined as follows:
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)
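The depth in the name can be recovered from layers=[2, 2, 2, 2]: each BasicBlock contributes two convolutions, plus the stem convolution and the final fully connected layer:

blocks_per_stage = [2, 2, 2, 2]   # the `layers` argument passed for resnet18
convs_per_block = 2               # each BasicBlock holds two 3x3 convolutions
depth = 1 + convs_per_block * sum(blocks_per_stage) + 1   # stem conv + blocks + fc
print(depth)  # 18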
resnet18 calls the module-private helper function _resnet, which is defined as follows:
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """
    arch: name of the network
    block: residual block type; two kinds are defined, BasicBlock and Bottleneck
    layers: number of residual blocks in each stage, length 4
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
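For comparison, the deeper variants reuse exactly the same helper and differ only in the block type and the layers list; resnet34, for instance, is declared in torchvision as:

def resnet34(pretrained=False, progress=True, **kwargs):
    # Same pattern as resnet18, just more BasicBlocks per stage.
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)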
The ResNet class code is as follows:
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        """
        block: residual block type; two kinds are defined, BasicBlock and Bottleneck
        layers: number of residual blocks in each stage, length 4
        num_classes: number of output classes
        zero_init_residual: if True, initialize the last BN layer of each
            residual block to zero, so each residual branch starts from zero
            and each residual block initially behaves like an identity mapping.
            According to the paper this improves the network by 0.2%~0.3%.
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)
    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)
At the start of ResNet, a 7x7 convolution performs one 2x downsampling, and max pooling then downsamples by 2x again. After that come four layers, built by _make_layer, followed by the fully connected layer. The implementation of _make_layer is described below.
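To make the downsampling schedule concrete, here is a shape trace for a 224x224 input (a sketch; the attribute names match the class above):

import torch
from torchvision import models

model = models.resnet18()
x = torch.randn(1, 3, 224, 224)
x = model.conv1(x)                           # (1, 64, 112, 112): 7x7 conv, stride 2
x = model.maxpool(model.relu(model.bn1(x)))  # (1, 64, 56, 56):   3x3 max pool, stride 2
x = model.layer1(x)                          # (1, 64, 56, 56)
x = model.layer2(x)                          # (1, 128, 28, 28)
x = model.layer3(x)                          # (1, 256, 14, 14)
x = model.layer4(x)                          # (1, 512, 7, 7)
x = model.avgpool(x)                         # (1, 512, 1, 1)
print(torch.flatten(x, 1).shape)             # torch.Size([1, 512]), then fc -> 1000 logits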
_make_layer Definition
From the code comments: block is the block type; for ResNets of different depths it is either BasicBlock or Bottleneck. planes is the number of output channels of the first convolution in the stage. blocks is an int, the number of residual blocks contained in this _make_layer call.
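As an illustration (a sketch using BasicBlock and conv1x1, both defined in the next section): when layer2 of ResNet-18 is built, self.inplanes is 64, planes is 128 and stride is 2, so _make_layer creates a downsample branch and two blocks:

import torch
import torch.nn as nn

inplanes, planes, stride = 64, 128, 2
# stride != 1, so the shortcut needs its own conv1x1 + BN to match shapes
downsample = nn.Sequential(
    conv1x1(inplanes, planes * BasicBlock.expansion, stride),
    nn.BatchNorm2d(planes * BasicBlock.expansion),
)
layer2 = nn.Sequential(
    BasicBlock(inplanes, planes, stride, downsample),  # downsampling block
    BasicBlock(planes, planes),                        # shape-preserving block
)
print(layer2(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 128, 28, 28])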
Residual Block Definition
As mentioned in the comments above, ResNet has two kinds of blocks: BasicBlock and Bottleneck.
1. BasicBlock is the residual block used by resnet18 and resnet34
2. Bottleneck is the residual block used by resnet50, resnet101 and resnet152
BasicBlock first:
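The two also differ in their channel expansion factor, which is why the fc layer above takes 512 * block.expansion input features:

from torchvision.models.resnet import BasicBlock, Bottleneck
print(BasicBlock.expansion)  # 1: a stage's output channels equal planes
print(Bottleneck.expansion)  # 4: a stage's output channels equal 4 * planes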
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
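Both helpers set bias=False because every convolution here is followed by a BatchNorm layer, whose affine shift makes a conv bias redundant; and with padding=dilation, a stride-1 conv3x3 preserves spatial size. A quick check (a sketch):

import torch
feat = torch.randn(1, 64, 56, 56)
print(conv3x3(64, 64)(feat).shape)             # torch.Size([1, 64, 56, 56]): size preserved
print(conv3x3(64, 128, stride=2)(feat).shape)  # torch.Size([1, 128, 28, 28]): downsampled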
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        """
        inplanes: number of input channels, int
        planes: number of output channels, int
        stride: stride of the convolution layer
        downsample: downsampling for the shortcut branch (nn.Sequential)
        """
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
Each block contains, in order: conv3x3, bn, relu, conv3x3, bn. The shortcut connection is implemented in forward via out += identity. If stride=2, the downsampling happens in the first conv3x3; in that case downsample must be supplied so the shortcut branch matches the output shape, and as the _make_layer code shows, downsample is a 1x1 convolution followed by a BN layer.
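The role of downsample can be seen by tracing shapes (a sketch): without it, out += identity would fail whenever the block changes resolution or channel count:

import torch
block = BasicBlock(64, 64)                 # shapes match, so no downsample is needed
out = block(torch.randn(1, 64, 56, 56))    # out += identity works: both are (1, 64, 56, 56)
# BasicBlock(64, 128, stride=2) alone would produce out of shape (1, 128, 28, 28)
# while identity stays (1, 64, 56, 56); the conv1x1 downsample branch fixes this.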
ResNet Architecture Diagram
Finally
This is my first blog post; please point out anything incomplete.
