Skip to content

Commit 12c1225

Browse files
authored
Merge pull request #2271 from stweil/refactor
Refactor class Network
2 parents 13bd96f + 98dd3b6 commit 12c1225

File tree

5 files changed

+63
-59
lines changed

src/lstm/convolve.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
// and pulls in random data to fill out-of-input inputs.
55
// Output is therefore same size as its input, but deeper.
66
// Author: Ray Smith
7-
// Created: Tue Mar 18 16:45:34 PST 2014
87
//
98
// (C) Copyright 2014, Google Inc.
109
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,6 +60,11 @@ class Convolve : public Network {
6160
NetworkScratch* scratch,
6261
NetworkIO* back_deltas) override;
6362

63+
private:
64+
void DebugWeights() override {
65+
tprintf("Must override Network::DebugWeights for type %d\n", type_);
66+
}
67+
6468
protected:
6569
// Serialized data.
6670
int32_t half_x_;
@@ -69,5 +73,4 @@ class Convolve : public Network {
6973

7074
} // namespace tesseract.
7175

72-
7376
#endif // TESSERACT_LSTM_SUBSAMPLE_H_

src/lstm/input.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// File: input.h
33
// Description: Input layer class for neural network implementations.
44
// Author: Ray Smith
5-
// Created: Thu Mar 13 08:56:26 PDT 2014
65
//
76
// (C) Copyright 2014, Google Inc.
87
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -93,6 +92,10 @@ class Input : public Network {
9392
TRand* randomizer, NetworkIO* input);
9493

9594
private:
95+
void DebugWeights() override {
96+
tprintf("Must override Network::DebugWeights for type %d\n", type_);
97+
}
98+
9699
// Input shape determines how images are dealt with.
97100
StaticShape shape_;
98101
// Cached total network x scale factor for scaling bounding boxes.

src/lstm/network.cpp

+43-37
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// File: network.cpp
33
// Description: Base class for neural network implementations.
44
// Author: Ray Smith
5-
// Created: Wed May 01 17:25:06 PST 2013
65
//
76
// (C) Copyright 2013, Google Inc.
87
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -53,10 +52,11 @@ const int kMaxWinSize = 2000;
5352
const int kXWinFrameSize = 30;
5453
const int kYWinFrameSize = 80;
5554

