Skip to content

Commit 94d9954

Browse files
committed
Fix bugs with RGB
1 parent e37597d commit 94d9954

File tree

10 files changed

+32
-15
lines changed

10 files changed

+32
-15
lines changed

benchmark/detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from surya.benchmark.metrics import precision_recall
88
from surya.benchmark.tesseract import tesseract_parallel
99
from surya.model.detection.segformer import load_model, load_processor
10-
from surya.input.processing import open_pdf, get_page_images
10+
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
1111
from surya.detection import batch_text_detection
1212
from surya.postprocessing.heatmap import draw_polys_on_image
1313
from surya.postprocessing.util import rescale_bbox
@@ -47,7 +47,7 @@ def main():
4747
# These have already been shuffled randomly, so sampling from the start is fine
4848
dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
4949
images = list(dataset["image"])
50-
images = [i.convert("RGB") for i in images]
50+
images = convert_if_not_rgb(images)
5151
correct_boxes = []
5252
for i, boxes in enumerate(dataset["bboxes"]):
5353
img_size = images[i].size

benchmark/layout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from surya.benchmark.metrics import precision_recall
77
from surya.detection import batch_text_detection
88
from surya.model.detection.segformer import load_model, load_processor
9-
from surya.input.processing import open_pdf, get_page_images
9+
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
1010
from surya.layout import batch_layout_detection
1111
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
1212
from surya.postprocessing.util import rescale_bbox
@@ -33,7 +33,7 @@ def main():
3333
# These have already been shuffled randomly, so sampling from the start is fine
3434
dataset = datasets.load_dataset(settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
3535
images = list(dataset["image"])
36-
images = [i.convert("RGB") for i in images]
36+
images = convert_if_not_rgb(images)
3737

3838
start = time.time()
3939
line_predictions = batch_text_detection(images, det_model, det_processor)

benchmark/ordering.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import copy
44
import json
55

6+
from surya.input.processing import convert_if_not_rgb
67
from surya.model.ordering.model import load_model
78
from surya.model.ordering.processor import load_processor
89
from surya.ordering import batch_ordering
@@ -29,7 +30,7 @@ def main():
2930
split = f"train[:{args.max}]"
3031
dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
3132
images = list(dataset["image"])
32-
images = [i.convert("RGB") for i in images]
33+
images = convert_if_not_rgb(images)
3334
bboxes = list(dataset["bboxes"])
3435

3536
start = time.time()

benchmark/recognition.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55

66
from benchmark.scoring import overlap_score
7+
from surya.input.processing import convert_if_not_rgb
78
from surya.model.recognition.model import load_model as load_recognition_model
89
from surya.model.recognition.processor import load_processor as load_recognition_processor
910
from surya.ocr import run_recognition
@@ -48,7 +49,7 @@ def main():
4849
dataset = dataset.filter(lambda x: x["language"] in langs)
4950

5051
images = list(dataset["image"])
51-
images = [i.convert("RGB") for i in images]
52+
images = convert_if_not_rgb(images)
5253
bboxes = list(dataset["bboxes"])
5354
line_text = list(dataset["text"])
5455
languages = list(dataset["language"])

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "surya-ocr"
3-
version = "0.4.10"
3+
version = "0.4.11"
44
description = "OCR, layout, reading order, and line detection in 90+ languages"
55
authors = ["Vik Paruchuri <[email protected]>"]
66
readme = "README.md"

surya/detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from surya.model.detection.segformer import SegformerForRegressionMask
88
from surya.postprocessing.heatmap import get_and_clean_boxes
99
from surya.postprocessing.affinity import get_vertical_lines
10-
from surya.input.processing import prepare_image_detection, split_image, get_total_splits
10+
from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
1111
from surya.schema import TextDetectionResult
1212
from surya.settings import settings
1313
from tqdm import tqdm
@@ -51,7 +51,7 @@ def batch_detection(images: List, model: SegformerForRegressionMask, processor,
5151
all_preds = []
5252
for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
5353
batch_image_idxs = batches[batch_idx]
54-
batch_images = [images[j].convert("RGB") for j in batch_image_idxs]
54+
batch_images = convert_if_not_rgb([images[j] for j in batch_image_idxs])
5555

5656
split_index = []
5757
split_heights = []

surya/input/processing.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
import random
31
from typing import List
42

53
import cv2
@@ -11,6 +9,15 @@
119
from surya.settings import settings
1210

1311

12+
def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:
13+
new_images = []
14+
for image in images:
15+
if image.mode != "RGB":
16+
image = image.convert("RGB")
17+
new_images.append(image)
18+
return new_images
19+
20+
1421
def get_total_splits(image_size, processor):
1522
img_height = list(image_size)[1]
1623
max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
@@ -48,6 +55,8 @@ def split_image(img, processor):
4855
def prepare_image_detection(img, processor):
4956
new_size = (processor.size["width"], processor.size["height"])
5057

58+
# This double resize actually necessary for downstream accuracy
59+
img.thumbnail(new_size, Image.Resampling.LANCZOS)
5160
img = img.resize(new_size, Image.Resampling.LANCZOS) # Stretch smaller dimension to fit new size
5261

5362
img = np.asarray(img, dtype=np.uint8)

surya/ocr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
from PIL import Image
33

44
from surya.detection import batch_text_detection
5-
from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image
5+
from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image, convert_if_not_rgb
66
from surya.postprocessing.text import sort_text_lines
77
from surya.recognition import batch_recognition
88
from surya.schema import TextLine, OCRResult
99

1010

1111
def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None, batch_size=None) -> List[OCRResult]:
12+
images = convert_if_not_rgb(images)
1213
# Polygons need to be in corner format - [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], bboxes in [x1, y1, x2, y2] format
1314
assert bboxes is not None or polygons is not None
1415
assert len(images) == len(langs), "You need to pass in one list of languages for each image"
@@ -57,6 +58,7 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model
5758

5859

5960
def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor, batch_size=None) -> List[OCRResult]:
61+
images = convert_if_not_rgb(images)
6062
det_predictions = batch_text_detection(images, det_model, det_processor)
6163

