Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 9c21696

Browse files
LiyuanLucasLiuDeNeutoy
authored andcommitted
fixing a bug in trainer for histograms (#1498)
otherwise it would raise an error in line 505 (expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #2 'other')
1 parent 5f2f539 commit 9c21696

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

allennlp/training/trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
500500
for name, param in self._model.named_parameters():
501501
param_updates[name].sub_(param.detach().cpu())
502502
update_norm = torch.norm(param_updates[name].view(-1, ))
503-
param_norm = torch.norm(param.view(-1, ))
503+
param_norm = torch.norm(param.view(-1, )).cpu()
504504
self._tensorboard.add_train_scalar("gradient_update/" + name,
505505
update_norm / (param_norm + 1e-7),
506506
batch_num_total)

0 commit comments

Comments
 (0)