1
+ import datetime
2
+ import torch
3
+ import torch .nn as nn
4
+ import torch .nn .functional as F
5
+ from utils .parser import get_parser_with_args
6
+ from utils .helpers import (get_loaders , get_criterion ,
7
+ load_model , initialize_metrics , get_mean_metrics ,
8
+ set_metrics )
9
+ from sklearn .metrics import precision_recall_fscore_support as prfs
10
+ import os
11
+ import logging
12
+ import json
13
+ import random
14
+ import numpy as np
15
+ import re
16
+ import warnings
17
+ from models .vgg import Vgg19
18
+ warnings .filterwarnings ("ignore" )
19
+
20
+ """
21
+ Initialize Parser and define arguments
22
+ """
23
+ parser , metadata = get_parser_with_args (metadata_json_path = '/home/aaa/xujialang/master_thesis/MFPNet/metadata.json' )
24
+ opt = parser .parse_args ()
25
+
26
+ """
27
+ Initialize experiments log
28
+ """
29
+ logging .basicConfig (level = logging .INFO )
30
+
31
+ """
32
+ Set up environment: define paths, download data, and set device
33
+ """
34
+ dev = torch .device ('cuda:3' if torch .cuda .is_available () else 'cpu' )
35
+ logging .info ('GPU AVAILABLE? ' + str (torch .cuda .is_available ()))
36
+
37
+ def seed_torch (seed ):
38
+ random .seed (seed )
39
+ os .environ ['PYTHONHASHSEED' ] = str (seed )
40
+ np .random .seed (seed )
41
+ torch .manual_seed (seed )
42
+ torch .cuda .manual_seed (seed )
43
+ # torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
44
+ torch .backends .cudnn .benchmark = False
45
+ torch .backends .cudnn .deterministic = True
46
+ seed_torch (seed = 777 )
47
+
48
+ train_loader , val_loader = get_loaders (opt )
49
+ print (opt .batch_size * len (train_loader ))
50
+ print (opt .batch_size * len (val_loader ))
51
+
52
+ """
53
+ Load Model then define other aspects of the model
54
+ """
55
+ logging .info ('LOADING Model' )
56
+ model = load_model (opt , dev )
57
+ vgg = Vgg19 ().to (dev )
58
+ """
59
+ Resume
60
+ """
61
+ epoch_resume = 0
62
+ if opt .resume != "None" :
63
+ model .load_state_dict (torch .load (os .path .join (opt .resume )))
64
+ epoch_resume = int (re .sub ("\D" ,"" ,opt .resume ))
65
+ print ('resume success: epoch {}' .format (epoch_resume ))
66
+
67
+ criterion_ce = nn .CrossEntropyLoss ().to (dev )
68
+ criterion_perceptual = nn .MSELoss ().to (dev )
69
+ criterion = get_criterion (opt )
70
+ optimizer = torch .optim .Adam (model .parameters (), lr = opt .learning_rate ) # Be careful when you adjust learning rate, you can refer to the linear scaling rule
71
+ scheduler = torch .optim .lr_scheduler .CosineAnnealingWarmRestarts (optimizer , 10 , T_mult = 2 , eta_min = 0 , last_epoch = - 1 )
72
+
73
+ """
74
+ Set starting values
75
+ """
76
+ best_metrics = {'cd_f1scores' : - 1 , 'cd_recalls' : - 1 , 'cd_precisions' : - 1 }
77
+ logging .info ('STARTING training' )
78
+
79
+ for epoch in range (opt .epochs ):
80
+ epoch = epoch + epoch_resume + 1
81
+ train_metrics = initialize_metrics ()
82
+ val_metrics = initialize_metrics ()
83
+
84
+ """
85
+ Begin Training
86
+ """
87
+ model .train ()
88
+ logging .info ('SET model mode to train!' )
89
+
90
+ for batch_img1 , batch_img2 , labels in train_loader :
91
+ # Set variables for training
92
+ batch_img1 = batch_img1 .float ().to (dev )
93
+ batch_img2 = batch_img2 .float ().to (dev )
94
+ labels = labels .long ().to (dev )
95
+
96
+ # Zero the gradient
97
+ optimizer .zero_grad ()
98
+
99
+ # Get model predictions, calculate loss, backprop
100
+ cd_preds = model (batch_img1 , batch_img2 )
101
+ loss = criterion (criterion_ce , criterion_perceptual , cd_preds , labels , batch_img1 , vgg , dev )
102
+
103
+ loss .backward ()
104
+ optimizer .step ()
105
+
106
+ # Calculate and log other batch metrics
107
+ cd_preds = torch .argmax (cd_preds , dim = 1 )
108
+ cd_corrects = (100 *
109
+ (cd_preds .squeeze ().byte () == labels .squeeze ().byte ()).sum () /
110
+ (labels .size ()[0 ] * (opt .patch_size ** 2 )))
111
+ cd_train_report = prfs (labels .data .cpu ().numpy ().flatten (),
112
+ cd_preds .data .cpu ().numpy ().flatten (),
113
+ average = 'binary' ,
114
+ pos_label = 1 )
115
+ train_metrics = set_metrics (train_metrics ,
116
+ loss ,
117
+ cd_corrects ,
118
+ cd_train_report ,
119
+ scheduler .get_last_lr ())
120
+
121
+ # log the batch mean metrics
122
+ mean_train_metrics = get_mean_metrics (train_metrics )
123
+
124
+ # clear batch variables from memory
125
+ del batch_img1 , batch_img2 , labels
126
+
127
+ scheduler .step ()
128
+ logging .info ("EPOCH {} TRAIN METRICS. " .format (epoch ) + str (mean_train_metrics ))
129
+
130
+
131
+ """
132
+ Begin Validation
133
+ """
134
+ model .eval ()
135
+ with torch .no_grad ():
136
+ for batch_img1 , batch_img2 , labels in val_loader :
137
+ # Set variables for training
138
+ batch_img1 = batch_img1 .float ().to (dev )
139
+ batch_img2 = batch_img2 .float ().to (dev )
140
+ labels = labels .long ().to (dev )
141
+
142
+ # Get predictions and calculate loss
143
+ cd_preds = model (batch_img1 , batch_img2 )
144
+ val_loss = criterion (criterion_ce , criterion_perceptual , cd_preds , labels , batch_img1 , vgg , dev )
145
+
146
+ # Calculate and log other batch metrics
147
+ cd_preds = torch .argmax (cd_preds , dim = 1 )
148
+ cd_corrects = (100 *
149
+ (cd_preds .squeeze ().byte () == labels .squeeze ().byte ()).sum () /
150
+ (labels .size ()[0 ] * (opt .patch_size ** 2 )))
151
+ cd_val_report = prfs (labels .data .cpu ().numpy ().flatten (),
152
+ cd_preds .data .cpu ().numpy ().flatten (),
153
+ average = 'binary' ,
154
+ pos_label = 1 )
155
+ val_metrics = set_metrics (val_metrics ,
156
+ val_loss ,
157
+ cd_corrects ,
158
+ cd_val_report ,
159
+ scheduler .get_lr ())
160
+
161
+ # log the batch mean metrics
162
+ mean_val_metrics = get_mean_metrics (val_metrics )
163
+
164
+ # clear batch variables from memory
165
+ del batch_img1 , batch_img2 , labels
166
+
167
+ logging .info ("EPOCH {} VALIDATION METRICS" .format (epoch )+ str (mean_val_metrics ))
168
+
169
+ """
170
+ Store the weights of good epochs based on validation results
171
+ """
172
+ if (mean_val_metrics ['cd_f1scores' ] > best_metrics ['cd_f1scores' ]):
173
+ # Insert training and epoch information to metadata dictionary
174
+ logging .info ('updata the model' )
175
+ metadata ['val_metrics' ] = mean_val_metrics
176
+
177
+ # Save model and log
178
+ if not os .path .exists (opt .weight_dir ):
179
+ os .mkdir (opt .weight_dir )
180
+ with open (opt .weight_dir + 'metadata_val_epoch_' + str (epoch ) + '.json' , 'w' ) as fout :
181
+ json .dump (metadata , fout )
182
+
183
+ torch .save (model .state_dict (), opt .weight_dir + 'checkpoint_epoch_' + str (epoch )+ '_f1_' + str (mean_val_metrics ['cd_f1scores' ])+ '.pt' )
184
+ best_metrics = mean_val_metrics
185
+ print ('best val: ' + str (mean_val_metrics ))
186
+
187
+ print ('An epoch finished.' )
188
+
189
+ print ('Done!' )
0 commit comments