6264
all_slices = []

surya/ordering.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
from PIL import Image
55

6+
from surya.input.processing import convert_if_not_rgb
67
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
78
from surya.schema import OrderBox, OrderResult
89
from surya.settings import settings
@@ -37,7 +38,7 @@ def batch_ordering(images: List, bboxes: List[List[List[float]]], model: OrderVi
3738
if batch_size is None:
3839
batch_size = get_batch_size()
3940

40-
images = [image.convert("RGB") for image in images]
41+
images = convert_if_not_rgb(images)
4142

4243
output_order = []
4344
for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"):

surya/recognition.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
from PIL import Image
44

5+
from surya.input.processing import convert_if_not_rgb
56
from surya.postprocessing.math.latex import fix_math, contains_math
67
from surya.postprocessing.text import truncate_repetitions
78
from surya.settings import settings
@@ -24,9 +25,11 @@ def get_batch_size():
2425
def batch_recognition(images: List, languages: List[List[str]], model, processor, batch_size=None):
2526
assert all([isinstance(image, Image.Image) for image in images])
2627
assert len(images) == len(languages)
27-
assert [len(l) <= settings.RECOGNITION_MAX_LANGS for l in languages], f"OCR only supports up to {settings.RECOGNITION_MAX_LANGS} languages per image"
2828

29-
images = [image.convert("RGB") for image in images]
29+
for l in languages:
30+
assert len(l) <= settings.RECOGNITION_MAX_LANGS, f"OCR only supports up to {settings.RECOGNITION_MAX_LANGS} languages per image, you passed {l}."
31+
32+
images = convert_if_not_rgb(images)
3033
if batch_size is None:
3134
batch_size = get_batch_size()
3235

0 commit comments

Comments
 (0)