Skip to content

Commit a7d408e

Browse files
authored
Merge pull request #18 from ai-forever/dev
Text detection & OCR filters
2 parents 662172b + 7b434a6 commit a7d408e

16 files changed

Lines changed: 1708 additions & 3 deletions

DPF/dataloaders/images/raw_dataset.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(
1717
df: pd.DataFrame,
1818
cols_to_return: Optional[List[str]] = None,
1919
preprocess_f=default_preprocess,
20+
return_none_on_error: bool = False
2021
):
2122
super(RawDataset).__init__()
2223
if cols_to_return is None:
@@ -25,6 +26,7 @@ def __init__(
2526
self.columns = ["image_path"] + cols_to_return
2627
self.data_to_iterate = df[self.columns].values
2728
self.preprocess_f = preprocess_f
29+
self.return_none_on_error = return_none_on_error
2830

2931
def __len__(self):
3032
return len(self.data_to_iterate)
def __getitem__(self, idx):
    """Return preprocess_f(image_bytes, row_data) for dataset row *idx*.

    When ``return_none_on_error`` is set, an unreadable image yields
    ``image_bytes = None`` instead of raising.
    """
    data = {
        self.columns[c]: item for c, item in enumerate(self.data_to_iterate[idx])
    }
    image_path = data["image_path"]
    if self.return_none_on_error:
        try:
            image_bytes = self.filesystem.read_file(image_path, binary=True).getvalue()
        except Exception:
            # Bug fix: this branch previously assigned `img_bytes`, leaving
            # `image_bytes` undefined and raising NameError on return —
            # the exact failure this flag is meant to suppress.
            image_bytes = None
    else:
        image_bytes = self.filesystem.read_file(image_path, binary=True).getvalue()

    return self.preprocess_f(image_bytes, data)

DPF/dataloaders/images/shards_dataset.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(
2121
df: pd.DataFrame,
2222
cols_to_return: Optional[List[str]] = None,
2323
preprocess_f=default_preprocess,
24+
return_none_on_error: bool = False
2425
):
2526
super(ShardsDataset).__init__()
2627
if cols_to_return is None:
@@ -32,6 +33,7 @@ def __init__(
3233
)
3334
self.total_samples = len(df)
3435
self.preprocess_f = preprocess_f
36+
self.return_none_on_error = return_none_on_error
3537

3638
def __len__(self):
3739
return self.total_samples
@@ -49,6 +51,13 @@ def __iter__(self):
4951
for data in data_all:
5052
data = {self.columns[i]: item for i, item in enumerate(data)}
5153
filename = os.path.basename(data["image_path"])
52-
img_bytes = tar.extractfile(filename).read()
54+
if self.return_none_on_error:
55+
try:
56+
img_bytes = tar.extractfile(filename).read()
57+
except Exception as err:
58+
img_bytes = None
59+
else:
60+
img_bytes = tar.extractfile(filename).read()
61+
5362
yield self.preprocess_f(img_bytes, data)
5463
tar.close()

DPF/dataloaders/images/universal_dataloader.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
df,
2525
cols_to_return=None,
2626
preprocess_f=default_preprocess,
27+
return_none_on_error: bool = False,
2728
**dataloader_kwargs,
2829
):
2930
if cols_to_return is None:
@@ -36,6 +37,7 @@ def __init__(
3637
), "Unknown data format in dataloader"
3738
self.cols_to_return = cols_to_return
3839
self.preprocess_f = preprocess_f
40+
self.return_none_on_error = return_none_on_error
3941
self.dataloader_kwargs = dataloader_kwargs
4042
self.len = None
4143

@@ -47,6 +49,7 @@ def test(self):
4749
self.df[self.df["data_format"] == data_format],
4850
self.cols_to_return,
4951
self.preprocess_f,
52+
self.return_none_on_error
5053
)
5154
print(f'"{data_format}" dataset created')
5255
dataloader = DataLoader(dataset, **self.dataloader_kwargs)

DPF/filters/images/ocr_filter.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
from typing import Optional
2+
import os
3+
import torch
4+
from torch import nn
5+
import numpy as np
6+
import json
7+
8+
try:
9+
from torch.utils.data.dataloader import default_collate
10+
except ImportError:
11+
from torch.utils.data import default_collate
12+
from torchvision import models, transforms
13+
from huggingface_hub import hf_hub_url, cached_download
14+
15+
from DPF.filters.utils import FP16Module, identical_collate_fn
16+
from DPF.utils import read_image_rgb_from_bytes
17+
from .img_filter import ImageFilter
18+
19+
from .ocr_model.utils import AttnLabelConverter
20+
from .ocr_model.dataset import AlignCollate
21+
from .ocr_model.model import Model
22+
23+
24+
class Options:
    """Empty attribute container used as a lightweight namespace for
    model hyper-parameters (attributes are assigned after construction)."""
