Skip to content

Commit bd706ae

Browse files
authored
Update usage.rst
1 parent ca64b2a commit bd706ae

File tree

1 file changed

+36
-29
lines changed

1 file changed

+36
-29
lines changed

docs/source/usage.rst

+36-29
Original file line numberDiff line numberDiff line change
@@ -1207,44 +1207,47 @@ Putting all previous snippet examples together, we obtain the following pipeline
12071207

12081208
# 1. Define a CEBRA model
12091209
cebra_model = cebra.CEBRA(
1210-
model_architecture = "offset10-model",
1211-
batch_size = 512,
1212-
learning_rate = 1e-4,
1213-
temperature_mode='constant',
1214-
temperature = 0.1,
1215-
max_iterations = 10, # TODO(user): to change to ~500-10000 depending on dataset size
1216-
#max_adapt_iterations = 10, # TODO(user): use and to change to ~100-500 if adapting
1217-
time_offsets = 10,
1218-
output_dimension = 8,
1219-
verbose = False
1210+
model_architecture = "offset10-model",
1211+
batch_size = 512,
1212+
learning_rate = 1e-4,
1213+
temperature_mode='constant',
1214+
temperature = 0.1,
1215+
max_iterations = 10, # TODO(user): to change to ~500-10000 depending on dataset size
1216+
#max_adapt_iterations = 10, # TODO(user): use and to change to ~100-500 if adapting
1217+
time_offsets = 10,
1218+
output_dimension = 8,
1219+
verbose = False
12201220
)
1221-
1221+
12221222
# 2. Load example data
12231223
neural_data = cebra.load_data(file="neural_data.npz", key="neural")
12241224
new_neural_data = cebra.load_data(file="neural_data.npz", key="new_neural")
12251225
continuous_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["continuous1", "continuous2", "continuous3"])
12261226
discrete_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"]).flatten()
1227-
1227+
1228+
12281229
assert neural_data.shape == (100, 3)
12291230
assert new_neural_data.shape == (100, 4)
12301231
assert discrete_label.shape == (100, )
12311232
assert continuous_label.shape == (100, 3)
1232-
1233+
12331234
# 3. Split data and labels into train/validation
12341235
from sklearn.model_selection import train_test_split
1235-
1236-
split_idx = int(0.8 * len(neural_data))
1236+
1237+
split_idx = int(0.8 * len(neural_data))
12371238
# suggestion: 5%-20% depending on your dataset size; note that this splits the
12381239
# into an early and late part, which might not be ideal for your data/experiment!
12391240
# As a more involved alternative, consider e.g. a nested time-series split.
1240-
1241+
12411242
train_data = neural_data[:split_idx]
12421243
valid_data = neural_data[split_idx:]
1243-
1244+
12441245
train_continuous_label = continuous_label[:split_idx]
12451246
valid_continuous_label = continuous_label[split_idx:]
1246-
1247-
1247+
1248+
train_discrete_label = discrete_label[:split_idx]
1249+
valid_discrete_label = discrete_label[split_idx:]
1250+
12481251
# 4. Fit the model
12491252
# time contrastive learning
12501253
cebra_model.fit(train_data)
@@ -1254,32 +1257,36 @@ Putting all previous snippet examples together, we obtain the following pipeline
12541257
cebra_model.fit(train_data, train_continuous_label)
12551258
# mixed behavior contrastive learning
12561259
cebra_model.fit(train_data, train_discrete_label, train_continuous_label)
1257-
1260+
1261+
12581262
# 5. Save the model
12591263
tmp_file = Path(tempfile.gettempdir(), 'cebra.pt')
12601264
cebra_model.save(tmp_file)
1261-
1265+
12621266
# 6. Load the model and compute an embedding
12631267
cebra_model = cebra.CEBRA.load(tmp_file)
12641268
train_embedding = cebra_model.transform(train_data)
12651269
valid_embedding = cebra_model.transform(valid_data)
1266-
assert train_embedding.shape == (70, 8) # TODO(user): change to split ratio & output dim
1267-
assert valid_embedding.shape == (30, 8) # TODO(user): change to split ratio & output dim
1268-
1270+
1271+
assert train_embedding.shape == (80, 8) # TODO(user): change to split ratio & output dim
1272+
assert valid_embedding.shape == (20, 8) # TODO(user): change to split ratio & output dim
1273+
12691274
# 7. Evaluate the model performance (you can also check the train_data)
1270-
goodness_of_fit = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model,
1271-
valid_data,
1275+
goodness_of_fit = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model,
1276+
valid_data,
12721277
valid_discrete_label,
12731278
valid_continuous_label)
1274-
1279+
12751280
# 8. Adapt the model to a new session
12761281
cebra_model.fit(new_neural_data, adapt = True)
1277-
1282+
12781283
# 9. Decode discrete labels behavior from the embedding
12791284
decoder = cebra.KNNDecoder()
12801285
decoder.fit(train_embedding, train_discrete_label)
12811286
prediction = decoder.predict(valid_embedding)
1282-
assert prediction.shape == (30,)
1287+
assert prediction.shape == (20,)
1288+
1289+
12831290

12841291
👉 For further guidance on different/customized applications of CEBRA on your own data, refer to the ``examples/`` folder or to the full documentation folder ``docs/``.
12851292

0 commit comments

Comments
 (0)