Skip to content

Commit 354cb65

Browse files
committed
[model cards] Keep evaluation order in training logs if there's multiple evaluators (#2963)
Also rename "loss" to Validation Loss
1 parent 6a3750e commit 354cb65

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

sentence_transformers/model_card.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def on_evaluate(
152152
**kwargs,
153153
) -> None:
154154
loss_dict = {" ".join(key.split("_")[1:]): metrics[key] for key in metrics if key.endswith("_loss")}
155+
if len(loss_dict) == 1 and "loss" in loss_dict:
156+
loss_dict = {"Validation Loss": loss_dict["loss"]}
155157
if (
156158
model.model_card_data.training_logs
157159
and model.model_card_data.training_logs[-1]["Step"] == state.global_step
@@ -830,19 +832,25 @@ def try_to_pure_python(value: Any) -> Any:
830832

831833
def format_training_logs(self):
832834
# Get the keys from all evaluation lines
833-
eval_lines_keys = {key for lines in self.training_logs for key in lines.keys()}
835+
eval_lines_keys = []
836+
for lines in self.training_logs:
837+
for key in lines.keys():
838+
if key not in eval_lines_keys:
839+
eval_lines_keys.append(key)
834840

835841
# Sort the metric columns: Epoch, Step, Training Loss, Validation Loss, Evaluator results
836842
def sort_metrics(key: str) -> str:
837843
if key == "Epoch":
838-
return "0"
844+
return 0
839845
if key == "Step":
840-
return "1"
846+
return 1
841847
if key == "Training Loss":
842-
return "2"
848+
return 2
849+
if key == "Validation Loss":
850+
return 3
843851
if key.endswith("loss"):
844-
return "3"
845-
return key
852+
return 4
853+
return eval_lines_keys.index(key) + 5
846854

847855
sorted_eval_lines_keys = sorted(eval_lines_keys, key=sort_metrics)
848856
training_logs = [

0 commit comments

Comments
 (0)