Skip to content

Commit 9fd8287

Browse files
Lattigo: Add Negate op and e2e CMUX example
1 parent 2ab4d48 commit 9fd8287

File tree

12 files changed

+257
-28
lines changed

12 files changed

+257
-28
lines changed

lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ struct ConvertRlweCommutativePlainOp : public OpConversionPattern<PlainOp> {
290290
};
291291

292292
// Lattigo API enforces ciphertext, plaintext ordering.
293-
template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp>
293+
template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp,
294+
typename LattigoAddOp, typename LattigoNegateOp>
294295
struct ConvertRlweSubPlainOp : public OpConversionPattern<PlainOp> {
295296
using OpConversionPattern<PlainOp>::OpConversionPattern;
296297

@@ -302,21 +303,28 @@ struct ConvertRlweSubPlainOp : public OpConversionPattern<PlainOp> {
302303
if (failed(result)) return result;
303304

304305
Value evaluator = result.value();
305-
Value ciphertext = adaptor.getLhs();
306-
Value plaintext = adaptor.getRhs();
307-
if (isa<lwe::NewLWECiphertextType>(adaptor.getLhs().getType())) {
306+
if (isa<lattigo::RLWECiphertextType>(adaptor.getLhs().getType())) {
308307
// Lattigo API enforces ciphertext, plaintext ordering, so we can use
309308
// LattigoPlainOp directly.
309+
Value ciphertext = adaptor.getLhs();
310+
Value plaintext = adaptor.getRhs();
310311
rewriter.replaceOpWithNewOp<LattigoPlainOp>(
311312
op, this->typeConverter->convertType(op.getOutput().getType()),
312313
evaluator, ciphertext, plaintext);
313314
return success();
314315
}
315316

316-
// TODO(#1623): Support this case by lowering it to Add(Negate(ciphertext),
317-
// plaintext).
318-
return op.emitOpError() << "subplain op does not support plaintext, "
319-
"ciphertext operand ordering";
317+
// handle plaintext - ciphertext using (-ciphertext) + plaintext
318+
Value plaintext = adaptor.getLhs();
319+
Value ciphertext = adaptor.getRhs();
320+
321+
auto negated = rewriter.create<LattigoNegateOp>(
322+
op.getLoc(), this->typeConverter->convertType(op.getOutput().getType()),
323+
evaluator, ciphertext);
324+
rewriter.replaceOpWithNewOp<LattigoAddOp>(
325+
op, this->typeConverter->convertType(op.getOutput().getType()),
326+
evaluator, negated, plaintext);
327+
return success();
320328
}
321329
};
322330

@@ -478,8 +486,10 @@ struct ConvertLWEReinterpretApplicationData
478486
LogicalResult matchAndRewrite(
479487
lwe::ReinterpretApplicationDataOp op, OpAdaptor adaptor,
480488
ConversionPatternRewriter &rewriter) const override {
481-
// erase reinterpret underlying
482-
rewriter.replaceOp(op, adaptor.getOperands()[0].getDefiningOp());
489+
// Erase reinterpret application data.
490+
// If operand has no defining op, we can not replace it with defining op.
491+
rewriter.replaceAllOpUsesWith(op, adaptor.getOperands()[0]);
492+
rewriter.eraseOp(op);
483493
return success();
484494
}
485495
};
@@ -498,7 +508,8 @@ using ConvertBGVAddPlainOp =
498508
lattigo::BGVAddNewOp>;
499509
using ConvertBGVSubPlainOp =
500510
ConvertRlweSubPlainOp<lattigo::BGVEvaluatorType, bgv::SubPlainOp,
501-
lattigo::BGVSubNewOp>;
511+
lattigo::BGVSubNewOp, lattigo::BGVAddNewOp,
512+
lattigo::RLWENegateNewOp>;
502513
using ConvertBGVMulPlainOp =
503514
ConvertRlweCommutativePlainOp<lattigo::BGVEvaluatorType, bgv::MulPlainOp,
504515
lattigo::BGVMulNewOp>;
@@ -545,7 +556,8 @@ using ConvertCKKSAddPlainOp =
545556
lattigo::CKKSAddNewOp>;
546557
using ConvertCKKSSubPlainOp =
547558
ConvertRlweSubPlainOp<lattigo::CKKSEvaluatorType, ckks::SubPlainOp,
548-
lattigo::CKKSSubNewOp>;
559+
lattigo::CKKSSubNewOp, lattigo::CKKSAddNewOp,
560+
lattigo::RLWENegateNewOp>;
549561
using ConvertCKKSMulPlainOp =
550562
ConvertRlweCommutativePlainOp<lattigo::CKKSEvaluatorType, ckks::MulPlainOp,
551563
lattigo::CKKSMulNewOp>;

