def translated_function(*args, **kwargs):
    """Call ``fn`` with numpy-style argument names translated to the
    backend's own names.

    ``translator`` (captured from the enclosing scope; an ordered
    mapping supporting ``popitem(last=False)``, e.g. ``OrderedDict``)
    maps ``source_name -> (backend_name,)`` or
    ``source_name -> (backend_name, default_value)``.

    Positional arguments consume translation entries in declaration
    order; keyword arguments not present in the translation are
    forwarded unchanged; entries left unconsumed inject their default
    value, if they carry one.
    """
    new_kwargs = {}
    # work on a copy so each call consumes its own translation entries
    translation = translator.copy()

    # convert args, pairing them off with kwargs
    for arg_value in args:
        # popitem(last=False) -> (source_name, opts); opts[0] is the
        # backend-side name for this positional slot
        new_arg_name = translation.popitem(last=False)[1][0]
        new_kwargs[new_arg_name] = arg_value

    # convert kwargs - but only those in the translation
    for key, value in kwargs.items():
        try:
            new_kwargs[translation.pop(key)[0]] = value
        except KeyError:
            # not a translated name -> forward as-is
            new_kwargs[key] = value

    # set remaining default kwargs
    for opt in translation.values():
        if len(opt) == 2:
            # backend_name, default_value
            new_kwargs[opt[0]] = opt[1]
        # else, no default value -> don't inject

    return fn(**new_kwargs)
# numpy-style signature wrappers for individual torch functions
_CUSTOM_WRAPPERS["torch", "sort"] = torch_sort_wrap
_CUSTOM_WRAPPERS["torch", "flip"] = torch_flip_wrap

# shared numpy -> torch kwarg translation for reduction functions; no
# default values are supplied, so omitted kwargs are simply not injected
_torch_reduce_translation = [
    ("a", ("input",)),
    ("axis", ("dim",)),
    ("keepdims", ("keepdim",)),
]
for f in ("sum", "max", "min", "prod", "mean", "median", "std", "var"):
    # TODO: search "keepdim" in torch docs to find more
    _CUSTOM_WRAPPERS["torch", f] = make_translator(_torch_reduce_translation)

# for older versions of torch, can provide some alternative implementations
_MODULE_ALIASES["torch[alt]"] = "torch"