
Commit 442069e

Merge pull request #379 from VikParuchuri/dev
Dev
2 parents 8a63dfc + 16f2e35 commit 442069e

11 files changed, +256 -32 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.14.2"
+version = "0.14.3"
 description = "OCR, layout, reading order, and table recognition in 90+ languages"
 authors = ["Vik Paruchuri <[email protected]>"]
 readme = "README.md"

surya/common/adetr/decoder.py

Lines changed: 20 additions & 0 deletions
@@ -193,6 +193,12 @@ def forward(
         attn_output = self.o_proj(attn_output)
         return attn_output

+    def _clear_cache(self):
+        if self.value_states is not None:
+            del self.value_states
+        if self.key_states is not None:
+            del self.key_states
+
     def _setup_cache(self, batch_size, device, dtype=None):
         # Setup initial caches
         self.value_states = None
@@ -297,6 +303,12 @@ def _setup_cache(self, batch_size, device, dtype=None):
         self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
         self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)

+    def _clear_cache(self):
+        if self.value_states is not None:
+            del self.value_states
+        if self.key_states is not None:
+            del self.key_states
+
     def _update_static_cache(self, key_states, value_states, **cache_kwargs):
         cache_position = cache_kwargs.get("cache_position")
         k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
@@ -479,6 +491,14 @@ def _setup_cache(self, config, batch, device, dtype):
             if layer.cross_attn_block:
                 layer.cross_attn_block._setup_cache(batch, device, dtype)

+    def _clear_cache(self):
+        layers = getattr(self, "model", self).layers
+        for layer in layers:
+            if layer.temporal_block:
+                layer.temporal_block._clear_cache()
+            if layer.cross_attn_block:
+                layer.cross_attn_block._clear_cache()
+
     def reset_cache(self, batch, device, dtype):
         pass
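The new _clear_cache hooks let callers drop the decoder's key/value cache tensors between batches. A minimal sketch of how this is wired up elsewhere in this commit (see surya/layout/__init__.py below), with predictor as a stand-in name for any predictor that holds this decoder:

import torch

def free_decoder_memory(predictor) -> None:
    # Drop the per-layer key/value cache tensors held by the ADETR decoder...
    predictor.model.decoder.model._clear_cache()
    # ...then release cached CUDA blocks back to the allocator, if a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()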

surya/common/surya/decoder/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -156,6 +156,7 @@ def __init__(self, config: SuryaDecoderConfig, layer_idx: int):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
         )
+        self.merged_kv = False

     def forward(
         self,
@@ -178,9 +179,6 @@ def forward(
             query_states, key_states, cos, sin
         )

-        # IMPORTANT: Do not use causal mask for prefill; Matches training
-        # This is required for flash attn, which doesn't support a 4D mask as input
-        # The `is_causal` argument is ignored by SDPA since we pass a 4D attention mask
         is_prefill = all(
             (
                 input_shape[1] > 1,

surya/common/surya/encoder/__init__.py

Lines changed: 105 additions & 17 deletions
@@ -270,6 +270,100 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.qkv = nn.Linear(dim, dim * 3, bias=True)
         self.proj = nn.Linear(dim, dim)

+    def unpack_qkv_with_mask(self, q, k, v, cu_seqlens):
+        """
+        Unpacks q, k, v sequences into batch-major form and constructs an additive attention mask.
+
+        Args:
+            q, k, v: Tensors of shape (total_seq_len, num_heads, head_dim)
+            cu_seqlens: Tensor of shape (batch_size + 1,) with cumulative sequence lengths
+
+        Returns:
+            batched_q: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
+            batched_k: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
+            batched_v: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
+            attention_mask: Tensor of shape (batch_size, 1, max_seq_len, max_seq_len)
+                with 0 for valid tokens and -inf for padding (for additive attention)
+        """
+        device = q.device
+        dtype = q.dtype
+
+        batch_size = cu_seqlens.shape[0] - 1
+        num_heads = q.shape[1]
+        head_dim = q.shape[2]
+
+        seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seq_len = seq_lengths.max().item()
+
+        batch_indices = []
+        position_indices = []
+
+        for i, seq_len in enumerate(seq_lengths):
+            batch_indices.extend([i] * seq_len)
+            position_indices.extend(list(range(seq_len)))
+
+        batch_indices = torch.tensor(batch_indices, device=device)
+        position_indices = torch.tensor(position_indices, device=device)
+
+        batched_q = torch.zeros((batch_size, max_seq_len, num_heads, head_dim), device=device, dtype=dtype)
+        batched_k = torch.zeros_like(batched_q)
+        batched_v = torch.zeros_like(batched_q)
+
+        # Create additive attention mask: shape (batch_size, 1, max_seq_len, max_seq_len)
+        # Each batch has a (max_seq_len, max_seq_len) matrix:
+        # - Rows = queries, Columns = keys
+        # - If query or key is padding, set to -inf
+        attention_mask = torch.full(
+            (batch_size, max_seq_len, max_seq_len),
+            fill_value=float('-inf'),
+            device=device,
+            dtype=dtype
+        )
+        for b in range(batch_size):
+            valid_len = seq_lengths[b].item()
+            attention_mask[b, :valid_len, :valid_len] = 0  # Unmasked
+
+        attention_mask = attention_mask.unsqueeze(1)  # (batch_size, 1, max_seq_len, max_seq_len)
+
+        batched_q[batch_indices, position_indices] = q
+        batched_k[batch_indices, position_indices] = k
+        batched_v[batch_indices, position_indices] = v
+
+        return batched_q, batched_k, batched_v, attention_mask
+
+    def repack_hidden_states(self, batched_output, cu_seqlens):
+        """
+        Reverses the unpacking operation using indexing to convert batched outputs
+        back to a flat tensor of shape (total_seq_len, hidden_dim).
+
+        Args:
+            batched_output: Tensor of shape (batch_size, max_seq_len, hidden_dim)
+            cu_seqlens: Tensor of shape (batch_size + 1,) with cumulative sequence lengths
+
+        Returns:
+            packed_output: Tensor of shape (total_seq_len, hidden_dim)
+        """
+        device = batched_output.device
+        dtype = batched_output.dtype
+
+        batch_size, max_seq_len, hidden_dim = batched_output.shape
+        seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+        total_seq_len = seq_lengths.sum().item()
+
+        batch_indices = []
+        position_indices = []
+
+        for i, seq_len in enumerate(seq_lengths):
+            batch_indices.extend([i] * seq_len)
+            position_indices.extend(list(range(seq_len)))
+
+        batch_indices = torch.tensor(batch_indices, device=device)
+        position_indices = torch.tensor(position_indices, device=device)
+
+        packed_output = batched_output[batch_indices, position_indices]
+
+        return packed_output  # Shape: (total_seq_len, hidden_dim)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -298,28 +392,22 @@ def forward(
         cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

-        attention_mask = torch.zeros(
-            [1, seq_length, seq_length], device=q.device, dtype=torch.bool
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[
-                ...,
-                cu_seqlens[i - 1] : cu_seqlens[i],
-                cu_seqlens[i - 1] : cu_seqlens[i],
-            ] = True
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
+        q, k, v, attention_mask = self.unpack_qkv_with_mask(q, k, v, cu_seqlens)
+        batch_size, max_seqlen = q.shape[:2]
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
         attn_output = F.scaled_dot_product_attention(
-            q.unsqueeze(0),
-            k.unsqueeze(0),
-            v.unsqueeze(0),
+            q,
+            k,
+            v,
             attention_mask,
             dropout_p=0.0,
         )
-        attn_output = attn_output.squeeze(0).transpose(0, 1)
-        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, max_seqlen, -1)  # Bring back to (batch_size, max_seqlen, hidden_dim)
         attn_output = self.proj(attn_output)
+        attn_output = self.repack_hidden_states(attn_output, cu_seqlens)

         return attn_output
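A self-contained sketch of the pack/unpack round trip these helpers implement, using dummy tensors; the variable names below are illustrative, and only the cu_seqlens convention is taken from the diff.

import torch

total_len, num_heads, head_dim = 7, 2, 4
x = torch.randn(total_len, num_heads, head_dim)  # packed (total_seq_len, num_heads, head_dim)
cu_seqlens = torch.tensor([0, 3, 7])             # two sequences of lengths 3 and 4

seq_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
max_len = max(seq_lengths)
batched = torch.zeros(len(seq_lengths), max_len, num_heads, head_dim)
mask = torch.full((len(seq_lengths), 1, max_len, max_len), float("-inf"))
for b, n in enumerate(seq_lengths):
    batched[b, :n] = x[cu_seqlens[b]:cu_seqlens[b + 1]]  # scatter into the padded batch
    mask[b, 0, :n, :n] = 0.0                             # valid query/key pairs stay unmasked

# Repacking drops the padded rows and restores the original flat layout.
repacked = torch.cat([batched[b, :n] for b, n in enumerate(seq_lengths)], dim=0)
assert torch.equal(repacked, x)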

surya/common/util.py

Lines changed: 34 additions & 4 deletions
@@ -1,5 +1,6 @@
 import copy
 from typing import List
+import torch

 from surya.common.polygon import PolygonBox
 from surya.settings import settings
@@ -22,7 +23,12 @@ def clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
             other_box = other_box_obj.bbox
             if box == other_box:
                 continue
-            if box[0] >= other_box[0] and box[1] >= other_box[1] and box[2] <= other_box[2] and box[3] <= other_box[3]:
+            if (
+                box[0] >= other_box[0]
+                and box[1] >= other_box[1]
+                and box[2] <= other_box[2]
+                and box[3] <= other_box[3]
+            ):
                 contained = True
                 break
         if not contained:
@@ -45,18 +51,42 @@ def rescale_bbox(bbox, processor_size, image_size):
     return new_bbox


-def expand_bbox(bbox, expansion_factor=.01):
+def expand_bbox(bbox, expansion_factor=0.01):
     expansion_low = 1 - expansion_factor
     expansion_high = 1 + expansion_factor
     return [
         bbox[0] * expansion_low,
         bbox[1] * expansion_low,
         bbox[2] * expansion_high,
-        bbox[3] * expansion_high
+        bbox[3] * expansion_high,
     ]


-if settings.TORCH_DEVICE_MODEL == 'xla':
+def is_flash_attn_2_supported(device: str | torch.device) -> bool:
+    if not torch.cuda.is_available():
+        return False
+
+    if "cuda" not in str(device):
+        return False
+
+    # Check CUDA version >= 12.0
+    cuda_version_str = torch.version.cuda
+    if cuda_version_str is None:
+        return False
+    cuda_version = tuple(map(int, cuda_version_str.split(".")))
+    if cuda_version < (12, 0):
+        return False
+
+    # Check GPU compute capability (Ampere, Ada, Hopper GPUs)
+    major, minor = torch.cuda.get_device_capability()
+    compute_capability = major + minor / 10
+    if compute_capability < 8.0:
+        return False
+
+    return True
+
+
+if settings.TORCH_DEVICE_MODEL == "xla":
     import torch_xla.core.xla_model as xm
 else:
     xm = None
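A minimal usage sketch for the new helper, assuming it is imported from surya.common.util; the attn_implementation switch shown here is illustrative, not code from this diff.

import torch
from surya.common.util import is_flash_attn_2_supported

device = "cuda" if torch.cuda.is_available() else "cpu"
# Flash Attention 2 needs CUDA 12+ and compute capability >= 8.0 (Ampere or newer),
# so fall back to PyTorch SDPA everywhere else.
attn_implementation = "flash_attention_2" if is_flash_attn_2_supported(device) else "sdpa"
print(f"device={device}, attention={attn_implementation}")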

surya/detection/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -151,3 +151,5 @@ def batch_detection(
             preds[idx] = heatmaps

         yield preds, [orig_sizes[j] for j in batch_image_idxs]
+
+    torch.cuda.empty_cache()

surya/layout/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -219,5 +219,8 @@ def batch_layout_detection(
                 batch_results = slicer.join(batch_results, tile_positions)
             results.extend(batch_results)

+        self.model.decoder.model._clear_cache()
+        torch.cuda.empty_cache()
+
         assert len(results) == len(images)
         return results

surya/recognition/__init__.py

Lines changed: 32 additions & 4 deletions
@@ -33,6 +33,7 @@
     words_from_chars,
     detect_repeat_token,
     prediction_to_polygon_batch,
+    unwrap_math,
 )
 from surya.recognition.schema import TextLine, OCRResult, TextChar
 from surya.common.surya.schema import TaskNames
@@ -75,7 +76,7 @@ class RecognitionPrompt:
 class RecognitionPredictor(BasePredictor):
     model_loader_cls = RecognitionModelLoader
     batch_size = settings.RECOGNITION_BATCH_SIZE
-    torch_dtype = settings.MODEL_DTYPE_BFLOAT
+    torch_dtype = None  # No default, loader picks the dtype based on device properties - bf16/fp16
     default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 128}
     encoder_chunk_size: int = 4096
     encoder_chunk_sizes = {"cpu": 4096, "mps": 4096, "cuda": 32768, "xla": 32768}
@@ -85,7 +86,7 @@ class RecognitionPredictor(BasePredictor):
         TaskNames.ocr_with_boxes: {
             "needs_bboxes": True,
             "img_size": (1024, 256),  # 370 max tokens
-            "max_tokens": 224,
+            "max_tokens": 256,
         },
         TaskNames.ocr_without_boxes: {
             "needs_bboxes": False,
@@ -272,6 +273,10 @@ def prepare_input(

         # Task input is the same for all tasks for now
         text = text or ""
+
+        # Remove input text that exceeds max generation tokens (likely invalid)
+        if len(text) > self.tasks[task_name]["max_tokens"]:
+            text = ""
         inputs = [
             {"type": "image", "image": image, "rotated": False},
             {"type": "text", "text": text.strip(), "math": math_mode},
@@ -588,11 +593,20 @@ def prediction_loop(
             current_inputs = self.maybe_trim_cache_padding(current_inputs)
             mark_step()
         pbar.close()
+
+        del self.kv_cache
+        self.kv_cache = None
+        torch.cuda.empty_cache()

         return predicted_tokens, batch_bboxes, scores

     def get_bboxes_text(
-        self, flat: dict, predicted_tokens: list, scores: list, predicted_polygons: list
+        self,
+        flat: dict,
+        predicted_tokens: list,
+        scores: list,
+        predicted_polygons: list,
+        drop_repeated_text: bool = False,
     ) -> list:
         char_predictions = []
         needs_boxes = [
@@ -614,10 +628,23 @@ def get_bboxes_text(
                 needs_boxes,
             )
         ):
+            blank_bbox = [[0, 0], [0, 1], [1, 1], [1, 0]]
             if self.processor.no_output_token in image_tokens:
                 char_predictions.append(None)
                 continue

+            # If the image is very out of distribution, we can get nonsense repeats, and we may need to drop the text entirely
+            if drop_repeated_text and detect_repeat_token(image_tokens):
+                char_predictions.append(
+                    TextChar(
+                        text="",
+                        polygon=blank_bbox,
+                        confidence=0,
+                        bbox_valid=False,
+                    )
+                )
+                continue
+
             image_polygons = image_polygons[: len(image_tokens)].cpu().numpy().tolist()

             detokenize_sequences = []
@@ -681,7 +708,6 @@ def _add_detokenize_sequence(
         img_chars = []
         for sequence in detokenize_sequences:
             token_ids, seq_score, bboxes, token_type = sequence
-            blank_bbox = [[0, 0], [0, 1], [1, 1], [1, 0]]
             if token_type == "ocr":
                 text = self.processor.ocr_tokenizer.decode(
                     token_ids, task=TaskNames.ocr_with_boxes
@@ -750,6 +776,7 @@ def __call__(
         sort_lines: bool = False,
         math_mode: bool = True,
         return_words: bool = False,
+        drop_repeated_text: bool = False,
     ) -> List[OCRResult]:
         allowed_tasks = self.tasks.keys()
         if task_names is None:
@@ -874,6 +901,7 @@ def __call__(
                         text_line, self.processor.ocr_tokenizer.special_tokens
                     )
                     text = "".join([char.text for char in text_line])
+                    text = unwrap_math(text)
                     lines.append(
                         TextLine(
                             text=text,
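A hypothetical usage sketch for the new drop_repeated_text flag on the recognition predictor; the constructor call, image handling, and result fields below are illustrative rather than taken from this diff, and depending on the task a detection predictor or precomputed polygons may also need to be passed.

from PIL import Image
from surya.recognition import RecognitionPredictor

predictor = RecognitionPredictor()  # assumed default construction; check the repo docs
image = Image.open("page.png")      # hypothetical input image

# drop_repeated_text=True discards lines whose tokens are detected as nonsense repeats,
# which can happen on images that are far out of distribution.
results = predictor([image], drop_repeated_text=True)
for line in results[0].text_lines:
    print(line.text)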
