Skip to content

Commit 5894867

Browse files
committed
update dtype
1 parent 3d7089c commit 5894867

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

rfcs/APIs/20230929_api_design_for_diagonal_scatter.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,8 @@ paddle.diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None)
300300
```
301301
参数定义:
302302

303-
- `x(Tensor)`:输入张量,张量的维度至少为2维,支持bool、int32、int64、float16、float32、float64、complex64、complex128数据类型
304-
- `y(Tensor)`:嵌入张量,将会被嵌入到输入张量中,支持bool、int32、int64、float16、float32、float64、complex64、complex128数据类型
303+
- `x(Tensor)`:输入张量,张量的维度至少为2维,支持float16、 float32、 float64、 bfloat16、 uint8、 int8、 int32、 int64、 bool、 complex64、 complex128数据类型
304+
- `y(Tensor)`:嵌入张量,将会被嵌入到输入张量中,支持float16、 float32、 float64、 bfloat16、 uint8、 int8、 int32、 int64、 bool、 complex64、 complex128数据类型
305305
- `offset(int, optional)`:偏移的对角线,默认值为0
306306
- 偏移量为0,则嵌入对角线位置
307307
- 偏移量大于0,则嵌入对角线上方
@@ -318,7 +318,7 @@ Tensor.diagonal_scatter(y, offset=0, axis1=0, axis2=1, name=None)
318318
```
319319
参数定义:
320320

321-
- `y(Tensor)`:嵌入张量,将会被嵌入到输入张量中,支持bool、int32、int64、float16、float32、float64、complex64、complex128数据类型
321+
- `y(Tensor)`:嵌入张量,将会被嵌入到输入张量中,支持float16、 float32、 float64、 bfloat16、 uint8、 int8、 int32、 int64、 bool、 complex64、 complex128数据类型
322322
- `offset(int, optional)`:偏移的对角线,默认值为0
323323
- 偏移量为0,则嵌入对角线位置
324324
- 偏移量大于0,则嵌入对角线上方
@@ -349,14 +349,13 @@ def diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None)
349349
check_variable_and_dtype(
350350
x,
351351
'x',
352-
['float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'complex64', 'complex128'],
352+
['float16', 'float32', 'float64', 'bfloat16', 'uint8', 'int8', 'int32', 'int64', 'bool', 'complex64', 'complex128'],
353353
'paddle.tensor.manipulation.diagonal_scatter',
354354
)
355355
check_variable_and_dtype(
356356
y,
357357
'y',
358-
['float16', 'float32', 'float64', 'int32', 'int64', 'bool',
359-
'complex64', 'complex128'],
358+
['float16', 'float32', 'float64', 'bfloat16', 'uint8', 'int8', 'int32', 'int64', 'bool', 'complex64', 'complex128'],
360359
'paddle.tensor.manipulation.diagonal_scatter',
361360
)
362361
out = helper.create_variable_for_type_inference(x.dtype)

0 commit comments

Comments
 (0)