Skip to content

Mark jax.abstract_arrays as deprecated. #16271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.12

* Deprecations
* The following APIs have been removed after a 3 month deprecation period, in
accordance with the {ref}`api-compatibility` policy:
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.
* `jax.abstract_arrays` and its contents are now deprecated. See related
functionality in :mod:`jax.core`.
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.

## jaxlib 0.4.12

Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/How_JAX_primitives_work.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@
}
],
"source": [
"from jax._src import abstract_arrays\n",
"from jax import core\n",
"@trace(\"multiply_add_abstract_eval\")\n",
"def multiply_add_abstract_eval(xs, ys, zs):\n",
" \"\"\"Abstract evaluation of the primitive.\n",
Expand All @@ -533,7 +533,7 @@
" \"\"\"\n",
" assert xs.shape == ys.shape\n",
" assert xs.shape == zs.shape\n",
" return abstract_arrays.ShapedArray(xs.shape, xs.dtype)\n",
" return core.ShapedArray(xs.shape, xs.dtype)\n",
"\n",
"# Now we register the abstract evaluation with JAX\n",
"multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)"
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/How_JAX_primitives_work.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ In the latter case, JAX uses the actual concrete value wrapped as an abstract va
:id: ctQmEeckIbdo
:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2

from jax._src import abstract_arrays
from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
"""Abstract evaluation of the primitive.
Expand All @@ -322,7 +322,7 @@ def multiply_add_abstract_eval(xs, ys, zs):
"""
assert xs.shape == ys.shape
assert xs.shape == zs.shape
return abstract_arrays.ShapedArray(xs.shape, xs.dtype)
return core.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
Expand Down
8 changes: 7 additions & 1 deletion jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@

# These submodules are separate because they are in an import cycle with
# jax and rely on the names imported above.
from jax import abstract_arrays as abstract_arrays
from jax import abstract_arrays as _deprecated_abstract_arrays
from jax import custom_derivatives as custom_derivatives
from jax import custom_batching as custom_batching
from jax import custom_transpose as custom_transpose
Expand Down Expand Up @@ -186,6 +186,11 @@
del _ccache

_deprecations = {
# Added 06 June 2023
"abstract_arrays": (
"jax.abstract_arrays is deprecated. Refer to jax.core.",
_deprecated_abstract_arrays
),
# Added 28 March 2023
"ShapedArray": (
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
Expand Down Expand Up @@ -219,6 +224,7 @@

import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src import abstract_arrays as abstract_arrays
from jax._src.core import ShapedArray as ShapedArray
from jax.interpreters import ad as ad
from jax.interpreters import partial_eval as partial_eval
Expand Down
33 changes: 30 additions & 3 deletions jax/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,35 @@

# TODO(phawkins): fix users of these aliases and delete this file.

from jax._src.abstract_arrays import array_types
from jax._src.abstract_arrays import array_types as _deprecated_array_types
from jax._src.core import (
ShapedArray,
raise_to_shaped,
ShapedArray as _deprecated_ShapedArray,
raise_to_shaped as _deprecated_raise_to_shaped,
)

_deprecations = {
# Added 06 June 2023
"array_types": (
"jax.abstract_arrays.array_types is deprecated.",
_deprecated_array_types,
),
"ShapedArray": (
"jax.abstract_arrays.ShapedArray is deprecated. Use jax.core.ShapedArray.",
_deprecated_ShapedArray,
),
"raise_to_shaped": (
"jax.abstract_arrays.raise_to_shaped is deprecated. Use jax.core.raise_to_shaped.",
_deprecated_raise_to_shaped,
),
}

import typing
if typing.TYPE_CHECKING:
from jax._src.abstract_arrays import array_types as array_types
from jax._src.core import ShapedArray as ShapedArray
from jax._src.core import raise_to_shaped as raise_to_shaped
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing