
Commit fc14f6b

junjiang-lab authored and copybara-github committed
Add impl for aten.arange.start_step and lowering.
PiperOrigin-RevId: 766755082
1 parent 51248ef commit fc14f6b

5 files changed: +92 lines, -1 line

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py

Lines changed: 3 additions & 1 deletion
@@ -79,7 +79,9 @@ def edges(self):
 
   @property
   def graph(self):
-    edges = np.array(self.edges)
+    # Ensure edges is a 2D array with shape (N, 3) and int32 dtype.
+    # If self.edges is empty, this will result in an array with shape (0, 3).
+    edges = np.array(self.edges, dtype=np.int32).reshape(-1, 3)
     return scipy.sparse.csr_matrix(
         (
             np.minimum(edges[:, 2], MinCutSolver.INF_COST),
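
Why the reshape matters: a quick sketch outside the commit, showing the zero-edge case the new code handles. np.array([]) is 1-D, so column indexing like edges[:, 2] raises an IndexError; the explicit reshape(-1, 3) keeps the array 2-D even when no edges exist.

import numpy as np

# With the fix: an empty edge list still yields a (0, 3) array.
edges = np.array([], dtype=np.int32).reshape(-1, 3)
assert edges.shape == (0, 3)
assert edges[:, 2].shape == (0,)  # column indexing works, no IndexError

# Without the fix, np.array([]) has shape (0,), and edges[:, 2] raises
# "IndexError: too many indices for array".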

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 7 additions & 0 deletions
@@ -160,6 +160,13 @@ def _aten_view_decomp(x, shape: Sequence[int]):
   return torch.ops.tfl.reshape(x, shape)
 
 
+@register_decomp(torch.ops.aten.arange.start_step)
+def _aten_arange_start_step_decomp(
+    start, end, step=1, dtype=None, layout=None, device=None, pin_memory=None
+):
+  return torch.ops.tfl.range(start, end, step)
+
+
 @register_decomp(torch.ops.aten._softmax.default)
 def _aten__softmax_decomp(
     x, dim: int, half_to_float: bool  # pylint: disable=unused-argument
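
The decomposition forwards only (start, end, step) to the new tfl.range op; the dtype/layout/device/pin_memory keywords are accepted but unused, since the lowering below recovers the output dtype from the FX node's tensor metadata. A quick eager-mode sanity check, assuming the torch_tfl ops module has been imported so torch.ops.tfl.range is registered:

import torch

# What the graph contains before the decomposition...
a = torch.ops.aten.arange.start_step(0, 100, 5)
# ...and the op it is rewritten to; both match torch.arange semantics.
b = torch.ops.tfl.range(0, 100, 5)
assert torch.equal(a, b)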

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 69 additions & 0 deletions
@@ -339,6 +339,75 @@ def _tfl_reshape_lowering(
   )
 
 
+@lower(torch.ops.tfl.range.default)
+def _tfl_range_lowering(
+    lctx: LoweringContext,
+    start: int | float | ir.Value,
+    limit: int | float | ir.Value,
+    delta: int | float | ir.Value = 1,
+) -> ir.Value:
+  tensor_meta = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
+  output_torch_dtype = tensor_meta.dtype
+
+  original_mlir_output_types = lowering_utils.node_meta_to_ir_types(lctx.node)
+  if not original_mlir_output_types or not isinstance(
+      original_mlir_output_types[0], ir.RankedTensorType
+  ):
+    raise ValueError(
+        "tfl.range output type is not a RankedTensorType as expected."
+    )
+
+  original_mlir_output_type = original_mlir_output_types[0]
+  original_output_shape = original_mlir_output_type.shape
+  original_output_element_type = original_mlir_output_type.element_type
+  tflite_op_internal_element_type = (
+      lowering_utils.torch_dtype_to_ir_element_type(output_torch_dtype)
+  )
+
+  operands = []
+  for val_py_scalar in [
+      start,
+      limit,
+      delta,
+  ]:
+    if isinstance(val_py_scalar, ir.Value):
+      operands.append(val_py_scalar)
+    else:
+      numpy_scalar_0d = (
+          torch.tensor(val_py_scalar, dtype=output_torch_dtype)
+          .detach()
+          .numpy()
+      )
+      scalar_tensor_val = lowering_utils.numpy_array_constant(numpy_scalar_0d)
+      operands.append(scalar_tensor_val)
+
+  # Define the result type that the tfl.range *kernel* (the custom op) will
+  # produce.
+  tfl_op_kernel_output_type = ir.RankedTensorType.get(
+      original_output_shape, tflite_op_internal_element_type
+  )
+
+  tfl_range_op_val = _ir_operation(
+      "tfl.range",
+      results=[tfl_op_kernel_output_type],
+      operands=operands,
+  )
+
+  # The _tfl_range_lowering function must return a value of the
+  # original_mlir_output_type.
+  # If the tfl.range op's internal element type is different from the
+  # original_output_element_type, we need to convert.
+  if tflite_op_internal_element_type != original_output_element_type:
+    # Convert the tfl.range output to the original expected type.
+    final_output_val = stablehlo.convert(
+        original_mlir_output_type, tfl_range_op_val
+    )
+  else:
+    final_output_val = tfl_range_op_val
+
+  return final_output_val
+
+
 @lower(torch.ops.tfl.softmax.default)
 def _tfl_softmax_lowering(
     lctx: LoweringContext,
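
Two details of this lowering are worth noting. First, Python scalar operands are materialized as rank-0 constants in the output dtype recovered from the node's tensor_meta/val, so the tfl.range kernel sees consistently typed inputs. Second, if the element type the kernel computes in differs from the element type the surrounding graph expects, a stablehlo.convert reconciles the two. A small standalone sketch of the scalar-materialization step (torch/numpy only, no MLIR context required; the dtype here is an assumed stand-in):

import torch

# Assumed stand-in for the dtype recovered from lctx.node.meta["val"].
output_torch_dtype = torch.float32

for scalar in (0, 10.5, 0.5):
  # Mirrors the lowering: each Python scalar becomes a 0-d array in the
  # output dtype before being embedded as an MLIR constant.
  numpy_scalar_0d = torch.tensor(scalar, dtype=output_torch_dtype).detach().numpy()
  assert numpy_scalar_0d.shape == ()  # rank-0, as the kernel expects
  assert numpy_scalar_0d.dtype.name == "float32"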

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 9 additions & 0 deletions
@@ -189,6 +189,15 @@ def tfl_reshape_fake(input: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
   return torch.empty(inferred_shape, dtype=input.dtype)
 
 
+@custom_op_with_fake(
+    "tfl::range", schema="(Scalar start, Scalar limit, Scalar delta) -> Tensor"
+)
+def tfl_range(
+    start: int | float, limit: int | float, delta: int | float
+) -> torch.Tensor:
+  return torch.arange(start, limit, delta)
+
+
 @custom_op_with_fake("tfl::softmax")
 def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
   return torch.nn.functional.softmax(x, dim=-1)
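
Usage sketch for the new custom op, assuming the torch_tfl ops module has been imported so the tfl namespace is registered. In eager mode the op simply defers to torch.arange, which is also what gives export tracing its shape and dtype information:

import torch

r = torch.ops.tfl.range(0, 10.5, 0.5)
assert r.dtype == torch.float32  # any float input produces a float range
assert r.numel() == 21           # ceil((10.5 - 0) / 0.5) elements
assert r[0] == 0.0
assert r[-1] == 10.0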

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 4 additions & 0 deletions
@@ -154,6 +154,10 @@ def _assert_export_and_close(
     ("aten_gelu_3", torch.ops.aten.gelu.default, (rnd(torch.float32, (1, 10)),), dict(approximate="tanh")),
     ("aten_permute_0", torch.ops.aten.permute.default, (rnd(torch.float32, (10, 10)), [0, 1],), dict()),
     ("aten_permute_1", torch.ops.aten.permute.default, (rnd(torch.float32, (1, 10)), [0, 1],), dict()),
+    ("aten_arange_start_step_0", torch.ops.aten.arange.start_step, (0, 100, 5), dict()),
+    ("aten_arange_start_step_1", torch.ops.aten.arange.start_step, (0, 100), dict()),
+    ("aten_arange_start_step_2", torch.ops.aten.arange.start_step, (100, 0, -1), dict()),
+    ("aten_arange_start_step_3", torch.ops.aten.arange.start_step, (0, 10.5, 0.5), dict()),
     ("aten_view_0", torch.ops.aten.view.default, (rnd(torch.float32, (10, 10)), [1, 100],), dict()),
     ("aten_view_1", torch.ops.aten.view.default, (rnd(torch.float32, (1, 10)), [10, 1],), dict()),
     ("aten_view_2", torch.ops.aten.view.default, (rnd(torch.float32, (10, 10)), [2, 5, 10],), dict()),
