File tree Expand file tree Collapse file tree 4 files changed +32
-8
lines changed Expand file tree Collapse file tree 4 files changed +32
-8
lines changed Original file line number Diff line number Diff line change 62
62
bazel --output_user_root=`pwd`/baztmp build :enzyme_ad
63
63
cp bazel-bin/*.whl .
64
64
python -m pip install *.whl
65
- cd test && python -m pip install "jax[cpu]" && python test.py && python bench_vs_xla.py
65
+ python -m pip install "jax[cpu]"
66
+ bazel --output_user_root=`pwd`/baztmp test --test_output=errors ...
66
67
artifact_paths :
67
68
- " *.whl"
68
69
Original file line number Diff line number Diff line change 45
45
- name : test
46
46
run : |
47
47
python3 -m pip install --user --force-reinstall "jax[cpu]" *.whl
48
- cd test
49
- nm -C $(python3 -c "from enzyme_ad.jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache::
50
- python3 test.py
51
- python3 bench_vs_xla.py
52
- cd lit_tests
53
- lit . --verbose
48
+ bazel test --test_output=errors ...
54
49
55
50
- name : Upload Build
56
51
uses : actions/upload-artifact@v3
Original file line number Diff line number Diff line change 1
1
load ("@rules_python//python:py_test.bzl" , "py_test" )
2
+ load ("@llvm-project//llvm:lit_test.bzl" , "package_path" , "lit_test" )
3
+
4
+
5
+ exports_files (
6
+ ["lit_tests/lit.cfg.py" ],
7
+ visibility = [":__subpackages__" ],
8
+ )
9
+
10
+ [
11
+ lit_test (
12
+ name = "%s.test" % src ,
13
+ srcs = [src ],
14
+ data = [
15
+ "//test:lit_tests/lit.cfg.py" ,
16
+ "//src/enzyme_ad/jax:enzyme_jax_internal" ,
17
+ "@llvm-project//clang:builtin_headers_gen" ,
18
+ "@llvm-project//llvm:FileCheck" ,
19
+ "@llvm-project//llvm:count" ,
20
+ "@llvm-project//llvm:not" ,
21
+ ] + glob (["**/*.h" ]),
22
+ )
23
+ for src in glob (
24
+ [
25
+ "**/*.pyt" ,
26
+ ],
27
+ )
28
+ ]
2
29
3
30
py_test (
4
31
name = "test" ,
Original file line number Diff line number Diff line change @@ -4,6 +4,7 @@ import jax
4
4
import jax .numpy as jnp
5
5
from enzyme_ad .jax import cpp_call
6
6
7
+ argv = ("-I/usr/include/c++/11" , "-I/usr/include/x86_64-linux-gnu/c++/11" )
7
8
8
9
def do_something (ones , twos ):
9
10
shape = jax .core .ShapedArray (tuple (3 * s for s in ones .shape ), ones .dtype )
@@ -26,7 +27,7 @@ def do_something(ones, twos):
26
27
}
27
28
}
28
29
}
29
- """ ,
30
+ """ , argv = argv ,
30
31
fn = "myfn" ,
31
32
)
32
33
return a , b
You can’t perform that action at this time.
0 commit comments