Issue with logs when using torch.compile  #18835

Open
@Forbu

Description

Bug description

I experience a bug with logging when using `torch.compile`.

Traceback (most recent call last):
  File "/app/data/src_stardust/prediction/training/train_gnn_realdata.py", line 586, in <module>
    trainer.fit(modele, train_loader, validation_loader, ckpt_path=last_checkpoint)
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 545, in fit
    call._call_and_handle_interrupt(
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 581, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 990, in _run
    results = self._run_stage()
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1034, in _run_stage
    self._run_sanity_check()
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1063, in _run_sanity_check
    val_loop.run()
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 181, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 134, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 391, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 403, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/app/data/src_stardust/prediction/training/train_gnn_realdata.py", line 296, in validation_step
    ) = st.preprocessing_batch_real(
  File "/app/data/src_stardust/prediction/training/train_gnn_realdata.py", line 344, in <resume in validation_step>
    self.log(
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 412, in log
    apply_to_collection(value, dict, self.__check_not_nested, name)
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 413, in <resume in log>
    apply_to_collection(
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 500, in <resume in log>
    results.log(
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 390, in log
    meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, _group=sync_dist_group, rank_zero_only=rank_zero_only)
  File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 398, in <resume in log>
    raise MisconfigurationException(
lightning.fabric.utilities.exceptions.MisconfigurationException: You called `self.log(val_aff_db, ...)` twice in `validation_step` with different arguments. This is not allowed

In my `validation_step` method, I only have one `self.log` statement:

self.log(
    name="validation_metric",
    value=myvalue,
    on_step=True,
    on_epoch=True,
    logger=True,
    batch_size=1,
    prog_bar=False,
    reduce_fx="mean",
    enable_graph=False,
    add_dataloader_idx=False,
    metric_attribute=None,
    rank_zero_only=True,
    sync_dist=False,
)

This issue only happens when using `torch.compile()` on the whole LightningModule.
I have the latest Lightning version, 2.1
(and torch 2.1).

What version are you seeing the problem on?

master

How to reproduce the bug

I will try to use the simple BoringModel from the tests to reproduce the bug easily (I will do that in a comment later).

Environment

Current environment
- Lightning Component (e.g. Trainer, LightningModule): LightningModule
- PyTorch Lightning Version: 2.1
- PyTorch Version: 2.1
- I installed lightning with pip

More info

This bug only happens when I compile my module:

my_lightning_module = torch.compile(my_lightning_module)

cc @carmocca
