Skip to content

Commit 039264d

Browse files
(Array API support): Add __array_namespace_info__ and device (#101)
Co-authored-by: Christian Bourjau <[email protected]>
1 parent d9802b0 commit 039264d

15 files changed

+172
-39
lines changed

.gitmodules

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
[submodule "array-api-tests"]
2-
path = api-coverage-tests
1+
[submodule "api-coverage-tests"]
2+
path = array-api-tests
33
url = [email protected]:data-apis/array-api-tests.git

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ repos:
4848
language: system
4949
types: [python]
5050
require_serial: true
51-
exclude: ^(tests|api-coverage-tests)/
51+
exclude: ^(tests|array-api-tests)/
5252
# prettier
5353
- id: prettier
5454
name: prettier

CHANGELOG.rst

+6
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@ Changelog
1010
0.10.0 (unreleased)
1111
-------------------
1212

13+
**Breaking change**
1314
- Removed the deprecated :func:`ndonnx.promote_nullable` function. Use :func:`ndonnx.additional.make_nullable` instead.
1415

16+
**Array API compliance**
17+
18+
- ndonnx now supports the :func:`ndonnx.__array_namespace_info__` function from the Array API standard.
19+
- Arrays now expose the :meth:`ndonnx.Array.device` property to improve Array API compatibility. Note that serializing an ONNX model inherently postpones device placement decisions to the runtime so currently one abstract device is supported.
20+
1521

1622
0.9.3 (2024-10-25)
1723
------------------

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ pytest tests -n auto
4343

4444
It has a couple of key features:
4545

46-
- It implements the [`Array API`](https://data-apis.org/array-api/) standard. Standard compliant code can be executed without changes across numerous backends such as like `NumPy`, `JAX` and now `ndonnx`.
46+
- It implements the [`Array API`](https://data-apis.org/array-api/) standard. Standard compliant code can be executed without changes across numerous backends such as like NumPy, JAX and now ndonnx.
4747

4848
```python
4949
import numpy as np
@@ -93,7 +93,7 @@ In the future we will be enabling a stable API for an extensible data type syste
9393

9494
## Array API coverage
9595

96-
Array API compatibility is tracked in `api-coverage-tests`. Missing coverage is tracked in the `skips.txt` file. Contributions are welcome!
96+
Array API compatibility is tracked in `array-api-tests`. Missing coverage is tracked in the `skips.txt` file. Contributions are welcome!
9797

9898
Summary(1119 total):
9999

api-coverage-tests

-1
This file was deleted.

array-api-tests

Submodule array-api-tests added at dad7731

ndonnx/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@
161161
take,
162162
UnsupportedOperationError,
163163
)
164+
from ._info import __array_namespace_info__
164165
from ._constants import (
165166
e,
166167
inf,
@@ -176,6 +177,7 @@
176177

177178

178179
__all__ = [
180+
"__array_namespace_info__",
179181
"Array",
180182
"array",
181183
"from_spox_var",

ndonnx/_array.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
# Copyright (c) QuantCo 2023-2024
1+
# Copyright (c) QuantCo 2023-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from __future__ import annotations
55

66
import typing
77
from collections.abc import Callable
8-
from typing import Union
8+
from typing import Any, Union
99

1010
import numpy as np
1111
import spox.opset.ai.onnx.v19 as op
@@ -254,6 +254,19 @@ def shape(self) -> tuple[int | None, ...]:
254254
else:
255255
return static_shape(self)
256256

257+
@property
258+
def device(self):
259+
return device
260+
261+
def to_device(
262+
self, device: _Device, /, *, stream: int | Any | None = None
263+
) -> Array:
264+
if device != self.device:
265+
raise ValueError("Cannot move Array to a different device")
266+
if stream is not None:
267+
raise ValueError("The 'stream' parameter is not supported in ndonnx.")
268+
return self.copy()
269+
257270
@property
258271
def values(self) -> Array:
259272
"""Accessor for data in a ``Array`` with nullable datatype."""
@@ -579,7 +592,23 @@ def any(self, axis: int | None = 0, keepdims: bool | None = False) -> ndx.Array:
579592
return ndx.any(self, axis=axis, keepdims=False)
580593

581594

595+
class _Device:
596+
# We would rather not give users the impression that their arrays
597+
# are tied to a specific device when serializing an ONNX graph as
598+
# such a concept does not exist in the ONNX standard.
599+
600+
def __str__(self):
601+
return "ndonnx device"
602+
603+
def __eq__(self, other):
604+
return type(other) is _Device
605+
606+
607+
device = _Device()
608+
609+
582610
__all__ = [
583611
"Array",
584612
"array",
613+
"device",
585614
]

ndonnx/_data_types/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) QuantCo 2023-2024
1+
# Copyright (c) QuantCo 2023-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from __future__ import annotations
@@ -28,6 +28,8 @@
2828
uint32,
2929
uint64,
3030
utf8,
31+
canonical_name,
32+
kinds,
3133
)
3234
from .classes import (
3335
Floating,
@@ -145,4 +147,6 @@ def into_nullable(dtype: StructType | CoreType) -> NullableCore:
145147
"CastMixin",
146148
"CastError",
147149
"Dtype",
150+
"canonical_name",
151+
"kinds",
148152
]

ndonnx/_data_types/aliases.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1-
# Copyright (c) QuantCo 2023-2024
1+
# Copyright (c) QuantCo 2023-2025
22
# SPDX-License-Identifier: BSD-3-Clause
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
if TYPE_CHECKING:
8+
from ndonnx import CoreType
39

410
from .classes import (
511
Boolean,
@@ -55,3 +61,52 @@
5561
nuint32: NUInt32 = NUInt32()
5662
nuint64: NUInt64 = NUInt64()
5763
nutf8: NUtf8 = NUtf8()
64+
65+
66+
_canonical_names = {
67+
bool: "bool",
68+
float32: "float32",
69+
float64: "float64",
70+
int8: "int8",
71+
int16: "int16",
72+
int32: "int32",
73+
int64: "int64",
74+
uint8: "uint8",
75+
uint16: "uint16",
76+
uint32: "uint32",
77+
uint64: "uint64",
78+
utf8: "utf8",
79+
}
80+
81+
82+
def canonical_name(dtype: CoreType) -> str:
83+
"""Return the canonical name of the data type."""
84+
if dtype in _canonical_names:
85+
return _canonical_names[dtype]
86+
else:
87+
raise ValueError(f"Unknown data type: {dtype}")
88+
89+
90+
_kinds = {
91+
bool: ("bool",),
92+
int8: ("signed integer", "integer", "numeric"),
93+
int16: ("signed integer", "integer", "numeric"),
94+
int32: ("signed integer", "integer", "numeric"),
95+
int64: ("signed integer", "integer", "numeric"),
96+
uint8: ("unsigned integer", "integer", "numeric"),
97+
uint16: ("unsigned integer", "integer", "numeric"),
98+
uint32: ("unsigned integer", "integer", "numeric"),
99+
uint64: ("unsigned integer", "integer", "numeric"),
100+
float32: ("floating", "numeric"),
101+
float64: ("floating", "numeric"),
102+
}
103+
104+
105+
def kinds(dtype: CoreType) -> tuple[str, ...]:
106+
"""Return the kinds of the data type."""
107+
if dtype in _kinds:
108+
return _kinds[dtype]
109+
elif dtype == utf8:
110+
raise ValueError(f"We don't yet define a kind for {dtype}")
111+
else:
112+
raise ValueError(f"Unknown data type: {dtype}")

ndonnx/_funcs.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -267,20 +267,7 @@ def iinfo(dtype):
267267

268268
def isdtype(dtype, kind) -> bool:
269269
if isinstance(kind, str):
270-
if kind == "bool":
271-
return dtype == dtypes.bool
272-
elif kind == "signed integer":
273-
return dtype in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
274-
elif kind == "unsigned integer":
275-
return dtype in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
276-
elif kind == "integral":
277-
return isinstance(dtype, dtypes.Integral)
278-
elif kind == "real floating":
279-
return isinstance(dtype, dtypes.Floating)
280-
elif kind == "complex floating":
281-
raise ValueError("'complex floating' is not supported")
282-
elif kind == "numeric":
283-
return isinstance(dtype, dtypes.Numerical)
270+
return kind in dtypes.kinds(dtype)
284271
elif isinstance(kind, dtypes.CoreType):
285272
return dtype == kind
286273
elif isinstance(kind, tuple):

ndonnx/_info.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) QuantCo 2023-2025
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
from __future__ import annotations
5+
6+
import ndonnx as ndx
7+
from ndonnx._array import _Device, device
8+
from ndonnx._data_types import canonical_name
9+
10+
11+
class ArrayNamespaceInfo:
12+
"""Namespace metadata for the Array API standard."""
13+
14+
_all_array_api_types = [
15+
ndx.bool,
16+
ndx.float32,
17+
ndx.float64,
18+
ndx.int8,
19+
ndx.int16,
20+
ndx.int32,
21+
ndx.int64,
22+
ndx.uint8,
23+
ndx.uint16,
24+
ndx.uint32,
25+
ndx.uint64,
26+
]
27+
28+
def capabilities(self) -> dict[str, bool]:
29+
return {
30+
"boolean indexing": True,
31+
"data-dependent shapes": True,
32+
}
33+
34+
def default_device(self) -> _Device:
35+
return device
36+
37+
def devices(self) -> list[_Device]:
38+
return [device]
39+
40+
def dtypes(
41+
self, *, device=None, kind: str | tuple[str, ...] | None = None
42+
) -> dict[str, ndx.CoreType]:
43+
out: dict[str, ndx.CoreType] = {}
44+
for dtype in self._all_array_api_types:
45+
if kind is None or ndx.isdtype(dtype, kind):
46+
out[canonical_name(dtype)] = dtype
47+
return out
48+
49+
def default_dtypes(
50+
self,
51+
*,
52+
device=None,
53+
) -> dict[str, ndx.CoreType]:
54+
return {
55+
"real floating": ndx.float64,
56+
"integral": ndx.int64,
57+
"indexing": ndx.int64,
58+
}
59+
60+
61+
def __array_namespace_info__() -> ArrayNamespaceInfo: # noqa: N807
62+
return ArrayNamespaceInfo()

pixi.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ test = "pytest"
4949
test-coverage = "pytest --cov=ndonnx --cov-report=xml --cov-report=term-missing"
5050

5151
[feature.test.tasks.arrayapitests]
52-
cmd = "pytest api-coverage-tests/array_api_tests/ -v -rfX --json-report --json-report-file=api-coverage-tests.json -n auto --disable-deadline --disable-extension linalg --skips-file=skips.txt --xfails-file=xfails.txt"
52+
cmd = "python -m pytest array-api-tests/array_api_tests/ -v --disable-extension linalg --disable-deadline --skips-file=skips.txt --xfails-file=xfails.txt --json-report --json-report-file=api-coverage-tests.json -nauto"
5353
[feature.test.tasks.arrayapitests.env]
5454
ARRAY_API_TESTS_MODULE = "ndonnx"
5555
ARRAY_API_TESTS_VERSION = "2023.12"

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,14 @@ indent-style = "space"
7777
python_version = '3.10'
7878
no_implicit_optional = true
7979
check_untyped_defs = true
80-
exclude = ["api-coverage-tests", "tests"]
80+
exclude = ["array-api-tests", "tests"]
8181

8282
[[tool.mypy.overrides]]
8383
module = ["onnxruntime"]
8484
ignore_missing_imports = true
8585

8686
[tool.pytest.ini_options]
87-
addopts = "--ignore=api-coverage-tests"
87+
addopts = "--ignore=array-api-tests"
8888
filterwarnings = ["ignore:.*google.protobuf.pyext.*:DeprecationWarning"]
8989

9090
[tool.typos.default]

xfails.txt

-12
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@ array_api_tests/test_creation_functions.py::test_eye
77
array_api_tests/test_creation_functions.py::test_meshgrid
88
array_api_tests/test_data_type_functions.py::test_can_cast
99
array_api_tests/test_data_type_functions.py::test_isdtype
10-
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
1110
array_api_tests/test_has_names.py::test_has_names[array_method-__complex__]
1211
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
1312
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
14-
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
1513
array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
1614
array_api_tests/test_has_names.py::test_has_names[creation-meshgrid]
1715
array_api_tests/test_has_names.py::test_has_names[elementwise-conj]
@@ -36,7 +34,6 @@ array_api_tests/test_has_names.py::test_has_names[fft-irfftn]
3634
array_api_tests/test_has_names.py::test_has_names[fft-rfft]
3735
array_api_tests/test_has_names.py::test_has_names[fft-rfftfreq]
3836
array_api_tests/test_has_names.py::test_has_names[fft-rfftn]
39-
array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__]
4037
array_api_tests/test_has_names.py::test_has_names[linalg-cholesky]
4138
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
4239
array_api_tests/test_has_names.py::test_has_names[linalg-det]
@@ -65,8 +62,6 @@ array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot]
6562
array_api_tests/test_has_names.py::test_has_names[manipulation-moveaxis]
6663
array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
6764
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
68-
array_api_tests/test_inspection_functions.py::test_array_namespace_info
69-
array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes
7065
array_api_tests/test_linalg.py::test_matrix_transpose
7166
array_api_tests/test_linalg.py::test_vecdot
7267
array_api_tests/test_manipulation_functions.py::test_moveaxis
@@ -101,8 +96,6 @@ array_api_tests/test_set_functions.py::test_unique_values
10196
array_api_tests/test_signatures.py::test_array_method_signature[__complex__]
10297
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
10398
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
104-
array_api_tests/test_signatures.py::test_array_method_signature[to_device]
105-
array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__]
10699
array_api_tests/test_signatures.py::test_func_signature[astype]
107100
array_api_tests/test_signatures.py::test_func_signature[conj]
108101
array_api_tests/test_signatures.py::test_func_signature[copysign]
@@ -118,11 +111,6 @@ array_api_tests/test_signatures.py::test_func_signature[signbit]
118111
array_api_tests/test_signatures.py::test_func_signature[tile]
119112
array_api_tests/test_signatures.py::test_func_signature[unstack]
120113
array_api_tests/test_signatures.py::test_func_signature[vecdot]
121-
array_api_tests/test_signatures.py::test_info_func_signature[capabilities]
122-
array_api_tests/test_signatures.py::test_info_func_signature[default_device]
123-
array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes]
124-
array_api_tests/test_signatures.py::test_info_func_signature[devices]
125-
array_api_tests/test_signatures.py::test_info_func_signature[dtypes]
126114
array_api_tests/test_sorting_functions.py::test_argsort
127115
array_api_tests/test_special_cases.py::test_binary[copysign(x1_i is NaN and x2_i < 0) -> NaN]
128116
array_api_tests/test_special_cases.py::test_binary[copysign(x1_i is NaN and x2_i > 0) -> NaN]

0 commit comments

Comments
 (0)