Skip to content

Commit 65cd3a9

Browse files
Added Result filter methods (#224)
* Added Result filter methods * PR update: Use a different error value in the tests * PR updates * Added pipeable functions for some methods that were missing them * Added tests for `Result#map_error` * Added tests for piping functions
1 parent 172bb4b commit 65cd3a9

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

expression/core/result.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,38 @@ def is_ok(self) -> bool:
150150
"""Return `True` if the result is an `Ok` value."""
151151
return self.tag == "ok"
152152

153+
def filter(self, predicate: Callable[[_TSource], bool], default: _TError) -> Result[_TSource, _TError]:
154+
"""Filter result.
155+
156+
Returns the input if the predicate evaluates to true, otherwise
157+
returns the `default`
158+
"""
159+
match self:
160+
case Result(tag="ok", ok=value) if predicate(value):
161+
return self
162+
case Result(tag="error"):
163+
return self
164+
case _:
165+
return Error(default)
166+
167+
def filter_with(
168+
self,
169+
predicate: Callable[[_TSource], bool],
170+
default: Callable[[_TSource], _TError],
171+
) -> Result[_TSource, _TError]:
172+
"""Filter result.
173+
174+
Returns the input if the predicate evaluates to true, otherwise
175+
returns the `default` using the value as input
176+
"""
177+
match self:
178+
case Result(tag="ok", ok=value) if predicate(value):
179+
return self
180+
case Result(tag="ok", ok=value):
181+
return Error(default(value))
182+
case Result():
183+
return self
184+
153185
def dict(self) -> builtins.dict[str, _TSource | _TError | Literal["ok", "error"]]:
154186
"""Return a json serializable representation of the result."""
155187
match self:
@@ -352,6 +384,11 @@ def map2(
352384
return x.map2(y, mapper)
353385

354386

387+
@curry_flip(1)
388+
def map_error(result: Result[_TSource, _TError], mapper: Callable[[_TError], _TResult]) -> Result[_TSource, _TResult]:
389+
return result.map_error(mapper)
390+
391+
355392
@curry_flip(1)
356393
def bind(
357394
result: Result[_TSource, _TError],
@@ -374,11 +411,46 @@ def is_error(result: Result[_TSource, _TError]) -> TypeGuard[Result[_TSource, _T
374411
return result.is_error()
375412

376413

414+
@curry_flip(1)
415+
def filter(
416+
result: Result[_TSource, _TError],
417+
predicate: Callable[[_TSource], bool],
418+
default: _TError,
419+
) -> Result[_TSource, _TError]:
420+
return result.filter(predicate, default)
421+
422+
423+
@curry_flip(1)
424+
def filter_with(
425+
result: Result[_TSource, _TError],
426+
predicate: Callable[[_TSource], bool],
427+
default: Callable[[_TSource], _TError],
428+
) -> Result[_TSource, _TError]:
429+
return result.filter_with(predicate, default)
430+
431+
377432
def swap(result: Result[_TSource, _TError]) -> Result[_TError, _TSource]:
378433
"""Swaps the value in the result so an Ok becomes an Error and an Error becomes an Ok."""
379434
return result.swap()
380435

381436

437+
@curry_flip(1)
438+
def or_else(result: Result[_TSource, _TError], other: Result[_TSource, _TError]) -> Result[_TSource, _TError]:
439+
return result.or_else(other)
440+
441+
442+
@curry_flip(1)
443+
def or_else_with(
444+
result: Result[_TSource, _TError],
445+
other: Callable[[_TError], Result[_TSource, _TError]],
446+
) -> Result[_TSource, _TError]:
447+
return result.or_else_with(other)
448+
449+
450+
def merge(result: Result[_TSource, _TSource]) -> _TSource:
451+
return result.merge()
452+
453+
382454
def to_option(result: Result[_TSource, Any]) -> Option[_TSource]:
383455
from expression.core.option import Nothing, Some
384456

@@ -406,9 +478,17 @@ def of_option_with(value: Option[_TSource], error: Callable[[], _TError]) -> Res
406478
"map",
407479
"bind",
408480
"dict",
481+
"filter",
482+
"filter_with",
409483
"is_ok",
410484
"is_error",
485+
"map2",
486+
"map_error",
487+
"merge",
411488
"to_option",
412489
"of_option",
413490
"of_option_with",
491+
"or_else",
492+
"or_else_with",
493+
"swap",
414494
]

tests/test_result.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,14 @@ def test_result_error_chained_map(msg: str, y: int):
194194
case _:
195195
assert False
196196

197+
@given(st.text())
198+
def test_map_error(msg: str):
199+
assert Error(msg).map_error(lambda x: f"more {x}") == Error("more " + msg)
200+
201+
@given(st.text())
202+
def test_map_error_piped(msg: str):
203+
assert Error(msg).pipe(result.map_error(lambda x: f"more {x}")) == Error(f"more {msg}")
204+
197205

198206
@given(st.integers(), st.integers()) # type: ignore
199207
def test_result_bind_piped(x: int, y: int):
@@ -362,6 +370,54 @@ def test_pipeline_error():
362370
assert hn(42) == error
363371

364372

373+
def test_filter_ok_passing_predicate():
374+
xs: Result[int, str] = Ok(42)
375+
ys = xs.filter(lambda x: x > 10, "error")
376+
377+
assert ys == xs
378+
379+
380+
def test_filter_ok_failing_predicate():
381+
xs: Result[int, str] = Ok(5)
382+
ys = xs.filter(lambda x: x > 10, "error")
383+
384+
assert ys == Error("error")
385+
386+
387+
def test_filter_error():
388+
error = Error("original error")
389+
ys = error.filter(lambda x: x > 10, "error")
390+
391+
assert ys == error
392+
393+
def test_filter_piped():
394+
assert Ok(42).pipe(result.filter(lambda x: x > 10, "error")) == Ok(42)
395+
396+
397+
def test_filter_with_ok_passing_predicate():
398+
xs: Result[int, str] = Ok(42)
399+
ys = xs.filter_with(lambda x: x > 10, lambda value: f"error {value}")
400+
401+
assert ys == xs
402+
403+
404+
def test_filter_with_ok_failing_predicate():
405+
xs: Result[int, str] = Ok(5)
406+
ys = xs.filter_with(lambda x: x > 10, lambda value: f"error {value}")
407+
408+
assert ys == Error("error 5")
409+
410+
411+
def test_filter_with_error():
412+
error = Error("original error")
413+
ys = error.filter_with(lambda x: x > 10, lambda value: f"error {value}")
414+
415+
assert ys == error
416+
417+
def test_filter_with_piped():
418+
assert Ok(42).pipe(result.filter_with(lambda x: x > 10, lambda value: f"error {value}")) == Ok(42)
419+
420+
365421
class MyError(BaseModel):
366422
message: str
367423

@@ -525,6 +581,8 @@ def test_result_swap_with_error():
525581
xs = result.swap(error)
526582
assert xs == Ok(1)
527583

584+
def test_swap_piped():
585+
assert Ok(42).pipe(result.swap) == Error(42)
528586

529587
def test_ok_or_else_ok():
530588
xs: Result[int, str] = Ok(42)
@@ -549,6 +607,8 @@ def test_error_or_else_error():
549607
ys = xs.or_else(Error("new error"))
550608
assert ys == Error("new error")
551609

610+
def test_or_else_piped():
611+
assert Ok(42).pipe(result.or_else(Ok(0))) == Ok(42)
552612

553613
def test_ok_or_else_with_ok():
554614
xs: Result[str, str] = Ok("good")
@@ -574,6 +634,10 @@ def test_error_or_else_with_error():
574634
assert ys == Error("new error from original error")
575635

576636

637+
def test_or_else_with_piped():
638+
assert Ok(42).pipe(result.or_else_with(lambda _: Ok(0))) == Ok(42)
639+
640+
577641
def test_merge_ok():
578642
assert Result.Ok(42).merge() == 42
579643

@@ -601,3 +665,7 @@ class Child2(Parent):
601665
def test_merge_subclasses():
602666
xs: Result[Parent, Parent] = Result.Ok(Child1(x=42))
603667
assert xs.merge() == Child1(x=42)
668+
669+
670+
def test_merge_piped():
671+
assert Ok(42).pipe(result.merge) == 42

0 commit comments

Comments
 (0)