Skip to content

Commit 92a955b

Browse files
frontend: various cleanups and minor fixes
1 parent c2ee317 commit 92a955b

File tree

10 files changed

+486
-215
lines changed

10 files changed

+486
-215
lines changed

frontend/e2e_test.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from heir import compile
22
from heir.mlir import I16, Secret
3+
from heir import compile
4+
from heir.mlir import F32, I16, I64, Secret
5+
from heir.backends.cleartext import CleartextBackend
36

47

58
from absl.testing import absltest # fmt: skip
@@ -22,6 +25,107 @@ def foo(a: Secret[I16], b: Secret[I16]):
2225
result = foo.decrypt_result(result_enc)
2326
self.assertEqual(-15, result)
2427

28+
def test_simple_example(self):
29+
30+
@compile()
31+
def func(x: Secret[I16], y: Secret[I16]):
32+
sum = x + y
33+
diff = x - y
34+
mul = x * y
35+
expression = sum * diff + mul
36+
deadcode = expression * mul
37+
return expression
38+
39+
# Test cleartext functionality
40+
self.assertEqual(41, func.original(7, 8))
41+
42+
# Test FHE functionality
43+
self.assertEqual(41, func(7, 8))
44+
45+
def test_manual_example(self):
46+
47+
@compile()
48+
def manual(x: Secret[I16], y: Secret[I16]):
49+
return (x + y) * (x - y) + (x * y)
50+
51+
manual.setup() # runs keygen/etc
52+
enc_x = manual.encrypt_x(7)
53+
enc_y = manual.encrypt_y(8)
54+
result_enc = manual.eval(enc_x, enc_y)
55+
result = manual.decrypt_result(result_enc)
56+
57+
# Test cleartext functionality
58+
self.assertEqual(41, manual.original(7, 8))
59+
60+
# Test FHE functionality
61+
self.assertEqual(41, result)
62+
63+
def test_loop_example(self):
64+
65+
@compile()
66+
def loop_test(a: Secret[I64]):
67+
"""An example function with a static loop."""
68+
result = 2
69+
for i in range(3):
70+
result = a + result
71+
return result
72+
73+
# Test cleartext functionality
74+
self.assertEqual(8, loop_test.original(2))
75+
76+
# Test FHE functionality
77+
self.assertEqual(8, loop_test(2))
78+
79+
def test_ckks_example(self):
80+
81+
@compile(scheme="ckks")
82+
def bar(x: Secret[F32], y: Secret[F32]):
83+
return (x + y) * (x - y) + (x * y)
84+
85+
# Test cleartext functionality
86+
self.assertAlmostEqual(0.41, bar.original(0.7, 0.8))
87+
88+
# Test FHE functionality
89+
self.assertAlmostEqual(0.41, bar(0.7, 0.8))
90+
91+
def test_ctxt_ptxt_example(self):
92+
93+
@compile()
94+
def baz(x: Secret[I16], y: Secret[I16], z: I16):
95+
ptxt_mul = x * z
96+
ctxt_mul = x * x
97+
ctxt_mul2 = y * y
98+
add = ctxt_mul + ctxt_mul2
99+
return ptxt_mul + add
100+
101+
# Test cleartext functionality
102+
self.assertEqual(127, baz.original(7, 8, 2))
103+
104+
# Test FHE functionality
105+
self.assertEqual(127, baz(7, 8, 2))
106+
107+
def test_custom_example(self):
108+
109+
@compile(
110+
heir_opt_options=[
111+
"--mlir-to-secret-arithmetic",
112+
"--canonicalize",
113+
"--cse",
114+
],
115+
backend=CleartextBackend(), # just runs the python function when `custom(...)` is called
116+
debug=True,
117+
)
118+
def custom(x: Secret[I16], y: Secret[I16]):
119+
return (x + y) * (x - y) + (x * y)
120+
121+
# Test cleartext functionality
122+
self.assertEqual(41, custom.original(7, 8))
123+
124+
# Test cleartext functionality via CleartextBackend
125+
self.assertEqual(41, custom(7, 8))
126+
127+
# There's unfortunately no way to test the MLIR output here
128+
25129

