@@ -1129,28 +1129,25 @@ def __init__(
     def training_step(self, model, item):
         model.train()

-        text1 = item[0]
-        labels = item[1].to(self.args.device)
-        if self.args.model_type == 'zen':
-            inputs = convert_examples_to_features(text1=text1, ngram_dict=self.ngram_dict,
-                                                  tokenizer=self.tokenizer, max_seq_length=self.args.max_length,
-                                                  return_tensors=True)
-        else:
-            inputs = self.tokenizer(text1, padding='max_length', max_length=self.args.max_length,
-                                    truncation=True, return_tensors='pt')
+        input_ids = item[0].to(self.args.device)
+        token_type_ids = item[1].to(self.args.device)
+        attention_mask = item[2].to(self.args.device)
+        labels = item[3].to(self.args.device)

         if self.args.model_type == 'zen':
-            inputs['input_ngram_ids'] = inputs['input_ngram_ids'].to(self.args.device)
-            inputs['ngram_position_matrix'] = inputs['ngram_position_matrix'].to(self.args.device)
-            inputs['ngram_attention_mask'] = inputs['ngram_attention_mask'].to(self.args.device)
-            inputs['ngram_token_type_ids'] = inputs['ngram_token_type_ids'].to(self.args.device)
-
-        inputs['input_ids'] = inputs['input_ids'].to(self.args.device)
-        inputs['attention_mask'] = inputs['attention_mask'].to(self.args.device)
-        inputs['token_type_ids'] = inputs['token_type_ids'].to(self.args.device)
+            input_ngram_ids = item[4].to(self.args.device)
+            ngram_attention_mask = item[5].to(self.args.device)
+            ngram_token_type_ids = item[6].to(self.args.device)
+            ngram_position_matrix = item[7].to(self.args.device)

         # default using 'Transformers' library models.
-        outputs = model(labels=labels, **inputs)
+        if self.args.model_type == 'zen':
+            outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
+                            labels=labels, ngram_ids=input_ngram_ids, ngram_positions=ngram_position_matrix,
+                            ngram_attention_mask=ngram_attention_mask, ngram_token_type_ids=ngram_token_type_ids)
+        else:
+            outputs = model(labels=labels, input_ids=input_ids, token_type_ids=token_type_ids,
+                            attention_mask=attention_mask)
         loss = outputs[0]
         loss.backward()

@@ -1170,29 +1167,28 @@ def evaluate(self, model):
         for step, item in enumerate(eval_dataloader):
             model.eval()

-            text1 = item[0]
-            labels = item[1].to(args.device)
-
-            if self.args.model_type == 'zen':
-                inputs = convert_examples_to_features(text1=text1, ngram_dict=self.ngram_dict,
-                                                      tokenizer=self.tokenizer, max_seq_length=self.args.max_length,
-                                                      return_tensors=True)
-            else:
-                inputs = self.tokenizer(text1, padding='max_length', max_length=self.args.max_length,
-                                        truncation=True, return_tensors='pt')
-
-            if self.args.model_type == 'zen':
-                inputs['input_ngram_ids'] = inputs['input_ngram_ids'].to(self.args.device)
-                inputs['ngram_position_matrix'] = inputs['ngram_position_matrix'].to(self.args.device)
-                inputs['ngram_attention_mask'] = inputs['ngram_attention_mask'].to(self.args.device)
-                inputs['ngram_token_type_ids'] = inputs['ngram_token_type_ids'].to(self.args.device)
+            input_ids = item[0].to(self.args.device)
+            token_type_ids = item[1].to(self.args.device)
+            attention_mask = item[2].to(self.args.device)
+            labels = item[3].to(self.args.device)

-            inputs['input_ids'] = inputs['input_ids'].to(self.args.device)
-            inputs['attention_mask'] = inputs['attention_mask'].to(self.args.device)
-            inputs['token_type_ids'] = inputs['token_type_ids'].to(self.args.device)
+            if args.model_type == 'zen':
+                input_ngram_ids = item[4].to(self.args.device)
+                ngram_attention_mask = item[5].to(self.args.device)
+                ngram_token_type_ids = item[6].to(self.args.device)
+                ngram_position_matrix = item[7].to(self.args.device)

             with torch.no_grad():
-                outputs = model(labels=labels, **inputs)
+                if self.args.model_type == 'zen':
+                    outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
+                                    labels=labels, ngram_ids=input_ngram_ids,
+                                    ngram_positions=ngram_position_matrix,
+                                    ngram_token_type_ids=ngram_token_type_ids,
+                                    ngram_attention_mask=ngram_attention_mask)
+                else:
+                    outputs = model(labels=labels, input_ids=input_ids, token_type_ids=token_type_ids,
+                                    attention_mask=attention_mask)
+
                 loss, logits = outputs[:2]

             if preds is None:
@@ -1222,32 +1218,31 @@ def predict(self, test_dataset, model):
         for step, item in enumerate(test_dataloader):
            model.eval()

-            text1 = item
-
-            if self.args.model_type == 'zen':
-                inputs = convert_examples_to_features(text1=text1, ngram_dict=self.ngram_dict,
-                                                      tokenizer=self.tokenizer, max_seq_length=self.args.max_length,
-                                                      return_tensors=True)
-            else:
-                inputs = self.tokenizer(text1, padding='max_length', max_length=self.args.max_length,
-                                        truncation=True, return_tensors='pt')
-
-            if self.args.model_type == 'zen':
-                inputs['input_ngram_ids'] = inputs['input_ngram_ids'].to(self.args.device)
-                inputs['ngram_position_matrix'] = inputs['ngram_position_matrix'].to(self.args.device)
-                inputs['ngram_attention_mask'] = inputs['ngram_attention_mask'].to(self.args.device)
-                inputs['ngram_token_type_ids'] = inputs['ngram_token_type_ids'].to(self.args.device)
+            input_ids = item[0].to(self.args.device)
+            token_type_ids = item[1].to(self.args.device)
+            attention_mask = item[2].to(self.args.device)

-            inputs['input_ids'] = inputs['input_ids'].to(self.args.device)
-            inputs['attention_mask'] = inputs['attention_mask'].to(self.args.device)
-            inputs['token_type_ids'] = inputs['token_type_ids'].to(self.args.device)
+            if args.model_type == 'zen':
+                input_ngram_ids = item[3].to(self.args.device)
+                ngram_attention_mask = item[4].to(self.args.device)
+                ngram_token_type_ids = item[5].to(self.args.device)
+                ngram_position_matrix = item[6].to(self.args.device)

             with torch.no_grad():
-                outputs = model(**inputs)
+                if self.args.model_type == 'zen':
+                    outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
+                                    ngram_ids=input_ngram_ids,
+                                    ngram_positions=ngram_position_matrix,
+                                    ngram_token_type_ids=ngram_token_type_ids,
+                                    ngram_attention_mask=ngram_attention_mask)
+                else:
+                    outputs = model(input_ids=input_ids, token_type_ids=token_type_ids,
+                                    attention_mask=attention_mask)
+
             if args.model_type == 'zen':
-                logits = outputs
+                logits = outputs.detach()
             else:
-                logits = outputs[0]
+                logits = outputs[0].detach()

             if preds is None:
                 preds = logits.detach().cpu().numpy()
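
Note on the new batch layout: all three rewritten methods assume the DataLoader already yields pre-tokenized tensors in a fixed order (input_ids, token_type_ids, attention_mask, labels, followed by the four ZEN n-gram tensors when model_type is 'zen'; predict omits labels). The sketch below shows one way a dataset for the labeled splits could produce batches with that layout under PyTorch's default collation. The class name EncodedDataset, its constructor arguments, and the precomputed ngram_features dicts are hypothetical illustrations, not part of this commit.

import torch
from torch.utils.data import Dataset

class EncodedDataset(Dataset):
    # Hypothetical sketch: yields tuples matching the item[0..7] indexing above.
    def __init__(self, texts, labels, tokenizer, max_length, ngram_features=None):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer          # assumed to be a BERT-style tokenizer
        self.max_length = max_length
        self.ngram_features = ngram_features  # per-example ZEN n-gram tensors, 'zen' only

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer(self.texts[idx], padding='max_length',
                             max_length=self.max_length, truncation=True,
                             return_tensors='pt')
        # squeeze(0) drops the batch dimension so the default collate_fn can stack examples
        item = (enc['input_ids'].squeeze(0),
                enc['token_type_ids'].squeeze(0),
                enc['attention_mask'].squeeze(0),
                torch.tensor(self.labels[idx]))
        if self.ngram_features is not None:
            ng = self.ngram_features[idx]
            # order must match the indices used in training_step/evaluate
            item += (ng['input_ngram_ids'],
                     ng['ngram_attention_mask'],
                     ng['ngram_token_type_ids'],
                     ng['ngram_position_matrix'])
        return item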