Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 6bff547

Browse files
liuzh47ptrendx
authored andcommitted
Add evaluation_loss to the estimator base class. (#16888)
* Add evaluation_loss to the estimator base class. * Update the base estimator class to support the separate evaluation loss. * Add evaluation loss to the base estimator class. * Add unittest for evaluation loss in the test_evaluation function * Update estimator.py * Update estimator.py
1 parent ca76bf1 commit 6bff547

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

python/mxnet/gluon/contrib/estimator/estimator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ class Estimator(object):
5959
Trainer to apply optimizer on network parameters.
6060
context : Context or list of Context
6161
Device(s) to run the training on.
62+
evaluation_loss: gluon.loss.loss
63+
Loss (objective) function to calculate during evaluation. If set evaluation_loss
64+
None, it will use the same loss function as self.loss
6265
6366
"""
6467

@@ -85,12 +88,16 @@ def __init__(self, net,
8588
metrics=None,
8689
initializer=None,
8790
trainer=None,
88-
context=None):
91+
context=None,
92+
evaluation_loss=None):
8993
self.net = net
9094
self.loss = self._check_loss(loss)
9195
self._train_metrics = _check_metrics(metrics)
9296
self._add_default_training_metrics()
9397
self._add_validation_metrics()
98+
self.evaluation_loss = self.loss
99+
if evaluation_loss is not None:
100+
self.evaluation_loss = self._check_loss(evaluation_loss)
94101

95102
self.logger = logging.Logger(name='Estimator', level=logging.INFO)
96103
self.logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -228,7 +235,7 @@ def evaluate_batch(self,
228235
"""
229236
data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
230237
pred = [self.net(x) for x in data]
231-
loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
238+
loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
232239
# update metrics
233240
for metric in val_metrics:
234241
if isinstance(metric, metric_loss):

tests/python/unittest/test_gluon_estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,15 @@ def test_validation():
8383
ctx = mx.cpu()
8484
loss = gluon.loss.L2Loss()
8585
acc = mx.metric.Accuracy()
86+
evaluation_loss = gluon.loss.L1Loss()
8687
net.initialize(ctx=ctx)
8788
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
8889
est = Estimator(net=net,
8990
loss=loss,
9091
metrics=acc,
9192
trainer=trainer,
92-
context=ctx)
93+
context=ctx,
94+
evaluation_loss=evaluation_loss)
9395
# Input dataloader
9496
est.fit(train_data=dataloader,
9597
val_data=dataloader,

0 commit comments

Comments
 (0)