Describe the bug
- If `calculate_per_token_loss` is used and cp > 1:
  1. aux_loss is first divided by the square of the full num_tokens (which already accounts for cp);
  2. aux_loss is then scaled by num_local_tokens here;
  3. finally, both the main_loss gradient and the aux_loss gradient are scaled by 1/(num_local_tokens * dp_size * num_micro_batches) in the `finalize_model_grads` function.

  However, the num_local_tokens used there is not the local count but the full count (Line 179 in a845aa7), so aux_loss should be scaled by the full num_tokens (accounting for cp and sp), not by num_local_tokens. See the sketch after this list.
- If `calculate_per_token_loss` is not used but cp > 1, the gradient is divided by dp * cp in the `finalize_model_grads` function. lm_loss is scaled by cp in advance, but aux_loss is not, so should we multiply aux_loss by cp? (See the second function in the sketch below.)
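
To make the arithmetic above easier to check, here is a minimal numeric sketch (plain Python, not Megatron-LM code) of the two cases. All names in it (`tokens_local`, `tokens_full`, `step2_scale`, `cp`, `dp`, `num_micro_batches`) are hypothetical placeholders for the quantities described above, and case 1 assumes, as stated above, that `finalize_model_grads` uses the full token count while the current step-2 multiplier is the per-cp-rank count.

```python
def per_token_loss_aux_coeff(tokens_full, step2_scale, dp, num_micro_batches):
    """Case 1 (calculate_per_token_loss enabled, cp > 1).

    Effective multiplier applied to the raw load-balancing statistic:
      step 1: divide by the square of the full num_tokens,
      step 2: multiply by `step2_scale` (num_local_tokens today,
              the full num_tokens under the proposed fix),
      step 3: divide by num_tokens * dp_size * num_micro_batches in
              finalize_model_grads, where (per this report) that token
              count is the full one, not the local one.
    """
    coeff = 1.0 / tokens_full ** 2                      # step 1
    coeff *= step2_scale                                # step 2
    coeff /= tokens_full * dp * num_micro_batches       # step 3
    return coeff


def non_per_token_loss_factors(cp, dp):
    """Case 2 (calculate_per_token_loss disabled, cp > 1).

    finalize_model_grads divides gradients by dp * cp; lm_loss is
    pre-scaled by cp, aux_loss is not.
    """
    lm_factor = cp / (dp * cp)     # = 1 / dp
    aux_factor = 1.0 / (dp * cp)   # smaller than lm_factor by a factor of cp
    return lm_factor, aux_factor


if __name__ == "__main__":
    tokens_local, cp, dp, num_micro_batches = 1024, 2, 4, 8
    tokens_full = tokens_local * cp

    # Case 1: the two step-2 choices differ by exactly a factor of cp.
    print(per_token_loss_aux_coeff(tokens_full, tokens_local, dp, num_micro_batches))
    print(per_token_loss_aux_coeff(tokens_full, tokens_full, dp, num_micro_batches))

    # Case 2: lm_loss ends up with 1/dp, aux_loss with 1/(dp * cp).
    print(non_per_token_loss_factors(cp, dp))
```

If this reading of the code is right, both cases leave the aux loss smaller than the main loss by a factor of cp, which is the gap the two questions above are pointing at.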