Commit defd2d6

add type promotion static T+T logit. (#60638)
* add type promotion static T+T logit.
* fix bug
* fix code comment
* add where op test for type promotion.
* fix
* fix bug
* fix
* fix path
* fix
* fix
* fix spelling problem.
* support paddle inference.
* add where grad
1 parent 5e87a34 commit defd2d6
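A minimal sketch of what this change enables (the tensor names and values are illustrative, not part of the commit): in static-graph mode, adding two tensors with different floating-point dtypes no longer casts at graph-build time; the Executor promotes the inputs when the program is run (see the math_op_patch.py and executor.py changes below), so the fetched result carries the common dtype.

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    # inputs with mismatched floating-point dtypes (illustrative values)
    x = paddle.full(shape=[2, 2], fill_value=1.0, dtype='float32')
    y = paddle.full(shape=[2, 2], fill_value=2.0, dtype='float64')
    # graph-build time only warns; cast ops are inserted later by the Executor
    out = x + y

exe = paddle.static.Executor(paddle.CPUPlace())
(result,) = exe.run(main_prog, fetch_list=[out])
print(result.dtype)  # expected: float64, the promoted common dtype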

File tree

7 files changed: +291 lines, -44 lines

python/paddle/base/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -107,6 +107,7 @@
     is_compiled_with_rocm,
     is_compiled_with_xpu,
     name_scope,
+    process_type_promotion,
     program_guard,
     require_version,
     set_flags,

python/paddle/base/executor.py

Lines changed: 3 additions & 0 deletions

@@ -41,6 +41,7 @@
     get_flags,
     in_pir_mode,
     paddle_type_to_proto_type,
+    process_type_promotion,
     set_flags,
 )
 from .incubate.checkpoint import auto_checkpoint as acp

@@ -1770,6 +1771,8 @@ def run(
                 return_numpy=return_numpy,
             )
         else:
+            # do type promotion if necessary
+            program = process_type_promotion(program)
             res = self._run_impl(
                 program=program,
                 feed=feed,

python/paddle/base/framework.py

Lines changed: 107 additions & 0 deletions

@@ -56,6 +56,17 @@
 CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
 _global_flags_ = core.globals()
 
+SUPPORT_PROMOTION_OPS_AND_INPUTNAME = {
+    "elementwise_add": ['X', 'Y'],
+    "elementwise_add_grad": ['X', 'Y'],
+    "elementwise_sub": ['X', 'Y'],
+    "elementwise_sub_grad": ['X', 'Y'],
+    "elementwise_mul": ['X', 'Y'],
+    "elementwise_mul_grad": ['X', 'Y'],
+    "where": ['X', 'Y'],
+    "where_grad": ['X', 'Y'],
+}
+
 
 def _global_flags():
     return _global_flags_

@@ -8141,3 +8152,99 @@ def _get_paddle_place_list(places):
         ret.append(p)
 
     return ret
+
+
+def dtype_to_str(in_dtype):
+    if in_dtype == core.VarDesc.VarType.FP16:
+        return "fp16"
+    elif in_dtype == core.VarDesc.VarType.BF16:
+        return "bf16"
+    elif in_dtype == core.VarDesc.VarType.FP32:
+        return "fp32"
+    elif in_dtype == core.VarDesc.VarType.FP64:
+        return "fp64"
+    else:
+        return None
+
+
+def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype):
+    op_device = op.attr('op_device')
+    cast_name = var_name.name + '.cast_' + dtype_to_str(out_dtype)
+    out_var = block.create_var(
+        name=cast_name,
+        dtype=out_dtype,
+        persistable=False,
+        stop_gradient=var_name.stop_gradient,
+    )
+    op_role = (
+        int(core.op_proto_and_checker_maker.OpRole.Forward)
+        if not op.has_attr('op_role')
+        else op.attr('op_role')
+    )
+    block._insert_op_without_sync(
+        idx,
+        type="cast",
+        inputs={"X": var_name},
+        outputs={"Out": out_var},
+        attrs={
+            "in_dtype": var_name.dtype,
+            "out_dtype": out_var.dtype,
+            "op_device": op_device,
+            "op_role": op_role,
+        },
+    )
+    op.desc._rename_input(var_name.name, out_var.name)
+
+
+def process_type_promotion(program):
+    org_program = program
+    if program is None:
+        program = default_main_program()
+    # not support pir for now
+    if not isinstance(program, Program):
+        return org_program
+    global_block = program.global_block()
+    all_params = global_block.all_parameters()
+    for block in program.blocks:
+        ops = block.ops
+        idx = 0
+        while idx < len(ops):
+            op = ops[idx]
+            var_name = None
+            all_dtypes = []
+            all_input_name_need_cast = []
+
+            need_transed_var_names = SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get(
+                op.type, None
+            )
+            # type promotion only support some dyadic api
+            if need_transed_var_names is None:
+                idx += 1
+                continue
+
+            # get all dtype and input_name
+            for input_idx in range(len(op.input_arg_names)):
+                if op.input_names[input_idx] in need_transed_var_names:
+                    input_arg_name = op.input_arg_names[input_idx]
+                    all_dtypes.append(
+                        op.block._var_recursive(input_arg_name).dtype
+                    )
+                    all_input_name_need_cast.append(input_arg_name)
+
+            # only support promote between float
+            if core.need_type_promotion(*all_dtypes):
+                common_dtype = core.get_promote_dtype(op.type, *all_dtypes)
+                for input_name_need_cast in all_input_name_need_cast:
+                    var_name = op.block._var_recursive(input_name_need_cast)
+                    if var_name.dtype != common_dtype:
+                        # add cast op for different dtype
+                        add_cast_for_type_promotion(
+                            op,
+                            block,
+                            idx,
+                            var_name,
+                            common_dtype,
+                        )
+                        idx += 1
+            idx += 1
+    return program
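The pass added above can also be invoked directly on a Program. A minimal sketch, assuming this commit (including the re-export of process_type_promotion from paddle.base in __init__.py above); the program and variable names are illustrative:

import paddle
from paddle.base import process_type_promotion

paddle.enable_static()

prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.full([4], 1.0, dtype='float32')
    y = paddle.full([4], 2.0, dtype='float64')
    out = paddle.add(x, y)  # elementwise_add with mismatched input dtypes

# walk the blocks and insert cast ops so both inputs share the promoted dtype
prog = process_type_promotion(prog)

op_types = [op.type for op in prog.global_block().ops]
print(op_types)  # a "cast" op should now precede "elementwise_add"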

python/paddle/base/layers/math_op_patch.py

Lines changed: 2 additions & 8 deletions

@@ -534,19 +534,13 @@ def __impl__(self, other_var):
         if lhs_dtype != rhs_dtype:
             if method_name in SUPPORT_PROMOTION_OPS:
                 if core.need_type_promotion(lhs_dtype, rhs_dtype):
-                    common_dtype = core.get_promote_dtype(
-                        op_type, lhs_dtype, rhs_dtype
-                    )
+                    # only report warning here, real promotion deal in Executor
                     warnings.warn(
-                        f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted to {common_dtype}"
+                        f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted"
                     )
                     warnings.filterwarnings(
                         "ignore", message="The input dtypes of OP"
                     )
-                    if rhs_dtype != common_dtype:
-                        other_var = astype(other_var, common_dtype)
-                    if lhs_dtype != common_dtype:
-                        self = astype(self, common_dtype)
                 else:
                     # NOTE(zoooo0820): Currently, we still keep the old illogical \
                     # logic for compatibility reasons
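With this change, a mixed-dtype operator overload in static mode only emits the warning; no astype is traced into the program at that point, and the cast ops are inserted later by process_type_promotion. A small sketch of that behavior, assuming this commit is applied (variable names are illustrative):

import warnings

import paddle

paddle.enable_static()

prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.full([2], 1.0, dtype='float32')
    y = paddle.full([2], 2.0, dtype='float64')
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        out = x + y  # the overload only warns; the Executor adds the real cast
    assert any("auto-promoted" in str(w.message) for w in caught)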

python/paddle/static/io.py

Lines changed: 16 additions & 1 deletion

@@ -33,7 +33,12 @@
     unique_name,
 )
 from paddle.base.executor import Executor, global_scope
-from paddle.base.framework import Parameter, dygraph_not_support, static_only
+from paddle.base.framework import (
+    Parameter,
+    dygraph_not_support,
+    process_type_promotion,
+    static_only,
+)
 from paddle.base.log_helper import get_logger
 from paddle.framework.io_utils import (
     _clone_var_in_block_,

@@ -587,6 +592,10 @@ def save_inference_model(
     _check_vars('fetch_vars', fetch_vars)
 
     program = _get_valid_program(kwargs.get('program', None))
+
+    # do type promotion
+    program = process_type_promotion(program)
+
     clip_extra = kwargs.get('clip_extra', True)
     program = normalize_program(
         program,

@@ -903,6 +912,9 @@ def load_inference_model(path_prefix, executor, **kwargs):
         # deserialize bytes to program
         program = deserialize_program(program_bytes)
 
+        # do type promotion
+        program = process_type_promotion(program)
+
         vars = list(filter(is_persistable, program.list_vars()))
         if len(vars) > 0:
             load_vars(

@@ -958,6 +970,9 @@ def load_inference_model(path_prefix, executor, **kwargs):
         # deserialize bytes to program
         program = deserialize_program(program_bytes)
 
+        # do type promotion
+        program = process_type_promotion(program)
+
         vars = list(filter(is_persistable, program.list_vars()))
         if len(vars) > 0:
             load_dirname = os.path.dirname(params_path)
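The same pass is now applied on the inference-model path: save_inference_model promotes before serialization and load_inference_model promotes after deserialization. A minimal sketch, assuming this commit; the feed variables and the path prefix '/tmp/type_promotion_demo' are illustrative:

import paddle

paddle.enable_static()

prog = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(prog, startup):
    x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
    y = paddle.static.data(name='y', shape=[None, 4], dtype='float64')
    out = paddle.add(x, y)  # mixed dtypes; promotion handled at save time

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)

# save_inference_model now calls process_type_promotion on the program first,
# so the serialized inference program already contains the inserted cast ops;
# load_inference_model applies the same pass after deserialization
paddle.static.save_inference_model(
    '/tmp/type_promotion_demo', [x, y], [out], exe, program=prog
)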

test/legacy_test/test_tensor_type_promotion.py

Lines changed: 50 additions & 35 deletions

@@ -119,6 +119,31 @@ def test_dtype_is_expected(self):
         )
 
 
+class TestAPIAddInStatic(TestOperatorOverloadAddInStatic):
+    def run_api(self):
+        prog = paddle.static.Program()
+        with paddle.static.program_guard(prog):
+            self.generate_test_value()
+
+            out = paddle.add(self.l_value, self.r_value)
+            out_reverse = paddle.add(self.r_value, self.l_value)
+
+        res = self.exe.run(prog, fetch_list=[out, out_reverse])
+        return res
+
+
+create_test_case(TestAPIAddInStatic, 'float16', 'float32', 'float32')
+create_test_case(TestAPIAddInStatic, 'float16', 'float64', 'float64')
+
+create_test_case(TestAPIAddInStatic, 'float32', 'float64', 'float64')
+
+
+if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
+    create_test_case(TestAPIAddInStatic, 'bfloat16', 'float16', 'float32')
+    create_test_case(TestAPIAddInStatic, 'bfloat16', 'float32', 'float32')
+    create_test_case(TestAPIAddInStatic, 'bfloat16', 'float64', 'float64')
+
+
 class TestOperatorOverloadSubInStatic(TestOperatorOverloadAddInStatic):
     def run_api(self):
         prog = paddle.static.Program()

@@ -156,74 +181,64 @@ def run_api(self):
 )
 
 
-class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic):
+class TestAPISubInStatic(TestOperatorOverloadAddInStatic):
     def run_api(self):
         prog = paddle.static.Program()
         with paddle.static.program_guard(prog):
             self.generate_test_value()
 
-            out = self.l_value * self.r_value
-            out_reverse = self.r_value * self.l_value
+            out = paddle.subtract(self.l_value, self.r_value)
+            out_reverse = paddle.subtract(self.r_value, self.l_value)
 
         res = self.exe.run(prog, fetch_list=[out, out_reverse])
         return res
 
 
-create_test_case(
-    TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32'
-)
-create_test_case(
-    TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64'
-)
+create_test_case(TestAPISubInStatic, 'float16', 'float32', 'float32')
+create_test_case(TestAPISubInStatic, 'float16', 'float64', 'float64')
 
-create_test_case(
-    TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64'
-)
+create_test_case(TestAPIAddInStatic, 'float32', 'float64', 'float64')
 
-if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
-    create_test_case(
-        TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32'
-    )
-    create_test_case(
-        TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32'
-    )
-    create_test_case(
-        TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64'
-    )
 
+if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
+    create_test_case(TestAPISubInStatic, 'bfloat16', 'float16', 'float32')
+    create_test_case(TestAPISubInStatic, 'bfloat16', 'float32', 'float32')
+    create_test_case(TestAPISubInStatic, 'bfloat16', 'float64', 'float64')
 
-class TestOperatorOverloadGTInStatic(TestOperatorOverloadAddInStatic):
-    def set_dtype(self):
-        self.ldtype = 'float32'
-        self.rdtype = 'float64'
-        self.expected_out_dtype = 'bool'
 
+class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic):
     def run_api(self):
         prog = paddle.static.Program()
         with paddle.static.program_guard(prog):
             self.generate_test_value()
 
-            out = self.l_value > self.r_value
-            out_reverse = self.r_value > self.l_value
+            out = self.l_value * self.r_value
+            out_reverse = self.r_value * self.l_value
 
         res = self.exe.run(prog, fetch_list=[out, out_reverse])
         return res
 
 
-create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float32', 'bool')
-create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float64', 'bool')
+create_test_case(
+    TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32'
+)
+create_test_case(
+    TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64'
+)
 
-create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'float64', 'bool')
+create_test_case(
+    TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64'
+)
 
 if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
     create_test_case(
-        TestOperatorOverloadGTInStatic, 'bfloat16', 'float16', 'bool'
+        TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32'
     )
     create_test_case(
-        TestOperatorOverloadGTInStatic, 'bfloat16', 'float32', 'bool'
+        TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32'
    )
    create_test_case(
-        TestOperatorOverloadGTInStatic, 'bfloat16', 'float64', 'bool'
+        TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64'
     )
 