transformer基本架构及代码实现

更新时间:2023-07-30 13:36:52 阅读: 评论:0

transformer基本架构及代码实现
从2018年Google提出BERT模型开始,transformer结构就在NLP领域⼤杀四⽅,使⽤transformer的BERT模型在当时横扫NLP领域的11项任务,取得SOTA成绩,包括⼀直到后来相继出现的XLNET,roBERT等,均采⽤transformer结构作为核⼼。在著名的SOTA机器翻译排⾏榜上,⼏乎所有排名靠前的模型都是⽤transformer。那么在transformer出现之前,占领市场的⼀直都是LSTM和GRU等模型,相⽐之
下,transformer具有如下两个显著的优势:
2.在分析预测长序列⽂本时,transformer能够捕捉间隔较长的语义关联效果。
由于transformer在NLP领域的巨⼤成功,使得研究⼈员很⾃然的想到,如果将其应⽤于CV领域,⼜会取得怎样的效果呢,毕竟CV领域中的模型长期以来都是被CNNs主导,如果transformer能在CV领域进⾏适配和创新,是否能为CV模型的发展开辟⼀条新的道路。果然,近期transformer⼜在CV领域杀疯了,关于transformer的视觉模型在各⼤顶会论⽂中登场,其中⼜有不少模型实现了优于CNNs的效果。
今天我们就从最原始的transformer模型⼊⼿,来带⼤家彻底认识⼀下transformer。
transformer的架构
transformer的总体架构如下图:
从上图可以看到,transformer的总体架构可以分为四个部分:输⼊、输出、编码器和解码器,以机器翻译任务为例,各个部分的组成如下:输⼊部分(橙⾊区域)包含:
1.源⽂本的嵌⼊层以及位置编码器
2.⽬标⽂本的嵌⼊层以及位置编码器
输出部分(蓝⾊区域)包含:
1.线性层
2.softmax层
编码器部分(红⾊区域):
1.由N个编码器层堆叠⽽成
2.每个编码器层由两个⼦层连接结构组成
3.第⼀个⼦层连接结构包括⼀个多头⾃注意⼒层和规范化层以及⼀个残差连接
4.第⼆个⼦层连接结构包括⼀个前馈全连接⼦层和规范化层以及⼀个残差连接
解码器部分(紫⾊区域):
1.由N个解码器层堆叠⽽成
2.每个解码器层由三个⼦层连接结构组成
3.第⼀个⼦层连接结构包括⼀个多头⾃注意⼒⼦层和规范化层以及⼀个残差连接
4.第⼆个⼦层连接结构包括⼀个多头注意⼒⼦层和规范化层以及⼀个残差连接
5.第三个⼦层连接结构包括⼀个前馈全连接⼦层和规范化层以及⼀个残差连接
输⼊部分:
⽂本嵌⼊层(Input Embedding)作⽤:⽆论是从源⽂本嵌⼊还是⽬标⽂本嵌⼊,都是为了将⽂本中的词汇的数字表⽰转变为向量表⽰,希望在这样的⾼维空间捕捉词汇间的关系。
幸福的味道
Embedding代码实现:
1# ⽂本嵌⼊层
2class Embedding(Layer):
3
4'''
5    :param vocab:词表⼤⼩
6    :param dim_model:词嵌⼊的维度
7'''
8def__init__(lf,vocab,dim_model,**kwargs):
9        lf._vocab = vocab
10        lf._dim_model = dim_model
11        super(Embedding, lf).__init__(**kwargs)
12透视
13def build(lf, input_shape, **kwargs):
14        lf.embeddings = lf.add_weight(
15            shape=(lf._vocab,lf._dim_model),
16            initializer='global_uniform',
17            name='embeddings'
18        )
19        super(Embedding, lf).build(input_shape)
20
21def call(lf, inputs):
22if K.dtype(inputs) != 'int32':
23            inputs = K.cast(inputs,'int32')
24        embeddings = K.beddings,inputs)
25        embeddings *= lf._dim_model**0.5
26return embeddings
27
28def compute_output_shape(lf, input_shape):
29return input_shape + (lf._dim_model)
位置编码层(Position Encoding)作⽤:因为在transformer编码器结构中并没有针对词汇位置信息的处理,因此需要在Embedding层后加⼊位置编码器,将词汇位置不同可能会产⽣不同语义的信息加⼊到词嵌⼊张量中,以弥补位置信息的缺失。
PE计算公式:
PE(pos,2i)=sin(pos/100002i/d model)
PE(pos,2i+1)=cos(pos/100002i/d model)
Position Encoding代码实现:
1# 位置编码层
2class PositionEncoding(Layer):
3
4'''
5    :param dim_model:词嵌⼊维度
6'''
7def__init__(lf,dim_model,**kwargs):
8        lf._dim_model = dim_model
9        super(PositionEncoding, lf).__init__(**kwargs)
10
11def call(lf, inputs, **kwargs):
12        q_length = inputs.shape[1]
13        position_encodings = np.zeros((q_length, lf._model_dim))
14for pos in range(q_length):
15for i in range(lf._model_dim):
16                position_encodings[pos, i] = pos / np.power(10000, (i - i % 2) / lf._model_dim)
满意拼音17        position_encodings[:, 0::2] = np.sin(position_encodings[:, 0::2])  # 2i
18        position_encodings[:, 1::2] = np.cos(position_encodings[:, 1::2])  # 2i+1
19        position_encodings = K.cast(position_encodings, 'float32')
20return position_encodings
21
22def compute_output_shape(lf, input_shape):
23return input_shape
Embedding和Position Encoding相加层代码实现:
1# Embeddings和Position Encodings相加层
2class Add(Layer):
3def__init__(lf,**kwargs):
4        super(Add, lf).__init__(**kwargs)
5
6def call(lf, inputs, **kwargs):
7        embeddings,positionEncodings = inputs
8return embeddings + positionEncodings
9
10def compute_output_shape(lf, input_shape):
11return input_shape[0]
编码器解码器组件实现
相关概念:
  - 掩码张量:掩代表遮掩,码就是张量中的数值,它的尺⼨不定,⾥⾯⼀般只有0 和 1 元素,代表位置被遮掩或者不被遮掩,因此它的作⽤就是让另外⼀个张量中的⼀些数值被遮掩,也可以说是被替换,它的表现形式是⼀个张量。
  - 掩码张量的作⽤:在transformer中,掩码张量的主要作⽤在应⽤attention,有⼀些⽣成的attention张量中的值计算有可能已知了未来信息⽽得到的,未来信息被看到是因为训练时会把整个输出结果都⼀次性进⾏Embedding,但是理论上解码器的输出却不是⼀次就能产⽣最终结果的,⽽是⼀次次的通过上⼀次结果综合得到的,因此,未来的信息可能被提前利⽤,这个时候就需要对未来信息进⾏遮掩。
  - Multi-Head Attention 是由多个Self-Attention 组成。从多头注意⼒的结构图中,我们看到貌似这个所谓的多头指的就是多组线性变变换层,其实并不是,这⾥其实仅使⽤了⼀组线性变换层,即三个变
