east_postprocess.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. from .locality_aware_nms import nms_locality
  19. import cv2
  20. import paddle
  21. import os
  22. import sys
  class EASTPostProcess(object):
      """
      The post process for EAST.

      Decodes the EAST head outputs (``f_score`` and ``f_geo``) into
      quadrilateral text boxes, rescaled by the per-image ratios supplied in
      ``shape_list``.
      """

      def __init__(self,
                   score_thresh=0.8,
                   cover_thresh=0.1,
                   nms_thresh=0.2,
                   **kwargs):
          """
          Args:
              score_thresh: a score-map pixel must exceed this value to seed
                  a box proposal.
              cover_thresh: a merged box is kept only if the mean score map
                  inside its polygon exceeds this value.
              nms_thresh: threshold forwarded to the NMS routine (presumably
                  an overlap/IoU threshold -- confirm against nms_locality).
              **kwargs: accepted and ignored, so configs may carry extra keys.
          """
          self.score_thresh = score_thresh
          self.cover_thresh = cover_thresh
          self.nms_thresh = nms_thresh
          # The C++ locality-aware NMS (``lanms``) is faster, but only
          # supports Python 3.5; other interpreters use the Python fallback.
          self.is_python35 = sys.version_info[:2] == (3, 5)

      def restore_rectangle_quad(self, origin, geometry):
          """
          Restore quadrangles from per-point geometry offsets.

          Args:
              origin: (n, 2) array of (x, y) anchor points.
              geometry: (n, 8) array of offsets; corner coordinates are
                  recovered as ``origin - offset``.

          Returns:
              (n, 4, 2) array of corner coordinates.
          """
          # Repeat each (x, y) once per corner -> (n, 8).
          origin_concat = np.concatenate([origin] * 4, axis=1)
          pred_quads = origin_concat - geometry
          return pred_quads.reshape((-1, 4, 2))

      def detect(self,
                 score_map,
                 geo_map,
                 score_thresh=0.8,
                 cover_thresh=0.1,
                 nms_thresh=0.2):
          """
          Restore text boxes from a score map and a geometry map.

          Args:
              score_map: array whose first slice ``score_map[0]`` is the
                  2-D (H, W) score map.
              geo_map: channel-first (8, H, W) geometry map.
              score_thresh: pixels scoring above this seed proposals.
              cover_thresh: boxes whose mean in-polygon score is not above
                  this are dropped.
              nms_thresh: threshold forwarded to locality-aware NMS.

          Returns:
              ``[]`` if no pixel passes ``score_thresh`` or NMS returns no
              boxes; otherwise an (m, 9) array (m may be 0 after cover
              filtering) of 8 corner values (x1, y1, ..., x4, y4) followed
              by the box's mean in-polygon score.
          """
          score_map = score_map[0]
          # (C, H, W) -> (H, W, C) so geo_map[y, x] holds that pixel's offsets.
          geo_map = np.transpose(geo_map, (1, 2, 0))
          # Candidate pixels as (row, col) = (y, x) indices.
          xy_text = np.argwhere(score_map > score_thresh)
          if len(xy_text) == 0:
              return []
          # Sort the proposals by their y (row) coordinate.
          xy_text = xy_text[np.argsort(xy_text[:, 0])]
          ys, xs = xy_text[:, 0], xy_text[:, 1]
          # Restore quad proposals. Score-map (x, y) positions are scaled by
          # 4 here; the coverage mask below divides by 4 to map back.
          text_box_restored = self.restore_rectangle_quad(
              xy_text[:, ::-1] * 4, geo_map[ys, xs, :])
          boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
          boxes[:, :8] = text_box_restored.reshape((-1, 8))
          boxes[:, 8] = score_map[ys, xs]
          if self.is_python35:
              import lanms
              boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
          else:
              boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
          if boxes.shape[0] == 0:
              return []
          # Re-score every merged box by the average score map inside its
          # polygon, then drop weakly covered boxes. This differs from the
          # original paper.
          mask = np.zeros_like(score_map, dtype=np.uint8)
          for i, box in enumerate(boxes):
              mask.fill(0)
              poly = box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4
              cv2.fillPoly(mask, poly, 1)
              boxes[i, 8] = cv2.mean(score_map, mask)[0]
          return boxes[boxes[:, 8] > cover_thresh]

      def sort_poly(self, p):
          """
          Put a 4-point polygon into a canonical vertex order.

          The vertex with the smallest x + y becomes the first point. If the
          first edge p[0] -> p[1] is more horizontal than vertical
          (|dx| > |dy|), that cyclic order is kept; otherwise the traversal
          direction is reversed to (p[0], p[3], p[2], p[1]).

          Args:
              p: (4, 2) array of (x, y) points.

          Returns:
              The reordered (4, 2) array (a new array).
          """
          start = np.argmin(np.sum(p, axis=1))
          # Rotate cyclically so the vertex with the smallest x + y is first.
          p = p[(np.arange(4) + start) % 4]
          if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
              return p
          return p[[0, 3, 2, 1]]

      def __call__(self, outs_dict, shape_list):
          """
          Decode a batch of EAST outputs into per-image text boxes.

          Args:
              outs_dict: dict with ``'f_score'`` and ``'f_geo'`` entries,
                  each indexable per image; ``paddle.Tensor`` values are
                  converted to numpy first.
              shape_list: one row per image, unpacked as
                  (src_h, src_w, ratio_h, ratio_w); its length sets how many
                  images are decoded.

          Returns:
              A list with one ``{'points': ndarray}`` dict per image. The
              array stacks int32 (4, 2) quadrilaterals ordered by
              ``sort_poly`` (empty when no box survives).
          """
          score_list = outs_dict['f_score']
          geo_list = outs_dict['f_geo']
          # Convert each output independently so a mix of numpy arrays and
          # tensors is handled as well.
          if isinstance(score_list, paddle.Tensor):
              score_list = score_list.numpy()
          if isinstance(geo_list, paddle.Tensor):
              geo_list = geo_list.numpy()
          dt_boxes_list = []
          for ino in range(len(shape_list)):
              boxes = self.detect(
                  score_map=score_list[ino],
                  geo_map=geo_list[ino],
                  score_thresh=self.score_thresh,
                  cover_thresh=self.cover_thresh,
                  nms_thresh=self.nms_thresh)
              boxes_norm = []
              if len(boxes) > 0:
                  _, _, ratio_h, ratio_w = shape_list[ino]
                  # Divide by the per-image ratios to map corners back to
                  # the source image (presumably ratio = resized / source).
                  boxes = boxes[:, :8].reshape((-1, 4, 2))
                  boxes[:, :, 0] /= ratio_w
                  boxes[:, :, 1] /= ratio_h
                  for box in boxes:
                      quad = self.sort_poly(box.astype(np.int32))
                      # Skip degenerate boxes: either edge meeting at
                      # vertex 0 shorter than 5.
                      if (np.linalg.norm(quad[0] - quad[1]) < 5
                              or np.linalg.norm(quad[3] - quad[0]) < 5):
                          continue
                      boxes_norm.append(quad)
              dt_boxes_list.append({'points': np.array(boxes_norm)})
          return dt_boxes_list