Skip to content

Commit a85e55f

Browse files
committed
Use CPUPinned context in ImageRecordIOParser2 (apache#13980)
* create NDArray with CPUPinned context in ImageRecordIOParser2 * update document * use -1 device_id as an option to create CPU(0) context * retrigger CI * fix cpplint error
1 parent 45a1554 commit a85e55f

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

src/io/image_iter_common.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,13 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
125125
bool verbose;
126126
/*! \brief partition the data into multiple parts */
127127
int num_parts;
128-
/*! \brief the index of the part will read*/
128+
/*! \brief the index of the part will read */
129129
int part_index;
130-
/*! \brief the size of a shuffle chunk*/
130+
/*! \brief device id used to create context for internal NDArray */
131+
int device_id;
132+
/*! \brief the size of a shuffle chunk */
131133
size_t shuffle_chunk_size;
132-
/*! \brief the seed for chunk shuffling*/
134+
/*! \brief the seed for chunk shuffling */
133135
int shuffle_chunk_seed;
134136

135137
// declare parameters
@@ -161,6 +163,11 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
161163
.describe("Virtually partition the data into these many parts.");
162164
DMLC_DECLARE_FIELD(part_index).set_default(0)
163165
.describe("The *i*-th virtual partition to be read.");
166+
DMLC_DECLARE_FIELD(device_id).set_default(0)
167+
.describe("The device id used to create context for internal NDArray. "\
168+
"Setting device_id to -1 will create Context::CPU(0). Setting "
169+
"device_id to valid positive device id will create "
170+
"Context::CPUPinned(device_id). Default is 0.");
164171
DMLC_DECLARE_FIELD(shuffle_chunk_size).set_default(0)
165172
.describe("The data shuffle buffer size in MB. Only valid if shuffle is true.");
166173
DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0)

src/io/iter_image_recordio_2.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,14 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
285285
shape_vec.push_back(param_.label_width);
286286
TShape label_shape(shape_vec.begin(), shape_vec.end());
287287

288-
out->data.at(0) = NDArray(data_shape, Context::CPU(0), false,
288+
auto ctx = Context::CPU(0);
289+
auto dev_id = param_.device_id;
290+
if (dev_id != -1) {
291+
ctx = Context::CPUPinned(dev_id);
292+
}
293+
out->data.at(0) = NDArray(data_shape, ctx, false,
289294
mshadow::DataType<DType>::kFlag);
290-
out->data.at(1) = NDArray(label_shape, Context::CPU(0), false,
295+
out->data.at(1) = NDArray(label_shape, ctx, false,
291296
mshadow::DataType<real_t>::kFlag);
292297
unit_size_[0] = param_.data_shape.Size();
293298
unit_size_[1] = param_.label_width;

0 commit comments

Comments
 (0)