Skip to content

Commit 75d48d5

Browse files
junjiang-labcopybara-github
authored andcommitted
Promote operand types before calling TFLite binary ops.
PiperOrigin-RevId: 773888968
1 parent 654fc93 commit 75d48d5

File tree

3 files changed

+54
-13
lines changed

3 files changed

+54
-13
lines changed

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,33 +66,60 @@ def _aten_sub_tensor_decomp(x, y, alpha=1):
6666
return out
6767

6868

69+
def _promote_types_for_binary_op(x, y):
70+
"""Promotes operand types for a binary op."""
71+
# TFLite's binary ops require operands to have the same element type.
72+
# We promote the types before calling the op.
73+
if not isinstance(x, torch.Tensor):
74+
# x is a scalar, y must be a tensor.
75+
x = torch.scalar_tensor(x, dtype=y.dtype)
76+
elif not isinstance(y, torch.Tensor):
77+
# y is a scalar, x is a tensor.
78+
# Handle scalar operand by converting scalar to a tensor.
79+
# The dtype of the new tensor should match x's dtype for promotion.
80+
y = torch.scalar_tensor(y, dtype=x.dtype)
81+
82+
target_dtype = torch.promote_types(x.dtype, y.dtype)
83+
if x.dtype != target_dtype:
84+
x = x.to(target_dtype)
85+
if y.dtype != target_dtype:
86+
y = y.to(target_dtype)
87+
return x, y
88+
89+
6990
@register_decomp(torch.ops.aten.mul.Tensor)
7091
def _aten_mul_tensor_decomp(x, y):
92+
x, y = _promote_types_for_binary_op(x, y)
7193
return torch.ops.tfl.mul(x, y)
7294

7395

7496
@register_decomp(torch.ops.aten.mul.Scalar)
7597
def _aten_mul_scalar_decomp(x, y):
98+
x, y = _promote_types_for_binary_op(x, y)
7699
return torch.ops.tfl.mul(x, y)
77100

78101

79102
@register_decomp(torch.ops.aten.div.Tensor)
80103
def _aten_div_tensor_decomp(x, y):
104+
x, y = _promote_types_for_binary_op(x, y)
81105
return torch.ops.tfl.div(x, y)
82106

83107

84108
@register_decomp(torch.ops.aten.pow.Scalar)
85109
def _aten_pow_scalar_decomp(x, y):
110+
x, y = _promote_types_for_binary_op(x, y)
86111
return torch.ops.tfl.pow(x, y)
87112

88113

89114
@register_decomp(torch.ops.aten.pow.Tensor_Scalar)
90115
def _aten_pow_tensor_scalar_decomp(x, y):
116+
x, y = _promote_types_for_binary_op(x, y)
91117
return torch.ops.tfl.pow(x, y)
92118

93119

94120
@register_decomp(torch.ops.aten.pow.Tensor_Tensor)
95121
def _aten_pow_tensor_tensor_decomp(x, y):
122+
x, y = _promote_types_for_binary_op(x, y)
96123
return torch.ops.tfl.pow(x, y)
97124

98125

@@ -117,11 +144,13 @@ def _aten_mean_dim_decomp(x, dim, keepdim=False):
117144

118145
@register_decomp(torch.ops.aten.gt.Tensor)
119146
def _aten_gt_tensor_decomp(x, y):
147+
x, y = _promote_types_for_binary_op(x, y)
120148
return torch.ops.tfl.greater(x, y)
121149

122150

123151
@register_decomp(torch.ops.aten.lt.Tensor)
124152
def _aten_lt_tensor_decomp(x, y):
153+
x, y = _promote_types_for_binary_op(x, y)
125154
return torch.ops.tfl.less(x, y)
126155

127156

@@ -280,6 +309,7 @@ def _aten_select_int_decomp(x, dim, index):
280309

281310
@register_decomp(torch.ops.aten.where.self)
282311
def _aten_where_self_decomp(condition, x, y):
312+
x, y = _promote_types_for_binary_op(x, y)
283313
return torch.ops.tfl.select_v2(condition, x, y)
284314

285315

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -429,22 +429,32 @@ def _tfl_range_lowering(
429429
lowering_utils.torch_dtype_to_ir_element_type(output_torch_dtype)
430430
)
431431

432+
# All operands and the output of tfl.range must have the same element type.
433+
# We cast all operands to the expected output element type.
432434
operands = []
433-
for val_py_scalar in [
434-
start,
435-
limit,
436-
delta,
437-
]:
438-
if isinstance(val_py_scalar, ir.Value):
439-
operands.append(val_py_scalar)
440-
else:
435+
for operand in [start, limit, delta]:
436+
if not isinstance(operand, ir.Value):
437+
# Convert python scalars to ir.Value.
441438
numpy_scalar_0d = (
442-
torch.tensor(val_py_scalar, dtype=output_torch_dtype)
443-
.detach()
444-
.numpy()
439+
torch.tensor(operand, dtype=output_torch_dtype).detach().numpy()
445440
)
446-
scalar_tensor_val = lowering_utils.numpy_array_constant(numpy_scalar_0d)
447-
operands.append(scalar_tensor_val)
441+
operand = lowering_utils.numpy_array_constant(numpy_scalar_0d)
442+
443+
# `operand` is now an ir.Value.
444+
# Cast its element type to the output element type if they don't match.
445+
operand_type = operand.type
446+
if not isinstance(operand_type, ir.RankedTensorType):
447+
raise TypeError(
448+
"tfl.range operand expected to be RankedTensorType, got"
449+
f" {operand_type}"
450+
)
451+
452+
if operand_type.element_type != tflite_op_internal_element_type:
453+
cast_to_type = ir.RankedTensorType.get(
454+
operand_type.shape, tflite_op_internal_element_type
455+
)
456+
operand = stablehlo.convert(cast_to_type, operand)
457+
operands.append(operand)
448458

449459
# Define the result type that the tfl.range *kernel* (the custom op) will
450460
# produce.

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def _assert_export_and_close(
132132
("aten_div_Tensor_2", torch.ops.aten.div.Tensor, (rnd(torch.float32, (10, 10)), np.random.rand(),), dict()),
133133
("aten_pow_Scalar_0", torch.ops.aten.pow.Scalar, (np.random.rand(), rnd(torch.float32, (10, 10)),), dict()),
134134
("aten_pow_Tensor_Scalar_0", torch.ops.aten.pow.Tensor_Scalar, (rnd(torch.float32, (10, 10)), np.random.rand(),), dict()),
135+
("aten_pow_Tensor_Scalar_1", torch.ops.aten.pow.Tensor_Scalar, (rnd(torch.float32, (10, 10)), np.random.randint(2, 5),), dict()),
135136
("aten_pow_Tensor_Tensor_0", torch.ops.aten.pow.Tensor_Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
136137
("aten_bitwise_and_Tensor_0", torch.ops.aten.bitwise_and.Tensor, (rnd(torch.bool, (10, 10)), rnd(torch.bool, (10, 10)),), dict()),
137138
("aten_mean_dim_0", torch.ops.aten.mean.dim, (rnd(torch.float32, (10, 10)), 0), dict()),

0 commit comments

Comments
 (0)