@@ -76,17 +76,17 @@ class RecognitionPrompt:
76
76
class RecognitionPredictor (BasePredictor ):
77
77
model_loader_cls = RecognitionModelLoader
78
78
batch_size = settings .RECOGNITION_BATCH_SIZE
79
- torch_dtype = None # No default, loader picks the dtype based on device properties - bf16/fp16
79
+ torch_dtype = None # No default, loader picks the dtype based on device properties - bf16/fp16
80
80
default_batch_sizes = {"cpu" : 32 , "mps" : 64 , "cuda" : 256 , "xla" : 128 }
81
- encoder_chunk_size : int = 4096
81
+ encoder_chunk_size : int = 4096 # Default chunk size
82
82
encoder_chunk_sizes = {"cpu" : 4096 , "mps" : 4096 , "cuda" : 32768 , "xla" : 32768 }
83
83
min_prefill_ratio : int = 0.2
84
84
min_trim_length : int = 50
85
85
tasks = {
86
86
TaskNames .ocr_with_boxes : {
87
87
"needs_bboxes" : True ,
88
88
"img_size" : (1024 , 256 ), # 370 max tokens
89
- "max_tokens" : 256 ,
89
+ "max_tokens" : 224 ,
90
90
},
91
91
TaskNames .ocr_without_boxes : {
92
92
"needs_bboxes" : False ,
@@ -111,8 +111,11 @@ def __init__(self, checkpoint=None, device=settings.TORCH_DEVICE_MODEL, dtype=No
111
111
self .processor .pad_token_id , device = self .model .device , dtype = torch .long
112
112
)
113
113
114
- def get_encoder_chunk_size (self ):
115
- chunk_size = self .encoder_chunk_size
114
+ def get_encoder_chunk_size (self ) -> int :
115
+ if settings .RECOGNITION_CHUNK_SIZE is not None :
116
+ return settings .RECOGNITION_CHUNK_SIZE
117
+
118
+ chunk_size = settings .encoder_chunk_size
116
119
if settings .TORCH_DEVICE_MODEL in self .encoder_chunk_sizes :
117
120
if settings .TORCH_DEVICE_MODEL in self .encoder_chunk_sizes :
118
121
chunk_size = self .encoder_chunk_sizes [settings .TORCH_DEVICE_MODEL ]
@@ -239,6 +242,8 @@ def slice_bboxes(
239
242
== len (all_polygons )
240
243
== len (all_text )
241
244
== len (all_task_names )
245
+ ), (
246
+ f"Mismatch in lengths: { len (all_slices )} , { sum (slice_map )} , { len (all_polygons )} , { len (all_text )} , { len (all_task_names )} "
242
247
)
243
248
244
249
return {
@@ -593,7 +598,7 @@ def prediction_loop(
593
598
current_inputs = self .maybe_trim_cache_padding (current_inputs )
594
599
mark_step ()
595
600
pbar .close ()
596
-
601
+
597
602
del self .kv_cache
598
603
self .kv_cache = None
599
604
torch .cuda .empty_cache ()
@@ -636,12 +641,14 @@ def get_bboxes_text(
636
641
# If the image is very out of distribution, we can get nonsense repeats, and we may need to drop the text entirely
637
642
if drop_repeated_text and detect_repeat_token (image_tokens ):
638
643
char_predictions .append (
639
- TextChar (
640
- text = "" ,
641
- polygon = blank_bbox ,
642
- confidence = 0 ,
643
- bbox_valid = False ,
644
- )
644
+ [
645
+ TextChar (
646
+ text = "" ,
647
+ polygon = blank_bbox ,
648
+ confidence = 0 ,
649
+ bbox_valid = False ,
650
+ )
651
+ ]
645
652
)
646
653
continue
647
654
@@ -772,7 +779,7 @@ def __call__(
772
779
highres_images : List [Image .Image ] | None = None ,
773
780
bboxes : List [List [List [int ]]] | None = None ,
774
781
polygons : List [List [List [List [int ]]]] | None = None ,
775
- input_text : List [str | None ] | None = None ,
782
+ input_text : List [List [ str | None ] ] | None = None ,
776
783
sort_lines : bool = False ,
777
784
math_mode : bool = True ,
778
785
return_words : bool = False ,
@@ -857,7 +864,11 @@ def __call__(
857
864
batch_bboxes , image_sizes , bbox_size , bbox_size // 2
858
865
)
859
866
char_predictions = self .get_bboxes_text (
860
- flat , predicted_tokens , scores , predicted_polygons
867
+ flat ,
868
+ predicted_tokens ,
869
+ scores ,
870
+ predicted_polygons ,
871
+ drop_repeated_text = drop_repeated_text ,
861
872
)
862
873
863
874
char_predictions = sorted (zip (indices , char_predictions ), key = lambda x : x [0 ])
@@ -886,7 +897,11 @@ def __call__(
886
897
)
887
898
)
888
899
else :
889
- confidence = float (np .mean ([char .confidence for char in text_line ]))
900
+ confidence = (
901
+ float (np .mean ([char .confidence for char in text_line ]))
902
+ if len (text_line ) > 0
903
+ else 0
904
+ )
890
905
poly_box = PolygonBox (polygon = polygon )
891
906
for char in text_line :
892
907
char .rescale (
0 commit comments