Skip to content

fix iter args bug, allow multi-test files #1648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontend/heir/mlir_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def emit_loop(self, target, blocks_to_print):
raise NotImplementedError("Nested loops are not supported")

body_str = self.emit_block(loop_block, blocks_to_print)
if len(loop.inits) > 1:
if loop.inits:
# Yield the iter args
yield_vars = ", ".join([self.get_name(init) for init in loop.inits])
ret_types = ", ".join(
Expand Down
15 changes: 14 additions & 1 deletion frontend/loop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from absl.testing import absltest # fmt: skip


class EndToEndTest(absltest.TestCase):
class LoopTest(absltest.TestCase):

def test_loop(self):

Expand All @@ -21,6 +21,19 @@ def loopa(a: Secret[I64]):

self.assertEqual(32, loopa(2))

def test_loop_one_iter_arg(self):

@compile()
def one_iter_arg(a: Secret[I64]):
result1 = a
lb = 1
ub = 5
for _ in range(lb, ub):
result1 = result1 + result1
return result1

self.assertEqual(32, one_iter_arg(2))


if __name__ == "__main__":
absltest.main()
14 changes: 9 additions & 5 deletions lib/Target/OpenFhePke/OpenFhePkeTemplates.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,22 @@ using namespace lbcrypto;
namespace py = pybind11;

// Minimal bindings required for generated functions to run.
// Cf. https://pybind11.readthedocs.io/en/stable/advanced/classes.html#module-local-class-bindings
// which is a temporary workaround to allow us to have multiple compilations in
// the same python program. Better would be to cache the pybind11 module across
// calls.
void bind_common(py::module &m)
{
py::class_<PublicKeyImpl<DCRTPoly>, std::shared_ptr<PublicKeyImpl<DCRTPoly>>>(m, "PublicKey")
py::class_<PublicKeyImpl<DCRTPoly>, std::shared_ptr<PublicKeyImpl<DCRTPoly>>>(m, "PublicKey", py::module_local())
.def(py::init<>());
py::class_<PrivateKeyImpl<DCRTPoly>, std::shared_ptr<PrivateKeyImpl<DCRTPoly>>>(m, "PrivateKey")
py::class_<PrivateKeyImpl<DCRTPoly>, std::shared_ptr<PrivateKeyImpl<DCRTPoly>>>(m, "PrivateKey", py::module_local())
.def(py::init<>());
py::class_<KeyPair<DCRTPoly>>(m, "KeyPair")
py::class_<KeyPair<DCRTPoly>>(m, "KeyPair", py::module_local())
.def_readwrite("publicKey", &KeyPair<DCRTPoly>::publicKey)
.def_readwrite("secretKey", &KeyPair<DCRTPoly>::secretKey);
py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext", py::module_local())
.def(py::init<>());
py::class_<CryptoContextImpl<DCRTPoly>, std::shared_ptr<CryptoContextImpl<DCRTPoly>>>(m, "CryptoContext")
py::class_<CryptoContextImpl<DCRTPoly>, std::shared_ptr<CryptoContextImpl<DCRTPoly>>>(m, "CryptoContext", py::module_local())
.def(py::init<>())
.def("KeyGen", &CryptoContextImpl<DCRTPoly>::KeyGen);
}
Expand Down
24 changes: 12 additions & 12 deletions tests/Dialect/Openfhe/Emitters/emit_pybind.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@
// CHECK: namespace py = pybind11;
// CHECK: void bind_common(py::module &m)
// CHECK: {
// CHECK: py::class_<PublicKeyImpl<DCRTPoly>, std::shared_ptr<PublicKeyImpl<DCRTPoly>>>(m, "PublicKey")
// CHECK: .def(py::init<>());
// CHECK: py::class_<PrivateKeyImpl<DCRTPoly>, std::shared_ptr<PrivateKeyImpl<DCRTPoly>>>(m, "PrivateKey")
// CHECK: .def(py::init<>());
// CHECK: py::class_<KeyPair<DCRTPoly>>(m, "KeyPair")
// CHECK: .def_readwrite("publicKey", &KeyPair<DCRTPoly>::publicKey)
// CHECK: .def_readwrite("secretKey", &KeyPair<DCRTPoly>::secretKey);
// CHECK: py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
// CHECK: .def(py::init<>());
// CHECK: py::class_<CryptoContextImpl<DCRTPoly>, std::shared_ptr<CryptoContextImpl<DCRTPoly>>>(m, "CryptoContext")
// CHECK: .def(py::init<>())
// CHECK: .def("KeyGen", &CryptoContextImpl<DCRTPoly>::KeyGen);
// CHECK: py::class_<PublicKeyImpl<DCRTPoly>, std::shared_ptr<PublicKeyImpl<DCRTPoly>>>(m, "PublicKey", py::module_local())
// CHECK: .def(py::init<>());
// CHECK: py::class_<PrivateKeyImpl<DCRTPoly>, std::shared_ptr<PrivateKeyImpl<DCRTPoly>>>(m, "PrivateKey", py::module_local())
// CHECK: .def(py::init<>());
// CHECK: py::class_<KeyPair<DCRTPoly>>(m, "KeyPair", py::module_local())
// CHECK: .def_readwrite("publicKey", &KeyPair<DCRTPoly>::publicKey)
// CHECK: .def_readwrite("secretKey", &KeyPair<DCRTPoly>::secretKey);
// CHECK: py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext", py::module_local())
// CHECK: .def(py::init<>());
// CHECK: py::class_<CryptoContextImpl<DCRTPoly>, std::shared_ptr<CryptoContextImpl<DCRTPoly>>>(m, "CryptoContext", py::module_local())
// CHECK: .def(py::init<>())
// CHECK: .def("KeyGen", &CryptoContextImpl<DCRTPoly>::KeyGen);
// CHECK: }

// CHECK: PYBIND11_MODULE(_heir_foo, m) {
Expand Down
Loading