Skip to content
This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Commit a4047e9

Browse files
committed
BYOL improvements
1 parent 6e3063d commit a4047e9

File tree

6 files changed

+75
-65
lines changed

6 files changed

+75
-65
lines changed

configs/config/benchmark/linear_image_classification/imagenet1k/byol_transfer_in1k_linear.yaml

+17-8
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ config:
2222
TRANSFORMS:
2323
- name: RandomResizedCrop
2424
size: 224
25+
interpolation: 3
2526
- name: RandomHorizontalFlip
2627
- name: ToTensor
2728
- name: Normalize
@@ -38,6 +39,7 @@ config:
3839
TRANSFORMS:
3940
- name: Resize
4041
size: 256
42+
interpolation: 3
4143
- name: CenterCrop
4244
size: 224
4345
- name: ToTensor
@@ -82,7 +84,7 @@ config:
8284
PARAMS_FILE: "specify the model weights"
8385
STATE_DICT_KEY_NAME: classy_state_dict
8486
SYNC_BN_CONFIG:
85-
CONVERT_BN_TO_SYNC_BN: True
87+
CONVERT_BN_TO_SYNC_BN: False
8688
SYNC_BN_TYPE: apex
8789
GROUP_SIZE: 8
8890
LOSS:
@@ -93,22 +95,29 @@ config:
9395
name: sgd
9496
momentum: 0.9
9597
num_epochs: 80
98+
weight_decay: 0
9699
nesterov: True
97100
regularize_bn: False
98101
regularize_bias: True
99102
param_schedulers:
100103
lr:
101104
auto_lr_scaling:
102-
auto_scale: true
103-
base_value: 0.4
105+
# if set to True, learning rate will be scaled.
106+
auto_scale: True
107+
# base learning rate value that will be scaled.
108+
base_value: 0.2
109+
# batch size for which the base learning rate is specified. The current batch size
110+
# is used to determine how to scale the base learning rate value.
111+
# scaled_lr = ((batchsize_per_gpu * world_size) * base_value ) / base_lr_batch_size
104112
base_lr_batch_size: 256
105-
name: multistep
106-
values: [0.4, 0.3, 0.2, 0.1, 0.05]
107-
milestones: [16, 32, 48, 64]
108-
update_interval: epoch
113+
# scaling_type can be set to "sqrt" to reduce the impact of scaling on the base value
114+
scaling_type: "linear"
115+
name: constant
116+
update_interval: "epoch"
117+
value: 0.2
109118
DISTRIBUTED:
110119
BACKEND: nccl
111-
NUM_NODES: 8
120+
NUM_NODES: 4
112121
NUM_PROC_PER_NODE: 8
113122
INIT_METHOD: tcp
114123
RUN_ID: auto

configs/config/pretrain/byol/byol_8node_resnet.yaml

