24
24
ParametrizedDependency ,
25
25
UpstreamDependency ,
26
26
)
27
- from hamilton .function_modifiers .expanders import extract_columns
28
27
29
28
30
29
def assign_namespace (node_name : str , namespace : str ) -> str :
@@ -631,15 +630,9 @@ def prune_nodes(nodes: List[node.Node], select: Optional[List[str]] = None) -> L
631
630
return output
632
631
633
632
634
- class with_columns (base .NodeInjector , abc .ABC ):
635
- """Performs with_columns operation on a dataframe. This is used when you want to extract some
633
+ class with_columns_factory (base .NodeInjector , abc .ABC ):
634
+ """Factory for with_columns operation on a dataframe. This is used when you want to extract some
636
635
columns out of the dataframe, perform operations on them and then append to the original dataframe.
637
- For now can be used with:
638
-
639
- - Pandas
640
- - Polars
641
-
642
-
643
636
644
637
Here's an example of calling it on a pandas dataframe -- if you've seen ``@subdag``, you should be familiar with
645
638
the concepts:
@@ -742,6 +735,25 @@ def _check_for_duplicates(nodes_: List[node.Node]) -> bool:
742
735
return True
743
736
return False
744
737
738
+ @staticmethod
739
+ def validate_dataframe (
740
+ fn : Callable , inject_parameter : str , params : Dict [str , Type [Type ]], required_type : Type
741
+ ) -> None :
742
+ input_types = typing .get_type_hints (fn )
743
+ if inject_parameter not in params :
744
+ raise InvalidDecoratorException (
745
+ f"Function: { fn .__name__ } does not have the parameter { inject_parameter } as a dependency. "
746
+ f"@with_columns requires the parameter names to match the function parameters. "
747
+ f"If you wish do not wish to use the first argument, please use `pass_dataframe_as` option. "
748
+ f"It might not be compatible with some other decorators."
749
+ )
750
+
751
+ if input_types [inject_parameter ] != required_type :
752
+ raise InvalidDecoratorException (
753
+ "The selected dataframe parameter is not the correct dataframe type. "
754
+ f"You selected a parameter of type { input_types [inject_parameter ]} , but we expect to get { required_type } "
755
+ )
756
+
745
757
def __init__ (
746
758
self ,
747
759
* load_from : Union [Callable , ModuleType ],
@@ -750,6 +762,7 @@ def __init__(
750
762
select : List [str ] = None ,
751
763
namespace : str = None ,
752
764
config_required : List [str ] = None ,
765
+ dataframe_type : Type = None ,
753
766
):
754
767
"""Instantiates a ``@with_columns`` decorator.
755
768
@@ -795,119 +808,64 @@ def __init__(
795
808
self .namespace = namespace
796
809
self .config_required = config_required
797
810
798
- def required_config (self ) -> List [str ]:
799
- return self .config_required
800
-
801
- def _create_column_nodes (
802
- self , inject_parameter : str , params : Dict [str , Type [Type ]]
803
- ) -> List [node .Node ]:
804
- output_type = params [inject_parameter ]
805
-
806
- if self .is_async :
807
-
808
- async def temp_fn (** kwargs ) -> Any :
809
- return kwargs [inject_parameter ]
810
- else :
811
-
812
- def temp_fn (** kwargs ) -> Any :
813
- return kwargs [inject_parameter ]
814
-
815
- # We recreate the df node to use extract columns
816
- temp_node = node .Node (
817
- name = inject_parameter ,
818
- typ = output_type ,
819
- callabl = temp_fn ,
820
- input_types = {inject_parameter : output_type },
821
- )
811
+ if dataframe_type is None :
812
+ raise InvalidDecoratorException (
813
+ "Please provide the dataframe type for this specific library."
814
+ )
822
815
823
- extract_columns_decorator = extract_columns ( * self . initial_schema )
816
+ self . dataframe_type = dataframe_type
824
817
825
- out_nodes = extract_columns_decorator . transform_node ( temp_node , config = {}, fn = temp_fn )
826
- return out_nodes [ 1 :]
818
+ def required_config ( self ) -> List [ str ]:
819
+ return self . config_required
827
820
828
- def _get_inital_nodes (
821
+ @abc .abstractmethod
822
+ def get_initial_nodes (
829
823
self , fn : Callable , params : Dict [str , Type [Type ]]
830
824
) -> Tuple [str , Collection [node .Node ]]:
831
- """Selects the correct dataframe and optionally extracts out columns."""
832
- initial_nodes = []
833
- sig = inspect .signature (fn )
834
- input_types = typing .get_type_hints (fn )
825
+ """Preparation stage where columns get extracted into nodes. In case `pass_dataframe_as` is
826
+ used, this should return an empty list (no column nodes) since the users will extract it
827
+ themselves.
828
+
829
+ :param fn: the function we are decorating. By using the inspect library you can get information
830
+ about what arguments it has / find out the dataframe argument.
831
+ :param params: Dictionary of all the type names one wants to inject.
832
+ :return: name of the dataframe parameter and list of nodes representing the extracted columns (can be empty).
833
+ """
834
+ pass
835
835
836
- if self .dataframe_subdag_param is not None :
837
- inject_parameter = self .dataframe_subdag_param
838
- else :
839
- # If we don't have a specified dataframe we assume it's the first argument
840
- inject_parameter = list (sig .parameters .values ())[0 ].name
836
+ @abc .abstractmethod
837
+ def get_subdag_nodes (self , config : Dict [str , Any ]) -> Collection [node .Node ]:
838
+ """Creates subdag from the passed in module / functions.
841
839
842
- if inject_parameter not in params :
843
- raise base .InvalidDecoratorException (
844
- f"Function: { fn .__name__ } does not have the parameter { inject_parameter } as a dependency. "
845
- f"@with_columns requires the parameter names to match the function parameters. "
846
- f"If you wish do not wish to use the first argument, please use `pass_dataframe_as` option. "
847
- f"It might not be compatible with some other decorators."
848
- )
840
+ :param config: Configuration with which the DAG was constructed.
841
+ :return: the subdag as a list of nodes.
842
+ """
843
+ pass
849
844
850
- dataframe_type = input_types [inject_parameter ]
851
- initial_nodes = (
852
- []
853
- if self .dataframe_subdag_param is not None
854
- else self ._create_column_nodes (inject_parameter = inject_parameter , params = params )
855
- )
845
+ @abc .abstractmethod
846
+ def create_merge_node (self , fn : Callable , inject_parameter : str ) -> node .Node :
847
+ """Combines the original dataframe with selected columns. This should produce a
848
+ dataframe output that is injected into the decorated function with new columns
849
+ appended and existing columns overridden.
856
850
857
- return inject_parameter , initial_nodes , dataframe_type
858
-
859
- def create_merge_node (
860
- self , upstream_node : str , node_name : str , dataframe_type : Type
861
- ) -> node .Node :
862
- "Node that adds to / overrides columns for the original dataframe based on selected output."
863
- if self .is_async :
864
-
865
- async def new_callable (** kwargs ) -> Any :
866
- df = kwargs [upstream_node ]
867
- columns_to_append = {}
868
- for column in self .select :
869
- columns_to_append [column ] = kwargs [column ]
870
- new_df = registry .with_columns (df , columns_to_append )
871
- return new_df
872
- else :
873
-
874
- def new_callable (** kwargs ) -> Any :
875
- df = kwargs [upstream_node ]
876
- columns_to_append = {}
877
- for column in self .select :
878
- columns_to_append [column ] = kwargs [column ]
879
-
880
- new_df = registry .with_columns (df , columns_to_append )
881
- return new_df
882
-
883
- column_type = registry .get_column_type_from_df_type (dataframe_type )
884
- input_map = {column : column_type for column in self .select }
885
- input_map [upstream_node ] = dataframe_type
886
-
887
- return node .Node (
888
- name = node_name ,
889
- typ = dataframe_type ,
890
- callabl = new_callable ,
891
- input_types = input_map ,
892
- )
851
+ :param inject_parameter: the name of the parameter holding the original dataframe.
852
+ :return: the new dataframe with the columns appended / overwritten.
853
+ """
854
+ pass
893
855
894
856
def inject_nodes (
895
857
self , params : Dict [str , Type [Type ]], config : Dict [str , Any ], fn : Callable
896
858
) -> Tuple [List [node .Node ], Dict [str , str ]]:
897
- self .is_async = inspect .iscoroutinefunction (fn )
898
859
namespace = fn .__name__ if self .namespace is None else self .namespace
899
860
900
- inject_parameter , initial_nodes , dataframe_type = self ._get_inital_nodes (
901
- fn = fn , params = params
902
- )
903
-
904
- subdag_nodes = subdag .collect_nodes (config , self .subdag_functions )
861
+ inject_parameter , initial_nodes = self .get_initial_nodes (fn = fn , params = params )
862
+ subdag_nodes = self .get_subdag_nodes (config = config )
905
863
906
864
# TODO: for now we restrict that if user wants to change columns that already exist, he needs to
907
865
# pass the dataframe and extract them himself. If we add namespace to initial nodes and rewire the
908
866
# initial node names with the ongoing ones that have a column argument, we can also allow in place
909
867
# changes when using columns_to_pass
910
- if with_columns ._check_for_duplicates (initial_nodes + subdag_nodes ):
868
+ if with_columns_factory ._check_for_duplicates (initial_nodes + subdag_nodes ):
911
869
raise ValueError (
912
870
"You can only specify columns once. You used `columns_to_pass` and we "
913
871
"extract the columns for you. In this case they cannot be overwritten -- only new columns get "
@@ -927,16 +885,11 @@ def inject_nodes(
927
885
self .select = [
928
886
sink_node .name
929
887
for sink_node in pruned_nodes
930
- if sink_node .type == registry .get_column_type_from_df_type (dataframe_type )
888
+ if sink_node .type == registry .get_column_type_from_df_type (self . dataframe_type )
931
889
]
932
890
933
- merge_node = self .create_merge_node (
934
- inject_parameter , node_name = "__append" , dataframe_type = dataframe_type
935
- )
891
+ merge_node = self .create_merge_node (fn = fn , inject_parameter = inject_parameter )
936
892
937
893
output_nodes = initial_nodes + pruned_nodes + [merge_node ]
938
894
output_nodes = subdag .add_namespace (output_nodes , namespace )
939
895
return output_nodes , {inject_parameter : assign_namespace (merge_node .name , namespace )}
940
-
941
- def validate (self , fn : Callable ):
942
- pass
0 commit comments