
Transformation Tutorial


Mechanism

  1. Each transformation is a class with a callable function, for example:
import numpy as np
from PIL import Image

class ToCHWImage(object):
    """ Convert an HWC image to a CHW image.
    required keys: image
    modified keys: image
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, data: dict):
        img = data['image']
        # PIL images are converted to numpy arrays before transposing
        if isinstance(img, Image.Image):
            img = np.array(img)
        data['image'] = img.transpose((2, 0, 1))
        return data
  2. The input to a transformation is always a dict, which carries the data information such as img_path and the raw label.

  3. A transformation should declare the keys it requires in the input and the keys it modifies and/or adds in the output, as shown in the sketch after this list.
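
The snippet below is a minimal sketch of that contract (an illustration only, not the actual MindOCR data-pipeline code): a data dict is passed through a list of transformations in order, and each transformation reads its required keys and writes back the keys it modifies or adds. The image path is a placeholder.

from mindocr.data.transforms.general_transforms import DecodeImage, ToCHWImage

# build a short transformation pipeline
transforms = [DecodeImage(img_mode='RGB'), ToCHWImage()]

# the data dict initially carries only the image path (placeholder path)
data = {'img_path': '/path/to/your_image.jpg'}

# apply the transformations sequentially; each one updates and returns the dict
for transform in transforms:
    data = transform(data)

print(data['image'].shape)  # the decoded image is now a (C, H, W) numpy array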

The available transformations can be found in mindocr/data/transforms/*_transform.py.

# import and check available transforms

from mindocr.data.transforms import general_transforms, det_transforms, rec_transforms
general_transforms.__all__
['DecodeImage', 'NormalizeImage', 'ToCHWImage', 'PackLoaderInputs']
det_transforms.__all__
['DetLabelEncode',
 'MakeBorderMap',
 'MakeShrinkMap',
 'EastRandomCropData',
 'PSERandomCrop']

Text Detection

1. Load the Image and Annotation

Preparation

%load_ext autoreload
%autoreload 2
%reload_ext autoreload
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
import os

# load the label file which has the info of image path and annotation.
# This file is generated from the ic15 annotations using the converter script.
label_fp = '/Users/Samit/Data/datasets/ic15/det/train/train_icdar2015_label.txt'
root_dir = '/Users/Samit/Data/datasets/ic15/det/train'

data_lines = []
with open(label_fp, 'r') as f:
    for line in f:
        data_lines.append(line)

# just pick one image and its annotation
idx = 3
img_path, annot = data_lines[idx].strip().split('\t')

img_path = os.path.join(root_dir, img_path)
print('img_path', img_path)
print('raw annotation: ', annot)
img_path /Users/Samit/Data/datasets/ic15/det/train/ch4_training_images/img_612.jpg
raw annotation:  [{"transcription": "where", "points": [[483, 197], [529, 174], [530, 197], [485, 221]]}, {"transcription": "people", "points": [[531, 168], [607, 136], [608, 166], [532, 198]]}, {"transcription": "meet", "points": [[613, 128], [691, 100], [691, 131], [613, 160]]}, {"transcription": "###", "points": [[695, 299], [888, 315], [931, 635], [737, 618]]}, {"transcription": "###", "points": [[709, 19], [876, 8], [880, 286], [713, 296]]}, {"transcription": "###", "points": [[530, 270], [660, 246], [661, 300], [532, 324]]}, {"transcription": "###", "points": [[113, 356], [181, 359], [180, 387], [112, 385]]}, {"transcription": "###", "points": [[281, 328], [369, 338], [366, 361], [279, 351]]}, {"transcription": "###", "points": [[66, 314], [183, 313], [183, 328], [68, 330]]}]

Decode the Image - DecodeImage

#img_path = '/Users/Samit/Data/datasets/ic15/det/train/ch4_training_images/img_1.jpg'
decode_image = general_transforms.DecodeImage(img_mode='RGB')

# TODO: check the input keys and output keys for the trans. func.

data = {'img_path': img_path}
data  = decode_image(data)
img = data['image']

# visualize
from mindocr.utils.visualize import show_img, show_imgs
show_img(img)

(figure: the decoded image)

import time

start = time.time()
att = 100
for i in range(att):
    img  = decode_image(data)['image']
avg = (time.time() - start) / att

print('avg reading time: ', avg)
avg reading time:  0.004545390605926514

Detection Label Encoding - DetLabelEncode

data['label'] = annot

label_encode = det_transforms.DetLabelEncode()
data = label_encode(data)

#print(data['polys'])
print(data['texts'])

# visualize
from mindocr.utils.visualize import draw_boxes

res = draw_boxes(data['image'], data['polys'])
show_img(res)
['where', 'people', 'meet', '###', '###', '###', '###', '###', '###']

(figure: the image with the annotated text boxes drawn)

2. Image and Annotation Processing/Augmentation

RandomCrop - EastRandomCropData

from mindocr.data.transforms.general_transforms import RandomCropWithBBox
import copy

#crop_data = det_transforms.EastRandomCropData(size=(640, 640))
crop_data = RandomCropWithBBox(crop_size=(640, 640))

show_img(data['image'])
for i in range(2):
    data_cache = copy.deepcopy(data)
    data_cropped = crop_data(data_cache)

    res_crop = draw_boxes(data_cropped['image'], data_cropped['polys'])
    show_img(res_crop)

(figures: the original image followed by two random-crop results with the boxes drawn)

ColorJitter

random_color_adj = general_transforms.RandomColorAdjust(brightness=0.4, saturation=0.5)

data_cache = copy.deepcopy(data)
#data_cache['image'] = data_cache['image'][:,:, ::-1]
data_adj = random_color_adj(data_cache)
#print(data_adj)
show_img(data_adj['image'], is_bgr_img=True)

(figure: the color-jittered image)
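
To wrap up, here is a hedged sketch that chains the transformations demonstrated in this tutorial into a single pipeline over one sample. It reuses the img_path and annot variables loaded earlier and the same key contract (img_path, label, image, polys); in a real training setup the data loader assembles such a list of transformations rather than applying them by hand.

from mindocr.data.transforms import general_transforms, det_transforms
from mindocr.data.transforms.general_transforms import RandomCropWithBBox

# compose the transformations shown above into one ordered pipeline
pipeline = [
    general_transforms.DecodeImage(img_mode='RGB'),
    det_transforms.DetLabelEncode(),
    RandomCropWithBBox(crop_size=(640, 640)),
    general_transforms.RandomColorAdjust(brightness=0.4, saturation=0.5),
]

# one sample: the image path plus its raw annotation string
data = {'img_path': img_path, 'label': annot}

for transform in pipeline:
    data = transform(data)

# visualize the augmented image with the remaining boxes drawn
res = draw_boxes(data['image'], data['polys'])
show_img(res, is_bgr_img=True)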