det_basic_loss.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. from paddle import nn
  20. import paddle.nn.functional as F
  21. class BalanceLoss(nn.Layer):
  22. def __init__(self,
  23. balance_loss=True,
  24. main_loss_type='DiceLoss',
  25. negative_ratio=3,
  26. return_origin=False,
  27. eps=1e-6,
  28. **kwargs):
  29. """
  30. The BalanceLoss for Differentiable Binarization text detection
  31. args:
  32. balance_loss (bool): whether balance loss or not, default is True
  33. main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
  34. 'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
  35. negative_ratio (int|float): float, default is 3.
  36. return_origin (bool): whether return unbalanced loss or not, default is False.
  37. eps (float): default is 1e-6.
  38. """
  39. super(BalanceLoss, self).__init__()
  40. self.balance_loss = balance_loss
  41. self.main_loss_type = main_loss_type
  42. self.negative_ratio = negative_ratio
  43. self.return_origin = return_origin
  44. self.eps = eps
  45. if self.main_loss_type == "CrossEntropy":
  46. self.loss = nn.CrossEntropyLoss()
  47. elif self.main_loss_type == "Euclidean":
  48. self.loss = nn.MSELoss()
  49. elif self.main_loss_type == "DiceLoss":
  50. self.loss = DiceLoss(self.eps)
  51. elif self.main_loss_type == "BCELoss":
  52. self.loss = BCELoss(reduction='none')
  53. elif self.main_loss_type == "MaskL1Loss":
  54. self.loss = MaskL1Loss(self.eps)
  55. else:
  56. loss_type = [
  57. 'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
  58. ]
  59. raise Exception(
  60. "main_loss_type in BalanceLoss() can only be one of {}".format(
  61. loss_type))
  62. def forward(self, pred, gt, mask=None):
  63. """
  64. The BalanceLoss for Differentiable Binarization text detection
  65. args:
  66. pred (variable): predicted feature maps.
  67. gt (variable): ground truth feature maps.
  68. mask (variable): masked maps.
  69. return: (variable) balanced loss
  70. """
  71. # if self.main_loss_type in ['DiceLoss']:
  72. # # For the loss that returns to scalar value, perform ohem on the mask
  73. # mask = ohem_batch(pred, gt, mask, self.negative_ratio)
  74. # loss = self.loss(pred, gt, mask)
  75. # return loss
  76. positive = gt * mask
  77. negative = (1 - gt) * mask
  78. positive_count = int(positive.sum())
  79. negative_count = int(
  80. min(negative.sum(), positive_count * self.negative_ratio))
  81. loss = self.loss(pred, gt, mask=mask)
  82. if not self.balance_loss:
  83. return loss
  84. positive_loss = positive * loss
  85. negative_loss = negative * loss
  86. negative_loss = paddle.reshape(negative_loss, shape=[-1])
  87. if negative_count > 0:
  88. sort_loss = negative_loss.sort(descending=True)
  89. negative_loss = sort_loss[:negative_count]
  90. # negative_loss, _ = paddle.topk(negative_loss, k=negative_count_int)
  91. balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
  92. positive_count + negative_count + self.eps)
  93. else:
  94. balance_loss = positive_loss.sum() / (positive_count + self.eps)
  95. if self.return_origin:
  96. return balance_loss, loss
  97. return balance_loss
  98. class DiceLoss(nn.Layer):
  99. def __init__(self, eps=1e-6):
  100. super(DiceLoss, self).__init__()
  101. self.eps = eps
  102. def forward(self, pred, gt, mask, weights=None):
  103. """
  104. DiceLoss function.
  105. """
  106. assert pred.shape == gt.shape
  107. assert pred.shape == mask.shape
  108. if weights is not None:
  109. assert weights.shape == mask.shape
  110. mask = weights * mask
  111. intersection = paddle.sum(pred * gt * mask)
  112. union = paddle.sum(pred * mask) + paddle.sum(gt * mask) + self.eps
  113. loss = 1 - 2.0 * intersection / union
  114. assert loss <= 1
  115. return loss
  116. class MaskL1Loss(nn.Layer):
  117. def __init__(self, eps=1e-6):
  118. super(MaskL1Loss, self).__init__()
  119. self.eps = eps
  120. def forward(self, pred, gt, mask):
  121. """
  122. Mask L1 Loss
  123. """
  124. loss = (paddle.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
  125. loss = paddle.mean(loss)
  126. return loss
  127. class BCELoss(nn.Layer):
  128. def __init__(self, reduction='mean'):
  129. super(BCELoss, self).__init__()
  130. self.reduction = reduction
  131. def forward(self, input, label, mask=None, weight=None, name=None):
  132. loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
  133. return loss
  134. def ohem_single(score, gt_text, training_mask, ohem_ratio):
  135. pos_num = (int)(np.sum(gt_text > 0.5)) - (
  136. int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))
  137. if pos_num == 0:
  138. # selected_mask = gt_text.copy() * 0 # may be not good
  139. selected_mask = training_mask
  140. selected_mask = selected_mask.reshape(
  141. 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
  142. return selected_mask
  143. neg_num = (int)(np.sum(gt_text <= 0.5))
  144. neg_num = (int)(min(pos_num * ohem_ratio, neg_num))
  145. if neg_num == 0:
  146. selected_mask = training_mask
  147. selected_mask = selected_mask.reshape(
  148. 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
  149. return selected_mask
  150. neg_score = score[gt_text <= 0.5]
  151. # 将负样本得分从高到低排序
  152. neg_score_sorted = np.sort(-neg_score)
  153. threshold = -neg_score_sorted[neg_num - 1]
  154. # 选出 得分高的 负样本 和正样本 的 mask
  155. selected_mask = ((score >= threshold) |
  156. (gt_text > 0.5)) & (training_mask > 0.5)
  157. selected_mask = selected_mask.reshape(
  158. 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
  159. return selected_mask
  160. def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
  161. scores = scores.numpy()
  162. gt_texts = gt_texts.numpy()
  163. training_masks = training_masks.numpy()
  164. selected_masks = []
  165. for i in range(scores.shape[0]):
  166. selected_masks.append(
  167. ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
  168. i, :, :], ohem_ratio))
  169. selected_masks = np.concatenate(selected_masks, 0)
  170. selected_masks = paddle.to_variable(selected_masks)
  171. return selected_masks