Skip to content

Commit 41742f5

Browse files
serrasnomisRev
andauthored
MemoizedDeepRecursiveFunction (#3091)
Co-authored-by: serras <[email protected]> Co-authored-by: Simon Vergauwen <[email protected]>
1 parent b8a0ce0 commit 41742f5

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

arrow-libs/core/arrow-core/api/arrow-core.api

+4
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,10 @@ public final class arrow/core/Memoization {
776776
public static final fun memoize (Lkotlin/jvm/functions/Function5;)Lkotlin/jvm/functions/Function5;
777777
}
778778

779+
public final class arrow/core/MemoizedDeepRecursiveFunctionKt {
780+
public static final fun MemoizedDeepRecursiveFunction (Lkotlin/jvm/functions/Function3;)Lkotlin/DeepRecursiveFunction;
781+
}
782+
779783
public abstract interface class arrow/core/NonEmptyCollection : java/util/Collection, kotlin/jvm/internal/markers/KMappedMarker {
780784
public abstract fun distinct ()Larrow/core/NonEmptyList;
781785
public abstract fun distinctBy (Lkotlin/jvm/functions/Function1;)Larrow/core/NonEmptyList;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package arrow.core
2+
3+
import arrow.atomic.Atomic
4+
import arrow.atomic.loop
5+
6+
/**
7+
* Defines a recursive **pure** function that:
8+
* - keeps its stack on the heap, which allows very deep recursive computations that do not use the actual call stack;
9+
* - memoizes every call, which means that the function is execute only once per argument.
10+
*
11+
* [MemoizedDeepRecursiveFunction] takes one parameter of type [T] and returns a result of type [R].
12+
* The [block] of code defines the body of a recursive function. In this block
13+
* [callRecursive][DeepRecursiveScope.callRecursive] function can be used to make a recursive call
14+
* to the declared function.
15+
*/
16+
public fun <T, R> MemoizedDeepRecursiveFunction(
17+
block: suspend DeepRecursiveScope<T, R>.(T) -> R
18+
): DeepRecursiveFunction<T, R> {
19+
val cache = Atomic(emptyMap<T, R>())
20+
return DeepRecursiveFunction { x ->
21+
when (val v = cache.get()[x]) {
22+
null -> {
23+
val result = block(x)
24+
cache.loop { old ->
25+
when (x) {
26+
in old ->
27+
return@DeepRecursiveFunction old.getValue(x)
28+
else -> {
29+
if (cache.compareAndSet(old, old + Pair(x, result)))
30+
return@DeepRecursiveFunction result
31+
}
32+
}
33+
}
34+
}
35+
else -> v
36+
}
37+
}
38+
}

arrow-libs/core/arrow-core/src/commonTest/kotlin/arrow/core/MemoizationTest.kt

+31
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,37 @@ class MemoizationTest : StringSpec({
215215
memoized(1, 2, 3, 4, 5) shouldBe null
216216
runs shouldBe 1
217217
}
218+
219+
"Recursive memoization" {
220+
var runs = 0
221+
val memoizedDeepRecursiveFibonacci: DeepRecursiveFunction<Int, Int> =
222+
MemoizedDeepRecursiveFunction { n ->
223+
when (n) {
224+
0 -> 0.also { runs++ }
225+
1 -> 1
226+
else -> callRecursive(n - 1) + callRecursive(n - 2)
227+
}
228+
}
229+
val result = memoizedDeepRecursiveFibonacci(5)
230+
result shouldBe 5
231+
runs shouldBe 1
232+
}
233+
234+
"Recursive memoization, run twice should be memoized" {
235+
var runs = 0
236+
val memoizedDeepRecursiveFibonacci: DeepRecursiveFunction<Int, Int> =
237+
MemoizedDeepRecursiveFunction { n ->
238+
when (n) {
239+
0 -> 0.also { runs++ }
240+
1 -> 1
241+
else -> callRecursive(n - 1) + callRecursive(n - 2)
242+
}
243+
}
244+
val result1 = memoizedDeepRecursiveFibonacci(5)
245+
val result2 = memoizedDeepRecursiveFibonacci(5)
246+
result1 shouldBe result2
247+
runs shouldBe 1
248+
}
218249
})
219250

220251
private fun consecSumResult(n: Int): Int = (n * (n + 1)) / 2

0 commit comments

Comments
 (0)