26130
if __name__ == "__main__":
27131
absltest.main()

frontend/example.py

Lines changed: 124 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,132 @@
11
"""Example of HEIR Python usage."""
22

33
from heir import compile
4-
from heir.mlir import F32, I16, I64, Secret, Tensor
4+
from heir.mlir import F32, I16, I64, Secret
5+
from heir.backends.cleartext import CleartextBackend
56

67
# TODO (#1162): Also add the tensorflow-to-tosa-to-HEIR example in example.py, even it doesn't use the main Python frontend?
78

8-
9-
### Simple Example
10-
@compile() # defaults to scheme="bgv", OpenFHE backend, and debug=False
11-
def func(x: Secret[I16], y: Secret[I16]):
12-
sum = x + y
13-
diff = x - y
14-
mul = x * y
15-
expression = sum * diff + mul
16-
deadcode = expression * mul
17-
return expression
18-
19-
20-
print(
21-
f"Expected result for `func`: {func.original(7,8)}, FHE result: {func(7,8)}"
22-
)
23-
24-
25-
# ### Manual setup/enc/dec example
26-
# @compile()
27-
# def foo(x: Secret[I16], y: Secret[I16]):
28-
# return (x + y) * (x - y) + (x * y)
29-
30-
31-
# foo.setup() # runs keygen/etc
32-
# enc_x = foo.encrypt_x(7)
33-
# enc_y = foo.encrypt_y(8)
34-
# result_enc = foo.eval(enc_x, enc_y)
35-
# result = foo.decrypt_result(result_enc)
36-
# print(
37-
# f"Expected result for `foo`: {foo.original(7,8)}, "
38-
# f"decrypted FHE result: {result}"
39-
# )
40-
41-
42-
# ### Loop Example
43-
# @compile()
44-
# def loop_test(a: Secret[I64]):
45-
# """An example function with a static loop."""
46-
# result = 2
47-
# for i in range(3):
48-
# result = a + result
49-
# return result
50-
51-
52-
# print(
53-
# f"Expected result for `loop_test`: {loop_test(2)}, "
54-
# f"FHE result: {loop_test(2)}"
55-
# )
9+
# TODO (#1162): Remove the need for wrapper functions around each `@compile`-d function to isolate backend pybindings
5610

5711

