Skip to content

Commit 8dfb702

Browse files
committed
Fixes
1 parent 20b06d5 commit 8dfb702

File tree

3 files changed

+4
-11
lines changed

3 files changed

+4
-11
lines changed

adlib/learners/outlier_removal_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def train(self):
4040

4141
base_cutoff = cutoff
4242
factor = 1
43-
max_cutoff = cutoff * 100
43+
max_cutoff = cutoff * 200
4444

4545
if self.verbose:
4646
print('\nBase cutoff:', cutoff, '\nMax cutoff:', max_cutoff, '\n')

adlib/learners/simple_learner.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def set_model(self, model):
2222

2323
def set_params(self, params: Dict):
2424
if 'model' in params:
25-
self.model = self.set_model(params['model'])
25+
self.set_model(params['model'])
2626
self.model.set_params(params)
2727

2828
def train(self):
@@ -60,10 +60,6 @@ def predict_log_proba(self, testing_instances):
6060
def decision_function(self, X):
6161
return self.model.learner.decision_function(X)
6262

63-
def set_params(self, params: Dict):
64-
if params['model'] is not None:
65-
self.model = self.set_model(params['model'])
66-
6763
def get_weight(self):
6864
if self.model.learner.kernel == 'rbf':
6965
return None
@@ -75,6 +71,3 @@ def get_weight(self):
7571

7672
def get_constant(self):
7773
return self.model.learner.intercept_
78-
79-
def decision_function(self, X):
80-
return self.model.learner.decision_function(X)

adlib/tests/learners/dp_learner_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test(self):
121121
list(self.dp_learner_testing_pred_labels),
122122
end - begin)
123123

124-
self.results.append(result)
124+
self.results.append(deepcopy(result))
125125

126126
if self.verbose:
127127
print('\nEND', self.learner_names[0] if len(self.learner_names) == 1 else 'learner',
@@ -201,7 +201,7 @@ def attack(self, instances):
201201
def _retrain(self):
202202
# Retrain the model with poisoned data
203203
learning_model = svm.SVC(probability=True, kernel='linear')
204-
self.attack_learner = SimpleLearner(learning_model, self.attack_instances)
204+
self.attack_learner = SimpleLearner(learning_model, deepcopy(self.attack_instances))
205205
self.attack_learner.train()
206206

207207
self.attack_training_pred_labels = self.attack_learner.predict(self.training_instances)

0 commit comments

Comments
 (0)