| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 | # copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at##    http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionfrom paddle import nnfrom .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLossclass DBLoss(nn.Layer):    """    Differentiable Binarization (DB) Loss Function    args:        param (dict): the super paramter for DB Loss    """    def __init__(self,                 balance_loss=True,                 main_loss_type='DiceLoss',                 alpha=5,                 beta=10,                 ohem_ratio=3,                 eps=1e-6,                 **kwargs):        super(DBLoss, self).__init__()        self.alpha = alpha        self.beta = beta        self.dice_loss = DiceLoss(eps=eps)        self.l1_loss = MaskL1Loss(eps=eps)        self.bce_loss = BalanceLoss(            balance_loss=balance_loss,            main_loss_type=main_loss_type,            negative_ratio=ohem_ratio)    def forward(self, predicts, labels):        predict_maps = predicts['maps']        label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[            1:]        shrink_maps = predict_maps[:, 0, :, :]        threshold_maps = predict_maps[:, 1, :, :]        binary_maps = predict_maps[:, 2, :, :]        loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,                                         label_shrink_mask)        loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,                                           label_threshold_mask)        loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,                                          label_shrink_mask)        loss_shrink_maps = self.alpha * loss_shrink_maps        loss_threshold_maps = self.beta * loss_threshold_maps        loss_all = loss_shrink_maps + loss_threshold_maps \                   + loss_binary_maps        losses = {'loss': loss_all, \                  "loss_shrink_maps": loss_shrink_maps, \                  "loss_threshold_maps": loss_threshold_maps, \                  "loss_binary_maps": loss_binary_maps}        return losses
 |