Skip to content

Commit 66bdf01

Browse files
MatKbauerMatthias Karlbauer
andauthored
Added naming convention checks to lint (#501)
* Added naming convention checks to lint * Implemented python naming conventions and corrected code accordingly --------- Co-authored-by: Matthias Karlbauer <[email protected]>
1 parent ee6e757 commit 66bdf01

File tree

11 files changed

+149
-143
lines changed

11 files changed

+149
-143
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ select = [
8888
"I",
8989
# Banned imports
9090
"TID",
91+
# Naming conventions
92+
"N",
9193
# print
9294
"T201"
9395
]
@@ -102,7 +104,8 @@ ignore = [
102104
"SIM102",
103105
"SIM401",
104106
# To ignore, not relevant for us
105-
"SIM108" # in case additional norm layer supports are added in future
107+
"SIM108", # in case additional norm layer supports are added in future
108+
"N817" # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted)
106109
]
107110

108111
[tool.ruff.lint.flake8-tidy-imports.banned-api]

scripts/check_gh_issue.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@
4646
pr = pr.split("/")[0]
4747
r = requests.get(f"https://github.com/{repo}/pull/{pr}")
4848
soup = BeautifulSoup(r.text, "html.parser")
49-
issueForm = soup.find_all("form", {"aria-label": re.compile("Link issues")})
49+
issue_form = soup.find_all("form", {"aria-label": re.compile("Link issues")})
5050
msg = msg_template.format(pr=pr, repo=repo)
5151

52-
if not issueForm:
52+
if not issue_form:
5353
print(msg)
5454
exit(1)
55-
issues = [i["href"] for i in issueForm[0].find_all("a")]
55+
issues = [i["href"] for i in issue_form[0].find_all("a")]
5656
issues = [i for i in issues if i is not None and repo in i]
5757
print(f"Linked issues for PR {pr}:")
5858
print(f"Found {len(issues)} linked issues.")

src/weathergen/datasets/tokenizer_forecast.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,65 +40,65 @@ def __init__(self, healpix_level: int, seed: int):
4040
self.num_healpix_cells_source = 12 * 4**self.hl_source
4141
self.num_healpix_cells_target = 12 * 4**self.hl_target
4242

43-
verts00, verts00_Rs = healpix_verts_rots(self.hl_source, 0.0, 0.0)
44-
verts10, verts10_Rs = healpix_verts_rots(self.hl_source, 1.0, 0.0)
45-
verts11, verts11_Rs = healpix_verts_rots(self.hl_source, 1.0, 1.0)
46-
verts01, verts01_Rs = healpix_verts_rots(self.hl_source, 0.0, 1.0)
47-
vertsmm, vertsmm_Rs = healpix_verts_rots(self.hl_source, 0.5, 0.5)
43+
verts00, verts00_rots = healpix_verts_rots(self.hl_source, 0.0, 0.0)
44+
verts10, verts10_rots = healpix_verts_rots(self.hl_source, 1.0, 0.0)
45+
verts11, verts11_rots = healpix_verts_rots(self.hl_source, 1.0, 1.0)
46+
verts01, verts01_rots = healpix_verts_rots(self.hl_source, 0.0, 1.0)
47+
vertsmm, vertsmm_rots = healpix_verts_rots(self.hl_source, 0.5, 0.5)
4848
self.hpy_verts = [
4949
verts00.to(torch.float32),
5050
verts10.to(torch.float32),
5151
verts11.to(torch.float32),
5252
verts01.to(torch.float32),
5353
vertsmm.to(torch.float32),
5454
]
55-
self.hpy_verts_Rs_source = [
56-
verts00_Rs.to(torch.float32),
57-
verts10_Rs.to(torch.float32),
58-
verts11_Rs.to(torch.float32),
59-
verts01_Rs.to(torch.float32),
60-
vertsmm_Rs.to(torch.float32),
55+
self.hpy_verts_rots_source = [
56+
verts00_rots.to(torch.float32),
57+
verts10_rots.to(torch.float32),
58+
verts11_rots.to(torch.float32),
59+
verts01_rots.to(torch.float32),
60+
vertsmm_rots.to(torch.float32),
6161
]
6262

63-
verts00, verts00_Rs = healpix_verts_rots(self.hl_target, 0.0, 0.0)
64-
verts10, verts10_Rs = healpix_verts_rots(self.hl_target, 1.0, 0.0)
65-
verts11, verts11_Rs = healpix_verts_rots(self.hl_target, 1.0, 1.0)
66-
verts01, verts01_Rs = healpix_verts_rots(self.hl_target, 0.0, 1.0)
67-
vertsmm, vertsmm_Rs = healpix_verts_rots(self.hl_target, 0.5, 0.5)
63+
verts00, verts00_rots = healpix_verts_rots(self.hl_target, 0.0, 0.0)
64+
verts10, verts10_rots = healpix_verts_rots(self.hl_target, 1.0, 0.0)
65+
verts11, verts11_rots = healpix_verts_rots(self.hl_target, 1.0, 1.0)
66+
verts01, verts01_rots = healpix_verts_rots(self.hl_target, 0.0, 1.0)
67+
vertsmm, vertsmm_rots = healpix_verts_rots(self.hl_target, 0.5, 0.5)
6868
self.hpy_verts = [
6969
verts00.to(torch.float32),
7070
verts10.to(torch.float32),
7171
verts11.to(torch.float32),
7272
verts01.to(torch.float32),
7373
vertsmm.to(torch.float32),
7474
]
75-
self.hpy_verts_Rs_target = [
76-
verts00_Rs.to(torch.float32),
77-
verts10_Rs.to(torch.float32),
78-
verts11_Rs.to(torch.float32),
79-
verts01_Rs.to(torch.float32),
80-
vertsmm_Rs.to(torch.float32),
75+
self.hpy_verts_rots_target = [
76+
verts00_rots.to(torch.float32),
77+
verts10_rots.to(torch.float32),
78+
verts11_rots.to(torch.float32),
79+
verts01_rots.to(torch.float32),
80+
vertsmm_rots.to(torch.float32),
8181
]
8282

8383
self.verts_local = []
8484
verts = torch.stack([verts10, verts11, verts01, vertsmm])
85-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts00_Rs, verts.transpose(0, 1)))
85+
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts00_rots, verts.transpose(0, 1)))
8686
self.verts_local.append(temp.flatten(1, 2))
8787

