Skip to content

Commit 88a330b

Browse files
authored
Update train_SSL.py
1 parent dd9c67d commit 88a330b

File tree

1 file changed

+1
-8
lines changed

1 file changed

+1
-8
lines changed

scripts/train_SSL.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,6 @@
1919
import MIR.random_image_generation as rs
2020
from MIR.utils import mk_grid_img
2121

22-
def MSE_torch(x, y):
23-
return torch.mean((x - y) ** 2)
24-
25-
def prepare_input(resolution):
26-
x = torch.FloatTensor(1, *resolution)
27-
y = torch.FloatTensor(1, *resolution)
28-
return dict(inputs=(x,y))
29-
3022
def main():
3123
iter_max = 3000
3224
val_step = 50
@@ -186,6 +178,7 @@ def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max
186178
while len(model_lists) > max_model_num:
187179
os.remove(model_lists[0])
188180
model_lists = natsorted(glob.glob(save_dir + '*'))
181+
189182
if __name__ == '__main__':
190183
'''
191184
GPU configuration

0 commit comments

Comments
 (0)