lib/Dialect/Lattigo/IR/LattigoRLWEOps.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,31 @@ def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", [InplaceOpInterface]>
151151
let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
152152
}
153153

154+
def Lattigo_RLWENegateNewOp : Lattigo_RLWEOp<"negate_new"> {
155+
let summary = "Negate a ciphertext";
156+
let arguments = (ins
157+
Lattigo_RLWEEvaluator:$evaluator,
158+
Lattigo_RLWECiphertext:$input
159+
);
160+
let results = (outs Lattigo_RLWECiphertext:$output);
161+
}
162+
163+
def Lattigo_RLWENegateOp : Lattigo_RLWEOp<"negate", [InplaceOpInterface]> {
164+
let summary = "Negate of a ciphertext";
165+
let description = [{
166+
This operation negates a ciphertext
167+
168+
The result will be written to the `inplace` operand. The `output` result is
169+
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
170+
}];
171+
let arguments = (ins
172+
Lattigo_RLWEEvaluator:$evaluator,
173+
Lattigo_RLWECiphertext:$input,
174+
Lattigo_RLWECiphertext:$inplace
175+
);
176+
let results = (outs Lattigo_RLWECiphertext:$output);
177+
178+
let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
179+
}
180+
154181
#endif // LIB_DIALECT_LATTIGO_IR_LATTIGORLWEOPS_TD_

