EarlyStopping callback issue with Skorch + HuggingFace Accelerate #952
Thanks again for this very detailed bug report. I'll probably only have time to look at this sometime next week, so sorry it'll take a bit longer. The whole checkpointing code was written before working on the accelerate support, so it's very much possible that there are some issues. Just a few thoughts I had from reading:
# in Checkpoint callback
def on_epoch_end(self, net, **kwargs):
...
if self.misses_ == self.patience:
...
net.should_interrupt = True
# raise KeyboardInterrupt # <= don't raise exception
# in accelerated net
def run_single_epoch(self, iterator, training, prefix, step_fn, **fit_params):
...
for batch in iterator:
...
batch_count += 1
if getattr(self, 'should_interrupt', False):
if self.accelerator.is_main_process: # maybe not needed
self.accelerator.wait_for_everyone() # maybe not needed
break
...
I hope I could convey my intent here. Again, I couldn't test it yet, so I might be totally off.
Thanks for the quick reply.
Will try, thanks for the suggestion. Does this require that […]?
Will also try, thanks. Does this mean that […]?
Thanks a lot for your recommendations, will keep you posted here.
Some update:
Good news: when using another score (here, average_precision via EpochScoring) and use_caching=False, early stopping works as expected.
Code changes:
from skorch.callbacks import EarlyStopping, EpochScoring
...
class AcceleratedEpochScoring(EpochScoring):
def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
super().on_epoch_end(net, dataset_train, dataset_valid, **kwargs)
print(
f"Epoch {len(net.history)} - Process {accelerator.process_index} "
f"valid AP: {net.history[-1]['valid_ap']}"
)
...
callbacks=[
AcceleratedEpochScoring(
scoring='average_precision',
name="valid_ap",
lower_is_better=False,
use_caching=False,
),
AcceleratedEarlyStopping(
monitor='valid_ap',
patience=3,
load_best=True,
threshold=0.0001,
lower_is_better=False,
),
]
...
Output:
Unfortunately, when setting use_caching=True, the execution freezes.
Example:
...
callbacks=[
AcceleratedEpochScoring(
scoring='average_precision',
name="valid_ap",
lower_is_better=False,
use_caching=True,
),
AcceleratedEarlyStopping(
monitor='valid_ap',
patience=3,
load_best=True,
threshold=0.0001,
lower_is_better=False,
),
]
...
Output:
The execution freezes at this point like before.
Have not tried yet, because your first recommendation works out for my use case so far. Thanks again for these suggestions.
EDIT: I don't get why, but seeding everything appears to be necessary.
Code change:
...
# seed_everything()
...
Output:
Thanks so much for running those experiments.
First of all, I believe defining how to "correctly" early stop when using data parallel training is tricky. Ideally, we would like the outcome to be independent of the training procedure (data parallel vs. not). Let's say we have two batches, b0 and b1. In sequential training, we do forward(b0), then backward() + update the model, then forward(b1), then backward() + update. So when we call forward(b1), the model has already been updated. Let's assume the score has improved here. In parallel training, when we call forward(b1), the model has not been updated yet, therefore the results will be different, presumably worse. So it could very well happen that the score has not improved because of that, and we early stop. These kinds of differences are expected and unavoidable. I tried to do some research on how to "correctly" early stop when using PyTorch DDP and found surprisingly little. My takeaway is that it's not often used and maybe should be avoided if possible (possibly using […]).
Regarding the problem with caching: Yes, I'm not surprised that caching causes issues here. I'm not sure if much can be done, except for documenting the possible issues and recommending to avoid caching with parallel training.
Regarding the necessity for seeding everything: Hmm, interesting, not quite sure what's going on here. I'll do some of my own experiments; if I find something interesting, I'll post an update.
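For what it's worth, one direction that might help is to make the stop decision collectively, so that either all processes stop at a given epoch or none do. Below is an untested sketch of that idea (the class name is made up, and it assumes the net exposes its accelerator as net.accelerator, as AccelerateMixin does):

import torch
from skorch.callbacks import EarlyStopping

class RankSyncedEarlyStopping(EarlyStopping):
    # untested sketch: make the stop decision collectively across ranks
    def on_epoch_end(self, net, **kwargs):
        try:
            # the stock implementation raises KeyboardInterrupt on this rank
            # once its local patience is exhausted
            super().on_epoch_end(net, **kwargs)
            local_stop = 0.0
        except KeyboardInterrupt:
            local_stop = 1.0
        # collective reduction: every rank takes part each epoch, so no
        # process is left waiting on a rank that already stopped
        flag = torch.tensor(local_stop, device=net.accelerator.device)
        flag = net.accelerator.reduce(flag, reduction="sum")
        if flag.item() > 0:  # here: stop as soon as any rank wants to stop
            raise KeyboardInterrupt

Whether "stop as soon as any rank wants to stop" is the right policy is debatable; the important part is that the decision is taken with a collective op, so no process is left hanging.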
After some more reading and experimenting, I think that the whole situation of using DDP with skorch is a bit more complicated than I expected. Since we have multiple processes, there is actually a copy of the history for each process.
The most correct approach, if I understand what's going on correctly, would be to synchronize the histories. So when proc0 writes batch metrics and proc1 writes different batch metrics, those two should be merged into a single history, and then that merged history should be distributed to both processes again. This is not easily achieved; I need to think about if and how that can be implemented (it might require the use of a distributed key-value store). Until then, I'd be careful about relying on anything that uses the history.
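To make the "distributed key-value store" idea a bit more concrete, here is a tiny, self-contained sketch of how torch.distributed.TCPStore could be used to share per-rank history entries. This is not an actual implementation, just the general shape of it; the keys and values are made up, and in a real run rank/world_size would come from the launcher:

import json
from torch.distributed import TCPStore

# hard-coded here so the snippet runs standalone; normally set by the launcher
rank, world_size = 0, 1
store = TCPStore("127.0.0.1", 2345, world_size=world_size, is_master=(rank == 0))

# each process writes its own batch metrics under a rank-specific key ...
store.set(f"epoch-1/rank-{rank}", json.dumps({"train_loss": 0.52}))

# ... and every process can read all entries back and merge them, so all
# ranks end up with the same, complete history for that epoch
merged = [json.loads(store.get(f"epoch-1/rank-{r}")) for r in range(world_size)]
print(merged)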
Thanks a lot for your detailed answers and sorry for the late reply.
I may be misunderstanding something here. To my mind, when using DDP, backprop is performed on each GPU (i.e., for each model copy, using its own batch) and then the gradients are averaged across all GPUs before performing the optimizer step. So I am picturing an updated (i.e., synced) model copy when we call forward(b1).
This would be great indeed. Thanks for the recent changes you committed.
Do you agree that, as long as I am using a validation metric from the history (i.e., a metric computed using Skorch's eval step), it is safe in DDP mode? Again, thanks a lot for your time. Let me know if I can be of any help.
Your understanding is correct (unless my understanding is also incorrect); what I meant is that this batch b1 is the "using its own batch" batch, so b0 and b1 are processed independently, gradients are synced, then b2 and b3 are processed independently, etc.
Yes, from my current understanding, when using a validation metric computed in skorch's eval step, it should be safe.
Thanks for the offer. Maybe once #955 is merged, you could try it in your initial test to verify that DistributedHistory solves the problem.
Hi, I re-ran the initial test using the recent merge of #955 with two GPUs and DistributedHistory.
Code:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
import torch
import torch.nn as nn
import numpy as np
import random
from skorch import NeuralNetClassifier
from sklearn.datasets import make_classification
from accelerate import Accelerator
from skorch.hf import AccelerateMixin
from skorch.callbacks import EarlyStopping
from skorch.dataset import ValidSplit
from torch.distributed import TCPStore
# NB: this import is missing in the example from
# https://skorch.readthedocs.io/en/latest/history.html#skorch.history.DistributedHistory
from skorch.history import DistributedHistory
# Reproducibility
SEED = 42
def seed_everything(seed=42):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.dense0 = nn.Linear(100, 50)
self.dense1 = nn.Linear(50, 25)
self.dense2 = nn.Linear(25, 2)
self.nonlin = nn.Softmax(dim=-1)
def forward(self, X):
X = self.dense0(X)
X = self.dense1(X)
X = self.dense2(X)
X = self.nonlin(X)
return X
class AcceleratedNeuralNetClassifier(AccelerateMixin, NeuralNetClassifier):
# cf. issue 944
def evaluation_step(self, batch, training=False):
output = super().evaluation_step(batch, training=training)
return self.accelerator.gather_for_metrics(output)
class AcceleratedEarlyStopping(EarlyStopping):
def on_train_end(self, net, **kwargs):
if (
self.load_best and (self.best_epoch_ != net.history[-1, "epoch"])
and (self.best_model_weights_ is not None)
):
best_model_weights_accelerate = {
# Need to load the module keys
# e.g., `dense0.weight` instead of `module.dense0.weight`
k.split("module.")[1]: v
for k, v
in self.best_model_weights_.items()
}
net.module_.load_state_dict(best_model_weights_accelerate)
self._sink(
f"Process {net.accelerator.process_index}: " # <= added
f"Restoring best model from epoch {self.best_epoch_}.",
verbose=net.verbose
)
class SkorchAccelerator(Accelerator):
# cf. issue 944
def __deepcopy__(self, memo):
return self
seed_everything()
batch_size = 200
# Data generation
X, y = make_classification(
100_000, 100,
n_informative=10, random_state=SEED, flip_y=0.1
)
X = X.astype(np.float32)
y = y.astype(np.int64)
# Accelerate
accelerator = SkorchAccelerator(split_batches=True)
# Distributed history config
is_master = accelerator.is_main_process
world_size = accelerator.num_processes
rank = accelerator.local_process_index
store = TCPStore(
"127.0.0.1",
port=1234,
world_size=world_size,
is_master=is_master
)
dist_history = DistributedHistory(
store=store,
rank=rank,
world_size=world_size
)
# Model
model_skorch = AcceleratedNeuralNetClassifier(
history=dist_history,
accelerator=accelerator,
module=MyModule,
max_epochs=10,
verbose=True,
iterator_train__shuffle=False,
iterator_valid__shuffle=False,
batch_size=batch_size,
train_split=ValidSplit(0.2),
callbacks=[
AcceleratedEarlyStopping(
monitor='valid_loss',
patience=3,
load_best=True,
threshold=0.001,
lower_is_better=True,
)
],
)
model_skorch.fit(X, y)
preds = model_skorch.predict_proba(X[:5])
print(f"Process {accelerator.process_index}: {preds}") Output: epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.5084 0.8088 0.4413 1.3090
2 0.4341 0.8151 0.4305 1.3063
3 0.4296 0.8184 0.4287 1.3474
4 0.4288 0.8189 0.4284 1.3253
5 0.4287 0.8188 0.4283 1.2274
6 0.4287 0.8186 0.4283 1.2227
7 0.4287 0.8186 0.4283 1.2602
Stopping since valid_loss has not improved in the last 3 epochs.
Stopping since valid_loss has not improved in the last 3 epochs.
Process 1: Restoring best model from epoch 5.
Process 0: Restoring best model from epoch 5.
Process 0: [[0.9338996 0.06610045]
[0.9456476 0.05435238]
[0.7032327 0.2967673 ]
[0.7254356 0.27456442]
[0.9866545 0.01334549]]
Process 1: [[0.9338996 0.06610045]
[0.9456476 0.05435238]
[0.7032327 0.2967673 ]
[0.7254356 0.27456442]
[0.9866545 0.01334549]]
The execution completes normally. Thanks again.
Yes, exactly, that's why a small difference is expected. Thanks for testing.
Small update: With the new sklearn v1.3.0 release, the devs added a new protocol, __sklearn_clone__, which lets an object define how it should be cloned:
class MyAccelerator:
    def __sklearn_clone__(self):
        return self
I tried it out and it works. So if sklearn 1.3 is an option, it is better to override __sklearn_clone__ rather than __deepcopy__.
I asked the accelerate devs if they would consider adding this to accelerate, let's see. I could imagine this is more acceptable to them than overriding __deepcopy__.
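For context, here is roughly how this would slot in, replacing the __deepcopy__ workaround from the example above (assuming sklearn>=1.3 is installed; the subclass name is just for illustration):

from accelerate import Accelerator
from sklearn.base import clone

class CloneSafeAccelerator(Accelerator):
    # with sklearn>=1.3, clone() checks for __sklearn_clone__ before falling
    # back to deep-copying, so the accelerator is simply reused as is
    def __sklearn_clone__(self):
        return self

accelerator = CloneSafeAccelerator()
# cloning (which is what sklearn does with estimator params, e.g. in grid
# search) now just returns the very same accelerator object
assert clone(accelerator, safe=False) is accelerator

Compared to overriding __deepcopy__, this only changes how sklearn clones the object and leaves regular deep-copying untouched.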
Hi,
I am trying to use the EarlyStopping callback with Skorch + HuggingFace's accelerate on two GPUs.
accelerate config:
IIUC accelerate, the model is copied on both GPUs and each GPU will receive a different slice of the batch. Is this assumption correct?
This can lead to inconsistencies where only one of the two model copies has its early stopping callback triggered, leaving the remaining process hanging. In this case, the training stalls, possibly because one process is waiting for the other (that has been interrupted due to early stopping).
It is a bit tricky to reproduce, but here is my attempt:
Launched using accelerate launch my_script.py.
Output:
The execution freezes and only the torch.distributed timeout will end it with an error. I suspect that the freeze comes from the KeyboardInterrupt raised by the process early-stopped on one GPU, leaving the accelerator hanging, waiting until timeout for a process that has been interrupted.
Note that when the threshold is less sensitive, the two models on both GPUs are early stopped at the same epoch and no problem arises:
Output:
Maybe a solution could be to find a way to gather all data across GPUs before updating the callback internals (e.g., self.misses_). In this way, the callbacks for each model copy would not diverge and would raise KeyboardInterrupt at the same epoch. Without this, I fear that this kind of problem would also arise with other callbacks like Checkpoint.
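To illustrate what I have in mind, here is a rough, untested sketch (the class name is mine; it assumes the monitored value is a plain history key and that the net exposes its accelerator as net.accelerator):

import torch
from skorch.callbacks import EarlyStopping

class GatheredEarlyStopping(EarlyStopping):
    def on_epoch_end(self, net, **kwargs):
        # average the locally computed score over all processes and write it
        # back to the history, so that misses_ evolves identically on every
        # rank and all processes raise KeyboardInterrupt at the same epoch
        local_score = net.history[-1, self.monitor]
        score = torch.tensor(local_score, device=net.accelerator.device)
        score = net.accelerator.reduce(score, reduction="sum") / net.accelerator.num_processes
        net.history.record(self.monitor, score.item())
        super().on_epoch_end(net, **kwargs)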
Many many thanks in advance for your time and reply.