Skip to content

Commit e36d1dc

Browse files
committed
torch reduce fns: no default dim or keepdim value
1 parent 380bcfe commit e36d1dc

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

autoray/autoray.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -1312,21 +1312,24 @@ def translated_function(*args, **kwargs):
13121312
new_kwargs = {}
13131313
translation = translator.copy()
13141314

1315-
# convert args
1315+
# convert args, pairing them off with kwargs
13161316
for arg_value in args:
13171317
new_arg_name = translation.popitem(last=False)[1][0]
13181318
new_kwargs[new_arg_name] = arg_value
13191319

1320-
# convert kwargs - but only those in the translation
1320+
# convert kwargs - but only those in the translation
13211321
for key, value in kwargs.items():
13221322
try:
13231323
new_kwargs[translation.pop(key)[0]] = value
13241324
except KeyError:
13251325
new_kwargs[key] = value
13261326

13271327
# set remaining default kwargs
1328-
for key, value in translation.items():
1329-
new_kwargs[value[0]] = value[1]
1328+
for opt in translation.values():
1329+
if len(opt) == 2:
1330+
# backend_name, default_value
1331+
new_kwargs[opt[0]] = opt[1]
1332+
# else, no default value -> don't inject
13301333

13311334
return fn(**new_kwargs)
13321335

@@ -2146,12 +2149,13 @@ def numpy_like(x, axis=None):
21462149
_CUSTOM_WRAPPERS["torch", "sort"] = torch_sort_wrap
21472150
_CUSTOM_WRAPPERS["torch", "flip"] = torch_flip_wrap
21482151
_torch_reduce_translation = [
2149-
("a", ("input",)), ("axis", ("dim", None)), ("keepdims", ("keepdim", False))
2152+
("a", ("input",)),
2153+
("axis", ("dim",)),
2154+
("keepdims", ("keepdim",)),
21502155
]
2151-
_CUSTOM_WRAPPERS["torch", "sum"] = make_translator(_torch_reduce_translation)
2152-
_CUSTOM_WRAPPERS["torch", "max"] = make_translator(_torch_reduce_translation)
2153-
_CUSTOM_WRAPPERS["torch", "min"] = make_translator(_torch_reduce_translation)
2154-
_CUSTOM_WRAPPERS["torch", "prod"] = make_translator(_torch_reduce_translation)
2156+
for f in ("sum", "max", "min", "prod", "mean", "median", "std", "var"):
2157+
# TODO: search "keepdim" in torch docs to find more
2158+
_CUSTOM_WRAPPERS["torch", f] = make_translator(_torch_reduce_translation)
21552159

21562160
# for older versions of torch, can provide some alternative implementations
21572161
_MODULE_ALIASES["torch[alt]"] = "torch"

tests/test_autoray.py

+36
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,42 @@ def test_binary_functions(f, args, xfail_backends, backend):
149149
assert ar.do("allclose", yt, yn)
150150

151151

152+
@pytest.mark.parametrize(
153+
"f",
154+
[
155+
"sum",
156+
"prod",
157+
"max",
158+
"min",
159+
"mean",
160+
],
161+
)
162+
@pytest.mark.parametrize(
163+
"kwargs",
164+
[
165+
{},
166+
{"axis": 1},
167+
{"axis": 1, "keepdims": True},
168+
{"axis": (0, 2)},
169+
],
170+
)
171+
@pytest.mark.parametrize("backend", BACKENDS)
172+
def test_reduce_functions(f, kwargs, backend):
173+
if (
174+
backend == "torch"
175+
and f == "prod"
176+
and isinstance(kwargs.get("axis"), tuple)
177+
):
178+
pytest.xfail("Pytorch doesn't support prod with tuple axis.")
179+
180+
x = ar.do("random.normal", size=(2, 3, 4), like="numpy")
181+
y = ar.do(f, x, **kwargs)
182+
xb = ar.do("asarray", x, like=backend)
183+
yb = ar.do(f, xb, **kwargs)
184+
yt = ar.do("to_numpy", yb)
185+
assert ar.do("allclose", yt, y)
186+
187+
152188
@pytest.mark.parametrize("backend", BACKENDS)
153189
@pytest.mark.parametrize("fn", ["sqrt", "exp", "sum"])
154190
def test_basic(backend, fn):

0 commit comments

Comments
 (0)