Skip to content

Feat: upgrade variable assigner #11285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 44 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
6d5b1f0
refactor(variable_assigner): Rename to variable_operator.
laipz8200 Nov 21, 2024
96f9fc8
Merge branch 'main' into feat/variable-operator
YIXIAO0 Nov 21, 2024
ba5e2bd
fix(variable_operator): Rename _node_type
laipz8200 Nov 21, 2024
fdf9bea
refactor: rearrange SegmentType enum order
laipz8200 Nov 21, 2024
17d5548
refactor(variable_operator): Removes unused attributes from VariableO…
laipz8200 Nov 21, 2024
23372eb
chore(variable_operator): Move exc.py into common.
laipz8200 Nov 21, 2024
1e56773
Merge branch 'main' into feat/variable-operator
YIXIAO0 Nov 22, 2024
5371c79
feat(variable_operator): Add version 2 of variable operator.
laipz8200 Nov 22, 2024
9fa4fdd
feat(variable_operator): Avoid do any change if operator is failed.
laipz8200 Nov 22, 2024
4653f47
refactor(variable_operator): Rename VariableOperator to VariableOpera…
laipz8200 Nov 22, 2024
92ab118
fix(api): Set default value in DefaultBlockConfigApi
laipz8200 Nov 22, 2024
080d7a0
feat(workflow): Add version control in workflow nodes
laipz8200 Nov 22, 2024
6891999
feat(variable_operator_v2): Update operation mark.
laipz8200 Nov 22, 2024
80b0782
refactor(variable_operator): Move old version into v1/
laipz8200 Nov 22, 2024
49f8482
chore: variable assigner ui update
YIXIAO0 Nov 25, 2024
6d2bb16
fix(variable_operator_v2): Append and extend variable correctly.
laipz8200 Nov 27, 2024
580afd1
Merge branch 'feat/variable-operator' of https://github.com/langgeniu…
YIXIAO0 Nov 27, 2024
e4dd750
chore: operations with variables frontend
YIXIAO0 Nov 27, 2024
95b2565
fix: typo
YIXIAO0 Nov 27, 2024
0b86c4c
fix(variable_factory): Add selector into environment / conversation v…
laipz8200 Nov 27, 2024
28e39bf
fix(variable_operator_v2): Skip value check if operation is 'clear'
laipz8200 Nov 27, 2024
303d91e
feat(variable-operator-v2): Update `value` to optional and set defaul…
laipz8200 Nov 28, 2024
17cbe70
feat(variable-operator-v2): Check if value is None before use.
laipz8200 Nov 28, 2024
8307283
feat(variable-operator-v2): Convert string to object when set to a ob…
laipz8200 Nov 28, 2024
84df0f3
chore(variable-operator-v2): Add comment
laipz8200 Nov 28, 2024
4e2be52
refactor(variable-operator): Move update_conversation_variable into c…
laipz8200 Nov 28, 2024
1beccc4
fix(variable-operator-v2): Update conversation variables after operation
laipz8200 Nov 28, 2024
00c7437
fix(variable-operator-v2): Update input validation.
laipz8200 Nov 28, 2024
8e924eb
test(variable-operator-v2): Test overwrite array string
laipz8200 Nov 28, 2024
651845d
chore(variable-operator-v2): Update comments
laipz8200 Nov 28, 2024
e042776
chore: add operations and i18n
YIXIAO0 Nov 28, 2024
3f5d9c5
Merge remote-tracking branch 'origin/main' into feat/variable-operator
laipz8200 Nov 28, 2024
9925e5f
chore: allow parameters to be same as the assigned variable
YIXIAO0 Nov 28, 2024
ae062c5
refactor(variable-operator): Rename node to Variable Assigner
laipz8200 Nov 28, 2024
7f44871
chore(variable_factory): Add comments
laipz8200 Nov 29, 2024
d82c5dc
fix: checklist display issue & operation selector ui issue
YIXIAO0 Nov 29, 2024
39190ae
fix: number value can be 0
YIXIAO0 Nov 29, 2024
b7b5cef
Merge branch 'main' into feat/variable-operator
YIXIAO0 Dec 2, 2024
a31bac9
Merge branch 'main' into feat/variable-operator
YIXIAO0 Dec 2, 2024
3e02f75
feat: compatible to previous workflow
YIXIAO0 Dec 2, 2024
7cadebf
fix: add no variable available tips for other langs
YIXIAO0 Dec 3, 2024
9d135f0
Merge branch 'main' into feat/variable-operator
YIXIAO0 Dec 3, 2024
3504572
Merge branch 'main' into feat/variable-operator
laipz8200 Dec 3, 2024
3dbcd9f
test(test_workflow): fix test
laipz8200 Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/controllers/console/app/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ def post(self, app_model: App):
try:
environment_variables_list = args.get("environment_variables") or []
environment_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
Expand Down Expand Up @@ -382,7 +382,7 @@ def get(self, app_model: App, block_type: str):
filters = None
if args.get("q"):
try:
filters = json.loads(args.get("q"))
filters = json.loads(args.get("q", ""))
except json.JSONDecodeError:
raise ValueError("Invalid filters")

