Guide for Flax partitioning (auto SPMD) API #2730
Conversation
Check out this pull request on ReviewNB: see visual diffs & provide feedback on Jupyter Notebooks.
@IvyZX the doc build is failing, some problem with the sharding code. Could you please fix it?
docs/guides/spmd_api.ipynb
Outdated
"scaled module output" what does "scaled" mean here? do you mean "sharded" or am I missing something?
Removed "scaled" - hopefully it's clearer now.
docs/guides/spmd_api.ipynb
Outdated
Done.
docs/guides/spmd_api.ipynb
Outdated
"replaced"->"replaces"
We should perhaps note that there are more complicated/advanced sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. For really advanced stuff, people may wish to do fine-grained manual mesh assignments using the simpler system above, but the logical-naming helpers are useful for getting started in exploring different sharding layouts.
Added a new section at the bottom of the logical axis introduction. Let me know if there could be a better phrasing of the idea!
Force-pushed from fb0813e to dba3086.
This seems to be due to JAX dropping Python 3.7. I switched the readthedocs build to 3.8 and the test passes now.
Looks really cool! I reviewed it only partially, but thought I'd send the comments out already.
docs/guides/spmd_api.md
Outdated
```
!pip3 install -qq "git+https://github.com/google/flax.git@main#egg=flax"
```

```python executionInfo={"elapsed": 2, "status": "ok", "timestamp": 1671158206224, "user": {"displayName": "Ivy Zheng", "userId": "15297372265856137303"}, "user_tz": 480} id="FJzB9vfp-OoK"
```
These cells look a bit odd; why is there all this execution info? I think this way they aren't rendered properly and you don't get Python syntax highlighting (see preview here). Maybe just use `python`?
docs/guides/spmd_api.md
Outdated
Each value passed into a sharding annotates a dimension, and its value should be among `data`, `model` or `None`. This refers to whether the data's axis should be sharded across one of the device mesh dimensions, or not sharded at all.
I don't really get this sentence. What does it mean for a value to be passed into a sharding? Are you referring to the pairs like `('data', 'model')` that are passed to `with_partitioning` and `with_sharding_constraints`?

Perhaps it is clearer to give an example and explain one of the calls below in more detail. For instance, I suppose the definition of `W1` means that you define a parameter called `W1`, which has two dimensions `(x.shape[-1], self.depth)`, and the sharding specification means that you don't partition the first axis (I think this means you replicate it on all devices?) and you partition the second axis over the `model` dimension. Is that correct?
Yes, this is correct. I rephrased it to:

> Calling these APIs annotates each dimension of your parameters and intermediate variables with an axis name, namely `'data'`, `'model'` or `None`. This refers to how this dimension should be sharded: across one of the device mesh dimensions, or not sharded at all.
>
> For example, if we define `W1`, which has shape `(x.shape[-1], self.depth)`, with `(None, 'model')`, the first dimension (of length `x.shape[-1]`) will be replicated across all devices, while the second dimension (of length `self.depth`) will be sharded over the `model` axis of the device mesh.
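For readers following the thread, here is a minimal sketch of what that kind of annotation looks like in module code using `nn.with_partitioning` (the layer below is illustrative, not the exact cell from the guide):

```python
import flax.linen as nn
import jax.numpy as jnp

class DotModel(nn.Module):
  depth: int

  @nn.compact
  def __call__(self, x):
    # W1 has shape (x.shape[-1], self.depth). The (None, 'model') annotation
    # replicates the first dimension across all devices and shards the second
    # dimension over the 'model' axis of the device mesh.
    w1 = self.param(
        'W1',
        nn.with_partitioning(nn.initializers.xavier_normal(), (None, 'model')),
        (x.shape[-1], self.depth))
    return jnp.dot(x, w1)
```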
Force-pushed from 36a086f to 3e81c37.
Thank you Marc! I addressed all your comments, either fixing them directly or leaving some comments beneath yours. Please take a look and let me know what you think.
@IvyZX A few suggestions for the first part of the guide. Thanks for putting this together! Amazing doc
docs/guides/spmd_api.md
Outdated
---

<!-- #region id="2-JYevEklHfB" -->
# Scale up Flax modules with `pjit`
Suggestion: maybe mention multi-host/multi-device in the title so that users know what this scaling up with `pjit` means without in-depth JAX knowledge. Similar to "Using JAX in multi-host and multi-process environments": https://jax.readthedocs.io/en/latest/multi_process.html
docs/guides/spmd_api.md
Outdated
In this guide, you will learn how to use JAX's `pjit` and `flax.linen.spmd` to scale up your Flax modules on multiple devices and hosts.

Modern JAX provides [`jax.experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html) as a way to automatically compile and scale up JAX computations using the [SPMD](https://jax.readthedocs.io/en/latest/glossary.html?highlight=spmd#term-SPMD) model. You can think of `pjit` as [`jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) where one explicitly specificies how the input and output data should be partitioned across devices. See the [JAX guide](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html) for more information on pjit.
Suggestions:

- Do we need "modern"? (Currently, there's just one JAX 0.4, isn't it?)
- Spell out "Single Program Multi Data" before abbreviating it.
- Use the SPMD link without the `highlight=...` part, like `https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD`.
- Last sentence:
  - "You can think of `pjit` as [`jax.jit`]..." (with `jax.jit`).
  - Use "Refer to" instead of "See" (more accessibility-friendly language), "JAX guide" -> "JAX-101 `pjit` tutorial" (more precise).
  - `pjit` (formatting).
Suggested change:

> JAX provides [`jax.experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html) as a way to automatically compile and scale up JAX computations using the [Single Program Multiple Device (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) model. You can think of JAX's `pjit` as [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) where you can explicitly specify how the input and output data should be partitioned across devices. Refer to the [JAX-101 `pjit` tutorial](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html) for more information on `pjit`.
docs/guides/spmd_api.md
Outdated
Flax provides an interface for you to specify partitions for your module parameters when defining your module, and later use it for `pjit` compilation. You can also use logical axis annotations to decouple your model code and partition plan, in order to customize and experiment with different partition layouts more easily.
Suggestion:

> Flax provides an interface to specify partitions for your [Module parameters](https://flax.readthedocs.io/en/latest/advanced_topics/arguments.html) when defining your [Flax `Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module), and later use it for `pjit` compilation. You can also use logical axis annotations to decouple your model code and partition plan to customize and experiment with different partition layouts more easily.
docs/guides/spmd_api.md
Outdated
Install Flax from head and make imports as necessary.

Note that this Colab uses `--xla_force_host_platform_device_count=8` to emulate multiple devices on a CPU environment. You may also choose to run on a multi-device TPU environment following [this guide](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#setup), in which case you should ignore the `os.environ` cell.
Suggestion:

> Note that this guide uses `--xla_force_host_platform_device_count=8` to emulate multiple devices on a CPU environment in a Google Colab/Jupyter Notebook. You can also follow [this JAX-101 `pjit` guide](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#setup) to emulate a multi-device TPU environment (in which case you should ignore the `os.environ` cell).
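As an aside for readers, the `os.environ` cell being referred to amounts to something like the following sketch (the exact cell contents in the guide may differ slightly):

```python
import os

# Emulate 8 devices on a single CPU host. Set this before importing JAX, and
# skip it entirely when running on a real multi-device (e.g. TPU) setup.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
```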
When we link 2x to the same `jax-101/08-pjit` guide, could we use the same name when referring to it? (Saves the user an extra click when reading through the guide.)
Awesome tutorial @IvyZX, great work and thanks a lot! Just some minor comments but overall LGTM.
docs/guides/spmd_api.md
Outdated
<!-- #region id="XfuDTqt1mO6g" -->
Import all the `pjit` related libraries. Note that they are still in the experimental package of JAX, so some API may change. In the near future, the plan is to move the `pjit` API out of the experimental status and merge it with the current `jax.jit` API.

Start a device mesh using the 8 devices available. In this example, we set them as a `2x4` device mesh, same as the layout of TPU v3-8, and annotate each axis with a name. A typical way to annotate axis names is `('data', 'model')`: `'data'` being the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations, and the `'model'` being the mesh dimension used for sharding parameters of the model across devices.
Thinking about this some more: if you have a `2x4` device mesh and you annotate the axis names with `('data', 'model')`, then this means that if you partition over the `data` dimension you are effectively doing a two-way partition over the first dimension of the devices, and if you partition over the `model` dimension you partition it four ways over the second dimension, right? It might be good to explain this explicitly, so the relation between the devices and the axis names becomes clearer.
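For context, the device-mesh setup being discussed looks roughly like this sketch (variable names are illustrative; newer JAX exposes `Mesh` under `jax.sharding`, older versions under `jax.experimental.maps`):

```python
import numpy as np
import jax
from jax.sharding import Mesh

# Arrange the 8 available devices into a 2x4 grid and name the two mesh axes.
# Partitioning an array axis over 'data' splits it 2 ways (first mesh
# dimension); partitioning over 'model' splits it 4 ways (second dimension).
device_grid = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices=device_grid, axis_names=('data', 'model'))
```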
docs/guides/spmd_api.md
Outdated
<!-- #region id="3d6b15d08980" -->
However, to generate `PartitionSpec` for the output, we need to use some actual output as reference. One solution is to create a model and evaluate `model.init` abstractly using `jax.eval_shape`, and then use `nn.get_partition_spec` to automatically generate the `PartitionSpec`.
"we need to use some actual output as reference" --> don't we need to generate the variable dict structure rather than the model outputs?
That works too. I am just trying to show a more complicated example here, in which the pjitted function output is another dataclass instead of the simple variable dict. Added the explanation in the doc.
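To make the pattern under discussion concrete, here is a minimal sketch (assuming the `model`, `optimizer`, `init_fn`, `k` and `x` defined in the guide's earlier cells):

```python
import functools
import jax
import flax.linen as nn

# Close over the non-numeric arguments, since jax.eval_shape only accepts
# numeric/array-like inputs for the function it traces.
abstract_init = functools.partial(init_fn, model=model, optimizer=optimizer)

# Evaluate the initializer abstractly: only shapes and dtypes are computed,
# no actual arrays are materialized.
abstract_state = jax.eval_shape(abstract_init, k, x)

# Convert the Partitioned metadata inside that structure into a PartitionSpec
# pytree, which can then be passed to pjit as out_axis_resources.
state_spec = nn.get_partition_spec(abstract_state)
```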
docs/guides/spmd_api.md
Outdated
> A side note: Here we define our `init_fn` as purely functional and takes `model` and `optimizer` as arguments. This is not necessary - you can simply define with `def init_fn(k, x):` and all will work fine here.
>
> This guide doesn't do it because later we will show you another way to define your model and will run `init_fn` with another model instance. But this does a bit of trouble because `jax.eval_shape` only takes numeric inputs, and we have to create an abstract closure before feeding the function in.
Suggested change:

> This guide doesn't do it because later we will show you another way to define your model and will run `init_fn` with another model instance. However, this is problematic because `jax.eval_shape` only takes numeric inputs, so we have to create an abstract closure before feeding the function in.
docs/guides/spmd_api.md
Outdated
<!-- #region id="KAyPfNP9pTR4" -->
# `pjit` the initialization and train step
Initialization with pjit is complex! Quite a number of steps, maybe we can simplify this with a simple utility function (not now, but maybe as a follow-up?)
You can use the decorator `@functools.partial(pjit, ...)` like `jit`. I added the explanation in the doc.
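As a rough illustration of that decorator pattern (a sketch: `x_spec` and `state_spec` are assumed from the guide, and import paths for `pjit`/`PartitionSpec` vary across JAX versions):

```python
import functools
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec

# Decorator form: equivalent to pjit_init_fn = pjit(init_fn, ...), but it keeps
# the partitioning choices next to the function definition.
@functools.partial(
    pjit,
    static_argnums=(2, 3),                             # model and optimizer are static
    in_axis_resources=(PartitionSpec(None), x_spec),   # RNG key and x
    out_axis_resources=state_spec)
def init_fn(k, x, model, optimizer):
  ...
```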
docs/guides/spmd_api.md
Outdated
<!-- #region id="dcLFwH1dhr3v" -->
## `pjit` the train step and inference
It isn't clear which parts of these code blocks refer to "train step" and which refer to "inference". Maybe add some text before the inference code block to clarify that.
docs/guides/spmd_api.md
Outdated
<!-- #region id="vXY2mx7WjOyv" -->
## Profiling

If you are running on a TPU pod or pod slice, you can use the `block_all` util below to profile the cell below to observe a performance increase.
Suggested change:

> If you are running on a TPU pod or pod slice, you can use the `block_all` utility function below to observe a performance increase.
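For readers wondering what `block_all` is for: JAX dispatches work asynchronously, so a timed cell only measures real device execution if you wait for the results first. A minimal sketch of such a utility (the guide's exact implementation may differ):

```python
import jax

def block_all(xs):
  # Wait until every array in the pytree has actually been computed, so that
  # wall-clock timing of the enclosing cell reflects device execution time.
  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
  return xs

# Usage idea (function and argument names assumed from the guide):
#   %timeit block_all(pjit_step_fn(initialized_state, x))
```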
docs/guides/spmd_api.md
Outdated
<!-- #region id="GUXkJiQkotfm" -->
# Logical axis annotation

JAX auto SPMD encourages user to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any array with more descriptive axis names, instead of only the device mesh axis names (i.e., `data` and `model`).
Suggested change:

> JAX auto SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any array with more descriptive axis names, instead of only the device mesh axis names (i.e., `data` and `model`).
docs/guides/spmd_api.md
Outdated
Check out the `Logical-` model definitions below. It's exactly the same with the model above, except for two differences:

1. Each axis are annotated with more concrete, meaningful names - like `embed`, `hidden`, `batch` and `layer`. These names are referred as "logical axis names" in Flax. They make the dimensional changes inside model definition more readable.
Suggested change:

> 1. All axes are annotated with more concrete, meaningful names - like `embed`, `hidden`, `batch` and `layer`. These names are referred as "logical axis names" in Flax. They make the dimensional changes inside model definitions more readable.
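A minimal sketch of a logically annotated parameter plus a rule set (the names `'embed'`/`'hidden'` and the rule list are illustrative, not the guide's exact code):

```python
import flax.linen as nn
import jax.numpy as jnp

class LogicalDot(nn.Module):
  depth: int

  @nn.compact
  def __call__(self, x):
    # Annotate dimensions with descriptive logical names instead of the raw
    # device mesh axis names.
    w = self.param(
        'W',
        nn.with_logical_partitioning(nn.initializers.xavier_normal(),
                                     ('embed', 'hidden')),
        (x.shape[-1], self.depth))
    return jnp.dot(x, w)

# A separate rule set maps each logical name onto a mesh axis (or None), so the
# sharding layout can be changed without touching the model code.
rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None))
```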
Great guide! This is very helpful to make use of `nn.with_partitioning` and friends. I'm fairly new to `pjit`, so I also marked some things where I had genuine understanding questions.
docs/guides/spmd_api.md
Outdated
<!-- #endregion -->

```python id="e3d005e672d6"
# TODO: replace this with `pip3 install flax` when 0.6.4 release rolls out.
```
Should we do `# TODO(username)`?
Since this is a public doc, it's probably better to not expose our names. I can open an issue to remind myself of that.
fyi agreed with @IvyZX
# Once Flax v0.6.4 is out, replace this with `pip3 install flax`.
This way, the docs can be more or less up-to-date after 0.6.4.
You could use your GitHub username (which is already known) ... Or create an issue and assign it to yourself.
The main motivation here is to avoid accumulating TODOs in public docs.
docs/guides/spmd_api.md
Outdated
Note that in the output `initialized_state`, the params `W1` and `W2` are of type [`Partitioned`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.Partitioned.html). This is a wrapper around the actual JAX array that allows Flax to record metadata associated with it. You can access the raw JAX array by adding `.value` or running `.unbox()`.

You can also check the underlying `.sharding` of the JAX array, which gives a hint on the way it is partitioned.
Maybe the interpretation of `.sharding` merits a comment?
This is a fairly new concept in JAX and is not very well documented for now. I only found one mention of OpSharding here: https://jax.readthedocs.io/en/latest/jax_array_migration.html#why-create-jax-array
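To make this part of the thread concrete, inspecting one of those boxed params roughly looks like this (a sketch; the exact pytree path into `initialized_state.params` depends on the model structure):

```python
# A Partitioned box bundles the axis-name metadata with the actual array.
w1_box = initialized_state.params['W1']
print(w1_box.names)    # the annotated axis names, e.g. (None, 'model')

w1 = w1_box.unbox()    # or w1_box.value: the raw jax.Array
print(w1.sharding)     # describes how the array is laid out across the device mesh
```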
docs/guides/spmd_api.md
Outdated
pjit_init_fn = pjit(init_fn,
                    static_argnums=(2, 3),
                    in_axis_resources=(PartitionSpec(None), x_spec),  # RNG key and x
                    out_axis_resources=state_spec
Do we actually need to specify `out_axis_resources` here?

(I don't see a difference when I don't specify it; if not needed, that would make the usage pattern a bit simpler because we could get rid of the `jax.eval_shape()` step above.)

... although if I don't specify it here, then the call to `pjit_step_fn()` below fails with "Sharding passed to pjit does not match the sharding on the respective arg". (But I don't quite understand why; maybe worth pointing out in this guide?)
The JAX team is making `pjit` more automatic so that users don't have to specify input and output axes. But as far as Anselm and I know, this approach is not very efficient on complex models for now, so we decided to still present the version that explicitly states the axes.
Ah ... so it's still a bit rough around the edges, I see.

I think I would still mention this, because in many examples users will see `pjit()` calls where the `out_axis_resources` is omitted (e.g. our very own `flax.core.meta.Partitioned`).
docs/guides/spmd_api.md
Outdated
<!-- #region id="vXY2mx7WjOyv" -->
## Profiling

If you are running on a TPU pod or pod slice, you can use the `block_all` util below to profile the cell below to observe a performance increase.
How do I see the "performance increase"?
Thanks Ivy! This material is super useful; I've been reading and hacking stuff around this topic, so personally it's great to have this. Just a few remarks:
+1, I think scan over layers deserves its own guide and right now it is a bit much to take in here. But to avoid stalling this guide, maybe we can keep it in but @IvyZX can file an issue saying we want to add a guide for scan over layers, and then link to it from here?
@marcvanzee agreed. Let's merge this PR and improve the guide in the near future. Maybe just collect some of the points in these comments in an issue.
FYI: going through the guide with @IvyZX today.
Thank you all for the reviews! I have addressed all the comments listed here, either by a reply comment or by directly modifying the doc content. I also filed #2752 for the scan-over-layers doc.
I think that's a good idea to explore in the near future. JAX team is actively developing on
I got feedback from @8bitmp3 that this guide is already pretty heavy, so I am a bit hesitant to add more in-depth explanations. The idea of
docs/guides/flax_on_pjit.md
Outdated
To replicate identical layers, you can either use `flax.linen.scan` or a for-loop: `flax.linen.scan` can offer faster compilation times, whereas the for-loop can be faster on runtime. This guide uses `flax.linen.scan` simply to show that [Flax lifted transforms](https://flax.readthedocs.io/en/latest/advanced_topics/lift.html#supported-transformations) work together with [JAX `pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html).

**Note:** `flax.linen.scan` will introduce another dimension for the params (the dimension over which `scan` is applied), and you need to use the `metadata_params` argument to annotate the partition of this dimension. Since the parameters inside your `DotReluDot` (a sub-`Module`) are already sharded along the `model` axis, you don't need to partition multiple layers across the `model` dimension here, and therefore you should denote it as `None`.
@andsteing We are doing a 2nd (3rd?) review with @IvyZX asap, going through the second half. Will make sure "Note:" is formatted consistently 👍 Thanks for the feedback!!
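For readers following along, the `metadata_params` annotation being discussed looks roughly like this sketch (`DotReluDot` and the axis choices mirror the guide; the surrounding module and its call convention are assumptions):

```python
import flax.linen as nn

class StackedModel(nn.Module):
  num_layers: int
  depth: int

  @nn.compact
  def __call__(self, x):
    # nn.scan stacks num_layers copies of DotReluDot, adding a new leading
    # (layer) axis to every scanned param. metadata_params tells Flax how to
    # partition that extra axis; since DotReluDot's params are already sharded
    # over 'model', the layer axis is left unpartitioned (None).
    ScanLayers = nn.scan(
        DotReluDot,
        length=self.num_layers,
        variable_axes={'params': 0},
        split_rngs={'params': True},
        metadata_params={nn.PARTITION_NAME: None})
    # Assumes DotReluDot.__call__ returns a (carry, output) pair, as nn.scan requires.
    x, _ = ScanLayers(self.depth)(x)
    return x
```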
def init_fn(k, x, model, optimizer):
  variables = model.init(k, x)
  state = train_state.TrainState.create(
    apply_fn=model.apply,
nit: 2 spaces missing
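(For context: the quoted cell is truncated above. Under the guide's `TrainState` pattern it presumably continues along these lines; this is a sketch, not the guide's exact cell.)

```python
from flax.training import train_state

def init_fn(k, x, model, optimizer):
  # Initialize the model variables and bundle params and optimizer state into
  # a single TrainState pytree, so everything can be partitioned by pjit at once.
  variables = model.init(k, x)
  state = train_state.TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=optimizer)
  return state
```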
I had another go at the updated guide. Nice improvements, thanks!
I left some more comments, but will approve here, since it's mostly nits and nothing is blocking.
Original code PR: #2704