diff --git a/modules/core/arrow-extras-extensions/src/main/kotlin/arrow/data/extensions/andthen.kt b/modules/core/arrow-extras-extensions/src/main/kotlin/arrow/data/extensions/andthen.kt new file mode 100644 index 00000000000..8611fe967b3 --- /dev/null +++ b/modules/core/arrow-extras-extensions/src/main/kotlin/arrow/data/extensions/andthen.kt @@ -0,0 +1,89 @@ +package arrow.data.extensions + +import arrow.Kind +import arrow.Kind2 +import arrow.core.Either +import arrow.data.* +import arrow.extension +import arrow.typeclasses.* + +@extension +interface AndThenSemigroup : Semigroup> { + fun SB(): Semigroup + + override fun AndThen.combine(b: AndThen): AndThen = SB().run { + AndThen { a: A -> invoke(a).combine(b.invoke(a)) } + } + +} + +@extension +interface AndThenMonoid : Monoid>, AndThenSemigroup { + + fun MB(): Monoid + + override fun SB(): Semigroup = MB() + + override fun empty(): AndThen = + AndThen { MB().empty() } + +} + +@extension +interface AndThenFunctor : Functor> { + override fun AndThenOf.map(f: (A) -> B): AndThen = + fix().map(f) +} + +@extension +interface AndThenApplicative : Applicative>, AndThenFunctor { + override fun just(a: A): AndThenOf = + AndThen.just(a) + + override fun AndThenOf.ap(ff: AndThenOf B>): AndThen = + fix().ap(ff) + + override fun AndThenOf.map(f: (A) -> B): AndThen = + fix().map(f) + +} + +@extension +interface AndThenMonad : Monad>, AndThenApplicative { + override fun AndThenOf.flatMap(f: (A) -> AndThenOf): AndThen = + fix().flatMap(f) + + override fun tailRecM(a: A, f: (A) -> AndThenOf>): AndThen = + AndThen.tailRecM(a, f) + + override fun AndThenOf.map(f: (A) -> B): AndThen = + fix().map(f) + + override fun AndThenOf.ap(ff: AndThenOf B>): AndThen = + fix().ap(ff) + +} + +@extension +interface AndThenCategory : Category { + override fun id(): AndThen = + AndThen.id() + + override fun AndThenOf.compose(arr: Kind2): AndThen = + fix().compose(arr::invoke) + +} + +@extension +interface AndThenContravariant : Contravariant> { + + override fun Kind, A>.contramap(f: (B) -> A): Kind, B> = + counnest().fix().contramap(f).conest() + +} + +@extension +interface AndThenProfunctor : Profunctor { + override fun AndThenOf.dimap(fl: (C) -> A, fr: (B) -> D): AndThen = + fix().andThen(fr).compose(fl) +} diff --git a/modules/core/arrow-extras/src/main/kotlin/arrow/data/AndThen.kt b/modules/core/arrow-extras/src/main/kotlin/arrow/data/AndThen.kt new file mode 100644 index 00000000000..0c79c3b0341 --- /dev/null +++ b/modules/core/arrow-extras/src/main/kotlin/arrow/data/AndThen.kt @@ -0,0 +1,250 @@ +package arrow.data + +import arrow.core.* +import arrow.higherkind + +operator fun AndThenOf.invoke(a: A): B = fix().invoke(a) + +/** + * [AndThen] wraps a function of shape `(A) -> B` and can be used to do function composition. + * It's similar to [arrow.core.andThen] and [arrow.core.compose] and can be used to build stack safe + * data structures that make use of lambdas. Usage is typically used for signature such as `A -> Kind` where + * `F` has a [arrow.typeclasses.Monad] instance i.e. [StateT.flatMap]. + * + * As you can see the usage of [AndThen] is the same as `[arrow.core.andThen] except we start our computation by + * wrapping our function in [AndThen]. + * + * ```kotlin:ank:playground + * import arrow.core.andThen + * import arrow.data.AndThen + * import arrow.data.extensions.list.foldable.foldLeft + * + * fun main(args: Array) { + * //sampleStart + * val f = (0..10000).toList() + * .fold({ x: Int -> x + 1 }) { acc, _ -> + * acc.andThen { it + 1 } + * } + * + * val f2 = (0..10000).toList() + * .foldLeft(AndThen { x: Int -> x + 1 }) { acc, _ -> + * acc.andThen { it + 1 } + * } + * //sampleEnd + * println("f(0) = ${f(0)}, f2(0) = ${f2(0)}") + * } + * ``` + * + */ +@higherkind +sealed class AndThen : (A) -> B, AndThenOf { + + private data class Single(val f: (A) -> B, val index: Int) : AndThen() + + private data class Concat(val left: AndThen, val right: AndThen) : AndThen() { + override fun toString(): String = "AndThen.Concat(...)" + } + + /** + * Compose a function to be invoked after the current function is invoked. + * + * ```kotlin:ank:playground + * import arrow.data.AndThen + * import arrow.data.extensions.list.foldable.foldLeft + * + * fun main(args: Array) { + * //sampleStart + * val f = (0..10000).toList().foldLeft(AndThen { i: Int -> i + 1 }) { acc, _ -> + * acc.andThen { it + 1 } + * } + * + * val result = f(0) + * //sampleEnd + * println("result = $result") + * } + * ``` + * + * @param g function to apply on the result of this function. + * @return a composed [AndThen] function that first applies this function to its input, + * and then applies [g] to the result. + */ + fun andThen(g: (B) -> X): AndThen = + when (this) { + // Fusing calls up to a certain threshold, using the fusion technique implemented for `IO#map` + is Single -> if (index != maxStackDepthSize) Single(f andThen g, index + 1) + else andThenF(AndThen(g)) + else -> andThenF(AndThen(g)) + } + + /** + * Compose a function to be invoked before the current function is invoked. + * + * ```kotlin:ank:playground + * import arrow.data.AndThen + * import arrow.data.extensions.list.foldable.foldLeft + * + * fun main(args: Array) { + * //sampleStart + * val f = (0..10000).toList().foldLeft(AndThen { i: Int -> i + 1 }) { acc, _ -> + * acc.compose { it + 1 } + * } + * + * val result = f(0) + * //sampleEnd + * println("result = $result") + * } + * ``` + * + * @param g function to invoke before invoking this function with the result. + * @return a composed [AndThen] function that first applies [g] to its input, + * and then applies this function to the result. + */ + infix fun compose(g: (C) -> A): AndThen = + when (this) { + // Fusing calls up to a certain threshold, using the fusion technique implemented for `IO#map` + is Single -> if (index != maxStackDepthSize) Single(f compose g, index + 1) + else composeF(AndThen(g)) + else -> composeF(AndThen(g)) + } + + /** + * Alias for [andThen] + * + * @see andThen + */ + fun map(f: (B) -> C): AndThen = + andThen(f) + + /** + * Alias for [andThen] + * + * @see compose + */ + fun contramap(f: (C) -> A): AndThen = + this compose f + + fun flatMap(f: (B) -> AndThenOf): AndThen = + AndThen { a: A -> f(this.invoke(a)).fix().invoke(a) } + + fun ap(ff: AndThenOf C>): AndThen = + ff.fix().flatMap { f -> + map(f) + } + + /** + * Invoke the `[AndThen]` function + * + * ```kotlin:ank:playground + * import arrow.data.AndThen + * + * fun main(args: Array) { + * //sampleStart + * val f: AndThen = AndThen(Int::toString) + * val result = f.invoke(0) + * //sampleEnd + * println("result = $result") + * } + * ``` + * + * @param a value to invoke function with + * @return result of type [B]. + * + **/ + @Suppress("UNCHECKED_CAST") + override fun invoke(a: A): B = loop(this as AndThen, a) + + override fun toString(): String = "AndThen(...)" + + companion object { + + fun just(b: B): AndThen = + AndThen { b } + + fun id(): AndThen = + AndThen(::identity) + + /** + * Wraps a function in [AndThen]. + * + * ```kotlin:ank:playground + * import arrow.data.AndThen + * + * fun main(args: Array) { + * //sampleStart + * val f = AndThen { x: Int -> x + 1 } + * val result = f(0) + * //sampleEnd + * println("result = $result") + * } + * ``` + * + * @param f the function to wrap + * @return wrapped function [f]. + * + */ + operator fun invoke(f: (A) -> B): AndThen = when (f) { + is AndThen -> f + else -> Single(f, 0) + } + + fun tailRecM(a: A, f: (A) -> AndThenOf>): AndThen = + AndThen { t: I -> step(a, t, f) } + + private tailrec fun step(a: A, t: I, fn: (A) -> AndThenOf>): B { + val af = fn(a)(t) + return when (af) { + is Either.Right -> af.b + is Either.Left -> step(af.a, t, fn) + } + } + + /** + * Establishes the maximum stack depth when fusing `andThen` or `compose` calls. + * + * The default is `128`, from which we substract one as an + * optimization. This default has been reached like this: + * + * - according to official docs, the default stack size on 32-bits + * Windows and Linux was 320 KB, whereas for 64-bits it is 1024 KB + * - according to measurements chaining `Function1` references uses + * approximately 32 bytes of stack space on a 64 bits system; + * this could be lower if "compressed oops" is activated + * - therefore a "map fusion" that goes 128 in stack depth can use + * about 4 KB of stack space + */ + private const val maxStackDepthSize = 127 + } + + private fun andThenF(right: AndThen): AndThen = Concat(this, right) + private fun composeF(right: AndThen): AndThen = Concat(right, this) + + @Suppress("UNCHECKED_CAST") + private tailrec fun loop(self: AndThen, current: Any?): B = when (self) { + is Single -> self.f(current) as B + is Concat<*, *, *> -> { + when (val oldLeft = self.left) { + is Single<*, *> -> { + val left = oldLeft as Single + val newSelf = self.right as AndThen + loop(newSelf, left.f(current)) + } + is Concat<*, *, *> -> loop( + rotateAccumulate(self.left as AndThen, self.right as AndThen), + current + ) + } + } + } + + @Suppress("UNCHECKED_CAST") + private tailrec fun rotateAccumulate( + left: AndThen, + right: AndThen): AndThen = when (left) { + is Concat<*, *, *> -> rotateAccumulate( + left.left as AndThen, + (left.right as AndThen).andThenF(right) + ) + is Single<*, *> -> left.andThenF(right) + } + +} diff --git a/modules/core/arrow-extras/src/main/kotlin/arrow/data/StateT.kt b/modules/core/arrow-extras/src/main/kotlin/arrow/data/StateT.kt index 06fe0bf10ff..8ada10943d5 100644 --- a/modules/core/arrow-extras/src/main/kotlin/arrow/data/StateT.kt +++ b/modules/core/arrow-extras/src/main/kotlin/arrow/data/StateT.kt @@ -99,7 +99,7 @@ class StateT( * @param f the modify function to apply. */ fun modify(AF: Applicative, f: (S) -> S): StateT = AF.run { - StateT(just({ s -> + StateT(just({ s -> just(f(s)).map { Tuple2(it, Unit) } })) } @@ -169,7 +169,7 @@ class StateT( fun map2(MF: Monad, sb: StateTOf, fn: (A, B) -> Z): StateT = MF.run { invokeF(runF.map2(sb.fix().runF) { (ssa, ssb) -> - ssa.andThen { fsa -> + AndThen(ssa).andThen { fsa -> fsa.flatMap { (s, a) -> ssb(s).map { (s, b) -> Tuple2(s, fn(a, b)) } } @@ -186,7 +186,7 @@ class StateT( */ fun map2Eval(MF: Monad, sb: EvalOf>, fn: (A, B) -> Z): Eval> = MF.run { runF.map2Eval(sb.fix().map { it.runF }) { (ssa, ssb) -> - ssa.andThen { fsa -> + AndThen(ssa).andThen { fsa -> fsa.flatMap { (s, a) -> ssb((s)).map { (s, b) -> Tuple2(s, fn(a, b)) } } @@ -221,7 +221,7 @@ class StateT( fun flatMap(MF: Monad, fas: (A) -> StateTOf): StateT = MF.run { invokeF( runF.map { sfsa -> - sfsa.andThen { fsa -> + AndThen(sfsa).andThen { fsa -> fsa.flatMap { fas(it.b).runM(MF, it.a) } @@ -238,7 +238,7 @@ class StateT( fun flatMapF(MF: Monad, faf: (A) -> Kind): StateT = MF.run { invokeF( runF.map { sfsa -> - sfsa.andThen { fsa -> + AndThen(sfsa).andThen { fsa -> fsa.flatMap { (s, a) -> faf(a).map { b -> Tuple2(s, b) } } diff --git a/modules/core/arrow-extras/src/test/kotlin/arrow/data/AndThenTest.kt b/modules/core/arrow-extras/src/test/kotlin/arrow/data/AndThenTest.kt new file mode 100644 index 00000000000..5e853a1b961 --- /dev/null +++ b/modules/core/arrow-extras/src/test/kotlin/arrow/data/AndThenTest.kt @@ -0,0 +1,100 @@ +package arrow.data + +import arrow.Kind +import arrow.core.* +import arrow.core.extensions.monoid +import arrow.data.extensions.andthen.category.category +import arrow.data.extensions.andthen.contravariant.contravariant +import arrow.data.extensions.andthen.monad.monad +import arrow.data.extensions.andthen.monoid.monoid +import arrow.data.extensions.andthen.profunctor.profunctor +import arrow.data.extensions.list.foldable.foldLeft +import arrow.test.UnitSpec +import arrow.test.generators.genFunctionAToB +import arrow.test.laws.* +import arrow.typeclasses.Conested +import arrow.typeclasses.Eq +import arrow.typeclasses.conest +import arrow.typeclasses.counnest +import io.kotlintest.matchers.shouldBe +import io.kotlintest.properties.Gen +import io.kotlintest.properties.forAll +import io.kotlintest.runner.junit4.KotlinTestRunner +import org.junit.runner.RunWith + +@RunWith(KotlinTestRunner::class) +class AndThenTest : UnitSpec() { + + val ConestedEQ: Eq, Int>> = Eq { a, b -> + a.counnest().invoke(1) == b.counnest().invoke(1) + } + + val EQ: Eq> = Eq { a, b -> + a(1) == b(1) + } + + init { + + testLaws( + MonadLaws.laws(AndThen.monad(), EQ), + MonoidLaws.laws(AndThen.monoid(Int.monoid()), Gen.int().map { i -> AndThen { i } }, EQ), + ContravariantLaws.laws(AndThen.contravariant(), { AndThen.just(it).conest() }, ConestedEQ), + ProfunctorLaws.laws(AndThen.profunctor(), { AndThen.just(it) }, EQ), + CategoryLaws.laws(AndThen.category(), { AndThen.just(it) }, EQ) + ) + + "compose a chain of functions with andThen should be same with AndThen" { + forAll(Gen.int(), Gen.list(genFunctionAToB(Gen.int()))) { i, fs -> + val result = fs.map(AndThen.Companion::invoke) + .fold(AndThen(::identity)) { acc, b -> + acc.andThen(b) + }.invoke(i) + + val expect = fs.fold({ x: Int -> x }) { acc, b -> + acc.andThen(b) + }.invoke(i) + + result == expect + } + } + + "compose a chain of function with compose should be same with AndThen" { + forAll(Gen.int(), Gen.list(genFunctionAToB(Gen.int()))) { i, fs -> + val result = fs.map(AndThen.Companion::invoke) + .fold(AndThen(::identity)) { acc, b -> + acc.compose(b) + }.invoke(i) + + val expect = fs.fold({ x: Int -> x }) { acc, b -> + acc.compose(b) + }.invoke(i) + + result == expect + } + } + + val count = 500000 + + "andThen is stack safe" { + val result = (0 until count).toList().foldLeft(AndThen(::identity)) { acc, _ -> + acc.andThen { it + 1 } + }.invoke(0) + + result shouldBe count + } + + "compose is stack safe" { + val result = (0 until count).toList().foldLeft(AndThen(::identity)) { acc, _ -> + acc.compose { it + 1 } + }.invoke(0) + + result shouldBe count + } + + "toString is stack safe" { + (0 until count).toList().foldLeft(AndThen(::identity)) { acc, _ -> + acc.compose { it + 1 } + }.toString() shouldBe "AndThen.Concat(...)" + } + } +} \ No newline at end of file diff --git a/modules/core/arrow-test/src/main/kotlin/arrow/test/laws/MonadDeferLaws.kt b/modules/core/arrow-test/src/main/kotlin/arrow/test/laws/MonadDeferLaws.kt index 6be0c2cc2a8..599dcd8641d 100644 --- a/modules/core/arrow-test/src/main/kotlin/arrow/test/laws/MonadDeferLaws.kt +++ b/modules/core/arrow-test/src/main/kotlin/arrow/test/laws/MonadDeferLaws.kt @@ -10,6 +10,7 @@ import arrow.test.concurrency.SideEffect import arrow.test.generators.genIntSmall import arrow.test.generators.genThrowable import arrow.typeclasses.Eq +import io.kotlintest.properties.Gen import io.kotlintest.properties.forAll import io.kotlintest.shouldBe import kotlinx.coroutines.Dispatchers @@ -141,29 +142,33 @@ object MonadDeferLaws { df.flatMap { df }.flatMap { df }.equalUnderTheLaw(just(3), EQ) shouldBe true } - fun MonadDefer.stackSafetyOverRepeatedLeftBinds(iterations: Int = 5000, EQ: Eq>): Unit { - (0..iterations).toList().k().foldLeft(just(0)) { def, x -> - def.flatMap { just(x) } - }.equalUnderTheLaw(just(iterations), EQ) shouldBe true - } + fun MonadDefer.stackSafetyOverRepeatedLeftBinds(iterations: Int = 5000, EQ: Eq>): Unit = + forAll(Gen.create { Unit }) { + (0..iterations).toList().k().foldLeft(just(0)) { def, x -> + def.flatMap { just(x) } + }.equalUnderTheLaw(just(iterations), EQ) + } - fun MonadDefer.stackSafetyOverRepeatedRightBinds(iterations: Int = 5000, EQ: Eq>): Unit { - (0..iterations).toList().foldRight(just(iterations)) { x, def -> - lazy().flatMap { def } - }.equalUnderTheLaw(just(iterations), EQ) shouldBe true - } + fun MonadDefer.stackSafetyOverRepeatedRightBinds(iterations: Int = 5000, EQ: Eq>): Unit = + forAll(Gen.create { Unit }) { + (0..iterations).toList().foldRight(just(iterations)) { x, def -> + lazy().flatMap { def } + }.equalUnderTheLaw(just(iterations), EQ) + } - fun MonadDefer.stackSafetyOverRepeatedAttempts(iterations: Int = 5000, EQ: Eq>): Unit { - (0..iterations).toList().foldLeft(just(0)) { def, x -> - def.attempt().map { x } - }.equalUnderTheLaw(just(iterations), EQ) shouldBe true - } + fun MonadDefer.stackSafetyOverRepeatedAttempts(iterations: Int = 5000, EQ: Eq>): Unit = + forAll(Gen.create { Unit }) { + (0..iterations).toList().foldLeft(just(0)) { def, x -> + def.attempt().map { x } + }.equalUnderTheLaw(just(iterations), EQ) + } - fun MonadDefer.stackSafetyOnRepeatedMaps(iterations: Int = 5000, EQ: Eq>): Unit { - (0..iterations).toList().foldLeft(just(0)) { def, x -> - def.map { x } - }.equalUnderTheLaw(just(iterations), EQ) shouldBe true - } + fun MonadDefer.stackSafetyOnRepeatedMaps(iterations: Int = 5000, EQ: Eq>): Unit = + forAll(Gen.create { Unit }) { + (0..iterations).toList().foldLeft(just(0)) { def, x -> + def.map { x } + }.equalUnderTheLaw(just(iterations), EQ) + } fun MonadDefer.asyncBind(EQ: Eq>): Unit = forAll(genIntSmall(), genIntSmall(), genIntSmall()) { x: Int, y: Int, z: Int ->