Commit cfd3452

sgd
1 parent e30c612 commit cfd3452

2 files changed: +43 -44 lines changed

NMT/code/grid_search.py

Lines changed: 3 additions & 3 deletions
@@ -59,9 +59,9 @@ def load_pairs(train_or_test):
 num_layers = 1
 bidirectional = False

-for att_strategy in ['dot','general','concat']: #['none']:
+for att_strategy in ['concat','general','dot','none']:

-    hidden_dim_s = 30
+    hidden_dim_s = 30

     if bidirectional:
         if att_strategy == 'dot':
@@ -89,7 +89,7 @@ def load_pairs(train_or_test):
                       max_size = 30, # for the decoder, in prediction mode
                       dropout = 0)

-    model.fit(training_set, test_set, lr=0.002, batch_size=64, n_epochs=1, patience=5)
+    model.fit(training_set, test_set, lr=0.1, batch_size=64, n_epochs=200, patience=10, my_optimizer='SGD')

     model_name = '_'.join([att_strategy, str(num_layers), str(bidirectional)])
     model.save(path_to_save, model_name)
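
For reference, a small illustration (not part of the commit) of the model names the updated grid search will save, assuming the loop above runs with num_layers = 1 and bidirectional = False:

for att_strategy in ['concat', 'general', 'dot', 'none']:
    print('_'.join([att_strategy, str(1), str(False)]))
# concat_1_False
# general_1_False
# dot_1_False
# none_1_False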

NMT/code/model.py

Lines changed: 40 additions & 41 deletions
@@ -84,44 +84,34 @@ def __init__(self, hidden_dim, hidden_dim_s, hidden_dim_t, strategy, bidirection

    def forward(self, target_h, source_hs):

-        # if self.strategy in ['dot','general']:
-        #     source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
-
-        # if self.strategy == 'dot':
-        #     # with this strategy, no trainable parameters are involved
-        #     # here, feat = hidden_dim_t = hidden_dim_s
-        #     target_h = target_h.permute(1,2,0) # (1,batch,feat) -> (batch,feat,1)
-        #     dot_product = torch.matmul(source_hs, target_h) # (batch,seq,feat) * (batch,feat,1) -> (batch,seq,1)
-        #     scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
+        if self.strategy in ['dot','general']:
+            source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
+
+        if self.strategy == 'dot':
+            # with this strategy, no trainable parameters are involved
+            # here, feat = hidden_dim_t = hidden_dim_s
+            target_h = target_h.permute(1,2,0) # (1,batch,feat) -> (batch,feat,1)
+            dot_product = torch.matmul(source_hs, target_h) # (batch,seq,feat) * (batch,feat,1) -> (batch,seq,1)
+            scores = dot_product.permute(1,0,2) # -> (seq,batch,1)

-        # elif self.strategy == 'general':
-        #     target_h = target_h.permute(1,0,2) # (1,batch,hidden_dim_t) -> (batch,1,hidden_dim_t)
-        #     output = self.ff_general(target_h) # -> (batch,1,hidden_dim_s)
-        #     output = output.permute(0,2,1) # -> (batch,hidden_dim_s,1)
-        #     dot_product = torch.matmul(source_hs, output) # (batch,seq,hidden_dim_s) * (batch,hidden_dim_s,1) -> (batch,seq,1)
-        #     scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
+        elif self.strategy == 'general':
+            target_h = target_h.permute(1,0,2) # (1,batch,hidden_dim_t) -> (batch,1,hidden_dim_t)
+            output = self.ff_general(target_h) # -> (batch,1,hidden_dim_s)
+            output = output.permute(0,2,1) # -> (batch,hidden_dim_s,1)
+            dot_product = torch.matmul(source_hs, output) # (batch,seq,hidden_dim_s) * (batch,hidden_dim_s,1) -> (batch,seq,1)
+            scores = dot_product.permute(1,0,2) # -> (seq,batch,1)

-        # elif self.strategy == 'concat':
-        #     target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,hidden_dim_s) -> (seq,batch,hidden_dim_s)
-        #     concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # (seq,batch,hidden_dim_s+hidden_dim_t) -> (seq,batch,hidden_dim)
-        #     scores = self.ff_score(torch.tanh(concat_output)) # -> (seq,batch,1)
-        #     source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
+        elif self.strategy == 'concat':
+            target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,hidden_dim_s) -> (seq,batch,hidden_dim_s)
+            concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # (seq,batch,hidden_dim_s+hidden_dim_t) -> (seq,batch,hidden_dim)
+            scores = self.ff_score(torch.tanh(concat_output)) # -> (seq,batch,1)
+            source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)

-        # scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). We specify a dimension, because we don't want to squeeze the batch dim in case batch size is equal to 1
-        # norm_scores = torch.softmax(scores,0) # sequence-wise normalization
-        # source_hs_p = source_hs.permute((2,1,0)) # (batch,seq,hidden_dim_s) -> (hidden_dim_s,seq,batch)
-        # weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (hidden_dim_s,seq,batch) -> (hidden_dim_s,seq,batch) (we use broadcasting here - the * operator checks from right to left that the dimensions match)
-        # ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (hidden_dim_s,seq,batch) -> (seq,batch,hidden_dim_s) -> (1,batch,hidden_dim_s); we need keepdim as sum squeezes by default
-
-        target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,feat) -> (seq,batch,feat)
-        concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # source_hs is (seq,batch,feat)
-        scores = self.ff_score(torch.tanh(concat_output)) # (seq,batch,feat) -> (seq,batch,1)
-        scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). dim=2 because we don't want to squeeze the batch dim if batch size = 1
-        norm_scores = torch.softmax(scores,0)
-        source_hs_p = source_hs.permute((2,0,1)) # (seq,batch,feat) -> (feat,seq,batch)
-        weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (feat,seq,batch) (* checks from right to left that the dimensions match)
-        ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (feat,seq,batch) -> (seq,batch,feat) -> (1,batch,feat); keepdim otherwise sum squeezes
-
+        scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). We specify a dimension, because we don't want to squeeze the batch dim in case batch size is equal to 1
+        norm_scores = torch.softmax(scores,0) # sequence-wise normalization
+        source_hs_p = source_hs.permute((2,1,0)) # (batch,seq,hidden_dim_s) -> (hidden_dim_s,seq,batch)
+        weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (hidden_dim_s,seq,batch) -> (hidden_dim_s,seq,batch) (we use broadcasting here - the * operator checks from right to left that the dimensions match)
+        ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (hidden_dim_s,seq,batch) -> (seq,batch,hidden_dim_s) -> (1,batch,hidden_dim_s); we need keepdim as sum squeezes by default

        return ct
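
To make the shape bookkeeping in the re-enabled branches easier to follow, here is a minimal standalone sketch (not from the repository) of the dot-product scoring path with toy dimensions; seq, batch and feat are illustrative assumptions, and the weighted sum at the end uses a simpler broadcasting layout than the code above while producing the same context vector shape:

import torch

seq, batch, feat = 5, 2, 30                        # toy sizes; feat = hidden_dim_s = hidden_dim_t
source_hs = torch.randn(seq, batch, feat)          # encoder hidden states
target_h = torch.randn(1, batch, feat)             # current decoder hidden state

source_hs_b = source_hs.permute(1, 0, 2)           # (batch, seq, feat)
target_h_b = target_h.permute(1, 2, 0)             # (batch, feat, 1)
scores = torch.matmul(source_hs_b, target_h_b)     # (batch, seq, 1): one score per source position
scores = scores.permute(1, 0, 2).squeeze(dim=2)    # (seq, batch)
norm_scores = torch.softmax(scores, 0)             # normalize over the source sequence

# context vector: attention-weighted sum of the encoder states
ct = torch.sum(norm_scores.unsqueeze(2) * source_hs, dim=0, keepdim=True)
print(ct.shape)                                    # torch.Size([1, 2, 30])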

@@ -261,11 +251,17 @@ def forward(self, input, max_size, is_prod):
        return to_return


-    def fit(self, trainingDataset, testDataset, lr, batch_size, n_epochs, patience):
+    def fit(self, trainingDataset, testDataset, lr, batch_size, n_epochs, patience, my_optimizer):

        parameters = [p for p in self.parameters() if p.requires_grad]

-        optimizer = optim.Adam(parameters, lr=lr)
+        if my_optimizer == 'adam':
+            optimizer = optim.Adam(parameters, lr=lr)
+        elif my_optimizer == 'SGD':
+            optimizer = optim.SGD(parameters, lr=lr) # https://pytorch.org/docs/stable/optim.html#torch.optim.SGD
+            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
+                                                             factor=0.1, patience=5,
+                                                             verbose=True, threshold=0.1) # https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau

        criterion = torch.nn.CrossEntropyLoss(ignore_index=self.padding_token) # the softmax is inside the loss!
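
As a self-contained illustration of the pattern introduced here (the toy model and dummy loss are assumptions, not repository code), the SGD branch pairs the optimizer with ReduceLROnPlateau, which multiplies the learning rate by factor once the monitored loss has stopped improving for patience epochs:

import torch
import torch.optim as optim

model = torch.nn.Linear(10, 2)                      # hypothetical stand-in model
parameters = [p for p in model.parameters() if p.requires_grad]

my_optimizer = 'SGD'                                # or 'adam'
if my_optimizer == 'adam':
    optimizer = optim.Adam(parameters, lr=0.002)
elif my_optimizer == 'SGD':
    optimizer = optim.SGD(parameters, lr=0.1)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.1, patience=5,
                                                     threshold=0.1)

for epoch in range(3):                              # toy training loop
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()
    if my_optimizer == 'SGD':
        scheduler.step(loss.item())                 # LR is reduced only when the loss plateaus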

@@ -286,9 +282,9 @@ def fit(self, trainingDataset, testDataset, lr, batch_size, n_epochs, patience):
        it_times = []

        # my fake code
-        for p in self.parameters():
-            if not p.requires_grad:
-                print(p.name, p.data)
+        #for p in self.parameters():
+        #    if not p.requires_grad:
+        #        print(p.name, p.data)

        for epoch in range(n_epochs):

@@ -366,7 +362,10 @@ def fit(self, trainingDataset, testDataset, lr, batch_size, n_epochs, patience):

            if patience_counter>patience:
                break
-
+
+            if my_optimizer == 'SGD':
+                scheduler.step(total_loss)
+
            self.test_toy(test_sents)

        self.logs['avg_time_it'] = round(np.mean(it_times),4)
