Skip to content

Commit 7e6412b

Browse files
complete BigST with preprocess (#206)
1 parent a3612df commit 7e6412b

12 files changed

+442
-509
lines changed

baselines/BigST/PEMS08.py

+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Configuration for training BigST on the PEMS08 dataset, using features
# produced by the BigSTPreprocess stage (see PreprocessPEMS08.py).
import os
import sys

import torch
from easydict import EasyDict

sys.path.append(os.path.abspath(__file__ + '/../../..'))

from basicts.metrics import masked_mae, masked_mape, masked_rmse
from basicts.data import TimeSeriesForecastingDataset
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.scaler import ZScoreScaler
from basicts.utils import get_regular_settings, load_adj

from .arch import BigST
from .loss import bigst_loss

############################## Hot Parameters ##############################
# Dataset & Metrics configuration
DATA_NAME = 'PEMS08'  # Dataset name
regular_settings = get_regular_settings(DATA_NAME)
# BigST consumes a long history window, so the regular-settings defaults
# (regular_settings['INPUT_LEN'] / ['OUTPUT_LEN']) are deliberately overridden.
INPUT_LEN = 2016  # Length of input sequence
OUTPUT_LEN = 12  # Length of output sequence
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO']  # Train/Validation/Test split ratios
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL']  # Whether to normalize each channel of the data
RESCALE = regular_settings['RESCALE']  # Whether to rescale the data
NULL_VAL = regular_settings['NULL_VAL']  # Null value in the data

# Model architecture and parameters
NUM_NODES = 170  # Number of sensor nodes in PEMS08
# Checkpoint produced by the preprocessing run. Built with os.path.join so the
# path is portable (previously hard-coded with Windows '\\' separators).
PREPROCESSED_FILE = os.path.join(
    "checkpoints", "BigSTPreprocess", "PEMS08_100_2016_12",
    "db8308a2c87de35e5f3db6177c5714ff", "BigSTPreprocess_best_val_MAE.pt")
MODEL_ARCH = BigST

adj_mx, _ = load_adj(os.path.join("datasets", DATA_NAME, "adj_mx.pkl"),
                     "doubletransition")
MODEL_PARAM = {
    "bigst_args": {
        "num_nodes": NUM_NODES,
        # NOTE(review): presumably the per-sample sequence length consumed by
        # the BigST core (distinct from INPUT_LEN) — confirm against the arch.
        "seq_num": 12,
        "in_dim": 3,
        "out_dim": OUTPUT_LEN,  # the upstream BigST code hard-codes this to 12
        "hid_dim": 32,
        "tau": 0.25,
        "random_feature_dim": 64,
        "node_emb_dim": 32,
        "time_emb_dim": 32,
        "use_residual": True,
        "use_bn": True,
        "use_long": True,
        "use_spatial": True,
        "dropout": 0.3,
        "supports": [torch.tensor(i) for i in adj_mx],
        "time_of_day_size": 288,  # 5-minute slots per day
        "day_of_week_size": 7
    },
    "preprocess_path": PREPROCESSED_FILE,
    "preprocess_args": {
        "num_nodes": NUM_NODES,
        "in_dim": 3,
        "dropout": 0.3,
        "input_length": INPUT_LEN,    # was a duplicated literal 2016
        "output_length": OUTPUT_LEN,  # was a duplicated literal 12
        "nhid": 32,
        "tiny_batch_size": 64,
    }
}

NUM_EPOCHS = 100

############################## General Configuration ##############################
CFG = EasyDict()
# General settings
CFG.DESCRIPTION = 'An Example Config'
CFG.GPU_NUM = 1  # Number of GPUs to use (0 for CPU mode)
# Runner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner

############################## Environment Configuration ##############################

CFG.ENV = EasyDict()  # Environment settings. Default: None
CFG.ENV.SEED = 0  # Random seed. Default: None

############################## Dataset Configuration ##############################
CFG.DATASET = EasyDict()
# Dataset settings
CFG.DATASET.NAME = DATA_NAME
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
CFG.DATASET.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
    'input_len': INPUT_LEN,
    'output_len': OUTPUT_LEN,
    # 'mode' is automatically set by the runner
})

############################## Scaler Configuration ##############################
CFG.SCALER = EasyDict()
# Scaler settings
CFG.SCALER.TYPE = ZScoreScaler  # Scaler class
CFG.SCALER.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_ratio': TRAIN_VAL_TEST_RATIO[0],
    'norm_each_channel': NORM_EACH_CHANNEL,
    'rescale': RESCALE,
})

############################## Model Configuration ##############################
CFG.MODEL = EasyDict()
# Model settings
CFG.MODEL.NAME = MODEL_ARCH.__name__
CFG.MODEL.ARCH = MODEL_ARCH
CFG.MODEL.PARAM = MODEL_PARAM
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2]  # value + time-of-day + day-of-week channels
CFG.MODEL.TARGET_FEATURES = [0]  # predict the traffic value channel only

############################## Metrics Configuration ##############################

CFG.METRICS = EasyDict()
# Metrics settings
CFG.METRICS.FUNCS = EasyDict({
    'MAE': masked_mae,
    'MAPE': masked_mape,
    'RMSE': masked_rmse,
})
CFG.METRICS.TARGET = 'MAE'
CFG.METRICS.NULL_VAL = NULL_VAL

############################## Training Configuration ##############################
CFG.TRAIN = EasyDict()
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    'checkpoints',
    MODEL_ARCH.__name__,
    '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
)

