[Causal Learning] VC R-CNN (CVPR 2020) Code

Updated: 2023-05-20 10:10:23

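The author's code is built on the Mask R-CNN framework (maskrcnn-benchmark, the predecessor of Detectron2). The design is inspired by Bottom-Up and Top-Down Attention for Image Captioning and VQA: Mask R-CNN serves as the bottom-up backbone and provides image features for downstream tasks such as image captioning and VQA.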
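According to the paper, the RPN is removed and ground-truth (GT) bboxes are used as input. The training loss becomes the sum of a self-prediction loss and a context-prediction (causal) loss, both implemented in loss.py below.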
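The objective can be written out from the CausalPredictor and loss.py code later in this post. This is our reconstruction from the code, not the formula as printed in the paper. Notation (ours):

- M RoIs in the batch, with self logits s_r and labels c_r.
- Image b has N_b GT boxes with features x_i.
- z_k is one of K dictionary entries, with prior P(z_k).
- d is `embedding_size`.
- f is the `causal_score` layer.
- W_y and W_z stand for the `Wy` and `Wz` linear layers.

```latex
\begin{aligned}
\alpha_{jk} &= \operatorname{softmax}_k\!\left(\frac{(W_y x_j)^\top (W_z z_k)}{\sqrt{d}}\right),
\qquad \tilde z_j = \sum_{k=1}^{K} P(z_k)\,\alpha_{jk}\,z_k, \\
\mathcal{L}_{\text{self}} &= \frac{1}{M}\sum_{r=1}^{M} \operatorname{CE}\big(s_r,\, c_r\big), \\
\mathcal{L}_{\text{cxt}} &= \sum_{b}\frac{1}{N_b^{2}}\sum_{i \neq j} \operatorname{CE}\big(f([x_i;\, \tilde z_j]),\, c_j\big), \\
\mathcal{L} &= \mathcal{L}_{\text{self}} + \mathcal{L}_{\text{cxt}}.
\end{aligned}
```

The 1/N_b^2 factor comes from `torch.mean` over all N_b x N_b pairs after the diagonal (i = j) terms have been zeroed. The context loss is therefore not normalised by the N_b(N_b - 1) off-diagonal pairs.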
At test time the model is used purely as a feature extractor: the features output by the ROI head (ROI_HEAD) are taken as the VC features.
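The configuration file is e2e_mask_rcnn_R_101_FPN_1x.yaml. Compared with the standard Mask R-CNN settings, the author makes three changes:

- BASE_LR is lowered from 0.02 to 0.005.
- MAX_ITER is raised from 90k to 240k.
- The model is trained from scratch.

The main files involved are in ROI_BOX_HEAD: FPN2MLPFeatureExtractor and FPNPredictor. The former applies ROI Align, flattens the result, and passes it through two fc + ReLU layers to output a 1024-d feature. The latter, FPNPredictor, performs class prediction and box regression.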
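The NumPy sketch below shows the shape flow of a flatten + two fc + ReLU head that produces the 1024-d feature described above. It is illustrative only, not the maskrcnn-benchmark FPN2MLPFeatureExtractor.

Assumptions:

- The ROI Align output is taken as given, with a 7x7 pooling size.
- The weights are random.
- The toy channel count of 8 is used only to keep the example small.

```python
import numpy as np

def mlp_box_head(pooled, w6, b6, w7, b7):
    """Flatten ROI-Align output, then apply two fc + ReLU layers.

    pooled: (N, C, P, P) pooled features, one slice per box.
    Returns (N, H), where H is the number of columns of w7.
    """
    x = pooled.reshape(pooled.shape[0], -1)    # flatten -> (N, C*P*P)
    x = np.maximum(x @ w6 + b6, 0.0)           # first fc + ReLU
    x = np.maximum(x @ w7 + b7, 0.0)           # second fc + ReLU
    return x

rng = np.random.default_rng(0)
N, C, P, H = 5, 8, 7, 1024                     # toy channel count; H = 1024 as in the post
pooled = rng.standard_normal((N, C, P, P))
w6 = rng.standard_normal((C * P * P, H)) * 0.01
b6 = np.zeros(H)
w7 = rng.standard_normal((H, H)) * 0.01
b7 = np.zeros(H)
features = mlp_box_head(pooled, w6, b6, w7, b7)
print(features.shape)                          # (5, 1024)
```

Because the last layer is followed by ReLU, every entry of the resulting VC feature is non-negative.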
In ROIBoxHead() in box_head.py, two attributes are added: causal_predictor and feature_save_path.
The box_regression branch of predictor() is removed, so it only performs classification. class_logits and class_logits_causal_list are passed to loss_evaluator(), and at test time save_object_feature_gt_bu() is called.
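The two predictor outputs have different shapes. A small NumPy sketch makes this concrete:

- `class_logits` has one row per GT box in the batch.
- The causal branch first splits the batched RoI features per image (as `x.split(box_size_list)` does) and then yields one row per (i, j) box pair within each image.

`split_per_image` and `logits_shapes` are hypothetical helpers, and `num_classes=81` is an arbitrary toy value.

```python
import numpy as np

def split_per_image(x, box_size_list):
    """Split a (sum(N_b), D) array into one (N_b, D) chunk per image,
    mirroring torch.Tensor.split(box_size_list) in CausalPredictor.forward."""
    return np.split(x, np.cumsum(box_size_list)[:-1], axis=0)

def logits_shapes(box_size_list, num_classes):
    """Expected shapes of the two predictor outputs fed to loss_evaluator()."""
    total = sum(box_size_list)
    class_logits_shape = (total, num_classes)                      # one row per GT box
    causal_shapes = [(n * n, num_classes) for n in box_size_list]  # one row per (i, j) pair
    return class_logits_shape, causal_shapes

x = np.arange(12, dtype=float).reshape(6, 2)    # 6 boxes with 2-d toy features
chunks = split_per_image(x, [2, 3, 1])
print([c.shape for c in chunks])                 # [(2, 2), (3, 2), (1, 2)]
print(logits_shapes([2, 3, 1], num_classes=81))  # ((6, 81), [(4, 81), (9, 81), (1, 81)])
```

The quadratic N_b x N_b size of each causal output is why the context loss later reshapes labels into an N_b x N_b grid.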
CausalPredictor() is added in roi_box_predictors.py:
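import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from maskrcnn_benchmark.modeling import registry


@registry.ROI_BOX_PREDICTOR.register("CausalPredictor")
class CausalPredictor(nn.Module):
    def __init__(self, cfg, in_channels):
        super(CausalPredictor, self).__init__()
        num_class = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        representation_size = in_channels
        # NOTE: self.embedding_size is used below, but its definition is not
        # shown in this excerpt; it must be set before these layers are built.
        # Classifier on the concatenated [x; z] pair feature
        self.causal_score = nn.Linear(2 * representation_size, num_class)
        # Projections of the RoI feature (Wy) and the dictionary entries (Wz) for attention
        self.Wy = nn.Linear(representation_size, self.embedding_size)
        self.Wz = nn.Linear(representation_size, self.embedding_size)
        nn.init.normal_(self.causal_score.weight, std=0.01)
        nn.init.normal_(self.Wy.weight, std=0.02)
        nn.init.normal_(self.Wz.weight, std=0.02)
        nn.init.constant_(self.Wy.bias, 0)
        nn.init.constant_(self.Wz.bias, 0)
        nn.init.constant_(self.causal_score.bias, 0)
        self.feature_size = representation_size
        # Confounder dictionary (first row dropped, presumably the background class) and its prior
        self.dic = torch.tensor(np.load(cfg.DIC_FILE)[1:], dtype=torch.float)
        self.prior = torch.tensor(np.load(cfg.PRIOR_PROB), dtype=torch.float)

    def forward(self, x, proposals):
        # x.device also works on CPU, where x.get_device() returns -1
        device = x.device
        dic_z = self.dic.to(device)
        prior = self.prior.to(device)
        # Split the batched RoI features into one chunk per image
        box_size_list = [proposal.bbox.size(0) for proposal in proposals]
        feature_split = x.split(box_size_list)
        xzs = [self.z_dic(feature_pre_obj, dic_z, prior) for feature_pre_obj in feature_split]
        causal_logits_list = [self.causal_score(xz) for xz in xzs]
        return causal_logits_list

    def z_dic(self, y, dic_z, prior):
        """
        Please note that we compute the intervention in the whole batch rather than for one object in the main paper.
        """
        length = y.size(0)
        if length == 1:
            print('debug')
        # Scaled dot-product attention between each RoI feature and each dictionary entry: (N, K)
        attention = torch.mm(self.Wy(y), self.Wz(dic_z).t()) / (self.embedding_size ** 0.5)
        attention = F.softmax(attention, 1)
        # (N, K, D): dictionary entries weighted by attention
        z_hat = attention.unsqueeze(2) * dic_z.unsqueeze(0)
        # (N, D): sum over the dictionary, weighted by the prior
        z = torch.matmul(prior.unsqueeze(0), z_hat).squeeze(1)
        # (N*N, 2D): row i*N + j is [y_i; z_j]
        xz = torch.cat((y.unsqueeze(1).repeat(1, length, 1),
                        z.unsqueeze(0).repeat(length, 1, 1)), 2).view(-1, 2 * y.size(1))
        # detect if encounter nan
        if torch.isnan(xz).sum():
            print(xz)
        return xz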
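To see what `z_dic` computes without a GPU or maskrcnn-benchmark, here is a NumPy mirror of it (linear-layer biases omitted) with random weights.

Assumptions:

- The contents of DIC_FILE and PRIOR_PROB are not shown in the post. We assume each dictionary entry is the mean RoI feature of one class and the prior is that class's frequency, via the hypothetical helper `build_confounder_dictionary`.
- Unlike the original, which drops row 0 with `[1:]`, the toy dictionary has no background row.

```python
import numpy as np

def softmax(a, axis):
    a = a - a.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def build_confounder_dictionary(features, labels, num_classes):
    """ASSUMPTION: entry k = mean feature of class k; prior[k] = fraction of boxes of class k."""
    dic = np.stack([features[labels == k].mean(axis=0) for k in range(num_classes)])
    prior = np.bincount(labels, minlength=num_classes) / len(labels)
    return dic, prior

def z_dic(y, dic_z, prior, Wy, Wz):
    """NumPy mirror of CausalPredictor.z_dic (biases omitted)."""
    n = y.shape[0]
    d = Wy.shape[1]                                             # embedding size
    attention = softmax((y @ Wy) @ (dic_z @ Wz).T / np.sqrt(d), axis=1)  # (N, K)
    z_hat = attention[:, :, None] * dic_z[None, :, :]          # (N, K, D)
    z = np.einsum("k,nkd->nd", prior, z_hat)                   # (N, D) prior-weighted sum
    # Row i*N + j pairs y_i with z_j
    xz = np.concatenate(
        [np.repeat(y[:, None, :], n, axis=1), np.repeat(z[None, :, :], n, axis=0)],
        axis=2,
    ).reshape(-1, 2 * y.shape[1])
    return xz

rng = np.random.default_rng(0)
D, E, K = 8, 4, 3                           # feature dim, embedding dim, dictionary size (toy)
train_feats = rng.standard_normal((30, D))
train_labels = np.arange(30) % K            # every class appears
dic, prior = build_confounder_dictionary(train_feats, train_labels, K)
y = rng.standard_normal((4, D))             # 4 boxes in one image
Wy = rng.standard_normal((D, E)) * 0.02
Wz = rng.standard_normal((D, E)) * 0.02
xz = z_dic(y, dic, prior, Wy, Wz)
print(xz.shape)                             # (16, 16)
```

The right half of row i*N + j depends only on j. Each object's confounder estimate is therefore reused across all N rows that pair it with another object's feature.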
In loss.py, the __call__ function of FastRCNNLossComputation() is modified:
# loss.py relies on: import torch; from torch.nn import functional as F;
# from maskrcnn_benchmark.modeling.utils import cat
def __call__(self, class_logits, causal_logits_list, proposals):
    """
    Computes the self-prediction and context-prediction (causal) losses.
    This requires that the subsample method has been called beforehand.
    Arguments:
        class_logits (list[Tensor])
        causal_logits_list (list[Tensor]): one (N_i * N_i, num_classes) tensor per image
        proposals (list[BoxList])
    Returns:
        classification_loss (Tensor)
        causal_loss (Tensor)
    """
    class_logits = cat(class_logits, dim=0)
    device = class_logits.device
    labels = [proposal.get_field("labels").to(dtype=torch.int64) for proposal in proposals]
    labels_self = cat(labels, dim=0)
    # self predictor loss
    classification_loss = F.cross_entropy(class_logits, labels_self)
    # context predictor loss
    causal_loss = 0.
    for causal_logit, label in zip(causal_logits_list, labels):
        # entry (i, j) holds the label of object j, matching row i*N + j of causal_logit
        mask_label = label.unsqueeze(0).repeat(label.size(0), 1)
        # zero the diagonal so that no object is used to predict itself
        mask = 1 - torch.eye(mask_label.size(0)).to(device)
        loss_causal = F.cross_entropy(causal_logit, mask_label.view(-1), reduction='none')
        loss_causal = loss_causal * mask.view(-1)
        causal_loss += torch.mean(loss_causal)
    return classification_loss, causal_loss
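The loop body above can be checked in isolation with NumPy.

- `cross_entropy_none` is a hypothetical stand-in for `F.cross_entropy(..., reduction='none')`.
- `context_loss` mirrors the masking and averaging for one image.

With uniform logits over 3 classes each pair costs log 3. Six of the nine pairs survive the mask, so the mean is (2/3) log 3.

```python
import numpy as np

def cross_entropy_none(logits, targets):
    """Per-row cross-entropy, like F.cross_entropy(..., reduction='none')."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets]

def context_loss(causal_logit, label):
    """Masked context loss for one image, mirroring the loop body in __call__."""
    n = label.shape[0]
    mask_label = np.tile(label[None, :], (n, 1))   # entry (i, j) -> label of object j
    mask = 1.0 - np.eye(n)                         # drop the self pairs (i == j)
    loss = cross_entropy_none(causal_logit, mask_label.reshape(-1))
    return np.mean(loss * mask.reshape(-1))        # averaged over all N*N entries

label = np.array([0, 2, 1])
uniform = np.zeros((9, 3))                         # uniform logits -> CE = log(3) per pair
print(round(context_loss(uniform, label), 4))      # 0.7324
```

Whatever the diagonal rows (0, 4 and 8 here) predict, the loss does not change. That is the purpose of the `1 - eye` mask.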
A function is added to ROIBoxHead() in box_head.py to save the features at test time:
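# box_head.py relies on: import os; import numpy as np
def save_object_feature_gt_bu(self, x, result, targets):
    for i, image in enumerate(result):
        feature_pre_image = image.get_field("features").cpu().numpy()
        try:
            # sanity check: one feature row per GT box of this image
            assert image.get_field("num_box")[0] == feature_pre_image.shape[0]
            image_id = str(image.get_field("image_id")[0].cpu().numpy())
            path = os.path.join(self.feature_save_path, image_id) + '.npy'
            np.save(path, feature_pre_image)
        except Exception:
            # skip images that fail the check and print them for debugging
            print(image)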
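A downstream model reads these files back with `np.load` using the same `<feature_save_path>/<image_id>.npy` naming. The round trip below is a sketch:

- `save_features` and `load_features` are hypothetical helpers.
- The image id 42 and the (3, 1024) toy array are invented for illustration.

```python
import os
import tempfile
import numpy as np

def save_features(feature_save_path, image_id, features):
    """Write one image's (num_box, dim) features to <feature_save_path>/<image_id>.npy."""
    path = os.path.join(feature_save_path, str(image_id)) + '.npy'
    np.save(path, features)
    return path

def load_features(feature_save_path, image_id):
    """Read the features back for a downstream model."""
    return np.load(os.path.join(feature_save_path, str(image_id)) + '.npy')

with tempfile.TemporaryDirectory() as tmp:
    feats = np.random.default_rng(0).standard_normal((3, 1024)).astype(np.float32)
    saved_path = save_features(tmp, 42, feats)
    loaded = load_features(tmp, 42)
    print(loaded.shape, np.array_equal(loaded, feats))   # (3, 1024) True
```

Because one file is written per image, the row order inside each file must match the box order the downstream model uses.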
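Overall, the author removes every bbox-related component. The Mask R-CNN used here needs GT bboxes at test time. In a sense it only performs classification and carries no localization information, so its features cannot be used on their own and must be combined with the Up-Down features.
This differs from the Faster R-CNN in the Up-Down paper, which can also be used for object detection.
The paper's results likewise show that Only VC performs worse than Origin.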
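One simple way to combine the two per-region features is concatenation along the feature axis. The sketch below assumes:

- The VC features were extracted on the same boxes, in the same order, as the Up-Down features.
- The Up-Down features are 2048-d and the VC features are 1024-d.

`combine_features` is a hypothetical helper, and the zero/one arrays are toy stand-ins.

```python
import numpy as np

def combine_features(updown, vc):
    """Concatenate per-region Up-Down and VC features along the feature axis.

    updown: (num_box, 2048); vc: (num_box, 1024); rows must refer to the same boxes.
    """
    if updown.shape[0] != vc.shape[0]:
        raise ValueError("Up-Down and VC features must cover the same boxes")
    return np.concatenate([updown, vc], axis=1)

updown = np.zeros((36, 2048), dtype=np.float32)   # toy stand-in for bottom-up features
vc = np.ones((36, 1024), dtype=np.float32)        # toy stand-in for saved VC features
combined = combine_features(updown, vc)
print(combined.shape)                              # (36, 3072)
```

The explicit row-count check fails loudly if the two feature files were produced on different box sets. Silently misaligned regions would otherwise be hard to detect downstream.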