27+
28+
class OCRFilter(ImageFilter):
    """
    Filter that recognises text inside pre-detected text boxes of an image
    with a TPS-ResNet-BiLSTM-Attn OCR model.

    Reads a ``text_boxes`` column holding a JSON list of boxes, each box a
    pair of corner points ``[[left, top], [right, bottom]]``, and writes an
    ``OCR_<model_name>`` column holding a JSON list of
    ``(box, predicted_text)`` pairs.
    """

    def __init__(
        self,
        weights_path: str,
        model_name: Optional[str] = None,
        device: str = "cuda:0",
        workers: int = 16,
        pad: int = 5,
        pbar: bool = True,
    ):
        """
        Parameters
        ----------
        weights_path:
            Path to the model checkpoint (state dict).
        model_name:
            Name used in the output column; defaults to the checkpoint
            file name without its extension.
        device:
            Torch device the model runs on.
        workers:
            Number of dataloader workers.
        pad:
            Currently unused; kept for interface compatibility.
        pbar:
            Whether to show a progress bar.
        """
        super().__init__(pbar)

        self.num_workers = workers
        self.batch_size = 1  # one image (with all its boxes) per batch
        self.device = device

        self.weights_path = weights_path
        self.model_name = model_name or os.path.basename(self.weights_path).split('.')[0]

        # Model hyper-parameters, fixed to the released checkpoint config.
        self.opt = Options()
        self.opt.workers = 4
        self.opt.batch_size = 192
        self.opt.batch_max_length = 32
        self.opt.imgH = 32
        self.opt.imgW = 100
        self.opt.rgb = False
        self.opt.character = '0123456789abcdefghijklmnopqrstuvwxyz'
        self.opt.sensitive = False
        self.opt.PAD = False
        self.opt.Transformation = "TPS"
        self.opt.FeatureExtraction = "ResNet"
        self.opt.SequenceModeling = "BiLSTM"
        self.opt.Prediction = "Attn"
        self.opt.num_fiducial = 20
        self.opt.input_channel = 1
        self.opt.output_channel = 512
        self.opt.hidden_size = 256

        self.converter = AttnLabelConverter(self.opt.character)
        self.opt.num_class = len(self.converter.character)

        self.model = Model(self.opt)
        # map_location lets checkpoints saved on CUDA load on any host;
        # the model is moved to the target device afterwards.
        weights = torch.load(self.weights_path, map_location='cpu')
        # Checkpoints saved through nn.DataParallel prefix every key with
        # "module." — strip that exact prefix.  The previous
        # key.lstrip('module.') removed any run of leading characters from
        # the set {m,o,d,u,l,e,.} (corrupting keys such as
        # "module.model...") and, via the unconditional pop, silently
        # dropped keys that had no prefix at all.
        weights = {
            (key[len('module.'):] if key.startswith('module.') else key): value
            for key, value in weights.items()
        }

        self.model.load_state_dict(weights)
        self.model.to(self.device)
        self.model.eval()

        self.AlignCollate = AlignCollate(imgH=self.opt.imgH, imgW=self.opt.imgW, keep_ratio_with_pad=self.opt.PAD)

        self.text_box_col = "text_boxes"
        self.ocr_col = f"OCR_{self.model_name}"

        self.schema = ["image_path", self.ocr_col]
        self.dataloader_kwargs = {
            "num_workers": self.num_workers,
            "batch_size": self.batch_size,
            "preprocess_f": self.preprocess,
            "collate_fn": lambda x: x,
            "drop_last": False,
            "cols_to_return": [self.text_box_col],
        }

    def preprocess(self, img_bytes: bytes, data: dict):
        """Decode the image to grayscale ('L') and parse its text boxes."""
        image_path = data["image_path"]
        boxes = json.loads(data[self.text_box_col])
        pil_img = read_image_rgb_from_bytes(img_bytes).convert('L')
        return image_path, pil_img, boxes

    def process_batch(self, batch) -> dict:
        """Run OCR on every text box of the (single) image in *batch*."""
        df_batch_labels = self._generate_dict_from_schema()
        image_path, pil_img, boxes = batch[0]
        w, h = pil_img.size

        # Crop each box, clamping to the image bounds and fixing any
        # inverted corner order.
        input_data = []
        for box in boxes:
            left = max(box[0][0], 0)
            upper = max(box[0][1], 0)
            right = min(box[1][0], w)
            lower = min(box[1][1], h)
            if upper > lower:
                upper, lower = lower, upper
            if left > right:
                left, right = right, left

            crop = pil_img.crop(
                (left, upper, right, lower)
            )
            input_data.append((crop, ''))

        if len(input_data) == 0:
            # No boxes: still emit a row with an empty result list.
            df_batch_labels[self.ocr_col].append("[]")
            df_batch_labels["image_path"].append(image_path)
            return df_batch_labels

        data_preproc = self.AlignCollate(input_data)
        image_tensors = data_preproc[0]

        batch_size = image_tensors.size(0)
        image = image_tensors.to(self.device)
        length_for_pred = torch.IntTensor([self.opt.batch_max_length] * batch_size).to(self.device)
        text_for_pred = torch.LongTensor(batch_size, self.opt.batch_max_length + 1).fill_(0).to(self.device)

        # Inference only — disable autograd to avoid tracking gradients.
        with torch.no_grad():
            preds = self.model(image, text_for_pred, is_train=False)
        _, preds_index = preds.max(2)
        preds_str = self.converter.decode(preds_index, length_for_pred)
        # Remove the '[s]' end-of-sequence marker from decoded strings.
        preds_str = [s.replace('[s]', '') for s in preds_str]

        res = []
        for box, prediction in zip(boxes, preds_str):
            res.append((box, prediction))

        df_batch_labels[self.ocr_col].append(json.dumps(res))
        df_batch_labels["image_path"].append(image_path)

        return df_batch_labels

