 import gc
 import warnings
+
+from transformers.activations import ACT2FN
+from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+
 warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")

 import math
 from typing import Optional, Tuple, Union

-from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerDecodeHead, SegformerModel
+from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerDecodeHead, \
+    SegformerPreTrainedModel
 from surya.model.detection.processor import SegformerImageProcessor
 import torch
 from torch import nn

-from transformers.modeling_outputs import SemanticSegmenterOutput
+from transformers.modeling_outputs import SemanticSegmenterOutput, BaseModelOutput

 from surya.settings import settings

@@ -63,7 +68,6 @@ def __init__(self, config):
         self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
         self.activation = nn.ReLU()

-        self.dropout = nn.Dropout(config.classifier_dropout_prob)
         self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)

         self.config = config
@@ -94,14 +98,342 @@ def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
         hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
         hidden_states = self.batch_norm(hidden_states)
         hidden_states = self.activation(hidden_states)
-        hidden_states = self.dropout(hidden_states)

         # logits are of shape (batch_size, num_labels, height/4, width/4)
         logits = self.classifier(hidden_states)

         return logits


+class SegformerOverlapPatchEmbeddings(nn.Module):
+    """Construct the overlapping patch embeddings."""
+
+    def __init__(self, patch_size, stride, num_channels, hidden_size):
+        super().__init__()
+        self.proj = nn.Conv2d(
+            num_channels,
+            hidden_size,
+            kernel_size=patch_size,
+            stride=stride,
+            padding=patch_size // 2,
+        )
+
+        self.layer_norm = nn.LayerNorm(hidden_size)
+
+    def forward(self, pixel_values):
+        embeddings = self.proj(pixel_values)
+        _, _, height, width = embeddings.shape
+        # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels)
+        # this can be fed to a Transformer layer
+        embeddings = embeddings.flatten(2).transpose(1, 2)
+        embeddings = self.layer_norm(embeddings)
+        return embeddings, height, width
+
+
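+# Editor's shape sketch (not part of this commit): for a 3-channel 512x512
+# input and a first stage with patch_size=7, stride=4, hidden_size=64,
+# self.proj gives (B, 64, 128, 128) and flatten(2).transpose(1, 2) turns that
+# into (B, 16384, 64): one 64-dim token per output location.
+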
+class SegformerEfficientSelfAttention(nn.Module):
+    """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
+    paper](https://arxiv.org/abs/2102.12122)."""
+
+    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+
+        if self.hidden_size % self.num_attention_heads != 0:
+            raise ValueError(
+                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({self.num_attention_heads})"
+            )
+
+        self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(self.hidden_size, self.all_head_size)
+        self.key = nn.Linear(self.hidden_size, self.all_head_size)
+        self.value = nn.Linear(self.hidden_size, self.all_head_size)
+
+        self.sr_ratio = sequence_reduction_ratio
+        if sequence_reduction_ratio > 1:
+            self.sr = nn.Conv2d(
+                hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
+            )
+            self.layer_norm = nn.LayerNorm(hidden_size)
+
+    def transpose_for_scores(self, hidden_states):
+        # (batch_size, seq_len, all_head_size) -> (batch_size, num_heads, seq_len, head_size)
+        new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        hidden_states = hidden_states.view(new_shape)
+        return hidden_states.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states,
+        height,
+        width,
+        output_attentions=False,
+    ):
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+        if self.sr_ratio > 1:
+            batch_size, seq_len, num_channels = hidden_states.shape
+            # Reshape to (batch_size, num_channels, height, width)
+            hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
+            # Apply sequence reduction
+            hidden_states = self.sr(hidden_states)
+            # Reshape back to (batch_size, seq_len, num_channels)
+            hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
+            hidden_states = self.layer_norm(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
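+# Editor's note (not part of this commit): with reduction ratio R, self.sr
+# pools the key/value sequence from N = height*width tokens down to N / R^2,
+# so the attention matrix is (N x N/R^2) instead of (N x N); at the default
+# stage-1 ratio R=8 that is 64x fewer scores.
+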
+class SegformerEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        # patch embeddings
+        embeddings = []
+        for i in range(config.num_encoder_blocks):
+            embeddings.append(
+                SegformerOverlapPatchEmbeddings(
+                    patch_size=config.patch_sizes[i],
+                    stride=config.strides[i],
+                    num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
+                    hidden_size=config.hidden_sizes[i],
+                )
+            )
+        self.patch_embeddings = nn.ModuleList(embeddings)
+
+        # Transformer blocks
+        blocks = []
+        cur = 0
+        for i in range(config.num_encoder_blocks):
+            # each block consists of layers
+            layers = []
+            if i != 0:
+                cur += config.depths[i - 1]
+            for j in range(config.depths[i]):
+                layers.append(
+                    SegformerLayer(
+                        config,
+                        hidden_size=config.hidden_sizes[i],
+                        num_attention_heads=config.num_attention_heads[i],
+                        sequence_reduction_ratio=config.sr_ratios[i],
+                        mlp_ratio=config.mlp_ratios[i],
+                    )
+                )
+            blocks.append(nn.ModuleList(layers))
+
+        self.block = nn.ModuleList(blocks)
+
+        # Layer norms
+        self.layer_norm = nn.ModuleList(
+            [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
+        )
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, BaseModelOutput]:
+        all_hidden_states = () if output_hidden_states else None
+
+        batch_size = pixel_values.shape[0]
+
+        hidden_states = pixel_values
+        for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)):
+            embedding_layer, block_layer, norm_layer = x
+            # first, obtain patch embeddings
+            hidden_states, height, width = embedding_layer(hidden_states)
+            # second, send embeddings through blocks
+            for i, blk in enumerate(block_layer):
+                layer_outputs = blk(hidden_states, height, width, output_attentions)
+                hidden_states = layer_outputs[0]
+            # third, apply layer norm
+            hidden_states = norm_layer(hidden_states)
+            # fourth, optionally reshape back to (batch_size, num_channels, height, width)
+            if idx != len(self.patch_embeddings) - 1 or (
+                idx == len(self.patch_embeddings) - 1 and self.config.reshape_last_stage
+            ):
+                hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        return all_hidden_states
+
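+# Editor's note (not part of this commit): each stage runs overlap patch
+# embedding -> depths[i] Transformer layers -> LayerNorm -> reshape to a 2D
+# map, which feeds the next stage; the collected per-stage maps form the
+# pyramid consumed by the decode head.
+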
+class SegformerSelfOutput(nn.Module):
+    def __init__(self, config, hidden_size):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+
+    def forward(self, hidden_states, input_tensor):
+        # input_tensor is unused; the signature mirrors the attention-output
+        # interface, and the residual is added in SegformerLayer instead.
+        hidden_states = self.dense(hidden_states)
+        return hidden_states
+
+
+class SegformerAttention(nn.Module):
+    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
+        super().__init__()
+        self.self = SegformerEfficientSelfAttention(
+            config=config,
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            sequence_reduction_ratio=sequence_reduction_ratio,
+        )
+        self.output = SegformerSelfOutput(config, hidden_size=hidden_size)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(self, hidden_states, height, width, output_attentions=False):
+        self_outputs = self.self(hidden_states, height, width, output_attentions)
+
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
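+# Editor's note (not part of this commit): prune_heads removes whole heads by
+# slicing the kept rows out of the query/key/value projections (dim=0 default)
+# and the matching input columns out of output.dense (dim=1), then shrinks
+# num_attention_heads to match.
+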
+class SegformerDWConv(nn.Module):
+    def __init__(self, dim=768):
+        super().__init__()
+        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+    def forward(self, hidden_states, height, width):
+        batch_size, seq_len, num_channels = hidden_states.shape
+        hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
+        hidden_states = self.dwconv(hidden_states)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+        return hidden_states
+
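+# Editor's note (not part of this commit): groups=dim makes the 3x3 conv above
+# depthwise (one filter per channel). In the SegFormer paper this conv inside
+# Mix-FFN injects local positional information, replacing explicit position
+# embeddings.
+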
+class SegformerMixFFN(nn.Module):
+    def __init__(self, config, in_features, hidden_features=None, out_features=None):
+        super().__init__()
+        out_features = out_features or in_features
+        self.dense1 = nn.Linear(in_features, hidden_features)
+        self.dwconv = SegformerDWConv(hidden_features)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+        self.dense2 = nn.Linear(hidden_features, out_features)
+
+    def forward(self, hidden_states, height, width):
+        hidden_states = self.dense1(hidden_states)
+        hidden_states = self.dwconv(hidden_states, height, width)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = self.dense2(hidden_states)
+        return hidden_states
+
+
+class SegformerLayer(nn.Module):
+    """This corresponds to the Block class in the original implementation."""
+
+    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio, mlp_ratio):
+        super().__init__()
+        self.layer_norm_1 = nn.LayerNorm(hidden_size)
+        self.attention = SegformerAttention(
+            config,
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            sequence_reduction_ratio=sequence_reduction_ratio,
+        )
+        self.layer_norm_2 = nn.LayerNorm(hidden_size)
+        mlp_hidden_size = int(hidden_size * mlp_ratio)
+        self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)
+
+    def forward(self, hidden_states, height, width, output_attentions=False):
+        self_attention_outputs = self.attention(
+            self.layer_norm_1(hidden_states),  # in Segformer, layernorm is applied before self-attention
+            height,
+            width,
+            output_attentions=output_attentions,
+        )
+
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        # first residual connection (the upstream stochastic-depth drop path is omitted here)
+        hidden_states = attention_output + hidden_states
+
+        mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
+
+        # second residual connection (again without drop path)
+        layer_output = mlp_output + hidden_states
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
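+# Editor's note (not part of this commit): the layer above uses the pre-norm
+# Transformer layout:
+#   x = x + Attention(LayerNorm(x)); x = x + MixFFN(LayerNorm(x), H, W)
+# so both residual additions happen on the unnormalized stream.
+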
+class SegformerModel(SegformerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+
+        # hierarchical Transformer encoder
+        self.encoder = SegformerEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+        class PreTrainedModel.
+        """
+        # The encoder groups its layers by stage under self.block, so flatten
+        # the stages before indexing by a global layer number.
+        layers = [layer for stage in self.encoder.block for layer in stage]
+        for layer, heads in heads_to_prune.items():
+            layers[layer].attention.prune_heads(heads)
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        encoder_outputs = self.encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        return encoder_outputs
+
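+# Editor's usage sketch (not part of this commit): this slimmed encoder only
+# returns the per-stage feature maps when hidden states are requested, which is
+# why the forward pass below asks for them explicitly:
+#
+#   model = SegformerModel(SegformerConfig())
+#   feats = model(torch.randn(1, 3, 512, 512), output_hidden_states=True)
+#   # feats is a 4-tuple of maps at strides 4, 8, 16, 32 of the input size
+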
 class SegformerForRegressionMask(SegformerForSemanticSegmentation):
     def __init__(self, config):
         super().__init__(config)
@@ -119,17 +451,14 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SemanticSegmenterOutput]:
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        outputs = self.segformer(
+        encoder_hidden_states = self.segformer(
             pixel_values,
-            output_attentions=output_attentions,
+            output_attentions=False,
             output_hidden_states=True,  # we need the intermediate hidden states
             return_dict=False,
         )

-        encoder_hidden_states = outputs[1]
-
         logits = self.decode_head(encoder_hidden_states)
         # Apply sigmoid to get 0-1 output
         sigmoid_logits = torch.special.expit(logits)