8888
verts = torch.stack([verts00, verts11, verts01, vertsmm])
89-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts10_Rs, verts.transpose(0, 1)))
89+
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts10_rots, verts.transpose(0, 1)))
9090
self.verts_local.append(temp.flatten(1, 2))
9191

9292
verts = torch.stack([verts00, verts10, verts01, vertsmm])
93-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts11_Rs, verts.transpose(0, 1)))
93+
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts11_rots, verts.transpose(0, 1)))
9494
self.verts_local.append(temp.flatten(1, 2))
9595

9696
verts = torch.stack([verts00, verts11, verts10, vertsmm])
97-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts01_Rs, verts.transpose(0, 1)))
97+
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts01_rots, verts.transpose(0, 1)))
9898
self.verts_local.append(temp.flatten(1, 2))
9999

100100
verts = torch.stack([verts00, verts10, verts11, verts01])
101-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(vertsmm_Rs, verts.transpose(0, 1)))
101+
temp = ref - torch.stack(locs_to_cell_coords_ctrs(vertsmm_rots, verts.transpose(0, 1)))
102102
self.verts_local.append(temp.flatten(1, 2))
103103

104104
self.hpy_verts_local_target = torch.stack(self.verts_local).transpose(0, 1)
@@ -168,7 +168,7 @@ def batchify_source(
168168
time_win=time_win,
169169
token_size=token_size,
170170
hl=self.hl_source,
171-
hpy_verts_Rs=self.hpy_verts_Rs_source[-1],
171+
hpy_verts_rots=self.hpy_verts_rots_source[-1],
172172
n_coords=normalizer.normalize_coords,
173173
n_geoinfos=normalizer.normalize_geoinfos,
174174
n_data=normalizer.normalize_source_channels,
@@ -272,7 +272,7 @@ def batchify_target(
272272
target_coords,
273273
target_geoinfos,
274274
target_times,
275-
self.hpy_verts_Rs_target,
275+
self.hpy_verts_rots_target,
276276
self.hpy_verts_local_target,
277277
self.hpy_nctrs_target,
278278
)

src/weathergen/datasets/tokenizer_masking.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -43,65 +43,65 @@ def __init__(self, healpix_level: int, seed: int, masker: Masker):
4343
self.num_healpix_cells_source = 12 * 4**self.hl_source
4444
self.num_healpix_cells_target = 12 * 4**self.hl_target
4545

46-
verts00, verts00_Rs = healpix_verts_rots(self.hl_source, 0.0, 0.0)
47-
verts10, verts10_Rs = healpix_verts_rots(self.hl_source, 1.0, 0.0)
48-
verts11, verts11_Rs = healpix_verts_rots(self.hl_source, 1.0, 1.0)
49-
verts01, verts01_Rs = healpix_verts_rots(self.hl_source, 0.0, 1.0)
50-
vertsmm, vertsmm_Rs = healpix_verts_rots(self.hl_source, 0.5, 0.5)
46+
verts00, verts00_rots = healpix_verts_rots(self.hl_source, 0.0, 0.0)
47+
verts10, verts10_rots = healpix_verts_rots(self.hl_source, 1.0, 0.0)
48+
verts11, verts11_rots = healpix_verts_rots(self.hl_source, 1.0, 1.0)
49+
verts01, verts01_rots = healpix_verts_rots(self.hl_source, 0.0, 1.0)
50+
vertsmm, vertsmm_rots = healpix_verts_rots(self.hl_source, 0.5, 0.5)
5151
self.hpy_verts = [
5252
verts00.to(torch.float32),
5353
verts10.to(torch.float32),
5454
verts11.to(torch.float32),
5555
verts01.to(torch.float32),
5656
vertsmm.to(torch.float32),
5757
]
58-
self.hpy_verts_Rs_source = [
59-
verts00_Rs.to(torch.float32),
60-
verts10_Rs.to(torch.float32),
61-
verts11_Rs.to(torch.float32),
62-
verts01_Rs.to(torch.float32),
63-
vertsmm_Rs.to(torch.float32),
58+
self.hpy_verts_rots_source = [
59+
verts00_rots.to(torch.float32),
60+
verts10_rots.to(torch.float32),
61+
verts11_rots.to(torch.float32),
62+
verts01_rots.to(torch.float32),
63+
vertsmm_rots.to(torch.float32),
6464
]
6565

66-
verts00, verts00_Rs = healpix_verts_rots(self.hl_target, 0.0, 0.0)
67-
verts10, verts10_Rs = healpix_verts_rots(self.hl_target, 1.0, 0.0)
68-
verts11, verts11_Rs = healpix_verts_rots(self.hl_target, 1.0, 1.0)
69-
verts01, verts01_Rs = healpix_verts_rots(self.hl_target, 0.0, 1.0)
70-
vertsmm, vertsmm_Rs = healpix_verts_rots(self.hl_target, 0.5, 0.5)
66+
verts00, verts00_rots = healpix_verts_rots(self.hl_target, 0.0, 0.0)
67+
verts10, verts10_rots = healpix_verts_rots(self.hl_target, 1.0, 0.0)
68+
verts11, verts11_rots = healpix_verts_rots(self.hl_target, 1.0, 1.0)
69+
verts01, verts01_rots = healpix_verts_rots(self.hl_target, 0.0, 1.0)
70+
vertsmm, vertsmm_rots = healpix_verts_rots(self.hl_target, 0.5, 0.5)
7171
self.hpy_verts = [
7272
verts00.to(torch.float32),
7373
verts10.to(torch.float32),
7474
verts11.to(torch.float32),
7575
verts01.to(torch.float32),
7676
vertsmm.to(torch.float32),
7777
]
78-
self.hpy_verts_Rs_target = [
79-
verts00_Rs.to(torch.float32),
80-
verts10_Rs.to(torch.float32),
81-
verts11_Rs.to(torch.float32),
82-
verts01_Rs.to(torch.float32),
83-
vertsmm_Rs.to(torch.float32),
78+
self.hpy_verts_rots_target = [
79+
verts00_rots.to(torch.float32),
80+
verts10_rots.to(torch.float32),
81+
verts11_rots.to(torch.float32),
82+
verts01_rots.to(torch.float32),
83+
vertsmm_rots.to(torch.float32),
8484
]
8585

8686
self.verts_local = []
8787
verts = torch.stack([verts10, verts11, verts01, vertsmm])
88-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts00_Rs, verts.transpose(0, 1)))
88+
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts00_rots, verts.transpose(0, 1)))
8989
self.verts_local.append(temp.flatten(1, 2))
9090

