diff --git a/.github/workflows/publish_whl.yml b/.github/workflows/publish_whl.yml
index f9f0ae6..f2ea7d5 100644
--- a/.github/workflows/publish_whl.yml
+++ b/.github/workflows/publish_whl.yml
@@ -56,7 +56,7 @@ jobs:
ZIP_NAME=${RESOURCES_URL##*/}
DIR_NAME=${ZIP_NAME%.*}
unzip $ZIP_NAME
- mv $DIR_NAME/en_ppstructure_mobile_v2_SLANet.onnx rapid_table/models/
+ mv $DIR_NAME/slanet-plus.onnx rapid_table/models/
python setup.py bdist_wheel ${{ github.event.head_commit.message }}
- name: Publish distribution 📦 to PyPI
diff --git a/README.md b/README.md
index 0a627b9..adfdb7d 100644
--- a/README.md
+++ b/README.md
@@ -19,13 +19,13 @@ RapidTable库是专门用来文档类图像的表格结构还原,结合RapidOC
Two categories of table recognition model are currently supported, Chinese and English; see the table below for details:
-slanet_plus is an upgraded version of the SLANet model built into PaddleX, with a large accuracy improvement, but paddle2onnx does not yet support converting it
+slanet_plus is an upgraded version of the SLANet model built into PaddleX, with a large accuracy improvement
| 模型类型 Model type | 模型名称 Model name | 模型大小 Model size |
|:--------------:|:--------------------------------------:| :------: |
| English | `en_ppstructure_mobile_v2_SLANet.onnx` | 7.3M |
| Chinese | `ch_ppstructure_mobile_v2_SLANet.onnx` | 7.4M |
- | slanet_plus Chinese | `inference.pdmodel` | 7.4M |
+ | slanet_plus Chinese | `slanet-plus.onnx` | 6.8M |
Model source: [PaddleOCR table recognition](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)
@@ -45,26 +45,23 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu
### Installation
-Because the model is small, the English table recognition model (`en_ppstructure_mobile_v2_SLANet.onnx`) is pre-packaged into the whl; for English table recognition you can install and use it directly.
+Because the model is small, the slanet-plus table recognition model (`slanet-plus.onnx`) is pre-packaged into the whl.
> ⚠️ Note: since `rapid_table>=v0.1.0`, the `rapidocr_onnxruntime` dependency is no longer force-bundled into `rapid_table`. Install the `rapidocr_onnxruntime` package yourself before use.
```bash
pip install rapidocr_onnxruntime
pip install rapid_table
-# Installing this pulls in paddlepaddle CPU 3.0.0b0
-#pip install slanet_plus_table
```
### Usage
#### Running a Python script
-The RapidTable class provides a model_path parameter, so you can point it at either of the two models above; the default is `en_ppstructure_mobile_v2_SLANet.onnx`. For example:
+The RapidTable class provides a model_path parameter, so you can point it at any of the models above; the default is `slanet-plus.onnx`. For example:
```python
-table_engine = RapidTable(model_path='ch_ppstructure_mobile_v2_SLANet.onnx')
-#table_engine = SLANetPlus()
+table_engine = RapidTable()
```
Full example:
@@ -72,11 +69,9 @@ table_engine = RapidTable(model_path='ch_ppstructure_mobile_v2_SLANet.onnx')
```python
from pathlib import Path
-from rapid_table import RapidTable
from rapid_table import RapidTable, VisTable
table_engine = RapidTable()
-#table_engine = SLANetPlus()
ocr_engine = RapidOCR()
viser = VisTable()
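For reference, the fragment above can be assembled into a self-contained script for the post-change API. This is a minimal sketch based only on the pieces shown in this diff; the image path is a placeholder, and it assumes `rapidocr_onnxruntime` is installed.

```python
from pathlib import Path

from rapidocr_onnxruntime import RapidOCR
from rapid_table import RapidTable, VisTable

table_engine = RapidTable()  # uses the bundled slanet-plus.onnx by default
ocr_engine = RapidOCR()
viser = VisTable()

img_path = "test_images/table.jpg"  # placeholder path

# Run OCR first, then feed its result to the table-structure engine.
ocr_result, _ = ocr_engine(img_path)
table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)

# Save the HTML and a visualization of the predicted cell boxes.
save_dir = Path("outputs")
save_dir.mkdir(parents=True, exist_ok=True)
save_html_path = save_dir / f"{Path(img_path).stem}.html"
save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path)
```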
diff --git a/rapid_table/main.py b/rapid_table/main.py
index e8907a3..02d6516 100644
--- a/rapid_table/main.py
+++ b/rapid_table/main.py
@@ -19,12 +19,13 @@
class RapidTable:
- def __init__(self, model_path: Optional[str] = None):
+ def __init__(self, model_path: Optional[str] = None, model_type: Optional[str] = None):
if model_path is None:
model_path = str(
- root_dir / "models" / "en_ppstructure_mobile_v2_SLANet.onnx"
+ root_dir / "models" / "slanet-plus.onnx"
)
-
+ model_type = "slanet-plus"
+ self.model_type = model_type
self.load_img = LoadImage()
self.table_structure = TableStructurer(model_path)
self.table_matcher = TableMatch()
@@ -54,6 +55,9 @@ def __call__(
dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img))
+ # Rescale the boxes predicted by the slanet-plus model back to the original image size
+ if self.model_type == "slanet-plus":
+ pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes)
pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
elapse = time.time() - s
@@ -76,7 +80,15 @@ def get_boxes_recs(
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
return dt_boxes, rec_res
-
+ def adapt_slanet_plus(self, img: np.ndarray, pred_bboxes: np.ndarray) -> np.ndarray:
+ h, w = img.shape[:2]
+ resized = 488
+ ratio = min(resized / h, resized / w)
+ w_ratio = resized / (w * ratio)
+ h_ratio = resized / (h * ratio)
+ pred_bboxes[:, 0::2] *= w_ratio
+ pred_bboxes[:, 1::2] *= h_ratio
+ return pred_bboxes
def main():
parser = argparse.ArgumentParser()
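To make the constructor change above concrete: when `model_path` is omitted, the class falls back to the bundled `slanet-plus.onnx` and sets `model_type` to `"slanet-plus"`, which is what enables the box rescaling in `adapt_slanet_plus`. A minimal sketch of both paths (the explicit path below is illustrative):

```python
from rapid_table import RapidTable

# Default: bundled slanet-plus.onnx; model_type is set to "slanet-plus" internally,
# so predicted cell boxes are rescaled back to the original image size.
engine_default = RapidTable()

# Explicit model file: model_type stays None, so no slanet-plus rescaling is applied.
# The path is illustrative; point it at a downloaded model file.
engine_en = RapidTable(model_path="models/en_ppstructure_mobile_v2_SLANet.onnx")
```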
diff --git a/rapid_table/utils.py b/rapid_table/utils.py
index bb75e42..cc860b3 100644
--- a/rapid_table/utils.py
+++ b/rapid_table/utils.py
@@ -114,7 +114,7 @@ def __call__(
return drawed_img
def insert_border_style(self, table_html_str: str):
- style_res = """"""
diff --git a/setup.py b/setup.py
index dea104a..0539473 100644
--- a/setup.py
+++ b/setup.py
@@ -53,7 +53,7 @@ def get_readme():
f"{MODULE_NAME}.table_matcher",
f"{MODULE_NAME}.table_structure",
],
- package_data={"": ["en_ppstructure_mobile_v2_SLANet.onnx"]},
+ package_data={"": ["slanet-plus.onnx"]},
keywords=["ppstructure,table,rapidocr,rapid_table"],
classifiers=[
"Programming Language :: Python :: 3.6",
diff --git a/slanet_plus_table/__init__.py b/slanet_plus_table/__init__.py
deleted file mode 100644
index 418f3c1..0000000
--- a/slanet_plus_table/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .main import SLANetPlus
-from .utils import VisTable
diff --git a/slanet_plus_table/main.py b/slanet_plus_table/main.py
deleted file mode 100644
index 14eeb62..0000000
--- a/slanet_plus_table/main.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import copy
-import importlib
-import time
-from pathlib import Path
-from typing import Optional, Union, List, Tuple
-
-import cv2
-import numpy as np
-
-from slanet_plus_table.table_matcher import TableMatch
-from slanet_plus_table.table_structure import TableStructurer
-from slanet_plus_table.utils import LoadImage, VisTable
-
-root_dir = Path(__file__).resolve().parent
-
-
-class SLANetPlus:
- def __init__(self, model_path: Optional[str] = None):
- if model_path is None:
- model_path = str(
- root_dir / "models"
- )
-
- self.load_img = LoadImage()
- self.table_structure = TableStructurer(model_path)
- self.table_matcher = TableMatch()
-
- try:
- self.ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
- except ModuleNotFoundError:
- self.ocr_engine = None
-
- def __call__(
- self,
- img_content: Union[str, np.ndarray, bytes, Path],
- ocr_result: List[Union[List[List[float]], str, str]] = None,
- ) -> Tuple[str, float]:
- if self.ocr_engine is None and ocr_result is None:
- raise ValueError(
- "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
- )
-
- img = self.load_img(img_content)
-
- s = time.time()
- h, w = img.shape[:2]
-
- if ocr_result is None:
- ocr_result, _ = self.ocr_engine(img)
- dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
-
- pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img))
- pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
-
- elapse = time.time() - s
- return pred_html, pred_bboxes, elapse
-
- def get_boxes_recs(
- self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
- ) -> Tuple[np.ndarray, Tuple[str, str]]:
- dt_boxes, rec_res, scores = list(zip(*ocr_result))
- rec_res = list(zip(rec_res, scores))
-
- r_boxes = []
- for box in dt_boxes:
- box = np.array(box)
- x_min = max(0, box[:, 0].min() - 1)
- x_max = min(w, box[:, 0].max() + 1)
- y_min = max(0, box[:, 1].min() - 1)
- y_max = min(h, box[:, 1].max() + 1)
- box = [x_min, y_min, x_max, y_max]
- r_boxes.append(box)
- dt_boxes = np.array(r_boxes)
- return dt_boxes, rec_res
-
-
-if __name__ == '__main__':
- slanet_table = SLANetPlus()
- img_path = "D:\pythonProjects\TableStructureRec\outputs\\benchmark\\border_left_7267_OEJGHZF525Q011X2ZC34.jpg"
- img = cv2.imread(img_path)
- try:
- ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
- except ModuleNotFoundError as exc:
- raise ModuleNotFoundError(
- "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
- ) from exc
- ocr_result, _ = ocr_engine(img)
- table_html_str, table_cell_bboxes, elapse = slanet_table(img, ocr_result)
-
- viser = VisTable()
-
- img_path = Path(img_path)
-
- save_dir = "outputs"
- save_html_path = f"{save_dir}/{Path(img_path).stem}.html"
- save_drawed_path = f"{save_dir}/vis_{Path(img_path).name}"
- viser(
- img_path,
- table_html_str,
- save_html_path,
- table_cell_bboxes,
- save_drawed_path,
- )
diff --git a/slanet_plus_table/models/inference.yml b/slanet_plus_table/models/inference.yml
deleted file mode 100644
index dd6db0d..0000000
--- a/slanet_plus_table/models/inference.yml
+++ /dev/null
@@ -1,106 +0,0 @@
-Hpi:
- backend_config:
- paddle_infer:
- cpu_num_threads: 8
- enable_log_info: false
- selected_backends:
- cpu: paddle_infer
- gpu: paddle_infer
- supported_backends:
- cpu:
- - paddle_infer
- gpu:
- - paddle_infer
-Global:
- model_name: SLANet_plus
-PreProcess:
- transform_ops:
- - DecodeImage:
- channel_first: false
- img_mode: BGR
- - TableLabelEncode:
- learn_empty_box: false
- loc_reg_num: 8
- max_text_length: 500
- merge_no_span_structure: true
- replace_empty_cell_token: false
- - TableBoxEncode:
- in_box_format: xyxyxyxy
- out_box_format: xyxyxyxy
- - ResizeTableImage:
- max_len: 488
- - NormalizeImage:
- mean:
- - 0.485
- - 0.456
- - 0.406
- order: hwc
- scale: 1./255.
- std:
- - 0.229
- - 0.224
- - 0.225
- - PaddingTableImage:
- size:
- - 488
- - 488
- - ToCHWImage: null
- - KeepKeys:
- keep_keys:
- - image
- - structure
- - bboxes
- - bbox_masks
- - shape
-PostProcess:
- name: TableLabelDecode
- merge_no_span_structure: true
- character_dict:
- <thead>
- </thead>
- <tbody>
- </tbody>
- <tr>
- </tr>
- <td
- '>'
- </td>
- - ' colspan="2"'
- - ' colspan="3"'
- - ' colspan="4"'
- - ' colspan="5"'
- - ' colspan="6"'
- - ' colspan="7"'
- - ' colspan="8"'
- - ' colspan="9"'
- - ' colspan="10"'
- - ' colspan="11"'
- - ' colspan="12"'
- - ' colspan="13"'
- - ' colspan="14"'
- - ' colspan="15"'
- - ' colspan="16"'
- - ' colspan="17"'
- - ' colspan="18"'
- - ' colspan="19"'
- - ' colspan="20"'
- - ' rowspan="2"'
- - ' rowspan="3"'
- - ' rowspan="4"'
- - ' rowspan="5"'
- - ' rowspan="6"'
- - ' rowspan="7"'
- - ' rowspan="8"'
- - ' rowspan="9"'
- - ' rowspan="10"'
- - ' rowspan="11"'
- - ' rowspan="12"'
- - ' rowspan="13"'
- - ' rowspan="14"'
- - ' rowspan="15"'
- - ' rowspan="16"'
- - ' rowspan="17"'
- - ' rowspan="18"'
- - ' rowspan="19"'
- - ' rowspan="20"'
diff --git a/slanet_plus_table/requirements.txt b/slanet_plus_table/requirements.txt
deleted file mode 100644
index cb794ff..0000000
--- a/slanet_plus_table/requirements.txt
+++ /dev/null
@@ -1,6 +0,0 @@
---extra-index-url=https://www.paddlepaddle.org.cn/packages/stable/cpu/
-opencv_python>=4.5.1.48
-numpy>=1.21.6,<2
-paddlepaddle==3.0.0b0
-Pillow
-requests
diff --git a/slanet_plus_table/setup.py b/slanet_plus_table/setup.py
deleted file mode 100644
index 8f20759..0000000
--- a/slanet_plus_table/setup.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
-import sys
-
-import setuptools
-
-from setuptools.command.install import install
-
-MODULE_NAME = "slanet_plus_table"
-
-
-
-setuptools.setup(
- name=MODULE_NAME,
- version="0.0.2",
- platforms="Any",
- long_description="simplify paddlex slanet plus table use",
- long_description_content_type="text/markdown",
- description="Tools for parsing table structures based paddlepaddle.",
- author="jockerK",
- author_email="xinyijianggo@gmail.com",
- url="https://github.com/RapidAI/RapidTable",
- license="Apache-2.0",
- include_package_data=True,
- install_requires=[
- "paddlepaddle==3.0.0b0",
- "PyYAML>=6.0",
- "opencv_python>=4.5.1.48",
- "numpy>=1.21.6",
- "Pillow",
- ],
- packages=[
- MODULE_NAME,
- f"{MODULE_NAME}.models",
- f"{MODULE_NAME}.table_matcher",
- f"{MODULE_NAME}.table_structure",
- ],
- package_data={"": ["inference.pdiparams","inference.pdmodel"]},
- keywords=["ppstructure,table,rapidocr,rapid_table"],
- classifiers=[
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.11",
- "Programming Language :: Python :: 3.12",
- ],
- python_requires=">=3.8",
-)
diff --git a/slanet_plus_table/table_matcher/__init__.py b/slanet_plus_table/table_matcher/__init__.py
deleted file mode 100644
index 3cebc9e..0000000
--- a/slanet_plus_table/table_matcher/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# -*- encoding: utf-8 -*-
-from .matcher import TableMatch
diff --git a/slanet_plus_table/table_matcher/matcher.py b/slanet_plus_table/table_matcher/matcher.py
deleted file mode 100644
index b930c70..0000000
--- a/slanet_plus_table/table_matcher/matcher.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -*- encoding: utf-8 -*-
-import numpy as np
-
-from .utils import compute_iou, distance
-
-
-class TableMatch:
- def __init__(self, filter_ocr_result=True, use_master=False):
- self.filter_ocr_result = filter_ocr_result
- self.use_master = use_master
-
- def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
- if self.filter_ocr_result:
- dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res)
- matched_index = self.match_result(dt_boxes, pred_bboxes)
- pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
- return pred_html
-
- def match_result(self, dt_boxes, pred_bboxes):
- matched = {}
- for i, gt_box in enumerate(dt_boxes):
- distances = []
- for j, pred_box in enumerate(pred_bboxes):
- if len(pred_box) == 8:
- pred_box = [
- np.min(pred_box[0::2]),
- np.min(pred_box[1::2]),
- np.max(pred_box[0::2]),
- np.max(pred_box[1::2]),
- ]
- distances.append(
- (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
- ) # compute iou and l1 distance
- sorted_distances = distances.copy()
- # select det box by iou and l1 distance
- sorted_distances = sorted(
- sorted_distances, key=lambda item: (item[1], item[0])
- )
- if distances.index(sorted_distances[0]) not in matched.keys():
- matched[distances.index(sorted_distances[0])] = [i]
- else:
- matched[distances.index(sorted_distances[0])].append(i)
- return matched
-
- def get_pred_html(self, pred_structures, matched_index, ocr_contents):
- end_html = []
- td_index = 0
- for tag in pred_structures:
- if "" not in tag:
- end_html.append(tag)
- continue
-
- if " | " == tag:
- end_html.extend("")
-
- if td_index in matched_index.keys():
- b_with = False
- if (
- "" in ocr_contents[matched_index[td_index][0]]
- and len(matched_index[td_index]) > 1
- ):
- b_with = True
- end_html.extend("")
-
- for i, td_index_index in enumerate(matched_index[td_index]):
- content = ocr_contents[td_index_index][0]
- if len(matched_index[td_index]) > 1:
- if len(content) == 0:
- continue
-
- if content[0] == " ":
- content = content[1:]
-
- if "" in content:
- content = content[3:]
-
- if "" in content:
- content = content[:-4]
-
- if len(content) == 0:
- continue
-
- if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
- content += " "
- end_html.extend(content)
-
- if b_with:
- end_html.extend("")
-
- if " | | " == tag:
- end_html.append("")
- else:
- end_html.append(tag)
-
- td_index += 1
-
- # Filter <thead></thead><tbody></tbody> elements
- filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
- end_html = [v for v in end_html if v not in filter_elements]
- return "".join(end_html), end_html
-
- def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
- y1 = pred_bboxes[:, 1::2].min()
- new_dt_boxes = []
- new_rec_res = []
-
- for box, rec in zip(dt_boxes, rec_res):
- if np.max(box[1::2]) < y1:
- continue
- new_dt_boxes.append(box)
- new_rec_res.append(rec)
- return new_dt_boxes, new_rec_res
diff --git a/slanet_plus_table/table_matcher/utils.py b/slanet_plus_table/table_matcher/utils.py
deleted file mode 100644
index de55fe1..0000000
--- a/slanet_plus_table/table_matcher/utils.py
+++ /dev/null
@@ -1,36 +0,0 @@
-def distance(box_1, box_2):
- x1, y1, x2, y2 = box_1
- x3, y3, x4, y4 = box_2
- dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
- dis_2 = abs(x3 - x1) + abs(y3 - y1)
- dis_3 = abs(x4 - x2) + abs(y4 - y2)
- return dis + min(dis_2, dis_3)
-
-
-def compute_iou(rec1, rec2):
- """
- computing IoU
- :param rec1: (y0, x0, y1, x1), which reflects
- (top, left, bottom, right)
- :param rec2: (y0, x0, y1, x1)
- :return: scala value of IoU
- """
- # computing area of each rectangles
- S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
- S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
-
- # computing the sum_area
- sum_area = S_rec1 + S_rec2
-
- # find the each edge of intersect rectangle
- left_line = max(rec1[1], rec2[1])
- right_line = min(rec1[3], rec2[3])
- top_line = max(rec1[0], rec2[0])
- bottom_line = min(rec1[2], rec2[2])
-
- # judge if there is an intersect
- if left_line >= right_line or top_line >= bottom_line:
- return 0.0
- else:
- intersect = (right_line - left_line) * (bottom_line - top_line)
- return (intersect / (sum_area - intersect)) * 1.0
diff --git a/slanet_plus_table/table_structure/__init__.py b/slanet_plus_table/table_structure/__init__.py
deleted file mode 100644
index 9da248f..0000000
--- a/slanet_plus_table/table_structure/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# -*- encoding: utf-8 -*-
-from .table_structure import TableStructurer
diff --git a/slanet_plus_table/table_structure/table_structure.py b/slanet_plus_table/table_structure/table_structure.py
deleted file mode 100644
index cd7af51..0000000
--- a/slanet_plus_table/table_structure/table_structure.py
+++ /dev/null
@@ -1,178 +0,0 @@
-import time
-import numpy as np
-from .utils import TablePredictor, TablePreprocess, TableLabelDecode
-
-
-# class SLANetPlus:
-# def __init__(self, model_dir, model_prefix="inference"):
-# self.preprocess_op = TablePreprocess()
-#
-# self.mean=[0.485, 0.456, 0.406]
-# self.std=[0.229, 0.224, 0.225]
-# self.target_img_size = [488, 488]
-# self.scale=1 / 255
-# self.order="hwc"
-# self.img_loader = LoadImage()
-# self.target_size = 488
-# self.pad_color = 0
-# self.predictor = TablePredictor(model_dir, model_prefix)
-# dict_character=['sos', '<thead>', '</thead>', '<tbody>', '</tbody>', '<tr>', '</tr>', '<td', '>', '</td>', ' colspan="2"', ' colspan="3"', ' colspan="4"', ' colspan="5"', ' colspan="6"', ' colspan="7"', ' colspan="8"', ' colspan="9"', ' colspan="10"', ' colspan="11"', ' colspan="12"', ' colspan="13"', ' colspan="14"', ' colspan="15"', ' colspan="16"', ' colspan="17"', ' colspan="18"', ' colspan="19"', ' colspan="20"', ' rowspan="2"', ' rowspan="3"', ' rowspan="4"', ' rowspan="5"', ' rowspan="6"', ' rowspan="7"', ' rowspan="8"', ' rowspan="9"', ' rowspan="10"', ' rowspan="11"', ' rowspan="12"', ' rowspan="13"', ' rowspan="14"', ' rowspan="15"', ' rowspan="16"', ' rowspan="17"', ' rowspan="18"', ' rowspan="19"', ' rowspan="20"', '<td></td>', 'eos']
-# self.beg_str = "sos"
-# self.end_str = "eos"
-# self.dict = {}
-# self.table_matcher = TableMatch()
-# for i, char in enumerate(dict_character):
-# self.dict[char] = i
-# self.character = dict_character
-# self.td_token = ["<td>", "<td", "<td></td>"]
-#
-# def call(self, img):
-# starttime = time.time()
-# data = {"image": img}
-# data = self.preprocess_op(data)
-# img = data[0]
-# if img is None:
-# return None, 0
-# img = np.expand_dims(img, axis=0)
-# img = img.copy()
-# def __call__(self, img, ocr_result):
-# img = self.img_loader(img)
-# h, w = img.shape[:2]
-# n_img, h_resize, w_resize = self.resize(img)
-# n_img = self.normalize(n_img)
-# n_img = self.pad(n_img)
-# n_img = n_img.transpose((2, 0, 1))
-# n_img = np.expand_dims(n_img, axis=0)
-# start = time.time()
-# batch_output = self.predictor(n_img)
-# elapse_time = time.time() - start
-# ori_img_size = [[w, h]]
-# output = self.decode(batch_output, ori_img_size)[0]
-# corners = np.stack(output['bbox'], axis=0)
-# dt_boxes, rec_res = get_boxes_recs(ocr_result, h, w)
-# pred_html = self.table_matcher(output['structure'], convert_corners_to_bounding_boxes(corners), dt_boxes, rec_res)
-# return pred_html,output['bbox'], elapse_time
-# def resize(self, img):
-# h, w = img.shape[:2]
-# scale = self.target_size / max(h, w)
-# h_resize = round(h * scale)
-# w_resize = round(w * scale)
-# resized_img = cv2.resize(img, (w_resize, h_resize), interpolation=cv2.INTER_LINEAR)
-# return resized_img, h_resize, w_resize
-# def pad(self, img):
-# h, w = img.shape[:2]
-# tw, th = self.target_img_size
-# ph = th - h
-# pw = tw - w
-# pad = (0, ph, 0, pw)
-# chns = 1 if img.ndim == 2 else img.shape[2]
-# im = cv2.copyMakeBorder(img, *pad, cv2.BORDER_CONSTANT, value=(self.pad_color,) * chns)
-# return im
-# def normalize(self, img):
-# img = img.astype("float32", copy=False)
-# img *= self.scale
-# img -= self.mean
-# img /= self.std
-# return img
-#
-#
-# def decode(self, pred, ori_img_size):
-# bbox_preds, structure_probs = [], []
-# for bbox_pred, stru_prob in pred:
-# bbox_preds.append(bbox_pred)
-# structure_probs.append(stru_prob)
-# bbox_preds = np.array(bbox_preds)
-# structure_probs = np.array(structure_probs)
-#
-# bbox_list, structure_str_list, structure_score = self.decode_single(
-# structure_probs, bbox_preds, [self.target_img_size], ori_img_size
-# )
-# structure_str_list = [
-# (
-# ["", "", ""]
-# + structure
-# + ["
", "", ""]
-# )
-# for structure in structure_str_list
-# ]
-# return [
-# {"bbox": bbox, "structure": structure, "structure_score": structure_score}
-# for bbox, structure in zip(bbox_list, structure_str_list)
-# ]
-#
-#
-# def decode_single(self, structure_probs, bbox_preds, padding_size, ori_img_size):
-# """convert text-label into text-index."""
-# ignored_tokens = [self.beg_str, self.end_str]
-# end_idx = self.dict[self.end_str]
-#
-# structure_idx = structure_probs.argmax(axis=2)
-# structure_probs = structure_probs.max(axis=2)
-#
-# structure_batch_list = []
-# bbox_batch_list = []
-# batch_size = len(structure_idx)
-# for batch_idx in range(batch_size):
-# structure_list = []
-# bbox_list = []
-# score_list = []
-# for idx in range(len(structure_idx[batch_idx])):
-# char_idx = int(structure_idx[batch_idx][idx])
-# if idx > 0 and char_idx == end_idx:
-# break
-# if char_idx in ignored_tokens:
-# continue
-# text = self.character[char_idx]
-# if text in self.td_token:
-# bbox = bbox_preds[batch_idx, idx]
-# bbox = self._bbox_decode(
-# bbox, padding_size[batch_idx], ori_img_size[batch_idx]
-# )
-# bbox_list.append(bbox.astype(int))
-# structure_list.append(text)
-# score_list.append(structure_probs[batch_idx, idx])
-# structure_batch_list.append(structure_list)
-# structure_score = np.mean(score_list)
-# bbox_batch_list.append(bbox_list)
-#
-# return bbox_batch_list, structure_batch_list, structure_score
-#
-# def _bbox_decode(self, bbox, padding_shape, ori_shape):
-#
-# pad_w, pad_h = padding_shape
-# w, h = ori_shape
-# ratio_w = pad_w / w
-# ratio_h = pad_h / h
-# ratio = min(ratio_w, ratio_h)
-#
-# bbox[0::2] *= pad_w
-# bbox[1::2] *= pad_h
-# bbox[0::2] /= ratio
-# bbox[1::2] /= ratio
-#
-# return bbox
-
-
-class TableStructurer:
- def __init__(self, model_path: str):
- self.preprocess_op = TablePreprocess()
- self.predictor = TablePredictor(model_path)
- self.character = ['<thead>', '</thead>', '<tbody>', '</tbody>', '<tr>', '</tr>', '<td', '>', '</td>', ' colspan="2"', ' colspan="3"', ' colspan="4"', ' colspan="5"', ' colspan="6"', ' colspan="7"', ' colspan="8"', ' colspan="9"', ' colspan="10"', ' colspan="11"', ' colspan="12"', ' colspan="13"', ' colspan="14"', ' colspan="15"', ' colspan="16"', ' colspan="17"', ' colspan="18"', ' colspan="19"', ' colspan="20"', ' rowspan="2"', ' rowspan="3"', ' rowspan="4"', ' rowspan="5"', ' rowspan="6"', ' rowspan="7"', ' rowspan="8"', ' rowspan="9"', ' rowspan="10"', ' rowspan="11"', ' rowspan="12"', ' rowspan="13"', ' rowspan="14"', ' rowspan="15"', ' rowspan="16"', ' rowspan="17"', ' rowspan="18"', ' rowspan="19"', ' rowspan="20"', '<td></td>']
- self.postprocess_op = TableLabelDecode(self.character)
-
- def __call__(self, img):
- start_time = time.time()
- data = {"image": img}
- h, w = img.shape[:2]
- ori_img_size = [[w, h]]
- data = self.preprocess_op(data)
- img = data[0]
- if img is None:
- return None, 0
- img = np.expand_dims(img, axis=0)
- img = img.copy()
- cur_img_size = [[488, 488]]
- outputs = self.predictor(img)
- output = self.postprocess_op(outputs, cur_img_size, ori_img_size)[0]
- elapse = time.time() - start_time
- return output["structure"], np.stack(output["bbox"]), elapse
diff --git a/slanet_plus_table/table_structure/utils.py b/slanet_plus_table/table_structure/utils.py
deleted file mode 100644
index f1cf7a0..0000000
--- a/slanet_plus_table/table_structure/utils.py
+++ /dev/null
@@ -1,331 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -*- encoding: utf-8 -*-
-# @Author: Jocker1212
-# @Contact: xinyijianggo@gmail.com
-from pathlib import Path
-
-import cv2
-import numpy as np
-from paddle.inference import Config, create_predictor
-
-class TablePredictor:
- def __init__(self, model_dir, model_prefix="inference"):
- model_file = f"{model_dir}/{model_prefix}.pdmodel"
- params_file = f"{model_dir}/{model_prefix}.pdiparams"
- config = Config(model_file, params_file)
- config.disable_gpu()
- config.disable_glog_info()
- config.enable_new_ir(True)
- config.enable_new_executor(True)
- config.enable_memory_optim()
- config.switch_ir_optim(True)
- # Disable feed, fetch OP, needed by zero_copy_run
- config.switch_use_feed_fetch_ops(False)
- predictor = create_predictor(config)
- self.config = config
- self.predictor = predictor
- # Get input and output handlers
- input_names = predictor.get_input_names()
- self.input_names = input_names.sort()
- self.input_handlers = []
- self.output_handlers = []
- for input_name in input_names:
- input_handler = predictor.get_input_handle(input_name)
- self.input_handlers.append(input_handler)
- self.output_names = predictor.get_output_names()
- for output_name in self.output_names:
- output_handler = predictor.get_output_handle(output_name)
- self.output_handlers.append(output_handler)
-
- def __call__(self, batch_imgs):
- self.input_handlers[0].reshape(batch_imgs.shape)
- self.input_handlers[0].copy_from_cpu(batch_imgs)
- self.predictor.run()
- output = []
- for out_tensor in self.output_handlers:
- batch = out_tensor.copy_to_cpu()
- output.append(batch)
- return self.format_output(output)
-
- def format_output(self, pred):
- return [res for res in zip(*pred)]
-
-class TableLabelDecode:
- """decode the table model outputs(probs) to character str"""
- def __init__(self, dict_character=[], merge_no_span_structure=True, **kwargs):
-
- if merge_no_span_structure:
- if " | " not in dict_character:
- dict_character.append(" | ")
- if "" in dict_character:
- dict_character.remove(" | ")
-
- dict_character = self.add_special_char(dict_character)
- self.dict = {}
- for i, char in enumerate(dict_character):
- self.dict[char] = i
- self.character = dict_character
- self.td_token = ["<td>", "<td", "<td></td>"]
-
- def add_special_char(self, dict_character):
- """add_special_char"""
- self.beg_str = "sos"
- self.end_str = "eos"
- dict_character = [self.beg_str] + dict_character + [self.end_str]
- return dict_character
-
- def get_ignored_tokens(self):
- """get_ignored_tokens"""
- beg_idx = self.get_beg_end_flag_idx("beg")
- end_idx = self.get_beg_end_flag_idx("end")
- return [beg_idx, end_idx]
-
- def get_beg_end_flag_idx(self, beg_or_end):
- """get_beg_end_flag_idx"""
- if beg_or_end == "beg":
- idx = np.array(self.dict[self.beg_str])
- elif beg_or_end == "end":
- idx = np.array(self.dict[self.end_str])
- else:
- assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
- return idx
-
- def __call__(self, pred, img_size, ori_img_size):
- """apply"""
- bbox_preds, structure_probs = [], []
- for bbox_pred, stru_prob in pred:
- bbox_preds.append(bbox_pred)
- structure_probs.append(stru_prob)
- bbox_preds = np.array(bbox_preds)
- structure_probs = np.array(structure_probs)
-
- bbox_list, structure_str_list, structure_score = self.decode(
- structure_probs, bbox_preds, img_size, ori_img_size
- )
- structure_str_list = [
- (
- ["", "", "", "", ""]
- )
- for structure in structure_str_list
- ]
- return [
- {"bbox": bbox, "structure": structure, "structure_score": structure_score}
- for bbox, structure in zip(bbox_list, structure_str_list)
- ]
-
- def decode(self, structure_probs, bbox_preds, padding_size, ori_img_size):
- """convert text-label into text-index."""
- ignored_tokens = self.get_ignored_tokens()
- end_idx = self.dict[self.end_str]
-
- structure_idx = structure_probs.argmax(axis=2)
- structure_probs = structure_probs.max(axis=2)
-
- structure_batch_list = []
- bbox_batch_list = []
- batch_size = len(structure_idx)
- for batch_idx in range(batch_size):
- structure_list = []
- bbox_list = []
- score_list = []
- for idx in range(len(structure_idx[batch_idx])):
- char_idx = int(structure_idx[batch_idx][idx])
- if idx > 0 and char_idx == end_idx:
- break
- if char_idx in ignored_tokens:
- continue
- text = self.character[char_idx]
- if text in self.td_token:
- bbox = bbox_preds[batch_idx, idx]
- bbox = self._bbox_decode(
- bbox, padding_size[batch_idx], ori_img_size[batch_idx]
- )
- bbox_list.append(bbox.astype(int))
- structure_list.append(text)
- score_list.append(structure_probs[batch_idx, idx])
- structure_batch_list.append(structure_list)
- structure_score = np.mean(score_list)
- bbox_batch_list.append(bbox_list)
-
- return bbox_batch_list, structure_batch_list, structure_score
-
- def _bbox_decode(self, bbox, padding_shape, ori_shape):
-
- pad_w, pad_h = padding_shape
- w, h = ori_shape
- ratio_w = pad_w / w
- ratio_h = pad_h / h
- ratio = min(ratio_w, ratio_h)
-
- bbox[0::2] *= pad_w
- bbox[1::2] *= pad_h
- bbox[0::2] /= ratio
- bbox[1::2] /= ratio
-
- return bbox
-
-
-class TablePreprocess:
- def __init__(self):
- self.table_max_len = 488
- self.build_pre_process_list()
- self.ops = self.create_operators()
-
- def __call__(self, data):
- """transform"""
- if self.ops is None:
- self.ops = []
-
- for op in self.ops:
- data = op(data)
- if data is None:
- return None
- return data
-
- def create_operators(
- self,
- ):
- """
- create operators based on the config
-
- Args:
- params(list): a dict list, used to create some operators
- """
- assert isinstance(
- self.pre_process_list, list
- ), "operator config should be a list"
- ops = []
- for operator in self.pre_process_list:
- assert (
- isinstance(operator, dict) and len(operator) == 1
- ), "yaml format error"
- op_name = list(operator)[0]
- param = {} if operator[op_name] is None else operator[op_name]
- op = eval(op_name)(**param)
- ops.append(op)
- return ops
-
- def build_pre_process_list(self):
- resize_op = {
- "ResizeTableImage": {
- "max_len": self.table_max_len,
- }
- }
- pad_op = {
- "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
- }
- normalize_op = {
- "NormalizeImage": {
- "std": [0.229, 0.224, 0.225],
- "mean": [0.485, 0.456, 0.406],
- "scale": "1./255.",
- "order": "hwc",
- }
- }
- to_chw_op = {"ToCHWImage": None}
- keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
- self.pre_process_list = [
- resize_op,
- normalize_op,
- pad_op,
- to_chw_op,
- keep_keys_op,
- ]
-
-
-class ResizeTableImage:
- def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
- super(ResizeTableImage, self).__init__()
- self.max_len = max_len
- self.resize_bboxes = resize_bboxes
- self.infer_mode = infer_mode
-
- def __call__(self, data):
- img = data["image"]
- height, width = img.shape[0:2]
- ratio = self.max_len / (max(height, width) * 1.0)
- resize_h = int(height * ratio)
- resize_w = int(width * ratio)
- resize_img = cv2.resize(img, (resize_w, resize_h))
- if self.resize_bboxes and not self.infer_mode:
- data["bboxes"] = data["bboxes"] * ratio
- data["image"] = resize_img
- data["src_img"] = img
- data["shape"] = np.array([height, width, ratio, ratio])
- data["max_len"] = self.max_len
- return data
-
-
-class PaddingTableImage:
- def __init__(self, size, **kwargs):
- super(PaddingTableImage, self).__init__()
- self.size = size
-
- def __call__(self, data):
- img = data["image"]
- pad_h, pad_w = self.size
- padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
- height, width = img.shape[0:2]
- padding_img[0:height, 0:width, :] = img.copy()
- data["image"] = padding_img
- shape = data["shape"].tolist()
- shape.extend([pad_h, pad_w])
- data["shape"] = np.array(shape)
- return data
-
-
-class NormalizeImage:
- """normalize image such as substract mean, divide std"""
-
- def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
- if isinstance(scale, str):
- scale = eval(scale)
- self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
- mean = mean if mean is not None else [0.485, 0.456, 0.406]
- std = std if std is not None else [0.229, 0.224, 0.225]
-
- shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
- self.mean = np.array(mean).reshape(shape).astype("float32")
- self.std = np.array(std).reshape(shape).astype("float32")
-
- def __call__(self, data):
- img = np.array(data["image"])
- assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
- data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
- return data
-
-
-class ToCHWImage:
- """convert hwc image to chw image"""
-
- def __init__(self, **kwargs):
- pass
-
- def __call__(self, data):
- img = np.array(data["image"])
- data["image"] = img.transpose((2, 0, 1))
- return data
-
-class KeepKeys:
- def __init__(self, keep_keys, **kwargs):
- self.keep_keys = keep_keys
-
- def __call__(self, data):
- data_list = []
- for key in self.keep_keys:
- data_list.append(data[key])
- return data_list
diff --git a/slanet_plus_table/utils.py b/slanet_plus_table/utils.py
deleted file mode 100644
index bb0736c..0000000
--- a/slanet_plus_table/utils.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from io import BytesIO
-from pathlib import Path
-from typing import Union, Optional
-
-import cv2
-import numpy as np
-from PIL import Image, UnidentifiedImageError
-
-InputType = Union[str, np.ndarray, bytes, Path]
-
-class LoadImage:
- def __init__(
- self,
- ):
- pass
-
- def __call__(self, img: InputType) -> np.ndarray:
- if not isinstance(img, InputType.__args__):
- raise LoadImageError(
- f"The img type {type(img)} does not in {InputType.__args__}"
- )
-
- img = self.load_img(img)
-
- if img.ndim == 2:
- return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
- if img.ndim == 3 and img.shape[2] == 4:
- return self.cvt_four_to_three(img)
-
- return img
-
- def load_img(self, img: InputType) -> np.ndarray:
- if isinstance(img, (str, Path)):
- self.verify_exist(img)
- try:
- img = np.array(Image.open(img))
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
- except UnidentifiedImageError as e:
- raise LoadImageError(f"cannot identify image file {img}") from e
- return img
-
- if isinstance(img, bytes):
- img = np.array(Image.open(BytesIO(img)))
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
- return img
-
- if isinstance(img, np.ndarray):
- return img
-
- raise LoadImageError(f"{type(img)} is not supported!")
-
- @staticmethod
- def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
- """RGBA → RGB"""
- r, g, b, a = cv2.split(img)
- new_img = cv2.merge((b, g, r))
-
- not_a = cv2.bitwise_not(a)
- not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
-
- new_img = cv2.bitwise_and(new_img, new_img, mask=a)
- new_img = cv2.add(new_img, not_a)
- return new_img
-
- @staticmethod
- def verify_exist(file_path: Union[str, Path]):
- if not Path(file_path).exists():
- raise LoadImageError(f"{file_path} does not exist.")
-
-class LoadImageError(Exception):
- pass
-
-class VisTable:
- def __init__(
- self,
- ):
- self.load_img = LoadImage()
-
- def __call__(
- self,
- img_path: Union[str, Path],
- table_html_str: str,
- save_html_path: Optional[str] = None,
- table_cell_bboxes: Optional[np.ndarray] = None,
- save_drawed_path: Optional[str] = None,
- ) -> None:
- if save_html_path:
- html_with_border = self.insert_border_style(table_html_str)
- self.save_html(save_html_path, html_with_border)
-
- if table_cell_bboxes is None:
- return None
-
- img = self.load_img(img_path)
-
- dims_bboxes = table_cell_bboxes.shape[1]
- if dims_bboxes == 4:
- drawed_img = self.draw_rectangle(img, table_cell_bboxes)
- elif dims_bboxes == 8:
- drawed_img = self.draw_polylines(img, table_cell_bboxes)
- else:
- raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")
-
- if save_drawed_path:
- self.save_img(save_drawed_path, drawed_img)
-
- return drawed_img
-
- def insert_border_style(self, table_html_str: str):
- style_res = """"""
- prefix_table, suffix_table = table_html_str.split("")
- html_with_border = f"{prefix_table}{style_res}{suffix_table}"
- return html_with_border
-
- @staticmethod
- def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
- img_copy = img.copy()
- for box in boxes.astype(int):
- x1, y1, x2, y2 = box
- cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
- return img_copy
-
- @staticmethod
- def draw_polylines(img: np.ndarray, points) -> np.ndarray:
- img_copy = img.copy()
- for point in points.astype(int):
- point = point.reshape(4, 2)
- cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
- return img_copy
-
- @staticmethod
- def save_img(save_path: Union[str, Path], img: np.ndarray):
- cv2.imwrite(str(save_path), img)
-
- @staticmethod
- def save_html(save_path: Union[str, Path], html: str):
- with open(save_path, "w", encoding="utf-8") as f:
- f.write(html)