Skip to content

Commit 8bda758

Browse files
asraacopybara-github
authored andcommitted
refactor: rename tosa-to-boolean-XYZ to --mlir-to-cggi and --to-XYZ
Follow-ups: * I'd also like to unify the arith-to-cggi (non-Yosys) paths with the `--mlir-to-cggi` but I don't know how possible it is given the compile time yosys dependency. I think I could make it work, tbh. In a follow-up #1646 * All the backends that support parallism should use the JaxiteBackendOptions, for now I just wanted to maintain parity with the codebase so only the jaxite pipeline has this option. #1645 fixes #1573 PiperOrigin-RevId: 741191624
1 parent 093b81e commit 8bda758

22 files changed

+161
-155
lines changed

docs/content/en/docs/Design/secret.md

+6-6
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,9 @@ func.func @dot_product(%arg0: !ty1, %arg1: !ty1) -> !ty2 {
208208

209209
## Differences for CGGI-style pipeline
210210

211-
The `tosa-to-boolean-tfhe` and related pipelines add a few additional steps. The
212-
main goal here is to apply a hardware circuit optimizer to blocks of standard
213-
MLIR code (inside `secret.generic` ops) which converts the computation to an
211+
The `mlir-to-cggi` and related pipelines add a few additional steps. The main
212+
goal here is to apply a hardware circuit optimizer to blocks of standard MLIR
213+
code (inside `secret.generic` ops) which converts the computation to an
214214
optimized boolean circuit with a desired set of gates. Only then is
215215
`-secret-distribute-generic` applied to split the ops up and lower them to the
216216
`cggi` dialect. In particular, because passing an IR through the circuit
@@ -245,9 +245,9 @@ func.func @main(%arg0: tensor<1x1xi8> {secret.secret}) -> tensor<1x16xi32> {
245245
}
246246
```
247247

248-
After running `--tosa-to-boolean-tfhe` and dumping the IR after the linalg ops
249-
are lowered to loops, we can see the `secret.separator` ops enclose the lowered
250-
ops, with the exception of some pure ops that are speculatively executed.
248+
After running `--mlir-to-cggi` and dumping the IR after the linalg ops are
249+
lowered to loops, we can see the `secret.separator` ops enclose the lowered ops,
250+
with the exception of some pure ops that are speculatively executed.
251251

252252
```mlir
253253
func.func @main(%arg0: memref<1x1xi8, strided<[?, ?], offset: ?>> {secret.secret}) -> memref<1x16xi32> {

docs/content/en/docs/pipelines.md

+4-3
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ location of the Yosys' techlib files that are needed to execute the path.
6161
This pass can be disabled by defining `HEIR_NO_YOSYS`; this will avoid Yosys
6262
library and ABC binary compilation, and avoid registration of this pass.
6363

64-
### `--tosa-to-boolean-tfhe`
64+
### `--mlir-to-cggi`
6565

66-
This is an experimental pipeline for end-to-end private inference.
66+
This is an experimental pipeline for lowering standard MLIR (including TOSA) to
67+
CGGI.
6768

6869
Converts a TOSA MLIR model to tfhe_rust dialect defined by HEIR. It converts a
69-
tosa model to optimized boolean circuit using Yosys ABC optimizations. The
70+
TOSA model to optimized boolean circuit using Yosys ABC optimizations. The
7071
resultant optimized boolean circuit in comb dialect is then converted to cggi
7172
and then to tfhe_rust exit dialect. This pipeline can be used with
7273
heir-translate --emit-tfhe-rust to generate code for
+81-98
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "lib/Pipelines/BooleanPipelineRegistration.h"
22

3-
#include <functional>
43
#include <memory>
54
#include <string>
65
#include <vector>
@@ -40,11 +39,17 @@ namespace mlir::heir {
4039
static std::vector<std::string> opsToDistribute = {"secret.separator"};
4140
static std::vector<unsigned> bitWidths = {1, 2, 4, 8, 16};
4241

43-
void tosaToCGGIPipelineBuilder(OpPassManager &pm,
44-
const TosaToBooleanTfheOptions &options,
45-
const std::string &yosysFilesPath,
46-
const std::string &abcPath,
47-
bool abcBooleanGates) {
42+
CGGIPipelineBuilder mlirToCGGIPipelineBuilder(const std::string &yosysFilesPath,
43+
const std::string &abcPath) {
44+
return [=](OpPassManager &pm, const MLIRToCGGIPipelineOptions &options) {
45+
mlirToCGGIPipeline(pm, options, yosysFilesPath, abcPath);
46+
};
47+
}
48+
49+
void mlirToCGGIPipeline(OpPassManager &pm,
50+
const MLIRToCGGIPipelineOptions &options,
51+
const std::string &yosysFilesPath,
52+
const std::string &abcPath) {
4853
// TOSA to linalg
4954
::mlir::heir::tosaToLinalg(pm);
5055

@@ -89,9 +94,9 @@ void tosaToCGGIPipelineBuilder(OpPassManager &pm,
8994
pm.addPass(createCanonicalizerPass());
9095

9196
// Booleanize and Yosys Optimize
92-
pm.addPass(createYosysOptimizer(
93-
yosysFilesPath, abcPath, options.abcFast, options.unrollFactor,
94-
/*useSubmodules=*/true, abcBooleanGates ? Mode::Boolean : Mode::LUT));
97+
pm.addPass(createYosysOptimizer(yosysFilesPath, abcPath, options.abcFast,
98+
options.unrollFactor, options.useSubmodules,
99+
options.mode));
95100

96101
// Cleanup
97102
pm.addPass(mlir::createCSEPass());
@@ -115,99 +120,77 @@ void tosaToCGGIPipelineBuilder(OpPassManager &pm,
115120
pm.addPass(createSCCPPass());
116121
}
117122

118-
void registerTosaToBooleanTfhePipeline(const std::string &yosysFilesPath,
119-
const std::string &abcPath) {
120-
PassPipelineRegistration<TosaToBooleanTfheOptions>(
121-
"tosa-to-boolean-tfhe", "Arithmetic modules to boolean tfhe-rs pipeline.",
122-
[yosysFilesPath, abcPath](OpPassManager &pm,
123-
const TosaToBooleanTfheOptions &options) {
124-
tosaToCGGIPipelineBuilder(pm, options, yosysFilesPath, abcPath,
125-
/*abcBooleanGates=*/false);
126-
127-
// CGGI to Tfhe-Rust exit dialect
128-
pm.addPass(createCGGIToTfheRust());
129-
// CSE must be run before canonicalizer, so that redundant ops are
130-
// cleared before the canonicalizer hoists TfheRust ops.
131-
pm.addPass(createCSEPass());
132-
pm.addPass(createCanonicalizerPass());
133-
134-
// Cleanup loads and stores
135-
pm.addPass(createExpandCopyPass(
136-
ExpandCopyPassOptions{.disableAffineLoop = true}));
137-
pm.addPass(memref::createFoldMemRefAliasOpsPass());
138-
pm.addPass(createForwardStoreToLoad());
139-
pm.addPass(createCanonicalizerPass());
140-
pm.addPass(createCSEPass());
141-
pm.addPass(createSCCPPass());
142-
});
123+
CGGIBackendPipelineBuilder toTfheRsPipelineBuilder() {
124+
return [=](OpPassManager &pm) {
125+
// CGGI to Tfhe-Rust exit dialect
126+
pm.addPass(createCGGIToTfheRust());
127+
// CSE must be run before canonicalizer, so that redundant ops are
128+
// cleared before the canonicalizer hoists TfheRust ops.
129+
pm.addPass(createCSEPass());
130+
pm.addPass(createCanonicalizerPass());
131+
132+
// Cleanup loads and stores
133+
pm.addPass(
134+
createExpandCopyPass(ExpandCopyPassOptions{.disableAffineLoop = true}));
135+
pm.addPass(memref::createFoldMemRefAliasOpsPass());
136+
pm.addPass(createForwardStoreToLoad());
137+
pm.addPass(createCanonicalizerPass());
138+
pm.addPass(createCSEPass());
139+
pm.addPass(createSCCPPass());
140+
};
143141
}
144142

145-
void registerTosaToBooleanFpgaTfhePipeline(const std::string &yosysFilesPath,
146-
const std::string &abcPath) {
147-
PassPipelineRegistration<TosaToBooleanTfheOptions>(
148-
"tosa-to-boolean-fpga-tfhe",
149-
"Arithmetic modules to boolean tfhe-rs for FPGA backend pipeline.",
150-
[yosysFilesPath, abcPath](OpPassManager &pm,
151-
const TosaToBooleanTfheOptions &options) {
152-
tosaToCGGIPipelineBuilder(pm, options, yosysFilesPath, abcPath,
153-
/*abcBooleanGates=*/true);
154-
155-
// Vectorize CGGI operations
156-
pm.addPass(cggi::createBooleanVectorizer());
157-
pm.addPass(createCanonicalizerPass());
158-
pm.addPass(createCSEPass());
159-
pm.addPass(createSCCPPass());
160-
161-
// CGGI to Tfhe-Rust exit dialect
162-
pm.addPass(createCGGIToTfheRustBool());
163-
// CSE must be run before canonicalizer, so that redundant ops are
164-
// cleared before the canonicalizer hoists TfheRust ops.
165-
pm.addPass(createCSEPass());
166-
pm.addPass(createCanonicalizerPass());
167-
168-
// Cleanup loads and stores
169-
pm.addPass(createExpandCopyPass(
170-
ExpandCopyPassOptions{.disableAffineLoop = true}));
171-
pm.addPass(memref::createFoldMemRefAliasOpsPass());
172-
pm.addPass(createForwardStoreToLoad());
173-
pm.addPass(createCanonicalizerPass());
174-
pm.addPass(createCSEPass());
175-
pm.addPass(createSCCPPass());
176-
});
143+
CGGIBackendPipelineBuilder toFptPipelineBuilder() {
144+
return [=](OpPassManager &pm) {
145+
// Vectorize CGGI operations
146+
pm.addPass(cggi::createBooleanVectorizer());
147+
pm.addPass(createCanonicalizerPass());
148+
pm.addPass(createCSEPass());
149+
pm.addPass(createSCCPPass());
150+
151+
// CGGI to Tfhe-Rust exit dialect
152+
pm.addPass(createCGGIToTfheRustBool());
153+
// CSE must be run before canonicalizer, so that redundant ops are
154+
// cleared before the canonicalizer hoists TfheRust ops.
155+
pm.addPass(createCSEPass());
156+
pm.addPass(createCanonicalizerPass());
157+
158+
// Cleanup loads and stores
159+
pm.addPass(
160+
createExpandCopyPass(ExpandCopyPassOptions{.disableAffineLoop = true}));
161+
pm.addPass(memref::createFoldMemRefAliasOpsPass());
162+
pm.addPass(createForwardStoreToLoad());
163+
pm.addPass(createCanonicalizerPass());
164+
pm.addPass(createCSEPass());
165+
pm.addPass(createSCCPPass());
166+
};
177167
}
178168

179-
void registerTosaToJaxitePipeline(const std::string &yosysFilesPath,
180-
const std::string &abcPath) {
181-
PassPipelineRegistration<TosaToBooleanJaxiteOptions>(
182-
"tosa-to-boolean-jaxite", "Arithmetic modules to jaxite pipeline.",
183-
[yosysFilesPath, abcPath](OpPassManager &pm,
184-
const TosaToBooleanJaxiteOptions &options) {
185-
tosaToCGGIPipelineBuilder(pm, options, yosysFilesPath, abcPath,
186-
/*abcBooleanGates=*/false);
187-
if (options.parallelism > 0) {
188-
pm.addPass(
189-
cggi::createBooleanVectorizer(cggi::BooleanVectorizerOptions{
190-
.parallelism = options.parallelism}));
191-
pm.addPass(createCSEPass());
192-
pm.addPass(createRemoveDeadValuesPass());
193-
}
194-
195-
// CGGI to Jaxite exit dialect
196-
pm.addPass(createCGGIToJaxite());
197-
// CSE must be run before canonicalizer, so that redundant ops are
198-
// cleared before the canonicalizer hoists TfheRust ops.
199-
pm.addPass(createCSEPass());
200-
pm.addPass(createCanonicalizerPass());
201-
202-
// Cleanup loads and stores
203-
pm.addPass(createExpandCopyPass(
204-
ExpandCopyPassOptions{.disableAffineLoop = true}));
205-
pm.addPass(memref::createFoldMemRefAliasOpsPass());
206-
pm.addPass(createForwardStoreToLoad());
207-
pm.addPass(createCanonicalizerPass());
208-
pm.addPass(createCSEPass());
209-
pm.addPass(createSCCPPass());
210-
});
169+
JaxiteBackendPipelineBuilder toJaxitePipelineBuilder() {
170+
return [=](OpPassManager &pm, const CGGIBackendOptions &options) {
171+
if (options.parallelism > 0) {
172+
pm.addPass(cggi::createBooleanVectorizer(
173+
cggi::BooleanVectorizerOptions{.parallelism = options.parallelism}));
174+
pm.addPass(createCSEPass());
175+
pm.addPass(createRemoveDeadValuesPass());
176+
}
177+
178+
// CGGI to Jaxite exit dialect
179+
pm.addPass(createCGGIToJaxite());
180+
// CSE must be run before canonicalizer, so that redundant ops are
181+
// cleared before the canonicalizer hoists TfheRust ops.
182+
pm.addPass(createCSEPass());
183+
pm.addPass(createCanonicalizerPass());
184+
185+
// Cleanup loads and stores
186+
pm.addPass(
187+
createExpandCopyPass(ExpandCopyPassOptions{.disableAffineLoop = true}));
188+
pm.addPass(memref::createFoldMemRefAliasOpsPass());
189+
pm.addPass(createForwardStoreToLoad());
190+
pm.addPass(createCanonicalizerPass());
191+
pm.addPass(createCSEPass());
192+
pm.addPass(createSCCPPass());
193+
};
211194
}
212195

213196
} // namespace mlir::heir

lib/Pipelines/BooleanPipelineRegistration.h

+24-26
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,49 @@
11
#ifndef LIB_PIPELINES_BOOLEANPIPELINEREGISTRATION_H_
22
#define LIB_PIPELINES_BOOLEANPIPELINEREGISTRATION_H_
33

4+
#include <functional>
45
#include <string>
56

7+
#include "lib/Transforms/YosysOptimizer/YosysOptimizer.h"
68
#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project
79
#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project
810
#include "mlir/include/mlir/Pass/PassOptions.h" // from @llvm-project
911
#include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project
1012

1113
namespace mlir::heir {
1214

13-
struct TosaToBooleanTfheOptions
14-
: public PassPipelineOptions<TosaToBooleanTfheOptions> {
15-
PassOptions::Option<bool> abcFast{*this, "abc-fast",
16-
llvm::cl::desc("Run abc in fast mode."),
17-
llvm::cl::init(false)};
18-
19-
PassOptions::Option<int> unrollFactor{
20-
*this, "unroll-factor",
21-
llvm::cl::desc("Unroll loops by a given factor before optimizing. A "
22-
"value of zero (default) prevents unrolling."),
23-
llvm::cl::init(0)};
24-
};
15+
struct MLIRToCGGIPipelineOptions : public YosysOptimizerPipelineOptions {};
2516

26-
struct TosaToBooleanJaxiteOptions : public TosaToBooleanTfheOptions {
17+
struct CGGIBackendOptions : public PassPipelineOptions<CGGIBackendOptions> {
2718
PassOptions::Option<int> parallelism{
2819
*this, "parallelism",
2920
llvm::cl::desc(
30-
"batching size for parallel execution on tpu. A value of 0 is no "
21+
"batching size for parallelism. A value of -1 (default) is infinite "
3122
"parallelism"),
32-
llvm::cl::init(0)};
23+
llvm::cl::init(-1)};
3324
};
3425

35-
void tosaToCGGIPipelineBuilder(OpPassManager &pm,
36-
const TosaToBooleanTfheOptions &options,
37-
const std::string &yosysFilesPath,
38-
const std::string &abcPath,
39-
bool abcBooleanGates);
26+
using CGGIPipelineBuilder =
27+
std::function<void(OpPassManager &, const MLIRToCGGIPipelineOptions &)>;
28+
29+
using CGGIBackendPipelineBuilder = std::function<void(OpPassManager &)>;
30+
31+
using JaxiteBackendPipelineBuilder =
32+
std::function<void(OpPassManager &, const CGGIBackendOptions &)>;
33+
34+
CGGIPipelineBuilder mlirToCGGIPipelineBuilder(const std::string &yosysFilesPath,
35+
const std::string &abcPath);
36+
37+
void mlirToCGGIPipeline(OpPassManager &pm,
38+
const MLIRToCGGIPipelineOptions &options,
39+
const std::string &yosysFilesPath,
40+
const std::string &abcPath);
4041

41-
void registerTosaToBooleanTfhePipeline(const std::string &yosysFilesPath,
42-
const std::string &abcPath);
42+
CGGIBackendPipelineBuilder toTfheRsPipelineBuilder();
4343

44-
void registerTosaToBooleanFpgaTfhePipeline(const std::string &yosysFilesPath,
45-
const std::string &abcPath);
44+
CGGIBackendPipelineBuilder toFptPipelineBuilder();
4645

47-
void registerTosaToJaxitePipeline(const std::string &yosysFilesPath,
48-
const std::string &abcPath);
46+
JaxiteBackendPipelineBuilder toJaxitePipelineBuilder();
4947

5048
} // namespace mlir::heir
5149

scripts/gcp/examples/BUILD

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ package(
99

1010
fhe_jaxite_lib(
1111
name = "add_one_lut3",
12-
heir_opt_pass_flags = ["--tosa-to-boolean-jaxite"],
12+
heir_opt_pass_flags = [
13+
"--mlir-to-cggi",
14+
"--scheme-to-jaxite",
15+
],
1316
mlir_src = "add_one_lut3.mlir",
1417
py_lib_target_name = "add_one_lut3_py_lib",
1518
)

tests/Examples/jaxite/BUILD

+4-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ py_library(
2424

2525
jaxite_end_to_end_test(
2626
name = "add_one_lut3",
27-
heir_opt_pass_flags = ["--tosa-to-boolean-jaxite"],
27+
heir_opt_pass_flags = [
28+
"--mlir-to-cggi",
29+
"--scheme-to-jaxite",
30+
],
2831
mlir_src = "add_one_lut3.mlir",
2932
test_src = "add_one_lut3_test.py",
3033
deps = [":test_utils"],

tests/Examples/jaxite/add_one_lut3.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: heir-opt --tosa-to-boolean-jaxite %s | heir-translate --emit-jaxite | FileCheck %s
1+
// RUN: heir-opt --mlir-to-cggi=abc-fast=true --scheme-to-jaxite %s | heir-translate --emit-jaxite | FileCheck %s
22

33
module {
44
// CHECK-LABEL: def test_add_one_lut3(

tests/Examples/jaxite/pmap_add_one_lut3.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: heir-opt --tosa-to-boolean-jaxite="parallelism=4" %s | heir-translate --emit-jaxite | FileCheck %s
1+
// RUN: heir-opt --mlir-to-cggi=abc-fast=true --scheme-to-jaxite="parallelism=4" %s | heir-translate --emit-jaxite | FileCheck %s
22

33
module {
44
// CHECK-LABEL: def test_add_one_lut3(

tests/Examples/tfhe_rust/test_fully_connected.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: heir-opt --tosa-to-boolean-tfhe="abc-fast=true" %s | heir-translate --emit-tfhe-rust > %S/src/fn_under_test.rs
1+
// RUN: heir-opt --mlir-to-cggi --scheme-to-tfhe-rs %s | heir-translate --emit-tfhe-rust > %S/src/fn_under_test.rs
22
// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_fully_connected -- 2 --message_bits=3 | FileCheck %s
33

44
// This takes takes the input x and outputs 2 \cdot x + 1.

tests/Examples/tfhe_rust_bool/fpga/test_fully_connected.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: heir-opt --tosa-to-boolean-fpga-tfhe="abc-fast=true" %s | heir-translate --emit-tfhe-rust-bool-packed > %S/src/fn_under_test.rs
1+
// RUN: heir-opt --mlir-to-cggi=abc-fast=true --scheme-to-fpt %s | heir-translate --emit-tfhe-rust-bool-packed > %S/src/fn_under_test.rs
22
// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_fully_connected -- 2 | FileCheck %s
33

44
// This takes takes the input x and outputs a FC layer operation.

tests/Transforms/tosa_to_boolean_tfhe/add_one.mlir tests/Transforms/mlir_to_tfhe_rs/add_one.mlir

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
// RUN: heir-opt --tosa-to-boolean-tfhe %s | FileCheck %s
2-
3-
// While this is not a TOSA model, it should still lower through the pipeline.
1+
// RUN: heir-opt --mlir-to-cggi --scheme-to-tfhe-rs %s | FileCheck %s
42

53
module {
64
// CHECK: @add_one([[sks:.*]]: !tfhe_rust.server_key, [[arg:.*]]: memref<8x!tfhe_rust.eui3>)

tests/Transforms/tosa_to_boolean_tfhe/fully_connected.mlir tests/Transforms/mlir_to_tfhe_rs/fully_connected.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: heir-opt --tosa-to-boolean-tfhe=abc-fast=true %s | FileCheck %s
1+
// RUN: heir-opt --mlir-to-cggi=abc-fast=true --scheme-to-tfhe-rs %s | FileCheck %s
22

33
#map = affine_map<(d0, d1) -> (0)>
44
#map1 = affine_map<(d0, d1) -> (d0, d1)>

tests/Transforms/tosa_to_boolean_tfhe/hello_world.mlir tests/Transforms/mlir_to_tfhe_rs/hello_world.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: heir-opt --tosa-to-boolean-tfhe=abc-fast=true %s | FileCheck %s
1+
// RUN: heir-opt --mlir-to-cggi=abc-fast=true --scheme-to-tfhe-rs %s | FileCheck %s
22

33
// CHECK-LABEL: module
44
#map = affine_map<(d0, d1) -> (d1)>

0 commit comments

Comments
 (0)