Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit cd04663

Browse files
author
Sheng Zha
committed
address comments
1 parent a440c1f commit cd04663

File tree

7 files changed

+37
-9
lines changed

7 files changed

+37
-9
lines changed

python/mxnet/gluon/contrib/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,5 @@
2121
from . import nn
2222

2323
from . import rnn
24+
25+
from . import data
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# coding: utf-8
19+
# pylint: disable=wildcard-import
20+
"""Contrib datasets."""
21+
22+
from . import text

python/mxnet/gluon/data/text.py renamed to python/mxnet/gluon/contrib/data/text.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
import shutil
2727
import numpy as np
2828

29-
from . import dataset
30-
from ..utils import download, check_sha1
31-
from ...contrib import text
32-
from ... import nd
29+
from ...data import dataset
30+
from ...utils import download, check_sha1
31+
from ....contrib import text
32+
from .... import nd
3333

3434

3535
class WikiText2(dataset._DownloadedDataset):

python/mxnet/gluon/data/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,3 @@
2626
from .dataloader import *
2727

2828
from . import vision
29-
30-
from . import text

python/mxnet/gluon/data/sampler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ class IntervalSampler(Sampler):
145145
----------
146146
length : int
147147
Length of the sequence.
148+
interval : int
149+
The number of items to skip between two samples.
148150
149151
Examples
150152
--------
@@ -153,6 +155,8 @@ class IntervalSampler(Sampler):
153155
[0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11]
154156
"""
155157
def __init__(self, length, interval):
158+
assert interval < length, \
159+
"Interval {} must be smaller than length {}".format(interval, length)
156160
self._length = length
157161
self._interval = interval
158162

tests/python/unittest/test_gluon_contrib.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ def test_identity():
172172
mx.test_utils.assert_almost_equal(model(x).asnumpy(),
173173
x.asnumpy())
174174

175+
def test_datasets():
176+
assert len(contrib.data.text.WikiText2(root='data/wikitext-2', segment='train')) == 42780
177+
assert len(contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation')) == 632
178+
assert len(contrib.data.text.WikiText2(root='data/wikitext-2', segment='test')) == 15941
179+
175180

176181
if __name__ == '__main__':
177182
import nose

tests/python/unittest/test_gluon_data.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,6 @@ def test_datasets():
8787
assert len(gluon.data.vision.CIFAR100(root='data/cifar100')) == 50000
8888
assert len(gluon.data.vision.CIFAR100(root='data/cifar100', fine_label=True)) == 50000
8989
assert len(gluon.data.vision.CIFAR100(root='data/cifar100', train=False)) == 10000
90-
assert len(gluon.data.text.WikiText2(root='data/wikitext-2', segment='train')) == 42780
91-
assert len(gluon.data.text.WikiText2(root='data/wikitext-2', segment='validation')) == 632
92-
assert len(gluon.data.text.WikiText2(root='data/wikitext-2', segment='test')) == 15941
9390

9491
def test_image_folder_dataset():
9592
prepare_record()

0 commit comments

Comments
 (0)