Skip to content

Commit fe2afd5

Browse files
Add interpolate nearest composite support (#41)
* Add interpolate nearest composite support * fix fmt
1 parent db77cc4 commit fe2afd5

File tree

4 files changed

+57
-8
lines changed

4 files changed

+57
-8
lines changed

ai_edge_torch/convert/conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ai_edge_torch import model
2626
from ai_edge_torch.convert import conversion_utils as cutils
2727
from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
28-
from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA
28+
from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA
2929
from ai_edge_torch.convert.fx_passes import CanonicalizePass
3030
from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
3131
from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
@@ -41,7 +41,7 @@ def _run_convert_passes(
4141
return run_passes(
4242
exported_program,
4343
[
44-
BuildUpsampleBilinear2DCompositePass(),
44+
BuildInterpolateCompositePass(),
4545
CanonicalizePass(),
4646
OptimizeLayoutTransposesPass(),
4747
CanonicalizePass(),

ai_edge_torch/convert/fx_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
2525
from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult
2626
from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA
27-
from ai_edge_torch.convert.fx_passes.build_upsample_bilinear2d_composite_pass import BuildUpsampleBilinear2DCompositePass # NOQA
27+
from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass # NOQA
2828
from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
2929
from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
3030
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA

ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py renamed to ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,34 @@ def attr_builder(graph_module, pattern, internal_match):
6666
return pattern
6767

6868

69-
class BuildUpsampleBilinear2DCompositePass(FxPassBase):
69+
@functools.cache
70+
def _get_interpolate_nearest2d_pattern():
71+
pattern = mark_pattern.Pattern(
72+
"tfl.resize_nearest_neighbor",
73+
lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
74+
export_args=(torch.rand(1, 3, 100, 100),),
75+
)
76+
77+
@pattern.register_attr_builder
78+
def attr_builder(pattern, graph_module, internal_match):
79+
output = internal_match.returning_nodes[0]
80+
output_h, output_w = output.meta["val"].shape[-2:]
81+
return {
82+
"size": (int(output_h), int(output_w)),
83+
"is_nchw_op": True,
84+
}
85+
86+
return pattern
87+
88+
89+
class BuildInterpolateCompositePass(FxPassBase):
7090

7191
def __init__(self):
7292
super().__init__()
7393
self._patterns = [
7494
_get_upsample_bilinear2d_pattern(),
7595
_get_upsample_bilinear2d_align_corners_pattern(),
96+
_get_interpolate_nearest2d_pattern(),
7697
]
7798

7899
def call(self, graph_module: torch.fx.GraphModule):

ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch
2020
import torch_xla
2121

22-
from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA
22+
from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA
2323
from ai_edge_torch.convert.fx_passes import run_passes
2424

2525

@@ -38,9 +38,7 @@ def forward(self, *args, **kwargs):
3838
module = func
3939

4040
exported_program = torch.export.export(module, export_args)
41-
exported_program = run_passes(
42-
exported_program, [BuildUpsampleBilinear2DCompositePass()]
43-
)
41+
exported_program = run_passes(exported_program, [BuildInterpolateCompositePass()])
4442

4543
return torch_xla.stablehlo.exported_program_to_stablehlo(
4644
exported_program
@@ -192,6 +190,36 @@ def test_nn_functional_interpolate_bilinear_size_align_corners(self):
192190
1,
193191
)
194192

193+
def test_nn_functional_interpolate_nearest(self):
194+
stablehlo = _export_to_stablehlo_with_composite(
195+
lambda x: torch.nn.functional.interpolate(x, scale_factor=3.0, mode='nearest'),
196+
(torch.rand(1, 3, 10, 10),),
197+
)
198+
self.assertTrue(
199+
stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1
200+
)
201+
self.assertTrue(
202+
stablehlo.count(
203+
'composite_attributes = {is_nchw_op = true, size = dense<30> : tensor<2xi64>}'
204+
),
205+
1,
206+
)
207+
208+
def test_nn_functional_interpolate_nearest_size(self):
209+
stablehlo = _export_to_stablehlo_with_composite(
210+
lambda x: torch.nn.functional.interpolate(x, size=[15, 20], mode='nearest'),
211+
(torch.rand(1, 3, 10, 10),),
212+
)
213+
self.assertTrue(
214+
stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1
215+
)
216+
self.assertTrue(
217+
stablehlo.count(
218+
'composite_attributes = {is_nchw_op = true, size = dense<[15, 20]> : tensor<2xi64>}'
219+
),
220+
1,
221+
)
222+
195223

196224
if __name__ == '__main__':
197225
unittest.main()

0 commit comments

Comments
 (0)