Skip to content

Commit 21d5ce5

Browse files
committed
Fix issue with big endian handling
Signed-off-by: Stefan Weil <[email protected]>
1 parent 9c1fe09 commit 21d5ce5

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

src/lstm/input.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,12 @@ Input::~Input() {
4242

4343
// Writes to the given file. Returns false in case of error.
4444
bool Input::Serialize(TFile* fp) const {
45-
if (!Network::Serialize(fp)) return false;
46-
if (fp->FWrite(&shape_, sizeof(shape_), 1) != 1) return false;
47-
return true;
45+
return Network::Serialize(fp) && shape_.Serialize(fp);
4846
}
4947

5048
// Reads from the given file. Returns false in case of error.
5149
bool Input::DeSerialize(TFile* fp) {
52-
return fp->FReadEndian(&shape_, sizeof(shape_), 1) == 1;
50+
return shape_.DeSerialize(fp);
5351
}
5452

5553
// Returns an integer reduction factor that the network applies to the

src/lstm/static_shape.h

+26-4
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,40 @@ class StaticShape {
5959
height_, width_, depth_, loss_type_);
6060
}
6161

62+
bool DeSerialize(TFile *fp) {
63+
int32_t tmp;
64+
bool result =
65+
fp->FReadEndian(&batch_, sizeof(batch_), 1) == 1 &&
66+
fp->FReadEndian(&height_, sizeof(height_), 1) == 1 &&
67+
fp->FReadEndian(&width_, sizeof(width_), 1) == 1 &&
68+
fp->FReadEndian(&depth_, sizeof(depth_), 1) == 1 &&
69+
fp->FReadEndian(&tmp, sizeof(tmp), 1) == 1;
70+
loss_type_ = static_cast<LossType>(tmp);
71+
return result;
72+
}
73+
74+
bool Serialize(TFile *fp) const {
75+
int32_t tmp = loss_type_;
76+
return
77+
fp->FWrite(&batch_, sizeof(batch_), 1) == 1 &&
78+
fp->FWrite(&height_, sizeof(height_), 1) == 1 &&
79+
fp->FWrite(&width_, sizeof(width_), 1) == 1 &&
80+
fp->FWrite(&depth_, sizeof(depth_), 1) == 1 &&
81+
fp->FWrite(&tmp, sizeof(tmp), 1) == 1;
82+
}
83+
6284
private:
6385
// Size of the 4-D tensor input/output to a network. A value of zero is
6486
// allowed for all except depth_ and means to be determined at runtime, and
6587
// regarded as variable.
6688
// Number of elements in a batch, or number of frames in a video stream.
67-
int batch_;
89+
int32_t batch_;
6890
// Height of the image.
69-
int height_;
91+
int32_t height_;
7092
// Width of the image.
71-
int width_;
93+
int32_t width_;
7294
// Depth of the image. (Number of "nodes").
73-
int depth_;
95+
int32_t depth_;
7496
// How to train/interpret the output.
7597
LossType loss_type_;
7698
};

0 commit comments

Comments
 (0)