
Commit 21b029f

Merge pull request #383 from VikParuchuri/dev
Dev
2 parents 442069e + 69e5294 commit 21b029f


5 files changed (+66, -25 lines)


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.14.3"
+version = "0.14.4"
 description = "OCR, layout, reading order, and table recognition in 90+ languages"
 authors = ["Vik Paruchuri <[email protected]>"]
 readme = "README.md"

surya/recognition/__init__.py

Lines changed: 30 additions & 15 deletions
@@ -76,17 +76,17 @@ class RecognitionPrompt:
 class RecognitionPredictor(BasePredictor):
     model_loader_cls = RecognitionModelLoader
     batch_size = settings.RECOGNITION_BATCH_SIZE
-    torch_dtype = None  # No default, loader picks the dtype based on device properties - bf16/fp16
+    torch_dtype = None  # No default, loader picks the dtype based on device properties - bf16/fp16
     default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 128}
-    encoder_chunk_size: int = 4096
+    encoder_chunk_size: int = 4096  # Default chunk size
     encoder_chunk_sizes = {"cpu": 4096, "mps": 4096, "cuda": 32768, "xla": 32768}
     min_prefill_ratio: int = 0.2
     min_trim_length: int = 50
     tasks = {
         TaskNames.ocr_with_boxes: {
             "needs_bboxes": True,
             "img_size": (1024, 256),  # 370 max tokens
-            "max_tokens": 256,
+            "max_tokens": 224,
         },
         TaskNames.ocr_without_boxes: {
             "needs_bboxes": False,
@@ -111,8 +111,11 @@ def __init__(self, checkpoint=None, device=settings.TORCH_DEVICE_MODEL, dtype=No
             self.processor.pad_token_id, device=self.model.device, dtype=torch.long
         )

-    def get_encoder_chunk_size(self):
-        chunk_size = self.encoder_chunk_size
+    def get_encoder_chunk_size(self) -> int:
+        if settings.RECOGNITION_CHUNK_SIZE is not None:
+            return settings.RECOGNITION_CHUNK_SIZE
+
+        chunk_size = self.encoder_chunk_size
         if settings.TORCH_DEVICE_MODEL in self.encoder_chunk_sizes:
             chunk_size = self.encoder_chunk_sizes[settings.TORCH_DEVICE_MODEL]
@@ -239,6 +242,8 @@ def slice_bboxes(
             == len(all_polygons)
             == len(all_text)
             == len(all_task_names)
+        ), (
+            f"Mismatch in lengths: {len(all_slices)}, {sum(slice_map)}, {len(all_polygons)}, {len(all_text)}, {len(all_task_names)}"
         )

         return {
@@ -593,7 +598,7 @@ def prediction_loop(
             current_inputs = self.maybe_trim_cache_padding(current_inputs)
             mark_step()
         pbar.close()
-
+
         del self.kv_cache
         self.kv_cache = None
         torch.cuda.empty_cache()
@@ -636,12 +641,14 @@ def get_bboxes_text(
             # If the image is very out of distribution, we can get nonsense repeats, and we may need to drop the text entirely
             if drop_repeated_text and detect_repeat_token(image_tokens):
                 char_predictions.append(
-                    TextChar(
-                        text="",
-                        polygon=blank_bbox,
-                        confidence=0,
-                        bbox_valid=False,
-                    )
+                    [
+                        TextChar(
+                            text="",
+                            polygon=blank_bbox,
+                            confidence=0,
+                            bbox_valid=False,
+                        )
+                    ]
                 )
                 continue

@@ -772,7 +779,7 @@ def __call__(
         highres_images: List[Image.Image] | None = None,
         bboxes: List[List[List[int]]] | None = None,
         polygons: List[List[List[List[int]]]] | None = None,
-        input_text: List[str | None] | None = None,
+        input_text: List[List[str | None]] | None = None,
         sort_lines: bool = False,
         math_mode: bool = True,
         return_words: bool = False,
@@ -857,7 +864,11 @@ def __call__(
                 batch_bboxes, image_sizes, bbox_size, bbox_size // 2
             )
             char_predictions = self.get_bboxes_text(
-                flat, predicted_tokens, scores, predicted_polygons
+                flat,
+                predicted_tokens,
+                scores,
+                predicted_polygons,
+                drop_repeated_text=drop_repeated_text,
             )

             char_predictions = sorted(zip(indices, char_predictions), key=lambda x: x[0])
@@ -886,7 +897,11 @@ def __call__(
                     )
                 )
             else:
-                confidence = float(np.mean([char.confidence for char in text_line]))
+                confidence = (
+                    float(np.mean([char.confidence for char in text_line]))
+                    if len(text_line) > 0
+                    else 0
+                )
             poly_box = PolygonBox(polygon=polygon)
             for char in text_line:
                 char.rescale(
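
The reworked get_encoder_chunk_size resolves the encoder chunk size in a fixed order: an explicit settings.RECOGNITION_CHUNK_SIZE wins, then the per-device entry in encoder_chunk_sizes, then the class default encoder_chunk_size. A minimal standalone sketch of that resolution order (resolve_chunk_size and the two constants below are illustrative stand-ins, not surya code):

    # Illustrative stand-ins; the real values live on RecognitionPredictor above.
    DEFAULT_CHUNK_SIZE = 4096
    DEVICE_CHUNK_SIZES = {"cpu": 4096, "mps": 4096, "cuda": 32768, "xla": 32768}

    def resolve_chunk_size(device: str, override: int | None = None) -> int:
        # 1. An explicit override (settings.RECOGNITION_CHUNK_SIZE) always wins.
        if override is not None:
            return override
        # 2. Otherwise use the per-device entry, falling back to the class default.
        return DEVICE_CHUNK_SIZES.get(device, DEFAULT_CHUNK_SIZE)

    assert resolve_chunk_size("cuda") == 32768
    assert resolve_chunk_size("cpu", override=8192) == 8192

The other changes in this file keep empty or dropped lines well-formed: drop_repeated_text now appends a list containing a single blank TextChar (matching the per-line list shape get_bboxes_text returns elsewhere), the flag is threaded through from __call__, and an empty text_line now yields a confidence of 0 rather than the mean of an empty array.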

surya/recognition/schema.py

Lines changed: 19 additions & 9 deletions
@@ -1,26 +1,36 @@
+import math
+import numpy as np
 from typing import Optional, List

-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator

 from surya.common.polygon import PolygonBox


-class TextChar(PolygonBox):
+class BaseChar(PolygonBox):
     text: str
+    confidence: Optional[float] = 0
+
+    @field_validator("confidence", mode="before")
+    @classmethod
+    def validate_confidence(cls, v: float) -> float:
+        if v is None:
+            return 0
+        elif math.isnan(v) or np.isnan(v):
+            return 0
+        return v
+
+
+class TextChar(BaseChar):
     bbox_valid: bool = True  # This is false when the given bbox is not valid
-    confidence: Optional[float] = None


-class TextWord(PolygonBox):
-    text: str
+class TextWord(BaseChar):
     bbox_valid: bool = True
-    confidence: Optional[float] = None


-class TextLine(PolygonBox):
-    text: str
+class TextLine(BaseChar):
     chars: List[TextChar]  # Individual characters in the line
-    confidence: Optional[float] = None
     original_text_good: bool = False
     words: List[TextWord] | None = None
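
TextChar, TextWord, and TextLine now share a BaseChar base class whose confidence defaults to 0 and is normalized before validation, so None and NaN values no longer propagate into downstream confidence averages. A small self-contained sketch of the same pydantic pattern (the Scored model is a stand-in for BaseChar, which in surya also inherits polygon fields from PolygonBox):

    import math
    from typing import Optional

    from pydantic import BaseModel, field_validator


    class Scored(BaseModel):  # stand-in for BaseChar; the real class extends PolygonBox
        confidence: Optional[float] = 0

        @field_validator("confidence", mode="before")
        @classmethod
        def validate_confidence(cls, v):
            # None and NaN are coerced to 0 before pydantic validates the field.
            if v is None or (isinstance(v, float) and math.isnan(v)):
                return 0
            return v


    assert Scored().confidence == 0
    assert Scored(confidence=None).confidence == 0
    assert Scored(confidence=float("nan")).confidence == 0
    assert Scored(confidence=0.9).confidence == 0.9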

surya/settings.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
     RECOGNITION_BATCH_SIZE: Optional[int] = (
         None  # Defaults to 8 for CPU/MPS, 256 otherwise
     )
+    RECOGNITION_CHUNK_SIZE: Optional[int] = None
     RECOGNITION_RENDER_FONTS: Dict[str, str] = {
         "all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"),
         "zh": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),

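RECOGNITION_CHUNK_SIZE is the setting consulted first by get_encoder_chunk_size above. Assuming the Settings class reads its fields from the environment the way the existing RECOGNITION_* knobs do, the override could be supplied before surya is imported; the import below assumes the module-level settings object that the recognition code references:

    import os

    # Assumption: the Settings object picks this up from the environment, like
    # RECOGNITION_BATCH_SIZE. Set it before surya modules are imported.
    os.environ["RECOGNITION_CHUNK_SIZE"] = "8192"

    from surya.settings import settings

    print(settings.RECOGNITION_CHUNK_SIZE)  # expected: 8192, overriding the per-device defaults
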
tests/test_recognition.py

Lines changed: 15 additions & 0 deletions
@@ -1,4 +1,5 @@
 import time
+from PIL import ImageDraw, Image


 def test_recognition(recognition_predictor, detection_predictor, test_image):
@@ -34,3 +35,17 @@ def test_recognition_input_text(recognition_predictor, detection_predictor, test
     text_lines = recognition_results[0].text_lines
     assert len(text_lines) == 4
     assert "Hello World" in text_lines[0].text
+
+
+def test_recognition_drop_repeats(recognition_predictor, detection_predictor):
+    image = Image.new("RGB", (1024, 128), "white")
+    draw = ImageDraw.Draw(image)
+    text = "a" * 80
+    draw.text((5, 5), text, fill="black", font_size=24)
+
+    recognition_results = recognition_predictor(
+        [image], None, bboxes=[[[0, 0, 1024, 128]]], drop_repeated_text=True
+    )
+    assert len(recognition_results) == 1
+    result = recognition_results[0].text_lines
+    assert result[0].text == ""
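
The new test draws a synthetic line of 80 repeated characters and asserts that, with drop_repeated_text=True, recognition returns an empty string for that line. Outside the pytest fixtures, the same call shape would look roughly like this (constructing the predictor directly and the import path are assumptions about typical usage, not something the test itself shows):

    from PIL import Image, ImageDraw

    from surya.recognition import RecognitionPredictor

    # Synthetic out-of-distribution input: one long run of the same character.
    image = Image.new("RGB", (1024, 128), "white")
    ImageDraw.Draw(image).text((5, 5), "a" * 80, fill="black", font_size=24)

    predictor = RecognitionPredictor()  # loads the recognition model with default settings
    results = predictor(
        [image], None, bboxes=[[[0, 0, 1024, 128]]], drop_repeated_text=True
    )
    print(repr(results[0].text_lines[0].text))  # expected: '' once the repeat detector trips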

0 commit comments
