-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathPopulateScaleCKKS.cpp
108 lines (94 loc) · 4.53 KB
/
PopulateScaleCKKS.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include <cstdint>
#include <utility>
#include "lib/Analysis/ScaleAnalysis/ScaleAnalysis.h"
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Dialect/CKKS/IR/CKKSAttributes.h"
#include "lib/Dialect/CKKS/IR/CKKSDialect.h"
#include "lib/Dialect/Mgmt/Transforms/AnnotateMgmt.h"
#include "lib/Dialect/ModuleAttributes.h"
#include "lib/Parameters/CKKS/Params.h"
#include "lib/Transforms/PopulateScale/PopulateScale.h"
#include "lib/Transforms/PopulateScale/PopulateScalePatterns.h"
#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project
namespace mlir {
namespace heir {
// CKKS implementation of the AdjustScaleMaterializer hook. An instance is
// handed to the ConvertAdjustScaleToMulPlain pattern by the PopulateScaleCKKS
// pass, which queries it for the scale gap that has to be bridged.
class CKKSAdjustScaleMaterializer : public AdjustScaleMaterializer {
 public:
  virtual ~CKKSAdjustScaleMaterializer() = default;

  // Returns how far `inputScale` is from the requested `scale`, computed as
  // the plain integer difference `scale - inputScale`.
  // NOTE(review): the pass seeds the analysis with a *log* default scale, so
  // both arguments are presumably log-scales - confirm against ScaleAnalysis.
  int64_t deltaScale(int64_t scale, int64_t inputScale) const override {
    // TODO(#1640): support high-precision scale management
    const int64_t delta = scale - inputScale;
    return delta;
  }
};
#define GEN_PASS_DEF_POPULATESCALECKKS
#include "lib/Transforms/PopulateScale/PopulateScale.h.inc"
// Pass that runs the CKKS scale analysis over the operation it is scheduled
// on, annotates the resulting scales, and rewrites adjust_scale ops into
// plaintext multiplications.
//
// The pass is a no-op for OpenFHE modules. It fails (with a diagnostic) when
// the CKKS scheme parameter attribute is missing, when the dataflow analysis
// fails, or when one of the nested pipelines fails.
struct PopulateScaleCKKS : impl::PopulateScaleCKKSBase<PopulateScaleCKKS> {
  using PopulateScaleCKKSBase::PopulateScaleCKKSBase;

  void runOnOperation() override {
    // skip scale management for openfhe
    if (moduleIsOpenfhe(getOperation())) {
      return;
    }

    // getAttrOfType yields a null attribute when the attribute is absent or
    // has a different type; bail out with a diagnostic instead of
    // dereferencing a null attribute below.
    auto ckksSchemeParamAttr =
        getOperation()->getAttrOfType<ckks::SchemeParamAttr>(
            ckks::CKKSDialect::kSchemeParamAttrName);
    if (!ckksSchemeParamAttr) {
      getOperation()->emitOpError()
          << "expected a CKKS scheme parameter attribute on the operation; "
             "cannot populate scales without it";
      signalPassFailure();
      return;
    }
    auto logDefaultScale = ckksSchemeParamAttr.getLogDefaultScale();

    DataFlowSolver solver;
    SymbolTableCollection symbolTable;
    solver.load<dataflow::DeadCodeAnalysis>();
    solver.load<dataflow::SparseConstantPropagation>();
    // ScaleAnalysis depends on SecretnessAnalysis
    solver.load<SecretnessAnalysis>();

    // set input scale to logDefaultScale
    auto inputScale = logDefaultScale;
    if (beforeMulIncludeFirstMul) {
      // encode at double degree
      inputScale *= 2;
    }
    solver.load<ScaleAnalysis<CKKSScaleModel>>(
        ckks::SchemeParam::getSchemeParamFromAttr(ckksSchemeParamAttr),
        /*inputScale*/ inputScale);
    // Back-prop ScaleAnalysis depends on (forward) ScaleAnalysis
    solver.load<ScaleAnalysisBackward<CKKSScaleModel>>(
        symbolTable,
        ckks::SchemeParam::getSchemeParamFromAttr(ckksSchemeParamAttr));

    if (failed(solver.initializeAndRun(getOperation()))) {
      getOperation()->emitOpError() << "Failed to run the analysis.\n";
      signalPassFailure();
      return;
    }

    // at this time all adjust_scale should have ScaleLattice for its result.
    // all plaintext (mgmt.init) should have ScaleLattice for its result.

    // pass scale to AnnotateMgmt pass
    annotateScale(getOperation(), &solver);
    OpPassManager annotateMgmt("builtin.module");
    annotateMgmt.addPass(mgmt::createAnnotateMgmt());
    // The rewrite below runs on the output of this pipeline, so do not
    // continue if it failed.
    if (failed(runPipeline(annotateMgmt, getOperation()))) {
      signalPassFailure();
      return;
    }

    // convert adjust scale to mul plain
    RewritePatternSet patterns(&getContext());
    // The pattern only receives a pointer to the materializer; it is declared
    // here so that it stays alive while the patterns are applied.
    CKKSAdjustScaleMaterializer materializer;
    // TODO(#1641): handle arith.muli in CKKS
    patterns.add<ConvertAdjustScaleToMulPlain<arith::MulFOp>>(&getContext(),
                                                              &materializer);
    walkAndApplyPatterns(getOperation(), std::move(patterns));

    // run canonicalizer and CSE to clean up arith.constant and move no-op out
    // of the secret.generic
    OpPassManager pipeline("builtin.module");
    pipeline.addPass(createCanonicalizerPass());
    pipeline.addPass(createCSEPass());
    // Propagate a cleanup failure instead of silently reporting success.
    if (failed(runPipeline(pipeline, getOperation()))) {
      signalPassFailure();
      return;
    }
  }
};
} // namespace heir
} // namespace mlir