9191
verts = torch.stack([verts00, verts11, verts01, vertsmm])
92-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts10_Rs, verts.transpose(0, 1)))
92+
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts10_rots, verts.transpose(0, 1)))
9393
self.verts_local.append(temp.flatten(1, 2))
9494

9595
verts = torch.stack([verts00, verts10, verts01, vertsmm])
96-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts11_Rs, verts.transpose(0, 1)))
96+
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts11_rots, verts.transpose(0, 1)))
9797
self.verts_local.append(temp.flatten(1, 2))
9898

9999
verts = torch.stack([verts00, verts11, verts10, vertsmm])
100-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts01_Rs, verts.transpose(0, 1)))
100+
temp = ref - torch.stack(locs_to_cell_coords_ctrs(verts01_rots, verts.transpose(0, 1)))
101101
self.verts_local.append(temp.flatten(1, 2))
102102

103103
verts = torch.stack([verts00, verts10, verts11, verts01])
104-
temp = ref - torch.stack(locs_to_cell_coords_ctrs(vertsmm_Rs, verts.transpose(0, 1)))
104+
temp = ref - torch.stack(locs_to_cell_coords_ctrs(vertsmm_rots, verts.transpose(0, 1)))
105105
self.verts_local.append(temp.flatten(1, 2))
106106