# bigst_loss carries the extra spatial regularization term needed when
# use_spatial is enabled; otherwise plain masked MAE suffices.
CFG.TRAIN.LOSS = bigst_loss if MODEL_PARAM['bigst_args']['use_spatial'] else masked_mae
# Optimizer settings
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "AdamW"
CFG.TRAIN.OPTIM.PARAM = {
    "lr": 0.002,
    "weight_decay": 0.0001,
}
# Learning rate scheduler settings
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM = {
    "milestones": [1, 50],
    "gamma": 0.5
}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 64
CFG.TRAIN.DATA.SHUFFLE = True
# Gradient clipping settings
CFG.TRAIN.CLIP_GRAD_PARAM = {
    "max_norm": 5.0
}

############################## Validation Configuration ##############################
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = EasyDict()
CFG.VAL.DATA.BATCH_SIZE = 64

############################## Test Configuration ##############################
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
CFG.TEST.DATA = EasyDict()
CFG.TEST.DATA.BATCH_SIZE = 64

############################## Evaluation Configuration ##############################
CFG.EVAL = EasyDict()

# Evaluation parameters
CFG.EVAL.HORIZONS = [3, 6, 12]  # Prediction horizons for evaluation. Default: []
CFG.EVAL.USE_GPU = True  # Whether to use GPU for evaluation. Default: True

baselines/BigST/PEMS04.py renamed to baselines/BigST/PreprocessPEMS08.py

+17-26
Original file line numberDiff line numberDiff line change
@@ -10,41 +10,32 @@
1010
from basicts.scaler import ZScoreScaler
1111
from basicts.utils import get_regular_settings, load_adj
1212

13-
from .arch import BigST
14-
from .loss import bigst_loss
13+
from .arch import BigSTPreprocess
14+
from .runner import BigSTPreprocessRunner
1515

1616
############################## Hot Parameters ##############################
1717
# Dataset & Metrics configuration
18-
DATA_NAME = 'PEMS04' # Dataset name
18+
DATA_NAME = 'PEMS08' # Dataset name
1919
regular_settings = get_regular_settings(DATA_NAME)
20-
INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence
21-
OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence
20+
INPUT_LEN = 2016
21+
OUTPUT_LEN = 12
2222
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios
2323
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data
2424
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data
2525
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data
2626
# Model architecture and parameters
27-
MODEL_ARCH = BigST
27+
MODEL_ARCH = BigSTPreprocess
2828
adj_mx, _ = load_adj("datasets/" + DATA_NAME +
2929
"/adj_mx.pkl", "doubletransition")
3030
MODEL_PARAM = {
31-
"num_nodes": 307,
32-
"seq_num": INPUT_LEN,
31+
"num_nodes": 170,
3332
"in_dim": 3,
34-
"out_dim": OUTPUT_LEN,
35-
"hid_dim": 32,
36-
"tau" : 0.25,
37-
"random_feature_dim": 64,
38-
"node_emb_dim": 32,
39-
"time_emb_dim": 32,
40-
"use_residual": True,
41-
"use_bn": True,
42-
"use_spatial": True,
43-
"use_long": False,
4433
"dropout": 0.3,
45-
"supports": [torch.tensor(i) for i in adj_mx],
46-
"time_of_day_size": 288,
47-
"day_of_week_size": 7,
34+
"input_length": INPUT_LEN,
35+
"output_length": OUTPUT_LEN,
36+
"nhid": 32,
37+
"tiny_batch_size": 64,
38+
4839
}
4940

5041
NUM_EPOCHS = 100
@@ -55,7 +46,7 @@
5546
CFG.DESCRIPTION = 'An Example Config'
5647
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode)
5748
# Runner
58-
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
49+
CFG.RUNNER = BigSTPreprocessRunner
5950

6051
############################## Environment Configuration ##############################
6152

@@ -115,7 +106,7 @@
115106
MODEL_ARCH.__name__,
116107
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
117108
)
118-
CFG.TRAIN.LOSS = bigst_loss
109+
CFG.TRAIN.LOSS = masked_mae
119110
# Optimizer settings
120111
CFG.TRAIN.OPTIM = EasyDict()
121112
CFG.TRAIN.OPTIM.TYPE = "AdamW"
@@ -132,7 +123,7 @@
132123
}
133124
# Train data loader settings
134125
CFG.TRAIN.DATA = EasyDict()
135-
CFG.TRAIN.DATA.BATCH_SIZE = 64
126+
CFG.TRAIN.DATA.BATCH_SIZE = 1
136127
CFG.TRAIN.DATA.SHUFFLE = True
137128
# Gradient clipping settings
138129
CFG.TRAIN.CLIP_GRAD_PARAM = {
@@ -143,13 +134,13 @@
143134
CFG.VAL = EasyDict()
144135
CFG.VAL.INTERVAL = 1
145136
CFG.VAL.DATA = EasyDict()
146-
CFG.VAL.DATA.BATCH_SIZE = 64
137+
CFG.VAL.DATA.BATCH_SIZE = 1
147138

148139
############################## Test Configuration ##############################
149140
CFG.TEST = EasyDict()
150141
CFG.TEST.INTERVAL = 1
151142
CFG.TEST.DATA = EasyDict()
152-
CFG.TEST.DATA.BATCH_SIZE = 64
143+
CFG.TEST.DATA.BATCH_SIZE = 1
153144

154145
############################## Evaluation Configuration ##############################
155146

baselines/BigST/arch/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
"""Public interface of the BigST architecture package."""

from .bigst_arch import BigST
from .preprocess import BigSTPreprocess

# BigST is the forecasting model; BigSTPreprocess is the long-history
# feature-extraction stage run before training BigST with use_long=True.
__all__ = ["BigST", "BigSTPreprocess"]

0 commit comments

Comments
 (0)