
configure_optimizers in LightningCLI doesn't work in distributed mode #16489

Open
@stevenmanton

Bug description

LightningCLI provides a way to override a model's configure_optimizers method, either by subclassing LightningCLI or through the configuration (for example, the --optimizer argument used below). This works on the CPU and on a single GPU, but it seems to fail with multiple GPUs: the optimizer parameters provided in the configuration don't appear to be used in distributed mode.
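Here is a minimal sketch of what a configuration-driven override can look like. It assumes the CLI takes a function that already carries the parsed optimizer arguments and binds it onto the model *instance* with `types.MethodType`. That is an assumption about the internals, not code taken from Lightning. `Model`, `cli_configure_optimizers`, and the `optimizer`/`lr` values are hypothetical:

```python
from functools import partial
from types import MethodType


class Model:
    """Hypothetical stand-in for a LightningModule that defines its own optimizer."""

    def configure_optimizers(self):
        return "optimizer defined by the model"


def cli_configure_optimizers(model, optimizer, lr):
    # Hypothetical generic implementation: builds the optimizer from
    # arguments parsed out of the command line or config file.
    return f"{optimizer}(lr={lr}) from the CLI config"


model = Model()
# Pre-fill the parsed arguments, then bind the result to this one instance.
# The class is untouched; only model.__dict__ gains a "configure_optimizers" entry.
fn = partial(cli_configure_optimizers, optimizer="Adam", lr=0.01)
model.configure_optimizers = MethodType(fn, model)

print(model.configure_optimizers())    # Adam(lr=0.01) from the CLI config
print(Model().configure_optimizers())  # optimizer defined by the model
```

In this sketch the override lives on one particular object, not on the class. Any other instance, including a copy rebuilt without that instance attribute, falls back to the class's method.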

How to reproduce the bug

Because logging the optimizer is tricky in distributed mode, an easy way to test this is to remove the configure_optimizers method from the model. Since LightningCLI.configure_optimizers overrides this method, the model doesn't need to define it. In distributed mode, however, the override isn't applied, so you'll get an error.
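Deleting the method works in the single-process case because the trainer only needs configure_optimizers to be overridden somewhere. The sketch below approximates such a check. It assumes the check compares the code reached through the model instance with the base class's default, roughly like the `is_overridden` helper in `pytorch_lightning.utilities.model_helpers`; treat the details as assumptions. `BaseModule`, `Model`, `cli_configure_optimizers`, and `is_overridden_sketch` are hypothetical names:

```python
from functools import partial
from types import MethodType


class BaseModule:
    """Hypothetical base class that ships a placeholder configure_optimizers."""

    def configure_optimizers(self):
        return None


class Model(BaseModule):
    """Like BoringModel after `del BoringModel.configure_optimizers`."""


def cli_configure_optimizers(model, optimizer):
    # Hypothetical CLI-provided implementation.
    return f"{optimizer} from the CLI config"


def is_overridden_sketch(name, instance, parent):
    """Return True if `instance` resolves `name` to code other than `parent`'s default."""
    func = getattr(instance, name).__func__  # unwrap the bound method
    if isinstance(func, partial):
        func = func.func  # unwrap the partial that carries the CLI arguments
    return func.__code__ is not getattr(parent, name).__code__


model = Model()
print(is_overridden_sketch("configure_optimizers", model, BaseModule))  # False

model.configure_optimizers = MethodType(partial(cli_configure_optimizers, optimizer="Adam"), model)
print(is_overridden_sketch("configure_optimizers", model, BaseModule))  # True
```

Under this model, any copy of the model that loses the instance-level attribute only exposes the base default. The check then fails, just as it would if no override had been supplied.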

For example, with the following script:

# dummy.py
from pytorch_lightning.demos.boring_classes import BoringModel, BoringDataModule
from pytorch_lightning.cli import LightningCLI

# Remove the model's own method so that only LightningCLI's override can supply it
del BoringModel.configure_optimizers

if __name__ == "__main__":

    cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={
        "max_steps": 1,
        "accelerator": "gpu",
    })

You can run:

python src/product_dna/_antonstv/train/dummy.py fit --optimizer=Adam --trainer.devices=1

and the script will complete. But if you run

python src/product_dna/_antonstv/train/dummy.py fit --optimizer=Adam --trainer.devices=2

then you'll get an exception.

Error messages and logs

Traceback (most recent call last):
  File "src/product_dna/_antonstv/train/dummy.py", line 8, in <module>
    cli = LightningCLI(
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/cli.py", line 358, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/cli.py", line 670, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 113, in launch
    mp.start_processes(
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 139, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1029, in _run
    verify_loop_configurations(self)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py", line 41, in verify_loop_configurations
    __verify_train_val_loop_configuration(trainer, model)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py", line 80, in __verify_train_val_loop_configuration
    raise MisconfigurationException(
lightning_fabric.utilities.exceptions.MisconfigurationException: No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.
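One possible explanation, offered as a hypothesis rather than a confirmed diagnosis: the traceback shows the model being handed to `torch.multiprocessing.spawn`, which pickles its arguments for the worker processes. CPython pickles a bound method as a call to `getattr(obj, method_name)`. During unpickling, that lookup can run before the object's `__dict__` has been restored, so it resolves to the method defined on the class. The stdlib-only sketch below reproduces this. It assumes the override is stored on the model instance as a bound method whose function is named `configure_optimizers` (here via `functools.update_wrapper`). `Model` and `CLI` are hypothetical stand-ins:

```python
import pickle
from functools import partial, update_wrapper
from types import MethodType


class Model:
    """Hypothetical stand-in for a LightningModule."""

    def configure_optimizers(self):
        return "optimizer defined by the model"


class CLI:
    """Hypothetical stand-in for the CLI's generic implementation."""

    @staticmethod
    def configure_optimizers(model, optimizer):
        return f"{optimizer} from the CLI config"


model = Model()
fn = partial(CLI.configure_optimizers, optimizer="Adam")
# Copy __name__ ("configure_optimizers") and related attributes onto the partial.
update_wrapper(fn, CLI.configure_optimizers)
model.configure_optimizers = MethodType(fn, model)

# Roughly what happens when the model is sent to a spawned worker process:
# the bound method is pickled as getattr(model, "configure_optimizers"), and on
# load that lookup runs before model.__dict__ is restored, so it finds Model's method.
clone = pickle.loads(pickle.dumps(model))

print(model.configure_optimizers())  # Adam from the CLI config
print(clone.configure_optimizers())  # optimizer defined by the model
```

If this is what happens, the unpickled model in the script above would fall back to LightningModule's default configure_optimizers, which matches the MisconfigurationException. A model that keeps its own method would silently train with that method instead of the configured optimizer. Under the same assumption, a strategy that re-runs the script in every process instead of pickling the model (for example --trainer.strategy=ddp) should let the CLI reapply the override. This has not been verified here.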

Environment

  • CUDA:
    • GPU:
      • Tesla V100-SXM2-16GB
      • Tesla V100-SXM2-16GB
      • Tesla V100-SXM2-16GB
      • Tesla V100-SXM2-16GB
    • available: True
    • version: 11.7
  • Lightning:
    • lightning-utilities: 0.6.0
    • pytorch-lightning: 1.9.0
    • pytorch-ranger: 0.1.1
    • torch: 1.13.0
    • torch-optimizer: 0.3.0
    • torchaudio: 0.13.0
    • torchmetrics: 0.10.2
    • torchtext: 0.14.0
    • torchvision: 0.14.0
  • Packages:
    • absl-py: 1.3.0
    • aiohttp: 3.8.3
    • aiosignal: 1.3.1
    • alabaster: 0.7.12
    • antlr4-python3-runtime: 4.9.3
    • anyio: 3.6.2
    • argon2-cffi: 21.3.0
    • argon2-cffi-bindings: 21.2.0
    • arrow: 1.2.3
    • asttokens: 2.1.0
    • astunparse: 1.6.3
    • async-timeout: 4.0.2
    • attrs: 22.1.0
    • aws-requests-auth: 0.4.3
    • babel: 2.11.0
    • backcall: 0.2.0
    • beautifulsoup4: 4.11.1
    • bleach: 5.0.1
    • boto3: 1.26.7
    • botocore: 1.29.7
    • bravado: 11.0.3
    • bravado-core: 5.17.1
    • cachetools: 5.2.0
    • certifi: 2022.9.24
    • cffi: 1.15.1
    • charset-normalizer: 2.1.1
    • click: 8.1.3
    • codetiming: 1.4.0
    • commonmark: 0.9.1
    • contextlib2: 21.6.0
    • contourpy: 1.0.6
    • coverage: 6.5.0
    • cycler: 0.11.0
    • debugpy: 1.6.3
    • decorator: 5.1.1
    • defusedxml: 0.7.1
    • dill: 0.3.6
    • docker: 6.0.1
    • docker-pycreds: 0.4.0
    • docstring-parser: 0.15
    • docutils: 0.17.1
    • entrypoints: 0.4
    • exceptiongroup: 1.0.1
    • execnet: 1.9.0
    • executing: 1.2.0
    • fastjsonschema: 2.16.2
    • fasttext-wheel: 0.9.2
    • filelock: 3.8.0
    • fire: 0.4.0
    • flatbuffers: 22.10.26
    • fonttools: 4.38.0
    • fqdn: 1.5.1
    • frozenlist: 1.3.3
    • fsspec: 2022.11.0
    • future: 0.18.3
    • gast: 0.4.0
    • gitdb: 4.0.9
    • gitpython: 3.1.29
    • google-auth: 2.14.1
    • google-auth-oauthlib: 0.4.6
    • google-pasta: 0.2.0
    • grpcio: 1.50.0
    • h5py: 3.7.0
    • huggingface-hub: 0.10.1
    • hydra-core: 1.2.0
    • idna: 3.4
    • imagesize: 1.4.1
    • importlib-metadata: 4.13.0
    • importlib-resources: 5.10.0
    • iniconfig: 1.1.1
    • ipykernel: 6.17.1
    • ipython: 8.6.0
    • ipython-genutils: 0.2.0
    • ipywidgets: 8.0.2
    • isoduration: 20.11.0
    • jedi: 0.18.1
    • jinja2: 3.1.2
    • jmespath: 1.0.1
    • joblib: 1.2.0
    • jsonargparse: 4.19.0
    • jsonpointer: 2.3
    • jsonref: 1.1.0
    • jsonschema: 4.17.0
    • jupyter: 1.0.0
    • jupyter-client: 7.4.4
    • jupyter-console: 6.4.4
    • jupyter-core: 5.0.0
    • jupyter-server: 1.23.1
    • jupyterlab-pygments: 0.2.2
    • jupyterlab-widgets: 3.0.3
    • keras: 2.10.0
    • keras-preprocessing: 1.1.2
    • kiwisolver: 1.4.4
    • libclang: 14.0.6
    • lightning-utilities: 0.6.0
    • markdown: 3.4.1
    • markupsafe: 2.1.1
    • matplotlib: 3.6.2
    • matplotlib-inline: 0.1.6
    • mistune: 2.0.4
    • monotonic: 1.6
    • more-itertools: 9.0.0
    • msgpack: 1.0.4
    • multidict: 6.0.2
    • multiprocess: 0.70.14
    • nbclassic: 0.4.8
    • nbclient: 0.7.0
    • nbconvert: 7.2.4
    • nbformat: 5.7.0
    • neptune-client: 0.16.16
    • nest-asyncio: 1.5.6
    • nltk: 3.7
    • notebook: 6.5.2
    • notebook-shim: 0.2.2
    • numpy: 1.23.4
    • nvidia-cublas-cu11: 11.10.3.66
    • nvidia-cuda-nvrtc-cu11: 11.7.99
    • nvidia-cuda-runtime-cu11: 11.7.99
    • nvidia-cudnn-cu11: 8.5.0.96
    • oauthlib: 3.2.2
    • omegaconf: 2.2.3
    • opt-einsum: 3.3.0
    • packaging: 21.3
    • pandas: 1.5.1
    • pandocfilters: 1.5.0
    • parso: 0.8.3
    • pathos: 0.3.0
    • pathtools: 0.1.2
    • pexpect: 4.8.0
    • pickleshare: 0.7.5
    • pillow: 9.3.0
    • pip: 22.2.2
    • pkgutil-resolve-name: 1.3.10
    • platformdirs: 2.5.3
    • pluggy: 1.0.0
    • pox: 0.3.2
    • ppft: 1.7.6.6
    • productdnascience: 0.1.268.dev4+gc10be3c.d20221110
    • prometheus-client: 0.15.0
    • promise: 2.3
    • prompt-toolkit: 3.0.32
    • protobuf: 3.19.6
    • protobuf3-to-dict: 0.1.5
    • psutil: 5.9.4
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • py-cpuinfo: 9.0.0
    • pyarrow: 10.0.0
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.8
    • pybind11: 2.10.1
    • pycparser: 2.21
    • pygments: 2.13.0
    • pyjwt: 2.6.0
    • pyparsing: 3.0.9
    • pyrsistent: 0.19.2
    • pytest: 7.2.0
    • pytest-benchmark: 4.0.0
    • pytest-cov: 4.0.0
    • pytest-mock: 3.10.0
    • pytest-xdist: 3.0.2
    • python-dateutil: 2.8.2
    • pytorch-lightning: 1.9.0
    • pytorch-ranger: 0.1.1
    • pytz: 2022.6
    • pyyaml: 6.0
    • pyzmq: 24.0.1
    • qtconsole: 5.4.0
    • qtpy: 2.3.0
    • regex: 2022.10.31
    • requests: 2.28.1
    • requests-oauthlib: 1.3.1
    • rfc3339-validator: 0.1.4
    • rfc3987: 1.3.8
    • rich: 12.6.0
    • rsa: 4.9
    • s3fs: 0.4.2
    • s3transfer: 0.6.0
    • sagemaker: 2.116.0
    • schema: 0.7.5
    • scikit-learn: 1.1.3
    • scipy: 1.9.3
    • send2trash: 1.8.0
    • sentry-sdk: 1.10.1
    • setproctitle: 1.3.2
    • setuptools: 65.5.0
    • shortuuid: 1.0.10
    • simplejson: 3.18.1
    • six: 1.16.0
    • smdebug-rulesconfig: 1.0.1
    • smmap: 5.0.0
    • sniffio: 1.3.0
    • snowballstemmer: 2.2.0
    • soupsieve: 2.3.2.post1
    • sphinx: 5.3.0
    • sphinx-rtd-theme: 1.1.1
    • sphinxcontrib-applehelp: 1.0.2
    • sphinxcontrib-devhelp: 1.0.2
    • sphinxcontrib-htmlhelp: 2.0.0
    • sphinxcontrib-jsmath: 1.0.1
    • sphinxcontrib-qthelp: 1.0.3
    • sphinxcontrib-serializinghtml: 1.1.5
    • stack-data: 0.6.0
    • swagger-spec-validator: 3.0.3
    • tensorboard: 2.10.1
    • tensorboard-data-server: 0.6.1
    • tensorboard-plugin-wit: 1.8.1
    • tensorboardx: 2.5.1
    • tensorflow: 2.10.0
    • tensorflow-estimator: 2.10.0
    • tensorflow-hub: 0.12.0
    • tensorflow-io-gcs-filesystem: 0.27.0
    • tensorflow-text: 2.10.0
    • termcolor: 2.1.0
    • terminado: 0.17.0
    • threadpoolctl: 3.1.0
    • tinycss2: 1.2.1
    • tokenizers: 0.13.2
    • tomli: 2.0.1
    • torch: 1.13.0
    • torch-optimizer: 0.3.0
    • torchaudio: 0.13.0
    • torchmetrics: 0.10.2
    • torchtext: 0.14.0
    • torchvision: 0.14.0
    • tornado: 6.2
    • tqdm: 4.64.1
    • traitlets: 5.5.0
    • transformers: 4.25.0.dev0
    • typeshed-client: 2.2.0
    • typing-extensions: 4.4.0
    • unicode: 2.9
    • unidecode: 1.3.6
    • uri-template: 1.2.0
    • urllib3: 1.26.12
    • wandb: 0.13.5
    • wcwidth: 0.2.5
    • webcolors: 1.12
    • webencodings: 0.5.1
    • websocket-client: 1.4.2
    • werkzeug: 2.2.2
    • wheel: 0.37.1
    • widgetsnbextension: 4.0.3
    • wrapt: 1.14.1
    • yarl: 1.8.1
    • zipp: 3.10.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.13
    • version: #1 SMP Wed Nov 2 05:27:06 UTC 2022

More info

This error took me a really long time to track down. Even if there's not an easy fix, it would be great to throw a warning.
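The warning requested above could take the form of a pre-launch check. The sketch below is hypothetical: `warn_if_override_lost`, `Model`, and `CLI` are illustrative names, not Lightning APIs. It assumes the override is a bound method stored on the model instance and that the model reaches worker processes through pickling. It round-trips the model through `pickle` and warns if the override resolves back to the class's own method:

```python
import pickle
import warnings
from functools import partial, update_wrapper
from types import MethodType


def warn_if_override_lost(model, name="configure_optimizers"):
    """Hypothetical pre-launch check, not a Lightning API.

    Warns when an override attached to the model *instance* resolves back to
    the class's own method after a pickle round trip, which is what sending
    the model to spawned worker processes involves. Returns True if it warned.
    """
    if name not in vars(model):
        return False  # no instance-level override, so nothing can be lost
    clone = pickle.loads(pickle.dumps(model))
    restored = getattr(clone, name)
    # A lost override shows up as a bound method wrapping the class's function.
    if getattr(restored, "__func__", None) is getattr(type(model), name, None):
        warnings.warn(
            f"{type(model).__name__}.{name} is overridden on the instance only; "
            "after pickling (e.g. for spawned worker processes) the class's own "
            "method is used instead.",
            RuntimeWarning,
        )
        return True
    return False


# Hypothetical stand-ins for a model and the CLI's implementation.
class Model:
    def configure_optimizers(self):
        return "optimizer defined by the model"


class CLI:
    @staticmethod
    def configure_optimizers(model, optimizer):
        return f"{optimizer} from the CLI config"


model = Model()
fn = update_wrapper(partial(CLI.configure_optimizers, optimizer="Adam"), CLI.configure_optimizers)
model.configure_optimizers = MethodType(fn, model)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    issued = warn_if_override_lost(model)

print(issued, len(caught))  # True 1
```

This check compares object identity with the class attribute, so it only makes sense within a single process, before launching workers.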

cc @Borda @carmocca @mauvilsa @justusschock @awaelchli

Labels: feature, help wanted, lightningcli, strategy: ddp
