detectron2中将模型参数的pkl文件load到model里

更新时间:2023-07-17 16:13:34 阅读: 评论:0

detectron2 中将模型参数的 pkl 文件 load 到 model 里:这里以 FPN 网络为例进行讲解。
首先下载 FPN 的网络参数文件。它以字典形式存储在 pkl 文件中,并且 key 以网络每一层的名称命名(例如 `res2_0_branch2a_w`、`fpn_inner_res2_2_sum_lateral_w`),如下所示。
而且它的命名方式基于 Caffe2,与 detectron2 不同,因此不能直接将其 load 到 model.state_dict 中。detectron2 中网络的命名方式如下(例如 `res2.0.conv1.weight`):
不同的层级之间使用点(.)而不是下划线(_)来分隔。
因此 detectron2 中提供了从 Caffe2 命名到 detectron2 命名的转换函数,主要通过正则表达式进行字符串的匹配与替换:
`convert_c2_detectron_names(weights)`,其中 weights 就是需要进行转换的参数字典。
def convert_c2_detectron_names(weights):
    """
    Map Caffe2 Detectron weight names to Detectron2 names.

    Backbone names are converted by ``convert_basic_c2_names``. On top of that,
    this function renames the RPN, box head, FPN, mask head and keypoint head
    parameters. It also trims or reorders the class-specific prediction
    weights so that the background class is no longer stored at index 0.

    Args:
        weights (dict): name -> tensor. Values must support slicing along the
            first dimension, ``.shape`` and ``torch.cat``.

    Returns:
        dict: detectron2 names -> tensor
        dict: detectron2 names -> C2 names

    Raises:
        AssertionError: if two different C2 names are renamed to the same
            detectron2 name.
    """
    import logging
    import re

    import torch

    logger = logging.getLogger(__name__)
    logger.info("Renaming Caffe2 weights ......")
    original_keys = sorted(weights.keys())

    # convert_basic_c2_names builds a new list, so original_keys stays intact
    # and can be zipped with the renamed keys below.
    layer_keys = convert_basic_c2_names(original_keys)

    # --------------------------------------------------------------------------
    # RPN hidden representation conv
    # --------------------------------------------------------------------------
    # FPN case
    # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
    # shared for all other levels, hence the appearance of "fpn2"
    layer_keys = [
        k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
    ]
    # Non-FPN case
    layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]

    # --------------------------------------------------------------------------
    # RPN box transformation conv
    # --------------------------------------------------------------------------
    # FPN case (see the "fpn2" note on the RPN hidden conv)
    layer_keys = [
        k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
        for k in layer_keys
    ]
    layer_keys = [
        k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
        for k in layer_keys
    ]
    # Non-FPN case
    layer_keys = [
        k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas")
        for k in layer_keys
    ]
    layer_keys = [
        k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
        for k in layer_keys
    ]

    # --------------------------------------------------------------------------
    # Fast R-CNN box head
    # --------------------------------------------------------------------------
    layer_keys = [re.sub(r"^bbox\.pred", "bbox_pred", k) for k in layer_keys]
    layer_keys = [re.sub(r"^cls\.score", "cls_score", k) for k in layer_keys]
    layer_keys = [re.sub(r"^fc6\.", "box_head.fc1.", k) for k in layer_keys]
    layer_keys = [re.sub(r"^fc7\.", "box_head.fc2.", k) for k in layer_keys]
    # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
    layer_keys = [re.sub(r"^head\.conv", "box_head.conv", k) for k in layer_keys]

    # --------------------------------------------------------------------------
    # FPN lateral and output convolutions
    # --------------------------------------------------------------------------
    def fpn_map(name):
        """
        Look for keys with the following patterns:
        1) Starts with "fpn.inner."
           Example: "fpn.inner.res2.2.sum.lateral.weight"
           Meaning: These are lateral pathway convolutions
        2) Starts with "fpn.res"
           Example: "fpn.res2.2.sum.weight"
           Meaning: These are FPN output convolutions
        Any other name is returned unchanged.
        """
        splits = name.split(".")
        norm = ".norm" if "norm" in splits else ""
        if name.startswith("fpn.inner."):
            # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
            stage = int(splits[2][len("res"):])
            return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
        if name.startswith("fpn.res"):
            # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
            stage = int(splits[1][len("res"):])
            return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
        return name

    layer_keys = [fpn_map(k) for k in layer_keys]

    # --------------------------------------------------------------------------
    # Mask R-CNN mask head
    # --------------------------------------------------------------------------
    # roi_heads.StandardROIHeads case
    layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
    layer_keys = [re.sub(r"^\.mask\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
    layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
    # roi_heads.Res5ROIHeads case
    layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]

    # --------------------------------------------------------------------------
    # Keypoint R-CNN head
    # --------------------------------------------------------------------------
    # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
    layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
    layer_keys = [
        k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres")
        for k in layer_keys
    ]
    layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]

    # --------------------------------------------------------------------------
    # Done with replacements
    # --------------------------------------------------------------------------
    # Two C2 names collapsing onto one detectron2 name would make one weight
    # silently overwrite the other in the dicts built below.
    assert len(set(layer_keys)) == len(layer_keys), "Renamed Caffe2 weights contain duplicates"
    assert len(original_keys) == len(layer_keys)

    new_weights = {}
    new_keys_to_original_keys = {}
    for orig, renamed in zip(original_keys, layer_keys):
        new_keys_to_original_keys[renamed] = orig
        if renamed.startswith(("bbox_pred.", "mask_head.predictor.")):
            # remove the meaningless prediction weight for background class:
            # drop the leading 4 rows of bbox_pred (presumably the bg box
            # deltas) and the leading row of the mask predictor.
            new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
            new_weights[renamed] = weights[orig][new_start_idx:]
            logger.info(
                "Remove prediction weight for background class in {}. The shape changes from "
                "{} to {}.".format(
                    renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
                )
            )
        elif renamed.startswith("cls_score."):
            # move weights of bg class from original index 0 to last index
            logger.info(
                "Move classification weights for background class in {} from index 0 to "
                "index {}.".format(renamed, weights[orig].shape[0] - 1)
            )
            new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
        else:
            new_weights[renamed] = weights[orig]

    return new_weights, new_keys_to_original_keys
