Skip to content

Commit 7c3c1ad

Browse files
committed
修改配置参数
1 parent e0de69a commit 7c3c1ad

File tree

5 files changed

+15
-34
lines changed

5 files changed

+15
-34
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
生成句向量不需要做fine tune,使用预先训练好的模型即可,可参考`extract_feature.py`的`main`方法,注意参数必须是一个list。
1414

15-
第一次生成句向量时需要加载graph,速度比较慢,后续速度会很快
15+
首次生成句向量时需要加载graph,并在output_dir路径下生成一个新的graph文件,因此速度比较慢,再次调用速度会很快
1616
```
1717
from bert.extract_feature import BertVector
1818
bv = BertVector()

args.py

+10-15
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,26 @@
11
import os
2-
from enum import Enum
2+
import tensorflow as tf
3+
4+
tf.logging.set_verbosity(tf.logging.INFO)
35

46
file_path = os.path.dirname(__file__)
57

68
model_dir = os.path.join(file_path, 'chinese_L-12_H-768_A-12/')
79
config_name = os.path.join(model_dir, 'bert_config.json')
810
ckpt_name = os.path.join(model_dir, 'bert_model.ckpt')
9-
1011
output_dir = os.path.join(model_dir, '../tmp/result/')
11-
1212
vocab_file = os.path.join(model_dir, 'vocab.txt')
1313
data_dir = os.path.join(model_dir, '../data/')
1414

15-
max_seq_len = 32
16-
17-
layer_indexes = [-2, -3, -4]
18-
15+
num_train_epochs = 10
1916
batch_size = 128
17+
learning_rate = 0.00005
2018

19+
# gpu使用率
2120
gpu_memory_fraction = 0.8
2221

23-
learning_rate = 0.00005
24-
25-
num_train_epochs = 10
22+
# 默认取倒数第二层的输出值作为句向量
23+
layer_indexes = [-2]
2624

27-
use_gpu = False
28-
if use_gpu:
29-
device_id = '0'
30-
else:
31-
device_id = '-1'
25+
# 序列的最大长度,单文本建议把该值调小
26+
max_seq_len = 32

extract_feature.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
from graph import import_tf
21
import modeling
32
import tokenization
43
from graph import optimize_graph
54
import args
65
from queue import Queue
76
from threading import Thread
8-
9-
tf = import_tf(0, True)
7+
import tensorflow as tf
108

119

1210
class InputExample(object):

graph.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,10 @@
1-
import os
21
import tempfile
32
import json
43
import logging
54
from termcolor import colored
65
import modeling
76
import args
8-
import contextlib
9-
10-
11-
def import_tf(device_id=-1, verbose=False):
12-
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' if device_id < 0 else str(device_id)
13-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if verbose else '3'
14-
import tensorflow as tf
15-
tf.logging.set_verbosity(tf.logging.DEBUG if verbose else tf.logging.ERROR)
16-
return tf
7+
import tensorflow as tf
178

189

1910
def set_logger(context, verbose=False):
@@ -35,7 +26,6 @@ def optimize_graph(logger=None, verbose=False):
3526
logger = set_logger(colored('BERT_VEC', 'yellow'), verbose)
3627
try:
3728
# we don't need GPU for optimizing the graph
38-
tf = import_tf(device_id=0, verbose=verbose)
3929
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
4030

4131
# allow_soft_placement:自动选择运行设备
@@ -75,9 +65,7 @@ def optimize_graph(logger=None, verbose=False):
7565

7666
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
7767

78-
minus_mask = lambda x, m: x - tf.expand_dims(1.0 - m, axis=-1) * 1e30
7968
mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
80-
masked_reduce_max = lambda x, m: tf.reduce_max(minus_mask(x, m), axis=1)
8169
masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
8270
tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)
8371

@@ -113,7 +101,7 @@ def optimize_graph(logger=None, verbose=False):
113101
[n.name[:-2] for n in output_tensors],
114102
[dtype.as_datatype_enum for dtype in dtypes],
115103
False)
116-
tmp_file = tempfile.NamedTemporaryFile('w', delete=False).name
104+
tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name
117105
logger.info('write graph to a tmp file: %s' % tmp_file)
118106
with tf.gfile.GFile(tmp_file, 'wb') as f:
119107
f.write(tmp_g.SerializeToString())

similarity.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
181181
def model_fn_builder(self, bert_config, num_labels, init_checkpoint, learning_rate,
182182
num_train_steps, num_warmup_steps,
183183
use_one_hot_embeddings):
184-
"""Returns `model_fn` closure for TPUEstimator."""
184+
"""Returns `model_fn` closure for TPUEstimator."""
185185

186186
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
187187
from tensorflow.python.estimator.model_fn import EstimatorSpec

0 commit comments

Comments
 (0)