Skip to content

Commit 80e9a7e

Browse files
authored
Merge pull request #392 from datalab-to/dev
Dev
2 parents 8023f3a + 3c84652 commit 80e9a7e

File tree

8 files changed

+64
-17
lines changed

8 files changed

+64
-17
lines changed

.github/workflows/benchmarks.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@ env:
77

88
jobs:
99
build:
10-
runs-on: ${{ matrix.os }}
11-
strategy:
12-
matrix:
13-
os: [ubuntu-latest, windows-latest]
10+
runs-on: t4_gpu
1411
steps:
1512
- uses: actions/checkout@v3
1613
- name: Set up Python 3.11

.github/workflows/ci.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ name: Unit tests
22

33
on: [push]
44

5-
env:
6-
TORCH_DEVICE: "cpu"
7-
85
jobs:
96
build:
10-
runs-on: ubuntu-latest
7+
runs-on: ${{ matrix.os }}
8+
strategy:
9+
matrix:
10+
os: [t4_gpu, ubuntu-latest, windows-latest]
1111
steps:
1212
- uses: actions/checkout@v3
1313
- name: Set up Python 3.11

.github/workflows/scripts.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,9 @@ name: Test CLI scripts
22

33
on: [push]
44

5-
env:
6-
TORCH_DEVICE: "cpu"
7-
85
jobs:
96
build:
10-
runs-on: ubuntu-latest
7+
runs-on: t4_gpu
118
steps:
129
- uses: actions/checkout@v3
1310
- name: Set up Python 3.11

benchmark/texify.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,12 @@ def score_text(predictions, references):
3838
return sum(lev_dist) / len(lev_dist)
3939

4040

41-
def inference_texify(source_data, predictor: RecognitionPredictor):
41+
def inference_texify(
42+
source_data, predictor: RecognitionPredictor, line_mode: bool = False
43+
):
4244
images = [sd["image"] for sd in source_data]
43-
tasks = [TaskNames.block_without_boxes] * len(images)
45+
mode = TaskNames.ocr_with_boxes if line_mode else TaskNames.block_without_boxes
46+
tasks = [mode] * len(images)
4447
bboxes = [[[0, 0, image.width, image.height]] for image in images]
4548
texify_predictions: List[OCRResult] = predictor(images, tasks, bboxes=bboxes)
4649
out_data = [
@@ -70,15 +73,18 @@ def inference_texify(source_data, predictor: RecognitionPredictor):
7073
@click.option(
7174
"--max_rows", type=int, help="Maximum number of images to benchmark.", default=None
7275
)
73-
def main(ds_name: str, results_dir: str, max_rows: int):
76+
@click.option(
77+
"--line_mode", is_flag=True, help="Use line mode for texify.", default=False
78+
)
79+
def main(ds_name: str, results_dir: str, max_rows: int, line_mode: bool):
7480
predictor = RecognitionPredictor()
7581
ds = datasets.load_dataset(ds_name, split="train")
7682

7783
if max_rows:
7884
ds = ds.filter(lambda x, idx: idx < max_rows, with_indices=True)
7985

8086
start = time.time()
81-
predictions = inference_texify(ds, predictor)
87+
predictions = inference_texify(ds, predictor, line_mode)
8288
time_taken = time.time() - start
8389

8490
text = [p["text"] for p in predictions]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "surya-ocr"
3-
version = "0.14.5"
3+
version = "0.14.6"
44
description = "OCR, layout, reading order, and table recognition in 90+ languages"
55
authors = ["Vik Paruchuri <[email protected]>"]
66
readme = "README.md"

surya/recognition/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
detect_repeat_token,
3535
prediction_to_polygon_batch,
3636
unwrap_math,
37+
clean_math_tags,
3738
)
3839
from surya.recognition.schema import TextLine, OCRResult, TextChar
3940
from surya.common.surya.schema import TaskNames
@@ -917,6 +918,7 @@ def __call__(
917918
)
918919
text = "".join([char.text for char in text_line])
919920
text = unwrap_math(text)
921+
text = clean_math_tags(text)
920922
lines.append(
921923
TextLine(
922924
text=text,

surya/recognition/util.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,35 @@ def unwrap_math(text: str) -> str:
2727
return text
2828

2929

30+
MATH_BLOCK = re.compile(r"(<math\b[^>]*>)(.*?)</math>", flags=re.I | re.S)
31+
STRIP_TAGS = re.compile(r"</?(?:br|u|del|mark|i|b|sup|sub)\b[^>]*>", flags=re.I | re.S)
32+
33+
34+
def clean_math_tags(html: str) -> str:
35+
# strip unwanted tags inside every well‑formed <math>…</math>
36+
def _inner(m):
37+
inner = STRIP_TAGS.sub("", m.group(2))
38+
return f"{m.group(1)}{inner}</math>" if inner.strip() else ""
39+
40+
cleaned = MATH_BLOCK.sub(_inner, html)
41+
42+
# drop only orphan *closing* </math> tags
43+
depth = 0
44+
parts = []
45+
for token in re.split(r"(</?math[^>]*>)", cleaned, flags=re.I):
46+
if token.lower().startswith("<math"):
47+
depth += 1
48+
parts.append(token)
49+
elif token.lower() == "</math>":
50+
if depth: # keep it only if it matches an open
51+
depth -= 1
52+
parts.append(token)
53+
# else: skip orphan closing tag
54+
else:
55+
parts.append(token)
56+
return "".join(parts)
57+
58+
3059
def detect_repeat_token(predicted_tokens: List[int], max_repeats: int = 40):
3160
if len(predicted_tokens) < max_repeats:
3261
return False

tests/test_recognition.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import time
22
from PIL import ImageDraw, Image
3+
from surya.recognition.util import clean_math_tags
34

45

56
def test_recognition(recognition_predictor, detection_predictor, test_image):
@@ -49,3 +50,18 @@ def test_recognition_drop_repeats(recognition_predictor, detection_predictor):
4950
assert len(recognition_results) == 1
5051
result = recognition_results[0].text_lines
5152
assert result[0].text == ""
53+
54+
55+
def test_recognition_clean_math():
56+
math = """<math display="block">na_n^{1+2r} \\text{cov}(\\hat{f}_n^{(r)}(x), \\hat{f}_n^{(r)}(y)) = \\frac{1}{n} \\sum_{j=1}^n \\frac{a_n^{1+2r}}{a_j^{1+2r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_j}{a_j}\\right)\\right) <br>+ \\frac{a_n^{1+2r}}{n} \\sum_{\\substack{j \\neq k \\\\ 1 \\le j, k \\le n}} \\frac{1}{(a_j a_k)^{1+r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_k}{a_k}\\right)\\right) <br>=: I_1 + I_2.</math> (1.7)</math>'"""
57+
clean_math = clean_math_tags(math)
58+
59+
assert clean_math.count("</math>") == 1, "Should have one closing math tag"
60+
assert "<br>" not in clean_math, "Should not have <br> tags in cleaned math"
61+
62+
63+
def test_recognition_clean_math_preserve_text():
64+
text = """Hello, this is a sentence with <math display="inline">x^2 + y^2 = z^2</math> and some text after it, with a weird tag <hello> and <goodbye>."""
65+
clean_text = clean_math_tags(text)
66+
67+
assert clean_text == text

0 commit comments

Comments
 (0)