Skip to content

Commit 913f911

Browse files
ecalubaquibtensorflower-gardener
authored andcommitted
Copy and use flex_portable_tensorflow_deps under mlir/lite
PiperOrigin-RevId: 650532052
1 parent 37a8362 commit 913f911

File tree

3 files changed

+208
-3
lines changed

3 files changed

+208
-3
lines changed

tensorflow/compiler/mlir/lite/delegates/flex/BUILD

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ load(
66
"tf_features_nolayering_check_if_ios",
77
)
88
load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
9+
load("//tensorflow/compiler/mlir/lite/delegates/flex:build_def.bzl", "tflite_flex_cc_library")
910
load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist")
1011

11-
# TODO(b/321735756) Copy tflite_flex_cc_library under mlir/ to remove compiler dep with lite/
12-
load("//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_cc_library")
13-
1412
default_visibility = [
1513
"//tensorflow/compiler/mlir/lite:__subpackages__",
1614
"//tensorflow/lite/android:__subpackages__",
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""Generate custom flex delegate library."""
2+
3+
load(
4+
"//tensorflow:tensorflow.bzl",
5+
"clean_dep",
6+
"if_android",
7+
"if_ios",
8+
"if_mobile",
9+
"tf_cc_binary",
10+
"tf_copts",
11+
"tf_defines_nortti_if_lite_protos",
12+
"tf_features_nolayering_check_if_ios",
13+
"tf_features_nomodules_if_mobile",
14+
"tf_opts_nortti_if_lite_protos",
15+
"tf_portable_full_lite_protos",
16+
)
17+
load("//tensorflow/compiler/mlir/lite:special_rules.bzl", "flex_portable_tensorflow_deps")
18+
19+
# LINT.IfChange
20+
21+
def generate_flex_kernel_header(
22+
name,
23+
models,
24+
testonly = 0,
25+
additional_deps = []):
26+
"""A rule to generate a header file listing only used operators.
27+
28+
Args:
29+
name: Name of the generated library.
30+
models: TFLite models to interpret.
31+
testonly: Should be marked as true if additional_deps is testonly.
32+
additional_deps: Dependencies for additional TF ops.
33+
34+
Returns:
35+
A struct with 'header' and 'include_path' fields that
36+
contain the generated header and the required include entry.
37+
"""
38+
include_path = "%s_tf_generated_kernel_header" % name
39+
header = include_path + "/ops_to_register.h"
40+
41+
if type(models) != type([]):
42+
models = [models]
43+
44+
# List all flex ops from models.
45+
model_file_args = " --graphs=%s" % ",".join(
46+
["$(location %s)" % f for f in models],
47+
)
48+
list_ops_output = include_path + "/list_flex_ops"
49+
list_ops_tool = clean_dep("//tensorflow/lite/tools:list_flex_ops_main")
50+
if additional_deps:
51+
tf_cc_binary(
52+
name = "%s_list_flex_ops_main" % name,
53+
deps = [
54+
clean_dep("//tensorflow/lite/tools:list_flex_ops_main_lib"),
55+
] + additional_deps,
56+
testonly = testonly,
57+
)
58+
list_ops_tool = ":%s_list_flex_ops_main" % name
59+
native.genrule(
60+
name = "%s_list_flex_ops" % name,
61+
srcs = models,
62+
outs = [list_ops_output],
63+
tools = [list_ops_tool],
64+
message = "Listing flex ops from %s..." % ",".join(models),
65+
cmd = ("$(location " + list_ops_tool + ")" +
66+
model_file_args + " > \"$@\""),
67+
testonly = testonly,
68+
)
69+
70+
# Generate the kernel registration header file from list of flex ops.
71+
tool = clean_dep("//tensorflow/python/tools:print_selective_registration_header")
72+
native.genrule(
73+
name = "%s_kernel_registration" % name,
74+
srcs = [list_ops_output],
75+
outs = [header],
76+
tools = [tool],
77+
message = "Processing %s..." % list_ops_output,
78+
cmd = ("$(location " + tool + ")" +
79+
" --default_ops=\"\"" +
80+
" --proto_fileformat=ops_list" +
81+
" --graphs=" + "$(location " + list_ops_output + ") > \"$@\""),
82+
)
83+
return struct(include_path = include_path, header = header)
84+
85+
def tflite_flex_cc_library(
86+
name,
87+
models = [],
88+
additional_deps = [],
89+
testonly = 0,
90+
visibility = ["//visibility:public"],
91+
link_symbol = True,
92+
compatible_with = None):
93+
"""A rule to generate a flex delegate with only ops to run listed models.
94+
95+
Args:
96+
name: Name of the generated flex delegate.
97+
models: TFLite models to interpret. The library will only include ops and kernels
98+
to support these models. If empty, the library will include all Tensorflow
99+
ops and kernels.
100+
additional_deps: Dependencies for additional TF ops.
101+
testonly: Mark this library as testonly if true.
102+
visibility: visibility of the generated rules.
103+
link_symbol: If true, add delegate_symbol to deps.
104+
compatible_with: The standard compatible_with attribute.
105+
"""
106+
portable_tensorflow_lib = clean_dep("//tensorflow/core:portable_tensorflow_lib")
107+
if models:
108+
CUSTOM_KERNEL_HEADER = generate_flex_kernel_header(
109+
name = "%s_tf_op_headers" % name,
110+
models = models,
111+
additional_deps = additional_deps,
112+
testonly = testonly,
113+
)
114+
115+
# Define a custom tensorflow_lib with selective registration.
116+
# The library will only contain ops exist in provided models.
117+
native.cc_library(
118+
name = "%s_tensorflow_lib" % name,
119+
srcs = if_mobile([
120+
clean_dep("//tensorflow/core:portable_op_registrations_and_gradients"),
121+
clean_dep("//tensorflow/core/kernels:portable_core_ops"),
122+
clean_dep("//tensorflow/core/kernels:portable_extended_ops"),
123+
]) + [CUSTOM_KERNEL_HEADER.header],
124+
copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_lite_protos() + if_ios(["-Os"]),
125+
compatible_with = compatible_with,
126+
defines = [
127+
"SELECTIVE_REGISTRATION",
128+
"SUPPORT_SELECTIVE_REGISTRATION",
129+
"EIGEN_NEON_GEBP_NR=4",
130+
] + tf_portable_full_lite_protos(
131+
full = [],
132+
lite = ["TENSORFLOW_LITE_PROTOS"],
133+
) + tf_defines_nortti_if_lite_protos(),
134+
features = tf_features_nomodules_if_mobile() + tf_features_nolayering_check_if_ios(),
135+
linkopts = if_android(["-lz"]) + if_ios(["-lz"]),
136+
includes = [
137+
CUSTOM_KERNEL_HEADER.include_path,
138+
],
139+
textual_hdrs = [
140+
clean_dep("//tensorflow/core/kernels:portable_all_ops_textual_hdrs"),
141+
],
142+
visibility = visibility,
143+
deps = flex_portable_tensorflow_deps() + [
144+
clean_dep("@ducc//:fft_wrapper"),
145+
clean_dep("//tensorflow/core:protos_all_cc"),
146+
clean_dep("//tensorflow/core:portable_tensorflow_lib_lite"),
147+
clean_dep("//tensorflow/core/platform:strong_hash"),
148+
clean_dep("//tensorflow/lite/delegates/flex:portable_images_lib"),
149+
],
150+
alwayslink = 1,
151+
testonly = testonly,
152+
)
153+
portable_tensorflow_lib = ":%s_tensorflow_lib" % name
154+
155+
delegate_symbol = []
156+
if link_symbol:
157+
delegate_symbol.append(clean_dep("//tensorflow/lite/delegates/flex:delegate_symbol"))
158+
159+
# Define a custom flex delegate with above tensorflow_lib.
160+
native.cc_library(
161+
name = name,
162+
hdrs = [
163+
clean_dep("//tensorflow/lite/delegates/flex:delegate.h"),
164+
],
165+
features = tf_features_nolayering_check_if_ios(),
166+
compatible_with = compatible_with,
167+
visibility = visibility,
168+
deps = [
169+
clean_dep("//tensorflow/lite/delegates/flex:delegate_data"),
170+
clean_dep("//tensorflow/lite/delegates/flex:delegate_only_runtime"),
171+
clean_dep("//tensorflow/lite/delegates/utils:simple_delegate"),
172+
] + select({
173+
clean_dep("//tensorflow:android"): [
174+
portable_tensorflow_lib,
175+
],
176+
clean_dep("//tensorflow:ios"): [
177+
portable_tensorflow_lib,
178+
],
179+
clean_dep("//tensorflow:chromiumos"): [
180+
portable_tensorflow_lib,
181+
],
182+
"//conditions:default": [
183+
clean_dep("//tensorflow/core:tensorflow"),
184+
clean_dep("//tensorflow/lite/core/c:private_common"),
185+
],
186+
}) + additional_deps + delegate_symbol,
187+
testonly = testonly,
188+
alwayslink = 1,
189+
)
190+
191+
# LINT.ThenChange(//tensorflow/lite/delegates/flex/build_def.bzl)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
11
"""External versions of build rules that differ outside of Google."""
22

3+
def flex_portable_tensorflow_deps():
4+
"""Returns dependencies for building portable tensorflow in Flex delegate."""
5+
6+
return [
7+
"//third_party/fft2d:fft2d_headers",
8+
"@com_google_absl//absl/log",
9+
"@com_google_absl//absl/log:check",
10+
"@com_google_absl//absl/strings",
11+
"@com_google_absl//absl/strings:str_format",
12+
"@com_google_absl//absl/types:optional",
13+
"@eigen_archive//:eigen3",
14+
"@gemmlowp",
15+
"@icu//:common",
16+
"//third_party/icu/data:conversion_data",
17+
]
18+
319
def tflite_copts_extra():
420
"""Defines extra compile time flags for tflite_copts(). Currently empty."""
521
return []

0 commit comments

Comments
 (0)