Skip to content

Commit 88389b5

Browse files
committed
passing test for dualv
1 parent 49518cd commit 88389b5

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed

tests/codegen/autodiffv2.rs

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
//
5+
// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many
6+
// breakages. One benefit is that we match the IR generated by Enzyme only after running it
7+
// through LLVM's O3 pipeline, which will remove most of the noise.
8+
// However, our integration test could also be affected by changes in how rustc lowers MIR into
9+
// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
10+
// reduce this test to only match the first lines and the ret instructions.
11+
12+
#![feature(autodiff)]
13+
14+
use std::autodiff::autodiff;
15+
16+
#[no_mangle]
17+
//#[autodiff(d_square1, Forward, Dual, Dual)]
18+
#[autodiff(d_square2, Forward, 4, Dualv, Dualv)]
19+
#[autodiff(d_square3, Forward, 4, Dual, Dual)]
20+
fn square(x: &[f32], y: &mut [f32]) {
21+
assert!(x.len() >= 4);
22+
assert!(y.len() >= 5);
23+
y[0] = 4.3 * x[0] + 1.2 * x[1] + 3.4 * x[2] + 2.1 * x[3];
24+
y[1] = 2.3 * x[0] + 4.5 * x[1] + 1.7 * x[2] + 6.4 * x[3];
25+
y[2] = 1.1 * x[0] + 3.3 * x[1] + 2.5 * x[2] + 4.7 * x[3];
26+
y[3] = 5.2 * x[0] + 1.4 * x[1] + 2.6 * x[2] + 3.8 * x[3];
27+
y[4] = 1.0 * x[0] + 2.0 * x[1] + 3.0 * x[2] + 4.0 * x[3];
28+
}
29+
30+
fn main() {
31+
let x1 = std::hint::black_box(vec![0.0, 1.0, 2.0, 3.0]);
32+
33+
let mut dx1 = std::hint::black_box(vec![1.0; 12]);
34+
35+
let z1 = std::hint::black_box(vec![1.0, 0.0, 0.0, 0.0]);
36+
let z2 = std::hint::black_box(vec![0.0, 1.0, 0.0, 0.0]);
37+
let z3 = std::hint::black_box(vec![0.0, 0.0, 1.0, 0.0]);
38+
let z4 = std::hint::black_box(vec![0.0, 0.0, 0.0, 1.0]);
39+
40+
let z5 = std::hint::black_box(vec![1.0, 0.0, 0.0, 0.0,
41+
0.0, 1.0, 0.0, 0.0,
42+
0.0, 0.0, 1.0, 0.0,
43+
0.0, 0.0, 0.0, 1.0]);
44+
45+
let mut y1 = std::hint::black_box(vec![0.0; 5]);
46+
let mut y2 = std::hint::black_box(vec![0.0; 5]);
47+
let mut y3 = std::hint::black_box(vec![0.0; 5]);
48+
let mut y4 = std::hint::black_box(vec![0.0; 5]);
49+
50+
let mut y5 = std::hint::black_box(vec![0.0; 5]);
51+
52+
let mut y6 = std::hint::black_box(vec![0.0; 5]);
53+
54+
let mut dy1_1 = std::hint::black_box(vec![0.0; 5]);
55+
let mut dy1_2 = std::hint::black_box(vec![0.0; 5]);
56+
let mut dy1_3 = std::hint::black_box(vec![0.0; 5]);
57+
let mut dy1_4 = std::hint::black_box(vec![0.0; 5]);
58+
59+
let mut dy2 = std::hint::black_box(vec![0.0; 20]);
60+
61+
let mut dy3_1 = std::hint::black_box(vec![0.0; 5]);
62+
let mut dy3_2 = std::hint::black_box(vec![0.0; 5]);
63+
let mut dy3_3 = std::hint::black_box(vec![0.0; 5]);
64+
let mut dy3_4 = std::hint::black_box(vec![0.0; 5]);
65+
66+
let result = std::hint::black_box(x1.iter().map(|x| 2.0 * x).collect::<Vec<_>>());
67+
68+
// scalar.
69+
//d_square1(&x1, &z1, &mut y1, &mut dy1_1);
70+
//d_square1(&x1, &z2, &mut y2, &mut dy1_2);
71+
//d_square1(&x1, &z3, &mut y3, &mut dy1_3);
72+
//d_square1(&x1, &z4, &mut y4, &mut dy1_4);
73+
74+
// assert y1 == y2 == y3 == y4
75+
//for i in 0..5 {
76+
// assert_eq!(y1[i], y2[i]);
77+
// assert_eq!(y1[i], y3[i]);
78+
// assert_eq!(y1[i], y4[i]);
79+
//}
80+
81+
// batch mode A)
82+
//dx1 = std::hint::black_box(vec![1.0; 12]);
83+
d_square2(&x1, &z5, &mut y5, &mut dy2);
84+
85+
// assert y1 == y2 == y3 == y4 == y5
86+
//for i in 0..5 {
87+
// assert_eq!(y1[i], y5[i]);
88+
//}
89+
90+
// batch mode B)
91+
d_square3(&x1, &z1, &z2, &z3, &z4, &mut y6, &mut dy3_1, &mut dy3_2, &mut dy3_3, &mut dy3_4);
92+
for i in 0..5 {
93+
assert_eq!(y5[i], y6[i]);
94+
}
95+
96+
dbg!(&dy2);
97+
dbg!(&dy3_1);
98+
dbg!(&dy3_2);
99+
dbg!(&dy3_3);
100+
dbg!(&dy3_4);
101+
for i in 0..5 {
102+
assert_eq!(dy2[0..5][i], dy3_1[i]);
103+
assert_eq!(dy2[5..10][i], dy3_2[i]);
104+
assert_eq!(dy2[10..15][i], dy3_3[i]);
105+
assert_eq!(dy2[15..20][i], dy3_4[i]);
106+
}
107+
}

0 commit comments

Comments
 (0)