Skip to content

Commit 8a63dfc

Browse files
authored
Merge pull request #374 from VikParuchuri/dev
Fix large image issue
2 parents 60ca35f + 6c1ae57 commit 8a63dfc

File tree

9 files changed

+228
-116
lines changed

9 files changed

+228
-116
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ There is a hosted API for all surya models available [here](https://www.datalab.
5959

6060
I want surya to be as widely accessible as possible, while still funding my development/training costs. Research and personal usage is always okay, but there are some restrictions on commercial usage.
6161

62-
The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period AND under $5M in lifetime VC/angel funding raised. You also must not be competitive with the [Datalab API](https://www.datalab.to/). If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options [here](https://www.datalab.to).
62+
The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under \$2M USD in gross revenue in the most recent 12-month period AND under \$2M in lifetime VC/angel funding raised. You also must not be competitive with the [Datalab API](https://www.datalab.to/). If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options [here](https://www.datalab.to).
6363

6464
# Installation
6565

poetry.lock

Lines changed: 136 additions & 98 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "surya-ocr"
3-
version = "0.14.1"
3+
version = "0.14.2"
44
description = "OCR, layout, reading order, and table recognition in 90+ languages"
55
authors = ["Vik Paruchuri <[email protected]>"]
66
readme = "README.md"
@@ -14,7 +14,7 @@ packages = [
1414
[tool.poetry.dependencies]
1515
python = "^3.10"
1616
transformers = "^4.51.2"
17-
torch = "^2.5.1"
17+
torch = "^2.7.0"
1818
pydantic = "^2.5.3"
1919
pydantic-settings = "^2.1.0"
2020
python-dotenv = "^1.0.0"
@@ -25,8 +25,8 @@ click = "^8.1.8"
2525
platformdirs = "^4.3.6"
2626
opencv-python-headless = "^4.11.0.86"
2727
einops = "^0.8.1"
28-
2928
pre-commit = "^4.2.0"
29+
3030
[tool.poetry.group.dev.dependencies]
3131
jupyter = "^1.0.0"
3232
pytesseract = "^0.3.10"

surya/common/s3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def download_directory(remote_path: str, local_dir: str):
8585
manifest = json.load(f)
8686

8787
pbar = tqdm(
88-
desc=f"Downloading {model_name} model...", total=len(manifest["files"])
88+
desc=f"Downloading {model_name} model to {local_dir}",
89+
total=len(manifest["files"]),
8990
)
9091

9192
with ThreadPoolExecutor(

surya/common/surya/__init__.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,19 @@
1010

1111
from surya.common.s3 import S3DownloaderMixin
1212
from surya.common.surya.config import SuryaModelConfig
13-
from surya.common.surya.decoder.__init__ import SuryaDecoderModel
14-
from surya.common.surya.embedder.__init__ import SimpleTokenEmbedder
15-
from surya.common.surya.encoder.__init__ import SuryaEncoderModel
13+
from surya.common.surya.decoder import SuryaDecoderModel
14+
from surya.common.surya.embedder import SimpleTokenEmbedder
15+
from surya.common.surya.encoder import SuryaEncoderModel
1616

1717
from transformers.utils import is_flash_attn_2_available
1818

19+
from surya.logging import get_logger
20+
1921
if is_flash_attn_2_available():
2022
from surya.common.surya.flash_attn_utils import _get_unpad_data
2123

24+
logger = get_logger()
25+
2226

2327
@dataclass
2428
class SuryaModelOutput(CausalLMOutputWithPast):
@@ -123,11 +127,57 @@ def set_output_embeddings(self, new_embeddings: nn.Module):
123127
def set_input_embeddings(self, new_embeddings: nn.Module):
124128
self.embedder.token_embed = new_embeddings
125129

126-
def get_image_embeddings(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor):
130+
def get_image_embeddings(
131+
self,
132+
pixel_values: torch.Tensor,
133+
grid_thw: torch.Tensor,
134+
encoder_chunk_size: int | None,
135+
):
127136
# embed all images with the vision encoder after they have already been tiled and flattened into a single batch
128-
embeddings = self.vision_encoder.embed_images(
129-
image_batch=pixel_values, grid_thw=grid_thw
137+
chunks = [0]
138+
grid_chunks = [0]
139+
curr_chunk_len = 0
140+
curr_seq_len = 0
141+
for i in range(len(grid_thw)):
142+
curr_chunk_len += (grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]).item()
143+
if curr_chunk_len > encoder_chunk_size:
144+
chunks.append(curr_chunk_len + curr_seq_len)
145+
curr_seq_len += curr_chunk_len
146+
curr_chunk_len = 0
147+
grid_chunks.append(i + 1)
148+
149+
if curr_chunk_len > 0:
150+
chunks.append(pixel_values.shape[0])
151+
grid_chunks.append(len(grid_thw))
152+
153+
assert curr_chunk_len + curr_seq_len == pixel_values.shape[0], (
154+
f"Mismatch in encoder chunking, {curr_chunk_len} + {curr_seq_len} != {pixel_values.shape[0]}"
155+
)
156+
157+
logger.debug(
158+
f"Chunking encoder sequence into {len(chunks) - 1} chunks of size {encoder_chunk_size} with lengths {chunks} and grids {grid_chunks}"
130159
)
160+
embeddings = []
161+
for i in range(len(chunks) - 1):
162+
start = chunks[i]
163+
end = chunks[i + 1]
164+
grid_start = grid_chunks[i]
165+
grid_end = grid_chunks[i + 1]
166+
chunk_embeddings = self.vision_encoder.embed_images(
167+
image_batch=pixel_values[start:end],
168+
grid_thw=grid_thw[grid_start:grid_end],
169+
)
170+
embeddings.append(chunk_embeddings)
171+
172+
if len(embeddings) == 0:
173+
raise ValueError(
174+
"No image embeddings were generated. Check the input images and grid sizes."
175+
)
176+
elif len(embeddings) == 1:
177+
embeddings = embeddings[0]
178+
else:
179+
embeddings = torch.cat(embeddings, dim=0)
180+
131181
encoding_2d = self.get_2d_learned_embeddings(
132182
grid_thw,
133183
device=embeddings.device,
@@ -144,7 +194,9 @@ def get_image_embeddings(self, pixel_values: torch.Tensor, grid_thw: torch.Tenso
144194

145195
return embeddings
146196

147-
def embed_ids_boxes_images(self, input_ids, pixel_values, grid_thw):
197+
def embed_ids_boxes_images(
198+
self, input_ids, pixel_values, grid_thw, encoder_chunk_size: int
199+
):
148200
"""
149201
Insert embedded image tiles into the corresponding positions into the full input sequence
150202
@@ -154,7 +206,9 @@ def embed_ids_boxes_images(self, input_ids, pixel_values, grid_thw):
154206
inputs_embeds = self.embedder.embed(input_tokens=input_ids)
155207
if pixel_values is not None:
156208
image_features = self.get_image_embeddings(
157-
pixel_values=pixel_values, grid_thw=grid_thw
209+
pixel_values=pixel_values,
210+
grid_thw=grid_thw,
211+
encoder_chunk_size=encoder_chunk_size,
158212
)
159213

160214
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
@@ -230,12 +284,13 @@ def forward(
230284
output_attentions=False,
231285
use_cache=False,
232286
logits_to_keep=None,
287+
encoder_chunk_size=None,
233288
**kwargs: KwargsForCausalLM,
234289
):
235290
# Process the mixed batch if provided
236291
if inputs_embeds is None:
237292
inputs_embeds = self.embed_ids_boxes_images(
238-
input_ids, image_tiles, grid_thw
293+
input_ids, image_tiles, grid_thw, encoder_chunk_size
239294
)
240295

241296
# Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder

surya/common/surya/processor/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ def scale_to_fit(
168168
elif current_pixels < min_pixels:
169169
scale_factor = (min_pixels / current_pixels) ** 0.5
170170

171-
new_width = int(width * scale_factor)
172-
new_height = int(height * scale_factor)
171+
new_width = math.ceil(width * scale_factor)
172+
new_height = math.ceil(height * scale_factor)
173173
else:
174174
return img
175175

surya/logging.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import logging
22
import warnings
3+
from surya.settings import settings
34

45

56
def configure_logging():
67
# Setup surya logger
7-
logger = logging.getLogger("surya")
8+
logger = get_logger()
89

910
if not logger.handlers:
1011
handler = logging.StreamHandler()
@@ -14,7 +15,7 @@ def configure_logging():
1415
handler.setFormatter(formatter)
1516
logger.addHandler(handler)
1617

17-
logger.setLevel(logging.DEBUG)
18+
logger.setLevel(settings.LOGLEVEL)
1819
warnings.simplefilter(action="ignore", category=FutureWarning)
1920

2021

surya/recognition/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
4141
ContinuousBatchingQuantizedCache,
4242
)
4343
from surya.settings import settings
44+
from surya.logging import get_logger, configure_logging
45+
46+
configure_logging()
47+
logger = get_logger()
4448

4549

4650
@dataclass
@@ -73,6 +77,8 @@ class RecognitionPredictor(BasePredictor):
7377
batch_size = settings.RECOGNITION_BATCH_SIZE
7478
torch_dtype = settings.MODEL_DTYPE_BFLOAT
7579
default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 128}
80+
encoder_chunk_size: int = 4096
81+
encoder_chunk_sizes = {"cpu": 4096, "mps": 4096, "cuda": 32768, "xla": 32768}
7682
min_prefill_ratio: int = 0.2
7783
min_trim_length: int = 50
7884
tasks = {
@@ -104,6 +110,13 @@ def __init__(self, checkpoint=None, device=settings.TORCH_DEVICE_MODEL, dtype=No
104110
self.processor.pad_token_id, device=self.model.device, dtype=torch.long
105111
)
106112

113+
def get_encoder_chunk_size(self):
114+
chunk_size = self.encoder_chunk_size
115+
if settings.TORCH_DEVICE_MODEL in self.encoder_chunk_sizes:
116+
if settings.TORCH_DEVICE_MODEL in self.encoder_chunk_sizes:
117+
chunk_size = self.encoder_chunk_sizes[settings.TORCH_DEVICE_MODEL]
118+
return chunk_size
119+
107120
def setup_cache(self, batch_size: int):
108121
self.kv_cache = None
109122
self.prompt_queue.clear()
@@ -328,6 +341,7 @@ def decode(self, current_inputs: Optional[ContinuousBatchInput] = None):
328341
return new_input, processed_output
329342

330343
def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None):
344+
logger.debug(f"Prefilling {self.num_empty_slots} slots")
331345
prompts: List[RecognitionPrompt] = [
332346
self.prompt_queue.popleft()
333347
for _ in range(min(self.num_empty_slots, len(self.prompt_queue)))
@@ -380,6 +394,7 @@ def prefill(self, current_inputs: Optional[ContinuousBatchInput] = None):
380394
past_key_values=prefill_cache,
381395
use_cache=True,
382396
logits_to_keep=1,
397+
encoder_chunk_size=self.get_encoder_chunk_size(),
383398
)
384399

385400
# Process outputs
@@ -462,6 +477,7 @@ def maybe_trim_cache_padding(self, current_inputs: ContinuousBatchInput):
462477
if trim_start < self.min_trim_length:
463478
return current_inputs
464479

480+
logger.debug(f"Trimming cache from left by {trim_start} tokens.")
465481
trimmed_attention_mask = attention_mask[:, trim_start:]
466482
current_inputs.attention_mask = trimmed_attention_mask
467483

surya/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class Settings(BaseSettings):
2222
10 # Number of workers for parallel model downloads
2323
)
2424
MODEL_CACHE_DIR: str = str(Path(user_cache_dir("datalab")) / "models")
25+
LOGLEVEL: str = "INFO" # Logging level
2526

2627
# Paths
2728
DATA_DIR: str = "data"

0 commit comments

Comments
 (0)