This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 3539f57

stephenyan1231 authored and facebook-github-bot committed
mixup data augmentation (#469)

Summary: Pull Request resolved: #469

This diff implements the mixup data augmentation from the paper
"mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412).
Empirically, it is much faster to apply the mixup transform on GPU than on CPU.

Results (accuracy gain):
- 1.0% with 135 training epochs
- 1.3% with 270 training epochs

[TODO]: fix accuracy meter at training phases.

Reviewed By: mannatsingh
Differential Revision: D20911088
fbshipit-source-id: 339c1939eaa224125a072fe971a2e1ce958ca26a
1 parent c635e82 commit 3539f57
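
For orientation, the core of the transform added in this commit is a convex combination of each batch with a permuted copy of itself, using a single mixing coefficient drawn from Beta(alpha, alpha). A minimal sketch of the idea (illustrative only, not code from this commit):

import torch
from torch.distributions.beta import Beta

alpha = 2.0
inputs = torch.rand(4, 3, 224, 224)                  # a batch of 4 images
targets = torch.eye(3)[torch.tensor([0, 1, 2, 2])]   # one-hot targets, shape (4, 3)

c = Beta(alpha, alpha).sample()                      # mixing coefficient in (0, 1)
perm = torch.randperm(inputs.shape[0])               # pair each sample with another one
mixed_inputs = c * inputs + (1.0 - c) * inputs[perm]
mixed_targets = c * targets + (1.0 - c) * targets[perm]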

File tree

7 files changed: +144 -8 lines

classy_vision/dataset/transforms/mixup.py

+48

@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+import torch
+from classy_vision.generic.util import convert_to_one_hot
+from torch.distributions.beta import Beta
+
+
+class MixupTransform:
+    """
+    This implements the mixup data augmentation in the paper
+    "mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412)
+    """
+
+    def __init__(self, alpha: float, num_classes: Optional[int] = None):
+        """
+        Args:
+            alpha: the hyperparameter of Beta distribution used to sample mixup
+                coefficient.
+            num_classes: number of classes in the dataset.
+        """
+        self.alpha = alpha
+        self.num_classes = num_classes
+
+    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Args:
+            sample: the batch data.
+        """
+        if sample["target"].ndim == 1:
+            assert self.num_classes is not None, "num_classes is expected for 1D target"
+            sample["target"] = convert_to_one_hot(
+                sample["target"].view(-1, 1), self.num_classes
+            )
+        else:
+            assert sample["target"].ndim == 2, "target tensor shape must be 1D or 2D"
+
+        c = Beta(self.alpha, self.alpha).sample().to(device=sample["target"].device)
+        permuted_indices = torch.randperm(sample["target"].shape[0])
+        for key in ["input", "target"]:
+            sample[key] = c * sample[key] + (1.0 - c) * sample[key][permuted_indices, :]
+
+        return sample
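
Note on the 1D-target branch above: convert_to_one_hot expands integer class labels into a one-hot matrix before mixing, so the convex combination yields soft (fractional) label vectors. A rough equivalent using plain PyTorch rather than the classy_vision helper (a sketch, not the actual implementation):

import torch
import torch.nn.functional as F

target = torch.as_tensor([0, 1, 2, 2])               # 1D integer class labels
one_hot = F.one_hot(target, num_classes=3).float()   # shape (4, 3), rows like [1., 0., 0.]
# After mixing with coefficient c, each row becomes a soft label such as [0.7, 0.3, 0.].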

classy_vision/generic/util.py

+5 -6

@@ -736,12 +736,11 @@ def maybe_convert_to_one_hot(target, model_output):
     ):
         target = convert_to_one_hot(target.view(-1, 1), model_output.shape[1])
 
-    assert (target.shape == model_output.shape) and (
-        torch.min(target.eq(0) + target.eq(1)) == 1
-    ), (
-        "Target must be one-hot/multi-label encoded and of the "
-        "same shape as model_output."
-    )
+    # target are not necessarily hard 0/1 encoding. It can be soft
+    # (i.e. fractional) in some cases, such as mixup label
+    assert (
+        target.shape == model_output.shape
+    ), "Target must of the same shape as model_output."
 
     return target
 

classy_vision/losses/soft_target_cross_entropy_loss.py

+8 -1

@@ -10,6 +10,7 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
+from classy_vision.generic.util import convert_to_one_hot
 from classy_vision.losses import ClassyLoss, register_loss
 
 
@@ -58,13 +59,19 @@ def from_config(cls, config: Dict[str, Any]) -> "SoftTargetCrossEntropyLoss":
     def forward(self, output, target):
         """for N examples and C classes
         - output: N x C these are raw outputs (without softmax/sigmoid)
-        - target: N x C corresponding targets
+        - target: N x C or N corresponding targets
 
         Target elements set to ignore_index contribute 0 loss.
 
         Samples where all entries are ignore_index do not contribute to the loss
         reduction.
         """
+        # check if targets are inputted as class integers
+        if target.ndim == 1:
+            assert (
+                output.shape[0] == target.shape[0]
+            ), "SoftTargetCrossEntropyLoss requires output and target to have same batch size"
+            target = convert_to_one_hot(target.view(-1, 1), output.shape[1])
         assert (
             output.shape == target.shape
         ), "SoftTargetCrossEntropyLoss requires output and target to be same"

classy_vision/meters/accuracy_meter.py

-1

@@ -145,7 +145,6 @@ def update(self, model_output, target, **kwargs):
         for i, k in enumerate(self._topk):
             self._curr_correct_predictions_k[i] += (
                 torch.gather(target, dim=1, index=pred[:, :k])
-                .long()
                 .max(dim=1)
                 .values.sum()
                 .item()

classy_vision/tasks/classification_task.py

+26

@@ -14,6 +14,7 @@
 import torch
 import torch.nn as nn
 from classy_vision.dataset import ClassyDataset, build_dataset
+from classy_vision.dataset.transforms.mixup import MixupTransform
 from classy_vision.generic.distributed_util import (
     all_reduce_mean,
     barrier,
@@ -141,6 +142,7 @@ def __init__(self):
             BroadcastBuffersMode.DISABLED
         )
         self.amp_args = None
+        self.mixup_transform = None
         self.perf_log = []
         self.last_batch = None
         self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
@@ -326,6 +328,19 @@ def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
             logging.info(f"AMP enabled with args {amp_args}")
         return self
 
+    def set_mixup_transform(self, mixup_transform: Optional["MixupTransform"]):
+        """Disable / enable mixup transform for data augmentation
+
+        Args::
+            mixup_transform: a callable object which performs mixup data augmentation
+        """
+        self.mixup_transform = mixup_transform
+        if mixup_transform is None:
+            logging.info(f"mixup disabled")
+        else:
+            logging.info(f"mixup enabled")
+        return self
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
         """Instantiates a ClassificationTask from a configuration.
@@ -353,6 +368,13 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
         meters = build_meters(config.get("meters", {}))
         model = build_model(config["model"])
 
+        mixup_transform = None
+        if config.get("mixup") is not None:
+            assert "alpha" in config["mixup"], "key alpha is missing in mixup dict"
+            mixup_transform = MixupTransform(
+                config["mixup"]["alpha"], config["mixup"].get("num_classes")
+            )
+
         # hooks config is optional
         hooks_config = config.get("hooks")
         hooks = []
@@ -371,6 +393,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
             .set_optimizer(optimizer)
             .set_meters(meters)
             .set_amp_args(amp_args)
+            .set_mixup_transform(mixup_transform)
             .set_distributed_options(
                 broadcast_buffers_mode=BroadcastBuffersMode[
                     config.get("broadcast_buffers", "disabled").upper()
@@ -775,6 +798,9 @@ def train_step(self):
            for key, value in sample.items():
                sample[key] = recursive_copy_to_gpu(value, non_blocking=True)
 
+        if self.mixup_transform is not None:
+            sample = self.mixup_transform(sample)
+
        with torch.enable_grad():
            # Forward pass
            output = self.model(sample["input"])
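
Based on the from_config changes above, mixup is switched on by adding a "mixup" section to the task config; only "alpha" is required, and "num_classes" is needed when targets arrive as 1D class indices. A hypothetical config fragment (the surrounding dataset / model / optimizer keys are elided):

config = {
    # ... "dataset", "model", "optimizer", "meters", ... as usual ...
    "mixup": {
        "alpha": 2.0,         # parameter of the Beta(alpha, alpha) mixing distribution
        "num_classes": 1000,  # optional; required when targets are 1D class indices
    },
}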

test/dataset_transforms_mixup_test.py

+50

@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import unittest
+
+import torch
+from classy_vision.dataset.transforms.mixup import MixupTransform
+
+
+class DatasetTransformsMixupTest(unittest.TestCase):
+    def test_mixup_transform_single_label(self):
+        alpha = 2.0
+        num_classes = 3
+        mixup_transform = MixupTransform(alpha, num_classes)
+        sample = {
+            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
+            "target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
+        }
+        sample_mixup = mixup_transform(sample)
+        self.assertTrue(sample["input"].shape == sample_mixup["input"].shape)
+        self.assertTrue(sample_mixup["target"].shape[0] == 4)
+        self.assertTrue(sample_mixup["target"].shape[1] == 3)
+
+    def test_mixup_transform_single_label_missing_num_classes(self):
+        alpha = 2.0
+        mixup_transform = MixupTransform(alpha, None)
+        sample = {
+            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
+            "target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
+        }
+        with self.assertRaises(Exception):
+            mixup_transform(sample)
+
+    def test_mixup_transform_multi_label(self):
+        alpha = 2.0
+        mixup_transform = MixupTransform(alpha, None)
+        sample = {
+            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
+            "target": torch.as_tensor(
+                [[1, 0, 0, 0], [0, 1, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1]],
+                dtype=torch.int32,
+            ),
+        }
+        sample_mixup = mixup_transform(sample)
+        self.assertTrue(sample["input"].shape == sample_mixup["input"].shape)
+        self.assertTrue(sample["target"].shape == sample_mixup["target"].shape)

test/losses_soft_target_cross_entropy_loss_test.py

+7

@@ -47,6 +47,13 @@ def test_soft_target_cross_entropy(self):
         targets = torch.tensor([[-1, 0, 0, 0, 1]])
         self.assertAlmostEqual(crit(outputs, targets).item(), 5.01097918)
 
+    def test_soft_target_cross_entropy_integer_label(self):
+        config = self._get_config()
+        crit = SoftTargetCrossEntropyLoss.from_config(config)
+        outputs = self._get_outputs()
+        targets = torch.tensor([4])
+        self.assertAlmostEqual(crit(outputs, targets).item(), 5.01097918)
+
     def test_unnormalized_soft_target_cross_entropy(self):
         config = {
             "name": "soft_target_cross_entropy",
