Skip to content

Commit 4fba792

Browse files
authored
Generate JUnit XML report for tutorials (#4203)
Fixes #4055 --------- Signed-off-by: Pavel Chekin <[email protected]>
1 parent e4f83c9 commit 4fba792

File tree

5 files changed

+164
-54
lines changed

5 files changed

+164
-54
lines changed

.github/workflows/build-test-reusable.yml

+7
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,13 @@ jobs:
387387
name: pass_rate-${{ inputs.python_version }}-${{ inputs.runner_label || inputs.driver_version }}
388388
path: pass_rate*.json
389389

390+
- name: Upload tutorials test report
391+
uses: actions/upload-artifact@v4
392+
with:
393+
name: test-reports-tutorials-${{ inputs.python_version }}-${{ inputs.runner_label || inputs.driver_version }}
394+
include-hidden-files: true
395+
path: reports/tutorials.xml
396+
390397
- name: Upload tutorials performance report
391398
uses: actions/upload-artifact@v4
392399
with:

scripts/pass_rate.py

+41-18
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import os
88
import pathlib
99
import platform
10+
import xml.etree.ElementTree as et
11+
1012
from typing import List
1113

1214
from defusedxml.ElementTree import parse
@@ -124,31 +126,51 @@ def find_stats(stats: List[ReportStats], name: str) -> ReportStats:
124126
raise ValueError(f'{name} not found')
125127

126128

127-
def parse_junit_reports(args: argparse.Namespace) -> List[ReportStats]:
129+
def parse_junit_reports(reports_path: pathlib.Path) -> List[ReportStats]:
128130
"""Parses junit report in the specified directory."""
129-
reports_path = pathlib.Path(args.reports)
130131
return [parse_report(report) for report in reports_path.glob('*.xml')]
131132

132133

133-
def parse_tutorials_reports(args: argparse.Namespace) -> List[ReportStats]:
134-
"""Parses tutorials reports in the specified directory."""
135-
reports_path = pathlib.Path(args.reports)
136-
stats = ReportStats(name='tutorials')
137-
for report in reports_path.glob('tutorial-*.txt'):
138-
result = report.read_text().strip()
139-
stats.total += 1
134+
# pylint: disable=too-many-locals
135+
def generate_junit_report(reports_path: pathlib.Path):
136+
"""Parses info files for tutorials and generates JUnit report.
137+
The script `run_tutorial.py` generates `tutorial-*.json` files in the reports directory.
138+
This function loads them and generates `tutorials.xml` file (JUnit XML report) in the same
139+
directory.
140+
"""
141+
testsuites = et.Element('testsuites')
142+
testsuite = et.SubElement(testsuites, 'testsuite', name='tutorials')
143+
144+
total_tests, total_errors, total_failures, total_skipped = 0, 0, 0, 0
145+
total_time = 0.0
146+
147+
for item in reports_path.glob('tutorial-*.json'):
148+
data = json.loads(item.read_text())
149+
name, result, time = data['name'], data['result'], data.get('time', 0)
150+
testcase = et.SubElement(testsuite, 'testcase', name=name)
140151
if result == 'PASS':
141-
stats.passed += 1
152+
testcase.set('time', str(time))
142153
elif result == 'SKIP':
143-
stats.skipped += 1
154+
total_skipped += 1
155+
et.SubElement(testcase, 'skipped', type='pytest.skip')
144156
elif result == 'FAIL':
145-
stats.failed += 1
146-
return [stats]
157+
total_failures += 1
158+
et.SubElement(testcase, 'failure', message=data.get('message', ''))
159+
else:
160+
continue
161+
total_tests += 1
162+
total_time += time
147163

164+
testsuite.set('tests', str(total_tests))
165+
testsuite.set('errors', str(total_errors))
166+
testsuite.set('failures', str(total_failures))
167+
testsuite.set('skipped', str(total_skipped))
168+
testsuite.set('time', str(total_time))
148169

149-
def parse_reports(args: argparse.Namespace) -> List[ReportStats]:
150-
"""Parses all report in the specified directory."""
151-
return parse_junit_reports(args) + parse_tutorials_reports(args)
170+
report_path = reports_path / 'tutorials.xml'
171+
with report_path.open('wb') as f:
172+
tree = et.ElementTree(testsuites)
173+
tree.write(f, encoding='UTF-8', xml_declaration=True)
152174

153175

154176
def print_text_stats(stats: ReportStats):
@@ -194,9 +216,10 @@ def print_json_stats(stats: ReportStats):
194216
def main():
195217
"""Main."""
196218
args = create_argument_parser().parse_args()
197-
args.report_path = pathlib.Path(args.reports)
198219

199-
stats = parse_reports(args)
220+
reports_path = pathlib.Path(args.reports)
221+
generate_junit_report(reports_path)
222+
stats = parse_junit_reports(reports_path)
200223

201224
if args.suite == 'all':
202225
summary = overall_stats(stats)

scripts/pytest-utils.sh

+8-26
Original file line numberDiff line numberDiff line change
@@ -63,36 +63,18 @@ run_tutorial_test() {
6363
echo "****** Running $1 test ******"
6464
echo
6565

66-
TUTORIAL_RESULT=TODO
67-
68-
if [[ -f $TRITON_TEST_SKIPLIST_DIR/tutorials.txt ]]; then
69-
if grep --fixed-strings --quiet "$1" "$TRITON_TEST_SKIPLIST_DIR/tutorials.txt"; then
70-
TUTORIAL_RESULT=SKIP
71-
fi
72-
fi
73-
74-
if [[ $TRITON_TEST_REPORTS = true ]]; then
75-
RUN_TUTORIAL="python -u $SCRIPTS_DIR/run_tutorial.py --reports $TRITON_TEST_REPORTS_DIR $1.py"
76-
else
77-
RUN_TUTORIAL="python -u $1.py"
78-
fi
79-
80-
if [[ $TUTORIAL_RESULT = TODO ]]; then
81-
if $RUN_TUTORIAL; then
82-
TUTORIAL_RESULT=PASS
83-
else
84-
TUTORIAL_RESULT=FAIL
85-
fi
86-
fi
66+
run_tutorial_args=(
67+
"--skip-list=$TRITON_TEST_SKIPLIST_DIR/tutorials.txt"
68+
"$1.py"
69+
)
8770

8871
if [[ $TRITON_TEST_REPORTS = true ]]; then
89-
mkdir -p "$TRITON_TEST_REPORTS_DIR"
90-
echo $TUTORIAL_RESULT > "$TRITON_TEST_REPORTS_DIR/tutorial-$1.txt"
72+
run_tutorial_args+=(
73+
"--reports=$TRITON_TEST_REPORTS_DIR"
74+
)
9175
fi
9276

93-
if [[ $TUTORIAL_RESULT = FAIL && $TRITON_TEST_IGNORE_ERRORS = false ]]; then
94-
exit 1
95-
fi
77+
python -u "$SCRIPTS_DIR/run_tutorial.py" "${run_tutorial_args[@]}" || $TRITON_TEST_IGNORE_ERRORS
9678
}
9779

9880
capture_runtime_env() {

scripts/run_tutorial.py

+72-10
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
"""Script to run a Triton tutorial and collect generated csv files."""
22

33
import argparse
4+
import dataclasses
5+
import datetime
46
import importlib.util
7+
import json
58
import pathlib
69
import shutil
710
import tempfile
11+
import sys
12+
from typing import Set, Optional
813

914
import triton.testing
1015

@@ -33,38 +38,95 @@ def create_argument_parser() -> argparse.ArgumentParser:
3338
"""Creates argument parser."""
3439
parser = argparse.ArgumentParser()
3540
parser.add_argument('tutorial', help='Tutorial to run')
36-
parser.add_argument('--reports', required=False, type=str, default='.',
37-
help='Directory to store tutorial CSV reports, default: %(default)s')
41+
parser.add_argument('--reports', required=False, type=str, help='Directory to store tutorial CSV reports')
42+
parser.add_argument('--skip-list', required=False, type=str, help='Skip list for tutorials')
3843
return parser
3944

4045

41-
def run_tutorial(path: pathlib.Path):
42-
"""Runs """
46+
def get_skip_list(path: pathlib.Path) -> Set[str]:
47+
"""Loads skip list from the specified file."""
48+
skip_list = set()
49+
if path.exists() and path.is_file():
50+
for item in path.read_text().splitlines():
51+
skip_list.add(item.strip())
52+
return skip_list
53+
54+
55+
def run_tutorial(path: pathlib.Path) -> float:
56+
"""Runs tutorial."""
4357
spec = importlib.util.spec_from_file_location('__main__', path)
4458
if not spec or not spec.loader:
4559
raise AssertionError(f'Failed to load module from {path}')
4660
module = importlib.util.module_from_spec(spec)
47-
# set __file__ to the absolute name, a workaround for 10i-experimental-block-pointer, which
61+
# Set __file__ to the absolute name, a workaround for 10i-experimental-block-pointer, which
4862
# uses dirname of its location to find 10-experimental-block-pointer.
4963
module.__file__ = path.resolve().as_posix()
64+
# Reset sys.argv because some tutorials, such as 09, parse their command line arguments.
65+
sys.argv = [str(path)]
66+
start_time = datetime.datetime.now()
5067
spec.loader.exec_module(module)
68+
elapsed_time = datetime.datetime.now() - start_time
69+
return elapsed_time.total_seconds()
70+
71+
72+
@dataclasses.dataclass
73+
class TutorialInfo:
74+
"""A record about tutorial execution."""
75+
name: str
76+
path: Optional[pathlib.Path] = None
77+
78+
def _report(self, **kwargs):
79+
"""Writes tutorial info."""
80+
if self.path:
81+
with self.path.open(mode='w') as f:
82+
json.dump({'name': self.name} | kwargs, f)
83+
84+
def report_pass(self, time: float):
85+
"""Reports successful tutorial."""
86+
self._report(result='PASS', time=time)
87+
88+
def report_skip(self):
89+
"""Reports skipped tutorial."""
90+
self._report(result='SKIP')
91+
92+
def report_fail(self, message: str):
93+
"""Reports failed tutorial."""
94+
self._report(result='FAIL', message=message)
5195

5296

5397
def main():
5498
"""Main."""
99+
skip_list = set()
55100
args = create_argument_parser().parse_args()
56101
tutorial_path = pathlib.Path(args.tutorial)
57-
reports_path = pathlib.Path(args.reports)
102+
103+
reports_path = pathlib.Path(args.reports) if args.reports else None
104+
if args.skip_list:
105+
skip_list = get_skip_list(pathlib.Path(args.skip_list))
106+
58107
name = tutorial_path.stem
59-
report_path = reports_path / name
60-
report_path.mkdir(parents=True, exist_ok=True)
108+
if reports_path:
109+
report_path = reports_path / name
110+
report_path.mkdir(parents=True, exist_ok=True)
111+
info = TutorialInfo(name=name, path=reports_path / f'tutorial-{name}.json')
112+
else:
113+
info = TutorialInfo(name=name)
61114

62115
def perf_report(benchmarks):
63116
"""Marks a function for benchmarking."""
64117
return lambda fn: CustomMark(fn, benchmarks, report_path)
65118

66-
triton.testing.perf_report = perf_report
67-
run_tutorial(tutorial_path)
119+
if name in skip_list:
120+
info.report_skip()
121+
else:
122+
if reports_path:
123+
triton.testing.perf_report = perf_report
124+
try:
125+
time = run_tutorial(tutorial_path)
126+
info.report_pass(time)
127+
except Exception as e:
128+
info.report_fail(str(e))
129+
raise
68130

69131

70132
if __name__ == '__main__':

scripts/test_pass_rate.py

+36
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from xml.etree.ElementTree import parse
2+
13
import pass_rate
24

35
WARNINGS = """\
@@ -34,3 +36,37 @@ def test_get_warnings(tmp_path):
3436
assert len(warnings) == 1
3537
assert warnings[0].location == 'location'
3638
assert warnings[0].message == 'message'
39+
40+
41+
def test_generate_junit_report(tmp_path):
42+
file1 = tmp_path / 'tutorial-1.json'
43+
file2 = tmp_path / 'tutorial-2.json'
44+
file3 = tmp_path / 'tutorial-3.json'
45+
46+
file1.write_text('{"name": "1", "result": "PASS", "time": 1.0}')
47+
file2.write_text('{"name": "2", "result": "SKIP"}')
48+
file3.write_text('{"name": "3", "result": "FAIL", "message": "Error"}')
49+
50+
pass_rate.generate_junit_report(tmp_path)
51+
52+
report_path = tmp_path / 'tutorials.xml'
53+
assert report_path.exists()
54+
55+
xml = parse(report_path)
56+
testsuites = xml.getroot()
57+
assert testsuites.tag == 'testsuites'
58+
59+
testsuite = testsuites[0]
60+
assert testsuite.tag == 'testsuite'
61+
assert testsuite.get('name') == 'tutorials'
62+
assert testsuite.get('tests') == '3'
63+
assert testsuite.get('skipped') == '1'
64+
assert testsuite.get('failures') == '1'
65+
66+
stats = pass_rate.parse_junit_reports(tmp_path)
67+
assert len(stats) == 1
68+
assert stats[0].name == 'tutorials'
69+
assert stats[0].passed == 1
70+
assert stats[0].failed == 1
71+
assert stats[0].skipped == 1
72+
assert stats[0].total == 3

0 commit comments

Comments
 (0)