|
| 1 | +import abc |
1 | 2 | import inspect
|
2 | 3 | import sys
|
| 4 | +import typing |
| 5 | +from collections import defaultdict |
3 | 6 | from types import ModuleType
|
4 | 7 | from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, TypedDict, Union
|
5 | 8 |
|
|
11 | 14 | else:
|
12 | 15 | from typing import NotRequired
|
13 | 16 |
|
| 17 | +from pandas import DataFrame as PandasDataFrame |
| 18 | +from polars import DataFrame as PolarsDataFrame |
| 19 | +from polars import LazyFrame as PolarsLazyFrame |
14 | 20 |
|
15 | 21 | # Copied this over from function_graph
|
16 | 22 | # TODO -- determine the best place to put this code
|
17 |
| -from hamilton import graph_utils, node |
| 23 | +from hamilton import graph_utils, node, registry |
18 | 24 | from hamilton.function_modifiers import base, dependencies
|
19 | 25 | from hamilton.function_modifiers.base import InvalidDecoratorException, NodeTransformer
|
20 | 26 | from hamilton.function_modifiers.dependencies import (
|
21 | 27 | LiteralDependency,
|
22 | 28 | ParametrizedDependency,
|
23 | 29 | UpstreamDependency,
|
24 | 30 | )
|
| 31 | +from hamilton.function_modifiers.expanders import extract_columns |
25 | 32 |
|
26 | 33 |
|
27 | 34 | def assign_namespace(node_name: str, namespace: str) -> str:
|
@@ -626,3 +633,233 @@ def prune_nodes(nodes: List[node.Node], select: Optional[List[str]] = None) -> L
|
626 | 633 | stack.append(dep_node)
|
627 | 634 | seen_nodes.add(dep)
|
628 | 635 | return output
|


# Dataframe types that `with_columns` subclasses are allowed to operate on.
# NOTE(review): the name has a typo ("DATAFAME" -> should be "DATAFRAME"), but it is
# kept as-is since other modules may already reference it by this exact name.
SUPPORTED_DATAFAME_TYPES = [PandasDataFrame, PolarsDataFrame, PolarsLazyFrame]

