@@ -290,7 +290,8 @@ struct ConvertRlweCommutativePlainOp : public OpConversionPattern<PlainOp> {
290
290
};
291
291
292
292
// Lattigo API enforces ciphertext, plaintext ordering.
293
- template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp>
293
+ template <typename EvaluatorType, typename PlainOp, typename LattigoPlainOp,
294
+ typename LattigoAddOp, typename LattigoNegateOp>
294
295
struct ConvertRlweSubPlainOp : public OpConversionPattern <PlainOp> {
295
296
using OpConversionPattern<PlainOp>::OpConversionPattern;
296
297
@@ -302,21 +303,28 @@ struct ConvertRlweSubPlainOp : public OpConversionPattern<PlainOp> {
302
303
if (failed (result)) return result;
303
304
304
305
Value evaluator = result.value ();
305
- Value ciphertext = adaptor.getLhs ();
306
- Value plaintext = adaptor.getRhs ();
307
306
if (isa<lattigo::RLWECiphertextType>(adaptor.getLhs ().getType ())) {
308
307
// Lattigo API enforces ciphertext, plaintext ordering, so we can use
309
308
// LattigoPlainOp directly.
309
+ Value ciphertext = adaptor.getLhs ();
310
+ Value plaintext = adaptor.getRhs ();
310
311
rewriter.replaceOpWithNewOp <LattigoPlainOp>(
311
312
op, this ->typeConverter ->convertType (op.getOutput ().getType ()),
312
313
evaluator, ciphertext, plaintext);
313
314
return success ();
314
315
}
315
316
316
- // TODO(#1623): Support this case by lowering it to Add(Negate(ciphertext),
317
- // plaintext).
318
- return op.emitOpError () << " subplain op does not support plaintext, "
319
- " ciphertext operand ordering" ;
317
+ // handle plaintext - ciphertext using (-ciphertext) + plaintext
318
+ Value plaintext = adaptor.getLhs ();
319
+ Value ciphertext = adaptor.getRhs ();
320
+
321
+ auto negated = rewriter.create <LattigoNegateOp>(
322
+ op.getLoc (), this ->typeConverter ->convertType (op.getOutput ().getType ()),
323
+ evaluator, ciphertext);
324
+ rewriter.replaceOpWithNewOp <LattigoAddOp>(
325
+ op, this ->typeConverter ->convertType (op.getOutput ().getType ()),
326
+ evaluator, negated, plaintext);
327
+ return success ();
320
328
}
321
329
};
322
330
@@ -478,8 +486,10 @@ struct ConvertLWEReinterpretApplicationData
478
486
LogicalResult matchAndRewrite (
479
487
lwe::ReinterpretApplicationDataOp op, OpAdaptor adaptor,
480
488
ConversionPatternRewriter &rewriter) const override {
481
- // erase reinterpret underlying
482
- rewriter.replaceOp (op, adaptor.getOperands ()[0 ].getDefiningOp ());
489
+ // Erase reinterpret application data.
490
+ // If operand has no defining op, we can not replace it with defining op.
491
+ rewriter.replaceAllOpUsesWith (op, adaptor.getOperands ()[0 ]);
492
+ rewriter.eraseOp (op);
483
493
return success ();
484
494
}
485
495
};
@@ -498,7 +508,8 @@ using ConvertBGVAddPlainOp =
498
508
lattigo::BGVAddNewOp>;
499
509
using ConvertBGVSubPlainOp =
500
510
ConvertRlweSubPlainOp<lattigo::BGVEvaluatorType, bgv::SubPlainOp,
501
- lattigo::BGVSubNewOp>;
511
+ lattigo::BGVSubNewOp, lattigo::BGVAddNewOp,
512
+ lattigo::RLWENegateNewOp>;
502
513
using ConvertBGVMulPlainOp =
503
514
ConvertRlweCommutativePlainOp<lattigo::BGVEvaluatorType, bgv::MulPlainOp,
504
515
lattigo::BGVMulNewOp>;
@@ -545,7 +556,8 @@ using ConvertCKKSAddPlainOp =
545
556
lattigo::CKKSAddNewOp>;
546
557
using ConvertCKKSSubPlainOp =
547
558
ConvertRlweSubPlainOp<lattigo::CKKSEvaluatorType, ckks::SubPlainOp,
548
- lattigo::CKKSSubNewOp>;
559
+ lattigo::CKKSSubNewOp, lattigo::CKKSAddNewOp,
560
+ lattigo::RLWENegateNewOp>;
549
561
using ConvertCKKSMulPlainOp =
550
562
ConvertRlweCommutativePlainOp<lattigo::CKKSEvaluatorType, ckks::MulPlainOp,
551
563
lattigo::CKKSMulNewOp>;
0 commit comments