This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Add FLOPs computation to ViT #746

Closed
classy_vision/generic/profiler.py (10 changes: 1 addition & 9 deletions)
@@ -308,15 +308,7 @@ def flops(self, x):
             batchsize_per_replica * num_channels * out_t * out_w * out_h * kernel_ops
         )

-    # dropout layer
-    elif layer_type in ["Dropout"]:
-        # At test time, we do not drop values but scale the feature map by the
-        # dropout ratio
-        flops = 1
-        for dim_size in x.size():
-            flops *= dim_size
-
-    elif layer_type == "Identity":
+    elif layer_type in ["Dropout", "Identity"]:
         flops = 0

     elif hasattr(layer, "flops"):
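For context, a minimal standalone sketch of the kind of per-layer dispatch this branch belongs to; the helper name and structure are illustrative assumptions, not ClassyVision's exact profiler code:

import torch
import torch.nn as nn

# Hypothetical per-layer FLOPs dispatcher. Dropout is a no-op at inference time
# and Identity does no arithmetic, so both are now counted as 0 FLOPs.
def layer_flops(layer: nn.Module, x: torch.Tensor) -> int:
    layer_type = type(layer).__name__
    if layer_type in ["Dropout", "Identity"]:
        return 0
    if hasattr(layer, "flops"):
        # A module may report its own count, as the ViT encoder block does below.
        return layer.flops(x)
    raise NotImplementedError(f"FLOPs not implemented for layer type {layer_type}")
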
classy_vision/hooks/tensorboard_plot_hook.py (6 changes: 3 additions & 3 deletions)
@@ -164,16 +164,16 @@ def on_phase_end(self, task) -> None:
                     f"Parameters/{name}", parameter, global_step=phase_type_idx
                 )

-        if torch.cuda.is_available() and task.train:
+        if torch.cuda.is_available():
             self.tb_writer.add_scalar(
-                "Memory/peak_allocated",
+                f"Memory/{phase_type}/peak_allocated",
                 torch.cuda.max_memory_allocated(),
                 global_step=phase_type_idx,
             )

         loss_avg = sum(task.losses) / batches

-        loss_key = "Losses/{phase_type}".format(phase_type=task.phase_type)
+        loss_key = f"Losses/{phase_type}"
         self.tb_writer.add_scalar(loss_key, loss_avg, global_step=phase_type_idx)

         # plot meters which return a dict
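A rough sketch of the TensorBoard tags this produces after the change; the writer setup, phase values, and loss value are made up for illustration, while add_scalar and max_memory_allocated are the standard PyTorch calls used above:

import torch
from torch.utils.tensorboard import SummaryWriter

# Illustrative: what gets written for a hypothetical "test" phase. Peak GPU
# memory is now logged for evaluation phases too, under a per-phase tag.
tb_writer = SummaryWriter(log_dir="/tmp/tb_sketch")  # hypothetical log dir
phase_type, phase_type_idx = "test", 0

if torch.cuda.is_available():
    tb_writer.add_scalar(
        f"Memory/{phase_type}/peak_allocated",  # was the single tag "Memory/peak_allocated"
        torch.cuda.max_memory_allocated(),
        global_step=phase_type_idx,
    )

loss_avg = 0.42  # placeholder; in the hook this is sum(task.losses) / batches
tb_writer.add_scalar(f"Losses/{phase_type}", loss_avg, global_step=phase_type_idx)
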
classy_vision/models/vision_transformer.py (28 changes: 28 additions & 0 deletions)
@@ -76,6 +76,7 @@ def __init__(
         self.dropout = nn.Dropout(dropout_rate)
         self.ln_2 = LayerNorm(hidden_dim)
         self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout_rate)
+        self.num_heads = num_heads

     def forward(self, input):
         x = self.ln_1(input)
@@ -87,6 +88,29 @@ def forward(self, input):
         y = self.mlp(y)
         return x + y

+    def flops(self, x):
+        flops = 0
+        seq_len, batch_size, hidden_dim = x.shape
+
+        num_elems = x.numel() // batch_size
+        flops += num_elems * 6  # ln_1 (* 2), x + input, ln_2 (* 2), x + y
+
+        # self_attention
+        # calculations are based on the fact that head_dim * num_heads = hidden_dim
+        flops += 3 * seq_len * (hidden_dim + 1) * hidden_dim  # projection with bias
+        flops += hidden_dim * seq_len  # scaling
+        flops += hidden_dim * seq_len * seq_len  # attention weights
+        flops += self.num_heads * seq_len * seq_len  # softmax
+        flops += hidden_dim * seq_len * seq_len  # attention application
+        flops += seq_len * (hidden_dim + 1) * hidden_dim  # out projection with bias
+
+        # mlp
+        mlp_dim = self.mlp.linear_1.out_features
+        flops += seq_len * (hidden_dim + 1) * mlp_dim  # linear_1
+        flops += seq_len * mlp_dim  # act
+        flops += seq_len * (mlp_dim + 1) * hidden_dim  # linear_2
+        return flops * batch_size
+

 class Encoder(nn.Module):
     """Transformer Encoder."""
@@ -300,6 +324,10 @@ def set_classy_state(self, state, strict=True):
         state["model"]["trunk"]["encoder.pos_embedding"] = new_pos_embedding
         super().set_classy_state(state, strict=strict)

+    @property
+    def input_shape(self):
+        return (3, self.image_size, self.image_size)
+

 @register_model("vit_b_32")
 class ViTB32(VisionTransformer):
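A hypothetical sketch of how a profiler can combine the new input_shape property with per-module flops() methods; estimate_flops and its hook-based traversal are assumptions for illustration, not the library's actual profiler API:

import torch
import torch.nn as nn

def estimate_flops(model: nn.Module, batch_size: int = 1) -> int:
    # input_shape is (3, image_size, image_size) for the ViT trunk above, so the
    # profiler can build a dummy input without knowing the model internals.
    dummy = torch.zeros(batch_size, *model.input_shape)
    total = 0

    def hook(module, inputs, output):
        nonlocal total
        if hasattr(module, "flops"):
            # Each flops-aware module (e.g. the encoder block) reports its own count.
            total += module.flops(inputs[0])

    handles = [m.register_forward_hook(hook) for m in model.modules()]
    try:
        with torch.no_grad():
            model(dummy)
    finally:
        for h in handles:
            h.remove()
    return total
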
classy_vision/tasks/classification_task.py (10 changes: 3 additions & 7 deletions)
@@ -1289,7 +1289,7 @@ def on_phase_start(self):
         self.phase_start_time_train = time.perf_counter()

     def on_phase_end(self):
-        self.log_phase_end("train")
+        self.log_phase_end(self.phase_type)

         if self.train:
             self.optimizer.on_epoch(where=self.where)
@@ -1308,7 +1308,7 @@ def on_phase_end(self):
             hook.on_phase_end(self)
         self.perf_log = []

-        self.log_phase_end("total")
+        self.log_phase_end(f"{self.phase_type}_total")

         if hasattr(self.datasets[self.phase_type], "on_phase_end"):
             self.datasets[self.phase_type].on_phase_end()
@@ -1318,12 +1318,9 @@ def on_end(self):
             hook.on_end(self)

     def log_phase_end(self, tag):
-        if not self.train:
-            return
-
         start_time = (
             self.phase_start_time_train
-            if tag == "train"
+            if tag == self.phase_type
             else self.phase_start_time_total
         )
         phase_duration = time.perf_counter() - start_time
@@ -1334,7 +1331,6 @@
             {
                 "tag": tag,
                 "phase_idx": self.train_phase_idx,
-                "epoch_duration": phase_duration,
                 "im_per_sec": im_per_sec,
             }
         )
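For illustration, the shape of the perf_log entries after this change, now emitted for evaluation phases as well; the numeric values below are made up:

# Illustrative perf_log contents after a test phase with this change applied:
# tags carry the real phase type, and epoch_duration is no longer recorded.
perf_log = [
    {"tag": "test", "phase_idx": 3, "im_per_sec": 1250.0},
    {"tag": "test_total", "phase_idx": 3, "im_per_sec": 1180.0},
]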