|
7 | 7 | # granted to it by virtue of its status as an intergovernmental organisation
|
8 | 8 | # nor does it submit to any jurisdiction.
|
9 | 9 |
|
| 10 | +import argparse |
10 | 11 | import pdb
|
11 | 12 | import sys
|
12 | 13 | import time
|
13 | 14 | import traceback
|
14 | 15 |
|
| 16 | +import pandas as pd |
| 17 | + |
15 | 18 | from weathergen.train.trainer import Trainer
|
16 | 19 | from weathergen.utils.config import Config, private_conf
|
17 | 20 | from weathergen.utils.logger import init_loggers
|
18 | 21 |
|
19 | 22 |
|
20 | 23 | ####################################################################################################
|
def _str_to_bool(value):
    """
    Convert a command-line token into a bool for use as an argparse ``type``.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True`` because every
    non-empty string is truthy, so ``--shuffle False`` would silently enable shuffling.

    Args:
        value (str | bool): Token from the command line (or an already-converted bool).

    Returns:
        bool: The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def _build_evaluate_parser():
    """Build the argument parser holding all command-line options of `evaluate`."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--run_id",
        type=str,
        required=True,
        help="Run/model id of pretrained WeatherGenerator model.",
    )
    parser.add_argument(
        "--start_date",
        "-start",
        type=str,
        required=True,
        help="Start date for evaluation. Format must be parsable with pd.to_datetime.",
    )
    parser.add_argument(
        "--end_date",
        "-end",
        type=str,
        required=True,
        help="End date for evaluation. Format must be parsable with pd.to_datetime.",
    )
    parser.add_argument(
        "--epoch",
        type=int,
        default=-1,
        help="Epoch of pretrained WeatherGenerator model used for evaluation "
        "(-1 corresponds to the last checkpoint).",
    )
    parser.add_argument(
        "--forecast_steps",
        type=int,
        default=None,
        help="Number of forecast steps for evaluation. Uses attribute from config when None is set.",
    )
    parser.add_argument(
        "--samples", type=int, default=10000000, help="Number of evaluation samples."
    )
    # Booleans still take an explicit value (e.g. `--shuffle True`), but are parsed properly.
    parser.add_argument(
        "--shuffle", type=_str_to_bool, default=False, help="Shuffle samples from evaluation."
    )
    parser.add_argument(
        "--save_samples", type=_str_to_bool, default=True, help="Save samples from evaluation."
    )
    # nargs="+" collects one or more stream names into a list; `type=list` would instead
    # split a single name into its characters ("ERA5" -> ['E', 'R', 'A', '5']).
    parser.add_argument(
        "--analysis_streams_output",
        type=str,
        nargs="+",
        default=["ERA5"],
        help="Analysis output streams during evaluation (one or more stream names).",
    )

    return parser


def evaluate(argv=None):
    """
    Evaluation function for WeatherGenerator model.
    Entry point for calling the evaluation code from the command line.

    All settings are read from command-line options rather than function arguments:

        --run_id (str, required): Run/model id of pretrained WeatherGenerator model.
        --start_date / -start (str, required): Start date for evaluation. Format must be
            parsable with pd.to_datetime.
        --end_date / -end (str, required): End date for evaluation. Format must be
            parsable with pd.to_datetime.
        --epoch (int): Epoch of the pretrained model used for evaluation
            (-1 corresponds to the last checkpoint). Defaults to -1.
        --forecast_steps (int): Number of forecast steps. The value from the loaded
            config is kept when the option is not given. Defaults to None.
        --samples (int): Number of samples for evaluation. Defaults to 10000000.
        --shuffle (bool): Shuffle samples for evaluation. Defaults to False.
        --save_samples (bool): Save samples for evaluation. Defaults to True.
        --analysis_streams_output (list of str): Analysis output streams during
            evaluation. Defaults to ['ERA5'].

    Args:
        argv (list[str], optional): Argument list to parse instead of ``sys.argv[1:]``.
            Defaults to None, i.e. the process command line is used.
    """
    args = _build_evaluate_parser().parse_args(argv)

    # TODO: move somewhere else
    init_loggers()

    # load config of the pretrained run
    cf = Config.load(args.run_id, args.epoch)

    cf.run_history += [(cf.run_id, cf.istep)]

    cf.samples_per_validation = args.samples
    cf.log_validation = args.samples if args.save_samples else 0

    start_date, end_date = pd.to_datetime(args.start_date), pd.to_datetime(args.end_date)

    # ML: would be better to use datetime-objects
    # Dates are handed over as "%Y%m%d%H%M" strings, e.g.
    #   Oct-Nov 2022: 202210011600 - 202212010400
    #   2022:         202201010400 - 202301010400
    cf.start_date_val = start_date.strftime("%Y%m%d%H%M")
    cf.end_date_val = end_date.strftime("%Y%m%d%H%M")

    cf.shuffle = args.shuffle

    # Only fall back to the configured value when the option was not given at all, so that
    # an explicit 0 is honoured instead of being treated as "unset".
    if args.forecast_steps is not None:
        cf.forecast_steps = args.forecast_steps
    # cf.forecast_policy = 'fixed'

    # e.g. ['Surface', 'Air', 'METEOSAT', 'ATMS', 'IASI', 'AMSR2']
    cf.analysis_streams_output = args.analysis_streams_output

    # make sure number of loaders does not exceed requested samples
    cf.loader_num_workers = min(cf.loader_num_workers, args.samples)

    trainer = Trainer()
    trainer.evaluate(cf, args.run_id, args.epoch, True)
66 | 131 |
|
67 | 132 |
|
68 | 133 | ####################################################################################################
|
69 |
| -def train(run_id=None) -> None: |
| 134 | +def train() -> None: |
| 135 | + """ |
| 136 | + Training function for WeatherGenerator model. |
| 137 | + Entry point for calling the training code from the command line. |
| 138 | + Configurations are set in the function body. |
| 139 | +
|
| 140 | + Args: |
| 141 | + run_id (str, optional): Run/model id of pretrained WeatherGenerator model to continue training. Defaults to None. |
| 142 | +
|
| 143 | + Note: All model configurations are set in the function body. |
| 144 | + """ |
| 145 | + parser = argparse.ArgumentParser() |
| 146 | + |
| 147 | + parser.add_argument( |
| 148 | + "--run_id", |
| 149 | + type=str, |
| 150 | + default=None, |
| 151 | + help="Run/model id of pretrained WeatherGenerator model to continue training. Defaults to None.", |
| 152 | + ) |
| 153 | + |
| 154 | + args = parser.parse_args() |
| 155 | + |
70 | 156 | # TODO: move somewhere else
|
71 | 157 | init_loggers()
|
72 | 158 | private_cf = private_conf()
|
| 159 | + |
73 | 160 | cf = Config()
|
74 | 161 |
|
75 | 162 | # directory where input streams are specified
|
@@ -204,7 +291,7 @@ def train(run_id=None) -> None:
|
204 | 291 | cf.istep = 0
|
205 | 292 | cf.run_history = []
|
206 | 293 |
|
207 |
| - cf.run_id = run_id |
| 294 | + cf.run_id = args.run_id |
208 | 295 | cf.desc = ""
|
209 | 296 |
|
210 | 297 | trainer = Trainer(log_freq=20, checkpoint_freq=250, print_freq=10)
|
|
0 commit comments