107107
self.hpy_verts_local_target = torch.stack(self.verts_local).transpose(0, 1)
@@ -171,7 +171,7 @@ def batchify_source(
171171
time_win=time_win,
172172
token_size=token_size,
173173
hl=self.hl_source,
174-
hpy_verts_Rs=self.hpy_verts_Rs_source[-1],
174+
hpy_verts_rots=self.hpy_verts_rots_source[-1],
175175
n_coords=normalizer.normalize_coords,
176176
n_geoinfos=normalizer.normalize_geoinfos,
177177
n_data=normalizer.normalize_source_channels,
@@ -257,7 +257,7 @@ def id(arg):
257257
time_win=time_win,
258258
token_size=token_size,
259259
hl=self.hl_source,
260-
hpy_verts_Rs=self.hpy_verts_Rs_source[-1],
260+
hpy_verts_rots=self.hpy_verts_rots_source[-1],
261261
n_coords=id,
262262
n_geoinfos=normalizer.normalize_geoinfos,
263263
n_data=normalizer.normalize_target_channels,
@@ -311,7 +311,7 @@ def id(arg):
311311
target_coords,
312312
target_geoinfos,
313313
target_times,
314-
self.hpy_verts_Rs_target,
314+
self.hpy_verts_rots_target,
315315
self.hpy_verts_local_target,
316316
self.hpy_nctrs_target,
317317
)

src/weathergen/datasets/tokenizer_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def tokenize_window_space(
166166
time_win,
167167
token_size,
168168
hl,
169-
hpy_verts_Rs,
169+
hpy_verts_rots,
170170
n_coords,
171171
n_geoinfos,
172172
n_data,
@@ -190,14 +190,15 @@ def tokenize_window_space(
190190
source_padded = torch.cat([torch.zeros_like(source[0]).unsqueeze(0), n_data(source)])
191191

192192
# convert to local coordinates
193-
# TODO: how to vectorize it so that there's no list comprhension (and the Rs are not duplicated)
193+
# TODO: how to vectorize it so that there's no list comprhension (and the rots are not
194+
# duplicated)
194195
# TODO: avoid that padded lists are rotated, which means potentially a lot of zeros
195196
if local_coords:
196197
fp32 = torch.float32
197198
posr3 = torch.cat([torch.zeros_like(posr3[0]).unsqueeze(0), posr3])
198199
coords_local = [
199200
n_coords(r3tos2(torch.matmul(R, posr3[idxs].transpose(1, 0)).transpose(1, 0)).to(fp32))
200-
for R, idxs in zip(hpy_verts_Rs, idxs_ord, strict=True)
201+
for R, idxs in zip(hpy_verts_rots, idxs_ord, strict=True)
201202
]
202203
else:
203204
coords_local = torch.cat([torch.zeros_like(coords[0]).unsqueeze(0), coords])
@@ -240,7 +241,7 @@ def tokenize_window_spacetime(
240241
time_win,
241242
token_size,
242243
hl,
243-
hpy_verts_Rs,
244+
hpy_verts_rots,
244245
n_coords,
245246
n_geoinfos,
246247
n_data,
@@ -267,7 +268,7 @@ def tokenize_window_spacetime(
267268
time_win,
268269
token_size,
269270
hl,
270-
hpy_verts_Rs,
271+
hpy_verts_rots,
271272
n_coords,
272273
n_geoinfos,
273274
n_data,

0 commit comments

Comments
 (0)