Guide for Flax partitioning (auto SPMD) API #2730


Merged

3 commits merged into google:main on Dec 23, 2022

Conversation

IvyZX
Collaborator

@IvyZX IvyZX commented Dec 16, 2022

Original code PR: #2704

@IvyZX IvyZX requested a review from levskaya December 16, 2022 04:03
@IvyZX IvyZX self-assigned this Dec 16, 2022

@marcvanzee
Collaborator

@IvyZX the doc build is failing, some problem with the sharding code. Could you please fix it?

Collaborator

"scaled module output" what does "scaled" mean here? do you mean "sharded" or am I missing something?



Collaborator Author

Removed "scaled" - hopefully it's clearer now.

Collaborator

"Profilling" --> "Profiling" (one "L")



Collaborator Author

Done.

Collaborator

"replaced"->"replaces"

We should perhaps note that there are more complicated/advanced sharding patterns that require annotating -activation- dimension names differently from -parameter- dimension names. For really advanced stuff, people may wish to do fine-grained manual mesh assignments using the simpler system above, but the logical-naming helpers are useful for getting started in exploring different sharding layouts.



Collaborator Author

Added a new section at the bottom of the logical axis introduction. Let me know if there could be a better phrasing of the idea!

@IvyZX IvyZX force-pushed the pjit branch 3 times, most recently from fb0813e to dba3086 Compare December 18, 2022 23:12
@IvyZX
Collaborator Author

IvyZX commented Dec 18, 2022

@IvyZX the doc build is failing, some problem with the sharding code. Could you please fix it?

This seems to be due to JAX dropping Python 3.7. I switched the readthedocs build to 3.8 and the test passes now.

Collaborator

@marcvanzee marcvanzee left a comment

Looks really cool! I reviewed it only partially, but thought I'd send the comments out already.

!pip3 install -qq "git+https://github.com/google/flax.git@main#egg=flax"
```

```python executionInfo={"elapsed": 2, "status": "ok", "timestamp": 1671158206224, "user": {"displayName": "Ivy Zheng", "userId": "15297372265856137303"}, "user_tz": 480} id="FJzB9vfp-OoK"
Collaborator

These cells look a bit odd, why is there all this execution info? I think in this way they aren't rendered properly and you don't get Python syntax highlighting (see preview here). Maybe just use python?




Each value passed into a sharding annotates a dimension, and its value should be among `data`, `model` or `None`. This refers to whether the data's axis should be sharded across one of the device mesh dimensions, or not sharded at all.
Collaborator

I don't really get this sentence. What does it mean for a value to be passed into a sharding? Are you referring to the pairs like ('data', 'model') that are passed to with_partitioning and with_sharding_constraints?

Perhaps it is clearer to give an example and explain one of the calls below in more detail. For instance, I suppose the definition of W1 means that you define a parameter called W1, which has two dimensions (x.shape[-1], self.depth), and the sharding specification means that you don't partition the first axis (I think this means you replicate it on all devices?) and you partition the second axis over the model dimension. Is that correct?

Collaborator Author

Yes, this is correct. I rephrased it to:

Calling these APIs annotates each dimension of your parameter and intermediate variables with an axis name, namely 'data', 'model' or None. This refers to how this dimension should be sharded - across one of the device mesh dimensions, or not sharded at all.

For example, if we define W1, which has shape (x.shape[-1], self.depth), with (None, 'model'), the first dimension (of length x.shape[-1]) will be replicated across all devices, while the second dimension (of length self.depth) will be sharded over the model axis of the device mesh.
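For readers following the thread, here is a minimal sketch of the kind of annotation being discussed (the module name, initializer, and shapes are illustrative, not the guide's exact code):

```python
import flax.linen as nn
import jax.numpy as jnp

class ShardedDense(nn.Module):  # hypothetical, simplified module
  depth: int

  @nn.compact
  def __call__(self, x):
    # (None, 'model'): replicate the first dimension across all devices and
    # shard the second dimension over the 'model' axis of the device mesh.
    w1 = self.param(
        'W1',
        nn.with_partitioning(nn.initializers.xavier_normal(), (None, 'model')),
        (x.shape[-1], self.depth))
    return nn.relu(jnp.dot(x, w1))
```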

@IvyZX IvyZX force-pushed the pjit branch 2 times, most recently from 36a086f to 3e81c37 Compare December 19, 2022 21:50
@IvyZX IvyZX requested a review from marcvanzee December 19, 2022 21:52
@IvyZX
Collaborator Author

IvyZX commented Dec 19, 2022

Thank you Marc! I addressed all your comments, either fixing them directly or leaving some comments beneath yours. Please take a look and let me know what you think.

Collaborator

@8bitmp3 8bitmp3 left a comment

@IvyZX A few suggestions for the first part of the guide. Thanks for putting this together! Amazing doc

---

<!-- #region id="2-JYevEklHfB" -->
# Scale up Flax modules with `pjit`
Collaborator

Suggestion:

Maybe mention multi-host/multi-device in the title so that users know what this scaling up with pjit means without in-depth JAX knowledge. Similar to "Using JAX in multi-host and multi-process environments" https://jax.readthedocs.io/en/latest/multi_process.html


In this guide, you will learn how to use JAX's `pjit` and `flax.linen.spmd` to scale up your Flax modules on multiple devices and hosts.

Modern JAX provides [`jax.experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html) as a way to automatically compile and scale up JAX computations using the [SPMD](https://jax.readthedocs.io/en/latest/glossary.html?highlight=spmd#term-SPMD) model. You can think of `pjit` as [`jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) where one explicitly specificies how the input and output data should be partitioned across devices. See the [JAX guide](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html) for more information on pjit.
Collaborator

Suggestions:

  • Do we need "modern"? (Currently, there's just one JAX 0.4, isn't it?)
  • Spell out "Single Program Multiple Data" before abbreviating it.
  • Use the SPMD link without the highlight=... part like https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD)
  • Last sentence:
    • "You can think of pjit as [jax.jit]..." (with jax.jit).
    • Use "Refer to" instead of "See" (more accessibility-friendly language), "JAX guide" -> "JAX-101 pjit tutorial" (more precise).
    • pjit (formatting).
Suggested change
Modern JAX provides [`jax.experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html) as a way to automatically compile and scale up JAX computations using the [SPMD](https://jax.readthedocs.io/en/latest/glossary.html?highlight=spmd#term-SPMD) model. You can think of `pjit` as [`jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) where one explicitly specificies how the input and output data should be partitioned across devices. See the [JAX guide](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html) for more information on pjit.
JAX provides [`jax.experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html) as a way to automatically compile and scale up JAX computations using the [Single Program Multiple Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) model. You can think of JAX's `pjit` as [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) where you can explicitly specify how the input and output data should be partitioned across devices. Refer to the [JAX-101 `pjit` tutorial](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html) for more information on `pjit`.


Modern JAX provides [`jax.experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html) as a way to automatically compile and scale up JAX computations using the [SPMD](https://jax.readthedocs.io/en/latest/glossary.html?highlight=spmd#term-SPMD) model. You can think of `pjit` as [`jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) where one explicitly specificies how the input and output data should be partitioned across devices. See the [JAX guide](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html) for more information on pjit.

Flax provides an interface for you to specify partitions for your module parameters when defining your module, and later use it for `pjit` compilation. You can also use logical axis annotations to decouple your model code and partition plan, in order to customize and experiment with different partition layouts more easily.
Collaborator

Suggestions:

Suggested change
Flax provides an interface for you to specify partitions for your module parameters when defining your module, and later use it for `pjit` compilation. You can also use logical axis annotations to decouple your model code and partition plan, in order to customize and experiment with different partition layouts more easily.
Flax provides an interface to specify partitions for your [Module parameters](https://flax.readthedocs.io/en/latest/advanced_topics/arguments.html) when defining your [Flax `Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module), and later use it for `pjit` compilation. You can also use logical axis annotations to decouple your model code and partition plan to customize and experiment with different partition layouts more easily.


Install Flax from head and make imports as necessary.

Note that this Colab uses `--xla_force_host_platform_device_count=8` to emulate multiple devices on a CPU environment. You may also choose to run on a multi-device TPU environment following [this guide](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#setup), in which case you should ignore the `os.environ` cell.
Collaborator

Suggestion:

Suggested change
Note that this Colab uses `--xla_force_host_platform_device_count=8` to emulate multiple devices on a CPU environment. You may also choose to run on a multi-device TPU environment following [this guide](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#setup), in which case you should ignore the `os.environ` cell.
Note that this guide uses `--xla_force_host_platform_device_count=8` to emulate multiple devices on a CPU environment in a Google Colab/Jupyter Notebook. You can also follow [this JAX-101 `pjit` guide](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#setup) to emulate a multi-device TPU environment (in which case you should ignore the `os.environ` cell).
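For reference, the `os.environ` cell mentioned here is roughly the following sketch (the flag has to be set before JAX initializes its backend):

```python
import os

# Emulate 8 devices on a single CPU host; skip this on a real multi-device setup.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())  # expect 8 CPU devices
```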

Collaborator

When we link 2x to the same jax-101/08-pjit guide, could we use the same name when referring to it?
(Saves the user an extra click when reading through the guide)

Collaborator

@marcvanzee marcvanzee left a comment

Awesome tutorial @IvyZX, great work and thanks a lot! Just some minor comments but overall LGTM.

<!-- #region id="XfuDTqt1mO6g" -->
Import all the `pjit` related libraries. Note that they are still in the experimental package of JAX, so some API may change. In the near future, the plan is to move the `pjit` API out of the experimental status and merge it with the current `jax.jit` API.

Start a device mesh using the 8 devices available. In this example, we set them as a `2x4` device mesh, same as the layout of TPU v3-8, and annotate each axis with a name. A typical way to annotate axis names is `('data', 'model')`: `'data'` being the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations, and the `'model'` being the mesh dimension used for sharding parameters of the model across devices.
Collaborator

Thinking about this some more, if you have a 2x4 device mesh and you annotate the axis names with ('data', 'model'), then this means that if you partition over the data dimension you are effectively doing a two-way partition over the first dimension of the devices, and if you partition over the model dimension you partition it four-way over the second dimension, right? It might be good to explain this explicitly, so the relation between the devices and the axis names becomes clearer.
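A small sketch of that relationship (assumes 8 available devices, for example after the emulation flag above; on older JAX releases `Mesh` lives under `jax.experimental.maps`):

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# 8 devices arranged as 2x4: sharding an array axis over 'data' splits it 2-way,
# sharding over 'model' splits it 4-way.
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=('data', 'model'))
print(mesh.shape)  # {'data': 2, 'model': 4}
```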

```

<!-- #region id="3d6b15d08980" -->
However, to generate `PartitionSpec` for the output, we need to use some actual output as reference. One solution is to create a model and evaluate `model.init` abstractly using `jax.eval_shape`, and then use `nn.get_partition_spec` to automatically generate the `PartitionSpec`.
Collaborator

"we need to use some actual output as reference" --> don't we need to generate the variable dict structure rather than the model outputs?

Collaborator Author

That works too. I am just trying to show a more complicated example here, in which the pjitted function output is another dataclass instead of the simple variable dict. Added the explanation to the doc.
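A rough, self-contained sketch of the pattern under discussion; `nn.Dense` stands in for the guide's annotated model (so the resulting spec is trivial here), and the closure is needed because `jax.eval_shape` only accepts array-like arguments:

```python
import functools
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

model = nn.Dense(4)          # stand-in; the guide's model carries partitioning metadata
optimizer = optax.adam(1e-3)

def init_fn(k, x, model, optimizer):
  variables = model.init(k, x)
  return train_state.TrainState.create(
      apply_fn=model.apply, params=variables['params'], tx=optimizer)

k = jax.random.PRNGKey(0)
x = jnp.ones((8, 4))

# Trace init_fn abstractly (no real arrays are materialized), then read a
# PartitionSpec pytree off the metadata of the abstract output.
abstract_state = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=optimizer), k, x)
state_spec = nn.get_partition_spec(abstract_state)
```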


> A side note: Here we define our `init_fn` as purely functional and takes `model` and `optimizer` as arguments. This is not necessary - you can simply define with `def init_fn(k, x):` and all will work fine here.
>
> This guide doesn't do it because later we will show you another way to define your model and will run `init_fn` with another model instance. But this does a bit of trouble because `jax.eval_shape` only takes numeric inputs, and we have to create an abstract closure before feeding the function in.
Collaborator

Suggested change
> This guide doesn't do it because later we will show you another way to define your model and will run `init_fn` with another model instance. But this does a bit of trouble because `jax.eval_shape` only takes numeric inputs, and we have to create an abstract closure before feeding the function in.
> This guide doesn't do it because later we will show you another way to define your model and will run `init_fn` with another model instance. However, this is problematic because `jax.eval_shape` only takes numeric inputs, so we have to create an abstract closure before feeding the function in.

```

<!-- #region id="KAyPfNP9pTR4" -->
# `pjit` the initialization and train step
Collaborator

Initialization with pjit is complex! Quite a number of steps, maybe we can simplify this with a simple utility function (not now, but maybe as a follow-up?)

Collaborator Author

You can use the decorator `@functools.partial(pjit, ...)`, just like with `jit`. I added the explanation in the doc.
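For illustration, a sketch of the decorator form (the spec names follow this thread and are placeholders; this only defines the function, and calling it still requires an active `Mesh` context as in the guide):

```python
import functools
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec

x_spec = PartitionSpec('data', None)  # shard the batch dimension over the 'data' mesh axis
state_spec = None                     # placeholder; in the guide this is a pytree of specs

@functools.partial(pjit,
                   in_axis_resources=(state_spec, x_spec),
                   out_axis_resources=state_spec)
def train_step(state, x):
  # ... compute loss and gradients, return the updated state, as in the guide ...
  return state
```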

```

<!-- #region id="dcLFwH1dhr3v" -->
## `pjit` the train step and inference
Collaborator

It isn't clear which parts of these code blocks refer to "train step" and which refer to "inference". Maybe add some text before the inference code block to clarify that.

<!-- #region id="vXY2mx7WjOyv" -->
## Profiling

If you are running on a TPU pod or pod slice, you can use the `block_all` util below to profile the cell below to observe a performance increase.
Collaborator

Suggested change
If you are running on a TPU pod or pod slice, you can use the `block_all` util below to profile the cell below to observe a performance increase.
If you are running on a TPU pod or pod slice, you can use the `block_all` utility function below to observe a performance increase.
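The `block_all` utility itself is not shown in this thread; a common way to write such a helper looks like this (a sketch, needed because JAX dispatch is asynchronous and timing requires an explicit barrier):

```python
import jax

def block_all(xs):
  # Wait until every array in the pytree has actually been computed, so that
  # timing the cell measures execution time rather than dispatch time.
  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
  return xs
```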

<!-- #region id="GUXkJiQkotfm" -->
# Logical axis annotation

JAX auto SPMD encourages user to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any array with more descriptive axis names, instead of only the device mesh axis names (i.e., `data` and `model`).
Collaborator

Suggested change
JAX auto SPMD encourages user to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any array with more descriptive axis names, instead of only the device mesh axis names (i.e., `data` and `model`).
JAX auto SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any array with more descriptive axis names, instead of only the device mesh axis names (i.e., `data` and `model`).


Check out the `Logical-` model definitions below. It's exactly the same with the model above, except for two differences:

1. Each axis are annotated with more concrete, meaningful names - like `embed`, `hidden`, `batch` and `layer`. These names are referred as "logical axis names" in Flax. They make the dimensional changes inside model definition more readable.
Collaborator

Suggested change
1. Each axis are annotated with more concrete, meaningful names - like `embed`, `hidden`, `batch` and `layer`. These names are referred as "logical axis names" in Flax. They make the dimensional changes inside model definition more readable.
1. All axes are annotated with more concrete, meaningful names - like `embed`, `hidden`, `batch` and `layer`. These names are referred as "logical axis names" in Flax. They make the dimensional changes inside model definitions more readable.
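A small sketch of the logical-axis style being discussed, using Flax's `nn.with_logical_partitioning` helper (module name and rule values are illustrative):

```python
import flax.linen as nn
import jax.numpy as jnp

class LogicalDense(nn.Module):  # hypothetical, simplified module
  depth: int

  @nn.compact
  def __call__(self, x):
    w = self.param(
        'W',
        nn.with_logical_partitioning(nn.initializers.xavier_normal(),
                                     ('embed', 'hidden')),
        (x.shape[-1], self.depth))
    return jnp.dot(x, w)

# The mapping from logical names to mesh axes lives outside the model code,
# so different layouts can be tried by changing only these rules.
rules = (('batch', 'data'), ('embed', None), ('hidden', 'model'))
```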

@andsteing andsteing self-requested a review December 20, 2022 17:17
Collaborator

@andsteing andsteing left a comment

Great guide! This is very helpful to make use of nn.with_partitioning and friends.

I'm fairly new to pjit, so I also marked some things where I had genuine understanding questions.



<!-- #endregion -->

```python id="e3d005e672d6"
# TODO: replace this with `pip3 install flax` when 0.6.4 release rolls out.
Collaborator

Should we do `# TODO(username)`?

Collaborator Author

Since this is a public doc, it's probably better to not expose our names. I can open an issue to remind myself of that.

Collaborator

fyi agreed with @IvyZX

# Once Flax v0.6.4 is out, replace this with `pip3 install flax`.

This way, the docs can be more or less up-to-date after 0.6.4.

Collaborator

You could use Github username (which is already known) ... Or create an issue and assign it to yourself.

The main motivation here is to avoid accumulating TODOs in public docs.


Note that in the output `initialized_state`, the params `W1` and `W2` are of type [`Partitioned`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.Partitioned.html). This is a wrapper around the actual JAX array that allows Flax to record metadata associated with it. You can access the raw JAX array by adding `.value` or running `.unbox()`.

You can also check the underlying `.sharding` of the JAX array, which gives a hint on the way it is partitioned.
Collaborator

Maybe the interpretation of .sharding merits a comment?

Collaborator Author

This is a fairly new concept in JAX and is not very well documented for now. I only found one mention of OpSharding here: https://jax.readthedocs.io/en/latest/jax_array_migration.html#why-create-jax-array
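To illustrate the quoted text, a tiny self-contained sketch of the `Partitioned` wrapper (constructed directly here rather than coming out of `model.init`):

```python
import jax.numpy as jnp
from flax.core import meta

boxed = meta.Partitioned(jnp.ones((4, 8)), names=(None, 'model'))
print(boxed.names)           # (None, 'model') -- the recorded partition annotation
print(boxed.value.shape)     # (4, 8); boxed.unbox() returns the same raw array
print(boxed.value.sharding)  # how the underlying jax.Array is laid out on devices
```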

pjit_init_fn = pjit(init_fn,
                    static_argnums=(2, 3),
                    in_axis_resources=(PartitionSpec(None), x_spec),  # RNG key and x
                    out_axis_resources=state_spec
Collaborator

Do we actually need to specify out_axis_resources here?
(I don't see a difference when I don't specify it; if not needed, that would make the usage pattern a bit simpler because we could get rid of the jax.eval_shape() step above)

... although if I don't specify it here, then the call to pjit_step_fn() below fails with "Sharding passed to pjit does not match the sharding on the respective arg".

(but I don't quite understand why, maybe worth pointing out in this guide?)

Collaborator Author

The JAX team is making pjit more automatic so that users don't have to specify input and output axes. But as far as Anselm and I know, this approach is not very efficient on complex models for now, so we decided to still present the version that explicitly states the axes.

Collaborator

ah ... so it's still a bit rough around the edges, I see.

I think I would still mention this because in many examples users will see examples of pjit() where the out_axis_resources is omitted (e.g. our very own flax.core.meta.Partitioned).

<!-- #region id="vXY2mx7WjOyv" -->
## Profiling

If you are running on a TPU pod or pod slice, you can use the `block_all` util below to profile the cell below to observe a performance increase.
Collaborator

how do I see the "performance increase" ?

@cgarciae
Collaborator

cgarciae commented Dec 21, 2022

Thanks Ivy! This material is super useful; I've been reading and hacking stuff around this topic, so personally it's great to have this. Just a few remarks:

  1. Given that JAX 0.4 is out, should we consider using the new APIs (e.g. device_put + jit)? I am a bit biased because this is how I am learning.
  2. On one hand, showing the scan trick is nice as it's used in real life; on the other hand, it's additional complexity that might not be ideal, as it's yet another concept in an already packed guide.
  3. Some references or explanation as to why the parameters in DotReluDot are sharded the way they are would be nice. I just happened to learn about this before reading this guide so it made perfect sense, but without it I'd have some questions.

@marcvanzee
Collaborator

On one hand, showing the scan trick is nice as it's used in real life; on the other hand, it's additional complexity that might not be ideal, as it's yet another concept in an already packed guide.

+1, I think scan over layers deserves its own guide and right now it is a bit much to take in here.

But to avoid stalling this guide, maybe we can keep it in but @IvyZX can file an issue saying we want to add a guide for scan over layers, and then link to it from here?

@cgarciae
Collaborator

@marcvanzee agreed. Let's merge this PR and improve the guide in the near future. Maybe just collect some of the points in these comments in an issue.

@8bitmp3
Collaborator

8bitmp3 commented Dec 21, 2022

fyi Going through the guide with @IvyZX today.

@IvyZX
Collaborator Author

IvyZX commented Dec 22, 2022

Thank you all for the reviews! I have addressed all the comments listed here, either by a reply comment or by directly modifying the doc content. I also filed #2752 for the scan-over-layers doc.

Given that JAX 0.4 is out, should we consider using the new APIs (e.g. device_put + jit)? I am a bit biased because this is how I am learning.

I think that's a good idea to explore in the near future. The JAX team is actively developing pjit and there will probably be other changes in the API. Let's stay tuned!

Some references or explanation as to why the parameters in DotReluDot are sharded the way they are would be nice. I just happened to learn about this before reading this guide so it made perfect sense, but without it I'd have some questions.

I got feedback from @8bitmp3 that this guide is already pretty heavy, so I am a bit hesitant to add more in-depth explanations. The idea of pjit is that the compiler will automatically choose how everything will be sharded, and more interpretation may diminish its flexibility (maybe not in this simple example though). I guess if people start to ask questions in the future, we can find a way to document it.


To replicate identical layers, you can either use `flax.linen.scan` or a for-loop: `flax.linen.scan` can offer faster compilation times, whereas the for-loop can be faster on runtime. This guide uses `flax.linen.scan` simply to show that [Flax lifted transforms](https://flax.readthedocs.io/en/latest/advanced_topics/lift.html#supported-transformations) work together with [JAX `pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html).

**Note:** `flax.linen.scan` will introduce another dimension for the params (the dimension over which `scan` is applied), and you need to use the `metadata_params` argument to annotate the partition of this dimension. Since the parameters inside your `DotReluDot` (a sub-`Module`) are already sharded along the `model` axis, you don't need to partition multiple layers across the `model` dimension here, and therefore you should denote it as `None`.
Collaborator

this "note" is formatted differently from other "notes" - is this on purpose?

Collaborator

@andsteing We are doing a 2nd (3rd?) review with @IvyZX asap, going through the second half. Will make sure "Note:" is formatted consistently 👍 Thanks for the feedback!!
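For readers unfamiliar with the quoted `metadata_params` note, here is a sketch of how such a scanned stack might be declared (both modules are illustrative, not the guide's code; the inner layer maps `depth` to `depth` so the scan carry keeps a fixed shape):

```python
import flax.linen as nn
import jax.numpy as jnp

class Layer(nn.Module):  # hypothetical inner layer; its params are sharded over 'model'
  depth: int

  @nn.compact
  def __call__(self, x, _):
    w = self.param(
        'W',
        nn.with_partitioning(nn.initializers.xavier_normal(), (None, 'model')),
        (self.depth, self.depth))
    return jnp.dot(x, w), None

class Stack(nn.Module):  # hypothetical: several identical layers stacked via nn.scan
  depth: int
  num_layers: int

  @nn.compact
  def __call__(self, x):
    x, _ = nn.scan(
        Layer,
        length=self.num_layers,
        variable_axes={'params': 0},    # scan stacks each layer's params along a new dim 0
        split_rngs={'params': True},
        # The scan-introduced param dimension is left unpartitioned (None), since
        # the per-layer params are already sharded over the 'model' mesh axis.
        metadata_params={nn.PARTITION_NAME: None},
    )(self.depth)(x, None)
    return x
```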

def init_fn(k, x, model, optimizer):
variables = model.init(k, x)
state = train_state.TrainState.create(
apply_fn=model.apply,
Collaborator

nit: 2 spaces missing


Collaborator

@andsteing andsteing left a comment

I had another go at the updated guide. Nice improvements, thanks!

I left some more comments, but will approve here, since it's mostly nits and nothing is blocking.

@8bitmp3
Collaborator

8bitmp3 commented Dec 22, 2022

Last review done dcf5517. Thanks @IvyZX for the amazing work 👍

@IvyZX IvyZX mentioned this pull request Dec 23, 2022
@copybara-service copybara-service bot merged commit 0453d38 into google:main Dec 23, 2022