Skip to content

Commit 7ebe500

Browse files
anuragarnabScenic Authors
authored andcommitted
Internal
PiperOrigin-RevId: 730796173
1 parent ac06058 commit 7ebe500

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

scenic/model_lib/base_models/multilabel_classification_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Dict, Optional, Tuple, Union
1919

2020
from flax.training import common_utils
21-
from immutabledict import immutabledict
21+
from immutabledict import immutabledict # pylint: disable=g-importing-member
2222
import jax.numpy as jnp
2323
from scenic.model_lib.base_models import base_model
2424
from scenic.model_lib.base_models import model_utils
@@ -151,8 +151,8 @@ def get_metrics_fn_jit(self,
151151
del split # For all splits, we return the same metric functions.
152152
return functools.partial(
153153
base_model.metrics_function_jit,
154-
target_is_multihot=self.dataset_meta_data.get('target_is_onehot',
155-
False),
154+
target_is_one_or_multihot=self.dataset_meta_data.get('target_is_onehot',
155+
False),
156156
metrics=_MULTI_LABEL_CLASSIFICATION_METRICS)
157157

158158
def loss_function(

0 commit comments

Comments
 (0)