基于mmdetection训练SwinTransformerObjectDetection

更新时间:2023-07-21 04:28:05 阅读: 评论:0

基于mmdetection训练SwinTransformerObjectDetection
环境搭建
docker
找了一个torch版本为1.5.1+cu101的docker环境,然后在其中安装mmdetection环境:
# NOTE: install an mmcv-full version accepted by this repo's bundled mmdet
# (see the mmcv version check in mmdet/__init__.py) and built for your
# torch/CUDA; the -f index provides prebuilt wheels (here CUDA 10.1, torch 1.5.x).
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.5.0/index.html
# Clone into a "-master" directory so the paths used later in this note match.
git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection.git Swin-Transformer-Object-Detection-master
cd Swin-Transformer-Object-Detection-master
pip install -r requirements/build.txt
pip install -v -e .
安装apex
git clone https://github.com/NVIDIA/apex
cd apex
pip install -r requirements.txt
# Builds apex with its C++ extension; add --cuda_ext for the CUDA kernels.
python setup.py install --cpp_ext
安装成功
Processing dependencies for apex==0.1
Finished processing dependencies for apex==0.1
mmdetection中各模块对应的代码位置:
backbone:mmdet/models/backbones
neck:mmdet/models/necks
head:mmdet/models/roi_heads
BBox Assigner:mmdet/core/bbox/assigners
BBox Sampler:mmdet/core/bbox/samplers
BBox Encoder:mmdet/core/bbox/coder
BBox Decoder:mmdet/core/bbox/coder
Loss:mmdet/models/losses
BBox PostProcess:mmdet/core/post_processing
在"Swin-Transformer-Object-Detection-master/configs/swin/"目录下可以看到各模型的配置文件,选择对应的文件进行修改。以"cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py"为例:
# Example: the RoI head. This is an excerpt; edit these fields in place inside
# the existing `model = dict(...)` of the config file. Do not paste it into a
# child config as-is: mmcv merges nested dicts but replaces lists, so a
# one-element `bbox_head` list would drop the other cascade stages.
model = dict(
    # ... backbone / neck / rpn_head unchanged ...
    roi_head=dict(
        bbox_head=[
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=15,  # number of classes in your dataset
                # Choose the BN type for your GPU setup: SyncBN only works
                # under distributed (DDP) training, so use plain BN when
                # training non-distributed on a single GPU.
                # norm_cfg=dict(type='SyncBN', requires_grad=True),
                norm_cfg=dict(type='BN', requires_grad=True),
                # ... remaining fields of this head unchanged ...
            ),
            # ... apply the same num_classes / norm_cfg change to every other
            # cascade-stage head in this list ...
        ],
        # NOTE: mask models also have a mask_head with its own num_classes;
        # set it to the same value.
    ),
)
# Learning rate and related settings, using the rule
# lr = 0.00125 * total batch size (total batch = number of GPUs * samples_per_gpu).
# NOTE(review): that rule is mmdetection's SGD linear-scaling heuristic
# (0.02 for a total batch of 16); the released Swin configs use AdamW with
# lr=0.0001 for a total batch of 16 -- confirm the value suits AdamW.
optimizer = dict(
    _delete_=True,  # replace the inherited optimizer dict instead of merging
    type='AdamW',
    lr=0.00125,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    # No weight decay on position embeddings, relative-position bias tables
    # and norm layers.
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
        }))
