@@ -1,21 +1,22 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- import math
import logging
+ import math

import torch
from classy_vision import tasks
from classy_vision.hooks.classy_hook import ClassyHook
from vissl.models import build_model
from vissl.utils.env import get_machine_local_and_dist_rank

+
class BYOLHook(ClassyHook):
    """
-     BYOL - Bootstrap your own latent: (https://arxiv.org/abs/2006.07733)
-     is based on Contrastive learning. This hook
-     creates a target network with the same architecture
-     as the main online network, but without the projection head.
-     The online network does not participate in backpropogation,
-     but instead is an exponential moving average of the online network.
+     BYOL (Bootstrap Your Own Latent, https://arxiv.org/abs/2006.07733)
+     is based on contrastive learning. This hook
+     creates a target network with the same architecture
+     as the main online network, but without the final prediction head.
+     The target network does not participate in backpropagation;
+     instead, it is an exponential moving average of the online network.
    """

    on_start = ClassyHook._noop
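The docstring describes the core BYOL mechanism: the target network is never trained by gradient descent; its weights are an exponential moving average (EMA) of the online network's weights. A minimal sketch of that update rule, with the function name and parameter lists chosen for illustration (the hook applies exactly this rule in `_update_target_network` below):

```python
import torch

@torch.no_grad()
def ema_update(online_params, target_params, momentum: float) -> None:
    # Move each target parameter a small step toward its online
    # counterpart; with momentum near 1.0 the target drifts slowly.
    for online_p, target_p in zip(online_params, target_params):
        target_p.data = target_p.data * momentum + online_p.data * (1.0 - momentum)
```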
@@ -28,7 +29,7 @@ class BYOLHook(ClassyHook):
    on_update = ClassyHook._noop

    @staticmethod
-     def cosine_decay(training_iter, max_iters, initial_value) -> float:
+     def cosine_decay(training_iter, max_iters, initial_value) -> float:
        """
        For a given starting value, this function anneals the learning
        rate.
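The body of `cosine_decay` falls outside this hunk. The sketch below assumes the standard BYOL schedule: the decay factor follows a half cosine from `initial_value` down to zero, so the EMA coefficient is annealed from `base_ema` up to 1.0 over training:

```python
import math

def cosine_decay(training_iter: int, max_iters: int, initial_value: float) -> float:
    # Half-cosine anneal from initial_value at iteration 0 to 0.0 at max_iters.
    training_iter = min(training_iter, max_iters)  # guard against overshoot
    return 0.5 * initial_value * (1.0 + math.cos(math.pi * training_iter / max_iters))

def target_ema(training_iter: int, base_ema: float, max_iters: int) -> float:
    # EMA coefficient rises from base_ema toward 1.0 as the decay falls to 0.
    decay = cosine_decay(training_iter, max_iters, 1.0)
    return 1.0 - (1.0 - base_ema) * decay

# With base_ema = 0.996: tau = 0.996 at iteration 0, ~0.998 at the midpoint,
# and 1.0 at the final iteration, matching the schedule in the BYOL paper.
```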
@@ -42,8 +43,8 @@ def target_ema(training_iter, base_ema, max_iters) -> float:
        """
        Updates the exponential moving average of the target network.
        """
-         decay = BYOLHook.cosine_decay(training_iter, max_iters, 1.)
-         return 1. - (1. - base_ema) * decay
+         decay = BYOLHook.cosine_decay(training_iter, max_iters, 1.0)
+         return 1.0 - (1.0 - base_ema) * decay

    def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
        """
@@ -53,27 +54,29 @@ def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
        """
        # Create the encoder, which will slowly track the model
        logging.info(
-             "BYOL: Building BYOL target network - rank %s %s", *get_machine_local_and_dist_rank()
+             "BYOL: Building BYOL target network - rank %s %s",
+             *get_machine_local_and_dist_rank(),
        )

-         # Target model has the same architecture, but without the projector head.
-         target_model_config = task.config['MODEL']
-         target_model_config['HEAD']['PARAMS'] = target_model_config['HEAD']['PARAMS'][0:1]
+         # Target model has the same architecture, *without* the predictor head.
+         target_model_config = task.config["MODEL"]
+         target_model_config["HEAD"]["PARAMS"] = target_model_config["HEAD"]["PARAMS"][
+             0:1
+         ]
        task.loss.target_network = build_model(
            target_model_config, task.config["OPTIMIZER"]
        )

-         # TESTED: Target Network and Online network are properly created.
-         # TODO: Check SyncBatchNorm settings (low prior)
-
        task.loss.target_network.to(task.device)

        # Restore a hypothetical checkpoint, else copy the model parameters from the
        # online network.
        if task.loss.checkpoint is not None:
            task.loss.load_state_dict(task.loss.checkpoint)
        else:
-             logging.info("BYOL: Copying and freezing model parameters from online to target network")
+             logging.info(
+                 "BYOL: Copying and freezing model parameters from online to target network"
+             )
            for param_q, param_k in zip(
                task.base_model.parameters(), task.loss.target_network.parameters()
            ):
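The body of this copy loop is cut off by the hunk. Going by the log message, the conventional BYOL/MoCo initialization is assumed here: each target parameter is overwritten with its online counterpart and then frozen, so only the EMA rule ever changes it:

```python
            for param_q, param_k in zip(
                task.base_model.parameters(), task.loss.target_network.parameters()
            ):
                param_k.data.copy_(param_q.data)  # start target identical to online
                param_k.requires_grad = False     # gradients never touch the target
```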
@@ -92,7 +95,9 @@ def _update_momentum_coefficient(self, task: tasks.ClassyTask) -> None:
        self.total_iters = task.max_iteration
        logging.info(f"{self.total_iters} total iters")
        training_iteration = task.iteration
-         self.momentum = self.target_ema(training_iteration, self.base_momentum, self.total_iters)
+         self.momentum = self.target_ema(
+             training_iteration, self.base_momentum, self.total_iters
+         )

    @torch.no_grad()
    def _update_target_network(self, task: tasks.ClassyTask) -> None:
@@ -106,10 +111,10 @@ def _update_target_network(self, task: tasks.ClassyTask) -> None:
            task.base_model.parameters(), task.loss.target_network.parameters()
        ):
            target_params.data = (
-                 target_params.data * self.momentum + online_params.data * (1. - self.momentum)
+                 target_params.data * self.momentum
+                 + online_params.data * (1.0 - self.momentum)
            )

-
    @torch.no_grad()
    def on_forward(self, task: tasks.ClassyTask) -> None:
        """
@@ -127,9 +132,8 @@ def on_forward(self, task: tasks.ClassyTask) -> None:
        else:
            self._update_target_network(task)

-
        # Compute target network embeddings
-         batch = task.last_batch.sample['input']
+         batch = task.last_batch.sample["input"]
        target_embs = task.loss.target_network(batch)[0]

        # Save target embeddings to use them in the loss
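Putting the pieces together, `on_forward` runs once per training step. A condensed sketch of the whole flow; the `is None` test and the `task.loss.target_embs` attribute name are assumptions, since the diff shows only the `else:` branch and the trailing comment:

```python
    @torch.no_grad()
    def on_forward(self, task: tasks.ClassyTask) -> None:
        self._update_momentum_coefficient(task)  # anneal tau for this step

        if task.loss.target_network is None:     # assumed first-step check
            self._build_byol_target_network(task)
        else:
            self._update_target_network(task)    # EMA step shown above

        # Compute target embeddings and stash them for the BYOL loss.
        batch = task.last_batch.sample["input"]
        target_embs = task.loss.target_network(batch)[0]
        task.loss.target_embs = target_embs      # assumed attribute name
```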