Skip to content

Commit 80361c6

Browse files
committed
docs: ✏️ update configs of iTransformer
1 parent c2940f2 commit 80361c6

File tree

2 files changed

+160
-1
lines changed

2 files changed

+160
-1
lines changed

baselines/iTransformer/Electricity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"sigma" : 0.2,
4444
"dropout": 0.1,
4545
"freq": 'h',
46-
"use_norm" : False,
46+
"use_norm" : True,
4747
"output_attention": False,
4848
"embed": "timeF", # [timeF, fixed, learned]
4949
"activation": "gelu",

baselines/iTransformer/Traffic.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import os
2+
import sys
3+
from easydict import EasyDict
4+
sys.path.append(os.path.abspath(__file__ + '/../../..'))
5+
from basicts.metrics import masked_mae, masked_mse, masked_mape, masked_rmse
6+
from basicts.data import TimeSeriesForecastingDataset
7+
from basicts.runners import SimpleTimeSeriesForecastingRunner
8+
from basicts.scaler import ZScoreScaler
9+
from basicts.utils import get_regular_settings
10+
11+
from .arch import iTransformer
12+
13+
############################## Hot Parameters ##############################
14+
# Dataset & Metrics configuration
15+
DATA_NAME = 'Traffic' # Dataset name
16+
regular_settings = get_regular_settings(DATA_NAME)
17+
INPUT_LEN = 96
18+
OUTPUT_LEN = 720
19+
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios
20+
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data
21+
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data
22+
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data
23+
# Model architecture and parameters
24+
MODEL_ARCH = iTransformer
25+
NUM_NODES = 862
26+
MODEL_PARAM = {
27+
"enc_in": NUM_NODES, # num nodes
28+
"dec_in": NUM_NODES,
29+
"c_out": NUM_NODES,
30+
"seq_len": INPUT_LEN,
31+
"label_len": INPUT_LEN/2, # start token length used in decoder
32+
"pred_len": OUTPUT_LEN, # prediction sequence length
33+
"factor": 3, # attn factor
34+
"p_hidden_dims": [128, 128],
35+
"p_hidden_layers": 2,
36+
"d_model": 512,
37+
"moving_avg": 25, # window size of moving average. This is a CRUCIAL hyper-parameter.
38+
"n_heads": 8,
39+
"e_layers": 4, # num of encoder layers
40+
"d_layers": 1, # num of decoder layers
41+
"d_ff": 512,
42+
"distil": True,
43+
"sigma" : 0.2,
44+
"dropout": 0.1,
45+
"freq": 'h',
46+
"use_norm" : True,
47+
"output_attention": False,
48+
"embed": "timeF", # [timeF, fixed, learned]
49+
"activation": "gelu",
50+
"num_time_features": 4, # number of used time features
51+
"time_of_day_size": 24,
52+
"day_of_week_size": 7,
53+
"day_of_month_size": 31,
54+
"day_of_year_size": 366
55+
}
56+
NUM_EPOCHS = 20
57+
58+
############################## General Configuration ##############################
59+
CFG = EasyDict()
60+
# General settings
61+
CFG.DESCRIPTION = 'An Example Config'
62+
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode)
63+
# Runner
64+
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
65+
66+
############################## Dataset Configuration ##############################
67+
CFG.DATASET = EasyDict()
68+
# Dataset settings
69+
CFG.DATASET.NAME = DATA_NAME
70+
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
71+
CFG.DATASET.PARAM = EasyDict({
72+
'dataset_name': DATA_NAME,
73+
'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
74+
'input_len': INPUT_LEN,
75+
'output_len': OUTPUT_LEN,
76+
# 'mode' is automatically set by the runner
77+
})
78+
79+
############################## Scaler Configuration ##############################
80+
CFG.SCALER = EasyDict()
81+
# Scaler settings
82+
CFG.SCALER.TYPE = ZScoreScaler # Scaler class
83+
CFG.SCALER.PARAM = EasyDict({
84+
'dataset_name': DATA_NAME,
85+
'train_ratio': TRAIN_VAL_TEST_RATIO[0],
86+
'norm_each_channel': NORM_EACH_CHANNEL,
87+
'rescale': RESCALE,
88+
})
89+
90+
############################## Model Configuration ##############################
91+
CFG.MODEL = EasyDict()
92+
# Model settings
93+
CFG.MODEL.NAME = MODEL_ARCH.__name__
94+
CFG.MODEL.ARCH = MODEL_ARCH
95+
CFG.MODEL.PARAM = MODEL_PARAM
96+
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2, 3, 4]
97+
CFG.MODEL.TARGET_FEATURES = [0]
98+
99+
############################## Metrics Configuration ##############################
100+
101+
CFG.METRICS = EasyDict()
102+
# Metrics settings
103+
CFG.METRICS.FUNCS = EasyDict({
104+
'MAE': masked_mae,
105+
'MSE': masked_mse,
106+
'RMSE': masked_rmse,
107+
'MAPE': masked_mape
108+
})
109+
CFG.METRICS.TARGET = 'MSE'
110+
CFG.METRICS.NULL_VAL = NULL_VAL
111+
112+
############################## Training Configuration ##############################
113+
CFG.TRAIN = EasyDict()
114+
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
115+
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
116+
'checkpoints',
117+
MODEL_ARCH.__name__,
118+
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
119+
)
120+
CFG.TRAIN.LOSS = masked_mae
121+
# Optimizer settings
122+
CFG.TRAIN.OPTIM = EasyDict()
123+
CFG.TRAIN.OPTIM.TYPE = "Adam"
124+
CFG.TRAIN.OPTIM.PARAM = {
125+
"lr": 0.001,
126+
}
127+
# Learning rate scheduler settings
128+
CFG.TRAIN.LR_SCHEDULER = EasyDict()
129+
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
130+
CFG.TRAIN.LR_SCHEDULER.PARAM = {
131+
"milestones": [5, 10],
132+
"gamma": 0.5
133+
}
134+
CFG.TRAIN.CLIP_GRAD_PARAM = {
135+
'max_norm': 5.0
136+
}
137+
# Train data loader settings
138+
CFG.TRAIN.DATA = EasyDict()
139+
CFG.TRAIN.DATA.BATCH_SIZE = 32
140+
CFG.TRAIN.DATA.SHUFFLE = True
141+
142+
############################## Validation Configuration ##############################
143+
CFG.VAL = EasyDict()
144+
CFG.VAL.INTERVAL = 1
145+
CFG.VAL.DATA = EasyDict()
146+
CFG.VAL.DATA.BATCH_SIZE = 32
147+
148+
############################## Test Configuration ##############################
149+
CFG.TEST = EasyDict()
150+
CFG.TEST.INTERVAL = 1
151+
CFG.TEST.DATA = EasyDict()
152+
CFG.TEST.DATA.BATCH_SIZE = 32
153+
154+
############################## Evaluation Configuration ##############################
155+
156+
CFG.EVAL = EasyDict()
157+
158+
# Evaluation parameters
159+
CFG.EVAL.USE_GPU = False # Whether to use GPU for evaluation. Default: True

0 commit comments

Comments
 (0)