
Commit a8ebacf

Pull in segformer code
1 parent 94d9954 commit a8ebacf

File tree

1 file changed (+338 -9 lines)


surya/model/detection/segformer.py

Lines changed: 338 additions & 9 deletions
@@ -1,16 +1,21 @@
 import gc
 import warnings
+
+from transformers.activations import ACT2FN
+from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+
 warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")
 
 import math
 from typing import Optional, Tuple, Union
 
-from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerDecodeHead, SegformerModel
+from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerDecodeHead, \
+    SegformerPreTrainedModel
 from surya.model.detection.processor import SegformerImageProcessor
 import torch
 from torch import nn
 
-from transformers.modeling_outputs import SemanticSegmenterOutput
+from transformers.modeling_outputs import SemanticSegmenterOutput, BaseModelOutput
 from surya.settings import settings
 
 
@@ -63,7 +68,6 @@ def __init__(self, config):
         self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
         self.activation = nn.ReLU()
 
-        self.dropout = nn.Dropout(config.classifier_dropout_prob)
         self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)
 
         self.config = config
@@ -94,14 +98,342 @@ def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
         hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
         hidden_states = self.batch_norm(hidden_states)
         hidden_states = self.activation(hidden_states)
-        hidden_states = self.dropout(hidden_states)
 
         # logits are of shape (batch_size, num_labels, height/4, width/4)
         logits = self.classifier(hidden_states)
 
         return logits
 
 
+class SegformerOverlapPatchEmbeddings(nn.Module):
+    """Construct the overlapping patch embeddings."""
+
+    def __init__(self, patch_size, stride, num_channels, hidden_size):
+        super().__init__()
+        self.proj = nn.Conv2d(
+            num_channels,
+            hidden_size,
+            kernel_size=patch_size,
+            stride=stride,
+            padding=patch_size // 2,
+        )
+
+        self.layer_norm = nn.LayerNorm(hidden_size)
+
+    def forward(self, pixel_values):
+        embeddings = self.proj(pixel_values)
+        _, _, height, width = embeddings.shape
+        # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels)
+        # this can be fed to a Transformer layer
+        embeddings = embeddings.flatten(2).transpose(1, 2)
+        embeddings = self.layer_norm(embeddings)
+        return embeddings, height, width
+
+
+class SegformerEfficientSelfAttention(nn.Module):
+    """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
+    paper](https://arxiv.org/abs/2102.12122)."""
+
+    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+
+        if self.hidden_size % self.num_attention_heads != 0:
+            raise ValueError(
+                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({self.num_attention_heads})"
+            )
+
+        self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(self.hidden_size, self.all_head_size)
+        self.key = nn.Linear(self.hidden_size, self.all_head_size)
+        self.value = nn.Linear(self.hidden_size, self.all_head_size)
+
+        self.sr_ratio = sequence_reduction_ratio
+        if sequence_reduction_ratio > 1:
+            self.sr = nn.Conv2d(
+                hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
+            )
+            self.layer_norm = nn.LayerNorm(hidden_size)
+
+    def transpose_for_scores(self, hidden_states):
+        new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        hidden_states = hidden_states.view(new_shape)
+        return hidden_states.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states,
+        height,
+        width,
+        output_attentions=False,
+    ):
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+        if self.sr_ratio > 1:
+            batch_size, seq_len, num_channels = hidden_states.shape
+            # Reshape to (batch_size, num_channels, height, width)
+            hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
+            # Apply sequence reduction
+            hidden_states = self.sr(hidden_states)
+            # Reshape back to (batch_size, seq_len, num_channels)
+            hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
+            hidden_states = self.layer_norm(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+class SegformerEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        # patch embeddings
+        embeddings = []
+        for i in range(config.num_encoder_blocks):
+            embeddings.append(
+                SegformerOverlapPatchEmbeddings(
+                    patch_size=config.patch_sizes[i],
+                    stride=config.strides[i],
+                    num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
+                    hidden_size=config.hidden_sizes[i],
+                )
+            )
+        self.patch_embeddings = nn.ModuleList(embeddings)
+
+        # Transformer blocks
+        blocks = []
+        cur = 0
+        for i in range(config.num_encoder_blocks):
+            # each block consists of layers
+            layers = []
+            if i != 0:
+                cur += config.depths[i - 1]
+            for j in range(config.depths[i]):
+                layers.append(
+                    SegformerLayer(
+                        config,
+                        hidden_size=config.hidden_sizes[i],
+                        num_attention_heads=config.num_attention_heads[i],
+                        sequence_reduction_ratio=config.sr_ratios[i],
+                        mlp_ratio=config.mlp_ratios[i],
+                    )
+                )
+            blocks.append(nn.ModuleList(layers))
+
+        self.block = nn.ModuleList(blocks)
+
+        # Layer norms
+        self.layer_norm = nn.ModuleList(
+            [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
+        )
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, BaseModelOutput]:
+        all_hidden_states = () if output_hidden_states else None
+
+        batch_size = pixel_values.shape[0]
+
+        hidden_states = pixel_values
+        for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)):
+            embedding_layer, block_layer, norm_layer = x
+            # first, obtain patch embeddings
+            hidden_states, height, width = embedding_layer(hidden_states)
+            # second, send embeddings through blocks
+            for i, blk in enumerate(block_layer):
+                layer_outputs = blk(hidden_states, height, width, output_attentions)
+                hidden_states = layer_outputs[0]
+            # third, apply layer norm
+            hidden_states = norm_layer(hidden_states)
+            # fourth, optionally reshape back to (batch_size, num_channels, height, width)
+            if idx != len(self.patch_embeddings) - 1 or (
+                idx == len(self.patch_embeddings) - 1 and self.config.reshape_last_stage
+            ):
+                hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        return all_hidden_states
+
+class SegformerSelfOutput(nn.Module):
+    def __init__(self, config, hidden_size):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        return hidden_states
+
+
+class SegformerAttention(nn.Module):
+    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
+        super().__init__()
+        self.self = SegformerEfficientSelfAttention(
+            config=config,
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            sequence_reduction_ratio=sequence_reduction_ratio,
+        )
+        self.output = SegformerSelfOutput(config, hidden_size=hidden_size)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(self, hidden_states, height, width, output_attentions=False):
+        self_outputs = self.self(hidden_states, height, width, output_attentions)
+
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+class SegformerDWConv(nn.Module):
+    def __init__(self, dim=768):
+        super().__init__()
+        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+    def forward(self, hidden_states, height, width):
+        batch_size, seq_len, num_channels = hidden_states.shape
+        hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
+        hidden_states = self.dwconv(hidden_states)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+        return hidden_states
+
+
+class SegformerMixFFN(nn.Module):
+    def __init__(self, config, in_features, hidden_features=None, out_features=None):
+        super().__init__()
+        out_features = out_features or in_features
+        self.dense1 = nn.Linear(in_features, hidden_features)
+        self.dwconv = SegformerDWConv(hidden_features)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+        self.dense2 = nn.Linear(hidden_features, out_features)
+
+    def forward(self, hidden_states, height, width):
+        hidden_states = self.dense1(hidden_states)
+        hidden_states = self.dwconv(hidden_states, height, width)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = self.dense2(hidden_states)
+        return hidden_states
+
+
+class SegformerLayer(nn.Module):
+    """This corresponds to the Block class in the original implementation."""
+
+    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio, mlp_ratio):
+        super().__init__()
+        self.layer_norm_1 = nn.LayerNorm(hidden_size)
+        self.attention = SegformerAttention(
+            config,
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            sequence_reduction_ratio=sequence_reduction_ratio,
+        )
+        self.layer_norm_2 = nn.LayerNorm(hidden_size)
+        mlp_hidden_size = int(hidden_size * mlp_ratio)
+        self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)
+
+    def forward(self, hidden_states, height, width, output_attentions=False):
+        self_attention_outputs = self.attention(
+            self.layer_norm_1(hidden_states),  # in Segformer, layernorm is applied before self-attention
+            height,
+            width,
+            output_attentions=output_attentions,
+        )
+
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        # first residual connection (with stochastic depth)
+        hidden_states = attention_output + hidden_states
+
+        mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
+
+        # second residual connection (with stochastic depth)
+        layer_output = mlp_output + hidden_states
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+class SegformerModel(SegformerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+
+        # hierarchical Transformer encoder
+        self.encoder = SegformerEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        encoder_outputs = self.encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        return encoder_outputs
+
 class SegformerForRegressionMask(SegformerForSemanticSegmentation):
     def __init__(self, config):
         super().__init__(config)
@@ -119,17 +451,14 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SemanticSegmenterOutput]:
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.segformer(
+        encoder_hidden_states = self.segformer(
             pixel_values,
-            output_attentions=output_attentions,
+            output_attentions=False,
             output_hidden_states=True,  # we need the intermediate hidden states
             return_dict=False,
         )
 
-        encoder_hidden_states = outputs[1]
-
         logits = self.decode_head(encoder_hidden_states)
         # Apply sigmoid to get 0-1 output
         sigmoid_logits = torch.special.expit(logits)
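
Note: a minimal standalone sketch of what the sequence-reduction trick in SegformerEfficientSelfAttention buys. Keys and values are computed on a spatially reduced copy of the token sequence, so attention cost shrinks by roughly sr_ratio**2. The sizes below are arbitrary assumptions, not part of the commit, and the layer norm step is omitted.

import torch
from torch import nn

# Assumed toy sizes: a 32x32 feature map with 32 channels, reduction ratio 4.
batch, channels, height, width = 1, 32, 32, 32
sr_ratio = 4

hidden_states = torch.randn(batch, height * width, channels)  # queries attend over all H*W tokens

# Same reshape + strided-conv reduction as the self.sr path in the class above.
sr = nn.Conv2d(channels, channels, kernel_size=sr_ratio, stride=sr_ratio)
reduced = hidden_states.permute(0, 2, 1).reshape(batch, channels, height, width)
reduced = sr(reduced).reshape(batch, channels, -1).permute(0, 2, 1)

print(hidden_states.shape)  # torch.Size([1, 1024, 32]) -> query tokens
print(reduced.shape)        # torch.Size([1, 64, 32])   -> key/value tokens (1024 / sr_ratio**2)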

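For context on the forward change at the bottom of the diff: the pulled-in SegformerModel returns its encoder output directly, and with output_hidden_states=True that is a plain tuple of four per-stage feature maps, which lines up with the revised SegformerForRegressionMask.forward feeding the segformer output straight into the decode head. A hedged usage sketch, assuming this file imports as surya.model.detection.segformer and that default SegformerConfig values apply:

import torch
from transformers import SegformerConfig
from surya.model.detection.segformer import SegformerModel

config = SegformerConfig()  # defaults: 4 stages, hidden_sizes [32, 64, 160, 256], strides [4, 2, 2, 2]
model = SegformerModel(config).eval()

pixel_values = torch.randn(1, 3, 512, 512)  # arbitrary input size
with torch.no_grad():
    # output_hidden_states=True is required here: the trimmed encoder only
    # builds the tuple of per-stage hidden states on that path.
    features = model(pixel_values, output_hidden_states=True)

for stage_features in features:
    print(tuple(stage_features.shape))
# With the default config this prints feature maps at strides 4/8/16/32:
# (1, 32, 128, 128), (1, 64, 64, 64), (1, 160, 32, 32), (1, 256, 16, 16)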