db_postprocess.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import cv2
  19. import paddle
  20. from shapely.geometry import Polygon
  21. import pyclipper
  22. class DBPostProcess(object):
  23. """
  24. The post process for Differentiable Binarization (DB).
  25. """
  26. def __init__(self,
  27. thresh=0.3,
  28. box_thresh=0.7,
  29. max_candidates=1000,
  30. unclip_ratio=2.0,
  31. use_dilation=False,
  32. **kwargs):
  33. self.thresh = thresh
  34. self.box_thresh = box_thresh
  35. self.max_candidates = max_candidates
  36. self.unclip_ratio = unclip_ratio
  37. self.min_size = 3
  38. self.dilation_kernel = None if not use_dilation else np.array(
  39. [[1, 1], [1, 1]])
  40. def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  41. '''
  42. _bitmap: single map with shape (1, H, W),
  43. whose values are binarized as {0, 1}
  44. '''
  45. bitmap = _bitmap
  46. height, width = bitmap.shape
  47. outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
  48. cv2.CHAIN_APPROX_SIMPLE)
  49. if len(outs) == 3:
  50. img, contours, _ = outs[0], outs[1], outs[2]
  51. elif len(outs) == 2:
  52. contours, _ = outs[0], outs[1]
  53. num_contours = min(len(contours), self.max_candidates)
  54. boxes = []
  55. scores = []
  56. for index in range(num_contours):
  57. contour = contours[index]
  58. points, sside = self.get_mini_boxes(contour)
  59. if sside < self.min_size:
  60. continue
  61. points = np.array(points)
  62. score = self.box_score_fast(pred, points.reshape(-1, 2))
  63. if self.box_thresh > score:
  64. continue
  65. box = self.unclip(points).reshape(-1, 1, 2)
  66. box, sside = self.get_mini_boxes(box)
  67. if sside < self.min_size + 2:
  68. continue
  69. box = np.array(box)
  70. box[:, 0] = np.clip(
  71. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  72. box[:, 1] = np.clip(
  73. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  74. boxes.append(box.astype(np.int16))
  75. scores.append(score)
  76. return np.array(boxes, dtype=np.int16), scores
  77. def unclip(self, box):
  78. unclip_ratio = self.unclip_ratio
  79. poly = Polygon(box)
  80. distance = poly.area * unclip_ratio / poly.length
  81. offset = pyclipper.PyclipperOffset()
  82. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  83. expanded = np.array(offset.Execute(distance))
  84. return expanded
  85. def get_mini_boxes(self, contour):
  86. bounding_box = cv2.minAreaRect(contour)
  87. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  88. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  89. if points[1][1] > points[0][1]:
  90. index_1 = 0
  91. index_4 = 1
  92. else:
  93. index_1 = 1
  94. index_4 = 0
  95. if points[3][1] > points[2][1]:
  96. index_2 = 2
  97. index_3 = 3
  98. else:
  99. index_2 = 3
  100. index_3 = 2
  101. box = [
  102. points[index_1], points[index_2], points[index_3], points[index_4]
  103. ]
  104. return box, min(bounding_box[1])
  105. def box_score_fast(self, bitmap, _box):
  106. h, w = bitmap.shape[:2]
  107. box = _box.copy()
  108. xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
  109. xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
  110. ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
  111. ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
  112. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  113. box[:, 0] = box[:, 0] - xmin
  114. box[:, 1] = box[:, 1] - ymin
  115. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
  116. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  117. def __call__(self, outs_dict, shape_list):
  118. pred = outs_dict['maps']
  119. if isinstance(pred, paddle.Tensor):
  120. pred = pred.numpy()
  121. pred = pred[:, 0, :, :]
  122. segmentation = pred > self.thresh
  123. boxes_batch = []
  124. for batch_index in range(pred.shape[0]):
  125. src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
  126. if self.dilation_kernel is not None:
  127. mask = cv2.dilate(
  128. np.array(segmentation[batch_index]).astype(np.uint8),
  129. self.dilation_kernel)
  130. else:
  131. mask = segmentation[batch_index]
  132. boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
  133. src_w, src_h)
  134. boxes_batch.append({'points': boxes})
  135. return boxes_batch