Skip to content

Commit 8dce080

Browse files
authored
Track subflow results along with task results (#18338)
1 parent 3636ab8 commit 8dce080

File tree

19 files changed

+315
-145
lines changed

19 files changed

+315
-145
lines changed

src/integrations/prefect-dask/prefect_dask/task_runners.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def count_to(highest_number):
9292
import distributed.deploy.cluster
9393
from typing_extensions import ParamSpec
9494

95-
from prefect.client.schemas.objects import State, TaskRunInput
95+
from prefect.client.schemas.objects import RunInput, State
9696
from prefect.futures import PrefectFuture, PrefectFutureList, PrefectWrappedFuture
9797
from prefect.logging.loggers import get_logger
9898
from prefect.task_runners import TaskRunner
@@ -405,7 +405,7 @@ def submit(
405405
task: "Task[P, Coroutine[Any, Any, R]]",
406406
parameters: dict[str, Any],
407407
wait_for: Iterable[PrefectDaskFuture[R]] | None = None,
408-
dependencies: dict[str, Set[TaskRunInput]] | None = None,
408+
dependencies: dict[str, Set[RunInput]] | None = None,
409409
) -> PrefectDaskFuture[R]: ...
410410

411411
@overload
@@ -414,15 +414,15 @@ def submit(
414414
task: "Task[Any, R]",
415415
parameters: dict[str, Any],
416416
wait_for: Iterable[PrefectDaskFuture[R]] | None = None,
417-
dependencies: dict[str, Set[TaskRunInput]] | None = None,
417+
dependencies: dict[str, Set[RunInput]] | None = None,
418418
) -> PrefectDaskFuture[R]: ...
419419

420420
def submit(
421421
self,
422422
task: "Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]",
423423
parameters: dict[str, Any],
424424
wait_for: Iterable[PrefectDaskFuture[R]] | None = None,
425-
dependencies: dict[str, Set[TaskRunInput]] | None = None,
425+
dependencies: dict[str, Set[RunInput]] | None = None,
426426
) -> PrefectDaskFuture[R]:
427427
if not self._started:
428428
raise RuntimeError(

src/integrations/prefect-ray/prefect_ray/task_runners.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def count_to(highest_number):
8787

8888
from typing_extensions import ParamSpec, Self
8989

90-
from prefect.client.schemas.objects import TaskRunInput
90+
from prefect.client.schemas.objects import RunInput
9191
from prefect.context import serialize_context
9292
from prefect.futures import PrefectFuture, PrefectFutureList, PrefectWrappedFuture
9393
from prefect.logging.loggers import get_logger
@@ -227,7 +227,7 @@ def submit(
227227
task: "Task[P, Coroutine[Any, Any, R]]",
228228
parameters: dict[str, Any],
229229
wait_for: Iterable[PrefectFuture[Any]] | None = None,
230-
dependencies: dict[str, set[TaskRunInput]] | None = None,
230+
dependencies: dict[str, set[RunInput]] | None = None,
231231
) -> PrefectRayFuture[R]: ...
232232

233233
@overload
@@ -236,15 +236,15 @@ def submit(
236236
task: "Task[P, R]",
237237
parameters: dict[str, Any],
238238
wait_for: Iterable[PrefectFuture[Any]] | None = None,
239-
dependencies: dict[str, set[TaskRunInput]] | None = None,
239+
dependencies: dict[str, set[RunInput]] | None = None,
240240
) -> PrefectRayFuture[R]: ...
241241

242242
def submit(
243243
self,
244244
task: Task[P, R],
245245
parameters: dict[str, Any],
246246
wait_for: Iterable[PrefectFuture[Any]] | None = None,
247-
dependencies: dict[str, set[TaskRunInput]] | None = None,
247+
dependencies: dict[str, set[RunInput]] | None = None,
248248
):
249249
if not self._started:
250250
raise RuntimeError(
@@ -342,7 +342,7 @@ def _run_prefect_task(
342342
context: dict[str, Any],
343343
parameters: dict[str, Any],
344344
wait_for: Iterable[PrefectFuture[Any]] | None = None,
345-
dependencies: dict[str, set[TaskRunInput]] | None = None,
345+
dependencies: dict[str, set[RunInput]] | None = None,
346346
) -> Any:
347347
"""Resolves Ray futures before calling the actual Prefect task function.
348348

src/prefect/client/orchestration/__init__.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,11 @@
101101
WorkQueueFilterName,
102102
)
103103
from prefect.client.schemas.objects import (
104-
Constant,
104+
TaskRunResult,
105+
FlowRunResult,
105106
Parameter,
107+
Constant,
106108
TaskRunPolicy,
107-
TaskRunResult,
108109
WorkQueue,
109110
WorkQueueStatusDetail,
110111
)
@@ -778,13 +779,7 @@ async def create_task_run(
778779
task_inputs: Optional[
779780
dict[
780781
str,
781-
list[
782-
Union[
783-
TaskRunResult,
784-
Parameter,
785-
Constant,
786-
]
787-
],
782+
list[Union[TaskRunResult, FlowRunResult, Parameter, Constant]],
788783
]
789784
] = None,
790785
) -> TaskRun:
@@ -1440,6 +1435,7 @@ def create_task_run(
14401435
list[
14411436
Union[
14421437
TaskRunResult,
1438+
FlowRunResult,
14431439
Parameter,
14441440
Constant,
14451441
]

src/prefect/client/schemas/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
BlockType,
1212
FlowRun,
1313
FlowRunPolicy,
14+
FlowRunResult,
1415
State,
1516
StateDetails,
1617
StateType,
1718
TaskRun,
18-
TaskRunInput,
19+
RunInput,
1920
TaskRunPolicy,
2021
TaskRunResult,
2122
Workspace,
@@ -36,6 +37,7 @@
3637
"DEFAULT_BLOCK_SCHEMA_VERSION": (__package__, ".objects"),
3738
"FlowRun": (__package__, ".objects"),
3839
"FlowRunPolicy": (__package__, ".objects"),
40+
"FlowRunResult": (__package__, ".objects"),
3941
"OrchestrationResult": (__package__, ".responses"),
4042
"SetStateStatus": (__package__, ".responses"),
4143
"State": (__package__, ".objects"),
@@ -60,6 +62,7 @@
6062
"DEFAULT_BLOCK_SCHEMA_VERSION",
6163
"FlowRun",
6264
"FlowRunPolicy",
65+
"FlowRunResult",
6366
"OrchestrationResult",
6467
"SetStateStatus",
6568
"State",
@@ -70,7 +73,7 @@
7073
"StateRejectDetails",
7174
"StateType",
7275
"TaskRun",
73-
"TaskRunInput",
76+
"RunInput",
7477
"TaskRunPolicy",
7578
"TaskRunResult",
7679
"Workspace",

src/prefect/client/schemas/actions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ class TaskRunCreate(ActionBaseModel):
443443
list[
444444
Union[
445445
objects.TaskRunResult,
446+
objects.FlowRunResult,
446447
objects.Parameter,
447448
objects.Constant,
448449
]

src/prefect/client/schemas/objects.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@
8181
DEFAULT_AGENT_WORK_POOL_NAME: Literal["default-agent-pool"] = "default-agent-pool"
8282

8383

84+
class RunType(AutoEnum):
    """Enumeration of run types."""

    # Distinguishes results produced by a (sub)flow run from those produced
    # by a task run.
    FLOW_RUN = "flow_run"
    TASK_RUN = "task_run"
87+
88+
8489
class StateType(AutoEnum):
8590
"""Enumeration of state types."""
8691

@@ -164,7 +169,6 @@ class ConcurrencyLimitConfig(PrefectBaseModel):
164169
class StateDetails(PrefectBaseModel):
165170
flow_run_id: Optional[UUID] = None
166171
task_run_id: Optional[UUID] = None
167-
# for task runs that represent subflows, the subflow's run ID
168172
child_flow_run_id: Optional[UUID] = None
169173
scheduled_time: Optional[DateTime] = None
170174
cache_key: Optional[str] = None
@@ -182,6 +186,16 @@ class StateDetails(PrefectBaseModel):
182186
# Captures the trace_id and span_id of the span where this state was created
183187
traceparent: Optional[str] = None
184188

189+
def to_run_result(
    self, run_type: RunType
) -> Optional[Union[FlowRunResult, TaskRunResult]]:
    """
    Build a run result reference from the IDs stored in these state details.

    Args:
        run_type: The kind of run to build a result reference for.

    Returns:
        A `FlowRunResult` wrapping `flow_run_id` when `run_type` is
        `RunType.FLOW_RUN`, a `TaskRunResult` wrapping `task_run_id` when
        `run_type` is `RunType.TASK_RUN`, or `None` when the matching ID is
        not set or `run_type` is neither of those members.
    """
    # Compare against the enum class itself rather than looking members up
    # through the `run_type` argument, so an unexpected argument falls
    # through to `None` instead of failing on attribute access.
    if run_type == RunType.FLOW_RUN and self.flow_run_id:
        return FlowRunResult(id=self.flow_run_id)
    if run_type == RunType.TASK_RUN and self.task_run_id:
        return TaskRunResult(id=self.task_run_id)
    return None
198+
185199

186200
def data_discriminator(x: Any) -> str:
187201
if isinstance(x, dict) and "storage_key" in x:
@@ -734,7 +748,7 @@ def validate_jitter_factor(cls, v: Optional[float]) -> Optional[float]:
734748
return validate_not_negative(v)
735749

736750

737-
class TaskRunInput(PrefectBaseModel):
751+
class RunInput(PrefectBaseModel):
738752
"""
739753
Base class for classes that represent inputs to task runs, which
740754
could include constants, parameters, or other task runs.
@@ -747,21 +761,26 @@ class TaskRunInput(PrefectBaseModel):
747761
input_type: str
748762

749763

750-
class TaskRunResult(TaskRunInput):
764+
class TaskRunResult(RunInput):
751765
"""Represents a task run result input to another task run."""
752766

753767
input_type: Literal["task_run"] = "task_run"
754768
id: UUID
755769

756770

757-
class Parameter(TaskRunInput):
771+
class FlowRunResult(RunInput):
    """Represents a flow run result input to a task run."""

    input_type: Literal["flow_run"] = "flow_run"
    id: UUID
774+
775+
776+
class Parameter(RunInput):
758777
"""Represents a parameter input to a task run."""
759778

760779
input_type: Literal["parameter"] = "parameter"
761780
name: str
762781

763782

764-
class Constant(TaskRunInput):
783+
class Constant(RunInput):
765784
"""Represents constant input value to a task run."""
766785

767786
input_type: Literal["constant"] = "constant"
@@ -811,7 +830,9 @@ class TaskRun(TimeSeriesBaseModel, ObjectBaseModel):
811830
state_id: Optional[UUID] = Field(
812831
default=None, description="The id of the current task run state."
813832
)
814-
task_inputs: dict[str, list[Union[TaskRunResult, Parameter, Constant]]] = Field(
833+
task_inputs: dict[
834+
str, list[Union[TaskRunResult, FlowRunResult, Parameter, Constant]]
835+
] = Field(
815836
default_factory=dict,
816837
description=(
817838
"Tracks the source of inputs to a task run. Used for internal bookkeeping. "

src/prefect/context.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from prefect.assets import Asset
3535
from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
3636
from prefect.client.schemas import FlowRun, TaskRun
37+
from prefect.client.schemas.objects import RunType
3738
from prefect.events.worker import EventsWorker
3839
from prefect.exceptions import MissingContextError
3940
from prefect.results import (
@@ -363,7 +364,7 @@ class EngineContext(RunContext):
363364
flow: The flow instance associated with the run
364365
flow_run: The API metadata for the flow run
365366
task_runner: The task runner instance being used for the flow run
366-
task_run_results: A mapping of result ids to task run states for this flow run
367+
run_results: A mapping of result object ids to (state, run type) pairs for this flow run
367368
log_prints: Whether to log print statements from the flow run
368369
parameters: The parameters passed to the flow run
369370
detached: Flag indicating if context has been serialized and sent to remote infrastructure
@@ -394,9 +395,10 @@ class EngineContext(RunContext):
394395
# Counter for flow pauses
395396
observed_flow_pauses: dict[str, int] = Field(default_factory=dict)
396397

397-
# Tracking for result from task runs in this flow run for dependency tracking
398-
# Holds the ID of the object returned by the task run and task run state
399-
task_run_results: dict[int, State] = Field(default_factory=dict)
398+
# Tracks results from task runs and subflows in this flow run for
399+
# dependency tracking. Maps the ID of the object returned by a run to
400+
# that run's state and run type
401+
run_results: dict[int, tuple[State, RunType]] = Field(default_factory=dict)
400402

401403
# Tracking information needed to track asset lineage between
402404
# tasks and materialization

src/prefect/flow_engine.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@
105105
from prefect.utilities.collections import visit_collection
106106
from prefect.utilities.engine import (
107107
capture_sigterm,
108-
link_state_to_result,
108+
link_state_to_flow_run_result,
109109
propose_state,
110110
propose_state_sync,
111111
resolve_to_final_result,
@@ -338,6 +338,7 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
338338
self._return_value, State
339339
):
340340
_result = self._return_value
341+
link_state_to_flow_run_result(self.state, _result)
341342

342343
if asyncio.iscoroutine(_result):
343344
# getting the value for a BaseResult may return an awaitable
@@ -373,6 +374,7 @@ def handle_success(self, result: R) -> R:
373374
self.set_state(terminal_state)
374375
self._return_value = resolved_result
375376

377+
link_state_to_flow_run_result(terminal_state, resolved_result)
376378
self._telemetry.end_span_on_success()
377379

378380
return result
@@ -903,6 +905,7 @@ async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]"
903905
self._return_value, State
904906
):
905907
_result = self._return_value
908+
link_state_to_flow_run_result(self.state, _result)
906909

907910
if asyncio.iscoroutine(_result):
908911
# getting the value for a BaseResult may return an awaitable
@@ -1426,7 +1429,7 @@ def run_generator_flow_sync(
14261429
while True:
14271430
gen_result = next(gen)
14281431
# link the current state to the result for dependency tracking
1429-
link_state_to_result(engine.state, gen_result)
1432+
link_state_to_flow_run_result(engine.state, gen_result)
14301433
yield gen_result
14311434
except StopIteration as exc:
14321435
engine.handle_success(exc.value)
@@ -1468,7 +1471,7 @@ async def run_generator_flow_async(
14681471
# can't use anext in Python < 3.10
14691472
gen_result = await gen.__anext__()
14701473
# link the current state to the result for dependency tracking
1471-
link_state_to_result(engine.state, gen_result)
1474+
link_state_to_flow_run_result(engine.state, gen_result)
14721475
yield gen_result
14731476
except (StopAsyncIteration, GeneratorExit) as exc:
14741477
await engine.handle_success(None)

src/prefect/runner/submit.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
FlowRunFilterParentFlowRunId,
1818
TaskRunFilter,
1919
)
20-
from prefect.client.schemas.objects import Constant, FlowRun, Parameter, TaskRunResult
20+
from prefect.client.schemas.objects import (
21+
Constant,
22+
FlowRun,
23+
FlowRunResult,
24+
Parameter,
25+
TaskRunResult,
26+
)
2127
from prefect.context import FlowRunContext
2228
from prefect.flows import Flow
2329
from prefect.logging import get_logger
@@ -66,9 +72,9 @@ async def _submit_flow_to_runner(
6672

6773
parent_flow_run_context = FlowRunContext.get()
6874

69-
task_inputs: dict[str, list[TaskRunResult | Parameter | Constant]] = {
70-
k: list(await collect_task_run_inputs(v)) for k, v in parameters.items()
71-
}
75+
task_inputs: dict[
76+
str, list[Union[TaskRunResult, FlowRunResult, Parameter, Constant]]
77+
] = {k: list(await collect_task_run_inputs(v)) for k, v in parameters.items()}
7278
parameters = await resolve_inputs(parameters)
7379
dummy_task = Task(name=flow.name, fn=flow.fn, version=flow.version)
7480
parent_task_run = await client.create_task_run(

src/prefect/server/database/orm_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,10 @@ def __table_args__(cls) -> Iterable[sa.Index]:
643643

644644

645645
_TaskInput = Union[
646-
schemas.core.TaskRunResult, schemas.core.Parameter, schemas.core.Constant
646+
schemas.core.TaskRunResult,
647+
schemas.core.FlowRunResult,
648+
schemas.core.Parameter,
649+
schemas.core.Constant,
647650
]
648651
_TaskInputs = dict[str, list[_TaskInput]]
649652

src/prefect/server/schemas/actions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ class TaskRunCreate(ActionBaseModel):
501501
List[
502502
Union[
503503
schemas.core.TaskRunResult,
504+
schemas.core.FlowRunResult,
504505
schemas.core.Parameter,
505506
schemas.core.Constant,
506507
]

0 commit comments

Comments
 (0)