Skip to content

Commit 7b858ab

Browse files
committed
add torch flip with support for single axis
1 parent a09d42f commit 7b858ab

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

autoray/autoray.py

+15
Original file line numberDiff line numberDiff line change
@@ -2062,6 +2062,20 @@ def torch_indices(dimensions):
20622062
return _meshgrid(*map(_arange, dimensions), indexing="ij")
20632063

20642064

2065+
def torch_flip_wrap(torch_flip):
2066+
def numpy_like(x, axis=None):
2067+
if axis is None:
2068+
dims = tuple(range(x.ndimension()))
2069+
elif isinstance(axis, int):
2070+
dims = (axis,)
2071+
else:
2072+
# already tuple/list
2073+
dims = axis
2074+
return torch_flip(x, dims)
2075+
2076+
return numpy_like
2077+
2078+
20652079
_FUNCS["torch", "pad"] = torch_pad
20662080
_FUNCS["torch", "real"] = torch_real
20672081
_FUNCS["torch", "imag"] = torch_imag
@@ -2125,6 +2139,7 @@ def torch_indices(dimensions):
21252139
[("a", ("input",)), ("axis", ("dim",))]
21262140
)
21272141
_CUSTOM_WRAPPERS["torch", "sort"] = torch_sort_wrap
2142+
_CUSTOM_WRAPPERS["torch", "flip"] = torch_flip_wrap
21282143

21292144
# for older versions of torch, can provide some alternative implementations
21302145
_MODULE_ALIASES["torch[alt]"] = "torch"

0 commit comments

Comments
 (0)