
recover num_ranks from previous run to calculate epoch_base #317


Merged: 5 commits merged on Jun 24, 2025 (showing changes from 4 commits).
14 changes: 13 additions & 1 deletion src/weathergen/train/trainer.py
@@ -54,6 +54,10 @@ def init(

self.devices = self.init_torch()

# Get num_ranks of the previous, to-be-continued run before
# num_ranks gets overwritten by the current setting during init_ddp()
self.num_ranks_original = cf.num_ranks if "num_ranks" in cf.keys() else None
Collaborator: It is so cheap to access that we should not add it as extra state in the class.

Contributor Author: The problem is that the next line will overwrite the "num_ranks" of the original run and adapt it to the current system. That's why it needs to be captured here.

Collaborator: OK, see my comment below about style, but it looks good otherwise.
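For illustration of the ordering issue the author describes above, a minimal sketch; the init_ddp body is a placeholder (not the real implementation), and cf is assumed to be a dict-like config object, as the cf.keys() call in the diff suggests:

```python
# Sketch only: init_ddp is a stand-in for the real method; all that matters
# here is that it reassigns cf.num_ranks to the current job's rank count.
class Trainer:
    def init(self, cf):
        # Capture the previous run's rank count before init_ddp() overwrites it.
        self.num_ranks_original = cf.num_ranks if "num_ranks" in cf.keys() else None
        self.init_ddp(cf)  # after this call cf.num_ranks reflects the current job

    def init_ddp(self, cf):
        # Placeholder: the real method sets up distributed training and then
        # records the current job's world size on the config.
        cf.num_ranks = 8  # e.g. the rank count of the current job
```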


self.init_ddp(cf)

# read configuration of data streams
@@ -264,7 +268,15 @@ def run(self, cf, run_id_contd=None, epoch_contd=None):
self.loss_fcts_val = [[getattr(losses, name), w] for name, w in cf.loss_fcts_val]

# recover epoch when continuing run
epoch_base = int(self.cf.istep / len(self.data_loader))
if self.num_ranks_original is None:
Collaborator: I would calculate it here (and note the more pythonic way):

num_ranks_original = self.cf.get("num_ranks", None)
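As a self-contained illustration of the dict-style .get() fallback being suggested (a plain dict stands in for the real config object here, which is an assumption for the sake of the example):

```python
# Stand-in config dicts, only to show the .get() fallback behaviour.
cf_old_run = {"istep": 12_000}                     # no "num_ranks" recorded
cf_new_run = {"istep": 12_000, "num_ranks": 4}     # rank count recovered

print(cf_old_run.get("num_ranks", None))  # None -> fall back to the original formula
print(cf_new_run.get("num_ranks", None))  # 4    -> use the multi-rank formula
```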

    epoch_base = int(self.cf.istep / len(self.data_loader))
else:
    len_per_rank = (
        len(self.dataset) // (self.num_ranks_original * cf.batch_size)
    ) * cf.batch_size
    epoch_base = int(
        self.cf.istep / (min(len_per_rank, cf.samples_per_epoch) * self.num_ranks_original)
    )

# torch.autograd.set_detect_anomaly(True)
if cf.forecast_policy is not None:
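To make the arithmetic in the continued-run branch concrete, a small self-contained sketch with made-up numbers (dataset size, batch size, samples_per_epoch, rank count, and istep are illustrative assumptions, not values from this PR):

```python
# Illustrative numbers only; the formula mirrors the continued-run branch above.
dataset_len = 100_000        # len(self.dataset)
batch_size = 8               # cf.batch_size
samples_per_epoch = 4_096    # cf.samples_per_epoch
num_ranks_original = 4       # rank count recovered from the previous run
istep = 40_960               # self.cf.istep carried over from the previous run

# Samples each rank of the previous run could draw per epoch, rounded down
# to whole batches.
len_per_rank = (dataset_len // (num_ranks_original * batch_size)) * batch_size

# Effective samples consumed per epoch across all ranks of the previous run.
per_epoch_total = min(len_per_rank, samples_per_epoch) * num_ranks_original

epoch_base = int(istep / per_epoch_total)
print(len_per_rank, per_epoch_total, epoch_base)  # 25000 16384 2
```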