[深度学习-实战项⽬]以图搜图Resnet+LSH-特征编码图像检索相似度计算
参考代码来源于
以图搜图
1.写在最前⾯
⼊职新公司以后⼀直在搞项⽬,没什么时间写博客。
最近⼀个项⽬是以图搜图项⽬,主要⽤到的技术就是⽬标检测(yolo)+图像检索(ResNet+LSH)。
⽬标检测就不⽤多说了,成熟和现成的代码⼀抓⼀⼤把,主要问题就是在优化提升精度和性能上的摸索。
图像检索的技术也挺多,但是⽹上的资源相对较少,所以记录⼀下这段时间⽤到的⼀个代码。
这个以图搜图和⼈脸识别技术其实很像,可以说是⼀样。⽆⾮就是提取特征,然后进⾏相似度计算。所以相关的技术有ReID,Arcface,以
及我在调研的时候有看到⼀个素描草图的图像匹配的研究。
2.源码解析
(1)跑通代码-即测试⼀下⾃⼰的图⽚
我只⽤到⾥⾯的编码和检索部分。运⾏。直接跑通这部分,然后缺什么库函数去pipinstall就⾏了。
val_featureimportAntiFraudFeatureDatat
val_indeximportEvaluteMap
if__name__=='__main__':
hash_size=0
input_dim=2048
num_hashtables=1
img_dir='ImageRetrieval/data'#存放所有图像库的图⽚
test_img_dir='./images'#待检索的图像
network='./weights/'#模型权重
#下⾯这⼏个好像没有⽤,不管他
out_similar_dir='./output/similar'
out_similar_file_dir='./output/similar_file'
all_csv_file='./output/'
feature_dict,lsh=AntiFraudFeatureDatat(img_dir,network).constructfeature(hash_size,input_dim,num_hashtables)
test_feature_dict=AntiFraudFeatureDatat(test_img_dir,network).test_feature()
EvaluteMap(out_similar_dir,out_similar_file_dir,all_csv_file).retrieval_images(test_feature_dict,lsh,3)
(2)特征编码
代码⾸先对img_dir中的所有图⽚进⾏特征提取:feature_dict,lsh=AntiFraudFeatureDatat(img_dir,network).constructfeature(hash_size,
input_dim,num_hashtables)
返回的feature_dict就是图⽚特征。(可以直接⽤余弦相似度进⾏相似计算)
但是这⾥还通过LSH对每张图⽚特征图进⾏0,1编号,所在这⾥后⾯⽤来图⽚检索的不是feature_dict,⽽是lsh,(应该是加速后⾯图⽚
检索时候的速度)
进到特征编码那块代码retrieval_,⾥⾯主要对图⽚进⾏编码的函数对象是AntiFraudFeatureDatat
⾸先前⾯⼀⼤段到(),都是加载⽹络模型,可以看到模型选择有很多参数,这些参数对应⽹络的结构设置,(后⾯如果⽤⾃⼰的数据
对⾃⼰的特征编码模型进⾏训练的话,要根据使⽤的不同模型参数进⾏修改)
这个函数ImageProcess是遍历⽬录底下的全部图⽚,并将他们的路径保存在数组中。
然后再这个函数extract_vectors中提取图像特征。(在这个⽬录底下ImageRetrieval-LSH/cirtorch/networks/)主要也不需要
怎么做修改,除⾮说你要修改⼀下图⽚的dataloader(这⾥是通过将所有图⽚路径保存下来做的datat,因为每张图⽚的尺⼨可以不⼀
样,Resnet⽹络的最后通过⼀个全连接层输出1*2048特征图。)
所以这⾥出来的vecs是N张图⽚的特征编码,每个特征编码是1*2048。
defconstructfeature(lf,hash_size,input_dim,num_hashtables):
multiscale='[1]'
print(">>Loadingnetwork:n>>>>'{}'".format(k))
state=(k)
net_params={}
net_params['architecture']=state['meta']['architecture']
net_params['pooling']=state['meta']['pooling']
net_params['local_whitening']=state['meta'].get('local_whitening',Fal)
net_params['regional']=state['meta'].get('regional',Fal)
net_params['whitening']=state['meta'].get('whitening',Fal)
net_params['mean']=state['meta']['mean']
net_params['std']=state['meta']['std']
net_params['pretrained']=Fal
#networkinitialization
net=init_network(net_params)
_state_dict(state['state_dict'])
print(">>>>loadednetwork:")
print(_repr())
#ttingupthemulti-scaleparameters
ms=list(eval(multiscale))
print(">>>>Evaluatingscales:{}".format(ms))
#movingnetworktogpuandevalmode
_available():
()
()
#tupthetransform数据预处理
normalize=ize(
mean=['mean'],
std=['std']
)
transform=e([
or(),
normalize
])
#extractdatabaandqueryvectors对图⽚进⾏编码提取数据库图⽚特征
print('>>databaimages...')
images=ImageProcess(_dir).process()
vecs,img_paths=extract_vectors(net,images,1024,transform,ms=ms)
feature_dict=dict(zip(img_paths,list(().cpu().numpy().T)))
#index
lsh=LSHash(hash_size=int(hash_size),input_dim=int(input_dim),num_hashtables=int(num_hashtables))
forimg_path,vecinfeature_():
(n(),extra_data=img_path)
###保存索引模型
#withopen(e_path,"wb")asf:
#(feature_dict,f)
#withopen(_path,"wb")asf:
#(lsh,f)
print("extractfeatureisdone")
returnfeature_dict,lsh
(3)图像检索
这⾥图像检索这块我没怎么改动,因为只是测试⼀下⾃⼰训练后的模型的效果⽐较⽅便查看⽤的。所以我只是修改了输出的数量。
这⾥如果要输出多个Top,要⾃⼰多加⼏个,(也可以⾃⼰写个循环,我⽐较懒,没有写)然后后⾯我还显⽰出了得分情况,(因为后⾯要
进⾏模型的对⽐)
deffind_similar_img_gyz(lf,feature_dict,lsh,num_results):
forq_path,q_vecinfeature_():
try:
respon=(q_n(),distance_func="cosine")#,num_results=int(num_results)
#print(respon[0][1])
#print((100*(1-respon[0][1])))
query_img_path0=respon[0][0][1]
query_img_path1=respon[1][0][1]
query_img_path2=respon[2][0][1]
query_img_path3=respon[3][0][1]
query_img_path4=respon[4][0][1]
score_img_path0=respon[0][1]
score_img_path1=respon[1][1]
score_img_path2=respon[2][1]
score_img_path3=respon[3][1]
score_img_path4=respon[4][1]
#score0=respon[0][1]
#score0=(100*(1-score0))
print('**********************************************')
print('inputimg:{}'.format(q_path))
print('query0img:{}'.format(query_img_path0),
'score:{}'.format((100*(1-score_img_path0))))
print('query1img:{}'.format(query_img_path1),
'score:{}'.format((100*(1-score_img_path1))))
print('query2img:{}'.format(query_img_path2),
'score:{}'.format((100*(1-score_img_path2))))
print('query3img:{}'.format(query_img_path3),
'score:{}'.format((100*(1-score_img_path3))))
print('query4img:{}'.format(query_img_path4),
'score:{}'.format((100*(1-score_img_path4))))
except:
continue
3.训练⾃⼰的数据集
(1)训练参数配置
#networkarchitectureandinitializationoptions
_argument('--arch','-a',metavar='ARCH',default='resnet50',choices=model_names,
help='modelarchitecture:'+
'|'.join(model_names)+
'(default:resnet101)')
_argument('--pool','-p',metavar='POOL',default='gem',choices=pool_names,
help='poolingoptions:'+
'|'.join(pool_names)+
'(default:gem)')
_argument('--local-whitening','-lw',dest='local_whitening',action='store_true',
help='trainmodelwithlearnablelocalwhitening(linearlayer)beforethepooling')
_argument('--regional','-r',dest='regional',action='store_true',
help='trainmodelwithregionalpoolingusingfixedgrid')
_argument('--whitening','-w',dest='whitening',action='store_true',
help='trainmodelwithlearnablewhitening(linearlayer)afterthepooling')
_argument('--not-pretrained',dest='pretrained',action='store_fal',
help='initializemodelwithrandomweights(default:pretrainedonimagenet)')
_argument('--loss','-l',metavar='LOSS',default='contrastive',
choices=loss_names,
help='traininglossoptions:'+
'|'.join(loss_names)+
'(default:contrastive)')
_argument('--loss-margin','-lm',metavar='LM',default=0.7,type=float,
help='lossmargin:(default:0.7)')
(2)训练数据的准备
这⾥训练数据⽤的是retrieval-SfM-120k但是因为数据集38个GB,⽹速不⾏,在外⽹上下不下来,所以⽓急败坏的我直接看他的标签⽂
件。
这个⽂件就是⼀个字典格式⽂件,⼤概分了⼏层如下,因为我没有准备验证集和测试集,所以训练时候测试和验证那部分我直接删去了(主
要因为测试集的格式和训练集不⼀样,我懒得再去解析另⼀个数据集的格式)。
{train:{
cids:[],cluster:[],qidxs:[],pidxs:[]
},
val:{…}
}
①cids:主要⽤来存放所有图⽚的路径,所以不管你图⽚存放在哪,只要有图⽚路径即可。数组长度就是总的图⽚数量。
②cluster:这个是存放该图⽚的类别,数组的长度和cids⼀样,类别⼀⼀对应cids的图⽚(retrieval-SfM-120k是有713个建筑物所以
是713类,依据⾃⼰的数据集⽽定,我的数据集每对图⽚都是⼀个类,所以有⼏千个类别
本文发布于:2023-03-01 23:11:49,感谢您对本站的认可!
本文链接:https://www.wtabcd.cn/fanwen/zuowen/1677683509102382.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文word下载地址:百度以图搜图.doc
本文 PDF 下载地址:百度以图搜图.pdf
留言与评论(共有 0 条评论) |