self_attention.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import ParamAttr, nn
from paddle.nn import functional as F
import paddle.fluid as fluid
import numpy as np

gradient_clip = 10


class WrapEncoderForFeature(nn.Layer):
    def __init__(self,
                 src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd,
                 postprocess_cmd,
                 weight_sharing,
                 bos_idx=0):
        super(WrapEncoderForFeature, self).__init__()
        self.prepare_encoder = PrepareEncoder(
            src_vocab_size,
            d_model,
            max_length,
            prepostprocess_dropout,
            bos_idx=bos_idx,
            word_emb_param_name="src_word_emb_table")
        self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
                               d_inner_hid, prepostprocess_dropout,
                               attention_dropout, relu_dropout, preprocess_cmd,
                               postprocess_cmd)

    def forward(self, enc_inputs):
        conv_features, src_pos, src_slf_attn_bias = enc_inputs
        enc_input = self.prepare_encoder(conv_features, src_pos)
        enc_output = self.encoder(enc_input, src_slf_attn_bias)
        return enc_output


class WrapEncoder(nn.Layer):
    """
    embedder + encoder
    """

    def __init__(self,
                 src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd,
                 postprocess_cmd,
                 weight_sharing,
                 bos_idx=0):
        super(WrapEncoder, self).__init__()
        self.prepare_decoder = PrepareDecoder(
            src_vocab_size,
            d_model,
            max_length,
            prepostprocess_dropout,
            bos_idx=bos_idx)
        self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
                               d_inner_hid, prepostprocess_dropout,
                               attention_dropout, relu_dropout, preprocess_cmd,
                               postprocess_cmd)

    def forward(self, enc_inputs):
        src_word, src_pos, src_slf_attn_bias = enc_inputs
        enc_input = self.prepare_decoder(src_word, src_pos)
        enc_output = self.encoder(enc_input, src_slf_attn_bias)
        return enc_output


class Encoder(nn.Layer):
    """
    encoder
    """

    def __init__(self,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd="n",
                 postprocess_cmd="da"):
        super(Encoder, self).__init__()
        self.encoder_layers = list()
        for i in range(n_layer):
            self.encoder_layers.append(
                self.add_sublayer(
                    "layer_%d" % i,
                    EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
                                 prepostprocess_dropout, attention_dropout,
                                 relu_dropout, preprocess_cmd,
                                 postprocess_cmd)))
        self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
                                             prepostprocess_dropout)

    def forward(self, enc_input, attn_bias):
        for encoder_layer in self.encoder_layers:
            enc_output = encoder_layer(enc_input, attn_bias)
            enc_input = enc_output
        enc_output = self.processer(enc_output)
        return enc_output


class EncoderLayer(nn.Layer):
    """
    EncoderLayer
    """

    def __init__(self,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd="n",
                 postprocess_cmd="da"):
        super(EncoderLayer, self).__init__()
        self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
                                                 prepostprocess_dropout)
        self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
                                            attention_dropout)
        self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
                                                  prepostprocess_dropout)
        self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
                                                 prepostprocess_dropout)
        self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
        self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
                                                  prepostprocess_dropout)

    def forward(self, enc_input, attn_bias):
        attn_output = self.self_attn(
            self.preprocesser1(enc_input), None, None, attn_bias)
        attn_output = self.postprocesser1(attn_output, enc_input)
        ffn_output = self.ffn(self.preprocesser2(attn_output))
        ffn_output = self.postprocesser2(ffn_output, attn_output)
        return ffn_output


class MultiHeadAttention(nn.Layer):
    """
    Multi-Head Attention
    """

    def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_key = d_key
        self.d_value = d_value
        self.d_model = d_model
        self.dropout_rate = dropout_rate
        self.q_fc = paddle.nn.Linear(
            in_features=d_model, out_features=d_key * n_head, bias_attr=False)
        self.k_fc = paddle.nn.Linear(
            in_features=d_model, out_features=d_key * n_head, bias_attr=False)
        self.v_fc = paddle.nn.Linear(
            in_features=d_model, out_features=d_value * n_head, bias_attr=False)
        self.proj_fc = paddle.nn.Linear(
            in_features=d_value * n_head, out_features=d_model, bias_attr=False)

    def _prepare_qkv(self, queries, keys, values, cache=None):
        if keys is None:  # self-attention
            keys, values = queries, queries
            static_kv = False
        else:  # cross-attention
            static_kv = True

        q = self.q_fc(queries)
        q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
        q = paddle.transpose(x=q, perm=[0, 2, 1, 3])

        if cache is not None and static_kv and "static_k" in cache:
            # encoder-decoder attention at inference, k/v already cached
            k = cache["static_k"]
            v = cache["static_v"]
        else:
            k = self.k_fc(keys)
            v = self.v_fc(values)
            k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
            k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
            v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
            v = paddle.transpose(x=v, perm=[0, 2, 1, 3])

        if cache is not None:
            if static_kv and "static_k" not in cache:
                # encoder-decoder attention at inference, not yet cached
                cache["static_k"], cache["static_v"] = k, v
            elif not static_kv:
                # decoder self-attention at inference
                cache_k, cache_v = cache["k"], cache["v"]
                k = paddle.concat([cache_k, k], axis=2)
                v = paddle.concat([cache_v, v], axis=2)
                cache["k"], cache["v"] = k, v

        return q, k, v

    def forward(self, queries, keys, values, attn_bias, cache=None):
        # compute q, k, v
        keys = queries if keys is None else keys
        values = keys if values is None else values
        q, k, v = self._prepare_qkv(queries, keys, values, cache)

        # scaled dot-product attention
        product = paddle.matmul(x=q, y=k, transpose_y=True)
        product = product * self.d_model**-0.5
        if attn_bias is not None:
            product += attn_bias
        weights = F.softmax(product)
        if self.dropout_rate:
            weights = F.dropout(
                weights, p=self.dropout_rate, mode="downscale_in_infer")
        out = paddle.matmul(weights, v)

        # combine heads
        out = paddle.transpose(out, perm=[0, 2, 1, 3])
        out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.proj_fc(out)
        return out
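
# Note on the `cache` argument above: during autoregressive inference the
# caller may pass a dict. For cross-attention, the projected keys/values are
# stored once under "static_k"/"static_v" and reused on later steps; for
# decoder self-attention, "k"/"v" hold the keys/values of previous steps and
# the current step is concatenated along the sequence axis (axis=2).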


class PrePostProcessLayer(nn.Layer):
    """
    PrePostProcessLayer
    """

    def __init__(self, process_cmd, d_model, dropout_rate):
        super(PrePostProcessLayer, self).__init__()
        self.process_cmd = process_cmd
        self.functors = []
        for cmd in self.process_cmd:
            if cmd == "a":  # add residual connection
                self.functors.append(lambda x, y: x + y if y is not None else x)
            elif cmd == "n":  # add layer normalization
                self.functors.append(
                    self.add_sublayer(
                        "layer_norm_%d" % len(
                            self.sublayers(include_sublayers=False)),
                        paddle.nn.LayerNorm(
                            normalized_shape=d_model,
                            weight_attr=fluid.ParamAttr(
                                initializer=fluid.initializer.Constant(1.)),
                            bias_attr=fluid.ParamAttr(
                                initializer=fluid.initializer.Constant(0.)))))
            elif cmd == "d":  # add dropout
                self.functors.append(lambda x: F.dropout(
                    x, p=dropout_rate, mode="downscale_in_infer")
                    if dropout_rate else x)

    def forward(self, x, residual=None):
        for i, cmd in enumerate(self.process_cmd):
            if cmd == "a":
                x = self.functors[i](x, residual)
            else:
                x = self.functors[i](x)
        return x
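
# The process_cmd string is a small recipe applied character by character:
# "n" = layer normalization, "a" = residual add, "d" = dropout. With the
# defaults used in this file (preprocess_cmd="n", postprocess_cmd="da"), each
# sub-layer runs pre-norm: LayerNorm -> sub-layer -> dropout -> residual add.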


class PrepareEncoder(nn.Layer):
    def __init__(self,
                 src_vocab_size,
                 src_emb_dim,
                 src_max_len,
                 dropout_rate=0,
                 bos_idx=0,
                 word_emb_param_name=None,
                 pos_enc_param_name=None):
        super(PrepareEncoder, self).__init__()
        self.src_emb_dim = src_emb_dim
        self.src_max_len = src_max_len
        self.emb = paddle.nn.Embedding(
            num_embeddings=self.src_max_len,
            embedding_dim=self.src_emb_dim,
            sparse=True)
        self.dropout_rate = dropout_rate

    def forward(self, src_word, src_pos):
        src_word_emb = src_word
        src_word_emb = fluid.layers.cast(src_word_emb, 'float32')
        src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
        src_pos = paddle.squeeze(src_pos, axis=-1)
        src_pos_enc = self.emb(src_pos)
        src_pos_enc.stop_gradient = True
        enc_input = src_word_emb + src_pos_enc
        if self.dropout_rate:
            out = F.dropout(
                x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
        else:
            out = enc_input
        return out


class PrepareDecoder(nn.Layer):
    def __init__(self,
                 src_vocab_size,
                 src_emb_dim,
                 src_max_len,
                 dropout_rate=0,
                 bos_idx=0,
                 word_emb_param_name=None,
                 pos_enc_param_name=None):
        super(PrepareDecoder, self).__init__()
        self.src_emb_dim = src_emb_dim
        """
        self.emb0 = Embedding(num_embeddings=src_vocab_size,
                              embedding_dim=src_emb_dim)
        """
        self.emb0 = paddle.nn.Embedding(
            num_embeddings=src_vocab_size,
            embedding_dim=self.src_emb_dim,
            padding_idx=bos_idx,
            weight_attr=paddle.ParamAttr(
                name=word_emb_param_name,
                initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
        self.emb1 = paddle.nn.Embedding(
            num_embeddings=src_max_len,
            embedding_dim=self.src_emb_dim,
            weight_attr=paddle.ParamAttr(name=pos_enc_param_name))
        self.dropout_rate = dropout_rate

    def forward(self, src_word, src_pos):
        src_word = fluid.layers.cast(src_word, 'int64')
        src_word = paddle.squeeze(src_word, axis=-1)
        src_word_emb = self.emb0(src_word)
        src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
        src_pos = paddle.squeeze(src_pos, axis=-1)
        src_pos_enc = self.emb1(src_pos)
        src_pos_enc.stop_gradient = True
        enc_input = src_word_emb + src_pos_enc
        if self.dropout_rate:
            out = F.dropout(
                x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
        else:
            out = enc_input
        return out


class FFN(nn.Layer):
    """
    Feed-Forward Network
    """

    def __init__(self, d_inner_hid, d_model, dropout_rate):
        super(FFN, self).__init__()
        self.dropout_rate = dropout_rate
        self.fc1 = paddle.nn.Linear(
            in_features=d_model, out_features=d_inner_hid)
        self.fc2 = paddle.nn.Linear(
            in_features=d_inner_hid, out_features=d_model)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = F.relu(hidden)
        if self.dropout_rate:
            hidden = F.dropout(
                hidden, p=self.dropout_rate, mode="downscale_in_infer")
        out = self.fc2(hidden)
        return out
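

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: builds a small
    # WrapEncoderForFeature and runs one forward pass on random inputs. All
    # dimensions below (batch, sequence length, d_model, n_head, vocab size)
    # are illustrative assumptions, not values required by this file.
    batch, seq_len, d_model, n_head = 2, 25, 512, 8
    encoder = WrapEncoderForFeature(
        src_vocab_size=37,
        max_length=seq_len,
        n_layer=2,
        n_head=n_head,
        d_key=d_model // n_head,
        d_value=d_model // n_head,
        d_model=d_model,
        d_inner_hid=d_model * 4,
        prepostprocess_dropout=0.1,
        attention_dropout=0.1,
        relu_dropout=0.1,
        preprocess_cmd="n",
        postprocess_cmd="da",
        weight_sharing=True)
    # conv_features stand in for the backbone output: [batch, seq_len, d_model]
    conv_features = paddle.rand([batch, seq_len, d_model])
    # position indices shaped [batch, seq_len, 1], squeezed inside PrepareEncoder
    src_pos = paddle.arange(seq_len, dtype="int64")
    src_pos = src_pos.reshape([1, seq_len, 1]).tile([batch, 1, 1])
    # zero bias == no masking; shape matches the attention logits
    attn_bias = paddle.zeros([batch, n_head, seq_len, seq_len])
    enc_out = encoder((conv_features, src_pos, attn_bias))
    print(enc_out.shape)  # expected: [batch, seq_len, d_model]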