@@ -132,7 +132,7 @@ def wrapped_body(i, refs):
132
132
nsteps , = nsteps
133
133
flat_state , state_tree = tree_flatten (init_state )
134
134
state_avals = map (state_utils .val_to_ref_aval , flat_state )
135
- idx_aval = core .ShapedArray ((), jnp . dtype ( "int32" ))
135
+ idx_aval = core .ShapedArray ((), dtypes . canonicalize_dtype ( jnp . int64 ))
136
136
jaxpr , consts , out_tree = _trace_to_jaxpr_with_refs (
137
137
body , state_tree , [idx_aval , * state_avals ])
138
138
if out_tree != tree_structure (None ):
@@ -251,7 +251,7 @@ def body(i, state):
251
251
252
252
def _for_impl_unrolled (body , nsteps , unroll , * args ):
253
253
remainder = nsteps % unroll
254
- i = jnp .int32 ( 0 )
254
+ i = jnp .astype ( 0 , dtypes . canonicalize_dtype ( jnp . int64 ) )
255
255
state = list (args )
256
256
257
257
for _ in range (remainder ):
@@ -748,15 +748,15 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
748
748
"""
749
749
flat_state , state_tree = tree_flatten (init_state )
750
750
state_avals = map (state_utils .val_to_ref_aval , flat_state )
751
- idx_aval = core .ShapedArray ((), jnp . dtype ( "int32" ))
751
+ idx_aval = core .ShapedArray ((), dtypes . canonicalize_dtype ( jnp . int64 ))
752
752
jaxpr , consts , out_tree = _trace_to_jaxpr_with_refs (
753
753
body , state_tree , [idx_aval , * state_avals ])
754
754
if out_tree != tree_structure (None ):
755
755
raise Exception ("`body` should not return anything." )
756
756
discharged_jaxpr , discharged_consts = discharge_state (jaxpr , consts )
757
757
758
758
def fori_body (i , carry ):
759
- i = jnp .int32 ( i )
759
+ i = jnp .astype ( i , dtypes . canonicalize_dtype ( jnp . int64 ) )
760
760
if reverse :
761
761
i = nsteps - i - 1
762
762
out_flat = core .eval_jaxpr (discharged_jaxpr , discharged_consts ,
0 commit comments