
Commit 1818302

Refactor with_columns to with_columns_factory
Central with_columns_factory from which dataframe libraries inherit. Changes:
- Use with_columns_factory as the base to inherit from; subclasses only need to supply the correct dataframe types and a create_merge_node method.
- Refactored h_pandas.with_columns and h_polars.with_columns to inherit from it.
- Added async support for pandas (not sure it makes sense for polars).
- select=None appends all sink nodes with matching column types -- same behavior as h_spark.with_columns.
1 parent 801e5ce commit 1818302
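
The library-specific subclasses this commit refactors (h_pandas.with_columns, h_polars.with_columns) are not rendered on this page, so here is a minimal sketch of the inheritance pattern the commit message describes, assuming a pandas-backed subclass. The class body, the .copy()-based merge logic, and the factory's import path are illustrative assumptions, not the committed implementation; only the node.Node constructor and registry.get_column_type_from_df_type are taken from the diff below.

# Hypothetical subclass -- illustrative only, not the committed h_pandas code.
from types import ModuleType
from typing import Callable, Union

import pandas as pd

from hamilton import node, registry
from hamilton.function_modifiers.recursive import with_columns_factory


class with_columns(with_columns_factory):
    """Pandas flavor: only the dataframe type and the merge node are library-specific."""

    def __init__(self, *load_from: Union[Callable, ModuleType], **kwargs):
        super().__init__(*load_from, dataframe_types=pd.DataFrame, **kwargs)

    def create_merge_node(self, upstream_node: str, node_name: str) -> node.Node:
        def new_callable(**kwargs) -> pd.DataFrame:
            # Pandas has no native with_columns, so assign the selected columns
            # onto a copy of the upstream dataframe.
            df = kwargs[upstream_node].copy()
            for column in self.select:
                df[column] = kwargs[column]
            return df

        # Column nodes are typed with the column type registered for the
        # dataframe type (pd.Series for pd.DataFrame).
        column_type = registry.get_column_type_from_df_type(self.dataframe_type)
        input_types = {column: column_type for column in self.select}
        input_types[upstream_node] = self.dataframe_type
        return node.Node(
            name=node_name,
            typ=self.dataframe_type,
            callabl=new_callable,
            input_types=input_types,
        )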

File tree

9 files changed: +1806 additions, -1601 deletions


examples/pandas/with_columns/notebook.ipynb

Lines changed: 430 additions & 433 deletions
Large diffs are not rendered by default.

examples/polars/notebook.ipynb

Lines changed: 121 additions & 119 deletions
Large diffs are not rendered by default.

examples/polars/with_columns/notebook.ipynb

Lines changed: 673 additions & 691 deletions
Large diffs are not rendered by default.

hamilton/function_modifiers/recursive.py

Lines changed: 238 additions & 1 deletion
@@ -1,5 +1,8 @@
+import abc
 import inspect
 import sys
+import typing
+from collections import defaultdict
 from types import ModuleType
 from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, TypedDict, Union

@@ -11,17 +14,21 @@
 else:
     from typing import NotRequired

+from pandas import DataFrame as PandasDataFrame
+from polars import DataFrame as PolarsDataFrame
+from polars import LazyFrame as PolarsLazyFrame

 # Copied this over from function_graph
 # TODO -- determine the best place to put this code
-from hamilton import graph_utils, node
+from hamilton import graph_utils, node, registry
 from hamilton.function_modifiers import base, dependencies
 from hamilton.function_modifiers.base import InvalidDecoratorException, NodeTransformer
 from hamilton.function_modifiers.dependencies import (
     LiteralDependency,
     ParametrizedDependency,
     UpstreamDependency,
 )
+from hamilton.function_modifiers.expanders import extract_columns


 def assign_namespace(node_name: str, namespace: str) -> str:
@@ -626,3 +633,233 @@ def prune_nodes(nodes: List[node.Node], select: Optional[List[str]] = None) -> L
                 stack.append(dep_node)
                 seen_nodes.add(dep)
     return output
+
+
+SUPPORTED_DATAFRAME_TYPES = [PandasDataFrame, PolarsDataFrame, PolarsLazyFrame]
+
+
+class with_columns_factory(base.NodeInjector, abc.ABC):
+    """Performs the with_columns operation on a dataframe. This is a special case of NodeInjector
+    that applies only to dataframes. For now it can be used with:
+
+    - Pandas
+    - Polars
+
+    This is used when you want to extract some columns out of the dataframe, perform operations
+    on them, and then append them to the original dataframe.
+
+    def processed_data(data: pd.DataFrame) -> pd.DataFrame:
+        ...
+
+    In this case we would build a subdag out of the node ``data`` and append selected nodes back to
+    the original dataframe before feeding it into ``processed_data``.
+    """
+
+    # TODO: if we rename the column nodes into something smarter this can be avoided and
+    # we can also modify columns in place
+    @staticmethod
+    def _check_for_duplicates(nodes_: List[node.Node]) -> bool:
+        """Ensures that we don't run into name clashing of columns and group operations.
+
+        In the case when we extract columns for the user, because ``columns_to_pass`` was used, we
+        want to safeguard against name clashing with functions that are passed into
+        ``with_columns`` - i.e. there are no functions that have the same name as the columns. This
+        effectively means that using ``columns_to_pass`` will only append new columns to the
+        dataframe; for changing existing columns, ``pass_dataframe_as`` needs to be used.
+        """
+        node_counter = defaultdict(int)
+        for node_ in nodes_:
+            node_counter[node_.name] += 1
+            if node_counter[node_.name] > 1:
+                return True
+        return False
+
+    def validate_dataframe_type(self):
+        if not set(self.allowed_dataframe_types).issubset(list(SUPPORTED_DATAFRAME_TYPES)):
+            raise InvalidDecoratorException(
+                f"The provided dataframe types: {self.allowed_dataframe_types} are currently not "
+                "supported to be used in `with_columns`. Please reach out if you need it. "
+                f"We currently only support: {SUPPORTED_DATAFRAME_TYPES}."
+            )
+
+    def __init__(
+        self,
+        *load_from: Union[Callable, ModuleType],
+        columns_to_pass: List[str] = None,
+        pass_dataframe_as: str = None,
+        select: List[str] = None,
+        namespace: str = None,
+        config_required: List[str] = None,
+        dataframe_types: Collection[Type] = None,
+    ):
+        """Instantiates a ``@with_columns`` decorator.
+
+        :param load_from: The functions or modules that will be used to generate the group of map operations.
+        :param columns_to_pass: The initial schema of the dataframe. This is used to determine which
+            upstream inputs should be taken from the dataframe, and which shouldn't. Note that, if this is
+            left empty (and external_inputs is as well), we will assume that all dependencies come
+            from the dataframe. This cannot be used in conjunction with pass_dataframe_as.
+        :param pass_dataframe_as: The name of the dataframe that we're modifying, as known to the subdag.
+            If you pass this in, you are responsible for extracting columns out. If not provided, you have
+            to pass columns_to_pass in, and we will extract the columns out for you.
+        :param select: The end nodes that represent columns to be appended to the original dataframe
+            via with_columns. Existing columns will be overridden.
+        :param namespace: The namespace of the nodes, so they don't clash with the global namespace
+            and so this can be reused. If it's left out, there will be no namespace (in which case you'll
+            want to be careful about repeating it/reusing the nodes in other parts of the DAG.)
+        :param config_required: The list of config keys that are required to resolve any functions. Pass in
+            None if you want the functions/modules to have access to all possible config.
+        :param dataframe_types: The dataframe types this decorator operates on; supplied by the
+            library-specific subclass.
+        """
+
+        if dataframe_types is None:
+            raise ValueError("You need to specify which dataframe types it will be applied to.")
+        else:
+            if isinstance(dataframe_types, Type):
+                dataframe_types = [dataframe_types]
+            self.allowed_dataframe_types = dataframe_types
+            self.validate_dataframe_type()
+
+        self.subdag_functions = subdag.collect_functions(load_from)
+        self.select = select
+
+        if (pass_dataframe_as is not None and columns_to_pass is not None) or (
+            pass_dataframe_as is None and columns_to_pass is None
+        ):
+            raise ValueError(
+                "You must specify only one of columns_to_pass and "
+                "pass_dataframe_as. "
+                "This is because specifying pass_dataframe_as injects into "
+                "the set of columns, allowing you to perform your own extraction "
+                "from the dataframe. We then execute all columns in the subdag "
+                "in order, passing in that initial dataframe. If you want "
+                "to reference columns in your code, you'll have to specify "
+                "the set of initial columns, and allow the subdag decorator "
+                "to inject the dataframe through. The initial columns tell "
+                "us which parameters to take from that dataframe, so we can "
+                "feed the right data into the right columns."
+            )
+
+        self.initial_schema = columns_to_pass
+        self.dataframe_subdag_param = pass_dataframe_as
+        self.namespace = namespace
+        self.config_required = config_required
+
+    def required_config(self) -> List[str]:
+        return self.config_required
+
+    def _create_column_nodes(
+        self, inject_parameter: str, params: Dict[str, Type[Type]]
+    ) -> List[node.Node]:
+        output_type = params[inject_parameter]
+
+        if self.is_async:
+
+            async def temp_fn(**kwargs) -> Any:
+                return kwargs[inject_parameter]
+        else:
+
+            def temp_fn(**kwargs) -> Any:
+                return kwargs[inject_parameter]
+
+        # We recreate the df node to use extract columns
+        temp_node = node.Node(
+            name=inject_parameter,
+            typ=output_type,
+            callabl=temp_fn,
+            input_types={inject_parameter: output_type},
+        )
+
+        extract_columns_decorator = extract_columns(*self.initial_schema)
+
+        out_nodes = extract_columns_decorator.transform_node(temp_node, config={}, fn=temp_fn)
+        return out_nodes[1:]
+
+    def _get_initial_nodes(
+        self, fn: Callable, params: Dict[str, Type[Type]]
+    ) -> Tuple[str, Collection[node.Node]]:
+        """Selects the correct dataframe and optionally extracts out columns."""
+        initial_nodes = []
+        sig = inspect.signature(fn)
+        input_types = typing.get_type_hints(fn)
+
+        if self.dataframe_subdag_param is not None:
+            inject_parameter = self.dataframe_subdag_param
+        else:
+            # If we don't have a specified dataframe we assume it's the first argument
+            inject_parameter = list(sig.parameters.values())[0].name
+
+        if inject_parameter not in params:
+            raise base.InvalidDecoratorException(
+                f"Function: {fn.__name__} does not have the parameter {inject_parameter} as a dependency. "
+                f"@with_columns requires the parameter names to match the function parameters. "
+                f"If you do not wish to use the first argument, please use the `pass_dataframe_as` option. "
+                f"It might not be compatible with some other decorators."
+            )
+
+        if input_types[inject_parameter] not in self.allowed_dataframe_types:
+            raise ValueError(
+                f"The dataframe parameter must be one of the following types: "
+                f"{self.allowed_dataframe_types}."
+            )
+        else:
+            self.dataframe_type = input_types[inject_parameter]
+
+        initial_nodes = (
+            []
+            if self.dataframe_subdag_param is not None
+            else self._create_column_nodes(inject_parameter=inject_parameter, params=params)
+        )
+
+        return inject_parameter, initial_nodes
+
+    @abc.abstractmethod
+    def create_merge_node(self, upstream_node: str, node_name: str) -> node.Node:
+        """Node that adds to / overrides columns of the original dataframe based on the selected output.
+
+        This is platform specific; e.g. Polars already has with_columns, whereas for Pandas we need
+        to implement it ourselves.
+        """
+        pass
+
+    def inject_nodes(
+        self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable
+    ) -> Tuple[List[node.Node], Dict[str, str]]:
+        self.is_async = inspect.iscoroutinefunction(fn)
+        namespace = fn.__name__ if self.namespace is None else self.namespace
+
+        inject_parameter, initial_nodes = self._get_initial_nodes(fn=fn, params=params)
+
+        subdag_nodes = subdag.collect_nodes(config, self.subdag_functions)
+
+        # TODO: for now we restrict that if the user wants to change columns that already exist, they
+        # need to pass the dataframe and extract them themselves. If we add a namespace to initial
+        # nodes and rewire the initial node names with the ongoing ones that have a column argument,
+        # we can also allow in-place changes when using columns_to_pass.
+        if with_columns_factory._check_for_duplicates(initial_nodes + subdag_nodes):
+            raise ValueError(
+                "You can only specify columns once. You used `columns_to_pass` and we "
+                "extract the columns for you. In this case they cannot be overwritten -- only new "
+                "columns get appended. If you want to modify columns in place, pass in a dataframe, "
+                "extract + modify the columns, and afterwards select them."
+            )
+
+        pruned_nodes = prune_nodes(subdag_nodes, self.select)
+        if len(pruned_nodes) == 0:
+            raise ValueError(
+                f"No nodes found upstream from select columns: {self.select} for function: "
+                f"{fn.__qualname__}"
+            )
+        # In case no node is selected we append all possible nodes that have a column type matching
+        # what the dataframe expects
+        if self.select is None:
+            self.select = [
+                sink_node.name
+                for sink_node in pruned_nodes
+                if sink_node.type == registry.get_column_type_from_df_type(self.dataframe_type)
+            ]
+
+        merge_node = self.create_merge_node(inject_parameter, node_name="__append")
+
+        output_nodes = initial_nodes + pruned_nodes + [merge_node]
+        output_nodes = subdag.add_namespace(output_nodes, namespace)
+        return output_nodes, {inject_parameter: assign_namespace(merge_node.name, namespace)}
+
+    def validate(self, fn: Callable):
+        self.validate_dataframe_type()
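
For orientation, the refactored decorator is used the same way as before. Below is a sketch based on the docstring above; the function and column names are made up, and the import path is an assumption based on the commit message's reference to h_pandas.with_columns.

# Illustrative usage only; names and the exact import path are assumptions.
import pandas as pd

from hamilton.plugins.h_pandas import with_columns


def a_b_sum(a: pd.Series, b: pd.Series) -> pd.Series:
    """Map operation that becomes a column node in the subdag."""
    return a + b


# Extracts columns `a` and `b` from the incoming dataframe, runs the subdag,
# and appends the selected result before `processed_data` receives it.
@with_columns(a_b_sum, columns_to_pass=["a", "b"], select=["a_b_sum"])
def processed_data(initial_df: pd.DataFrame) -> pd.DataFrame:
    return initial_df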
