|
| 1 | +"""Generate custom flex delegate library.""" |
| 2 | + |
| 3 | +load( |
| 4 | + "//tensorflow:tensorflow.bzl", |
| 5 | + "clean_dep", |
| 6 | + "if_android", |
| 7 | + "if_ios", |
| 8 | + "if_mobile", |
| 9 | + "tf_cc_binary", |
| 10 | + "tf_copts", |
| 11 | + "tf_defines_nortti_if_lite_protos", |
| 12 | + "tf_features_nolayering_check_if_ios", |
| 13 | + "tf_features_nomodules_if_mobile", |
| 14 | + "tf_opts_nortti_if_lite_protos", |
| 15 | + "tf_portable_full_lite_protos", |
| 16 | +) |
| 17 | +load("//tensorflow/compiler/mlir/lite:special_rules.bzl", "flex_portable_tensorflow_deps") |
| 18 | + |
| 19 | +# LINT.IfChange |
| 20 | + |
| 21 | +def generate_flex_kernel_header( |
| 22 | + name, |
| 23 | + models, |
| 24 | + testonly = 0, |
| 25 | + additional_deps = []): |
| 26 | + """A rule to generate a header file listing only used operators. |
| 27 | +
|
| 28 | + Args: |
| 29 | + name: Name of the generated library. |
| 30 | + models: TFLite models to interpret. |
| 31 | + testonly: Should be marked as true if additional_deps is testonly. |
| 32 | + additional_deps: Dependencies for additional TF ops. |
| 33 | +
|
| 34 | + Returns: |
| 35 | + A struct with 'header' and 'include_path' fields that |
| 36 | + contain the generated header and the required include entry. |
| 37 | + """ |
| 38 | + include_path = "%s_tf_generated_kernel_header" % name |
| 39 | + header = include_path + "/ops_to_register.h" |
| 40 | + |
| 41 | + if type(models) != type([]): |
| 42 | + models = [models] |
| 43 | + |
| 44 | + # List all flex ops from models. |
| 45 | + model_file_args = " --graphs=%s" % ",".join( |
| 46 | + ["$(location %s)" % f for f in models], |
| 47 | + ) |
| 48 | + list_ops_output = include_path + "/list_flex_ops" |
| 49 | + list_ops_tool = clean_dep("//tensorflow/lite/tools:list_flex_ops_main") |
| 50 | + if additional_deps: |
| 51 | + tf_cc_binary( |
| 52 | + name = "%s_list_flex_ops_main" % name, |
| 53 | + deps = [ |
| 54 | + clean_dep("//tensorflow/lite/tools:list_flex_ops_main_lib"), |
| 55 | + ] + additional_deps, |
| 56 | + testonly = testonly, |
| 57 | + ) |
| 58 | + list_ops_tool = ":%s_list_flex_ops_main" % name |
| 59 | + native.genrule( |
| 60 | + name = "%s_list_flex_ops" % name, |
| 61 | + srcs = models, |
| 62 | + outs = [list_ops_output], |
| 63 | + tools = [list_ops_tool], |
| 64 | + message = "Listing flex ops from %s..." % ",".join(models), |
| 65 | + cmd = ("$(location " + list_ops_tool + ")" + |
| 66 | + model_file_args + " > \"$@\""), |
| 67 | + testonly = testonly, |
| 68 | + ) |
| 69 | + |
| 70 | + # Generate the kernel registration header file from list of flex ops. |
| 71 | + tool = clean_dep("//tensorflow/python/tools:print_selective_registration_header") |
| 72 | + native.genrule( |
| 73 | + name = "%s_kernel_registration" % name, |
| 74 | + srcs = [list_ops_output], |
| 75 | + outs = [header], |
| 76 | + tools = [tool], |
| 77 | + message = "Processing %s..." % list_ops_output, |
| 78 | + cmd = ("$(location " + tool + ")" + |
| 79 | + " --default_ops=\"\"" + |
| 80 | + " --proto_fileformat=ops_list" + |
| 81 | + " --graphs=" + "$(location " + list_ops_output + ") > \"$@\""), |
| 82 | + ) |
| 83 | + return struct(include_path = include_path, header = header) |
| 84 | + |
| 85 | +def tflite_flex_cc_library( |
| 86 | + name, |
| 87 | + models = [], |
| 88 | + additional_deps = [], |
| 89 | + testonly = 0, |
| 90 | + visibility = ["//visibility:public"], |
| 91 | + link_symbol = True, |
| 92 | + compatible_with = None): |
| 93 | + """A rule to generate a flex delegate with only ops to run listed models. |
| 94 | +
|
| 95 | + Args: |
| 96 | + name: Name of the generated flex delegate. |
| 97 | + models: TFLite models to interpret. The library will only include ops and kernels |
| 98 | + to support these models. If empty, the library will include all Tensorflow |
| 99 | + ops and kernels. |
| 100 | + additional_deps: Dependencies for additional TF ops. |
| 101 | + testonly: Mark this library as testonly if true. |
| 102 | + visibility: visibility of the generated rules. |
| 103 | + link_symbol: If true, add delegate_symbol to deps. |
| 104 | + compatible_with: The standard compatible_with attribute. |
| 105 | + """ |
| 106 | + portable_tensorflow_lib = clean_dep("//tensorflow/core:portable_tensorflow_lib") |
| 107 | + if models: |
| 108 | + CUSTOM_KERNEL_HEADER = generate_flex_kernel_header( |
| 109 | + name = "%s_tf_op_headers" % name, |
| 110 | + models = models, |
| 111 | + additional_deps = additional_deps, |
| 112 | + testonly = testonly, |
| 113 | + ) |
| 114 | + |
| 115 | + # Define a custom tensorflow_lib with selective registration. |
| 116 | + # The library will only contain ops exist in provided models. |
| 117 | + native.cc_library( |
| 118 | + name = "%s_tensorflow_lib" % name, |
| 119 | + srcs = if_mobile([ |
| 120 | + clean_dep("//tensorflow/core:portable_op_registrations_and_gradients"), |
| 121 | + clean_dep("//tensorflow/core/kernels:portable_core_ops"), |
| 122 | + clean_dep("//tensorflow/core/kernels:portable_extended_ops"), |
| 123 | + ]) + [CUSTOM_KERNEL_HEADER.header], |
| 124 | + copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_lite_protos() + if_ios(["-Os"]), |
| 125 | + compatible_with = compatible_with, |
| 126 | + defines = [ |
| 127 | + "SELECTIVE_REGISTRATION", |
| 128 | + "SUPPORT_SELECTIVE_REGISTRATION", |
| 129 | + "EIGEN_NEON_GEBP_NR=4", |
| 130 | + ] + tf_portable_full_lite_protos( |
| 131 | + full = [], |
| 132 | + lite = ["TENSORFLOW_LITE_PROTOS"], |
| 133 | + ) + tf_defines_nortti_if_lite_protos(), |
| 134 | + features = tf_features_nomodules_if_mobile() + tf_features_nolayering_check_if_ios(), |
| 135 | + linkopts = if_android(["-lz"]) + if_ios(["-lz"]), |
| 136 | + includes = [ |
| 137 | + CUSTOM_KERNEL_HEADER.include_path, |
| 138 | + ], |
| 139 | + textual_hdrs = [ |
| 140 | + clean_dep("//tensorflow/core/kernels:portable_all_ops_textual_hdrs"), |
| 141 | + ], |
| 142 | + visibility = visibility, |
| 143 | + deps = flex_portable_tensorflow_deps() + [ |
| 144 | + clean_dep("@ducc//:fft_wrapper"), |
| 145 | + clean_dep("//tensorflow/core:protos_all_cc"), |
| 146 | + clean_dep("//tensorflow/core:portable_tensorflow_lib_lite"), |
| 147 | + clean_dep("//tensorflow/core/platform:strong_hash"), |
| 148 | + clean_dep("//tensorflow/lite/delegates/flex:portable_images_lib"), |
| 149 | + ], |
| 150 | + alwayslink = 1, |
| 151 | + testonly = testonly, |
| 152 | + ) |
| 153 | + portable_tensorflow_lib = ":%s_tensorflow_lib" % name |
| 154 | + |
| 155 | + delegate_symbol = [] |
| 156 | + if link_symbol: |
| 157 | + delegate_symbol.append(clean_dep("//tensorflow/lite/delegates/flex:delegate_symbol")) |
| 158 | + |
| 159 | + # Define a custom flex delegate with above tensorflow_lib. |
| 160 | + native.cc_library( |
| 161 | + name = name, |
| 162 | + hdrs = [ |
| 163 | + clean_dep("//tensorflow/lite/delegates/flex:delegate.h"), |
| 164 | + ], |
| 165 | + features = tf_features_nolayering_check_if_ios(), |
| 166 | + compatible_with = compatible_with, |
| 167 | + visibility = visibility, |
| 168 | + deps = [ |
| 169 | + clean_dep("//tensorflow/lite/delegates/flex:delegate_data"), |
| 170 | + clean_dep("//tensorflow/lite/delegates/flex:delegate_only_runtime"), |
| 171 | + clean_dep("//tensorflow/lite/delegates/utils:simple_delegate"), |
| 172 | + ] + select({ |
| 173 | + clean_dep("//tensorflow:android"): [ |
| 174 | + portable_tensorflow_lib, |
| 175 | + ], |
| 176 | + clean_dep("//tensorflow:ios"): [ |
| 177 | + portable_tensorflow_lib, |
| 178 | + ], |
| 179 | + clean_dep("//tensorflow:chromiumos"): [ |
| 180 | + portable_tensorflow_lib, |
| 181 | + ], |
| 182 | + "//conditions:default": [ |
| 183 | + clean_dep("//tensorflow/core:tensorflow"), |
| 184 | + clean_dep("//tensorflow/lite/core/c:private_common"), |
| 185 | + ], |
| 186 | + }) + additional_deps + delegate_symbol, |
| 187 | + testonly = testonly, |
| 188 | + alwayslink = 1, |
| 189 | + ) |
| 190 | + |
| 191 | +# LINT.ThenChange(//tensorflow/lite/delegates/flex/build_def.bzl) |
0 commit comments