123456789101112131415161718192021222324252627282930313233 |
- import paddle
class ClsPostProcess(object):
    """Convert between text-label and text-index for classifier outputs.

    Turns a batch of per-class scores into ``(label, score)`` pairs by looking
    up the highest-scoring class index in ``label_list``.
    """

    def __init__(self, label_list, **kwargs):
        """Remember the index-to-label lookup.

        Args:
            label_list: Indexable collection; position ``i`` holds the label
                reported for class index ``i``.
            **kwargs: Accepted and ignored.
        """
        super(ClsPostProcess, self).__init__()
        self.label_list = label_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions and, when supplied, ground-truth indices.

        Args:
            preds: A ``paddle.Tensor`` (converted with ``.numpy()``) or any
                object that already supports ``argmax(axis=1)`` and
                ``[row, col]`` indexing.
                NOTE(review): presumably shaped (batch, num_classes) - confirm
                against callers.
            label: Optional iterable of ground-truth class indices.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            A list of ``(label, score)`` tuples, one per row of ``preds``,
            when ``label`` is None. Otherwise a 2-tuple of that list and a
            list of ``(label, 1.0)`` tuples built from ``label``.
        """
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        # Column index of the top score in every row.
        best_cols = preds.argmax(axis=1)

        decoded = []
        for row, col in enumerate(best_cols):
            decoded.append((self.label_list[col], preds[row, col]))

        if label is not None:
            # Ground truth is reported with a fixed confidence of 1.0.
            truth = [(self.label_list[idx], 1.0) for idx in label]
            return decoded, truth
        return decoded
|