Skip to content

Commit 7b0cfbd

Browse files
committed
train_eval
1 parent 3a76e97 commit 7b0cfbd

13 files changed

+1027
-3
lines changed

MFPNet_code/eval.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from shutil import copyfile
2+
import torch.utils.data
3+
from utils.parser import get_parser_with_args
4+
from utils.helpers import get_test_loaders
5+
from tqdm import tqdm
6+
from sklearn.metrics import confusion_matrix
7+
import numpy as np
8+
import torch.nn.functional as F
9+
import cv2
10+
import os
11+
from utils.helpers import load_model
12+
13+
parser, metadata = get_parser_with_args(metadata_json_path='/home/aaa/xujialang/master_thesis/MFPNet/metadata.json')
14+
opt = parser.parse_args()
15+
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
16+
17+
test_loader = get_test_loaders(opt)
18+
19+
weight_path = os.path.join(opt.weight_dir, 'model_weight.pt') # the path of the model weight
20+
model = load_model(opt, dev)
21+
model.load_state_dict(torch.load(weight_path))
22+
"""
23+
Begin Test
24+
"""
25+
model.eval()
26+
with torch.no_grad():
27+
c_matrix = {'tn': 0, 'fp': 0, 'fn': 0, 'tp': 0}
28+
test_metrics = {
29+
'cd_precisions': [],
30+
'cd_recalls': [],
31+
'cd_f1scores': [],
32+
}
33+
34+
for batch_img1, batch_img2, labels in test_loader:
35+
batch_img1 = batch_img1.float().to(dev)
36+
batch_img2 = batch_img2.float().to(dev)
37+
labels = labels.long().to(dev)
38+
cd_preds = model(batch_img1, batch_img2)
39+
cd_preds = torch.argmax(cd_preds, dim = 1)
40+
41+
tp= (labels.cpu().numpy() * cd_preds.cpu().numpy()).sum()
42+
tn= ((1-labels.cpu().numpy()) * (1-cd_preds.cpu().numpy())).sum()
43+
fn= (labels.cpu().numpy() * (1-cd_preds.cpu().numpy())).sum()
44+
fp= ((1-labels.cpu().numpy()) * cd_preds.cpu().numpy()).sum()
45+
c_matrix['tn'] += tn
46+
c_matrix['fp'] += fp
47+
c_matrix['fn'] += fn
48+
c_matrix['tp'] += tp
49+
50+
tn, fp, fn, tp = c_matrix['tn'], c_matrix['fp'], c_matrix['fn'], c_matrix['tp']
51+
P = tp / (tp + fp)
52+
R = tp / (tp + fn)
53+
F1 = 2 * P * R / (R + P)
54+
IOU = tp/ (fn+tp+fp)
55+
56+
ttt_test=tn+fp+fn+tp
57+
TA_test = (tp+tn) / ttt_test
58+
Pcp1_test = (tp + fn) / ttt_test
59+
Pcp2_test = (tp + fp) / ttt_test
60+
Pcn1_test = (fp + tn) / ttt_test
61+
Pcn2_test = (fn + tn) / ttt_test
62+
Pc_test = Pcp1_test*Pcp2_test + Pcn1_test*Pcn2_test
63+
kappa_test = (TA_test - Pc_test) / (1 - Pc_test)
64+
65+
test_metrics['cd_f1scores'] = F1
66+
test_metrics['cd_precisions'] = P
67+
test_metrics['cd_recalls'] = R
68+
print("TEST METRICS. KAPPA: {}. IOU: {} ".format(kappa_test, IOU) + str(test_metrics))

MFPNet_code/metadata.json

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"patch_size": 256,
3+
"augmentation": true,
4+
"num_gpus": 1,
5+
"num_workers": 4,
6+
"num_channel": 3,
7+
"epochs": 200,
8+
"batch_size": 4,
9+
"learning_rate": 1e-4,
10+
"loss_function": "hybrid",
11+
"dataset_dir": "/home/bigspace/xujialang/cd_dataset/Google/",
12+
"weight_dir": "/home/bigspace/xujialang/MFPNet_result/Google/",
13+
"resume": "None"
14+
}

