eval_det_iou.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from collections import namedtuple
  4. import numpy as np
  5. from shapely.geometry import Polygon
  6. """
  7. reference from :
  8. https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
  9. """
  10. class DetectionIoUEvaluator(object):
  11. def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
  12. self.iou_constraint = iou_constraint
  13. self.area_precision_constraint = area_precision_constraint
  14. def evaluate_image(self, gt, pred):
  15. def get_union(pD, pG):
  16. return Polygon(pD).union(Polygon(pG)).area
  17. def get_intersection_over_union(pD, pG):
  18. return get_intersection(pD, pG) / get_union(pD, pG)
  19. def get_intersection(pD, pG):
  20. return Polygon(pD).intersection(Polygon(pG)).area
  21. def compute_ap(confList, matchList, numGtCare):
  22. correct = 0
  23. AP = 0
  24. if len(confList) > 0:
  25. confList = np.array(confList)
  26. matchList = np.array(matchList)
  27. sorted_ind = np.argsort(-confList)
  28. confList = confList[sorted_ind]
  29. matchList = matchList[sorted_ind]
  30. for n in range(len(confList)):
  31. match = matchList[n]
  32. if match:
  33. correct += 1
  34. AP += float(correct) / (n + 1)
  35. if numGtCare > 0:
  36. AP /= numGtCare
  37. return AP
  38. perSampleMetrics = {}
  39. matchedSum = 0
  40. Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
  41. numGlobalCareGt = 0
  42. numGlobalCareDet = 0
  43. arrGlobalConfidences = []
  44. arrGlobalMatches = []
  45. recall = 0
  46. precision = 0
  47. hmean = 0
  48. detMatched = 0
  49. iouMat = np.empty([1, 1])
  50. gtPols = []
  51. detPols = []
  52. gtPolPoints = []
  53. detPolPoints = []
  54. # Array of Ground Truth Polygons' keys marked as don't Care
  55. gtDontCarePolsNum = []
  56. # Array of Detected Polygons' matched with a don't Care GT
  57. detDontCarePolsNum = []
  58. pairs = []
  59. detMatchedNums = []
  60. arrSampleConfidences = []
  61. arrSampleMatch = []
  62. evaluationLog = ""
  63. # print(len(gt))
  64. for n in range(len(gt)):
  65. points = gt[n]['points']
  66. # transcription = gt[n]['text']
  67. dontCare = gt[n]['ignore']
  68. # points = Polygon(points)
  69. # points = points.buffer(0)
  70. if not Polygon(points).is_valid or not Polygon(points).is_simple:
  71. continue
  72. gtPol = points
  73. gtPols.append(gtPol)
  74. gtPolPoints.append(points)
  75. if dontCare:
  76. gtDontCarePolsNum.append(len(gtPols) - 1)
  77. evaluationLog += "GT polygons: " + str(len(gtPols)) + (
  78. " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
  79. if len(gtDontCarePolsNum) > 0 else "\n")
  80. for n in range(len(pred)):
  81. points = pred[n]['points']
  82. # points = Polygon(points)
  83. # points = points.buffer(0)
  84. if not Polygon(points).is_valid or not Polygon(points).is_simple:
  85. continue
  86. detPol = points
  87. detPols.append(detPol)
  88. detPolPoints.append(points)
  89. if len(gtDontCarePolsNum) > 0:
  90. for dontCarePol in gtDontCarePolsNum:
  91. dontCarePol = gtPols[dontCarePol]
  92. intersected_area = get_intersection(dontCarePol, detPol)
  93. pdDimensions = Polygon(detPol).area
  94. precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
  95. if (precision > self.area_precision_constraint):
  96. detDontCarePolsNum.append(len(detPols) - 1)
  97. break
  98. evaluationLog += "DET polygons: " + str(len(detPols)) + (
  99. " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
  100. if len(detDontCarePolsNum) > 0 else "\n")
  101. if len(gtPols) > 0 and len(detPols) > 0:
  102. # Calculate IoU and precision matrixs
  103. outputShape = [len(gtPols), len(detPols)]
  104. iouMat = np.empty(outputShape)
  105. gtRectMat = np.zeros(len(gtPols), np.int8)
  106. detRectMat = np.zeros(len(detPols), np.int8)
  107. for gtNum in range(len(gtPols)):
  108. for detNum in range(len(detPols)):
  109. pG = gtPols[gtNum]
  110. pD = detPols[detNum]
  111. iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
  112. for gtNum in range(len(gtPols)):
  113. for detNum in range(len(detPols)):
  114. if gtRectMat[gtNum] == 0 and detRectMat[
  115. detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
  116. if iouMat[gtNum, detNum] > self.iou_constraint:
  117. gtRectMat[gtNum] = 1
  118. detRectMat[detNum] = 1
  119. detMatched += 1
  120. pairs.append({'gt': gtNum, 'det': detNum})
  121. detMatchedNums.append(detNum)
  122. evaluationLog += "Match GT #" + \
  123. str(gtNum) + " with Det #" + str(detNum) + "\n"
  124. numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
  125. numDetCare = (len(detPols) - len(detDontCarePolsNum))
  126. if numGtCare == 0:
  127. recall = float(1)
  128. precision = float(0) if numDetCare > 0 else float(1)
  129. else:
  130. recall = float(detMatched) / numGtCare
  131. precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
  132. hmean = 0 if (precision + recall) == 0 else 2.0 * \
  133. precision * recall / (precision + recall)
  134. matchedSum += detMatched
  135. numGlobalCareGt += numGtCare
  136. numGlobalCareDet += numDetCare
  137. perSampleMetrics = {
  138. 'precision': precision,
  139. 'recall': recall,
  140. 'hmean': hmean,
  141. 'pairs': pairs,
  142. 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
  143. 'gtPolPoints': gtPolPoints,
  144. 'detPolPoints': detPolPoints,
  145. 'gtCare': numGtCare,
  146. 'detCare': numDetCare,
  147. 'gtDontCare': gtDontCarePolsNum,
  148. 'detDontCare': detDontCarePolsNum,
  149. 'detMatched': detMatched,
  150. 'evaluationLog': evaluationLog
  151. }
  152. return perSampleMetrics
  153. def combine_results(self, results):
  154. numGlobalCareGt = 0
  155. numGlobalCareDet = 0
  156. matchedSum = 0
  157. for result in results:
  158. numGlobalCareGt += result['gtCare']
  159. numGlobalCareDet += result['detCare']
  160. matchedSum += result['detMatched']
  161. methodRecall = 0 if numGlobalCareGt == 0 else float(
  162. matchedSum) / numGlobalCareGt
  163. methodPrecision = 0 if numGlobalCareDet == 0 else float(
  164. matchedSum) / numGlobalCareDet
  165. methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
  166. methodRecall * methodPrecision / (methodRecall + methodPrecision)
  167. # print(methodRecall, methodPrecision, methodHmean)
  168. # sys.exit(-1)
  169. methodMetrics = {
  170. 'precision': methodPrecision,
  171. 'recall': methodRecall,
  172. 'hmean': methodHmean
  173. }
  174. return methodMetrics
  175. if __name__ == '__main__':
  176. evaluator = DetectionIoUEvaluator()
  177. gts = [[{
  178. 'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
  179. 'text': 1234,
  180. 'ignore': False,
  181. }, {
  182. 'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
  183. 'text': 5678,
  184. 'ignore': False,
  185. }]]
  186. preds = [[{
  187. 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
  188. 'text': 123,
  189. 'ignore': False,
  190. }]]
  191. results = []
  192. for gt, pred in zip(gts, preds):
  193. results.append(evaluator.evaluate_image(gt, pred))
  194. metrics = evaluator.combine_results(results)
  195. print(metrics)