
Commit c2940f2

feat: 🎸 add an LTSF baseline, SOFTS (added by @superarthurlx)
1 parent 47e0070 commit c2940f2

File tree

6 files changed (+307, -10 lines)


README.md

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ The code links (💻Code) in the table below point to the official implementations

 | 📊Baseline | 📝Title | 📄Paper | 💻Code | 🏛Venue | 🎯Task |
 | :------------ | :------------------------------------------------------------------------------------------------------- | :----------------------------------------------------- | :---------------------------------------------------------------------------- | :--------- | :----- |
+| SOFTS | SOFTS: Efficient Multivariate Time Series Forecasting with Series-Core Fusion | [Link](https://arxiv.org/pdf/2404.14197) | [Link](https://github.com/Secilia-Cxy/SOFTS) | NeurIPS'24 | LTSF |
 | CATS | Are Self-Attentions Effective for Time Series Forecasting? | [Link](https://arxiv.org/pdf/2405.16877) | [Link](https://github.com/dongbeank/CATS) | NeurIPS'24 | LTSF |
 | Sumba | Structured Matrix Basis for Multivariate Time Series Forecasting with Interpretable Dynamics | [Link](https://xiucheng.org/assets/pdfs/nips24-sumba.pdf) | [Link](https://github.com/chenxiaodanhit/Sumba/) | NeurIPS'24 | LTSF |
 | GLAFF | Rethinking the Power of Timestamps for Robust Time Series Forecasting: A Global-Local Fusion Perspective | [Link](https://arxiv.org/pdf/2409.18696) | [Link](https://github.com/ForestsKing/GLAFF) | NeurIPS'24 | LTSF |

README_CN.md

Lines changed: 1 addition & 0 deletions
@@ -135,6 +135,7 @@ BasicTS implements a rich set of baseline models, including classic models and spatiotemporal forecasting models

 | 📊Baseline | 📝Title | 📄Paper | 💻Code | 🏛Venue | 🎯Task |
 | :------------ | :------------------------------------------------------------------------------------------------------- | :----------------------------------------------------- | :---------------------------------------------------------------------------- | :--------- | :----- |
+| SOFTS | SOFTS: Efficient Multivariate Time Series Forecasting with Series-Core Fusion | [Link](https://arxiv.org/pdf/2404.14197) | [Link](https://github.com/Secilia-Cxy/SOFTS) | NeurIPS'24 | LTSF |
 | CATS | Are Self-Attentions Effective for Time Series Forecasting? | [Link](https://arxiv.org/pdf/2405.16877) | [Link](https://github.com/dongbeank/CATS) | NeurIPS'24 | LTSF |
 | Sumba | Structured Matrix Basis for Multivariate Time Series Forecasting with Interpretable Dynamics | [Link](https://xiucheng.org/assets/pdfs/nips24-sumba.pdf) | [Link](https://github.com/chenxiaodanhit/Sumba/) | NeurIPS'24 | LTSF |
 | GLAFF | Rethinking the Power of Timestamps for Robust Time Series Forecasting: A Global-Local Fusion Perspective | [Link](https://arxiv.org/pdf/2409.18696) | [Link](https://github.com/ForestsKing/GLAFF) | NeurIPS'24 | LTSF |

baselines/SOFTS/ETTh1.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@

import os
import sys
from easydict import EasyDict
sys.path.append(os.path.abspath(__file__ + '/../../..'))
from basicts.metrics import masked_mae, masked_mse, masked_mape, masked_rmse
from basicts.data import TimeSeriesForecastingDataset
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.scaler import ZScoreScaler
from basicts.utils import get_regular_settings

from .arch import SOFTS

############################## Hot Parameters ##############################
# Dataset & Metrics configuration
DATA_NAME = 'ETTh1'  # Dataset name
regular_settings = get_regular_settings(DATA_NAME)
INPUT_LEN = regular_settings['INPUT_LEN']  # 336, better performance
OUTPUT_LEN = regular_settings['OUTPUT_LEN']  # Length of output sequence
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO']  # Train/Validation/Test split ratios
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL']  # Whether to normalize each channel of the data
RESCALE = regular_settings['RESCALE']  # Whether to rescale the data
NULL_VAL = regular_settings['NULL_VAL']  # Null value in the data
# Model architecture and parameters
MODEL_ARCH = SOFTS
NUM_NODES = 7
MODEL_PARAM = {
    "enc_in": NUM_NODES,  # num nodes
    "dec_in": NUM_NODES,
    "c_out": NUM_NODES,
    "seq_len": INPUT_LEN,
    "pred_len": OUTPUT_LEN,  # prediction sequence length
    "e_layers": 2,  # num of encoder layers
    "d_model": 256,
    "d_core": 256,
    "d_ff": 512,
    "dropout": 0.0,
    "use_norm": True,
    "activation": "gelu",
    "num_time_features": 4,  # number of used time features
    "time_of_day_size": 24,
    "day_of_week_size": 7,
    "day_of_month_size": 31,
    "day_of_year_size": 366
}
NUM_EPOCHS = 50

############################## General Configuration ##############################
CFG = EasyDict()
# General settings
CFG.DESCRIPTION = 'An Example Config'
CFG.GPU_NUM = 1  # Number of GPUs to use (0 for CPU mode)
# Runner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner

############################## Dataset Configuration ##############################
CFG.DATASET = EasyDict()
# Dataset settings
CFG.DATASET.NAME = DATA_NAME
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
CFG.DATASET.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
    'input_len': INPUT_LEN,
    'output_len': OUTPUT_LEN,
    # 'mode' is automatically set by the runner
})

############################## Scaler Configuration ##############################
CFG.SCALER = EasyDict()
# Scaler settings
CFG.SCALER.TYPE = ZScoreScaler  # Scaler class
CFG.SCALER.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_ratio': TRAIN_VAL_TEST_RATIO[0],
    'norm_each_channel': NORM_EACH_CHANNEL,
    'rescale': RESCALE,
})

############################## Model Configuration ##############################
CFG.MODEL = EasyDict()
# Model settings
CFG.MODEL.NAME = MODEL_ARCH.__name__
CFG.MODEL.ARCH = MODEL_ARCH
CFG.MODEL.PARAM = MODEL_PARAM
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2, 3, 4]
CFG.MODEL.TARGET_FEATURES = [0]

############################## Metrics Configuration ##############################
CFG.METRICS = EasyDict()
# Metrics settings
CFG.METRICS.FUNCS = EasyDict({
    'MAE': masked_mae,
    'MSE': masked_mse,
    'RMSE': masked_rmse,
    'MAPE': masked_mape
})
CFG.METRICS.TARGET = 'MAE'
CFG.METRICS.NULL_VAL = NULL_VAL

############################## Training Configuration ##############################
CFG.TRAIN = EasyDict()
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    'checkpoints',
    MODEL_ARCH.__name__,
    '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
)
CFG.TRAIN.LOSS = masked_mae
# Optimizer settings
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
    "lr": 0.0003,
}
# Learning rate scheduler settings
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM = {
    "milestones": [1, 25, 50],
    "gamma": 0.5
}
CFG.TRAIN.CLIP_GRAD_PARAM = {
    'max_norm': 5.0
}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 64
CFG.TRAIN.DATA.SHUFFLE = True
CFG.TRAIN.EARLY_STOPPING_PATIENCE = 10

############################## Validation Configuration ##############################
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = EasyDict()
CFG.VAL.DATA.BATCH_SIZE = 64

############################## Test Configuration ##############################
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
CFG.TEST.DATA = EasyDict()
CFG.TEST.DATA.BATCH_SIZE = 64

############################## Evaluation Configuration ##############################
CFG.EVAL = EasyDict()
# Evaluation parameters
CFG.EVAL.USE_GPU = True  # Whether to use GPU for evaluation. Default: True
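
As a usage note (mine, not part of the commit): a config module like this is handed to BasicTS's training launcher. A minimal sketch, assuming the basicts.launch_training entry point used by the stock experiments/train.py script; verify the exact signature against your checkout:

# Hypothetical launch helper, not part of this commit.
# Assumes basicts.launch_training(cfg, gpus=...) as exposed in recent BasicTS releases.
import basicts

if __name__ == '__main__':
    # Train SOFTS on ETTh1 with the configuration module above, on GPU 0.
    basicts.launch_training('baselines/SOFTS/ETTh1.py', gpus='0')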

baselines/SOFTS/ETTh2.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@

import os
import sys
from easydict import EasyDict
sys.path.append(os.path.abspath(__file__ + '/../../..'))
from basicts.metrics import masked_mae, masked_mse, masked_mape, masked_rmse
from basicts.data import TimeSeriesForecastingDataset
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.scaler import ZScoreScaler
from basicts.utils import get_regular_settings

from .arch import SOFTS

############################## Hot Parameters ##############################
# Dataset & Metrics configuration
DATA_NAME = 'ETTh2'  # Dataset name
regular_settings = get_regular_settings(DATA_NAME)
# INPUT_LEN = regular_settings['INPUT_LEN']  # Length of input sequence
INPUT_LEN = 192  # better performance
OUTPUT_LEN = regular_settings['OUTPUT_LEN']  # Length of output sequence
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO']  # Train/Validation/Test split ratios
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL']  # Whether to normalize each channel of the data
RESCALE = regular_settings['RESCALE']  # Whether to rescale the data
NULL_VAL = regular_settings['NULL_VAL']  # Null value in the data
# Model architecture and parameters
MODEL_ARCH = SOFTS
NUM_NODES = 7
MODEL_PARAM = {
    "enc_in": NUM_NODES,  # num nodes
    "dec_in": NUM_NODES,
    "c_out": NUM_NODES,
    "seq_len": INPUT_LEN,
    "pred_len": OUTPUT_LEN,  # prediction sequence length
    "e_layers": 2,  # num of encoder layers
    "d_model": 128,
    "d_core": 64,
    "d_ff": 128,
    "dropout": 0.0,
    "use_norm": True,
    "activation": "gelu",
    "num_time_features": 4,  # number of used time features
    "time_of_day_size": 24,
    "day_of_week_size": 7,
    "day_of_month_size": 31,
    "day_of_year_size": 366
}
NUM_EPOCHS = 20

############################## General Configuration ##############################
CFG = EasyDict()
# General settings
CFG.DESCRIPTION = 'An Example Config'
CFG.GPU_NUM = 1  # Number of GPUs to use (0 for CPU mode)
# Runner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner

CFG.ENV = EasyDict()  # Environment settings. Default: None
CFG.ENV.SEED = 2024  # Random seed. Default: None

############################## Dataset Configuration ##############################
CFG.DATASET = EasyDict()
# Dataset settings
CFG.DATASET.NAME = DATA_NAME
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
CFG.DATASET.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
    'input_len': INPUT_LEN,
    'output_len': OUTPUT_LEN,
    # 'mode' is automatically set by the runner
})

############################## Scaler Configuration ##############################
CFG.SCALER = EasyDict()
# Scaler settings
CFG.SCALER.TYPE = ZScoreScaler  # Scaler class
CFG.SCALER.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_ratio': TRAIN_VAL_TEST_RATIO[0],
    'norm_each_channel': NORM_EACH_CHANNEL,
    'rescale': RESCALE,
})

############################## Model Configuration ##############################
CFG.MODEL = EasyDict()
# Model settings
CFG.MODEL.NAME = MODEL_ARCH.__name__
CFG.MODEL.ARCH = MODEL_ARCH
CFG.MODEL.PARAM = MODEL_PARAM
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2, 3, 4]
CFG.MODEL.TARGET_FEATURES = [0]

############################## Metrics Configuration ##############################
CFG.METRICS = EasyDict()
# Metrics settings
CFG.METRICS.FUNCS = EasyDict({
    'MAE': masked_mae,
    'MSE': masked_mse,
    'RMSE': masked_rmse,
    'MAPE': masked_mape
})
CFG.METRICS.TARGET = 'MAE'
CFG.METRICS.NULL_VAL = NULL_VAL

############################## Training Configuration ##############################
CFG.TRAIN = EasyDict()
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    'checkpoints',
    MODEL_ARCH.__name__,
    '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
)
CFG.TRAIN.LOSS = masked_mse
# Optimizer settings
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
    "lr": 0.0003,
}
# Learning rate scheduler settings
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM = {
    "milestones": [1, 25, 50],
    "gamma": 0.5
}
CFG.TRAIN.CLIP_GRAD_PARAM = {
    'max_norm': 5.0
}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 32
CFG.TRAIN.DATA.SHUFFLE = True
CFG.TRAIN.EARLY_STOPPING_PATIENCE = 10

############################## Validation Configuration ##############################
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = EasyDict()
CFG.VAL.DATA.BATCH_SIZE = 64

############################## Test Configuration ##############################
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
CFG.TEST.DATA = EasyDict()
CFG.TEST.DATA.BATCH_SIZE = 64

############################## Evaluation Configuration ##############################
CFG.EVAL = EasyDict()
# Evaluation parameters
CFG.EVAL.USE_GPU = True  # Whether to use GPU for evaluation. Default: True
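
One detail shared by both configs, worth a gloss because it pairs with num_time_features = 4 above (my illustration, not commit code): FORWARD_FEATURES = [0, 1, 2, 3, 4] feeds the model the target channel plus the four time-feature channels, while TARGET_FEATURES = [0] restricts the loss and metrics to the raw series. A minimal sketch of that selection, assuming the runner lays batches out as [batch, length, nodes, channels]:

# Illustration of the feature selection (assumed runner behavior, not commit code).
import torch

B, L, N, C = 2, 192, 7, 5       # batch, input length, nodes, channels
data = torch.randn(B, L, N, C)  # channel 0: series; channels 1-4: time features

model_input = data[..., [0, 1, 2, 3, 4]]  # FORWARD_FEATURES -> [B, L, N, 5]
target = data[..., [0]]                   # TARGET_FEATURES  -> [B, L, N, 1]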

baselines/SOFTS/Weather.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 # Dataset & Metrics configuration
 DATA_NAME = 'Weather'  # Dataset name
 regular_settings = get_regular_settings(DATA_NAME)
-INPUT_LEN = regular_settings['INPUT_LEN']  # Length of input sequence
+INPUT_LEN = regular_settings['INPUT_LEN']  # 336, better performance
 OUTPUT_LEN = regular_settings['OUTPUT_LEN']  # Length of output sequence
 TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO']  # Train/Validation/Test split ratios
 NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL']  # Whether to normalize each channel of the data

baselines/SOFTS/arch/softs_arch.py

Lines changed: 2 additions & 9 deletions
@@ -54,7 +54,7 @@ class SOFTS(nn.Module):
     '''
     Paper: SOFTS: Efficient Multivariate Time Series Forecasting with Series-Core Fusion
     Official Code: https://github.com/Secilia-Cxy/SOFTS
-    Link: https://xiucheng.org/assets/pdfs/nips24-sumba.pdf
+    Link: https://arxiv.org/pdf/2404.14197
     Venue: NeurIPS 2024
     Task: Long-term Time Series Forecasting
     '''
@@ -119,18 +119,11 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_s
             torch.Tensor: outputs with shape [B, L2, N, 1]
         """

-        # change MinuteOfDay to MinuteOfHour
+        # change TimeOfDay to MinuteOfHour
         history_data[..., 1] = history_data[..., 1] * self.time_of_day_size // (self.time_of_day_size / 24) / 23.0
         x_enc, x_mark_enc, x_dec, x_mark_dec = data_transformation_4_xformer(history_data=history_data,
                                                                              future_data=future_data,
                                                                              start_token_len=0)
         #print(x_mark_enc.shape, x_mark_dec.shape)
         prediction = self.forward_xformer(x_enc=x_enc, x_mark_enc=x_mark_enc)
         return prediction.unsqueeze(-1)
-
-
-
-
-
-
-
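
To make the channel-1 rescaling in the hunk above concrete: with the hourly ETT configs, time_of_day_size = 24, so the expression floors the normalized time-of-day value to an integer hour in 0-23 and renormalizes it by 23; whatever the comment's naming, for hourly data the result is a normalized hour-of-day index. A standalone numeric check (my illustration, not part of the commit):

# Standalone check of the feature rescaling above (not commit code).
time_of_day_size = 24  # hourly data, as in the ETTh1/ETTh2 configs
tod = 12 / 24          # normalized time-of-day for 12:00 noon
hour = tod * time_of_day_size // (time_of_day_size / 24) / 23.0
print(hour)            # 12.0 // 1.0 / 23.0 = 0.5217...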
