123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import sys
- __dir__ = os.path.dirname(__file__)
- sys.path.append(__dir__)
- sys.path.append(os.path.join(__dir__, '..'))
- import numpy as np
- from .locality_aware_nms import nms_locality
- import paddle
- import cv2
- import time
- class SASTPostProcess(object):
- """
- The post process for SAST.
- """
- def __init__(self,
- score_thresh=0.5,
- nms_thresh=0.2,
- sample_pts_num=2,
- shrink_ratio_of_width=0.3,
- expand_scale=1.0,
- tcl_map_thresh=0.5,
- **kwargs):
- self.score_thresh = score_thresh
- self.nms_thresh = nms_thresh
- self.sample_pts_num = sample_pts_num
- self.shrink_ratio_of_width = shrink_ratio_of_width
- self.expand_scale = expand_scale
- self.tcl_map_thresh = tcl_map_thresh
-
- # c++ la-nms is faster, but only support python 3.5
- self.is_python35 = False
- if sys.version_info.major == 3 and sys.version_info.minor == 5:
- self.is_python35 = True
-
- def point_pair2poly(self, point_pair_list):
- """
- Transfer vertical point_pairs into poly point in clockwise.
- """
- # constract poly
- point_num = len(point_pair_list) * 2
- point_list = [0] * point_num
- for idx, point_pair in enumerate(point_pair_list):
- point_list[idx] = point_pair[0]
- point_list[point_num - 1 - idx] = point_pair[1]
- return np.array(point_list).reshape(-1, 2)
-
- def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
- """
- Generate shrink_quad_along_width.
- """
- ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
- p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
- p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
- return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
-
- def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
- """
- expand poly along width.
- """
- point_num = poly.shape[0]
- left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
- left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
- (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
- left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
- right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
- poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
- right_ratio = 1.0 + \
- shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
- (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
- right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
- poly[0] = left_quad_expand[0]
- poly[-1] = left_quad_expand[-1]
- poly[point_num // 2 - 1] = right_quad_expand[1]
- poly[point_num // 2] = right_quad_expand[2]
- return poly
- def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
- """Restore quad."""
- xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
- xy_text = xy_text[:, ::-1] # (n, 2)
- # Sort the text boxes via the y axis
- xy_text = xy_text[np.argsort(xy_text[:, 1])]
- scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
- scores = scores[:, np.newaxis]
- # Restore
- point_num = int(tvo_map.shape[-1] / 2)
- assert point_num == 4
- tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
- xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
- quads = xy_text_tile - tvo_map
- return scores, quads, xy_text
- def quad_area(self, quad):
- """
- compute area of a quad.
- """
- edge = [
- (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
- (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
- (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
- (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
- ]
- return np.sum(edge) / 2.
-
- def nms(self, dets):
- if self.is_python35:
- import lanms
- dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
- else:
- dets = nms_locality(dets, self.nms_thresh)
- return dets
- def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
- """
- Cluster pixels in tcl_map based on quads.
- """
- instance_count = quads.shape[0] + 1 # contain background
- instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
- if instance_count == 1:
- return instance_count, instance_label_map
- # predict text center
- xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
- n = xy_text.shape[0]
- xy_text = xy_text[:, ::-1] # (n, 2)
- tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
- pred_tc = xy_text - tco
-
- # get gt text center
- m = quads.shape[0]
- gt_tc = np.mean(quads, axis=1) # (m, 2)
- pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
- gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
- dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
- xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
- instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
- return instance_count, instance_label_map
- def estimate_sample_pts_num(self, quad, xy_text):
- """
- Estimate sample points number.
- """
- eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
- ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
- dense_sample_pts_num = max(2, int(ew))
- dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
- endpoint=True, dtype=np.float32).astype(np.int32)]
- dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
- estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
- sample_pts_num = max(2, int(estimate_arc_len / eh))
- return sample_pts_num
- def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
- shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
- """
- first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
- """
- # restore quad
- scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
- dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
- dets = self.nms(dets)
- if dets.shape[0] == 0:
- return []
- quads = dets[:, :-1].reshape(-1, 4, 2)
- # Compute quad area
- quad_areas = []
- for quad in quads:
- quad_areas.append(-self.quad_area(quad))
- # instance segmentation
- # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
- instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
- # restore single poly with tcl instance.
- poly_list = []
- for instance_idx in range(1, instance_count):
- xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
- quad = quads[instance_idx - 1]
- q_area = quad_areas[instance_idx - 1]
- if q_area < 5:
- continue
-
- #
- len1 = float(np.linalg.norm(quad[0] -quad[1]))
- len2 = float(np.linalg.norm(quad[1] -quad[2]))
- min_len = min(len1, len2)
- if min_len < 3:
- continue
- # filter small CC
- if xy_text.shape[0] <= 0:
- continue
- # filter low confidence instance
- xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
- if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
- # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
- continue
- # sort xy_text
- left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
- (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
- right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
- (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
- proj_unit_vec = (right_center_pt - left_center_pt) / \
- (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
- proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
- xy_text = xy_text[np.argsort(proj_value)]
- # Sample pts in tcl map
- if self.sample_pts_num == 0:
- sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
- else:
- sample_pts_num = self.sample_pts_num
- xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
- endpoint=True, dtype=np.float32).astype(np.int32)]
- point_pair_list = []
- for x, y in xy_center_line:
- # get corresponding offset
- offset = tbo_map[y, x, :].reshape(2, 2)
- if offset_expand != 1.0:
- offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
- expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
- offset_detal = offset / offset_length * expand_length
- offset = offset + offset_detal
- # original point
- ori_yx = np.array([y, x], dtype=np.float32)
- point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
- point_pair_list.append(point_pair)
- # ndarry: (x, 2), expand poly along width
- detected_poly = self.point_pair2poly(point_pair_list)
- detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
- detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
- detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
- poly_list.append(detected_poly)
- return poly_list
- def __call__(self, outs_dict, shape_list):
- score_list = outs_dict['f_score']
- border_list = outs_dict['f_border']
- tvo_list = outs_dict['f_tvo']
- tco_list = outs_dict['f_tco']
- if isinstance(score_list, paddle.Tensor):
- score_list = score_list.numpy()
- border_list = border_list.numpy()
- tvo_list = tvo_list.numpy()
- tco_list = tco_list.numpy()
-
- img_num = len(shape_list)
- poly_lists = []
- for ino in range(img_num):
- p_score = score_list[ino].transpose((1,2,0))
- p_border = border_list[ino].transpose((1,2,0))
- p_tvo = tvo_list[ino].transpose((1,2,0))
- p_tco = tco_list[ino].transpose((1,2,0))
- src_h, src_w, ratio_h, ratio_w = shape_list[ino]
- poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
- shrink_ratio_of_width=self.shrink_ratio_of_width,
- tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
- poly_lists.append({'points': np.array(poly_list)})
- return poly_lists
|