Skip to content

Lattigo: Add Negate op and e2e CMUX example #1627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ struct ConvertRlweCommutativePlainOp : public OpConversionPattern<PlainOp> {
};

// Lattigo API enforces ciphertext, plaintext ordering.
template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp>
template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp,
typename LattigoAddOp, typename LattigoNegateOp>
struct ConvertRlweSubPlainOp : public OpConversionPattern<PlainOp> {
using OpConversionPattern<PlainOp>::OpConversionPattern;

Expand All @@ -302,21 +303,28 @@ struct ConvertRlweSubPlainOp : public OpConversionPattern<PlainOp> {
if (failed(result)) return result;

Value evaluator = result.value();
Value ciphertext = adaptor.getLhs();
Value plaintext = adaptor.getRhs();
if (isa<lattigo::RLWECiphertextType>(adaptor.getLhs().getType())) {
// Lattigo API enforces ciphertext, plaintext ordering, so we can use
// LattigoPlainOp directly.
Value ciphertext = adaptor.getLhs();
Value plaintext = adaptor.getRhs();
rewriter.replaceOpWithNewOp<LattigoPlainOp>(
op, this->typeConverter->convertType(op.getOutput().getType()),
evaluator, ciphertext, plaintext);
return success();
}

// TODO(#1623): Support this case by lowering it to Add(Negate(ciphertext),
// plaintext).
return op.emitOpError() << "subplain op does not support plaintext, "
"ciphertext operand ordering";
// handle plaintext - ciphertext using (-ciphertext) + plaintext
Value plaintext = adaptor.getLhs();
Value ciphertext = adaptor.getRhs();

auto negated = rewriter.create<LattigoNegateOp>(
op.getLoc(), this->typeConverter->convertType(op.getOutput().getType()),
evaluator, ciphertext);
rewriter.replaceOpWithNewOp<LattigoAddOp>(
op, this->typeConverter->convertType(op.getOutput().getType()),
evaluator, negated, plaintext);
return success();
}
};

Expand Down Expand Up @@ -478,8 +486,10 @@ struct ConvertLWEReinterpretApplicationData
LogicalResult matchAndRewrite(
lwe::ReinterpretApplicationDataOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// erase reinterpret underlying
rewriter.replaceOp(op, adaptor.getOperands()[0].getDefiningOp());
// Erase reinterpret application data.
// If operand has no defining op, we can not replace it with defining op.
rewriter.replaceAllOpUsesWith(op, adaptor.getOperands()[0]);
rewriter.eraseOp(op);
return success();
}
};
Expand All @@ -498,7 +508,8 @@ using ConvertBGVAddPlainOp =
lattigo::BGVAddNewOp>;
using ConvertBGVSubPlainOp =
ConvertRlweSubPlainOp<lattigo::BGVEvaluatorType, bgv::SubPlainOp,
lattigo::BGVSubNewOp>;
lattigo::BGVSubNewOp, lattigo::BGVAddNewOp,
lattigo::RLWENegateNewOp>;
using ConvertBGVMulPlainOp =
ConvertRlweCommutativePlainOp<lattigo::BGVEvaluatorType, bgv::MulPlainOp,
lattigo::BGVMulNewOp>;
Expand Down Expand Up @@ -545,7 +556,8 @@ using ConvertCKKSAddPlainOp =
lattigo::CKKSAddNewOp>;
using ConvertCKKSSubPlainOp =
ConvertRlweSubPlainOp<lattigo::CKKSEvaluatorType, ckks::SubPlainOp,
lattigo::CKKSSubNewOp>;
lattigo::CKKSSubNewOp, lattigo::CKKSAddNewOp,
lattigo::RLWENegateNewOp>;
using ConvertCKKSMulPlainOp =
ConvertRlweCommutativePlainOp<lattigo::CKKSEvaluatorType, ckks::MulPlainOp,
lattigo::CKKSMulNewOp>;
Expand Down
27 changes: 27 additions & 0 deletions lib/Dialect/Lattigo/IR/LattigoRLWEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,31 @@ def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", [InplaceOpInterface]>
let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
}

def Lattigo_RLWENegateNewOp : Lattigo_RLWEOp<"negate_new"> {
let summary = "Negate a ciphertext";
let arguments = (ins
Lattigo_RLWEEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input
);
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_RLWENegateOp : Lattigo_RLWEOp<"negate", [InplaceOpInterface]> {
let summary = "Negate of a ciphertext";
let description = [{
This operation negates a ciphertext

The result will be written to the `inplace` operand. The `output` result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
let arguments = (ins
Lattigo_RLWEEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
Lattigo_RLWECiphertext:$inplace
);
let results = (outs Lattigo_RLWECiphertext:$output);

let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
}

#endif // LIB_DIALECT_LATTIGO_IR_LATTIGORLWEOPS_TD_
1 change: 1 addition & 0 deletions lib/Dialect/Lattigo/Transforms/AllocToInplace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ struct AllocToInplace : impl::AllocToInplaceBase<AllocToInplace> {
ConvertUnaryOp<lattigo::CKKSRescaleNewOp, lattigo::CKKSRescaleOp>,
ConvertRotateOp<lattigo::CKKSRotateNewOp, lattigo::CKKSRotateOp>,
// RLWE
ConvertUnaryOp<lattigo::RLWENegateNewOp, lattigo::RLWENegateOp>,
ConvertDropLevelOp<lattigo::RLWEDropLevelNewOp,
lattigo::RLWEDropLevelOp>>(context, &liveness,
&blockToStorageInfo);
Expand Down
Loading
Loading