56-
// String names corresponding to the NetworkType enum. Keep in sync.
55+
// String names corresponding to the NetworkType enum.
56+
// Keep in sync with NetworkType.
5757
// Names used in Serialization to allow re-ordering/addition/deletion of
5858
// layer types in NetworkType without invalidating existing network files.
59-
char const* const Network::kTypeNames[NT_COUNT] = {
59+
static char const* const kTypeNames[NT_COUNT] = {
6060
"Invalid", "Input",
6161
"Convolve", "Maxpool",
6262
"Parallel", "Replicated",
@@ -165,81 +165,87 @@ bool Network::Serialize(TFile* fp) const {
165165
return true;
166166
}
167167

168-
// Reads from the given file. Returns false in case of error.
169-
// Should be overridden by subclasses, but NOT called by their DeSerialize.
170-
bool Network::DeSerialize(TFile* fp) {
168+
static NetworkType getNetworkType(TFile* fp) {
171169
int8_t data;
172-
if (!fp->DeSerialize(&data)) return false;
170+
if (!fp->DeSerialize(&data)) return NT_NONE;
173171
if (data == NT_NONE) {
174172
STRING type_name;
175-
if (!type_name.DeSerialize(fp)) return false;
173+
if (!type_name.DeSerialize(fp)) return NT_NONE;
176174
for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
177175
}
178176
if (data == NT_COUNT) {
179177
tprintf("Invalid network layer type:%s\n", type_name.string());
180-
return false;
178+
return NT_NONE;
181179
}
182180
}
183-
type_ = static_cast<NetworkType>(data);
184-
if (!fp->DeSerialize(&data)) return false;
185-
training_ = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
186-
if (!fp->DeSerialize(&data)) return false;
187-
needs_to_backprop_ = data != 0;
188-
if (!fp->DeSerialize(&network_flags_)) return false;
189-
if (!fp->DeSerialize(&ni_)) return false;
190-
if (!fp->DeSerialize(&no_)) return false;
191-
if (!fp->DeSerialize(&num_weights_)) return false;
192-
if (!name_.DeSerialize(fp)) return false;
193-
return true;
181+
return static_cast<NetworkType>(data);
194182
}
195183

196184
// Reads from the given file. Returns nullptr in case of error.
197185
// Determines the type of the serialized class and calls its DeSerialize
198186
// on a new object of the appropriate type, which is returned.
199187
Network* Network::CreateFromFile(TFile* fp) {
200-
Network stub;
201-
if (!stub.DeSerialize(fp)) return nullptr;
188+
NetworkType type; // Type of the derived network class.
189+
TrainingState training; // Are we currently training?
190+
bool needs_to_backprop; // This network needs to output back_deltas.
191+
int32_t network_flags; // Behavior control flags in NetworkFlags.
192+
int32_t ni; // Number of input values.
193+
int32_t no; // Number of output values.
194+
int32_t num_weights; // Number of weights in this and sub-network.
195+
STRING name; // A unique name for this layer.
196+
int8_t data;
202197
Network* network = nullptr;
203-
switch (stub.type_) {
198+
type = getNetworkType(fp);
199+
if (!fp->DeSerialize(&data)) return nullptr;
200+
training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
201+
if (!fp->DeSerialize(&data)) return nullptr;
202+
needs_to_backprop = data != 0;
203+
if (!fp->DeSerialize(&network_flags)) return nullptr;
204+
if (!fp->DeSerialize(&ni)) return nullptr;
205+
if (!fp->DeSerialize(&no)) return nullptr;
206+
if (!fp->DeSerialize(&num_weights)) return nullptr;
207+
if (!name.DeSerialize(fp)) return nullptr;
208+
209+
switch (type) {
204210
case NT_CONVOLVE:
205-
network = new Convolve(stub.name_, stub.ni_, 0, 0);
211+
network = new Convolve(name, ni, 0, 0);
206212
break;
207213
case NT_INPUT:
208-
network = new Input(stub.name_, stub.ni_, stub.no_);
214+
network = new Input(name, ni, no);
209215
break;
210216
case NT_LSTM:
211217
case NT_LSTM_SOFTMAX:
212218
case NT_LSTM_SOFTMAX_ENCODED:
213219
case NT_LSTM_SUMMARY:
214220
network =
215-
new LSTM(stub.name_, stub.ni_, stub.no_, stub.no_, false, stub.type_);
221+
new LSTM(name, ni, no, no, false, type);
216222
break;
217223
case NT_MAXPOOL:
218-
network = new Maxpool(stub.name_, stub.ni_, 0, 0);
224+
network = new Maxpool(name, ni, 0, 0);
219225
break;
220226
// All variants of Parallel.
221227
case NT_PARALLEL:
222228
case NT_REPLICATED:
223229
case NT_PAR_RL_LSTM:
224230
case NT_PAR_UD_LSTM:
225231
case NT_PAR_2D_LSTM:
226-
network = new Parallel(stub.name_, stub.type_);
232+
network = new Parallel(name, type);
227233
break;
228234
case NT_RECONFIG:
229-
network = new Reconfig(stub.name_, stub.ni_, 0, 0);
235+
network = new Reconfig(name, ni, 0, 0);
230236
break;
231237
// All variants of reversed.
232238
case NT_XREVERSED:
233239
case NT_YREVERSED:
234240
case NT_XYTRANSPOSE:
235-
network = new Reversed(stub.name_, stub.type_);
241+
network = new Reversed(name, type);
236242
break;
237243
case NT_SERIES:
238-
network = new Series(stub.name_);
244+
network = new Series(name);
239245
break;
240246
case NT_TENSORFLOW:
241247
#ifdef INCLUDE_TENSORFLOW
242-
network = new TFNetwork(stub.name_);
248+
network = new TFNetwork(name);
243249
#else
244250
tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
245251
#endif
@@ -253,16 +259,16 @@ Network* Network::CreateFromFile(TFile* fp) {
253259
case NT_LOGISTIC:
254260
case NT_POSCLIP:
255261
case NT_SYMCLIP:
256-
network = new FullyConnected(stub.name_, stub.ni_, stub.no_, stub.type_);
262+
network = new FullyConnected(name, ni, no, type);
257263
break;
258264
default:
259265
break;
260266
}
261267
if (network) {
262-
network->training_ = stub.training_;
263-
network->needs_to_backprop_ = stub.needs_to_backprop_;
264-
network->network_flags_ = stub.network_flags_;
265-
network->num_weights_ = stub.num_weights_;
268+
network->training_ = training;
269+
network->needs_to_backprop_ = needs_to_backprop;
270+
network->network_flags_ = network_flags;
271+
network->num_weights_ = num_weights;
266272
if (!network->DeSerialize(fp)) {
267273
delete network;
268274
network = nullptr;

src/lstm/network.h

+5-16
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// File: network.h
33
// Description: Base class for neural network implementations.
44
// Author: Ray Smith
5-
// Created: Wed May 01 16:38:06 PST 2013
65
//
76
// (C) Copyright 2013, Google Inc.
87
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -215,17 +214,16 @@ class Network {
215214
virtual void CacheXScaleFactor(int factor) {}
216215

217216
// Provides debug output on the weights.
218-
virtual void DebugWeights() {
219-
tprintf("Must override Network::DebugWeights for type %d\n", type_);
220-
}
217+
virtual void DebugWeights() = 0;
221218

222219
// Writes to the given file. Returns false in case of error.
223220
// Should be overridden by subclasses, but called by their Serialize.
224221
virtual bool Serialize(TFile* fp) const;
225222
// Reads from the given file. Returns false in case of error.
226223
// Should be overridden by subclasses, but NOT called by their DeSerialize.
227-
virtual bool DeSerialize(TFile* fp);
224+
virtual bool DeSerialize(TFile* fp) = 0;
228225

226+
public:
229227
// Updates the weights using the given learning rate, momentum and adam_beta.
230228
// num_samples is used in the adam computation iff use_adam_ is true.
231229
virtual void Update(float learning_rate, float momentum, float adam_beta,
@@ -261,9 +259,7 @@ class Network {
261259
// instead of all the replicated networks having to do it.
262260
virtual void Forward(bool debug, const NetworkIO& input,
263261
const TransposedArray* input_transpose,
264-
NetworkScratch* scratch, NetworkIO* output) {
265-
tprintf("Must override Network::Forward for type %d\n", type_);
266-
}
262+
NetworkScratch* scratch, NetworkIO* output) = 0;
267263

268264
// Runs backward propagation of errors on fwdX_deltas.
269265
// Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
@@ -272,10 +268,7 @@ class Network {
272268
// return false from Backward!
273269
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
274270
NetworkScratch* scratch,
275-
NetworkIO* back_deltas) {
276-
tprintf("Must override Network::Backward for type %d\n", type_);
277-
return false;
278-
}
271+
NetworkIO* back_deltas) = 0;
279272

280273
// === Debug image display methods. ===
281274
// Displays the image of the matrix to the forward window.
@@ -309,12 +302,8 @@ class Network {
309302
ScrollView* forward_win_; // Recognition debug display window.
310303
ScrollView* backward_win_; // Training debug display window.
311304
TRand* randomizer_; // Random number generator.
312-
313-
// Static serialized name/type_ mapping. Keep in sync with NetworkType.
314-
static char const* const kTypeNames[NT_COUNT];
315305
};
316306

317-
318307
} // namespace tesseract.
319308

320309
#endif // TESSERACT_LSTM_NETWORK_H_

src/lstm/reconfig.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// Description: Network layer that reconfigures the scaling vs feature
44
// depth.
55
// Author: Ray Smith
6-
// Created: Wed Feb 26 15:37:42 PST 2014
76
//
87
// (C) Copyright 2014, Google Inc.
98
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,10 +15,10 @@
1615
// See the License for the specific language governing permissions and
1716
// limitations under the License.
1817
///////////////////////////////////////////////////////////////////////
18+
1919
#ifndef TESSERACT_LSTM_RECONFIG_H_
2020
#define TESSERACT_LSTM_RECONFIG_H_
2121

22-
2322
#include "genericvector.h"
2423
#include "matrix.h"
2524
#include "network.h"
@@ -71,6 +70,11 @@ class Reconfig : public Network {
7170
NetworkScratch* scratch,
7271
NetworkIO* back_deltas) override;
7372

73+
private:
74+
void DebugWeights() override {
75+
tprintf("Must override Network::DebugWeights for type %d\n", type_);
76+
}
77+
7478
protected:
7579
// Non-serialized data used to store parameters between forward and back.
7680
StrideMap back_map_;
@@ -81,5 +85,4 @@ class Reconfig : public Network {
8185

8286
} // namespace tesseract.
8387

84-
8588
#endif // TESSERACT_LSTM_SUBSAMPLE_H_

0 commit comments

Comments
 (0)