Skip to content

Commit 1eb828a

Browse files
committed
Pad out languages if needed
1 parent 407c34d commit 1eb828a

File tree

1 file changed: +9 additions, -1 deletion

1 file changed: +9 additions, -1 deletion

surya/recognition.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,14 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
6363

6464
batch_pixel_values = processed_batches["pixel_values"][i:i+batch_size]
6565
batch_langs = processed_batches["langs"][i:i+batch_size]
66+
max_lang_len = max([len(lang) for lang in batch_langs])
67+
68+
# Pad languages to max length if needed, to ensure we can convert to a tensor
69+
for lang_idx in range(len(batch_langs)):
70+
lang_len = len(batch_langs[lang_idx])
71+
if lang_len < max_lang_len:
72+
batch_langs[lang_idx] = [processor.tokenizer.pad_id] * (max_lang_len - lang_len) + batch_langs[lang_idx]
73+
6674
batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
6775
current_batch_size = len(batch_pixel_values)
6876

@@ -120,7 +128,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
120128
encoder_cache = [None] * layer_count
121129
all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)
122130

123-
with torch.no_grad():
131+
with torch.no_grad(): # inference_mode doesn't work with torch.compile
124132
# Run post-prefill tokens
125133
while token_count < settings.RECOGNITION_MAX_TOKENS:
126134
is_prefill = token_count == 0

0 commit comments

Comments (0)