12
12
import logging
13
13
import warnings
14
14
from functools import cached_property
15
+ from typing import Generator
15
16
16
17
import numpy as np
17
18
import torch
@@ -257,6 +258,38 @@ def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
257
258
# TODO: move this to a Stepper class.
258
259
return model .predict_step (input_tensor_torch )
259
260
261
+ def forecast_stepper (self , start_date , lead_time ) -> Generator :
262
+ """Generate step and date variables for the forecast loop
263
+
264
+ Parameters
265
+ ----------
266
+ start_date : datetime.datetime
267
+ Start date of the forecast
268
+ lead_time : datetime.timedelta
269
+ Lead time of the forecast
270
+
271
+ Returns
272
+ ------
273
+ step : datetime.timedelta
274
+ Time delta since beginning of forecast
275
+ valid_date : datetime.datetime
276
+ Date of the forecast
277
+ next_date : datetime.datetime
278
+ Date used to prepare the next input tensor
279
+ is_last_step : bool
280
+ True if it's the last step of the forecast
281
+ """
282
+ steps = lead_time // self .checkpoint .timestep
283
+
284
+ LOG .info ("Lead time: %s, time stepping: %s Forecasting %s steps" , lead_time , self .checkpoint .timestep , steps )
285
+
286
+ for s in range (steps ):
287
+ step = (s + 1 ) * self .checkpoint .timestep
288
+ valid_date = start_date + step
289
+ next_date = valid_date
290
+ is_last_step = s == steps - 1
291
+ yield step , valid_date , next_date , is_last_step
292
+
260
293
def forecast (self , lead_time , input_tensor_numpy , input_state ):
261
294
self .model .eval ()
262
295
@@ -268,10 +301,6 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
268
301
LOG .info ("Using autocast %s" , self .autocast )
269
302
270
303
lead_time = to_timedelta (lead_time )
271
- steps = lead_time // self .checkpoint .timestep
272
-
273
- LOG .info ("Using autocast %s" , self .autocast )
274
- LOG .info ("Lead time: %s, time stepping: %s Forecasting %s steps" , lead_time , self .checkpoint .timestep , steps )
275
304
276
305
result = input_state .copy () # We should not modify the input state
277
306
result ["fields" ] = dict ()
@@ -295,9 +324,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
295
324
if self .verbosity > 0 :
296
325
self ._print_input_tensor ("First input tensor" , input_tensor_torch )
297
326
298
- for s in range (steps ):
299
- step = (s + 1 ) * self .checkpoint .timestep
300
- date = start + step
327
+ for s , (step , date , next_date , is_last_step ) in enumerate (self .forecast_stepper (start , lead_time )):
301
328
title = f"Forecasting step { step } ({ date } )"
302
329
303
330
result ["date" ] = date
@@ -333,8 +360,8 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
333
360
yield result
334
361
335
362
# No need to prepare next input tensor if we are at the last step
336
- if s == steps - 1 :
337
- continue
363
+ if is_last_step :
364
+ break
338
365
339
366
# Update tensor for next iteration
340
367
with ProfilingLabel ("Update tensor for next step" , self .use_profiler ):
@@ -347,10 +374,10 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
347
374
del y_pred # Recover memory
348
375
349
376
input_tensor_torch = self .add_dynamic_forcings_to_input_tensor (
350
- input_tensor_torch , input_state , date , check
377
+ input_tensor_torch , input_state , next_date , check
351
378
)
352
379
input_tensor_torch = self .add_boundary_forcings_to_input_tensor (
353
- input_tensor_torch , input_state , date , check
380
+ input_tensor_torch , input_state , next_date , check
354
381
)
355
382
356
383
if not check .all ():
0 commit comments