Skip to content

Commit ae2273c

Browse files
junjiang-lab and copybara-github
authored and committed
Add impl for aten.squeeze.dims and lowering.
PiperOrigin-RevId: 773096398
1 parent b097677 commit ae2273c

File tree

4 files changed

+35
-0
lines changed

4 files changed

+35
-0
lines changed

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,16 @@ def _aten_expand_decomp(x, shape: Sequence[int]):
248248
return torch.ops.tfl.broadcast_to(x, shape)
249249

250250

251+
@register_decomp(torch.ops.aten.squeeze.dims)
def _aten_squeeze_dims_decomp(x, squeeze_dims: Sequence[int]):
  """Decomposes aten.squeeze.dims into the tfl.squeeze custom op.

  The tfl.squeeze op accepts at most 8 squeeze dimensions, so longer lists
  are rejected up front with a descriptive error rather than failing later
  during lowering.
  """
  num_dims = len(squeeze_dims)
  if num_dims > 8:
    raise ValueError(
        "torch.ops.tfl.squeeze supports squeezing at most 8 dimensions, but got"
        f" {num_dims} dimensions."
    )
  return torch.ops.tfl.squeeze(x, squeeze_dims)
259+
260+
251261
@register_decomp(torch.ops.aten._softmax.default)
252262
def _aten__softmax_decomp(
253263
x, dim: int, half_to_float: bool # pylint: disable=unused-argument

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,25 @@ def _tfl_broadcast_to_lowering(
525525
)
526526

527527

528+
@lower(torch.ops.tfl.squeeze.default)
def _tfl_squeeze_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    squeeze_dims: Sequence[int | ir.Value],
) -> ir.Value:
  """Lowers torch.ops.tfl.squeeze to a "tfl.squeeze" IR operation.

  The squeeze dimensions are emitted as an i64 array attribute on the op;
  result types are taken from the FX node metadata on the lowering context.
  """
  # Build the squeeze_dims attribute as 64-bit signless integers.
  i64 = ir.IntegerType.get_signless(64)
  dim_attrs = [ir.IntegerAttr.get(i64, int(d)) for d in squeeze_dims]
  return _ir_operation(
      "tfl.squeeze",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[x],
      attributes={"squeeze_dims": ir.ArrayAttr.get(dim_attrs)},
  )
545+
546+
528547
@lower(torch.ops.tfl.softmax.default)
529548
def _tfl_softmax_lowering(
530549
lctx: LoweringContext,

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,11 @@ def tfl_broadcast_to(x: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
237237
return x.expand(shape).clone()
238238

239239

240+
@custom_op_with_fake("tfl::squeeze")
def tfl_squeeze(x: torch.Tensor, squeeze_dims: Sequence[int]) -> torch.Tensor:
  """Reference implementation of tfl.squeeze.

  Removes the given size-1 dimensions from `x` via torch.squeeze; the result
  is cloned so the custom op returns a fresh tensor rather than a view.
  """
  squeezed = torch.squeeze(x, squeeze_dims)
  return squeezed.clone()
243+
244+
240245
@custom_op_with_fake("tfl::softmax")
241246
def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
242247
return torch.nn.functional.softmax(x, dim=-1)

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def _assert_export_and_close(
182182
("aten_unsqueeze_3", torch.ops.aten.unsqueeze.default, (rnd(torch.float32, (10, 10)), -1,), dict()),
183183
("aten_expand_0", torch.ops.aten.expand.default, (rnd(torch.float32, (10, 1)), [10, 10],), dict()),
184184
("aten_expand_1", torch.ops.aten.expand.default, (rnd(torch.float32, (1, 10)), [10, 10],), dict()),
185+
("aten_squeeze_dims_0", torch.ops.aten.squeeze.dims, (rnd(torch.float32, (2, 1, 2, 1, 2)), [1, 2, 3],), dict()),
185186
("aten__softmax_0", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), -1, False), dict()),
186187
("aten__softmax_1", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), -1, False), dict()),
187188
("aten__softmax_2", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), 0, False), dict()),

0 commit comments

Comments
 (0)