Skip to content

Commit b0c72a9

Browse files
authored
Enable parsing arguments to train and evaluate (#40)
* Implement argument parsing to allow running train and evaluate directly from the command line with arguments. * Include several fixes from code review. * Resolve remaining review comments. * Add changes from linting.
1 parent ce4cb31 commit b0c72a9

File tree

1 file changed

+114
-27
lines changed

1 file changed

+114
-27
lines changed

src/weathergen/__init__.py

Lines changed: 114 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,69 +7,156 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10+
import argparse
1011
import pdb
1112
import sys
1213
import time
1314
import traceback
1415

16+
import pandas as pd
17+
1518
from weathergen.train.trainer import Trainer
1619
from weathergen.utils.config import Config, private_conf
1720
from weathergen.utils.logger import init_loggers
1821

1922

2023
####################################################################################################
21-
def evaluate(
22-
run_id,
23-
epoch,
24-
masking_mode=None,
25-
forecacast_steps=None,
26-
samples=10000000,
27-
shuffle=False,
28-
save_samples=True,
29-
gridded_output_streams=[],
30-
):
24+
def _str2bool(value):
    """
    Convert a command-line token into a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so boolean options need an explicit
    converter.

    Args:
        value (str | bool): Token from the command line, or an already-parsed bool.

    Returns:
        bool: The parsed truth value.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    if token in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def _build_evaluate_parser():
    """Build the command-line parser used by ``evaluate``."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--run_id",
        type=str,
        required=True,
        help="Run/model id of pretrained WeatherGenerator model.",
    )
    parser.add_argument(
        "--start_date",
        "-start",
        type=str,
        required=True,
        help="Start date for evaluation. Format must be parsable with pd.to_datetime.",
    )
    parser.add_argument(
        "--end_date",
        "-end",
        type=str,
        required=True,
        help="End date for evaluation. Format must be parsable with pd.to_datetime.",
    )
    parser.add_argument(
        "--epoch",
        type=int,
        default=-1,
        help="Epoch of pretrained WeatherGenerator model used for evaluation (-1 corresponds to the last checkpoint).",
    )
    parser.add_argument(
        "--forecast_steps",
        type=int,
        default=None,
        help="Number of forecast steps for evaluation. Uses attribute from config when None is set.",
    )
    parser.add_argument(
        "--samples", type=int, default=10000000, help="Number of evaluation samples."
    )
    # nargs="?" with const=True keeps "--shuffle True" working and also allows
    # a bare "--shuffle"; "--shuffle False" now really yields False.
    parser.add_argument(
        "--shuffle",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Shuffle samples from evaluation.",
    )
    parser.add_argument(
        "--save_samples",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Save samples from evaluation.",
    )
    # nargs="+" collects whole stream names; type=list would split a single
    # name into its characters.
    parser.add_argument(
        "--analysis_streams_output",
        type=str,
        nargs="+",
        default=["ERA5"],
        help="Analysis output streams during evaluation.",
    )

    return parser


def evaluate():
    """
    Evaluation function for WeatherGenerator model.
    Entry point for calling the evaluation code from the command line.

    All inputs are read from the command line (sys.argv):

    Args:
        --run_id (str): Run/model id of pretrained WeatherGenerator model. Required.
        --start_date / -start (str): Start date for evaluation. Format must be parsable with pd.to_datetime. Required.
        --end_date / -end (str): End date for evaluation. Format must be parsable with pd.to_datetime. Required.
        --epoch (int, optional): Epoch of pretrained WeatherGenerator model used for evaluation (-1 corresponds to last epoch). Defaults to -1.
        --forecast_steps (int, optional): Number of forecast steps for evaluation. The value from the loaded config is kept when omitted.
        --samples (int, optional): Number of samples for evaluation. Defaults to 10000000.
        --shuffle (bool, optional): Shuffle samples for evaluation. Defaults to False.
        --save_samples (bool, optional): Save samples for evaluation. Defaults to True.
        --analysis_streams_output (list of str, optional): Analysis output streams during evaluation. Defaults to ['ERA5'].

    Note: gridded_output_streams is currently unused and therefore not exposed.
    """
    args = _build_evaluate_parser().parse_args()

    # TODO: move somewhere else
    init_loggers()

    # load config of the pretrained run
    cf = Config.load(args.run_id, args.epoch)

    cf.run_history += [(cf.run_id, cf.istep)]

    cf.samples_per_validation = args.samples
    cf.log_validation = args.samples if args.save_samples else 0

    start_date = pd.to_datetime(args.start_date)
    end_date = pd.to_datetime(args.end_date)

    # Dates are handed over as "YYYYmmddHHMM" strings.
    # ML: would be better to use datetime-objects
    # NOTE(review): these values were previously integer literals
    # (e.g. 202210011600) - confirm that consumers accept strings.
    cf.start_date_val = start_date.strftime("%Y%m%d%H%M")
    cf.end_date_val = end_date.strftime("%Y%m%d%H%M")

    cf.shuffle = args.shuffle

    # Only fall back to the config value when the option was omitted, so that
    # an explicit "--forecast_steps 0" is honoured.
    if args.forecast_steps is not None:
        cf.forecast_steps = args.forecast_steps
    # cf.forecast_policy = 'fixed'

    cf.analysis_streams_output = args.analysis_streams_output

    # make sure number of loaders does not exceed requested samples
    cf.loader_num_workers = min(cf.loader_num_workers, args.samples)

    trainer = Trainer()
    trainer.evaluate(cf, args.run_id, args.epoch, True)
66131

67132

68133
####################################################################################################
69-
def train(run_id=None) -> None:
134+
def train() -> None:
135+
"""
136+
Training function for WeatherGenerator model.
137+
Entry point for calling the training code from the command line.
138+
Configurations are set in the function body.
139+
140+
Args:
141+
run_id (str, optional): Run/model id of pretrained WeatherGenerator model to continue training. Defaults to None.
142+
143+
Note: All model configurations are set in the function body.
144+
"""
145+
parser = argparse.ArgumentParser()
146+
147+
parser.add_argument(
148+
"--run_id",
149+
type=str,
150+
default=None,
151+
help="Run/model id of pretrained WeatherGenerator model to continue training. Defaults to None.",
152+
)
153+
154+
args = parser.parse_args()
155+
70156
# TODO: move somewhere else
71157
init_loggers()
72158
private_cf = private_conf()
159+
73160
cf = Config()
74161

75162
# directory where input streams are specified
@@ -204,7 +291,7 @@ def train(run_id=None) -> None:
204291
cf.istep = 0
205292
cf.run_history = []
206293

207-
cf.run_id = run_id
294+
cf.run_id = args.run_id
208295
cf.desc = ""
209296

210297
trainer = Trainer(log_freq=20, checkpoint_freq=250, print_freq=10)

0 commit comments

Comments
 (0)