rec_metric.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import Levenshtein
  15. class RecMetric(object):
  16. def __init__(self, main_indicator='acc', **kwargs):
  17. self.main_indicator = main_indicator
  18. self.reset()
  19. def __call__(self, pred_label, *args, **kwargs):
  20. preds, labels = pred_label
  21. correct_num = 0
  22. all_num = 0
  23. norm_edit_dis = 0.0
  24. for (pred, pred_conf), (target, _) in zip(preds, labels):
  25. pred = pred.replace(" ", "")
  26. target = target.replace(" ", "")
  27. norm_edit_dis += Levenshtein.distance(pred, target) / max(
  28. len(pred), len(target), 1)
  29. if pred == target:
  30. correct_num += 1
  31. all_num += 1
  32. self.correct_num += correct_num
  33. self.all_num += all_num
  34. self.norm_edit_dis += norm_edit_dis
  35. return {
  36. 'acc': correct_num / all_num,
  37. 'norm_edit_dis': 1 - norm_edit_dis / all_num
  38. }
  39. def get_metric(self):
  40. """
  41. return metrics {
  42. 'acc': 0,
  43. 'norm_edit_dis': 0,
  44. }
  45. """
  46. acc = 1.0 * self.correct_num / self.all_num
  47. norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
  48. self.reset()
  49. return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
  50. def reset(self):
  51. self.correct_num = 0
  52. self.all_num = 0
  53. self.norm_edit_dis = 0