使用TensorFlowObjectDetectionAPI训练自己的目标检测模型(一)制作数据集

更新时间:2023-07-14 03:33:56 阅读: 评论:0

使⽤TensorFlowObjectDetectionAPI训练⾃⼰的⽬标检测模型
安提诺乌斯(⼀)制作数据集
当前⽬标检测⽹络已经⼗分成熟,github上可以找到各种各样的检测⽹络,如果是在⼿机上使⽤,还有⼗分⽅便的推理框架,但是如果不是⼿机应⽤,⼜没有NPU可以⽤,⼜不想使⽤CPU推理,那选择就不多了,这⾥我是使⽤ARMNN作为推理框架,使⽤RK3399上的GPU进⾏推理。⽹络使⽤的是SSD,但是跑起来只有6-8帧,达不到实时性要求,因为检测⽬标种类很少,原⽣的SSD⽹络要检测80种,显然是有很多冗余的,⽹络压缩的⽅法很多,其中剪枝基本上就是去掉冗余的通道,这和直接在设计⽹络时少⼀些通道有什么区别我也不太清楚,有知道的还请多多指教。这⾥参考了⽹上的⼀些做法,决定先使⽤TensorFlow Object Detection API训练⼀个只检测特定种类的⼩型的SSD⽹络。
参考连接:
1. 使⽤fiftyone下载数据集中的特定类别
要进⾏⽹络训练第⼀步就是制作数据集,我要检测的物体是person,⾸先肯定要⽤⼀个⽐较⼤的数据集去训练,这样模型的性能会好⼀些,⽬前主流的开源数据集合⼤概有以下⼏种:
image net1400w张图⽚,27⼤类和2W+⼩类
open image170w张图⽚,600类
ms coco33w张图⽚, 91类
pascal voc    1.7w张图⽚,20类
tensorflow detection API 中预训练⽤的是,我们也⽤coco,因为我们只检测⼈,所以只下载带⼈的图⽚,这⾥要⽤到,这是⼀个管理数据集的,上边那些数据集都⽀持。
安装好fiftyone以后,使⽤python下载:
import fiftyone as fo
as foz
datat = foz.load_zoo_datat(
"coco-2017",
label_types=["detections"],
class=["person"],
only_matching=True
)
⾥边这些参数可以到⽂档⾥看,都是什么意思,我这⾥是把coco数据集⾥所有的带⼈的图⽚都下载了,如果想先试⼀试可以限制⼀下下载数量,官⽅⽂档⾥的例⼦:
datat = foz.load_zoo_datat(
"open-images-v6",#⽤的openimage
split="validation",#默认的化就是分成tran val test
max_samples=100,#下载100张
ed=51,
shuffle=True,
)
coco中带⼈的图⽚⼀共61407张,进度条有的时候会断开,重新输⼊上边下载的命令就可以接着下载,还没有找到什么好的解决办法。
看⼀下coco数据集的⽬录结构
这⾥是fiftyone下的datat结构,⼤概的意思就是把coco数据集下载下来以后,fiftyone可以根据load时的参数⽣成⼀个datat结构,之后可以通过fiftyone的app去操作这个datat ,这⾥的json⽂件都是fiftyone⽤的,我们只是通过fiftyone下载特定类别,所以不⽤研究这些json。raw⽂件夹下是coco数据集的标注信息,这些标注信息⾥就有我们想要的bbox,当然还有其他信息,我们需要进⾏过滤。剩下的/test /train /validation ⾥边的data就是放的图⽚了。
2. 将下载的数据集转换成Tensorflow使⽤的tfrecord格式
将下载好的数据整理成如下⽬录结构
/annotations ⽂件夹⾥存放标注⽂件
/
test2017 /train2017 /val2017 存放相应图⽚,使⽤下边脚本根据coco数据集⽣成图⽚对应的xml标注⽂件,因为脚本⾥train2017 val2017除了路径还有其他作⽤,所以就不改代码了,改⼀下⽬录结构吧。还有⼀点要注意,下载的图⽚多的话有可能有损坏的,这个脚本没有做这个判断,需要⼿动把损坏的图⽚先删了。
import COCO
import os
import shutil
from tqdm import tqdm
import skimage.io as io
import matplotlib.pyplot as plt
import cv2
from PIL import Image, ImageDraw, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
# 需要设置的路径
home_path = os.environ['HOME']
savepath=home_path+"/COCO/"
img_dir=savepath+'images/'
anno_dir=savepath+'annotations/'
datats_list=['train2017', 'val2017']#
#coco有80类,这⾥写要提取类的名字,以person为例
class_names = ['person']
class_names = ['person']
#这⾥是coco数据集的路径
dataDir= home_path+'/coco_data/'
print(dataDir)
'''
⽬录格式如下:
素菜汤$COCO_PATH
----|annotations
----|train2017
----|val2017
----|test2017
'''
headstr = """\
<annotation>
<folder>VOC</folder>
<filename>%s</filename>
<source>
<databa>My Databa</databa>
<annotation>COCO</annotation>
<image>flickr</image>
<flickrid>NULL</flickrid>
</source>
<owner>
<flickrid>NULL</flickrid>
<name>company</name>
</owner>
<size>
<width>%d</width>性知识书籍
<height>%d</height>
<depth>%d</depth>
</size>
<gmented>0</gmented>
"""
objstr = """\
<object>滕王阁序拼音
<name>%s</name>
<po>Unspecified</po>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>%d</xmin>
<ymin>%d</ymin>
<xmax>%d</xmax>
<ymax>%d</ymax>
</bndbox>
</object>
"""
tailstr = '''\
</annotation>杜仲雄花茶的功效
'''
# 检查⽬录是否存在,如果存在,先删除再创建,否则,直接创建def mkr(path):
if not ists(path):
os.makedirs(path)  # 可以创建多级⽬录
def id2name(coco):
class=dict()
for cls in coco.datat['categories']:
class[cls['id']]=cls['name']
return class
def write_xml(anno_path,head, objs, tail):
def write_xml(anno_path,head, objs, tail):
f = open(anno_path, "w")
f.write(head)
for obj in objs:
f.write(objstr%(obj[0],obj[1],obj[2],obj[3],obj[4]))
f.write(tail)
def save_annotations_and_imgs(coco,datat,filename,objs):
#将图⽚转为xml,例:COCO_train2017_000000196610.jpg-->COCO_l    dst_anno_dir = os.path.join(anno_dir, datat)
mkr(dst_anno_dir)
anno_path=dst_anno_dir + '/' + filename[:-3]+'xml'
img_path=dataDir+datat+'/'+filename
#print("img_path: ", img_path)
dst_img_dir = os.path.join(img_dir, datat)
mkr(dst_img_dir)
dst_imgpath=dst_img_dir+ '/' + filename
#print("dst_imgpath: ", dst_imgpath)
img=cv2.imread(img_path)
#if (img.shape[2] == 1):
#    print(filename + " not a RGB image")
#  return
head=headstr % (filename, img.shape[1], img.shape[0], img.shape[2])
tail = tailstr
write_xml(anno_path,head, objs, tail)
#    标注⽂件 train&val 图⽚信息所有类别⼈的类别id
def showimg(coco,datat,img,class,cls_id,show=True):
global dataDir
img_path = os.path.join(dataDir, datat,img['file_name']) #dataDir+datat+'/'+img['file_name']
#print(img_path)
objs = []
if not ists(img_path):
print("no such file")
el:
# 打开这个图⽚
try:
I=Image.open('%s/%s/%s'%(dataDir,datat,img['file_name']))
except UnidentifiedImageError:
print("bad image, skip!")
print(img_path)
#通过id,得到注释的信息
annIds = AnnIds(imgIds=img['id'], catIds=cls_id, iscrowd=None)
工作日志模板# print(annIds)
anns = coco.loadAnns(annIds)#得到这个图⽚的标注信息
# print(anns)
# coco.showAnns(anns)
for ann in anns:#遍历标注信息
class_name=class[ann['category_id']]#得到这个标注的类别
if class_name in class_names:#如果是我们想要的
#print(class_name)
if 'bbox' in ann:#如果标注信息⾥有bbox
bbox=ann['bbox']
xmin = int(bbox[0])
ymin = int(bbox[1])
xmax = int(bbox[2] + bbox[0])
ymax = int(bbox[3] + bbox[1])
obj = [class_name, xmin, ymin, xmax, ymax]
objs.append(obj)
draw = ImageDraw.Draw(I)
if show:
if show:
plt.figure()
plt.axis('off')
plt.imshow(I)
plt.show()
return objs
# 遍历标注⽂件 instances_train2017 和 instances_val2017 ⾥的数据
for datat in datats_list:
#./COCO/annotations/instances_train2017.json高速摄影
annFile='{}/annotations/instances_{}.json'.format(dataDir,datat)
#使⽤COCO API⽤来初始化注释数据
coco = COCO(annFile)
#获取COCO数据集中的所有类别
class = id2name(coco)
#print(class)
#[1, 2, 3, 4, 6, 8]
class_ids = CatIds(catNms=class_names)#class_names:person
#print(class_ids)# 打印出我们要挑选的类的id(person -> 1)
miss = 0
for cls in class_names:
#获取该类的id
cls_CatIds(catNms=[cls])
img_ImgIds(catIds=cls_id)
#print(cls,len(img_ids))# person 64115 标准⽂件⾥⼀共有64115个图⽚⾥有person
# imgIds=img_ids[0:10]
#print(img_ids)
for imgId in tqdm(img_ids):
img = coco.loadImgs(imgId)[0]
#print(img)
filename = img['file_name']
#print(filename)
objs=showimg(coco, datat, img, class,class_ids,show=Fal)
if(objs):
save_annotations_and_imgs(coco, datat, filename, objs)
空中码头打一地名el:
miss+1
print(miss)
运⾏脚本以后会⽣成⽬录
/annotatios ⾥装的是xml⽂件  /images ⾥装的是jpg图⽚
然后按照⾥的⽅法制作tfrecord
""" Sample TensorFlow XML-to-TFRecord converter
usage: generate_tfrecord.py [-h] [-x XML_DIR] [-l LABELS_PATH] [-o OUTPUT_PATH] [-i IMAGE_DIR] [-c CSV_PATH]
optional arguments:
-h, --help            show this help message and exit
-x XML_DIR, --xml_dir XML_DIR
Path to the folder where the input .xml files are stored.
-l LABELS_PATH, --labels_path LABELS_PATH
Path to the labels (.pbtxt) file.
-o OUTPUT_PATH, --output_path OUTPUT_PATH

本文发布于:2023-07-14 03:33:56,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/82/1095425.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:数据   下载   标注   信息   检测
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图