
Introduce masking class and incorporate in TokenizerMasking #383


Merged
merged 29 commits into develop from shmh40/dev/masking_class on Jun 28, 2025

Conversation

@shmh40 (Contributor) commented Jun 24, 2025

Description

WIP: create and instantiate a Masker class that implements masking strategies and is called from TokenizerMasking to produce masked source and target tokens. The class should allow different masking strategies (e.g. random, per-healpix-cell, inpainting) and should ensure that the source and target masks are properly aligned. A rough interface sketch is included after the questions below.

Current questions:

  1. The Masker returns the masked source_data and the mask perm_sel. This perm_sel is then used in tokenizer_masking, outside the Masker. Perhaps both the source data and the target data should be produced by this Masker class? We could, in Masker.mask, use ~perm_sel to produce the masked target data when a target is requested.
  2. I haven't yet worked out how this should interact with different streams, if we want to mask a whole stream and predict it from another stream.
  3. The second masking strategy, "block", is just a placeholder; ideally it would be implemented with healpix cells. I still need to work out how to use the healpix cells in this Masker class: just pass them in (produced in TokenizerMasking?) and use them?
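For context, here is a minimal sketch of the interface being discussed; the constructor signature, method names, and per-cell handling are assumptions for illustration, not the code in this PR:

```python
import numpy as np
import torch


class Masker:
    """Generate per-token masks and apply them to source and target tokens."""

    def __init__(
        self,
        masking_rate: float,
        masking_strategy: str = "random",
        rng: np.random.Generator | None = None,
    ):
        self.masking_rate = masking_rate
        self.masking_strategy = masking_strategy
        self.rng = rng if rng is not None else np.random.default_rng()
        self.perm_sel: list[torch.Tensor] | None = None  # kept so source/target stay aligned

    def mask_source(self, tokenized_data: list[torch.Tensor]) -> list[torch.Tensor]:
        # One boolean entry per token; True means "removed from the source".
        self.perm_sel = [
            torch.from_numpy(self.rng.uniform(0.0, 1.0, len(cell)) < self.masking_rate)
            for cell in tokenized_data
        ]
        return [cell[~sel] for cell, sel in zip(tokenized_data, self.perm_sel)]

    def mask_target(self, tokenized_data: list[torch.Tensor]) -> list[torch.Tensor]:
        # The complement of the source mask, so no token appears in both.
        assert self.perm_sel is not None, "mask_source must be called before mask_target"
        return [cell[sel] for cell, sel in zip(tokenized_data, self.perm_sel)]
```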

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

Issue Number

Resolves #380
Resolves #408

Code Compatibility

  • I have performed a self-review of my code

Code Performance and Testing

  • I ran uv run train and (if necessary) uv run evaluate on at least one GPU node and it works
  • If the new feature introduces modifications at the config level, I have made sure to notify the other software developers through Mattermost and to update the paths in the $WEATHER_GENERATOR_PRIVATE directory

Dependencies

  • I have ensured that the code is still pip-installable after the changes and runs
  • I have tested that new dependencies themselves are pip-installable.
  • I have not introduced new dependencies in the inference portion of the pipeline

Documentation

  • My code follows the style guidelines of this project
  • I have updated the documentation and docstrings to reflect the changes
  • I have added comments to my code, particularly in hard-to-understand areas

Additional Notes

@shmh40 shmh40 requested a review from clessig June 24, 2025 13:52
@shmh40 shmh40 self-assigned this Jun 24, 2025
@shmh40 shmh40 added the enhancement label (New feature or request) Jun 24, 2025
@tjhunter (Collaborator) left a comment:

a few style comments

@shmh40 shmh40 marked this pull request as draft June 24, 2025 14:54
@clessig (Collaborator) left a comment:

Thanks for the draft. Let's try to keep it as much as possible as a refactor and introduce new features later; same for some fixes that are not directly related.

The Masker should encapsulate the masking as much as possible.

def mask_source(
    self,
    tokenized_data: list[torch.Tensor],
    rng: np.random.Generator,
Collaborator:
This should be part of the state, potentially passed in the constructor

@shmh40 (Contributor, author):
Agreed thank you - done with new commit

    self,
    tokenized_data: list[torch.Tensor],
    rng: np.random.Generator,
    masking_rate: float,
Collaborator:
It's passed in the constructor. Why is it passed here again? Any sampling should also happen in this class.

@shmh40 (Contributor, author):
Agreed



class Masker:
    """Class to generate boolean masks for token sequences and apply them.
Collaborator:
I think the class can also be used for BERT-type masking + noising.

@shmh40 (Contributor, author) commented Jun 25, 2025:
Updated the docstring, removed "boolean".

if num_tokens == 0:
    return tokenized_data, []

# Determine the masking rate to use for this call
Collaborator:
Remove, see above

@shmh40 (Contributor, author):
Done


if self.masking_rate_sampling:
    rate = np.clip(
        np.abs(rng.normal(loc=rate, scale=1.0 / (2.5 * np.pi))),
Collaborator:
We should parametrize this. But better to do it in a separate PR
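For illustration, such a parametrization might look like the sketch below; the helper name sample_masking_rate and the scale parameter are assumptions, not part of this PR:

```python
import numpy as np

# Hypothetical sketch: lift the hard-coded 1.0 / (2.5 * np.pi) into a parameter.
DEFAULT_RATE_SCALE = 1.0 / (2.5 * np.pi)


def sample_masking_rate(
    rng: np.random.Generator, base_rate: float, scale: float = DEFAULT_RATE_SCALE
) -> float:
    """Draw a per-sample masking rate around base_rate, clipped to [0, 1]."""
    return float(np.clip(np.abs(rng.normal(loc=base_rate, scale=scale)), 0.0, 1.0))


rate = sample_masking_rate(np.random.default_rng(0), base_rate=0.6)
```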

@@ -336,6 +339,8 @@ def __iter__(self):

(ss_cells, ss_lens, ss_centroids) = self.tokenizer.batchify_source(
    stream_info,
    # NOTE: two unused arguments in TokenizerMasking,
    # still used in TokenizerForecast?
    self.masking_rate,
Collaborator:
No, they should be removed

@@ -178,6 +178,7 @@ def tokenize_window_space(

if len(source) < 2:
    return

# idx_ord_lens is length...
Collaborator:
idx_ord_lens is the number of tokens per healpix cell

)
for cc, pp in zip(target_tokens_cells, self.perm_sel, strict=True)
]
######################
Collaborator:
Remove before we merge

@@ -280,7 +278,6 @@ def id(arg):
)

# tokenize
# TODO: properly set stream_id; don't forget to normalize
Collaborator:
This still needs to be done

mask[self.rng.integers(low=0, high=len(mask))] = False
# if masking rate is 1.0, all tokens are masked, so the source is empty
# but we must compute perm_sel for the target function
if masking_rate == 1.0:
Collaborator:
This case should be handled in the Masker and not here
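For illustration, one way the edge case could live inside the Masker; the helper name generate_source_mask and the return convention are assumptions:

```python
import numpy as np


def generate_source_mask(rng: np.random.Generator, num_tokens: int, rate: float) -> np.ndarray:
    """Return a boolean mask; True marks tokens removed from the source."""
    if rate >= 1.0:
        # Everything is masked: the source ends up empty, but the full mask is
        # still returned so perm_sel remains valid for building the target.
        return np.ones(num_tokens, dtype=bool)
    return rng.uniform(0.0, 1.0, num_tokens) < rate
```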

@clessig clessig moved this to In Progress in WeatherGen-dev Jun 24, 2025
@clessig (Collaborator) left a comment:

Just some minor comments. If it has been tested, then it's good to be merged.

):
    self.masking_rate = masking_rate
    self.masking_strategy = masking_strategy
    # self.masking_combination = masking_combination
Collaborator:
Can we remove this line before we merge?

# Initialize the random number generator.
worker_info = torch.utils.data.get_worker_info()
div_factor = (worker_info.id + 1) if worker_info is not None else 1
self.rng = np.random.default_rng(int(time.time() / div_factor))
Collaborator:
The rng seed should be passed from the cf.seed, to ensure we can reproduce maskings if we want, while ensuring it is different for the different parallel workers.

I would suggest keeping it as is so as not to overload the PR and, once merged, opening an issue to address the problem.
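One possible shape for that follow-up issue; the use of np.random.SeedSequence and the helper name make_worker_rng are assumptions, not an agreed design:

```python
import numpy as np
import torch


def make_worker_rng(base_seed: int) -> np.random.Generator:
    """Derive a reproducible, per-worker generator from a single config seed (e.g. cf.seed)."""
    worker_info = torch.utils.data.get_worker_info()
    worker_id = worker_info.id if worker_info is not None else 0
    # spawn() yields statistically independent child streams per worker.
    children = np.random.SeedSequence(base_seed).spawn(worker_id + 1)
    return np.random.default_rng(children[worker_id])
```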

)

# Handle the special case where all tokens are masked
# NOTE: not going to handle different streams correctly.
Collaborator:
I don't fully understand the comment. Can you please elaborate.

@shmh40 (Contributor, author):
Removed, my misunderstanding

if self.masking_strategy == "random":
    flat_mask = self.rng.uniform(0, 1, num_tokens) < rate

elif self.masking_strategy == "block":
Collaborator:
Have you visualized this and checked correctness?

@shmh40 (Contributor, author):
Yes, visualised and looks as expected. Just a placeholder for now. Healpix in next PR.
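Purely as an illustration of a contiguous-block placeholder (an assumed sketch, not the PR's actual code; the healpix-cell version is deferred to a follow-up PR):

```python
import numpy as np


def block_mask(rng: np.random.Generator, num_tokens: int, rate: float) -> np.ndarray:
    """Mask one contiguous run of tokens covering roughly `rate` of the sequence."""
    block_len = int(round(rate * num_tokens))
    start = int(rng.integers(0, num_tokens - block_len + 1)) if block_len < num_tokens else 0
    mask = np.zeros(num_tokens, dtype=bool)
    mask[start : start + block_len] = True
    return mask
```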

"""

# check that self.perm_sel is set with an assert statement
assert hasattr(self, 'perm_sel'), "Masker.perm_sel must be set (in mask_source) before calling mask_target."
Collaborator:
You set it in the constructor (to None), so the hasattr will always be true. You need to test for not-None.
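The suggested check would look roughly like this (a sketch of the reviewer's point, not the merged code):

```python
# perm_sel is initialized to None in __init__, so hasattr() is always true;
# test the value instead.
assert self.perm_sel is not None, (
    "Masker.perm_sel must be set (in mask_source) before calling mask_target."
)
```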

@shmh40 shmh40 marked this pull request as ready for review June 27, 2025 15:25
@shmh40 (Contributor, author) commented Jun 27, 2025:

Note, this PR does introduce a new config parameter, and hence other developers will be notified.

# include a masking strategy here, currently only supporting "random" and "block"
masking_strategy: "random"

@clessig clessig merged commit 6f831c3 into develop Jun 28, 2025
3 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in WeatherGen-dev Jun 28, 2025
@clessig clessig deleted the shmh40/dev/masking_class branch June 28, 2025 12:31
Labels: enhancement (New feature or request)
Projects: WeatherGen-dev (Status: Done)
Development: successfully merging this pull request may close these issues:
  • RNGs in tokenizers are not properly initialized
  • Masking class and separation of tokenizer_masking
3 participants