Get runtime sharding without doing any FLOPs inside nnx.shard_map
#4797
Unanswered
carlesoctav asked this question in General
Replies: 0 comments
I have a modifier/function that basically converts any module into an FSDP module. In short, it adds `gather_params()` and `scatter_params()` hooks before and after `__call__`. By default, a model has no sharding metadata, but I add runtime metadata when `__call__` gets called.
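Roughly, the wrapper does something like this (`gather_params()` / `scatter_params()` are my own helpers, sketched here only as stubs, not Flax APIs):

```python
from flax import nnx

def gather_params(module: nnx.Module):
    ...  # all-gather the FSDP-sharded params before the forward pass

def scatter_params(module: nnx.Module):
    ...  # re-shard the params and attach runtime sharding metadata afterwards

def fsdp_call(module: nnx.Module, *args, **kwargs):
    gather_params(module)
    out = module(*args, **kwargs)  # the original __call__
    scatter_params(module)
    return out
```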
However, I've run into a problem. With Linen I can run `shape = jax.eval_shape(model.init, param_rng, x)` and get the partitioning via `linen.get_partition_spec()`, which returns the `PartitionSpec` for `nn.Partitioned` values. I think this works because `model.init` calls `__call__` at least once, so `scatter_params()` gets called.
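Concretely, the Linen flow I mean looks roughly like this (the `nn.with_partitioning` annotation is only for illustration; in my case the metadata comes from `scatter_params()`):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        kernel_init = nn.with_partitioning(
            nn.initializers.lecun_normal(), ("data", None)
        )
        return nn.Dense(16, kernel_init=kernel_init)(x)

model = MLP()
x = jnp.ones((4, 8))

# No FLOPs are spent here: eval_shape only traces model.init abstractly,
# but the nn.Partitioned boxes survive the trace.
abstract_vars = jax.eval_shape(model.init, jax.random.key(0), x)
specs = nn.get_partition_spec(abstract_vars)  # PyTree of PartitionSpec
```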
Since NNX creates the params when we create the object in `__init__`, it doesn't even call `__call__`. We could simply run the model once, but that would spend FLOPs just to get the runtime `PartitionSpec`.
I tried running `nnx.eval_shape(lambda x: model(x))` and then `nnx.get_partition_spec(nnx.state(model))` to get the partition spec, but I got an error because `PartitionSpec` is not a valid JAX type.
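For reference, what I tried looks roughly like this (with `nnx.Linear` standing in for the actual FSDP-wrapped module; the real module also mutates its state via the hooks above):

```python
from flax import nnx
import jax.numpy as jnp

model = nnx.Linear(8, 16, rngs=nnx.Rngs(0))  # stand-in for the wrapped module
x = jnp.ones((4, 8))

# Abstract trace of the forward pass: no FLOPs are spent, but writing a
# PartitionSpec into the state during the trace is presumably what triggers
# the "not a valid JAX type" error.
_ = nnx.eval_shape(lambda x: model(x), x)
specs = nnx.get_partition_spec(nnx.state(model))
```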
Please note that I need to run this under `nnx.shard_map` because I need `jax.lax` primitives (for example, to find indices and do other operations).

Any ideas on how I can achieve this? Thanks!