locality_aware_nms.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. """
  2. Locality aware nms.
  3. """
  4. import numpy as np
  5. from shapely.geometry import Polygon
  6. def intersection(g, p):
  7. """
  8. Intersection.
  9. """
  10. g = Polygon(g[:8].reshape((4, 2)))
  11. p = Polygon(p[:8].reshape((4, 2)))
  12. g = g.buffer(0)
  13. p = p.buffer(0)
  14. if not g.is_valid or not p.is_valid:
  15. return 0
  16. inter = Polygon(g).intersection(Polygon(p)).area
  17. union = g.area + p.area - inter
  18. if union == 0:
  19. return 0
  20. else:
  21. return inter / union
  22. def intersection_iog(g, p):
  23. """
  24. Intersection_iog.
  25. """
  26. g = Polygon(g[:8].reshape((4, 2)))
  27. p = Polygon(p[:8].reshape((4, 2)))
  28. if not g.is_valid or not p.is_valid:
  29. return 0
  30. inter = Polygon(g).intersection(Polygon(p)).area
  31. #union = g.area + p.area - inter
  32. union = p.area
  33. if union == 0:
  34. print("p_area is very small")
  35. return 0
  36. else:
  37. return inter / union
  38. def weighted_merge(g, p):
  39. """
  40. Weighted merge.
  41. """
  42. g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
  43. g[8] = (g[8] + p[8])
  44. return g
  45. def standard_nms(S, thres):
  46. """
  47. Standard nms.
  48. """
  49. order = np.argsort(S[:, 8])[::-1]
  50. keep = []
  51. while order.size > 0:
  52. i = order[0]
  53. keep.append(i)
  54. ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
  55. inds = np.where(ovr <= thres)[0]
  56. order = order[inds + 1]
  57. return S[keep]
  58. def standard_nms_inds(S, thres):
  59. """
  60. Standard nms, retun inds.
  61. """
  62. order = np.argsort(S[:, 8])[::-1]
  63. keep = []
  64. while order.size > 0:
  65. i = order[0]
  66. keep.append(i)
  67. ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
  68. inds = np.where(ovr <= thres)[0]
  69. order = order[inds + 1]
  70. return keep
  71. def nms(S, thres):
  72. """
  73. nms.
  74. """
  75. order = np.argsort(S[:, 8])[::-1]
  76. keep = []
  77. while order.size > 0:
  78. i = order[0]
  79. keep.append(i)
  80. ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
  81. inds = np.where(ovr <= thres)[0]
  82. order = order[inds + 1]
  83. return keep
  84. def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
  85. """
  86. soft_nms
  87. :para boxes_in, N x 9 (coords + score)
  88. :para threshould, eliminate cases min score(0.001)
  89. :para Nt_thres, iou_threshi
  90. :para sigma, gaussian weght
  91. :method, linear or gaussian
  92. """
  93. boxes = boxes_in.copy()
  94. N = boxes.shape[0]
  95. if N is None or N < 1:
  96. return np.array([])
  97. pos, maxpos = 0, 0
  98. weight = 0.0
  99. inds = np.arange(N)
  100. tbox, sbox = boxes[0].copy(), boxes[0].copy()
  101. for i in range(N):
  102. maxscore = boxes[i, 8]
  103. maxpos = i
  104. tbox = boxes[i].copy()
  105. ti = inds[i]
  106. pos = i + 1
  107. #get max box
  108. while pos < N:
  109. if maxscore < boxes[pos, 8]:
  110. maxscore = boxes[pos, 8]
  111. maxpos = pos
  112. pos = pos + 1
  113. #add max box as a detection
  114. boxes[i, :] = boxes[maxpos, :]
  115. inds[i] = inds[maxpos]
  116. #swap
  117. boxes[maxpos, :] = tbox
  118. inds[maxpos] = ti
  119. tbox = boxes[i].copy()
  120. pos = i + 1
  121. #NMS iteration
  122. while pos < N:
  123. sbox = boxes[pos].copy()
  124. ts_iou_val = intersection(tbox, sbox)
  125. if ts_iou_val > 0:
  126. if method == 1:
  127. if ts_iou_val > Nt_thres:
  128. weight = 1 - ts_iou_val
  129. else:
  130. weight = 1
  131. elif method == 2:
  132. weight = np.exp(-1.0 * ts_iou_val**2 / sigma)
  133. else:
  134. if ts_iou_val > Nt_thres:
  135. weight = 0
  136. else:
  137. weight = 1
  138. boxes[pos, 8] = weight * boxes[pos, 8]
  139. #if box score falls below thresold, discard the box by
  140. #swaping last box update N
  141. if boxes[pos, 8] < threshold:
  142. boxes[pos, :] = boxes[N - 1, :]
  143. inds[pos] = inds[N - 1]
  144. N = N - 1
  145. pos = pos - 1
  146. pos = pos + 1
  147. return boxes[:N]
  148. def nms_locality(polys, thres=0.3):
  149. """
  150. locality aware nms of EAST
  151. :param polys: a N*9 numpy array. first 8 coordinates, then prob
  152. :return: boxes after nms
  153. """
  154. S = []
  155. p = None
  156. for g in polys:
  157. if p is not None and intersection(g, p) > thres:
  158. p = weighted_merge(g, p)
  159. else:
  160. if p is not None:
  161. S.append(p)
  162. p = g
  163. if p is not None:
  164. S.append(p)
  165. if len(S) == 0:
  166. return np.array([])
  167. return standard_nms(np.array(S), thres)
  168. if __name__ == '__main__':
  169. # 343,350,448,135,474,143,369,359
  170. print(
  171. Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]]))
  172. .area)