@@ -1061,7 +1061,8 @@ def dag(x):
     try:
         return x.H
     except AttributeError:
-        return do("conj", do("transpose", x))
+        backend = infer_backend(x)
+        return do("conj", do("transpose", x, like=backend), like=backend)
 
 
 def real(x):
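Note: `dag` now infers the backend once and passes it through explicitly, rather than letting each nested `do` call re-infer it from the intermediate result. A minimal usage sketch, assuming a NumPy input (array values are illustrative only):

    import numpy as np
    from autoray import do, infer_backend

    x = np.arange(6).reshape(2, 3) + 1j

    backend = infer_backend(x)  # -> "numpy"
    # both dispatches skip inference because the backend is given
    xd = do("conj", do("transpose", x, like=backend), like=backend)

    assert np.allclose(xd, x.conj().T)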
@@ -1093,13 +1094,10 @@ def to_backend_dtype(dtype_name, like):
     return dtype_name
 
 
+@compose
 def get_dtype_name(x):
     """Find string specifier ``dtype_name`` of array ``x``."""
-    try:
-        return x.dtype.name
-    except AttributeError:
-        # let modules provide their own
-        return do("get_dtype_name", x, like=x)
+    return x.dtype.name
 
 
 _COMPLEX_DTYPES = {"complex64", "complex128"}
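`get_dtype_name` becomes a `@compose`d dispatcher: the generic `x.dtype.name` body is the fallback, and backends register overrides via `get_dtype_name.register(backend)` instead of routing through the old `try`/`except AttributeError` path. A minimal sketch of the pattern (not the actual `compose` implementation, which lives elsewhere in this module), assuming dispatch keys come from `infer_backend`:

    import functools

    from autoray import infer_backend

    def compose(fn):
        registry = {}

        def register(backend):
            def decorator(impl):
                registry[backend] = impl
                return impl
            return decorator

        @functools.wraps(fn)
        def dispatched(x, *args, **kwargs):
            # use a registered override, else the generic fallback
            impl = registry.get(infer_backend(x), fn)
            return impl(x, *args, **kwargs)

        dispatched.register = register
        return dispatched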
@@ -1269,7 +1267,7 @@ def numpy_like(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
         if size is None:
             size = ()
 
-        x = fn(size=size, **kwargs)
+        x = fn(size, **kwargs)
 
         if (loc != 0.0) or (scale != 1.0):
            x = scale * x + loc
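Note: passing the shape positionally presumably lets this one wrapper serve backends that name the parameter differently, e.g. paddle's `randn` calls it `shape` rather than `size`:

    import paddle

    # paddle.randn's first parameter is `shape`, so `fn(size=size)` would fail
    x = paddle.randn([2, 3])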
@@ -1513,11 +1511,11 @@ def __repr__():
 }
 
 
+@get_dtype_name.register("builtins")
 def builtins_get_dtype_name(x):
     return _builtin_dtype_lookup[x.__class__]
 
 
-_FUNCS["builtins", "get_dtype_name"] = builtins_get_dtype_name
 _FUNCS["builtins", "complex"] = complex
 
 # ---------------------------------- numpy ---------------------------------- #
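With the registration above, plain python scalars go through the same dispatch entry point. Assuming `_builtin_dtype_lookup` maps builtin types to numpy-style names:

    get_dtype_name(1.0)     # -> e.g. "float64"
    get_dtype_name(1 + 2j)  # -> e.g. "complex128"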
@@ -1691,6 +1689,7 @@ def ctf_count_nonzero(x):
     return (x != 0).astype(int).sum()
 
 
+@get_dtype_name.register("ctf")
 def ctf_get_dtype_name(x):
     return x.dtype.__name__
 
@@ -1700,7 +1699,6 @@ def ctf_get_dtype_name(x):
 _FUNCS["ctf", "allclose"] = allclose
 _FUNCS["ctf", "to_numpy"] = ctf_to_numpy
 _FUNCS["ctf", "count_nonzero"] = ctf_count_nonzero
-_FUNCS["ctf", "get_dtype_name"] = ctf_get_dtype_name
 
 _SUBMODULE_ALIASES["ctf", "float32"] = "numpy"
 _SUBMODULE_ALIASES["ctf", "float64"] = "numpy"
@@ -1931,6 +1929,7 @@ def _torch_get_dtype_name(dtype):
     return str(dtype).split(".")[-1]
 
 
+@get_dtype_name.register("torch")
 def torch_get_dtype_name(x):
     return _torch_get_dtype_name(x.dtype)
 
@@ -1952,7 +1951,7 @@ def torch_imag(x):
         return x.imag
     except AttributeError:
         pass
-    return do("zeros_like", x, like="torch")
+    return do("zeros_like", x)
 
 
 def torch_linalg_solve_wrap(fn):
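Dropping `like="torch"` is safe because `do` can infer the backend from the tensor argument itself:

    import torch
    from autoray import do

    t = torch.ones(3)
    do("zeros_like", t)  # dispatches to torch.zeros_like by inspecting `t`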
@@ -2092,7 +2091,6 @@ def numpy_like(x, axis=None):
 _FUNCS["torch", "complex"] = complex_add_re_im
 _FUNCS["torch", "transpose"] = torch_transpose
 _FUNCS["torch", "count_nonzero"] = torch_count_nonzero
-_FUNCS["torch", "get_dtype_name"] = torch_get_dtype_name
 _FUNCS["torch", "indices"] = torch_indices
 
 _FUNC_ALIASES["torch", "array"] = "tensor"
@@ -2185,3 +2183,161 @@ def mxnet_to_numpy(x):
 
 _MODULE_ALIASES["mxnet"] = "mxnet.numpy"
 _FUNCS["mxnet", "to_numpy"] = mxnet_to_numpy
+
+
+# --------------------------------- paddle ---------------------------------- #
+
+_paddle_dtype_name_conversion = {
+    "BOOL": "bool",
+    "INT8": "int8",
+    "INT16": "int16",
+    "INT32": "int32",
+    "INT64": "int64",
+    "FP16": "float16",
+    "FP32": "float32",
+    "FP64": "float64",
+    "COMPLEX64": "complex64",
+    "COMPLEX128": "complex128",
+}
+
+
+@get_dtype_name.register("paddle")
+def paddle_get_dtype_name(x):
+    return _paddle_dtype_name_conversion[x.dtype.name]
+
+
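Paddle reports enum-style dtype names ("FP32", "COMPLEX64", ...), which the table above converts to numpy-style specifiers. An illustrative sketch:

    import paddle

    x = paddle.ones([2, 2])   # default dtype, reported as "FP32"
    paddle_get_dtype_name(x)  # -> "float32"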
+@shape.register("paddle")
+def paddle_shape(x):
+    # convert from list
+    return tuple(x.shape)
+
+
+def paddle_to_numpy(x):
+    return x.numpy()
+
+
+def paddle_transpose(a, axes=None):
+    if axes is None:
+        axes = tuple(range(a.ndim - 1, -1, -1))
+    return a.transpose(perm=axes)
+
+
+def paddle_real(x):
+    # paddle doesn't support calling real on real arrays
+    try:
+        if x.is_complex():
+            return x.real()
+    except AttributeError:
+        pass
+    return x
+
+
+def paddle_imag(x):
+    # paddle doesn't support calling imag on real arrays
+    try:
+        if x.is_complex():
+            return x.imag()
+    except AttributeError:
+        pass
+    return do("zeros_like", x)
+
+
+def paddle_indices(dimensions):
+    _meshgrid = get_lib_fn("paddle", "meshgrid")
+    _arange = get_lib_fn("paddle", "arange")
+    return _meshgrid(*map(_arange, dimensions), indexing="ij")
+
+
+def paddle_ravel(x):
+    return x.reshape((-1,))
+
+
+def paddle_pad(array, pad_width, mode="constant", constant_values=0):
+    if mode != "constant":
+        raise NotImplementedError
+
+    try:
+        # numpy takes pads like ((0, 0), (1, 1), ... (n-1, n-1))
+        # paddle takes pads like (0, 0, 1, 1, 2, 2, ...)
+        pad = tuple(itertools.chain.from_iterable(pad_width))
+
+        # a single tuple was specified ((a, b),) - use for all axes
+        if len(pad) == 2:
+            pad = pad * array.ndim
+
+    except TypeError:
+        # assume int
+        pad = (pad_width,) * 2 * array.ndim
+
+    return do(
+        "nn.functional.pad",
+        array,
+        pad=pad,
+        mode=mode,
+        value=constant_values,
+        like="paddle",
+    )
+
+
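A worked example of the pad-width translation above, flattening numpy's per-axis pairs into paddle's flat layout:

    import itertools

    pad_width = ((0, 0), (1, 2))  # numpy style: one (before, after) pair per axis
    pad = tuple(itertools.chain.from_iterable(pad_width))
    assert pad == (0, 0, 1, 2)    # paddle style: flat sequence

    # a bare int pads every axis on both sides, e.g. 3 on a 2d array:
    assert (3,) * 2 * 2 == (3, 3, 3, 3)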
+def paddle_wrap_reduction(fn):
+    def numpy_like(*args, **kwargs):
+        keepdims = kwargs.pop("keepdims", None)
+        if keepdims is not None:
+            kwargs["keepdim"] = keepdims
+        return fn(*args, **kwargs)
+
+    return numpy_like
+
+
+def paddle_split_wrap(fn):
+    # paddle doesn't always seem to have `tensor_split`
+
+    @functools.wraps(fn)
+    def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
+        if isinstance(indices_or_sections, int):
+            return fn(ary, indices_or_sections, axis=axis, **kwargs)
+        else:
+            diff = do(
+                "diff",
+                indices_or_sections,
+                prepend=0,
+                append=shape(ary)[axis],
+                like="numpy",
+            )
+            diff = list(diff)
+            return fn(ary, diff, axis=axis)
+
+    return numpy_like
+
+_MODULE_ALIASES["paddle[alt]"] = "paddle"
+
+_FUNCS["paddle", "to_numpy"] = paddle_to_numpy
+_FUNCS["paddle", "transpose"] = paddle_transpose
+_FUNCS["paddle", "real"] = paddle_real
+_FUNCS["paddle", "imag"] = paddle_imag
+_FUNCS["paddle", "indices"] = paddle_indices
+_FUNCS["paddle", "ravel"] = paddle_ravel
+_FUNCS["paddle", "pad"] = paddle_pad
+
+_FUNC_ALIASES["paddle", "random.normal"] = "randn"
+_FUNC_ALIASES["paddle", "random.uniform"] = "rand"
+_FUNC_ALIASES["paddle", "asarray"] = "to_tensor"
+_FUNC_ALIASES["paddle", "concatenate"] = "concat"
+_FUNC_ALIASES["paddle", "power"] = "pow"
+_FUNC_ALIASES["paddle", "identity"] = "eye"
+_FUNC_ALIASES["paddle", "split"] = "tensor_split"
+
+_SUBMODULE_ALIASES["paddle", "random.normal"] = "paddle"
+_SUBMODULE_ALIASES["paddle", "random.uniform"] = "paddle"
+
+_CUSTOM_WRAPPERS["paddle", "random.normal"] = scale_random_normal_manually
+_CUSTOM_WRAPPERS["paddle", "random.uniform"] = scale_random_uniform_manually
+_CUSTOM_WRAPPERS["paddle[alt]", "split"] = paddle_split_wrap
+_CUSTOM_WRAPPERS["paddle", "tril"] = make_translator(
+    [("m", ("x",)), ("k", ("diagonal", 0))]
+)
+_CUSTOM_WRAPPERS["paddle", "triu"] = make_translator(
+    [("m", ("x",)), ("k", ("diagonal", 0))]
+)
+for f in ("sum", "max", "min", "prod", "mean", "std", "var"):
+    _CUSTOM_WRAPPERS["paddle", f] = paddle_wrap_reduction
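A worked example of the index-to-sizes conversion in `paddle_split_wrap`: numpy-style `split` takes cut indices, paddle-style takes section sizes, and `diff` with 0 prepended and the axis length appended maps one onto the other:

    import numpy as np

    indices = (2, 5)  # cut points along an axis of length 10
    sizes = np.diff(indices, prepend=0, append=10)
    assert list(sizes) == [2, 3, 5]  # the section sizes paddle accepts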