diff --git a/requirements.txt b/requirements.txt index 6cb79d3..7693b48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ numpy>=1.13.3 -torch>=0.4.0 +torch>=0.4.0,<=1.4.0 transformers>=3.5.1,<4.0.0 sklearn diff --git a/train_k_fold_cross_val.py b/train_k_fold_cross_val.py index e42268e..deae640 100644 --- a/train_k_fold_cross_val.py +++ b/train_k_fold_cross_val.py @@ -80,8 +80,6 @@ def _reset_params(self): else: stdv = 1. / math.sqrt(p.shape[0]) torch.nn.init.uniform_(p, a=-stdv, b=stdv) - else: - self.model.bert.load_state_dict(self.pretrained_bert_state_dict) def _train(self, criterion, optimizer, train_data_loader, val_data_loader): max_val_acc = 0