Skip to content

Commit 219f20d

Browse files
authored
Improve backend docstrings (#482)
* Empty commit * Add section about backends in docs/quickstart.rst * Add docs/backend.rst * Add examples using the PyTorch backend in dtw, soft_dtw, soft_dtw_alignment, cdist_soft_dtw, cdist_soft_dtw_normalized * Add import torch in examples using PyTorch backend * Improve backend.rst * Skip doctest if Torch is not installed * Remove print in docstrings * Add import pytest in docstrings when Torch is not installed * Improve backend.rst * Complete backend.rst * Remove useless # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS * Remove try import torch from docstrings and skip doctests when torch is not installed. * Skip doctests for linux_without_torch when torch is not installed in azure-pipelines.yml * Improve backend.rst * Add # noqa: E501 after docstrings with lines too long * Add blank line to fix a display error of the docstrings * Empty commit
1 parent 09441ab commit 219f20d

File tree

5 files changed

+324
-6
lines changed

5 files changed

+324
-6
lines changed

azure-pipelines.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ jobs:
8282
python -m pip install pytest pytest-azurepipelines
8383
python -m pip install scikit-learn==1.2
8484
python -m pip install tensorflow==2.9.0
85-
python -m pytest -v tslearn/ --doctest-modules
85+
python -m pytest -v tslearn/ --doctest-modules -k 'not tslearn.metrics.softdtw_variants.soft_dtw and not tslearn.metrics.softdtw_variants.cdist_soft_dtw and not tslearn.metrics.dtw_variants.dtw or tslearn.metrics.dtw_variants.dtw_'
8686
displayName: 'Test'
8787
8888

docs/backend.rst

+202
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
Backend selection and use
2+
=========================
3+
4+
`tslearn` proposes different backends (`NumPy` and `PyTorch`)
5+
to compute time series metrics such as `DTW` and `Soft-DTW`.
6+
The `PyTorch` backend can be used to compute gradients of
7+
metric functions thanks to automatic differentiation.
8+
9+
Backend selection
10+
-----------------
11+
12+
A backend can be instantiated using the function ``instantiate_backend``.
13+
To specify which backend should be instantiated (`NumPy` or `PyTorch`),
14+
this function accepts four different kind of input parameters:
15+
16+
* a string equal to ``"numpy"`` or ``"pytorch"``.
17+
* a `NumPy` array or a `Torch` tensor.
18+
* a Backend instance. The input backend is then returned.
19+
* ``None`` or anything else than mentioned previously. The backend `NumPy` is then instantiated.
20+
21+
Examples
22+
~~~~~~~~
23+
24+
If the input is the string ``"numpy"``, the ``NumPyBackend`` is instantiated.
25+
26+
.. code-block:: python
27+
28+
>>> from tslearn.backend import instantiate_backend
29+
>>> be = instantiate_backend("numpy")
30+
>>> print(be.backend_string)
31+
"numpy"
32+
33+
If the input is the string ``"pytorch"``, the ``PyTorchBackend`` is instantiated.
34+
35+
.. code-block:: python
36+
37+
>>> be = instantiate_backend("pytorch")
38+
>>> print(be.backend_string)
39+
"pytorch"
40+
41+
If the input is a `NumPy` array, the ``NumPyBackend`` is instantiated.
42+
43+
.. code-block:: python
44+
45+
>>> import numpy as np
46+
>>> be = instantiate_backend(np.array([0]))
47+
>>> print(be.backend_string)
48+
"numpy"
49+
50+
If the input is a `Torch` tensor, the ``PyTorchBackend`` is instantiated.
51+
52+
.. code-block:: python
53+
54+
>>> import torch
55+
>>> be = instantiate_backend(torch.tensor([0]))
56+
>>> print(be.backend_string)
57+
"pytorch"
58+
59+
If the input is a Backend instance, the input backend is returned.
60+
61+
.. code-block:: python
62+
63+
>>> print(be.backend_string)
64+
"pytorch"
65+
>>> be = instantiate_backend(be)
66+
>>> print(be.backend_string)
67+
"pytorch"
68+
69+
If the input is ``None``, the ``NumPyBackend`` is instantiated.
70+
71+
.. code-block:: python
72+
73+
>>> be = instantiate_backend(None)
74+
>>> print(be.backend_string)
75+
"numpy"
76+
77+
If the input is anything else, the ``NumPyBackend`` is instantiated.
78+
79+
.. code-block:: python
80+
81+
>>> be = instantiate_backend("Hello, World!")
82+
>>> print(be.backend_string)
83+
"numpy"
84+
85+
The function ``instantiate_backend`` accepts any number of input parameters, including zero.
86+
To select which backend should be instantiated (`NumPy` or `PyTorch`),
87+
a for loop is performed on the inputs until a backend is selected.
88+
89+
.. code-block:: python
90+
91+
>>> be = instantiate_backend(1, None, "Hello, World!", torch.tensor([0]), "numpy")
92+
>>> print(be.backend_string)
93+
"pytorch"
94+
95+
If none of the inputs are related to `NumPy` or `PyTorch`, the ``NumPyBackend`` is instantiated.
96+
97+
.. code-block:: python
98+
99+
>>> be = instantiate_backend(1, None, "Hello, World!")
100+
>>> print(be.backend_string)
101+
"numpy"
102+
103+
Use the backends
104+
----------------
105+
106+
The names of the attributes and methods of the backends
107+
are inspired by the `NumPy` backend.
108+
109+
Examples
110+
~~~~~~~~
111+
112+
Create backend objects.
113+
114+
.. code-block:: python
115+
116+
>>> be = instantiate_backend("pytorch")
117+
>>> mat = be.array([[0 , 1], [2, 3]], dtype=float)
118+
>>> print(mat)
119+
tensor([[0., 1.],
120+
[2., 3.]], dtype=torch.float64)
121+
122+
Use backend functions.
123+
124+
.. code-block:: python
125+
126+
>>> norm = be.linalg.norm(mat)
127+
>>> print(norm)
128+
tensor(3.7417, dtype=torch.float64)
129+
130+
Choose the backend used by metric functions
131+
-------------------------------------------
132+
133+
`tslearn`'s metric functions have an optional input parameter "``be``" to specify the
134+
backend to use to compute the metric.
135+
136+
Examples
137+
~~~~~~~~
138+
139+
.. code-block:: python
140+
141+
>>> import torch
142+
>>> from tslearn.metrics import dtw
143+
>>> s1 = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
144+
>>> s2 = torch.tensor([[3.0], [4.0], [-3.0]])
145+
>>> sim = dtw(s1, s2, be="pytorch")
146+
>>> print(sim)
147+
sim tensor(6.4807, grad_fn=<SqrtBackward0>)
148+
149+
By default, the optional input parameter ``be`` is equal to ``None``.
150+
Note that the first line of the function ``dtw`` is:
151+
152+
.. code-block:: python
153+
154+
be = instantiate_backend(be, s1, s2)
155+
156+
Therefore, even if ``be=None``, the ``PyTorchBackend`` is instantiated and used to compute the
157+
DTW metric since ``s1`` and ``s2`` are `Torch` tensors.
158+
159+
.. code-block:: python
160+
161+
>>> sim = dtw(s1, s2)
162+
>>> print(sim)
163+
sim tensor(6.4807, grad_fn=<SqrtBackward0>)
164+
165+
Automatic differentiation
166+
-------------------------
167+
168+
The `PyTorch` backend can be used to compute the gradients of the metric functions thanks to automatic differentiation.
169+
170+
Examples
171+
~~~~~~~~
172+
173+
Compute the gradient of the Dynamic Time Warping similarity measure.
174+
175+
.. code-block:: python
176+
177+
>>> s1 = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
178+
>>> s2 = torch.tensor([[3.0], [4.0], [-3.0]])
179+
>>> sim = dtw(s1, s2, be="pytorch")
180+
>>> sim.backward()
181+
>>> d_s1 = s1.grad
182+
>>> print(d_s1)
183+
tensor([[-0.3086],
184+
[-0.1543],
185+
[ 0.7715]])
186+
187+
Compute the gradient of the Soft-DTW similarity measure.
188+
189+
.. code-block:: python
190+
191+
>>> from tslearn.metrics import soft_dtw
192+
>>> ts1 = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
193+
>>> ts2 = torch.tensor([[3.0], [4.0], [-3.0]])
194+
>>> sim = soft_dtw(ts1, ts2, gamma=1.0, be="pytorch", compute_with_backend=True)
195+
>>> print(sim)
196+
tensor(41.1876, dtype=torch.float64, grad_fn=<SelectBackward0>)
197+
>>> sim.backward()
198+
>>> d_ts1 = ts1.grad
199+
>>> print(d_ts1)
200+
tensor([[-4.0001],
201+
[-2.2852],
202+
[10.1643]])

docs/quickstart.rst

+1
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ look at our :doc:`API Reference <reference>`.
1010
installation
1111
gettingstarted
1212
variablelength
13+
backend
1314
integration_other_software
1415
contributing

tslearn/metrics/dtw_variants.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,31 @@ def dtw(
694694
>>> dtw([1, 2, 3], [1., 2., 2., 3., 4.])
695695
1.0
696696
697+
The PyTorch backend can be used to compute gradients:
698+
699+
>>> import torch
700+
>>> s1 = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True)
701+
>>> s2 = torch.tensor([[3.0], [4.0], [-3.0]])
702+
>>> sim = dtw(s1, s2, be="pytorch")
703+
>>> print(sim)
704+
tensor(6.4807, grad_fn=<SqrtBackward0>)
705+
>>> sim.backward()
706+
>>> print(s1.grad)
707+
tensor([[-0.3086],
708+
[-0.1543],
709+
[ 0.7715]])
710+
711+
>>> s1_2d = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], requires_grad=True)
712+
>>> s2_2d = torch.tensor([[3.0, 3.0], [4.0, 4.0], [-3.0, -3.0]])
713+
>>> sim = dtw(s1_2d, s2_2d, be="pytorch")
714+
>>> print(sim)
715+
tensor(9.1652, grad_fn=<SqrtBackward0>)
716+
>>> sim.backward()
717+
>>> print(s1_2d.grad)
718+
tensor([[-0.2182, -0.2182],
719+
[-0.1091, -0.1091],
720+
[ 0.5455, 0.5455]])
721+
697722
See Also
698723
--------
699724
dtw_path : Get both the matching path and the similarity score for DTW
@@ -705,7 +730,7 @@ def dtw(
705730
spoken word recognition," IEEE Transactions on Acoustics, Speech and
706731
Signal Processing, vol. 26(1), pp. 43--49, 1978.
707732
708-
"""
733+
""" # noqa: E501
709734
be = instantiate_backend(be, s1, s2)
710735
s1 = to_time_series(s1, remove_nans=True, be=be)
711736
s2 = to_time_series(s2, remove_nans=True, be=be)

0 commit comments

Comments
 (0)