ReLU bounding might lead to vanishing gradient #244


Open
jakob-schloer opened this issue Apr 9, 2025 · 8 comments
Assignees: lzampier
Labels: bug Something isn't working

Comments

@jakob-schloer
Collaborator

jakob-schloer commented Apr 9, 2025

After discussions with Tobias Finn and @lzampier, it came up that ReLU output bounding could lead to vanishing gradients.

Brief explanation:

  • When ReLU is the final function used to ensure positive outputs, the network can get stuck if the predictions happen to be negative at some stage of training. This could, for instance, be the case right after the network is initialized.
  • The ReLU gradient is zero for all negative predictions, making it impossible to update the weights towards positive values.

Alternatives to ensure positive outputs:

  • Softplus: f(x) = log(1 + e^x) behaves like ReLU for large inputs but never has zero gradient (see the sketch below)
  • Exponential: f(x) = e^x, but this is very sensitive to input changes
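
A minimal sketch of the gradient difference (PyTorch assumed, purely illustrative; not the anemoi bounding code):

```python
import torch

# Raw (pre-bounding) predictions, some of them negative.
x = torch.tensor([-2.0, -0.1, 0.5], requires_grad=True)

# ReLU bounding: the gradient is exactly zero wherever the raw prediction is negative.
torch.relu(x).sum().backward()
print(x.grad)  # tensor([0., 0., 1.]) -> no learning signal for the negative entries

x.grad = None

# Softplus bounding: the output is still strictly positive, but the gradient is
# sigmoid(x), which is small for very negative inputs yet never exactly zero.
torch.nn.functional.softplus(x).sum().backward()
print(x.grad)  # tensor([0.1192, 0.4750, 0.6225])
```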

Any thoughts on that? Was that ever a problem?

@jakob-schloer jakob-schloer closed this as not planned Apr 9, 2025
@lzampier
Member

lzampier commented Apr 9, 2025

I am adding @sahahner because she is more optimistic about this issue. It will be good to discuss soon and land on an optimal implementation :)

@lzampier lzampier reopened this Apr 9, 2025
@lzampier lzampier self-assigned this Apr 9, 2025
@lzampier lzampier added the bug Something isn't working label Apr 9, 2025
@sahahner
Member

We have considered this for some sea ice variables, where most of the grid points have values equal to zero.
We still use relu/hardtanh bounding because we have generally seen good convergence. Some kind of bounding is needed: without it, the loss function seems to focus on the wrong thing.
Interestingly, the model outputs after initialisation seem to be mostly positive (I tend to check the plots). Maybe we can understand this better.
Happy to try out different bounding functions in the sea ice case.
I wonder whether the sea ice edge can be as crisp as we want it to be with a different bounding?

@JesperDramsch
Member

I don't see why this is a problem; in fact, I thought of the output clipping as something positive for the gradients (adding my 2 cents to cover the full problem/benefit spectrum). But I'm also just working on the engineering, so I can be fully convinced otherwise.

My understanding of the training dynamics with "clipping" (let's set the specifics of ReLU aside for a second, because it might have other implications) is that it gives the network a degree of freedom: it has to learn the "correct" dynamics for all positive values, while the raw outputs that get clipped to zero can be "whatever" and don't get optimised. So we still get proper optimisation for positive values, but avoid possible jitter around zero.

Here we benefit from the interconnectedness of our implementation, the amount of data we have, and the fact that we are building a global model with enough sampling of, e.g., precip. The second you split variables into multiple outputs this would have to be re-evaluated, but in that case the correct implementation would include a penalty loss function anyway.

We could also consider looking at the initialization of the layers, but (at least for linear layers) they're sampled uniformly from (-n_features^-0.5, n_features^-0.5), so we'd have to get really unlucky with the initialisation and the sampling of the training data for this to actually become a problem. But maybe I'm missing a crucial aspect, so I can be convinced that this is actually really bad and isn't just a theoretical possibility when all the stars align.
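
For what it's worth, a quick toy check of that initialisation argument (hypothetical, plain `nn.Linear` with PyTorch's default init; the actual anemoi output layers may differ):

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(256, 1)   # hypothetical final projection for one bounded variable
x = torch.randn(10_000, 256)      # stand-in for roughly normalised latent features

# With the default uniform init in (-1/sqrt(n_features), 1/sqrt(n_features)) and
# zero-mean inputs, the raw outputs are roughly symmetric around zero at init,
# so about half of them already sit on the "live" side of the ReLU.
frac_positive = (layer(x) > 0).float().mean()
print(f"fraction of positive raw outputs at init: {frac_positive:.2f}")  # ~0.5
```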

@sahahner
Member

This PR is related: #256

@pinnstorm
Member

pinnstorm commented Apr 11, 2025

Interesting ideas! I like the way you put it @JesperDramsch. @gabrieloks has done some nice work looking at the negative space for ReLU-bounded precipitation, and it is quite interesting to see how it behaves.

Slightly connected, as @sahahner has tagged: we found that the HardTanh function and a high loss scaling on soil moisture had caused issues with AIFS Single ("Rain Pox" 😱). Using a leaky formulation here seems to have helped (as well as down-weighting soil moisture fairly aggressively), but it will be good to understand whether this is really what is driving the improvement (all other variables still have the same bounding, only soil moisture is "leaky").

We have a branch here with some other activation function boundings that might be interesting to try (SiLU, ScaledTanh, LeakyReLU, LeakyHardTanh): https://github.com/ecmwf/anemoi-core/compare/feature/leaky-bounding
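
For reference, a rough sketch of what a leaky bounding can look like (not the code on the branch above; the function and parameter names here are made up for illustration):

```python
import torch

def leaky_hardtanh(x: torch.Tensor, min_val: float = 0.0, max_val: float = 1.0,
                   slope: float = 0.01) -> torch.Tensor:
    """Like HardTanh bounding, but with a small slope outside [min_val, max_val]
    so out-of-range raw predictions still receive a (small) gradient."""
    below = slope * (x - min_val).clamp(max=0.0)   # negative contribution below the range
    above = slope * (x - max_val).clamp(min=0.0)   # positive contribution above the range
    return x.clamp(min_val, max_val) + below + above

x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
leaky_hardtanh(x).sum().backward()
print(x.grad)  # tensor([0.0100, 1.0000, 0.0100]) -> small but non-zero outside [0, 1]
```

The slope keeps a small gradient alive outside the physical range, at the cost of allowing slightly out-of-range raw values during training.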

@gabrieloks
Contributor

gabrieloks commented Apr 13, 2025

Yes, this is something we are aware of, and it actually happens with some variables in the AIFS 1.0 training. Here is a retrain of the current operational model, and you can see this "falling off the cliff" behaviour that is quite visible in the train loss.

[Image: training loss of the operational retrain, showing the "falling off the cliff" behaviour]

When you check where this behaviour is coming from, it's mostly soil moisture and runoff.

[Image: per-variable loss contributions; the drop comes mostly from soil moisture and runoff]

In this case, soil moisture is bounded to 0 at the minimum and the model, in the first steps of training, is stuck in a state where it just outputs zeros everywhere. This is also enhanced by the fact that we are predicting soil moisture over the oceans, so a zero prediction is not that bad. But very interestingly, the model manages to jump out of this state, which in theory should be a black hole (due to vanishing gradients). I think what is going on here is that the bounded variables (sm, tp, etc.) are not the main driving forces in the model: t, z, winds, etc. drive the main dynamics and the other variables follow. As soon as the model gains some skill in the forecast, I would speculate that the chances of the vanishing gradient actually being a problem are low, especially for diagnostic variables.
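
A toy sketch of that "escaping the black hole" intuition (hypothetical two-head model in PyTorch, not the AIFS architecture): even when the ReLU-bounded head is fully saturated at zero, the shared trunk still receives gradients through the unbounded heads, so the trunk keeps moving and can eventually push the bounded head's raw outputs back above zero.

```python
import torch

trunk = torch.nn.Linear(8, 16)          # shared representation
head_dyn = torch.nn.Linear(16, 4)       # unbounded "dynamics" variables (t, z, winds, ...)
head_sm = torch.nn.Linear(16, 1)        # ReLU-bounded variable (e.g. soil moisture)
torch.nn.init.constant_(head_sm.bias, -10.0)   # force the bounded head into saturation

x = torch.randn(32, 8)
h = trunk(x)
# Dummy losses, just to generate gradients for both heads.
loss = head_dyn(h).pow(2).mean() + torch.relu(head_sm(h)).pow(2).mean()
loss.backward()

print(head_sm.weight.grad.abs().max())  # 0: all raw outputs are negative, so no gradient
print(trunk.weight.grad.abs().max())    # > 0: the trunk is still updated via the other head
```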

Intuitively, I would agree with @JesperDramsch that a clean-cut clipping of the output is beneficial to the model. By doing this you greatly simplify the prediction of zeros: instead of forcing the model to learn to predict precisely the zero value, you turn the prediction into a classification problem (rain (positives) / no rain (negatives)). I remember @JesperDramsch using this analogy some time ago; I liked it very much and I think it's quite true! I also wonder if we can still maintain the "crispness" in the forecasts without clipping, as @sahahner mentioned.

But I still agree that the vanishing-gradient possibility is concerning, and we should explore it. I think the leaky approach @pinnstorm is exploring is quite interesting, and I would love to test it when back from holidays haha

MLflow run with the "falling off the cliff" behaviour:

https://mlflow.ecmwf.int/#/metric?runs=[%2232ae7d066de9438cbd15cf1acbf60d2a%22]&metric=%22val_wmae%2Fsfc_swvl1%2F1%22&experiments=[%22109%22]&plot_metric_keys=%5B%22train_wmse_step%22%5D&plot_layout={%22autosize%22:true,%22xaxis%22:{},%22yaxis%22:{%22type%22:%22log%22,%22autorange%22:true,%22exponentformat%22:%22e%22}}&x_axis=relative&y_axis_scale=linear&line_smoothness=1&show_point=false&deselected_curves=[]&last_linear_y_axis_range=[]

@clessig

clessig commented Apr 14, 2025

@gabrieloks, many thanks for the details. Another direction would be to pretrain with the "dynamics" primitive equation variables and then add some of the more problematic ones only later in the training when the model has learned the overall dynamics. This might also help to avoid learning spurious correlations.

@sahahner
Member

sahahner commented Apr 15, 2025

For the sea ice, we have moved some bounding functions to postprocessors to avoid the problem of vanishing gradients. I am working on the PR at the moment.
Great to see that using leaky bounding helps solve the vanishing-gradient problem. Then it might be useful to apply non-leaky postprocessors after the loss function to enforce physical consistency.
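
A hedged sketch of that split, assuming a generic train/inference setup (the actual anemoi postprocessor interface may look different):

```python
import torch

def training_loss(raw_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # The loss sees the raw (or leaky-bounded) prediction, so gradients never
    # vanish for negative raw values.
    return torch.nn.functional.mse_loss(raw_pred, target)

def postprocess(raw_pred: torch.Tensor) -> torch.Tensor:
    # The hard physical constraint is applied only to the released forecast,
    # outside the backward pass.
    return raw_pred.clamp(min=0.0)
```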
