Commit b7715e2

Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32. Closes jax-ml#18847
1 parent 433f66a commit b7715e2
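
The core of the change: instead of hard-coding int32, loop indices are now built with JAX's canonical integer dtype, which follows the jax_enable_x64 flag. A minimal standalone sketch of how that dtype resolves (not part of the diff; dtypes is the private jax._src module the patch uses):

import jax.numpy as jnp
from jax._src import dtypes  # private helper module, as used in the patch

# In the default 32-bit mode, int64 is canonicalized down to int32.
print(dtypes.canonicalize_dtype(jnp.int64))  # int32

# With `jax.config.update("jax_enable_x64", True)` set at startup, the same
# call returns int64, so indices built this way widen automatically on x64.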

10 files changed: 65 additions, 85 deletions

jax/_src/core.py

+10 -8

@@ -2076,14 +2076,16 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):
   if handler: return handler(aval, weak_type)
   raise TypeError(type(aval))
 
-raise_to_shaped_mappings : dict[type, Callable] = {
-  AbstractToken: lambda aval, _: aval,
-  Bot: lambda aval, _: aval,
-  UnshapedArray: lambda aval, _: aval,
-  ShapedArray: lambda aval, weak_type: ShapedArray(
-    aval.shape, aval.dtype, weak_type, aval.named_shape),
-  DConcreteArray: lambda aval, weak_type: DShapedArray(
-    aval.shape, aval.dtype, weak_type),
+raise_to_shaped_mappings: dict[type, Callable] = {
+    AbstractToken: lambda aval, _: aval,
+    Bot: lambda aval, _: aval,
+    UnshapedArray: lambda aval, _: aval,
+    ShapedArray: lambda aval, weak_type: ShapedArray(
+        aval.shape, aval.dtype, weak_type, aval.named_shape
+    ),
+    DConcreteArray: lambda aval, weak_type: DShapedArray(
+        aval.shape, aval.dtype, weak_type
+    ),
 }
 
 ### Operations on shapes and dimension sizes.

jax/_src/lax/control_flow/for_loop.py

+4 -4

@@ -132,7 +132,7 @@ def wrapped_body(i, refs):
   nsteps, = nsteps
   flat_state, state_tree = tree_flatten(init_state)
   state_avals = map(state_utils.val_to_ref_aval, flat_state)
-  idx_aval = core.ShapedArray((), jnp.dtype("int32"))
+  idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64))
   jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
       body, state_tree, [idx_aval, *state_avals])
   if out_tree != tree_structure(None):
@@ -251,7 +251,7 @@ def body(i, state):
 
 def _for_impl_unrolled(body, nsteps, unroll, *args):
   remainder = nsteps % unroll
-  i = jnp.int32(0)
+  i = jnp.astype(0, dtypes.canonicalize_dtype(jnp.int64))
   state = list(args)
 
   for _ in range(remainder):
@@ -748,15 +748,15 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
   """
   flat_state, state_tree = tree_flatten(init_state)
   state_avals = map(state_utils.val_to_ref_aval, flat_state)
-  idx_aval = core.ShapedArray((), jnp.dtype("int32"))
+  idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64))
   jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
       body, state_tree, [idx_aval, *state_avals])
   if out_tree != tree_structure(None):
     raise Exception("`body` should not return anything.")
   discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
 
   def fori_body(i, carry):
-    i = jnp.int32(i)
+    i = jnp.astype(i, dtypes.canonicalize_dtype(jnp.int64))
     if reverse:
       i = nsteps - i - 1
     out_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts,
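
For orientation, discharged_for_loop lowers to a fori_loop whose body casts the index to the canonical dtype and optionally reverses it. A rough sketch of that pattern in isolation (run_loop, body_fn, nsteps, and carry0 are illustrative names, not from the patch):

import jax
import jax.numpy as jnp
from jax._src import dtypes

def run_loop(nsteps, body_fn, carry0, reverse=False):
  idx_dtype = dtypes.canonicalize_dtype(jnp.int64)  # int32, or int64 under x64

  def fori_body(i, carry):
    i = jnp.astype(i, idx_dtype)  # keep the index in the canonical dtype
    if reverse:
      i = nsteps - i - 1          # iterate from the back, as in fori_body above
    return body_fn(i, carry)

  return jax.lax.fori_loop(0, nsteps, fori_body, carry0)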

jax/_src/pallas/pallas_call.py

+2 -2

@@ -274,7 +274,7 @@ def body(carry):
         len(blocks),
         len(scratch_values),
     )
-    blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
+    blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
                                  *blocks, *scratch)
    blocks = blocks[grid_mapping.num_index_operands:]
    blocks, out_scratch = split_list(blocks, [num_inout])
@@ -787,7 +787,7 @@ def checked_kernel_fn(*args):
     # errors before other arguments.
     jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
     assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args)
-    result_flat = jax.core.eval_jaxpr(
+    result_flat = jax_core.eval_jaxpr(
         checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
     output_errors, _ = split_list(result_flat, [num_err_vals])
     # Store new errors back in the error refs.
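
The only change in this file is the alias: the public jax.core shim is swapped for the internal jax._src.core module (imported as jax_core). For reference, a hedged reminder of the eval_jaxpr calling convention both hunks use, shown on a toy jaxpr rather than the Pallas one:

import jax
import jax.numpy as jnp
from jax._src import core as jax_core  # same alias pallas_call.py already has

closed = jax.make_jaxpr(lambda x: x * 2.0 + 1.0)(jnp.float32(3.0))
# eval_jaxpr interprets a jaxpr directly: constants first, then flat arguments.
out, = jax_core.eval_jaxpr(closed.jaxpr, closed.consts, jnp.float32(3.0))
print(out)  # 7.0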

jax/_src/pallas/primitives.py

+7 -1

@@ -26,6 +26,7 @@
 from jax import tree_util
 from jax._src import ad_util
 from jax._src import core as jax_core
+from jax._src import dtypes
 from jax._src import effects
 from jax._src import pretty_printer as pp
 from jax._src import state
@@ -359,7 +360,12 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
     # of bounds, it will instead move the start_index backwards so the slice
     # will fit in memory.
     ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
-    out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
+    idx_dtype = dtypes.canonicalize_dtype(jnp.int64)
+    out_ones = lax.dynamic_slice(
+        ref,
+        [jnp.astype(s, idx_dtype) for s in slice_starts],
+        slice_sizes=slice_sizes,
+    )
     out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
     out = out_ones[out_indexer]
   elif all(not isinstance(s, Slice) for s in idx.indices):
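
lax.dynamic_slice expects its start indices to share one integer dtype, which can break when Python ints (weakly typed, canonicalized to int64 under x64) are mixed with traced int32 indices; the rule therefore casts every start to a single canonical dtype. A standalone sketch of the same pattern on a plain array (the array and starts below are made up for illustration):

import jax.numpy as jnp
from jax import lax
from jax._src import dtypes

x = jnp.arange(16.0).reshape(4, 4)
starts = (jnp.int32(1), 2)  # mixed: an int32 array and a Python int

idx_dtype = dtypes.canonicalize_dtype(jnp.int64)
window = lax.dynamic_slice(
    x,
    [jnp.astype(s, idx_dtype) for s in starts],  # uniform start-index dtype
    slice_sizes=(2, 2),
)
print(window.shape)  # (2, 2)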

jax/_src/pallas/triton/BUILD

+1

@@ -52,6 +52,7 @@ pytype_strict_library(
         "//jax",
         "//jax:ad_util",
         "//jax:api_util",
+        "//jax:config",
         "//jax:core",
         "//jax:mlir",
         "//jax:partial_eval",

jax/_src/pallas/triton/lowering.py

+5 -3

@@ -29,6 +29,7 @@
 from jax._src import ad_checkpoint
 from jax._src import ad_util
 from jax._src import api_util
+from jax._src import config
 from jax._src import core as jax_core
 from jax._src import custom_derivatives
 from jax._src import linear_util as lu
@@ -2263,9 +2264,10 @@ def _for_lowering_rule(
   del which_linear
   if reverse or unroll != 1:
     raise NotImplementedError
-  lower_bound = _i32_constant(0)
-  upper_bound = _i32_constant(nsteps)
-  step = _i32_constant(1)
+  _i_constant = _i64_constant if config.enable_x64.value else _i32_constant
+  lower_bound = _i_constant(0)
+  upper_bound = _i_constant(nsteps)
+  step = _i_constant(1)
   init_args = map(_ensure_ir_value, args, ctx.avals_in)
   # Partially discharge state from jaxpr for non-pointers
   should_discharge = [
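
The lowering now checks the x64 flag at lowering time and emits 64-bit loop-bound constants when it is set. A minimal sketch of that selection in isolation (the two constant builders are hypothetical stand-ins for the file's _i32_constant/_i64_constant helpers):

from jax._src import config

def pick_index_constant(value, i32_constant, i64_constant):
  # Match the loop index width to the mode: 64-bit constants under
  # jax_enable_x64, 32-bit constants otherwise.
  make = i64_constant if config.enable_x64.value else i32_constant
  return make(value)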

jax/_src/state/primitives.py

+1 -1

@@ -190,7 +190,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef,
                     f"Expected shape: {expected_out_shape}. "
                     f"Value shape: {val_aval.shape}. "
                     f"Indices: {indexers}. ")
-  if ref_aval.dtype != val_aval.dtype:
+  if ref_aval.dtype != val_aval.dtype and not val_aval.weak_type:
     raise ValueError("Invalid dtype for `swap`. "
                      f"Ref dtype: {ref_aval.dtype}. "
                      f"Value dtype: {val_aval.dtype}. ")
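
The relaxed condition lets weakly typed values, i.e. ones whose dtype is only the Python-literal default, be swapped into a Ref of a different dtype without tripping the error, while explicit dtype mismatches still fail. A quick illustration of what weak_type means on ordinary arrays:

import jax.numpy as jnp

# Values built from Python literals carry only a default ("weak") dtype.
print(jnp.asarray(2.0).weak_type)               # True: dtype is just a default
print(jnp.asarray(jnp.float32(2.0)).weak_type)  # False: explicitly float32
print((jnp.float32(1.0) + 2.0).dtype)           # float32: the weak value adapts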

jax/_src/state/types.py

+6

@@ -132,6 +132,12 @@ class AbstractRef(core.AbstractValue):
   def __init__(self, inner_aval: core.AbstractValue):
     self.inner_aval = inner_aval
 
+  @property
+  def weak_type(self) -> bool:
+    if not hasattr(self.inner_aval, "weak_type"):
+      raise AttributeError
+    return self.inner_aval.weak_type
+
   def update(self, inner_aval=None):
     if inner_aval is None:
       return AbstractRef(self.inner_aval)
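
The new property forwards weak_type from the wrapped inner aval so refs can answer the same dtype questions as ordinary avals; raising AttributeError inside the property keeps hasattr() returning False when the inner aval has no notion of weak typing. A toy sketch of that forwarding pattern (the classes here are illustrative stand-ins, not the real avals):

class Inner:
  weak_type = True  # stands in for a ShapedArray-like aval

class RefLike:
  def __init__(self, inner_aval):
    self.inner_aval = inner_aval

  @property
  def weak_type(self) -> bool:
    # Raising AttributeError from a property makes hasattr() report False
    # when the wrapped aval has no weak_type of its own.
    if not hasattr(self.inner_aval, "weak_type"):
      raise AttributeError
    return self.inner_aval.weak_type

print(hasattr(RefLike(Inner()), "weak_type"))   # True
print(hasattr(RefLike(object()), "weak_type"))  # False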

tests/pallas/BUILD

-2

@@ -41,8 +41,6 @@ jax_test(
     disable_configs = [
         "gpu",
         "gpu_x32",
-        "gpu_a100",
-        "gpu_h100",
         "gpu_p100",
         "gpu_p100_x32",
     ],
