Skip to content

Commit 50e05d7

Browse files
authored
Fix segfault (#35)
* Fix segfault * add test build file * Now with tests in bazel * fix lit
1 parent d8c2d85 commit 50e05d7

File tree

8 files changed

+84
-20
lines changed

8 files changed

+84
-20
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ steps:
6262
bazel --output_user_root=`pwd`/baztmp build :enzyme_ad
6363
cp bazel-bin/*.whl .
6464
python -m pip install *.whl
65-
cd test && python -m pip install "jax[cpu]" && python test.py && python bench_vs_xla.py
65+
python -m pip install "jax[cpu]"
66+
bazel --output_user_root=`pwd`/baztmp test --test_output=errors ...
6667
artifact_paths:
6768
- "*.whl"
6869

.github/workflows/build.yml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,7 @@ jobs:
4545
- name: test
4646
run: |
4747
python3 -m pip install --user --force-reinstall "jax[cpu]" *.whl
48-
cd test
49-
nm -C $(python3 -c "from enzyme_ad.jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache::
50-
python3 test.py
51-
python3 bench_vs_xla.py
52-
cd lit_tests
53-
lit . --verbose
48+
bazel test --test_output=errors ...
5449
5550
- name: Upload Build
5651
uses: actions/upload-artifact@v3

WORKSPACE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen
6060

6161
pip_install_dependencies()
6262

63-
ENZYME_COMMIT = "41e16dee0efd7f4c81e552e5899fe0068576573f"
64-
ENZYME_SHA256 = "fbfa6707db19b96ac8bd7f5edb58c4d028e11301873305fd7df7d7cd52abe66d"
63+
ENZYME_COMMIT = "e8ca2b1de3b770c767145d027b357bed97178bb0"
64+
ENZYME_SHA256 = "dd3789b8e749ed989d7a1a4956880826d9fd2501283dc60be468adec1186e63f"
6565

6666
http_archive(
6767
name = "enzyme",

src/enzyme_ad/jax/enzyme_call.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,6 @@ class CpuKernel {
972972
void **outs = num_out > 1 ? reinterpret_cast<void **>(out) : &out;
973973
for (int i = 0; i < num_out; i++) {
974974
void *data = outs[i];
975-
*(void **)(data) = 0;
976975
}
977976
auto fn = (void (*)(void **outs, void **ins))addr;
978977
fn(outs, ins);

test/BUILD

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
load("@rules_python//python:py_test.bzl", "py_test")
2+
load("@llvm-project//llvm:lit_test.bzl", "package_path", "lit_test")
3+
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
4+
5+
expand_template(
6+
name = "lit_site_cfg_py",
7+
testonly = True,
8+
out = "lit.site.cfg.py",
9+
substitutions = {
10+
"@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.",
11+
"@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"),
12+
"@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"),
13+
"@ENZYME_SOURCE_DIR@": "",
14+
"@ENZYME_BINARY_DIR@": "",
15+
},
16+
template = "lit.site.cfg.py.in",
17+
visibility = [":__subpackages__"],
18+
)
19+
20+
exports_files(
21+
["lit.cfg.py"],
22+
visibility = [":__subpackages__"],
23+
)
24+
25+
[
26+
lit_test(
27+
name = "%s.test" % src,
28+
srcs = [src],
29+
data = [
30+
":lit.cfg.py",
31+
":lit_site_cfg_py",
32+
"//src/enzyme_ad/jax:enzyme_jax_internal",
33+
"@llvm-project//clang:builtin_headers_gen",
34+
"@llvm-project//llvm:FileCheck",
35+
"@llvm-project//llvm:count",
36+
"@llvm-project//llvm:not",
37+
] + glob(["**/*.h"]),
38+
)
39+
for src in glob(
40+
[
41+
"**/*.pyt",
42+
],
43+
)
44+
]
45+
46+
py_test(
47+
name = "test",
48+
srcs = [
49+
"test.py",
50+
],
51+
deps = [
52+
"//src/enzyme_ad/jax:enzyme_jax_internal",
53+
],
54+
)
55+
56+
py_test(
57+
name = "bench_vs_xla",
58+
srcs = [
59+
"bench_vs_xla.py",
60+
],
61+
deps = [
62+
"//src/enzyme_ad/jax:enzyme_jax_internal",
63+
],
64+
)
65+

test/lit_tests/lit.cfg.py renamed to test/lit.cfg.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,7 @@
3434

3535
# Tweak the PATH to include the tools dir and the scripts dir.
3636
base_paths = [
37-
os.path.join(
38-
os.path.dirname(__file__),
39-
"..",
40-
"..",
41-
"bazel-bin",
42-
"external",
43-
"llvm-project",
44-
"llvm",
45-
),
37+
config.llvm_tools_dir,
4638
config.environment["PATH"],
4739
]
4840
path = os.path.pathsep.join(base_paths) # + config.extra_paths)

test/lit.site.cfg.py.in

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@LIT_SITE_CFG_IN_HEADER@
2+
3+
import os
4+
config.llvm_tools_dir = "@LLVM_TOOLS_BINARY_DIR@"
5+
6+
7+
if len("@ENZYME_BINARY_DIR@") == 0:
8+
config.llvm_tools_dir = os.getcwd() + "/" + config.llvm_tools_dir
9+
10+
cfgfile = os.path.dirname(os.path.abspath(__file__)) + "/lit.cfg.py"
11+
lit_config.load_config(config, cfgfile)

test/lit_tests/ir.pyt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import jax
44
import jax.numpy as jnp
55
from enzyme_ad.jax import cpp_call
66

7+
argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")
78

89
def do_something(ones, twos):
910
shape = jax.core.ShapedArray(tuple(3 * s for s in ones.shape), ones.dtype)
@@ -26,7 +27,7 @@ def do_something(ones, twos):
2627
}
2728
}
2829
}
29-
""",
30+
""", argv=argv,
3031
fn="myfn",
3132
)
3233
return a, b

0 commit comments

Comments
 (0)