
Commit 77b2f9a

feat: update CTC task.
1 parent 111aaf0 commit 77b2f9a

7 files changed: +101 -80 lines

baselines/run_classifier.py (+8 -8)
@@ -139,14 +139,14 @@ def main():
         train_samples = data_processor.get_train_sample()
         eval_samples = data_processor.get_dev_sample()
 
-        if args.task_name != 'ee':
-            train_dataset = dataset_class(train_samples, data_processor, mode='train')
-            eval_dataset = dataset_class(eval_samples, data_processor, mode='eval')
-        else:
+        if args.task_name == 'ee' or args.task_name == 'ctc':
             train_dataset = dataset_class(train_samples, data_processor, tokenizer, mode='train',
                                           model_type=args.model_type, ngram_dict=ngram_dict, max_length=args.max_length)
             eval_dataset = dataset_class(eval_samples, data_processor, tokenizer, mode='eval',
                                          model_type=args.model_type, ngram_dict=ngram_dict, max_length=args.max_length)
+        else:
+            train_dataset = dataset_class(train_samples, data_processor, mode='train')
+            eval_dataset = dataset_class(eval_samples, data_processor, mode='eval')
 
         model = model_class.from_pretrained(os.path.join(args.model_dir, args.model_name),
                                             num_labels=data_processor.num_labels)
@@ -167,12 +167,12 @@ def main():
         data_processor = data_processor_class(root=args.data_dir)
         test_samples = data_processor.get_test_sample()
 
-        if args.task_name != 'ee':
-            test_dataset = dataset_class(test_samples, data_processor, mode='test')
-        else:
+        if args.task_name == 'ee' or args.task_name == 'ctc':
             test_dataset = dataset_class(test_samples, data_processor, tokenizer, mode='test', ngram_dict=ngram_dict,
                                          max_length=args.max_length, model_type=args.model_type)
-
+        else:
+            test_dataset = dataset_class(test_samples, data_processor, mode='test')
+
         model = model_class.from_pretrained(args.output_dir, num_labels=data_processor.num_labels)
         trainer = trainer_class(args=args, model=model, data_processor=data_processor,
                                 tokenizer=tokenizer, logger=logger, model_class=model_class, ngram_dict=ngram_dict)
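Note on the change: 'ctc' now takes the same construction path as 'ee', whose dataset tokenizes inside __getitem__ and therefore needs the tokenizer, max_length, and the ZEN ngram_dict up front. The same routing could be written as a membership test; a minimal sketch (illustrative only, not code from this commit):

    # Hypothetical refactor: tasks whose datasets tokenize per sample.
    TOKENIZE_IN_DATASET = ('ee', 'ctc')
    if args.task_name in TOKENIZE_IN_DATASET:
        train_dataset = dataset_class(train_samples, data_processor, tokenizer, mode='train',
                                      model_type=args.model_type, ngram_dict=ngram_dict,
                                      max_length=args.max_length)
    else:
        train_dataset = dataset_class(train_samples, data_processor, mode='train')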
Binary file changed (-143 bytes); not shown.
Binary file changed (680 bytes); not shown.

cblue/data/data_process.py (+1 -1)
@@ -693,7 +693,7 @@ def _pre_process(self, path, is_predict=False):
         samples = load_json(path)
         outputs = {'text': [], 'label': [], 'id': []}
         for sample in samples:
-            outputs['text'].append(sample['text'])
+            outputs['text'].append("\002".join([ t for t in list(sample["text"].lower())]))
             outputs['id'].append(sample['id'])
             if not is_predict:
                 outputs['label'].append(self.label2id[sample['label']])
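Note on the change: each CTC text is now lower-cased, exploded into single characters, and stored joined by the "\002" control character, so the dataset can recover the exact character sequence later. A round-trip sketch (illustrative only; the sample string is made up):

    text = "2型糖尿病患者"                        # hypothetical sample['text']
    encoded = "\002".join(list(text.lower()))    # what _pre_process now stores
    chars = encoded.split("\002")                # what CTCDataset.__init__ recovers
    assert chars == list(text.lower())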

cblue/data/dataset.py (+31 -5)
@@ -267,26 +267,52 @@ def __init__(
         self,
         samples,
         data_processor,
-        mode='train'
+        tokenizer,
+        max_length=128,
+        mode='train',
+        model_type='bert',
+        ngram_dict=None
     ):
         super(CTCDataset, self).__init__()
 
-        self.texts = samples['text']
+        self.texts = [text.split("\002") for text in samples['text']]
         self.ids = samples['id']
 
         if mode != 'test':
             self.labels = samples['label']
         self.data_processor = data_processor
         self.mode = mode
+        self.ngram_dict = ngram_dict
+        self.max_length = max_length
+        self.tokenizer = tokenizer
+        self.model_type = model_type
 
     def __getitem__(self, idx):
         text = self.texts[idx]
+        if self.model_type == 'zen':
+            inputs = convert_examples_to_features(text1=text, ngram_dict=self.ngram_dict,
+                                                  tokenizer=self.tokenizer, max_seq_length=self.max_length)
+        else:
+            inputs = self.tokenizer.encode_plus(text, padding='max_length', max_length=self.max_length, truncation=True)
 
         if self.mode != 'test':
-            label = self.labels[idx]
-            return text, label
+            if self.model_type == 'zen':
+                return inputs['input_ids'], inputs['token_type_ids'], \
+                       inputs['attention_mask'], self.labels[idx], inputs['input_ngram_ids'], \
+                       inputs['ngram_attention_mask'], inputs['ngram_token_type_ids'], \
+                       inputs['ngram_position_matrix']
+            else:
+                return np.array(inputs['input_ids']), np.array(inputs['token_type_ids']), \
+                       np.array(inputs['attention_mask']), self.labels[idx]
         else:
-            return text
+            if self.model_type == 'zen':
+                return inputs['input_ids'], inputs['token_type_ids'], \
+                       inputs['attention_mask'], inputs['input_ngram_ids'], \
+                       inputs['ngram_attention_mask'], inputs['ngram_token_type_ids'], \
+                       inputs['ngram_position_matrix']
+            else:
+                return np.array(inputs['input_ids']), np.array(inputs['token_type_ids']), \
+                       np.array(inputs['attention_mask']),
 
     def __len__(self):
         return len(self.texts)
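Note on the change: tokenization moves from the trainer into the dataset. Because self.texts holds character lists, the non-ZEN branch hands encode_plus a list of tokens rather than a raw string, so each character maps to exactly one position in input_ids. A usage sketch of that path (illustrative only; the checkpoint name is a placeholder):

    import numpy as np
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # placeholder checkpoint
    chars = list("2型糖尿病患者".lower())       # what self.texts[idx] holds after the split
    inputs = tokenizer.encode_plus(chars, padding='max_length', max_length=32, truncation=True)
    # same triple the non-ZEN __getitem__ returns (plus the label in train/eval mode)
    features = (np.array(inputs['input_ids']), np.array(inputs['token_type_ids']),
                np.array(inputs['attention_mask']))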

cblue/trainer/train.py (+54 -59)
@@ -1129,28 +1129,25 @@ def __init__(
     def training_step(self, model, item):
         model.train()
 
-        text1 = item[0]
-        labels = item[1].to(self.args.device)
-        if self.args.model_type == 'zen':
-            inputs = convert_examples_to_features(text1=text1, ngram_dict=self.ngram_dict,
-                                                  tokenizer=self.tokenizer, max_seq_length=self.args.max_length,
-                                                  return_tensors=True)
-        else:
-            inputs = self.tokenizer(text1, padding='max_length', max_length=self.args.max_length,
-                                    truncation=True, return_tensors='pt')
+        input_ids = item[0].to(self.args.device)
+        token_type_ids = item[1].to(self.args.device)
+        attention_mask = item[2].to(self.args.device)
+        labels = item[3].to(self.args.device)
 
         if self.args.model_type == 'zen':
-            inputs['input_ngram_ids'] = inputs['input_ngram_ids'].to(self.args.device)
-            inputs['ngram_position_matrix'] = inputs['ngram_position_matrix'].to(self.args.device)
-            inputs['ngram_attention_mask'] = inputs['ngram_attention_mask'].to(self.args.device)
-            inputs['ngram_token_type_ids'] = inputs['ngram_token_type_ids'].to(self.args.device)
-
-        inputs['input_ids'] = inputs['input_ids'].to(self.args.device)
-        inputs['attention_mask'] = inputs['attention_mask'].to(self.args.device)
-        inputs['token_type_ids'] = inputs['token_type_ids'].to(self.args.device)
+            input_ngram_ids = item[4].to(self.args.device)
+            ngram_attention_mask = item[5].to(self.args.device)
+            ngram_token_type_ids = item[6].to(self.args.device)
+            ngram_position_matrix = item[7].to(self.args.device)
 
         # default using 'Transformers' library models.
-        outputs = model(labels=labels, **inputs)
+        if self.args.model_type == 'zen':
+            outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
+                            labels=labels, ngram_ids=input_ngram_ids, ngram_positions=ngram_position_matrix,
+                            ngram_attention_mask=ngram_attention_mask, ngram_token_type_ids=ngram_token_type_ids)
+        else:
+            outputs = model(labels=labels, input_ids=input_ids, token_type_ids=token_type_ids,
+                            attention_mask=attention_mask)
         loss = outputs[0]
         loss.backward()

@@ -1170,29 +1167,28 @@ def evaluate(self, model):
         for step, item in enumerate(eval_dataloader):
             model.eval()
 
-            text1 = item[0]
-            labels = item[1].to(args.device)
-
-            if self.args.model_type == 'zen':
-                inputs = convert_examples_to_features(text1=text1, ngram_dict=self.ngram_dict,
-                                                      tokenizer=self.tokenizer, max_seq_length=self.args.max_length,
-                                                      return_tensors=True)
-            else:
-                inputs = self.tokenizer(text1, padding='max_length', max_length=self.args.max_length,
-                                        truncation=True, return_tensors='pt')
-
-            if self.args.model_type == 'zen':
-                inputs['input_ngram_ids'] = inputs['input_ngram_ids'].to(self.args.device)
-                inputs['ngram_position_matrix'] = inputs['ngram_position_matrix'].to(self.args.device)
-                inputs['ngram_attention_mask'] = inputs['ngram_attention_mask'].to(self.args.device)
-                inputs['ngram_token_type_ids'] = inputs['ngram_token_type_ids'].to(self.args.device)
+            input_ids = item[0].to(self.args.device)
+            token_type_ids = item[1].to(self.args.device)
+            attention_mask = item[2].to(self.args.device)
+            labels = item[3].to(self.args.device)
 
-            inputs['input_ids'] = inputs['input_ids'].to(self.args.device)
-            inputs['attention_mask'] = inputs['attention_mask'].to(self.args.device)
-            inputs['token_type_ids'] = inputs['token_type_ids'].to(self.args.device)
+            if args.model_type == 'zen':
+                input_ngram_ids = item[4].to(self.args.device)
+                ngram_attention_mask = item[5].to(self.args.device)
+                ngram_token_type_ids = item[6].to(self.args.device)
+                ngram_position_matrix = item[7].to(self.args.device)
 
             with torch.no_grad():
-                outputs = model(labels=labels, **inputs)
+                if self.args.model_type == 'zen':
+                    outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
+                                    labels=labels, ngram_ids=input_ngram_ids,
+                                    ngram_positions=ngram_position_matrix,
+                                    ngram_token_type_ids=ngram_token_type_ids,
+                                    ngram_attention_mask=ngram_attention_mask)
+                else:
+                    outputs = model(labels=labels, input_ids=input_ids, token_type_ids=token_type_ids,
+                                    attention_mask=attention_mask)
+
                 loss, logits = outputs[:2]
 
                 if preds is None:
@@ -1222,32 +1218,31 @@ def predict(self, test_dataset, model):
         for step, item in enumerate(test_dataloader):
             model.eval()
 
-            text1 = item
-
-            if self.args.model_type == 'zen':
-                inputs = convert_examples_to_features(text1=text1, ngram_dict=self.ngram_dict,
-                                                      tokenizer=self.tokenizer, max_seq_length=self.args.max_length,
-                                                      return_tensors=True)
-            else:
-                inputs = self.tokenizer(text1, padding='max_length', max_length=self.args.max_length,
-                                        truncation=True, return_tensors='pt')
-
-            if self.args.model_type == 'zen':
-                inputs['input_ngram_ids'] = inputs['input_ngram_ids'].to(self.args.device)
-                inputs['ngram_position_matrix'] = inputs['ngram_position_matrix'].to(self.args.device)
-                inputs['ngram_attention_mask'] = inputs['ngram_attention_mask'].to(self.args.device)
-                inputs['ngram_token_type_ids'] = inputs['ngram_token_type_ids'].to(self.args.device)
+            input_ids = item[0].to(self.args.device)
+            token_type_ids = item[1].to(self.args.device)
+            attention_mask = item[2].to(self.args.device)
 
-            inputs['input_ids'] = inputs['input_ids'].to(self.args.device)
-            inputs['attention_mask'] = inputs['attention_mask'].to(self.args.device)
-            inputs['token_type_ids'] = inputs['token_type_ids'].to(self.args.device)
+            if args.model_type == 'zen':
+                input_ngram_ids = item[3].to(self.args.device)
+                ngram_attention_mask = item[4].to(self.args.device)
+                ngram_token_type_ids = item[5].to(self.args.device)
+                ngram_position_matrix = item[6].to(self.args.device)
 
             with torch.no_grad():
-                outputs = model(**inputs)
+                if self.args.model_type == 'zen':
+                    outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
+                                    ngram_ids=input_ngram_ids,
+                                    ngram_positions=ngram_position_matrix,
+                                    ngram_token_type_ids=ngram_token_type_ids,
+                                    ngram_attention_mask=ngram_attention_mask)
+                else:
+                    outputs = model(input_ids=input_ids, token_type_ids=token_type_ids,
+                                    attention_mask=attention_mask)
+
                 if args.model_type == 'zen':
-                    logits = outputs
+                    logits = outputs.detach()
                 else:
-                    logits = outputs[0]
+                    logits = outputs[0].detach()
 
                 if preds is None:
                     preds = logits.detach().cpu().numpy()
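Note on the change: with tokenization done in CTCDataset, the DataLoader's default collation stacks the per-sample numpy arrays (and int labels) into batched tensors, which training_step, evaluate, and predict now unpack positionally from item. A self-contained sketch of that hand-off (illustrative stand-in class, not from the repo):

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyCTC(Dataset):                          # stand-in for CTCDataset's BERT path
        def __len__(self):
            return 4
        def __getitem__(self, idx):
            ids = np.zeros(32, dtype=np.int64)      # placeholder input_ids
            return ids, np.zeros_like(ids), np.ones_like(ids), 1   # ..., label

    batch = next(iter(DataLoader(ToyCTC(), batch_size=2)))
    input_ids, token_type_ids, attention_mask, labels = batch      # stacked tensors
    assert isinstance(input_ids, torch.Tensor) and input_ids.shape == (2, 32)
    assert labels.tolist() == [1, 1]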

examples/run_ctc.sh (+7 -7)
@@ -4,7 +4,7 @@ DATA_DIR="CBLUEDatasets"
 TASK_NAME="ctc"
 MODEL_TYPE="bert"
 MODEL_DIR="data/model_data"
-MODEL_NAME="chinese-bert-wwm-ext"
+MODEL_NAME="chinese-roberta-large"
 OUTPUT_DIR="data/output"
 RESULT_OUTPUT_DIR="data/result_output"

@@ -23,15 +23,15 @@ if [ $# == 0 ]; then
         --result_output_dir=${RESULT_OUTPUT_DIR} \
         --do_train \
         --max_length=${MAX_LENGTH} \
-        --train_batch_size=16 \
-        --eval_batch_size=16 \
-        --learning_rate=3e-5 \
-        --epochs=3 \
+        --train_batch_size=24 \
+        --eval_batch_size=64 \
+        --learning_rate=2e-5 \
+        --epochs=5 \
         --warmup_proportion=0.1 \
-        --earlystop_patience=3 \
+        --earlystop_patience=10 \
         --logging_steps=200 \
         --save_steps=200 \
-        --seed=2021
+        --seed=1000
 elif [ $1 == "predict" ]; then
     python baselines/run_classifier.py \
         --data_dir=${DATA_DIR} \
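Note on the new hyperparameters: the baseline moves to a larger RoBERTa checkpoint with bigger batches, a lower peak learning rate, more epochs, and a much higher early-stop patience. Under the usual BERT-style semantics, warmup_proportion=0.1 ramps the learning rate over the first tenth of all optimisation steps; a back-of-envelope sketch (the training-set size is an assumed placeholder):

    train_examples = 20000                               # assumption: substitute the real CTC split size
    batch_size, epochs = 24, 5                           # values from the updated script
    steps_per_epoch = -(-train_examples // batch_size)   # ceil division
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(0.1 * total_steps)                # warmup_proportion=0.1
    print(total_steps, warmup_steps)                     # 4170 417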