MFPNet_code/metadata_descripation.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
2+
# For Seasonvarying/LEVIR-CD/Google Dataset
3+
{
4+
"patch_size": 256,
5+
"augmentation": true,
6+
"num_gpus": 1,
7+
"num_workers": 4,
8+
"num_channel": 3,
9+
"epochs": 200,
10+
"batch_size": 4,
11+
"learning_rate": 1e-4,
12+
"loss_function": "hybrid", # ['hybird', 'bce', 'dice', 'jaccard'], 'hybrid' means Softmax PPCE + Perceputal Loss
13+
"dataset_dir": "/home/bigspace/xujialang/cd_dataset/Seasonvarying/", # change to your own path
14+
"weight_dir": "/home/bigspace/xujialang/MFPNet_result/Seasonvarying/", # change to your own path
15+
"resume": "None" # Change if you want to continue your training process
16+
}
17+
18+
# For Zhang dataset
19+
{
20+
"patch_size": 512,
21+
"augmentation": true,
22+
"num_gpus": 1,
23+
"num_workers": 4,
24+
"num_channel": 3,
25+
"epochs": 200,
26+
"batch_size": 2,
27+
"learning_rate": 1e-4,
28+
"loss_function": "hybrid",
29+
"dataset_dir": "/home/bigspace/xujialang/cd_dataset/Zhang/"
30+
"weight_dir": "/home/bigspace/xujialang/MFPNet_result/Zhang/",
31+
"resume": "None"
32+
}

MFPNet_code/MFPNet_model.py MFPNet_code/models/MFPNet_model.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.nn.functional as F
66
import os
77

8-
from seresnet50 import se_resnet50
8+
from .seresnet50 import se_resnet50
99

1010
class BasicConvBlock(nn.Module):
1111
def __init__(self, in_channels, out_channels=None):
@@ -357,9 +357,8 @@ def forward(self, x_prev, x_now):
357357

358358
x_fuse=self.maffm(features_t1_t2)
359359
dis_map=self.dec(x_fuse)
360-
result = torch.argmax(dis_map, dim = 1, keepdim = True)
361360

362-
return dis_map, result
361+
return dis_map
363362

364363
if __name__ == "__main__":
365364
model = MFPNET(classes = 2)
File renamed without changes.
File renamed without changes.

MFPNet_code/train.py

