NLP任务样本数据不均衡问题解决⽅案的总结和数据增强回译
的实战展⽰
⽬录
在做NLP分类标注等任务的时候,避免不了会遇到样本不均衡的情况,那么我们就需要处理这个问题,这样才能使模型有良好的表
现。为此,在收集了⼀些资料以后,做了⼀个简单总结,⽅便以后回顾(怕跳槽⾯试的时候问道答不上来)。主要是从数据、算法和模型评
价标准这个三个⽅⾯,来减少数据不平衡对模型性能的影响。
⼀、数据层⾯
当数据极度不平衡的时候,最容易相到的解决⽅案,就是从数据层⾯出发,⼩类数据太少了,那么就增加⼩类数据;⼤类样本太多了就
删除⼀些样本。不管是2分类还是多分类,样本不均衡的表现都是样本数据数⽬之间存在着很⼤的差异。为了克服这个问题,实质上就要把
数据经过⼀定的处理,变得不那么不均衡,⽐例适当⼀些。有实验表明,只要数据之间的⽐例超过了1:4,就会对算法造成偏差影响。针对
数据⽐重失调,就可以对原始数据集进⾏采样调整,这⾥主要是⽋采样和过采样。
1、⽋采样(under-sampling)
对⼤类的数据样本进⾏采样来减少该类数据的样本个数。使⽤的⼀般经验规则,⼀般⽽⾔是对样本数⽬超过1W,10W甚⾄更多,进
⾏⽋采样。⼀般简单的做法,就是随机的删除部分样本。注意的是,⼀般很少使⽤⽋采样,标注数据的成本⽐较⾼,⽽深度学习的⽅法是数
据量越⾼越好,所以⼀般都是使⽤过采样。
2、过采样
对⼩类数据的样本进⾏采样来增加⼩类样本数据的个数。Smote算法(它就是在少数类样本中⽤KNN⽅法合成了新样本)⼀般⽤来进
⾏过采样的操作,这⾥有⼀点不⽅便的地⽅就是NLP任务中,不好使⽤Smote算法,我们的样本⼀般都是⽂本数据,不是直接的数字数据,
只有把⽂本数据转化为数字数据才能进⾏smote操作。另外现在⼀般都是基于预训练模型做微调的,⽂本的向量表⽰也是变化的,所有不能
进⾏smote算法来增加⼩类数据。那么针对NLP进⾏过采样的⼀些⽅法有那些呢?
1.最简单的就是直接复制⼩类样本,从⽽达到增加⼩类样本数据的⽬的。这样的⽅法缺点也是很明显的,实际上样本中并没有加⼊新的
特征,特征还是很少,那么就会出现过拟合的问题。
2.对⼩类样本数据经过⼀定的处理,做⼀些⼩的改变。例如随机的打乱词的顺序,句⼦的顺序;随机的删除⼀些词,⼀些句⼦;裁剪⽂
本的开头或者结尾等。我认为这些⼩⽅法⾄合适对语序不是特别重要的任务,像⼀些对语序特征特别重要的序列任务这种操做就不太
恰当。
3.复述⽣成:这个就属性q2q任务,根据原始问题成成格式更好的问题,然后把新问题替换到问答系统中。
:同义词替换、随机插⼊和随机交换
5.回译(backtranslation)把中⽂——英⽂(其他的语⾔)——中⽂
6.⽣成对抗⽹络——GAN
个⼈认为使⽤复述⽣成和回译以及⽣成对抗⽹络应该是最有效的,因为它们在做数据增强的时候,对原始数据做的处理使得语义发⽣了变
化,但同时⼜保证了整个语义的完整性。随机删除的词,打乱顺序的⽅式,我认为对数据的整个语义破坏太⼤了。当然,这些技巧都值得在
具体的数据集下做对应的实验,说不定它恰好就在这个数据集上起很重要的作⽤。
另外我⾃⼰做过的⼀些实践,回译是⽐较不错的,在百度翻译API免费的前提下,⼏乎没有成本。另外的复述⽣成和⽣成对抗⽹络不知道,
听说⽣成对抗⽹络很难也很⿇烦。
⼆、算法层⾯
1、权重设置
在训练的时候给损失函数直接设定⼀定的⽐例,使得算法能够对⼩类数据更多的注意⼒。例如在深度学习中,做⼀个3分类任务,标签
a、b、c的样本⽐例为1:1:8。在我们的交叉熵损失函数中就可以⽤类似这样的权重设置:
ntropyLoss(weight=_numpy(([8,8,1])).float().to(device))
2、新的损失函数——FocalLoss
importtorch
fromtorchimportnn
rtfunctionalasF
importtime
classfocal_loss():
"""
需要保证每个batch的长度⼀样,不然会报错。
"""
def__init__(lf,alpha=0.25,gamma=2,num_class=2,size_average=True):
"""
focal_loss损失函数,-α(1-yi)**γ*ce_loss(xi,yi)=-α(1-yi)**γ*log(yi)
:paramalpha:
:paramgamma:
:paramnum_class:
:paramsize_average:
"""
super(focal_loss,lf).__init__()
_average=size_average
ifisinstance(alpha,list):
#α可以以list⽅式输⼊,size:[num_class]⽤于对不同类别精细地赋予权重
asrtlen(alpha)==num_class
print("Focal_lossalpha={},对每⼀类权重进⾏精细化赋值".format(alpha))
=(alpha)
el:
asrtalpha<1#如果α为⼀个常数,则降低第⼀类的影响
print("---Focal_lossalpha={},将对背景类或者⼤类负样本进⾏权重衰减".format(alpha))
=(num_class)
[0]+=alpha
[1:]+=(1-alpha)
=gamma
defforward(lf,preds,labels):
"""
focal_loss损失计算
:parampreds:预测类别.size:[B,N,C]or[B,C]B:batchN:检测框数⽬C:类别数
:paramlabels:实际类别.size:[B,N]or[B]
:return:
"""
preds=(-1,(-1))
=()
#这⾥并没有直接使⽤log_softmax,因为后⾯会⽤到softmax的结果(当然你也可以使⽤log_softmax,然后进⾏exp操作)
preds_softmax=x(preds,dim=1)
preds_logsoft=(preds_softmax)
#这部分实现nll_loss(crosmpty=log_softmax+nll)
preds_softmax=preds_(1,(-1,1))
preds_logsoft=preds_(1,(-1,1))
=(0,(-1))
loss=-(((1-preds_softmax),),preds_logsoft)
loss=(,loss.t())
_average:
loss=()
el:
loss=()
returnloss
有⼀个要注意的是
=(0,(-1))
当传⼊的labels长度不⼀致,就会使得的长度不⼀样,进⽽报错。所有要保证训练的时候每个bath传⼊的数据长度要⼀致。
三、评价⽅式
在模型评价的时候,我们⼀般简单的采⽤accuracy就可以了。但是在样本数据极度不平衡,特别是那种重点关注⼩类识别准确率的时候,
就不能使⽤accuracy来评价模型了。要使⽤precision和recall来综合考虑模型的性能,降低⼩类分错的⼏率。在pytorch中,⼀般使⽤
tensor来计算,下⾯给出关于tensor计算precision和recall的代码,主要是熟悉tensor的操作——孰能⽣巧。
correct+=(predict==label).sum().item()
total+=(0)
train_acc=correct/total
#精确率、recall和F1的计算
foriinrange(_of_class):
ifi==_label:
continue
#TP和FP
lf._true_positives+=((predictions==i)*(gold_labels==i)*()).sum()
lf._fal_positives+=((predictions==i)*(gold_labels!=i)*()).sum()
#TN和FN
lf._true_negatives+=((predictions!=i)*(gold_labels!=i)*()).sum()
lf._fal_negatives+=((predictions!=i)*(gold_labels==i)*()).sum()
#精确率、
precision=float(lf._true_positives)/(float(lf._true_positives+lf._fal_positives)+1e-13)
#recall
recall=float(lf._true_positives)/(float(lf._true_positives+lf._fal_negatives)+1e-13)
#F1
f1_measure=2.*((precision*recall)/(precision+recall+1e-13))
四、数据增强实战——回译(backtranslate)
尝试过的库或者API分别是Translator、TextBlob和百度翻译的API。其实这些⽅法都是在⽹上都有,这⾥我做⼀个总结吧。
1、Translator
⾸选看看Translator,这个翻译的库⽤的是MyMeory的API,免费的限制是每天1000words。安装Translator
fromtranslateimportTranslator
直接看⽰例:
fromtranslateimportTranslator
deftranslation_translate(text):
print(text)
translator=Translator(from_lang="chine",to_lang="english")
translation=ate(text)
print(translation)
print(len(translation))
iflen(translation)>500:
translation=translation[0:500]
print(translation)
translator=Translator(from_lang="english",to_lang="chine")
translation=ate(translation)
print(len(translation))
print(translation)
returntranslation
if__name__=='__main__':
text='国家“⼗五”重⼤专项“创新药物和中药现代化”(863计划2004AA2Z3380)。基因⼯程药物注射给药存在着:⾎浆半衰期较短,⽣物利⽤度不⾼;抗原性较强,易
translation_translate(text)
注意的是text的长度不能超过500。所以这个做回译还是有⼀定的限制的,要是text翻译到中间语⾔,中间语⾔的长度超过了500,要做
截取处理,语义就会丢失很多。⽽且每天的字数也有限制,⼀天1000字太少了。但是翻译效果还不错,如下:
国家“⼗五”重⼤专项“创新药物和中药现代化”(863计划2004AA2Z3380)。基因⼯程药物注射给药存在着:⾎浆半衰期较短,⽣物利⽤度不⾼;抗原性较强,易引起过敏
Thenational"FifthFive-YearPlan"majorspecial"innovativedrugsandmodernizationofChinemedicine"(863plan2004A2Z380).Theinjectionofgeneticallye
320
Thenational"FifthFive-YearPlan"majorspecial"innovativedrugsandmodernizationofChinemedicine"(863plan2004A2Z380).Theinjectionofgeneticallye
85
国家"⼗五"重⼤专项"创新药物与中药现代化"(863计划2004A2Z380)。基因⼯程药物的注射存在:⾎浆半寿命短,⽣物利⽤度不⾼,抗原强,容易引起过敏反应等不
2、TextBlob
类似Translator的使⽤,但是这个是调⽤Google翻译的API,内⽹⽤不了。
3、百度翻译API
使⽤这个来做翻译的话,需要使⽤模块⼉来实现,百度也给出了详细的教程。我这⾥的⼀个需求是需要做数据增
强,每条数据需要,使⽤6种语⾔来做回译,才能配平样本⽐例。直接上代码:
核⼼函数:
defbaidu_translate(content,from_lang,to_lang):
appid='×××××××××'
cretKey='××××××××××××××'
httpClient=None
myurl='/api/trans/vip/translate'
q=content
fromLang=from_lang#源语⾔
toLang=to_lang#翻译后的语⾔
salt=t(32768,65536)
sign=appid+q+str(salt)+cretKey
sign=5(()).hexdigest()
myurl=myurl+'?appid='+appid+'&q='+(
q)+'&from='+fromLang+'&to='+toLang+'&salt='+str(
salt)+'&sign='+sign
try:
httpClient=nnection('')
t('GET',myurl)
#respon是HTTPRespon对象
respon=pon()
jsonRespon=().decode("utf-8")#获得返回的结果,结果为json格式
js=(jsonRespon)#将json格式的结果转换字典结构
dst=str(js["trans_result"][0]["dst"])#取得翻译后的⽂本结果
#print(dst)#打印结果
returndst
exceptExceptiona:
print('err:',e)
finally:
ifhttpClient:
()
defdo_translate(content,from_lang,to_lang):
iflen(content)>=260:
content=content[0:260]
temp=baidu_translate(content,from_lang,to_lang)
(1)#百度API免费调⽤的QPS=1,所以要1s以后才能调⽤
iftempisNone:
temp=0
iflen(temp)>=1500:
temp=temp[0:1500]
res=baidu_translate(temp,to_lang,from_lang)
returnres
遇到的⼀些坑:
注意到,这⾥使⽤的是标准版,没有收费,⽬前是免费的,但是以后说不定就不会开放免费的版本了。另外QPS=1,也就是1秒内并发能⼒
只有1,所有这个在代码中,⽤了(1),保证API被及时调⽤,⽽不会报错。最后由于我的中⽂预料长度很长⼤都在100-500之
间,翻译成其他语⾔,字符数就有1500-2000多,虽然百度API对字符数长度放宽了,但是不做长度处理还是会报错,这个就需要⾃⼰有
针对性的调整了。
看⼀看下⾯的回译的结果,原始的中⽂就没有展⽰出来,这个6种不同语⾔,回译的⽂本。信息都有缺失,但是整体都还在,做数据增强就
很不错了。
本项⽬属于有⾊⾦属材料制备加⼯技术领域。通过系统研究,证明了铜合⾦纳⽶强化相的形核、长⼤机理和强化机理,突破了引⼊纳⽶强化相、控制弥散分布等
这个项⽬属于有⾊⾦属材料的调制加⼯技术领域。通过系统研究,明确了在铜合⾦中纳⽶强化相核增长机构和强化机构,突破了纳⽶强化相的引进和扩散分布控制的共
该项⽬属于有⾊⾦属材料制品加⼯技术领域。通过系统研究,通过联合合⾦中纳⽶通过象形核成长机制和加强机制的引进纳⽶强化奖和分布控制的共同技术难题,开发
该项⽬是有⾊⾦属材料制造技术领域的⼀部分。通过系统研究,确定了在铜合⾦中⽣长和加强强化纳⽶芯的机制。1.克服引⼊和控制增强纳⽶散射的共同技术困难;开
该项⽬是有⾊⾦属制备和加⼯技术领域的⼀部分。通过引⼊纳⽶增强相、控制分散分布等常见技术问题,突破了铜合⾦中纳⽶增强相的核与⽣长机理和强化机理,发展
本项⽬属于有⾊⾦属的加⼯和加⼯技术领域。通过系统的研究,已经确定了纳⽶热的增长和加强机制。(a)在铜合⾦中强化核,克服了采⽤纳⽶热相和控制弥散分布的
完整代码:
importpandasaspd
importhashlib
importjson
importurllib
importrandom
importtime
fromtqdmimporttqdm
importcsv
defbaidu_translate(content,from_lang,to_lang):
appid='××××××××××××××
cretKey='××××××××××××××××××××'
httpClient=None
myurl='/api/trans/vip/translate'
q=content
fromLang=from_lang#源语⾔
toLang=to_lang#翻译后的语⾔
salt=t(32768,65536)
sign=appid+q+str(salt)+cretKey
sign=5(()).hexdigest()
myurl=myurl+'?appid='+appid+'&q='+(
q)+'&from='+fromLang+'&to='+toLang+'&salt='+str(
salt)+'&sign='+sign
try:
httpClient=nnection('')
t('GET',myurl)
#respon是HTTPRespon对象
respon=pon()
jsonRespon=().decode("utf-8")#获得返回的结果,结果为json格式
js=(jsonRespon)#将json格式的结果转换字典结构
dst=str(js["trans_result"][0]["dst"])#取得翻译后的⽂本结果
#print(dst)#打印结果
returndst
exceptExceptiona:
print('err:',e)
finally:
ifhttpClient:
()
defdo_translate(content,from_lang,to_lang):
iflen(content)>=260:
content=content[0:260]
temp=baidu_translate(content,from_lang,to_lang)
(1)#百度API免费调⽤的QPS=1,所以要1s以后才能调⽤
(1)#百度API免费调⽤的QPS=1,所以要1s以后才能调⽤
iftempisNone:
temp=0
iflen(temp)>=1500:
temp=temp[0:1500]
res=baidu_translate(temp,to_lang,from_lang)
returnres
defback_translate(A_title,R_title,A_content,R_content,level,writer):
new_A_titles=[]
new_R_titles=[]
new_A_contents=[]
new_R_contents=[]
new_levels=[]
fromlang_tolangs=[
('zh','en'),
('zh','jp'),
('zh','kor'),
('zh','fra'),
('zh','de'),
('zh','ru')
]
foreleinfromlang_tolangs:
from_lang=ele[0]
to_lang=ele[1]
A_content_new=do_translate(A_content,from_lang,to_lang)
(1)#百度API免费调⽤的QPS=1,所以要1s以后才能调⽤
R_content_new=do_translate(R_content,from_lang,to_lang)
(1)#百度API免费调⽤的QPS=1,所以要1s以后才能调⽤
new_A_(A_title)
new_R_(R_title)
new_A_(A_content_new)
new_R_(R_content_new)
new_(level)
ow([A_title,A_content_new,R_title,R_content_new,level])
returnnew_A_titles,new_R_titles,new_A_contents,new_R_contents,new_levels
if__name__=='__main__':
orginal_data=_csv('data/interrelation_',p='t')
print(orginal_y(['Level']).size())
A_titles=orginal_data[(orginal_data['Level']==3)|(orginal_data['Level']==4)]['A_title'].()
R_titles=orginal_data[(orginal_data['Level']==3)|(orginal_data['Level']==4)]['R_title'].()
A_contents=orginal_data[(orginal_data['Level']==3)|(orginal_data['Level']==4)]['A_content'].()
R_contents=orginal_data[(orginal_data['Level']==3)|(orginal_data['Level']==4)]['R_content'].()
levels=orginal_data[(orginal_data['Level']==3)|(orginal_data['Level']==4)]['Level'].()
A_title_new=[]
R_title_new=[]
A_content_new=[]
R_content_new=[]
levels_new=[]
count=0
csv_header=['A_title','A_content','R_title','R_content','Level']
withopen('data/final_augment_data_','w')asf:
writer=(f)
ow(csv_header)
forA_title,R_title,A_content,R_content,levelintqdm(list(zip(A_titles,R_titles,A_contents,R_contents,levels)),desc='回译执⾏:'):
ifcount>=311:
new_A_titles,new_R_titles,new_A_contents,new_R_contents,new_levels=back_translate(A_title,R_title,A_content,R_content,level,writer)
A_title_(new_A_titles)
R_title_(new_R_titles)
R_title_(new_R_titles)
A_content_(new_A_contents)
R_content_(new_R_contents)
levels_(new_levels)
count+=1
参考⽂章:
本文发布于:2023-03-12 09:01:32,感谢您对本站的认可!
本文链接:https://www.wtabcd.cn/zhishi/a/16785828928684.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文word下载地址:回译.doc
本文 PDF 下载地址:回译.pdf
留言与评论(共有 0 条评论) |