EarlyStopping callback issue with Skorch + HuggingFace Accelerate #952

Closed
Raphaaal opened this issue Apr 6, 2023 · 10 comments

@Raphaaal

Raphaaal commented Apr 6, 2023

Hi,

I am trying to use the EarlyStopping callback with Skorch + HuggingFace's accelerate on two GPUs.

accelerate config:

compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: MULTI_GPU
downcast_bf16: 'no'
dynamo_config: {}
fsdp_config: {}
gpu_ids: 0,1
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

IIUC accelerate, the model is copied to both GPUs and each GPU receives a different slice of the batch. Is this assumption correct?

This can lead to inconsistencies where the early stopping callback is triggered for only one of the two model copies. In that case, training stalls, presumably because the remaining process keeps waiting for the one that was interrupted by early stopping.

It is a bit tricky to reproduce, but here is my attempt:

import os
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"

import torch
import torch.nn as nn
import numpy as np
import random
from skorch import NeuralNetClassifier
from sklearn.datasets import make_classification
from accelerate import Accelerator
from skorch.hf import AccelerateMixin
from skorch.callbacks import EarlyStopping
from skorch.dataset import ValidSplit


# Reproducibility
SEED = 42
def seed_everything(seed=42):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense0 = nn.Linear(100, 50)
        self.dense1 = nn.Linear(50, 25)
        self.dense2 = nn.Linear(25, 2)
        self.nonlin = nn.Softmax(dim=-1)

    def forward(self, X):
        X = self.dense0(X)
        X = self.dense1(X)
        X = self.dense2(X)
        X = self.nonlin(X)
        return X


class AcceleratedNeuralNetClassifier(AccelerateMixin, NeuralNetClassifier):
    # cf. issue 944
    def evaluation_step(self, batch, training=False):
        output = super().evaluation_step(batch, training=training)
        return self.accelerator.gather_for_metrics(output)


class AcceleratedEarlyStopping(EarlyStopping):
    def on_train_end(self, net, **kwargs):
        if (
            self.load_best and (self.best_epoch_ != net.history[-1, "epoch"])
            and (self.best_model_weights_ is not None)
        ):
            best_model_weights_accelerate = {
                # Need to load the module keys
                # e.g., `dense0.weight` instead of `module.dense0.weight`
                k.split("module.")[1]: v 
                for k, v 
                in self.best_model_weights_.items()
            }
            net.module_.load_state_dict(best_model_weights_accelerate)
            self._sink(
                f"Process {net.accelerator.process_index}: " # <= added
                f"Restoring best model from epoch {self.best_epoch_}.", 
                verbose=net.verbose
            )


class SkorchAccelerator(Accelerator):
    # cf. issue 944
    def __deepcopy__(self, memo):
        return self


seed_everything()

batch_size = 200

# Data generation
X, y = make_classification(
    100_000, 100, 
    n_informative=10, random_state=SEED, flip_y=0.1
)
X = X.astype(np.float32)
y = y.astype(np.int64)

# For each batch:
# the first GPU will receive the examples [0:100]
# the second GPU will receive the examples [100:200]
accelerator = SkorchAccelerator(split_batches=True)

# Model
model_skorch = AcceleratedNeuralNetClassifier(
    accelerator=accelerator, 
    module=MyModule, 
    max_epochs=10, 
    verbose=True, 
    iterator_train__shuffle=False,
    iterator_valid__shuffle=False,
    batch_size=batch_size, 
    train_split=ValidSplit(0.2),
    callbacks=[
        AcceleratedEarlyStopping(
            monitor='valid_loss', 
            patience=3, 
            load_best=True, 
            threshold=0.001,
            lower_is_better=True,
        )
    ],
)

model_skorch.fit(X, y)

preds = model_skorch.predict_proba(X[:5])
# Print from every process so the restored weights can be compared.
print(f"Process {accelerator.process_index}: {preds}")

Launched using accelerate launch my_script.py.

Output:

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.5071       0.8088        0.4415  0.9237
      2        0.4317       0.8151        0.4305  0.8976
      3        0.4269       0.8184        0.4288  0.9003
      4        0.4260       0.8189        0.4285  0.8913
      5        0.4258       0.8188        0.4284  0.8990
Stopping since valid_loss has not improved in the last 3 epochs.
Process 0: Restoring best model from epoch 3.
Process 0: [[0.92847335 0.07152668]
 [0.9461131  0.05388684]
 [0.68290985 0.3170901 ]
 [0.7143555  0.2856444 ]
 [0.9872084  0.01279159]]

The execution freezes and only the torch.distributed timeout ends it with an error. I suspect that the freeze comes from the KeyboardInterrupt raised in the process that was early-stopped on one GPU, which leaves the other process hanging, waiting until the timeout for a peer that has already been interrupted.

Note that when the threshold is less sensitive, the model copies on both GPUs are early-stopped at the same epoch and no problem arises:

AcceleratedEarlyStopping(
    monitor='valid_loss', 
    patience=3, 
    load_best=True, 
    threshold=0.1, # <= changed
    lower_is_better=True,
)

Output:

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.5071       0.8088        0.4415  1.0222
      2        0.4317       0.8151        0.4305  1.0172
      3        0.4269       0.8184        0.4288  1.0142
Stopping since valid_loss has not improved in the last 3 epochs.
Process 1: Restoring best model from epoch 1.
Stopping since valid_loss has not improved in the last 3 epochs.
Process 0: Restoring best model from epoch 1.
Process 1: [[0.86593235 0.13406765]
 [0.93646425 0.06353572]
 [0.45666963 0.5433304 ]
 [0.6188156  0.3811844 ]
 [0.98520964 0.01479038]]
Process 0: [[0.86593235 0.13406765]
 [0.93646425 0.06353572]
 [0.45666963 0.5433304 ]
 [0.6188156  0.3811844 ]
 [0.98520964 0.01479038]]

Maybe a solution could be to gather all data across GPUs before updating the callback internals (e.g., self.misses_). That way, the callbacks of the model copies would not diverge and would raise KeyboardInterrupt at the same epoch. Without this, I fear that the same kind of problem would also arise with other callbacks such as Checkpoint.
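
For illustration, here is a rough, untested sketch of what I mean, assuming the monitored metric is a plain history key like valid_loss and that the net exposes the accelerator as in the snippet above:

import torch
from skorch.callbacks import EarlyStopping


class SyncedEarlyStopping(EarlyStopping):
    # Hypothetical sketch: average the monitored score over all processes
    # before the stock EarlyStopping logic compares it to the best score,
    # so that every model copy reaches the same stop/continue decision.
    def on_epoch_end(self, net, **kwargs):
        local_score = net.history[-1, self.monitor]
        score = torch.tensor(local_score, device=net.accelerator.device)
        # Average the value across all processes.
        synced_score = net.accelerator.reduce(score, reduction="mean")
        # Overwrite the history entry so that EarlyStopping (and the logs)
        # see the same, synchronized value on every process.
        net.history.record(self.monitor, synced_score.item())
        return super().on_epoch_end(net, **kwargs)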

Many many thanks in advance for your time and reply.

@BenjaminBossan
Collaborator

Thanks again for this very detailed bug report. I'll probably only have time to look at this sometime next week, so sorry it'll take a bit longer.

The whole checkpointing code was written before the accelerate support was added, so it's quite possible that there are some issues. Just a few thoughts I had from reading:

  • I wonder if gathering is really the issue here. If you monitor a different score that is calculated using EpochScoring (which, with the bugfix, should gather correctly), do you observe the same behavior?
  • Maybe the KeyboardInterrupt itself is the issue here, which may not be communicated correctly across processes. A solution may be to instead set some kind of flag on the net, e.g. something like this:
# in EarlyStopping callback
    def on_epoch_end(self, net, **kwargs):
        ...
        if self.misses_ == self.patience:
            ...
            net.should_interrupt = True
            # raise KeyboardInterrupt  # <= don't raise exception

# in accelerated net
    def run_single_epoch(self, iterator, training, prefix, step_fn, **fit_params):
        ...
        for batch in iterator:
            ...
            batch_count += 1
            if getattr(self, 'should_interrupt', False):
                if self.accelerator.is_main_process:  # maybe not needed
                    self.accelerator.wait_for_everyone()  # maybe not needed
                    break
            ...

I hope I could convey my intent here. Again, I couldn't test it yet, so I might be totally off.
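
For the cross-process communication part specifically, one possible (also untested) way to keep the flag consistent would be to sum it over all processes, so that if any model copy wants to stop, all of them break out of the loop together:

import torch

# Hypothetical helper, not an existing skorch or accelerate API: returns
# True on every process if at least one process has set the flag.
def any_process_wants_to_stop(net):
    flag = torch.tensor(
        float(getattr(net, 'should_interrupt', False)),
        device=net.accelerator.device,
    )
    # Sum over processes: > 0 means at least one process raised the flag.
    flag = net.accelerator.reduce(flag, reduction="sum")
    return flag.item() > 0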

@Raphaaal
Author

Raphaaal commented Apr 6, 2023

Thanks for the quick reply.

I wonder if gathering is really the issue here. If you monitor a different score that is calculated using EpochScoring (which, with the bugfix, should gather correctly), do you observe the same behavior?

Will try, thanks for the suggestion. Does this require that EpochScoring is initialized with use_caching=False?

Maybe the KeyboardInterrupt itself is the issue here, which may not be communicated correctly across processes. A solution may be to instead set some kind of flag on the net, e.g. something like this:

Will also try, thanks. Does this mean that early stopping would be triggered even if only one of the processes meets its misses condition? In other words, if the results were gathered, early stopping might not have been triggered at all.

Thanks a lot for your recommendations, will keep you posted here.

@Raphaaal
Author

Raphaaal commented Apr 9, 2023

An update:

I wonder if gathering is really the issue here. If you monitor a different score that is calculated using EpochScoring (which, with the bugfix, should gather correctly), do you observe the same behavior?

Good news: when using another score with use_caching=False, the scores are gathered correctly and the EarlyStopping callbacks do not diverge (i.e., they stop at the same epoch).

Code changes:

from skorch.callbacks import EarlyStopping, EpochScoring
...
class AcceleratedEpochScoring(EpochScoring):
    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        super().on_epoch_end(net, dataset_train, dataset_valid, **kwargs)
        print(
            f"Epoch {len(net.history)} - Process {accelerator.process_index} "
            f"valid AP: {net.history[-1]['valid_ap']}"
        )
...
callbacks=[
        AcceleratedEpochScoring(
            scoring='average_precision',
            name="valid_ap",
            lower_is_better=False,
            use_caching=False,
        ),
        AcceleratedEarlyStopping(
            monitor='valid_ap', 
            patience=3, 
            load_best=True, 
            threshold=0.0001,
            lower_is_better=False,
        ),
    ]
...

Output:

Epoch 1 - Process 0 valid AP: 0.8667104491884005
Epoch 1 - Process 1 valid AP: 0.8667104491884005
  epoch    train_loss    valid_acc    valid_ap    valid_loss     dur
-------  ------------  -----------  ----------  ------------  ------
      1        0.5071       0.8088      0.8667        0.4415  0.9132
Epoch 2 - Process 0 valid AP: 0.87453497832751
Epoch 2 - Process 1 valid AP: 0.87453497832751
      2        0.4317       0.8151      0.8745        0.4305  0.9020
Epoch 3 - Process 0 valid AP: 0.8762546816799396
Epoch 3 - Process 1 valid AP: 0.8762546816799396
      3        0.4269       0.8184      0.8763        0.4288  0.9039
Epoch 4 - Process 0 valid AP: 0.8766565617981912
Epoch 4 - Process 1 valid AP: 0.8766565617981912
      4        0.4260       0.8189      0.8767        0.4285  0.9013
Epoch 5 - Process 1 valid AP: 0.8767607151401329
Epoch 5 - Process 0 valid AP: 0.8767607151401329
      5        0.4258       0.8188      0.8768        0.4284  0.9012
Epoch 6 - Process 0 valid AP: 0.8767967911974108
Epoch 6 - Process 1 valid AP: 0.8767967911974108
      6        0.4257       0.8186      0.8768        0.4284  0.9026
Epoch 7 - Process 0 valid AP: 0.8768130228274358
Epoch 7 - Process 1 valid AP: 0.8768130228274358
      7        0.4257       0.8186      0.8768        0.4284  0.8985
Epoch 8 - Process 1 valid AP: 0.8768181790654082
Stopping since valid_ap has not improved in the last 3 epochs.
Epoch 8 - Process 0 valid AP: 0.8768181790654082
Stopping since valid_ap has not improved in the last 3 epochs.
Process 1: Restoring best model from epoch 5.
Process 0: Restoring best model from epoch 5.
Process 1: [[0.9338996  0.06610045]
 [0.9456476  0.05435238]
 [0.7032327  0.2967673 ]
 [0.7254356  0.27456442]
 [0.9866545  0.01334549]]
Process 0: [[0.9338996  0.06610045]
 [0.9456476  0.05435238]
 [0.7032327  0.2967673 ]
 [0.7254356  0.27456442]
 [0.9866545  0.01334549]]

Unfortunately, when setting use_caching=True, the scores (and thus the callbacks) diverge, leading to the same inconsistencies as before.

Example:

...
callbacks=[
        AcceleratedEpochScoring(
            scoring='average_precision',
            name="valid_ap",
            lower_is_better=False,
            use_caching=True,
        ),
        AcceleratedEarlyStopping(
            monitor='valid_ap', 
            patience=3, 
            load_best=True, 
            threshold=0.0001,
            lower_is_better=False,
        ),
    ]
...

Output:

Epoch 1 - Process 0 valid AP: 0.8658875089719353
Epoch 1 - Process 1 valid AP: 0.8676702178819956
  epoch    train_loss    valid_acc    valid_ap    valid_loss     dur
-------  ------------  -----------  ----------  ------------  ------
      1        0.5071       0.8088      0.8659        0.4415  0.9064
Epoch 2 - Process 0 valid AP: 0.8739121987265073
Epoch 2 - Process 1 valid AP: 0.8752980170341061
      2        0.4317       0.8151      0.8739        0.4305  0.8672
Epoch 3 - Process 0 valid AP: 0.8756911186984955
Epoch 3 - Process 1 valid AP: 0.8770161878651679
      3        0.4269       0.8184      0.8757        0.4288  0.8659
Epoch 4 - Process 1 valid AP: 0.8774304030718305
Epoch 4 - Process 0 valid AP: 0.876075109821432
      4        0.4260       0.8189      0.8761        0.4285  0.8579
Epoch 5 - Process 1 valid AP: 0.8775437625178769
Epoch 5 - Process 0 valid AP: 0.8761556960152714
      5        0.4258       0.8188      0.8762        0.4284  0.8669
Epoch 6 - Process 0 valid AP: 0.8761769421377053
Epoch 6 - Process 1 valid AP: 0.8775944576563434
      6        0.4257       0.8186      0.8762        0.4284  0.8555
Epoch 7 - Process 1 valid AP: 0.8776132354459487
Epoch 7 - Process 0 valid AP: 0.8761924790408311
      7        0.4257       0.8186      0.8762        0.4284  0.8606
Epoch 8 - Process 0 valid AP: 0.8761933441749683
Epoch 8 - Process 1 valid AP: 0.8776246081399928
Stopping since valid_ap has not improved in the last 3 epochs.
      8        0.4257       0.8186      0.8762        0.4284  0.8500
Process 1: Restoring best model from epoch 5.
Process 1: [[ 7.2763360e-05 -6.3443067e-04]
 [ 1.0134692e-03  6.3253456e-04]
 [-4.8138766e-04 -7.2052615e-04]
 [ 3.3249665e-04  1.3802793e-03]
 [-1.5416289e-04 -1.9535948e-04]]

The execution freezes at this point like before.

Maybe the KeyboardInterrupt itself is the issue here, which may not be communicated correctly across processes. A solution may be to instead set some kind of flag on the net, e.g. something like this:

I have not tried this yet because your first recommendation works for my use case so far.

Thanks again for these suggestions.


EDIT: I don't understand why, but seed_everything() seems to be required for the scores not to diverge across workers.

Code change:

...
# seed_everything()
...

Output:

Epoch 1 - Process 0 valid AP: 0.6529586114318979
Epoch 1 - Process 1 valid AP: 0.6581087486075916
  epoch    train_loss    valid_acc    valid_ap    valid_loss     dur
-------  ------------  -----------  ----------  ------------  ------
      1        0.5155       0.8046      0.6530        0.4494  0.9343
Epoch 2 - Process 1 valid AP: 0.6617621615160396
Epoch 2 - Process 0 valid AP: 0.6560850631151591
      2        0.4359       0.8175      0.6561        0.4371  0.9169
Epoch 3 - Process 0 valid AP: 0.657530893463605
Epoch 3 - Process 1 valid AP: 0.663897991348084
      3        0.4326       0.8173      0.6575        0.4352  0.9161
Epoch 4 - Process 1 valid AP: 0.6635836630290339
Epoch 4 - Process 0 valid AP: 0.6571444995037148
      4        0.4293       0.8209      0.6571        0.4349  0.9250
Epoch 5 - Process 0 valid AP: 0.6571361297506764
Epoch 5 - Process 1 valid AP: 0.6639303478077014
      5        0.4260       0.8196      0.6571        0.4351  0.9278
Epoch 6 - Process 0 valid AP: 0.6574242903366264
Stopping since valid_ap has not improved in the last 3 epochs.
Epoch 6 - Process 1 valid AP: 0.6639567529073835
Stopping since valid_ap has not improved in the last 3 epochs.
Process 1: Restoring best model from epoch 3.
Process 0: Restoring best model from epoch 3.
Process 1: [[0.9379423  0.06205774]
 [0.9535447  0.04645533]
 [0.73085994 0.26914003]
 [0.7446962  0.25530383]
 [0.98844004 0.01156   ]]
Process 0: [[0.9379423  0.06205774]
 [0.9535447  0.04645533]
 [0.73085994 0.26914003]
 [0.7446962  0.25530383]
 [0.98844004 0.01156   ]]

@BenjaminBossan
Collaborator

Thanks so much for running those experiments.

First of all, I believe that defining how to "correctly" early stop with data parallel training is tricky. Ideally, we would like the outcome to be independent of the training procedure (data parallel vs. not). Let's say we have two batches, b0 and b1. In sequential training, we do forward(b0), then backward() + update the model, then forward(b1), then backward() + update. So when we call forward(b1), the model is already updated. Let's assume the score has improved here.

In parallel training, when we call forward(b1), the model is not updated yet, therefore the results will be different, presumably worse. So it could very well happen that the score has not improved because of that and we early stop. These kinds of differences are expected and unavoidable.

I tried to do some research on how to "correctly" early stop when using PyTorch DDP and found surprisingly little. My takeaway is that it's not often used and should maybe be avoided if possible (possibly using Checkpoint instead to retrieve the best model, though training will then continue even if there is no improvement).
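
To illustrate that alternative, here is a minimal, untested sketch reusing the classes from your snippet (with multiple processes, the file writes should probably also be restricted to the main process, which I'm not handling here):

from skorch.callbacks import Checkpoint

# Snapshot the parameters whenever valid_loss improves; no process ever
# raises KeyboardInterrupt, so the processes cannot get out of sync.
cp = Checkpoint(monitor='valid_loss_best', f_params='best_params.pt')

net = AcceleratedNeuralNetClassifier(
    accelerator=accelerator,
    module=MyModule,
    max_epochs=10,
    train_split=ValidSplit(0.2),
    callbacks=[cp],
)
net.fit(X, y)

# After training, restore the best snapshot on every process.
net.load_params(checkpoint=cp)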

Regarding the problem with caching: Yes, I'm not surprised that caching causes issues here. I'm not sure much can be done, except for documenting the possible issues and recommending that caching be avoided with parallel training.

Regarding the necessity for seeding everything: Hmm, interesting, not quite sure what's going on here.

I'll do some of my own experiments; if I find something interesting, I'll post an update.

@BenjaminBossan
Collaborator

BenjaminBossan commented Apr 12, 2023

After some more reading and experimenting, I think that the whole situation of using DDP with skorch is a bit more complicated than I expected. Since we have multiple processes, there is actually a copy of net.history for each process. These histories will be similar but not identical, since the processes see different batches. As a result, it is unclear what the correct way is to handle anything that deals with the history. Early stopping uses the history, hence it is unclear how early stopping should work correctly.

The most correct approach, if I understand what's going on correctly, would be to synchronize the histories. So when proc0 writes its batch metrics and proc1 writes different batch metrics, the two should be merged into a single history, and that history should then be distributed back to both processes. This is not easily achieved; I need to think about whether and how it can be implemented (it might require the use of a distributed key-value store). Until then, I'd be careful about relying on anything using net.history.
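
To make this a bit more concrete, here is a very rough, untested sketch of what such a synchronization could look like on top of torch.distributed's TCPStore (all names are made up, this is not an existing skorch API):

import json
from datetime import timedelta
from torch.distributed import TCPStore


def sync_epoch_metrics(store, rank, world_size, epoch, metrics):
    # Hypothetical helper: each rank publishes its metrics for the epoch
    # under a rank-specific key, then reads every rank's entry back and
    # returns the merged list, giving all processes the same view.
    store.set(f"epoch{epoch}/rank{rank}", json.dumps(metrics))
    merged = []
    for r in range(world_size):
        # get() waits until the key exists, so this also acts as a barrier.
        merged.append(json.loads(store.get(f"epoch{epoch}/rank{r}")))
    return merged


# Example wiring for two processes; `rank` would come from the launcher:
# store = TCPStore("127.0.0.1", 1234, world_size=2, is_master=(rank == 0),
#                  timeout=timedelta(seconds=30))
# synced = sync_epoch_metrics(store, rank, 2, epoch=1,
#                             metrics={"valid_loss": 0.43})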

@Raphaaal
Author

Thanks a lot for your detailed answers and sorry for the late reply.

In parallel training, when we call forward(b1), the model is not updated yet, therefore the results will be different, presumably worse.

I may be misunderstanding something here. To my mind, when using DDP, backprop is performed on each GPU (i.e., for each model copy, using its own batch) and then the gradients are averaged across all GPUs before the optimizer step is performed. So I am picturing an updated (i.e., synced) model copy when we call forward(b1) in your example. Please let me know if I am wrong.

The most correct approach, if I understand what's going on correctly, would be to synchronize the histories

This would be great indeed. Thanks for the recent changes you committed.

Until then, I'd be careful about relying on anything using net.history

Do you agree that, as long as I am using a validation metric from the history (i.e., a metric computed using Skorch's eval step), it is safe in DDP mode?

Again, thanks a lot for your time. Let me know if I can be of any help.

@BenjaminBossan
Collaborator

I may be misunderstanding something here. To my mind, when using DDP, backprop is performed on each GPU (i.e., for each model copy, using its own batch) and then the gradients are averaged across all GPUs before the optimizer step is performed. So I am picturing an updated (i.e., synced) model copy when we call forward(b1) in your example. Please let me know if I am wrong.

Your understanding is correct (unless my understanding is also incorrect); what I meant is that batch b1 is the "using its own batch" batch: b0 and b1 are processed independently, the gradients are synced, then b2 and b3 are processed independently, and so on.

Do you agree that, as long as I am using a validation metric from the history (i.e., a metric computed using Skorch's eval step), it is safe in DDP mode?

Yes, from my current understanding, when using gather_for_metrics, the epoch scores should be synced and thus safe to use.

Again, thanks a lot for your time. Let me know if I can be of any help.

Thanks for the offer. Maybe once #955 is merged, you could try it in your initial test to verify that valid_loss indeed works correctly with early stopping. I tried it successfully with a toy example, but having a real-world test would be better.

@Raphaaal
Author

Hi,

I re-ran the initial test with two GPUs and Accelerate, using the recent merge of #955.
The test passes.

Code:

import os
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"

import torch
import torch.nn as nn
import numpy as np
import random
from skorch import NeuralNetClassifier
from sklearn.datasets import make_classification
from accelerate import Accelerator
from skorch.hf import AccelerateMixin
from skorch.callbacks import EarlyStopping
from skorch.dataset import ValidSplit
from torch.distributed import TCPStore

# NB: this import is missing in the example from 
# https://skorch.readthedocs.io/en/latest/history.html#skorch.history.DistributedHistory
from skorch.history import DistributedHistory 


# Reproducibility
SEED = 42
def seed_everything(seed=42):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense0 = nn.Linear(100, 50)
        self.dense1 = nn.Linear(50, 25)
        self.dense2 = nn.Linear(25, 2)
        self.nonlin = nn.Softmax(dim=-1)

    def forward(self, X):
        X = self.dense0(X)
        X = self.dense1(X)
        X = self.dense2(X)
        X = self.nonlin(X)
        return X


class AcceleratedNeuralNetClassifier(AccelerateMixin, NeuralNetClassifier):
    # cf. issue 944
    def evaluation_step(self, batch, training=False):
        output = super().evaluation_step(batch, training=training)
        return self.accelerator.gather_for_metrics(output)


class AcceleratedEarlyStopping(EarlyStopping):
    def on_train_end(self, net, **kwargs):
        if (
            self.load_best and (self.best_epoch_ != net.history[-1, "epoch"])
            and (self.best_model_weights_ is not None)
        ):
            best_model_weights_accelerate = {
                # Need to load the module keys
                # e.g., `dense0.weight` instead of `module.dense0.weight`
                k.split("module.")[1]: v 
                for k, v 
                in self.best_model_weights_.items()
            }
            net.module_.load_state_dict(best_model_weights_accelerate)
            self._sink(
                f"Process {net.accelerator.process_index}: " # <= added
                f"Restoring best model from epoch {self.best_epoch_}.", 
                verbose=net.verbose
            )


class SkorchAccelerator(Accelerator):
    # cf. issue 944
    def __deepcopy__(self, memo):
        return self


seed_everything()

batch_size = 200

# Data generation
X, y = make_classification(
    100_000, 100, 
    n_informative=10, random_state=SEED, flip_y=0.1
)
X = X.astype(np.float32)
y = y.astype(np.int64)

# Accelerate
accelerator = SkorchAccelerator(split_batches=True)

# Distributed history config
is_master = accelerator.is_main_process
world_size = accelerator.num_processes
rank = accelerator.local_process_index
store = TCPStore(
    "127.0.0.1", 
    port=1234, 
    world_size=world_size, 
    is_master=is_master
    )
dist_history = DistributedHistory(
    store=store, 
    rank=rank, 
    world_size=world_size
)

# Model
model_skorch = AcceleratedNeuralNetClassifier(
    history=dist_history,
    accelerator=accelerator, 

    module=MyModule, 
    max_epochs=10, 
    verbose=True, 
    iterator_train__shuffle=False,
    iterator_valid__shuffle=False,
    batch_size=batch_size, 
    train_split=ValidSplit(0.2),
    callbacks=[
        AcceleratedEarlyStopping(
            monitor='valid_loss', 
            patience=3, 
            load_best=True, 
            threshold=0.001,
            lower_is_better=True,
        )
    ],
)

model_skorch.fit(X, y)

preds = model_skorch.predict_proba(X[:5])
print(f"Process {accelerator.process_index}: {preds}")

Output:

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.5084       0.8088        0.4413  1.3090
      2        0.4341       0.8151        0.4305  1.3063
      3        0.4296       0.8184        0.4287  1.3474
      4        0.4288       0.8189        0.4284  1.3253
      5        0.4287       0.8188        0.4283  1.2274
      6        0.4287       0.8186        0.4283  1.2227
      7        0.4287       0.8186        0.4283  1.2602
Stopping since valid_loss has not improved in the last 3 epochs.
Stopping since valid_loss has not improved in the last 3 epochs.
Process 1: Restoring best model from epoch 5.
Process 0: Restoring best model from epoch 5.
Process 0: [[0.9338996  0.06610045]
 [0.9456476  0.05435238]
 [0.7032327  0.2967673 ]
 [0.7254356  0.27456442]
 [0.9866545  0.01334549]]
Process 1: [[0.9338996  0.06610045]
 [0.9456476  0.05435238]
 [0.7032327  0.2967673 ]
 [0.7254356  0.27456442]
 [0.9866545  0.01334549]]

The execution completes normally.
Note that the values printed from DistributedHistory differ slightly from those of the previous standard History (e.g., valid_loss). I guess this is expected because the history values are now aggregated across processes (whereas before, only the main process's values were printed).

Thanks again

@BenjaminBossan
Collaborator

Note that the values printed from DistributedHistory differ slightly from those of the previous standard History (e.g., valid_loss). I guess this is expected because the history values are now aggregated across processes (whereas before, only the main process's values were printed).

Yes, exactly, that's why a small difference is expected. Thanks for testing.

@BenjaminBossan
Collaborator

Small update: With the new sklearn v1.3.0 release, the devs added a new protocol, __sklearn_clone__. When sklearn clones an object, it will first check if that method is present. This means that instead of overriding __deepcopy__, which works but may have some unintended side effects, it is better to override __sklearn_clone__ now:

class MyAccelerator(Accelerator):
    def __sklearn_clone__(self):
        return self

I tried it out and it works. So if sklearn 1.3 is an option, it is better to override __sklearn_clone__.
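
As a quick sanity check, here is a minimal sketch (assuming sklearn >= 1.3) showing what the protocol does; the clone call is what sklearn performs internally when copying estimator parameters such as the accelerator:

from sklearn.base import clone
from accelerate import Accelerator


class MyAccelerator(Accelerator):
    def __sklearn_clone__(self):
        # sklearn >= 1.3 calls this hook instead of deep-copying the object.
        return self


acc = MyAccelerator()
assert clone(acc, safe=False) is acc  # the same instance is reused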

I asked the accelerate devs if they would consider adding this to accelerate; let's see. I could imagine this is more acceptable to them than overriding __deepcopy__. If they do, all good; otherwise I'll just update the skorch docs.
