Skip to content

Commit 31e36e7

Browse files
authored
Merge pull request #115 from VikParuchuri/dev
OCR speedup
2 parents 80889bd + 1eb828a commit 31e36e7

File tree

11 files changed

+754
-413
lines changed

11 files changed

+754
-413
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,16 @@ rec_model, rec_processor = load_model(), load_processor()
121121
predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
122122
```
123123

124+
### Compilation
125+
126+
The OCR model can be compiled to get a ~15% speedup in total inference time, although the first run will be slow while the model compiles. First set `RECOGNITION_STATIC_CACHE=true`, then:
127+
128+
```python
129+
import torch
130+
131+
rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder)
132+
```
133+
124134
## Text line detection
125135

126136
This command will write out a JSON file with the detected bboxes.

benchmark/recognition.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import argparse
22
from collections import defaultdict
33

4+
import torch
5+
46
from benchmark.scoring import overlap_score
57
from surya.model.recognition.model import load_model as load_recognition_model
68
from surya.model.recognition.processor import load_processor as load_recognition_processor
@@ -26,8 +28,12 @@ def main():
2628
parser.add_argument("--tesseract", action="store_true", help="Run tesseract instead of surya.", default=False)
2729
parser.add_argument("--langs", type=str, help="Specify certain languages to benchmark.", default=None)
2830
parser.add_argument("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28)
31+
parser.add_argument("--compile", action="store_true", help="Compile the model.", default=False)
2932
args = parser.parse_args()
3033

34+
if args.compile:
35+
assert settings.RECOGNITION_STATIC_CACHE, "You must set RECOGNITION_STATIC_CACHE to compile the model."
36+
3137
rec_model = load_recognition_model()
3238
rec_processor = load_recognition_processor()
3339

@@ -56,6 +62,11 @@ def main():
5662
else:
5763
lang_list.append(l)
5864

65+
if args.compile:
66+
rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder)
67+
# Run through one batch to compile the model
68+
run_recognition(images[:1], lang_list[:1], rec_model, rec_processor, bboxes=bboxes[:1])
69+
5970
start = time.time()
6071
predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes)
6172
surya_time = time.time() - start

ocr_text.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import json
33
from collections import defaultdict
44

5+
import torch
6+
57
from surya.input.langs import replace_lang_with_code, get_unique_langs
68
from surya.input.load import load_from_folder, load_from_file, load_lang_file
79
from surya.model.detection.segformer import load_model as load_detection_model, load_processor as load_detection_processor

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "surya-ocr"
3-
version = "0.4.8"
3+
version = "0.4.9"
44
description = "OCR, layout, reading order, and line detection in 90+ languages"
55
authors = ["Vik Paruchuri <[email protected]>"]
66
readme = "README.md"

surya/input/processing.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,9 @@ def split_image(img, processor):
4848
def prepare_image_detection(img, processor):
4949
new_size = (processor.size["width"], processor.size["height"])
5050

51-
img.thumbnail(new_size, Image.Resampling.LANCZOS) # Shrink largest dimension to fit new size
52-
img = img.resize(new_size, Image.Resampling.LANCZOS) # Stretch smaller dimension to fit new size
53-
5451
img = np.asarray(img, dtype=np.uint8)
52+
img = cv2.resize(img, new_size, interpolation=cv2.INTER_LANCZOS4)
53+
5554
img = processor(img)["pixel_values"][0]
5655
img = torch.from_numpy(img)
5756
return img

surya/model/recognition/decoder.py

Lines changed: 121 additions & 327 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)