
Commit 03bc04f

fix example for enas
1 parent 67d3e50 commit 03bc04f

9 files changed: +352 −16 lines changed

cmd/metricscollector/v1beta1/tfevent-metricscollector/Dockerfile.ppc64le

Lines changed: 0 additions & 1 deletion

@@ -1,5 +1,4 @@
 FROM ibmcom/tensorflow-ppc64le:2.2.0-py3
-RUN pip install rfc3339 grpcio googleapis-common-protos
 ADD . /usr/src/app/github.com/kubeflow/katib
 WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/metricscollector/v1beta1/tfevent-metricscollector/
 RUN pip install --no-cache-dir -r requirements.txt

examples/v1beta1/trial-images/enas-cnn-cifar10/Dockerfile.cpu

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ ENV TARGET_DIR /opt/enas-cnn-cifar10
 ADD examples/v1beta1/trial-images/enas-cnn-cifar10 ${TARGET_DIR}
 WORKDIR ${TARGET_DIR}
 
+RUN pip3 install --no-cache-dir -r requirements.txt
 ENV PYTHONPATH ${TARGET_DIR}
 
 RUN chgrp -R 0 ${TARGET_DIR} \

examples/v1beta1/trial-images/enas-cnn-cifar10/RunTrial.py

Lines changed: 9 additions & 11 deletions

@@ -1,12 +1,10 @@
-import keras
-import numpy as np
+from tensorflow import keras
 from keras.datasets import cifar10
 from ModelConstructor import ModelConstructor
 from tensorflow.keras.utils import to_categorical
 from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
 from keras.preprocessing.image import ImageDataGenerator
 import argparse
-import time
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='TrainingContainer')
@@ -46,7 +44,7 @@
 
     test_model.summary()
     test_model.compile(loss=keras.losses.categorical_crossentropy,
-                       optimizer=keras.optimizers.Adam(lr=1e-3, decay=1e-4),
+                       optimizer=keras.optimizers.Adam(learning_rate=1e-3, decay=1e-4),
                        metrics=['accuracy'])
 
     (x_train, y_train), (x_test, y_test) = cifar10.load_data()
@@ -67,12 +65,12 @@
 
     print(">>> Data Loaded. Training starts.")
     for e in range(num_epochs):
-        print("\nTotal Epoch {}/{}".format(e+1, num_epochs))
-        history = test_model.fit_generator(generator=aug_data_flow,
-                                           steps_per_epoch=int(len(x_train)/128)+1,
-                                           epochs=1, verbose=1,
-                                           validation_data=(x_test, y_test))
-        print("Training-Accuracy={}".format(history.history['acc'][-1]))
+        print("\nTotal Epoch {}/{}".format(e + 1, num_epochs))
+        history = test_model.fit(aug_data_flow,
+                                 steps_per_epoch=int(len(x_train) / 128) + 1,
+                                 epochs=1, verbose=1,
+                                 validation_data=(x_test, y_test))
+        print("Training-Accuracy={}".format(history.history['accuracy'][-1]))
         print("Training-Loss={}".format(history.history['loss'][-1]))
-        print("Validation-Accuracy={}".format(history.history['val_acc'][-1]))
+        print("Validation-Accuracy={}".format(history.history['val_accuracy'][-1]))
         print("Validation-Loss={}".format(history.history['val_loss'][-1]))
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+scipy>=1.7.2

examples/v1beta1/trial-images/tf-mnist-with-summaries/README.md

Lines changed: 1 addition & 1 deletion

@@ -8,4 +8,4 @@ If you want to read more about this example, visit the official
 GitHub repository.
 
 Katib uses this training container in some Experiments, for instance in the
-[TF Event Metrics Collector](../../metrics-collector/tfevent-metrics-collector.yaml#L55-L64).
+[TF Event Metrics Collector](../../metrics-collector/tfevent-metrics-collector.yaml#L42-L49).
Lines changed: 333 additions & 0 deletions

@@ -0,0 +1,333 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for downloading and reading MNIST data (deprecated).
+
+This module and all its submodules are deprecated.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import gzip
+import os
+
+import numpy
+from six.moves import urllib
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import random_seed
+from tensorflow.python.platform import gfile
+from tensorflow.python.util.deprecation import deprecated
+
+_Datasets = collections.namedtuple('_Datasets', ['train', 'validation', 'test'])
+
+# CVDF mirror of http://yann.lecun.com/exdb/mnist/
+DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
+
+
+def _read32(bytestream):
+  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
+  return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
+
+
+@deprecated(None, 'Please use tf.data to implement this functionality.')
+def _extract_images(f):
+  """Extract the images into a 4D uint8 numpy array [index, y, x, depth].
+
+  Args:
+    f: A file object that can be passed into a gzip reader.
+
+  Returns:
+    data: A 4D uint8 numpy array [index, y, x, depth].
+
+  Raises:
+    ValueError: If the bytestream does not start with 2051.
+
+  """
+  print('Extracting', f.name)
+  with gzip.GzipFile(fileobj=f) as bytestream:
+    magic = _read32(bytestream)
+    if magic != 2051:
+      raise ValueError('Invalid magic number %d in MNIST image file: %s' %
+                       (magic, f.name))
+    num_images = _read32(bytestream)
+    rows = _read32(bytestream)
+    cols = _read32(bytestream)
+    buf = bytestream.read(rows * cols * num_images)
+    data = numpy.frombuffer(buf, dtype=numpy.uint8)
+    data = data.reshape(num_images, rows, cols, 1)
+    return data
+
+
+@deprecated(None, 'Please use tf.one_hot on tensors.')
+def _dense_to_one_hot(labels_dense, num_classes):
+  """Convert class labels from scalars to one-hot vectors."""
+  num_labels = labels_dense.shape[0]
+  index_offset = numpy.arange(num_labels) * num_classes
+  labels_one_hot = numpy.zeros((num_labels, num_classes))
+  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
+  return labels_one_hot
+
+
+@deprecated(None, 'Please use tf.data to implement this functionality.')
+def _extract_labels(f, one_hot=False, num_classes=10):
+  """Extract the labels into a 1D uint8 numpy array [index].
+
+  Args:
+    f: A file object that can be passed into a gzip reader.
+    one_hot: Does one hot encoding for the result.
+    num_classes: Number of classes for the one hot encoding.
+
+  Returns:
+    labels: a 1D uint8 numpy array.
+
+  Raises:
+    ValueError: If the bystream doesn't start with 2049.
+  """
+  print('Extracting', f.name)
+  with gzip.GzipFile(fileobj=f) as bytestream:
+    magic = _read32(bytestream)
+    if magic != 2049:
+      raise ValueError('Invalid magic number %d in MNIST label file: %s' %
+                       (magic, f.name))
+    num_items = _read32(bytestream)
+    buf = bytestream.read(num_items)
+    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
+    if one_hot:
+      return _dense_to_one_hot(labels, num_classes)
+    return labels
+
+
+class _DataSet(object):
+  """Container class for a _DataSet (deprecated).
+
+  THIS CLASS IS DEPRECATED.
+  """
+
+  @deprecated(None, 'Please use alternatives such as official/mnist/_DataSet.py'
+              ' from tensorflow/models.')
+  def __init__(self,
+               images,
+               labels,
+               fake_data=False,
+               one_hot=False,
+               dtype=dtypes.float32,
+               reshape=True,
+               seed=None):
+    """Construct a _DataSet.
+
+    one_hot arg is used only if fake_data is true. `dtype` can be either
+    `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
+    `[0, 1]`. Seed arg provides for convenient deterministic testing.
+
+    Args:
+      images: The images
+      labels: The labels
+      fake_data: Ignore inages and labels, use fake data.
+      one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
+        False).
+      dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
+        range [0,255]. float32 output has range [0,1].
+      reshape: Bool. If True returned images are returned flattened to vectors.
+      seed: The random seed to use.
+    """
+    seed1, seed2 = random_seed.get_seed(seed)
+    # If op level seed is not set, use whatever graph level seed is returned
+    numpy.random.seed(seed1 if seed is None else seed2)
+    dtype = dtypes.as_dtype(dtype).base_dtype
+    if dtype not in (dtypes.uint8, dtypes.float32):
+      raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
+                      dtype)
+    if fake_data:
+      self._num_examples = 10000
+      self.one_hot = one_hot
+    else:
+      assert images.shape[0] == labels.shape[0], (
+          'images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
+      self._num_examples = images.shape[0]
+
+      # Convert shape from [num examples, rows, columns, depth]
+      # to [num examples, rows*columns] (assuming depth == 1)
+      if reshape:
+        assert images.shape[3] == 1
+        images = images.reshape(images.shape[0],
+                                images.shape[1] * images.shape[2])
+      if dtype == dtypes.float32:
+        # Convert from [0, 255] -> [0.0, 1.0].
+        images = images.astype(numpy.float32)
+        images = numpy.multiply(images, 1.0 / 255.0)
+    self._images = images
+    self._labels = labels
+    self._epochs_completed = 0
+    self._index_in_epoch = 0
+
+  @property
+  def images(self):
+    return self._images
+
+  @property
+  def labels(self):
+    return self._labels
+
+  @property
+  def num_examples(self):
+    return self._num_examples
+
+  @property
+  def epochs_completed(self):
+    return self._epochs_completed
+
+  def next_batch(self, batch_size, fake_data=False, shuffle=True):
+    """Return the next `batch_size` examples from this data set."""
+    if fake_data:
+      fake_image = [1] * 784
+      if self.one_hot:
+        fake_label = [1] + [0] * 9
+      else:
+        fake_label = 0
+      return [fake_image for _ in xrange(batch_size)
+             ], [fake_label for _ in xrange(batch_size)]
+    start = self._index_in_epoch
+    # Shuffle for the first epoch
+    if self._epochs_completed == 0 and start == 0 and shuffle:
+      perm0 = numpy.arange(self._num_examples)
+      numpy.random.shuffle(perm0)
+      self._images = self.images[perm0]
+      self._labels = self.labels[perm0]
+    # Go to the next epoch
+    if start + batch_size > self._num_examples:
+      # Finished epoch
+      self._epochs_completed += 1
+      # Get the rest examples in this epoch
+      rest_num_examples = self._num_examples - start
+      images_rest_part = self._images[start:self._num_examples]
+      labels_rest_part = self._labels[start:self._num_examples]
+      # Shuffle the data
+      if shuffle:
+        perm = numpy.arange(self._num_examples)
+        numpy.random.shuffle(perm)
+        self._images = self.images[perm]
+        self._labels = self.labels[perm]
+      # Start next epoch
+      start = 0
+      self._index_in_epoch = batch_size - rest_num_examples
+      end = self._index_in_epoch
+      images_new_part = self._images[start:end]
+      labels_new_part = self._labels[start:end]
+      return numpy.concatenate((images_rest_part, images_new_part),
+                               axis=0), numpy.concatenate(
+                                   (labels_rest_part, labels_new_part), axis=0)
+    else:
+      self._index_in_epoch += batch_size
+      end = self._index_in_epoch
+      return self._images[start:end], self._labels[start:end]
+
+
+@deprecated(None, 'Please write your own downloading logic.')
+def _maybe_download(filename, work_directory, source_url):
+  """Download the data from source url, unless it's already here.
+
+  Args:
+    filename: string, name of the file in the directory.
+    work_directory: string, path to working directory.
+    source_url: url to download from if file doesn't exist.
+
+  Returns:
+    Path to resulting file.
+  """
+  if not gfile.Exists(work_directory):
+    gfile.MakeDirs(work_directory)
+  filepath = os.path.join(work_directory, filename)
+  if not gfile.Exists(filepath):
+    urllib.request.urlretrieve(source_url, filepath)
+    with gfile.GFile(filepath) as f:
+      size = f.size()
+    print('Successfully downloaded', filename, size, 'bytes.')
+  return filepath
+
+
+@deprecated(None, 'Please use alternatives such as:'
+            ' tensorflow_datasets.load(\'mnist\')')
+def read_data_sets(train_dir,
+                   fake_data=False,
+                   one_hot=False,
+                   dtype=dtypes.float32,
+                   reshape=True,
+                   validation_size=5000,
+                   seed=None,
+                   source_url=DEFAULT_SOURCE_URL):
+  if fake_data:
+
+    def fake():
+      return _DataSet([], [],
+                      fake_data=True,
+                      one_hot=one_hot,
+                      dtype=dtype,
+                      seed=seed)
+
+    train = fake()
+    validation = fake()
+    test = fake()
+    return _Datasets(train=train, validation=validation, test=test)
+
+  if not source_url:  # empty string check
+    source_url = DEFAULT_SOURCE_URL
+
+  train_images_file = 'train-images-idx3-ubyte.gz'
+  train_labels_file = 'train-labels-idx1-ubyte.gz'
+  test_images_file = 't10k-images-idx3-ubyte.gz'
+  test_labels_file = 't10k-labels-idx1-ubyte.gz'
+
+  local_file = _maybe_download(train_images_file, train_dir,
+                               source_url + train_images_file)
+  with gfile.Open(local_file, 'rb') as f:
+    train_images = _extract_images(f)
+
+  local_file = _maybe_download(train_labels_file, train_dir,
+                               source_url + train_labels_file)
+  with gfile.Open(local_file, 'rb') as f:
+    train_labels = _extract_labels(f, one_hot=one_hot)
+
+  local_file = _maybe_download(test_images_file, train_dir,
+                               source_url + test_images_file)
+  with gfile.Open(local_file, 'rb') as f:
+    test_images = _extract_images(f)
+
+  local_file = _maybe_download(test_labels_file, train_dir,
+                               source_url + test_labels_file)
+  with gfile.Open(local_file, 'rb') as f:
+    test_labels = _extract_labels(f, one_hot=one_hot)
+
+  if not 0 <= validation_size <= len(train_images):
+    raise ValueError(
+        'Validation size should be between 0 and {}. Received: {}.'.format(
+            len(train_images), validation_size))
+
+  validation_images = train_images[:validation_size]
+  validation_labels = train_labels[:validation_size]
+  train_images = train_images[validation_size:]
+  train_labels = train_labels[validation_size:]
+
+  options = dict(dtype=dtype, reshape=reshape, seed=seed)
+
+  train = _DataSet(train_images, train_labels, **options)
+  validation = _DataSet(validation_images, validation_labels, **options)
+  test = _DataSet(test_images, test_labels, **options)
+
+  return _Datasets(train=train, validation=validation, test=test)
+
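The 333-line file above is the deprecated MNIST loader vendored from tensorflow/examples, likely for the tf-mnist-with-summaries example whose README is updated in this commit. A hedged usage sketch, assuming the module is saved as input_data.py next to the trial script ('/tmp/mnist-data' is an arbitrary cache directory; the first call downloads the four IDX archives from the CVDF mirror):

import input_data  # assumed filename for the vendored module added in this commit

mnist = input_data.read_data_sets('/tmp/mnist-data', one_hot=True)
print(mnist.train.num_examples, mnist.validation.num_examples, mnist.test.num_examples)

# next_batch() yields flattened 28x28 images (reshape=True by default) and one-hot labels.
batch_xs, batch_ys = mnist.train.next_batch(100)
print(batch_xs.shape, batch_ys.shape)  # (100, 784) (100, 10)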
