NLP任务样本数据不均衡问题解决⽅案的总结和数据增强回译
的实战展⽰
⽬录
在做NLP分类标注等任务的时候,避免不了会遇到样本不均衡的情况,那么我们就需要处理这个问题,这样才能使模型有良好的表现。为此,在收集了⼀些资料以后,做了⼀个简单总结,⽅便以后回顾(怕跳槽⾯试的时候问道答不上来)。主要是从数据、算法和模型评价标准这个三个⽅⾯,来减少数据不平衡对模型性能的影响。
⼀、数据层⾯
当数据极度不平衡的时候,最容易相到的解决⽅案,就是从数据层⾯出发,⼩类数据太少了,那么就增加⼩类数据;⼤类样本太多了就删除⼀些样本。不管是2分类还是多分类,样本不均衡的表现都是样本数据数⽬之间存在着很⼤的差异。为了克服这个问题,实质上就要把数据经过⼀定的处理,变得不那么不均衡,⽐例适当⼀些。有实验表明,只要数据之间的⽐例超过了1:4,就会对算法造成偏差影响。针对数据⽐重失调,就可以对原始数据集进⾏采样调整,这⾥主要是⽋采样和过采样。
1、⽋采样(under-sampling)
对⼤类的数据样本进⾏采样来减少该类数据的样本个数。使⽤的⼀般经验规则,⼀般⽽⾔是对样本数⽬超过1W,10W 甚⾄更多,进⾏⽋采样。⼀般简单的做法,就是随机的删除部分样本。注意的是,⼀般很少使⽤⽋采样,标注数据的成本⽐较⾼,⽽深度学习的⽅法是数据量越⾼越好,所以⼀般都是使⽤过采样。
2、过采样
对⼩类数据的样本进⾏采样来增加⼩类样本数据的个数。Smote算法(它就是在少数类样本中⽤KNN⽅法合成了新样本)⼀般⽤来进⾏过采样的操作,这⾥有⼀点不⽅便的地⽅就是NLP任务中,不好使⽤Smote算法,我们的样本⼀般都是⽂本数据,不是直接的数字数据,只有把⽂本数据转化为数字数据才能进⾏smote操作。另外现在⼀般都是基于预训练模型做微调的,⽂本的向量表⽰也是变化的,所有不能进⾏smote算法来增加⼩类数据。那么针对NLP进⾏过采样的⼀些⽅法有那些呢?
1. 最简单的就是直接复制⼩类样本,从⽽达到增加⼩类样本数据的⽬的。这样的⽅法缺点也是很明显的,实际上样本中并没有加⼊新的
特征,特征还是很少,那么就会出现过拟合的问题。
2. 对⼩类样本数据经过⼀定的处理,做⼀些⼩的改变。例如随机的打乱词的顺序,句⼦的顺序;随机的删除⼀些词,⼀些句⼦;裁剪⽂
本的开头或者结尾等。我认为这些⼩⽅法⾄合适对语序不是特别重要的任务,像⼀些对语序特征特别重要的序列任务这种操做就不太恰当。
3. 复述⽣成:这个就属性q2q任务,根据原始问题成成格式更好的问题,然后把新问题替换到问答系统中。
4. EDA:同义词替换、随机插⼊和随机交换
5. 回译(back translation) 把中⽂——英⽂(其他的语⾔)——中⽂
6. ⽣成对抗⽹络——GAN
个⼈认为使⽤复述⽣成和回译以及⽣成对抗⽹络应该是最有效的,因为它们在做数据增强的时候,对原始数据做的处理使得语义发⽣了变化,但同时⼜保证了整个语义的完整性。随机删除的词,打乱顺序的⽅式,我认为对数据的整个语义破坏太⼤了。当然,这些技巧都值得在具体的数据集下做对应的实验,说不定它恰好就在这个数据集上起很重要的作⽤。
另外我⾃⼰做过的⼀些实践,回译是⽐较不错的,在百度翻译API免费的前提下,⼏乎没有成本。另外的复述⽣成和⽣成对抗⽹络不知道,听说⽣成对抗⽹络很难也很⿇烦。
⼆、算法层⾯
1、权重设置
在训练的时候给损失函数直接设定⼀定的⽐例,使得算法能够对⼩类数据更多的注意⼒。例如在深度学习中,做⼀个3分类任务,标签a、b、c的样本⽐例为1:1:8。在我们的交叉熵损失函数中就可以⽤类似这样的权重设置:
2、新的损失函数——Focal Loss
import torch
from torch import nn
import functional as F
import time
class focal_loss(nn.Module):
手机英语怎么说
"""
需要保证每个batch的长度⼀样,不然会报错。
"""
蒙眼睛
def __init__(lf,alpha=0.25,gamma = 2, num_class = 2, size_average =True):
"""
focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi) = -α(1-yi)**γ * log(yi)
:param alpha:
:param gamma:
:param num_class:
:param size_average:
"""
super(focal_loss, lf).__init__()
lf.size_average = size_average
if isinstance(alpha,list):
# α可以以list⽅式输⼊,size:[num_class] ⽤于对不同类别精细地赋予权重
asrt len(alpha) == num_class
print("Focal_loss alpha = {},对每⼀类权重进⾏精细化赋值".format(alpha))
lf.alpha = sor(alpha)
el:
asrt alpha<1 #如果α为⼀个常数,则降低第⼀类的影响
print("--- Focal_loss alpha = {},将对背景类或者⼤类负样本进⾏权重衰减".format(alpha))
lf.alpha = s(num_class)
lf.alpha[0] += alpha
lf.alpha[1:] += (1-alpha)
lf.gamma = gamma
def forward(lf, preds,labels):
"""
focal_loss损失计算
:param preds: 预测类别. size:[B,N,C] or [B,C] B:batch N:检测框数⽬ C:类别数
:param labels: 实际类别. size:[B,N] or [B]
:return:
"""
preds = preds.view(-1, preds.size(-1))
lf.alpha = (preds.device)
# 这⾥并没有直接使⽤log_softmax, 因为后⾯会⽤到softmax的结果(当然你也可以使⽤log_softmax,然后进⾏exp操作) preds_softmax = F.softmax(preds,dim=1)
preds_logsoft = torch.log(preds_softmax)
# 这部分实现nll_loss ( crosmpty = log_softmax + nll )写意花鸟画
preds_softmax = preds_softmax.gather(1,labels.view(-1,1))
preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
lf.alpha = lf.alpha.gather(0,labels.view(-1))
loss = -torch.mul(torch.pow((1-preds_softmax),lf.gamma),preds_logsoft)
loss = torch.mul(lf.alpha,loss.t())
甲亢灵if lf.size_average:
loss = an()
el:
loss = loss.sum()
return loss
有⼀个要注意的是
lf.alpha = lf.alpha.gather(0,labels.view(-1))
当传⼊的labels长度不⼀致,就会使得lf.alpha的长度不⼀样,进⽽报错。所有要保证训练的时候每个bath传⼊的数据长度要⼀致。三、评价⽅式
在模型评价的时候,我们⼀般简单的采⽤accuracy就可以了。但是在样本数据极度不平衡,特别是那种重点关注⼩类识别准确率的时候,就不能使⽤accuracy来评价模型了。要使⽤precision和recall来综合考虑模型的性能,降低⼩类分错的⼏率。在pytorch中,⼀般使⽤tensor来计算,下⾯给出关于tensor计算precision和recall的代码,主要是熟悉tensor的操作——孰能⽣巧。
correct += (predict == label).sum().item()
total += label.size(0)
train_acc = correct / total
#精确率、recall和F1的计算
for i in range(lf.number_of_class):
if i == lf.none_label:
continue
#TP和FP
lf._true_positives += ((predictions==i)*(gold_labels==i)*mask.bool()).sum()
lf._fal_positives += ((predictions==i)*(gold_labels!=i)*mask.bool()).sum()
鹅蛋炒香菜#TN和FN
lf._true_negatives += ((predictions!=i)*(gold_labels!=i)*mask.bool()).sum()
lf._fal_negatives += ((predictions!=i)*(gold_labels==i)*mask.bool()).sum()
#精确率、
precision = float(lf._true_positives) / (float(lf._true_positives + lf._fal_positives) + 1e-13)
#recall
recall = float(lf._true_positives) / (float(lf._true_positives + lf._fal_negatives) + 1e-13)
#F1
f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))
四、数据增强实战——回译(back translate)
尝试过的库或者API分别是Translator、TextBlob 和百度翻译的API。其实这些⽅法都是在⽹上都有,这⾥我做⼀个总结吧。
1、Translator
⾸选看看Translator,这个翻译的库⽤的是MyMeory的API,免费的限制是每天1000words。安装Translator
from translate import Translator
直接看⽰例:
from translate import Translator
def translation_translate(text):
print(text)
translator = Translator(from_lang="chine", to_lang="english")
translation = anslate(text)
print(translation)
print(len(translation))
if len(translation)> 500:
translation = translation[0:500]
print(translation)
translator = Translator(from_lang="english", to_lang="chine")
translation = anslate(translation)
print(len(translation))
print(translation)
return translation
if __name__ == '__main__':
text = '国家“⼗五”重⼤专项“创新药物和中药现代化”(863计划2004AA2Z3380)。基因⼯程药物注射给药存在着:⾎浆半衰期较短,⽣物利⽤度不⾼;抗原性较强,易 translation_translate(text)
注意的是text的长度不能超过500。所以这个做回译还是有⼀定的限制的,要是text翻译到中间语⾔,中间语⾔的长度超过了500,要做
截取处理,语义就会丢失很多。⽽且每天的字数也有限制,⼀天1000字太少了。但是翻译效果还不错,如下:
国家“⼗五”重⼤专项“创新药物和中药现代化”(863计划2004AA2Z3380)。基因⼯程药物注射给药存在着:⾎浆半衰期较短,⽣物利⽤度不⾼;抗原性较强,易引起过敏The national "Fifth Five-Year Plan" major special "innovative drugs and modernization of Chine medicine" (863 plan 2004A2Z380). The injection of genetically e 320
The national "Fifth Five-Year Plan" major special "innovative drugs and modernization of Chine medicine" (863 plan 2004A2Z380). The injection of genetically e 85
国家"⼗五"重⼤专项"创新药物与中药现代化"(863计划2004A2Z380)。基因⼯程药物的注射存在:⾎浆半寿命短,⽣物利⽤度不⾼,抗原强,容易引起过敏反应等不
2、TextBlob
类似Translator的使⽤,但是这个是调⽤Google翻译的API,内⽹⽤不了。
3、百度翻译API
使⽤这个来做翻译的话,需要使⽤import http.client模块⼉来实现,百度也给出了详细的教程。我这⾥的⼀个需求是需要做数据增
强,每条数据需要,使⽤6种语⾔来做回译,才能配平样本⽐例。直接上代码:
核⼼函数:
def baidu_translate(content,from_lang,to_lang):
appid = '×××××××××'
cretKey = '××××××××××××××'
httpClient = None
myurl = '/api/trans/vip/translate'
q = content
fromLang = from_lang # 源语⾔
乌镇作文toLang = to_lang # 翻译后的语⾔
salt = random.randint(32768, 65536)
sign = appid + q + str(salt) + cretKey
sign = hashlib.de()).hexdigest()
myurl = myurl + '?appid=' + appid + '&q=' + urllib.par.quote(
q) + '&from=' + fromLang + '&to=' + toLang + '&salt=' + str(
salt) + '&sign=' + sign
try:
httpClient = http.client.HTTPConnection('api.')
# respon是HTTPRespon对象
respon = spon()
jsonRespon = ad().decode("utf-8") # 获得返回的结果,结果为json格式
js = json.loads(jsonRespon) # 将json格式的结果转换字典结构
dst = str(js["trans_result"][0]["dst"]) # 取得翻译后的⽂本结果
# print(dst) # 打印结果
return dst
except Exception as e:
print('err:',e)
finally:
if httpClient:
httpClient.clo()
def do_translate(content,from_lang,to_lang):
if len(content)>= 260:
content = content[0:260]
temp = baidu_translate(content,from_lang,to_lang)
time.sleep(1)#百度API免费调⽤的QPS=1,所以要1s以后才能调⽤
if temp is None:100首古诗大全
temp = 0
if len(temp) >= 1500:
temp = temp[0:1500]
京剧的资料res = baidu_translate(temp,to_lang,from_lang)
return res
遇到的⼀些坑:
注意到,这⾥使⽤的是标准版,没有收费,⽬前是免费的,但是以后说不定就不会开放免费的版本了。另外QPS=1,也就是1秒内并发能⼒只有1,所有这个在代码中,⽤了time.sleep(1),保证API被及时调⽤,⽽不会报错。最后由于我的中⽂预料长度很长⼤都在100-500之间,翻译成其他语⾔,字
符数就有1500-2000多,虽然百度API对字符数长度放宽了,但是不做长度处理还是会报错,这个就需要⾃⼰有针对性的调整了。
看⼀看下⾯的回译的结果,原始的中⽂就没有展⽰出来,这个6种不同语⾔,回译的⽂本。信息都有缺失,但是整体都还在,做数据增强就很不错了。