@@ -84,34 +84,44 @@ def __init__(self, hidden_dim, hidden_dim_s, hidden_dim_t, strategy, bidirection

    def forward(self, target_h, source_hs):

-        if self.strategy in ['dot','general']:
-            source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
-
-        if self.strategy == 'dot':
-            # with this strategy, no trainable parameters are involved
-            # here, feat = hidden_dim_t = hidden_dim_s
-            target_h = target_h.permute(1,2,0) # (1,batch,feat) -> (batch,feat,1)
-            dot_product = torch.matmul(source_hs, target_h) # (batch,seq,feat) * (batch,feat,1) -> (batch,seq,1)
-            scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
+        # if self.strategy in ['dot','general']:
+        #     source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
+
+        # if self.strategy == 'dot':
+        #     # with this strategy, no trainable parameters are involved
+        #     # here, feat = hidden_dim_t = hidden_dim_s
+        #     target_h = target_h.permute(1,2,0) # (1,batch,feat) -> (batch,feat,1)
+        #     dot_product = torch.matmul(source_hs, target_h) # (batch,seq,feat) * (batch,feat,1) -> (batch,seq,1)
+        #     scores = dot_product.permute(1,0,2) # -> (seq,batch,1)

-        elif self.strategy == 'general':
-            target_h = target_h.permute(1,0,2) # (1,batch,hidden_dim_t) -> (batch,1,hidden_dim_t)
-            output = self.ff_general(target_h) # -> (batch,1,hidden_dim_s)
-            output = output.permute(0,2,1) # -> (batch,hidden_dim_s,1)
-            dot_product = torch.matmul(source_hs, output) # (batch,seq,hidden_dim_s) * (batch,hidden_dim_s,1) -> (batch,seq,1)
-            scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
+        # elif self.strategy == 'general':
+        #     target_h = target_h.permute(1,0,2) # (1,batch,hidden_dim_t) -> (batch,1,hidden_dim_t)
+        #     output = self.ff_general(target_h) # -> (batch,1,hidden_dim_s)
+        #     output = output.permute(0,2,1) # -> (batch,hidden_dim_s,1)
+        #     dot_product = torch.matmul(source_hs, output) # (batch,seq,hidden_dim_s) * (batch,hidden_dim_s,1) -> (batch,seq,1)
+        #     scores = dot_product.permute(1,0,2) # -> (seq,batch,1)

-        elif self.strategy == 'concat':
-            target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,hidden_dim_s) -> (seq,batch,hidden_dim_s)
-            concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # (seq,batch,hidden_dim_s+hidden_dim_t) -> (seq,batch,hidden_dim)
-            scores = self.ff_score(torch.tanh(concat_output)) # -> (seq,batch,1)
-            source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
+        # elif self.strategy == 'concat':
+        #     target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,hidden_dim_s) -> (seq,batch,hidden_dim_s)
+        #     concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # (seq,batch,hidden_dim_s+hidden_dim_t) -> (seq,batch,hidden_dim)
+        #     scores = self.ff_score(torch.tanh(concat_output)) # -> (seq,batch,1)
+        #     source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)

-        scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). We specify a dimension, because we don't want to squeeze the batch dim in case batch size is equal to 1
-        norm_scores = torch.softmax(scores,0) # sequence-wise normalization
-        source_hs_p = source_hs.permute((2,1,0)) # (batch,seq,hidden_dim_s) -> (hidden_dim_s,seq,batch)
-        weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (hidden_dim_s,seq,batch) -> (hidden_dim_s,seq,batch) (we use broadcasting here - the * operator checks from right to left that the dimensions match)
-        ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (hidden_dim_s,seq,batch) -> (seq,batch,hidden_dim_s) -> (1,batch,hidden_dim_s); we need keepdim as sum squeezes by default
+        # scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). We specify a dimension, because we don't want to squeeze the batch dim in case batch size is equal to 1
+        # norm_scores = torch.softmax(scores,0) # sequence-wise normalization
+        # source_hs_p = source_hs.permute((2,1,0)) # (batch,seq,hidden_dim_s) -> (hidden_dim_s,seq,batch)
+        # weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (hidden_dim_s,seq,batch) -> (hidden_dim_s,seq,batch) (we use broadcasting here - the * operator checks from right to left that the dimensions match)
+        # ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (hidden_dim_s,seq,batch) -> (seq,batch,hidden_dim_s) -> (1,batch,hidden_dim_s); we need keepdim as sum squeezes by default
+
+        target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,feat) -> (seq,batch,feat)
+        concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # source_hs is (seq,batch,feat)
+        scores = self.ff_score(torch.tanh(concat_output)) # (seq,batch,feat) -> (seq,batch,1)
+        scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). dim=2 because we don't want to squeeze the batch dim if batch size = 1
+        norm_scores = torch.softmax(scores,0)
+        source_hs_p = source_hs.permute((2,0,1)) # (seq,batch,feat) -> (feat,seq,batch)
+        weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (feat,seq,batch) (* checks from right to left that the dimensions match)
+        ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (feat,seq,batch) -> (seq,batch,feat) -> (1,batch,feat); keepdim otherwise sum squeezes
+


        return ct
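Side note (not part of the diff): the shape bookkeeping in the new concat-attention path can be checked in isolation with dummy tensors. The sketch below uses arbitrary sizes seq=5, batch=2, feat=7, hidden_dim=11 and two plain linear layers standing in for self.ff_concat and self.ff_score (all assumptions, not taken from the repo); it only verifies that the permute/softmax/broadcast/sum sequence yields one context vector of size feat per batch element.

import torch
import torch.nn as nn

seq, batch, feat, hidden_dim = 5, 2, 7, 11      # illustrative sizes (assumed)
ff_concat = nn.Linear(2 * feat, hidden_dim)     # stand-in for self.ff_concat (here hidden_dim_s = hidden_dim_t = feat)
ff_score = nn.Linear(hidden_dim, 1)             # stand-in for self.ff_score

source_hs = torch.randn(seq, batch, feat)       # encoder hidden states, (seq,batch,feat)
target_h = torch.randn(1, batch, feat)          # current decoder hidden state, (1,batch,feat)

target_h_rep = target_h.repeat(seq, 1, 1)                                            # (seq,batch,feat)
scores = ff_score(torch.tanh(ff_concat(torch.cat((target_h_rep, source_hs), -1))))   # (seq,batch,1)
norm_scores = torch.softmax(scores.squeeze(dim=2), 0)                                # (seq,batch), sums to 1 over seq
weighted = (norm_scores * source_hs.permute((2, 0, 1))).permute((1, 2, 0))           # (seq,batch,feat)
ct = torch.sum(weighted, 0, keepdim=True)                                            # (1,batch,feat)

assert ct.shape == (1, batch, feat)
# same result without the permute trick, via an explicit trailing singleton dim:
ct_alt = torch.sum(norm_scores.unsqueeze(2) * source_hs, 0, keepdim=True)
assert torch.allclose(ct, ct_alt)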
@@ -186,7 +196,6 @@ def __init__(self, vocab_s, source_language, vocab_t_inv, embedding_dim_s, embed
        self.decoder = Decoder(self.max_target_idx + 1, self.embedding_dim_t, self.hidden_dim_t, self.hidden_dim_s, self.num_layers, self.bidirectional, self.padding_token, self.dropout).to(self.device)

        if not self.att_strategy == 'none':
-            print(self.att_strategy)
            self.att_mech = seq2seqAtt(self.hidden_dim_att, self.hidden_dim_s, self.hidden_dim_t, self.att_strategy, self.bidirectional).to(self.device)

    def my_pad(self, my_list):
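For completeness, a minimal smoke test of the attention module wired up above could look like the sketch below. Nothing here is part of the commit: the import path nmt_model is a placeholder for wherever seq2seqAtt is defined in this repo, the sizes are arbitrary, and it assumes that with bidirectional=False the encoder feature size equals hidden_dim_s and that hidden_dim_s == hidden_dim_t, as the concat path expects.

import torch
from nmt_model import seq2seqAtt   # placeholder import path, adjust to the actual file

hidden_dim_att, hidden_dim_s, hidden_dim_t = 32, 30, 30      # assumed sizes
att = seq2seqAtt(hidden_dim_att, hidden_dim_s, hidden_dim_t, 'concat', False)

seq_len, batch = 6, 4
source_hs = torch.randn(seq_len, batch, hidden_dim_s)        # encoder hidden states
target_h = torch.randn(1, batch, hidden_dim_t)               # one decoder step's hidden state

ct = att(target_h, source_hs)
print(ct.shape)   # expected: torch.Size([1, 4, 30]), one context vector per batch element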