Skip to content

Commit 2ab5208

Browse files
authored
Refactor of batch data handling and passing to model (#34)
* Refactored handling of data from data loader to model. Functional although final convergence tests are missing. Some cleanup and adding of comments is also missing. * Fixed problems that surfaces when training with a large number of streams. * - Added docs. - Fixed minor isses, in particular related to evaluation. * - Added proper docs to StreamData - Cleaned up some data that was no longer used or could be computed more locally. * - Updated batchifyer (missed in last commit). - Fixed problem in running evaluation with writing output - Auto-formatting and linting
1 parent de731a3 commit 2ab5208

File tree

6 files changed

+581
-504
lines changed

6 files changed

+581
-504
lines changed

src/weathergen/datasets/batchifyer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,6 @@ def batchify_target(
314314
if len(source) < 2:
315315
target_tokens, target_coords = torch.tensor([]), torch.tensor([])
316316
target_tokens_lens = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32)
317-
target_coords_lens = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32)
318317

319318
else:
320319
thetas = ((90.0 - source[:, geoinfo_offset]) / 180.0) * np.pi
@@ -339,7 +338,6 @@ def batchify_target(
339338
target_coords[c] = normalize_targets(t[:, :geoinfo_size].clone())
340339

341340
target_tokens_lens = torch.tensor([len(s) for s in target_tokens], dtype=torch.int32)
342-
target_coords_lens = target_tokens_lens.detach().clone()
343341

344342
# if target_coords_local and target_tokens_lens.sum()>0 :
345343
if target_tokens_lens.sum() > 0:
@@ -352,6 +350,6 @@ def batchify_target(
352350
self.hpy_nctrs_target,
353351
)
354352
target_coords.requires_grad = False
355-
target_coords = list(target_coords.split(target_coords_lens.tolist()))
353+
target_coords = list(target_coords.split(target_tokens_lens.tolist()))
356354

357-
return (target_tokens, target_tokens_lens, target_coords, target_coords_lens)
355+
return (target_tokens, target_coords)

0 commit comments

Comments
 (0)