Commit eb75fc5

Update usage.rst
- update suggestion on data split
1 parent 191d4e5 commit eb75fc5

1 file changed, +13 -12 lines changed

docs/source/usage.rst

+13 -12
@@ -1230,18 +1230,19 @@ Putting all previous snippet examples together, we obtain the following pipeline
 assert discrete_label.shape == (100, )
 assert continuous_label.shape == (100, 3)
 
-# 3. Split data and labels
-(
-    train_data,
-    valid_data,
-    train_discrete_label,
-    valid_discrete_label,
-    train_continuous_label,
-    valid_continuous_label,
-) = train_test_split(neural_data,
-                     discrete_label,
-                     continuous_label,
-                     test_size=0.3)
+# 3. Split data and labels into train/validation
+
+from sklearn.model_selection import train_test_split
+
+split_idx = int(0.8 * len(neural_data))
+# Suggested: hold out 5%-20% for validation, depending on your dataset size. Note that this
+# also splits the data into early/late parts, which might not be ideal for your data/experiment!
+
+train_data = neural_data[:split_idx]
+valid_data = neural_data[split_idx:]
+
+train_continuous_label = continuous_label[:split_idx]
+valid_continuous_label = continuous_label[split_idx:]
 
 # 4. Fit the model
 # time contrastive learning
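
The index-based split in this change keeps samples in their recorded order, so the validation set is always the latest part of the recording. As a rough sketch of that idea applied consistently to the data and both label arrays, the snippet below uses a hypothetical helper, `chronological_split`, which is not part of any library. The arrays are random placeholders: the label shapes follow the asserts in the diff context, while the feature count of 30 is an assumption made only for illustration.

```python
import numpy as np


def chronological_split(*arrays, train_fraction=0.8):
    """Hypothetical helper: cut every array at the same index.

    Earlier samples go to the training part, later samples to validation.
    """
    n_samples = len(arrays[0])
    if any(len(a) != n_samples for a in arrays):
        raise ValueError("All arrays must have the same number of samples.")
    split_idx = int(train_fraction * n_samples)
    # One (train, valid) pair per input array, in input order.
    return [(a[:split_idx], a[split_idx:]) for a in arrays]


# Placeholder data; only the label shapes come from the asserts in the diff.
rng = np.random.default_rng(0)
neural_data = rng.normal(size=(100, 30))  # feature count is assumed
discrete_label = rng.integers(0, 10, size=(100,))
continuous_label = rng.normal(size=(100, 3))

(
    (train_data, valid_data),
    (train_discrete_label, valid_discrete_label),
    (train_continuous_label, valid_continuous_label),
) = chronological_split(neural_data, discrete_label, continuous_label)

print(train_data.shape, valid_data.shape)
```

By contrast, the removed `train_test_split` call shuffles samples by default, which mixes early and late samples across both sets. Which behaviour is preferable depends on the experiment, as the new comment in the diff points out.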
