docs/source/usage.rst (+22 −18)
@@ -1,7 +1,7 @@
 Using CEBRA
 ===========

-This page covers a standard CEBRA usage. We recommend checking out the :py:doc:`demos` for in-depth CEBRA usage examples as well. Here we present a quick overview on how to use CEBRA on various datasets. Note that we provide two ways to interact with the code:
+This page covers a standard CEBRA usage. We recommend checking out the :py:doc:`demos` for CEBRA usage examples as well. Here we present a quick overview on how to use CEBRA on various datasets. Note that we provide two ways to interact with the code:

 * For regular usage, we recommend leveraging the **high-level interface**, adhering to ``scikit-learn`` formatting.
 * Upon specific needs, advanced users might consider diving into the **low-level interface** that adheres to ``PyTorch`` formatting.
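For orientation, here is a minimal sketch of the high-level, ``scikit-learn``-style interface described above. The array ``neural_data`` is a placeholder for your own (time × neurons) recording, and the hyperparameter values are illustrative assumptions rather than recommendations.

.. code:: python

    import numpy as np
    import cebra

    # Placeholder data: 1000 time points, 30 neurons (replace with your recording).
    neural_data = np.random.normal(size=(1000, 30))

    # High-level, scikit-learn style interface: construct, fit, transform.
    cebra_model = cebra.CEBRA(
        model_architecture="offset10-model",
        batch_size=512,
        max_iterations=10,   # illustrative only; see the recommendations below
        output_dimension=8,
    )
    cebra_model.fit(neural_data)              # CEBRA-Time when no labels are passed
    embedding = cebra_model.transform(neural_data)
    print(embedding.shape)                    # (1000, 8)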
@@ -12,7 +12,7 @@ Firstly, why use CEBRA?

 CEBRA is primarily designed for producing robust, consistent extractions of latent factors from time-series data. It supports three modes, and is a self-supervised representation learning algorithm that uses our modified contrastive learning approach designed for multi-modal time-series data. In short, it is a type of non-linear dimensionality reduction, like `tSNE <https://www.jmlr.org/papers/v9/vandermaaten08a.html>`_ and `UMAP <https://arxiv.org/abs/1802.03426>`_. We show in our original paper that it outperforms tSNE and UMAP at producing closer-to-ground-truth latents and is more consistent.

-That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see our paper (Schneider, Lee, Mathis, 2023).
+That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see `Schneider, Lee, Mathis. Nature 2023 <https://www.nature.com/articles/s41586-023-06031-6>`_.

 The CEBRA workflow
 ------------------
@@ -22,7 +22,7 @@ We recommend to start with running CEBRA-Time (unsupervised) and look both at th

 (1) Use CEBRA-Time for unsupervised data exploration.
 (2) Consider running a hyperparameter sweep on the inputs to the model, such as :py:attr:`cebra.CEBRA.model_architecture`, :py:attr:`cebra.CEBRA.time_offsets`, :py:attr:`cebra.CEBRA.output_dimension`, and set :py:attr:`cebra.CEBRA.batch_size` to be as high as your GPU allows. You want to see clear structure in the 3D plot (the first 3 latents are shown by default).
-(3) Use CEBRA-Behavior with many different labels and combinations, then look at the InfoNCE loss - the lower the loss value, the better the fit (see :py:doc:`cebra-figures/figures/ExtendedDataFigure5`), and visualize the embeddings. The goal is to understand which labels are contributing to the structure you see in CEBRA-Time, and improve this structure. Again, you should consider a hyperparameter sweep.
+(3) Use CEBRA-Behavior with many different labels and combinations, then look at the InfoNCE loss - the lower the loss value, the better the fit (see :py:doc:`cebra-figures/figures/ExtendedDataFigure5`), and visualize the embeddings. The goal is to understand which labels are contributing to the structure you see in CEBRA-Time, and improve this structure. Again, you should consider a hyperparameter sweep (and avoid overfitting by performing a proper train/validation split; see Step 3 in our quick start guide below).
 (4) Interpretability: now you can use these latents in downstream tasks, such as measuring consistency, decoding, and determining the dimensionality of your data with topological data analysis.

 All the steps to do this are described below. Enjoy using CEBRA! 🔥🦓
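To make steps (1) and (3) concrete, below is a minimal, schematic sketch contrasting an unsupervised CEBRA-Time fit with a label-guided CEBRA-Behavior fit. ``neural_data`` and ``behavior_labels`` are placeholders for your own arrays, and the settings shown are illustrative rather than tuned.

.. code:: python

    import numpy as np
    import cebra

    neural_data = np.random.normal(size=(1000, 30))       # (time, neurons) placeholder
    behavior_labels = np.random.uniform(0, 1, (1000, 1))   # continuous label placeholder

    # (1) CEBRA-Time: fit without labels for unsupervised exploration.
    time_model = cebra.CEBRA(model_architecture="offset10-model",
                             batch_size=512, max_iterations=10, output_dimension=3)
    time_model.fit(neural_data)

    # (3) CEBRA-Behavior: pass labels to test which variables shape the structure.
    behavior_model = cebra.CEBRA(model_architecture="offset10-model",
                                 batch_size=512, max_iterations=10, output_dimension=3)
    behavior_model.fit(neural_data, behavior_labels)

    time_embedding = time_model.transform(neural_data)
    behavior_embedding = behavior_model.transform(neural_data)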
@@ -179,7 +179,7 @@ We provide a set of pre-defined models. You can access (and search) a list of av

 Then, you can choose the one that fits best with your needs and provide it to the CEBRA model as the :py:attr:`~.CEBRA.model_architecture` parameter.

-As an indication the table below presents the model architecture we used to train CEBRA on the datasets presented in our paper (Schneider, Lee, Mathis, 2022).
+As an indication, the table below presents the model architecture we used to train CEBRA on the datasets presented in our paper (Schneider, Lee, Mathis. Nature 2023).

 .. list-table::
    :widths: 25 25 20 30
@@ -265,9 +265,8 @@ For standard usage we recommend the default values (i.e., ``InfoNCE`` and ``cosi

 .. rubric:: Temperature :py:attr:`~.CEBRA.temperature`

-:py:attr:`~.CEBRA.temperature` has the largest effect on visualization of the embedding (see :py:doc:`cebra-figures/figures/ExtendedDataFigure2`). Hence, it is important that it is fitted to your specific data.
+:py:attr:`~.CEBRA.temperature` has the largest effect on *visualization* of the embedding (see :py:doc:`cebra-figures/figures/ExtendedDataFigure2`). Hence, it is important that it is fitted to your specific data. Lower temperatures (e.g., around 0.1) will result in a more dispersed embedding, while higher temperatures (larger than 1) will concentrate the embedding.

-The simplest way to handle it is to use a *learnable temperature*. For that, set :py:attr:`~.CEBRA.temperature_mode` to ``auto``. :py:attr:`~.CEBRA.temperature` will be trained alongside the model.

 🚀 For advance usage, you might need to find the optimal :py:attr:`~.CEBRA.temperature`. For that we recommend to perform a grid-search.

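A minimal sketch of such a grid search over fixed temperatures is shown below: it simply fits one model per candidate value and keeps the resulting embeddings for comparison (visually, or via the loss curves). The candidate values, data, and other settings are illustrative assumptions.

.. code:: python

    import numpy as np
    import cebra

    neural_data = np.random.normal(size=(1000, 30))  # placeholder (time, neurons)

    embeddings = {}
    for temperature in (0.1, 0.5, 1.0, 2.0):  # illustrative candidate values
        model = cebra.CEBRA(
            model_architecture="offset10-model",
            batch_size=512,
            temperature_mode="constant",
            temperature=temperature,
            max_iterations=10,   # increase for a real run
            output_dimension=3,
        )
        model.fit(neural_data)
        embeddings[temperature] = model.transform(neural_data)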
@@ -307,7 +306,6 @@ Here is an example of a CEBRA model initialization:

 cebra_model = CEBRA(
     model_architecture = "offset10-model",
     batch_size = 1024,
-    temperature_mode="auto",
     learning_rate = 0.001,
     max_iterations = 10,
     time_offsets = 10,
@@ -321,8 +319,7 @@ Here is an example of a CEBRA model initialization:
@@ -568,7 +565,8 @@ We provide a simple hyperparameters sweep to compare CEBRA models with different
     learning_rate = [0.001],
     time_offsets = 5,
     max_iterations = 5,
-    temperature_mode = "auto",
+    temperature_mode = "constant",
+    temperature = 0.1,
     verbose = False)

 # 2. Define the datasets to iterate over
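The snippet above is truncated in this view. As a hedged illustration of the same idea, written as a plain manual loop rather than CEBRA's built-in sweep helper, the sketch below fits one model per dataset and per parameter combination so that the resulting embeddings and loss curves can be compared. All data and parameter values are placeholders.

.. code:: python

    import itertools
    import numpy as np
    import cebra

    # Placeholder datasets to iterate over (replace with your own recordings).
    datasets = {
        "dataset1": np.random.normal(size=(1000, 30)),
        "dataset2": np.random.normal(size=(1000, 40)),
    }

    # Parameter values to sweep (illustrative assumptions).
    output_dimensions = [3, 8]
    temperatures = [0.1, 1.0]

    models = {}
    for name, data in datasets.items():
        for output_dimension, temperature in itertools.product(output_dimensions, temperatures):
            model = cebra.CEBRA(
                model_architecture="offset10-model",
                batch_size=512,
                learning_rate=0.001,
                temperature_mode="constant",
                temperature=temperature,
                time_offsets=5,
                max_iterations=5,        # illustrative; increase for a real sweep
                output_dimension=output_dimension,
                verbose=False,
            )
            model.fit(data)
            models[(name, output_dimension, temperature)] = model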
@@ -820,7 +818,7 @@ It takes a CEBRA model and returns a 2D plot of the loss against the number of i
 Displaying the temperature
 """"""""""""""""""""""""""

-:py:attr:`~.CEBRA.temperature` has the largest effect on the visualization of the embedding. Hence it might be interesting to check its evolution when ``temperature_mode=auto``.
+:py:attr:`~.CEBRA.temperature` has the largest effect on the visualization of the embedding. Hence it might be interesting to check its evolution when ``temperature_mode=auto``. We recommend only using ``auto`` if you have first explored the ``constant`` setting. If you use the ``auto`` mode, please always check the evolution of the temperature over time alongside the loss curve.

 To that extend, you can use the function :py:func:`~.plot_temperature`.

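As a hedged illustration of that check, the sketch below fits a model with a learnable temperature and then displays the temperature evolution next to the loss curve. The exposure of the plotting helpers as ``cebra.plot_temperature`` and ``cebra.plot_loss`` is assumed here, following the plotting functions referenced in this guide.

.. code:: python

    import numpy as np
    import cebra

    neural_data = np.random.normal(size=(1000, 30))  # placeholder (time, neurons)

    auto_model = cebra.CEBRA(
        model_architecture="offset10-model",
        batch_size=512,
        temperature_mode="auto",   # learnable temperature
        max_iterations=10,         # illustrative only
        output_dimension=3,
    )
    auto_model.fit(neural_data)

    # Check the temperature evolution alongside the loss curve.
    # NOTE: `cebra.plot_temperature` / `cebra.plot_loss` are assumed entry points;
    # adjust the import path to your installed CEBRA version if needed.
    cebra.plot_temperature(auto_model)
    cebra.plot_loss(auto_model)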
@@ -1186,9 +1184,10 @@ Improve model performance
 🧐 Below is a (non-exhaustive) list of actions you can try if your embedding looks different from what you were expecting.

 #. Assess that your model `converged <https://machine-learning.paperspace.com/wiki/convergence>`_. For that, observe if the training loss stabilizes itself around the end of the training or still seems to be decreasing. Refer to `Visualize the training loss`_ for more details on how to display the training loss.
-#. Increase the number of iterations. It should be at least 10,000.
+#. Increase the number of iterations. It should typically be at least 10,000. On small datasets, it can make sense to stop training earlier to avoid overfitting effects.
 #. Make sure the batch size is big enough. It should be at least 512.
 #. Fine-tune the model's hyperparameters, namely ``learning_rate``, ``output_dimension``, ``num_hidden_units`` and eventually ``temperature`` (by setting ``temperature_mode`` back to ``constant``). Refer to `Grid search`_ for more details on performing hyperparameters tuning.
+#. Note that you should still be mindful of performing train/validation splits and shuffle controls to avoid `overfitting <https://developers.google.com/machine-learning/crash-course/overfitting/overfitting>`_ (a minimal sketch follows after this list).

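To make the last point concrete, here is a minimal, hedged sketch of a train/validation split and a shuffle control for a label-guided CEBRA fit. The data, labels, and settings are placeholders, and the comparison criterion (for example, the validation embedding or a decoding score) is up to you.

.. code:: python

    import numpy as np
    from sklearn.model_selection import train_test_split
    import cebra

    neural_data = np.random.normal(size=(1000, 30))       # placeholder (time, neurons)
    behavior_labels = np.random.uniform(0, 1, (1000, 1))   # placeholder labels

    # Train/validation split (no shuffling, to respect the time-series structure).
    (train_data, valid_data,
     train_labels, valid_labels) = train_test_split(neural_data, behavior_labels,
                                                    test_size=0.2, shuffle=False)

    model = cebra.CEBRA(model_architecture="offset10-model", batch_size=512,
                        max_iterations=10, output_dimension=8)
    model.fit(train_data, train_labels)
    valid_embedding = model.transform(valid_data)

    # Shuffle control: refit with permuted labels; structure that survives this
    # control is not driven by the labels.
    shuffled_labels = np.random.permutation(train_labels)
    control_model = cebra.CEBRA(model_architecture="offset10-model", batch_size=512,
                                max_iterations=10, output_dimension=8)
    control_model.fit(train_data, shuffled_labels)
    control_embedding = control_model.transform(valid_data)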
@@ -1202,14 +1201,19 @@ Putting all previous snippet examples together, we obtain the following pipeline
 import cebra
 from numpy.random import uniform, randint
 from sklearn.model_selection import train_test_split
+import os
+import tempfile
+from pathlib import Path

 # 1. Define a CEBRA model
 cebra_model = cebra.CEBRA(
     model_architecture = "offset10-model",
     batch_size = 512,
     learning_rate = 1e-4,
-    max_iterations = 10, # TODO(user): to change to at least 10'000
-    max_adapt_iterations = 10, # TODO(user): to change to ~100-500
+    temperature_mode = "constant",
+    temperature = 0.1,
+    max_iterations = 10, # TODO(user): change to ~500-10000 depending on dataset size
+    # max_adapt_iterations = 10, # TODO(user): uncomment and change to ~100-500 if adapting
     time_offsets = 10,
     output_dimension = 8,
     verbose = False
@@ -1243,7 +1247,7 @@ Putting all previous snippet examples together, we obtain the following pipeline