Skip to content

Commit 0515c7e

Browse files
committed
Keep with_columns in abstract factory pattern
Removed the registry dependency and single dispatch for now.
1 parent 343b5af commit 0515c7e

23 files changed

+1557
-1345
lines changed

docs/reference/decorators/with_columns.rst

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,32 @@
22
with_columns
33
=======================
44

5-
Pandas and Polars
5+
We support the `with_columns` operation that appends the results as new columns to the original dataframe for several libraries:
6+
7+
Pandas
8+
-----------------------
9+
10+
**Reference Documentation**
11+
12+
.. autoclass:: hamilton.plugins.h_pandas.with_columns
13+
:special-members: __init__
14+
15+
16+
Polars (Eager)
617
-----------------------
718

8-
We have a ``with_columns`` option to run operations on columns of a Pandas / Polars dataframe and append the results as new columns.
19+
**Reference Documentation**
20+
21+
.. autoclass:: hamilton.plugins.h_polars.with_columns
22+
:special-members: __init__
23+
24+
25+
Polars (Lazy)
26+
-----------------------
927

1028
**Reference Documentation**
1129

12-
.. autoclass:: hamilton.function_modifiers.with_columns
30+
.. autoclass:: hamilton.plugins.h_polars_lazyframe.with_columns
1331
:special-members: __init__
1432

1533

examples/pandas/with_columns/notebook.ipynb

Lines changed: 309 additions & 311 deletions
Large diffs are not rendered by default.

examples/polars/with_columns/notebook.ipynb

Lines changed: 598 additions & 587 deletions
Large diffs are not rendered by default.

hamilton/function_modifiers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@
8888

8989
subdag = recursive.subdag
9090
parameterized_subdag = recursive.parameterized_subdag
91-
with_columns = recursive.with_columns
9291

9392
# resolve/meta stuff -- power user features
9493

hamilton/function_modifiers/recursive.py

Lines changed: 62 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
ParametrizedDependency,
2525
UpstreamDependency,
2626
)
27-
from hamilton.function_modifiers.expanders import extract_columns
2827

2928

3029
def assign_namespace(node_name: str, namespace: str) -> str:
@@ -631,15 +630,9 @@ def prune_nodes(nodes: List[node.Node], select: Optional[List[str]] = None) -> L
631630
return output
632631

633632

634-
class with_columns(base.NodeInjector, abc.ABC):
635-
"""Performs with_columns operation on a dataframe. This is used when you want to extract some
633+
class with_columns_factory(base.NodeInjector, abc.ABC):
634+
"""Factory for with_columns operation on a dataframe. This is used when you want to extract some
636635
columns out of the dataframe, perform operations on them and then append to the original dataframe.
637-
For now can be used with:
638-
639-
- Pandas
640-
- Polars
641-
642-
643636
644637
Here's an example of calling it on a pandas dataframe -- if you've seen ``@subdag``, you should be familiar with
645638
the concepts:
@@ -742,6 +735,25 @@ def _check_for_duplicates(nodes_: List[node.Node]) -> bool:
742735
return True
743736
return False
744737

