SSD源码解读——损失函数的构建

更新时间:2023-07-15 01:32:04 阅读: 评论:0

SSD源码解读——损失函数的构建
之前,对SSD的论⽂进⾏了解读,可以回顾之前的博客:。
为了加深对SSD的理解,因此对SSD的源码进⾏了复现,主要参考的github项⽬是。同时,我⾃⼰对该项⽬增加了⼤量注释:
搭建SSD的项⽬,可以分成以下三个部分:
1. ;
2. ;
3. 损失函数的构建;
4. 。
接下来,本篇博客重点分析损失函数的构建。
检测任务的损失函数,与分类任务的损失函数具有很⼤不同。在检测的损失函数中,不仅需要计类别置信度的差异,坐标的差异,还需要使⽤到各种tricks,例如hard negative mining等。
在train.py中,⾸先需要对损失函数MultiBoxLoss()进⾏初始化,需要传⼊的参数为num_class类别数,正例的IOU阈值和hard negative mining的正负样本⽐例。在论⽂中,VOC的类别总数是21(20个类别加上1个背景);当预测框与GT框的IOU⼤于0.5时,认为该预测框是正例;hard negative mining的正样本和负样本的⽐例是1:3。
六级听力分值
# 损失函数
criterion = MultiBoxLoss(num_class=voc['num_class'],
overlap_thresh=0.5,
吉胡阿依neg_pos=3)
在models/multibox_loss中,定义了损失函数MultiBoxLoss()。在函数forward()中,需要传进来两个参数,分别是predictions和targets,其中,predictions是SSD⽹络得到的结果,分别是预测框坐标,类别置信度和先验锚点框;⽽targets是则是数据读取中的值,是GT框的坐标和类别label。⾸先,需要创建坐标loc_t和类别置信度conf_t的tensor,其shape分别是[batch_size,8732,4]和[batch_size,8732]。然后,使⽤⼀个for循环,将GT框与先验锚点框的坐标与label进⾏match,得到每个锚点框的label和坐标偏差,并将结果保存与loc_t和conf_t中。由于制定了某些锚点框⽤于预测⽬标,
因此,接下来,需要使⽤这部分锚点框信息来计算损失。取出含⽬标的锚点框,得到其index,其中,pos 的shape为[batch_size,8732],每个元素是true或者fal。再从⽹络预测的8732个预测框中,取出同样index的预测框的坐标偏差loc_p,⽽loc_t则是同样index的先验锚点框的坐标偏差。由于锚点框对应上了,则使⽤smooth_l1来计算预测框回归的算是loss_l,如下图所⽰
的L_{loc},。
接下来,则是使⽤hard negative mining和计算置信度损失。⾸先为模型预测出来的置信度conf_data进⾏维度变换,由[batch_size,8732,21]变成[batch_size*8732,21]的batch_conf,应该是为了⽅便下⾯进⾏计算。接下来,计算所有预测框的置信度损失loss_c,将含⽬标的锚点框(正例)的损失置0,并对损失进⾏排名,从⽽选出损失最⼤的前num_neg个损失的index。将正例的pos_index和损失最⼤的负例neg_idx 提取出来成conf_p,⽤于参与训练中,与相同index的先验锚点框进⾏计算交叉熵损失计算。最后将置信度损失和位置损失返回。
class MultiBoxLoss(nn.Module):
'''
SSD损失函数的计算
'''
def__init__(lf, num_class, overlap_thresh, neg_pos):
super(MultiBoxLoss, lf).__init__()
广东高考英语lf.num_class = num_class  # 类别数
lf.threshold = overlap_thresh  # GT框与先验锚点框的阈值
def forward(lf, predictions, targets):
'''
对损失函数进⾏计算:
1.进⾏GT框与先验锚点框的匹配,得到loc_t和conf_t,分别表⽰锚点框需要匹配的坐标和锚点框需要匹配的label
2.对包含⽬标的先验锚点框loc_t(即正例)与预测的loc_data计算位置损失函数
3.对负例(即背景)进⾏损失计算,选择损失最⼤的num_neg个负例和正例共同组成训练样本,取出这些训练样本的锚点框targets_weighted
与置信度预测值conf_p,计算置信度损失:sign up
a)为Hard Negative Mining计算最⼤置信度loss_c
b)将loss_c中正例对应的值置0,即保留了所有负例
c)对此loss_c进⾏排序,得到损失最⼤的idx_rank
d)计算⽤于训练的负例的个数num_neg,约为正例的3倍
e)选择idx_rank中前num_neg个⽤作训练
f)将正例的index和负例的index共同组成⽤于计算损失的index,并从预测置信度conf_data和真实置信度conf_t提出这些样本,形成
conf_p和targets_weighted,计算两者的置信度损失.
:param predictions: ⼀个元祖,包含位置预测,置信度预测,先验锚点框
位置预测:(batch_size,num_priors,4),即[batch_size,8732,4]
置信度预测:(batch_size,num_priors,num_class),即[batch_size, 8732, 21]
先验锚点框:(num_priors,4),即[8732, 4]
:param targets: 真实框的坐标与label,[batch_size,num_objs,5]
其中,5代表[xmin,ymin,xmia,ymax,label]
'''
loc_data, conf_data, priors = predictions
num = loc_data.shape[0]  # 即batch_size⼤⼩
priors = priors[:loc_data.shape[1], :]  # 取出8732个锚点框,与位置预测的锚点框数量相同
num_priors = priors.shape[0]  # 8732
loc_t = torch.Tensor(num, num_priors, 4)  # [batch_size,8732,4],⽣成随机tensor,后续⽤于填充
conf_t = torch.Tensor(num, num_priors)  # [batch_size,8732]
# 取消梯度更新,貌似默认是Fal
quires_grad = Fal
quires_grad = Fal
for idx in range(num):
truths = targets[idx][:, :-1]  # 坐标值,[xmin,ymin,xmia,ymax]
labels = targets[idx][:, -1]  # label
defaults = priors.cuda()
match(lf.threshold, truths, defaults, labels, loc_t, conf_t, idx)
if torch.cuda.is_available():
loc_t = loc_t.cuda()
conf_t = conf_t.cuda()  # shape:[batch_size,8732],其元素组成是类别标签号和背景
pos = conf_t > 0  # 排除label=0,即排除背景,shape[batch_size,8732],其元素组成是true或者fal
# Localization Loss (Smooth L1),定位损失函数
# Shape: [batch,num_priors,4]
# pos.dim()表⽰pos有多少维,应该是⼀个定值(2)
# pos由[batch_size,8732]变成[batch_size,8732,1],然后展开成[batch_size,8732,4]
anxiouslypos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
loc_p = loc_data[pos_idx].view(-1, 4)  # [num_pos,4],取出带⽬标的这些框
loc_t = loc_t[pos_idx].view(-1, 4)  # [num_pos,4]
# 位置损失函数
loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')  # 这⾥对损失值是相加,有公式可知,还没到相除的地步
# 为Hard Negative Mining计算max conf across batch
batch_conf = conf_data.view(-1, lf.num_class)  # shape[batch_size*8732,21]
# gather函数的作⽤是沿着定轴dim(1),按照Index(conf_t.view(-1, 1))取出元素
# batch_conf.gather(1, conf_t.view(-1, 1))的shape[8732,1],作⽤是得到每个锚点框在匹配GT框后的label
loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1).long())  # 这个不是最终的置信度损失函数
# Hard Negative Mining
# 由于正例与负例的数据不均衡,因此不是所有负例都⽤于训练
loss_c[pos.view(-1, 1)] = 0  # pos与loss_c维度不⼀样,所以需要转换⼀下,选出负例
loss_c = loss_c.view(num, -1)  # [batch_size,8732]
_, loss_idx = loss_c.sort(1, descending=True)  # 得到降序排列的index
_, idx_rank = loss_idx.sort(1)
num_pos = pos.sum(1, keepdim=True)  # pos⾥⾯是true或者fal,因此sum后的结果应该是包含的⽬标数量
num_neg = torch.pos_ratio * num_pos, max=pos.size(1) - 1)  # ⽣成⼀个随机数⽤于表⽰负例的数量,正例和负例的⽐例约3:1
neg = idx_rank < pand_as(idx_rank)  # [batch_size,8732] 选择num_neg个负例,其元素组成是true或者fal
# 置信度损失,包括正例和负例
# [batch_size, 8732, 21],元素组成是true或者fal,但true代表着存在⽬标,其对应的index为label
pos_idx = pos.unsqueeze(2).expand_as(conf_data)
neg_idx = neg.unsqueeze(2).expand_as(conf_data)
# pos_idx由true和fal组成,表⽰选择出来的正例,neg_idx同理
# (pos_idx + neg_idx)表⽰选择出来⽤于训练的样例,包含正例和反例
# (other)函数的作⽤是逐个元素与other进⾏⼤⼩⽐较,⼤于则为true,否则为fal
# 因此conf_data[(pos_idx + neg_idx).gt(0)]得到了所有⽤于训练的样例
conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, lf.num_class)
targets_weighted = conf_t[(pos + neg).gt(0)]
loss_c = F.cross_entropy(conf_p, targets_weighted.long(), reduction='sum')
# L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
N = num_pos.sum()  # ⼀个batch⾥⾯所有正例的数量
loss_l /= N
loss_c /= N
return loss_l, loss_c
在hard negative mining中,需要先计算loss_c。从代码可以看到  loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1,
1).long()) ,这句代码就是置信度损失的计算,可以参考公式进⾏理解。这⾥可以提及⼀下,对loss_c的两次排序,参考,⾸先对值进⾏降序排序,得到排名1,然后对排名⼜进⾏降序排序,得到排名2,
如下图所⽰,即能取出idx_rank的前N个,可获得损失最⼤那些值,即变量neg的作⽤。
在计算损失函数时,提及了函数match(),这个函数位于models/box_utils.py中,是⼀个⾮常关键的函
数,对应论⽂的匹配策略那⼀章节,其作⽤是为每个锚点框指定GT框和为每个GT框指定锚点框。需要传进来⼏个参数,truths是GT框的坐标,priors是先验锚点框的坐标[中⼼点x,中⼼点y,W,H],labels是GT框对应的类别(不包含背景),loc_t和conf_t是⽤来保存结果的,idx是第i张图⽚。dp是什么意思
为了⽅便表述,num_objects表⽰⼀张图中,GT框的数量;num_priors表⽰先验锚点框的数量,即8732。
第⼀步,由于先验锚点框priors的坐标形式是[中⼼点x,中⼼点y,W,H],需要使⽤函数point_from()来将其转化成
[x_min,y_min,x_max,y_max]。然后计算每个GT框与所有先验锚点框的jaccard值,即IOU的值,使⽤了numpy风格的计算⽅式,返回的变量overlaps的shape为[GT框数量,8732]。
第⼆步,根据论⽂,为每个GT框匹配⼀个最⼤IOU的先验锚点框,确保每个GT框⾄少有⼀个锚点框进⾏预测。
第三步,为每个锚点框匹配上⼀个最⼤IOU的GT框来进⾏预测。
第四步,变量best_truth_overlap保存着每个框与GT框的最⼤IOU值(第三步的结果),使⽤index_fill()函数,将第⼆步的结果同步到这个变量中。在index_fill()函数中,使⽤数值2来进⾏填充,
是为了确保第⼆步中得到的锚点框肯定会被选到。对变量best_truth_idx也进⾏同样的处理。
第五步,由于传⼊进来的labels的类别是从0开始的,SSD中认为0应该是背景,所以,需要对labels进⾏加⼀。这⾥需要注意⼀
下,best_truth_idx的shape是[8732],每个元素的范围为[0,num_objects],所以conf的shape为[num_priors],每个元素表⽰先验锚点框的label(0是背景)。同时,需要将变量best_truth_overlap中IOU⼩于阈值(0.5)的锚点框的label设置为0。并将结果保存与conf_t,返回给外⾯的函数⽤于计算。
第六步,同样需要将GT框的坐标进⾏扩展,形成shape为[num_priors,4]的matches,这样每个锚点框都有对应的坐标进⾏预测,但最终并不是每个锚点框都⽤于训练中。
第七步,使⽤GT框与锚点框进⾏编码,对应论⽂中的公式2,得到shape为[num_priors,4]的值,即偏差,将此结果返回出去。
注意,这⾥使⽤的是GT框的信息和先验锚点框的信息,并没有涉及到⽹络预测出来的结果。得到每个锚点框的类别conf_t和坐标loc_t。由于没有⽤到⽹络预测的结果,可以认为这部分⼀直都是定值。
def match(threshold, truths, priors, labels, loc_t, conf_t, idx):
'''
这个函数对应论⽂中的matching strategy匹配策略.SSD需要为每⼀个先验锚点框都指定⼀个label,
这个label或者指向背景,或者指向每个类别.
论⽂中的匹配策略是:
1.⾸先,每个GT框选择与其IOU最⼤的⼀个锚点框,并令这个锚点框的label等于这个GT框的label
2.然后,当锚点框与GT框的IOU⼤于阈值(0.5)时,同样令这个锚点框的label等于这个GT框的label
因此,代码上的逻辑为:
1.计算每个GT框与每个锚点框的IOU,得到⼀个shape为[num_object,num_priors]的矩阵overlaps
2.选择与GT框的IOU最⼤的锚点框,锚点框的index为best_prior_idx,对应的IOU值为best_prior_overlap
3.为每⼀个锚点框选择⼀个IOU最⼤的GT框,可能会出现多个锚点框匹配⼀个GT框的情况,此时,每个锚点框对应GT框的index为best_truth_idx,
对应的IOU为best_truth_overlap.注意,此时IOU值可能会存在⼩于阈值的情况.
4.第3步可能到导致存在GT框没有与锚点框匹配上的情况,所以要和第2步进⾏结合.在第3步的基础上,对best_truth_overlap进⾏选择,选择出
best_prior_idx这些锚点框,让其对其的IOU等于⼀个⼤于1的定值;并且让best_truth_idx中index为best_prior_idx的锚点框的label
与GT框对应上.最终,best_truth_overlap表⽰每个锚点框与GT框的最⼤IOU值,⽽best_truth_idx表⽰每个锚点框⽤于与相应的GT框进⾏
匹配.
5.第4步中,会存在IOU⼩于阈值的情况,要将这些⼩于IOU阈值的锚点框的label指向背景,完成第⼆条匹配策略.
labels表⽰GT框对应的标签号,"conf=labels[best_truth_idx]+1"得到每个锚点框对应的标签号,其中label=0是背景.
"conf[best_truth_overlap < threshold] = 0"则将⼩于IOU阈值的锚点框的label指向背景
6.得到的conf表⽰每个锚点框对应的label,还需要⼀个矩阵,来表⽰每个锚点框需要匹配GT框的坐标.人人都恨克里斯第二季
英语发音词典
rsbtruths表⽰GT框的坐标,"matches = truths[best_truth_idx]"得到每个锚点框需要匹配GT框的坐标.
:param threshold:IOU的阈值
:param truths:GT框的坐标,shape:[num_obj,4]
:param priors:先验锚点框的坐标,shape:[num_priors,4],num_priors=8732
:param labels:这些GT框对应的label,shape:[num_obj],此时label=0还不是背景
:param loc_t:坐标结果会保存在这个tensor
:param conf_t:置信度结果会保存在这个tensor
:param idx:结果保存的idx
'''
# 第1步,计算IOU
overlaps = jaccard(truths, point_from(priors))  # shape:[num_object,num_priors]
# 第2步,为每个真实框匹配⼀个IOU最⼤的锚点框,GT框->锚点框
# best_prior_overlap为每个真实框的最⼤IOU值,shape[num_objects,1]
# best_prior_idx为对应的最⼤IOU的先验锚点框的Index,其元素值的范围为[0,num_priors]
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
# 第3步,若先验锚点框与GT框的IOU>阈值,也将这些锚点框匹配上,锚点框->GT框
# best_truth_overlap为每个先验锚点框对应其中⼀个真实框的最⼤IOU,shape[1,num_priors]
# best_truth_idx为每个先验锚点框对应的真实框的index,其元素值的范围为[0,num_objects]
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
best_prior_idx.squeeze_(1)  # [num_objects]
best_prior_overlap.squeeze_(1)  # [num_objects]
best_truth_idx.squeeze_(0)  # [num_priors],8732
best_truth_overlap.squeeze_(0)  # [num_priors],8732
# 第4步
# index_fill_(lf, dim: _int, index: Tensor, value: Number)对第dim⾏的index使⽤value进⾏填充
# best_truth_overlap为第⼀步匹配的结果,需要使⽤到,使⽤best_prior_idx是第⼆步的结果,也是需要使⽤上的
# 所以在best_truth_overlap上进⾏填充,表明选出来的正例
# 使⽤2进⾏填充,是因为,IOU值的范围是[0,1],只要使⽤⼤于1的值填充,就表明肯定能被选出来
best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # 确定最佳先验锚点框
# 确保每个GT框都能匹配上最⼤IOU的先验锚点框
# 得到每个先验锚点框都能有⼀个匹配上的数字
# best_prior_idx的元素值的范围是[0,num_priors],长度为num_objects
for j in range(best_prior_idx.size(0)):
best_truth_idx[best_prior_idx[j]] = j
# 第5步
conf = labels[best_truth_idx] + 1  # Shape: [num_priors],0为背景,所以其余编号+1
conf[best_truth_overlap < threshold] = 0  # 置信度⼩于阈值的label设置为0
# 第6步
matches = truths[best_truth_idx]  # 取出最佳匹配的GT框,Shape: [num_priors,4]
# 进⾏位置编码一方什么
loc = encode(matches, priors,voc['variance'])
loc_t[idx] = loc  # [num_priors,4],应该学习的编码偏差
conf_t[idx] = conf  # [num_priors],每个锚点框的label
在函数match()中,使⽤到了函数encode()来对位置进⾏编码。参考和R-CNN中的公式,假设先验锚
点框的坐标为(d^{cx},d^{cy},d^w,d^h),预测框的坐标为(b^{cx},b^{cy},b^w,b^h),则预测框的转换值l为:
l^{cx}=(b^{cx}-d^{cx})/d^w,  l^{cy}=(b^{cy}-d^{cy})/d^h
b^w=d^wexp(l^x),  b^h=d^hexp(l^h)
⽽代码中,我们利⽤了⽅差的信息,因此进⾏了相应的调整,整体上是⼀致的。
def encode(matched, priors, variances):
'''
对坐标进⾏编码,对应论⽂中的公式2
利⽤GT框和先验锚点框,计算偏差,⽤于回归
:param matched: 每个先验锚点框对应最佳的GT框,Shape: [num_priors, 4],
其中4代表[xmin,ymin,xmax,ymax]
:param priors: 先验锚点框,Shape: [num_priors,4],
其中4代表[中⼼点x,中⼼点y,宽,⾼]
:return: shape:[num_priors, 4]
'''
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]  # 计算GT框与锚点框中⼼点的距离
g_cxcy /= (variances[0] * priors[:, 2:])
g_wh = (matched[:, 2:] - matched[:, :2])  # xmax-xmin,ymax-ymin
g_wh /= priors[:, 2:]
g_wh = torch.log(g_wh) / variances[1]
return torch.cat([g_cxcy, g_wh], 1)
⾄此,SSD的损失函数构建以介绍完成。相⽐于分类任务,⽬标检测的损失函数构建需要更多的代码,包含了各种tricks。
Processing math: 0%

本文发布于:2023-07-15 01:32:04,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/90/177710.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:锚点   损失   函数   计算   先验   置信度
相关文章
留言与评论(共有 0 条评论)
   
验证码:
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图