+5-4
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ config:
6767
RESNETS:
6868
DEPTH: 50
6969
ZERO_INIT_RESIDUAL: True
70-
HEAD:
70+
HEAD:
7171
PARAMS: [
7272
["mlp", {"dims": [2048, 4096, 256], "use_relu": True, "use_bn": True}],
7373
["mlp", {"dims": [256, 4096, 256], "use_relu": True, "use_bn": True}]
@@ -82,15 +82,16 @@ config:
8282
byol_loss:
8383
embedding_dim: 256
8484
momentum: 0.99
85-
OPTIMIZER: # from official BYOL implementation, deepmind-research/byol/configs/byol.py
85+
OPTIMIZER:
8686
name: lars
87-
trust_coefficient: 0.001
87+
eta: 0.001
8888
weight_decay: 1.0e-6
8989
momentum: 0.9
9090
nesterov: False
9191
num_epochs: 300
9292
regularize_bn: False
93-
regularize_bias: True
93+
regularize_bias: False
94+
exclude_bias_and_norm: True
9495
param_schedulers:
9596
lr:
9697
auto_lr_scaling:

vissl/data/ssl_transforms/img_pil_color_distortion.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,16 @@ class ImgPilColorDistortion(ClassyTransform):
2121
randomly convert the image to grayscale.
2222
"""
2323

24-
def __init__(self, strength, brightness=0.8, contrast=0.8, saturation=0.8,
25-
hue=0.2, color_jitter_probability=0.8, grayscale_probability=0.2):
24+
def __init__(
25+
self,
26+
strength,
27+
brightness=0.8,
28+
contrast=0.8,
29+
saturation=0.8,
30+
hue=0.2,
31+
color_jitter_probability=0.8,
32+
grayscale_probability=0.2,
33+
):
2634
"""
2735
Args:
2836
strength (float): A number used to quantify the strength of the
@@ -41,22 +49,23 @@ def __init__(self, strength, brightness=0.8, contrast=0.8, saturation=0.8,
4149
grayscale_probability (float): A floating point number used to
4250
quantify to apply randomly convert image to grayscale with
4351
the assigned probability. Default value is 0.2.
44-
This function follows the Pytorch documentation: https://pytorch.org/vision/stable/transforms.html
4552
"""
4653
self.strength = strength
4754
self.brightness = brightness
4855
self.contrast = contrast
4956
self.saturation = saturation
5057
self.hue = hue
51-
self.color_jitter_probability=color_jitter_probability
52-
self.grayscale_probability=grayscale_probability
58+
self.color_jitter_probability = color_jitter_probability
59+
self.grayscale_probability = grayscale_probability
5360
self.color_jitter = pth_transforms.ColorJitter(
5461
self.brightness * self.strength,
5562
self.contrast * self.strength,
5663
self.saturation * self.strength,
5764
self.hue * self.strength,
5865
)
59-
self.rnd_color_jitter = pth_transforms.RandomApply([self.color_jitter], p=self.color_jitter_probability)
66+
self.rnd_color_jitter = pth_transforms.RandomApply(
67+
[self.color_jitter], p=self.color_jitter_probability
68+
)
6069
self.rnd_gray = pth_transforms.RandomGrayscale(p=self.grayscale_probability)
6170
self.transforms = pth_transforms.Compose([self.rnd_color_jitter, self.rnd_gray])
6271

vissl/hooks/__init__.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from classy_vision.hooks.classy_hook import ClassyHook
1010
from vissl.config import AttrDict
11+
from vissl.hooks.byol_hooks import BYOLHook # noqa
1112
from vissl.hooks.deepclusterv2_hooks import ClusterMemoryHook, InitMemoryHook # noqa
1213
from vissl.hooks.dino_hooks import DINOHook
1314
from vissl.hooks.grad_clip_hooks import GradClipHook # noqa
@@ -21,15 +22,12 @@
2122
)
2223
from vissl.hooks.moco_hooks import MoCoHook # noqa
2324
from vissl.hooks.profiling_hook import ProfilingHook
24-
from vissl.hooks.byol_hooks import BYOLHook # noqa
25-
2625
from vissl.hooks.state_update_hooks import ( # noqa
2726
CheckNanLossHook,
2827
FreezeParametersHook,
2928
SetDataSamplerEpochHook,
3029
SSLModelComplexityHook,
3130
)
32-
from vissl.hooks.byol_hooks import BYOLHook # noqa
3331
from vissl.hooks.swav_hooks import NormalizePrototypesHook # noqa
3432
from vissl.hooks.swav_hooks import SwAVUpdateQueueScoresHook # noqa
3533
from vissl.hooks.swav_momentum_hooks import (
@@ -149,14 +147,6 @@ def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
149147
)
150148
]
151149
)
152-
if cfg.LOSS.name == "byol_loss":
153-
hooks.extend(
154-
[
155-
BYOLHook(
156-
cfg.LOSS["byol_loss"]["momentum"],
157-
)
158-
]
159-
)
160150
if cfg.HOOKS.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY:
161151
hooks.extend([SSLModelComplexityHook()])
162152
if cfg.HOOKS.LOG_GPU_STATS:

vissl/hooks/byol_hooks.py

+27-23
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2-
import math
32
import logging
3+
import math
44

55
import torch
66
from classy_vision import tasks
77
from classy_vision.hooks.classy_hook import ClassyHook
88
from vissl.models import build_model
99
from vissl.utils.env import get_machine_local_and_dist_rank
1010

11+
1112
class BYOLHook(ClassyHook):
1213
"""
13-
BYOL - Bootstrap your own latent: (https://arxiv.org/abs/2006.07733)
14-
is based on Contrastive learning. This hook
15-
creates a target network with the same architecture
16-
as the main online network, but without the projection head.
17-
The online network does not participate in backpropogation,
18-
but instead is an exponential moving average of the online network.
14+
BYOL - Bootstrap your own latent: (https://arxiv.org/abs/2006.07733)
15+
is based on Contrastive learning. This hook
16+
creates a target network with the same architecture
17+
as the main online network, but without the projection head.
18+
The online network does not participate in backpropogation,
19+
but instead is an exponential moving average of the online network.
1920
"""
2021

2122
on_start = ClassyHook._noop
@@ -28,7 +29,7 @@ class BYOLHook(ClassyHook):
2829
on_update = ClassyHook._noop
2930

3031
@staticmethod
31-
def cosine_decay(training_iter, max_iters, initial_value) -> float:
32+
def cosine_decay(training_iter, max_iters, initial_value) -> float:
3233
"""
3334
For a given starting value, this function anneals the learning
3435
rate.
@@ -42,8 +43,8 @@ def target_ema(training_iter, base_ema, max_iters) -> float:
4243
"""
4344
Updates Exponential Moving average of the Target Network.
4445
"""
45-
decay = BYOLHook.cosine_decay(training_iter, max_iters, 1.)
46-
return 1. - (1. - base_ema) * decay
46+
decay = BYOLHook.cosine_decay(training_iter, max_iters, 1.0)
47+
return 1.0 - (1.0 - base_ema) * decay
4748

4849
def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
4950
"""
@@ -53,27 +54,29 @@ def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
5354
"""
5455
# Create the encoder, which will slowly track the model
5556
logging.info(
56-
"BYOL: Building BYOL target network - rank %s %s", *get_machine_local_and_dist_rank()
57+
"BYOL: Building BYOL target network - rank %s %s",
58+
*get_machine_local_and_dist_rank(),
5759
)
5860

59-
# Target model has the same architecture, but without the projector head.
60-
target_model_config = task.config['MODEL']
61-
target_model_config['HEAD']['PARAMS'] = target_model_config['HEAD']['PARAMS'][0:1]
61+
# Target model has the same architecture, *without* the projector head.
62+
target_model_config = task.config["MODEL"]
63+
target_model_config["HEAD"]["PARAMS"] = target_model_config["HEAD"]["PARAMS"][
64+
0:1
65+
]
6266
task.loss.target_network = build_model(
6367
target_model_config, task.config["OPTIMIZER"]
6468
)
6569

66-
# TESTED: Target Network and Online network are properly created.
67-
# TODO: Check SyncBatchNorm settings (low prior)
68-
6970
task.loss.target_network.to(task.device)
7071

7172
# Restore an hypothetical checkpoint, else copy the model parameters from the
7273
# online network.
7374
if task.loss.checkpoint is not None:
7475
task.loss.load_state_dict(task.loss.checkpoint)
7576
else:
76-
logging.info("BYOL: Copying and freezing model parameters from online to target network")
77+
logging.info(
78+
"BYOL: Copying and freezing model parameters from online to target network"
79+
)
7780
for param_q, param_k in zip(
7881
task.base_model.parameters(), task.loss.target_network.parameters()
7982
):
@@ -92,7 +95,9 @@ def _update_momentum_coefficient(self, task: tasks.ClassyTask) -> None:
9295
self.total_iters = task.max_iteration
9396
logging.info(f"{self.total_iters} total iters")
9497
training_iteration = task.iteration
95-
self.momentum = self.target_ema(training_iteration, self.base_momentum, self.total_iters)
98+
self.momentum = self.target_ema(
99+
training_iteration, self.base_momentum, self.total_iters
100+
)
96101

97102
@torch.no_grad()
98103
def _update_target_network(self, task: tasks.ClassyTask) -> None:
@@ -106,10 +111,10 @@ def _update_target_network(self, task: tasks.ClassyTask) -> None:
106111
task.base_model.parameters(), task.loss.target_network.parameters()
107112
):
108113
target_params.data = (
109-
target_params.data * self.momentum + online_params.data * (1. - self.momentum)
114+
target_params.data * self.momentum
115+
+ online_params.data * (1.0 - self.momentum)
110116
)
111117

112-
113118
@torch.no_grad()
114119
def on_forward(self, task: tasks.ClassyTask) -> None:
115120
"""
@@ -127,9 +132,8 @@ def on_forward(self, task: tasks.ClassyTask) -> None:
127132
else:
128133
self._update_target_network(task)
129134

130-
131135
# Compute target network embeddings
132-
batch = task.last_batch.sample['input']
136+
batch = task.last_batch.sample["input"]
133137
target_embs = task.loss.target_network(batch)[0]
134138

135139
# Save target embeddings to use them in the loss

vissl/losses/byol_loss.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import torch.nn.functional as F
88
from classy_vision.losses import ClassyLoss, register_loss
99

10-
_BYOLLossConfig = namedtuple(
11-
"_BYOLLossConfig", ["embedding_dim", "momentum"]
12-
)
10+
11+
_BYOLLossConfig = namedtuple("_BYOLLossConfig", ["embedding_dim", "momentum"])
12+
1313

1414
def regression_loss(x, y):
1515
"""
@@ -19,17 +19,16 @@ def regression_loss(x, y):
1919
Cosine similarity. This implementation uses Cosine similarity.
2020
"""
2121
normed_x, normed_y = F.normalize(x, dim=1), F.normalize(y, dim=1)
22-
return torch.sum((normed_x - normed_y).pow(2), dim=1)
22+
# Euclidean Distance squared.
23+
return 2 - 2 * (normed_x * normed_y).sum(dim=1)
2324

2425

2526
class BYOLLossConfig(_BYOLLossConfig):
26-
""" Settings for the BYOL loss"""
27+
"""Settings for the BYOL loss"""
2728

2829
@staticmethod
2930
def defaults() -> "BYOLLossConfig":
30-
return BYOLLossConfig(
31-
embedding_dim=256, momentum=0.999
32-
)
31+
return BYOLLossConfig(embedding_dim=256, momentum=0.999)
3332

3433

3534
@register_loss("byol_loss")
@@ -68,7 +67,9 @@ def from_config(cls, config: BYOLLossConfig) -> "BYOLLoss":
6867
"""
6968
return cls(config)
7069

71-
def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> torch.Tensor:
70+
def forward(
71+
self, online_network_prediction: torch.Tensor, *args, **kwargs
72+
) -> torch.Tensor:
7273
"""
7374
In this function, the Online Network receives the tensor as input after projection
7475
and they make predictions on the output of the target network’s projection,
@@ -79,7 +80,6 @@ def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> t
7980
compute the cross entropy loss for this batch.
8081
8182
Args:
82-
query: output of the encoder given the current batch
8383
online_network_prediction: online model output. this is a prediction of the
8484
target network output.
8585
@@ -91,8 +91,6 @@ def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> t
9191
online_view1, online_view2 = torch.chunk(online_network_prediction, 2, 0)
9292
target_view1, target_view2 = torch.chunk(self.target_embs.detach(), 2, 0)
9393

94-
# TESTED: Views are received correctly.
95-
9694
# Compute losses
9795
loss1 = regression_loss(online_view1, target_view2)
9896
loss2 = regression_loss(online_view2, target_view1)
@@ -111,7 +109,6 @@ def load_state_dict(self, state_dict, *args, **kwargs) -> None:
111109
Args:
112110
state_dict (serialized via torch.save)
113111
"""
114-
115112
# If the encoder has been allocated, use the normal pytorch restoration
116113
if self.target_network is None:
117114
self.checkpoint = state_dict

0 commit comments

Comments
 (0)