# -*- coding: utf-8 -*-
"""Example main file for LSTM network training
Scenario: LSTM network for predicting 1 label per sequence.
Input: Command line argument with path to config file 'config.json'.
Output: Output files will be saved in the output folder specified in 'config.json'.
Dataset: Dataset 'RandomOrSine' gives us sequences that need to be classified into random uniform signal or
sine signals.
Sequences have different lengths, so we need to use widis_lstm_tools.preprocessing.PadToEqualLengths for padding.
Config: Setup is done via config file 'config.json'.
Author -- Michael Widrich
Contact -- [email protected]
"""
import os
import time
import matplotlib
matplotlib.use('Agg')
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from widis_lstm_tools.nn import LSTMLayer, LearningRateDecay
from widis_lstm_tools.utils.config_tools import get_config
from widis_lstm_tools.preprocessing import PadToEqualLengths
from widis_lstm_tools.examples.basic.dataset import RandomOrSine
from widis_lstm_tools.measures import bacc
from widis_lstm_tools.utils.collection import TeePrint, close_all


class Net(nn.Module):
    def __init__(self, n_input_features, n_lstm, n_outputs):
        super(Net, self).__init__()
        # Let's say we want an LSTM with forward connections to the cell input and recurrent connections to the
        # input and output gate only; furthermore, we want a linear LSTM output activation instead of tanh:
        self.lstm1 = LSTMLayer(
            in_features=n_input_features, out_features=n_lstm,
            # Possible input formats: 'NLC' (samples, length, channels), 'NCL', or 'LNC'
            inputformat='NLC',
            # cell input: initialize weights for forward inputs with Xavier, disable connections to recurrent inputs
            w_ci=(nn.init.xavier_normal_, False),
            # input gate: disable connections to forward inputs, initialize weights for recurrent inputs with Xavier
            w_ig=(False, nn.init.xavier_normal_),
            # output gate: disable connections to forward inputs, initialize weights for recurrent inputs with Xavier
            w_og=(False, nn.init.xavier_normal_),
            # forget gate: disable all connections (=no forget gate) and bias
            w_fg=False, b_fg=False,
            # LSTM output activation shall be the identity function
            a_out=lambda x: x,
            # Optionally use a negative input gate bias for long sequences
            b_ig=lambda *args, **kwargs: nn.init.normal_(mean=-5, *args, **kwargs),
            # Optionally let the LSTM do computations after the sequence end, using tickersteps
            n_tickersteps=5,
        )
        # This would be a fully connected LSTM (cell input and gates connected to forward and recurrent inputs)
        # without tickersteps:
        # self.lstm1 = LSTMLayer(
        #     in_features=n_input_features, out_features=n_lstm,
        #     inputformat='NLC',
        #     w_ci=nn.init.xavier_normal_, b_ci=nn.init.normal_,  # equal to w_ci=(nn.init.xavier_normal_, nn.init.xavier_normal_)
        #     w_ig=nn.init.xavier_normal_, b_ig=nn.init.normal_,
        #     w_og=nn.init.xavier_normal_, b_og=nn.init.normal_,
        #     w_fg=nn.init.xavier_normal_, b_fg=nn.init.normal_,
        #     a_out=lambda x: x
        # )
        # After the LSTM layer, we add a fully connected output layer
        self.fc_out = nn.Linear(n_lstm, n_outputs)

    def forward(self, x, true_seq_lens):
        # We only need the output of the LSTM; we get format (samples, n_lstm) since we set return_all_seq_pos=False:
        lstm_out, *_ = self.lstm1.forward(
            x,
            true_seq_lens=true_seq_lens,  # true sequence lengths of the padded sequences
            return_all_seq_pos=False  # only return predictions for the last sequence position
        )
        net_out = self.fc_out(lstm_out)
        return net_out
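
# Shape sketch for Net.forward() (the batch size and lengths are illustrative assumptions): for a padded
# minibatch of 4 sequences with maximum length 100, x has shape (4, 100, n_input_features) in 'NLC' format and
# true_seq_lens holds the 4 unpadded lengths; lstm_out then has shape (4, n_lstm) and net_out (4, n_outputs).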


def main():
    # Read config file path and set up results folder
    config, resdir = get_config()
    logfile = os.path.join(resdir, 'log.txt')
    os.makedirs(resdir, exist_ok=True)

    # Get a tprint() function that prints to stdout and to our logfile
    tee_print = TeePrint(logfile)
    tprint = tee_print.tee_print

    # Set up PyTorch and set random seeds
    torch.set_num_threads(config['num_threads'])
    torch.manual_seed(config['rnd_seed'])
    np.random.seed(config['rnd_seed'])
    device = torch.device(config['device'])  # e.g. "cpu" or "cuda:0"

    # Get datasets
    trainset = RandomOrSine(n_samples=config['n_trainingset_samples'])
    testset = RandomOrSine(n_samples=config['n_testset_samples'])

    # Set up sequence padding
    padder = PadToEqualLengths(
        padding_dims=(0, None, None),  # only pad the first entry (the sequences) at dimension 0 (=seq. len.)
        padding_values=(0, None, None)  # pad with zeros
    )
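
    # Illustration of the collate output (the shapes are illustrative assumptions): for two samples
    # (sequence, label, sample_id) with sequence shapes (70, n_features) and (100, n_features),
    # padder.pad_collate_fn() would yield ((padded_sequences of shape (2, 100, n_features), seq_lens=(70, 100)),
    # labels, sample_ids), which matches the unpacking in the training loop below.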

    # Get DataLoaders
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=config['batch_size'], shuffle=True, num_workers=2,
                                              collate_fn=padder.pad_collate_fn)
    testloader = torch.utils.data.DataLoader(testset, batch_size=config['batch_size'] * 4, shuffle=False,
                                             num_workers=2, collate_fn=padder.pad_collate_fn)

    # Create network
    net = Net(n_input_features=trainset.n_features, n_lstm=config['n_lstm'], n_outputs=trainset.n_classes)
    net.to(device)

    # Get loss function
    mean_cross_entropy = nn.CrossEntropyLoss()

    # Get optimizer (with a small weight decay for regularization)
    optimizer = optim.Adam(net.parameters(), lr=config['lr'], weight_decay=1e-5)

    # Get a linear learning rate decay
    lr_decay = LearningRateDecay(max_n_updates=config['n_updates'], optimizer=optimizer, original_lr=config['lr'])
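
    # Example (assuming LearningRateDecay decays the learning rate linearly from original_lr towards 0 over
    # max_n_updates; this is an assumption for illustration): with lr=1e-3 and n_updates=10000, get_lr(5000)
    # would yield roughly 5e-4.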

    #
    # Start training
    #
    tprint("# settings: {}".format(config))
    update = 0
    running_loss = []
    while update < config['n_updates']:
        start_time = time.time()
        for data in trainloader:
            # Get and set current learning rate
            lr = lr_decay.get_lr(update)

            # Get samples
            inputs, labels, sample_id = data
            padded_sequences, seq_lens = inputs
            padded_sequences, labels = padded_sequences.to(device), labels.long().to(device)

            # Reset gradients
            optimizer.zero_grad()

            # Get network outputs
            outputs = net(padded_sequences, seq_lens)

            # Calculate loss, do backward pass, and update
            loss = mean_cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            update += 1

            # Update running losses for our statistics
            running_loss.append(loss.item() if hasattr(loss, 'item') else loss)
            running_loss = running_loss[-10:]  # keep last 10 losses for averaging

            # Print current status and score
            if update % config['print_stats_at'] == 0 and update > 0:
                run_time = (time.time() - start_time) / config['print_stats_at']
                tprint(f"[train] u: {update:07d}; loss: {np.mean(running_loss):8.7f}; "
                       f"sec/update: {run_time:8.7f}; lr: {lr:8.7f}")
                start_time = time.time()

            # Do some plotting using the LSTMLayer plotting function
            if update % config['plot_at'] == 0:
                # This will plot the LSTM internals for sample 0 in the current minibatch
                mb_index = 0
                pred = (outputs[mb_index, 1] > outputs[mb_index, 0]).float().cpu().item()
                net.lstm1.plot_internals(
                    filename=os.path.join(resdir, 'lstm_plots',
                                          f'u{update:07d}_id{sample_id[0]}_cl{labels[0]}_pr{pred}.png'),
                    mb_index=mb_index, fdict=dict(figsize=(50, 10), dpi=100))
                start_time = time.time()

            if update >= config['n_updates']:
                break

    print('Finished Training! Starting evaluation on test set...')

    # Compute scores on testset
    with torch.no_grad():
        tp_sum = 0.
        tn_sum = 0.
        p_sum = 0.
        loss = 0.
        for testdata in testloader:
            # Get samples
            inputs, labels, _ = testdata
            padded_sequences, seq_lens = inputs
            padded_sequences, labels = padded_sequences.to(device), labels.long().to(device)

            # Get network outputs
            outputs = net(padded_sequences, seq_lens)

            # Add minibatch loss, weighted by minibatch size, to the mean loss over the testset
            loss += (mean_cross_entropy(outputs, labels) * (len(labels) / len(testset)))

            # Store sums of tp, tn, and p for BACC calculation
            labels = labels.float()
            p_sum += labels.sum(dim=0)  # number of positive samples
            predictions = (outputs[:, 1] > outputs[:, 0]).float()
            tp_sum += (predictions * labels).sum()
            tn_sum += ((1 - predictions) * (1 - labels)).sum()
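
        # Balanced accuracy: BACC = (TP/P + TN/N) / 2, i.e. the mean of true positive rate and true negative rate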
        # Compute balanced accuracy
        n_sum = len(testset) - p_sum
        bacc_score = bacc(tp=tp_sum, tn=tn_sum, p=p_sum, n=n_sum).cpu().item()
        loss = loss.cpu().item()

    # Print results
    tprint(f"[eval] u: {update:07d}; loss: {loss:8.7f}; bacc: {bacc_score:5.4f}")
    print('Done!')


if __name__ == '__main__':
    try:
        main()
    finally:
        close_all()