Commit 7be1c45

[MRG] Add PyTorch backend for soft-DTW (#431)
Co-authored-by: Romain Tavenard <[email protected]>
1 parent a091483 commit 7be1c45

34 files changed (+3091 -1125)

CHANGELOG.md (+6)

@@ -10,9 +10,15 @@ Changelogs for this project are recorded in this file since v0.2.0.
 
 ## [Towards v0.6]
 
+### Added
+
+* Support for the `PyTorch` backend for the metrics of `tslearn`.
+  In particular, the Dynamic Time Warping (DTW) metric and the Soft-DTW metric now support the `PyTorch` backend.
+
 ### Removed
 
 * Support for Python version 3.7 is dropped
+* Elements that were deprecated in v0.4 are now removed, as announced
 
 ## [v0.5.3]
 
README.md (+1 -1)

@@ -140,4 +140,4 @@ If you use `tslearn` in a scientific publication, we would appreciate citations:
 ```
 
 #### Acknowledgments
-Authors would like to thank Mathieu Blondel for providing code for [Kernel k-means](https://gist.github.com/mblondel/6230787) and [Soft-DTW](https://github.com/mblondel/soft-dtw).
+Authors would like to thank Mathieu Blondel for providing code for [Kernel k-means](https://gist.github.com/mblondel/6230787) and [Soft-DTW](https://github.com/mblondel/soft-dtw), and Mehran Maghoumi for his [`torch`-compatible implementation of SoftDTW](https://github.com/Maghoumi/pytorch-softdtw-cuda).

docs/conf.py (+1 -1)

@@ -80,7 +80,7 @@ def __call__(self, *args, **kwargs):
     'doc_module': ('tslearn',),
     'subsection_order': ["examples", "examples/metrics", "examples/neighbors",
                          "examples/clustering", "examples/classification",
-                         "examples/misc"].index,
+                         "examples/autodiff", "examples/misc"].index,
     'image_scrapers': (matplotlib_svg_scraper(),),
     # 'binder': {
     #     # Required keys

docs/examples/autodiff/README.txt (+3, new file)

@@ -0,0 +1,3 @@
+Automatic differentiation
+-------------------------
+
docs/examples/autodiff/ (new example script, +149)

@@ -0,0 +1,149 @@
+# -*- coding: utf-8 -*-
+"""
+Soft-DTW loss for PyTorch neural network
+========================================
+
+The aim here is to use the Soft Dynamic Time Warping metric as the loss
+function of a PyTorch neural network for time series forecasting.
+
+The `torch`-compatible implementation of the soft-DTW loss function is
+available from the :mod:`tslearn.metrics` module.
+"""
+
+# Authors: Yann Cabanes, Romain Tavenard
+# License: BSD 3 clause
+# sphinx_gallery_thumbnail_number = 2
+
+"""Import the modules"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+from tslearn.datasets import CachedDatasets
+from tslearn.metrics import SoftDTWLossPyTorch
+import torch
+from torch import nn
+
+##############################################################################
+# Load the dataset
+# ----------------
+#
+# Using the CachedDatasets utility from tslearn, we load the "Trace" time
+# series dataset. The arrays storing the training and testing sets both have
+# shape (100, 275, 1). We create a new dataset X_subset made of 50 random
+# time series from the classes indexed 1 to 3 (y_train < 4) in the training
+# set, so that X_subset has shape (50, 275, 1).
+
+data_loader = CachedDatasets()
+X_train, y_train, X_test, y_test = data_loader.load_dataset("Trace")
+
+X_subset = X_train[y_train < 4]
+np.random.shuffle(X_subset)
+X_subset = X_subset[:50]
+
+##############################################################################
+# Multi-step ahead forecasting
+# ----------------------------
+#
+# In this section, our goal is to implement a single-hidden-layer perceptron
+# for time series forecasting, trained to minimize the soft-DTW metric.
+# We rely on the `torch`-compatible implementation of the soft-DTW loss
+# function. The code below implements a generic multi-layer perceptron class
+# in torch; we then rely on it to build a forecasting MLP with a soft-DTW loss.
+
+# Note that soft-DTW can take negative values due to the regularization
+# parameter gamma. The normalized soft-DTW (also coined soft-DTW divergence)
+# between the time series x and y is defined as:
+#     Soft-DTW(x, y) - (Soft-DTW(x, x) + Soft-DTW(y, y)) / 2
+# The normalized soft-DTW is always non-negative. However, it costs roughly
+# three times as much to compute as the soft-DTW itself.
+
+class MultiLayerPerceptron(torch.nn.Module):
+    def __init__(self, layers, loss=None):
+        # At init, we define our layers
+        super(MultiLayerPerceptron, self).__init__()
+        self.layers = layers
+        if loss is None:
+            self.loss = torch.nn.MSELoss(reduction="none")
+        else:
+            self.loss = loss
+        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.001)
+
+    def forward(self, X):
+        # The forward method describes the forward pass: how outputs of the
+        # network are computed from the input and the layers registered at init
+        if not isinstance(X, torch.Tensor):
+            X = torch.Tensor(X)
+        batch_size = X.size(0)
+        # Manipulations to deal with the time series format
+        X_reshaped = torch.reshape(X, (batch_size, -1))
+        output = self.layers(X_reshaped)
+        # Manipulations to deal with the time series format
+        return torch.reshape(output, (batch_size, -1, 1))
+
+    def fit(self, X, y, max_epochs=10):
+        # The fit method performs the actual optimization
+        X_torch = torch.Tensor(X)
+        y_torch = torch.Tensor(y)
+
+        for e in range(max_epochs):
+            self.optimizer.zero_grad()
+            # Forward pass
+            y_pred = self.forward(X_torch)
+            # Compute loss
+            loss = self.loss(y_pred, y_torch).mean()
+            # Backward pass
+            loss.backward()
+            self.optimizer.step()
+
+
+##############################################################################
+# Using MSE as a loss function
+# ----------------------------
+#
+# We instantiate a single-hidden-layer model trained with the mean squared
+# error (MSE) as the loss function to be optimized. We train the network for
+# 1000 epochs on a forecasting task that consists, given the first 150
+# elements of a time series, in predicting the next 125 ones.
+
+model = MultiLayerPerceptron(
+    layers=nn.Sequential(
+        nn.Linear(in_features=150, out_features=256),
+        nn.ReLU(),
+        nn.Linear(in_features=256, out_features=125)
+    )
+)
+
+# Predict the last 125 time steps of each series from its first 150 ones
+model.fit(X_subset[:, :150], X_subset[:, 150:], max_epochs=1000)
+
+ts_index = 50
+y_pred = model(X_test[:, :150, 0]).detach().numpy()
+
+plt.figure()
+plt.title('Multi-step ahead forecasting using MSE')
+plt.plot(X_test[ts_index].ravel())
+plt.plot(np.arange(150, 275), y_pred[ts_index], 'r-')
+
+
+##############################################################################
+# Using Soft-DTW as a loss function
+# ---------------------------------
+#
+# We take inspiration from the code above to train a single-hidden-layer
+# model using soft-DTW as the criterion to be optimized. We train the network
+# for 100 epochs on a forecasting task that consists, given the first 150
+# elements of a time series, in predicting the next 125 ones.
+
+model = MultiLayerPerceptron(
+    layers=nn.Sequential(
+        nn.Linear(in_features=150, out_features=256),
+        nn.ReLU(),
+        nn.Linear(in_features=256, out_features=125)
+    ),
+    loss=SoftDTWLossPyTorch(gamma=0.1)
+)
+
+model.fit(X_subset[:, :150], X_subset[:, 150:], max_epochs=100)
+
+y_pred = model(X_test[:, :150, 0]).detach().numpy()
+
+plt.figure()
+plt.title('Multi-step ahead forecasting using Soft-DTW loss')
+plt.plot(X_test[ts_index].ravel())
+plt.plot(np.arange(150, 275), y_pred[ts_index], 'r-')
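The normalized soft-DTW (soft-DTW divergence) discussed in the comments of this example can be composed from three calls to the loss. A minimal sketch under that reading of the formula; the helper name `soft_dtw_divergence` is ours, and whether `SoftDTWLossPyTorch` also exposes a built-in normalization option should be checked in the API docs.

```python
import torch
from tslearn.metrics import SoftDTWLossPyTorch

def soft_dtw_divergence(x, y, gamma=0.1):
    # D(x, y) - (D(x, x) + D(y, y)) / 2: non-negative, but costs
    # roughly three soft-DTW evaluations instead of one
    loss = SoftDTWLossPyTorch(gamma=gamma)
    return loss(x, y) - 0.5 * (loss(x, x) + loss(y, y))
```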

docs/gen_modules/tslearn.backend.rst (+20, new file)

@@ -0,0 +1,20 @@
+.. _mod-backend:
+
+tslearn.backend
+===============
+
+.. automodule:: tslearn.backend
+
+
+.. rubric:: Functions
+
+.. autosummary::
+    :toctree: backend
+    :template: function.rst
+
+    Backend
+    instantiate_backend
+    select_backend
+    NumPyBackend
+    PyTorchBackend
docs/gen_modules/tslearn.metrics.rst (+1)

@@ -37,3 +37,4 @@ tslearn.metrics
     lb_keogh
     sigma_gak
     gamma_soft_dtw
+    SoftDTWLossPyTorch
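For reference, a minimal calling sketch for the newly documented loss, with input shapes matching the example script above (batch, length, dim); that the loss returns one value per pair in the batch is an assumption consistent with the `.mean()` reduction used there.

```python
import torch
from tslearn.metrics import SoftDTWLossPyTorch

loss_fn = SoftDTWLossPyTorch(gamma=0.1)
x = torch.randn(8, 150, 1)    # (batch, length, dim)
y = torch.randn(8, 125, 1)    # the two series may differ in length
value = loss_fn(x, y).mean()  # reduce the per-pair losses to a scalar
```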

docs/reference.rst (+1)

@@ -11,6 +11,7 @@ The complete ``tslearn`` project is automatically documented for every module.
     :toctree: gen_modules/
     :template: module.rst
 
+    tslearn.backend
     tslearn.barycenters
     tslearn.clustering
     tslearn.datasets

docs/requirements_rtd.txt (+1)

@@ -14,3 +14,4 @@ numba
 sphinx_bootstrap_theme
 numpydoc
 matplotlib
+torch

requirements.txt (+1)

@@ -4,6 +4,7 @@ scipy
 scikit-learn
 joblib>=0.12
 tensorflow>=2
+torch
 pandas
 cesium
 h5py

requirements_nocast.txt (+1)

@@ -4,4 +4,5 @@ scipy
 scikit-learn
 joblib>=0.12
 tensorflow>=2
+torch
 h5py

setup.py (+1 -1)

@@ -35,7 +35,7 @@
     packages=find_packages(),
    package_data={"tslearn": [".cached_datasets/Trace.npz"]},
     install_requires=['numpy', 'scipy', 'scikit-learn', 'numba', 'joblib'],
-    extras_require={'tests': ['pytest']},
+    extras_require={'tests': ['pytest', 'torch'], 'pytorch': ['torch']},
     version=VERSION,
     url="http://tslearn.readthedocs.io/",
     project_urls={
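With the new extra in place, the optional `torch` dependency can be pulled in at install time:

```bash
pip install tslearn[pytorch]
```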

tslearn/backend/__init__.py (+16, new file)

@@ -0,0 +1,16 @@
+"""
+The :mod:`tslearn.backend` module provides multiple backends.
+The backends provided are NumPy and PyTorch.
+"""
+
+from .backend import Backend, instantiate_backend, select_backend
+from .numpy_backend import NumPyBackend
+from .pytorch_backend import PyTorchBackend
+
+__all__ = [
+    "Backend",
+    "instantiate_backend",
+    "select_backend",
+    "NumPyBackend",
+    "PyTorchBackend",
+]
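A hedged sketch of how these exported names might be used. That `instantiate_backend` accepts a string identifier and that backend objects expose NumPy-like array creation are assumptions based on the names alone; this diff does not show the implementations.

```python
from tslearn.backend import instantiate_backend

be = instantiate_backend("pytorch")  # assumed: "numpy" is the other identifier
x = be.array([[1.0], [2.0], [3.0]])  # assumed: NumPy-like helpers on the backend
```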
