123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409 |
- # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import math
- import paddle
- from paddle import ParamAttr, nn
- from paddle import nn, ParamAttr
- from paddle.nn import functional as F
- import paddle.fluid as fluid
- import numpy as np
- gradient_clip = 10
- class WrapEncoderForFeature(nn.Layer):
- def __init__(self,
- src_vocab_size,
- max_length,
- n_layer,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- preprocess_cmd,
- postprocess_cmd,
- weight_sharing,
- bos_idx=0):
- super(WrapEncoderForFeature, self).__init__()
- self.prepare_encoder = PrepareEncoder(
- src_vocab_size,
- d_model,
- max_length,
- prepostprocess_dropout,
- bos_idx=bos_idx,
- word_emb_param_name="src_word_emb_table")
- self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
- d_inner_hid, prepostprocess_dropout,
- attention_dropout, relu_dropout, preprocess_cmd,
- postprocess_cmd)
- def forward(self, enc_inputs):
- conv_features, src_pos, src_slf_attn_bias = enc_inputs
- enc_input = self.prepare_encoder(conv_features, src_pos)
- enc_output = self.encoder(enc_input, src_slf_attn_bias)
- return enc_output
- class WrapEncoder(nn.Layer):
- """
- embedder + encoder
- """
- def __init__(self,
- src_vocab_size,
- max_length,
- n_layer,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- preprocess_cmd,
- postprocess_cmd,
- weight_sharing,
- bos_idx=0):
- super(WrapEncoder, self).__init__()
- self.prepare_decoder = PrepareDecoder(
- src_vocab_size,
- d_model,
- max_length,
- prepostprocess_dropout,
- bos_idx=bos_idx)
- self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
- d_inner_hid, prepostprocess_dropout,
- attention_dropout, relu_dropout, preprocess_cmd,
- postprocess_cmd)
- def forward(self, enc_inputs):
- src_word, src_pos, src_slf_attn_bias = enc_inputs
- enc_input = self.prepare_decoder(src_word, src_pos)
- enc_output = self.encoder(enc_input, src_slf_attn_bias)
- return enc_output
- class Encoder(nn.Layer):
- """
- encoder
- """
- def __init__(self,
- n_layer,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- preprocess_cmd="n",
- postprocess_cmd="da"):
- super(Encoder, self).__init__()
- self.encoder_layers = list()
- for i in range(n_layer):
- self.encoder_layers.append(
- self.add_sublayer(
- "layer_%d" % i,
- EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
- prepostprocess_dropout, attention_dropout,
- relu_dropout, preprocess_cmd,
- postprocess_cmd)))
- self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
- prepostprocess_dropout)
- def forward(self, enc_input, attn_bias):
- for encoder_layer in self.encoder_layers:
- enc_output = encoder_layer(enc_input, attn_bias)
- enc_input = enc_output
- enc_output = self.processer(enc_output)
- return enc_output
- class EncoderLayer(nn.Layer):
- """
- EncoderLayer
- """
- def __init__(self,
- n_head,
- d_key,
- d_value,
- d_model,
- d_inner_hid,
- prepostprocess_dropout,
- attention_dropout,
- relu_dropout,
- preprocess_cmd="n",
- postprocess_cmd="da"):
- super(EncoderLayer, self).__init__()
- self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
- prepostprocess_dropout)
- self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
- attention_dropout)
- self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
- prepostprocess_dropout)
- self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
- prepostprocess_dropout)
- self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
- self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
- prepostprocess_dropout)
- def forward(self, enc_input, attn_bias):
- attn_output = self.self_attn(
- self.preprocesser1(enc_input), None, None, attn_bias)
- attn_output = self.postprocesser1(attn_output, enc_input)
- ffn_output = self.ffn(self.preprocesser2(attn_output))
- ffn_output = self.postprocesser2(ffn_output, attn_output)
- return ffn_output
- class MultiHeadAttention(nn.Layer):
- """
- Multi-Head Attention
- """
- def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
- super(MultiHeadAttention, self).__init__()
- self.n_head = n_head
- self.d_key = d_key
- self.d_value = d_value
- self.d_model = d_model
- self.dropout_rate = dropout_rate
- self.q_fc = paddle.nn.Linear(
- in_features=d_model, out_features=d_key * n_head, bias_attr=False)
- self.k_fc = paddle.nn.Linear(
- in_features=d_model, out_features=d_key * n_head, bias_attr=False)
- self.v_fc = paddle.nn.Linear(
- in_features=d_model, out_features=d_value * n_head, bias_attr=False)
- self.proj_fc = paddle.nn.Linear(
- in_features=d_value * n_head, out_features=d_model, bias_attr=False)
- def _prepare_qkv(self, queries, keys, values, cache=None):
- if keys is None: # self-attention
- keys, values = queries, queries
- static_kv = False
- else: # cross-attention
- static_kv = True
- q = self.q_fc(queries)
- q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
- q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
- if cache is not None and static_kv and "static_k" in cache:
- # for encoder-decoder attention in inference and has cached
- k = cache["static_k"]
- v = cache["static_v"]
- else:
- k = self.k_fc(keys)
- v = self.v_fc(values)
- k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
- k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
- v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
- v = paddle.transpose(x=v, perm=[0, 2, 1, 3])
- if cache is not None:
- if static_kv and not "static_k" in cache:
- # for encoder-decoder attention in inference and has not cached
- cache["static_k"], cache["static_v"] = k, v
- elif not static_kv:
- # for decoder self-attention in inference
- cache_k, cache_v = cache["k"], cache["v"]
- k = paddle.concat([cache_k, k], axis=2)
- v = paddle.concat([cache_v, v], axis=2)
- cache["k"], cache["v"] = k, v
- return q, k, v
- def forward(self, queries, keys, values, attn_bias, cache=None):
- # compute q ,k ,v
- keys = queries if keys is None else keys
- values = keys if values is None else values
- q, k, v = self._prepare_qkv(queries, keys, values, cache)
- # scale dot product attention
- product = paddle.matmul(x=q, y=k, transpose_y=True)
- product = product * self.d_model**-0.5
- if attn_bias is not None:
- product += attn_bias
- weights = F.softmax(product)
- if self.dropout_rate:
- weights = F.dropout(
- weights, p=self.dropout_rate, mode="downscale_in_infer")
- out = paddle.matmul(weights, v)
- # combine heads
- out = paddle.transpose(out, perm=[0, 2, 1, 3])
- out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
- # project to output
- out = self.proj_fc(out)
- return out
- class PrePostProcessLayer(nn.Layer):
- """
- PrePostProcessLayer
- """
- def __init__(self, process_cmd, d_model, dropout_rate):
- super(PrePostProcessLayer, self).__init__()
- self.process_cmd = process_cmd
- self.functors = []
- for cmd in self.process_cmd:
- if cmd == "a": # add residual connection
- self.functors.append(lambda x, y: x + y if y is not None else x)
- elif cmd == "n": # add layer normalization
- self.functors.append(
- self.add_sublayer(
- "layer_norm_%d" % len(
- self.sublayers(include_sublayers=False)),
- paddle.nn.LayerNorm(
- normalized_shape=d_model,
- weight_attr=fluid.ParamAttr(
- initializer=fluid.initializer.Constant(1.)),
- bias_attr=fluid.ParamAttr(
- initializer=fluid.initializer.Constant(0.)))))
- elif cmd == "d": # add dropout
- self.functors.append(lambda x: F.dropout(
- x, p=dropout_rate, mode="downscale_in_infer")
- if dropout_rate else x)
- def forward(self, x, residual=None):
- for i, cmd in enumerate(self.process_cmd):
- if cmd == "a":
- x = self.functors[i](x, residual)
- else:
- x = self.functors[i](x)
- return x
- class PrepareEncoder(nn.Layer):
- def __init__(self,
- src_vocab_size,
- src_emb_dim,
- src_max_len,
- dropout_rate=0,
- bos_idx=0,
- word_emb_param_name=None,
- pos_enc_param_name=None):
- super(PrepareEncoder, self).__init__()
- self.src_emb_dim = src_emb_dim
- self.src_max_len = src_max_len
- self.emb = paddle.nn.Embedding(
- num_embeddings=self.src_max_len,
- embedding_dim=self.src_emb_dim,
- sparse=True)
- self.dropout_rate = dropout_rate
- def forward(self, src_word, src_pos):
- src_word_emb = src_word
- src_word_emb = fluid.layers.cast(src_word_emb, 'float32')
- src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
- src_pos = paddle.squeeze(src_pos, axis=-1)
- src_pos_enc = self.emb(src_pos)
- src_pos_enc.stop_gradient = True
- enc_input = src_word_emb + src_pos_enc
- if self.dropout_rate:
- out = F.dropout(
- x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
- else:
- out = enc_input
- return out
- class PrepareDecoder(nn.Layer):
- def __init__(self,
- src_vocab_size,
- src_emb_dim,
- src_max_len,
- dropout_rate=0,
- bos_idx=0,
- word_emb_param_name=None,
- pos_enc_param_name=None):
- super(PrepareDecoder, self).__init__()
- self.src_emb_dim = src_emb_dim
- """
- self.emb0 = Embedding(num_embeddings=src_vocab_size,
- embedding_dim=src_emb_dim)
- """
- self.emb0 = paddle.nn.Embedding(
- num_embeddings=src_vocab_size,
- embedding_dim=self.src_emb_dim,
- padding_idx=bos_idx,
- weight_attr=paddle.ParamAttr(
- name=word_emb_param_name,
- initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
- self.emb1 = paddle.nn.Embedding(
- num_embeddings=src_max_len,
- embedding_dim=self.src_emb_dim,
- weight_attr=paddle.ParamAttr(name=pos_enc_param_name))
- self.dropout_rate = dropout_rate
- def forward(self, src_word, src_pos):
- src_word = fluid.layers.cast(src_word, 'int64')
- src_word = paddle.squeeze(src_word, axis=-1)
- src_word_emb = self.emb0(src_word)
- src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
- src_pos = paddle.squeeze(src_pos, axis=-1)
- src_pos_enc = self.emb1(src_pos)
- src_pos_enc.stop_gradient = True
- enc_input = src_word_emb + src_pos_enc
- if self.dropout_rate:
- out = F.dropout(
- x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
- else:
- out = enc_input
- return out
- class FFN(nn.Layer):
- """
- Feed-Forward Network
- """
- def __init__(self, d_inner_hid, d_model, dropout_rate):
- super(FFN, self).__init__()
- self.dropout_rate = dropout_rate
- self.fc1 = paddle.nn.Linear(
- in_features=d_model, out_features=d_inner_hid)
- self.fc2 = paddle.nn.Linear(
- in_features=d_inner_hid, out_features=d_model)
- def forward(self, x):
- hidden = self.fc1(x)
- hidden = F.relu(hidden)
- if self.dropout_rate:
- hidden = F.dropout(
- hidden, p=self.dropout_rate, mode="downscale_in_infer")
- out = self.fc2(hidden)
- return out
|