Skip to content

Commit 351b1be

Browse files
committed
Refactor rec predictor initialization in scripts + benchmarks
1 parent cd2acf0 commit 351b1be

File tree

8 files changed

+24
-8
lines changed

8 files changed

+24
-8
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,13 @@ Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference
132132

133133
```python
134134
from PIL import Image
135+
from surya.foundation import FoudnationPredictor
135136
from surya.recognition import RecognitionPredictor
136137
from surya.detection import DetectionPredictor
137138

138139
image = Image.open(IMAGE_PATH)
139-
recognition_predictor = RecognitionPredictor()
140+
foundation_predictor = FoundationPredictor()
141+
recognition_predictor = RecognitionPredictor(foundation_predictor)
140142
detection_predictor = DetectionPredictor()
141143

142144
predictions = recognition_predictor([image], det_predictor=detection_predictor)

benchmark/recognition.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from benchmark.utils.scoring import overlap_score, overlap_score_exact
88
from surya.input.processing import convert_if_not_rgb
99
from surya.debug.text import draw_text_on_image
10+
from surya.foundation import FoundationPredictor
1011
from surya.recognition import RecognitionPredictor
1112
from surya.settings import settings
1213
from surya.recognition.languages import CODE_TO_LANGUAGE
@@ -112,7 +113,8 @@ def main(
112113
textract_cpus: int,
113114
languages: str | None,
114115
):
115-
rec_predictor = RecognitionPredictor()
116+
foundation_predictor = FoundationPredictor()
117+
rec_predictor = RecognitionPredictor(foundation_predictor)
116118

117119
split = "train"
118120
dataset = datasets.load_dataset(

benchmark/texify.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from surya.common.surya.schema import TaskNames
1313
from surya.settings import settings
14+
from surya.foundation import FoundationPredictor
1415
from surya.recognition import RecognitionPredictor, OCRResult
1516
import json
1617
from rapidfuzz.distance import Levenshtein
@@ -77,7 +78,8 @@ def inference_texify(
7778
"--line_mode", is_flag=True, help="Use line mode for texify.", default=False
7879
)
7980
def main(ds_name: str, results_dir: str, max_rows: int, line_mode: bool):
80-
predictor = RecognitionPredictor()
81+
foundation_predictor = FoundationPredictor()
82+
predictor = RecognitionPredictor(foundation_predictor)
8183
ds = datasets.load_dataset(ds_name, split="train")
8284

8385
if max_rows:

surya/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from surya.layout import LayoutPredictor
88
from surya.logging import configure_logging
99
from surya.ocr_error import OCRErrorPredictor
10+
from surya.foundation import FoundationPredictor
1011
from surya.recognition import RecognitionPredictor
1112
from surya.table_rec import TableRecPredictor
1213

@@ -16,10 +17,11 @@
1617
def load_predictors(
1718
device: str | torch.device | None = None, dtype: torch.dtype | str | None = None
1819
) -> Dict[str, BasePredictor]:
20+
foundation_predictor = FoundationPredictor(device=device, dtype=dtype)
1921
return {
2022
"layout": LayoutPredictor(device=device, dtype=dtype),
2123
"ocr_error": OCRErrorPredictor(device=device, dtype=dtype),
22-
"recognition": RecognitionPredictor(device=device, dtype=dtype),
24+
"recognition": RecognitionPredictor(foundation_predictor),
2325
"detection": DetectionPredictor(device=device, dtype=dtype),
2426
"table_rec": TableRecPredictor(device=device, dtype=dtype),
2527
}

surya/scripts/ocr_latex.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from surya.logging import configure_logging, get_logger
99
from surya.scripts.config import CLILoader
10+
from surya.foundation import FoundationPredictor
1011
from surya.recognition import RecognitionPredictor
1112
from surya.common.surya.schema import TaskNames
1213

@@ -19,7 +20,8 @@
1920
def ocr_latex_cli(input_path: str, **kwargs):
2021
loader = CLILoader(input_path, kwargs, highres=True)
2122

22-
texify_predictor = RecognitionPredictor()
23+
foundation_predictor = FoundationPredictor()
24+
texify_predictor = RecognitionPredictor(foundation_predictor)
2325
tasks = [TaskNames.block_without_boxes] * len(loader.images)
2426
bboxes = [[[0, 0, image.width, image.height]] for image in loader.images]
2527

surya/scripts/ocr_text.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from surya.detection import DetectionPredictor
99
from surya.debug.text import draw_text_on_image
1010
from surya.logging import configure_logging, get_logger
11+
from surya.foundation import FoundationPredictor
1112
from surya.recognition import RecognitionPredictor
1213
from surya.scripts.config import CLILoader
1314

@@ -25,8 +26,9 @@ def ocr_text_cli(input_path: str, task_name: str, disable_math: bool, **kwargs):
2526
loader = CLILoader(input_path, kwargs, highres=True)
2627
task_names = [task_name] * len(loader.images)
2728

29+
foundation_predictor = FoundationPredictor()
2830
det_predictor = DetectionPredictor()
29-
rec_predictor = RecognitionPredictor()
31+
rec_predictor = RecognitionPredictor(foundation_predictor)
3032

3133
start = time.time()
3234
predictions_by_image = rec_predictor(

surya/scripts/texify_app.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import List
44

55
from surya.recognition import RecognitionPredictor
6+
from surya.foundation import FoundationPredictor
67
from surya.common.surya.schema import TaskNames
78

89
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = (
@@ -33,7 +34,8 @@ def replace_fences(text):
3334

3435
@st.cache_resource()
3536
def load_predictor():
36-
return RecognitionPredictor()
37+
foundation_predictor = FoundationPredictor()
38+
return RecognitionPredictor(foundation_predictor)
3739

3840

3941
@st.cache_data()

tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from surya.ocr_error import OCRErrorPredictor
1010
from surya.layout import LayoutPredictor
1111
from surya.recognition import RecognitionPredictor
12+
from surya.foundation import FoundationPredictor
1213
from surya.table_rec import TableRecPredictor
1314

1415

@@ -35,7 +36,8 @@ def detection_predictor() -> DetectionPredictor:
3536

3637
@pytest.fixture(scope="session")
3738
def recognition_predictor() -> RecognitionPredictor:
38-
recognition_predictor = RecognitionPredictor()
39+
foundation_predictor = FoundationPredictor()
40+
recognition_predictor = RecognitionPredictor(foundation_predictor)
3941
yield recognition_predictor
4042
del recognition_predictor
4143

0 commit comments

Comments
 (0)