Commit 654fc93

junjiang-lab authored and copybara-github committed
Add impl for aten.embedding and lowering.
PiperOrigin-RevId: 773846228
1 parent 1549d42 commit 654fc93

File tree

4 files changed: +55, -0 lines


ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 6 additions & 0 deletions
@@ -283,6 +283,12 @@ def _aten_where_self_decomp(condition, x, y):
   return torch.ops.tfl.select_v2(condition, x, y)


+@register_decomp(torch.ops.aten.embedding.default)
+def _aten_embedding_decomp(weight, indices, padding_idx=-1):
+  # TODO: b/425747317 - Decomp to tfl.embedding_lookup once it's ready.
+  return torch.ops.tfl.gather(weight, indices, axis=0)
+
+
 @register_decomp(torch.ops.aten._softmax.default)
 def _aten__softmax_decomp(
     x, dim: int, half_to_float: bool  # pylint: disable=unused-argument
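
This decomposition rewrites aten.embedding as a plain row gather from the weight table along axis 0; padding_idx only affects gradient accumulation during training, so it can be ignored here. A minimal sketch of the equivalence using stock torch ops (torch.index_select stands in for tfl.gather with axis=0; the 1-D indices case is assumed):

import torch

weight = torch.randn(10, 4)        # (num_embeddings, embedding_dim)
indices = torch.tensor([0, 2, 4])  # rows to look up

# aten.embedding is just a row lookup into the weight table.
via_embedding = torch.nn.functional.embedding(indices, weight)
# Stand-in for tfl.gather(weight, indices, axis=0).
via_gather = torch.index_select(weight, 0, indices)

assert torch.equal(via_embedding, via_gather)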

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 34 additions & 0 deletions
@@ -600,6 +600,40 @@ def _tfl_select_v2_lowering(
   )


+@lower(torch.ops.tfl.embedding_lookup.default)
+def _tfl_embedding_lookup_lowering(
+    lctx: LoweringContext,
+    indices: ir.Value,
+    weight: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.embedding_lookup",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[indices, weight],
+  )
+
+
+@lower(torch.ops.tfl.gather.default)
+def _tfl_gather_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    indices: ir.Value,
+    axis: int,
+    batch_dims: int = 0,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.gather",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, indices],
+      attributes={
+          "axis": ir.IntegerAttr.get(ir.IntegerType.get_signless(32), axis),
+          "batch_dims": ir.IntegerAttr.get(
+              ir.IntegerType.get_signless(32), batch_dims
+          ),
+      },
+  )
+
+
 @lower(torch.ops.tfl.softmax.default)
 def _tfl_softmax_lowering(
     lctx: LoweringContext,
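
Both lowerings emit their TFLite op directly via _ir_operation; note the operand-order difference (tfl.embedding_lookup takes indices first, tfl.gather takes the table first), and that tfl.gather carries axis and batch_dims as signless 32-bit integer attributes. A standalone sketch of building such an attribute with the MLIR Python bindings (the import path is an assumption; inside _lowerings.py the ir module is already in scope):

from mlir import ir  # assumed import path; adjust to your MLIR Python package

with ir.Context(), ir.Location.unknown():
  # tfl.gather expects its axis/batch_dims as signless i32 attributes.
  axis_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
  print(axis_attr)  # prints: 0 : i32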

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 14 additions & 0 deletions
@@ -273,6 +273,20 @@ def tfl_select_v2(
   return torch.where(condition, x, y)


+@custom_op_with_fake("tfl::embedding_lookup")
+def tfl_embedding_lookup(
+    indices: torch.Tensor, weight: torch.Tensor
+) -> torch.Tensor:
+  return torch.nn.functional.embedding(indices, weight)
+
+
+@custom_op_with_fake("tfl::gather")
+def tfl_gather(
+    input: torch.Tensor, indices: torch.Tensor, axis: int
+) -> torch.Tensor:
+  return torch.index_select(input, axis, indices)
+
+
 @custom_op_with_fake("tfl::softmax")
 def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
   return torch.nn.functional.softmax(x, dim=-1)
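
These reference implementations give the tfl::* custom ops eager semantics, so traced graphs can still execute in PyTorch before conversion. One caveat: torch.index_select requires 1-D indices, so the tfl_gather reference covers the 1-D case that the embedding decomposition produces. A hedged usage sketch (assuming, per the decorators above, that importing the module is what registers the ops):

import torch
# Assumption: importing _ops registers the tfl::* custom ops with PyTorch.
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops  # noqa: F401

weight = torch.randn(10, 10)
indices = torch.tensor([0, 2, 4, 6, 8])

looked_up = torch.ops.tfl.embedding_lookup(indices, weight)
gathered = torch.ops.tfl.gather(weight, indices, 0)

assert torch.equal(looked_up, gathered)  # both select the same five rows
assert looked_up.shape == (5, 10)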

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 1 addition & 0 deletions
@@ -186,6 +186,7 @@ def _assert_export_and_close(
     ("aten_select_int_0", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 0, 1,), dict()),
     ("aten_select_int_1", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 1, 1,), dict()),
     ("aten_where_self_0", torch.ops.aten.where.self, (rnd(torch.bool, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
+    ("aten_embedding_0", torch.ops.aten.embedding.default, (rnd(torch.float32, (10, 10)), torch.tensor([0, 2, 4, 6, 8]),), dict()),
     ("aten__softmax_0", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), -1, False), dict()),
     ("aten__softmax_1", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), -1, False), dict()),
     ("aten__softmax_2", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), 0, False), dict()),
