Skip to content

Commit c082690

Browse files
committed
Now with tests in bazel
1 parent ad97658 commit c082690

File tree

4 files changed

+32
-8
lines changed

4 files changed

+32
-8
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ steps:
6262
bazel --output_user_root=`pwd`/baztmp build :enzyme_ad
6363
cp bazel-bin/*.whl .
6464
python -m pip install *.whl
65-
cd test && python -m pip install "jax[cpu]" && python test.py && python bench_vs_xla.py
65+
python -m pip install "jax[cpu]"
66+
bazel --output_user_root=`pwd`/baztmp test --test_output=errors ...
6667
artifact_paths:
6768
- "*.whl"
6869

.github/workflows/build.yml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,7 @@ jobs:
4545
- name: test
4646
run: |
4747
python3 -m pip install --user --force-reinstall "jax[cpu]" *.whl
48-
cd test
49-
nm -C $(python3 -c "from enzyme_ad.jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache::
50-
python3 test.py
51-
python3 bench_vs_xla.py
52-
cd lit_tests
53-
lit . --verbose
48+
bazel test --test_output=errors ...
5449
5550
- name: Upload Build
5651
uses: actions/upload-artifact@v3

test/BUILD

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,31 @@
11
load("@rules_python//python:py_test.bzl", "py_test")
2+
load("@llvm-project//llvm:lit_test.bzl", "package_path", "lit_test")
3+
4+
5+
exports_files(
6+
["lit_tests/lit.cfg.py"],
7+
visibility = [":__subpackages__"],
8+
)
9+
10+
[
11+
lit_test(
12+
name = "%s.test" % src,
13+
srcs = [src],
14+
data = [
15+
"//test:lit_tests/lit.cfg.py",
16+
"//src/enzyme_ad/jax:enzyme_jax_internal",
17+
"@llvm-project//clang:builtin_headers_gen",
18+
"@llvm-project//llvm:FileCheck",
19+
"@llvm-project//llvm:count",
20+
"@llvm-project//llvm:not",
21+
] + glob(["**/*.h"]),
22+
)
23+
for src in glob(
24+
[
25+
"**/*.pyt",
26+
],
27+
)
28+
]
229

330
py_test(
431
name = "test",

test/lit_tests/ir.pyt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import jax
44
import jax.numpy as jnp
55
from enzyme_ad.jax import cpp_call
66

7+
argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")
78

89
def do_something(ones, twos):
910
shape = jax.core.ShapedArray(tuple(3 * s for s in ones.shape), ones.dtype)
@@ -26,7 +27,7 @@ def do_something(ones, twos):
2627
}
2728
}
2829
}
29-
""",
30+
""", argv=argv,
3031
fn="myfn",
3132
)
3233
return a, b

0 commit comments

Comments
 (0)