Skip to content

Commit 7a79fb5

Browse files
authored
Merge pull request #321 from VikParuchuri/dev
Swap to headless opencv
2 parents 5b61bd7 + 294f711 commit 7a79fb5

33 files changed

+388
-188
lines changed

poetry.lock

Lines changed: 105 additions & 110 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "surya-ocr"
3-
version = "0.11.1"
3+
version = "0.12.0"
44
description = "OCR, layout, reading order, and table recognition in 90+ languages"
55
authors = ["Vik Paruchuri <[email protected]>"]
66
readme = "README.md"
@@ -20,9 +20,10 @@ pydantic-settings = "^2.1.0"
2020
python-dotenv = "^1.0.0"
2121
pillow = "^10.2.0"
2222
pypdfium2 = "=4.30.0"
23-
opencv-python = "^4.9.0.80"
2423
filetype = "^1.2.0"
2524
click = "^8.1.8"
25+
platformdirs = "^4.3.6"
26+
opencv-python-headless = "^4.11.0.86"
2627

2728
[tool.poetry.group.dev.dependencies]
2829
jupyter = "^1.0.0"

surya/common/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
3+

surya/common/donut/processor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
import numpy as np
1010
from PIL import Image
1111
import PIL
12-
from surya.settings import settings
1312

13+
from surya.common.s3 import S3DownloaderMixin
14+
from surya.settings import settings
1415

15-
class SuryaEncoderImageProcessor(DonutImageProcessor):
16+
class SuryaEncoderImageProcessor(S3DownloaderMixin, DonutImageProcessor):
1617
def __init__(self, *args, max_size=None, align_long_axis=False, **kwargs):
1718
super().__init__(*args, **kwargs)
1819

