Skip to content

Commit cc7fed5

Browse files
authored
feat(runner): Forecast loop step generator (#168)
1 parent 8b22ea7 commit cc7fed5

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

src/anemoi/inference/runner.py

+38-11
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import logging
1313
import warnings
1414
from functools import cached_property
15+
from typing import Generator
1516

1617
import numpy as np
1718
import torch
@@ -257,6 +258,38 @@ def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
257258
# TODO: move this to a Stepper class.
258259
return model.predict_step(input_tensor_torch)
259260

261+
def forecast_stepper(self, start_date, lead_time) -> Generator:
    """Generate step and date variables for the forecast loop

    Parameters
    ----------
    start_date : datetime.datetime
        Start date of the forecast
    lead_time : datetime.timedelta
        Lead time of the forecast

    Yields
    ------
    step : datetime.timedelta
        Time delta since beginning of forecast
    valid_date : datetime.datetime
        Date of the forecast
    next_date : datetime.datetime
        Date used to prepare the next input tensor
    is_last_step : bool
        True if it's the last step of the forecast
    """
    # Number of model steps needed to cover the requested lead time.
    timestep = self.checkpoint.timestep
    total_steps = lead_time // timestep

    LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, timestep, total_steps)

    for index in range(total_steps):
        # Step counter is 1-based: the first yield is one timestep past start_date.
        step = (index + 1) * timestep
        valid_date = start_date + step
        # next_date currently coincides with valid_date; kept as a separate
        # yielded value so callers can prepare the next input tensor.
        yield step, valid_date, valid_date, index == total_steps - 1
292+
260293
def forecast(self, lead_time, input_tensor_numpy, input_state):
261294
self.model.eval()
262295

@@ -268,10 +301,6 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
268301
LOG.info("Using autocast %s", self.autocast)
269302

270303
lead_time = to_timedelta(lead_time)
271-
steps = lead_time // self.checkpoint.timestep
272-
273-
LOG.info("Using autocast %s", self.autocast)
274-
LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)
275304

276305
result = input_state.copy() # We should not modify the input state
277306
result["fields"] = dict()
@@ -295,9 +324,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
295324
if self.verbosity > 0:
296325
self._print_input_tensor("First input tensor", input_tensor_torch)
297326

298-
for s in range(steps):
299-
step = (s + 1) * self.checkpoint.timestep
300-
date = start + step
327+
for s, (step, date, next_date, is_last_step) in enumerate(self.forecast_stepper(start, lead_time)):
301328
title = f"Forecasting step {step} ({date})"
302329

303330
result["date"] = date
@@ -333,8 +360,8 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
333360
yield result
334361

335362
# No need to prepare next input tensor if we are at the last step
336-
if s == steps - 1:
337-
continue
363+
if is_last_step:
364+
break
338365

339366
# Update tensor for next iteration
340367
with ProfilingLabel("Update tensor for next step", self.use_profiler):
@@ -347,10 +374,10 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
347374
del y_pred # Recover memory
348375

349376
input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
350-
input_tensor_torch, input_state, date, check
377+
input_tensor_torch, input_state, next_date, check
351378
)
352379
input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
353-
input_tensor_torch, input_state, date, check
380+
input_tensor_torch, input_state, next_date, check
354381
)
355382

356383
if not check.all():

0 commit comments

Comments
 (0)