| 640 | + |
class with_columns_factory(base.NodeInjector, abc.ABC):
    """Performs with_columns operation on a dataframe. This is a special case of NodeInjector
    that applies only to dataframes. For now can be used with:

        - Pandas
        - Polars

    This is used when you want to extract some columns out of the dataframe, perform operations
    on them and then append to the original dataframe. For example:

    .. code-block:: python

        def processed_data(data: pd.DataFrame) -> pd.DataFrame:
            ...

    In this case we would build a subdag out of the node ``data`` and append selected nodes back to
    the original dataframe before feeding it into ``processed_data``.
    """

    # TODO: if we rename the column nodes into something smarter this can be avoided and
    # can also modify columns in place
    @staticmethod
    def _check_for_duplicates(nodes_: List[node.Node]) -> bool:
        """Ensures that we don't run into name clashing of columns and group operations.

        In the case when we extract columns for the user, because ``columns_to_pass`` was used, we want
        to safeguard against nameclashing with functions that are passed into ``with_columns`` - i.e.
        there are no functions that have the same name as the columns. This effectively means that
        using ``columns_to_pass`` will only append new columns to the dataframe and for changing
        existing columns ``pass_dataframe_as`` needs to be used.

        :param nodes_: nodes to check for name collisions.
        :return: True if at least two nodes share the same name, False otherwise.
        """
        seen_names = set()
        for node_ in nodes_:
            if node_.name in seen_names:
                return True
            seen_names.add(node_.name)
        return False

    def validate_dataframe_type(self) -> None:
        """Validates that the configured dataframe types are all supported.

        :raises InvalidDecoratorException: if ``allowed_dataframe_types`` contains a type
            that is not in ``SUPPORTED_DATAFAME_TYPES``.
        """
        if not set(self.allowed_dataframe_types).issubset(SUPPORTED_DATAFAME_TYPES):
            raise InvalidDecoratorException(
                f"The provided dataframe types: {self.allowed_dataframe_types} are currently not supported "
                "to be used in `with_columns`. Please reach out if you need it. "
                f"We currently only support: {SUPPORTED_DATAFAME_TYPES}."
            )

    def __init__(
        self,
        *load_from: Union[Callable, ModuleType],
        columns_to_pass: List[str] = None,
        pass_dataframe_as: str = None,
        select: List[str] = None,
        namespace: str = None,
        config_required: List[str] = None,
        dataframe_types: Collection[Type] = None,
    ):
        """Instantiates a ``@with_column`` decorator.

        :param load_from: The functions or modules that will be used to generate the group of map operations.
        :param columns_to_pass: The initial schema of the dataframe. This is used to determine which
            upstream inputs should be taken from the dataframe, and which shouldn't. Note that, if this is
            left empty (and external_inputs is as well), we will assume that all dependencies come
            from the dataframe. This cannot be used in conjunction with pass_dataframe_as.
        :param pass_dataframe_as: The name of the dataframe that we're modifying, as known to the subdag.
            If you pass this in, you are responsible for extracting columns out. If not provided, you have
            to pass columns_to_pass in, and we will extract the columns out for you.
        :param select: The end nodes that represent columns to be appended to the original dataframe
            via with_columns. Existing columns will be overridden.
        :param namespace: The namespace of the nodes, so they don't clash with the global namespace
            and so this can be reused. If its left out, there will be no namespace (in which case you'll want
            to be careful about repeating it/reusing the nodes in other parts of the DAG.)
        :param config_required: the list of config keys that are required to resolve any functions. Pass in None\
            if you want the functions/modules to have access to all possible config.
        :param dataframe_types: the dataframe type(s) this decorator instance may operate on. A single
            type is accepted and normalized into a list.
        :raises ValueError: if ``dataframe_types`` is missing, or if not exactly one of
            ``columns_to_pass`` / ``pass_dataframe_as`` is provided.
        """

        if dataframe_types is None:
            raise ValueError("You need to specify which dataframe types it will be applied to.")
        # Normalize a single class into a collection before validating.
        if isinstance(dataframe_types, type):
            dataframe_types = [dataframe_types]
        self.allowed_dataframe_types = dataframe_types
        self.validate_dataframe_type()

        self.subdag_functions = subdag.collect_functions(load_from)
        self.select = select

        # Exactly one of pass_dataframe_as / columns_to_pass must be provided.
        if (pass_dataframe_as is None) == (columns_to_pass is None):
            raise ValueError(
                "You must specify only one of columns_to_pass and "
                "pass_dataframe_as. "
                "This is because specifying pass_dataframe_as injects into "
                "the set of columns, allowing you to perform your own extraction "
                "from the dataframe. We then execute all columns in the subdag "
                "in order, passing in that initial dataframe. If you want "
                "to reference columns in your code, you'll have to specify "
                "the set of initial columns, and allow the subdag decorator "
                "to inject the dataframe through. The initial columns tell "
                "us which parameters to take from that dataframe, so we can "
                "feed the right data into the right columns."
            )

        self.initial_schema = columns_to_pass
        self.dataframe_subdag_param = pass_dataframe_as
        self.namespace = namespace
        self.config_required = config_required

    def required_config(self) -> List[str]:
        """Returns the config keys required to resolve the subdag functions (None means all config)."""
        return self.config_required

    def _create_column_nodes(
        self, inject_parameter: str, params: Dict[str, Type[Type]]
    ) -> List[node.Node]:
        """Creates the per-column nodes by running ``extract_columns`` over a pass-through dataframe node.

        :param inject_parameter: name of the dataframe parameter whose columns get extracted.
        :param params: mapping of the decorated function's parameter names to their types.
        :return: only the extracted column nodes (the recreated dataframe node itself is dropped).
        """
        output_type = params[inject_parameter]

        # self.is_async is set by inject_nodes() before this is called.
        if self.is_async:

            async def temp_fn(**kwargs) -> Any:
                return kwargs[inject_parameter]
        else:

            def temp_fn(**kwargs) -> Any:
                return kwargs[inject_parameter]

        # We recreate the df node to use extract columns
        temp_node = node.Node(
            name=inject_parameter,
            typ=output_type,
            callabl=temp_fn,
            input_types={inject_parameter: output_type},
        )

        extract_columns_decorator = extract_columns(*self.initial_schema)

        out_nodes = extract_columns_decorator.transform_node(temp_node, config={}, fn=temp_fn)
        # The first node is the recreated dataframe node -- only the column nodes are wanted.
        return out_nodes[1:]

    def _get_inital_nodes(
        self, fn: Callable, params: Dict[str, Type[Type]]
    ) -> Tuple[str, Collection[node.Node]]:
        """Selects the correct dataframe and optionally extracts out columns.

        :param fn: the decorated function.
        :param params: mapping of the function's parameter names to their types.
        :return: tuple of (name of the dataframe parameter to inject into, initial column nodes).
        :raises InvalidDecoratorException: if the dataframe parameter is not a dependency of ``fn``.
        :raises ValueError: if the dataframe parameter's type is not one of the allowed dataframe types.
        """
        sig = inspect.signature(fn)
        input_types = typing.get_type_hints(fn)

        if self.dataframe_subdag_param is not None:
            inject_parameter = self.dataframe_subdag_param
        else:
            # If we don't have a specified dataframe we assume it's the first argument
            inject_parameter = list(sig.parameters.values())[0].name

        if inject_parameter not in params:
            raise base.InvalidDecoratorException(
                f"Function: {fn.__name__} does not have the parameter {inject_parameter} as a dependency. "
                f"@with_columns requires the parameter names to match the function parameters. "
                f"If you do not wish to use the first argument, please use `pass_dataframe_as` option. "
                f"It might not be compatible with some other decorators."
            )

        if input_types[inject_parameter] not in self.allowed_dataframe_types:
            raise ValueError(f"Dataframe has to be a {self.allowed_dataframe_types} DataFrame.")
        # Remember the concrete dataframe type for later column-type matching in inject_nodes().
        self.dataframe_type = input_types[inject_parameter]

        # When the user passes the dataframe in themselves there is nothing to extract.
        initial_nodes = (
            []
            if self.dataframe_subdag_param is not None
            else self._create_column_nodes(inject_parameter=inject_parameter, params=params)
        )

        return inject_parameter, initial_nodes

    @abc.abstractmethod
    def create_merge_node(self, upstream_node: str, node_name: str) -> node.Node:
        """Node that adds to / overrides columns for the original dataframe based on selected output.

        This will be platform specific, e.g. Polars already has with_columns whereas Pandas we need
        to implement it ourselves.

        :param upstream_node: name of the original dataframe node to merge the selected columns onto.
        :param node_name: name to give the merge node.
        :return: the merge node.
        """
        pass

    def inject_nodes(
        self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable
    ) -> Tuple[List[node.Node], Dict[str, str]]:
        """Builds the subdag nodes and rewires the decorated function's dataframe input.

        :param params: mapping of the function's parameter names to their types.
        :param config: DAG config used to resolve the subdag functions.
        :param fn: the decorated function.
        :return: tuple of (all injected nodes, rename mapping pointing the dataframe parameter
            at the namespaced merge node).
        :raises ValueError: on column/function name clashes, or when nothing is found upstream
            of the selected columns.
        """
        self.is_async = inspect.iscoroutinefunction(fn)
        namespace = fn.__name__ if self.namespace is None else self.namespace

        inject_parameter, initial_nodes = self._get_inital_nodes(fn=fn, params=params)

        subdag_nodes = subdag.collect_nodes(config, self.subdag_functions)

        # TODO: for now we restrict that if user wants to change columns that already exist, he needs to
        # pass the dataframe and extract them himself. If we add namespace to initial nodes and rewire the
        # initial node names with the ongoing ones that have a column argument, we can also allow in place
        # changes when using columns_to_pass
        if with_columns_factory._check_for_duplicates(initial_nodes + subdag_nodes):
            raise ValueError(
                "You can only specify columns once. You used `columns_to_pass` and we "
                "extract the columns for you. In this case they cannot be overwritten -- only new columns get "
                "appended. If you want to modify in-place columns pass in a dataframe and "
                "extract + modify the columns and afterwards select them."
            )

        pruned_nodes = prune_nodes(subdag_nodes, self.select)
        if not pruned_nodes:
            raise ValueError(
                f"No nodes found upstream from select columns: {self.select} for function: "
                f"{fn.__qualname__}"
            )
        # In case no node is selected we append all possible nodes that have a column type matching
        # what the dataframe expects
        if self.select is None:
            # Loop-invariant: resolve the expected column type once.
            target_column_type = registry.get_column_type_from_df_type(self.dataframe_type)
            self.select = [
                sink_node.name
                for sink_node in pruned_nodes
                if sink_node.type == target_column_type
            ]

        merge_node = self.create_merge_node(inject_parameter, node_name="__append")

        output_nodes = initial_nodes + pruned_nodes + [merge_node]
        output_nodes = subdag.add_namespace(output_nodes, namespace)
        return output_nodes, {inject_parameter: assign_namespace(merge_node.name, namespace)}

    def validate(self, fn: Callable):
        """Validates the decorator configuration; re-checks the allowed dataframe types."""
        self.validate_dataframe_type()
0 commit comments