738+
@staticmethod
739+
def validate_dataframe(
740+
fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]], required_type: Type
741+
) -> None:
742+
input_types = typing.get_type_hints(fn)
743+
if inject_parameter not in params:
744+
raise InvalidDecoratorException(
745+
f"Function: {fn.__name__} does not have the parameter {inject_parameter} as a dependency. "
746+
f"@with_columns requires the parameter names to match the function parameters. "
747+
f"If you wish do not wish to use the first argument, please use `pass_dataframe_as` option. "
748+
f"It might not be compatible with some other decorators."
749+
)
750+
751+
if input_types[inject_parameter] != required_type:
752+
raise InvalidDecoratorException(
753+
"The selected dataframe parameter is not the correct dataframe type. "
754+
f"You selected a parameter of type {input_types[inject_parameter]}, but we expect to get {required_type}"
755+
)
756+
745757
def __init__(
746758
self,
747759
*load_from: Union[Callable, ModuleType],
@@ -750,6 +762,7 @@ def __init__(
750762
select: List[str] = None,
751763
namespace: str = None,
752764
config_required: List[str] = None,
765+
dataframe_type: Type = None,
753766
):
754767
"""Instantiates a ``@with_columns`` decorator.
755768
@@ -795,119 +808,64 @@ def __init__(
795808
self.namespace = namespace
796809
self.config_required = config_required
797810

798-
def required_config(self) -> List[str]:
799-
return self.config_required
800-
801-
def _create_column_nodes(
802-
self, inject_parameter: str, params: Dict[str, Type[Type]]
803-
) -> List[node.Node]:
804-
output_type = params[inject_parameter]
805-
806-
if self.is_async:
807-
808-
async def temp_fn(**kwargs) -> Any:
809-
return kwargs[inject_parameter]
810-
else:
811-
812-
def temp_fn(**kwargs) -> Any:
813-
return kwargs[inject_parameter]
814-
815-
# We recreate the df node to use extract columns
816-
temp_node = node.Node(
817-
name=inject_parameter,
818-
typ=output_type,
819-
callabl=temp_fn,
820-
input_types={inject_parameter: output_type},
821-
)
811+
if dataframe_type is None:
812+
raise InvalidDecoratorException(
813+
"Please provide the dataframe type for this specific library."
814+
)
822815

823-
extract_columns_decorator = extract_columns(*self.initial_schema)
816+
self.dataframe_type = dataframe_type
824817

825-
out_nodes = extract_columns_decorator.transform_node(temp_node, config={}, fn=temp_fn)
826-
return out_nodes[1:]
818+
def required_config(self) -> List[str]:
819+
return self.config_required
827820

828-
def _get_inital_nodes(
821+
@abc.abstractmethod
822+
def get_initial_nodes(
829823
self, fn: Callable, params: Dict[str, Type[Type]]
830824
) -> Tuple[str, Collection[node.Node]]:
831-
"""Selects the correct dataframe and optionally extracts out columns."""
832-
initial_nodes = []
833-
sig = inspect.signature(fn)
834-
input_types = typing.get_type_hints(fn)
825+
"""Preparation stage where columns get extracted into nodes. In case `pass_dataframe_as` is
826+
used, this should return an empty list (no column nodes) since the users will extract it
827+
themselves.
828+
829+
:param fn: the function we are decorating. By using the inspect library you can get information
830+
about what arguments it has / find out the dataframe argument.
831+
:param params: Dictionary of all the type names one wants to inject.
832+
:return: name of the dataframe parameter and list of nodes representing the extracted columns (can be empty).
833+
"""
834+
pass
835835

836-
if self.dataframe_subdag_param is not None:
837-
inject_parameter = self.dataframe_subdag_param
838-
else:
839-
# If we don't have a specified dataframe we assume it's the first argument
840-
inject_parameter = list(sig.parameters.values())[0].name
836+
@abc.abstractmethod
837+
def get_subdag_nodes(self, config: Dict[str, Any]) -> Collection[node.Node]:
838+
"""Creates subdag from the passed in module / functions.
841839
842-
if inject_parameter not in params:
843-
raise base.InvalidDecoratorException(
844-
f"Function: {fn.__name__} does not have the parameter {inject_parameter} as a dependency. "
845-
f"@with_columns requires the parameter names to match the function parameters. "
846-
f"If you wish do not wish to use the first argument, please use `pass_dataframe_as` option. "
847-
f"It might not be compatible with some other decorators."
848-
)
840+
:param config: Configuration with which the DAG was constructed.
841+
:return: the subdag as a list of nodes.
842+
"""
843+
pass
849844

850-
dataframe_type = input_types[inject_parameter]
851-
initial_nodes = (
852-
[]
853-
if self.dataframe_subdag_param is not None
854-
else self._create_column_nodes(inject_parameter=inject_parameter, params=params)
855-
)
845+
@abc.abstractmethod
846+
def create_merge_node(self, fn: Callable, inject_parameter: str) -> node.Node:
847+
"""Combines the original dataframe with selected columns. This should produce a
848+
dataframe output that is injected into the decorated function with new columns
849+
appended and existing columns overridden.
856850
857-
return inject_parameter, initial_nodes, dataframe_type
858-
859-
def create_merge_node(
860-
self, upstream_node: str, node_name: str, dataframe_type: Type
861-
) -> node.Node:
862-
"Node that adds to / overrides columns for the original dataframe based on selected output."
863-
if self.is_async:
864-
865-
async def new_callable(**kwargs) -> Any:
866-
df = kwargs[upstream_node]
867-
columns_to_append = {}
868-
for column in self.select:
869-
columns_to_append[column] = kwargs[column]
870-
new_df = registry.with_columns(df, columns_to_append)
871-
return new_df
872-
else:
873-
874-
def new_callable(**kwargs) -> Any:
875-
df = kwargs[upstream_node]
876-
columns_to_append = {}
877-
for column in self.select:
878-
columns_to_append[column] = kwargs[column]
879-
880-
new_df = registry.with_columns(df, columns_to_append)
881-
return new_df
882-
883-
column_type = registry.get_column_type_from_df_type(dataframe_type)
884-
input_map = {column: column_type for column in self.select}
885-
input_map[upstream_node] = dataframe_type
886-
887-
return node.Node(
888-
name=node_name,
889-
typ=dataframe_type,
890-
callabl=new_callable,
891-
input_types=input_map,
892-
)
851+
:param inject_parameter: the name of the original dataframe parameter.
852+
:return: the new dataframe with the columns appended / overwritten.
853+
"""
854+
pass
893855

894856
def inject_nodes(
895857
self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable
896858
) -> Tuple[List[node.Node], Dict[str, str]]:
897-
self.is_async = inspect.iscoroutinefunction(fn)
898859
namespace = fn.__name__ if self.namespace is None else self.namespace
899860

900-
inject_parameter, initial_nodes, dataframe_type = self._get_inital_nodes(
901-
fn=fn, params=params
902-
)
903-
904-
subdag_nodes = subdag.collect_nodes(config, self.subdag_functions)
861+
inject_parameter, initial_nodes = self.get_initial_nodes(fn=fn, params=params)
862+
subdag_nodes = self.get_subdag_nodes(config=config)
905863

906864
# TODO: for now we restrict that if user wants to change columns that already exist, he needs to
907865
# pass the dataframe and extract them himself. If we add namespace to initial nodes and rewire the
908866
# initial node names with the ongoing ones that have a column argument, we can also allow in place
909867
# changes when using columns_to_pass
910-
if with_columns._check_for_duplicates(initial_nodes + subdag_nodes):
868+
if with_columns_factory._check_for_duplicates(initial_nodes + subdag_nodes):
911869
raise ValueError(
912870
"You can only specify columns once. You used `columns_to_pass` and we "
913871
"extract the columns for you. In this case they cannot be overwritten -- only new columns get "
@@ -927,16 +885,11 @@ def inject_nodes(
927885
self.select = [
928886
sink_node.name
929887
for sink_node in pruned_nodes
930-
if sink_node.type == registry.get_column_type_from_df_type(dataframe_type)
888+
if sink_node.type == registry.get_column_type_from_df_type(self.dataframe_type)
931889
]
932890

933-
merge_node = self.create_merge_node(
934-
inject_parameter, node_name="__append", dataframe_type=dataframe_type
935-
)
891+
merge_node = self.create_merge_node(fn=fn, inject_parameter=inject_parameter)
936892

937893
output_nodes = initial_nodes + pruned_nodes + [merge_node]
938894
output_nodes = subdag.add_namespace(output_nodes, namespace)
939895
return output_nodes, {inject_parameter: assign_namespace(merge_node.name, namespace)}
940-
941-
def validate(self, fn: Callable):
942-
pass

hamilton/plugins/dask_extensions.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,6 @@ def fill_with_scalar_dask(df: dd.DataFrame, column_name: str, value: Any) -> dd.
2222
return df
2323

2424

25-
@registry.with_columns.register(dd.DataFrame)
26-
def with_columns_dask(df: dd.DataFrame, columns: dd.Series) -> dd.DataFrame:
27-
raise NotImplementedError(
28-
"As of Hamilton version 1.83.1, with_columns for Dask isn't supported."
29-
)
30-
31-
3225
def register_types():
3326
"""Function to register the types for this extension."""
3427
registry.register_types("dask", DATAFRAME_TYPE, COLUMN_TYPE)

hamilton/plugins/geopandas_extensions.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,6 @@ def fill_with_scalar_geopandas(
2424
return df
2525

2626

27-
@registry.with_columns.register(gpd.GeoDataFrame)
28-
def with_columns_geopandas(df: gpd.GeoDataFrame, columns: gpd.GeoSeries) -> gpd.GeoDataFrame:
29-
raise NotImplementedError(
30-
"As of Hamilton version 1.83.1, with_columns for geopandas isn't supported."
31-
)
32-
33-
3427
def register_types():
3528
"""Function to register the types for this extension."""
3629
registry.register_types("geopandas", DATAFRAME_TYPE, COLUMN_TYPE)

0 commit comments

Comments
 (0)