Enhance cross validation #899

Closed
manuelgitgomes opened this issue Mar 26, 2024 · 10 comments

@manuelgitgomes
Collaborator

Enhance cross validation by using scikit-learn functions.

@manuelgitgomes manuelgitgomes added the enhancement New feature or request label Mar 26, 2024
@manuelgitgomes manuelgitgomes self-assigned this Mar 26, 2024
manuelgitgomes added a commit that referenced this issue Apr 24, 2024
Adding cross validation support to batch execution, by creating folds and dividing each run into folds.
@manuelgitgomes
Collaborator Author

Hello @miguelriemoliveira and @Kazadhum.

I have now added cross validation support in batch execution on the branch dev/cross-validation.

The user now defines the type of cross validation and its parameters in data.yaml:

cross_validation:
  type: "stratified-k-fold"
  n_splits: 3 # Number of folds
  train_size: # Percentage of the dataset used for training, only used in StratifiedShuffleSplit

Right now, fold creation is supported with StratifiedKFold, KFold, LeaveOneOut, and StratifiedShuffleSplit from scikit-learn.
The classes used for stratification are the combination of sensors and patterns detected, as written in:

def generateClasses(dataset):
    """
    Generate classes based on an ATOM dataset.
    Classes follow the format detected_pattern--detected_sensor1-detected_sensor2-[...]---[...]
    Args:
        dataset (dict): ATOM dataset.
    Returns:
        tuple: A tuple containing the classes and collection keys.
    """
    classes = []
    collection_keys = list(dataset['collections'].keys())
    for collection_key in collection_keys:
        detected_sensors = []
        class_name = ''
        for pattern_key in dataset['patterns'].keys():
            for sensor_key in dataset['sensors'].keys():
                if dataset['collections'][collection_key]['labels'][pattern_key][sensor_key]['detected']:
                    detected_sensors.append(sensor_key)
            detected_pattern_and_sensors = pattern_key + '--' + '-'.join(detected_sensors)
            class_name += detected_pattern_and_sensors + '---'
        classes.append(class_name.rstrip('---'))
    return classes, collection_keys
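
A minimal sketch of how class labels like these could drive fold creation with the four scikit-learn splitters named above. The helpers `build_splitter` and `make_folds`, the toy labels, and every type string other than "stratified-k-fold" are assumptions made for illustration; this is not the code on dev/cross-validation.

```python
import numpy as np
from sklearn.model_selection import (KFold, LeaveOneOut, StratifiedKFold,
                                     StratifiedShuffleSplit)


def build_splitter(cv_type, n_splits=3, train_size=None):
    """Map a cross_validation type string to a scikit-learn splitter."""
    # Only 'stratified-k-fold' appears in the thread; the other keys are guesses.
    if cv_type == 'stratified-k-fold':
        return StratifiedKFold(n_splits=n_splits)
    if cv_type == 'k-fold':
        return KFold(n_splits=n_splits)
    if cv_type == 'leave-one-out':
        return LeaveOneOut()
    if cv_type == 'stratified-shuffle-split':
        return StratifiedShuffleSplit(n_splits=n_splits, train_size=train_size,
                                      random_state=0)
    raise ValueError(f'Unknown cross validation type: {cv_type}')


def make_folds(classes, collection_keys, splitter):
    """Return one (train_keys, test_keys) pair per fold."""
    # The splitters only need the sample count from X; stratified ones use `classes` as y.
    placeholder = np.zeros(len(collection_keys))
    folds = []
    for train_idx, test_idx in splitter.split(placeholder, classes):
        folds.append(([collection_keys[i] for i in train_idx],
                      [collection_keys[i] for i in test_idx]))
    return folds


# Toy input (hypothetical): six collections, two classes with three members each.
collection_keys = ['0', '1', '2', '3', '4', '5']
classes = ['pattern_1--rgb_left-rgb_right'] * 3 + ['pattern_1--rgb_left'] * 3

folds = make_folds(classes, collection_keys,
                   build_splitter('stratified-k-fold', n_splits=3))
for train_keys, test_keys in folds:
    print('train:', train_keys, 'test:', test_keys)
```

With stratification, each test fold in this toy case holds one collection from each class, so every fold evaluates on the same mix of detections.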

This then creates a new auto_rendered.yaml in which each run is divided into folds, using -csf to select the collections used by each fold:

nig_0.1_run001_fold001:
  cmd: |
    rosrun atom_calibration calibrate -json $ATOM_DATASETS/rrbot/train/dataset.json \
    -v -max_nfev 2 -ss 1 \
    -nig 0.1 0.1 \
    -csf 'lambda x: int(x) in [0, 3]' \
    && \
    rosrun atom_evaluation rgb_to_rgb_evaluation \
    -train_json $ATOM_DATASETS/rrbot/train/atom_calibration.json \
    -test_json $ATOM_DATASETS/rrbot/train/dataset.json \
    -ss rgb_left -st rgb_right \
    -csf 'lambda x: int(x) in [1, 2]' \
    -sfr -sfrn /tmp/rgb_rgb_evaluation.csv
  files_to_collect:
    - '$ATOM_DATASETS/rrbot/train/atom_calibration.json'
    - '$ATOM_DATASETS/rrbot/train/atom_calibration_params.yml'
    - '$ATOM_DATASETS/rrbot/train/command_line_args.yml'
    - '/tmp/rgb_rgb_evaluation.csv'
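
As a rough illustration, the two -csf strings in a fold could be built from that fold's train and test collection keys along these lines. The helper `csf_lambda` is hypothetical, and the sketch assumes collection keys are integer-like strings, as in the command above.

```python
def csf_lambda(collection_keys):
    """Build a -csf selection string that keeps only the given collections."""
    ids = sorted(int(key) for key in collection_keys)
    return f'lambda x: int(x) in {ids}'


# One fold: calibrate on collections 0 and 3, evaluate on 1 and 2.
train_csf = csf_lambda(['0', '3'])
test_csf = csf_lambda(['1', '2'])
print(train_csf)
print(test_csf)
```

Calibration and evaluation receive complementary selections, so no collection used for calibration is also used to score the result.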

process_results is also adapted to run with these folds!

I ran a test with rrbot and everything seemed to work. Can you test on your machines?

@miguelriemoliveira
Member

Hi @manuelgitgomes,

Looks great. Thanks.

@Kazadhum I do not have a lot of time right now. Can you test it please?

One question: if we want to run the old way, is it possible or not?

@Kazadhum
Collaborator

Hi @manuelgitgomes and @miguelriemoliveira! I'll test it as soon as I can. I'll try to report back today.

@Kazadhum
Collaborator

Kazadhum commented Apr 25, 2024

Hi @manuelgitgomes and @miguelriemoliveira!
I just tested with stratified k-folding and it seems to be working correctly at first glance, as does process_results.

BTW, is it still the case that it doesn't make sense to have the collection column in the processed results, since it also averages the collection number?

I can run more thorough tests next week, if necessary.

@manuelgitgomes
Collaborator Author

One question: if we want to run the old way, is it possible or not?

Right now, you can't, but I can change that easily enough (I think). On it.

BTW, is it still the case that it doesn't make sense to have the collection column in the processed results, since it also averages the collection number?

It doesn't, I believe. I can delete it.

@Kazadhum
Collaborator

BTW, is it still the case that it doesn't make sense to have the collection column in the processed results, since it also averages the collection number?

It doesn't, I believe. I can delete it.

Now that I mention it, I think it doesn't make much sense to have any lines at all besides the "Average" line, since all other values are "meaningless" (as they belong to different collections). Do you agree?

@manuelgitgomes
Collaborator Author

Now that I mention it, I think it doesn't make much sense to have any lines at all besides the "Average" line, since all other values are "meaningless" (as they belong to different collections). Do you agree?

In the processed results? Sure, makes sense. They can also be removed, but I don't see anything wrong with leaving them there, right?

manuelgitgomes added a commit that referenced this issue Apr 26, 2024
Changed batch execution and process results to allow for an empty cross validation definition in data.yml.
This guarantees backwards compatibility, which was tested with old_template.yml.j2
Also removed the "Collection #" column from the processed results.
@manuelgitgomes
Collaborator Author

Changed batch execution and process results to allow for an empty cross validation definition in data.yml.
This guarantees backwards compatibility, which was tested with old_template.yml.j2
Also removed the "Collection #" column from the processed results.
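
One way such a fallback might look, sketched under stated assumptions: the helper `experiment_names` is hypothetical, the fold-free name `nig_0.1_run001` is a guess at the old naming, and using `n_splits` as the fold count is a simplification that does not cover LeaveOneOut.

```python
def experiment_names(base_name, num_runs, cross_validation=None):
    """List experiment names, adding a fold suffix only when cross validation is configured."""
    # An absent or empty cross_validation entry means one experiment per run, as before.
    n_folds = (cross_validation or {}).get('n_splits')
    names = []
    for run in range(1, num_runs + 1):
        run_name = f'{base_name}_run{run:03d}'
        if not n_folds:
            names.append(run_name)
            continue
        names.extend(f'{run_name}_fold{fold:03d}' for fold in range(1, n_folds + 1))
    return names


with_cv = experiment_names('nig_0.1', 1, {'type': 'stratified-k-fold', 'n_splits': 3})
without_cv = experiment_names('nig_0.1', 2, None)
print(with_cv)
print(without_cv)
```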

@miguelriemoliveira
Member

Looks great!

manuelgitgomes added a commit that referenced this issue Apr 30, 2024
* #899 Adding cross validation to atom batch execution

Adding cross validation support to batch execution, by creating folds and dividing each run into folds.

* #899 Guaranteeing backwards compatibility in batch execution

Changed batch execution and process results to allow for an empty cross validation definition in data.yml.
This guarantees backwards compatibility, which was tested with old_template.yml.j2
Also removed the "Collection #" column from the processed results.
@manuelgitgomes
Collaborator Author

This has been merged to main, closing.