其中用于处理 backbone 中命名的 convert_basic_c2_names 函数如下:
def convert_basic_c2_names(original_keys):
    """
    Apply some basic name conversion to names in C2 weights.
    It only deals with typical backbone models.

    Args:
        original_keys (list[str]): Caffe2 Detectron weight names, e.g.
            "res2_0_branch2a_w". The input list is not modified.

    Returns:
        list[str]: The same number of strings matching those in original_keys,
        in the same order, e.g. "res2.0.conv1.weight".
    """
    import re

    # some hard-coded mappings; building a new list here also means the
    # caller's list is never mutated by the steps below
    layer_keys = [
        {"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in original_keys
    ]

    # Caffe2 separates name components with "_", detectron2 with "."
    layer_keys = [k.replace("_", ".") for k in layer_keys]
    # Expand the short ".b" / ".w" suffixes to ".bias" / ".weight"
    layer_keys = [re.sub(r"\.b$", ".bias", k) for k in layer_keys]
    layer_keys = [re.sub(r"\.w$", ".weight", k) for k in layer_keys]
    # Uniform both bn and gn names to "norm"
    layer_keys = [re.sub(r"bn\.s$", "norm.weight", k) for k in layer_keys]
    layer_keys = [re.sub(r"bn\.bias$", "norm.bias", k) for k in layer_keys]
    layer_keys = [re.sub(r"bn\.rm", "norm.running_mean", k) for k in layer_keys]
    layer_keys = [re.sub(r"bn\.running.mean$", "norm.running_mean", k) for k in layer_keys]
    layer_keys = [re.sub(r"bn\.riv$", "norm.running_var", k) for k in layer_keys]
    layer_keys = [re.sub(r"bn\.running.var$", "norm.running_var", k) for k in layer_keys]
    layer_keys = [re.sub(r"bn\.gamma$", "norm.weight", k) for k in layer_keys]
    layer_keys = [re.sub(r"bn\.beta$", "norm.bias", k) for k in layer_keys]
    layer_keys = [re.sub(r"gn\.s$", "norm.weight", k) for k in layer_keys]
    layer_keys = [re.sub(r"gn\.bias$", "norm.bias", k) for k in layer_keys]

    # stem: "res.conv1.norm.*" -> "conv1.norm.*" (then prefixed with "stem." below)
    layer_keys = [re.sub(r"^res\.conv1\.norm\.", "conv1.norm.", k) for k in layer_keys]
    # to avoid mis-matching with "conv1" in other components (e.g. detection head)
    layer_keys = [re.sub(r"^conv1\.", "stem.conv1.", k) for k in layer_keys]

    # layer1-4 is used by torchvision, however we follow the C2 naming strategy
    # (res2-5), so the stage prefixes are intentionally left unchanged.

    # blocks
    layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys]
    layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
    layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
    layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]

    # DensePose substitutions (restore the underscores these heads expect)
    layer_keys = [re.sub(r"^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys]
    layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys]
    layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys]
    layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys]
    layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys]
    return layer_keys
将所有网络参数名转换成 detectron2 风格后,就可以进行参数加载了。

本文发布于:2023-07-17 16:13:34,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/89/1085281.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:参数   命名   转变   字典   转换   字符串
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图