det_metric.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. __all__ = ['DetMetric']
  18. from .eval_det_iou import DetectionIoUEvaluator
  19. class DetMetric(object):
  20. def __init__(self, main_indicator='hmean', **kwargs):
  21. self.evaluator = DetectionIoUEvaluator()
  22. self.main_indicator = main_indicator
  23. self.reset()
  24. def __call__(self, preds, batch, **kwargs):
  25. '''
  26. batch: a list produced by dataloaders.
  27. image: np.ndarray of shape (N, C, H, W).
  28. ratio_list: np.ndarray of shape(N,2)
  29. polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
  30. ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
  31. preds: a list of dict produced by post process
  32. points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
  33. '''
  34. gt_polyons_batch = batch[2]
  35. ignore_tags_batch = batch[3]
  36. for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
  37. ignore_tags_batch):
  38. # prepare gt
  39. gt_info_list = [{
  40. 'points': gt_polyon,
  41. 'text': '',
  42. 'ignore': ignore_tag
  43. } for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
  44. # prepare det
  45. det_info_list = [{
  46. 'points': det_polyon,
  47. 'text': ''
  48. } for det_polyon in pred['points']]
  49. result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
  50. self.results.append(result)
  51. def get_metric(self):
  52. """
  53. return metrics {
  54. 'precision': 0,
  55. 'recall': 0,
  56. 'hmean': 0
  57. }
  58. """
  59. metircs = self.evaluator.combine_results(self.results)
  60. self.reset()
  61. return metircs
  62. def reset(self):
  63. self.results = [] # clear results