Skip to content

Fix pydantic validate and pickle for tagged union #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions expression/core/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ def validate_none(v: Any, handler: ValidatorFunctionWrapHandler) -> Option[Any]:
),
core_schema.chain_schema(
[
# Ensure the value is an instance of _T
core_schema.is_instance_schema(item_tp),
# item_tp's schema should ensure the value is an instance of _T
# is_instance_schema doesn't work for Annotated[_T, ...]
# Use the value_schema to validate `values`
core_schema.no_info_wrap_validator_function(validate_some, value_schema),
]
Expand Down
9 changes: 9 additions & 0 deletions expression/core/tagged_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ def transform(cls: Any) -> Any:
field_names = tuple(f.name for f in fields_)
original_init = cls.__init__

def tagged_union_getstate(self: Any) -> dict[str, Any]:
return {f.name: getattr(self, f.name) for f in fields(self)}

def tagged_union_setstate(self: Any, state: dict[str, Any]):
self.__init__(**state)

cls.__setstate__ = tagged_union_setstate
cls.__getstate__ = tagged_union_getstate

def __init__(self: Any, **kwargs: Any) -> None:
tag = kwargs.pop("tag", None)

Expand Down
33 changes: 28 additions & 5 deletions tests/test_block.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import functools
from builtins import list as list
from collections.abc import Callable
from typing import Any, List
from typing import Any, List, Annotated

from hypothesis import given # type: ignore
from hypothesis import strategies as st
from pydantic import BaseModel
from pydantic import BaseModel, Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

from expression import Nothing, Option, Some, pipe
from expression.collections import Block, block


Func = Callable[[int], int]


Expand Down Expand Up @@ -408,24 +408,43 @@ def test_block_monad_law_associativity_iterable(xs: List[int]):
assert m.collect(f).collect(g) == m.collect(lambda x: f(x).collect(g))


PositiveInt = Annotated[int, Field(gt=0)]


class Username(str):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
return core_schema.no_info_after_validator_function(cls, handler(str))


class Model(BaseModel):
one: Block[int]
two: Block[str] = block.empty
three: Block[float] = block.empty
annotated_type: Block[PositiveInt] = block.empty
annotated_type_empty: Block[PositiveInt] = block.empty

custom_type: Block[Username] = block.empty
custom_type_empty: Block[Username] = block.empty


def test_parse_block_works():
obj = dict(one=[1, 2, 3], two=[])
obj = dict(one=[1, 2, 3], two=[], annotated_type=[1, 2, 3], custom_type=["a", "b", "c"])
model = Model.model_validate(obj)
assert isinstance(model.one, Block)
assert model.one == Block([1, 2, 3])
assert model.two == Block.empty()
assert model.three == block.empty
assert model.annotated_type == Block([1, 2, 3])
assert model.annotated_type_empty == block.empty
assert model.custom_type == Block(["a", "b", "c"])
assert model.custom_type_empty == block.empty


def test_serialize_block_works():
# arrange
model = Model(one=Block([1, 2, 3]), two=Block.empty())
obj = dict(one=[1, 2, 3], two=[], annotated_type=[1, 2, 3], custom_type=["a", "b", "c"])
model = Model.model_validate(obj)

# act
json = model.model_dump_json()
Expand All @@ -435,3 +454,7 @@ def test_serialize_block_works():
assert model_.one == Block([1, 2, 3])
assert model_.two == Block.empty()
assert model_.three == block.empty
assert model_.annotated_type == Block([1, 2, 3])
assert model_.annotated_type_empty == block.empty
assert model_.custom_type == Block(["a", "b", "c"])
assert model_.custom_type_empty == block.empty
48 changes: 44 additions & 4 deletions tests/test_option.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections.abc import Callable, Generator
from typing import Any
from typing import Any, Annotated

import pytest
from hypothesis import given # type: ignore
from hypothesis import strategies as st
from pydantic import BaseModel
from pydantic import BaseModel, Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

from expression import (
Error,
Expand Down Expand Up @@ -599,30 +600,69 @@ def test_pipeline_error():
assert hn(42) == Nothing


PositiveInt = Annotated[int, Field(gt=0)]


class Username(str):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
return core_schema.no_info_after_validator_function(cls, handler(str))


class Model(BaseModel):
one: Option[int]
two: Option[str] = Nothing
three: Option[float] = Nothing
annotated_type: Option[PositiveInt] = Nothing
annotated_type_none: Option[PositiveInt] = Nothing

custom_type: Option[Username] = Nothing
custom_type_none: Option[Username] = Nothing


def test_parse_option_works():
obj = dict(one=10, two=None)
obj = dict(
one=10, two=None, annotated_type=20, annotated_type_none=None, custom_type="test_user", custom_type_none=None
)
model = Model.model_validate(obj)

assert model.one.is_some()
assert model.one.value == 10
assert model.two == Nothing
assert model.three == Nothing
assert model.custom_type == Some("test_user")
assert model.annotated_type == Some(20)
assert model.annotated_type_none == Nothing
assert model.custom_type_none == Nothing


def test_serialize_option_works():
model = Model(one=Some(10))
json = model.model_dump_json()
assert json == '{"one":10,"two":null,"three":null}'
assert (
json
== '{"one":10,"two":null,"three":null,"annotated_type":null,"annotated_type_none":null,"custom_type":null,"custom_type_none":null}'
)

model_ = Model.model_validate_json(json)

assert model_.one.is_some()
assert model_.one.value == 10
assert model_.two == Nothing
assert model_.three == Nothing


def test_pickle_option_works():
import pickle

x = Some(10)
y = Nothing
dump_x = pickle.dumps(x)
load_x = pickle.loads(dump_x)
dump_y = pickle.dumps(y)
load_y = pickle.loads(dump_y)
assert x == load_x
assert y == load_y


#
31 changes: 26 additions & 5 deletions tests/test_result.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections.abc import Callable, Generator
from typing import Any
from typing import Any, Annotated

import pytest
from hypothesis import given # type: ignore
from hypothesis import strategies as st
from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel, TypeAdapter, Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

from expression import Error, Nothing, Ok, Option, Result, Some, effect, result
from expression.collections import Block
Expand Down Expand Up @@ -364,20 +365,38 @@ class MyError(BaseModel):
message: str


PositiveInt = Annotated[int, Field(gt=0)]


class Username(str):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
return core_schema.no_info_after_validator_function(cls, handler(str))


class Model(BaseModel):
one: Result[int, MyError]
two: Result[str, MyError] = Error(MyError(message="error"))
three: Result[float, MyError] = Error(MyError(message="error"))
annotated_type: Result[PositiveInt, MyError] = Error(MyError(message="error"))
annotated_type_error: Result[PositiveInt, MyError] = Error(MyError(message="error"))

custom_type: Result[Username, MyError] = Error(MyError(message="error"))
custom_type_error: Result[Username, MyError] = Error(MyError(message="error"))


def test_parse_block_works():
obj = dict(one=dict(ok=42))
obj = dict(one=dict(ok=42), annotated_type=dict(ok=42), custom_type=dict(ok="johndoe"))
model = Model.model_validate(obj)

assert isinstance(model.one, Result)
assert model.one == Ok(42)
assert model.two == Error(MyError(message="error"))
assert model.three == Error(MyError(message="error"))
assert model.annotated_type == Ok(42)
assert model.annotated_type_error == Error(MyError(message="error"))
assert model.custom_type == Ok(Username("johndoe"))
assert model.custom_type_error == Error(MyError(message="error"))


def test_ok_to_dict_works():
Expand Down Expand Up @@ -422,11 +441,13 @@ def test_error_from_dict_works():


def test_model_to_json_works():
model = Model(one=Ok(10))
obj = dict(one=dict(ok=10), annotated_type=dict(ok=10), custom_type=dict(ok="johndoe"))

model = Model.model_validate(obj)
obj = model.model_dump_json()
assert (
obj
== '{"one":{"tag":"ok","ok":10},"two":{"tag":"error","error":{"message":"error"}},"three":{"tag":"error","error":{"message":"error"}}}'
== '{"one":{"tag":"ok","ok":10},"two":{"tag":"error","error":{"message":"error"}},"three":{"tag":"error","error":{"message":"error"}},"annotated_type":{"tag":"ok","ok":10},"annotated_type_error":{"tag":"error","error":{"message":"error"}},"custom_type":{"tag":"ok","ok":"johndoe"},"custom_type_error":{"tag":"error","error":{"message":"error"}}}'
)


Expand Down
13 changes: 12 additions & 1 deletion tests/test_tagged_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from expression import case, tag, tagged_union


_T = TypeVar("_T")


Expand Down Expand Up @@ -70,6 +69,18 @@ def test_union_shape_repr_works():
assert repr(shape) == "Shape(circle=Circle(radius=10.0))"


def test_union_shape_pickle_works():
shape = Shape(circle=Circle(10.0))
import pickle, dataclasses

shape_ser = pickle.dumps(shape)
shape_deser = pickle.loads(shape_ser)
assert shape == shape_deser
assert shape.__dict__ == shape_deser.__dict__
assert dataclasses.fields(shape) == dataclasses.fields(shape_deser)
assert dataclasses.asdict(shape) == dataclasses.asdict(shape_deser)


def test_union_can_add_custom_attributes_to_shape():
shape = Shape(circle=Circle(10.0))
setattr(shape, "custom", "rectangle")
Expand Down
Loading