Commit 31d9126

Merge pull request #293 from VikParuchuri/dev
Integrate new latex OCR model
2 parents 217439b + 91c9ad9 commit 31d9126

41 files changed, with 1231 additions and 105 deletions

.github/workflows/benchmarks.yml

Lines changed: 5 additions & 1 deletion
@@ -37,4 +37,8 @@ jobs:
       - name: Run table recognition benchmark
         run: |
           poetry run python benchmark/table_recognition.py --max_rows 5
-          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition
+          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition
+      - name: Run texify benchmark
+        run: |
+          poetry run python benchmark/texify.py --max_rows 5
+          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/texify_bench/results.json --bench_type texify

.github/workflows/scripts.yml

Lines changed: 6 additions & 0 deletions
@@ -25,10 +25,16 @@ jobs:
       - name: Test detection
         run: poetry run surya_detect benchmark_data/pdfs/switch_trans.pdf --page_range 0
       - name: Test OCR
+        env:
+          RECOGNITION_MAX_TOKENS: 25
         run: poetry run surya_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0
       - name: Test layout
         run: poetry run surya_layout benchmark_data/pdfs/switch_trans.pdf --page_range 0
       - name: Test table
         run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0
+      - name: Test texify
+        env:
+          TEXIFY_MAX_TOKENS: 25
+        run: poetry run surya_latex_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0
       - name: Test detection folder
         run: poetry run surya_detect benchmark_data/pdfs --page_range 0
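
Every step in this workflow passes `--page_range 0`. The README text added in this commit describes the flag as accepting a single number, a comma separated list, a range, or comma separated ranges (example: `0,5-10,20`). A minimal sketch of how such a spec could be expanded is below. `parse_page_range` is a hypothetical helper, not surya's own parser, and it assumes both ends of a range are inclusive.

```python
from typing import List


def parse_page_range(spec: str) -> List[int]:
    """Expand a spec such as "0,5-10,20" into a sorted list of unique page indices."""
    pages = set()
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue  # tolerate stray commas
        if "-" in part:
            start, end = (int(p) for p in part.split("-", 1))
            if start > end:
                raise ValueError(f"invalid range: {part!r}")
            # Assumption: both ends of a range are inclusive.
            pages.update(range(start, end + 1))
        else:
            pages.add(int(part))
    return sorted(pages)


print(parse_page_range("0"))          # [0]
print(parse_page_range("0,5-10,20"))  # [0, 5, 6, 7, 8, 9, 10, 20]
```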

README.md

Lines changed: 60 additions & 4 deletions
@@ -7,6 +7,7 @@ Surya is a document OCR toolkit that does:
 - Layout analysis (table, image, header, etc detection)
 - Reading order detection
 - Table recognition (detecting rows/columns)
+- LaTeX OCR

 It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).

@@ -19,9 +20,9 @@ It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmar
 |:------------------------------------------------------------------:|:--------------------------------------------------------------------------:|
 | <img src="static/images/excerpt_layout.png" width="500px"/> | <img src="static/images/excerpt_reading.jpg" width="500px"/> |

-| Table Recognition | |
-|:-------------------------------------------------------------:|:----------------:|
-| <img src="static/images/scanned_tablerec.png" width="500px"/> | <img width="500px"/> |
+| Table Recognition | LaTeX OCR |
+|:-------------------------------------------------------------:|:------------------------------------------------------:|
+| <img src="static/images/scanned_tablerec.png" width="500px"/> | <img src="static/images/latex_ocr.png" width="500px"/> |


 Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision.
@@ -284,10 +285,48 @@ from surya.table_rec import TableRecPredictor
 image = Image.open(IMAGE_PATH)
 table_rec_predictor = TableRecPredictor()

-# list of dicts, one per image
 table_predictions = table_rec_predictor([image])
 ```

+## LaTeX OCR
+
+This command writes out a JSON file with the LaTeX of the equations. You must pass in images that are already cropped to the equations. One way to do this is to run the layout model first, then crop.
+
+```shell
+surya_latex_ocr DATA_PATH
+```
+
+- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
+- `--output_dir` specifies the directory to save results to instead of the default
+- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
+
+The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
+
+- `text` - the detected LaTeX text - it will be in KaTeX compatible LaTeX, with `<math display="block">...</math>` and `<math>...</math>` as delimiters.
+- `confidence` - the prediction confidence from 0-1.
+- `page` - the page number in the file
+
+### From python
+
+```python
+from PIL import Image
+from surya.texify import TexifyPredictor
+
+image = Image.open(IMAGE_PATH)
+predictor = TexifyPredictor()
+
+predictor([image])
+```
+
+### Interactive app
+
+You can also run a special interactive app that lets you select equations and OCR them (kind of like MathPix snip) with:
+
+```shell
+pip install streamlit==1.40 streamlit-drawable-canvas-jsretry
+texify_gui
+```
+
 # Limitations

 - This is specialized for document OCR. It will likely not work on photos or other images.
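
The `results.json` layout described in the hunk above (filename keys, a list of page dictionaries with `text`, `confidence` and `page`) can be walked with a few lines of Python. This is only a sketch under stated assumptions: the sample data is made up in the documented shape, and `summarize_results` and its `min_confidence` cutoff of 0.5 are hypothetical names and values that are not part of surya.

```python
from typing import Dict, List


def summarize_results(results: Dict[str, List[dict]], min_confidence: float = 0.5) -> List[str]:
    """Return one line per page, flagging predictions below min_confidence."""
    lines = []
    for name, pages in results.items():
        for page in pages:
            flag = "" if page["confidence"] >= min_confidence else " (low confidence)"
            lines.append(f"{name} p{page['page']}: {page['text']}{flag}")
    return lines


# Made-up sample in the documented shape; a real file would be read with json.load.
sample = {
    "equations": [
        {"text": '<math display="block">E = mc^2</math>', "confidence": 0.97, "page": 0},
        {"text": "<math>x_i</math>", "confidence": 0.41, "page": 1},
    ]
}

for line in summarize_results(sample):
    print(line)
```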
@@ -413,6 +452,14 @@ Higher is better for intersection, which the percentage of the actual row/column

 The benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM. It has labeled rows and columns. After table recognition is run, the predicted rows and columns are compared to the ground truth. There is an additional penalty for predicting too many or too few rows/columns.

+## LaTeX OCR
+
+| Method | edit ⬇ | time taken (s) ⬇ |
+|--------|----------|------------------|
+| texify | 0.122617 | 35.6345 |
+
+This runs texify on a set of images with ground truth LaTeX, then computes the edit distance between each prediction and its reference. The metric is a bit noisy, since two LaTeX strings that render the same can contain different symbols.
+
 ## Running your own benchmarks

 You can benchmark the performance of surya on your machine.
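
The noise in the edit-distance metric can be seen with two LaTeX strings that render as the same fraction. The sketch below is a plain-Python stand-in for the `Levenshtein.normalized_distance` call used in `benchmark/texify.py`. It assumes the normalization divides by the length of the longer string, which is my understanding of the default behaviour rather than something this commit states.

```python
def levenshtein(a: str, b: str) -> int:
    """Dynamic-programming edit distance (insert, delete, substitute each cost 1)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute, or free if characters match
            ))
        prev = curr
    return prev[-1]


def normalized_distance(a: str, b: str) -> float:
    """Edit distance scaled to [0, 1] by the longer string's length (assumed normalization)."""
    longest = max(len(a), len(b))
    return levenshtein(a, b) / longest if longest else 0.0


frac_form = r"\frac{a}{b}"
over_form = r"{a \over b}"  # renders as the same fraction
print(normalized_distance(frac_form, frac_form))  # 0.0
print(normalized_distance(frac_form, over_form))  # greater than 0 despite identical rendering
```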
@@ -482,6 +529,15 @@ python benchmark/table_recognition.py --max_rows 1024 --tatr
 - `--results_dir` will let you specify a directory to save results to instead of the default one
 - `--tatr` specifies whether to also run table transformer

+**LaTeX OCR**
+
+```shell
+python benchmark/texify.py --max_rows 128
+```
+
+- `--max_rows` controls how many images to process for the benchmark
+- `--results_dir` will let you specify a directory to save results to instead of the default one
+
 # Training

 Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified efficientvit architecture for semantic segmentation.

benchmark/texify.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import argparse
+import os.path
+import random
+import re
+import time
+from functools import partial
+from pathlib import Path
+from typing import List
+
+import click
+import datasets
+from tabulate import tabulate
+from bs4 import BeautifulSoup
+
+from surya.settings import settings
+from surya.texify import TexifyPredictor, TexifyResult
+import json
+import io
+from rapidfuzz.distance import Levenshtein
+
+def normalize_text(text):
+    soup = BeautifulSoup(text, "html.parser")
+    text = soup.get_text()
+    text = re.sub(r"\n", " ", text)
+    text = re.sub(r"\s+", " ", text)
+    return text.strip()
+
+
+def score_text(predictions, references):
+    lev_dist = []
+    for p, r in zip(predictions, references):
+        p = normalize_text(p)
+        r = normalize_text(r)
+        lev_dist.append(Levenshtein.normalized_distance(p, r))
+
+    return sum(lev_dist) / len(lev_dist)
+
+
+def inference_texify(source_data, predictor):
+    texify_predictions: List[TexifyResult] = predictor([sd["image"] for sd in source_data])
+    out_data = [
+        {"text": texify_predictions[i].text, "equation": source_data[i]["equation"]}
+        for i in range(len(texify_predictions))
+    ]
+
+    return out_data
+
+
+def image_to_bmp(image):
+    img_out = io.BytesIO()
+    image.save(img_out, format="BMP")
+    return img_out
+
+@click.command(help="Benchmark the performance of texify.")
+@click.option("--ds_name", type=str, help="Path to dataset file with source images/equations.", default=settings.TEXIFY_BENCHMARK_DATASET)
+@click.option("--results_dir", type=str, help="Path to directory where benchmark results are saved.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
+@click.option("--max_rows", type=int, help="Maximum number of images to benchmark.", default=None)
+def main(ds_name: str, results_dir: str, max_rows: int):
+    predictor = TexifyPredictor()
+    ds = datasets.load_dataset(ds_name, split="train")
+
+    if max_rows:
+        ds = ds.filter(lambda x, idx: idx < max_rows, with_indices=True)
+
+    start = time.time()
+    predictions = inference_texify(ds, predictor)
+    time_taken = time.time() - start
+
+    text = [p["text"] for p in predictions]
+    references = [p["equation"] for p in predictions]
+    scores = score_text(text, references)
+
+    write_data = {
+        "scores": scores,
+        "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)]
+    }
+
+    score_table = [
+        ["texify", write_data["scores"], time_taken]
+    ]
+    score_headers = ["edit", "time taken (s)"]
+    score_dirs = ["⬇", "⬇"]
+
+    score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
+    print()
+    print(tabulate(score_table, headers=["Method", *score_headers]))
+
+    result_path = Path(results_dir) / "texify_bench"
+    result_path.mkdir(parents=True, exist_ok=True)
+    with open(result_path / "results.json", "w") as f:
+        json.dump(write_data, f, indent=4)
+
+
+if __name__ == "__main__":
+    main()
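
`normalize_text` in this file strips markup (such as the `<math display="block">` delimiters) and collapses whitespace before scoring, so delimiters and line breaks do not count toward the edit distance. The sketch below approximates that step with the standard library only. `normalize_text_stdlib` is a hypothetical stand-in; the script itself uses BeautifulSoup, whose output may differ on malformed input.

```python
import re
from html.parser import HTMLParser


class _TextExtractor(HTMLParser):
    """Collects text nodes and ignores tags such as <math display="block">."""

    def __init__(self):
        super().__init__(convert_charrefs=True)
        self.chunks = []

    def handle_data(self, data):
        self.chunks.append(data)


def normalize_text_stdlib(text: str) -> str:
    parser = _TextExtractor()
    parser.feed(text)
    parser.close()  # flush any buffered trailing text
    stripped = "".join(parser.chunks)
    # Collapse newlines and runs of whitespace into single spaces.
    return re.sub(r"\s+", " ", stripped).strip()


prediction = '<math display="block">\\frac{a}{b}\n  + c</math>'
reference = "\\frac{a}{b} + c"
print(normalize_text_stdlib(prediction) == reference)  # True
```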

benchmark/utils/verify_benchmark_scores.py

Lines changed: 7 additions & 0 deletions
@@ -37,6 +37,11 @@ def verify_table_rec(data):
     if row_score < 0.75 or col_score < 0.75:
         raise ValueError("Scores do not meet the required threshold")

+def verify_texify(data):
+    edit_dist = data["scores"]
+    if edit_dist > .2:
+        raise ValueError("Scores do not meet the required threshold")
+

 @click.command(help="Verify benchmark scores")
 @click.argument("file_path", type=str)
@@ -55,6 +60,8 @@ def main(file_path, bench_type):
         verify_order(data)
     elif bench_type == "table_recognition":
         verify_table_rec(data)
+    elif bench_type == "texify":
+        verify_texify(data)
     else:
         raise ValueError("Invalid benchmark type")
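
To illustrate how the new check behaves in CI, here is a standalone sketch of the threshold test applied to two payloads shaped like the `results.json` that `benchmark/texify.py` writes. The 0.2 threshold and the 0.122617 score come from this commit; the `passes` wrapper, the constant name and the failing 0.35 score are made up for illustration.

```python
import json

EDIT_DISTANCE_THRESHOLD = 0.2  # value used by verify_texify in this commit


def verify_texify(data: dict) -> None:
    """Raise if the mean normalized edit distance exceeds the threshold."""
    edit_dist = data["scores"]
    if edit_dist > EDIT_DISTANCE_THRESHOLD:
        raise ValueError("Scores do not meet the required threshold")


def passes(data: dict) -> bool:
    """Hypothetical wrapper that turns the exception into a boolean."""
    try:
        verify_texify(data)
    except ValueError:
        return False
    return True


# Illustrative payloads; only the "scores" key is read by the check.
good = json.loads('{"scores": 0.122617, "text": []}')
bad = json.loads('{"scores": 0.35, "text": []}')
print(passes(good))  # True
print(passes(bad))   # False
```

Lower is better for this metric, so the check fails only when the score rises above the threshold.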

detect_layout.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from surya.scripts import detect_layout_cli
+from surya.scripts.detect_layout import detect_layout_cli

 if __name__ == "__main__":
     detect_layout_cli()

detect_text.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from surya.scripts import detect_text_cli
+from surya.scripts.detect_text import detect_text_cli

 if __name__ == "__main__":
     detect_text_cli()

ocr_app.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from surya.scripts import streamlit_app_cli
+from surya.scripts.run_streamlit_app import streamlit_app_cli

 if __name__ == "__main__":
     streamlit_app_cli()

ocr_latex.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+from surya.scripts.ocr_latex import ocr_latex_cli
+
+if __name__ == "__main__":
+    ocr_latex_cli()

ocr_text.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from surya.scripts import ocr_text_cli
+from surya.scripts.ocr_text import ocr_text_cli

 if __name__ == "__main__":
     ocr_text_cli()