# Number of training epochs.
runner = dict(type='EpochBasedRunner', max_epochs=20)
# fp16 (apex mixed precision) is not used: use_fp16 is set to False.
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=False,
)
在"configs/_base_/datasets/coco_instance.py"中根据需要修改:
# Dataset type and root directory.
dataset_type = 'CocoDataset'
data_root = '/home/coco/'
# Input size and related settings; reduce img_scale on "CUDA out of memory".
# NOTE(review): the Swin *_mstrain_480-800_* configs appear to define their
# own multi-scale train_pipeline and assign it to data.train.pipeline, which
# would override this one -- check that config if this change has no effect.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    # Originally 1333 x 800:
    # dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='Resize', img_scale=(416, 416), keep_ratio=True),
    # ... remaining pipeline steps unchanged ...
]
# Batch size and data-loading workers.
data = dict(
    samples_per_gpu=1,  # images per GPU; total batch size = number of GPUs * samples_per_gpu
    workers_per_gpu=1,  # data-loading workers per GPU
    # The train split is shown; val/test are configured the same way.
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',  # annotation file
        img_prefix=data_root + 'train2017/',  # training image directory
        pipeline=train_pipeline),
    # ... val / test: same kind of path edits ...
)
修改类别:mmdet/datasets/coco.py和mmdet/core/evaluation/class_names.py文件(两处的类别名称及顺序需保持一致)
class CocoDataset(CustomDataset):
    # Original 80 COCO class names, kept for reference:
    # CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    #            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    #            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    #            'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
    #            'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    #            'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    #            'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    #            'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    #            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    #            'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    #            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
    #            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    #            'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
    #            'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')

    # Custom classes. Their count must equal num_classes in the model config
    # (15 here). The names should match the category names in the annotation
    # json: mmdet 2.x resolves category ids via get_cat_ids(cat_names=CLASSES).
    CLASSES = ('person', 'tool_vehicle', 'bicycle', 'motorbike',
               'pedal_tricycle', 'car', 'passenger_car', 'truck',
               'police_car', 'ambulance', 'bus', 'dump_truck', 'tanker',
               'roadblock', 'fire_car')

    # ... rest of the class unchanged ...
def coco_classes():
    """Return the custom class names, in the same order as CocoDataset.CLASSES.

    mmdet's ``get_classes('coco')`` looks this function up by the name
    ``coco_classes``, so it must keep exactly that name.
    """
    return [
        'person', 'tool_vehicle', 'bicycle', 'motorbike', 'pedal_tricycle',
        'car', 'passenger_car', 'truck', 'police_car', 'ambulance', 'bus',
        'dump_truck', 'tanker', 'roadblock', 'fire_car'
    ]


# Alias for the name used in earlier versions of this note.
coco_class = coco_classes
修改"./tools/train.py"文件:
# Choose the launcher: 'none' (default) trains on one machine without
# distribution (MMDataParallel); 'pytorch' / 'slurm' / 'mpi' start
# distributed training on one or more machines (MMDistributedDataParallel).
parser.add_argument(
    '--launcher',
    choices=['none', 'pytorch', 'slurm', 'mpi'],
    default='none',
    help='job launcher')
预训练权重加载、权重保存等参数位于configs/_base_/default_runtime.py文件中:
checkpoint_config = dict(interval=1)  # save a checkpoint after every epoch
# Full detector checkpoint to initialise from (fine-tuning). It loads weights
# only; the optimizer state and epoch counter start fresh. Backbone-only
# pretrained weights are given through the model's `pretrained` field instead.
load_from = None
resume_from = None  # checkpoint to resume an interrupted run (weights, optimizer, epoch)
训练模型
使用编号为3的单个GPU训练:
# --gpu-ids only applies to non-distributed training (launcher 'none').
python ./tools/train.py configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py --gpu-ids 3
使用多GPU训练(最后一个参数为GPU数量):
# Use the same config that was edited above; the last argument is the GPU count.
tools/dist_train.sh configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py 4
训练Log及权重
保存在"Swin-Transformer-Object-Detection-master/work_dirs/"中
coco测试
# Evaluate a Swin-S Cascade Mask R-CNN checkpoint on the data.test split of
# the config (bbox: box mAP, segm: mask mAP).
python tools/test.py configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py cascade_mask_rcnn_swin_small_patch4_window7.pth --eval bbox segm
推理demo:将检测结果保存为txt文件,每行一个目标,格式为"cls x1 y1 x2 y2"(以空格分隔)
from argparse import ArgumentParser
import os

