From 106fcb92ea765744969ffdaa57f0c2392122e5e1 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Mon, 27 Mar 2023 13:07:08 +0800 Subject: [PATCH] =?UTF-8?q?[Feature]=20=E5=A2=9E=E5=8A=A0HuggingfaceDumper?= =?UTF-8?q?(=E4=BD=BF=E7=94=A8huggingface=E7=9A=84datasets=E5=8C=85?= =?UTF-8?q?=E5=B0=81=E8=A3=85)=EF=BC=8Cprepare=5Fdataset=E8=84=9A=E6=9C=AC?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0--huggingface=E5=8F=AF=E9=80=89=E6=8C=89?= =?UTF-8?q?=E9=92=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../preparers/config_generators/base.py | 9 ++- .../config_generators/re_config_generator.py | 8 +-- .../config_generators/ser_config_generator.py | 8 +-- .../textdet_config_generator.py | 2 +- mmocr/datasets/preparers/dumpers/__init__.py | 3 +- .../preparers/dumpers/huggingface_dumper.py | 37 +++++++++++++ tools/dataset_converters/prepare_dataset.py | 55 +++++++++++++++++++ 7 files changed, 109 insertions(+), 13 deletions(-) create mode 100644 mmocr/datasets/preparers/dumpers/huggingface_dumper.py diff --git a/mmocr/datasets/preparers/config_generators/base.py b/mmocr/datasets/preparers/config_generators/base.py index ba3811a425..7e1237e93e 100644 --- a/mmocr/datasets/preparers/config_generators/base.py +++ b/mmocr/datasets/preparers/config_generators/base.py @@ -83,11 +83,18 @@ def _prepare_anns(self, train_anns: Optional[List[Dict]], assert 'ann_file' in ann_dict suffix = ann_dict['ann_file'].split('.')[-1] if suffix == 'json': - dataset_type = 'OCRDataset' + if self.task in ['ser', 're']: + dataset_type = f'{self.task.upper()}Dataset' + else: + dataset_type = 'OCRDataset' elif suffix == 'lmdb': assert self.task == 'textrecog', \ 'LMDB format only works for textrecog now.' dataset_type = 'RecogLMDBDataset' + elif suffix == 'huggingface': + assert self.task in ['ser', 're'], \ + 'Huggingface format only works for ser or re now.' + dataset_type = f'{self.task.upper()}HuggingfaceDataset' else: raise NotImplementedError( 'ann file only supports JSON file or LMDB file') diff --git a/mmocr/datasets/preparers/config_generators/re_config_generator.py b/mmocr/datasets/preparers/config_generators/re_config_generator.py index 35f2b6589f..6806167a0b 100644 --- a/mmocr/datasets/preparers/config_generators/re_config_generator.py +++ b/mmocr/datasets/preparers/config_generators/re_config_generator.py @@ -87,12 +87,10 @@ def _gen_dataset_config(self) -> str: cfg = '' for key_name, ann_dict in self.anns.items(): cfg += f'\n{key_name} = dict(\n' - cfg += ' type=\'REDataset\',\n' - cfg += ' data_root=' + f'{self.dataset_name}_{self.task}_data_root,\n' # noqa: E501 + cfg += f' type=\'{ann_dict["dataset_type"]}\',\n' + cfg += f' data_root={self.dataset_name}_{self.task}_data_root,\n' # noqa: E501 cfg += f' ann_file=\'{ann_dict["ann_file"]}\',\n' - if ann_dict['split'] == 'train': - cfg += ' filter_cfg=dict(filter_empty_gt=True, min_size=32),\n' # noqa: E501 - elif ann_dict['split'] in ['test', 'val']: + if ann_dict['split'] in ['test', 'val']: cfg += ' test_mode=True,\n' cfg += ' pipeline=None)\n' return cfg diff --git a/mmocr/datasets/preparers/config_generators/ser_config_generator.py b/mmocr/datasets/preparers/config_generators/ser_config_generator.py index c931678698..f05740b1b5 100644 --- a/mmocr/datasets/preparers/config_generators/ser_config_generator.py +++ b/mmocr/datasets/preparers/config_generators/ser_config_generator.py @@ -87,12 +87,10 @@ def _gen_dataset_config(self) -> str: cfg = '' for key_name, ann_dict in self.anns.items(): cfg += f'\n{key_name} = dict(\n' - cfg += ' type=\'SERDataset\',\n' - cfg += ' data_root=' + f'{self.dataset_name}_{self.task}_data_root,\n' # noqa: E501 + cfg += f' type=\'{ann_dict["dataset_type"]}\',\n' + cfg += f' data_root={self.dataset_name}_{self.task}_data_root,\n' # noqa: E501 cfg += f' ann_file=\'{ann_dict["ann_file"]}\',\n' - if ann_dict['split'] == 'train': - cfg += ' filter_cfg=dict(filter_empty_gt=True, min_size=32),\n' # noqa: E501 - elif ann_dict['split'] in ['test', 'val']: + if ann_dict['split'] in ['test', 'val']: cfg += ' test_mode=True,\n' cfg += ' pipeline=None)\n' return cfg diff --git a/mmocr/datasets/preparers/config_generators/textdet_config_generator.py b/mmocr/datasets/preparers/config_generators/textdet_config_generator.py index fcb8af4fb0..0c26cbc497 100644 --- a/mmocr/datasets/preparers/config_generators/textdet_config_generator.py +++ b/mmocr/datasets/preparers/config_generators/textdet_config_generator.py @@ -86,7 +86,7 @@ def _gen_dataset_config(self) -> str: for key_name, ann_dict in self.anns.items(): cfg += f'\n{key_name} = dict(\n' cfg += ' type=\'OCRDataset\',\n' - cfg += ' data_root=' + f'{self.dataset_name}_{self.task}_data_root,\n' # noqa: E501 + cfg += f' data_root={self.dataset_name}_{self.task}_data_root,\n' # noqa: E501 cfg += f' ann_file=\'{ann_dict["ann_file"]}\',\n' if ann_dict['split'] == 'train': cfg += ' filter_cfg=dict(filter_empty_gt=True, min_size=32),\n' # noqa: E501 diff --git a/mmocr/datasets/preparers/dumpers/__init__.py b/mmocr/datasets/preparers/dumpers/__init__.py index ed3dda486b..39e94a07f1 100644 --- a/mmocr/datasets/preparers/dumpers/__init__.py +++ b/mmocr/datasets/preparers/dumpers/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import BaseDumper +from .huggingface_dumper import HuggingfaceDumper from .json_dumper import JsonDumper from .lmdb_dumper import TextRecogLMDBDumper from .wild_receipt_openset_dumper import WildreceiptOpensetDumper __all__ = [ 'BaseDumper', 'JsonDumper', 'WildreceiptOpensetDumper', - 'TextRecogLMDBDumper' + 'TextRecogLMDBDumper', 'HuggingfaceDumper' ] diff --git a/mmocr/datasets/preparers/dumpers/huggingface_dumper.py b/mmocr/datasets/preparers/dumpers/huggingface_dumper.py new file mode 100644 index 0000000000..09889c94f1 --- /dev/null +++ b/mmocr/datasets/preparers/dumpers/huggingface_dumper.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from collections import defaultdict +from typing import Dict + +from datasets import Dataset, Image + +from mmocr.registry import DATA_DUMPERS +from .base import BaseDumper + + +@DATA_DUMPERS.register_module() +class HuggingfaceDumper(BaseDumper): + """Semantic Entity Recognition and Relation Extraction huggingface datasets + format dumper.""" + + def dump(self, data: Dict) -> None: + """Dump data to datasets format to disk. + + Args: + data (Dict): MMOCR format data to be dumped. + """ + data_list = data.get('data_list', None) + filename = f'{self.task}_{self.split}.huggingface' + dst_file = osp.join(self.data_root, filename) + + merged_dict = defaultdict(list) + for d in data_list: + instances = d['instances'] + img_path = osp.join(self.data_root, d['img_path']) + merged_dict['image'].append(img_path) + for k, v in instances.items(): + merged_dict[k].append(v) + ds = Dataset.from_dict(merged_dict) + ds = ds.cast_column('image', Image()) + # save to disk + ds.save_to_disk(dst_file) diff --git a/tools/dataset_converters/prepare_dataset.py b/tools/dataset_converters/prepare_dataset.py index 1d2e74c069..bcfaa922c8 100644 --- a/tools/dataset_converters/prepare_dataset.py +++ b/tools/dataset_converters/prepare_dataset.py @@ -36,6 +36,13 @@ def parse_args(): help='Whether to dump the textrecog dataset to LMDB format, It\'s a ' 'shortcut to force the dataset to be dumped in lmdb format. ' 'Applicable when --task=textrecog') + parser.add_argument( + '--huggingface', + action='store_true', + default=False, + help='Whether to dump the ser/re dataset to huggingface format,' + 'It\'s a shortcut to force the dataset to be dumped in huggingface ' + 'format. Applicable when --task=ser or re') parser.add_argument( '--overwrite-cfg', action='store_true', @@ -124,10 +131,56 @@ def force_lmdb(cfg): return cfg +def force_huggingface(cfg): + """Force the dataset to be dumped in huggingface format. + + Args: + cfg (Config): Config object. + + Returns: + Config: Config object. + """ + for split in ['train', 'val', 'test']: + preparer_cfg = cfg.get(f'{split}_preparer') + if preparer_cfg: + if preparer_cfg.get('dumper') is None: + raise ValueError( + f'{split} split does not come with a dumper, ' + 'so most likely the annotations are MMOCR-ready and do ' + 'not need any adaptation, and it ' + 'cannot be dumped in LMDB format.') + preparer_cfg.dumper['type'] = 'HuggingfaceDumper' + + cfg.config_generator['dataset_name'] = f'{cfg.dataset_name}_huggingface' + + for split in ['train_anns', 'val_anns', 'test_anns']: + if split in cfg.config_generator: + # It can be None when users want to clear out the default + # value + if not cfg.config_generator[split]: + continue + ann_list = cfg.config_generator[split] + for ann_dict in ann_list: + ann_dict['ann_file'] = ( + osp.splitext(ann_dict['ann_file'])[0] + '.huggingface') + else: + if split == 'train_anns': + ann_list = [dict(ann_file=f'{cfg.task}_train.huggingface')] + elif split == 'test_anns': + ann_list = [dict(ann_file=f'{cfg.task}_test.huggingface')] + else: + ann_list = [] + cfg.config_generator[split] = ann_list + + return cfg + + def main(): args = parse_args() if args.lmdb and args.task != 'textrecog': raise ValueError('--lmdb only works with --task=textrecog') + if args.huggingface and args.task not in ['ser', 're']: + raise ValueError('--huggingface only works with --task=ser or re') for dataset in args.datasets: if not osp.isdir(osp.join(args.dataset_zoo_path, dataset)): warnings.warn(f'{dataset} is not supported yet. Please check ' @@ -145,6 +198,8 @@ def main(): cfg.dataset_name = dataset if args.lmdb: cfg = force_lmdb(cfg) + if args.huggingface: + cfg = force_huggingface(cfg) preparer = DatasetPreparer.from_file(cfg) preparer.run(args.splits)