Skip to content

Commit 6ca5c4e

Browse files
Gaussian processes via gpytorch (#782)
This is a feature to add support for Gaussian Processes (GPs) via integration with gpytorch. Similar to "vanilla" PyTorch, the idea here is that skorch allows the user to focus on what's important (implementing the mean function and kernel function) and not to bother with stuff like the training loop, callbacks, etc. This is probably best illustrated in the accompanying notebook. GPs are primarily for regression, hence those are the main focus here. Traditionally, there are "exact" solutions and approximations. skorch will provide an ExactGPRegressor and a GPRegressor for those two use cases. On top of that, a GPBinaryClassifier is offered, though I suspect it to be rarely used. I couldn't get the GPClassifier for multiclass to work, the code is therefore commented out. The API is mostly the same as for the normal skorch estimators. There are some additions to make working with GPs easier: - predict method takes a return_std argument to return the standard deviation as well (as in sklearn's GaussianProcessRegressor; return_cov is not supported) - sample method to sample for the model - confidence_region method to get the confidence region Co-authored-by: Thomas J. Fan <[email protected]>
1 parent 156efe6 commit 6ca5c4e

14 files changed

+4411
-5
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- Added a `get_all_learnable_params` method to retrieve the named parameters of all PyTorch modules defined on the net, including of criteria if applicable
1414
- Added `MlflowLogger` callback for logging to Mlflow (#769)
1515
- Added `InputShapeSetter` callback for automatically setting the input dimension of the PyTorch module
16+
- Added a new module to support Gaussian Processes through [GPyTorch](https://gpytorch.ai/). To learn more about it, read the [GP documentation](https://skorch.readthedocs.io/en/latest/user/probabilistic.html) or take a look at the [GP notebook](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb). This feature is experimental, i.e. the API could be changed in the future in a backwards incompatible way.
1617

1718
### Changed
1819

README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Resources
3131

3232
- `Documentation <https://skorch.readthedocs.io/en/latest/?badge=latest>`_
3333
- `Source Code <https://github.com/skorch-dev/skorch/>`_
34+
- `Installation <https://github.com/skorch-dev/skorch#installation>`_
3435

3536
========
3637
Examples
@@ -127,6 +128,7 @@ skorch also provides many convenient features, among others:
127128
- `Parameter freezing/unfreezing <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.Freezer>`_
128129
- `Progress bar <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.ProgressBar>`_ (for CLI as well as jupyter)
129130
- `Automatic inference of CLI parameters <https://github.com/skorch-dev/skorch/tree/master/examples/cli>`_
131+
- `Integration with GPyTorch for Gaussian Processes <https://skorch.readthedocs.io/en/latest/user/probabilistic.html>`_
130132

131133
============
132134
Installation

docs/conf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,11 @@
5050

5151
intersphinx_mapping = {
5252
'pytorch': ('https://pytorch.org/docs/stable/', None),
53-
'sklearn': ('http://scikit-learn.org/stable/', None),
54-
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
53+
'sklearn': ('https://scikit-learn.org/stable/', None),
54+
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
5555
'python': ('https://docs.python.org/3', None),
5656
'mlflow': ('https://mlflow.org/docs/latest/', None),
57+
'gpytorch': ('https://docs.gpytorch.ai/en/stable/', None),
5758
}
5859

5960
# Add any paths that contain templates here, relative to this directory.
@@ -118,7 +119,7 @@
118119
# html_theme_options = {}
119120

120121
def setup(app):
121-
app.add_stylesheet('css/my_theme.css')
122+
app.add_css_file('css/my_theme.css')
122123

123124
# Add any paths that contain custom static files (such as style sheets) here,
124125
# relative to this directory. They are copied after the builtin static files,

docs/index.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ User's Guide
5252
user/callbacks
5353
user/dataset
5454
user/save_load
55+
user/probabilistic
5556
user/history
5657
user/toy
5758
user/helper
@@ -82,5 +83,5 @@ Indices and tables
8283
* :ref:`search`
8384

8485

85-
.. _pytorch: http://pytorch.org/
86-
.. _sklearn: http://scikit-learn.org/
86+
.. _pytorch: https://pytorch.org/
87+
.. _sklearn: https://scikit-learn.org/

docs/probabilistic.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
skorch.probabilistic
2+
====================
3+
4+
.. automodule:: skorch.probabilistic
5+
:members:

docs/skorch.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ skorch
1111
helper
1212
history
1313
net
14+
probabilistic
1415
regressor
1516
scoring
1617
toy

docs/user/probabilistic.rst

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
==================
2+
Gaussian Processes
3+
==================
4+
5+
skorch integrates with GPyTorch_ to make it easy to train Gaussian Process (GP)
6+
models. You should already know how Gaussian Processes work. Please refer to
7+
other resources if you want to learn about them, this section assumes
8+
familiarity with the concept.
9+
10+
GPyTorch adopts many patterns from PyTorch, thus making it easy to pick up for
11+
seasoned PyTorch users. Similarly, the skorch GPyTorch integration should look
12+
familiar to seasoned skorch users. However, GPs are a different beast than the
13+
more common, non-probabilistic machine learning techniques. It is important to
14+
understand the basic concepts before using them in practice.
15+
16+
Installation
17+
------------
18+
19+
In addition to the normal skorch dependencies and PyTorch, you need to install
20+
GPyTorch as well. It wasn't added as a normal dependency since most users
21+
probably are not interested in using skorch for GPs. To install GPyTorch, use
22+
either pip or conda:
23+
24+
.. code:: bash
25+
26+
# using pip
27+
pip install -U gpytorch
28+
# using conda
29+
conda install gpytorch -c gpytorch
30+
31+
When to use GPyTorch with skorch
32+
--------------------------------
33+
34+
Here we want to quickly explain when it would be a good idea for you to use
35+
GPyTorch with skorch. There are a couple of offerings in the Python ecosystem
36+
when it comes to Gaussian Processes. We cannot provide an exhaustive list of
37+
pros and cons of each possibility. There are, however, two obvious alternatives
38+
that are worth discussing: using the sklearn_ implementation and using GPyTorch
39+
without skorch.
40+
41+
When to use skorch + GPyTorch over sklearn:
42+
43+
* When you are more familiar with PyTorch than with sklearn
44+
* When the kernels provided by sklearn are not sufficient for your use case and
45+
you would like to implement custom kernels with PyTorch
46+
* When you want to use the rich set of optimizers available in PyTorch
47+
* When sklearn is too slow and you want to use the GPU or scale across machines
48+
* When you like to use the skorch extras, e.g. callbacks
49+
50+
When to use skorch + GPyTorch over pure GPyTorch
51+
52+
* When you're already familiar with skorch and want an easy entry into GPs
53+
* When you like to use the skorch extras, e.g. callbacks and grid search
54+
* When you don't want to bother with writing your own training loop
55+
56+
However, if you are researching GPs and would like to have control over every
57+
detail, using all the rich but very specific featues that GPyTorch has on offer,
58+
it is better to use it directly without skorch.
59+
60+
Examples
61+
--------
62+
63+
Exact Gaussian Processes
64+
^^^^^^^^^^^^^^^^^^^^^^^^
65+
66+
Same as GPyTorch, skorch supports exact and approximate Gaussian Processes
67+
regression. For exact GPs, use the
68+
:class:`~skorch.probabilistic.ExactGPRegressor`. The likelihood has to be a
69+
:class:`~gpytorch.likelihoods.GaussianLikelihood` and the criterion
70+
:class:`~gpytorch.mlls.ExactMarginalLogLikelihood`, but those are the defaults
71+
and thus don't need to be specified. For exact GPs, the module needs to be an
72+
:class:`~gpytorch.models.ExactGP`. For this example, we use a simple RBF kernel.
73+
74+
.. code:: python
75+
76+
import gpytorch
77+
from skorch.probabilistic import ExactGPRegressor
78+
79+
class RbfModule(gpytorch.models.ExactGP):
80+
def __init__(likelihood, self):
81+
# detail: We don't set train_inputs and train_targets here skorch because
82+
# will take care of that.
83+
super().__init__()
84+
self.mean_module = gpytorch.means.ConstantMean()
85+
self.covar_module = gpytorch.kernels.RBFKernel()
86+
87+
def forward(self, x):
88+
mean_x = self.mean_module(x)
89+
covar_x = self.covar_module(x)
90+
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
91+
92+
gpr = ExactGPRegressor(RbfModule)
93+
gpr.fit(X_train, y_train)
94+
y_pred = gpr.predict(X_test)
95+
96+
As you can see, this almost looks like a normal skorch regressor with a normal
97+
PyTorch module. We can fit as normal using the ``fit`` method and predict using
98+
the ``predict`` method.
99+
100+
Inside the module, we determine the mean by using a mean function (just constant
101+
in this case) and the covariance matrix using the RBF kernel function. You
102+
should know about mean and kernel functions already. Having the mean and
103+
covariance matrix, we assume that the output distribution is a multivariate
104+
normal function, since exact GPs rely on this assumption. We could send the
105+
``x`` through an MLP for `Deep Kernel Learning
106+
<https://docs.gpytorch.ai/en/stable/examples/06_PyTorch_NN_Integration_DKL/index.html>`_
107+
but left it out to keep the example simple.
108+
109+
One major difference to usual deep learning models is that we actually predict a
110+
distribution, not just a point estimate. That means that if we choose an
111+
appropriate model that fits the data well, we can express the **uncertainty** of
112+
the model:
113+
114+
.. code:: python
115+
116+
y_pred, y_std = gpr.predict(X, return_std=True)
117+
lower_conf_region = y_pred - y_std
118+
upper_conf_region = y_pred + y_std
119+
120+
Here we not only returned the mean of the prediction, ``y_pred``, but also its
121+
standard deviation, ``y_std``. This tells us how uncertain the model is about
122+
its prediction. E.g., it could be the case that the model is fairly certain when
123+
*interpolating* between data points but uncertain about *extrapolating*. This is
124+
not possible to know when models only learn point predictions.
125+
126+
The obtain the confidence region, you can also use the ``confidence_region``
127+
method:
128+
129+
.. code:: python
130+
131+
# 1 standard deviation
132+
lower, upper = gpr.confidence_region(X, sigmas=1)
133+
134+
# 2 standard deviation, the default
135+
lower, upper = gpr.confidence_region(X, sigmas=2)
136+
137+
Furthermore, a GP allows you to sample from the distribution even *before
138+
fitting* it. The GP needs to be initialized, however:
139+
140+
.. code:: python
141+
142+
gpr = ExactGPRegressor(...)
143+
gpr.initialize()
144+
samples = gpr.sample(X, n_samples=100)
145+
146+
By visualizing the samples and comparing them to the true underlying
147+
distribution of the target, you can already get a feel about whether the model
148+
you built is capable of generating the distribution of the target. If fitting
149+
takes a long time, it is therefore recommended to check the distribution first,
150+
otherwise you may try to fit a model that is incapable of generating the true
151+
distribution and waste a lot of time.
152+
153+
Approximate Gaussian Processes
154+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
155+
156+
For some situations, fitting an exact GP might be infeasible, e.g. because the
157+
distribution is not Gaussian or because you want to perform stochastic
158+
optimization with mini-batches. For this, GPyTorch provides facilities to train
159+
variational and approximate GPs. The module should inherit from
160+
:class:`~gpytorch.models.ApproximateGP` and should define a *variational
161+
strategy*. From the skorch side of things, use
162+
:class:`~skorch.probabilistic.GPRegressor`.
163+
164+
.. code:: python
165+
166+
import gpytorch
167+
from gpytorch.models import ApproximateGP
168+
from gpytorch.variational import CholeskyVariationalDistribution
169+
from gpytorch.variational import VariationalStrategy
170+
from skorch.probabilistic import GPRegressor
171+
172+
class VariationalModule(ApproximateGP):
173+
def __init__(self, inducing_points):
174+
variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
175+
variational_strategy = VariationalStrategy(
176+
self, inducing_points, variational_distribution, learn_inducing_locations=True,
177+
)
178+
super().__init__(variational_strategy)
179+
self.mean_module = gpytorch.means.ConstantMean()
180+
self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
181+
182+
def forward(self, x):
183+
mean_x = self.mean_module(x)
184+
covar_x = self.covar_module(x)
185+
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
186+
187+
X, y = get_data(...)
188+
X_incuding = X[:100]
189+
X_train, y_train = X[100:], y[100:]
190+
num_training_samples = len(X_train)
191+
192+
gpr = GPRegressor(
193+
VariationalModule,
194+
module__inducing_points=X_inducing,
195+
criterion__num_data=num_training_samples,
196+
)
197+
198+
gpr.fit(X_train, y_train)
199+
y_pred = gpr.predict(X_train)
200+
201+
As you can see, the variational strategy requires us to use inducing points. We
202+
split off 100 of our training data samples to use as inducing points, assuming
203+
that they are representative of the whole distribution. Apart from this, there
204+
is basically no difference to using exact GP regression.
205+
206+
Finally, skorch also provides :class:`~skorch.probabilistic.GPBinaryClassifier`
207+
for binary classification with GPs. It uses a Bernoulli likelihood by default.
208+
However, using GPs for classification is not very common, GPs are most commonly
209+
used for regression tasks where data points have a known relationship to each
210+
other (e.g. in time series forecasts).
211+
212+
Multiclass classification is not currently provided, but you can use
213+
:class:`~skorch.probabilistic.GPBinaryClassifier` in conjunction with
214+
:class:`~sklearn.multiclass.OneVsRestClassifier` to achieve the same result.
215+
216+
Further examples
217+
----------------
218+
219+
To see all of this in action, we provide a notebook that shows using skorch with GPs on real world data: `Gaussian Processes notebook <https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb)>`_.
220+
221+
.. _GPyTorch: https://gpytorch.ai/
222+
.. _sklearn: https://scikit-learn.org/stable/modules/gaussian_process.html

docs/user/tutorials.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ The following are examples and notebooks on how to use skorch.
2222
* `Seq2Seq Translation using skorch <https://github.com/skorch-dev/skorch/tree/master/examples/translation>`_ - Translation with a seqeuence to sequence network.
2323

2424
* `Advanced Usage <https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb>`_ - Dives deep into the inner works of skorch. `Run in Google Colab 💻 <https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb>`_
25+
26+
* `Gaussian Processes <https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb>`_ - Train Gaussian Processes with the help of GPyTorch `Run in Google Colab 💻 <https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb>`_

0 commit comments

Comments
 (0)