Commit 8fbbd2a

Develop (#154)

Add tf ckpt converter for bert (#159). The memory of AddSplitTranspose was not contiguous for the Q, K, V outputs. Fix bert profiler bugs. Update the GPU fixed-length benchmark scripts. Add protection for SplitTranspose's QKV outputs.
1 parent e623096 commit 8fbbd2a

18 files changed: +477 −104 lines

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_CXX_FLAGS "-Wall")
 set(CMAKE_C_FLAGS "-Wall")
 
-set(TURBO_TRANSFORMERS_VERSION 0.4.1)
+set(TURBO_TRANSFORMERS_VERSION 0.4.2)
 
 option(WITH_PROFILER "Compile with profiler" OFF)
 option(WITH_GPU "Build with GPU" OFF)

benchmark/run_gpu_fixed_benchmark.sh

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ do
 for framework in ${FRAMEWORKS[*]}
 do
   python benchmark.py ${MODEL} --seq_len=${seq_len} --batch_size=${batch_size}\
-    -n ${N} --framework=${framework}
+    -n ${N} --framework=${framework} --use_gpu
 done
 done
 done

example/cpp/bert_model.cpp

Lines changed: 5 additions & 5 deletions
@@ -211,18 +211,18 @@ struct BertModel::Impl {
     layer(hidden, extendedAttentionMask, &attOut, &intermediateOut, &hidden);
   }
 
-  core::Tensor poolingOutput(nullptr);
-  layers::SequencePool(static_cast<layers::types::PoolType>(pooling))(
-      hidden, &poolingOutput);
   std::vector<float> vec;
   if (use_pooler) {
     core::Tensor output(nullptr);
+    core::Tensor poolingOutput(nullptr);
+    layers::SequencePool(static_cast<layers::types::PoolType>(pooling))(
+        hidden, &poolingOutput);
     (*pooler_)(poolingOutput, &output);
     vec.resize(output.numel());
     core::Copy(output, vec);
   } else {
-    vec.resize(poolingOutput.numel());
-    core::Copy(poolingOutput, vec);
+    vec.resize(hidden.numel());
+    core::Copy(hidden, vec);
   }
 
   return vec;
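
In the change above, SequencePool now runs only when use_pooler is true; in the else branch the model copies the full hidden states out instead of a pooled slice. A minimal NumPy sketch of the two output paths (the sizes are illustrative assumptions, and PoolType FIRST is assumed to keep the first token):

import numpy as np

batch, seq_len, hidden_size = 2, 5, 768  # illustrative sizes, not from the repo
hidden = np.random.rand(batch, seq_len, hidden_size).astype(np.float32)

# use_pooler=True: sequence-pool first (FIRST assumed to keep token 0),
# then the dense pooler produces the final output.
pooling_output = hidden[:, 0, :]     # rough analogue of layers::SequencePool
# output = pooler(pooling_output)    # dense layer applied in the C++ code

# use_pooler=False (after this commit): the full hidden states are copied
# out, i.e. batch * seq_len * hidden_size floats rather than a pooled slice.
vec = hidden.reshape(-1)
assert vec.size == batch * seq_len * hidden_size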

example/python/bert_for_sequence_classification_example.py

Lines changed: 48 additions & 37 deletions
@@ -19,46 +19,51 @@
 # import the class of the acceleration model. Here is the example of BertForSequenceClassification.
 from transformers.modeling_bert import BertModel as TorchBertModel
 from transformers import BertTokenizer
-from transformers.modeling_bert import BertForSequenceClassification as TorchBertForSequenceClassification
+from transformers.modeling_bert import (
+    BertForSequenceClassification as TorchBertForSequenceClassification,
+)
 import os
 import torch
 from typing import Optional
 
 
-#TODO(jiarufang) developed under v0.1.0, not tested after that.
-#Contact me if you find it is wrong.
+# TODO(jiarufang) developed under v0.1.0, not tested after that.
+# Contact me if you find it is wrong.
 class BertForSequenceClassification:  # create a new class for speeding up
     def __init__(
-            self, bertmodel, classifier
+        self, bertmodel, classifier
     ):  # the realization of the init function (we can just copy it)
         self.bert = bertmodel
         self.classifier = classifier
 
     def __call__(
-            self,  # the realization of the call function (we can just copy it)
-            inputs,
-            attention_masks=None,
-            token_type_ids=None,
-            position_ids=None,
-            pooling_type=PoolingType.FIRST,
-            return_type=None):
-        pooler_output, _, _ = self.bert(inputs,
-                                        attention_masks,
-                                        token_type_ids,
-                                        position_ids,
-                                        pooling_type,
-                                        return_type=ReturnType.TORCH)
+        self,  # the realization of the call function (we can just copy it)
+        input_ids,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        pooling_type=PoolingType.FIRST,
+        return_type=None,
+    ):
+        bert_outputs = self.bert(
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            position_ids,
+            pooling_type,
+            return_type=ReturnType.TORCH,
+        )
+        pooled_output = bert_outputs[1]
         logits = self.classifier(
-            pooler_output
+            pooled_output
        )  # the output of the classifier; if the user wants another output type, they can define it after this.
         return logits
 
     @staticmethod
     def from_torch(
-            model: TorchBertModel,  # from_torch implementation
-            device: Optional[torch.device] = None):
-        if device is not None and 'cuda' in device.type and torch.cuda.is_available(
-        ):
+        model: TorchBertModel, device: Optional[torch.device] = None  # from_torch implementation
+    ):
+        if device is not None and "cuda" in device.type and torch.cuda.is_available():
             model.to(device)
         bertmodel = turbo_transformers.BertModel.from_torch(model.bert)
         # We can copy the following code and do not change it
@@ -67,11 +72,11 @@ def from_torch
         return BertForSequenceClassification(bertmodel, model.classifier)
 
     @staticmethod
-    def from_pretrained(model_id_or_path: str,
-                        device: Optional[torch.device] = None):
+    def from_pretrained(model_id_or_path: str, device: Optional[torch.device] = None):
         # First, use from_pretrained to load the model you trained.
         torch_model = TorchBertForSequenceClassification.from_pretrained(
-            model_id_or_path)
+            model_id_or_path
+        )
         # Then, use the init function of the acceleration model to get it.
         model = BertForSequenceClassification.from_torch(torch_model, device)
         model._torch_model = torch_model  # keep a reference so the torch model is not destroyed.
@@ -82,18 +87,24 @@ def from_pretrained(model_id_or_path: str,
 turbo_transformers.set_num_threads(4)
 
 model_id = os.path.join(
-    os.path.dirname(__file__),
-    'test-seq-classification-model')  # the path of the huggingface model
-tokenizer = BertTokenizer.from_pretrained(
-    model_id)  # the initialization of the tokenizer
+    os.path.dirname(__file__), "bert_model"
+)  # the path of the huggingface model
+tokenizer = BertTokenizer.from_pretrained(model_id)  # the initialization of the tokenizer
 turbo_model = BertForSequenceClassification.from_pretrained(
-    model_id,
-    torch.device('cpu:0'))  # the initialization of the acceleration model
+    model_id, torch.device("cpu:0")
+)  # the initialization of the acceleration model
 
 # predict after loading the model
-input_ids = torch.tensor(
-    tokenizer.encode('Test whether the performance and accuracy of the bert model meet the requirements?',
-                     add_special_tokens=True)).unsqueeze(0)
-torch_result = turbo_model(input_ids)
-print(torch_result)
-# tensor([[ 0.1451, -0.0373]], grad_fn=<AddmmBackward>)
+
+text = "Sample input text"
+inputs = tokenizer.encode_plus(text, add_special_tokens=True, return_tensors="pt")
+# turbo_result holds the logits returned by the TurboTransformers model
+turbo_result = turbo_model(**inputs)
+
+torch_model = TorchBertForSequenceClassification.from_pretrained(model_id)
+# torch_result holds the logits returned by the original Transformers model
+torch_result = torch_model(**inputs)[0]
+print(turbo_result)
+# tensor([[0.2716, 0.0318]], grad_fn=<AddmmBackward>)
+print(torch_result)  # torch_result and turbo_result should hold the same logits
+# tensor([[0.2716, 0.0318]], grad_fn=<AddmmBackward>)
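
Since the updated example prints both results, a natural follow-up is to check that they agree numerically. A minimal sketch, assuming turbo_result and torch_result from the example above are in scope:

import torch

# The two models share weights, so the logits should match to float tolerance.
print("max abs diff:", (turbo_result - torch_result).abs().max().item())
assert torch.allclose(turbo_result, torch_result, atol=1e-3)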

tools/convert_tf_bert_to_npz.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+# Copyright (C) 2020 THL A29 Limited, a Tencent company.
+# All rights reserved.
+# Licensed under the BSD 3-Clause License (the "License"); you may
+# not use this file except in compliance with the License. You may
+# obtain a copy of the License at
+# https://opensource.org/licenses/BSD-3-Clause
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" basis,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+# See the AUTHORS file for names of contributors.
+
+from transformers import BertConfig
+try:
+    import tensorflow as tf
+except ImportError:
+    print("please install tensorflow 2.0 by running `pip install tensorflow`")
+import numpy as np
+import sys
+import os
+
+
+# The user should define the map from the tf model's layer names to the tt model's layer names
+def build_dic(num_layers):
+    dic = {
+        'bert/embeddings/word_embeddings':
+        'embeddings.word_embeddings.weight',
+        'bert/embeddings/position_embeddings':
+        'embeddings.position_embeddings.weight',
+        'bert/embeddings/token_type_embeddings':
+        'embeddings.token_type_embeddings.weight',
+        'bert/embeddings/LayerNorm/gamma':
+        'embeddings.LayerNorm.weight',
+        'bert/embeddings/LayerNorm/beta':
+        'embeddings.LayerNorm.bias',
+        'bert/pooler/dense/kernel': 'pooler.dense.weight',
+        'bert/pooler/dense/bias': 'pooler.dense.bias'
+    }
+
+    for i in range(num_layers):
+        dic[f'bert/encoder/layer_{i}/attention/self/query/kernel'] = f'encoder.layer.{i}.attention.self.query.weight'
+        dic[f'bert/encoder/layer_{i}/attention/self/query/bias'] = f'encoder.layer.{i}.attention.self.query.bias'
+        dic[f'bert/encoder/layer_{i}/attention/self/key/kernel'] = f'encoder.layer.{i}.attention.self.key.weight'
+        dic[f'bert/encoder/layer_{i}/attention/self/key/bias'] = f'encoder.layer.{i}.attention.self.key.bias'
+        dic[f'bert/encoder/layer_{i}/attention/self/value/kernel'] = f'encoder.layer.{i}.attention.self.value.weight'
+        dic[f'bert/encoder/layer_{i}/attention/self/value/bias'] = f'encoder.layer.{i}.attention.self.value.bias'
+        dic[f'bert/encoder/layer_{i}/attention/output/dense/kernel'] = f'encoder.layer.{i}.attention.output.dense.weight'
+        dic[f'bert/encoder/layer_{i}/attention/output/dense/bias'] = f'encoder.layer.{i}.attention.output.dense.bias'
+        dic[f'bert/encoder/layer_{i}/attention/output/LayerNorm/gamma'] = f'encoder.layer.{i}.attention.output.LayerNorm.weight'
+        dic[f'bert/encoder/layer_{i}/attention/output/LayerNorm/beta'] = f'encoder.layer.{i}.attention.output.LayerNorm.bias'
+        dic[f'bert/encoder/layer_{i}/intermediate/dense/kernel'] = f'encoder.layer.{i}.intermediate.dense.weight'
+        dic[f'bert/encoder/layer_{i}/intermediate/dense/bias'] = f'encoder.layer.{i}.intermediate.dense.bias'
+        dic[f'bert/encoder/layer_{i}/output/dense/kernel'] = f'encoder.layer.{i}.output.dense.weight'
+        dic[f'bert/encoder/layer_{i}/output/dense/bias'] = f'encoder.layer.{i}.output.dense.bias'
+        dic[f'bert/encoder/layer_{i}/output/LayerNorm/gamma'] = f'encoder.layer.{i}.output.LayerNorm.weight'
+        dic[f'bert/encoder/layer_{i}/output/LayerNorm/beta'] = f'encoder.layer.{i}.output.LayerNorm.bias'
+    return dic
+
+
+def trans_layer_name_tf2turbo(dic, name):
+    return dic[name]
+
+
+def main():
+    if len(sys.argv) != 3:
+        print(
+            "Usage: \n"
+            "    convert_tf_bert_to_npz.py model_name output_file")
+        exit(0)
+    model_path = sys.argv[1]
+    ckpt_path = os.path.join(model_path, "bert_model.ckpt")
+    cfg = BertConfig.from_pretrained(os.path.join(model_path, "bert_config.json"))
+    dic = build_dic(cfg.num_hidden_layers)
+    names = [v[0] for v in tf.train.list_variables(ckpt_path)]
+
+    arrays = {}
+    for i in range(len(names)):
+        if names[i].startswith("cls"):
+            continue
+        arrays[trans_layer_name_tf2turbo(dic, names[i])] = tf.train.load_variable(ckpt_path, names[i])
+
+    q_weight_key = 'self.query.weight'
+    k_weight_key = 'self.key.weight'
+    v_weight_key = 'self.value.weight'
+
+    q_bias_key = 'self.query.bias'
+    k_bias_key = 'self.key.bias'
+    v_bias_key = 'self.value.bias'
+
+    numpy_dict = {}
+
+    for k in arrays.keys():
+        if k.endswith(q_weight_key):
+            ret = []
+            ret.append(arrays[k])
+            ret.append(arrays[k[:-len(q_weight_key)] + k_weight_key])
+            ret.append(arrays[k[:-len(q_weight_key)] + v_weight_key])
+            v = np.concatenate(ret, axis=1)
+            numpy_dict[k[:-len(q_weight_key)] +
+                       "qkv.weight"] = np.ascontiguousarray(v)
+        elif k.endswith(q_bias_key):
+            ret = []
+            ret.append(arrays[k])
+            ret.append(arrays[k[:-len(q_bias_key)] + k_bias_key])
+            ret.append(arrays[k[:-len(q_bias_key)] + v_bias_key])
+            v = np.ascontiguousarray(np.concatenate(ret, axis=0))
+            numpy_dict[k[:-len(q_bias_key)] + 'qkv.bias'] = v
+        elif any((k.endswith(suffix) for suffix in (k_weight_key, v_weight_key,
+                                                    k_bias_key, v_bias_key))):
+            continue
+        else:
+            numpy_dict[k] = np.ascontiguousarray(arrays[k])
+
+    np.savez_compressed(sys.argv[2], **numpy_dict)
+
+
+if __name__ == '__main__':
+    main()
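
A sketch of how the converter might be invoked and its output inspected; the paths and shapes below are illustrative assumptions (a bert-base checkpoint with hidden size 768), not taken from the repo:

# shell: python tools/convert_tf_bert_to_npz.py /path/to/tf_bert_dir bert.npz
# where /path/to/tf_bert_dir contains bert_model.ckpt and bert_config.json
import numpy as np

npz = np.load("bert.npz")
# Q, K and V kernels are concatenated on axis 1 and biases on axis 0, and the
# 'self.' prefix is stripped when the fused key is built, so for hidden size H:
print(npz["encoder.layer.0.attention.qkv.weight"].shape)  # (H, 3 * H), e.g. (768, 2304)
print(npz["encoder.layer.0.attention.qkv.bias"].shape)    # (3 * H,), e.g. (2304,)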

tools/docker/Dockerfile_release.cpu

Lines changed: 1 addition & 1 deletion
@@ -14,5 +14,5 @@ RUN /opt/conda/bin/conda install pytorch==1.5.0 cpuonly -c pytorch && \
     /opt/conda/bin/conda install make cmake git graphviz gperftools git-lfs docopt -c conda-forge && \
     /opt/conda/bin/conda clean -afy
 
-RUN pip --no-cache-dir install contexttimer future transformers==3.0.2 docopt
+RUN pip --no-cache-dir install contexttimer future transformers==3.0.2 docopt onnxruntime-tools
 WORKDIR /workspace

turbo_transformers/layers/bert_embedding.cpp

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@
 #include "turbo_transformers/layers/kernels/embedding.h"
 #include "turbo_transformers/layers/kernels/layer_norm.h"
 
-
 namespace turbo_transformers {
 namespace layers {
 
turbo_transformers/layers/bert_intermediate.cpp

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ void BertIntermediate::operator()(const core::Tensor& input_tensor,
   kernels::AddBiasAct<float, kernels::ActivationType::Gelu>(
       dense_bias_, output_tensor, "BertIntermediate/AddBiasAct");
 #ifdef WITH_PERFTOOLS
-  profile_ctx.end_profile("BertIntermediate");
+  profile_ctx.end_profile("BertIntermediate", input_tensor.device_type());
 #endif
 }
 
turbo_transformers/layers/bert_output.cpp

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ void BertOutput::operator()(const core::Tensor &hidden_states,
       input_tensor, dense_bias_, layer_norm_weight_, layer_norm_bias_,
       output_tensor, 1e-12, "BertOutput/AddBiasLayerNorm");
 #ifdef WITH_PERFTOOLS
-  profile_ctx.end_profile("BertOutput");
+  profile_ctx.end_profile("BertOutput", input_tensor.device_type());
 #endif
 }
 
turbo_transformers/layers/kernels/gpu_transpose_kernel.cu

Lines changed: 60 additions & 0 deletions
@@ -73,6 +73,66 @@ void GPUSplitAddBiasTransposeForScore(
       weight_num, size_per_head, out_data);
 }
 
+/*
+   Output transpose results into three tensors
+*/
+static __global__ void split_add_bias_transpose_for_score_3output(
+    const float* input_data, const float* bias_data, const int batch_size,
+    const int seq_len, const int head_num, const int weight_num,
+    const int size_per_head, float* q_output_data, float* k_output_data,
+    float* v_output_data) {
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+  int idx = tid;
+  int batch_id = bid / (seq_len * weight_num * head_num);
+  int seq_id =
+      bid % (seq_len * weight_num * head_num) / (weight_num * head_num);
+  int weight_id = bid % (weight_num * head_num) / head_num;
+  int head_id = bid % head_num;
+
+  int head_num_size_per_head = head_num * size_per_head;
+  int weight_id_head_num_size_per_head = weight_id * head_num_size_per_head;
+  int head_id_size_per_head = head_id * size_per_head;
+
+  float* output_data = nullptr;
+  if (weight_id == 0) {
+    output_data = q_output_data;
+  } else if (weight_id == 1) {
+    output_data = k_output_data;
+  } else if (weight_id == 2) {
+    output_data = v_output_data;
+  }
+
+  while (idx < size_per_head) {
+    float bias_val = bias_data[weight_id_head_num_size_per_head +
+                               head_id_size_per_head + idx];
+    output_data[batch_id * seq_len * head_num_size_per_head +
+                head_id * seq_len * size_per_head + seq_id * size_per_head +
+                idx] =
+        input_data[batch_id * seq_len * weight_num * head_num_size_per_head +
+                   seq_id * weight_num * head_num_size_per_head +
+                   weight_id_head_num_size_per_head + head_id_size_per_head +
+                   idx] +
+        bias_val;
+    idx += blockDim.x;
+  }
+}
+
+template <>
+void GPUSplitAddBiasTransposeForScoreThreeOutput(
+    const float* input_data, const float* bias_data, int64_t batch_size,
+    int64_t seq_len, int64_t weight_num, int64_t num_attention_heads,
+    int64_t size_per_head, cudaStream_t stream, float* q_out_data,
+    float* k_out_data, float* v_out_data) {
+  const int n = size_per_head;
+  const int m = batch_size * seq_len * num_attention_heads * weight_num;
+  dim3 grid(m);
+  dim3 block(min(n, 1024));
+  split_add_bias_transpose_for_score_3output<<<grid, block, 0, stream>>>(
+      input_data, bias_data, batch_size, seq_len, num_attention_heads,
+      weight_num, size_per_head, q_out_data, k_out_data, v_out_data);
+}
+
 namespace {
 
 // batch, head, seq, size_per_head -> batch head seq size_per_head
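
The kernel above is the protection mentioned in the commit message: Q, K and V are written into three separately owned tensors, so each output is contiguous on its own rather than a strided view of one buffer. A minimal NumPy sketch of the same add-bias-and-transpose layout transform, with illustrative sizes:

import numpy as np

batch, seq_len, weight_num, head_num, size_per_head = 2, 4, 3, 12, 64

# input layout [batch, seq, weight_num (Q/K/V), head, size_per_head];
# the bias is broadcast over batch and seq, matching the kernel's indexing.
x = np.random.rand(batch, seq_len, weight_num, head_num, size_per_head).astype(np.float32)
bias = np.random.rand(weight_num, head_num, size_per_head).astype(np.float32)

y = (x + bias).transpose(2, 0, 3, 1, 4)  # -> [weight_num, batch, head, seq, size_per_head]
q, k, v = (np.ascontiguousarray(t) for t in y)  # three contiguous outputs
assert q.shape == (batch, head_num, seq_len, size_per_head)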
