Skip to content

Commit 1549d42

Browse files
junjiang-labcopybara-github
authored andcommitted
Add impl for aten.where.self and lowering.
PiperOrigin-RevId: 773821850
1 parent 672576c commit 1549d42

File tree

4 files changed

+27
-0
lines changed

4 files changed

+27
-0
lines changed

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,11 @@ def _aten_select_int_decomp(x, dim, index):
278278
return torch.ops.tfl.squeeze(sliced, [dim])
279279

280280

281+
@register_decomp(torch.ops.aten.where.self)
282+
def _aten_where_self_decomp(condition, x, y):
283+
return torch.ops.tfl.select_v2(condition, x, y)
284+
285+
281286
@register_decomp(torch.ops.aten._softmax.default)
282287
def _aten__softmax_decomp(
283288
x, dim: int, half_to_float: bool # pylint: disable=unused-argument

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,20 @@ def _tfl_strided_slice_lowering(
586586
)
587587

588588

589+
@lower(torch.ops.tfl.select_v2.default)
590+
def _tfl_select_v2_lowering(
591+
lctx: LoweringContext,
592+
condition: ir.Value,
593+
x: ir.Value,
594+
y: ir.Value,
595+
) -> ir.Value:
596+
return _ir_operation(
597+
"tfl.select_v2",
598+
results=lowering_utils.node_meta_to_ir_types(lctx.node),
599+
operands=[condition, x, y],
600+
)
601+
602+
589603
@lower(torch.ops.tfl.softmax.default)
590604
def _tfl_softmax_lowering(
591605
lctx: LoweringContext,

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,13 @@ def tfl_strided_slice(
266266
return result
267267

268268

269+
@custom_op_with_fake("tfl::select_v2")
270+
def tfl_select_v2(
271+
condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor
272+
) -> torch.Tensor:
273+
return torch.where(condition, x, y)
274+
275+
269276
@custom_op_with_fake("tfl::softmax")
270277
def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
271278
return torch.nn.functional.softmax(x, dim=-1)

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def _assert_export_and_close(
185185
("aten_squeeze_dims_0", torch.ops.aten.squeeze.dims, (rnd(torch.float32, (2, 1, 2, 1, 2)), [1, 2, 3],), dict()),
186186
("aten_select_int_0", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 0, 1,), dict()),
187187
("aten_select_int_1", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 1, 1,), dict()),
188+
("aten_where_self_0", torch.ops.aten.where.self, (rnd(torch.bool, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
188189
("aten__softmax_0", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), -1, False), dict()),
189190
("aten__softmax_1", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), -1, False), dict()),
190191
("aten__softmax_2", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), 0, False), dict()),

0 commit comments

Comments
 (0)