Skip to content

Commit ddf8220

Browse files
committed
Revert "Fixes"
This reverts commit 8dfb702.
1 parent 2043f77 commit ddf8220

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

adlib/learners/outlier_removal_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def train(self):
4040

4141
base_cutoff = cutoff
4242
factor = 1
43-
max_cutoff = cutoff * 200
43+
max_cutoff = cutoff * 100
4444

4545
if self.verbose:
4646
print('\nBase cutoff:', cutoff, '\nMax cutoff:', max_cutoff, '\n')

adlib/learners/simple_learner.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def set_model(self, model):
2222

2323
def set_params(self, params: Dict):
2424
if 'model' in params:
25-
self.set_model(params['model'])
25+
self.model = self.set_model(params['model'])
2626
self.model.set_params(params)
2727

2828
def train(self):
@@ -60,6 +60,10 @@ def predict_log_proba(self, testing_instances):
6060
def decision_function(self, X):
6161
return self.model.learner.decision_function(X)
6262

63+
def set_params(self, params: Dict):
64+
if params['model'] is not None:
65+
self.model = self.set_model(params['model'])
66+
6367
def get_weight(self):
6468
if self.model.learner.kernel == 'rbf':
6569
return None
@@ -71,3 +75,6 @@ def get_weight(self):
7175

7276
def get_constant(self):
7377
return self.model.learner.intercept_
78+
79+
def decision_function(self, X):
80+
return self.model.learner.decision_function(X)

adlib/tests/learners/dp_learner_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test(self):
121121
list(self.dp_learner_testing_pred_labels),
122122
end - begin)
123123

124-
self.results.append(deepcopy(result))
124+
self.results.append(result)
125125

126126
if self.verbose:
127127
print('\nEND', self.learner_names[0] if len(self.learner_names) == 1 else 'learner',
@@ -201,7 +201,7 @@ def attack(self, instances):
201201
def _retrain(self):
202202
# Retrain the model with poisoned data
203203
learning_model = svm.SVC(probability=True, kernel='linear')
204-
self.attack_learner = SimpleLearner(learning_model, deepcopy(self.attack_instances))
204+
self.attack_learner = SimpleLearner(learning_model, self.attack_instances)
205205
self.attack_learner.train()
206206

207207
self.attack_training_pred_labels = self.attack_learner.predict(self.training_instances)

0 commit comments

Comments
 (0)