+189
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import datetime
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
from utils.parser import get_parser_with_args
6+
from utils.helpers import (get_loaders, get_criterion,
7+
load_model, initialize_metrics, get_mean_metrics,
8+
set_metrics)
9+
from sklearn.metrics import precision_recall_fscore_support as prfs
10+
import os
11+
import logging
12+
import json
13+
import random
14+
import numpy as np
15+
import re
16+
import warnings
17+
from models.vgg import Vgg19
18+
warnings.filterwarnings("ignore")
19+
20+
"""
21+
Initialize Parser and define arguments
22+
"""
23+
parser, metadata = get_parser_with_args(metadata_json_path='/home/aaa/xujialang/master_thesis/MFPNet/metadata.json')
24+
opt = parser.parse_args()
25+
26+
"""
27+
Initialize experiments log
28+
"""
29+
logging.basicConfig(level=logging.INFO)
30+
31+
"""
32+
Set up environment: define paths, download data, and set device
33+
"""
34+
dev = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
35+
logging.info('GPU AVAILABLE? ' + str(torch.cuda.is_available()))
36+
37+
def seed_torch(seed):
38+
random.seed(seed)
39+
os.environ['PYTHONHASHSEED'] = str(seed)
40+
np.random.seed(seed)
41+
torch.manual_seed(seed)
42+
torch.cuda.manual_seed(seed)
43+
# torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
44+
torch.backends.cudnn.benchmark = False
45+
torch.backends.cudnn.deterministic = True
46+
seed_torch(seed=777)
47+
48+
train_loader, val_loader = get_loaders(opt)
49+
print(opt.batch_size * len(train_loader))
50+
print(opt.batch_size * len(val_loader))
51+
52+
"""
53+
Load Model then define other aspects of the model
54+
"""
55+
logging.info('LOADING Model')
56+
model = load_model(opt, dev)
57+
vgg=Vgg19().to(dev)
58+
"""
59+
Resume
60+
"""
61+
epoch_resume=0
62+
if opt.resume != "None":
63+
model.load_state_dict(torch.load(os.path.join(opt.resume)))
64+
epoch_resume=int(re.sub("\D","",opt.resume))
65+
print('resume success: epoch {}'.format(epoch_resume))
66+
67+
criterion_ce = nn.CrossEntropyLoss().to(dev)
68+
criterion_perceptual = nn.MSELoss().to(dev)
69+
criterion = get_criterion(opt)
70+
optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate) # Be careful when you adjust learning rate, you can refer to the linear scaling rule
71+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, T_mult=2, eta_min=0, last_epoch=-1)
72+
73+
"""
74+
Set starting values
75+
"""
76+
best_metrics = {'cd_f1scores': -1, 'cd_recalls': -1, 'cd_precisions': -1}
77+
logging.info('STARTING training')
78+
79+
for epoch in range(opt.epochs):
80+
epoch= epoch + epoch_resume +1
81+
train_metrics = initialize_metrics()
82+
val_metrics = initialize_metrics()
83+
84+
"""
85+
Begin Training
86+
"""
87+
model.train()
88+
logging.info('SET model mode to train!')
89+
90+
for batch_img1, batch_img2, labels in train_loader:
91+
# Set variables for training
92+
batch_img1 = batch_img1.float().to(dev)
93+
batch_img2 = batch_img2.float().to(dev)
94+
labels = labels.long().to(dev)
95+
96+
# Zero the gradient
97+
optimizer.zero_grad()
98+
99+
# Get model predictions, calculate loss, backprop
100+
cd_preds= model(batch_img1, batch_img2)
101+
loss = criterion(criterion_ce, criterion_perceptual, cd_preds, labels, batch_img1, vgg, dev)
102+
103+
loss.backward()
104+
optimizer.step()
105+
106+
# Calculate and log other batch metrics
107+
cd_preds = torch.argmax(cd_preds, dim = 1)
108+
cd_corrects = (100 *
109+
(cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() /
110+
(labels.size()[0] * (opt.patch_size**2)))
111+
cd_train_report = prfs(labels.data.cpu().numpy().flatten(),
112+
cd_preds.data.cpu().numpy().flatten(),
113+
average='binary',
114+
pos_label=1)
115+
train_metrics = set_metrics(train_metrics,
116+
loss,
117+
cd_corrects,
118+
cd_train_report,
119+
scheduler.get_last_lr())
120+
121+
# log the batch mean metrics
122+
mean_train_metrics = get_mean_metrics(train_metrics)
123+
124+
# clear batch variables from memory
125+
del batch_img1, batch_img2, labels
126+
127+
scheduler.step()
128+
logging.info("EPOCH {} TRAIN METRICS. ".format(epoch) + str(mean_train_metrics))
129+
130+
131+
"""
132+
Begin Validation
133+
"""
134+
model.eval()
135+
with torch.no_grad():
136+
for batch_img1, batch_img2, labels in val_loader:
137+
# Set variables for training
138+
batch_img1 = batch_img1.float().to(dev)
139+
batch_img2 = batch_img2.float().to(dev)
140+
labels = labels.long().to(dev)
141+
142+
# Get predictions and calculate loss
143+
cd_preds = model(batch_img1, batch_img2)
144+
val_loss = criterion(criterion_ce, criterion_perceptual, cd_preds, labels, batch_img1, vgg, dev)
145+
146+
# Calculate and log other batch metrics
147+
cd_preds = torch.argmax(cd_preds, dim = 1)
148+
cd_corrects = (100 *
149+
(cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() /
150+
(labels.size()[0] * (opt.patch_size**2)))
151+
cd_val_report = prfs(labels.data.cpu().numpy().flatten(),
152+
cd_preds.data.cpu().numpy().flatten(),
153+
average='binary',
154+
pos_label=1)
155+
val_metrics = set_metrics(val_metrics,
156+
val_loss,
157+
cd_corrects,
158+
cd_val_report,
159+
scheduler.get_lr())
160+
161+
# log the batch mean metrics
162+
mean_val_metrics = get_mean_metrics(val_metrics)
163+
164+
# clear batch variables from memory
165+
del batch_img1, batch_img2, labels
166+
167+
logging.info("EPOCH {} VALIDATION METRICS".format(epoch)+str(mean_val_metrics))
168+
169+
"""
170+
Store the weights of good epochs based on validation results
171+
"""
172+
if (mean_val_metrics['cd_f1scores'] > best_metrics['cd_f1scores']):
173+
# Insert training and epoch information to metadata dictionary
174+
logging.info('updata the model')
175+
metadata['val_metrics'] = mean_val_metrics
176+
177+
# Save model and log
178+
if not os.path.exists(opt.weight_dir):
179+
os.mkdir(opt.weight_dir)
180+
with open(opt.weight_dir + 'metadata_val_epoch_' + str(epoch) + '.json', 'w') as fout:
181+
json.dump(metadata, fout)
182+
183+
torch.save(model.state_dict(), opt.weight_dir + 'checkpoint_epoch_'+str(epoch)+'_f1_'+str(mean_val_metrics['cd_f1scores'])+'.pt')
184+
best_metrics = mean_val_metrics
185+
print('best val: ' + str(mean_val_metrics))
186+
187+
print('An epoch finished.')
188+
189+
print('Done!')

0 commit comments

Comments
 (0)