41
41
public class FilteredHnswGraphSearcher extends HnswGraphSearcher {
42
42
// The maximum percentage of filtered docs before using this filtered strategy becomes less
43
43
// effective than regular HNSW search
44
- static final float MAX_FILTER_THRESHOLD = 0.60f ;
44
+ static final float MAX_FILTER_THRESHOLD = 1.0f ;
45
45
46
46
// How many filtered candidates must be found to consider N-hop neighbors
47
47
private static final float EXPANDED_EXPLORATION_LAMBDA = 0.10f ;
48
48
49
- private final BitSet explorationVisited ;
50
49
private final int maxExplorationMultiplier ;
50
+ private final int minToScore ;
51
51
52
52
/** Creates a new graph searcher. */
53
53
private FilteredHnswGraphSearcher (
54
- NeighborQueue candidates ,
55
- BitSet explorationVisited ,
56
- BitSet visited ,
57
- int filterSize ,
58
- HnswGraph graph ) {
54
+ NeighborQueue candidates , BitSet visited , int filterSize , HnswGraph graph ) {
59
55
super (candidates , visited );
60
56
assert graph .maxConn () > 0 : "graph must have known max connections" ;
61
- this .explorationVisited = explorationVisited ;
62
- this .maxExplorationMultiplier = Math .min (graph .size () / filterSize , 8 );
57
+ this .maxExplorationMultiplier = Math . min ( graph . size () / filterSize , graph . maxConn () / 2 ) ;
58
+ this .minToScore = Math .max (graph .maxConn () / 4 , 1 );
63
59
}
64
60
65
61
/**
@@ -80,11 +76,7 @@ public static FilteredHnswGraphSearcher create(
80
76
throw new IllegalArgumentException ("filterSize must be > 0 and < graph size" );
81
77
}
82
78
return new FilteredHnswGraphSearcher (
83
- new NeighborQueue (k , true ),
84
- bitSet (filterSize , getGraphSize (graph ), k ),
85
- new SparseFixedBitSet (getGraphSize (graph )),
86
- filterSize ,
87
- graph );
79
+ new NeighborQueue (k , true ), bitSet (filterSize , getGraphSize (graph ), k ), filterSize , graph );
88
80
}
89
81
90
82
private static BitSet bitSet (long filterSize , int graphSize , int topk ) {
@@ -164,22 +156,33 @@ void searchLevel(
164
156
float filteredAmount = toExplore .count () / (float ) neighborCount ;
165
157
int maxToScoreCount =
166
158
(int ) (neighborCount * Math .min (maxExplorationMultiplier , 1f / (1f - filteredAmount )));
159
+ int maxAdditionalToExploreCount = toExplore .capacity () - 1 ;
167
160
// There is enough filtered, or we don't have enough candidates to score and explore
168
- if (toScore .count () < maxToScoreCount && filteredAmount > EXPANDED_EXPLORATION_LAMBDA ) {
161
+ int totalExplored = toScore .count () + toExplore .count ();
162
+ if (toScore .count () < maxToScoreCount
163
+ && filteredAmount > EXPANDED_EXPLORATION_LAMBDA
164
+ && totalExplored < maxAdditionalToExploreCount ) {
169
165
// Now we need to explore the neighbors of the neighbors
170
166
int exploreFriend ;
171
167
while ((exploreFriend = toExplore .poll ()) != NO_MORE_DOCS
168
+ // only explore initial additional neighborhood
169
+ && totalExplored < maxAdditionalToExploreCount
172
170
&& toScore .count () < maxToScoreCount ) {
173
171
graphSeek (graph , level , exploreFriend );
174
172
int friendOfAFriendOrd ;
175
173
while ((friendOfAFriendOrd = graph .nextNeighbor ()) != NO_MORE_DOCS
176
174
&& toScore .count () < maxToScoreCount ) {
177
- if (visited .get (friendOfAFriendOrd )
178
- || explorationVisited .getAndSet (friendOfAFriendOrd )) {
175
+ if (visited .getAndSet (friendOfAFriendOrd )) {
179
176
continue ;
180
177
}
178
+ totalExplored ++;
181
179
if (acceptOrds .get (friendOfAFriendOrd )) {
182
180
toScore .add (friendOfAFriendOrd );
181
+ // If we have YET to find a minimum of number candidates, we will continue to explore
182
+ // until our max
183
+ } else if (totalExplored < maxAdditionalToExploreCount
184
+ && toScore .count () < minToScore ) {
185
+ toExplore .add (friendOfAFriendOrd );
183
186
}
184
187
}
185
188
}
@@ -202,7 +205,6 @@ void searchLevel(
202
205
private void prepareScratchState () {
203
206
candidates .clear ();
204
207
visited .clear ();
205
- explorationVisited .clear ();
206
208
}
207
209
208
210
private static class IntArrayQueue {
@@ -214,22 +216,17 @@ private static class IntArrayQueue {
214
216
nodes = new int [capacity ];
215
217
}
216
218
217
- int count () {
218
- return size - upto ;
219
+ int capacity () {
220
+ return nodes . length ;
219
221
}
220
222
221
- void expand (int capacity ) {
222
- if (nodes .length < capacity ) {
223
- int [] newNodes = new int [capacity ];
224
- System .arraycopy (nodes , 0 , newNodes , 0 , size );
225
- nodes = newNodes ;
226
- }
223
+ int count () {
224
+ return size - upto ;
227
225
}
228
226
229
227
void add (int node ) {
230
- assert isFull () == false ;
231
- if (size == nodes .length ) {
232
- expand (size * 2 );
228
+ if (isFull ()) {
229
+ throw new UnsupportedOperationException ("Initial capacity should remain unchanged" );
233
230
}
234
231
nodes [size ++] = node ;
235
232
}
0 commit comments