Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 18badaa

Browse files
committed
add unit test to check output context
1 parent 4b5edf7 commit 18badaa

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/python/unittest/test_gluon_data.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,30 @@ def test_multi_worker_dataloader_release_pool():
256256
del the_iter
257257
del D
258258

259+
260+
def test_dataloader_context():
261+
X = np.random.uniform(size=(10, 20))
262+
dataset = gluon.data.ArrayDataset(X)
263+
default_dev_id = 0
264+
custom_dev_id = 1
265+
266+
# use non-pinned memory
267+
loader1 = gluon.data.DataLoader(dataset, 8)
268+
for _, x in enumerate(loader1):
269+
assert x.context == context.cpu(default_dev_id)
270+
271+
# use pinned memory with default device id
272+
loader2 = gluon.data.DataLoader(dataset, 8, pin_memory=True)
273+
for _, x in enumerate(loader2):
274+
assert x.context == context.cpu_pinned(default_dev_id)
275+
276+
# use pinned memory with custom device id
277+
loader3 = gluon.data.DataLoader(dataset, 8, pin_memory=True,
278+
pin_device_id=custom_dev_id)
279+
for _, x in enumerate(loader3):
280+
assert x.context == context.cpu_pinned(custom_dev_id)
281+
282+
259283
if __name__ == '__main__':
260284
import nose
261285
nose.runmodule()

0 commit comments

Comments
 (0)