diff --git a/torchci/scripts/check_alerts.py b/torchci/scripts/check_alerts.py index 9b9da5b282..d1ff4e3cac 100755 --- a/torchci/scripts/check_alerts.py +++ b/torchci/scripts/check_alerts.py @@ -444,8 +444,16 @@ def handle_flaky_tests_alert(existing_alerts: List[Dict]) -> Dict: return None -def check_for_recurrently_failing_jobs_alert(repo: str, branch: str, dry_run: bool): +# filter job names that don't match the regex +def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]: + if job_name_regex: + return [job_name for job_name in job_names if re.match(job_name_regex, job_name)] + return job_names + + +def check_for_recurrently_failing_jobs_alert(repo: str, branch: str, job_name_regex: str, dry_run: bool): job_names, sha_grid = fetch_hud_data(repo=repo, branch=branch) + job_names = filter_job_names(job_names, job_name_regex) (jobs_to_alert_on, flaky_jobs) = classify_jobs(job_names, sha_grid) # Fetch alerts @@ -497,6 +505,12 @@ def parse_args() -> argparse.Namespace: type=str, default=os.getenv("BRANCH_TO_CHECK", "master") ) + parser.add_argument( + "--job-name-regex", + help="Consider only job names matching given regex (if omitted, all jobs are matched)", + type=str, + default=os.getenv("JOB_NAME_REGEX", "") + ) parser.add_argument( "--with-flaky-test-alert", help="Run this script with the flaky test alerting", @@ -514,7 +528,7 @@ def parse_args() -> argparse.Namespace: def main(): args = parse_args() - check_for_recurrently_failing_jobs_alert(args.repo, args.branch, args.dry_run) + check_for_recurrently_failing_jobs_alert(args.repo, args.branch, args.job_name_regex, args.dry_run) # TODO: Fill out dry run for flaky test alerting, not going to do in one PR if args.with_flaky_test_alert: check_for_no_flaky_tests_alert() diff --git a/torchci/scripts/test_check_alerts.py b/torchci/scripts/test_check_alerts.py index 6fd66e9aa4..8c6a9393c0 100644 --- a/torchci/scripts/test_check_alerts.py +++ b/torchci/scripts/test_check_alerts.py @@ -3,6 +3,7 @@ from unittest.mock import patch from check_alerts import ( + filter_job_names, gen_update_comment, generate_no_flaky_tests_issue, handle_flaky_tests_alert, @@ -126,6 +127,36 @@ def test_handle_flaky_tests_alert( res = handle_flaky_tests_alert(existing_alerts) self.assertDictEqual(res, mock_issue) + # test filter job names + def test_job_filter(self): + job_names = ["pytorch_linux_xenial_py3_6_gcc5_4_test", "pytorch_linux_xenial_py3_6_gcc5_4_test2"] + self.assertListEqual( + filter_job_names(job_names, ""), + job_names, + "empty regex should match all jobs" + ) + self.assertListEqual( + filter_job_names(job_names, ".*"), + job_names + ) + self.assertListEqual( + filter_job_names(job_names, ".*xenial.*"), + job_names + ) + self.assertListEqual( + filter_job_names(job_names, ".*xenial.*test2"), + ["pytorch_linux_xenial_py3_6_gcc5_4_test2"] + ) + self.assertListEqual( + filter_job_names(job_names, ".*xenial.*test3"), + [] + ) + self.assertRaises( + Exception, + lambda: filter_job_names(job_names, "["), + msg="malformed regex should throw exception" + ) + if __name__ == "__main__": main()