58-
# ### CKKS Example
59-
# @compile(scheme="ckks")
60-
# def bar(x: Secret[F32], y: Secret[F32]):
61-
# return (x + y) * (x - y) + (x * y)
62-
63-
64-
# print(f"Expected result for `bar`: {bar.original(7,8)}, FHE result: {bar(7,8)}")
65-
66-
67-
# ### Ciphertext-Plaintext Example
68-
# @compile(debug=True)
69-
# def baz(x: Secret[I16], y: Secret[I16], z: I16):
70-
# ptxt_mul = x * z
71-
# ctxt_mul = x * x
72-
# ctxt_mul2 = y * y
73-
# add = ctxt_mul + ctxt_mul2
74-
# return ptxt_mul + add
75-
76-
77-
# print(
78-
# f"Expected result for `baz`: {baz.original(7,8,9)}, "
79-
# f"FHE result: {baz(7,8,9)}"
80-
# )
81-
82-
83-
# ### Custom Pipeline Example
84-
# @compile(
85-
# heir_opt_options=["--mlir-to-secret-arithmetic", "--canonicalize", "--cse"],
86-
# backend=None, # defaults to CleartextBackend
87-
# debug=True,
88-
# )
89-
# def custom(x: Secret[I16], y: Secret[I16]):
90-
# return (x + y) * (x - y) + (x * y)
91-
92-
93-
# print(
94-
# f"CleartextBackend simply runs the original python function: {custom(7,8)}"
95-
# )
12+
### Simple Example
13+
def simple_example():
14+
print("Running simple example")
15+
16+
@compile() # defaults to scheme="bgv", OpenFHE backend, and debug=False
17+
def func(x: Secret[I16], y: Secret[I16]):
18+
sum = x + y
19+
diff = x - y
20+
mul = x * y
21+
expression = sum * diff + mul
22+
deadcode = expression * mul
23+
return expression
24+
25+
print(
26+
f"Expected result for `func`: {func.original(7,8)}, FHE result:"
27+
f" {func(7,8)}"
28+
)
29+
30+
31+
### Manual setup/enc/dec example
32+
def manual_example():
33+
print("Running manual example")
34+
35+
@compile()
36+
def foo(x: Secret[I16], y: Secret[I16]):
37+
return (x + y) * (x - y) + (x * y)
38+
39+
foo.setup() # runs keygen/etc
40+
enc_x = foo.encrypt_x(7)
41+
enc_y = foo.encrypt_y(8)
42+
result_enc = foo.eval(enc_x, enc_y)
43+
result = foo.decrypt_result(result_enc)
44+
print(
45+
f"Expected result for `foo`: {foo.original(7,8)}, "
46+
f"decrypted FHE result: {result}"
47+
)
48+
49+
50+
### Loop Example
51+
def loop_example():
52+
print("Running loop example")
53+
54+
@compile()
55+
def loop_test(a: Secret[I64]):
56+
"""An example function with a static loop."""
57+
result = 2
58+
for i in range(3):
59+
result = a + result
60+
return result
61+
62+
print(
63+
f"Expected result for `loop_test`: {loop_test(2)}, "
64+
f"FHE result: {loop_test(2)}"
65+
)
66+
67+
68+
### CKKS Example
69+
def ckks_example():
70+
print("Running CKKS example")
71+
72+
@compile(scheme="ckks")
73+
def bar(x: Secret[F32], y: Secret[F32]):
74+
return (x + y) * (x - y) + (x * y)
75+
76+
print(
77+
f"Expected result for `bar`: {bar.original(0.7,0.8)}, FHE result:"
78+
f" {bar(0.7,0.8)}"
79+
)
80+
81+
82+
### Ciphertext-Plaintext Example
83+
def ctxt_ptxt_example():
84+
print("Running ciphertext-plaintext example")
85+
86+
@compile()
87+
def baz(x: Secret[I16], y: Secret[I16], z: I16):
88+
ptxt_mul = x * z
89+
ctxt_mul = x * x
90+
ctxt_mul2 = y * y
91+
add = ctxt_mul + ctxt_mul2
92+
return ptxt_mul + add
93+
94+
print(
95+
f"Expected result for `baz`: {baz.original(7,8,9)}, "
96+
f"FHE result: {baz(7,8,9)}"
97+
)
98+
99+
100+
### Custom Pipeline Example
101+
def custom_example():
102+
print("Running custom pipeline example")
103+
104+
@compile(
105+
heir_opt_options=[
106+
"--mlir-to-secret-arithmetic",
107+
"--canonicalize",
108+
"--cse",
109+
],
110+
backend=CleartextBackend(), # just runs the python function when `custom(...)` is called
111+
debug=True, # so that we can see the file that contains the output of the pipeline
112+
)
113+
def custom(x: Secret[I16], y: Secret[I16]):
114+
return (x + y) * (x - y) + (x * y)
115+
116+
print(
117+
"CleartextBackend simply runs the original python function:"
118+
f" {custom(7,8)}"
119+
)
120+
121+
122+
def main():
123+
simple_example()
124+
manual_example()
125+
loop_example()
126+
ckks_example()
127+
ctxt_ptxt_example()
128+
custom_example()
129+
130+
131+
if __name__ == "__main__":
132+
main()

0 commit comments

Comments
 (0)