Skip to content

Commit ca5f45a

Browse files
committed
add paddle support
1 parent 8c6510e commit ca5f45a

File tree

5 files changed

+223
-28
lines changed

5 files changed

+223
-28
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ that means you can write backend agnostic code that works for:
2020
* [autograd](https://github.com/HIPS/autograd)
2121
* [tensorflow](https://github.com/tensorflow/tensorflow)
2222
* [sparse](https://sparse.pydata.org/)
23+
* [paddle](https://github.com/paddlepaddle/paddle)
2324
* [mars](https://github.com/mars-project/mars)
2425
* ... and indeed **any** library that provides a numpy-*ish* api, even if it
2526
knows nothing about `autoray`.

autoray/autoray.py

+167-11
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,8 @@ def dag(x):
10611061
try:
10621062
return x.H
10631063
except AttributeError:
1064-
return do("conj", do("transpose", x))
1064+
backend = infer_backend(x)
1065+
return do("conj", do("transpose", x, like=backend), like=backend)
10651066

10661067

10671068
def real(x):
@@ -1093,13 +1094,10 @@ def to_backend_dtype(dtype_name, like):
10931094
return dtype_name
10941095

10951096

1097+
@compose
10961098
def get_dtype_name(x):
10971099
"""Find string specifier ``dtype_name`` of array ``x``."""
1098-
try:
1099-
return x.dtype.name
1100-
except AttributeError:
1101-
# let modules provide their own
1102-
return do("get_dtype_name", x, like=x)
1100+
return x.dtype.name
11031101

11041102

11051103
_COMPLEX_DTYPES = {"complex64", "complex128"}
@@ -1269,7 +1267,7 @@ def numpy_like(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
12691267
if size is None:
12701268
size = ()
12711269

1272-
x = fn(size=size, **kwargs)
1270+
x = fn(size, **kwargs)
12731271

12741272
if (loc != 0.0) or (scale != 1.0):
12751273
x = scale * x + loc
@@ -1513,11 +1511,11 @@ def __repr__():
15131511
}
15141512

15151513

1514+
@get_dtype_name.register("builtins")
15161515
def builtins_get_dtype_name(x):
15171516
return _builtin_dtype_lookup[x.__class__]
15181517

15191518

1520-
_FUNCS["builtins", "get_dtype_name"] = builtins_get_dtype_name
15211519
_FUNCS["builtins", "complex"] = complex
15221520

15231521
# ---------------------------------- numpy ---------------------------------- #
@@ -1691,6 +1689,7 @@ def ctf_count_nonzero(x):
16911689
return (x != 0).astype(int).sum()
16921690

16931691

1692+
@get_dtype_name.register("ctf")
16941693
def ctf_get_dtype_name(x):
16951694
return x.dtype.__name__
16961695

@@ -1700,7 +1699,6 @@ def ctf_get_dtype_name(x):
17001699
_FUNCS["ctf", "allclose"] = allclose
17011700
_FUNCS["ctf", "to_numpy"] = ctf_to_numpy
17021701
_FUNCS["ctf", "count_nonzero"] = ctf_count_nonzero
1703-
_FUNCS["ctf", "get_dtype_name"] = ctf_get_dtype_name
17041702

17051703
_SUBMODULE_ALIASES["ctf", "float32"] = "numpy"
17061704
_SUBMODULE_ALIASES["ctf", "float64"] = "numpy"
@@ -1931,6 +1929,7 @@ def _torch_get_dtype_name(dtype):
19311929
return str(dtype).split(".")[-1]
19321930

19331931

1932+
@get_dtype_name.register("torch")
19341933
def torch_get_dtype_name(x):
19351934
return _torch_get_dtype_name(x.dtype)
19361935

@@ -1952,7 +1951,7 @@ def torch_imag(x):
19521951
return x.imag
19531952
except AttributeError:
19541953
pass
1955-
return do("zeros_like", x, like="torch")
1954+
return do("zeros_like", x)
19561955

19571956

19581957
def torch_linalg_solve_wrap(fn):
@@ -2092,7 +2091,6 @@ def numpy_like(x, axis=None):
20922091
_FUNCS["torch", "complex"] = complex_add_re_im
20932092
_FUNCS["torch", "transpose"] = torch_transpose
20942093
_FUNCS["torch", "count_nonzero"] = torch_count_nonzero
2095-
_FUNCS["torch", "get_dtype_name"] = torch_get_dtype_name
20962094
_FUNCS["torch", "indices"] = torch_indices
20972095

20982096
_FUNC_ALIASES["torch", "array"] = "tensor"
@@ -2185,3 +2183,161 @@ def mxnet_to_numpy(x):
21852183

21862184
_MODULE_ALIASES["mxnet"] = "mxnet.numpy"
21872185
_FUNCS["mxnet", "to_numpy"] = mxnet_to_numpy
2186+
2187+
2188+
# --------------------------------- paddle ---------------------------------- #
2189+
2190+
# map paddle's enum-style dtype names (``x.dtype.name`` e.g. "FP32") to the
# numpy-style dtype name strings that autoray uses for every other backend
_paddle_dtype_name_conversion = {
    "BOOL": "bool",
    "INT8": "int8",
    "INT16": "int16",
    "INT32": "int32",
    "INT64": "int64",
    "FP16": "float16",
    "FP32": "float32",
    "FP64": "float64",
    "COMPLEX64": "complex64",
    "COMPLEX128": "complex128",
}
2202+
2203+
2204+
@get_dtype_name.register("paddle")
def paddle_get_dtype_name(x):
    """Dtype name of paddle tensor ``x``, translated from paddle's
    enum-style spelling (e.g. ``FP32``) to the numpy-style string
    (e.g. ``"float32"``)."""
    return _paddle_dtype_name_conversion[x.dtype.name]
2207+
2208+
2209+
@shape.register("paddle")
def paddle_shape(x):
    """Shape of paddle tensor ``x`` as a tuple (paddle reports ``.shape``
    as a list, unlike most backends)."""
    return tuple(x.shape)
2213+
2214+
2215+
def paddle_to_numpy(x):
    """Convert paddle tensor ``x`` into a ``numpy.ndarray``."""
    return x.numpy()
2217+
2218+
2219+
def paddle_transpose(a, axes=None):
    """numpy-style ``transpose`` for paddle: if ``axes`` is not given,
    reverse all dimensions (paddle itself requires an explicit
    permutation, passed as ``perm``)."""
    if axes is None:
        axes = tuple(reversed(range(a.ndim)))
    return a.transpose(perm=axes)
2223+
2224+
2225+
def paddle_real(x):
    """Real part of ``x``. Real-valued (or non-paddle) inputs are returned
    unchanged, since paddle does not allow calling ``real`` on non-complex
    tensors."""
    is_complex = getattr(x, "is_complex", None)
    if is_complex is None or not is_complex():
        return x
    return x.real()
2233+
2234+
2235+
def paddle_imag(x):
    """Imaginary part of ``x``. For real-valued (or non-paddle) inputs
    return an all-zeros array instead, since paddle does not allow calling
    ``imag`` on non-complex tensors."""
    is_complex = getattr(x, "is_complex", None)
    if is_complex is not None and is_complex():
        return x.imag()
    return do("zeros_like", x)
2243+
2244+
2245+
def paddle_indices(dimensions):
    """numpy-style ``indices``: a sequence of coordinate grids, one per
    entry of ``dimensions``, built from paddle's ``meshgrid`` + ``arange``."""
    meshgrid = get_lib_fn("paddle", "meshgrid")
    arange = get_lib_fn("paddle", "arange")
    ranges = [arange(d) for d in dimensions]
    return meshgrid(*ranges, indexing="ij")
2249+
2250+
2251+
def paddle_ravel(x):
    """Flatten ``x`` to 1D, via ``reshape`` since paddle has no ``ravel``."""
    return x.reshape((-1,))
2253+
2254+
2255+
def paddle_pad(array, pad_width, mode="constant", constant_values=0):
    """numpy-style ``pad`` for paddle tensors.

    Parameters
    ----------
    array : paddle.Tensor
        The tensor to pad.
    pad_width : int, (int, int), or sequence of (int, int)
        Padding in numpy's convention: a per-axis sequence of
        ``(before, after)`` pairs, a single pair applied to every axis,
        or a single int used before and after on every axis.
    mode : str, optional
        Only ``"constant"`` is supported.
    constant_values : scalar, optional
        Fill value for the padded entries.

    Raises
    ------
    NotImplementedError
        If ``mode`` is anything other than ``"constant"``.
    """
    if mode != "constant":
        raise NotImplementedError

    try:
        # numpy takes pads like ((0, 0), (1, 1), ... (n-1, n-1))
        # paddle takes pads like (0, 0, 1, 1, 2, 2, ...)
        pad = tuple(itertools.chain.from_iterable(pad_width))

        # a single tuple was specified ((a, b),) - use for all axes
        if len(pad) == 2:
            pad = pad * array.ndim

    except TypeError:
        try:
            # fix: a *flat* pair (before, after) also raises TypeError in
            # chain.from_iterable above - previously it fell through to the
            # int branch and produced an invalid nested tuple
            before, after = pad_width
            pad = (before, after) * array.ndim
        except TypeError:
            # a single int - use before and after, on all axes
            pad = (pad_width,) * 2 * array.ndim

    return do(
        "nn.functional.pad",
        array,
        pad=pad,
        mode=mode,
        value=constant_values,
        like="paddle",
    )
2280+
2281+
2282+
def paddle_wrap_reduction(fn):
    """Wrap paddle reduction ``fn`` (``sum``, ``max``, ...) so that it
    accepts numpy's ``keepdims`` keyword, translating it to paddle's
    ``keepdim``. All other arguments pass through untouched."""

    # fix: preserve fn's name/doc, consistent with paddle_split_wrap
    @functools.wraps(fn)
    def numpy_like(*args, **kwargs):
        keepdims = kwargs.pop("keepdims", None)
        if keepdims is not None:
            kwargs["keepdim"] = keepdims
        return fn(*args, **kwargs)

    return numpy_like
2290+
2291+
2292+
def paddle_split_wrap(fn):
    """Wrap paddle's ``split`` to accept numpy-style ``indices_or_sections``.

    paddle's ``split`` takes either an integer number of equal sections or
    the *size* of each section, whereas numpy's split functions take the
    *locations* to split at - convert locations into sizes by taking their
    successive differences.
    """
    # paddle doesn't seem to always have `tensor_split`

    @functools.wraps(fn)
    def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
        if isinstance(indices_or_sections, int):
            # an integer number of sections is supported directly
            return fn(ary, indices_or_sections, axis=axis, **kwargs)
        # split locations -> section sizes: differences padded with the
        # start (0) and the full extent of ``axis``
        sizes = do(
            "diff",
            indices_or_sections,
            prepend=0,
            append=shape(ary)[axis],
            like="numpy",
        )
        # fix: forward **kwargs here too (previously silently dropped)
        return fn(ary, list(sizes), axis=axis, **kwargs)

    return numpy_like
2311+
2312+
# alternate alias so some functions can get a differently-wrapped version
_MODULE_ALIASES["paddle[alt]"] = "paddle"

# custom implementations defined above
_FUNCS["paddle", "to_numpy"] = paddle_to_numpy
_FUNCS["paddle", "transpose"] = paddle_transpose
_FUNCS["paddle", "real"] = paddle_real
_FUNCS["paddle", "imag"] = paddle_imag
_FUNCS["paddle", "indices"] = paddle_indices
_FUNCS["paddle", "ravel"] = paddle_ravel
_FUNCS["paddle", "pad"] = paddle_pad

# numpy name -> paddle name translations
_FUNC_ALIASES["paddle", "random.normal"] = "randn"
_FUNC_ALIASES["paddle", "random.uniform"] = "rand"
_FUNC_ALIASES["paddle", "asarray"] = "to_tensor"
_FUNC_ALIASES["paddle", "concatenate"] = "concat"
_FUNC_ALIASES["paddle", "power"] = "pow"
_FUNC_ALIASES["paddle", "identity"] = "eye"
_FUNC_ALIASES["paddle", "split"] = "tensor_split"

# the random sampling functions live at paddle's top level, not in a
# ``random`` submodule
_SUBMODULE_ALIASES["paddle", "random.normal"] = "paddle"
_SUBMODULE_ALIASES["paddle", "random.uniform"] = "paddle"

# wrappers translating numpy-style signatures to paddle's
_CUSTOM_WRAPPERS["paddle", "random.normal"] = scale_random_normal_manually
_CUSTOM_WRAPPERS["paddle", "random.uniform"] = scale_random_uniform_manually
_CUSTOM_WRAPPERS["paddle[alt]", "split"] = paddle_split_wrap
_CUSTOM_WRAPPERS["paddle", "tril"] = make_translator(
    [("m", ("x",)), ("k", ("diagonal", 0))]
)
_CUSTOM_WRAPPERS["paddle", "triu"] = make_translator(
    [("m", ("x",)), ("k", ("diagonal", 0))]
)
# reductions all need numpy's ``keepdims`` mapped to paddle's ``keepdim``
for f in ("sum", "max", "min", "prod", "mean", "std", "var"):
    _CUSTOM_WRAPPERS["paddle", f] = paddle_wrap_reduction

docs/index.md

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ write backend agnostic code that works for:
2020
* [autograd](https://github.com/HIPS/autograd)
2121
* [tensorflow](https://github.com/tensorflow/tensorflow)
2222
* [sparse](https://sparse.pydata.org/)
23+
* [paddle](https://github.com/paddlepaddle/paddle)
2324
* [mars](https://github.com/mars-project/mars)
2425
* ... and indeed **any** library that provides a numpy-*ish* api, even if it
2526
knows nothing about `autoray`.

0 commit comments

Comments
 (0)