else:
    from typing import NotRequired

- from pandas import DataFrame as PandasDataFrame
- from polars import DataFrame as PolarsDataFrame
- from polars import LazyFrame as PolarsLazyFrame
-
# Copied this over from function_graph
# TODO -- determine the best place to put this code
from hamilton import graph_utils, node, registry
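The dropped imports reflect that supported dataframe types are no longer enumerated up front; the decorator now resolves types through ``registry`` (see ``registry.get_column_type_from_df_type`` and ``registry.with_columns`` further down in this diff). A minimal sketch of the kind of mapping the registry maintains -- illustrative only, with pandas as the example; the real tables are populated by the extension modules:

.. code-block:: python

    import pandas as pd

    # Illustrative stand-in: maps a dataframe type to its column type,
    # e.g. pd.DataFrame -> pd.Series. Plugins register their own entries.
    _DF_TYPE_TO_COLUMN_TYPE = {pd.DataFrame: pd.Series}

    def get_column_type_from_df_type(df_type: type) -> type:
        return _DF_TYPE_TO_COLUMN_TYPE[df_type]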
@@ -635,24 +631,96 @@ def prune_nodes(nodes: List[node.Node], select: Optional[List[str]] = None) -> L
    return output


- SUPPORTED_DATAFAME_TYPES = [PandasDataFrame, PolarsDataFrame, PolarsLazyFrame]
-
-
- class with_columns_factory(base.NodeInjector, abc.ABC):
-     """Performs with_columns operation on a dataframe. This is a special case of NodeInjector
-     that applies only to dataframes. For now can be used with:
+ class with_columns(base.NodeInjector, abc.ABC):
+     """Performs with_columns operation on a dataframe. This is used when you want to extract some
+     columns out of the dataframe, perform operations on them, and then append them to the original
+     dataframe. For now it can be used with:

    - Pandas
    - Polars

-     This is used when you want to extract some columns out of the dataframe, perform operations
-     on them and then append to the original dataframe.

-         def processed_data(data: pd.DataFrame) -> pd.DataFrame:
+
+     Here's an example of calling it on a pandas dataframe -- if you've seen ``@subdag``, you
+     should be familiar with the concepts:
+
+     .. code-block:: python
+
+         # my_module.py
+         def a(a_from_df: pd.Series) -> pd.Series:
+             return _process(a)
+
+         def b(b_from_df: pd.Series) -> pd.Series:
+             return _process(b)
+
+         def a_b_average(a_from_df: pd.Series, b_from_df: pd.Series) -> pd.Series:
+             return (a_from_df + b_from_df) / 2
+
+
+     .. code-block:: python
+
+         # with_columns_module.py
+         def a_plus_b(a: pd.Series, b: pd.Series) -> pd.Series:
+             return a + b
+
+
+         # the with_columns call
+         @with_columns(
+             *[my_module],  # Load from any module
+             *[a_plus_b],  # or list the functions directly
+             columns_to_pass=["a_from_df", "b_from_df"],  # The columns to pass from the
+                                                          # dataframe into the subdag
+             select=["a", "b", "a_plus_b", "a_b_average"],  # The subdag outputs to append
+         )
+         def final_df(initial_df: pd.DataFrame) -> pd.DataFrame:
+             # process, or just return unprocessed
             ...

-     In this case we would build a subdag out of the node ``data`` and append selected nodes back to
-     the original dataframe before feeding it into ``processed_data``.
+     In this instance ``initial_df`` would get the four selected columns appended to it: ``a``,
+     ``b``, ``a_plus_b``, and ``a_b_average``.
+
+     The operations are applied in topological order. This allows you to express the operations
+     individually, making them easy to unit-test and reuse.
+
+     Note that the operation is an "append", meaning that the selected columns are appended onto
+     the dataframe.
+
+     If the function takes multiple dataframes, the dataframe to process will always be the first
+     argument. It will be passed to the subdag, transformed, and passed back to the function. This
+     follows the Hamilton rule of reference by parameter name. To demonstrate this, in the code
+     above the dataframe that is passed to the subdag is ``initial_df``; it is transformed by the
+     subdag and then returned as the final dataframe.
+
+     You can read it as:
+
+     "final_df is a function that transforms the upstream dataframe initial_df, running the
+     transformations from my_module. It starts with the columns a_from_df and b_from_df, appends
+     the columns a, b, a_plus_b, and a_b_average to the dataframe, and then returns the dataframe
+     after doing some processing on it."
+
+     In case you need more flexibility, you can alternatively use ``pass_dataframe_as``, for
+     example:
+
+     .. code-block:: python
+
+         # with_columns_module.py
+         def a_from_df(initial_df: pd.DataFrame) -> pd.Series:
+             return initial_df["a_from_df"] / 100
+
+         def b_from_df(initial_df: pd.DataFrame) -> pd.Series:
+             return initial_df["b_from_df"] / 100
+
+
+         # the with_columns call
+         @with_columns(
+             *[my_module],
+             *[a_from_df, b_from_df],
+             pass_dataframe_as="initial_df",
+             select=["a_from_df", "b_from_df", "a", "b", "a_plus_b", "a_b_average"],
+         )
+         def final_df(initial_df: pd.DataFrame) -> pd.DataFrame:
+             # process, or just return unprocessed
+             ...
+
+     The above would output a dataframe where the two columns ``a_from_df`` and ``b_from_df`` get
+     overwritten.
    """
# TODO: if we rename the column nodes into something smarter this can be avoided and
@@ -674,14 +742,6 @@ def _check_for_duplicates(nodes_: List[node.Node]) -> bool:
            return True
        return False

-     def validate_dataframe_type(self):
-         if not set(self.allowed_dataframe_types).issubset(list(SUPPORTED_DATAFAME_TYPES)):
-             raise InvalidDecoratorException(
-                 f"The provided dataframe types: {self.allowed_dataframe_types} are currently not supported "
-                 "to be used in `with_columns`. Please reach out if you need it. "
-                 f"We currently only support: {SUPPORTED_DATAFAME_TYPES}."
-             )
-
    def __init__(
        self,
        *load_from: Union[Callable, ModuleType],
@@ -690,7 +750,6 @@ def __init__(
        select: List[str] = None,
        namespace: str = None,
        config_required: List[str] = None,
-         dataframe_types: Collection[Type] = None,
    ):
        """Instantiates a ``@with_columns`` decorator.
@@ -711,14 +770,6 @@ def __init__(
            if you want the functions/modules to have access to all possible config.
        """

-         if dataframe_types is None:
-             raise ValueError("You need to specify which dataframe types it will be applied to.")
-         else:
-             if isinstance(dataframe_types, Type):
-                 dataframe_types = [dataframe_types]
-             self.allowed_dataframe_types = dataframe_types
-             self.validate_dataframe_type()
-
        self.subdag_functions = subdag.collect_functions(load_from)
        self.select = select
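Since ``load_from`` is a ``*args`` parameter accepting modules and functions, splatting a list (as the docstring examples do) and passing the arguments directly are equivalent. A small sketch, reusing the hypothetical ``my_module``:

.. code-block:: python

    # These two forms load the same subdag functions.
    @with_columns(my_module, columns_to_pass=["a_from_df"], select=["a"])
    def final_df(initial_df: pd.DataFrame) -> pd.DataFrame: ...

    @with_columns(*[my_module], columns_to_pass=["a_from_df"], select=["a"])
    def final_df(initial_df: pd.DataFrame) -> pd.DataFrame: ...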
@@ -796,44 +847,67 @@ def _get_inital_nodes(
                f"It might not be compatible with some other decorators."
            )

-         if input_types[inject_parameter] not in self.allowed_dataframe_types:
-             raise ValueError(f"Dataframe has to be a {self.allowed_dataframe_types} DataFrame.")
-         else:
-             self.dataframe_type = input_types[inject_parameter]
-
+         dataframe_type = input_types[inject_parameter]
        initial_nodes = (
            []
            if self.dataframe_subdag_param is not None
            else self._create_column_nodes(inject_parameter=inject_parameter, params=params)
        )

-         return inject_parameter, initial_nodes
+         return inject_parameter, initial_nodes, dataframe_type
+
+     def create_merge_node(
+         self, upstream_node: str, node_name: str, dataframe_type: Type
+     ) -> node.Node:
+         "Node that adds to / overrides columns for the original dataframe based on selected output."
+         if self.is_async:
+
+             async def new_callable(**kwargs) -> Any:
+                 df = kwargs[upstream_node]
+                 columns_to_append = {}
+                 for column in self.select:
+                     columns_to_append[column] = kwargs[column]
+                 new_df = registry.with_columns(df, columns_to_append)
+                 return new_df
+
+         else:
+
+             def new_callable(**kwargs) -> Any:
+                 df = kwargs[upstream_node]
+                 columns_to_append = {}
+                 for column in self.select:
+                     columns_to_append[column] = kwargs[column]
+                 new_df = registry.with_columns(df, columns_to_append)
+                 return new_df
+
+         column_type = registry.get_column_type_from_df_type(dataframe_type)
+         input_map = {column: column_type for column in self.select}
+         input_map[upstream_node] = dataframe_type
+
+         return node.Node(
+             name=node_name,
+             typ=dataframe_type,
+             callabl=new_callable,
+             input_types=input_map,
+         )

-     @abc.abstractmethod
-     def create_merge_node(self, upstream_node: str, node_name: str) -> node.Node:
-         """Should create a node that merges the results back into the original dataframe.
-
-         Node that adds to / overrides columns for the original dataframe based on selected output.
-
-         This will be platform specific, see Pandas and Polars plugins for implementation.
-         """
-         pass
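The merge node closes over ``self.select`` and delegates the actual column assignment to ``registry.with_columns``, which dispatches on the dataframe type. For pandas, the behavior amounts to the following sketch (``_merge`` is a hypothetical stand-in, not the registry's implementation):

.. code-block:: python

    import pandas as pd

    # Hypothetical equivalent of registry.with_columns for pandas:
    # append or overwrite the given columns, returning a new frame.
    def _merge(df: pd.DataFrame, columns_to_append: dict) -> pd.DataFrame:
        return df.assign(**columns_to_append)

    df = pd.DataFrame({"a_from_df": [1, 2], "b_from_df": [3, 4]})
    out = _merge(df, {"a_plus_b": df["a_from_df"] + df["b_from_df"]})
    print(out.columns.tolist())  # ['a_from_df', 'b_from_df', 'a_plus_b']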

    def inject_nodes(
        self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable
    ) -> Tuple[List[node.Node], Dict[str, str]]:
        self.is_async = inspect.iscoroutinefunction(fn)
        namespace = fn.__name__ if self.namespace is None else self.namespace

-         inject_parameter, initial_nodes = self._get_inital_nodes(fn=fn, params=params)
+         inject_parameter, initial_nodes, dataframe_type = self._get_inital_nodes(
+             fn=fn, params=params
+         )

        subdag_nodes = subdag.collect_nodes(config, self.subdag_functions)

        # TODO: for now we restrict that if the user wants to change columns that already exist, they
        # need to pass the dataframe and extract them themselves. If we add namespace to initial nodes
        # and rewire the initial node names with the ongoing ones that have a column argument, we can
        # also allow in-place changes when using columns_to_pass.
-         if with_columns_factory._check_for_duplicates(initial_nodes + subdag_nodes):
+         if with_columns._check_for_duplicates(initial_nodes + subdag_nodes):
            raise ValueError(
                "You can only specify columns once. You used `columns_to_pass` and we "
                "extract the columns for you. In this case they cannot be overwritten -- only new columns get "
@@ -853,14 +927,16 @@ def inject_nodes(
        self.select = [
            sink_node.name
            for sink_node in pruned_nodes
-             if sink_node.type == registry.get_column_type_from_df_type(self.dataframe_type)
+             if sink_node.type == registry.get_column_type_from_df_type(dataframe_type)
        ]

-         merge_node = self.create_merge_node(inject_parameter, node_name="__append")
+         merge_node = self.create_merge_node(
+             inject_parameter, node_name="__append", dataframe_type=dataframe_type
+         )

        output_nodes = initial_nodes + pruned_nodes + [merge_node]
        output_nodes = subdag.add_namespace(output_nodes, namespace)
        return output_nodes, {inject_parameter: assign_namespace(merge_node.name, namespace)}
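Because ``inject_nodes`` records ``inspect.iscoroutinefunction(fn)`` before building the merge node, an async target function gets an async merge node with no extra configuration. A hedged sketch of what that looks like at the call site, reusing the docstring's hypothetical names:

.. code-block:: python

    import pandas as pd

    # Hypothetical async variant: the coroutine is detected and the
    # async new_callable branch of create_merge_node is used.
    @with_columns(
        *[my_module],
        columns_to_pass=["a_from_df", "b_from_df"],
        select=["a_plus_b"],
    )
    async def final_df(initial_df: pd.DataFrame) -> pd.DataFrame:
        return initial_df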

    def validate(self, fn: Callable):
-         self.validate_dataframe_type()
+         pass