@@ -300,8 +300,8 @@ paddle.diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None)
300
300
```
301
301
参数定义:
302
302
303
- - ` x(Tensor) ` :输入张量,张量的维度至少为2维,支持bool、int32、int64、float16、float32、float64 、complex64、complex128数据类型
304
- - ` y(Tensor) ` :嵌入张量,将会被嵌入到输入张量中,支持bool、int32、int64、float16、float32、float64 、complex64、complex128数据类型
303
+ - ` x(Tensor) ` :输入张量,张量的维度至少为2维,支持float16、float32、float64、bfloat16、uint8、int8、int32、int64、bool 、complex64、complex128数据类型
304
+ - ` y(Tensor) ` :嵌入张量,将会被嵌入到输入张量中,支持float16、float32、float64、bfloat16、uint8、int8、int32、int64、bool 、complex64、complex128数据类型
305
305
- ` offset(int, optional) ` :偏移的对角线,默认值为0
306
306
- 偏移量为0,则嵌入对角线位置
307
307
- 偏移量大于0,则嵌入对角线上方
@@ -318,7 +318,7 @@ Tensor.diagonal_scatter(y, offset=0, axis1=0, axis2=1, name=None)
318
318
```
319
319
参数定义:
320
320
321
- - ` y(Tensor) ` :嵌入张量,将会被嵌入到输入张量中,支持bool、int32、int64、float16、float32、float64 、complex64、complex128数据类型
321
+ - ` y(Tensor) ` :嵌入张量,将会被嵌入到输入张量中,支持float16、float32、float64、bfloat16、uint8、int8、int32、int64、bool 、complex64、complex128数据类型
322
322
- ` offset(int, optional) ` :偏移的对角线,默认值为0
323
323
- 偏移量为0,则嵌入对角线位置
324
324
- 偏移量大于0,则嵌入对角线上方
@@ -349,14 +349,13 @@ def diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None)
349
349
check_variable_and_dtype(
350
350
x,
351
351
' x' ,
352
- [' float16' , ' float32' , ' float64' , ' int32' , ' int64' , ' bool' , ' complex64' , ' complex128' ],
352
+ [' float16' , ' float32' , ' float64' , ' bfloat16 ' , ' uint8 ' , ' int8 ' , ' int32' , ' int64' , ' bool' , ' complex64' , ' complex128' ],
353
353
' paddle.tensor.manipulation.diagonal_scatter' ,
354
354
)
355
355
check_variable_and_dtype(
356
356
y,
357
357
' y' ,
358
- [' float16' , ' float32' , ' float64' , ' int32' , ' int64' , ' bool' ,
359
- ' complex64' , ' complex128' ],
358
+ [' float16' , ' float32' , ' float64' , ' bfloat16' , ' uint8' , ' int8' , ' int32' , ' int64' , ' bool' , ' complex64' , ' complex128' ],
360
359
' paddle.tensor.manipulation.diagonal_scatter' ,
361
360
)
362
361
out = helper.create_variable_for_type_inference(x.dtype)
0 commit comments