Should I reset train state metrics after each epoch of training? #3284
-
Hi, the MNIST example at https://flax.readthedocs.io/en/latest/getting_started.html resets the state metrics to empty after each training epoch:

```python
for step, batch in enumerate(train_ds.as_numpy_iterator()):

  # Run optimization steps over training batches and compute batch metrics
  state = train_step(state, batch)  # get updated train state (which contains the updated parameters)
  state = compute_metrics(state=state, batch=batch)  # aggregate batch metrics

  if (step + 1) % num_steps_per_epoch == 0:  # one training epoch has passed
    for metric, value in state.metrics.compute().items():  # compute metrics
      metrics_history[f'train_{metric}'].append(value)  # record metrics
    state = state.replace(metrics=state.metrics.empty())  # reset train_metrics for next training epoch
    ...
```

But another MNIST example at https://github.com/google/flax/blob/main/examples/mnist/train.py#L118 does not use this technique, and most examples I've found online don't either. What is the benefit of resetting the metrics, and when should I use it?
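For context, that second example doesn't keep any metric state on the train state at all. Roughly (this is a paraphrase, not the exact code from train.py; `make_batches`, `apply_model`, and `update_model` are placeholder names), it just collects per-batch values in plain Python lists and averages them at the end of each epoch, so there is nothing to reset:

```python
import numpy as np

def train_epoch(state, train_ds, batch_size, rng):
  """One epoch of training; returns the epoch-mean loss and accuracy."""
  epoch_loss, epoch_accuracy = [], []
  for batch in make_batches(train_ds, batch_size, rng):  # hypothetical batching helper
    grads, loss, accuracy = apply_model(state, batch['image'], batch['label'])
    state = update_model(state, grads)
    epoch_loss.append(loss)          # fresh lists every epoch,
    epoch_accuracy.append(accuracy)  # so no explicit reset is needed
  return state, np.mean(epoch_loss), np.mean(epoch_accuracy)
```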
-
Here's a comparison of training with and without resetting the metrics.
Apparently, resetting the metrics causes some statistical differences in the reported values. So, should we reset the metrics after each training epoch?
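If it helps, here's how I think the difference arises: without the `.empty()` reset, `compute()` at the end of epoch N averages over every batch seen since step 0, not just over epoch N. A tiny numerical illustration with made-up per-batch losses:

```python
# Made-up per-batch losses for two epochs.
epoch1 = [1.0, 0.8, 0.6]
epoch2 = [0.4, 0.3, 0.2]

# With a reset after each epoch, each reported value covers one epoch only.
with_reset = [sum(epoch1) / len(epoch1),   # 0.8
              sum(epoch2) / len(epoch2)]   # 0.3

# Without a reset, the metric keeps accumulating, so the second reported
# value is the running average over all batches seen so far.
without_reset = [sum(epoch1) / len(epoch1),                    # 0.8
                 sum(epoch1 + epoch2) / len(epoch1 + epoch2)]  # 0.55

print(with_reset, without_reset)
```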
In practice you probably just want to use per-step metrics during training; that is, don't use cumulative metrics for training. Using cumulative metrics during training only really makes sense to get a smooth plot in a notebook environment. Tools such as TensorBoard or wandb let you smooth the metrics afterwards, which is much more convenient because you can still see the raw data and won't miss loss spikes.
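A minimal sketch of that per-step approach, assuming a `train_step` modified to also return the batch loss (which differs from the tutorial's signature) and Flax's TensorBoard writer; the log directory and tag name are just examples, and `wandb.log` works the same way:

```python
from flax.metrics import tensorboard

summary_writer = tensorboard.SummaryWriter('/tmp/mnist_logs')  # example log dir

for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # Assumes train_step returns the updated state plus the per-batch loss,
  # instead of folding it into a cumulative metric on the train state.
  state, loss = train_step(state, batch)
  # Log the raw per-step value; smoothing can be done later in the UI.
  summary_writer.scalar('train_loss', float(loss), step)
```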