@@ -50,7 +50,17 @@ def train(self):
50
50
iteration = 0
51
51
52
52
while cutoff < max_cutoff :
53
- fvs , labels = self ._remove_outliers (fvs , labels , cutoff )
53
+ try :
54
+ fvs , labels = self ._remove_outliers (fvs , labels , cutoff )
55
+ except : # Failure - usually when there are no instances left (removes all)
56
+ if self .verbose :
57
+ print ('\n ORL Iteration:' , iteration , '- factor:' , factor ,
58
+ '- cutoff:' , cutoff , '- FAILURE\n ' )
59
+
60
+ factor += 1
61
+ cutoff = base_cutoff * factor
62
+ fvs , labels = deepcopy (orig_fvs ), deepcopy (orig_labels )
63
+ iteration += 1
54
64
55
65
self .w = np .full (fvs .shape [1 ], 0.0 )
56
66
for i , fv in enumerate (fvs ):
@@ -97,7 +107,7 @@ def _remove_outliers(self, fvs, labels, cutoff):
97
107
while old_number_of_instances != len (labels ):
98
108
# Assume at least 50% are non-poisonous instances
99
109
if iteration > 0 and old_number_of_instances < 0.5 * original_num_instances :
100
- break
110
+ raise ValueError ()
101
111
102
112
if self .verbose :
103
113
print ('Iteration:' , iteration , '- num_instances:' , len (labels ))
0 commit comments