Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit c82af38

Browse files
liuzh47leezu
authored andcommitted
Add support of plug and play fit_batch and evaluate_batch (#16982)
* Add support of plug and play fit_batch and evaluate_batch * Add check for the validity of the estimator model * Rename estimator model as batch processor * Remove unused import * Add documentation of the batch processor class * refine the documentation of the batch processor * Fix merge bugs * fix bugs introduced during merge * fix sanity check failures * fix CI bugs
1 parent 27389b1 commit c82af38

File tree

4 files changed

+248
-60
lines changed

4 files changed

+248
-60
lines changed

python/mxnet/gluon/contrib/estimator/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,7 @@
1919
"""Gluon Estimator Module"""
2020
from . import estimator
2121
from . import event_handler
22+
from . import batch_processor
2223
from .estimator import *
2324
from .event_handler import *
25+
from .batch_processor import *
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# coding: utf-8
19+
# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
20+
"""Gluon Batch Processor for Estimators"""
21+
22+
from ...utils import split_and_load
23+
from .... import autograd
24+
25+
__all__ = ['BatchProcessor']
26+
27+
class BatchProcessor(object):
28+
"""BatchProcessor Class for plug and play fit_batch & evaluate_batch
29+
30+
During training or validation, data are divided into minibatches for processing. This
31+
class aims at providing hooks of training or validating on a minibatch of data. Users
32+
may provide customized fit_batch() and evaluate_batch() methods by inheriting from
33+
this class and overriding class methods.
34+
35+
:py:class:`BatchProcessor` can be used to replace fit_batch() and evaluate_batch()
36+
in the base estimator class
37+
"""
38+
39+
def __init__(self):
40+
pass
41+
42+
def _get_data_and_label(self, batch, ctx, batch_axis=0):
43+
data = batch[0]
44+
label = batch[1]
45+
data = split_and_load(data, ctx_list=ctx, batch_axis=batch_axis)
46+
label = split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
47+
return data, label
48+
49+
def evaluate_batch(self, estimator,
50+
val_batch,
51+
batch_axis=0):
52+
"""Evaluate the estimator model on a batch of validation data.
53+
54+
Parameters
55+
----------
56+
estimator : Estimator
57+
Reference to the estimator
58+
val_batch : tuple
59+
Data and label of a batch from the validation data loader.
60+
batch_axis : int, default 0
61+
Batch axis to split the validation data into devices.
62+
"""
63+
data, label = self._get_data_and_label(val_batch, estimator.context, batch_axis)
64+
pred = [estimator.eval_net(x) for x in data]
65+
loss = [estimator.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
66+
67+
return data, label, pred, loss
68+
69+
def fit_batch(self, estimator,
70+
train_batch,
71+
batch_axis=0):
72+
"""Trains the estimator model on a batch of training data.
73+
74+
Parameters
75+
----------
76+
estimator : Estimator
77+
Reference to the estimator
78+
train_batch : tuple
79+
Data and label of a batch from the training data loader.
80+
batch_axis : int, default 0
81+
Batch axis to split the training data into devices.
82+
83+
Returns
84+
-------
85+
data: List of NDArray
86+
Sharded data from the batch. Data is sharded with
87+
`gluon.split_and_load`.
88+
label: List of NDArray
89+
Sharded label from the batch. Labels are sharded with
90+
`gluon.split_and_load`.
91+
pred: List of NDArray
92+
Prediction on each of the sharded inputs.
93+
loss: List of NDArray
94+
Loss on each of the sharded inputs.
95+
"""
96+
data, label = self._get_data_and_label(train_batch, estimator.context, batch_axis)
97+
98+
with autograd.record():
99+
pred = [estimator.net(x) for x in data]
100+
loss = [estimator.loss(y_hat, y) for y_hat, y in zip(pred, label)]
101+
102+
for l in loss:
103+
l.backward()
104+
105+
return data, label, pred, loss

python/mxnet/gluon/contrib/estimator/estimator.py

Lines changed: 24 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
from ...loss import Loss as gluon_loss
3333
from ...trainer import Trainer
3434
from ...utils import split_and_load
35-
from .... import autograd
3635
from ....context import Context, cpu, gpu, num_gpus
3736
from ....metric import Loss as metric_loss
37+
from .batch_processor import BatchProcessor
3838

3939
__all__ = ['Estimator']
4040

@@ -84,7 +84,8 @@ class Estimator(object):
8484
the naming in mxnet Gluon API, please refer to the site
8585
(https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html)
8686
for future information.
87-
87+
batch_processor: BatchProcessor
88+
BatchProcessor provides customized fit_batch() and evaluate_batch() methods
8889
"""
8990

9091
logger = None
@@ -113,7 +114,8 @@ def __init__(self, net,
113114
trainer=None,
114115
context=None,
115116
evaluation_loss=None,
116-
eval_net=None):
117+
eval_net=None,
118+
batch_processor=None):
117119
self.net = net
118120
self.loss = self._check_loss(loss)
119121
self._train_metrics = _check_metrics(train_metrics)
@@ -133,6 +135,7 @@ def __init__(self, net,
133135
self.context = self._check_context(context)
134136
self._initialize(initializer)
135137
self.trainer = self._check_trainer(trainer)
138+
self.batch_processor = self._check_batch_processor(batch_processor)
136139

137140
def _check_loss(self, loss):
138141
if not isinstance(loss, gluon_loss):
@@ -173,6 +176,18 @@ def _check_context(self, context):
173176
context = [cpu()]
174177
return context
175178

179+
def _check_batch_processor(self, batch_processor):
180+
# check whether the batch processor contains fit_batch() and evaluate_batch() methods
181+
if batch_processor is not None:
182+
model_fit = getattr(batch_processor, 'fit_batch', None)
183+
model_evaluate = getattr(batch_processor, 'evaluate_batch', None)
184+
if not callable(model_fit) or not callable(model_evaluate):
185+
raise ValueError('Customized Batch Processor must contain fit_batch()'
186+
' and evaluate_batch() methods')
187+
else:
188+
batch_processor = BatchProcessor()
189+
return batch_processor
190+
176191
def _initialize(self, initializer):
177192
# initialize the network
178193
if not self._is_initialized():
@@ -254,24 +269,6 @@ def train_metrics(self):
254269
def val_metrics(self):
255270
return self._val_metrics
256271

257-
def evaluate_batch(self,
258-
val_batch,
259-
batch_axis=0):
260-
"""Evaluate model on a batch of validation data.
261-
262-
Parameters
263-
----------
264-
val_batch : tuple
265-
Data and label of a batch from the validation data loader.
266-
batch_axis : int, default 0
267-
Batch axis to split the validation data into devices.
268-
"""
269-
data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
270-
pred = [self.eval_net(x) for x in data]
271-
loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
272-
273-
return data, label, pred, loss
274-
275272
def evaluate(self,
276273
val_data,
277274
batch_axis=0,
@@ -300,6 +297,7 @@ def evaluate(self,
300297

301298
for metric in self.val_metrics:
302299
metric.reset()
300+
estimator_ref = self
303301

304302
event_handlers = self._prepare_default_validation_handlers(event_handlers)
305303

@@ -315,50 +313,16 @@ def evaluate(self,
315313
for handler in batch_begin:
316314
handler.batch_begin(estimator_ref, batch=batch)
317315

318-
_, label, pred, loss = self.evaluate_batch(batch, batch_axis)
316+
_, label, pred, loss = \
317+
self.batch_processor.evaluate_batch(estimator_ref, batch,
318+
batch_axis)
319319

320320
for handler in batch_end:
321321
handler.batch_end(estimator_ref, batch=batch, pred=pred, label=label, loss=loss)
322322

323323
for handler in epoch_end:
324324
handler.epoch_end(estimator_ref)
325325

326-
def fit_batch(self, train_batch, batch_axis=0):
327-
"""Trains the model on a batch of training data.
328-
329-
Parameters
330-
----------
331-
train_batch : tuple
332-
Data and label of a batch from the training data loader.
333-
batch_axis : int, default 0
334-
Batch axis to split the training data into devices.
335-
336-
Returns
337-
-------
338-
data: List of NDArray
339-
Sharded data from the batch. Data is sharded with
340-
`gluon.split_and_load`.
341-
label: List of NDArray
342-
Sharded label from the batch. Labels are sharded with
343-
`gluon.split_and_load`.
344-
pred: List of NDArray
345-
Prediction on each of the sharded inputs.
346-
loss: List of NDArray
347-
Loss on each of the sharded inputs.
348-
"""
349-
data, label = self._get_data_and_label(train_batch, self.context, batch_axis)
350-
351-
batch_size = train_batch[0].shape[batch_axis]
352-
353-
with autograd.record():
354-
pred = [self.net(x) for x in data]
355-
loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
356-
357-
for l in loss:
358-
l.backward()
359-
360-
return data, label, pred, loss
361-
362326
def fit(self, train_data,
363327
val_data=None,
364328
epochs=None,
@@ -432,8 +396,8 @@ def fit(self, train_data,
432396
for handler in batch_begin:
433397
handler.batch_begin(estimator_ref, batch=batch)
434398

435-
_, label, pred, loss = self.fit_batch(batch, batch_axis)
436-
399+
_, label, pred, loss = self.batch_processor.fit_batch(estimator_ref,
400+
batch, batch_axis)
437401
# batch end
438402

439403
batch_end_result = []
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
''' Unit tests for Gluon Batch Processor '''
19+
20+
import sys
21+
import unittest
22+
import warnings
23+
24+
import mxnet as mx
25+
from mxnet import gluon
26+
from mxnet.gluon import nn
27+
from mxnet.gluon.contrib.estimator import *
28+
from mxnet.gluon.contrib.estimator.event_handler import *
29+
from mxnet.gluon.contrib.estimator.batch_processor import BatchProcessor
30+
from nose.tools import assert_raises
31+
32+
def _get_test_network():
33+
net = nn.Sequential()
34+
net.add(nn.Dense(4, activation='relu', flatten=False))
35+
return net
36+
37+
38+
def _get_test_data():
39+
batch_size = 4
40+
in_data = mx.nd.random.uniform(shape=(10, 3))
41+
out_data = mx.nd.random.uniform(shape=(10, 4))
42+
# Input dataloader
43+
dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
44+
dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
45+
dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
46+
return dataloader, dataiter
47+
48+
def test_batch_processor_fit():
49+
''' test estimator with different train data types '''
50+
net = _get_test_network()
51+
dataloader, dataiter = _get_test_data()
52+
num_epochs = 1
53+
ctx = mx.cpu()
54+
loss = gluon.loss.L2Loss()
55+
acc = mx.metric.Accuracy()
56+
net.initialize(ctx=ctx)
57+
processor = BatchProcessor()
58+
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
59+
est = Estimator(net=net,
60+
loss=loss,
61+
train_metrics=acc,
62+
trainer=trainer,
63+
context=ctx,
64+
batch_processor=processor)
65+
66+
est.fit(train_data=dataloader,
67+
epochs=num_epochs)
68+
69+
with assert_raises(ValueError):
70+
est.fit(train_data=dataiter,
71+
epochs=num_epochs)
72+
73+
# Input NDArray
74+
with assert_raises(ValueError):
75+
est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
76+
epochs=num_epochs)
77+
78+
79+
def test_batch_processor_validation():
80+
''' test different validation data types'''
81+
net = _get_test_network()
82+
dataloader, dataiter = _get_test_data()
83+
num_epochs = 1
84+
ctx = mx.cpu()
85+
loss = gluon.loss.L2Loss()
86+
acc = mx.metric.Accuracy()
87+
evaluation_loss = gluon.loss.L1Loss()
88+
net.initialize(ctx=ctx)
89+
processor = BatchProcessor()
90+
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
91+
est = Estimator(net=net,
92+
loss=loss,
93+
train_metrics=acc,
94+
trainer=trainer,
95+
context=ctx,
96+
evaluation_loss=evaluation_loss,
97+
batch_processor=processor)
98+
# Input dataloader
99+
est.fit(train_data=dataloader,
100+
val_data=dataloader,
101+
epochs=num_epochs)
102+
103+
# using validation handler
104+
train_metrics = est.train_metrics
105+
val_metrics = est.val_metrics
106+
validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate)
107+
108+
with assert_raises(ValueError):
109+
est.fit(train_data=dataiter,
110+
val_data=dataiter,
111+
epochs=num_epochs)
112+
# Input NDArray
113+
with assert_raises(ValueError):
114+
est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
115+
val_data=[mx.nd.ones(shape=(10, 3))],
116+
epochs=num_epochs)
117+

0 commit comments

Comments
 (0)