Skip to content

Adds FutureAdapter that delegates executions to a threadpool for parallelization #1264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions examples/parallelism/lazy_threadpool_execution/my_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import time


def a() -> str:
print("a")
time.sleep(3)
return "a"


def b() -> str:
print("b")
time.sleep(3)
return "b"


def c(a: str, b: str) -> str:
print("c")
time.sleep(3)
return a + " " + b


def d() -> str:
print("d")
time.sleep(3)
return "d"


def e(c: str, d: str) -> str:
print("e")
time.sleep(3)
return c + " " + d


def z() -> str:
print("z")
time.sleep(3)
return "z"


def y() -> str:
print("y")
time.sleep(3)
return "y"


def x(z: str, y: str) -> str:
print("x")
time.sleep(3)
return z + " " + y


def s(x: str, e: str) -> str:
print("s")
time.sleep(3)
return x + " " + e
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
59 changes: 59 additions & 0 deletions examples/parallelism/lazy_threadpool_execution/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import time

import my_functions

from hamilton import driver
from hamilton.plugins import h_threadpool

start = time.time()
adapter = h_threadpool.FutureAdapter()
dr = driver.Builder().with_modules(my_functions).with_adapters(adapter).build()
dr.display_all_functions("my_funtions.png")
r = dr.execute("s")
print("got return from dr")
print(r)
print("Time taken with", time.time() - start)

from hamilton_sdk import adapters

tracker = adapters.HamiltonTracker(
project_id=21, # modify this as needed
username="[email protected]",
dag_name="with_caching",
tags={"environment": "DEV", "cached": "False", "team": "MY_TEAM", "version": "1"},
)

start = time.time()
dr = (
driver.Builder().with_modules(my_functions).with_adapters(tracker, adapter).with_cache().build()
)
dr.display_all_functions("a.png")
r = dr.execute("s")
print("got return from dr")
print(r)
print("Time taken with cold cache", time.time() - start)

tracker = adapters.HamiltonTracker(
project_id=21, # modify this as needed
username="[email protected]",
dag_name="with_caching",
tags={"environment": "DEV", "cached": "True", "team": "MY_TEAM", "version": "1"},
)

start = time.time()
dr = (
driver.Builder().with_modules(my_functions).with_adapters(tracker, adapter).with_cache().build()
)
dr.display_all_functions("a.png")
r = dr.execute("s")
print("got return from dr")
print(r)
print("Time taken with warm cache", time.time() - start)

start = time.time()
dr = driver.Builder().with_modules(my_functions).build()
dr.display_all_functions("a.png")
r = dr.execute("s")
print("got return from dr")
print(r)
print("Time taken without", time.time() - start)
58 changes: 58 additions & 0 deletions hamilton/plugins/h_threadpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict

from hamilton import registry

registry.disable_autoload()

from hamilton import lifecycle, node
from hamilton.lifecycle import base


def _new_fn(fn, **fn_kwargs):
"""Function that runs in the thread.

It can recursively check for Futures because we don't have to worry about
process serialization.
:param fn: Function to run
:param fn_kwargs: Keyword arguments to pass to the function
"""
for k, v in fn_kwargs.items():
if isinstance(v, Future):
while isinstance(v, Future):
v = v.result()
fn_kwargs[k] = v
# execute the function once all the futures are resolved
return fn(**fn_kwargs)


class FutureAdapter(base.BaseDoRemoteExecute, lifecycle.ResultBuilder):
def __init__(self, max_workers: int = None):
self.executor = ThreadPoolExecutor(max_workers=max_workers)
# self.executor = ProcessPoolExecutor(max_workers=max_workers)

def do_remote_execute(
self,
*,
execute_lifecycle_for_node: Callable,
node: node.Node,
**kwargs: Dict[str, Any],
) -> Any:
"""Method that is called to implement correct remote execution of hooks. This makes sure that all the pre-node and post-node hooks get executed in the remote environment which is necessary for some adapters. Node execution is called the same as before through "do_node_execute".

:param node: Node that is being executed
:param kwargs: Keyword arguments that are being passed into the node
:param execute_lifecycle_for_node: Function executing lifecycle_hooks and lifecycle_methods
"""
return self.executor.submit(_new_fn, execute_lifecycle_for_node, **kwargs)

def build_result(self, **outputs: Any) -> Any:
"""Given a set of outputs, build the result.

:param outputs: the outputs from the execution of the graph.
:return: the result of the execution of the graph.
"""
for k, v in outputs.items():
if isinstance(v, Future):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you want concurrent.futures.as_completed(...). Can't convince myselfe that we're not deadlocking in certain cases, but I think the fact that we're doing topological order should be good enough...

Copy link
Contributor Author

@skrawcz skrawcz Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this matters here. Since this wont block anything executing in a thread...

Copy link
Contributor

@elijahbenizzy elijahbenizzy Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's IMO slightly cleaner, but yeah, it'll not return until the slowest does regardless. Might also not want to mutate the outputs dictionary (copying is cleaner). But yes, nits.

outputs[k] = v.result()
return outputs