surya/common/load.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,4 @@ def model(
1818
def processor(
1919
self
2020
) -> Any:
21-
raise NotImplementedError()
22-
23-
@staticmethod
24-
def split_checkpoint_revision(checkpoint: str) -> tuple[str, str | None]:
25-
parts = checkpoint.rsplit("@", 1)
26-
if len(parts) == 1:
27-
return parts[0], "main" # Default revision is main
28-
return parts[0], parts[1]
21+
raise NotImplementedError()

surya/common/s3.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import json
2+
import os
3+
import shutil
4+
import tempfile
5+
import time
6+
from concurrent.futures import ThreadPoolExecutor
7+
from pathlib import Path
8+
9+
import requests
10+
from platformdirs import user_cache_dir
11+
from tqdm import tqdm
12+
13+
from surya.settings import settings
14+
15+
def join_urls(url1: str, url2: str):
    """Join two URL fragments with exactly one slash at the seam.

    Trailing slashes are stripped from ``url1`` and leading slashes from
    ``url2`` before the two are glued together with a single ``/``.

    Args:
        url1: Left-hand part (typically the base URL).
        url2: Right-hand part (typically a relative path).

    Returns:
        The combined URL string.
    """
    left = url1.rstrip("/")
    right = url2.lstrip("/")
    return "/".join((left, right))
19+
20+
21+
def get_model_name(pretrained_model_name_or_path: str):
    """Return the portion of the checkpoint path before its first ``/``.

    Args:
        pretrained_model_name_or_path: Slash-separated checkpoint path.

    Returns:
        The leading path segment, or the whole string when it contains
        no slash.
    """
    first_segment, _, _ = pretrained_model_name_or_path.partition("/")
    return first_segment
23+
24+
25+
def download_file(remote_path: str, local_path: str, chunk_size: int = 1024 * 1024, timeout: float = 60.0):
    """Stream a remote file to disk.

    Args:
        remote_path: URL to fetch.
        local_path: Destination file path. Missing parent directories are
            created.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Seconds to wait for the connection and for each read from
            the socket before giving up. This is not a limit on the total
            download time.

    Returns:
        The destination as a ``pathlib.Path``.

    Raises:
        Exception: Whatever the request or the file write raised. Any
            partially written file is removed before re-raising.
    """
    local_path = Path(local_path)
    try:
        # The target may live in a subdirectory that does not exist yet.
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Without a timeout a stalled connection would block forever; the
        # context manager releases the streamed connection on every path.
        with requests.get(remote_path, stream=True, allow_redirects=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for bad status codes

            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)

        return local_path
    except Exception as e:
        # Never leave a truncated file behind for a later run to trust.
        if local_path.exists():
            local_path.unlink()
        print(f"Download error for file {remote_path}: {str(e)}")
        raise
42+
43+
def check_manifest(local_dir: str):
    """Report whether ``local_dir`` holds every file its manifest lists.

    Args:
        local_dir: Directory expected to contain ``manifest.json``.

    Returns:
        ``True`` when ``manifest.json`` exists and every entry under its
        ``"files"`` key exists inside ``local_dir``. ``False`` when the
        manifest is absent, unreadable, malformed, or lists a missing file.
    """
    root = Path(local_dir)
    manifest_file = root / "manifest.json"
    if not os.path.exists(manifest_file):
        return False

    try:
        with open(manifest_file, "r") as handle:
            listed = json.load(handle)["files"]
            # Any failure while reading or probing counts as "not cached".
            return all(os.path.exists(root / name) for name in listed)
    except Exception:
        return False
59+
60+
61+
def download_directory(remote_path: str, local_dir: str):
    """Download every file listed in a remote manifest into ``local_dir``.

    The remote location is ``settings.S3_BASE_URL`` joined with
    ``remote_path``. Nothing is downloaded when ``check_manifest`` reports
    that ``local_dir`` is already complete. Files are fetched into a
    temporary directory first and only moved into ``local_dir`` once every
    download has succeeded.

    Args:
        remote_path: Path of the model folder relative to the S3 base URL.
        local_dir: Directory that receives the downloaded files; created if
            missing.

    Raises:
        Exception: Propagates the first failure from fetching the manifest
            or any listed file.
    """
    model_name = get_model_name(remote_path)
    s3_url = join_urls(settings.S3_BASE_URL, remote_path)
    # Check to see if it's already downloaded
    if check_manifest(local_dir):
        return

    # Use tempfile.TemporaryDirectory to automatically clean up
    with tempfile.TemporaryDirectory() as temp_dir:
        # Download the manifest file
        manifest_file = join_urls(s3_url, "manifest.json")
        manifest_path = os.path.join(temp_dir, "manifest.json")
        download_file(manifest_file, manifest_path)

        # List and download all files
        with open(manifest_path, "r") as f:
            manifest = json.load(f)
        files = manifest["files"]

        # The context manager closes the bar even when a download raises.
        with tqdm(desc=f"Downloading {model_name} model...", total=len(files)) as pbar:
            with ThreadPoolExecutor(max_workers=settings.PARALLEL_DOWNLOAD_WORKERS) as executor:
                futures = []
                for file in files:
                    remote_file = join_urls(s3_url, file)
                    local_file = os.path.join(temp_dir, file)
                    # Entries may name files in subdirectories (not confirmed
                    # from the manifest format); make sure the folder exists.
                    os.makedirs(os.path.dirname(local_file), exist_ok=True)
                    futures.append(executor.submit(download_file, remote_file, local_file))

                for future in futures:
                    future.result()
                    pbar.update(1)

        # Move all files to new directory. Use an explicit destination and
        # clear stale entries first: moving into a directory that already
        # holds an entry of the same name raises shutil.Error, which is
        # exactly the state a partially populated cache is in.
        os.makedirs(local_dir, exist_ok=True)
        for entry in os.listdir(temp_dir):
            destination = os.path.join(local_dir, entry)
            if os.path.isdir(destination) and not os.path.islink(destination):
                shutil.rmtree(destination)
            elif os.path.lexists(destination):
                os.remove(destination)
            shutil.move(os.path.join(temp_dir, entry), destination)
98+
99+
100+
class S3DownloaderMixin:
    """Mixin that lets ``from_pretrained`` accept ``s3://`` checkpoints.

    Place it before the class that provides ``from_pretrained`` in the base
    list so this override runs first and then delegates via ``super()``.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a checkpoint, downloading it first when it is an ``s3://`` path.

        Args:
            pretrained_model_name_or_path: Checkpoint identifier. Values that
                are strings starting with ``s3://`` are downloaded into the
                user cache directory; anything else (hub ids, local paths,
                ``os.PathLike`` objects) is handed to the parent
                ``from_pretrained`` untouched.
            *args: Forwarded to the parent ``from_pretrained``.
            **kwargs: Forwarded to the parent ``from_pretrained``.

        Returns:
            Whatever the parent ``from_pretrained`` returns.

        Raises:
            Exception: The last download error once all attempts have failed.
        """
        s3_prefix = "s3://"

        # Allow loading models directly from the hub, or using s3.
        # Non-string inputs such as pathlib.Path have no .startswith, so they
        # are delegated as-is instead of crashing here.
        is_s3 = isinstance(pretrained_model_name_or_path, str) and pretrained_model_name_or_path.startswith(s3_prefix)
        if not is_s3:
            return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

        # Strip only the leading scheme, not later occurrences of "s3://".
        remote_path = pretrained_model_name_or_path[len(s3_prefix):]
        cache_dir = Path(user_cache_dir('datalab')) / "models"
        local_path = os.path.join(cache_dir, remote_path)
        os.makedirs(local_path, exist_ok=True)

        # Retry logic for downloading the model folder
        retries = 3
        delay = 5
        for attempt in range(1, retries + 1):
            try:
                download_directory(remote_path, local_path)
                break
            except Exception as e:
                print(f"Error downloading model from {remote_path}. Attempt {attempt} of {retries}. Error: {e}")
                if attempt == retries:
                    print(f"Failed to download {remote_path} after {retries} attempts.")
                    raise  # Reraise exception after max retries
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)  # Wait before retrying

        return super().from_pretrained(local_path, *args, **kwargs)

