Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 48ef500

Browse files
author
Sheng Zha
committed
move interval sampler to contrib
1 parent cd04663 commit 48ef500

File tree

5 files changed

+72
-34
lines changed

5 files changed

+72
-34
lines changed

python/mxnet/gluon/contrib/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@
2020
"""Contrib datasets."""
2121

2222
from . import text
23+
24+
from .sampler import *
python/mxnet/gluon/contrib/data/sampler.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# coding: utf-8
19+
# pylint: disable=
20+
"""Dataset sampler."""
21+
__all__ = ['IntervalSampler']
22+
23+
from ...data import sampler
24+
25+
class IntervalSampler(sampler.Sampler):
    """Samples elements from [0, length) at fixed intervals.

    Parameters
    ----------
    length : int
        Length of the sequence.
    interval : int
        The step between two consecutive samples within one pass
        (i.e. ``interval - 1`` items are skipped between them).
    rollover : bool, default True
        Whether to start again from the first skipped item after reaching the end.
        If true, this sampler would start again from the first skipped item until all items
        are visited.
        Otherwise, iteration stops when end is reached and skipped items are ignored.

    Examples
    --------
    >>> sampler = contrib.data.IntervalSampler(13, interval=3)
    >>> list(sampler)
    [0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11]
    >>> len(sampler)
    13
    >>> sampler = contrib.data.IntervalSampler(13, interval=3, rollover=False)
    >>> list(sampler)
    [0, 3, 6, 9, 12]
    >>> len(sampler)
    5
    """
    def __init__(self, length, interval, rollover=True):
        assert interval < length, \
            "Interval {} must be smaller than length {}".format(interval, length)
        self._length = length
        self._interval = interval
        self._rollover = rollover

    def __iter__(self):
        # With rollover, make one pass per starting offset 0..interval-1 so that
        # every index is produced; without it, only the pass starting at 0 runs.
        for i in range(self._interval if self._rollover else 1):
            for j in range(i, self._length, self._interval):
                yield j

    def __len__(self):
        if self._rollover:
            # All passes together visit every index in [0, length).
            return self._length
        # Only the first pass is produced, so report exactly what __iter__
        # yields in that case instead of the full sequence length.
        return len(range(0, self._length, self._interval))

python/mxnet/gluon/data/sampler.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# coding: utf-8
1919
# pylint: disable=
2020
"""Dataset sampler."""
21-
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'BatchSampler', 'IntervalSampler']
21+
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'BatchSampler']
2222

2323
import random
2424

@@ -136,34 +136,3 @@ def __len__(self):
136136
raise ValueError(
137137
"last_batch must be one of 'keep', 'discard', or 'rollover', " \
138138
"but got %s"%self._last_batch)
139-
140-
141-
class IntervalSampler(Sampler):
142-
"""Samples elements from [0, length) at fixed intervals.
143-
144-
Parameters
145-
----------
146-
length : int
147-
Length of the sequence.
148-
interval : int
149-
The number of items to skip between two samples.
150-
151-
Examples
152-
--------
153-
>>> sampler = gluon.data.IntervalSampler(13, interval=3)
154-
>>> list(sampler)
155-
[0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11]
156-
"""
157-
def __init__(self, length, interval):
158-
assert interval < length, \
159-
"Interval {} must be smaller than length {}".format(interval, length)
160-
self._length = length
161-
self._interval = interval
162-
163-
def __iter__(self):
164-
for i in range(self._interval):
165-
for j in range(i, self._length, self._interval):
166-
yield j
167-
168-
def __len__(self):
169-
return self._length

tests/python/unittest/test_gluon_contrib.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,13 @@ def test_datasets():
178178
assert len(contrib.data.text.WikiText2(root='data/wikitext-2', segment='test')) == 15941
179179

180180

181+
def test_sampler():
    """Check IntervalSampler output with and without rollover."""
    # Default (rollover enabled): the sorted output must equal 0..9,
    # i.e. every index is produced exactly once.
    rollover_sampler = contrib.data.IntervalSampler(10, 3)
    visited = sorted(rollover_sampler)
    assert visited == list(range(10))

    # Rollover disabled: only the first pass starting at 0 is produced.
    single_pass_sampler = contrib.data.IntervalSampler(10, 3, rollover=False)
    produced = list(single_pass_sampler)
    assert produced == [0, 3, 6, 9]
186+
187+
181188
if __name__ == '__main__':
182189
import nose
183190
nose.runmodule()

tests/python/unittest/test_gluon_data.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ def test_recordimage_dataset():
6666
def test_sampler():
6767
seq_sampler = gluon.data.SequentialSampler(10)
6868
assert list(seq_sampler) == list(range(10))
69-
interval_sampler = gluon.data.IntervalSampler(10, 3)
70-
assert sorted(list(interval_sampler)) == list(range(10))
7169
rand_sampler = gluon.data.RandomSampler(10)
7270
assert sorted(list(rand_sampler)) == list(range(10))
7371
seq_batch_keep = gluon.data.BatchSampler(seq_sampler, 3, 'keep')

0 commit comments

Comments
 (0)