[BUG] aux_loss and z_loss are incorrect when using calculate_per_token_loss and cp #1652

Open
@Suparjie

Describe the bug

  1. If calculate_per_token_loss is used and cp > 1:
    First, aux_loss is divided by the square of the full num_tokens (which already accounts for CP):
    num_tokens = probs.shape[0] * num_sub_sequence
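The first step's normalization can be illustrated with a minimal pure-Python sketch. It assumes a Switch-Transformer-style load-balancing loss (coeff * num_experts * the sum over experts of router probability mass times routed-token count, divided by num_tokens² * topk) and assumes tokens_per_expert is already the count over the whole sequence. The function switch_aux_loss and its arguments are hypothetical, not Megatron-LM's API. The sketch shows that, with the full num_tokens in the denominator, each shard's loss is a partial term and the partials add up to the unsharded loss.

```python
from typing import List


def switch_aux_loss(
    probs: List[List[float]],       # [num_local_tokens][num_experts] router probabilities
    tokens_per_expert: List[int],   # routed-token counts, assumed summed over all shards
    topk: int,
    coeff: float,
    num_sub_sequence: int = 1,      # number of sequence shards (e.g. the CP size)
) -> float:
    """Hypothetical restatement of a Switch-style load-balancing loss."""
    num_experts = len(tokens_per_expert)
    # Full token count: local tokens times the number of sequence shards.
    num_tokens = len(probs) * num_sub_sequence
    # Router probability mass per expert, summed over the LOCAL tokens only.
    probs_per_expert = [sum(row[e] for row in probs) for e in range(num_experts)]
    dot = sum(p * t for p, t in zip(probs_per_expert, tokens_per_expert))
    # Divide by the square of the FULL token count.
    return dot * num_experts * coeff / (num_tokens * num_tokens * topk)


# 4 tokens, 2 experts, top-1 routing: tokens 0 and 2 go to expert 0, tokens 1 and 3 to expert 1.
probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
tokens_per_expert = [2, 2]
full = switch_aux_loss(probs, tokens_per_expert, topk=1, coeff=1.0)
# Split the sequence into 2 shards (e.g. cp = 2) and add the per-shard losses.
sharded = sum(
    switch_aux_loss(shard, tokens_per_expert, topk=1, coeff=1.0, num_sub_sequence=2)
    for shard in (probs[:2], probs[2:])
)
print(abs(full - sharded) < 1e-12)  # -> True
```

Because each shard only holds a partial term, the scaling applied afterwards decides how much weight the full-sequence loss ends up with, which is what the remaining steps are about.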

Second, aux_loss is multiplied by num_local_tokens here:

activation = MoEAuxLossAutoScaler.apply(activation, aux_loss * activation.shape[0])

Finally, the finalize_model_grads function scales both the main_loss gradient and the aux_loss gradient by 1/(num_local_tokens * dp_size * num_micro_batches).
However, this num_local_tokens is not local but the full count, because the loss function all-reduces the token count across the CP group:

torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())

So aux_loss should be scaled by the full num_tokens (accounting for both CP and SP), not by num_local_tokens.
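The scaling chain in point 1 can be checked with exact fractions. This is a hypothetical model, not Megatron-LM code: aux_grad_weight is an illustrative name, the sequence is assumed to split evenly into num_sub_sequence shards (CP, times TP when SP is on), per-shard aux_loss gradients are assumed to be summed across shards, DP ranks and micro-batches, and the intended net weight is assumed to be 1/(dp_size * num_micro_batches).

```python
from fractions import Fraction


def aux_grad_weight(
    num_full_tokens: int,
    num_sub_sequence: int,
    dp: int,
    num_microbatches: int,
    use_full_tokens: bool,
) -> Fraction:
    """Weight on one replica's full-sequence aux_loss gradient (hypothetical model)."""
    num_local_tokens = num_full_tokens // num_sub_sequence  # assumes an even split
    # Step 2: multiplier applied to aux_loss before backward.
    multiplier = num_full_tokens if use_full_tokens else num_local_tokens
    # Step 3: the token count is CP-all-reduced (full), then summed over DP ranks
    # and micro-batches; gradients are divided by that total.
    total_tokens = num_full_tokens * dp * num_microbatches
    # Shard partials are summed by the gradient reduction, so the weight on the
    # full-sequence gradient is multiplier / total_tokens.
    return Fraction(multiplier, total_tokens)


# 4096 tokens per sequence, cp = 2, dp = 4, 8 micro-batches
print(aux_grad_weight(4096, 2, 4, 8, use_full_tokens=False))  # -> 1/64
print(aux_grad_weight(4096, 2, 4, 8, use_full_tokens=True))   # -> 1/32
```

Under these assumptions, multiplying by num_local_tokens leaves an extra 1/cp factor (1/(cp * dp * num_micro_batches)), while multiplying by the full token count cancels the per-token normalization and leaves 1/(dp * num_micro_batches).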

  2. If calculate_per_token_loss is not used but cp > 1, the gradient is divided by dp * cp in the finalize_model_grads function. lm_loss is multiplied by cp in advance to compensate, but aux_loss is not. Should aux_loss also be multiplied by cp?
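The question in point 2 reduces to similar bookkeeping. The sketch below is a hypothetical model: it assumes gradients are averaged over dp * cp ranks and divided by num_microbatches, that each CP rank's lm_loss and aux_loss are partial terms summing to the full-sequence loss, and that only lm_loss is pre-multiplied by cp. effective_weights is an illustrative helper, not part of Megatron-LM.

```python
from fractions import Fraction
from typing import Tuple


def effective_weights(
    cp: int, dp: int, num_microbatches: int, scale_aux_by_cp: bool = False
) -> Tuple[Fraction, Fraction]:
    """Weights on one replica's full-sequence lm_loss and aux_loss gradients (hypothetical model)."""
    # Partial gradients from the cp ranks sum to the full gradient; every
    # gradient is then averaged over dp * cp ranks and divided by num_microbatches.
    averaging = Fraction(1, dp * cp * num_microbatches)
    lm_weight = cp * averaging  # lm_loss is pre-multiplied by cp
    aux_weight = (cp if scale_aux_by_cp else 1) * averaging
    return lm_weight, aux_weight


# cp = 2, dp = 4, 8 micro-batches
print(effective_weights(cp=2, dp=4, num_microbatches=8))
# -> (Fraction(1, 32), Fraction(1, 64))
print(effective_weights(cp=2, dp=4, num_microbatches=8, scale_aux_by_cp=True))
# -> (Fraction(1, 32), Fraction(1, 32))
```

Under these assumptions, aux_loss ends up weighted 1/cp relative to lm_loss, and multiplying it by cp would restore parity.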