lib/Dialect/Lattigo/Transforms/AllocToInplace.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ struct AllocToInplace : impl::AllocToInplaceBase<AllocToInplace> {
331331
ConvertUnaryOp<lattigo::CKKSRescaleNewOp, lattigo::CKKSRescaleOp>,
332332
ConvertRotateOp<lattigo::CKKSRotateNewOp, lattigo::CKKSRotateOp>,
333333
// RLWE
334+
ConvertUnaryOp<lattigo::RLWENegateNewOp, lattigo::RLWENegateOp>,
334335
ConvertDropLevelOp<lattigo::RLWEDropLevelNewOp,
335336
lattigo::RLWEDropLevelOp>>(context, &liveness,
336337
&blockToStorageInfo);

lib/Target/Lattigo/LattigoEmitter.cpp

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ LogicalResult LattigoEmitter::translate(Operation &op) {
6565
RLWENewEncryptorOp, RLWENewDecryptorOp, RLWENewKeyGeneratorOp,
6666
RLWEGenKeyPairOp, RLWEGenRelinearizationKeyOp, RLWEGenGaloisKeyOp,
6767
RLWENewEvaluationKeySetOp, RLWEEncryptOp, RLWEDecryptOp,
68-
RLWEDropLevelNewOp, RLWEDropLevelOp,
68+
RLWEDropLevelNewOp, RLWEDropLevelOp, RLWENegateNewOp,
69+
RLWENegateOp,
6970
// BGV
7071
BGVNewParametersFromLiteralOp, BGVNewEncoderOp, BGVNewEvaluatorOp,
7172
BGVNewPlaintextOp, BGVEncodeOp, BGVDecodeOp, BGVAddNewOp,
@@ -224,9 +225,16 @@ LogicalResult LattigoEmitter::printOperation(arith::ConstantOp op) {
224225
})
225226
.Case<DenseElementsAttr>([&](DenseElementsAttr denseAttr) {
226227
if (succeeded(denseAttr.tryGetValues<APInt>())) {
227-
valueString = "[]int64{";
228-
for (auto value : denseAttr.getValues<APInt>()) {
229-
valueString += std::to_string(value.getSExtValue()) + ", ";
228+
if (denseAttr.getType().getElementType().isInteger(1)) {
229+
valueString = "[]bool{";
230+
for (auto value : denseAttr.getValues<APInt>()) {
231+
valueString += value.getBoolValue() ? "true, " : "false, ";
232+
}
233+
} else {
234+
valueString = "[]int64{";
235+
for (auto value : denseAttr.getValues<APInt>()) {
236+
valueString += std::to_string(value.getSExtValue()) + ", ";
237+
}
230238
}
231239
} else if (succeeded(denseAttr.tryGetValues<APFloat>())) {
232240
valueString = "[]float64{";
@@ -357,6 +365,50 @@ LogicalResult LattigoEmitter::printOperation(RLWEDropLevelOp op) {
357365
return success();
358366
}
359367

368+
LogicalResult LattigoEmitter::printOperation(RLWENegateNewOp op) {
369+
// there is no NegateNew method in Lattigo, manually create new
370+
// ciphertext
371+
os << getName(op.getOutput()) << " := " << getName(op.getInput())
372+
<< ".CopyNew()\n";
373+
// for i := 0; i < len(out.Value); i++ {
374+
// evaluator.GetRLWEParameters().RingQ().AtLevel(out.LevelQ()).Neg(out.Value[i],
375+
// out.Value[i])
376+
// }
377+
auto indexName = getName(op.getOutput()) + "_index";
378+
os << "for " << indexName << " := 0; " << indexName << " < len("
379+
<< getName(op.getOutput()) << ".Value); " << indexName << "++ {\n";
380+
os.indent();
381+
os << getName(op.getEvaluator()) << ".GetRLWEParameters().RingQ().AtLevel("
382+
<< getName(op.getOutput()) << ".LevelQ()).Neg(" << getName(op.getOutput())
383+
<< ".Value[" << indexName << "], " << getName(op.getOutput()) << ".Value["
384+
<< indexName << "])\n";
385+
os.unindent();
386+
os << "}\n";
387+
return success();
388+
}
389+
390+
LogicalResult LattigoEmitter::printOperation(RLWENegateOp op) {
391+
if (getName(op.getOutput()) != getName(op.getInput())) {
392+
os << getName(op.getInput()) << ".Copy(" << getName(op.getOutput())
393+
<< ")\n";
394+
}
395+
// for i := 0; i < len(out.Value); i++ {
396+
// evaluator.GetRLWEParameters().RingQ().AtLevel(out.LevelQ()).Neg(out.Value[i],
397+
// out.Value[i])
398+
// }
399+
auto indexName = getName(op.getOutput()) + "_index";
400+
os << "for " << indexName << " := 0; " << indexName << " < len("
401+
<< getName(op.getOutput()) << ".Value); " << indexName << "++ {\n";
402+
os.indent();
403+
os << getName(op.getEvaluator()) << ".GetRLWEParameters().RingQ().AtLevel("
404+
<< getName(op.getOutput()) << ".LevelQ()).Neg(" << getName(op.getOutput())
405+
<< ".Value[" << indexName << "], " << getName(op.getOutput()) << ".Value["
406+
<< indexName << "])\n";
407+
os.unindent();
408+
os << "}\n";
409+
return success();
410+
}
411+
360412
// BGV
361413

362414
LogicalResult LattigoEmitter::printOperation(BGVNewEncoderOp op) {
@@ -408,8 +460,28 @@ LogicalResult LattigoEmitter::printOperation(BGVEncodeOp op) {
408460
os << maxSlotsName << ")\n";
409461
os << "for i := range " << packedName << " {\n";
410462
os.indent();
411-
os << packedName << "[i] = int64(" << getName(op.getValue()) << "[i % len("
412-
<< getName(op.getValue()) << ")])\n";
463+
if (getElementTypeOrSelf(op.getValue().getType()).getIntOrFloatBitWidth() ==
464+
1) {
465+
// if value[i] {
466+
// packedName[i] = 1
467+
// } else {
468+
// packedName[i] = 0
469+
// }
470+
os << "if " << getName(op.getValue()) << "[i % len("
471+
<< getName(op.getValue()) << ")] {\n";
472+
os.indent();
473+
os << packedName << "[i] = 1\n";
474+
os.unindent();
475+
os << "} else {\n";
476+
os.indent();
477+
os << packedName << "[i] = 0\n";
478+
os.unindent();
479+
os << "}\n";
480+
} else {
481+
// packedName[i] = int64(value[i % len(value)])
482+
os << packedName << "[i] = int64(" << getName(op.getValue()) << "[i % len("
483+
<< getName(op.getValue()) << ")])\n";
484+
}
413485
os.unindent();
414486
os << "}\n";
415487

@@ -619,8 +691,28 @@ LogicalResult LattigoEmitter::printOperation(CKKSEncodeOp op) {
619691
os << maxSlotsName << ")\n";
620692
os << "for i := range " << packedName << " {\n";
621693
os.indent();
622-
os << packedName << "[i] = float64(" << getName(op.getValue()) << "[i \% len("
623-
<< getName(op.getValue()) << ")])\n";
694+
if (getElementTypeOrSelf(op.getValue().getType()).getIntOrFloatBitWidth() ==
695+
1) {
696+
// if value[i] {
697+
// packedName[i] = 1.0
698+
// } else {
699+
// packedName[i] = 0.0
700+
// }
701+
os << "if " << getName(op.getValue()) << "[i % len("
702+
<< getName(op.getValue()) << ")] {\n";
703+
os.indent();
704+
os << packedName << "[i] = 1.0\n";
705+
os.unindent();
706+
os << "} else {\n";
707+
os.indent();
708+
os << packedName << "[i] = 0.0\n";
709+
os.unindent();
710+
os << "}\n";
711+
} else {
712+
// packedName[i] = float64(value[i % len(value)])
713+
os << packedName << "[i] = float64(" << getName(op.getValue())
714+
<< "[i \% len(" << getName(op.getValue()) << ")])\n";
715+
}
624716
os.unindent();
625717
os << "}\n";
626718

@@ -859,6 +951,9 @@ FailureOr<std::string> LattigoEmitter::convertType(Type type) {
859951
[&](auto ty) { return std::string("ckks.Parameters"); })
860952
.Case<IntegerType>([&](auto ty) -> FailureOr<std::string> {
861953
auto width = ty.getWidth();
954+
if (width == 1) {
955+
return std::string("bool");
956+
}
862957
if (width != 8 && width != 16 && width != 32 && width != 64) {
863958
return failure();
864959
}

lib/Target/Lattigo/LattigoEmitter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class LattigoEmitter {
7171
LogicalResult printOperation(RLWEDecryptOp op);
7272
LogicalResult printOperation(RLWEDropLevelNewOp op);
7373
LogicalResult printOperation(RLWEDropLevelOp op);
74+
LogicalResult printOperation(RLWENegateNewOp op);
75+
LogicalResult printOperation(RLWENegateOp op);
7476
// BGV
7577
LogicalResult printOperation(BGVNewParametersFromLiteralOp op);
7678
LogicalResult printOperation(BGVNewEncoderOp op);

tests/Dialect/Lattigo/Emitters/emit_lattigo.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,26 @@ module attributes {scheme.bgv} {
197197
return %ct1 : !lattigo.rlwe.ciphertext
198198
}
199199
}
200+
201+
// -----
202+
203+
// func test_negate_new(evaluator *bgv.Evaluator, ct *rlwe.Ciphertext) (*rlwe.Ciphertext) {
204+
// ct1 := ct.CopyNew()
205+
// for ct1_index := 0; ct1_index < len(ct1.Value); ct1_index++ {
206+
// evaluator.GetRLWEParameters().RingQ().AtLevel(ct1.LevelQ()).Neg(ct1.Value[ct1_index], ct1.Value[ct1_index])
207+
// }
208+
// return ct1
209+
// }
210+
211+
module attributes {scheme.bgv} {
212+
// CHECK-LABEL: func test_negate_new
213+
// CHECK-SAME: ([[evaluator:.*]] *bgv.Evaluator, [[ct:.*]] *rlwe.Ciphertext) (*rlwe.Ciphertext)
214+
func.func @test_negate_new(%evaluator: !lattigo.bgv.evaluator, %ct: !lattigo.rlwe.ciphertext) -> (!lattigo.rlwe.ciphertext) {
215+
// CHECK: [[ct1:[^, ]*]] := [[ct]].CopyNew()
216+
// CHECK: for [[i:[^, ]*]] := 0; [[i]] < len([[ct1]].Value); [[i]]++ {
217+
// CHECK: [[evaluator]].GetRLWEParameters().RingQ().AtLevel([[ct1]].LevelQ()).Neg([[ct1]].Value[[[i]]], [[ct1]].Value[[[i]]])
218+
// CHECK: }
219+
%negated = lattigo.rlwe.negate_new %evaluator, %ct : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
220+
return %negated : !lattigo.rlwe.ciphertext
221+
}
222+
}

tests/Dialect/Lattigo/IR/rlwe_ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,18 @@ module {
104104
%ct1 = lattigo.rlwe.drop_level %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct
105105
return
106106
}
107+
108+
// CHECK-LABEL: func @test_rlwe_negate_new
109+
func.func @test_rlwe_negate_new(%evaluator : !evaluator, %ct: !ct) {
110+
// CHECK: %[[v1:.*]] = lattigo.rlwe.negate_new
111+
%ct1 = lattigo.rlwe.negate_new %evaluator, %ct : (!evaluator, !ct) -> !ct
112+
return
113+
}
114+
115+
// CHECK-LABEL: func @test_rlwe_negate
116+
func.func @test_rlwe_negate(%evaluator : !evaluator, %ct: !ct) {
117+
// CHECK: %[[v1:.*]] = lattigo.rlwe.negate
118+
%0 = lattigo.rlwe.negate %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> (!ct)
119+
return
120+
}
107121
}

tests/Examples/common/cmux.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
func.func @cmux(%a: i64, %b: i64, %cond: i1 {secret.secret}) -> (i64) {
2+
%2 = scf.if %cond -> (i64) {
3+
scf.yield %a : i64
4+
} else {
5+
scf.yield %b : i64
6+
}
7+
func.return %2 : i64
8+
}

tests/Examples/lattigo/bgv/cmux/BUILD

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# See README.md for setup required to run these tests
2+
3+
load("@heir//tests/Examples/lattigo:test.bzl", "heir_lattigo_lib")
4+
load("@rules_go//go:def.bzl", "go_test")
5+
6+
package(default_applicable_licenses = ["@heir//:license"])
7+
8+
heir_lattigo_lib(
9+
name = "cmux",
10+
go_library_name = "main",
11+
heir_opt_flags = [
12+
"--annotate-module=backend=lattigo scheme=bgv",
13+
"--mlir-to-bgv",
14+
"--scheme-to-lattigo",
15+
],
16+
mlir_src = "cmux.mlir",
17+
)
18+
19+
# For Google-internal reasons we must separate the go_test rules from the macro
20+
# above.
21+
22+
go_test(
23+
name = "cmux_test",
24+
srcs = ["cmux_test.go"],
25+
embed = [":main"],
26+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../../common/cmux.mlir
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package main
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestBinops(t *testing.T) {
8+
evaluator, params, ecd, enc, dec := cmux__configure()
9+
10+
// Vector of plaintext values
11+
arg0 := []int64{1, 2, 3, 4, 5, 6, 7, 8}
12+
arg1 := []int64{2, 3, 4, 5, 6, 7, 8, 9}
13+
cond := []bool{true, false, true, true, false, true, false, true}
14+
expected := []int64{1, 3, 3, 4, 6, 6, 8, 8}
15+
16+
for i := 0; i < len(arg0); i++ {
17+
condEncrypted := cmux__encrypt__arg2(evaluator, params, ecd, enc, cond[i])
18+
19+
resultCt := cmux(evaluator, params, ecd, arg0[i], arg1[i], condEncrypted)
20+
21+
result := cmux__decrypt__result0(evaluator, params, ecd, dec, resultCt)
22+
23+
if result != expected[i] {
24+
t.Errorf("Decryption error %d != %d", result, expected[i])
25+
}
26+
}
27+
}

0 commit comments

Comments
 (0)