Skip to content

Commit 221c7ce

Browse files
committed
Mark jax.abstract_arrays as deprecated.
1 parent 3ba308d commit 221c7ce

File tree

6 files changed

+44
-19
lines changed

6 files changed

+44
-19
lines changed

CHANGELOG.md

+10-10
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@ Remember to align the itemized text with the first line of an item within a list
99
## jax 0.4.12
1010

1111
* Deprecations
12-
* The following APIs have been removed after a 3 month deprecation period, in
13-
accordance with the {ref}`api-compatibility` policy:
14-
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
15-
of `numpy.alltrue` in NumPy version 1.25.0.
16-
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
17-
of `numpy.sometrue` in NumPy version 1.25.0.
18-
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
19-
of `numpy.product` in NumPy version 1.25.0.
20-
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
21-
of `numpy.cumproduct` in NumPy version 1.25.0.
12+
* `jax.abstract_arrays` and its contents are now deprecated. See related
13+
functionality in :mod:`jax.core`.
14+
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
15+
of `numpy.alltrue` in NumPy version 1.25.0.
16+
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
17+
of `numpy.sometrue` in NumPy version 1.25.0.
18+
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
19+
of `numpy.product` in NumPy version 1.25.0.
20+
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
21+
of `numpy.cumproduct` in NumPy version 1.25.0.
2222

2323
## jaxlib 0.4.12
2424

docs/notebooks/How_JAX_primitives_work.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@
519519
}
520520
],
521521
"source": [
522-
"from jax._src import abstract_arrays\n",
522+
"from jax import core\n",
523523
"@trace(\"multiply_add_abstract_eval\")\n",
524524
"def multiply_add_abstract_eval(xs, ys, zs):\n",
525525
" \"\"\"Abstract evaluation of the primitive.\n",
@@ -533,7 +533,7 @@
533533
" \"\"\"\n",
534534
" assert xs.shape == ys.shape\n",
535535
" assert xs.shape == zs.shape\n",
536-
" return abstract_arrays.ShapedArray(xs.shape, xs.dtype)\n",
536+
" return core.ShapedArray(xs.shape, xs.dtype)\n",
537537
"\n",
538538
"# Now we register the abstract evaluation with JAX\n",
539539
"multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)"

docs/notebooks/How_JAX_primitives_work.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ In the latter case, JAX uses the actual concrete value wrapped as an abstract va
308308
:id: ctQmEeckIbdo
309309
:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2
310310
311-
from jax._src import abstract_arrays
311+
from jax import core
312312
@trace("multiply_add_abstract_eval")
313313
def multiply_add_abstract_eval(xs, ys, zs):
314314
"""Abstract evaluation of the primitive.
@@ -322,7 +322,7 @@ def multiply_add_abstract_eval(xs, ys, zs):
322322
"""
323323
assert xs.shape == ys.shape
324324
assert xs.shape == zs.shape
325-
return abstract_arrays.ShapedArray(xs.shape, xs.dtype)
325+
return core.ShapedArray(xs.shape, xs.dtype)
326326
327327
# Now we register the abstract evaluation with JAX
328328
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

jax/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@
155155

156156
# These submodules are separate because they are in an import cycle with
157157
# jax and rely on the names imported above.
158-
from jax import abstract_arrays as abstract_arrays
158+
from jax import abstract_arrays as _deprecated_abstract_arrays
159159
from jax import custom_derivatives as custom_derivatives
160160
from jax import custom_batching as custom_batching
161161
from jax import custom_transpose as custom_transpose
@@ -186,6 +186,11 @@
186186
del _ccache
187187

188188
_deprecations = {
189+
# Added 06 June 2023
190+
"abstract_arrays": (
191+
"jax.abstract_arrays is deprecated. Refer to jax.core.",
192+
_deprecated_abstract_arrays
193+
),
189194
# Added 28 March 2023
190195
"ShapedArray": (
191196
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
@@ -219,6 +224,7 @@
219224

220225
import typing as _typing
221226
if _typing.TYPE_CHECKING:
227+
from jax._src import abstract_arrays as abstract_arrays
222228
from jax._src.core import ShapedArray as ShapedArray
223229
from jax.interpreters import ad as ad
224230
from jax.interpreters import partial_eval as partial_eval

jax/abstract_arrays.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,28 @@
1414

1515
# TODO(phawkins): fix users of these aliases and delete this file.
1616

17-
from jax._src.abstract_arrays import array_types
17+
from jax._src.abstract_arrays import array_types as _deprecated_array_types
1818
from jax._src.core import (
19-
ShapedArray,
20-
raise_to_shaped,
19+
ShapedArray as _deprecated_ShapedArray,
20+
raise_to_shaped as _deprecated_raise_to_shaped,
2121
)
22+
23+
_deprecations = {
24+
# Added 06 June 2023
25+
"array_types": (
26+
"jax.abstract_arrays.array_types is deprecated.",
27+
_deprecated_array_types,
28+
),
29+
"ShapedArray": (
30+
"jax.abstract_arrays.ShapedArray is deprecated. Use jax.core.ShapedArray.",
31+
_deprecated_ShapedArray,
32+
),
33+
"raise_to_shaped": (
34+
"jax.abstract_arrays.raise_to_shaped is deprecated. Use jax.core.raise_to_shaped.",
35+
_deprecated_raise_to_shaped,
36+
),
37+
}
38+
39+
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
40+
__getattr__ = _deprecation_getattr(__name__, _deprecations)
41+
del _deprecation_getattr

setup.cfg

-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ per-file-ignores =
2121
# F401: unused imports
2222
# Note: we don't use jax/*.py because this matches contents of jax/_src
2323
__init__.py:F401
24-
jax/abstract_arrays.py:F401
2524
jax/ad_checkpoint.py:F401
2625
jax/api_util.py:F401
2726
jax/cloud_tpu_init.py:F401

0 commit comments

Comments
 (0)