一文读懂FocalLoss及Pytorch代码(详细注释)

更新时间:2023-07-08 01:26:29 阅读: 评论:0

⼀⽂读懂FocalLoss 及Pytorch 代码(详细注释)
⽂章⽬录
前⾔
Focal Loss及RetinaNet原理见另⼀篇⽂章:李世民的兄弟
本⽂介绍Focal Loss以及其Pytorch实现。
Focal Loss 详解
直接上公式:
其中:交叉熵损失:
是类别(t个不同类别)概率(多分类就是softmax的结果),衡量样本难易程度,如果较⼤则是简单样本,较⼩则是困难样本;是调节类别权重因⼦,它的值为第⼀类正样本权重,RetinaNet中设置为0.25,由于正样本远少于负样本,所以这样设置,让正样本的权重低,负样本权重为0.75;是调节难易样本的权重因⼦,让模型快速关注困难样本。RetinaNet中设为2;
整体控制损失⼤⼩,因为,所以时,loss⼤;时,loss⼩;这样合理的权重分配可以让模型更好的学习训练;
当时,FL就变成了CE。所以在复现的时候提供了⼀个思路,那就是定义前⾯两个权重因⼦,再乘上CE就得到了FL,是softmax的结果,那么就是log_softmax的结果。
实现思路
定义focal_loss类:
1.参数定义:alpha,gamma,num_class等,应⽤到⽬标检测中可能还需要anchors等参数。
2.forward:focalloss是由CE乘以因⼦组成的,⽽CE是由NLL和log_softmax组成的。所以,先实现NLL和log_softmax,然后⼀步⼀步通过公式实现focal loss,主要是Tensor的变换。
3.输⼊参数设计:输⼊张量是[B,N,C]或[B,C],其中B是批量,N是预测框数量,C是类别数;[B,C]就是单纯的分类问题。真实标签是[B,N]或[B],总的标签数,如果是⽬标检测中则标签数就是B*N。
Focal Loss 类的代码
from  torch import  nn
FL (p )=t −α(1−t p )log (p )
t γt p =t {p ,
1−p ,if  y =1otherwi王氏庄园
CE (p )=t −log (p )
t p t p t αt γ(1−p )t γγ=2p →t 0p →t 1α=1,γ=0p t log (p )t
from torch import nn
import torch
import functional as F
class focal_loss(nn.Module):
def__init__(lf, alpha=0.25, gamma=2, num_class =5, size_average=True):
"""
focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)
步骤详细的实现了 focal_loss损失函数.
:param alpha:  阿尔法α,类别权重. 当α是列表时,为各类别权重,当α为常数时,类别权重为[α, 1-α, 1-α, ....],常⽤于⽬标检测算法中抑制背景类 , retainnet 中设置为0.255
:param gamma:  伽马γ,难易样本调节参数. retainnet中设置为2
:param num_class:    类别数量
:param size_average:    损失计算⽅式,默认取均值
"""
super(focal_loss,lf).__init__()
lf.size_average = size_average
if isinstance(alpha,list):
asrt len(alpha)==num_class  # α可以以list⽅式输⼊,size:[num_class] ⽤于对不同类别精细地赋予权重
print(" --- Focal_loss alpha = {}, 将对每⼀类权重进⾏精细化赋值 --- ".format(alpha))
lf.alpha = torch.Tensor(alpha)
el:
asrt alpha<1#如果α为⼀个常数,则降低第⼀类的影响,在⽬标检测中为第⼀类
print(" --- Focal_loss alpha = {} ,将对背景类进⾏衰减,请在⽬标检测任务中使⽤ --- ".format(alpha))
lf.alpha = s(num_class)
lf.alpha[0]+= alpha
lf.alpha[1:]+=(1-alpha)# α最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_class]
lf.gamma = gamma
def forward(lf, preds, labels):
"""
focal_loss损失计算
:param preds:  预测类别. size:[B,N,C] or [B,C]    分别对应与检测与分类任务, B批次, N检测框数, C类别数
:param labels:  实际类别. size:[B,N] or [B]        [B*N个标签(假设框中有⽬标)],[B个标签]
:return:
"""
#固定类别维度,其余合并(总检测框数或总批次数),preds.size(-1)是最后⼀个维度
preds = preds.view(-1,preds.size(-1))
lf.alpha = (preds.device)
#使⽤log_softmax解决溢出问题,⽅便交叉熵计算⽽不⽤考虑值域
preds_logsoft = F.log_softmax(preds, dim=1)
#log_softmax是softmax+log运算,那再exp就算回去了变成softmax
preds_softmax = p(preds_logsoft)
# 这部分实现nll_loss ( crosntropy = log_softmax + nll)行政等级从低到高
preds_softmax = preds_softmax.gather(1,labels.view(-1,1))
preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
lf.alpha = lf.alpha.gather(0,labels.view(-1))
# torch.pow((1-preds_softmax), lf.gamma) 为focal loss中 (1-pt)**γ
#torch.mul 矩阵对应位置相乘,⼤⼩⼀致
loss =-torch.mul(torch.pow((1-preds_softmax), lf.gamma), preds_logsoft)
#torch.t()求转置
优美文段loss = torch.mul(lf.alpha, loss.t())
#print(loss.size()) [1,5]
if lf.size_average:
loss = an()
el:
loss = loss.sum()
return loss
注:根据⾃⼰需要修改num_class的值
具体流程分析代码
先简单⼀点测试⼀下图⽚分类:
假设有三张图⽚,分类类别⼀共五类,三张图⽚的真实类别类别数分别为2,3,4
torch.manual_ed(50)
preds = torch.randn((3,5))
#preds = torch.randn((3,10,5))
print(preds)
preds = preds.view(-1,preds.size(-1))
print(preds.size())
labels = sor([2,3,4])
#labels = sor([2,3,4]*10)
print(labels.view(-1))
print(labels.view(-1,1))
看⼀下各张量值以及size()后⾯会⽤到:
计算log_softmax以及softmax:
preds_logsoft = F.log_softmax(preds, dim=1)# log_softmax
print(preds_logsoft)
preds_softmax = p(preds_logsoft)
print(preds_softmax)
结果:
从中取得每个图⽚分别在2,3,4类别的预测值:如何取得可见详解。
preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
print(preds_logsoft)
迪士尼一日游攻略preds_softmax = preds_softmax.gather(1,labels.view(-1,1))
学化妆print(preds_softmax)
结果:
计算loss并转置:gamma取2
a = torch.pow((1-preds_softmax),2)
loss =-torch.mul(torch.pow((1-preds_softmax),2), preds_logsoft)
print(loss.t())
print(loss.t().size())
结果:
初始化alpha权重,并得到所需label的alpha权重
alpha = sor([0.25,0.75,0.75,0.75,0.75])
alpha = alpha.gather(0,labels.view(-1))
alpha
结果:
乘以alpha并求平均得到最终的focal loss
loss = torch.mul(alpha, loss.t())
loss = an()
loss
得到focal loss为:
提问问题大全
PS:若要加⼊检测框数量,输⼊采⽤代码中注释的部分,本例中是假设每张图⽚有10个检测框(正样本),则总标签个数就是30个,相应地改变标签张量为30个元素即可。
Focal Loss(FL)与CrossEntropy(CE)对⽐:
criterion = focal_loss()音响怎么连接电视
loss = criterion(preds, labels)
print("FL  loss",loss)
ss_entropy(preds, labels)
#a=nn.CrossEntropyLoss()
#a = a(preds,labels)
print('CE  loss',a)
结果说明FL⽐CE的损失低,效果好
总结
1. focal loss在⽬标检测中预测类别时使⽤,是分类⼦⽹下的损失。
2. 先log_softmax再取exp得到的softmax各类别之和不是严格为1。
3. F.cross_entropy和nn.CrossEntropyLoss相同。
4. gather函数详解见
5. 可以根据分类或检测类别中的⽬标多少来⾃定义alpha的权重。测试代码见资源()

本文发布于:2023-07-08 01:26:29,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/82/1084625.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:权重   类别   样本   检测   分类
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图