Skip to content

Commit 6ef267c

Browse files
committed
Use TFile::Serialize, TFile::DeSerialize
Signed-off-by: Stefan Weil <[email protected]>
1 parent c383b1a commit 6ef267c

File tree

5 files changed

+39
-42
lines changed

5 files changed

+39
-42
lines changed

src/lstm/network.cpp

+15-17
Original file line numberDiff line numberDiff line change
@@ -150,26 +150,26 @@ bool Network::SetupNeedsBackprop(bool needs_backprop) {
150150
// Writes to the given file. Returns false in case of error.
151151
bool Network::Serialize(TFile* fp) const {
152152
int8_t data = NT_NONE;
153-
if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
153+
if (!fp->Serialize(&data)) return false;
154154
STRING type_name = kTypeNames[type_];
155155
if (!type_name.Serialize(fp)) return false;
156156
data = training_;
157-
if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
157+
if (!fp->Serialize(&data)) return false;
158158
data = needs_to_backprop_;
159-
if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
160-
if (fp->FWrite(&network_flags_, sizeof(network_flags_), 1) != 1) return false;
161-
if (fp->FWrite(&ni_, sizeof(ni_), 1) != 1) return false;
162-
if (fp->FWrite(&no_, sizeof(no_), 1) != 1) return false;
163-
if (fp->FWrite(&num_weights_, sizeof(num_weights_), 1) != 1) return false;
159+
if (!fp->Serialize(&data)) return false;
160+
if (!fp->Serialize(&network_flags_)) return false;
161+
if (!fp->Serialize(&ni_)) return false;
162+
if (!fp->Serialize(&no_)) return false;
163+
if (!fp->Serialize(&num_weights_)) return false;
164164
if (!name_.Serialize(fp)) return false;
165165
return true;
166166
}
167167

168168
// Reads from the given file. Returns false in case of error.
169169
// Should be overridden by subclasses, but NOT called by their DeSerialize.
170170
bool Network::DeSerialize(TFile* fp) {
171-
int8_t data = 0;
172-
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
171+
int8_t data;
172+
if (!fp->DeSerialize(&data)) return false;
173173
if (data == NT_NONE) {
174174
STRING type_name;
175175
if (!type_name.DeSerialize(fp)) return false;
@@ -181,16 +181,14 @@ bool Network::DeSerialize(TFile* fp) {
181181
}
182182
}
183183
type_ = static_cast<NetworkType>(data);
184-
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
184+
if (!fp->DeSerialize(&data)) return false;
185185
training_ = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
186-
if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
186+
if (!fp->DeSerialize(&data)) return false;
187187
needs_to_backprop_ = data != 0;
188-
if (fp->FReadEndian(&network_flags_, sizeof(network_flags_), 1) != 1)
189-
return false;
190-
if (fp->FReadEndian(&ni_, sizeof(ni_), 1) != 1) return false;
191-
if (fp->FReadEndian(&no_, sizeof(no_), 1) != 1) return false;
192-
if (fp->FReadEndian(&num_weights_, sizeof(num_weights_), 1) != 1)
193-
return false;
188+
if (!fp->DeSerialize(&network_flags_)) return false;
189+
if (!fp->DeSerialize(&ni_)) return false;
190+
if (!fp->DeSerialize(&no_)) return false;
191+
if (!fp->DeSerialize(&num_weights_)) return false;
194192
if (!name_.DeSerialize(fp)) return false;
195193
return true;
196194
}

src/lstm/plumbing.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ float* Plumbing::LayerLearningRatePtr(const char* id) const {
181181
// Writes to the given file. Returns false in case of error.
182182
bool Plumbing::Serialize(TFile* fp) const {
183183
if (!Network::Serialize(fp)) return false;
184-
int32_t size = stack_.size();
184+
uint32_t size = stack_.size();
185185
// Can't use PointerVector::Serialize here as we need a special DeSerialize.
186-
if (fp->FWrite(&size, sizeof(size), 1) != 1) return false;
187-
for (int i = 0; i < size; ++i)
186+
if (!fp->Serialize(&size)) return false;
187+
for (uint32_t i = 0; i < size; ++i)
188188
if (!stack_[i]->Serialize(fp)) return false;
189189
if ((network_flags_ & NF_LAYER_SPECIFIC_LR) &&
190190
!learning_rates_.Serialize(fp)) {
@@ -197,9 +197,9 @@ bool Plumbing::Serialize(TFile* fp) const {
197197
bool Plumbing::DeSerialize(TFile* fp) {
198198
stack_.truncate(0);
199199
no_ = 0; // We will be modifying this as we AddToStack.
200-
int32_t size;
201-
if (fp->FReadEndian(&size, sizeof(size), 1) != 1) return false;
202-
for (int i = 0; i < size; ++i) {
200+
uint32_t size;
201+
if (!fp->DeSerialize(&size)) return false;
202+
for (uint32_t i = 0; i < size; ++i) {
203203
Network* network = CreateFromFile(fp);
204204
if (network == nullptr) return false;
205205
AddToStack(network);

src/lstm/reconfig.cpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,15 @@ int Reconfig::XScaleFactor() const {
4949

5050
// Writes to the given file. Returns false in case of error.
5151
bool Reconfig::Serialize(TFile* fp) const {
52-
if (!Network::Serialize(fp)) return false;
53-
if (fp->FWrite(&x_scale_, sizeof(x_scale_), 1) != 1) return false;
54-
if (fp->FWrite(&y_scale_, sizeof(y_scale_), 1) != 1) return false;
55-
return true;
52+
return Network::Serialize(fp) &&
53+
fp->Serialize(&x_scale_) &&
54+
fp->Serialize(&y_scale_);
5655
}
5756

5857
// Reads from the given file. Returns false in case of error.
5958
bool Reconfig::DeSerialize(TFile* fp) {
60-
if (fp->FReadEndian(&x_scale_, sizeof(x_scale_), 1) != 1) return false;
61-
if (fp->FReadEndian(&y_scale_, sizeof(y_scale_), 1) != 1) return false;
59+
if (!fp->DeSerialize(&x_scale_)) return false;
60+
if (!fp->DeSerialize(&y_scale_)) return false;
6261
no_ = ni_ * x_scale_ * y_scale_;
6362
return true;
6463
}

src/lstm/static_shape.h

+10-10
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,23 @@ class StaticShape {
6464
bool DeSerialize(TFile *fp) {
6565
int32_t tmp = LT_NONE;
6666
bool result =
67-
fp->FReadEndian(&batch_, sizeof(batch_), 1) == 1 &&
68-
fp->FReadEndian(&height_, sizeof(height_), 1) == 1 &&
69-
fp->FReadEndian(&width_, sizeof(width_), 1) == 1 &&
70-
fp->FReadEndian(&depth_, sizeof(depth_), 1) == 1 &&
71-
fp->FReadEndian(&tmp, sizeof(tmp), 1) == 1;
67+
fp->DeSerialize(&batch_) &&
68+
fp->DeSerialize(&height_) &&
69+
fp->DeSerialize(&width_) &&
70+
fp->DeSerialize(&depth_) &&
71+
fp->DeSerialize(&tmp);
7272
loss_type_ = static_cast<LossType>(tmp);
7373
return result;
7474
}
7575

7676
bool Serialize(TFile *fp) const {
7777
int32_t tmp = loss_type_;
7878
return
79-
fp->FWrite(&batch_, sizeof(batch_), 1) == 1 &&
80-
fp->FWrite(&height_, sizeof(height_), 1) == 1 &&
81-
fp->FWrite(&width_, sizeof(width_), 1) == 1 &&
82-
fp->FWrite(&depth_, sizeof(depth_), 1) == 1 &&
83-
fp->FWrite(&tmp, sizeof(tmp), 1) == 1;
79+
fp->Serialize(&batch_) &&
80+
fp->Serialize(&height_) &&
81+
fp->Serialize(&width_) &&
82+
fp->Serialize(&depth_) &&
83+
fp->Serialize(&tmp);
8484
}
8585

8686
private:

src/lstm/weightmatrix.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
148148
// format, without errs, so we can detect and read old format weight matrices.
149149
uint8_t mode =
150150
(int_mode_ ? kInt8Flag : 0) | (use_adam_ ? kAdamFlag : 0) | kDoubleFlag;
151-
if (fp->FWrite(&mode, sizeof(mode), 1) != 1) return false;
151+
if (!fp->Serialize(&mode)) return false;
152152
if (int_mode_) {
153153
if (!wi_.Serialize(fp)) return false;
154154
if (!scales_.Serialize(fp)) return false;
@@ -163,8 +163,8 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
163163
// Reads from the given file. Returns false in case of error.
164164

165165
bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
166-
uint8_t mode = 0;
167-
if (fp->FRead(&mode, sizeof(mode), 1) != 1) return false;
166+
uint8_t mode;
167+
if (!fp->DeSerialize(&mode)) return false;
168168
int_mode_ = (mode & kInt8Flag) != 0;
169169
use_adam_ = (mode & kAdamFlag) != 0;
170170
if ((mode & kDoubleFlag) == 0) return DeSerializeOld(training, fp);

0 commit comments

Comments
 (0)