-
Notifications
You must be signed in to change notification settings - Fork 22
Introduce masking class and incorporate in TokenizerMasking #383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a few style comments
…sking to use these, then style improvements
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the draft. Let's try to keep it as much as possible as a refactor and introduce new features later; same for some fixes that are not directly related.
The Masker should encapsulate the masking as much as possible.
src/weathergen/datasets/masking.py
Outdated
def mask_source( | ||
self, | ||
tokenized_data: list[torch.Tensor], | ||
rng: np.random.Generator, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be part of the state, potentially passed in the constructor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed thank you - done with new commit
src/weathergen/datasets/masking.py
Outdated
self, | ||
tokenized_data: list[torch.Tensor], | ||
rng: np.random.Generator, | ||
masking_rate: float, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's passed in the constructor. Why is it passed here again? any sampling should also happen in this class
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed
src/weathergen/datasets/masking.py
Outdated
|
||
|
||
class Masker: | ||
"""Class to generate boolean masks for token sequences and apply them. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the class can also be used for BERT-type masking + noising.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated the docstring, remove boolean
src/weathergen/datasets/masking.py
Outdated
if num_tokens == 0: | ||
return tokenized_data, [] | ||
|
||
# Determine the masking rate to use for this call |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove, see above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
src/weathergen/datasets/masking.py
Outdated
|
||
if self.masking_rate_sampling: | ||
rate = np.clip( | ||
np.abs(rng.normal(loc=rate, scale=1.0 / (2.5 * np.pi))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should parametrize this. But better to do it in a separate PR
@@ -336,6 +339,8 @@ def __iter__(self): | |||
|
|||
(ss_cells, ss_lens, ss_centroids) = self.tokenizer.batchify_source( | |||
stream_info, | |||
# NOTE: two unused arguments in TokenizerMasking, | |||
# still used in TokenizerForecast? | |||
self.masking_rate, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, they should be removed
@@ -178,6 +178,7 @@ def tokenize_window_space( | |||
if len(source) < 2: | |||
return | |||
|
|||
# idx_ord_lens is length... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
idx_ord_lens is length is number of tokens per healpix cell
) | ||
for cc, pp in zip(target_tokens_cells, self.perm_sel, strict=True) | ||
] | ||
###################### |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove before we merge
@@ -280,7 +278,6 @@ def id(arg): | |||
) | |||
|
|||
# tokenize | |||
# TODO: properly set stream_id; don't forget to normalize |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This still needs to be done
mask[self.rng.integers(low=0, high=len(mask))] = False | ||
# if masking rate is 1.0, all tokens are masked, so the source is empty | ||
# but we must compute perm_sel for the target function | ||
if masking_rate == 1.0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This case should be handled in the Masker and not here
…ng_rate, update comments, remove archived class
…rom batchify_source
…Masker class, remove handling special cases of masking (all masked)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some minor comments. If it has been tested, then it's good to be merged.
src/weathergen/datasets/masking.py
Outdated
): | ||
self.masking_rate = masking_rate | ||
self.masking_strategy = masking_strategy | ||
# self.masking_combination = masking_combination |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we remove this line before we merge.
# Initialize the random number generator. | ||
worker_info = torch.utils.data.get_worker_info() | ||
div_factor = (worker_info.id + 1) if worker_info is not None else 1 | ||
self.rng = np.random.default_rng(int(time.time() / div_factor)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The rng seed should be passed from the cf.seed, to ensure we can reproduce maskings if we want, while ensuring it is different for the different parallel workers.
I would suggest to keep it as is to not overload the PR but open once merged, open an issue to address the problem.
src/weathergen/datasets/masking.py
Outdated
) | ||
|
||
# Handle the special case where all tokens are masked | ||
# NOTE: not going to handle different streams correctly. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't fully understand the comment. Can you please elaborate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed, my misunderstanding
if self.masking_strategy == "random": | ||
flat_mask = self.rng.uniform(0, 1, num_tokens) < rate | ||
|
||
elif self.masking_strategy == "block": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you visualized this and checked correctness?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, visualised and looks as expected. Just a placeholder for now. Healpix in next PR.
src/weathergen/datasets/masking.py
Outdated
""" | ||
|
||
# check that self.perm_sel is set with an assert statement | ||
assert hasattr(self, 'perm_sel'), "Masker.perm_sel must be set (in mask_source) before calling mask_target." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You set it in the constructor (to None) so the hasattr will alway be true. You need to test for not-None
Note, this PR does introduce a new config parameter, and hence other developers will be notified. include a masking strategy here, currently only supporting "random" and "block" |
Description
WIP: create and instantiate a Masker class which implements masking strategies and is called in TokenizerMasking to produce masked source and target tokens. This class should allow the implementation of different masking strategies (e.g. random, per healpix cell, inpainting etc.) and should ensure the masking of source and target are properly aligned.
Current questions:
Type of Change
Issue Number
Resolves #380
Resolves #408
Code Compatibility
Code Performance and Testing
uv run train
and (if necessary)uv run evaluate
on a least one GPU node and it works$WEATHER_GENERATOR_PRIVATE
directoryDependencies
Documentation
Additional Notes