diff --git a/.github/workflows/publish_whl.yml b/.github/workflows/publish_whl.yml
index f9f0ae6..f2ea7d5 100644
--- a/.github/workflows/publish_whl.yml
+++ b/.github/workflows/publish_whl.yml
@@ -56,7 +56,7 @@ jobs:
           ZIP_NAME=${RESOURCES_URL##*/}
           DIR_NAME=${ZIP_NAME%.*}
           unzip $ZIP_NAME
-          mv $DIR_NAME/en_ppstructure_mobile_v2_SLANet.onnx rapid_table/models/
+          mv $DIR_NAME/slanet-plus.onnx rapid_table/models/
           python setup.py bdist_wheel ${{ github.event.head_commit.message }}

       - name: Publish distribution 📦 to PyPI
diff --git a/README.md b/README.md
index 0a627b9..adfdb7d 100644
--- a/README.md
+++ b/README.md
@@ -19,13 +19,13 @@ RapidTable库是专门用来文档类图像的表格结构还原,结合RapidOC
 目前支持两种类别的表格识别模型:中文和英文表格识别模型,具体可参见下面表格:
 
-slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升,但paddle2onnx暂时不支持转换
+slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升
 
 | 模型类型 | 模型名称 | 模型大小 |
 |:--------------:|:--------------------------------------:| :------: |
 | 英文 | `en_ppstructure_mobile_v2_SLANet.onnx` | 7.3M |
 | 中文 | `ch_ppstructure_mobile_v2_SLANet.onnx` | 7.4M |
- | slanet_plus 中文 | `inference.pdmodel` | 7.4M |
+ | slanet_plus 中文 | `slanet-plus.onnx` | 6.8M |
 
 模型来源:[PaddleOCR 表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)
@@ -45,26 +45,23 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu
 ### 安装
 
-由于模型较小,预先将英文表格识别模型(`en_ppstructure_mobile_v2_SLANet.onnx`)打包进了whl包内,如果做英文表格识别,可直接安装使用。
+由于模型较小,预先将slanet-plus表格识别模型(`slanet-plus.onnx`)打包进了whl包内。
 
 > ⚠️注意:`rapid_table>=v0.1.0`之后,不再将`rapidocr_onnxruntime`依赖强制打包到`rapid_table`中。使用前,需要自行安装`rapidocr_onnxruntime`包。
 
 ```bash
 pip install rapidocr_onnxruntime
 pip install rapid_table
-# 安装会引入paddlepaddle cpu 3.0.0b0
-#pip install slanet_plus_table
 ```
 
 ### 使用方式
 
 #### python脚本运行
 
-RapidTable类提供model_path参数,可以自行指定上述2个模型,默认是`en_ppstructure_mobile_v2_SLANet.onnx`。举例如下:
+RapidTable类提供model_path参数,可以自行指定上述2个模型,默认是`slanet-plus.onnx`。举例如下:
 
 ```python
-table_engine = RapidTable(model_path='ch_ppstructure_mobile_v2_SLANet.onnx')
-#table_engine = SLANetPlus()
+table_engine = RapidTable()
 ```
 
 完整示例:
@@ -72,11 +69,9 @@ table_engine = RapidTable(model_path='ch_ppstructure_mobile_v2_SLANet.onnx')
 ```python
 from pathlib import Path
 
-from rapid_table import RapidTable
 from rapid_table import RapidTable, VisTable
 
 table_engine = RapidTable()
-#table_engine = SLANetPlus()
 ocr_engine = RapidOCR()
 viser = VisTable()
diff --git a/rapid_table/main.py b/rapid_table/main.py
index e8907a3..02d6516 100644
--- a/rapid_table/main.py
+++ b/rapid_table/main.py
@@ -19,12 +19,13 @@ class RapidTable:
-    def __init__(self, model_path: Optional[str] = None):
+    def __init__(self, model_path: Optional[str] = None, model_type: str = None):
         if model_path is None:
             model_path = str(
-                root_dir / "models" / "en_ppstructure_mobile_v2_SLANet.onnx"
+                root_dir / "models" / "slanet-plus.onnx"
             )
-
+            model_type = "slanet-plus"
+        self.model_type = model_type
         self.load_img = LoadImage()
         self.table_structure = TableStructurer(model_path)
         self.table_matcher = TableMatch()
@@ -54,6 +55,9 @@ def __call__(
         dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
 
         pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img))
+        # 适配slanet-plus模型输出的box缩放还原
+        if self.model_type == "slanet-plus":
+            pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes)
         pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
 
         elapse = time.time() - s
@@ -76,7 +80,15 @@ def get_boxes_recs(
             r_boxes.append(box)
         dt_boxes = np.array(r_boxes)
         return dt_boxes, rec_res
-
+    def adapt_slanet_plus(self, img: np.ndarray, pred_bboxes: np.ndarray) -> np.ndarray:
+        h, w = img.shape[:2]
+        resized = 488
+        ratio = min(resized / h, resized / w)
+        w_ratio = resized / (w * ratio)
+        h_ratio = resized / (h * ratio)
+        pred_bboxes[:, 0::2] *= w_ratio
+        pred_bboxes[:, 1::2] *= h_ratio
+        return pred_bboxes
 
 def main():
     parser = argparse.ArgumentParser()
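
A reader's note on the `adapt_slanet_plus` method added above: it undoes the 488x488 letterbox scaling of the table-structure output, axis by axis. The sketch below restates that arithmetic as a standalone helper so it can be checked in isolation; the function name is illustrative, and unlike the in-place version above it works on a float copy of the boxes.

```python
import numpy as np

def rescale_slanet_plus_boxes(pred_bboxes: np.ndarray, img_h: int, img_w: int, resized: int = 488) -> np.ndarray:
    """Sketch of the adapt_slanet_plus arithmetic (per-axis correction for a 488x488 letterboxed input)."""
    ratio = min(resized / img_h, resized / img_w)   # letterbox scale factor used at preprocessing time
    w_ratio = resized / (img_w * ratio)             # x-axis correction
    h_ratio = resized / (img_h * ratio)             # y-axis correction
    boxes = np.asarray(pred_bboxes, dtype=np.float32).copy()
    boxes[:, 0::2] *= w_ratio                       # even indices hold x coordinates
    boxes[:, 1::2] *= h_ratio                       # odd indices hold y coordinates
    return boxes

# e.g. a 1000x2000 (h x w) image: ratio = 0.244, so w_ratio = 1.0 and h_ratio = 2.0
```
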
diff --git a/rapid_table/utils.py b/rapid_table/utils.py
index bb75e42..cc860b3 100644
--- a/rapid_table/utils.py
+++ b/rapid_table/utils.py
@@ -114,7 +114,7 @@ def __call__(
         return drawed_img
 
     def insert_border_style(self, table_html_str: str):
-        style_res = """"""
+        style_res = """"""
diff --git a/setup.py b/setup.py
index dea104a..0539473 100644
--- a/setup.py
+++ b/setup.py
@@ -53,7 +53,7 @@ def get_readme():
         f"{MODULE_NAME}.table_matcher",
         f"{MODULE_NAME}.table_structure",
     ],
-    package_data={"": ["en_ppstructure_mobile_v2_SLANet.onnx"]},
+    package_data={"": ["slanet-plus.onnx"]},
     keywords=["ppstructure,table,rapidocr,rapid_table"],
     classifiers=[
         "Programming Language :: Python :: 3.6",
diff --git a/slanet_plus_table/__init__.py b/slanet_plus_table/__init__.py
deleted file mode 100644
index 418f3c1..0000000
--- a/slanet_plus_table/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .main import SLANetPlus
-from .utils import VisTable
diff --git a/slanet_plus_table/main.py b/slanet_plus_table/main.py
deleted file mode 100644
index 14eeb62..0000000
--- a/slanet_plus_table/main.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import copy
-import importlib
-import time
-from pathlib import Path
-from typing import Optional, Union, List, Tuple
-
-import cv2
-import numpy as np
-
-from slanet_plus_table.table_matcher import TableMatch
-from slanet_plus_table.table_structure import TableStructurer
-from slanet_plus_table.utils import LoadImage, VisTable
-
-root_dir = Path(__file__).resolve().parent
-
-
-class SLANetPlus:
-    def __init__(self, model_path: Optional[str] = None):
-        if model_path is None:
-            model_path = str(
-                root_dir / "models"
-            )
-
-        self.load_img = LoadImage()
-        self.table_structure = TableStructurer(model_path)
-        self.table_matcher = TableMatch()
-
-        try:
-            self.ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
-        except ModuleNotFoundError:
-            self.ocr_engine = None
-
-    def __call__(
-        self,
-        img_content: Union[str, np.ndarray, bytes, Path],
-        ocr_result: List[Union[List[List[float]], str, str]] = None,
-    ) -> Tuple[str, float]:
-        if self.ocr_engine is None and ocr_result is None:
-            raise ValueError(
-                "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
-            )
-
-        img = self.load_img(img_content)
-
-        s = time.time()
-        h, w = img.shape[:2]
-
-        if ocr_result is None:
-            ocr_result, _ = self.ocr_engine(img)
-        dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
-
-        pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img))
-        pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
-
-        elapse = time.time() - s
-        return pred_html, pred_bboxes, elapse
-
-    def get_boxes_recs(
-        self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
-    ) -> Tuple[np.ndarray, Tuple[str, str]]:
-        dt_boxes, rec_res, scores = list(zip(*ocr_result))
-        rec_res = list(zip(rec_res, scores))
-
-        r_boxes = []
-        for box in dt_boxes:
-            box = np.array(box)
-            x_min = max(0, box[:, 0].min() - 1)
-            x_max = min(w, box[:, 0].max() + 1)
-            y_min = max(0, box[:, 1].min() - 1)
-            y_max = min(h, box[:, 1].max() + 1)
-            box = [x_min, y_min, x_max, y_max]
-            r_boxes.append(box)
-        dt_boxes = np.array(r_boxes)
-        return dt_boxes, rec_res
-
-
-if __name__ == '__main__':
-    slanet_table = SLANetPlus()
-    img_path = "D:\pythonProjects\TableStructureRec\outputs\\benchmark\\border_left_7267_OEJGHZF525Q011X2ZC34.jpg"
-    img = cv2.imread(img_path)
-    try:
-        ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
-    except ModuleNotFoundError as exc:
-        raise ModuleNotFoundError(
-            "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
-        ) from exc
-    ocr_result, _ = ocr_engine(img)
-    table_html_str, table_cell_bboxes, elapse = slanet_table(img, ocr_result)
-
-    viser = VisTable()
-
-    img_path = Path(img_path)
-
-    save_dir = "outputs"
-    save_html_path = f"{save_dir}/{Path(img_path).stem}.html"
-    save_drawed_path = f"{save_dir}/vis_{Path(img_path).name}"
-    viser(
-        img_path,
-        table_html_str,
-        save_html_path,
-        table_cell_bboxes,
-        save_drawed_path,
-    )
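
With the standalone `slanet_plus_table` package deleted (above and below), the same pipeline is reached through `RapidTable` directly. A minimal migration sketch, assuming `rapidocr_onnxruntime` is installed and the bundled `slanet-plus.onnx` default is used; the image path and output paths are placeholders, and the `VisTable` call mirrors the signature of the deleted `slanet_plus_table/utils.py` shown further down (rapid_table's own `VisTable` is assumed to match):

```python
from pathlib import Path

from rapidocr_onnxruntime import RapidOCR
from rapid_table import RapidTable, VisTable

ocr_engine = RapidOCR()
table_engine = RapidTable()        # replaces SLANetPlus() from the deleted package

img_path = "table.jpg"             # placeholder input image
ocr_result, _ = ocr_engine(img_path)
table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)

save_dir = Path("outputs")
save_dir.mkdir(parents=True, exist_ok=True)
VisTable()(
    img_path,
    table_html_str,
    save_dir / f"{Path(img_path).stem}.html",
    table_cell_bboxes,
    save_dir / f"vis_{Path(img_path).name}",
)
```
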
diff --git a/slanet_plus_table/models/inference.yml b/slanet_plus_table/models/inference.yml
deleted file mode 100644
index dd6db0d..0000000
--- a/slanet_plus_table/models/inference.yml
+++ /dev/null
@@ -1,106 +0,0 @@
-Hpi:
-  backend_config:
-    paddle_infer:
-      cpu_num_threads: 8
-      enable_log_info: false
-  selected_backends:
-    cpu: paddle_infer
-    gpu: paddle_infer
-  supported_backends:
-    cpu:
-    - paddle_infer
-    gpu:
-    - paddle_infer
-Global:
-  model_name: SLANet_plus
-PreProcess:
-  transform_ops:
-  - DecodeImage:
-      channel_first: false
-      img_mode: BGR
-  - TableLabelEncode:
-      learn_empty_box: false
-      loc_reg_num: 8
-      max_text_length: 500
-      merge_no_span_structure: true
-      replace_empty_cell_token: false
-  - TableBoxEncode:
-      in_box_format: xyxyxyxy
-      out_box_format: xyxyxyxy
-  - ResizeTableImage:
-      max_len: 488
-  - NormalizeImage:
-      mean:
-      - 0.485
-      - 0.456
-      - 0.406
-      order: hwc
-      scale: 1./255.
-      std:
-      - 0.229
-      - 0.224
-      - 0.225
-  - PaddingTableImage:
-      size:
-      - 488
-      - 488
-  - ToCHWImage: null
-  - KeepKeys:
-      keep_keys:
-      - image
-      - structure
-      - bboxes
-      - bbox_masks
-      - shape
-PostProcess:
-  name: TableLabelDecode
-  merge_no_span_structure: true
-  character_dict:
-  -
-  -
-  -
-  -
-  -
-  -
-  -
-  -
-  - '
-  -
-  - ' colspan="2"'
-  - ' colspan="3"'
-  - ' colspan="4"'
-  - ' colspan="5"'
-  - ' colspan="6"'
-  - ' colspan="7"'
-  - ' colspan="8"'
-  - ' colspan="9"'
-  - ' colspan="10"'
-  - ' colspan="11"'
-  - ' colspan="12"'
-  - ' colspan="13"'
-  - ' colspan="14"'
-  - ' colspan="15"'
-  - ' colspan="16"'
-  - ' colspan="17"'
-  - ' colspan="18"'
-  - ' colspan="19"'
-  - ' colspan="20"'
-  - ' rowspan="2"'
-  - ' rowspan="3"'
-  - ' rowspan="4"'
-  - ' rowspan="5"'
-  - ' rowspan="6"'
-  - ' rowspan="7"'
-  - ' rowspan="8"'
-  - ' rowspan="9"'
-  - ' rowspan="10"'
-  - ' rowspan="11"'
-  - ' rowspan="12"'
-  - ' rowspan="13"'
-  - ' rowspan="14"'
-  - ' rowspan="15"'
-  - ' rowspan="16"'
-  - ' rowspan="17"'
-  - ' rowspan="18"'
-  - ' rowspan="19"'
-  - ' rowspan="20"'
diff --git a/slanet_plus_table/requirements.txt b/slanet_plus_table/requirements.txt
deleted file mode 100644
index cb794ff..0000000
--- a/slanet_plus_table/requirements.txt
+++ /dev/null
@@ -1,6 +0,0 @@
---extra-index-url=https://www.paddlepaddle.org.cn/packages/stable/cpu/
-opencv_python>=4.5.1.48
-numpy>=1.21.6,<2
-paddlepaddle==3.0.0b0
-Pillow
-requests
diff --git a/slanet_plus_table/setup.py b/slanet_plus_table/setup.py
deleted file mode 100644
index 8f20759..0000000
--- a/slanet_plus_table/setup.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
-import sys
-
-import setuptools
-
-from setuptools.command.install import install
-
-MODULE_NAME = "slanet_plus_table"
-
-
-
-setuptools.setup(
-    name=MODULE_NAME,
-    version="0.0.2",
-    platforms="Any",
-    long_description="simplify paddlex slanet plus table use",
-    long_description_content_type="text/markdown",
-    description="Tools for parsing table structures based paddlepaddle.",
-    author="jockerK",
-    author_email="xinyijianggo@gmail.com",
-    url="https://github.com/RapidAI/RapidTable",
-    license="Apache-2.0",
-    include_package_data=True,
-    install_requires=[
-        "paddlepaddle==3.0.0b0",
-        "PyYAML>=6.0",
-        "opencv_python>=4.5.1.48",
-        "numpy>=1.21.6",
-        "Pillow",
-    ],
-    packages=[
-        MODULE_NAME,
-        f"{MODULE_NAME}.models",
-        f"{MODULE_NAME}.table_matcher",
-        f"{MODULE_NAME}.table_structure",
-    ],
-    package_data={"": ["inference.pdiparams","inference.pdmodel"]},
-    keywords=["ppstructure,table,rapidocr,rapid_table"],
-    classifiers=[
-        "Programming Language :: Python :: 3.8",
-        "Programming Language :: Python :: 3.9",
-        "Programming Language :: Python :: 3.10",
-        "Programming Language :: Python :: 3.11",
-        "Programming Language :: Python :: 3.12",
-    ],
-    python_requires=">=3.8",
-)
diff --git a/slanet_plus_table/table_matcher/__init__.py b/slanet_plus_table/table_matcher/__init__.py
deleted file mode 100644
index 3cebc9e..0000000
--- a/slanet_plus_table/table_matcher/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# -*- encoding: utf-8 -*-
-from .matcher import TableMatch
diff --git a/slanet_plus_table/table_matcher/matcher.py b/slanet_plus_table/table_matcher/matcher.py
deleted file mode 100644
index b930c70..0000000
--- a/slanet_plus_table/table_matcher/matcher.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -*- encoding: utf-8 -*-
-import numpy as np
-
-from .utils import compute_iou, distance
-
-
-class TableMatch:
-    def __init__(self, filter_ocr_result=True, use_master=False):
-        self.filter_ocr_result = filter_ocr_result
-        self.use_master = use_master
-
-    def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
-        if self.filter_ocr_result:
-            dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res)
-        matched_index = self.match_result(dt_boxes, pred_bboxes)
-        pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
-        return pred_html
-
-    def match_result(self, dt_boxes, pred_bboxes):
-        matched = {}
-        for i, gt_box in enumerate(dt_boxes):
-            distances = []
-            for j, pred_box in enumerate(pred_bboxes):
-                if len(pred_box) == 8:
-                    pred_box = [
-                        np.min(pred_box[0::2]),
-                        np.min(pred_box[1::2]),
-                        np.max(pred_box[0::2]),
-                        np.max(pred_box[1::2]),
-                    ]
-                distances.append(
-                    (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
-                )  # compute iou and l1 distance
-            sorted_distances = distances.copy()
-            # select det box by iou and l1 distance
-            sorted_distances = sorted(
-                sorted_distances, key=lambda item: (item[1], item[0])
-            )
-            if distances.index(sorted_distances[0]) not in matched.keys():
-                matched[distances.index(sorted_distances[0])] = [i]
-            else:
-                matched[distances.index(sorted_distances[0])].append(i)
-        return matched
-
-    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
-        end_html = []
-        td_index = 0
-        for tag in pred_structures:
-            if "</td>" not in tag:
-                end_html.append(tag)
-                continue
-
-            if "<td></td>" == tag:
-                end_html.extend("<td>")
-
-            if td_index in matched_index.keys():
-                b_with = False
-                if (
-                    "<b>" in ocr_contents[matched_index[td_index][0]]
-                    and len(matched_index[td_index]) > 1
-                ):
-                    b_with = True
-                    end_html.extend("<b>")
-
-                for i, td_index_index in enumerate(matched_index[td_index]):
-                    content = ocr_contents[td_index_index][0]
-                    if len(matched_index[td_index]) > 1:
-                        if len(content) == 0:
-                            continue
-
-                        if content[0] == " ":
-                            content = content[1:]
-
-                        if "<b>" in content:
-                            content = content[3:]
-
-                        if "</b>" in content:
-                            content = content[:-4]
-
-                        if len(content) == 0:
-                            continue
-
-                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
-                            content += " "
-                    end_html.extend(content)
-
-                if b_with:
-                    end_html.extend("</b>")
-
-            if "<td></td>" == tag:
-                end_html.append("</td>")
-            else:
-                end_html.append(tag)
-
-            td_index += 1
-
-        # Filter elements
-        filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
-        end_html = [v for v in end_html if v not in filter_elements]
-        return "".join(end_html), end_html
-
-    def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
-        y1 = pred_bboxes[:, 1::2].min()
-        new_dt_boxes = []
-        new_rec_res = []
-
-        for box, rec in zip(dt_boxes, rec_res):
-            if np.max(box[1::2]) < y1:
-                continue
-            new_dt_boxes.append(box)
-            new_rec_res.append(rec)
-        return new_dt_boxes, new_rec_res
diff --git a/slanet_plus_table/table_matcher/utils.py b/slanet_plus_table/table_matcher/utils.py
deleted file mode 100644
index de55fe1..0000000
--- a/slanet_plus_table/table_matcher/utils.py
+++ /dev/null
@@ -1,36 +0,0 @@
-def distance(box_1, box_2):
-    x1, y1, x2, y2 = box_1
-    x3, y3, x4, y4 = box_2
-    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
-    dis_2 = abs(x3 - x1) + abs(y3 - y1)
-    dis_3 = abs(x4 - x2) + abs(y4 - y2)
-    return dis + min(dis_2, dis_3)
-
-
-def compute_iou(rec1, rec2):
-    """
-    computing IoU
-    :param rec1: (y0, x0, y1, x1), which reflects
-            (top, left, bottom, right)
-    :param rec2: (y0, x0, y1, x1)
-    :return: scala value of IoU
-    """
-    # computing area of each rectangles
-    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
-    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
-
-    # computing the sum_area
-    sum_area = S_rec1 + S_rec2
-
-    # find the each edge of intersect rectangle
-    left_line = max(rec1[1], rec2[1])
-    right_line = min(rec1[3], rec2[3])
-    top_line = max(rec1[0], rec2[0])
-    bottom_line = min(rec1[2], rec2[2])
-
-    # judge if there is an intersect
-    if left_line >= right_line or top_line >= bottom_line:
-        return 0.0
-    else:
-        intersect = (right_line - left_line) * (bottom_line - top_line)
-        return (intersect / (sum_area - intersect)) * 1.0
diff --git a/slanet_plus_table/table_structure/__init__.py b/slanet_plus_table/table_structure/__init__.py
deleted file mode 100644
index 9da248f..0000000
--- a/slanet_plus_table/table_structure/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# -*- encoding: utf-8 -*-
-from .table_structure import TableStructurer
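
The deleted matcher above pairs each OCR text box with the predicted cell whose `(1 - IoU, distance)` tuple is smallest. A tiny worked example with hypothetical `[x_min, y_min, x_max, y_max]` boxes; the import path assumes the identically named helpers that remain under `rapid_table.table_matcher`, so adjust it if that module differs:

```python
from rapid_table.table_matcher.utils import compute_iou, distance  # assumed path

ocr_box = [12, 10, 98, 40]    # detected text line
cell_a  = [10, 8, 100, 42]    # candidate cell enclosing it
cell_b  = [10, 50, 100, 90]   # candidate cell one row below

print(compute_iou(ocr_box, cell_a))   # ~0.84 -> strong overlap
print(compute_iou(ocr_box, cell_b))   # 0.0   -> no overlap
print(distance(ocr_box, cell_a))      # 12    -> small coordinate gap
print(distance(ocr_box, cell_b))      # 136   -> far away
# match_result sorts candidates by (1 - IoU, distance), so cell_a wins.
```
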
diff --git a/slanet_plus_table/table_structure/table_structure.py b/slanet_plus_table/table_structure/table_structure.py
deleted file mode 100644
index cd7af51..0000000
--- a/slanet_plus_table/table_structure/table_structure.py
+++ /dev/null
@@ -1,178 +0,0 @@
-import time
-import numpy as np
-from .utils import TablePredictor, TablePreprocess, TableLabelDecode
-
-
-# class SLANetPlus:
-#     def __init__(self, model_dir, model_prefix="inference"):
-#         self.preprocess_op = TablePreprocess()
-#
-#         self.mean=[0.485, 0.456, 0.406]
-#         self.std=[0.229, 0.224, 0.225]
-#         self.target_img_size = [488, 488]
-#         self.scale=1 / 255
-#         self.order="hwc"
-#         self.img_loader = LoadImage()
-#         self.target_size = 488
-#         self.pad_color = 0
-#         self.predictor = TablePredictor(model_dir, model_prefix)
-#         dict_character=['sos', '', '', '', '', '', '', '', '', ' colspan="2"', ' colspan="3"', ' colspan="4"', ' colspan="5"', ' colspan="6"', ' colspan="7"', ' colspan="8"', ' colspan="9"', ' colspan="10"', ' colspan="11"', ' colspan="12"', ' colspan="13"', ' colspan="14"', ' colspan="15"', ' colspan="16"', ' colspan="17"', ' colspan="18"', ' colspan="19"', ' colspan="20"', ' rowspan="2"', ' rowspan="3"', ' rowspan="4"', ' rowspan="5"', ' rowspan="6"', ' rowspan="7"', ' rowspan="8"', ' rowspan="9"', ' rowspan="10"', ' rowspan="11"', ' rowspan="12"', ' rowspan="13"', ' rowspan="14"', ' rowspan="15"', ' rowspan="16"', ' rowspan="17"', ' rowspan="18"', ' rowspan="19"', ' rowspan="20"', '', 'eos']
-#         self.beg_str = "sos"
-#         self.end_str = "eos"
-#         self.dict = {}
-#         self.table_matcher = TableMatch()
-#         for i, char in enumerate(dict_character):
-#             self.dict[char] = i
-#         self.character = dict_character
-#         self.td_token = ["", ""]
-#
-#     def call(self, img):
-#         starttime = time.time()
-#         data = {"image": img}
-#         data = self.preprocess_op(data)
-#         img = data[0]
-#         if img is None:
-#             return None, 0
-#         img = np.expand_dims(img, axis=0)
-#         img = img.copy()
-#     def __call__(self, img, ocr_result):
-#         img = self.img_loader(img)
-#         h, w = img.shape[:2]
-#         n_img, h_resize, w_resize = self.resize(img)
-#         n_img = self.normalize(n_img)
-#         n_img = self.pad(n_img)
-#         n_img = n_img.transpose((2, 0, 1))
-#         n_img = np.expand_dims(n_img, axis=0)
-#         start = time.time()
-#         batch_output = self.predictor(n_img)
-#         elapse_time = time.time() - start
-#         ori_img_size = [[w, h]]
-#         output = self.decode(batch_output, ori_img_size)[0]
-#         corners = np.stack(output['bbox'], axis=0)
-#         dt_boxes, rec_res = get_boxes_recs(ocr_result, h, w)
-#         pred_html = self.table_matcher(output['structure'], convert_corners_to_bounding_boxes(corners), dt_boxes, rec_res)
-#         return pred_html,output['bbox'], elapse_time
-#     def resize(self, img):
-#         h, w = img.shape[:2]
-#         scale = self.target_size / max(h, w)
-#         h_resize = round(h * scale)
-#         w_resize = round(w * scale)
-#         resized_img = cv2.resize(img, (w_resize, h_resize), interpolation=cv2.INTER_LINEAR)
-#         return resized_img, h_resize, w_resize
-#     def pad(self, img):
-#         h, w = img.shape[:2]
-#         tw, th = self.target_img_size
-#         ph = th - h
-#         pw = tw - w
-#         pad = (0, ph, 0, pw)
-#         chns = 1 if img.ndim == 2 else img.shape[2]
-#         im = cv2.copyMakeBorder(img, *pad, cv2.BORDER_CONSTANT, value=(self.pad_color,) * chns)
-#         return im
-#     def normalize(self, img):
-#         img = img.astype("float32", copy=False)
-#         img *= self.scale
-#         img -= self.mean
-#         img /= self.std
-#         return img
-#
-#
-#     def decode(self, pred, ori_img_size):
-#         bbox_preds, structure_probs = [], []
-#         for bbox_pred, stru_prob in pred:
-#             bbox_preds.append(bbox_pred)
-#             structure_probs.append(stru_prob)
-#         bbox_preds = np.array(bbox_preds)
-#         structure_probs = np.array(structure_probs)
-#
-#         bbox_list, structure_str_list, structure_score = self.decode_single(
-#             structure_probs, bbox_preds, [self.target_img_size], ori_img_size
-#         )
-#         structure_str_list = [
-#             (
-#                 ["", "", ""]
-#                 + structure
-#                 + ["", "", ""]
-#             )
-#             for structure in structure_str_list
-#         ]
-#         return [
-#             {"bbox": bbox, "structure": structure, "structure_score": structure_score}
-#             for bbox, structure in zip(bbox_list, structure_str_list)
-#         ]
-#
-#
-#     def decode_single(self, structure_probs, bbox_preds, padding_size, ori_img_size):
-#         """convert text-label into text-index."""
-#         ignored_tokens = [self.beg_str, self.end_str]
-#         end_idx = self.dict[self.end_str]
-#
-#         structure_idx = structure_probs.argmax(axis=2)
-#         structure_probs = structure_probs.max(axis=2)
-#
-#         structure_batch_list = []
-#         bbox_batch_list = []
-#         batch_size = len(structure_idx)
-#         for batch_idx in range(batch_size):
-#             structure_list = []
-#             bbox_list = []
-#             score_list = []
-#             for idx in range(len(structure_idx[batch_idx])):
-#                 char_idx = int(structure_idx[batch_idx][idx])
-#                 if idx > 0 and char_idx == end_idx:
-#                     break
-#                 if char_idx in ignored_tokens:
-#                     continue
-#                 text = self.character[char_idx]
-#                 if text in self.td_token:
-#                     bbox = bbox_preds[batch_idx, idx]
-#                     bbox = self._bbox_decode(
-#                         bbox, padding_size[batch_idx], ori_img_size[batch_idx]
-#                     )
-#                     bbox_list.append(bbox.astype(int))
-#                 structure_list.append(text)
-#                 score_list.append(structure_probs[batch_idx, idx])
-#             structure_batch_list.append(structure_list)
-#             structure_score = np.mean(score_list)
-#             bbox_batch_list.append(bbox_list)
-#
-#         return bbox_batch_list, structure_batch_list, structure_score
-#
-#     def _bbox_decode(self, bbox, padding_shape, ori_shape):
-#
-#         pad_w, pad_h = padding_shape
-#         w, h = ori_shape
-#         ratio_w = pad_w / w
-#         ratio_h = pad_h / h
-#         ratio = min(ratio_w, ratio_h)
-#
-#         bbox[0::2] *= pad_w
-#         bbox[1::2] *= pad_h
-#         bbox[0::2] /= ratio
-#         bbox[1::2] /= ratio
-#
-#         return bbox
-
-
-class TableStructurer:
-    def __init__(self, model_path: str):
-        self.preprocess_op = TablePreprocess()
-        self.predictor = TablePredictor(model_path)
-        self.character = ['', '', '', '', '', '', '', '', ' colspan="2"', ' colspan="3"', ' colspan="4"', ' colspan="5"', ' colspan="6"', ' colspan="7"', ' colspan="8"', ' colspan="9"', ' colspan="10"', ' colspan="11"', ' colspan="12"', ' colspan="13"', ' colspan="14"', ' colspan="15"', ' colspan="16"', ' colspan="17"', ' colspan="18"', ' colspan="19"', ' colspan="20"', ' rowspan="2"', ' rowspan="3"', ' rowspan="4"', ' rowspan="5"', ' rowspan="6"', ' rowspan="7"', ' rowspan="8"', ' rowspan="9"', ' rowspan="10"', ' rowspan="11"', ' rowspan="12"', ' rowspan="13"', ' rowspan="14"', ' rowspan="15"', ' rowspan="16"', ' rowspan="17"', ' rowspan="18"', ' rowspan="19"', ' rowspan="20"', '']
-        self.postprocess_op = TableLabelDecode(self.character)
-
-    def __call__(self, img):
-        start_time = time.time()
-        data = {"image": img}
-        h, w = img.shape[:2]
-        ori_img_size = [[w, h]]
-        data = self.preprocess_op(data)
-        img = data[0]
-        if img is None:
-            return None, 0
-        img = np.expand_dims(img, axis=0)
-        img = img.copy()
-        cur_img_size = [[488, 488]]
-        outputs = self.predictor(img)
-        output = self.postprocess_op(outputs, cur_img_size, ori_img_size)[0]
-        elapse = time.time() - start_time
-        return output["structure"], np.stack(output["bbox"]), elapse
diff --git a/slanet_plus_table/table_structure/utils.py b/slanet_plus_table/table_structure/utils.py
deleted file mode 100644
index f1cf7a0..0000000
--- a/slanet_plus_table/table_structure/utils.py
+++ /dev/null
@@ -1,331 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -*- encoding: utf-8 -*-
-# @Author: Jocker1212
-# @Contact: xinyijianggo@gmail.com
-from pathlib import Path
-
-import cv2
-import numpy as np
-from paddle.inference import Config, create_predictor
-
-class TablePredictor:
-    def __init__(self, model_dir, model_prefix="inference"):
-        model_file = f"{model_dir}/{model_prefix}.pdmodel"
-        params_file = f"{model_dir}/{model_prefix}.pdiparams"
-        config = Config(model_file, params_file)
-        config.disable_gpu()
-        config.disable_glog_info()
-        config.enable_new_ir(True)
-        config.enable_new_executor(True)
-        config.enable_memory_optim()
-        config.switch_ir_optim(True)
-        # Disable feed, fetch OP, needed by zero_copy_run
-        config.switch_use_feed_fetch_ops(False)
-        predictor = create_predictor(config)
-        self.config = config
-        self.predictor = predictor
-        # Get input and output handlers
-        input_names = predictor.get_input_names()
-        self.input_names = input_names.sort()
-        self.input_handlers = []
-        self.output_handlers = []
-        for input_name in input_names:
-            input_handler = predictor.get_input_handle(input_name)
-            self.input_handlers.append(input_handler)
-        self.output_names = predictor.get_output_names()
-        for output_name in self.output_names:
-            output_handler = predictor.get_output_handle(output_name)
-            self.output_handlers.append(output_handler)
-
-    def __call__(self, batch_imgs):
-        self.input_handlers[0].reshape(batch_imgs.shape)
-        self.input_handlers[0].copy_from_cpu(batch_imgs)
-        self.predictor.run()
-        output = []
-        for out_tensor in self.output_handlers:
-            batch = out_tensor.copy_to_cpu()
-            output.append(batch)
-        return self.format_output(output)
-
-    def format_output(self, pred):
-        return [res for res in zip(*pred)]
-
-class TableLabelDecode:
-    """decode the table model outputs(probs) to character str"""
-    def __init__(self, dict_character=[], merge_no_span_structure=True, **kwargs):
-
-        if merge_no_span_structure:
-            if "" not in dict_character:
-                dict_character.append("")
-            if "" in dict_character:
-                dict_character.remove("")
-
-        dict_character = self.add_special_char(dict_character)
-        self.dict = {}
-        for i, char in enumerate(dict_character):
-            self.dict[char] = i
-        self.character = dict_character
-        self.td_token = ["", ""]
-
-    def add_special_char(self, dict_character):
-        """add_special_char"""
-        self.beg_str = "sos"
-        self.end_str = "eos"
-        dict_character = [self.beg_str] + dict_character + [self.end_str]
-        return dict_character
-
-    def get_ignored_tokens(self):
-        """get_ignored_tokens"""
-        beg_idx = self.get_beg_end_flag_idx("beg")
-        end_idx = self.get_beg_end_flag_idx("end")
-        return [beg_idx, end_idx]
-
-    def get_beg_end_flag_idx(self, beg_or_end):
-        """get_beg_end_flag_idx"""
-        if beg_or_end == "beg":
-            idx = np.array(self.dict[self.beg_str])
-        elif beg_or_end == "end":
-            idx = np.array(self.dict[self.end_str])
-        else:
-            assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
-        return idx
-
-    def __call__(self, pred, img_size, ori_img_size):
-        """apply"""
-        bbox_preds, structure_probs = [], []
-        for bbox_pred, stru_prob in pred:
-            bbox_preds.append(bbox_pred)
-            structure_probs.append(stru_prob)
-        bbox_preds = np.array(bbox_preds)
-        structure_probs = np.array(structure_probs)
-
-        bbox_list, structure_str_list, structure_score = self.decode(
-            structure_probs, bbox_preds, img_size, ori_img_size
-        )
-        structure_str_list = [
-            (
-                ["", "", ""]
-                + structure
-                + ["", "", ""]
-            )
-            for structure in structure_str_list
-        ]
-        return [
-            {"bbox": bbox, "structure": structure, "structure_score": structure_score}
-            for bbox, structure in zip(bbox_list, structure_str_list)
-        ]
-
-    def decode(self, structure_probs, bbox_preds, padding_size, ori_img_size):
-        """convert text-label into text-index."""
-        ignored_tokens = self.get_ignored_tokens()
-        end_idx = self.dict[self.end_str]
-
-        structure_idx = structure_probs.argmax(axis=2)
-        structure_probs = structure_probs.max(axis=2)
-
-        structure_batch_list = []
-        bbox_batch_list = []
-        batch_size = len(structure_idx)
-        for batch_idx in range(batch_size):
-            structure_list = []
-            bbox_list = []
-            score_list = []
-            for idx in range(len(structure_idx[batch_idx])):
-                char_idx = int(structure_idx[batch_idx][idx])
-                if idx > 0 and char_idx == end_idx:
-                    break
-                if char_idx in ignored_tokens:
-                    continue
-                text = self.character[char_idx]
-                if text in self.td_token:
-                    bbox = bbox_preds[batch_idx, idx]
-                    bbox = self._bbox_decode(
-                        bbox, padding_size[batch_idx], ori_img_size[batch_idx]
-                    )
-                    bbox_list.append(bbox.astype(int))
-                structure_list.append(text)
-                score_list.append(structure_probs[batch_idx, idx])
-            structure_batch_list.append(structure_list)
-            structure_score = np.mean(score_list)
-            bbox_batch_list.append(bbox_list)
-
-        return bbox_batch_list, structure_batch_list, structure_score
-
-    def _bbox_decode(self, bbox, padding_shape, ori_shape):
-
-        pad_w, pad_h = padding_shape
-        w, h = ori_shape
-        ratio_w = pad_w / w
-        ratio_h = pad_h / h
-        ratio = min(ratio_w, ratio_h)
-
-        bbox[0::2] *= pad_w
-        bbox[1::2] *= pad_h
-        bbox[0::2] /= ratio
-        bbox[1::2] /= ratio
-
-        return bbox
-
-
-class TablePreprocess:
-    def __init__(self):
-        self.table_max_len = 488
-        self.build_pre_process_list()
-        self.ops = self.create_operators()
-
-    def __call__(self, data):
-        """transform"""
-        if self.ops is None:
-            self.ops = []
-
-        for op in self.ops:
-            data = op(data)
-            if data is None:
-                return None
-        return data
-
-    def create_operators(
-        self,
-    ):
-        """
-        create operators based on the config
-
-        Args:
-            params(list): a dict list, used to create some operators
-        """
-        assert isinstance(
-            self.pre_process_list, list
-        ), "operator config should be a list"
-        ops = []
-        for operator in self.pre_process_list:
-            assert (
-                isinstance(operator, dict) and len(operator) == 1
-            ), "yaml format error"
-            op_name = list(operator)[0]
-            param = {} if operator[op_name] is None else operator[op_name]
-            op = eval(op_name)(**param)
-            ops.append(op)
-        return ops
-
-    def build_pre_process_list(self):
-        resize_op = {
-            "ResizeTableImage": {
-                "max_len": self.table_max_len,
-            }
-        }
-        pad_op = {
-            "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
-        }
-        normalize_op = {
-            "NormalizeImage": {
-                "std": [0.229, 0.224, 0.225],
-                "mean": [0.485, 0.456, 0.406],
-                "scale": "1./255.",
-                "order": "hwc",
-            }
-        }
-        to_chw_op = {"ToCHWImage": None}
-        keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
-        self.pre_process_list = [
-            resize_op,
-            normalize_op,
-            pad_op,
-            to_chw_op,
-            keep_keys_op,
-        ]
-
-
-class ResizeTableImage:
-    def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
-        super(ResizeTableImage, self).__init__()
-        self.max_len = max_len
-        self.resize_bboxes = resize_bboxes
-        self.infer_mode = infer_mode
-
-    def __call__(self, data):
-        img = data["image"]
-        height, width = img.shape[0:2]
-        ratio = self.max_len / (max(height, width) * 1.0)
-        resize_h = int(height * ratio)
-        resize_w = int(width * ratio)
-        resize_img = cv2.resize(img, (resize_w, resize_h))
-        if self.resize_bboxes and not self.infer_mode:
-            data["bboxes"] = data["bboxes"] * ratio
-        data["image"] = resize_img
-        data["src_img"] = img
-        data["shape"] = np.array([height, width, ratio, ratio])
-        data["max_len"] = self.max_len
-        return data
-
-
-class PaddingTableImage:
-    def __init__(self, size, **kwargs):
-        super(PaddingTableImage, self).__init__()
-        self.size = size
-
-    def __call__(self, data):
-        img = data["image"]
-        pad_h, pad_w = self.size
-        padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
-        height, width = img.shape[0:2]
-        padding_img[0:height, 0:width, :] = img.copy()
-        data["image"] = padding_img
-        shape = data["shape"].tolist()
-        shape.extend([pad_h, pad_w])
-        data["shape"] = np.array(shape)
-        return data
-
-
-class NormalizeImage:
-    """normalize image such as substract mean, divide std"""
-
-    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
-        if isinstance(scale, str):
-            scale = eval(scale)
-        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
-        mean = mean if mean is not None else [0.485, 0.456, 0.406]
-        std = std if std is not None else [0.229, 0.224, 0.225]
-
-        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
-        self.mean = np.array(mean).reshape(shape).astype("float32")
-        self.std = np.array(std).reshape(shape).astype("float32")
-
-    def __call__(self, data):
-        img = np.array(data["image"])
-        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
-        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
-        return data
-
-
-class ToCHWImage:
-    """convert hwc image to chw image"""
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, data):
-        img = np.array(data["image"])
-        data["image"] = img.transpose((2, 0, 1))
-        return data
-
-class KeepKeys:
-    def __init__(self, keep_keys, **kwargs):
-        self.keep_keys = keep_keys
-
-    def __call__(self, data):
-        data_list = []
-        for key in self.keep_keys:
-            data_list.append(data[key])
-        return data_list
diff --git a/slanet_plus_table/utils.py b/slanet_plus_table/utils.py
deleted file mode 100644
index bb0736c..0000000
--- a/slanet_plus_table/utils.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from io import BytesIO
-from pathlib import Path
-from typing import Union, Optional
-
-import cv2
-import numpy as np
-from PIL import Image, UnidentifiedImageError
-
-InputType = Union[str, np.ndarray, bytes, Path]
-
-class LoadImage:
-    def __init__(
-        self,
-    ):
-        pass
-
-    def __call__(self, img: InputType) -> np.ndarray:
-        if not isinstance(img, InputType.__args__):
-            raise LoadImageError(
-                f"The img type {type(img)} does not in {InputType.__args__}"
-            )
-
-        img = self.load_img(img)
-
-        if img.ndim == 2:
-            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
-        if img.ndim == 3 and img.shape[2] == 4:
-            return self.cvt_four_to_three(img)
-
-        return img
-
-    def load_img(self, img: InputType) -> np.ndarray:
-        if isinstance(img, (str, Path)):
-            self.verify_exist(img)
-            try:
-                img = np.array(Image.open(img))
-                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-            except UnidentifiedImageError as e:
-                raise LoadImageError(f"cannot identify image file {img}") from e
-            return img
-
-        if isinstance(img, bytes):
-            img = np.array(Image.open(BytesIO(img)))
-            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-            return img
-
-        if isinstance(img, np.ndarray):
-            return img
-
-        raise LoadImageError(f"{type(img)} is not supported!")
-
-    @staticmethod
-    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
-        """RGBA → RGB"""
-        r, g, b, a = cv2.split(img)
-        new_img = cv2.merge((b, g, r))
-
-        not_a = cv2.bitwise_not(a)
-        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
-
-        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
-        new_img = cv2.add(new_img, not_a)
-        return new_img
-
-    @staticmethod
-    def verify_exist(file_path: Union[str, Path]):
-        if not Path(file_path).exists():
-            raise LoadImageError(f"{file_path} does not exist.")
-
-class LoadImageError(Exception):
-    pass
-
-class VisTable:
-    def __init__(
-        self,
-    ):
-        self.load_img = LoadImage()
-
-    def __call__(
-        self,
-        img_path: Union[str, Path],
-        table_html_str: str,
-        save_html_path: Optional[str] = None,
-        table_cell_bboxes: Optional[np.ndarray] = None,
-        save_drawed_path: Optional[str] = None,
-    ) -> None:
-        if save_html_path:
-            html_with_border = self.insert_border_style(table_html_str)
-            self.save_html(save_html_path, html_with_border)
-
-        if table_cell_bboxes is None:
-            return None
-
-        img = self.load_img(img_path)
-
-        dims_bboxes = table_cell_bboxes.shape[1]
-        if dims_bboxes == 4:
-            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
-        elif dims_bboxes == 8:
-            drawed_img = self.draw_polylines(img, table_cell_bboxes)
-        else:
-            raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")
-
-        if save_drawed_path:
-            self.save_img(save_drawed_path, drawed_img)
-
-        return drawed_img
-
-    def insert_border_style(self, table_html_str: str):
-        style_res = """"""
-        prefix_table, suffix_table = table_html_str.split("")
-        html_with_border = f"{prefix_table}{style_res}{suffix_table}"
-        return html_with_border
-
-    @staticmethod
-    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
-        img_copy = img.copy()
-        for box in boxes.astype(int):
-            x1, y1, x2, y2 = box
-            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
-        return img_copy
-
-    @staticmethod
-    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
-        img_copy = img.copy()
-        for point in points.astype(int):
-            point = point.reshape(4, 2)
-            cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
-        return img_copy
-
-    @staticmethod
-    def save_img(save_path: Union[str, Path], img: np.ndarray):
-        cv2.imwrite(str(save_path), img)
-
-    @staticmethod
-    def save_html(save_path: Union[str, Path], html: str):
-        with open(save_path, "w", encoding="utf-8") as f:
-            f.write(html)
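
Note that the `style_res` literals in this dump (here and in `rapid_table/utils.py` above) appear as empty strings because their inline HTML/CSS was stripped during extraction. The snippet below is a hypothetical stand-in for the kind of payload `insert_border_style` prepends, not the project's exact markup:

```python
# Hypothetical reconstruction of the stripped style_res payload; the real string in
# rapid_table/utils.py may differ. It simply injects collapsed-border CSS so the
# saved HTML table is visible when opened in a browser.
style_res = """<meta charset="UTF-8">
<style>
table { border-collapse: collapse; }
table, th, td { border: 1px solid black; }
</style>"""
```
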