# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import paddle
from paddle import nn
import paddle.nn.functional as F

class BalanceLoss(nn.Layer):
    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 negative_ratio=3,
                 return_origin=False,
                 eps=1e-6,
                 **kwargs):
        """
        The BalanceLoss for Differentiable Binarization (DB) text detection.
        args:
            balance_loss (bool): whether to balance the loss with hard-negative
                mining, default is True.
            main_loss_type (str): one of ['CrossEntropy', 'DiceLoss',
                'Euclidean', 'BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
            negative_ratio (int|float): maximum ratio of negative to positive
                samples, default is 3.
            return_origin (bool): whether to also return the unbalanced loss,
                default is False.
            eps (float): small constant to avoid division by zero, default is 1e-6.
        """
        super(BalanceLoss, self).__init__()
        self.balance_loss = balance_loss
        self.main_loss_type = main_loss_type
        self.negative_ratio = negative_ratio
        self.return_origin = return_origin
        self.eps = eps

        if self.main_loss_type == "CrossEntropy":
            self.loss = nn.CrossEntropyLoss()
        elif self.main_loss_type == "Euclidean":
            self.loss = nn.MSELoss()
        elif self.main_loss_type == "DiceLoss":
            self.loss = DiceLoss(self.eps)
        elif self.main_loss_type == "BCELoss":
            self.loss = BCELoss(reduction='none')
        elif self.main_loss_type == "MaskL1Loss":
            self.loss = MaskL1Loss(self.eps)
        else:
            loss_type = [
                'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
            ]
            raise Exception(
                "main_loss_type in BalanceLoss() can only be one of {}".format(
                    loss_type))

    def forward(self, pred, gt, mask=None):
        """
        Balanced loss for Differentiable Binarization text detection.
        args:
            pred (Tensor): predicted feature maps.
            gt (Tensor): ground truth feature maps.
            mask (Tensor): mask of valid regions.
        return: (Tensor) balanced loss
        """
        # if self.main_loss_type in ['DiceLoss']:
        #     # For losses that return a scalar value, perform OHEM on the mask
        #     mask = ohem_batch(pred, gt, mask, self.negative_ratio)
        #     loss = self.loss(pred, gt, mask)
        #     return loss

        positive = gt * mask
        negative = (1 - gt) * mask

        positive_count = int(positive.sum())
        negative_count = int(
            min(negative.sum(), positive_count * self.negative_ratio))
        loss = self.loss(pred, gt, mask=mask)

        if not self.balance_loss:
            return loss

        positive_loss = positive * loss
        negative_loss = negative * loss
        negative_loss = paddle.reshape(negative_loss, shape=[-1])
        if negative_count > 0:
            # keep only the hardest (largest-loss) negatives
            sort_loss = negative_loss.sort(descending=True)
            negative_loss = sort_loss[:negative_count]
            # negative_loss, _ = paddle.topk(negative_loss, k=negative_count)
            balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
                positive_count + negative_count + self.eps)
        else:
            balance_loss = positive_loss.sum() / (positive_count + self.eps)
        if self.return_origin:
            return balance_loss, loss

        return balance_loss
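
# Minimal usage sketch of BalanceLoss (illustrative only; the shapes and the
# random inputs below are assumptions, not values used by the training code).
# With 'BCELoss' the wrapped loss is element-wise, so the hard-negative
# selection over the flattened loss map in forward() applies directly.
def _balance_loss_example():
    pred = F.sigmoid(paddle.randn([1, 640, 640]))
    gt = paddle.cast(paddle.rand([1, 640, 640]) > 0.9, 'float32')
    mask = paddle.ones([1, 640, 640], dtype='float32')
    loss_fn = BalanceLoss(balance_loss=True, main_loss_type='BCELoss')
    return loss_fn(pred, gt, mask)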

class DiceLoss(nn.Layer):
    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, gt, mask, weights=None):
        """
        Dice loss on the masked region: 1 - 2 * |pred * gt| / (|pred| + |gt|).
        """
        assert pred.shape == gt.shape
        assert pred.shape == mask.shape
        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask

        intersection = paddle.sum(pred * gt * mask)
        union = paddle.sum(pred * mask) + paddle.sum(gt * mask) + self.eps
        loss = 1 - 2.0 * intersection / union
        assert loss <= 1
        return loss
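
# Toy check of DiceLoss on assumed inputs (illustrative only): intersection is
# 0.9 + 0.8 = 1.7, union is 2.0 + 2.0 + eps, so the loss is about 0.15.
def _dice_loss_example():
    pred = paddle.to_tensor([[0.9, 0.1], [0.8, 0.2]])
    gt = paddle.to_tensor([[1.0, 0.0], [1.0, 0.0]])
    mask = paddle.ones([2, 2], dtype='float32')
    return DiceLoss()(pred, gt, mask)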

class MaskL1Loss(nn.Layer):
    def __init__(self, eps=1e-6):
        super(MaskL1Loss, self).__init__()
        self.eps = eps

    def forward(self, pred, gt, mask):
        """
        Mask L1 loss: mean absolute error over the masked region.
        """
        loss = (paddle.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
        loss = paddle.mean(loss)
        return loss
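
# Toy check of MaskL1Loss on assumed inputs (illustrative only): the masked-out
# second column contributes nothing, so |pred - gt| * mask sums to 0.5, the
# mask sums to 2, and the loss is about 0.25.
def _mask_l1_loss_example():
    pred = paddle.to_tensor([[0.5, 3.0], [1.0, -2.0]])
    gt = paddle.to_tensor([[1.0, 0.0], [1.0, 0.0]])
    mask = paddle.to_tensor([[1.0, 0.0], [1.0, 0.0]])
    return MaskL1Loss()(pred, gt, mask)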

class BCELoss(nn.Layer):
    def __init__(self, reduction='mean'):
        super(BCELoss, self).__init__()
        self.reduction = reduction

    def forward(self, input, label, mask=None, weight=None, name=None):
        loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
        return loss

def ohem_single(score, gt_text, training_mask, ohem_ratio):
    pos_num = int(np.sum(gt_text > 0.5)) - int(
        np.sum((gt_text > 0.5) & (training_mask <= 0.5)))

    if pos_num == 0:
        # selected_mask = gt_text.copy() * 0 # may be not good
        selected_mask = training_mask
        selected_mask = selected_mask.reshape(
            1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
        return selected_mask

    neg_num = int(np.sum(gt_text <= 0.5))
    neg_num = int(min(pos_num * ohem_ratio, neg_num))

    if neg_num == 0:
        selected_mask = training_mask
        selected_mask = selected_mask.reshape(
            1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
        return selected_mask

    neg_score = score[gt_text <= 0.5]
    # sort negative-sample scores from high to low
    neg_score_sorted = np.sort(-neg_score)
    threshold = -neg_score_sorted[neg_num - 1]
    # keep all positives plus the highest-scoring negatives within the valid mask
    selected_mask = ((score >= threshold) |
                     (gt_text > 0.5)) & (training_mask > 0.5)
    selected_mask = selected_mask.reshape(
        1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
    return selected_mask

def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
    scores = scores.numpy()
    gt_texts = gt_texts.numpy()
    training_masks = training_masks.numpy()

    selected_masks = []
    for i in range(scores.shape[0]):
        selected_masks.append(
            ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
                i, :, :], ohem_ratio))

    selected_masks = np.concatenate(selected_masks, 0)
    selected_masks = paddle.to_tensor(selected_masks)
    return selected_masks
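
# Minimal sketch of how ohem_batch is expected to be called (the shapes are
# assumptions; scores/gt_texts/training_masks are NxHxW probability, label and
# mask tensors in DB-style training).
def _ohem_batch_example():
    scores = paddle.rand([2, 64, 64])
    gt_texts = paddle.cast(paddle.rand([2, 64, 64]) > 0.9, 'float32')
    training_masks = paddle.ones([2, 64, 64], dtype='float32')
    # returns an NxHxW float32 mask of kept positives and hard negatives
    return ohem_batch(scores, gt_texts, training_masks, ohem_ratio=3)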