surya/detection/loader.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ def __init__(self, checkpoint: Optional[str] = None):
1717
if self.checkpoint is None:
1818
self.checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
1919

20-
self.checkpoint, self.revision = self.split_checkpoint_revision(self.checkpoint)
21-
2220
def model(
2321
self,
2422
device: Optional[torch.device | str] = None,
@@ -29,12 +27,11 @@ def model(
2927
if dtype is None:
3028
dtype = settings.MODEL_DTYPE
3129

32-
config = EfficientViTConfig.from_pretrained(self.checkpoint, revision=self.revision)
30+
config = EfficientViTConfig.from_pretrained(self.checkpoint)
3331
model = EfficientViTForSemanticSegmentation.from_pretrained(
3432
self.checkpoint,
3533
torch_dtype=dtype,
3634
config=config,
37-
revision=self.revision
3835
)
3936
model = model.to(device)
4037
model = model.eval()
@@ -52,7 +49,7 @@ def model(
5249
return model
5350

5451
def processor(self) -> SegformerImageProcessor:
55-
return SegformerImageProcessor.from_pretrained(self.checkpoint, revision=self.revision)
52+
return SegformerImageProcessor.from_pretrained(self.checkpoint)
5653

5754
class InlineDetectionModelLoader(DetectionModelLoader):
5855
def __init__(self, checkpoint: Optional[str] = None):

surya/detection/model/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from transformers import PretrainedConfig
22

3+
from surya.common.s3 import S3DownloaderMixin
34

4-
class EfficientViTConfig(PretrainedConfig):
5+
6+
class EfficientViTConfig(S3DownloaderMixin, PretrainedConfig):
57
r"""
68
```"""
79

surya/detection/model/encoderdecoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from transformers import PreTrainedModel
1919
from transformers.modeling_outputs import SemanticSegmenterOutput
2020

21+
from surya.common.s3 import S3DownloaderMixin
2122
from surya.detection.model.config import EfficientViTConfig
2223

2324

@@ -721,7 +722,7 @@ def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
721722
return logits
722723

723724

724-
class EfficientViTForSemanticSegmentation(EfficientViTPreTrainedModel):
725+
class EfficientViTForSemanticSegmentation(S3DownloaderMixin, EfficientViTPreTrainedModel):
725726
def __init__(self, config, **kwargs):
726727
super().__init__(config)
727728
self.vit = EfficientVitLarge(config)

surya/detection/processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
import PIL.Image
2121
import torch
2222

23+
from surya.common.s3 import S3DownloaderMixin
2324

24-
class SegformerImageProcessor(BaseImageProcessor):
25+
26+
class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
2527
r"""
2628
Constructs a Segformer image processor.
2729

surya/layout/loader.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ def __init__(self, checkpoint: Optional[str] = None):
1616
if self.checkpoint is None:
1717
self.checkpoint = settings.LAYOUT_MODEL_CHECKPOINT
1818

19-
self.checkpoint, self.revision = self.split_checkpoint_revision(self.checkpoint)
20-
2119
def model(
2220
self,
2321
device=settings.TORCH_DEVICE_MODEL,
@@ -28,7 +26,7 @@ def model(
2826
if dtype is None:
2927
dtype = settings.MODEL_DTYPE
3028

31-
config = SuryaLayoutConfig.from_pretrained(self.checkpoint, revision=self.revision)
29+
config = SuryaLayoutConfig.from_pretrained(self.checkpoint)
3230
decoder_config = config.decoder
3331
decoder = SuryaLayoutDecoderConfig(**decoder_config)
3432
config.decoder = decoder
@@ -37,7 +35,7 @@ def model(
3735
encoder = DonutSwinLayoutConfig(**encoder_config)
3836
config.encoder = encoder
3937

40-
model = SuryaLayoutModel.from_pretrained(self.checkpoint, config=config, torch_dtype=dtype, revision=self.revision)
38+
model = SuryaLayoutModel.from_pretrained(self.checkpoint, config=config, torch_dtype=dtype)
4139
model = model.to(device)
4240
model = model.eval()
4341

surya/layout/model/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from transformers import PretrainedConfig
55
from transformers.modeling_outputs import CausalLMOutput
66
from transformers.utils import ModelOutput
7+
from surya.common.s3 import S3DownloaderMixin
78
from surya.settings import settings
89

910
SPECIAL_TOKENS = 3
@@ -36,7 +37,7 @@
3637
LABEL_COUNT = len(ID_TO_LABEL)
3738

3839

39-
class SuryaLayoutConfig(PretrainedConfig):
40+
class SuryaLayoutConfig(S3DownloaderMixin, PretrainedConfig):
4041
model_type = "vision-encoder-decoder"
4142
is_composition = True
4243

surya/layout/model/encoderdecoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig
66
from transformers.modeling_outputs import BaseModelOutput
7+
from surya.common.s3 import S3DownloaderMixin
78
from surya.layout.model.encoder import DonutSwinLayoutModel
89
from surya.layout.model.decoder import SuryaLayoutDecoder
910
from transformers.utils import ModelOutput
@@ -16,7 +17,7 @@ class LayoutBboxOutput(ModelOutput):
1617
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
1718

1819

19-
class SuryaLayoutModel(PreTrainedModel):
20+
class SuryaLayoutModel(S3DownloaderMixin, PreTrainedModel):
2021
config_class = VisionEncoderDecoderConfig
2122
base_model_prefix = "vision_encoder_decoder"
2223
main_input_name = "pixel_values"

surya/ocr_error/loader.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ def __init__(self, checkpoint: Optional[str] = None):
1616
if self.checkpoint is None:
1717
self.checkpoint = settings.OCR_ERROR_MODEL_CHECKPOINT
1818

19-
self.checkpoint, self.revision = self.split_checkpoint_revision(self.checkpoint)
20-
2119
def model(
2220
self,
2321
device=settings.TORCH_DEVICE_MODEL,
@@ -28,12 +26,11 @@ def model(
2826
if dtype is None:
2927
dtype = settings.MODEL_DTYPE
3028

31-
config = DistilBertConfig.from_pretrained(self.checkpoint, revision=self.revision)
29+
config = DistilBertConfig.from_pretrained(self.checkpoint)
3230
model = DistilBertForSequenceClassification.from_pretrained(
3331
self.checkpoint,
3432
torch_dtype=dtype,
3533
config=config,
36-
revision=self.revision
3734
).to(device).eval()
3835

3936
if settings.COMPILE_ALL or settings.COMPILE_OCR_ERROR:
@@ -50,4 +47,4 @@ def model(
5047
def processor(
5148
self
5249
) -> DistilBertTokenizer:
53-
return DistilBertTokenizer.from_pretrained(self.checkpoint, revision=self.revision)
50+
return DistilBertTokenizer.from_pretrained(self.checkpoint)

surya/ocr_error/model/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
from transformers.configuration_utils import PretrainedConfig
55
from transformers.onnx import OnnxConfig
66

7+
from surya.common.s3 import S3DownloaderMixin
8+
79
ID2LABEL = {
810
0: 'good',
911
1: 'bad'
1012
}
1113

12-
class DistilBertConfig(PretrainedConfig):
14+
class DistilBertConfig(S3DownloaderMixin, PretrainedConfig):
1315
model_type = "distilbert"
1416
attribute_map = {
1517
"hidden_size": "dim",

surya/ocr_error/model/encoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from flash_attn import flash_attn_func, flash_attn_varlen_func
2222
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
2323

24+
from surya.common.s3 import S3DownloaderMixin
2425
from surya.ocr_error.model.config import DistilBertConfig
2526

2627

@@ -693,7 +694,7 @@ def forward(
693694
)
694695

695696

696-
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
697+
class DistilBertForSequenceClassification(S3DownloaderMixin, DistilBertPreTrainedModel):
697698
def __init__(self, config: DistilBertConfig):
698699
super().__init__(config)
699700
self.num_labels = config.num_labels

0 commit comments

Comments
 (0)