@@ -43,65 +43,65 @@ def __init__(self, healpix_level: int, seed: int, masker: Masker):
43
43
self .num_healpix_cells_source = 12 * 4 ** self .hl_source
44
44
self .num_healpix_cells_target = 12 * 4 ** self .hl_target
45
45
46
- verts00 , verts00_Rs = healpix_verts_rots (self .hl_source , 0.0 , 0.0 )
47
- verts10 , verts10_Rs = healpix_verts_rots (self .hl_source , 1.0 , 0.0 )
48
- verts11 , verts11_Rs = healpix_verts_rots (self .hl_source , 1.0 , 1.0 )
49
- verts01 , verts01_Rs = healpix_verts_rots (self .hl_source , 0.0 , 1.0 )
50
- vertsmm , vertsmm_Rs = healpix_verts_rots (self .hl_source , 0.5 , 0.5 )
46
+ verts00 , verts00_rots = healpix_verts_rots (self .hl_source , 0.0 , 0.0 )
47
+ verts10 , verts10_rots = healpix_verts_rots (self .hl_source , 1.0 , 0.0 )
48
+ verts11 , verts11_rots = healpix_verts_rots (self .hl_source , 1.0 , 1.0 )
49
+ verts01 , verts01_rots = healpix_verts_rots (self .hl_source , 0.0 , 1.0 )
50
+ vertsmm , vertsmm_rots = healpix_verts_rots (self .hl_source , 0.5 , 0.5 )
51
51
self .hpy_verts = [
52
52
verts00 .to (torch .float32 ),
53
53
verts10 .to (torch .float32 ),
54
54
verts11 .to (torch .float32 ),
55
55
verts01 .to (torch .float32 ),
56
56
vertsmm .to (torch .float32 ),
57
57
]
58
- self .hpy_verts_Rs_source = [
59
- verts00_Rs .to (torch .float32 ),
60
- verts10_Rs .to (torch .float32 ),
61
- verts11_Rs .to (torch .float32 ),
62
- verts01_Rs .to (torch .float32 ),
63
- vertsmm_Rs .to (torch .float32 ),
58
+ self .hpy_verts_rots_source = [
59
+ verts00_rots .to (torch .float32 ),
60
+ verts10_rots .to (torch .float32 ),
61
+ verts11_rots .to (torch .float32 ),
62
+ verts01_rots .to (torch .float32 ),
63
+ vertsmm_rots .to (torch .float32 ),
64
64
]
65
65
66
- verts00 , verts00_Rs = healpix_verts_rots (self .hl_target , 0.0 , 0.0 )
67
- verts10 , verts10_Rs = healpix_verts_rots (self .hl_target , 1.0 , 0.0 )
68
- verts11 , verts11_Rs = healpix_verts_rots (self .hl_target , 1.0 , 1.0 )
69
- verts01 , verts01_Rs = healpix_verts_rots (self .hl_target , 0.0 , 1.0 )
70
- vertsmm , vertsmm_Rs = healpix_verts_rots (self .hl_target , 0.5 , 0.5 )
66
+ verts00 , verts00_rots = healpix_verts_rots (self .hl_target , 0.0 , 0.0 )
67
+ verts10 , verts10_rots = healpix_verts_rots (self .hl_target , 1.0 , 0.0 )
68
+ verts11 , verts11_rots = healpix_verts_rots (self .hl_target , 1.0 , 1.0 )
69
+ verts01 , verts01_rots = healpix_verts_rots (self .hl_target , 0.0 , 1.0 )
70
+ vertsmm , vertsmm_rots = healpix_verts_rots (self .hl_target , 0.5 , 0.5 )
71
71
self .hpy_verts = [
72
72
verts00 .to (torch .float32 ),
73
73
verts10 .to (torch .float32 ),
74
74
verts11 .to (torch .float32 ),
75
75
verts01 .to (torch .float32 ),
76
76
vertsmm .to (torch .float32 ),
77
77
]
78
- self .hpy_verts_Rs_target = [
79
- verts00_Rs .to (torch .float32 ),
80
- verts10_Rs .to (torch .float32 ),
81
- verts11_Rs .to (torch .float32 ),
82
- verts01_Rs .to (torch .float32 ),
83
- vertsmm_Rs .to (torch .float32 ),
78
+ self .hpy_verts_rots_target = [
79
+ verts00_rots .to (torch .float32 ),
80
+ verts10_rots .to (torch .float32 ),
81
+ verts11_rots .to (torch .float32 ),
82
+ verts01_rots .to (torch .float32 ),
83
+ vertsmm_rots .to (torch .float32 ),
84
84
]
85
85
86
86
self .verts_local = []
87
87
verts = torch .stack ([verts10 , verts11 , verts01 , vertsmm ])
88
- temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts00_Rs , verts .transpose (0 , 1 )))
88
+ temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts00_rots , verts .transpose (0 , 1 )))
89
89
self .verts_local .append (temp .flatten (1 , 2 ))
90
90
91
91
verts = torch .stack ([verts00 , verts11 , verts01 , vertsmm ])
92
- temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts10_Rs , verts .transpose (0 , 1 )))
92
+ temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts10_rots , verts .transpose (0 , 1 )))
93
93
self .verts_local .append (temp .flatten (1 , 2 ))
94
94
95
95
verts = torch .stack ([verts00 , verts10 , verts01 , vertsmm ])
96
- temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts11_Rs , verts .transpose (0 , 1 )))
96
+ temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts11_rots , verts .transpose (0 , 1 )))
97
97
self .verts_local .append (temp .flatten (1 , 2 ))
98
98
99
99
verts = torch .stack ([verts00 , verts11 , verts10 , vertsmm ])
100
- temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts01_Rs , verts .transpose (0 , 1 )))
100
+ temp = ref - torch .stack (locs_to_cell_coords_ctrs (verts01_rots , verts .transpose (0 , 1 )))
101
101
self .verts_local .append (temp .flatten (1 , 2 ))
102
102
103
103
verts = torch .stack ([verts00 , verts10 , verts11 , verts01 ])
104
- temp = ref - torch .stack (locs_to_cell_coords_ctrs (vertsmm_Rs , verts .transpose (0 , 1 )))
104
+ temp = ref - torch .stack (locs_to_cell_coords_ctrs (vertsmm_rots , verts .transpose (0 , 1 )))
105
105
self .verts_local .append (temp .flatten (1 , 2 ))
106
106
107
107
self .hpy_verts_local_target = torch .stack (self .verts_local ).transpose (0 , 1 )
@@ -171,7 +171,7 @@ def batchify_source(
171
171
time_win = time_win ,
172
172
token_size = token_size ,
173
173
hl = self .hl_source ,
174
- hpy_verts_Rs = self .hpy_verts_Rs_source [- 1 ],
174
+ hpy_verts_rots = self .hpy_verts_rots_source [- 1 ],
175
175
n_coords = normalizer .normalize_coords ,
176
176
n_geoinfos = normalizer .normalize_geoinfos ,
177
177
n_data = normalizer .normalize_source_channels ,
@@ -257,7 +257,7 @@ def id(arg):
257
257
time_win = time_win ,
258
258
token_size = token_size ,
259
259
hl = self .hl_source ,
260
- hpy_verts_Rs = self .hpy_verts_Rs_source [- 1 ],
260
+ hpy_verts_rots = self .hpy_verts_rots_source [- 1 ],
261
261
n_coords = id ,
262
262
n_geoinfos = normalizer .normalize_geoinfos ,
263
263
n_data = normalizer .normalize_target_channels ,
@@ -311,7 +311,7 @@ def id(arg):
311
311
target_coords ,
312
312
target_geoinfos ,
313
313
target_times ,
314
- self .hpy_verts_Rs_target ,
314
+ self .hpy_verts_rots_target ,
315
315
self .hpy_verts_local_target ,
316
316
self .hpy_nctrs_target ,
317
317
)
0 commit comments