DPF/filters/images/ocr_model/__init__.py

Whitespace-only changes.
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import os
2+
import sys
3+
import re
4+
import six
5+
import math
6+
import lmdb
7+
import torch
8+
9+
from natsort import natsorted
10+
from PIL import Image
11+
import numpy as np
12+
from torch.utils.data import Dataset, ConcatDataset, Subset
13+
from torch._utils import _accumulate
14+
import torchvision.transforms as transforms
15+
16+
17+
class ResizeNormalize(object):
    """Resize a PIL image to a fixed size, convert it to a tensor and map
    pixel values from [0, 1] to [-1, 1]."""

    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        resized = img.resize(self.size, self.interpolation)
        tensor = self.toTensor(resized)
        # in-place normalization: (x - 0.5) / 0.5  ==>  [-1, 1]
        return tensor.sub_(0.5).div_(0.5)
29+
30+
31+
class NormalizePAD(object):
    """Convert a PIL image to a tensor normalized to [-1, 1], then pad it
    on the right up to ``max_size`` by replicating the last column."""

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        tensor = self.toTensor(img)
        tensor.sub_(0.5).div_(0.5)
        c, h, w = tensor.size()
        padded = torch.FloatTensor(*self.max_size).fill_(0)
        padded[:, :, :w] = tensor  # image placed at the left edge
        target_w = self.max_size[2]
        if w != target_w:
            # replicate the last image column across the padded region
            padded[:, :, w:] = tensor[:, :, w - 1].unsqueeze(2).expand(c, h, target_w - w)

        return padded
50+
51+
class AlignCollate(object):
    """Collate a batch of (PIL image, label) pairs into a single
    (image tensor batch, labels) pair, resizing every image to
    ``imgH`` x ``imgW``."""

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad

    def __call__(self, batch):
        batch = [sample for sample in batch if sample is not None]
        images, labels = zip(*batch)

        if self.keep_ratio_with_pad:  # same concept as the 'Rosetta' paper
            input_channel = 3 if images[0].mode == 'RGB' else 1
            transform = NormalizePAD((input_channel, self.imgH, self.imgW))

            tensors = []
            for image in images:
                w, h = image.size
                # width that keeps the aspect ratio, capped at the target width
                resized_w = min(self.imgW, math.ceil(self.imgH * (w / float(h))))
                resized = image.resize((resized_w, self.imgH), Image.BICUBIC)
                tensors.append(transform(resized))
        else:
            transform = ResizeNormalize((self.imgW, self.imgH))
            tensors = [transform(image) for image in images]

        image_tensors = torch.cat([t.unsqueeze(0) for t in tensors], 0)
        return image_tensors, labels
88+
89+
90+
def tensor2im(image_tensor, imtype=np.uint8):
    """Convert a CHW tensor with values in [-1, 1] into an HWC image array
    of dtype *imtype* (grayscale inputs are replicated to 3 channels)."""
    arr = image_tensor.cpu().float().numpy()
    if arr.shape[0] == 1:  # 1-channel -> 3-channel
        arr = np.tile(arr, (3, 1, 1))
    arr = (arr.transpose(1, 2, 0) + 1) / 2.0 * 255.0
    return arr.astype(imtype)
96+
97+
98+
def save_image(image_numpy, image_path):
    """Write a numpy image array to *image_path* via PIL."""
    Image.fromarray(image_numpy).save(image_path)

0 commit comments

Comments
 (0)