
Commit d90663b

Merge branch 'layoutlm' of https://github.com/KevinNuNu/mmocr into layoutlm
KevinNuNu committed Mar 28, 2023
2 parents 1d0c5e3 + 106fcb9 commit d90663b
Showing 7 changed files with 107 additions and 7 deletions.
9 changes: 8 additions & 1 deletion mmocr/datasets/preparers/config_generators/base.py
@@ -83,11 +83,18 @@ def _prepare_anns(self, train_anns: Optional[List[Dict]],
                 assert 'ann_file' in ann_dict
                 suffix = ann_dict['ann_file'].split('.')[-1]
                 if suffix == 'json':
-                    dataset_type = 'OCRDataset'
+                    if self.task in ['ser', 're']:
+                        dataset_type = f'{self.task.upper()}Dataset'
+                    else:
+                        dataset_type = 'OCRDataset'
                 elif suffix == 'lmdb':
                     assert self.task == 'textrecog', \
                         'LMDB format only works for textrecog now.'
                     dataset_type = 'RecogLMDBDataset'
+                elif suffix == 'huggingface':
+                    assert self.task in ['ser', 're'], \
+                        'Huggingface format only works for ser or re now.'
+                    dataset_type = f'{self.task.upper()}HuggingfaceDataset'
                 else:
                     raise NotImplementedError(
                         'ann file only supports JSON file or LMDB file')
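
Not part of the diff: a minimal standalone sketch of the suffix-to-dataset_type mapping this hunk introduces, for quick reference. The helper name pick_dataset_type is hypothetical; the branch logic and error messages mirror the lines above.

# Minimal sketch (not from the commit): the suffix -> dataset_type mapping
# introduced above, factored into a hypothetical standalone helper.
def pick_dataset_type(ann_file: str, task: str) -> str:
    suffix = ann_file.split('.')[-1]
    if suffix == 'json':
        if task in ['ser', 're']:
            return f'{task.upper()}Dataset'
        return 'OCRDataset'
    if suffix == 'lmdb':
        assert task == 'textrecog', 'LMDB format only works for textrecog now.'
        return 'RecogLMDBDataset'
    if suffix == 'huggingface':
        assert task in ['ser', 're'], 'Huggingface format only works for ser or re now.'
        return f'{task.upper()}HuggingfaceDataset'
    raise NotImplementedError(f'unsupported ann file suffix: {suffix}')


# Expected behaviour:
#   pick_dataset_type('ser_train.json', 'ser')         -> 'SERDataset'
#   pick_dataset_type('ser_train.huggingface', 'ser')  -> 'SERHuggingfaceDataset'
#   pick_dataset_type('textdet_train.json', 'textdet') -> 'OCRDataset'
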
@@ -87,8 +87,8 @@ def _gen_dataset_config(self) -> str:
         cfg = ''
         for key_name, ann_dict in self.anns.items():
             cfg += f'\n{key_name} = dict(\n'
-            cfg += '    type=\'REDataset\',\n'
-            cfg += '    data_root=' + f'{self.dataset_name}_{self.task}_data_root,\n'  # noqa: E501
+            cfg += f'    type=\'{ann_dict["dataset_type"]}\',\n'
+            cfg += f'    data_root={self.dataset_name}_{self.task}_data_root,\n'  # noqa: E501
             cfg += f'    ann_file=\'{ann_dict["ann_file"]}\',\n'
             if ann_dict['split'] in ['test', 'val']:
                 cfg += '    test_mode=True,\n'
@@ -87,8 +87,8 @@ def _gen_dataset_config(self) -> str:
         cfg = ''
         for key_name, ann_dict in self.anns.items():
             cfg += f'\n{key_name} = dict(\n'
-            cfg += '    type=\'SERDataset\',\n'
-            cfg += '    data_root=' + f'{self.dataset_name}_{self.task}_data_root,\n'  # noqa: E501
+            cfg += f'    type=\'{ann_dict["dataset_type"]}\',\n'
+            cfg += f'    data_root={self.dataset_name}_{self.task}_data_root,\n'  # noqa: E501
             cfg += f'    ann_file=\'{ann_dict["ann_file"]}\',\n'
             if ann_dict['split'] in ['test', 'val']:
                 cfg += '    test_mode=True,\n'
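
Not part of the diff: with dataset_type now taken from each annotation dict, the SER/RE config generators emit a block per annotation file whose type follows its suffix. A rough sketch of what the generated config might look like follows; the dataset name xfund_zh is made up, and only the fields visible in this diff are shown.

# Hypothetical generator output (dataset name and data root are invented).
xfund_zh_ser_data_root = 'data/xfund_zh'

xfund_zh_ser_train = dict(
    type='SERDataset',              # suffix 'json', task 'ser'
    data_root=xfund_zh_ser_data_root,
    ann_file='ser_train.json')

xfund_zh_ser_test = dict(
    type='SERHuggingfaceDataset',   # suffix 'huggingface', task 'ser'
    data_root=xfund_zh_ser_data_root,
    ann_file='ser_test.huggingface',
    test_mode=True)
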
@@ -86,7 +86,7 @@ def _gen_dataset_config(self) -> str:
         for key_name, ann_dict in self.anns.items():
             cfg += f'\n{key_name} = dict(\n'
             cfg += '    type=\'OCRDataset\',\n'
-            cfg += '    data_root=' + f'{self.dataset_name}_{self.task}_data_root,\n'  # noqa: E501
+            cfg += f'    data_root={self.dataset_name}_{self.task}_data_root,\n'  # noqa: E501
             cfg += f'    ann_file=\'{ann_dict["ann_file"]}\',\n'
             if ann_dict['split'] == 'train':
                 cfg += '    filter_cfg=dict(filter_empty_gt=True, min_size=32),\n'  # noqa: E501
3 changes: 2 additions & 1 deletion mmocr/datasets/preparers/dumpers/__init__.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base import BaseDumper
+from .huggingface_dumper import HuggingfaceDumper
 from .json_dumper import JsonDumper
 from .lmdb_dumper import TextRecogLMDBDumper
 from .wild_receipt_openset_dumper import WildreceiptOpensetDumper

 __all__ = [
     'BaseDumper', 'JsonDumper', 'WildreceiptOpensetDumper',
-    'TextRecogLMDBDumper'
+    'TextRecogLMDBDumper', 'HuggingfaceDumper'
 ]
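
Not part of the diff: since the new dumper is exported here and registered in DATA_DUMPERS (see the new file below), it can presumably be built through the registry like the existing dumpers. A minimal sketch, assuming BaseDumper accepts task, split and data_root as constructor arguments (which dump() reads back via self); it also requires the datasets package to be installed.

# Minimal sketch (not from the commit). The data root is hypothetical, and the
# BaseDumper constructor arguments (task, split, data_root) are an assumption.
from mmocr.registry import DATA_DUMPERS

dumper = DATA_DUMPERS.build(
    dict(
        type='HuggingfaceDumper',
        task='ser',
        split='train',
        data_root='data/xfund_zh'))

# `data` would be the MMOCR-format dict produced by the preparer pipeline:
# dumper.dump(data)
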
37 changes: 37 additions & 0 deletions mmocr/datasets/preparers/dumpers/huggingface_dumper.py
@@ -0,0 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import defaultdict
from typing import Dict

from datasets import Dataset, Image

from mmocr.registry import DATA_DUMPERS
from .base import BaseDumper


@DATA_DUMPERS.register_module()
class HuggingfaceDumper(BaseDumper):
    """Semantic Entity Recognition and Relation Extraction huggingface datasets
    format dumper."""

    def dump(self, data: Dict) -> None:
        """Dump data to datasets format to disk.

        Args:
            data (Dict): MMOCR format data to be dumped.
        """
        data_list = data.get('data_list', None)
        filename = f'{self.task}_{self.split}.huggingface'
        dst_file = osp.join(self.data_root, filename)

        merged_dict = defaultdict(list)
        for d in data_list:
            instances = d['instances']
            img_path = osp.join(self.data_root, d['img_path'])
            merged_dict['image'].append(img_path)
            for k, v in instances.items():
                merged_dict[k].append(v)
        ds = Dataset.from_dict(merged_dict)
        ds = ds.cast_column('image', Image())
        # save to disk
        ds.save_to_disk(dst_file)
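
Not part of the diff: because the dumper writes with Dataset.save_to_disk, the dumped split can be read back with datasets.load_from_disk. A minimal sketch; the path below is hypothetical.

# Minimal sketch (not from the commit): reading back a dumped split.
from datasets import load_from_disk

ds = load_from_disk('data/xfund_zh/ser_train.huggingface')
print(ds.features)       # an 'image' column plus the per-instance keys
print(ds[0]['image'])    # decoded as a PIL image thanks to the Image() cast
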
55 changes: 55 additions & 0 deletions tools/dataset_converters/prepare_dataset.py
@@ -36,6 +36,13 @@ def parse_args():
         help='Whether to dump the textrecog dataset to LMDB format, It\'s a '
         'shortcut to force the dataset to be dumped in lmdb format. '
         'Applicable when --task=textrecog')
+    parser.add_argument(
+        '--huggingface',
+        action='store_true',
+        default=False,
+        help='Whether to dump the ser/re dataset to huggingface format, '
+        'It\'s a shortcut to force the dataset to be dumped in huggingface '
+        'format. Applicable when --task=ser or re')
     parser.add_argument(
         '--overwrite-cfg',
         action='store_true',
@@ -124,10 +131,56 @@ def force_lmdb(cfg):
     return cfg


+def force_huggingface(cfg):
+    """Force the dataset to be dumped in huggingface format.
+
+    Args:
+        cfg (Config): Config object.
+
+    Returns:
+        Config: Config object.
+    """
+    for split in ['train', 'val', 'test']:
+        preparer_cfg = cfg.get(f'{split}_preparer')
+        if preparer_cfg:
+            if preparer_cfg.get('dumper') is None:
+                raise ValueError(
+                    f'{split} split does not come with a dumper, '
+                    'so most likely the annotations are MMOCR-ready and do '
+                    'not need any adaptation, and it '
+                    'cannot be dumped in huggingface format.')
+            preparer_cfg.dumper['type'] = 'HuggingfaceDumper'
+
+    cfg.config_generator['dataset_name'] = f'{cfg.dataset_name}_huggingface'
+
+    for split in ['train_anns', 'val_anns', 'test_anns']:
+        if split in cfg.config_generator:
+            # It can be None when users want to clear out the default
+            # value
+            if not cfg.config_generator[split]:
+                continue
+            ann_list = cfg.config_generator[split]
+            for ann_dict in ann_list:
+                ann_dict['ann_file'] = (
+                    osp.splitext(ann_dict['ann_file'])[0] + '.huggingface')
+        else:
+            if split == 'train_anns':
+                ann_list = [dict(ann_file=f'{cfg.task}_train.huggingface')]
+            elif split == 'test_anns':
+                ann_list = [dict(ann_file=f'{cfg.task}_test.huggingface')]
+            else:
+                ann_list = []
+            cfg.config_generator[split] = ann_list
+
+    return cfg
+
+
 def main():
     args = parse_args()
     if args.lmdb and args.task != 'textrecog':
         raise ValueError('--lmdb only works with --task=textrecog')
+    if args.huggingface and args.task not in ['ser', 're']:
+        raise ValueError('--huggingface only works with --task=ser or re')
     for dataset in args.datasets:
         if not osp.isdir(osp.join(args.dataset_zoo_path, dataset)):
             warnings.warn(f'{dataset} is not supported yet. Please check '
@@ -145,6 +198,8 @@ def main():
         cfg.dataset_name = dataset
         if args.lmdb:
             cfg = force_lmdb(cfg)
+        if args.huggingface:
+            cfg = force_huggingface(cfg)
         preparer = DatasetPreparer.from_file(cfg)
         preparer.run(args.splits)

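
Not part of the diff: a toy sketch of the fields force_huggingface rewrites, built on an mmengine Config. Every concrete name (dataset, preparer, annotation file) is invented for illustration; the expected values in the comments follow from the function body above.

# Minimal sketch (not from the commit); all concrete names here are made up.
from mmengine import Config

cfg = Config(
    dict(
        dataset_name='xfund_zh',
        task='ser',
        train_preparer=dict(dumper=dict(type='JsonDumper')),
        config_generator=dict(
            dataset_name='xfund_zh',
            train_anns=[dict(ann_file='ser_train.json')])))

# After cfg = force_huggingface(cfg), one would expect, per the code above:
#   cfg.train_preparer.dumper.type              -> 'HuggingfaceDumper'
#   cfg.config_generator.dataset_name           -> 'xfund_zh_huggingface'
#   cfg.config_generator.train_anns[0].ann_file -> 'ser_train.huggingface'
#   cfg.config_generator.test_anns              -> [dict(ann_file='ser_test.huggingface')]
#   cfg.config_generator.val_anns               -> []
# The CLI shortcut added in this commit would be along the lines of:
#   python tools/dataset_converters/prepare_dataset.py xfund_zh --task ser --huggingface
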
