Commit 30ce562

Merge pull request #252 from VikParuchuri/dev
New layout model
2 parents cb86a92 + e05dbd0 · commit 30ce562

40 files changed: +3440 additions, −4718 deletions

README.md

Lines changed: 17 additions & 62 deletions
@@ -197,9 +197,9 @@ model, processor = load_model(), load_processor()
 predictions = batch_text_detection([image], model, processor)
 ```

-## Layout analysis
+## Layout and reading order

-This command will write out a json file with the detected layout.
+This command will write out a json file with the detected layout and reading order.

 ```shell
 surya_layout DATA_PATH
@@ -215,14 +215,14 @@ The `results.json` file will contain a json dictionary where the keys are the in
 - `bboxes` - detected bounding boxes for text
   - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
   - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
-  - `confidence` - the confidence of the model in the detected text (0-1). This is currently not very reliable.
-  - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Text`, `Title`.
+  - `position` - the reading order of the box.
+  - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`.
 - `page` - the page number in the file
 - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.

 **Performance tips**

-Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `400MB` of VRAM, so very high batch sizes are possible. The default is a batch size of `36`, which will use about 16GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `6`.
+Setting the `LAYOUT_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `220MB` of VRAM, so very high batch sizes are possible. The default is a batch size of `32`, which will use about 7GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `4`.

 ### From python

@@ -231,7 +231,6 @@ from PIL import Image
 from surya.detection import batch_text_detection
 from surya.layout import batch_layout_detection
 from surya.model.layout.model import load_model, load_processor
-from surya.settings import settings

 image = Image.open(IMAGE_PATH)
 model = load_model()
@@ -244,52 +243,6 @@ line_predictions = batch_text_detection([image], det_model, det_processor)
 layout_predictions = batch_layout_detection([image], model, processor, line_predictions)
 ```

-## Reading order
-
-This command will write out a json file with the detected reading order and layout.
-
-```shell
-surya_order DATA_PATH
-```
-
-- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
-- `--images` will save images of the pages and detected text lines (optional)
-- `--max` specifies the maximum number of pages to process if you don't want to process everything
-- `--results_dir` specifies the directory to save results to instead of the default
-
-The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
-
-- `bboxes` - detected bounding boxes for text
-  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
-  - `position` - the position in the reading order of the bbox, starting from 0.
-  - `label` - the label for the bbox. See the layout section of the documentation for a list of potential labels.
-- `page` - the page number in the file
-- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
-
-**Performance tips**
-
-Setting the `ORDER_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `360MB` of VRAM, so very high batch sizes are possible. The default is a batch size of `32`, which will use about 11GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `4`.
-
-### From python
-
-```python
-from PIL import Image
-from surya.ordering import batch_ordering
-from surya.model.ordering.processor import load_processor
-from surya.model.ordering.model import load_model
-
-image = Image.open(IMAGE_PATH)
-# bboxes should be a list of lists with layout bboxes for the image in [x1,y1,x2,y2] format
-# You can get this from the layout model, see above for usage
-bboxes = [bbox1, bbox2, ...]
-
-model = load_model()
-processor = load_processor()
-
-# order_predictions will be a list of dicts, one per image
-order_predictions = batch_ordering([image], [bboxes], model, processor)
-```
-
 ## Table Recognition

 This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes. If you want to get a formatted markdown table, check out the [tabled](https://www.github.com/VikParuchuri/tabled) repo.
@@ -324,6 +277,9 @@ The `results.json` file will contain a json dictionary where the keys are the in

 Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `150MB` of VRAM, so very high batch sizes are possible. The default is a batch size of `64`, which will use about 10GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `8`.

+### From python
+
+See `table_recognition.py` for a code sample. Table recognition depends on extracting cells, so it is a little more involved to set up than other model types.

 # Limitations

@@ -410,16 +366,15 @@ Then we calculate precision and recall for the whole dataset.

 ## Layout analysis

-![Benchmark chart](static/images/benchmark_layout_chart.png)
-
-| Layout Type | precision | recall |
-| ----------- | --------- | ------ |
-| Image | 0.97 | 0.96 |
-| Table | 0.99 | 0.99 |
-| Text | 0.9 | 0.97 |
-| Title | 0.94 | 0.88 |
+| Layout Type | precision | recall |
+|---------------|-------------|----------|
+| Image | 0.91265 | 0.93976 |
+| List | 0.80849 | 0.86792 |
+| Table | 0.84957 | 0.96104 |
+| Text | 0.93019 | 0.94571 |
+| Title | 0.92102 | 0.95404 |

-Time per image - .4 seconds on GPU (A10).
+Time per image - .13 seconds on GPU (A10).

 **Methodology**

@@ -430,7 +385,7 @@ I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/

 ## Reading Order

-75% mean accuracy, and .14 seconds per image on an A6000 GPU. See methodology for notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check.
+88% mean accuracy, and .4 seconds per image on an A10 GPU. See methodology for notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check.

 **Methodology**

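For orientation, here is a minimal sketch of the layout call as the scripts changed in this commit use it: text detection is no longer a prerequisite, and reading order arrives as a `position` field on every layout box (`IMAGE_PATH` is a placeholder):

```python
# Minimal sketch of the post-commit layout API; IMAGE_PATH is a placeholder.
from PIL import Image

from surya.layout import batch_layout_detection
from surya.model.layout.model import load_model
from surya.model.layout.processor import load_processor

image = Image.open(IMAGE_PATH)
model = load_model()
processor = load_processor()

# One LayoutResult per input image; each entry in .bboxes carries bbox,
# polygon, label, and the new position (reading order) field.
layout_predictions = batch_layout_detection([image], model, processor)
```
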
benchmark/layout.py

Lines changed: 8 additions & 12 deletions
@@ -4,9 +4,8 @@
 import json

 from surya.benchmark.metrics import precision_recall
-from surya.detection import batch_text_detection
-from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
-from surya.model.layout.model import load_model, load_processor
+from surya.model.layout.model import load_model
+from surya.model.layout.processor import load_processor
 from surya.input.processing import convert_if_not_rgb
 from surya.layout import batch_layout_detection
 from surya.postprocessing.heatmap import draw_bboxes_on_image
@@ -26,8 +25,6 @@ def main():

     model = load_model()
     processor = load_processor()
-    det_model = load_det_model()
-    det_processor = load_det_processor()

     pathname = "layout_bench"
     # These have already been shuffled randomly, so sampling from the start is fine
@@ -36,12 +33,10 @@ def main():
     images = convert_if_not_rgb(images)

     if settings.LAYOUT_STATIC_CACHE:
-        line_prediction = batch_text_detection(images[:1], det_model, det_processor)
-        batch_layout_detection(images[:1], model, processor, line_prediction)
+        batch_layout_detection(images[:1], model, processor)

     start = time.time()
-    line_predictions = batch_text_detection(images, det_model, det_processor)
-    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)
+    layout_predictions = batch_layout_detection(images, model, processor)
     surya_time = time.time() - start

     folder_name = os.path.basename(pathname).split(".")[0]
@@ -50,9 +45,10 @@ def main():

     label_alignment = { # First is publaynet, second is surya
         "Image": [["Figure"], ["Picture", "Figure"]],
-        "Table": [["Table"], ["Table"]],
-        "Text": [["Text", "List"], ["Text", "Formula", "Footnote", "Caption", "List-item"]],
-        "Title": [["Title"], ["Section-header", "Title"]]
+        "Table": [["Table"], ["Table", "Form", "TableOfContents"]],
+        "Text": [["Text"], ["Text", "Formula", "Footnote", "Caption", "TextInlineMath", "Code", "Handwriting"]],
+        "List": [["List"], ["ListItem"]],
+        "Title": [["Title"], ["SectionHeader", "Title"]]
     }

     page_metrics = collections.OrderedDict()

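The `label_alignment` table above buckets Publaynet labels against the new Surya label set before scoring. As a hedged illustration of what an overlap-based precision/recall can look like (this is not necessarily the `precision_recall` implementation that `surya.benchmark.metrics` actually ships):

```python
# Illustrative overlap-based precision/recall for one aligned layout type.
# A sketch only; the repo's surya.benchmark.metrics.precision_recall may differ.
def box_area(box):
    # box is (x1, y1, x2, y2)
    return max(0, box[2] - box[0]) * max(0, box[3] - box[1])

def intersection_area(a, b):
    left, top = max(a[0], b[0]), max(a[1], b[1])
    right, bottom = min(a[2], b[2]), min(a[3], b[3])
    return max(0, right - left) * max(0, bottom - top)

def coverage(box, others):
    # Fraction of `box` covered by its best-overlapping counterpart.
    if not others:
        return 0.0
    return max(intersection_area(box, o) for o in others) / max(box_area(box), 1)

def precision_recall(preds, references, threshold=0.5):
    # Precision: predicted boxes sufficiently covered by a reference box.
    # Recall: reference boxes sufficiently covered by a predicted box.
    precision = sum(coverage(p, references) >= threshold for p in preds) / max(len(preds), 1)
    recall = sum(coverage(r, preds) >= threshold for r in references) / max(len(references), 1)
    return precision, recall

# Example: one predicted table box scored against one ground-truth table box.
print(precision_recall([(10, 10, 100, 50)], [(12, 10, 100, 55)]))  # (1.0, 1.0)
```
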
benchmark/ordering.py

Lines changed: 20 additions & 10 deletions
@@ -4,9 +4,10 @@
 import json

 from surya.input.processing import convert_if_not_rgb
-from surya.model.ordering.model import load_model
-from surya.model.ordering.processor import load_processor
-from surya.ordering import batch_ordering
+from surya.layout import batch_layout_detection
+from surya.model.layout.model import load_model
+from surya.model.layout.processor import load_processor
+from surya.schema import Bbox
 from surya.settings import settings
 from surya.benchmark.metrics import rank_accuracy
 import os
@@ -15,7 +16,7 @@


 def main():
-    parser = argparse.ArgumentParser(description="Benchmark surya reading order model.")
+    parser = argparse.ArgumentParser(description="Benchmark surya layout for reading order.")
     parser.add_argument("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
     parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
     args = parser.parse_args()
@@ -31,10 +32,9 @@ def main():
     dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
     images = list(dataset["image"])
     images = convert_if_not_rgb(images)
-    bboxes = list(dataset["bboxes"])

     start = time.time()
-    order_predictions = batch_ordering(images, bboxes, model, processor)
+    layout_predictions = batch_layout_detection(images, model, processor)
     surya_time = time.time() - start

     folder_name = os.path.basename(pathname).split(".")[0]
@@ -43,11 +43,21 @@ def main():

     page_metrics = collections.OrderedDict()
     mean_accuracy = 0
-    for idx, order_pred in enumerate(order_predictions):
+    for idx, order_pred in enumerate(layout_predictions):
         row = dataset[idx]
-        pred_labels = [str(l.position) for l in order_pred.bboxes]
         labels = row["labels"]
-        accuracy = rank_accuracy(pred_labels, labels)
+        bboxes = row["bboxes"]
+        pred_positions = []
+        for label, bbox in zip(labels, bboxes):
+            max_intersection = 0
+            matching_idx = 0
+            for pred_box in order_pred.bboxes:
+                intersection = pred_box.intersection_pct(Bbox(bbox=bbox))
+                if intersection > max_intersection:
+                    max_intersection = intersection
+                    matching_idx = pred_box.position
+            pred_positions.append(matching_idx)
+        accuracy = rank_accuracy(pred_positions, labels)
         mean_accuracy += accuracy
         page_results = {
             "accuracy": accuracy,
@@ -56,7 +66,7 @@ def main():

         page_metrics[idx] = page_results

-    mean_accuracy /= len(order_predictions)
+    mean_accuracy /= len(layout_predictions)

     out_data = {
         "time": surya_time,

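The rewritten loop above first maps every ground-truth box to the `position` of the best-overlapping predicted layout box, then scores the recovered positions with `rank_accuracy`. As a hedged sketch of one plausible rank-accuracy definition (the metric shipped in `surya.benchmark.metrics` may be defined differently), a pairwise-order version:

```python
# Sketch of a pairwise rank accuracy: the share of box pairs whose relative
# reading order the prediction preserves. Not necessarily the repo's metric.
from itertools import combinations

def pairwise_rank_accuracy(pred_positions, true_positions):
    pairs = list(combinations(range(len(true_positions)), 2))
    if not pairs:
        return 1.0
    agree = sum(
        (pred_positions[i] < pred_positions[j]) == (true_positions[i] < true_positions[j])
        for i, j in pairs
    )
    return agree / len(pairs)

# Example: the prediction swaps the last two of four boxes, so 5 of 6 pairs agree.
print(pairwise_rank_accuracy([0, 1, 3, 2], [0, 1, 2, 3]))  # 0.8333...
```
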
detect_layout.py

Lines changed: 5 additions & 14 deletions
@@ -6,11 +6,10 @@
 import json
 from collections import defaultdict

-from surya.detection import batch_text_detection
 from surya.input.load import load_from_folder, load_from_file
 from surya.layout import batch_layout_detection
-from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
-from surya.model.layout.model import load_model, load_processor
+from surya.model.layout.model import load_model
+from surya.model.layout.processor import load_processor
 from surya.postprocessing.heatmap import draw_polys_on_image
 from surya.settings import settings
 import os
@@ -27,8 +26,6 @@ def main():

     model = load_model()
     processor = load_processor()
-    det_model = load_det_model()
-    det_processor = load_det_processor()

     if os.path.isdir(args.input_path):
         images, names, _ = load_from_folder(args.input_path, args.max)
@@ -38,9 +35,7 @@ def main():
     folder_name = os.path.basename(args.input_path).split(".")[0]

     start = time.time()
-    line_predictions = batch_text_detection(images, det_model, det_processor)
-
-    layout_predictions = batch_layout_detection(images, model, processor, line_predictions, include_maps=args.debug)
+    layout_predictions = batch_layout_detection(images, model, processor)
     result_path = os.path.join(args.results_dir, folder_name)
     os.makedirs(result_path, exist_ok=True)
     if args.debug:
@@ -49,17 +44,13 @@ def main():
     if args.images:
         for idx, (image, layout_pred, name) in enumerate(zip(images, layout_predictions, names)):
             polygons = [p.polygon for p in layout_pred.bboxes]
-            labels = [p.label for p in layout_pred.bboxes]
+            labels = [f"{p.label}-{p.position}" for p in layout_pred.bboxes]
             bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)
             bbox_image.save(os.path.join(result_path, f"{name}_{idx}_layout.png"))

-            if args.debug:
-                heatmap = layout_pred.segmentation_map
-                heatmap.save(os.path.join(result_path, f"{name}_{idx}_segmentation.png"))
-
     predictions_by_page = defaultdict(list)
     for idx, (pred, name, image) in enumerate(zip(layout_predictions, names, images)):
-        out_pred = pred.model_dump(exclude=["segmentation_map"])
+        out_pred = pred.model_dump()
         out_pred["page"] = len(predictions_by_page[name]) + 1
         predictions_by_page[name].append(out_pred)

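Since `detect_layout.py` now dumps `position` alongside `label`, the `results.json` written by `surya_layout` can be walked in reading order directly. A minimal sketch, assuming the page-dictionary structure documented in the README:

```python
# Walk surya_layout's results.json in reading order via the new position field.
import json

with open("results.json") as f:
    results = json.load(f)

for name, pages in results.items():
    for page in pages:
        for box in sorted(page["bboxes"], key=lambda b: b["position"]):
            print(name, page["page"], box["position"], box["label"], box["bbox"])
```
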
ocr_app.py

Lines changed: 5 additions & 30 deletions
@@ -9,21 +9,19 @@
 from surya.input.pdflines import get_page_text_lines, get_table_blocks
 from surya.layout import batch_layout_detection
 from surya.model.detection.model import load_model, load_processor
-from surya.model.layout.model import load_model as load_layout_model, load_processor as load_layout_processor
+from surya.model.layout.model import load_model as load_layout_model
+from surya.model.layout.processor import load_processor as load_layout_processor
 from surya.model.recognition.model import load_model as load_rec_model
 from surya.model.recognition.processor import load_processor as load_rec_processor
-from surya.model.ordering.processor import load_processor as load_order_processor
-from surya.model.ordering.model import load_model as load_order_model
 from surya.model.table_rec.model import load_model as load_table_model
 from surya.model.table_rec.processor import load_processor as load_table_processor
-from surya.ordering import batch_ordering
 from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
 from surya.ocr import run_ocr
 from surya.postprocessing.text import draw_text_on_image
 from PIL import Image
 from surya.languages import CODE_TO_LANGUAGE
 from surya.input.langs import replace_lang_with_code
-from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult, TableResult
+from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
 from surya.settings import settings
 from surya.tables import batch_table_recognition
 from surya.postprocessing.util import rescale_bboxes, rescale_bbox
@@ -43,10 +41,6 @@ def load_rec_cached():
 def load_layout_cached():
     return load_layout_model(), load_layout_processor()

-@st.cache_resource()
-def load_order_cached():
-    return load_order_model(), load_order_processor()
-

 @st.cache_resource()
 def load_table_cached():
@@ -61,24 +55,13 @@ def text_detection(img) -> (Image.Image, TextDetectionResult):


 def layout_detection(img) -> (Image.Image, LayoutResult):
-    _, det_pred = text_detection(img)
-    pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0]
+    pred = batch_layout_detection([img], layout_model, layout_processor)[0]
     polygons = [p.polygon for p in pred.bboxes]
-    labels = [p.label for p in pred.bboxes]
+    labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
     layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
     return layout_img, pred


-def order_detection(img) -> (Image.Image, OrderResult):
-    _, layout_pred = layout_detection(img)
-    bboxes = [l.bbox for l in layout_pred.bboxes]
-    pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
-    polys = [l.polygon for l in pred.bboxes]
-    positions = [str(l.position) for l in pred.bboxes]
-    order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=18)
-    return order_img, pred
-
-
 def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
     if skip_table_detection:
         layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
@@ -171,7 +154,6 @@ def page_count(pdf_file):
 det_model, det_processor = load_det_cached()
 rec_model, rec_processor = load_rec_cached()
 layout_model, layout_processor = load_layout_cached()
-order_model, order_processor = load_order_cached()
 table_model, table_processor = load_table_cached()


@@ -211,7 +193,6 @@ def page_count(pdf_file):
 text_det = st.sidebar.button("Run Text Detection")
 text_rec = st.sidebar.button("Run OCR")
 layout_det = st.sidebar.button("Run Layout Analysis")
-order_det = st.sidebar.button("Run Reading Order")
 table_rec = st.sidebar.button("Run Table Rec")
 use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
 skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.")
@@ -245,12 +226,6 @@ def page_count(pdf_file):
     with text_tab:
         st.text("\n".join([p.text for p in pred.text_lines]))

-if order_det:
-    order_img, pred = order_detection(pil_image)
-    with col1:
-        st.image(order_img, caption="Reading Order", use_column_width=True)
-        st.json(pred.model_dump(), expanded=True)
-

 if table_rec:
     table_img, pred = table_recognition(pil_image, pil_image_highres, in_file, page_number - 1 if page_number else None, use_pdf_boxes, skip_table_detection)

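For callers that depended on the deleted ordering pipeline, a hedged migration sketch: reading order now comes from the `position` field on layout predictions rather than a separate `batch_ordering` pass (`img` stands in for any PIL image):

```python
# Hedged migration sketch: the removed batch_ordering pass is replaced by the
# position field on layout predictions. `img` is assumed to be a PIL image.
from surya.layout import batch_layout_detection
from surya.model.layout.model import load_model as load_layout_model
from surya.model.layout.processor import load_processor as load_layout_processor

layout_model, layout_processor = load_layout_model(), load_layout_processor()
pred = batch_layout_detection([img], layout_model, layout_processor)[0]

# Layout boxes sorted into reading order, replacing the deleted order_detection().
ordered_boxes = [b.bbox for b in sorted(pred.bboxes, key=lambda b: b.position)]
```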