Skip to content

Commit 960d48d

Browse files
committed
minor
1 parent cc40839 commit 960d48d

File tree

1 file changed

+133
-0
lines changed

1 file changed

+133
-0
lines changed

util/utils.py

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import torch
2+
from tqdm import tqdm
3+
import json
4+
import os
5+
import numpy as np
6+
import matplotlib.pyplot as plt
7+
import matplotlib
8+
import argparse
9+
import pickle
10+
import pdb
11+
from tqdm import tqdm
12+
13+
14+
15+
def get_data_location(args):
16+
if args.dataset == 'ins_channel':
17+
data_location = os.path.join(args.data_location, 'data_set_ins')
18+
elif args.dataset == 'backward_facing':
19+
data_location = os.path.join(args.data_location, 'data_set_pitz')
20+
elif args.dataset == 'duan':
21+
data_location = os.path.join(args.data_location, 'data_set_duan')
22+
else:
23+
raise ValueError('Not implemented')
24+
return data_location
25+
26+
27+
def save_loss(args, loss_list, Nt):
28+
plt.figure()
29+
plt.plot(loss_list,'-o')
30+
plt.yscale('log')
31+
plt.xlabel('epoch')
32+
plt.ylabel('loss')
33+
plt.title(str(min(loss_list))+'Nt'+str(Nt))
34+
print(os.path.join(args.logging_path, 'loss_curve.png'))
35+
plt.savefig(os.path.join(args.logging_path, 'loss_curve.png'))
36+
plt.close()
37+
np.savetxt(os.path.join(args.logging_path, 'loss_curve.txt'),
38+
np.asarray(loss_list))
39+
40+
def save_args(args):
41+
with open(os.path.join(args.logging_path, 'args.txt'), 'w') as f:
42+
json.dump(args.__dict__, f, indent=2)
43+
44+
def save_args_sample(args,name):
45+
with open(os.path.join(args.experiment_path, name), 'w') as f:
46+
json.dump(args.__dict__, f, indent=2)
47+
48+
def read_args_txt(args, argtxt):
49+
#args.parser.parse_args(namespace=args.update_args_no_folder_create())
50+
f = open (argtxt, "r")
51+
args = args.parser.parse_args(namespace=argparse.Namespace(**json.loads(f.read())))
52+
return args
53+
return t
54+
55+
def save_model(model, args, Nt, bestModel = False):
56+
if bestModel:
57+
torch.save(model.state_dict(),
58+
os.path.join(args.model_save_path,
59+
'best_model_sofar'))
60+
np.savetxt(os.path.join(args.model_save_path,
61+
'best_model_sofar_Nt'),np.ones(2)*Nt)
62+
else:
63+
torch.save(model.state_dict(),
64+
os.path.join(args.model_save_path,
65+
'model_epoch_' + str(Nt)))
66+
67+
def load_model(model,args_train,args_sample):
68+
if args_sample.usebestmodel:
69+
model.load_state_dict(torch.load(args_train.current_model_save_path+'best_model_sofar'))
70+
else:
71+
model.load_state_dict(torch.load(args_train.current_model_save_path+'model_epoch_'+str(args_sample.model_epoch)))
72+
return model
73+
74+
75+
76+
77+
78+
79+
80+
81+
82+
83+
84+
85+
86+
87+
class normalizer_1dks(object):
88+
"""
89+
arguments:
90+
target_dataset (torch.utils.data.Dataset) : this is dataset we
91+
want to normalize
92+
"""
93+
def __init__(self, target_dataset,args):
94+
# mark the orginal device of the target_dataset
95+
self.mean = target_dataset.mean().to(args.device)
96+
self.std = target_dataset.std().to(args.device)
97+
def normalize(self, batch):
98+
return (batch - self.mean) / self.std
99+
def normalize_inv(self, batch):
100+
return batch * self.std +self.mean
101+
102+
103+
104+
105+
106+
107+
108+
109+
110+
111+
112+
113+
114+
115+
116+
117+
118+
119+
120+
121+
if __name__ == '__main__':
122+
num_videos = 10
123+
fig, axs = plt.subplots(2,int(num_videos/2))
124+
number_of_sample = int(num_videos/2)
125+
fig.subplots_adjust(hspace=-0.9,wspace=0.1)
126+
videos_to_plot = [np.zeros([1,3,1,64,256]) for _ in range(num_videos)]
127+
j = 0
128+
for k in range(0, num_videos):
129+
this_video = videos_to_plot[k-1]
130+
axs[k//number_of_sample, k%number_of_sample].imshow(np.sqrt(this_video[0,0,j,:,:]**2 + this_video[0,1,j,:,:]**2))
131+
axs[k//number_of_sample, k%number_of_sample].set_xticks([])
132+
axs[k//number_of_sample, k%number_of_sample].set_yticks([])
133+
plt.savefig('test_space.png',bbox_inches='tight')

0 commit comments

Comments
 (0)