NLP任务样本数据不均衡问题解决方案的总结和数据增强回译的实战展示

更新时间:2023-07-27 03:24:38 阅读: 评论:0

NLP任务样本数据不均衡问题解决⽅案的总结和数据增强回译
的实战展⽰
⽬录
在做NLP分类标注等任务的时候,避免不了会遇到样本不均衡的情况,那么我们就需要处理这个问题,这样才能使模型有良好的表现。为此,在收集了⼀些资料以后,做了⼀个简单总结,⽅便以后回顾(怕跳槽⾯试的时候问道答不上来)。主要是从数据、算法和模型评价标准这个三个⽅⾯,来减少数据不平衡对模型性能的影响。
⼀、数据层⾯
当数据极度不平衡的时候,最容易相到的解决⽅案,就是从数据层⾯出发,⼩类数据太少了,那么就增加⼩类数据;⼤类样本太多了就删除⼀些样本。不管是2分类还是多分类,样本不均衡的表现都是样本数据数⽬之间存在着很⼤的差异。为了克服这个问题,实质上就要把数据经过⼀定的处理,变得不那么不均衡,⽐例适当⼀些。有实验表明,只要数据之间的⽐例超过了1:4,就会对算法造成偏差影响。针对数据⽐重失调,就可以对原始数据集进⾏采样调整,这⾥主要是⽋采样和过采样。
1、⽋采样(under-sampling)
对⼤类的数据样本进⾏采样来减少该类数据的样本个数。使⽤的⼀般经验规则,⼀般⽽⾔是对样本数⽬超过1W,10W 甚⾄更多,进⾏⽋采样。⼀般简单的做法,就是随机的删除部分样本。注意的是,⼀般很少使⽤⽋采样,标注数据的成本⽐较⾼,⽽深度学习的⽅法是数据量越⾼越好,所以⼀般都是使⽤过采样。
2、过采样
对⼩类数据的样本进⾏采样来增加⼩类样本数据的个数。Smote算法(它就是在少数类样本中⽤KNN⽅法合成了新样本)⼀般⽤来进⾏过采样的操作,这⾥有⼀点不⽅便的地⽅就是NLP任务中,不好使⽤Smote算法,我们的样本⼀般都是⽂本数据,不是直接的数字数据,只有把⽂本数据转化为数字数据才能进⾏smote操作。另外现在⼀般都是基于预训练模型做微调的,⽂本的向量表⽰也是变化的,所有不能进⾏smote算法来增加⼩类数据。那么针对NLP进⾏过采样的⼀些⽅法有那些呢?
1. 最简单的就是直接复制⼩类样本,从⽽达到增加⼩类样本数据的⽬的。这样的⽅法缺点也是很明显的,实际上样本中并没有加⼊新的
特征,特征还是很少,那么就会出现过拟合的问题。
2. 对⼩类样本数据经过⼀定的处理,做⼀些⼩的改变。例如随机的打乱词的顺序,句⼦的顺序;随机的删除⼀些词,⼀些句⼦;裁剪⽂
本的开头或者结尾等。我认为这些⼩⽅法⾄合适对语序不是特别重要的任务,像⼀些对语序特征特别重要的序列任务这种操做就不太恰当。
3. 复述⽣成:这个就属性q2q任务,根据原始问题成成格式更好的问题,然后把新问题替换到问答系统中。
4. EDA:同义词替换、随机插⼊和随机交换
5. 回译(back translation) 把中⽂——英⽂(其他的语⾔)——中⽂
6. ⽣成对抗⽹络——GAN
个⼈认为使⽤复述⽣成和回译以及⽣成对抗⽹络应该是最有效的,因为它们在做数据增强的时候,对原始数据做的处理使得语义发⽣了变化,但同时⼜保证了整个语义的完整性。随机删除的词,打乱顺序的⽅式,我认为对数据的整个语义破坏太⼤了。当然,这些技巧都值得在具体的数据集下做对应的实验,说不定它恰好就在这个数据集上起很重要的作⽤。
另外我⾃⼰做过的⼀些实践,回译是⽐较不错的,在百度翻译API免费的前提下,⼏乎没有成本。另外的复述⽣成和⽣成对抗⽹络不知道,听说⽣成对抗⽹络很难也很⿇烦。
⼆、算法层⾯
1、权重设置
在训练的时候给损失函数直接设定⼀定的⽐例,使得算法能够对⼩类数据更多的注意⼒。例如在深度学习中,做⼀个3分类任务,标签a、b、c的样本⽐例为1:1:8。在我们的交叉熵损失函数中就可以⽤类似这样的权重设置:
2、新的损失函数——Focal Loss
import torch
from torch import nn
import functional as F
import time
class focal_loss(nn.Module):
手机英语怎么说
"""
需要保证每个batch的长度⼀样,不然会报错。
"""
蒙眼睛
def __init__(lf,alpha=0.25,gamma = 2, num_class = 2, size_average =True):
"""
focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi) = -α(1-yi)**γ * log(yi)
:param alpha:
:param gamma:
:param num_class:
:param size_average:
"""
super(focal_loss, lf).__init__()
lf.size_average = size_average
if isinstance(alpha,list):
# α可以以list⽅式输⼊,size:[num_class] ⽤于对不同类别精细地赋予权重
asrt len(alpha) == num_class
print("Focal_loss alpha = {},对每⼀类权重进⾏精细化赋值".format(alpha))
lf.alpha = sor(alpha)
el:
asrt alpha<1 #如果α为⼀个常数,则降低第⼀类的影响
print("--- Focal_loss alpha = {},将对背景类或者⼤类负样本进⾏权重衰减".format(alpha))
lf.alpha = s(num_class)
lf.alpha[0] += alpha
lf.alpha[1:] += (1-alpha)
lf.gamma = gamma
def forward(lf, preds,labels):
"""
focal_loss损失计算
:param preds: 预测类别. size:[B,N,C] or [B,C]  B:batch N:检测框数⽬ C:类别数
:param labels: 实际类别. size:[B,N] or [B]
:return:
"""
preds = preds.view(-1, preds.size(-1))
lf.alpha = (preds.device)
# 这⾥并没有直接使⽤log_softmax, 因为后⾯会⽤到softmax的结果(当然你也可以使⽤log_softmax,然后进⾏exp操作)        preds_softmax = F.softmax(preds,dim=1)
preds_logsoft = torch.log(preds_softmax)
# 这部分实现nll_loss ( crosmpty = log_softmax + nll )写意花鸟画
preds_softmax = preds_softmax.gather(1,labels.view(-1,1))
preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
lf.alpha = lf.alpha.gather(0,labels.view(-1))
loss = -torch.mul(torch.pow((1-preds_softmax),lf.gamma),preds_logsoft)
loss = torch.mul(lf.alpha,loss.t())
甲亢灵if lf.size_average:
loss = an()
el:
loss = loss.sum()
return loss
有⼀个要注意的是
lf.alpha = lf.alpha.gather(0,labels.view(-1))
当传⼊的labels长度不⼀致,就会使得lf.alpha的长度不⼀样,进⽽报错。所有要保证训练的时候每个bath传⼊的数据长度要⼀致。三、评价⽅式
在模型评价的时候,我们⼀般简单的采⽤accuracy就可以了。但是在样本数据极度不平衡,特别是那种重点关注⼩类识别准确率的时候,就不能使⽤accuracy来评价模型了。要使⽤precision和recall来综合考虑模型的性能,降低⼩类分错的⼏率。在pytorch中,⼀般使⽤tensor来计算,下⾯给出关于tensor计算precision和recall的代码,主要是熟悉tensor的操作——孰能⽣巧。
correct += (predict == label).sum().item()
total += label.size(0)
train_acc = correct / total
#精确率、recall和F1的计算
for i in range(lf.number_of_class):
if i == lf.none_label:
continue
#TP和FP
lf._true_positives += ((predictions==i)*(gold_labels==i)*mask.bool()).sum()
lf._fal_positives += ((predictions==i)*(gold_labels!=i)*mask.bool()).sum()
鹅蛋炒香菜#TN和FN
lf._true_negatives += ((predictions!=i)*(gold_labels!=i)*mask.bool()).sum()
lf._fal_negatives += ((predictions!=i)*(gold_labels==i)*mask.bool()).sum()
#精确率、
precision = float(lf._true_positives) / (float(lf._true_positives + lf._fal_positives) + 1e-13)
#recall
recall = float(lf._true_positives) / (float(lf._true_positives + lf._fal_negatives) + 1e-13)
#F1
f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))
四、数据增强实战——回译(back translate)
尝试过的库或者API分别是Translator、TextBlob 和百度翻译的API。其实这些⽅法都是在⽹上都有,这⾥我做⼀个总结吧。
1、Translator
⾸选看看Translator,这个翻译的库⽤的是MyMeory的API,免费的限制是每天1000words。安装Translator
from translate import Translator
直接看⽰例:
from translate import Translator
def translation_translate(text):
print(text)
translator = Translator(from_lang="chine", to_lang="english")
translation = anslate(text)
print(translation)
print(len(translation))
if len(translation)> 500:
translation = translation[0:500]
print(translation)
translator = Translator(from_lang="english", to_lang="chine")
translation = anslate(translation)
print(len(translation))
print(translation)
return translation
if __name__ == '__main__':
text = '国家“⼗五”重⼤专项“创新药物和中药现代化”(863计划2004AA2Z3380)。基因⼯程药物注射给药存在着:⾎浆半衰期较短,⽣物利⽤度不⾼;抗原性较强,易    translation_translate(text)
注意的是text的长度不能超过500。所以这个做回译还是有⼀定的限制的,要是text翻译到中间语⾔,中间语⾔的长度超过了500,要做
截取处理,语义就会丢失很多。⽽且每天的字数也有限制,⼀天1000字太少了。但是翻译效果还不错,如下:
国家“⼗五”重⼤专项“创新药物和中药现代化”(863计划2004AA2Z3380)。基因⼯程药物注射给药存在着:⾎浆半衰期较短,⽣物利⽤度不⾼;抗原性较强,易引起过敏The national "Fifth Five-Year Plan" major special "innovative drugs and modernization of Chine medicine" (863 plan 2004A2Z380). The injection of genetically e 320
The national "Fifth Five-Year Plan" major special "innovative drugs and modernization of Chine medicine" (863 plan 2004A2Z380). The injection of genetically e 85
国家"⼗五"重⼤专项"创新药物与中药现代化"(863计划2004A2Z380)。基因⼯程药物的注射存在:⾎浆半寿命短,⽣物利⽤度不⾼,抗原强,容易引起过敏反应等不
2、TextBlob
类似Translator的使⽤,但是这个是调⽤Google翻译的API,内⽹⽤不了。
3、百度翻译API
使⽤这个来做翻译的话,需要使⽤import http.client模块⼉来实现,百度也给出了详细的教程。我这⾥的⼀个需求是需要做数据增
强,每条数据需要,使⽤6种语⾔来做回译,才能配平样本⽐例。直接上代码:
核⼼函数:
def baidu_translate(content,from_lang,to_lang):
appid = '×××××××××'
cretKey = '××××××××××××××'
httpClient = None
myurl = '/api/trans/vip/translate'
q = content
fromLang = from_lang  # 源语⾔
乌镇作文toLang = to_lang  # 翻译后的语⾔
salt = random.randint(32768, 65536)
sign = appid + q + str(salt) + cretKey
sign = hashlib.de()).hexdigest()
myurl = myurl + '?appid=' + appid + '&q=' + urllib.par.quote(
q) + '&from=' + fromLang + '&to=' + toLang + '&salt=' + str(
salt) + '&sign=' + sign
try:
httpClient = http.client.HTTPConnection('api.')
# respon是HTTPRespon对象
respon = spon()
jsonRespon = ad().decode("utf-8")  # 获得返回的结果,结果为json格式
js = json.loads(jsonRespon)  # 将json格式的结果转换字典结构
dst = str(js["trans_result"][0]["dst"])  # 取得翻译后的⽂本结果
# print(dst)  # 打印结果
return dst
except Exception as e:
print('err:',e)
finally:
if httpClient:
httpClient.clo()
def do_translate(content,from_lang,to_lang):
if len(content)>= 260:
content = content[0:260]
temp = baidu_translate(content,from_lang,to_lang)
time.sleep(1)#百度API免费调⽤的QPS=1,所以要1s以后才能调⽤
if temp is None:100首古诗大全
temp = 0
if len(temp) >= 1500:
temp = temp[0:1500]
京剧的资料res = baidu_translate(temp,to_lang,from_lang)
return res
遇到的⼀些坑:
注意到,这⾥使⽤的是标准版,没有收费,⽬前是免费的,但是以后说不定就不会开放免费的版本了。另外QPS=1,也就是1秒内并发能⼒只有1,所有这个在代码中,⽤了time.sleep(1),保证API被及时调⽤,⽽不会报错。最后由于我的中⽂预料长度很长⼤都在100-500之间,翻译成其他语⾔,字
符数就有1500-2000多,虽然百度API对字符数长度放宽了,但是不做长度处理还是会报错,这个就需要⾃⼰有针对性的调整了。
看⼀看下⾯的回译的结果,原始的中⽂就没有展⽰出来,这个6种不同语⾔,回译的⽂本。信息都有缺失,但是整体都还在,做数据增强就很不错了。

本文发布于:2023-07-27 03:24:38,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/82/1119274.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:数据   样本   时候
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图