forked from goncabakar/GMVAE-pytorch
utils.py

import os

import torch
from torch import optim

def seperate_dataset(data, labels, seperator):
    """Shuffle `data` and `labels` together, then split at index `seperator`."""
    # Permute along the sample (first) dimension; using data.nelement() here
    # would index past the first dimension for multi-dimensional tensors.
    idx = torch.randperm(data.size(0))
    data_perm = data[idx]
    labels_perm = labels[idx]
    train_data = data_perm[:seperator]
    test_data = data_perm[seperator:]
    train_labels = labels_perm[:seperator]
    test_labels = labels_perm[seperator:]
    return train_data, test_data, train_labels, test_labels
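
# A minimal usage sketch (shapes and sizes are hypothetical, not from this repo):
# shuffle 100 flattened samples and keep the first 80 for training.
#
#   data = torch.randn(100, 784)
#   labels = torch.randint(0, 10, (100,))
#   train_x, test_x, train_y, test_y = seperate_dataset(data, labels, 80)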

def save_checkpoint(model, epoch, model_out_path, save_dir, optimizer=None, lr=0.0001, tloss=-1):
    """Save the model (and, if given, the optimizer state) to `model_out_path`."""
    if optimizer is None:
        state = {"epoch": epoch, "model": model, "tloss": tloss}
    else:
        state = {"epoch": epoch, "model": model, "optimizer": optimizer.state_dict(),
                 "tloss": tloss, "lr": lr}
    # Create the checkpoint directory (including parents) if it does not exist.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, model_out_path)
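
# A minimal usage sketch (the model, paths, and loss value are assumptions
# made for illustration, not taken from this repo):
#
#   model = torch.nn.Linear(784, 10)
#   optimizer = optim.Adam(model.parameters(), lr=1e-4)
#   save_checkpoint(model, epoch=5, model_out_path="checkpoints/epoch_5.pth",
#                   save_dir="checkpoints", optimizer=optimizer, lr=1e-4, tloss=0.42)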

def load_checkpoint(path):
    """Load a checkpoint written by `save_checkpoint`; return None if `path` is missing."""
    if not os.path.exists(path):
        print("Checkpoint does not exist: {}".format(path))
        return None
    checkpoint = torch.load(path)
    epoch = checkpoint['epoch']
    model = checkpoint['model']
    tloss = checkpoint['tloss']
    optimizer = None
    if 'optimizer' in checkpoint:
        # Rebuild the optimizer (Adam, matching training) and restore its state.
        lr = checkpoint['lr']
        optimizer = optim.Adam(model.parameters(), lr=lr, amsgrad=False)
        optimizer.load_state_dict(checkpoint['optimizer'])
    return model, epoch, optimizer, tloss
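
# A minimal usage sketch: resume training from a saved checkpoint. The path is
# hypothetical; `load_checkpoint` returns None when the file is missing.
#
#   result = load_checkpoint("checkpoints/epoch_5.pth")
#   if result is not None:
#       model, epoch, optimizer, tloss = result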