通过和resnet18和resnet50理解PyTorch的ResNet模块

更新时间:2023-05-05 11:58:49 阅读: 评论:0

通过和resnet18和resnet50理解PyTorch的ResNet模块
⽂章⽬录
resnet和resnext的框架基本相同的,这⾥先学习下resnet的构建,感觉⾼度模块化,很⽅便。本⽂算是对 ResNet代码的详细理解,另外,强烈推荐这位⼤神的PyTorch的教程!
模型介绍
resnet的模型可以直接通过torchvision导⼊,可以通过pretrained设定是否导⼊预训练的参数。
import torchvision
model = snet50(pretrained=Fal)
如果选择导⼊,resnet50、resnet101和resnet18等的模型函数⼗分简洁并且只有ResNet的参数不同,只是需要导⼊预训练参数时,调⽤load_state_dict加载model_zoo.load_url下载的参数,这⾥model_urls是⼀个维护不同模型参数下载地址的字典。
def resnet18(pretrained=Fal,**kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock,[2,2,2,2],**kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
def resnet50(pretrained=Fal,**kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck,[3,4,6,3],**kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
model_urls ={
'resnet18':'/models/resnet18-5c106cde.pth',
'resnet34':'/models/resnet34-333f7ec4.pth',
'resnet50':'/models/resnet50-19c8e357.pth',
'resnet101':'/models/resnet101-5d3b4d8f.pth',
'resnet152':'/models/resnet152-b121ed2d.pth',
}
接下来我们看下重点,也就是ResNet,ResNet的组成是:基础模块Bottleneck/Basicblock,通过_make_layer⽣成四个的⼤的layer,然后在forward中排序。
__init__的两个重要参数,block和layers,block有两种(Bottleneck/Basicblock),不同模型调⽤的类不同在resnet50、resnet101、resnet152中调⽤的是Bottleneck类,⽽在resnet18和resnet34中调⽤的是BasicBlock类,在后⾯我们详细理解。layers是包含四个元素的列表,每个元素分别是_make_layer⽣成四个的⼤的layer的包含的resdual⼦结构的个数,在resnet50可以看到列表是 [3, 4, 6, 3]。_make_layer包含四个参数,第⼀个参数是block的类型,第⼆个参数planes是输出的channel数,第三个参数blocks每个blocks中包含多少个residual⼦结构,也就是上述列表layers所存储的数字,第四个参数为步长。
def__init__(lf, block, layers, num_class=1000):
lf.inplanes =64
super(ResNet, lf).__init__()
bias=Fal)
lf.bn1 = nn.BatchNorm2d(64)
lf.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
lf.layer1 = lf._make_layer(block,64, layers[0])
lf.layer2 = lf._make_layer(block,128, layers[1], stride=2)
lf.layer3 = lf._make_layer(block,256, layers[2], stride=2)
lf.layer4 = lf._make_layer(block,512, layers[3], stride=2)
lf.avgpool = nn.AvgPool2d(7, stride=1)
lf.fc = nn.Linear(512* pansion, num_class)
for m dules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0]* m.kernel_size[1]* m.out_channels
m.al_(0, math.sqrt(2./ n))# 卷积参数变量初始化
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)# BN参数初始化
m._()
def_make_layer(lf, block, planes, blocks, stride=1):
downsample =None
if stride !=1or lf.inplanes != planes * pansion:
downsample = nn.Sequential(
nn.Conv2d(lf.inplanes, planes * pansion,
kernel_size=1, stride=stride, bias=Fal),
nn.BatchNorm2d(planes * pansion),
)
layers =[]
layers.append(block(lf.inplanes, planes, stride, downsample))
lf.inplanes = planes * pansion
for i in range(1, blocks):
layers.append(block(lf.inplanes, planes))
return nn.Sequential(*layers)
def forward(lf, x):
x = lf.conv1(x)
x = lf.bn1(x)
x = lf.relu(x)
x = lf.maxpool(x)
x = lf.layer1(x)
x = lf.layer2(x)
x = lf.layer3(x)
x = lf.layer4(x)
x = lf.avgpool(x)
x = x.view(x.size(0),-1)
x = lf.fc(x)
return x
接下来我们看下两种block:Bottleneck/Basicblock,他们最重要的是resdual的结构。所有的模型都继承Module,bottleneck改写了__init__和forward(),forward()中的out += residual就是element-wi add的操作。Bottleneck需要理解的有两处:expansion=4和downsample(下采样)。关于下采样的理论我也不清楚,我们后⾯直接通过代码来理解吧。
expansion =4
def__init__(lf, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, lf).__init__()
padding=1, bias=Fal)
lf.bn2 = nn.BatchNorm2d(planes)
lf.downsample = downsample
lf.stride = stride
def forward(lf, x):
residual = x
out = lf.conv1(x)
out = lf.bn1(out)
out = lf.relu(out)
out = lf.conv2(out)
out = lf.bn2(out)
out = lf.relu(out)
out = lf.conv3(out)
out = lf.bn3(out)
if lf.downsample is not None:
residual = lf.downsample(x)
out += residual
out = lf.relu(out)
return out
Basicblock的resdual包含两个卷积层,第⼀层卷积层的kernel=3。
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=Fal)
class BasicBlock(nn.Module):
expansion =1
def__init__(lf, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, lf).__init__()
lf.bn1 = nn.BatchNorm2d(planes)
lf.bn2 = nn.BatchNorm2d(planes)
lf.downsample = downsample
lf.stride = stride
def forward(lf, x):
residual = x
out = lf.conv1(x)
out = lf.bn1(out)
out = lf.relu(out)
out = lf.conv2(out)
out = lf.bn2(out)
if lf.downsample is not None:
residual = lf.downsample(x)
out += residual
out = lf.relu(out)
return out
resnet18模型流程
resnet调⽤的Resnet参数是model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
Resnet – init()
lf.layer1之前的变量初始化不难理解,lf.layer1=lf._make_layer(block, 64, layers[0])这⾥block=Basicblock,layer[0]=2
执⾏_make_layer
downsample = None——if条件不满⾜,downsample=None
下⾯构建blocks层Basicblock:
layers=[]——layers.append(Basicblock(64,64,1,downsample=None))
赋值输⼊channel lf.inplanes = pansion = 641 = 64
for循环构建剩下的blocks-1个residual,不传downsample.
lf.layer2 执⾏lf._make_layer(block, 128, layers[1], stride=2)
downsample=None
显然if条件满⾜ downsample=nn.Sequential(nn.Conv2d(64,128, kernel_size=1, stride=2, bias=Fal), nn.BatchNorm2d(128), )
layers=[]——layers.append(Basicblock(64,128,2,downsample))
lf.inplanes = 128*1=128
for循环构建剩下的blocks-1个residual,不传dowmsample.
可以看出接下来layer3和layer4与layer2相似,最终构成resnet18.
总结
从layer2到layer4,每个layer第⼀个输⼊会增加⼀倍channel,所以resdual会采⽤下采样,⽽对于每⼀层⽽⾔,channel都是相同
的,pansion都为1,所以我们看不出其发挥的作⽤,我们将在resnet50研究下。如下图,这⾥没找到resnet18,图中的虚线就是downsample,其产⽣于channel变化的resdual。
resnet50
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs),可以看出,resnet50采⽤Bottleneck模块,并且每个⼤的layer的blocks数量也不同。layer1=lf._make_layer(Bottleneck, 64, 3)
if条件满⾜,downsample = nn.Sequential(
nn.Conv2d(lf.inplanes=64, 64 * 4,
kernel_size=1, stride=stride, bias=Fal),
nn.BatchNorm2d(644),)
layers.append(Bottleneck(64,64,1,dowmsample)),bottleneck内经过三个卷积层Conv2d(64,64) Conv2d(64,64)
Conv2d(64,644)保证每个block的输出channel是planes expansion,通过lf.inplanes = pansion赋值后⾯block的输⼊channel也是planes*expansion。

本文发布于:2023-05-05 11:58:49,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/82/531788.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:参数   模型   构建   理解   包含
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图