Commit b097677

junjiang-lab authored and copybara-github committed

Add impl for aten.expand and lowering.

PiperOrigin-RevId: 773076323

1 parent 2e75f89 · commit b097677

File tree: 4 files changed, +25 -0 lines

  ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py
  ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py
  ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py
  ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py
Lines changed: 5 additions & 0 deletions

@@ -243,6 +243,11 @@ def _aten_unsqueeze_decomp(x, dim):
   return torch.ops.tfl.expand_dims(x, dim)


+@register_decomp(torch.ops.aten.expand.default)
+def _aten_expand_decomp(x, shape: Sequence[int]):
+  return torch.ops.tfl.broadcast_to(x, shape)
+
+
 @register_decomp(torch.ops.aten._softmax.default)
 def _aten__softmax_decomp(
     x, dim: int, half_to_float: bool  # pylint: disable=unused-argument
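For context (not part of the commit), a minimal sketch of the assumption behind this decomposition: for a broadcastable input, aten.expand produces the same result as a plain broadcast_to, which is the semantics tfl.broadcast_to is meant to carry.

import torch

x = torch.randn(10, 1)
expanded = torch.ops.aten.expand.default(x, [10, 10])
broadcast = torch.broadcast_to(x, (10, 10))
assert torch.equal(expanded, broadcast)  # both repeat the size-1 dim 10 times
print(expanded.shape)  # torch.Size([10, 10])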

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py
Lines changed: 13 additions & 0 deletions

@@ -512,6 +512,19 @@ def _tfl_expand_dims_lowering(
   )


+@lower(torch.ops.tfl.broadcast_to.default)
+def _tfl_broadcast_to_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    shape: Sequence[int | ir.Value],
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.broadcast_to",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, lowering_utils.convert_shape_to_ir_value(shape)],
+  )
+
+
 @lower(torch.ops.tfl.softmax.default)
 def _tfl_softmax_lowering(
     lctx: LoweringContext,
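The lowering emits a two-operand tfl.broadcast_to op: the input value plus the target shape converted into an IR value. At the Torch level, the op it consumes is the custom op registered in _ops.py below; a rough usage sketch, assuming that importing the _ops module is enough to register the tfl::* custom ops:

import torch
# Assumption: importing this module registers the tfl::* custom ops.
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops  # noqa: F401

x = torch.randn(1, 10)
y = torch.ops.tfl.broadcast_to(x, [10, 10])  # same call _aten_expand_decomp emits
print(y.shape)  # torch.Size([10, 10])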

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py
Lines changed: 5 additions & 0 deletions

@@ -232,6 +232,11 @@ def tfl_expand_dims(x: torch.Tensor, dim: int) -> torch.Tensor:
   return torch.unsqueeze(x, dim).clone()


+@custom_op_with_fake("tfl::broadcast_to")
+def tfl_broadcast_to(x: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
+  return x.expand(shape).clone()
+
+
 @custom_op_with_fake("tfl::softmax")
 def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
   return torch.nn.functional.softmax(x, dim=-1)
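A note on the reference implementation: x.expand(shape) returns a zero-stride view that aliases the input, so the trailing .clone() materializes an independent tensor, mirroring the other tfl_* reference ops above. A small illustrative sketch (not from the repo):

import torch

x = torch.randn(10, 1)
view = x.expand(10, 10)   # no copy: the size-1 dim gets stride 0
copy = view.clone()       # materializes an independent (10, 10) buffer
print(view.stride())      # e.g. (1, 0)
print(copy.stride())      # e.g. (10, 1)
print(view.data_ptr() == x.data_ptr())  # True: the view aliases x
print(copy.data_ptr() == x.data_ptr())  # False: the clone owns its storage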

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py
Lines changed: 2 additions & 0 deletions

@@ -180,6 +180,8 @@ def _assert_export_and_close(
       ("aten_unsqueeze_1", torch.ops.aten.unsqueeze.default, (rnd(torch.float32, (10, 10)), 1,), dict()),
       ("aten_unsqueeze_2", torch.ops.aten.unsqueeze.default, (rnd(torch.float32, (10, 10)), 2,), dict()),
       ("aten_unsqueeze_3", torch.ops.aten.unsqueeze.default, (rnd(torch.float32, (10, 10)), -1,), dict()),
+      ("aten_expand_0", torch.ops.aten.expand.default, (rnd(torch.float32, (10, 1)), [10, 10],), dict()),
+      ("aten_expand_1", torch.ops.aten.expand.default, (rnd(torch.float32, (1, 10)), [10, 10],), dict()),
       ("aten__softmax_0", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), -1, False), dict()),
       ("aten__softmax_1", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), -1, False), dict()),
       ("aten__softmax_2", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), 0, False), dict()),
