Skip to content

Commit e02499c

Browse files
committed
update api implementation scheme
1 parent 1be2cf6 commit e02499c

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

rfcs/APIs/20230929_api_design_for_diagonal_scatter.md

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,20 @@ Tensor.diagonal_scatter(x, offset=0, axis1=0, axis2=1, name=None)
265265
## API实现方案
266266
在python/paddle/tensor/manipulation.py中增加diagonal_scatter函数
267267

268-
- 方案一:通过调用`paddle.fill_diagonal_tensor_impl`实现对应逻辑
268+
- 动态图
269+
270+
1. clone输入张量,获得output张量
271+
2. 调用diagonal方法,获得output张量对应位置上的张量视图diagonal_slice
272+
3. 通过张量索引,将diagonal_slice中的元素都变为嵌入张量
269273

270-
- 方案二:clone input张量得到output张量,再对output张量的diagnoal位置上的元素使用src张量元素进行覆盖
274+
- 静态图(无法仅通过修改python代码实现)
275+
276+
- 方案一:通过调用`fill_diagonal_tensor`实现对应逻辑,但是该方法只能在动态图中使用
277+
278+
- 方案二:调用`paddle.static.setitem`方法,覆盖diagonal_slice的元素,但是该方法在动态图中调用时,只会返回新的tensor,而不是inplace写
279+
- 如果想要调用`paddle.static.setitem(x, index, y)`,通过index来修改输入张量diagonal对应位置的元素,没有现成实现获得diagonal元素对应的index
280+
281+
- 方案三:类似torch实现方案,实现cpp算子逻辑
271282

272283
## 代码实现文件路径
273284

@@ -287,6 +298,8 @@ Tensor.diagonal_scatter(x, offset=0, axis1=0, axis2=1, name=None)
287298

288299
- 检查input的slice和src的维度是否相等,这样才能进行覆盖
289300

301+
- 对多种offset/axis1/axis2设置的情况进行测试
302+
290303
# 七、可行性分析和排期规划
291304

292305
方案实施难度可控,工期上可以满足在当前版本周期内开发完成

0 commit comments

Comments
 (0)