Skip to content

Commit 4382ab1

Browse files
committed
Support build with Tensorflow
It expects include files in /usr/include/tensorflow. * Add configure option --with-tensorflow (disabled by default) * Fix data type tensorflow::int64 * Remove "third_party/" in include statements * Add dummy implementations for Backward and DebugWeights in TFNetwork * Add files generated with protoc from tfnetwork.proto (so the Tensorflow sources are not needed for the build) * Update Makefiles Signed-off-by: Stefan Weil <[email protected]>
1 parent 3f74da5 commit 4382ab1

File tree

9 files changed

+1717
-8
lines changed

9 files changed

+1717
-8
lines changed

configure.ac

+9
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,15 @@ if test "$enable_opencl" = "yes"; then
195195
])
196196
fi
197197

198+
# Check whether to build with support for Tensorflow.
199+
AC_MSG_CHECKING([--with-tensorflow])
200+
AC_ARG_WITH([tensorflow],
201+
AS_HELP_STRING([--with-tensorflow],
202+
[support Tensorflow @<:@default=check@:>@]),
203+
[], [with_tensorflow=check])
204+
AC_MSG_RESULT([$with_tensorflow])
205+
AM_CONDITIONAL([TENSORFLOW], [test "$with_tensorflow" != "no"])
206+
198207
# https://lists.apple.com/archives/unix-porting/2009/Jan/msg00026.html
199208
m4_define([MY_CHECK_FRAMEWORK],
200209
[AC_CACHE_CHECK([if -framework $1 works],[my_cv_framework_$1],

src/api/Makefile.am

+5
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,8 @@ endif
9898
if ADD_RT
9999
tesseract_LDADD += -lrt
100100
endif
101+
102+
if TENSORFLOW
103+
tesseract_LDADD += -lprotobuf
104+
tesseract_LDADD += -ltensorflow_cc
105+
endif

src/lstm/Makefile.am

+9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ AM_CPPFLAGS += \
1010

1111
AM_CXXFLAGS = $(OPENMP_CXXFLAGS)
1212

13+
if TENSORFLOW
14+
AM_CPPFLAGS += -DINCLUDE_TENSORFLOW
15+
AM_CPPFLAGS += -I/usr/include/tensorflow
16+
endif
17+
1318
if !NO_TESSDATA_PREFIX
1419
AM_CXXFLAGS += -DTESSDATA_PREFIX=@datadir@
1520
endif
@@ -37,3 +42,7 @@ libtesseract_lstm_la_SOURCES = \
3742
networkbuilder.cpp network.cpp networkio.cpp \
3843
parallel.cpp plumbing.cpp recodebeam.cpp reconfig.cpp reversed.cpp \
3944
series.cpp stridemap.cpp tfnetwork.cpp weightmatrix.cpp
45+
46+
if TENSORFLOW
47+
libtesseract_lstm_la_SOURCES += tfnetwork.pb.cc
48+
endif

src/lstm/tfnetwork.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// Description: Encapsulation of an entire tensorflow graph as a
44
// Tesseract Network.
55
// Author: Ray Smith
6-
// Created: Fri Feb 26 09:35:29 PST 2016
76
//
87
// (C) Copyright 2016, Google Inc.
98
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -90,14 +89,14 @@ void TFNetwork::Forward(bool debug, const NetworkIO& input,
9089
if (!model_proto_.image_widths().empty()) {
9190
TensorShape size_shape{1};
9291
Tensor width_tensor(tensorflow::DT_INT64, size_shape);
93-
auto eigen_wtensor = width_tensor.flat<int64>();
92+
auto eigen_wtensor = width_tensor.flat<tensorflow::int64>();
9493
*eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
9594
tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
9695
}
9796
if (!model_proto_.image_heights().empty()) {
9897
TensorShape size_shape{1};
9998
Tensor height_tensor(tensorflow::DT_INT64, size_shape);
100-
auto eigen_htensor = height_tensor.flat<int64>();
99+
auto eigen_htensor = height_tensor.flat<tensorflow::int64>();
101100
*eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
102101
tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
103102
}

src/lstm/tfnetwork.h

+15-3
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727

2828
#include "network.h"
2929
#include "static_shape.h"
30-
#include "tfnetwork.proto.h"
31-
#include "third_party/tensorflow/core/framework/graph.pb.h"
32-
#include "third_party/tensorflow/core/public/session.h"
30+
#include "tfnetwork.pb.h"
31+
#include "tensorflow/core/framework/graph.pb.h"
32+
#include "tensorflow/core/public/session.h"
3333

3434
namespace tesseract {
3535

@@ -69,6 +69,18 @@ class TFNetwork : public Network {
6969
NetworkScratch* scratch, NetworkIO* output) override;
7070

7171
private:
72+
// Runs backward propagation of errors on the deltas line.
73+
// See Network for a detailed discussion of the arguments.
74+
bool Backward(bool debug, const NetworkIO& fwd_deltas,
75+
NetworkScratch* scratch,
76+
NetworkIO* back_deltas) override {
77+
tprintf("Must override Network::DebugWeights for type %d\n", type_);
78+
}
79+
80+
void DebugWeights() override {
81+
tprintf("Must override Network::DebugWeights for type %d\n", type_);
82+
}
83+
7284
int InitFromProto();
7385

7486
// The original network definition for reference.

0 commit comments

Comments
 (0)