Skip to content

Commit 98730a1

Browse files
Fix issue when calling cleanup while concurrently executing searches (#483)
1 parent b637f65 commit 98730a1

File tree

5 files changed

+270
-14
lines changed

5 files changed

+270
-14
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -372,11 +372,11 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi
372372
// Loop over 0..maxLayer, re-score neighbors for each layer
373373
var sf = newProvider.searchProviderFor(i).scoreFunction();
374374
for (int lvl = 0; lvl <= maxLayer; lvl++) {
375-
var oldNeighbors = other.graph.getNeighbors(lvl, i);
375+
var oldNeighborsIt = other.graph.getNeighborsIterator(lvl, i);
376376
// Copy edges, compute new scores
377-
var newNeighbors = new NodeArray(oldNeighbors.size());
378-
for (var it = oldNeighbors.iterator(); it.hasNext();) {
379-
int neighbor = it.nextInt();
377+
var newNeighbors = new NodeArray(oldNeighborsIt.size());
378+
while (oldNeighborsIt.hasNext()) {
379+
int neighbor = oldNeighborsIt.nextInt();
380380
// since we're using a different score provider, use insertSorted instead of addInOrder
381381
newNeighbors.insertSorted(neighbor, sf.similarityTo(neighbor));
382382
}
@@ -647,15 +647,14 @@ public synchronized long removeDeletedNodes() {
647647
var newEdges = new ConcurrentHashMap<Integer, Set<Integer>>(); // new edges for key k are values v
648648
parallelExecutor.submit(() -> {
649649
IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> {
650-
var neighbors = graph.getNeighbors(level, i);
651-
if (neighbors == null || toDelete.get(i)) {
650+
if (toDelete.get(i)) {
652651
return;
653652
}
654-
for (var it = neighbors.iterator(); it.hasNext(); ) {
653+
for (var it = graph.getNeighborsIterator(level, i); it.hasNext(); ) {
655654
var j = it.nextInt();
656655
if (toDelete.get(j)) {
657656
var newEdgesForI = newEdges.computeIfAbsent(i, __ -> ConcurrentHashMap.newKeySet());
658-
for (var jt = graph.getNeighbors(level, j).iterator(); jt.hasNext(); ) {
657+
for (var jt = graph.getNeighborsIterator(level, j); jt.hasNext(); ) {
659658
int k = jt.nextInt();
660659
if (i != k && !toDelete.get(k)) {
661660
newEdgesForI.add(k);

jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesIterator.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,23 @@ public boolean hasNext() {
9696
return cur < size;
9797
}
9898
}
99+
100+
EmptyNodeIterator EMPTY_NODE_ITERATOR = new EmptyNodeIterator();
101+
102+
class EmptyNodeIterator implements NodesIterator {
103+
@Override
104+
public int size() {
105+
return 0;
106+
}
107+
108+
@Override
109+
public int nextInt() {
110+
throw new NoSuchElementException();
111+
}
112+
113+
@Override
114+
public boolean hasNext() {
115+
return false;
116+
}
117+
}
99118
}

jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import java.io.UncheckedIOException;
4242
import java.util.ArrayList;
4343
import java.util.List;
44+
import java.util.NoSuchElementException;
4445
import java.util.concurrent.ConcurrentHashMap;
4546
import java.util.concurrent.ConcurrentMap;
4647
import java.util.concurrent.atomic.AtomicInteger;
@@ -108,6 +109,25 @@ Neighbors getNeighbors(int level, int node) {
108109
return layers.get(level).get(node);
109110
}
110111

112+
/**
113+
* Returns an iterator over the neighbors for the given node at the specified level.
114+
*
115+
* @param level the layer
116+
* @param node the node id
117+
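* @return a NodesIterator over the node's neighbors; the shared empty iterator when {@code level} is beyond the existing layers or the layer has no entry for {@code node}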
118+
*/
119+
NodesIterator getNeighborsIterator(int level, int node) {
120+
if (level >= layers.size()) {
121+
return NodesIterator.EMPTY_NODE_ITERATOR;
122+
}
123+
var neighs = layers.get(level).get(node);
124+
if (neighs == null) {
125+
return NodesIterator.EMPTY_NODE_ITERATOR;
126+
} else {
127+
return neighs.iterator();
128+
}
129+
}
130+
111131
@Override
112132
public int size(int level) {
113133
return layers.get(level).size();
@@ -366,7 +386,8 @@ public class ConcurrentGraphIndexView extends FrozenView {
366386

367387
@Override
368388
public NodesIterator getNeighborsIterator(int level, int node) {
369-
var it = getNeighbors(level, node).iterator();
389+
NodesIterator it = OnHeapGraphIndex.this.getNeighborsIterator(level, node);
390+
370391
return new NodesIterator() {
371392
int nextNode = advance();
372393

@@ -389,7 +410,7 @@ public int size() {
389410
public int nextInt() {
390411
int current = nextNode;
391412
if (current == Integer.MIN_VALUE) {
392-
throw new IndexOutOfBoundsException();
413+
throw new NoSuchElementException();
393414
}
394415
nextNode = advance();
395416
return current;
@@ -406,7 +427,8 @@ public boolean hasNext() {
406427
private class FrozenView implements View {
407428
@Override
408429
public NodesIterator getNeighborsIterator(int level, int node) {
409-
return getNeighbors(level, node).iterator();
430+
return OnHeapGraphIndex.this.getNeighborsIterator(level, node);
431+
410432
}
411433

412434
@Override
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.github.jbellis.jvector.graph;
18+
19+
import com.carrotsearch.randomizedtesting.RandomizedTest;
20+
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
21+
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
22+
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
23+
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
24+
import io.github.jbellis.jvector.util.FixedBitSet;
25+
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
26+
import io.github.jbellis.jvector.vector.types.VectorFloat;
27+
import org.junit.Test;
28+
import org.slf4j.Logger;
29+
import org.slf4j.LoggerFactory;
30+
31+
import static io.github.jbellis.jvector.TestUtil.createRandomVectors;
32+
import static io.github.jbellis.jvector.TestUtil.randomVector;
33+
34+
import java.util.Collections;
35+
import java.util.ArrayList;
36+
import java.util.List;
37+
import java.util.Map;
38+
import java.util.concurrent.ConcurrentHashMap;
39+
import java.util.concurrent.CopyOnWriteArrayList;
40+
import java.util.concurrent.ExecutionException;
41+
import java.util.concurrent.atomic.AtomicInteger;
42+
import java.util.concurrent.locks.Lock;
43+
import java.util.concurrent.locks.ReentrantLock;
44+
import java.util.stream.Collectors;
45+
import java.util.stream.IntStream;
46+
47+
/**
48+
* Runs "nVectors" operations, where each operation is either:
49+
* - an insertion
50+
* - a mock deletion, instantiated through the use of a BitSet for skipping these nodes during search
51+
* - a search
52+
* With probability 0.01, we run cleanup to commit the deletions to the index. The cleanup process and the insertions
53+
* cannot be concurrently executed (we use a lock to control their execution).
54+
*/
55+
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
56+
public class TestConcurrentReadWriteDeletes extends RandomizedTest {
57+
private static final Logger logger = LoggerFactory.getLogger(TestConcurrentReadWriteDeletes.class);
58+
59+
private static final int nVectors = 20_000;
60+
private static final int dimension = 16;
61+
private static final double cleanupProbability = 0.01;
62+
63+
private KeySet keysInserted = new KeySet();
64+
private List<Integer> keysRemoved = new CopyOnWriteArrayList();
65+
66+
private List<VectorFloat<?>> vectors = createRandomVectors(nVectors, dimension);
67+
private RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension);
68+
69+
private VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
70+
71+
private BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, similarityFunction);
72+
private GraphIndexBuilder builder = new GraphIndexBuilder(bsp, 2, 2, 10, 1.0f, 1.0f, true);
73+
74+
private FixedBitSet liveNodes = new FixedBitSet(nVectors);
75+
76+
private final Lock writeLock = new ReentrantLock();
77+
78+
@Test
79+
public void testConcurrentReadsWritesDeletes() throws ExecutionException, InterruptedException {
80+
var vv = ravv.threadLocalSupplier();
81+
82+
testConcurrentOps(i -> {
83+
var R = getRandom();
84+
if (R.nextDouble() < 0.2 || keysInserted.isEmpty())
85+
{
86+
// In the future, we could improve this test by acquiring the lock earlier and executing other operations while waiting for it, instead of blocking the thread here.
87+
writeLock.lock();
88+
try {
89+
builder.addGraphNode(i, vv.get().getVector(i));
90+
liveNodes.set(i);
91+
keysInserted.add(i);
92+
} finally {
93+
writeLock.unlock();
94+
}
95+
} else if (R.nextDouble() < 0.1) {
96+
var key = keysInserted.getRandom();
97+
if (!keysRemoved.contains(key)) {
98+
liveNodes.flip(key);
99+
keysRemoved.add(key);
100+
}
101+
} else {
102+
var queryVector = randomVector(getRandom(), dimension);
103+
SearchScoreProvider ssp = DefaultSearchScoreProvider.exact(queryVector, similarityFunction, ravv);
104+
105+
int topK = Math.min(1, keysInserted.size());
106+
int rerankK = Math.min(50, keysInserted.size());
107+
108+
GraphSearcher searcher = new GraphSearcher(builder.getGraph());
109+
searcher.search(ssp, topK, rerankK, 0.f, 0.f, liveNodes);
110+
}
111+
});
112+
}
113+
114+
@FunctionalInterface
115+
private interface Op
116+
{
117+
void run(int i) throws Throwable;
118+
}
119+
120+
private void testConcurrentOps(Op op) throws ExecutionException, InterruptedException {
121+
AtomicInteger counter = new AtomicInteger();
122+
long start = System.currentTimeMillis();
123+
124+
// Use a simpler approach that doesn't rely on parallel streams
125+
var keys = IntStream.range(0, nVectors).boxed().collect(Collectors.toList());
126+
Collections.shuffle(keys, getRandom());
127+
128+
// Use a bounded number of plain threads; NOTE(review): the workers still call getRandom() below, so confirm this really avoids relying on RandomizedContext
129+
int threadCount = Math.min(Runtime.getRuntime().availableProcessors(), 8); // Limit thread count
130+
List<Thread> threads = new ArrayList<>();
131+
int keysPerThread = nVectors / threadCount;
132+
133+
// Draw a single random seed up front on the test thread; NOTE(review): it is shared, not per-thread, and is not read anywhere below
134+
final long randomSeed = getRandom().nextLong();
135+
136+
for (int t = 0; t < threadCount; t++) {
137+
final int threadIndex = t;
138+
final int startIdx = threadIndex * keysPerThread;
139+
final int endIdx = (threadIndex == threadCount - 1) ? keys.size() : (threadIndex + 1) * keysPerThread;
140+
141+
Thread thread = new Thread(() -> {
142+
for (int i = startIdx; i < endIdx; i++) {
143+
int key = keys.get(i);
144+
wrappedOp(op, key);
145+
146+
if (counter.incrementAndGet() % 1_000 == 0) {
147+
var elapsed = System.currentTimeMillis() - start;
148+
logger.info(String.format("%d ops in %dms = %f ops/s",
149+
counter.get(), elapsed, counter.get() * 1000.0 / elapsed));
150+
}
151+
152+
if (getRandom().nextDouble() < cleanupProbability) {
153+
writeLock.lock();
154+
try {
155+
for (Integer keyToRemove : keysRemoved) {
156+
builder.markNodeDeleted(keyToRemove);
157+
}
158+
keysRemoved.clear();
159+
builder.cleanup();
160+
} finally {
161+
writeLock.unlock();
162+
}
163+
}
164+
}
165+
});
166+
167+
threads.add(thread);
168+
thread.start();
169+
}
170+
171+
// Wait for all threads to complete
172+
for (Thread thread : threads) {
173+
thread.join();
174+
}
175+
}
176+
177+
private static void wrappedOp(Op op, Integer i) {
178+
try
179+
{
180+
op.run(i);
181+
}
182+
catch (Throwable e)
183+
{
184+
throw new RuntimeException(e);
185+
}
186+
}
187+
188+
private static class KeySet
189+
{
190+
private final Map<Integer, Integer> keys = new ConcurrentHashMap<>();
191+
private final AtomicInteger ordinal = new AtomicInteger();
192+
193+
public void add(Integer key)
194+
{
195+
var i = ordinal.getAndIncrement();
196+
keys.put(i, key);
197+
}
198+
199+
public int getRandom()
200+
{
201+
if (isEmpty())
202+
throw new IllegalStateException();
203+
var i = TestConcurrentReadWriteDeletes.getRandom().nextInt(ordinal.get());
204+
// in case of a race with add(key) (ordinal already incremented but entry not yet stored), retry with another random ordinal
205+
return keys.containsKey(i) ? keys.get(i) : getRandom();
206+
}
207+
208+
public boolean isEmpty()
209+
{
210+
return keys.isEmpty();
211+
}
212+
213+
public int size() {
214+
return keys.size();
215+
}
216+
}
217+
}

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -590,9 +590,8 @@ public void testDiversity3d(boolean addHierarchy) {
590590

591591
private void assertNeighbors(OnHeapGraphIndex graph, int node, int... expected) {
592592
Arrays.sort(expected);
593-
ConcurrentNeighborMap.Neighbors nn = graph.getNeighbors(0, node); // TODO
594-
Iterator<Integer> it = nn.iterator();
595-
int[] actual = new int[nn.size()];
593+
NodesIterator it = graph.getNeighborsIterator(0, node);
594+
int[] actual = new int[it.size()];
596595
for (int i = 0; i < actual.length; i++) {
597596
actual[i] = it.next();
598597
}

0 commit comments

Comments
 (0)