diff --git a/code/DeepDA/models.py b/code/DeepDA/models.py index d4d5502e..75648439 100755 --- a/code/DeepDA/models.py +++ b/code/DeepDA/models.py @@ -82,12 +82,13 @@ def get_parameters(self, initial_lr=1.0): def predict(self, x): features = self.base_network(x) - x = self.bottleneck_layer(features) - clf = self.classifier_layer(x) + if self.use_bottleneck: + features = self.bottleneck_layer(features) + clf = self.classifier_layer(features) return clf def epoch_based_processing(self, *args, **kwargs): if self.transfer_loss == "daan": self.adapt_loss.loss_func.update_dynamic_factor(*args, **kwargs) else: - pass \ No newline at end of file + pass