@@ -84,44 +84,34 @@ def __init__(self, hidden_dim, hidden_dim_s, hidden_dim_t, strategy, bidirection
 
     def forward(self, target_h, source_hs):
 
-        # if self.strategy in ['dot','general']:
-        #     source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
-
-        # if self.strategy == 'dot':
-        #     # with this strategy, no trainable parameters are involved
-        #     # here, feat = hidden_dim_t = hidden_dim_s
-        #     target_h = target_h.permute(1,2,0) # (1,batch,feat) -> (batch,feat,1)
-        #     dot_product = torch.matmul(source_hs, target_h) # (batch,seq,feat) * (batch,feat,1) -> (batch,seq,1)
-        #     scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
+        if self.strategy in ['dot','general']:
+            source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
+
+        if self.strategy == 'dot':
+            # with this strategy, no trainable parameters are involved
+            # here, feat = hidden_dim_t = hidden_dim_s
+            target_h = target_h.permute(1,2,0) # (1,batch,feat) -> (batch,feat,1)
+            dot_product = torch.matmul(source_hs, target_h) # (batch,seq,feat) * (batch,feat,1) -> (batch,seq,1)
+            scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
 
-        # elif self.strategy == 'general':
-        #     target_h = target_h.permute(1,0,2) # (1,batch,hidden_dim_t) -> (batch,1,hidden_dim_t)
-        #     output = self.ff_general(target_h) # -> (batch,1,hidden_dim_s)
-        #     output = output.permute(0,2,1) # -> (batch,hidden_dim_s,1)
-        #     dot_product = torch.matmul(source_hs, output) # (batch,seq,hidden_dim_s) * (batch,hidden_dim_s,1) -> (batch,seq,1)
-        #     scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
+        elif self.strategy == 'general':
+            target_h = target_h.permute(1,0,2) # (1,batch,hidden_dim_t) -> (batch,1,hidden_dim_t)
+            output = self.ff_general(target_h) # -> (batch,1,hidden_dim_s)
+            output = output.permute(0,2,1) # -> (batch,hidden_dim_s,1)
+            dot_product = torch.matmul(source_hs, output) # (batch,seq,hidden_dim_s) * (batch,hidden_dim_s,1) -> (batch,seq,1)
+            scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
 
-        # elif self.strategy == 'concat':
-        #     target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,hidden_dim_s) -> (seq,batch,hidden_dim_s)
-        #     concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # (seq,batch,hidden_dim_s+hidden_dim_t) -> (seq,batch,hidden_dim)
-        #     scores = self.ff_score(torch.tanh(concat_output)) # -> (seq,batch,1)
-        #     source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
+        elif self.strategy == 'concat':
+            target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,hidden_dim_s) -> (seq,batch,hidden_dim_s)
+            concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # (seq,batch,hidden_dim_s+hidden_dim_t) -> (seq,batch,hidden_dim)
+            scores = self.ff_score(torch.tanh(concat_output)) # -> (seq,batch,1)
+            source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
 
-        # scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). We specify a dimension, because we don't want to squeeze the batch dim in case batch size is equal to 1
-        # norm_scores = torch.softmax(scores,0) # sequence-wise normalization
-        # source_hs_p = source_hs.permute((2,1,0)) # (batch,seq,hidden_dim_s) -> (hidden_dim_s,seq,batch)
-        # weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (hidden_dim_s,seq,batch) -> (hidden_dim_s,seq,batch) (we use broadcasting here - the * operator checks from right to left that the dimensions match)
-        # ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (hidden_dim_s,seq,batch) -> (seq,batch,hidden_dim_s) -> (1,batch,hidden_dim_s); we need keepdim as sum squeezes by default
-
-        target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,feat) -> (seq,batch,feat)
-        concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # source_hs is (seq,batch,feat)
-        scores = self.ff_score(torch.tanh(concat_output)) # (seq,batch,feat) -> (seq,batch,1)
-        scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). dim=2 because we don't want to squeeze the batch dim if batch size = 1
-        norm_scores = torch.softmax(scores,0)
-        source_hs_p = source_hs.permute((2,0,1)) # (seq,batch,feat) -> (feat,seq,batch)
-        weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (feat,seq,batch) (* checks from right to left that the dimensions match)
-        ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (feat,seq,batch) -> (seq,batch,feat) -> (1,batch,feat); keepdim otherwise sum squeezes
-
+        scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). We specify a dimension, because we don't want to squeeze the batch dim in case batch size is equal to 1
+        norm_scores = torch.softmax(scores,0) # sequence-wise normalization
+        source_hs_p = source_hs.permute((2,1,0)) # (batch,seq,hidden_dim_s) -> (hidden_dim_s,seq,batch)
+        weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (hidden_dim_s,seq,batch) -> (hidden_dim_s,seq,batch) (we use broadcasting here - the * operator checks from right to left that the dimensions match)
+        ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (hidden_dim_s,seq,batch) -> (seq,batch,hidden_dim_s) -> (1,batch,hidden_dim_s); we need keepdim as sum squeezes by default
 
         return ct
 
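For reference, the branches uncommented above are the dot/general/concat scoring functions of Luong et al. (2015). Below is a minimal standalone shape-check sketch of the 'dot' path plus the final context-vector computation, assuming the same (seq, batch, feat) layout as the diff; the tensor sizes are made up, and the weighted sum uses an equivalent but simpler broadcasting pattern than the permute-based one in the code.

# Standalone shape check (illustrative sizes; not part of the diff).
import torch

seq_len, batch, feat = 7, 2, 16                     # made-up dimensions
source_hs = torch.randn(seq_len, batch, feat)       # encoder states, (seq,batch,feat)
target_h  = torch.randn(1, batch, feat)             # current decoder state, (1,batch,feat)

# 'dot' scoring, as in the uncommented branch above
hs = source_hs.permute(1, 0, 2)                     # -> (batch,seq,feat)
th = target_h.permute(1, 2, 0)                      # -> (batch,feat,1)
scores = torch.matmul(hs, th).permute(1, 0, 2)      # (batch,seq,1) -> (seq,batch,1)

# sequence-wise softmax, then weighted sum of the encoder states
norm_scores = torch.softmax(scores.squeeze(dim=2), dim=0)     # (seq,batch)
ct = torch.sum(norm_scores.unsqueeze(2) * source_hs, dim=0,   # broadcast over feat
               keepdim=True)                                  # -> (1,batch,feat)
assert ct.shape == (1, batch, feat)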
@@ -261,11 +251,17 @@ def forward(self, input, max_size, is_prod):
         return to_return
 
 
-    def fit(self, trainingDataset, testDataset, lr, batch_size, n_epochs, patience):
+    def fit(self, trainingDataset, testDataset, lr, batch_size, n_epochs, patience, my_optimizer):
 
         parameters = [p for p in self.parameters() if p.requires_grad]
 
-        optimizer = optim.Adam(parameters, lr=lr)
+        if my_optimizer == 'adam':
+            optimizer = optim.Adam(parameters, lr=lr)
+        elif my_optimizer == 'SGD':
+            optimizer = optim.SGD(parameters, lr=lr) # https://pytorch.org/docs/stable/optim.html#torch.optim.SGD
+            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
+                                                             factor=0.1, patience=5,
+                                                             verbose=True, threshold=0.1) # https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
 
         criterion = torch.nn.CrossEntropyLoss(ignore_index=self.padding_token) # the softmax is inside the loss!
 
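For context on the new optimizer branch: a minimal, self-contained sketch of the SGD + ReduceLROnPlateau wiring it sets up, with the same factor/patience/threshold values. The toy model and data are made up, and verbose=True is left out since it only affects logging.

# Toy sketch of the SGD + ReduceLROnPlateau pattern added above (model and data are made up).
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)                                   # stand-in for the real seq2seq model
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(parameters, lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                 factor=0.1, patience=5,
                                                 threshold=0.1)

criterion = nn.CrossEntropyLoss()
x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))

for epoch in range(20):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())   # metric the scheduler monitors; lr is multiplied by `factor` after `patience` stagnant epochs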
@@ -286,9 +282,9 @@ def fit(self, trainingDataset, testDataset, lr, batch_size, n_epochs, patience):
         it_times = []
 
         # my fake code
-        for p in self.parameters():
-            if not p.requires_grad:
-                print(p.name, p.data)
+        # for p in self.parameters():
+        #     if not p.requires_grad:
+        #         print(p.name, p.data)
 
         for epoch in range(n_epochs):
 
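The loop commented out above was a debugging aid for listing frozen parameters. A common alternative, sketched here on a made-up toy model (not taken from this repo), is to iterate over named_parameters(), which also yields human-readable names:

# Inspecting trainable vs. frozen parameters via named_parameters() (toy model for illustration).
import torch.nn as nn

model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 2))   # stand-in model
model[0].weight.requires_grad = False                            # e.g. frozen pretrained embeddings

for name, p in model.named_parameters():
    status = 'trainable' if p.requires_grad else 'frozen'
    print(f'{name}: {tuple(p.shape)} ({status})')

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('trainable parameter count:', n_trainable)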
@@ -366,7 +362,10 @@ def fit(self, trainingDataset, testDataset, lr, batch_size, n_epochs, patience):
 
                 if patience_counter > patience:
                     break
-
+
+            if my_optimizer == 'SGD':
+                scheduler.step(total_loss)
+
             self.test_toy(test_sents)
 
             self.logs['avg_time_it'] = round(np.mean(it_times),4)
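Finally, a tiny standalone check (not from the diff) that the end-of-epoch scheduler.step(total_loss) call behaves as expected when the loss plateaus: the learning rate should drop by the configured factor once the loss has failed to improve for more than `patience` epochs.

# Standalone sanity check of scheduler.step() on a plateauing loss (toy parameter, made-up values).
import torch
import torch.optim as optim

param = torch.zeros(1, requires_grad=True)
optimizer = optim.SGD([param], lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                 factor=0.1, patience=5, threshold=0.1)

for epoch in range(8):
    total_loss = 1.0                                  # a loss that never improves
    scheduler.step(total_loss)                        # same call as at the end of each training epoch
    print(epoch, [g['lr'] for g in optimizer.param_groups])
# the printed lr drops from 0.1 to 0.01 once the loss has stagnated for more than `patience` epochs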