HRNet解析

更新时间:2023-05-06 13:04:37 阅读: 评论:0

HRNet解析
前⾔
⼤多数⽹络都是由较⾼的分辨率特征图开始,通过步长为2的卷积块,甚⾄是池化操作,来逐渐缩⼩特征图⼤⼩,丰富各个通道的信息,最后再通过⼀个全局池化,输出通道信息。于是HRNet的作者就在思考能否通过并⾏,来融合多个尺度特征图信息来提⾼⽹络的性能,事实上也证明了这种⽅法的有效。这篇SOTA的模型也常⽤于⽬标检测,姿势估计等复杂任务,且表现都⼗分不错
⽹络结构
这是论⽂⾥⾯的⼀幅图⽚,看上去⼗分清楚,整个⽹络有多个分⽀,经过⼀定卷积操作后,上⾯的分⽀通过下采样来缩⼩特征图⼤⼩,融合进下⾯的分⽀,⽽下⾯的分⽀则通过上采样来恢复原特征图⼤⼩,融合进上⾯的分⽀
这⾥上采样模块,作者使⽤的是最近邻元素填充的⽅式
代码分析
残差块构造
⽹络中的卷积操作,还是以残差块的思想,所以开头两段module是残差块的构造,包含基本块和bottleneck块
class BasicBlock(nn.Module):
"""
基本块构造
"""
expansion = 1
def __init__(lf, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, lf).__init__()
lf.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
lf.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
lf.downsample = downsample
lf.stride = stride
def forward(lf, x):
"""
构造残差连接
:param x:
:
return:
"""
residual = x
out = lf.conv1(x)
out = lf.bn1(out)
out = lf.relu(out)
out = lf.conv2(out)
out = lf.bn2(out)
if lf.downsample is not None:
# 如果downsample不为None,则进⾏下采样
residual = lf.downsample(x)
out += residual
out = lf.relu(out)
return out
class Bottleneck(nn.Module):
"""
残差块的bottleneck部分
"""
expansion = 4
def __init__(lf, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, lf).__init__()
lf.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
lf.bn3 = nn.BatchNorm2d(pansion, momentum=BN_MOMENTUM)
lf.downsample = downsample
lf.stride = stride
def forward(lf, x):
residual = x
out = lf.conv1(x)
out = lf.bn1(out)
out = lf.relu(out)
out = lf.conv2(out)
out = lf.bn2(out)
out = lf.relu(out)
out = lf.conv3(out)
out = lf.bn3(out)
if lf.downsample is not None:
residual = lf.downsample(x)
out += residual
out = lf.relu(out)
return out
HRNET主要⽹络构造
class HighResolutionModule(nn.Module):
def __init__(lf, num_branches, blocks, num_blocks, num_inchannels, num_channels,
fu_method, multi_scale_output=True):
super(HighResolutionModule, lf).__init__()
lf._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)
lf.num_inchannels = num_inchannels
lf.fu_method = fu_method
lf.num_branches = num_branches
lf.multi_scale_output = multi_scale_output
lf.branches = lf._make_branches(
num_branches, blocks, num_blocks, num_channels
)
lf.fu_layers = lf._make_fu_layers()
def _check_branches(lf, num_branches, blocks, num_blocks,
num_inchannels, num_channels):
if num_branches != len(num_blocks):
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
num_branches, len(num_blocks))
<(error_msg)
rai ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
num_branches, len(num_channels))
<(error_msg)
rai ValueError(error_msg)
if num_branches != len(num_inchannels):
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
num_branches, len(num_inchannels))
<(error_msg)
rai ValueError(error_msg)
def _make_one_branch(lf, branch_index, block, num_blocks, num_channels, stride=1):
"""
添加⼀个分⽀
:param branch_index:
:param block:
:param num_blocks:
:param num_channels:
:param stride:
:return:
"""
downsample = None
if stride != 1 or lf.num_inchannels[branch_index] != num_channels[branch_index]*pansion:
            # 当stride不为1,或输⼊通道不等于输出通道数
