-
Notifications
You must be signed in to change notification settings - Fork 2.2k
feat: add serialization to `State` / move `State` to utils
#9345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Pull Request Test Coverage Report for Build 14859791142. Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
Details
💛 - Coveralls |
I’ve kept the current |
def test_state_to_dict():
    """Serialize a ``State`` and verify both the emitted schema and the emitted data."""
    # Three kinds of values are covered: a plain Python type, a list of
    # Haystack dataclasses and a dict that holds a list.
    schema = {
        "numbers": {"type": int},
        "messages": {"type": List[ChatMessage]},
        "dict_of_lists": {"type": dict},
    }
    initial_data = {
        "numbers": 1,
        "messages": [ChatMessage.from_user(text="Hello, world!")],
        "dict_of_lists": {"numbers": [1, 2, 3]},
    }

    serialized = State(schema, initial_data).to_dict()

    # In the serialized schema, types and handlers are plain strings; a handler
    # appears for every key even though none was given in the input schema.
    expected_schema = {
        "numbers": {"type": "int", "handler": "haystack.utils.state_utils.replace_values"},
        "messages": {
            "type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]",
            "handler": "haystack.utils.state_utils.merge_lists",
        },
        "dict_of_lists": {"type": "dict", "handler": "haystack.utils.state_utils.replace_values"},
    }
    # The chat message is expected as a dict tagged with its type name.
    expected_data = {
        "numbers": 1,
        "messages": [
            {"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}], "_type": "ChatMessage"}
        ],
        "dict_of_lists": {"numbers": [1, 2, 3]},
    }

    assert serialized["schema"] == expected_schema
    assert serialized["data"] == expected_data
|
||
|
||
def test_state_from_dict():
    """Rebuild a ``State`` from its serialized form and verify schema and data."""
    serialized_schema = {
        "numbers": {"type": "int", "handler": "haystack.utils.state_utils.replace_values"},
        "messages": {
            "type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]",
            "handler": "haystack.utils.state_utils.merge_lists",
        },
        "dict_of_lists": {"type": "dict", "handler": "haystack.utils.state_utils.replace_values"},
    }
    serialized_data = {
        "numbers": 1,
        "messages": [
            {
                "role": "user",
                "meta": {},
                "name": None,
                "content": [{"text": "Hello, world!"}],
                "_type": "ChatMessage",
            }
        ],
        "dict_of_lists": {"numbers": [1, 2, 3]},
    }

    restored = State.from_dict({"schema": serialized_schema, "data": serialized_data})

    # Type strings must come back as the actual Python types.
    assert restored.schema["numbers"]["type"] == int
    assert restored.schema["dict_of_lists"]["type"] == dict

    # Handlers only need to be callable; the exact function objects are not
    # compared because they might be different references.
    for key in ("numbers", "messages", "dict_of_lists"):
        assert callable(restored.schema[key]["handler"])

    # The data must round-trip, including the ChatMessage dataclass.
    assert restored.data["numbers"] == 1
    assert restored.data["messages"] == [ChatMessage.from_user(text="Hello, world!")]
    assert restored.data["dict_of_lists"] == {"numbers": [1, 2, 3]}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For ease of review, these are the two new tests for serialization. The rest are just moved from the old test file.
# Lookup table from a type name to that class's ``from_dict`` constructor.
# NOTE(review): presumably keyed by the "_type" tag written into serialized
# values - confirm against the code that reads this table.
_type_deserializers = dict(
    Answer=Answer.from_dict,
    ChatMessage=ChatMessage.from_dict,
    Document=Document.from_dict,
    ExtractedAnswer=ExtractedAnswer.from_dict,
    GeneratedAnswer=GeneratedAnswer.from_dict,
    SparseEmbedding=SparseEmbedding.from_dict,
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of using a predefined set of deserializers, couldn't we follow a methodology similar to the one we use for `component_from_dict`, where we import the `_type` field, check whether the imported class has a `from_dict` attribute and, if so, use that? That way we wouldn't need to create a list of hard-coded deserializers.
Related Issues
- `State` #9286

Proposed Changes:
- Add `serialize_value` and `deserialize_value` utility methods in the `utils` module. These methods encapsulate logic that is also used in the breakpoints de/serialization logic and `tracing.utils.coerce_tag_value` (except the json load). Once the breakpoints feature is merged in haystack, it can reuse these centralized utility functions.
- Move the `State` class to the `utils` module, as it is not actually a data class. A deprecation warning is added to the existing State class in the dataclasses module to guide users toward the updated implementation.

How did you test it?
Moved the existing tests to `test_utils_state.py` and added two new tests for serialization and deserialization.