Skip to content

Commit 081a0ca

Browse files
Bulk Load CDK: Unit tests for memory manager (#45091)
1 parent 6730a3b commit 081a0ca

File tree

2 files changed

+137
-15
lines changed

2 files changed

+137
-15
lines changed

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/MemoryManager.kt

+35-15
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
package io.airbyte.cdk.state
66

7+
import io.micronaut.context.annotation.Secondary
78
import jakarta.inject.Singleton
89
import java.util.concurrent.atomic.AtomicLong
9-
import java.util.concurrent.locks.ReentrantLock
10-
import kotlin.concurrent.withLock
10+
import kotlinx.coroutines.channels.Channel
11+
import kotlinx.coroutines.sync.Mutex
12+
import kotlinx.coroutines.sync.withLock
1113

1214
/**
1315
* Manages memory usage for the destination.
@@ -17,31 +19,49 @@ import kotlin.concurrent.withLock
1719
* TODO: Some degree of logging/monitoring around how accurate we're actually being?
1820
*/
1921
@Singleton
20-
class MemoryManager {
21-
private val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory()
22+
class MemoryManager(availableMemoryProvider: AvailableMemoryProvider) {
23+
private val totalMemoryBytes: Long = availableMemoryProvider.availableMemoryBytes
2224
private var usedMemoryBytes = AtomicLong(0L)
23-
private val memoryLock = ReentrantLock()
24-
private val memoryLockCondition = memoryLock.newCondition()
25+
private val mutex = Mutex()
26+
private val syncChannel = Channel<Unit>(Channel.UNLIMITED)
2527

28+
val remainingMemoryBytes: Long
29+
get() = totalMemoryBytes - usedMemoryBytes.get()
30+
31+
/* Attempt to reserve memory. If enough memory is not available, waits until it is, then reserves. */
2632
suspend fun reserveBlocking(memoryBytes: Long) {
27-
memoryLock.withLock {
28-
while (usedMemoryBytes.get() + memoryBytes > availableMemoryBytes) {
29-
memoryLockCondition.await()
33+
if (memoryBytes > totalMemoryBytes) {
34+
throw IllegalArgumentException(
35+
"Requested ${memoryBytes}b memory exceeds ${totalMemoryBytes}b total"
36+
)
37+
}
38+
39+
mutex.withLock {
40+
while (usedMemoryBytes.get() + memoryBytes > totalMemoryBytes) {
41+
syncChannel.receive()
3042
}
3143
usedMemoryBytes.addAndGet(memoryBytes)
3244
}
3345
}
3446

3547
suspend fun reserveRatio(ratio: Double): Long {
36-
val estimatedSize = (availableMemoryBytes.toDouble() * ratio).toLong()
48+
val estimatedSize = (totalMemoryBytes.toDouble() * ratio).toLong()
3749
reserveBlocking(estimatedSize)
3850
return estimatedSize
3951
}
4052

41-
fun release(memoryBytes: Long) {
42-
memoryLock.withLock {
43-
usedMemoryBytes.addAndGet(-memoryBytes)
44-
memoryLockCondition.signalAll()
45-
}
53+
suspend fun release(memoryBytes: Long) {
54+
usedMemoryBytes.addAndGet(-memoryBytes)
55+
syncChannel.send(Unit)
4656
}
4757
}
58+
59+
interface AvailableMemoryProvider {
60+
val availableMemoryBytes: Long
61+
}
62+
63+
@Singleton
64+
@Secondary
65+
class JavaRuntimeAvailableMemoryProvider : AvailableMemoryProvider {
66+
override val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory()
67+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
3+
*/
4+
5+
package io.airbyte.cdk.state
6+
7+
import io.micronaut.context.annotation.Replaces
8+
import io.micronaut.context.annotation.Requires
9+
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
10+
import jakarta.inject.Singleton
11+
import java.util.concurrent.atomic.AtomicBoolean
12+
import kotlinx.coroutines.Dispatchers
13+
import kotlinx.coroutines.launch
14+
import kotlinx.coroutines.test.runTest
15+
import kotlinx.coroutines.withContext
16+
import kotlinx.coroutines.withTimeout
17+
import org.junit.jupiter.api.Assertions
18+
import org.junit.jupiter.api.Test
19+
20+
@MicronautTest
21+
class MemoryManagerTest {
22+
@Singleton
23+
@Replaces(MemoryManager::class)
24+
@Requires(env = ["test"])
25+
class MockAvailableMemoryProvider : AvailableMemoryProvider {
26+
override val availableMemoryBytes: Long = 1000
27+
}
28+
29+
@Test
30+
fun testReserveBlocking() = runTest {
31+
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
32+
val reserved = AtomicBoolean(false)
33+
34+
try {
35+
withTimeout(5000) { memoryManager.reserveBlocking(900) }
36+
} catch (e: Exception) {
37+
Assertions.fail<Unit>("Failed to reserve memory")
38+
}
39+
40+
Assertions.assertEquals(100, memoryManager.remainingMemoryBytes)
41+
42+
val job = launch {
43+
memoryManager.reserveBlocking(200)
44+
reserved.set(true)
45+
}
46+
47+
memoryManager.reserveBlocking(0)
48+
Assertions.assertFalse(reserved.get())
49+
50+
memoryManager.release(50)
51+
memoryManager.reserveBlocking(0)
52+
Assertions.assertEquals(150, memoryManager.remainingMemoryBytes)
53+
Assertions.assertFalse(reserved.get())
54+
55+
memoryManager.release(25)
56+
memoryManager.reserveBlocking(0)
57+
Assertions.assertEquals(175, memoryManager.remainingMemoryBytes)
58+
Assertions.assertFalse(reserved.get())
59+
60+
memoryManager.release(25)
61+
try {
62+
withTimeout(5000) { job.join() }
63+
} catch (e: Exception) {
64+
Assertions.fail<Unit>("Failed to unblock reserving memory")
65+
}
66+
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes)
67+
Assertions.assertTrue(reserved.get())
68+
}
69+
70+
@Test
71+
fun testReserveBlockingMultithreaded() = runTest {
72+
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
73+
withContext(Dispatchers.IO) {
74+
memoryManager.reserveBlocking(1000)
75+
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes)
76+
val nIterations = 100000
77+
78+
val jobs = (0 until nIterations).map { launch { memoryManager.reserveBlocking(10) } }
79+
80+
repeat(nIterations) {
81+
memoryManager.release(10)
82+
Assertions.assertTrue(
83+
memoryManager.remainingMemoryBytes >= 0,
84+
"Remaining memory is negative: ${memoryManager.remainingMemoryBytes}"
85+
)
86+
}
87+
jobs.forEach { it.join() }
88+
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes)
89+
}
90+
}
91+
92+
@Test
93+
fun testRequestingMoreThanAvailableThrows() = runTest {
94+
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
95+
try {
96+
memoryManager.reserveBlocking(1001)
97+
} catch (e: IllegalArgumentException) {
98+
return@runTest
99+
}
100+
Assertions.fail<Unit>("Requesting more memory than available should throw an exception")
101+
}
102+
}

0 commit comments

Comments
 (0)