
Commit db63214

Merge pull request #312 from VikParuchuri/dev
Inline math model, new text detection model
2 parents 06a3cc6 + 4eb67cf commit db63214

36 files changed (+906 lines, -358 lines)
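
The headline change is a new inline math detection model that runs on top of text detection. A minimal usage sketch follows, inferred from benchmark/inline_detection.py in this commit rather than from documentation; the input image path is a hypothetical example.

# Usage sketch, inferred from benchmark/inline_detection.py in this commit (not official docs).
# "page.png" is a hypothetical input page.
from PIL import Image

from surya.detection import DetectionPredictor, InlineDetectionPredictor

det_predictor = DetectionPredictor()
inline_det_predictor = InlineDetectionPredictor()

images = [Image.open("page.png").convert("RGB")]

# Text line detection runs first; its boxes are reformatted into the inline math input format.
det_results = det_predictor(images)
text_boxes = [[b.bbox for b in result.bboxes] for result in det_results]

# Inline math detection, conditioned on the detected text lines.
inline_results = inline_det_predictor(images, text_boxes)
for result in inline_results:
    for box in result.bboxes:
        print(box.bbox)  # assumed to be [x1, y1, x2, y2] pixel coordinates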

.github/workflows/benchmarks.yml

Lines changed: 4 additions & 0 deletions

@@ -25,6 +25,10 @@ jobs:
         run: |
           poetry run python benchmark/detection.py --max_rows 2
           poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
+      - name: Run inline detection benchmark
+        run: |
+          poetry run python benchmark/inline_detection.py --max_rows 5
+          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/inline_math_bench/results.json --bench_type inline_detection
       - name: Run recognition benchmark test
         run: |
           poetry run python benchmark/recognition.py --max_rows 2

README.md

Lines changed: 1 addition & 1 deletion

@@ -388,7 +388,7 @@ For Google Cloud, I aligned the output from Google Cloud with the ground truth.

 | Model     | Time (s)   | Time per page (s)   | precision   | recall   |
 |-----------|------------|---------------------|-------------|----------|
-| surya     | 50.2099    | 0.196133            | 0.821061    | 0.956556 |
+| surya     | 47.2285    | 0.094452            | 0.835857    | 0.960807 |
 | tesseract | 74.4546    | 0.290838            | 0.631498    | 0.997694 |


benchmark/inline_detection.py

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
+import collections
+import copy
+import json
+from pathlib import Path
+
+import click
+
+from benchmark.utils.metrics import precision_recall
+from surya.debug.draw import draw_polys_on_image
+from surya.input.processing import convert_if_not_rgb
+from surya.common.util import rescale_bbox
+from surya.settings import settings
+from surya.detection import DetectionPredictor, InlineDetectionPredictor
+
+import os
+import time
+from tabulate import tabulate
+import datasets
+
+
+@click.command(help="Benchmark inline math detection model.")
+@click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
+@click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100)
+@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
+def main(results_dir: str, max_rows: int, debug: bool):
+    det_predictor = DetectionPredictor()
+    inline_det_predictor = InlineDetectionPredictor()
+
+    dataset = datasets.load_dataset(settings.INLINE_MATH_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
+    images = list(dataset["image"])
+    images = convert_if_not_rgb(images)
+    correct_boxes = []
+    for i, boxes in enumerate(dataset["bboxes"]):
+        img_size = images[i].size
+        # Rescale from normalized 0-1 vals to image size
+        correct_boxes.append([rescale_bbox(b, (1, 1), img_size) for b in boxes])
+
+    if settings.DETECTOR_STATIC_CACHE:
+        # Run through one batch to compile the model
+        det_predictor(images[:1])
+        inline_det_predictor(images[:1], [[]])
+
+    start = time.time()
+    det_results = det_predictor(images)
+
+    # Reformat text boxes to inline math input format
+    text_boxes = []
+    for result in det_results:
+        text_boxes.append([b.bbox for b in result.bboxes])
+
+    inline_results = inline_det_predictor(images, text_boxes)
+    surya_time = time.time() - start
+
+    result_path = Path(results_dir) / "inline_math_bench"
+    result_path.mkdir(parents=True, exist_ok=True)
+
+    page_metrics = collections.OrderedDict()
+    for idx, (sb, cb) in enumerate(zip(inline_results, correct_boxes)):
+        surya_boxes = [s.bbox for s in sb.bboxes]
+        surya_polys = [s.polygon for s in sb.bboxes]
+
+        surya_metrics = precision_recall(surya_boxes, cb)
+
+        page_metrics[idx] = {
+            "surya": surya_metrics,
+        }
+
+        if debug:
+            bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
+            bbox_image.save(result_path / f"{idx}_bbox.png")
+
+    mean_metrics = {}
+    metric_types = sorted(page_metrics[0]["surya"].keys())
+    models = ["surya"]
+
+    for k in models:
+        for m in metric_types:
+            metric = []
+            for page in page_metrics:
+                metric.append(page_metrics[page][k][m])
+            if k not in mean_metrics:
+                mean_metrics[k] = {}
+            mean_metrics[k][m] = sum(metric) / len(metric)
+
+    out_data = {
+        "times": {
+            "surya": surya_time,
+        },
+        "metrics": mean_metrics,
+        "page_metrics": page_metrics
+    }
+
+    with open(result_path / "results.json", "w+", encoding="utf-8") as f:
+        json.dump(out_data, f, indent=4)
+
+    table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types
+    table_data = [
+        ["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types],
+    ]
+
+    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
+    print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines.")
+    print(f"Wrote results to {result_path}")
+
+
+if __name__ == "__main__":
+    main()
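
One detail worth noting in the benchmark above: ground-truth boxes in the dataset are stored normalized to 0-1, and rescale_bbox maps them onto each page's pixel size before scoring. A small illustration, with made-up values:

# Illustration of the rescaling step used in the benchmark; the box and page size are made-up values.
from surya.common.util import rescale_bbox

normalized_box = [0.25, 0.10, 0.75, 0.15]  # [x1, y1, x2, y2] in 0-1 coordinates
img_size = (1654, 2339)                    # (width, height) of the page in pixels

pixel_box = rescale_bbox(normalized_box, (1, 1), img_size)
print(pixel_box)  # roughly [413, 234, 1240, 351] after scaling from the (1, 1) reference frame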

benchmark/recognition.py

Lines changed: 34 additions & 5 deletions

@@ -10,6 +10,7 @@
 from surya.settings import settings
 from surya.recognition.languages import CODE_TO_LANGUAGE
 from benchmark.utils.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE
+from benchmark.utils.textract import textract_ocr_parallel
 import os
 import datasets
 import json
@@ -22,22 +23,24 @@
 @click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
 @click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=None)
 @click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
-@click.option("--tesseract", is_flag=True, help="Run tesseract instead of surya.", default=False)
+@click.option("--tesseract", is_flag=True, help="Run benchmarks on tesseract.", default=False)
+@click.option("--textract", is_flag=True, help="Run benchmarks on textract.", default=False)
 @click.option("--langs", type=str, help="Specify certain languages to benchmark.", default=None)
 @click.option("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28)
+@click.option("--textract_cpus", type=int, help="Number of CPUs to use for textract.", default=28)
 @click.option("--specify_language", is_flag=True, help="Pass language codes into the model.", default=False)
-def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, langs: str, tess_cpus: int, specify_language: bool):
+def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, textract: bool, langs: str, tess_cpus: int, textract_cpus:int, specify_language: bool):
     rec_predictor = RecognitionPredictor()

     split = "train"
-    if max_rows:
-        split = f"train[:{max_rows}]"
-
     dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split=split)

     if langs:
         langs = langs.split(",")
         dataset = dataset.filter(lambda x: x["language"] in langs, num_proc=4)
+
+    if max_rows and max_rows<len(dataset):
+        dataset = dataset.shuffle().select(range(max_rows))

     images = list(dataset["image"])
     images = convert_if_not_rgb(images)
@@ -121,6 +124,28 @@ def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, langs: s
         with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f:
             json.dump(tess_scores, f)

+    if textract:
+        start = time.time()
+        textract_predictions = textract_ocr_parallel(images, cpus=textract_cpus)
+        textract_time = time.time()-start
+
+        textract_scores = defaultdict(list)
+        for idx, (pred, ref_text, lang) in enumerate(zip(textract_predictions, line_text, lang_list)):
+            image_score = overlap_score(pred, ref_text)
+            for l in lang:
+                textract_scores[CODE_TO_LANGUAGE[l]].append(image_score)
+
+        flat_textract_scores = [s for l in textract_scores for s in textract_scores[l]]
+        benchmark_stats["textract"] = {
+            "avg_score": sum(flat_textract_scores) / len(flat_textract_scores),
+            "lang_scores": {l: sum(scores) / len(scores) for l, scores in textract_scores.items()},
+            "time_per_img": textract_time / len(images)
+        }
+        print(len(flat_textract_scores))
+
+        with open(os.path.join(result_path, "textract_scores.json"), "w+") as f:
+            json.dump(textract_scores, f)
+
     with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
         json.dump(benchmark_stats, f)

@@ -133,6 +158,10 @@ def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, langs: s
         table_data.append(
             ["tesseract", benchmark_stats["tesseract"]["time_per_img"], benchmark_stats["tesseract"]["avg_score"]] + [benchmark_stats["tesseract"]["lang_scores"].get(l, 0) for l in key_languages]
         )
+    if textract:
+        table_data.append(
+            ["textract", benchmark_stats["textract"]["time_per_img"], benchmark_stats["textract"]["avg_score"]] + [benchmark_stats["textract"]["lang_scores"][l] for l in key_languages],
+        )

     print(tabulate(table_data, headers=table_headers, tablefmt="github"))
     print("Only a few major languages are displayed. See the result path for additional languages.")

benchmark/utils/textract.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+import os
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+import traceback
+
+from surya.input.processing import slice_bboxes_from_image
+from surya.recognition import RecognitionPredictor
+
+def textract_ocr(extractor, img):
+    try:
+        document = extractor.detect_document_text(file_source=img)
+        return [line.text for line in document.lines]
+    except:
+        traceback.print_exc()
+        return [None]
+
+def textract_ocr_parallel(imgs, cpus=None):
+    from textractor import Textractor  # Optional dependency
+
+    extractor = Textractor(profile_name='default')
+    parallel_cores = min(len(imgs), RecognitionPredictor().get_batch_size())
+    if not cpus:
+        cpus = os.cpu_count()
+    parallel_cores = min(parallel_cores, cpus)
+
+    with ThreadPoolExecutor(max_workers=parallel_cores) as executor:
+        textract_text = tqdm(executor.map(textract_ocr, [extractor]*len(imgs), imgs), total=len(imgs), desc="Running textract OCR")
+        textract_text = list(textract_text)
+    return textract_text
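
A usage sketch for the new helper, mirroring how benchmark/recognition.py calls it in this PR; it assumes AWS credentials exist for the 'default' profile that Textractor loads, and the image path is illustrative.

# Sketch only: mirrors the call in benchmark/recognition.py; "page_0.png" is a hypothetical file.
from PIL import Image

from benchmark.utils.textract import textract_ocr_parallel

images = [Image.open("page_0.png").convert("RGB")]

# cpus=28 mirrors the --textract_cpus default added to recognition.py in this PR.
predictions = textract_ocr_parallel(images, cpus=28)
for lines in predictions:
    print(lines)  # recognized text lines per page, or [None] if the Textract call failed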

benchmark/utils/verify_benchmark_scores.py

Lines changed: 7 additions & 0 deletions

@@ -18,6 +18,11 @@ def verify_det(data):
         raise ValueError("Scores do not meet the required threshold")


+def verify_inline_det(data):
+    scores = data["metrics"]["surya"]
+    if scores["precision"] <= 0.5 or scores["recall"] <= 0.5:
+        raise ValueError("Scores do not meet the required threshold")
+
 def verify_rec(data):
     scores = data["surya"]
     if scores["avg_score"] <= 0.9:
@@ -62,6 +67,8 @@ def main(file_path, bench_type):
         verify_table_rec(data)
     elif bench_type == "texify":
         verify_texify(data)
+    elif bench_type == "inline_detection":
+        verify_inline_det(data)
     else:
         raise ValueError("Invalid benchmark type")
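
For reference, verify_inline_det reads the results.json that benchmark/inline_detection.py writes. A small sketch of that check with an illustrative payload (the metric values are made up):

# Sketch of the check verify_inline_det performs on the results.json written by
# benchmark/inline_detection.py; the numbers below are illustrative only.
example_results = {
    "times": {"surya": 12.3},
    "metrics": {"surya": {"precision": 0.83, "recall": 0.96}},
    "page_metrics": {},
}

scores = example_results["metrics"]["surya"]
if scores["precision"] <= 0.5 or scores["recall"] <= 0.5:
    raise ValueError("Scores do not meet the required threshold")
print("inline_detection benchmark passed")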
