
Commit d349f30

Merge pull request #315 from VikParuchuri/dev

Inline math fixes

2 parents: db63214 + 53007e2

5 files changed: +40 -10 lines


.github/workflows/benchmarks.yml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ jobs:
         run: |
           poetry run python benchmark/detection.py --max_rows 2
           poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
-      - name: Run inline detection benchmarj
+      - name: Run inline detection benchmark
         run: |
           poetry run python benchmark/inline_detection.py --max_rows 5
           poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/inline_math_bench/results.json --bench_type inline_detection

surya/detection/heatmap.py

Lines changed: 16 additions & 2 deletions
@@ -166,8 +166,22 @@ def parallel_get_inline_boxes(preds, orig_sizes, text_boxes, include_maps=False)
     for text_box in text_boxes:
         text_box_reshaped = rescale_bbox(text_box, orig_sizes, heatmap_size)
         x1, y1, x2, y2 = text_box_reshaped
-        heatmap[y2:y2+3, x1:x2] = 0
-    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes, text_threshold=settings.INLINE_MATH_THRESHOLD)
+
+        # Blank out above and below text boxes, so we avoid merging inline math blocks together
+        heatmap[y2:y2+settings.INLINE_MATH_TEXT_BLANK_PX, x1:x2] = 0
+        heatmap[y1-settings.INLINE_MATH_TEXT_BLANK_PX:y1, x1:x2] = 0
+        heatmap[y1:y2, x2:x2+settings.INLINE_MATH_TEXT_BLANK_PX] = 0
+        heatmap[y1:y2, x1-settings.INLINE_MATH_TEXT_BLANK_PX:x1] = 0
+
+    bboxes = get_and_clean_boxes(
+        heatmap,
+        heatmap_size,
+        orig_sizes,
+        text_threshold=settings.INLINE_MATH_THRESHOLD,
+        low_text=settings.INLINE_MATH_BLANK_THRESHOLD
+    )
+
+    bboxes = [bbox for bbox in bboxes if bbox.area > settings.INLINE_MATH_MIN_AREA]
 
     heat_img, aff_img = None, None
     if include_maps:
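The new blanking step above can be pictured on a toy heatmap. The sketch below is not the surya implementation: it uses numpy, a stand-in constant for settings.INLINE_MATH_TEXT_BLANK_PX, and made-up box coordinates, and only shows how zeroing a thin margin around each text box keeps inline-math activations from adjacent lines from merging into one region.

import numpy as np

BLANK_PX = 2  # stands in for settings.INLINE_MATH_TEXT_BLANK_PX (hypothetical value matching the new default)

heatmap = np.random.rand(100, 200).astype(np.float32)   # fake inline-math heatmap, indexed [y, x]
text_boxes = [(20, 10, 180, 30), (20, 34, 180, 54)]     # made-up (x1, y1, x2, y2) text boxes

for x1, y1, x2, y2 in text_boxes:
    heatmap[y2:y2 + BLANK_PX, x1:x2] = 0                 # thin strip just below the text box
    heatmap[max(y1 - BLANK_PX, 0):y1, x1:x2] = 0         # thin strip just above it
    heatmap[y1:y2, x2:x2 + BLANK_PX] = 0                 # thin strip just to the right
    heatmap[y1:y2, max(x1 - BLANK_PX, 0):x1] = 0         # thin strip just to the left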

surya/scripts/streamlit_app.py

Lines changed: 18 additions & 5 deletions
@@ -55,15 +55,20 @@ def ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=1
     return label, results.labels
 
 
-def text_detection(img) -> (Image.Image, TextDetectionResult):
+def inline_detection(img) -> (Image.Image, TextDetectionResult):
     text_pred = predictors["detection"]([img])[0]
-    text_polygons = [p.polygon for p in text_pred.bboxes]
     text_boxes = [p.bbox for p in text_pred.bboxes]
-    det_img = draw_polys_on_image(text_polygons, img.copy())
-
+
     inline_pred = predictors["inline_detection"]([img], [text_boxes], include_maps=True)[0]
     inline_polygons = [p.polygon for p in inline_pred.bboxes]
-    det_img = draw_polys_on_image(inline_polygons, det_img, color='blue')
+    det_img = draw_polys_on_image(inline_polygons, img.copy(), color='blue')
+    return det_img, text_pred, inline_pred
+
+
+def text_detection(img) -> (Image.Image, TextDetectionResult):
+    text_pred = predictors["detection"]([img])[0]
+    text_polygons = [p.polygon for p in text_pred.bboxes]
+    det_img = draw_polys_on_image(text_polygons, img.copy())
     return det_img, text_pred, inline_pred
 
 
@@ -193,6 +198,7 @@ def page_counter(pdf_file):
     page_number = None
 
 run_text_det = st.sidebar.button("Run Text Detection")
+run_inline_det = st.sidebar.button("Run Inline Math Detection")
 run_text_rec = st.sidebar.button("Run OCR")
 run_layout_det = st.sidebar.button("Run Layout Analysis")
 run_table_rec = st.sidebar.button("Run Table Rec")
@@ -211,6 +217,13 @@ def page_counter(pdf_file):
         st.json(text_pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True)
         st.json(inline_pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True)
 
+if run_inline_det:
+    det_img, text_pred, inline_pred = inline_detection(pil_image)
+    with col1:
+        st.image(det_img, caption="Detected Inline Math", use_container_width=True)
+        st.json(text_pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True)
+        st.json(inline_pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True)
+
 
 # Run layout
 if run_layout_det:
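For reference, the call pattern inside the new inline_detection() helper looks roughly like the sketch below when used outside Streamlit. It is not self-contained: predictors is assumed to be the same dict of loaded models the app builds at startup, and page_image a PIL image of a document page.

# Pseudo-usage of the inline math predictor, mirroring the helper above.
text_pred = predictors["detection"]([page_image])[0]
text_boxes = [p.bbox for p in text_pred.bboxes]

# The inline math predictor is given the text boxes so it can blank them out of its heatmap.
inline_pred = predictors["inline_detection"]([page_image], [text_boxes], include_maps=True)[0]

for box in inline_pred.bboxes:
    print(box.bbox)  # [x1, y1, x2, y2] of each detected inline-math region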

surya/settings.py

Lines changed: 4 additions & 1 deletion
@@ -57,8 +57,11 @@ def TORCH_DEVICE_MODEL(self) -> str:
 
     # Inline math detection
     INLINE_MATH_MODEL_CHECKPOINT: str = "datalab-to/inline_math_det0@75aafc7aa3d494ece6496d28038c91f0d2518a43"
-    INLINE_MATH_THRESHOLD: float = 0.9  # Threshold for inline math detection (above this is considered inline-math)
+    INLINE_MATH_THRESHOLD: float = 0.8  # Threshold for inline math detection (above this is considered inline-math)
+    INLINE_MATH_BLANK_THRESHOLD: float = 0.5  # Threshold for blank space (below this is considered blank)
     INLINE_MATH_BENCH_DATASET_NAME: str = "datalab-to/inline_detection_bench"
+    INLINE_MATH_TEXT_BLANK_PX: int = 2  # How many pixels to blank out at the bottom of each text line
+    INLINE_MATH_MIN_AREA: int = 100  # Minimum area for inline math detection
 
     # Text recognition
     RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec2@6611509b2c3a32c141703ce19adc899d9d0abf41"
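The two thresholds added above feed the text_threshold and low_text arguments of get_and_clean_boxes in heatmap.py. The sketch below shows the usual CRAFT-style double-threshold reading of such a pair, as an illustration only; the actual logic inside get_and_clean_boxes may differ.

import numpy as np
from scipy import ndimage

def double_threshold_regions(heatmap, text_threshold=0.8, low_text=0.5):
    # Grow connected regions from the permissive (low_text) threshold...
    labels, num = ndimage.label(heatmap > low_text)
    regions = []
    for i in range(1, num + 1):
        mask = labels == i
        # ...but keep a region only if its peak score clears the strict (text_threshold) one.
        if heatmap[mask].max() >= text_threshold:
            regions.append(mask)
    return regions

# Example with a random map; real heatmaps come from the inline math detection model.
print(len(double_threshold_regions(np.random.rand(64, 64))))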

surya/texify/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ def batch_texify(self, images: List[Image.Image], batch_size: int | None) -> Tup
             batch_confidences = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
             batch_confidences = batch_confidences.cpu()[:current_batch_size]
             batch_predictions = batch_predictions.cpu()[:current_batch_size, 1:]  # Cut off initial token
-            detected_text = self.processor.tokenizer.batch_decode(batch_predictions)
+            detected_text = self.processor.tokenizer.batch_decode(batch_predictions, skip_special_tokens=True)
 
             batch_confidences = batch_confidences.tolist()
 
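The texify change only adds skip_special_tokens=True to batch_decode, which keeps markers like the end-of-sequence token out of the decoded LaTeX. The sketch below illustrates the flag's effect with a generic Hugging Face tokenizer (t5-small here, not the texify tokenizer, and a plain string standing in for LaTeX output).

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
ids = tokenizer(["Hello world"], return_tensors="pt").input_ids

print(tokenizer.batch_decode(ids))                            # e.g. ['Hello world</s>']
print(tokenizer.batch_decode(ids, skip_special_tokens=True))  # e.g. ['Hello world']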