Skip to content

Commit d9cc780

Browse files
authored
feat: add a check for whether values in score_df are NaN (#756)
* add a check for whether values in score_df are NaN
* fix CI
* change raise to assert
1 parent 113889f commit d9cc780

File tree

3 files changed: +17 lines added, -0 lines removed

rdagent/components/coder/data_science/ensemble/eval_tests/ensemble_test.txt

+5
Original file line number | Diff line number | Diff line change
@@ -128,5 +128,10 @@ assert model_set_in_scores == set({{model_names}}).union({"ensemble"}), (
128128
assert score_df.index.is_unique, "The scores dataframe has duplicate model names."
129129
assert score_df.columns.tolist() == ["{{metric_name}}"], f"The column names of the scores dataframe should be ['{{metric_name}}'], but is '{score_df.columns.tolist()}'"
130130

131+
# Check for NaN values in score_df
132+
assert not score_df.isnull().values.any(), (
133+
f"The scores dataframe contains NaN values at the following locations:\n{score_df[score_df.isnull().any(axis=1)]}"
134+
)
135+
131136

132137
print("Ensemble test end.")

rdagent/components/coder/data_science/pipeline/eval.py

+6
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,12 @@ def evaluate(
8282
score_check_text += f"\n[Error] The scores dataframe does not contain the correct column names.\nCorrect columns is: ['{self.scen.metric_name}']\nBut got: {score_df.columns.tolist()}"
8383
score_ret_code = 1
8484

85+
# Check if scores contain NaN (values)
86+
if score_df.isnull().values.any():
87+
nan_locations = score_df[score_df.isnull().any(axis=1)]
88+
score_check_text += f"\n[Error] The scores dataframe contains NaN values at the following locations:\n{nan_locations}"
89+
score_ret_code = 1
90+
8591
except Exception as e:
8692
score_check_text += f"\n[Error] in checking the scores.csv file: {e}\nscores.csv's content:\n-----\n{score_fp.read_text()}\n-----"
8793
score_ret_code = 1

rdagent/components/coder/data_science/workflow/eval.py

+6
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,12 @@ def evaluate(
106106
score_check_text += f"\n[Error] The scores dataframe does not contain the correct column names.\nCorrect columns is: ['{self.scen.metric_name}']\nBut got: {score_df.columns.tolist()}"
107107
score_ret_code = 1
108108

109+
# Check if scores contain NaN (values)
110+
if score_df.isnull().values.any():
111+
nan_locations = score_df[score_df.isnull().any(axis=1)]
112+
score_check_text += f"\n[Error] The scores dataframe contains NaN values at the following locations:\n{nan_locations}"
113+
score_ret_code = 1
114+
109115
except Exception as e:
110116
score_check_text += f"\n[Error] in checking the scores.csv file: {e}\nscores.csv's content:\n-----\n{score_fp.read_text()}\n-----"
111117
score_ret_code = 1

0 commit comments

Comments (0)