2
2
// File: network.cpp
3
3
// Description: Base class for neural network implementations.
4
4
// Author: Ray Smith
5
- // Created: Wed May 01 17:25:06 PST 2013
6
5
//
7
6
// (C) Copyright 2013, Google Inc.
8
7
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -53,10 +52,11 @@ const int kMaxWinSize = 2000;
53
52
const int kXWinFrameSize = 30 ;
54
53
const int kYWinFrameSize = 80 ;
55
54
56
- // String names corresponding to the NetworkType enum. Keep in sync.
55
+ // String names corresponding to the NetworkType enum.
56
+ // Keep in sync with NetworkType.
57
57
// Names used in Serialization to allow re-ordering/addition/deletion of
58
58
// layer types in NetworkType without invalidating existing network files.
59
- char const * const Network:: kTypeNames [NT_COUNT] = {
59
+ static char const * const kTypeNames [NT_COUNT] = {
60
60
" Invalid" , " Input" ,
61
61
" Convolve" , " Maxpool" ,
62
62
" Parallel" , " Replicated" ,
@@ -165,81 +165,87 @@ bool Network::Serialize(TFile* fp) const {
165
165
return true ;
166
166
}
167
167
168
- // Reads from the given file. Returns false in case of error.
169
- // Should be overridden by subclasses, but NOT called by their DeSerialize.
170
- bool Network::DeSerialize (TFile* fp) {
168
+ static NetworkType getNetworkType (TFile* fp) {
171
169
int8_t data;
172
- if (!fp->DeSerialize (&data)) return false ;
170
+ if (!fp->DeSerialize (&data)) return NT_NONE ;
173
171
if (data == NT_NONE) {
174
172
STRING type_name;
175
- if (!type_name.DeSerialize (fp)) return false ;
173
+ if (!type_name.DeSerialize (fp)) return NT_NONE ;
176
174
for (data = 0 ; data < NT_COUNT && type_name != kTypeNames [data]; ++data) {
177
175
}
178
176
if (data == NT_COUNT) {
179
177
tprintf (" Invalid network layer type:%s\n " , type_name.string ());
180
- return false ;
178
+ return NT_NONE ;
181
179
}
182
180
}
183
- type_ = static_cast <NetworkType>(data);
184
- if (!fp->DeSerialize (&data)) return false ;
185
- training_ = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
186
- if (!fp->DeSerialize (&data)) return false ;
187
- needs_to_backprop_ = data != 0 ;
188
- if (!fp->DeSerialize (&network_flags_)) return false ;
189
- if (!fp->DeSerialize (&ni_)) return false ;
190
- if (!fp->DeSerialize (&no_)) return false ;
191
- if (!fp->DeSerialize (&num_weights_)) return false ;
192
- if (!name_.DeSerialize (fp)) return false ;
193
- return true ;
181
+ return static_cast <NetworkType>(data);
194
182
}
195
183
196
184
// Reads from the given file. Returns nullptr in case of error.
197
185
// Determines the type of the serialized class and calls its DeSerialize
198
186
// on a new object of the appropriate type, which is returned.
199
187
Network* Network::CreateFromFile (TFile* fp) {
200
- Network stub;
201
- if (!stub.DeSerialize (fp)) return nullptr ;
188
+ NetworkType type; // Type of the derived network class.
189
+ TrainingState training; // Are we currently training?
190
+ bool needs_to_backprop; // This network needs to output back_deltas.
191
+ int32_t network_flags; // Behavior control flags in NetworkFlags.
192
+ int32_t ni; // Number of input values.
193
+ int32_t no; // Number of output values.
194
+ int32_t num_weights; // Number of weights in this and sub-network.
195
+ STRING name; // A unique name for this layer.
196
+ int8_t data;
202
197
Network* network = nullptr ;
203
- switch (stub.type_ ) {
198
+ type = getNetworkType (fp);
199
+ if (!fp->DeSerialize (&data)) return nullptr ;
200
+ training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
201
+ if (!fp->DeSerialize (&data)) return nullptr ;
202
+ needs_to_backprop = data != 0 ;
203
+ if (!fp->DeSerialize (&network_flags)) return nullptr ;
204
+ if (!fp->DeSerialize (&ni)) return nullptr ;
205
+ if (!fp->DeSerialize (&no)) return nullptr ;
206
+ if (!fp->DeSerialize (&num_weights)) return nullptr ;
207
+ if (!name.DeSerialize (fp)) return nullptr ;
208
+
209
+ switch (type) {
204
210
case NT_CONVOLVE:
205
- network = new Convolve (stub. name_ , stub. ni_ , 0 , 0 );
211
+ network = new Convolve (name, ni , 0 , 0 );
206
212
break ;
207
213
case NT_INPUT:
208
- network = new Input (stub. name_ , stub. ni_ , stub. no_ );
214
+ network = new Input (name, ni, no );
209
215
break ;
210
216
case NT_LSTM:
211
217
case NT_LSTM_SOFTMAX:
212
218
case NT_LSTM_SOFTMAX_ENCODED:
213
219
case NT_LSTM_SUMMARY:
214
220
network =
215
- new LSTM (stub. name_ , stub. ni_ , stub. no_ , stub. no_ , false , stub. type_ );
221
+ new LSTM (name, ni, no, no , false , type );
216
222
break ;
217
223
case NT_MAXPOOL:
218
- network = new Maxpool (stub. name_ , stub. ni_ , 0 , 0 );
224
+ network = new Maxpool (name, ni , 0 , 0 );
219
225
break ;
220
226
// All variants of Parallel.
221
227
case NT_PARALLEL:
222
228
case NT_REPLICATED:
223
229
case NT_PAR_RL_LSTM:
224
230
case NT_PAR_UD_LSTM:
225
231
case NT_PAR_2D_LSTM:
226
- network = new Parallel (stub. name_ , stub. type_ );
232
+ network = new Parallel (name, type );
227
233
break ;
228
234
case NT_RECONFIG:
229
- network = new Reconfig (stub. name_ , stub. ni_ , 0 , 0 );
235
+ network = new Reconfig (name, ni , 0 , 0 );
230
236
break ;
231
237
// All variants of reversed.
232
238
case NT_XREVERSED:
233
239
case NT_YREVERSED:
234
240
case NT_XYTRANSPOSE:
235
- network = new Reversed (stub. name_ , stub. type_ );
241
+ network = new Reversed (name, type );
236
242
break ;
237
243
case NT_SERIES:
238
- network = new Series (stub. name_ );
244
+ network = new Series (name );
239
245
break ;
240
246
case NT_TENSORFLOW:
241
247
#ifdef INCLUDE_TENSORFLOW
242
- network = new TFNetwork (stub. name_ );
248
+ network = new TFNetwork (name );
243
249
#else
244
250
tprintf (" TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n " );
245
251
#endif
@@ -253,16 +259,16 @@ Network* Network::CreateFromFile(TFile* fp) {
253
259
case NT_LOGISTIC:
254
260
case NT_POSCLIP:
255
261
case NT_SYMCLIP:
256
- network = new FullyConnected (stub. name_ , stub. ni_ , stub. no_ , stub. type_ );
262
+ network = new FullyConnected (name, ni, no, type );
257
263
break ;
258
264
default :
259
265
break ;
260
266
}
261
267
if (network) {
262
- network->training_ = stub. training_ ;
263
- network->needs_to_backprop_ = stub. needs_to_backprop_ ;
264
- network->network_flags_ = stub. network_flags_ ;
265
- network->num_weights_ = stub. num_weights_ ;
268
+ network->training_ = training ;
269
+ network->needs_to_backprop_ = needs_to_backprop ;
270
+ network->network_flags_ = network_flags ;
271
+ network->num_weights_ = num_weights ;
266
272
if (!network->DeSerialize (fp)) {
267
273
delete network;
268
274
network = nullptr ;
0 commit comments