# 则加⼊downsample下采样模块
# 注意这个下采样模块是针对resnet结构做的,⽽不是特征融合的那个下采样
downsample = nn.Sequential(
nn.Conv2d(lf.num_inchannels[branch_index],
num_channels[branch_index]*pansion,
kernel_size=1, stride=stride, bias=Fal),
nn.BatchNorm2d(
num_channels[branch_index]*pansion,
momentum=BN_MOMENTUM
),
)
layers = []
# 以残差块模式加⼊
layers.append(
block(
lf.num_inchannels[branch_index],
num_channels[branch_index],
stride,
downsample
)
)
# 因为经过残差块后,对应分⽀的输⼊通道改变
# 这时候根据不同的残差块的expansion来改变通道数⽬
lf.num_inchannels[branch_index] = num_channels[branch_index]*pansion
# 将剩下的块都加⼊进来
# 后续的块都是步长为1,特征图⼤⼩不变
for i in range(1, num_blocks[branch_index]):
layers.append(
block(
lf.num_inchannels[branch_index],
num_channels[branch_index]
)
)
return nn.Sequential(*layers)
def _make_branches(lf, num_branches, block, num_blocks, num_channels):
"""
⽤⼀个循环来调⽤_make_one_branch
:param num_branches:
:param block:
:param num_blocks:
:param num_channels:
:return:
"""
branches = []
for i in range(num_branches):
branches.append(
lf._make_one_branch(i, block, num_blocks, num_channels)
)
return nn.ModuleList(branches)
def get_num_inchannels(lf):
return lf.num_inchannels
def _make_fu_layers(lf):
if lf.num_branches == 1:
return None
num_branches = lf.num_branches
num_inchannels = lf.num_inchannels
fu_layers = []
# i,j都是分⽀数
for i in range(num_branches if lf.multi_scale_output el 1):
fu_layer = []
for j in range(num_branches):
if j > i:
# 此时上⾯的分⽀与下⾯的分⽀融合
fu_layer.append(
nn.Sequential(
# 当j > i 执⾏上采样融合
nn.Conv2d(num_inchannels[j], num_inchannels[i],
1, 1, 0, bias=Fal),
nn.BatchNorm2d(num_inchannels[i]),
nn.Upsample(scale_factor=2**(j-i), mode='nearest')
)
)
elif j==i:
# 此时分⽀是⾃⼰
# j == i 不做操作
fu_layer.append(None)
el:
# 此时上分⽀下采样,所以经过⼀个步长为2的卷积减半特征图⼤⼩                    conv3x3s = []
for k in range(i-j):
"""
上⾯的分⽀需要经过多次下采样减⼩特征图
k = i-j-1这时候是最后⼀次,最后⼀次有点特别,不需要relu激活                        """
if k == i-j-1:
#
num_outchannels_conv3x3 = num_inchannels[i]
conv3x3s.append(
nn.Sequential(
nn.Conv2d(
num_inchannels[j],
num_outchannels_conv3x3,
3, 2, 1, bias=Fal
),
nn.BatchNorm2d(num_outchannels_conv3x3)
)
)
el:
num_outchannels_conv3x3 = num_inchannels[j]
conv3x3s.append(
nn.Sequential(
nn.Conv2d(
num_inchannels[j],
num_outchannels_conv3x3,
3, 2, 1, bias=Fal
),
nn.BatchNorm2d(num_outchannels_conv3x3),
nn.ReLU(inplace=True)
)
)
fu_layer.append(nn.Sequential(*conv3x3s))
fu_layers.append(nn.ModuleList(fu_layer))
return nn.ModuleList(fu_layers)
def forward(lf, x):
if lf.num_branches == 1:
return [lf.branches[0](x[0])]
for i in range(lf.num_branches):
x[i] = lf.branches[i](x[i])
x_fu = []
for i in range(len(lf.fu_layers)):
y = x[0] if i == 0 el lf.fu_layers[i][0](x[0])
for j in range(1, lf.num_branches):
if i == j:
y = y + x[j]
el:
y = y + lf.fu_layers[i][j](x[j])
x_fu.lu(y))
return x_fu

本文发布于:2023-05-06 13:04:37,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/78/540238.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:特征   采样   残差   通道   构造
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图