Expand Down
5 changes: 3 additions & 2 deletions api/core/app/apps/workflow_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
Expand Down Expand Up @@ -138,7 +138,8 @@ def _get_graph_and_variable_pool_of_single_iteration(

# Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping[node_type]
node_version = iteration_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

# init variable pool
variable_pool = VariablePool(
Expand Down
9 changes: 6 additions & 3 deletions api/core/variables/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@


class SegmentType(StrEnum):
NONE = "none"
NUMBER = "number"
STRING = "string"
OBJECT = "object"
SECRET = "secret"

FILE = "file"

ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
OBJECT = "object"
FILE = "file"
ARRAY_FILE = "array[file]"

NONE = "none"

GROUP = "group"
5 changes: 3 additions & 2 deletions api/core/workflow/graph_engine/graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
Expand Down Expand Up @@ -227,7 +227,8 @@ def _run(

# convert to specific node
node_type = NodeType(node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping[node_type]
node_version = node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _recursive_fetch_answer_dependencies(
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.CONVERSATION_VARIABLE_ASSIGNER,
NodeType.VARIABLE_ASSIGNER,
}:
answer_dependencies[answer_node_id].append(source_node_id)
else:
Expand Down
1 change: 1 addition & 0 deletions api/core/workflow/nodes/base/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
version: str = "1"


class BaseIterationNodeData(BaseNodeData):
Expand Down
4 changes: 3 additions & 1 deletion api/core/workflow/nodes/base/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def __init__(
raise ValueError("Node ID is required.")

self.node_id = node_id
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))

node_data = self._node_data_cls.model_validate(config.get("data", {}))
self.node_data = cast(GenericNodeData, node_data)

@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
Expand Down
4 changes: 2 additions & 2 deletions api/core/workflow/nodes/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request"
TOOL = "tool"
VARIABLE_AGGREGATOR = "variable-aggregator"
VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
7 changes: 4 additions & 3 deletions api/core/workflow/nodes/iteration/iteration_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,13 @@ def _extract_variable_selector_to_variable_mapping(
# variable selector to variable mapping
try:
# Get node class
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING

node_type = NodeType(sub_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping.get(node_type)
if not node_cls:
if node_type not in NODE_TYPE_CLASSES_MAPPING:
continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
Expand Down
105 changes: 84 additions & 21 deletions api/core/workflow/nodes/node_mapping.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Mapping

from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode
Expand All @@ -16,26 +18,87 @@
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2

LATEST_VERSION = "latest"

node_type_classes_mapping: dict[NodeType, type[BaseNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.ITERATION_START: IterationStartNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
NodeType.LIST_OPERATOR: ListOperatorNode,
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
"1": StartNode,
},
NodeType.END: {
LATEST_VERSION: EndNode,
"1": EndNode,
},
NodeType.ANSWER: {
LATEST_VERSION: AnswerNode,
"1": AnswerNode,
},
NodeType.LLM: {
LATEST_VERSION: LLMNode,
"1": LLMNode,
},
NodeType.KNOWLEDGE_RETRIEVAL: {
LATEST_VERSION: KnowledgeRetrievalNode,
"1": KnowledgeRetrievalNode,
},
NodeType.IF_ELSE: {
LATEST_VERSION: IfElseNode,
"1": IfElseNode,
},
NodeType.CODE: {
LATEST_VERSION: CodeNode,
"1": CodeNode,
},
NodeType.TEMPLATE_TRANSFORM: {
LATEST_VERSION: TemplateTransformNode,
"1": TemplateTransformNode,
},
NodeType.QUESTION_CLASSIFIER: {
LATEST_VERSION: QuestionClassifierNode,
"1": QuestionClassifierNode,
},
NodeType.HTTP_REQUEST: {
LATEST_VERSION: HttpRequestNode,
"1": HttpRequestNode,
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
"1": ToolNode,
},
NodeType.VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
},
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
}, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: {
LATEST_VERSION: IterationNode,
"1": IterationNode,
},
NodeType.ITERATION_START: {
LATEST_VERSION: IterationStartNode,
"1": IterationStartNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,
},
NodeType.VARIABLE_ASSIGNER: {
LATEST_VERSION: VariableAssignerNodeV2,
"1": VariableAssignerNodeV1,
"2": VariableAssignerNodeV2,
},
NodeType.DOCUMENT_EXTRACTOR: {
LATEST_VERSION: DocumentExtractorNode,
"1": DocumentExtractorNode,
},
NodeType.LIST_OPERATOR: {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
}
8 changes: 0 additions & 8 deletions api/core/workflow/nodes/variable_assigner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +0,0 @@
from .node import VariableAssignerNode
from .node_data import VariableAssignerData, WriteMode

__all__ = [
"VariableAssignerData",
"VariableAssignerNode",
"WriteMode",
]
Empty file.
4 changes: 4 additions & 0 deletions api/core/workflow/nodes/variable_assigner/common/exc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class VariableOperatorNodeError(Exception):
"""Base error type, don't use directly."""

pass
19 changes: 19 additions & 0 deletions api/core/workflow/nodes/variable_assigner/common/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.variables import Variable
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable


def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
2 changes: 0 additions & 2 deletions api/core/workflow/nodes/variable_assigner/exc.py

This file was deleted.

3 changes: 3 additions & 0 deletions api/core/workflow/nodes/variable_assigner/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .node import VariableAssignerNode

__all__ = ["VariableAssignerNode"]
Original file line number Diff line number Diff line change
@@ -1,40 +1,36 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.variables import SegmentType, Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
from models import ConversationVariable
from models.workflow import WorkflowNodeExecutionStatus

from .exc import VariableAssignerNodeError
from .node_data import VariableAssignerData, WriteMode


class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
_node_type = NodeType.VARIABLE_ASSIGNER

def _run(self) -> NodeRunResult:
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError("assigned variable not found")
raise VariableOperatorNodeError("assigned variable not found")

match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})

case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value})

Expand All @@ -43,7 +39,7 @@ def _run(self) -> NodeRunResult:
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})

case _:
raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")

# Over write the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
Expand All @@ -52,8 +48,8 @@ def _run(self) -> NodeRunResult:
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableAssignerNodeError("conversation_id not found")
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
raise VariableOperatorNodeError("conversation_id not found")
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
Expand All @@ -63,18 +59,6 @@ def _run(self) -> NodeRunResult:
)


def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()


def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
Expand All @@ -86,4 +70,4 @@ def get_zero_value(t: SegmentType):
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f"unsupported variable type: {t}")
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
Loading
Loading