Skip to content

Commit bfdefac

Browse files
Lattigo: Add Negate op and e2e CMUX example
1 parent 55d883d commit bfdefac

File tree

12 files changed

+344
-63
lines changed

12 files changed

+344
-63
lines changed

lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ struct ConvertRlweCommutativePlainOp : public OpConversionPattern<PlainOp> {
290290
};
291291

292292
// Lattigo API enforces ciphertext, plaintext ordering.
293-
template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp>
293+
template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp,
294+
typename LattigoAddOp, typename LattigoNegateOp>
294295
struct ConvertRlweSubPlainOp : public OpConversionPattern<PlainOp> {
295296
using OpConversionPattern<PlainOp>::OpConversionPattern;
296297

@@ -302,21 +303,28 @@ struct ConvertRlweSubPlainOp : public OpConversionPattern<PlainOp> {
302303
if (failed(result)) return result;
303304

304305
Value evaluator = result.value();
305-
Value ciphertext = adaptor.getLhs();
306-
Value plaintext = adaptor.getRhs();
307306
if (isa<lattigo::RLWECiphertextType>(adaptor.getLhs().getType())) {
308307
// Lattigo API enforces ciphertext, plaintext ordering, so we can use
309308
// LattigoPlainOp directly.
309+
Value ciphertext = adaptor.getLhs();
310+
Value plaintext = adaptor.getRhs();
310311
rewriter.replaceOpWithNewOp<LattigoPlainOp>(
311312
op, this->typeConverter->convertType(op.getOutput().getType()),
312313
evaluator, ciphertext, plaintext);
313314
return success();
314315
}
315316

316-
// TODO(#1623): Support this case by lowering it to Add(Negate(ciphertext),
317-
// plaintext).
318-
return op.emitOpError() << "subplain op does not support plaintext, "
319-
"ciphertext operand ordering";
317+
// handle plaintext - ciphertext using (-ciphertext) + plaintext
318+
Value plaintext = adaptor.getLhs();
319+
Value ciphertext = adaptor.getRhs();
320+
321+
auto negated = rewriter.create<LattigoNegateOp>(
322+
op.getLoc(), this->typeConverter->convertType(op.getOutput().getType()),
323+
evaluator, ciphertext);
324+
rewriter.replaceOpWithNewOp<LattigoAddOp>(
325+
op, this->typeConverter->convertType(op.getOutput().getType()),
326+
evaluator, negated, plaintext);
327+
return success();
320328
}
321329
};
322330

@@ -478,8 +486,10 @@ struct ConvertLWEReinterpretApplicationData
478486
LogicalResult matchAndRewrite(
479487
lwe::ReinterpretApplicationDataOp op, OpAdaptor adaptor,
480488
ConversionPatternRewriter &rewriter) const override {
481-
// erase reinterpret underlying
482-
rewriter.replaceOp(op, adaptor.getOperands()[0].getDefiningOp());
489+
// Erase reinterpret application data.
490+
// If operand has no defining op, we can not replace it with defining op.
491+
rewriter.replaceAllOpUsesWith(op, adaptor.getOperands()[0]);
492+
rewriter.eraseOp(op);
483493
return success();
484494
}
485495
};
@@ -498,7 +508,8 @@ using ConvertBGVAddPlainOp =
498508
lattigo::BGVAddNewOp>;
499509
using ConvertBGVSubPlainOp =
500510
ConvertRlweSubPlainOp<lattigo::BGVEvaluatorType, bgv::SubPlainOp,
501-
lattigo::BGVSubNewOp>;
511+
lattigo::BGVSubNewOp, lattigo::BGVAddNewOp,
512+
lattigo::RLWENegateNewOp>;
502513
using ConvertBGVMulPlainOp =
503514
ConvertRlweCommutativePlainOp<lattigo::BGVEvaluatorType, bgv::MulPlainOp,
504515
lattigo::BGVMulNewOp>;
@@ -545,7 +556,8 @@ using ConvertCKKSAddPlainOp =
545556
lattigo::CKKSAddNewOp>;
546557
using ConvertCKKSSubPlainOp =
547558
ConvertRlweSubPlainOp<lattigo::CKKSEvaluatorType, ckks::SubPlainOp,
548-
lattigo::CKKSSubNewOp>;
559+
lattigo::CKKSSubNewOp, lattigo::CKKSAddNewOp,
560+
lattigo::RLWENegateNewOp>;
549561
using ConvertCKKSMulPlainOp =
550562
ConvertRlweCommutativePlainOp<lattigo::CKKSEvaluatorType, ckks::MulPlainOp,
551563
lattigo::CKKSMulNewOp>;

lib/Dialect/Lattigo/IR/LattigoRLWEOps.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,31 @@ def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", [InplaceOpInterface]>
151151
let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
152152
}
153153

154+
def Lattigo_RLWENegateNewOp : Lattigo_RLWEOp<"negate_new"> {
155+
let summary = "Negate a ciphertext";
156+
let arguments = (ins
157+
Lattigo_RLWEEvaluator:$evaluator,
158+
Lattigo_RLWECiphertext:$input
159+
);
160+
let results = (outs Lattigo_RLWECiphertext:$output);
161+
}
162+
163+
def Lattigo_RLWENegateOp : Lattigo_RLWEOp<"negate", [InplaceOpInterface]> {
164+
let summary = "Negate of a ciphertext";
165+
let description = [{
166+
This operation negates a ciphertext
167+
168+
The result will be written to the `inplace` operand. The `output` result is
169+
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
170+
}];
171+
let arguments = (ins
172+
Lattigo_RLWEEvaluator:$evaluator,
173+
Lattigo_RLWECiphertext:$input,
174+
Lattigo_RLWECiphertext:$inplace
175+
);
176+
let results = (outs Lattigo_RLWECiphertext:$output);
177+
178+
let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
179+
}
180+
154181
#endif // LIB_DIALECT_LATTIGO_IR_LATTIGORLWEOPS_TD_

lib/Dialect/Lattigo/Transforms/AllocToInplace.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ struct AllocToInplace : impl::AllocToInplaceBase<AllocToInplace> {
331331
ConvertUnaryOp<lattigo::CKKSRescaleNewOp, lattigo::CKKSRescaleOp>,
332332
ConvertRotateOp<lattigo::CKKSRotateNewOp, lattigo::CKKSRotateOp>,
333333
// RLWE
334+
ConvertUnaryOp<lattigo::RLWENegateNewOp, lattigo::RLWENegateOp>,
334335
ConvertDropLevelOp<lattigo::RLWEDropLevelNewOp,
335336
lattigo::RLWEDropLevelOp>>(context, &liveness,
336337
&blockToStorageInfo);

0 commit comments

Comments
 (0)