Skip to content

Commit cd63d14

Browse files
Jake VanderPlasThe oryx Authors
Jake VanderPlas
authored and
The oryx Authors
committed
Remove references to deprecated submodule jax.abstract_arrays
Use jax.core instead (see jax-ml/jax#16271) PiperOrigin-RevId: 538767146
1 parent 92b0585 commit cd63d14

File tree

6 files changed

+11
-17
lines changed

6 files changed

+11
-17
lines changed

oryx/core/interpreters/harvest.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ def f(x):
138138
import functools
139139
from typing import Any, Callable, Dict, FrozenSet, Hashable, Iterable, List, Optional, Tuple, Union
140140

141-
from jax import abstract_arrays
142141
from jax import api_util
143142
from jax import lax
144143
from jax import linear_util as lu
@@ -419,7 +418,7 @@ def __init__(self, trace: 'HarvestTrace', val: Value):
419418

420419
@property
421420
def aval(self):
422-
return abstract_arrays.raise_to_shaped(jax_core.get_aval(self.val))
421+
return jax_core.raise_to_shaped(jax_core.get_aval(self.val))
423422

424423
def full_lower(self):
425424
return self
@@ -512,7 +511,7 @@ def handle_sow(self, *values, name, tag, tree, mode):
512511
raise ValueError(f'Variable has already been reaped: {name}')
513512
avals = tree_util.tree_unflatten(
514513
tree,
515-
[abstract_arrays.raise_to_shaped(jax_core.get_aval(v)) for v in values])
514+
[jax_core.raise_to_shaped(jax_core.get_aval(v)) for v in values])
516515
self.reaps[name] = Reap(
517516
tree_util.tree_unflatten(tree, values), dict(mode=mode, aval=avals))
518517
return values
@@ -781,7 +780,7 @@ def _get_harvest_metadata(closed_jaxpr, settings, *args):
781780
flat_args, in_tree = tree_util.tree_flatten(args)
782781
flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
783782
in_avals = jax_util.safe_map(
784-
lambda a: abstract_arrays.raise_to_shaped(jax_core.get_aval(a)),
783+
lambda a: jax_core.raise_to_shaped(jax_core.get_aval(a)),
785784
flat_args)
786785
pe.trace_to_jaxpr_final(flat_fun, in_avals)
787786
metadata = aux()

oryx/core/interpreters/inverse/core.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from typing import Iterable
1818

1919
import jax
20-
from jax import abstract_arrays
2120
from jax import tree_util
2221
from jax import util as jax_util
2322
from jax._src import core as jax_core
@@ -142,7 +141,7 @@ def unknown(cls, aval):
142141
def new(cls, val):
143142
val = np.array(val)
144143
aval = jax_core.get_aval(val)
145-
aval = abstract_arrays.raise_to_shaped(aval)
144+
aval = jax_core.raise_to_shaped(aval)
146145
ndslice = NDSlice.new(val, np.zeros_like(val))
147146
return InverseAndILDJ(aval, frozenset([ndslice]))
148147

@@ -319,8 +318,8 @@ def map_ildj(prim, incells, outcells, **params):
319318
f, incells = incells[0], incells[1:]
320319

321320
def slice_aval(aval):
322-
return abstract_arrays.ShapedArray(aval.shape[1:], aval.dtype,
323-
aval.weak_type)
321+
return jax_core.ShapedArray(aval.shape[1:], aval.dtype,
322+
aval.weak_type)
324323

325324
def add_slice(cell, old_cell):
326325
new_slices = [

oryx/core/ppl/effect_handler_test.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Tests for oryx.core.ppl.effect_handler."""
1616
from absl.testing import absltest
1717
import jax
18-
from jax import abstract_arrays
1918
from jax import random
2019
import jax.numpy as np
2120

@@ -40,7 +39,7 @@ def _random_normal_impl(key, loc, scale):
4039
@random_normal_p.def_abstract_eval
4140
def _random_normal_abstract(key, loc, scale):
4241
del key, loc, scale
43-
return [abstract_arrays.ShapedArray((), np.float32)]
42+
return [jax.core.ShapedArray((), np.float32)]
4443

4544

4645
class EffectHandlerTest(test_util.TestCase):

oryx/core/ppl/transformations_test.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from absl.testing import absltest
1717

1818
import jax
19-
from jax import abstract_arrays
2019
from jax import random
2120
from jax._src import core as jax_core
2221
from jax.interpreters import batching
@@ -61,7 +60,7 @@ def random_normal_impl(rng, *, batch_ndims):
6160

6261
def random_normal_abstract(key, **_):
6362
del key
64-
return abstract_arrays.ShapedArray((), jnp.float32)
63+
return jax_core.ShapedArray((), jnp.float32)
6564

6665

6766
def random_normal_log_prob_rule(incells, outcells, *, batch_ndims, **_):

oryx/core/primitive.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import itertools as it
1717
from typing import Callable
1818

19-
from jax import abstract_arrays
2019
from jax import api_util
2120
from jax import linear_util as lu
2221
from jax import tree_util
@@ -237,7 +236,7 @@ def subcall(self, name):
237236
tie_all_p.multiple_results = True
238237
tie_all_p.def_impl(lambda *args: args)
239238
tie_all_p.def_abstract_eval(lambda *args: safe_map( # pylint: disable=g-long-lambda
240-
abstract_arrays.raise_to_shaped, args))
239+
jax_core.raise_to_shaped, args))
241240

242241
mlir.register_lowering(tie_all_p, lambda c, *args: args)
243242

oryx/core/trace_util.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import threading
1818
from typing import Any, Dict, Generator, List
1919

20-
from jax import abstract_arrays
2120
from jax import api_util
2221
from jax import linear_util as lu
2322
from jax import tree_util
@@ -41,9 +40,9 @@
4140
def get_shaped_aval(x):
4241
"""Converts a JAX value type into a shaped abstract value."""
4342
if hasattr(x, 'dtype') and hasattr(x, 'shape'):
44-
return abstract_arrays.ShapedArray(
43+
return jax_core.ShapedArray(
4544
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True))
46-
return abstract_arrays.raise_to_shaped(jax_core.get_aval(x))
45+
return jax_core.raise_to_shaped(jax_core.get_aval(x))
4746

4847

4948
def pv_like(x, abstract=True):

0 commit comments

Comments
 (0)