Skip to content

Commit 5fbc9b1

Browse files
committed
further linting of docstrings
1 parent 08d676f commit 5fbc9b1

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

src/weathergen/datasets/masking.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ class Masker:
7 7
This class supports different masking strategies and combinations.
8 8
Attributes:
9 9
masking_rate (float): The base rate at which tokens are masked.
10 -
masking_strategy (str): The strategy used for masking (e.g., "random", "block", MORE TO BE IMPLEMENTED...).
11 -
TO BE IMPLEMENTED: masking_combination (str): The strategy for combining masking strategies through training (e.g., "sequential").
10 +
masking_strategy (str): The strategy used for masking (e.g., "random",
11 +
"block", MORE TO BE IMPLEMENTED...).
12 +
TO BE IMPLEMENTED: masking_combination (str): The strategy for combining masking
13 +
strategies through training (e.g., "sequential").
12 14
masking_rate_sampling (bool): Whether to sample the masking rate from a distribution.
13 15
"""
14 16

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,8 @@ def __iter__(self):
339 339

340 340
(ss_cells, ss_lens, ss_centroids) = self.tokenizer.batchify_source(
341 341
stream_info,
342 -
# NOTE: two unused arguments in TokenizerMasking, still used in TokenizerForecast?
342 +
# NOTE: two unused arguments in TokenizerMasking,
343 +
# still used in TokenizerForecast?
343 344
self.masking_rate,
344 345
self.masking_rate_sampling,
345 346
torch.from_numpy(rdata.coords),

src/weathergen/datasets/tokenizer_masking.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ def batchify_source(
198 198
if masking_rate == 1.0:
199 199
token_lens = [len(t) for t in tokenized_data]
200 200
self.perm_sel = [np.ones(l, dtype=bool) for l in token_lens]
201 -
source_tokens_cells = [c[~p] for c, p in zip(tokenized_data, self.perm_sel, strict=False)]
201 +
source_tokens_cells = [
202 +
c[~p] for c, p in zip(tokenized_data, self.perm_sel, strict=False)
203 +
]
202 204
source_tokens_lens = torch.zeros([self.num_healpix_cells_source], dtype=torch.int32)
203 205
source_centroids = torch.tensor([])
204 206
return (source_tokens_cells, source_tokens_lens, source_centroids)
@@ -281,10 +283,12 @@ def id(arg):
281 283
)
282 284

283 285
# --- MODIFICATION START ---
284 -
# The following block is modified to handle cases where a cell has no target tokens,
286 +
# The following block is modified to handle cases
287 +
# where a cell has no target tokens,
285 288
# which would cause an error in torch.cat with an empty list.
286 289

287 -
# Pre-calculate the total feature dimension of a token to create correctly shaped empty tensors.
290 +
# Pre-calculate the total feature dimension of a token to create
291 +
# correctly shaped empty tensors.
288 292
feature_dim = 6 + coords.shape[-1] + geoinfos.shape[-1] + source.shape[-1]
289 293

290 294
processed_target_tokens = []

0 commit comments

Comments
 (0)