换张量对Q,K,V进⾏线性变换,这些变换并不会改变原有张量的尺度,因此每个变换张量都是⽅阵,得到结果后多头作⽤才开始体现,每个头从词义层⾯分割输出张量,但是句⼦中的每个词的表⽰只取得⼀部分,也就是只分割了最后⼀维的词嵌⼊向量(words embedding)。
  - lf-attention和multi-head attention的结构如下图。在计算中需要⽤到矩阵Q(query),K(key),V(value),实际接收的输⼊是单词的表⽰向量组成的矩阵X 或上⼀个编码器的输出,Q,K,V通过将输⼊进⾏线性变换得到。
Self-Attention 层代码实现:
1# ⾃注意⼒层
2class ScaledDotProductAttention(Layer):
3def__init__(lf, masking=True, future=Fal, dropout_rate=0., **kwargs):
4        lf._masking = masking
5        lf._future = future
6        lf._dropout_rate = dropout_rate
7        lf._masking_num = -2 ** 32 + 1
8        super(ScaledDotProductAttention, lf).__init__(**kwargs)
9
10def mask(lf, inputs, masks):
11        masks = K.cast(masks, 'float32')
12        masks = K.tile(masks, [K.shape(inputs)[0] // K.shape(masks)[0], 1])
13        masks = K.expand_dims(masks, 1)
14        outputs = inputs + masks * lf._masking_num
15return outputs
16
17def future_mask(lf, inputs):
18        diag_vals = tf.ones_like(inputs[0, :, :])
19        tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_den()
20        future_masks = tf.pand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])
21        paddings = tf.ones_like(future_masks) * lf._masking_num
22        outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
23return outputs
狗熊会24
25def call(lf, inputs, **kwargs):
26if lf._masking:
27asrt len(inputs) == 4, "inputs should be t [queries, keys, values, masks]."
28            queries, keys, values, masks = inputs
29el:
30asrt len(inputs) == 3, "inputs should be t [queries, keys, values]."
31            queries, keys, values = inputs
32
33if K.dtype(queries) != 'float32':  queries = K.cast(queries, 'float32')
34if K.dtype(keys) != 'float32':  keys = K.cast(keys, 'float32')
35if K.dtype(values) != 'float32':  values = K.cast(values, 'float32')
36
37        matmul = K.batch_dot(queries, tf.transpo(keys, [0, 2, 1]))  # MatMul
38        scaled_matmul = matmul / int(queries.shape[-1]) ** 0.5  # Scale
39if lf._masking:
40            scaled_matmul = lf.mask(scaled_matmul, masks)  # Mask(opt.)
41是我想的太多
42if lf._future:
43            scaled_matmul = lf.future_mask(scaled_matmul)
44
45        softmax_out = K.softmax(scaled_matmul)  # SoftMax
46# Dropout
云南省行政区划47        out = K.dropout(softmax_out, lf._dropout_rate)
48
49        outputs = K.batch_dot(out, values)
50
51return outputs
52
53def compute_output_shape(lf, input_shape):
54return input_shape
Multi-Head Attention层代码实现:
1# 多头⾃注意⼒层
2class MultiHeadAttention(Layer):
3
4def__init__(lf, n_heads, head_dim, dropout_rate=.1, masking=True, future=Fal, trainable=True, **kwargs):
5        lf._n_heads = n_heads
6        lf._head_dim = head_dim
7        lf._dropout_rate = dropout_rate
8        lf._masking = masking
9        lf._future = future
10        lf._trainable = trainable
11        super(MultiHeadAttention, lf).__init__(**kwargs)
12
13# ⽤⽅阵做Q,K,V的权重矩阵进⾏线性变换,维度不变
14def build(lf, input_shape):
15        lf._weights_queries = lf.add_weight(
16            shape=(input_shape[0][-1], lf._n_heads * lf._head_dim),
17            initializer='glorot_uniform',
18            trainable=lf._trainable,
19            name='weights_queries')
20        lf._weights_keys = lf.add_weight(
21            shape=(input_shape[1][-1], lf._n_heads * lf._head_dim),
22            initializer='glorot_uniform',
23            trainable=lf._trainable,
24            name='weights_keys')
25        lf._weights_values = lf.add_weight(
26            shape=(input_shape[2][-1], lf._n_heads * lf._head_dim),
27            initializer='glorot_uniform',
28            trainable=lf._trainable,
29            name='weights_values')
30        super(MultiHeadAttention, lf).build(input_shape)
31
32def call(lf, inputs, **kwargs):
33if lf._masking:
34asrt len(inputs) == 4, "inputs should be t [queries, keys, values, masks]."
35            queries, keys, values, masks = inputs
36el:
37asrt len(inputs) == 3, "inputs should be t [queries, keys, values]."
38            queries, keys, values = inputs
39
40        queries_linear = K.dot(queries, lf._weights_queries)
41        keys_linear = K.dot(keys, lf._weights_keys)
42        values_linear = K.dot(values, lf._weights_values)
43
44# 将变换后的Q,K,V在embedding words的维度上进⾏切分
45        queries_multi_heads = tf.concat(tf.split(queries_linear, lf._n_heads, axis=2), axis=0)
46        keys_multi_heads = tf.concat(tf.split(keys_linear, lf._n_heads, axis=2), axis=0)
47        values_multi_heads = tf.concat(tf.split(values_linear, lf._n_heads, axis=2), axis=0) 48
49if lf._masking:
50            att_inputs = [queries_multi_heads, keys_multi_heads, values_multi_heads, masks] 51el:
52            att_inputs = [queries_multi_heads, keys_multi_heads, values_multi_heads]
53
54        attention = ScaledDotProductAttention(
55            masking=lf._masking, future=lf._future, dropout_rate=lf._dropout_rate)
56        att_out = attention(att_inputs)
57
58        outputs = tf.concat(tf.split(att_out, lf._n_heads, axis=0), axis=2)
59
60return outputs
61
62def compute_output_shape(lf, input_shape):
63return input_shape
佳能s80Position-wi Feed Forward代码实现:
1# Position-wi Feed Forward层
2# out = (relu(xW1+b1))W2+b2
3class PositionWiFeedForward(Layer):
4
5def__init__(lf, model_dim, inner_dim, trainable=True, **kwargs):
6        lf._model_dim = model_dim
7        lf._inner_dim = inner_dim
8        lf._trainable = trainable
9        super(PositionWiFeedForward, lf).__init__(**kwargs)
10
11def build(lf, input_shape):
12        lf.weights_inner = lf.add_weight(
13            shape=(input_shape[-1], lf._inner_dim),
14            initializer='glorot_uniform',
15            trainable=lf._trainable,
16            name="weights_inner")
17        lf.weights_out = lf.add_weight(
18            shape=(lf._inner_dim, lf._model_dim),
19            initializer='glorot_uniform',
20            trainable=lf._trainable,
21            name="weights_out")
22        lf.bais_inner = lf.add_weight(
23            shape=(lf._inner_dim,),
24            initializer='uniform',
25            trainable=lf._trainable,
26            name="bais_inner")
27        lf.bais_out = lf.add_weight(
28            shape=(lf._model_dim,),
青年精神29            initializer='uniform',
30            trainable=lf._trainable,
31            name="bais_out")
32        super(PositionWiFeedForward, lf).build(input_shape)
33
34def call(lf, inputs, **kwargs):
35if K.dtype(inputs) != 'float32':
36            inputs = K.cast(inputs, 'float32')
37        inner_out = K.relu(K.dot(inputs, lf.weights_inner) + lf.bais_inner)
38        outputs = K.dot(inner_out, lf.weights_out) + lf.bais_out
39return outputs
40

本文发布于:2023-07-30 13:36:52,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/82/1123771.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:模型   位置   信息
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图