import numpy as np


def get_bbox_result(result):
    """Return the per-class box list from an mmdet ``inference_detector`` result.

    Mask models (e.g. Cascade Mask R-CNN) return ``(bbox_result, segm_result)``;
    box-only models return ``bbox_result`` itself.
    """
    if isinstance(result, tuple):
        return result[0]
    return result


def result_to_lines(bbox_result, score_thr=0.3):
    """Format per-class detections as ``"cls x1 y1 x2 y2"`` lines.

    Args:
        bbox_result: sequence whose i-th item is an array of shape (k, 5)
            holding rows ``x1, y1, x2, y2, score`` for class index i.
        score_thr: when > 0, keep only rows whose score is strictly greater.

    Returns:
        List of strings, ordered by ascending class index. Coordinates are
        truncated to int.

    Raises:
        ValueError: if a non-empty per-class array is not of shape (k, 5).
    """
    lines = []
    for label, class_bboxes in enumerate(bbox_result):
        class_bboxes = np.asarray(class_bboxes)
        if class_bboxes.size == 0:
            continue
        if class_bboxes.ndim != 2 or class_bboxes.shape[1] != 5:
            raise ValueError(
                'expected (k, 5) boxes for class {}, got shape {}'.format(
                    label, class_bboxes.shape))
        if score_thr > 0:
            class_bboxes = class_bboxes[class_bboxes[:, 4] > score_thr]
        for x1, y1, x2, y2 in class_bboxes[:, :4].astype(np.int32):
            lines.append('{} {} {} {} {}'.format(label, x1, y1, x2, y2))
    return lines


def main():
    """Run a detector over every image in a folder and save one txt per image.

    ``<out-dir>/<image name without extension>.txt`` holds one detection per
    line as ``cls x1 y1 x2 y2``. Images with no detection above
    ``--score-thr`` get an empty file.
    """
    # Imported here so the helpers above work (and can be tested) without
    # mmdet/tqdm installed.
    from mmdet.apis import inference_detector, init_detector
    from tqdm import tqdm

    parser = ArgumentParser()
    parser.add_argument(
        '--img-path', default='/data/wj/test/',
        help='Directory of images to run inference on')
    parser.add_argument(
        '--config',
        default='../work_dirs/cascade_rcnn_x101_64x4d_fpn_20e_coco/'
        'cascade_rcnn_x101_64x4d_fpn_20e_coco.py',
        help='Config file')
    parser.add_argument(
        '--checkpoint',
        default='../work_dirs/cascade_rcnn_x101_64x4d_fpn_20e_coco/latest.pth',
        help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    parser.add_argument(
        '--out-dir', default='../output/',
        help='Directory for the per-image txt results')
    args = parser.parse_args()

    # open() below fails if the output directory does not exist yet.
    os.makedirs(args.out_dir, exist_ok=True)
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    for img_name in tqdm(sorted(os.listdir(args.img_path))):
        img_file = os.path.join(args.img_path, img_name)
        if not os.path.isfile(img_file):
            continue  # skip sub-directories
        result = inference_detector(model, img_file)
        lines = result_to_lines(get_bbox_result(result), args.score_thr)
        # splitext keeps dots inside the stem (e.g. "a.b.jpg" -> "a.b").
        stem = os.path.splitext(img_name)[0]
        txt_path = os.path.join(args.out_dir, stem + '.txt')
        # 'w' so re-running overwrites instead of appending duplicate lines.
        with open(txt_path, 'w') as f:
            f.writelines(line + '\n' for line in lines)


if __name__ == '__main__':
    main()
踩过的坑及解决方案:
参考:

本文发布于:2023-07-21 04:28:05,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/90/183910.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:修改   版本   权重   训练   学习   环境
相关文章
留言与评论(共有 0 条评论)
   
验证码:
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图