Description
(Note: original issue was raised by @patrickvonplaten after Flax/HuggingFace hackathon)
For some pretraining and fine-tuning tasks, using the correct initialization scheme is essential; some pretraining methods don't work at all with incorrect initialization. I spent a lot of time fighting with the JAX init functions this morning and think the JAX initialization functions (https://jax.readthedocs.io/en/latest/jax.nn.initializers.html) currently have very poor documentation IMO. A couple of things that could be improved:
It is not at all obvious that jax.nn.initializers.zeros and jax.nn.initializers.ones are not 1-to-1 replaceable by jax.nn.initializers.normal and the other initializers. What I mean here is that one would pass:
```python
flax.linen.Conv(
    ...
    bias_init=jax.nn.initializers.zeros
    ...
)
```
but would have to do
```python
flax.linen.Conv(
    ...
    bias_init=jax.nn.initializers.uniform(1.0)
    ...
)
```
=> flax.linen.Module seems to expect all init functions to have the signature (key, shape), but it is nowhere mentioned that zeros and ones already have that form, whereas the other init functions need to be called first to return a function of that form. => Better docs and examples are needed here -> IMO both in the JAX docs as well as in flax.linen (BTW, the `flax.linen.Conv` docs are pretty hard to read: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Conv.html).
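To make the asymmetry concrete, here is a minimal sketch (shapes and values are just examples) of how the two kinds of initializers behave when called directly:

```python
import jax

key = jax.random.PRNGKey(0)

# zeros/ones are already functions of (key, shape) and can be passed to
# flax.linen modules as-is:
bias = jax.nn.initializers.zeros(key, (8,))

# normal/uniform are factories: they must be called first to obtain an
# initializer with the (key, shape) signature:
init_fn = jax.nn.initializers.uniform(1.0)
kernel = init_fn(key, (3, 3, 8, 16))
```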
- The function design here is less intuitive than the PyTorch equivalent IMO. E.g.:
- Why does uniform have a default scale value of 0.02? When people talk about a uniform distribution, the default case is [0, 1] IMO -> so the default should be scale=1.0.
- Also, I think the uniform signature with scale is not well chosen. Most people IMO get to know the uniform distribution as x ~ U[a, b]. PyTorch has the inputs a and b and also shows the mathematical expression in the docs: https://pytorch.org/docs/stable/nn.init.html?highlight=uniform#torch.nn.init.uniform_. This is intuitive and I don't have to think twice about how to use torch.nn.init.uniform_. In JAX, however, I see a function uniform(scale=0.02) without any docs -> so I don't know how this corresponds to U[a, b], and I don't know whether scale=1.0 corresponds to U[-1, 1] or to U[0, 1]. I can only find this out by trying the function => very time-consuming! At least the mathematical expression for how scale influences the distribution should be given here (see the sketch after this list).
- Why does jax.nn.initializers.normal have 0.01 as the default stddev? Everybody would assume the default std is 1.0.
- There are more or less no docs at all for the init functions.
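For reference, below is a minimal sketch of what the current behaviour appears to be (found only by trying the functions), plus a hypothetical PyTorch-style U[a, b] initializer; the helper name uniform_ab is just an example and not part of any existing API:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# jax.nn.initializers.uniform(scale) appears to sample from U[0, scale),
# so scale=1.0 corresponds to U[0, 1), not U[-1, 1).
samples = jax.nn.initializers.uniform(1.0)(key, (1000,))
print(samples.min(), samples.max())  # roughly 0.0 and 1.0

# Similarly, a standard-normal init needs an explicit stddev:
std_normal_init = jax.nn.initializers.normal(stddev=1.0)

# Hypothetical U[a, b] initializer with the (key, shape) signature that
# flax.linen expects, e.g. flax.linen.Conv(..., kernel_init=uniform_ab(-1.0, 1.0)):
def uniform_ab(a, b):
    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval=a, maxval=b)
    return init
```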