
Commit 6b7b29b

feat: 🎸 add inference script
add experimental script for inference processing
1 parent f606324 commit 6b7b29b

File tree

8 files changed (+486, -9 lines)


.pylintrc

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ ignore=baselines,assets,checkpoints
 
 # Files or directories matching the regex patterns are skipped. The regex
 # matches against base names, not paths.
-ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.CFF|^LICENSE
+ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.csv|^.*\.CFF|^LICENSE
 
 # Pickle collected data for later comparisons.
 persistent=no
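As a quick illustration (not part of the commit): the new `^.*\.csv` alternative means pylint now skips CSV files by base name, alongside markdown, text, CFF, and LICENSE files. A minimal sketch of how the pattern behaves; the file names below are hypothetical.

import re

# Same alternation as the updated ignore-patterns value.
ignore_patterns = re.compile(r'^\.|^_|^.*\.md|^.*\.txt|^.*\.csv|^.*\.CFF|^LICENSE')

# A CSV base name now matches the new ^.*\.csv alternative and is skipped.
assert ignore_patterns.match('metr-la.csv')
# Python sources are still linted.
assert not ignore_patterns.match('launcher.py')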

basicts/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
-from .launcher import launch_evaluation, launch_training
+from .launcher import launch_evaluation, launch_inference, launch_training
 from .runners import BaseEpochRunner
 
-__version__ = '0.4.6.3'
+__version__ = '0.4.6.4'
 
-__all__ = ['__version__', 'launch_training', 'launch_evaluation', 'BaseEpochRunner']
+__all__ = ['__version__', 'launch_training', 'launch_evaluation', 'BaseEpochRunner', 'launch_inference']
Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
import json
import logging
from typing import List, Tuple, Union

import numpy as np
import pandas as pd

from .base_dataset import BaseDataset


class TimeSeriesInferenceDataset(BaseDataset):
    """
    A dataset class for time series inference tasks, where the input is a sequence of historical data points.

    Attributes:
        description_file_path (str): Path to the JSON file containing the description of the dataset.
        description (dict): Metadata about the dataset, such as shape and other properties.
        data (np.ndarray): The loaded time series data array.
        raw_data (str): The raw data path or data list of the dataset.
        last_datetime (pd.Timestamp): The last datetime in the dataset. Used to generate time features of future data.
    """

    def __init__(self, dataset_name: str, dataset: Union[str, list], input_len: int, output_len: int,
                 logger: logging.Logger = None, train_val_test_ratio: List[float] = None) -> None:
        """
        Initializes the TimeSeriesInferenceDataset by setting up paths, loading data, and
        preparing it according to the specified configurations.

        Args:
            dataset_name (str): The name of the dataset.
            dataset (str or list): The data path of the dataset, or the data itself.
            input_len (int): The length of the input sequence (number of historical points).
            output_len (int): The length of the output sequence (number of future points to predict).
            logger (logging.Logger): Logger instance.
            train_val_test_ratio (List[float]): The ratio of train, validation, and test data. Present only for interface compatibility; ignored for inference.
        """
        train_val_test_ratio: List[float] = []
        mode: str = 'inference'
        overlap = False
        super().__init__(dataset_name, train_val_test_ratio, mode, input_len, output_len, overlap)
        self.logger = logger

        self.description_file_path = f'datasets/{dataset_name}/desc.json'
        self.description = self._load_description()

        self.last_datetime: pd.Timestamp = pd.Timestamp.now()
        self._raw_data = dataset
        self.data = self._load_data()

    def _load_description(self) -> dict:
        """
        Loads the description of the dataset from a JSON file.

        Returns:
            dict: A dictionary containing metadata about the dataset, such as its shape and other properties.

        Raises:
            FileNotFoundError: If the description file is not found.
            json.JSONDecodeError: If there is an error decoding the JSON data.
        """
        try:
            with open(self.description_file_path, 'r') as f:
                return json.load(f)
        except FileNotFoundError as e:
            raise FileNotFoundError(f'Description file not found: {self.description_file_path}') from e
        except json.JSONDecodeError as e:
            raise ValueError(f'Error decoding JSON file: {self.description_file_path}') from e

    def _load_data(self) -> np.ndarray:
        """
        Loads the time series data from a file or list and processes it according to the dataset description.

        Returns:
            np.ndarray: The processed data array with temporal features, of shape (L, N, C).

        Raises:
            ValueError: If there is an issue with loading the data file or if the data shape is not as expected.
        """

        if isinstance(self._raw_data, str):
            df = pd.read_csv(self._raw_data, header=None)
        else:
            df = pd.DataFrame(self._raw_data)

        df_index = pd.to_datetime(df[0].values, format='%Y-%m-%d %H:%M:%S').to_numpy()
        df = df[df.columns[1:]]
        df.index = pd.Index(df_index)
        df = df.astype('float32')
        self.last_datetime = df.index[-1]

        data = np.expand_dims(df.values, axis=-1)
        data = data[..., [0]]

        data_with_features = self._add_temporal_features(data, df)

        data_set_shape = self.description['shape']
        _, n, c = data_with_features.shape
        if data_set_shape[1] != n or data_set_shape[2] != c:
            raise ValueError(f'Error loading data. Shape mismatch: expected {data_set_shape[1:]}, got {[n, c]}.')

        return data_with_features

    def _add_temporal_features(self, data, df) -> np.ndarray:
        '''
        Add time of day, day of week, day of month, and day of year as features to the data.

        Args:
            data (np.ndarray): The data array.
            df (pd.DataFrame): The dataframe containing the datetime index.

        Returns:
            np.ndarray: The data array with the added temporal features.
        '''

        _, n, _ = data.shape
        feature_list = [data]

        # numerical time_of_day
        tod = (df.index.hour * 60 + df.index.minute) / (24 * 60)
        tod_tiled = np.tile(tod, [1, n, 1]).transpose((2, 1, 0))
        feature_list.append(tod_tiled)

        # numerical day_of_week
        dow = df.index.dayofweek / 7
        dow_tiled = np.tile(dow, [1, n, 1]).transpose((2, 1, 0))
        feature_list.append(dow_tiled)

        # numerical day_of_month
        dom = (df.index.day - 1) / 31  # df.index.day starts from 1, so subtract 1 to make it start from 0.
        dom_tiled = np.tile(dom, [1, n, 1]).transpose((2, 1, 0))
        feature_list.append(dom_tiled)

        # numerical day_of_year
        doy = (df.index.dayofyear - 1) / 366  # df.index.dayofyear starts from 1, so subtract 1 to make it start from 0.
        doy_tiled = np.tile(doy, [1, n, 1]).transpose((2, 1, 0))
        feature_list.append(doy_tiled)

        data_with_features = np.concatenate(feature_list, axis=-1).astype('float32')  # L x N x C

        # Remove extra features
        data_set_shape = self.description['shape']
        data_with_features = data_with_features[..., range(data_set_shape[2])]

        return data_with_features

    def append_data(self, new_data: np.ndarray) -> None:
        """
        Append new data to the existing data.

        Args:
            new_data (np.ndarray): The new data to append to the existing data.
        """

        freq = self.description['frequency (minutes)']
        l, _, _ = new_data.shape

        data_with_features, datetime_list = self._gen_datetime_list(new_data, self.last_datetime, freq, l)
        self.last_datetime = datetime_list[-1]

        self.data = np.concatenate([self.data, data_with_features], axis=0)

    def _gen_datetime_list(self, new_data: np.ndarray, start_datetime: pd.Timestamp, freq: int, num_steps: int) -> Tuple[np.ndarray, List[pd.Timestamp]]:
        """
        Generate a list of datetime objects based on the start datetime, frequency, and number of steps,
        and attach the corresponding temporal features to `new_data`.

        Args:
            new_data (np.ndarray): The data array to which the temporal features are added.
            start_datetime (pd.Timestamp): The starting datetime for the sequence.
            freq (int): The frequency of the data in minutes.
            num_steps (int): The number of steps in the sequence.

        Returns:
            Tuple[np.ndarray, List[pd.Timestamp]]: The data array with temporal features, and the list of
            datetime objects for the sequence (including the start datetime).
        """
        datetime_list = [start_datetime]
        for _ in range(num_steps):
            datetime_list.append(datetime_list[-1] + pd.Timedelta(minutes=freq))
        new_index = pd.Index(datetime_list[1:])
        new_df = pd.DataFrame()
        new_df.index = new_index
        data_with_features = self._add_temporal_features(new_data, new_df)

        return data_with_features, datetime_list

    def __getitem__(self, index: int) -> dict:
        """
        Retrieves a sample from the dataset, considering both the input and output lengths.
        For inference, the input data is the last 'input_len' points in the dataset, and the output data is the next 'output_len' points.

        Args:
            index (int): The index of the desired sample in the dataset.

        Returns:
            dict: A dictionary containing 'inputs' and 'target', where both are slices of the dataset corresponding to
            the historical input data and future prediction data, respectively.
        """
        history_data = self.data[-self.input_len:]

        freq = self.description['frequency (minutes)']
        _, n, _ = history_data.shape
        future_data = np.zeros((self.output_len, n, 1))

        data_with_features, _ = self._gen_datetime_list(future_data, self.last_datetime, freq, self.output_len)
        return {'inputs': history_data, 'target': data_with_features}

    def __len__(self) -> int:
        """
        Calculates the total number of samples available in the dataset.
        For inference, there is only one valid sample, as the input data is the last 'input_len' points in the dataset.

        Returns:
            int: The number of valid samples that can be drawn from the dataset, based on the configurations of input and output lengths.
        """
        return 1
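For reference, a minimal usage sketch of the new dataset class (not part of the commit). The import path, dataset name, and CSV file below are assumptions for illustration; the actual module path of this file is not shown on this page.

import numpy as np

# Assumed import path; adjust to wherever TimeSeriesInferenceDataset lives in basicts.
from basicts.data import TimeSeriesInferenceDataset

# 'METR-LA' and 'history.csv' are hypothetical. The dataset must have datasets/METR-LA/desc.json,
# and the CSV's first column must hold '%Y-%m-%d %H:%M:%S' datetime strings.
dataset = TimeSeriesInferenceDataset(dataset_name='METR-LA', dataset='history.csv',
                                     input_len=12, output_len=12)

sample = dataset[0]                 # only one valid sample in inference mode (len(dataset) == 1)
history = sample['inputs']          # last input_len points, shape (input_len, N, C)
future_features = sample['target']  # zero values plus time features for the next output_len steps

# Predictions can be appended so that the next window includes them.
dataset.append_data(np.zeros((12, history.shape[1], 1), dtype='float32'))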

basicts/launcher.py

Lines changed: 101 additions & 0 deletions
@@ -132,3 +132,104 @@ def launch_training(cfg: Union[Dict, str],
 
     # launch the training process
     easytorch.launch_training(cfg=cfg, devices=gpus, node_rank=node_rank)

def inference_func(cfg: Dict,
                   input_data_file_path: str,
                   output_data_file_path: str,
                   ckpt_path: str,
                   strict: bool = True) -> None:
    """
    Starts the inference process.

    This function performs the following steps:
    1. Initializes the runner specified in the configuration (`cfg`).
    2. Sets up logging for the inference process.
    3. Loads the model checkpoint.
    4. Executes the inference pipeline using the initialized runner.

    Args:
        cfg (Dict): EasyTorch configuration dictionary.
        input_data_file_path (str): Path to the input data file.
        output_data_file_path (str): Path to the output data file.
        ckpt_path (str): Path to the model checkpoint. If not provided, the best model checkpoint is loaded automatically.
        strict (bool): Enforces that the checkpoint keys match the model. Defaults to True.

    Raises:
        Exception: Catches any exception, logs the traceback, and re-raises it.
    """

    # initialize the runner
    logger = get_logger('easytorch-launcher')
    logger.info(f"Initializing runner '{cfg['RUNNER']}'")
    runner = cfg['RUNNER'](cfg)

    # initialize the logger for the runner
    runner.init_logger(logger_name='easytorch-inference', log_file_name='inference_log')

    # setup the graph if needed
    if runner.need_setup_graph:
        runner.setup_graph(cfg=cfg, train=False)

    try:
        # load the model checkpoint
        if ckpt_path is None or not os.path.exists(ckpt_path):
            ckpt_path_auto = os.path.join(runner.ckpt_save_dir, '{}_best_val_{}.pt'.format(runner.model_name, runner.target_metrics.replace('/', '_')))
            logger.info(f'Checkpoint file not found at {ckpt_path}. Loading the best model checkpoint `{ckpt_path_auto}` automatically.')
            if not os.path.exists(ckpt_path_auto):
                raise FileNotFoundError(f'Checkpoint file not found at {ckpt_path}')
            runner.load_model(ckpt_path=ckpt_path_auto, strict=strict)
        else:
            logger.info(f'Loading model checkpoint from {ckpt_path}')
            runner.load_model(ckpt_path=ckpt_path, strict=strict)

        # start the inference pipeline
        runner.inference_pipeline(cfg=cfg, input_data=input_data_file_path, output_data_file_path=output_data_file_path)

    except BaseException as e:
        # log the exception and re-raise it
        runner.logger.error(traceback.format_exc())
        raise e

def launch_inference(cfg: Union[Dict, str],
                     ckpt_path: str,
                     input_data_file_path: str,
                     output_data_file_path: str,
                     device_type: str = 'gpu',
                     gpus: Optional[str] = None) -> None:
    """
    Launches the inference process.

    Args:
        cfg (Union[Dict, str]): EasyTorch configuration as a dictionary or a path to a config file.
        ckpt_path (str): Path to the model checkpoint.
        input_data_file_path (str): Path to the input data file.
        output_data_file_path (str): Path to the output data file.
        device_type (str, optional): Device type to use ('cpu' or 'gpu'). Defaults to 'gpu'.
        gpus (Optional[str]): GPU device IDs to use. Defaults to None (use all available GPUs).
    """

    logger = get_logger('easytorch-launcher')
    logger.info('Launching EasyTorch inference.')

    # check params
    # a config path that starts with './' (or '.\') will crash easytorch, so strip the leading dot segment
    while isinstance(cfg, str) and cfg.startswith(('./', '.\\')):
        cfg = cfg[2:]
    while ckpt_path.startswith(('./', '.\\')):
        ckpt_path = ckpt_path[2:]

    # initialize the configuration
    cfg_dict = init_cfg(cfg, save=True)

    # set the device type (CPU, GPU, or MLU)
    set_device_type(device_type)

    # set the visible GPUs if the device type is not CPU
    if device_type != 'cpu':
        set_visible_devices(gpus)

    # run the inference process
    inference_func(cfg_dict, input_data_file_path, output_data_file_path, ckpt_path)
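A hedged usage sketch of the new entry point (not part of the commit); the config and file paths below are placeholders.

import basicts

# All paths are hypothetical placeholders.
basicts.launch_inference(
    cfg='examples/regular_config.py',        # EasyTorch config dict or path to a config file
    ckpt_path='checkpoints/best_model.pt',   # falls back to the best checkpoint if this path is missing
    input_data_file_path='inputs/history.csv',
    output_data_file_path='outputs/predictions.csv',
    device_type='gpu',
    gpus='0',
)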
