Skip to content

refactor(list_operator): refine exception handling for error specificity #10206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions api/core/workflow/nodes/list_operator/exc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
class ListOperatorError(ValueError):
"""Base class for all ListOperator errors."""

pass


class InvalidFilterValueError(ListOperatorError):
pass


class InvalidKeyError(ListOperatorError):
pass


class InvalidConditionError(ListOperatorError):
pass
155 changes: 97 additions & 58 deletions api/core/workflow/nodes/list_operator/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Literal
from typing import Literal, Union

from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
Expand All @@ -9,6 +9,7 @@
from models.workflow import WorkflowNodeExecutionStatus

from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError


class ListOperatorNode(BaseNode[ListOperatorNodeData]):
Expand All @@ -26,7 +27,17 @@ def _run(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
if not variable.value:
inputs = {"variable": []}
process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
Expand All @@ -36,78 +47,106 @@ def _run(self):
)

if isinstance(variable, ArrayFileSegment):
inputs = {"variable": [item.to_dict() for item in variable.value]}
process_data["variable"] = [item.to_dict() for item in variable.value]
else:
inputs = {"variable": variable.value}
process_data["variable"] = variable.value

# Filter
if self.node_data.filter_by.enabled:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
if isinstance(condition.value, str):
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
else:
value = condition.value
filter_func = _get_file_filter_func(
key=condition.key,
condition=condition.comparison_operator,
value=value,
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})

# Order
if self.node_data.order_by.enabled:
try:
# Filter
if self.node_data.filter_by.enabled:
variable = self._apply_filter(variable)

# Order
if self.node_data.order_by.enabled:
variable = self._apply_order(variable)

# Slice
if self.node_data.limit.enabled:
variable = self._apply_slice(variable)

outputs = {
"result": variable.value,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
except ListOperatorError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=inputs,
process_data=process_data,
outputs=outputs,
)

def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
if isinstance(condition.value, str):
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
else:
value = condition.value
filter_func = _get_file_filter_func(
key=condition.key,
condition=condition.comparison_operator,
value=value,
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
return variable

# Slice
if self.node_data.limit.enabled:
result = variable.value[: self.node_data.limit.size]
def _apply_order(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable

outputs = {
"result": variable.value,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
result = variable.value[: self.node_data.limit.size]
return variable.model_copy(update={"value": result})


def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
match key:
case "size":
return lambda x: x.size
case _:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")


def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
Expand All @@ -125,7 +164,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
case "url":
return lambda x: x.remote_url or ""
case _:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")


def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
Expand All @@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "not empty":
return lambda x: x != ""
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")


def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
Expand All @@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "not in":
return lambda x: not _in(value)(x)
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")


def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
Expand All @@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
case "≥":
return _ge(value)
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")


def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
Expand All @@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
extract_func = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
else:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")


def _contains(value: str):
Expand Down