Skip to content

Commit a2f1280

Browse files
authored
[PIR] translate old ir to pir_program and trt program (PaddlePaddle#70364)
* support pir_trt * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix
1 parent efa5971 commit a2f1280

File tree

12 files changed

+636
-356
lines changed

12 files changed

+636
-356
lines changed

paddle/fluid/pir/utils/name_analysis.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ pir::Value GetOutputValueByName(const pir::Program &program,
3434
}
3535
value = op.operand_source(0);
3636
}
37+
} else if (op.isa<paddle::dialect::FeedOp>() ||
38+
op.isa<paddle::dialect::FetchOp>()) {
39+
if (op.attribute("name") == name_attr) {
40+
if (value) {
41+
PADDLE_THROW(common::errors::PreconditionNotMet(
42+
"More than one feed/fetch named with %s found.", name));
43+
}
44+
value = op.result(0);
45+
}
3746
}
3847
}
3948
return value;

python/paddle/base/executor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,17 @@ def cinn_process(program):
12381238
op.result(0).persistable,
12391239
)
12401240
data_op_infos.append(tup)
1241+
if op.name() == 'pd_op.feed':
1242+
feed_target_name = op.attrs()["name"]
1243+
var_type = paddle_type_to_proto_type[op.results()[0].dtype]
1244+
var_shape = op.results()[0].shape
1245+
tup = (
1246+
feed_target_name,
1247+
var_type,
1248+
var_shape,
1249+
op.result(0).persistable,
1250+
)
1251+
data_op_infos.append(tup)
12411252

12421253
return program, new_exe, data_op_infos
12431254

python/paddle/tensorrt/converter.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@
1919
import numpy as np
2020
import tensorrt as trt
2121

22-
# init tensorrt plugin
23-
trt_plugin_lib = ctypes.CDLL('libnvinfer_plugin.so')
24-
trt_plugin_lib.initLibNvInferPlugins(None, "")
25-
2622
import paddle
2723
from paddle import pir
2824
from paddle.base.core import clear_shape_info, get_value_shape_range_info
@@ -83,6 +79,9 @@ def __init__(self, paddle_program, scope, trt_config=None):
8379
self.input_info = {}
8480
self.trt_output_value_map = {}
8581
self.engine_num = 0
82+
# init tensorrt plugin
83+
trt_plugin_lib = ctypes.CDLL('libnvinfer_plugin.so')
84+
trt_plugin_lib.initLibNvInferPlugins(None, "")
8685

8786
def find_graph_inputs_outputs(self, group_op):
8887
operations = next(iter(group_op.blocks())).ops

python/paddle/tensorrt/export.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,13 @@ def generate_input_data(self):
123123
low, high = self.input_range
124124
self.input_min_data = np.random.randint(
125125
low, high, size=self.min_input_shape
126-
)
126+
).astype(self.input_data_type)
127127
self.input_optim_data = np.random.randint(
128128
low, high, size=self.optim_input_shape
129-
)
129+
).astype(self.input_data_type)
130130
self.input_max_data = np.random.randint(
131131
low, high, size=self.max_input_shape
132-
)
132+
).astype(self.input_data_type)
133133
else:
134134
low, high = self.input_range if self.input_range else (0, 1)
135135
self.input_min_data = np.random.uniform(

0 commit comments

Comments
 (0)