[WIP] Initial implementation of AMP support #707

Closed

Changes from 3 commits
5 changes: 5 additions & 0 deletions docs/user/neuralnet.rst
@@ -271,6 +271,11 @@ CUDA before being passed to the PyTorch :class:`~torch.nn.Module`. The
device parameter adheres to the general syntax of the PyTorch device
parameter.

amp_enabled and grad_scaler
^^^^^^^^^^^^^^^^^^^^^^^^^^^

When ``amp_enabled=True``, training uses automatic mixed precision (AMP):
the forward pass and loss computation run under
:func:`torch.cuda.amp.autocast` and gradients are scaled with a
:class:`torch.cuda.amp.GradScaler` to avoid underflow in float16. The
``grad_scaler`` parameter controls the grad scaler that is used; the
initialized scaler is available as the ``grad_scaler_`` attribute. AMP
requires a CUDA device and PyTorch 1.6 or later.
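
For instance, mixed precision training could be enabled like this (a
minimal sketch: ``MyModule`` stands in for any PyTorch module and
``X_train``/``y_train`` for your training data; a CUDA device is
assumed)::

    net = NeuralNetClassifier(
        MyModule,
        device='cuda',
        amp_enabled=True,  # run forward/loss under autocast, scale gradients
    )
    net.fit(X_train, y_train)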

initialize()
^^^^^^^^^^^^

53 changes: 37 additions & 16 deletions examples/benchmarks/mnist.py
@@ -12,7 +12,7 @@

```
python examples/benchmarks/mnist.py
python examples/benchmarks/mnist.py --device cpu --num_samples 5000
python examples/benchmarks/mnist.py --device cpu --num_samples 5000 --amp_enabled true
```

When called the first time, this will download MNIST data to
@@ -25,7 +25,6 @@
"""

import argparse
import os
import time

import numpy as np
@@ -34,11 +33,12 @@
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
import torch
from torch import nn

from skorch.utils import to_device
from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring
import torch
from torch import nn


BATCH_SIZE = 128
@@ -95,6 +95,7 @@ def performance_skorch(
device,
lr,
max_epochs,
amp_enabled=False,
):
torch.manual_seed(0)
net = NeuralNetClassifier(
@@ -104,6 +105,7 @@
lr=lr,
device=device,
max_epochs=max_epochs,
amp_enabled=amp_enabled,
callbacks=[
('tr_acc', EpochScoring(
'accuracy',
@@ -141,6 +143,7 @@ def train_torch(
device,
lr,
max_epochs,
amp_enabled=False,
):
model = to_device(model, device)

@@ -168,6 +171,7 @@
device=device,
criterion=criterion,
optimizer=optimizer,
amp_enabled=amp_enabled,
)
report(y=y_train, epoch=epoch, training=True, **train_out)

@@ -177,6 +181,7 @@
batch_size=batch_size,
device=device,
criterion=criterion,
amp_enabled=amp_enabled,
)
report(y=y_valid, epoch=epoch, training=False, **valid_out)

@@ -185,20 +190,29 @@ def train_torch(
return model


def train_step(model, dataset, device, criterion, batch_size, optimizer):
def train_step(model, dataset, device, criterion, batch_size, optimizer, amp_enabled=False):
model.train()
y_preds = []
losses = []
batch_sizes = []
tic = time.time()
scaler = None if not amp_enabled else torch.cuda.amp.GradScaler()

for Xi, yi in torch.utils.data.DataLoader(dataset, batch_size=batch_size):
Xi, yi = to_device(Xi, device), to_device(yi, device)
optimizer.zero_grad()
y_pred = model(Xi)
y_pred = torch.log(y_pred)
loss = criterion(y_pred, yi)
loss.backward()
optimizer.step()
with torch.cuda.amp.autocast(enabled=amp_enabled):
y_pred = model(Xi)
y_pred = torch.log(y_pred)
loss = criterion(y_pred, yi)

if not amp_enabled:
loss.backward()
optimizer.step()
else:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Member

Since PyTorch XLA requires wrapping the optimizer step as well (and that is a feature we might want to support in the future once TPUs become more accessible to smaller companies), I suggest we introduce something akin to self.optimizer_step(optimizer), which handles things like AMP scaling and XLA optimizations.

Collaborator Author

I'd like to give this a try in the context of this PR to see if it makes things more ergonomic.
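
For illustration, such a hook might look roughly like this (a hedged sketch only: the name `optimizer_step` and its placement follow the suggestion above and are not part of skorch's actual API; the `scaler.scale(loss).backward()` call would stay in the training step):

```python
# Hypothetical sketch of the proposed NeuralNet hook; name and signature
# are assumptions, not part of this PR.
def optimizer_step(self, optimizer):
    """Perform one parameter update, dispatching on the training backend."""
    if self.amp_enabled:
        # GradScaler.step skips the update if the scaled gradients contain
        # infs/NaNs; update() then adjusts the scale factor for the next step.
        self.grad_scaler_.step(optimizer)
        self.grad_scaler_.update()
    # An XLA path (e.g. xm.optimizer_step(optimizer)) could be dispatched
    # here in the future.
    else:
        optimizer.step()
```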


y_preds.append(y_pred)
losses.append(loss.item())
@@ -212,7 +226,7 @@ def train_step(model, dataset, device, criterion, batch_size, optimizer):
}


def valid_step(model, dataset, device, criterion, batch_size):
def valid_step(model, dataset, device, criterion, batch_size, amp_enabled=False):
model.eval()
y_preds = []
losses = []
@@ -223,9 +237,10 @@
dataset, batch_size=batch_size,
):
Xi, yi = to_device(Xi, device), to_device(yi, device)
y_pred = model(Xi)
y_pred = torch.log(y_pred)
loss = criterion(y_pred, yi)
with torch.cuda.amp.autocast(enabled=amp_enabled):
y_pred = model(Xi)
y_pred = torch.log(y_pred)
loss = criterion(y_pred, yi)

y_preds.append(y_pred)
loss = loss.item()
@@ -249,6 +264,7 @@ def performance_torch(
device,
lr,
max_epochs,
amp_enabled=False,
):
torch.manual_seed(0)
model = ClassifierModule()
@@ -262,6 +278,7 @@
device=device,
max_epochs=max_epochs,
lr=0.1,
amp_enabled=amp_enabled,
)

X_test = torch.tensor(X_test).to(device)
Expand All @@ -270,7 +287,7 @@ def performance_torch(
return accuracy_score(y_test, y_pred)


def main(device, num_samples):
def main(device, num_samples, amp_enabled):
data = get_data(num_samples)
# trigger potential cuda call overhead
torch.zeros(1).to(device)
@@ -284,6 +301,7 @@
max_epochs=MAX_EPOCHS,
lr=LEARNING_RATE,
device=device,
amp_enabled=amp_enabled,
)
time_skorch = time.time() - tic

@@ -296,6 +314,7 @@
max_epochs=MAX_EPOCHS,
lr=LEARNING_RATE,
device=device,
amp_enabled=amp_enabled,
)
time_torch = time.time() - tic

@@ -314,5 +333,7 @@ def main(device, num_samples):
help='device (e.g. "cuda", "cpu")')
parser.add_argument('--num_samples', type=int, default=20000,
help='total number of samples to use')
parser.add_argument('--amp_enabled', type=lambda s: s.lower() == 'true',
default=False,
help='whether to enable automatic mixed precision (true/false)')
args = parser.parse_args()
main(device=args.device, num_samples=args.num_samples)
main(device=args.device, num_samples=args.num_samples, amp_enabled=args.amp_enabled)
7 changes: 6 additions & 1 deletion skorch/callbacks/regularization.py
@@ -3,6 +3,7 @@
from torch.nn.utils import clip_grad_norm_

from skorch.callbacks import Callback
from skorch.utils import _unscale_optimizer_grads


__all__ = ['GradientNormClipping']
@@ -37,10 +38,14 @@ def __init__(
self.gradient_clip_value = gradient_clip_value
self.gradient_clip_norm_type = gradient_clip_norm_type

def on_grad_computed(self, _, named_parameters, **kwargs):
def on_grad_computed(self, net, named_parameters, **kwargs):
"""TODO"""
if self.gradient_clip_value is None:
return

if net.amp_enabled:
_unscale_optimizer_grads(net.grad_scaler_, net.optimizer_)

clip_grad_norm_(
(p for _, p in named_parameters),
max_norm=self.gradient_clip_value,
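The `_unscale_optimizer_grads` helper imported above lives in `skorch.utils` and its implementation is not shown in this diff; conceptually it only needs to delegate to `GradScaler.unscale_` so that clipping sees the true, unscaled gradients. A hedged sketch, not the PR's actual code:

```python
# Hypothetical sketch of the helper imported from skorch.utils; the real
# implementation is not part of this diff.
def _unscale_optimizer_grads(grad_scaler, optimizer):
    # Unscale the gradients held by the optimizer's parameters in place so
    # that GradientNormClipping clips unscaled gradient norms. Note that
    # GradScaler.unscale_ may only be called once per optimizer between
    # consecutive scaler.step() calls.
    grad_scaler.unscale_(optimizer)
```
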
9 changes: 9 additions & 0 deletions skorch/callbacks/training.py
@@ -104,6 +104,8 @@ class Checkpoint(Callback):

Supports the same format specifiers as ``f_params``.

f_grad_scaler : file-like object, str, None (default=None)
File path to the file or file-like object where the state of the
AMP grad scaler should be saved. Pass ``None`` to not save the
grad scaler.

f_history : file-like object, str, None (default='history.json')
File path to the file or file-like object where the model
training history should be saved. Pass ``None`` to disable
@@ -139,6 +141,7 @@ def __init__(
f_params='params.pt',
f_optimizer='optimizer.pt',
f_criterion='criterion.pt',
f_grad_scaler=None,
f_history='history.json',
f_pickle=None,
fn_prefix='',
@@ -151,6 +154,7 @@
self.f_params = f_params
self.f_optimizer = f_optimizer
self.f_criterion = f_criterion
self.f_grad_scaler = f_grad_scaler
self.f_history = f_history
self.f_pickle = f_pickle
self.fn_prefix = fn_prefix
@@ -216,6 +220,7 @@ def save_model(self, net):
- optimizer state;
- criterion state;
- training history;
- grad scaler, if any;
- custom modules;
- entire model object.

@@ -670,6 +675,8 @@ class TrainEndCheckpoint(Callback):

Supports the same format specifiers as ``f_params``.

f_grad_scaler : file-like object, str, None (default=None)
File path to the file or file-like object where the state of the
AMP grad scaler should be saved. Pass ``None`` to not save the
grad scaler.

f_history : file-like object, str, None (default='history.json')
File path to the file or file-like object where the model
training history should be saved. Pass ``None`` to disable
@@ -700,6 +707,7 @@ def __init__(
f_params='params.pt',
f_optimizer='optimizer.pt',
f_criterion='criterion.pt',
f_grad_scaler=None,
f_history='history.json',
f_pickle=None,
fn_prefix='train_end_',
@@ -710,6 +718,7 @@
self.f_params = f_params
self.f_optimizer = f_optimizer
self.f_criterion = f_criterion
self.f_grad_scaler = f_grad_scaler
self.f_history = f_history
self.f_pickle = f_pickle
self.fn_prefix = fn_prefix