Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 1c18261

Browse files
author
Sheng Zha
committed
move interval sampler to contrib
1 parent cd04663 commit 1c18261

File tree

3 files changed

+71
-0
lines changed

3 files changed

+71
-0
lines changed

python/mxnet/gluon/contrib/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@
2020
"""Contrib datasets."""
2121

2222
from . import text
23+
24+
from .sampler import *
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# coding: utf-8
19+
# pylint: disable=
20+
"""Dataset sampler."""
21+
__all__ = ['IntervalSampler']
22+
23+
from ...data import sampler
24+
25+
class IntervalSampler(sampler.Sampler):
26+
"""Samples elements from [0, length) at fixed intervals.
27+
28+
Parameters
29+
----------
30+
length : int
31+
Length of the sequence.
32+
interval : int
33+
The number of items to skip between two samples.
34+
rollover : bool, default True
35+
Whether to start again from the first skipped item after reaching the end.
36+
If true, this sampler would start again from the first skipped item until all items
37+
are visited.
38+
Otherwise, iteration stops when end is reached and skipped items are ignored.
39+
40+
Examples
41+
--------
42+
>>> sampler = contrib.data.IntervalSampler(13, interval=3)
43+
>>> list(sampler)
44+
[0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11]
45+
>>> sampler = contrib.data.IntervalSampler(13, interval=3, rollover=False)
46+
>>> list(sampler)
47+
[0, 3, 6, 9, 12]
48+
"""
49+
def __init__(self, length, interval, rollover=True):
50+
assert interval < length, \
51+
"Interval {} must be smaller than length {}".format(interval, length)
52+
self._length = length
53+
self._interval = interval
54+
self._rollover = rollover
55+
56+
def __iter__(self):
57+
for i in range(self._interval if self._rollover else 1):
58+
for j in range(i, self._length, self._interval):
59+
yield j
60+
61+
def __len__(self):
62+
return self._length

tests/python/unittest/test_gluon_contrib.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,13 @@ def test_datasets():
178178
assert len(contrib.data.text.WikiText2(root='data/wikitext-2', segment='test')) == 15941
179179

180180

181+
def test_sampler():
182+
interval_sampler = contrib.data.IntervalSampler(10, 3)
183+
assert sorted(list(interval_sampler)) == list(range(10))
184+
interval_sampler = contrib.data.IntervalSampler(10, 3, rollover=False)
185+
assert list(interval_sampler) == [0, 3, 6, 9]
186+
187+
181188
if __name__ == '__main__':
182189
import nose
183190
nose.runmodule()

0 commit comments

Comments
 (0)