Skip to content

Commit 7a92a9d

Browse files
authored
Fix pydantic validate and pickle for tagged union (#207)
* add test and fix * Update expression/core/tagged_union.py
1 parent d4bd6c7 commit 7a92a9d

File tree

6 files changed

+121
-17
lines changed

6 files changed

+121
-17
lines changed

expression/core/option.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,8 @@ def validate_none(v: Any, handler: ValidatorFunctionWrapHandler) -> Option[Any]:
329329
),
330330
core_schema.chain_schema(
331331
[
332-
# Ensure the value is an instance of _T
333-
core_schema.is_instance_schema(item_tp),
332+
# item_tp's schema should ensure the value is an instance of _T
333+
# is_instance_schema doesn't work for Annotated[_T, ...]
334334
# Use the value_schema to validate `values`
335335
core_schema.no_info_wrap_validator_function(validate_some, value_schema),
336336
]

expression/core/tagged_union.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ def transform(cls: Any) -> Any:
4545
field_names = tuple(f.name for f in fields_)
4646
original_init = cls.__init__
4747

48+
def tagged_union_getstate(self: Any) -> dict[str, Any]:
49+
return {f.name: getattr(self, f.name) for f in fields(self)}
50+
51+
def tagged_union_setstate(self: Any, state: dict[str, Any]):
52+
self.__init__(**state)
53+
54+
cls.__setstate__ = tagged_union_setstate
55+
cls.__getstate__ = tagged_union_getstate
56+
4857
def __init__(self: Any, **kwargs: Any) -> None:
4958
tag = kwargs.pop("tag", None)
5059

tests/test_block.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import functools
22
from builtins import list as list
33
from collections.abc import Callable
4-
from typing import Any, List
4+
from typing import Any, List, Annotated
55

66
from hypothesis import given # type: ignore
77
from hypothesis import strategies as st
8-
from pydantic import BaseModel
8+
from pydantic import BaseModel, Field, GetCoreSchemaHandler
9+
from pydantic_core import CoreSchema, core_schema
910

1011
from expression import Nothing, Option, Some, pipe
1112
from expression.collections import Block, block
1213

13-
1414
Func = Callable[[int], int]
1515

1616

@@ -408,24 +408,43 @@ def test_block_monad_law_associativity_iterable(xs: List[int]):
408408
assert m.collect(f).collect(g) == m.collect(lambda x: f(x).collect(g))
409409

410410

411+
PositiveInt = Annotated[int, Field(gt=0)]
412+
413+
414+
class Username(str):
415+
@classmethod
416+
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
417+
return core_schema.no_info_after_validator_function(cls, handler(str))
418+
419+
411420
class Model(BaseModel):
412421
one: Block[int]
413422
two: Block[str] = block.empty
414423
three: Block[float] = block.empty
424+
annotated_type: Block[PositiveInt] = block.empty
425+
annotated_type_empty: Block[PositiveInt] = block.empty
426+
427+
custom_type: Block[Username] = block.empty
428+
custom_type_empty: Block[Username] = block.empty
415429

416430

417431
def test_parse_block_works():
418-
obj = dict(one=[1, 2, 3], two=[])
432+
obj = dict(one=[1, 2, 3], two=[], annotated_type=[1, 2, 3], custom_type=["a", "b", "c"])
419433
model = Model.model_validate(obj)
420434
assert isinstance(model.one, Block)
421435
assert model.one == Block([1, 2, 3])
422436
assert model.two == Block.empty()
423437
assert model.three == block.empty
438+
assert model.annotated_type == Block([1, 2, 3])
439+
assert model.annotated_type_empty == block.empty
440+
assert model.custom_type == Block(["a", "b", "c"])
441+
assert model.custom_type_empty == block.empty
424442

425443

426444
def test_serialize_block_works():
427445
# arrange
428-
model = Model(one=Block([1, 2, 3]), two=Block.empty())
446+
obj = dict(one=[1, 2, 3], two=[], annotated_type=[1, 2, 3], custom_type=["a", "b", "c"])
447+
model = Model.model_validate(obj)
429448

430449
# act
431450
json = model.model_dump_json()
@@ -435,3 +454,7 @@ def test_serialize_block_works():
435454
assert model_.one == Block([1, 2, 3])
436455
assert model_.two == Block.empty()
437456
assert model_.three == block.empty
457+
assert model_.annotated_type == Block([1, 2, 3])
458+
assert model_.annotated_type_empty == block.empty
459+
assert model_.custom_type == Block(["a", "b", "c"])
460+
assert model_.custom_type_empty == block.empty

tests/test_option.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from collections.abc import Callable, Generator
2-
from typing import Any
2+
from typing import Any, Annotated
33

44
import pytest
55
from hypothesis import given # type: ignore
66
from hypothesis import strategies as st
7-
from pydantic import BaseModel
7+
from pydantic import BaseModel, Field, GetCoreSchemaHandler
8+
from pydantic_core import CoreSchema, core_schema
89

910
from expression import (
1011
Error,
@@ -599,30 +600,69 @@ def test_pipeline_error():
599600
assert hn(42) == Nothing
600601

601602

603+
PositiveInt = Annotated[int, Field(gt=0)]
604+
605+
606+
class Username(str):
607+
@classmethod
608+
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
609+
return core_schema.no_info_after_validator_function(cls, handler(str))
610+
611+
602612
class Model(BaseModel):
603613
one: Option[int]
604614
two: Option[str] = Nothing
605615
three: Option[float] = Nothing
616+
annotated_type: Option[PositiveInt] = Nothing
617+
annotated_type_none: Option[PositiveInt] = Nothing
618+
619+
custom_type: Option[Username] = Nothing
620+
custom_type_none: Option[Username] = Nothing
606621

607622

608623
def test_parse_option_works():
609-
obj = dict(one=10, two=None)
624+
obj = dict(
625+
one=10, two=None, annotated_type=20, annotated_type_none=None, custom_type="test_user", custom_type_none=None
626+
)
610627
model = Model.model_validate(obj)
611628

612629
assert model.one.is_some()
613630
assert model.one.value == 10
614631
assert model.two == Nothing
615632
assert model.three == Nothing
633+
assert model.custom_type == Some("test_user")
634+
assert model.annotated_type == Some(20)
635+
assert model.annotated_type_none == Nothing
636+
assert model.custom_type_none == Nothing
616637

617638

618639
def test_serialize_option_works():
619640
model = Model(one=Some(10))
620641
json = model.model_dump_json()
621-
assert json == '{"one":10,"two":null,"three":null}'
642+
assert (
643+
json
644+
== '{"one":10,"two":null,"three":null,"annotated_type":null,"annotated_type_none":null,"custom_type":null,"custom_type_none":null}'
645+
)
622646

623647
model_ = Model.model_validate_json(json)
624648

625649
assert model_.one.is_some()
626650
assert model_.one.value == 10
627651
assert model_.two == Nothing
628652
assert model_.three == Nothing
653+
654+
655+
def test_pickle_option_works():
656+
import pickle
657+
658+
x = Some(10)
659+
y = Nothing
660+
dump_x = pickle.dumps(x)
661+
load_x = pickle.loads(dump_x)
662+
dump_y = pickle.dumps(y)
663+
load_y = pickle.loads(dump_y)
664+
assert x == load_x
665+
assert y == load_y
666+
667+
668+
#

tests/test_result.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from collections.abc import Callable, Generator
2-
from typing import Any
2+
from typing import Any, Annotated
33

44
import pytest
55
from hypothesis import given # type: ignore
66
from hypothesis import strategies as st
7-
from pydantic import BaseModel, TypeAdapter
7+
from pydantic import BaseModel, TypeAdapter, Field, GetCoreSchemaHandler
8+
from pydantic_core import CoreSchema, core_schema
89

910
from expression import Error, Nothing, Ok, Option, Result, Some, effect, result
1011
from expression.collections import Block
@@ -364,20 +365,38 @@ class MyError(BaseModel):
364365
message: str
365366

366367

368+
PositiveInt = Annotated[int, Field(gt=0)]
369+
370+
371+
class Username(str):
372+
@classmethod
373+
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
374+
return core_schema.no_info_after_validator_function(cls, handler(str))
375+
376+
367377
class Model(BaseModel):
368378
one: Result[int, MyError]
369379
two: Result[str, MyError] = Error(MyError(message="error"))
370380
three: Result[float, MyError] = Error(MyError(message="error"))
381+
annotated_type: Result[PositiveInt, MyError] = Error(MyError(message="error"))
382+
annotated_type_error: Result[PositiveInt, MyError] = Error(MyError(message="error"))
383+
384+
custom_type: Result[Username, MyError] = Error(MyError(message="error"))
385+
custom_type_error: Result[Username, MyError] = Error(MyError(message="error"))
371386

372387

373388
def test_parse_block_works():
374-
obj = dict(one=dict(ok=42))
389+
obj = dict(one=dict(ok=42), annotated_type=dict(ok=42), custom_type=dict(ok="johndoe"))
375390
model = Model.model_validate(obj)
376391

377392
assert isinstance(model.one, Result)
378393
assert model.one == Ok(42)
379394
assert model.two == Error(MyError(message="error"))
380395
assert model.three == Error(MyError(message="error"))
396+
assert model.annotated_type == Ok(42)
397+
assert model.annotated_type_error == Error(MyError(message="error"))
398+
assert model.custom_type == Ok(Username("johndoe"))
399+
assert model.custom_type_error == Error(MyError(message="error"))
381400

382401

383402
def test_ok_to_dict_works():
@@ -422,11 +441,13 @@ def test_error_from_dict_works():
422441

423442

424443
def test_model_to_json_works():
425-
model = Model(one=Ok(10))
444+
obj = dict(one=dict(ok=10), annotated_type=dict(ok=10), custom_type=dict(ok="johndoe"))
445+
446+
model = Model.model_validate(obj)
426447
obj = model.model_dump_json()
427448
assert (
428449
obj
429-
== '{"one":{"tag":"ok","ok":10},"two":{"tag":"error","error":{"message":"error"}},"three":{"tag":"error","error":{"message":"error"}}}'
450+
== '{"one":{"tag":"ok","ok":10},"two":{"tag":"error","error":{"message":"error"}},"three":{"tag":"error","error":{"message":"error"}},"annotated_type":{"tag":"ok","ok":10},"annotated_type_error":{"tag":"error","error":{"message":"error"}},"custom_type":{"tag":"ok","ok":"johndoe"},"custom_type_error":{"tag":"error","error":{"message":"error"}}}'
430451
)
431452

432453

tests/test_tagged_union.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from expression import case, tag, tagged_union
99

10-
1110
_T = TypeVar("_T")
1211

1312

@@ -70,6 +69,18 @@ def test_union_shape_repr_works():
7069
assert repr(shape) == "Shape(circle=Circle(radius=10.0))"
7170

7271

72+
def test_union_shape_pickle_works():
73+
shape = Shape(circle=Circle(10.0))
74+
import pickle, dataclasses
75+
76+
shape_ser = pickle.dumps(shape)
77+
shape_deser = pickle.loads(shape_ser)
78+
assert shape == shape_deser
79+
assert shape.__dict__ == shape_deser.__dict__
80+
assert dataclasses.fields(shape) == dataclasses.fields(shape_deser)
81+
assert dataclasses.asdict(shape) == dataclasses.asdict(shape_deser)
82+
83+
7384
def test_union_can_add_custom_attributes_to_shape():
7485
shape = Shape(circle=Circle(10.0))
7586
setattr(shape, "custom", "rectangle")

0 commit comments

Comments
 (0)