|
| 1 | +import dataclasses |
| 2 | +import inspect |
| 3 | +import os |
| 4 | +from typing import List, Optional, Tuple, Type |
| 5 | + |
| 6 | +import git |
| 7 | +from docutils import nodes |
| 8 | +from docutils.parsers.rst import Directive |
| 9 | + |
| 10 | +import hamilton.io.data_adapters |
| 11 | +from hamilton import registry |
| 12 | + |
| 13 | +"""A module to crawl available data adapters and generate documentation for them. |
| 14 | +Note these currently link out to the source code on GitHub, but they should |
| 15 | +be linking to the documentation instead, which hasn't been generated yet. |
| 16 | +""" |
| 17 | + |
| 18 | +# These have fallbacks for local dev |
| 19 | +GIT_URL = os.environ.get("READTHEDOCS_GIT_CLONE_URL", "https://github.com/dagworks-inc/hamilton") |
| 20 | +GIT_ID = os.environ.get("READTHEDOCS_GIT_IDENTIFIER", "main") |
| 21 | + |
| 22 | +# All the modules that register data adapters |
| 23 | +# When you register a new one, add it here |
| 24 | +MODULES_TO_IMPORT = ["hamilton.io.default_data_loaders", "hamilton.plugins.pandas_extensions"] |
| 25 | + |
| 26 | +for module in MODULES_TO_IMPORT: |
| 27 | + __import__(module) |
| 28 | + |
| 29 | + |
def get_git_root(path: str) -> str:
    """Returns the root of the git repo containing the given path.

    :param path: Path to a file within a git repo
    :return: The root of the git repo
    """
    # search_parent_directories lets us pass any file inside the repo,
    # not just the repo root itself.
    git_repo = git.Repo(path, search_parent_directories=True)
    git_root = git_repo.git.rev_parse("--show-toplevel")
    return git_root
| 40 | + |
| 41 | + |
@dataclasses.dataclass
class Param:
    """Documentation-friendly representation of one adapter constructor parameter."""

    # Parameter name, as declared on the adapter dataclass
    name: str
    # String representation of the parameter's type
    type: str
    # String representation of the default value, if the field has one
    default: Optional[str] = None
| 47 | + |
| 48 | + |
def get_default(param: dataclasses.Field) -> Optional[str]:
    """Gets the default of a dataclass field, if it has one.

    :param param: The dataclass field
    :return: The str representation of the default, or None if there is no default.
    """
    if param.default is not dataclasses.MISSING:
        return str(param.default)
    if param.default_factory is not dataclasses.MISSING:  # type: ignore[misc]
        # Previously fields with a default_factory were reported as having no
        # default. Factories used on adapters (list, dict, ...) are cheap and
        # side-effect free, so calling one to render the default is fine here.
        return str(param.default_factory())
    return None
| 58 | + |
| 59 | + |
def get_lines_for_class(class_: Type) -> Tuple[int, int]:
    """Gets the span of source lines in which a class is implemented.

    :param class_: The class to get the lines for
    :return: A tuple of (start line, end line); start is the 1-indexed first line.
    """
    # getsourcelines returns (list_of_source_lines, starting_line_number);
    # unpack it instead of indexing the tuple repeatedly. (The original
    # annotation Type[Type] was also wrong -- this accepts any class.)
    source_lines, start_line = inspect.getsourcelines(class_)
    return start_line, start_line + len(source_lines)
| 70 | + |
| 71 | + |
def get_class_repr(class_: Type) -> str:
    """Gets a representation of a class that can be used in documentation.

    :param class_: Python class to get the representation for
    :return: Str representation
    """
    # Generics and other typing constructs may not carry __qualname__;
    # fall back to their str() form in that case.
    if hasattr(class_, "__qualname__"):
        return class_.__qualname__
    return str(class_)
| 84 | + |
| 85 | + |
@dataclasses.dataclass
class AdapterInfo:
    """Everything we render about a single data adapter in the docs table."""

    key: str  # registry key the adapter is registered under
    class_name: str  # adapter class name
    class_path: str  # dotted module path of the adapter class
    load_params: Optional[List[Param]]  # None when the class is not a DataLoader
    save_params: Optional[List[Param]]  # None when the class is not a DataSaver
    applicable_types: List[str]  # rendered type names the adapter applies to
    file_: str  # absolute path to the implementing file
    line_nos: Tuple[int, int]  # (start, end) lines of the implementation

    @staticmethod
    def from_loader(loader: Type[hamilton.io.data_adapters.DataLoader]) -> "AdapterInfo":
        """Utility constructor to create the AdapterInfo from a DataLoader class

        :param loader: DataLoader class
        :return: AdapterInfo derived from it
        """

        def _fields_as_params() -> List[Param]:
            # One Param per dataclass field declared on the adapter.
            return [
                Param(name=p.name, type=get_class_repr(p.type), default=get_default(p))
                for p in dataclasses.fields(loader)
            ]

        return AdapterInfo(
            key=loader.name(),
            class_name=loader.__name__,
            class_path=loader.__module__,
            # Bug fix: load_params previously guarded on issubclass(..., DataSaver)
            # (copy-paste from save_params), so pure loaders lost their load params.
            load_params=_fields_as_params()
            if issubclass(loader, hamilton.io.data_adapters.DataLoader)
            else None,
            save_params=_fields_as_params()
            if issubclass(loader, hamilton.io.data_adapters.DataSaver)
            else None,
            applicable_types=[get_class_repr(t) for t in loader.applicable_types()],
            file_=inspect.getfile(loader),
            line_nos=get_lines_for_class(loader),
        )
| 125 | + |
| 126 | + |
def _collect_loaders(saver_or_loader: str) -> List[Type[hamilton.io.data_adapters.AdapterCommon]]:
    """Collects all distinct adapter classes from the registry.

    :param saver_or_loader: Either "loader" or "saver" -- selects which registry to read.
    :return: De-duplicated list of adapter classes, in registration order.
    """
    registry_ = (
        registry.LOADER_REGISTRY if saver_or_loader == "loader" else registry.SAVER_REGISTRY
    )
    out: List[Type[hamilton.io.data_adapters.AdapterCommon]] = []
    # The registry maps key -> list of classes; flatten while de-duplicating
    # (the same class can be registered under multiple keys).
    for classes in registry_.values():
        for cls in classes:
            if cls not in out:
                out.append(cls)
    return out
| 143 | + |
| 144 | + |
| 145 | +# Utility functions to render different components of the adapter in table cells |
| 146 | + |
| 147 | + |
def render_key(key: str):
    """Renders the registry key as a plain text node."""
    # docutils >= 0.18 removed Text's second (rawsource) positional argument;
    # the old two-argument form nodes.Text(key, key) breaks on modern docutils.
    return [nodes.Text(key)]
| 150 | + |
| 151 | + |
def render_class_name(class_name: str):
    """Renders the adapter's class name as inline literal (code-styled) text."""
    return [nodes.literal(text=class_name)]
| 154 | + |
| 155 | + |
def render_class_path(class_path: str, file_: str, line_start: int, line_end: int):
    """Renders the module path as an HTML link to the implementation on GitHub.

    :param class_path: Dotted module path of the adapter class
    :param file_: Absolute path to the file implementing the adapter
    :param line_start: First line of the implementation
    :param line_end: Last line of the implementation
    :return: A list containing a single raw-HTML anchor node
    """
    git_path = get_git_root(file_)
    # GitHub URLs are relative to the repo root, so strip the local prefix.
    file_relative_to_git_root = os.path.relpath(file_, git_path)
    href = f"{GIT_URL}/blob/{GIT_ID}/{file_relative_to_git_root}#L{line_start}-L{line_end}"
    return [nodes.raw("", f'<a href="{href}">{class_path}</a>', format="html")]
| 162 | + |
| 163 | + |
def render_adapter_params(load_params: Optional[List[Param]]):
    """Renders the adapter's parameters as a field list of name/type(=default) pairs.

    :param load_params: Params to render; None yields an empty placeholder.
    :return: A docutils node tree for the table cell.
    """
    if load_params is None:
        return nodes.raw("", "<div/>", format="html")
    fieldlist = nodes.field_list()
    last_index = len(load_params) - 1
    for index, param in enumerate(load_params):
        type_repr = param.type
        if param.default is not None:
            type_repr = type_repr + "=" + param.default
        name_node = nodes.Text(param.name)
        body_node = nodes.literal(text=type_repr)
        fieldlist += nodes.field("", name_node, body_node)
        # Separate consecutive entries with a line break (but not after the last).
        if index != last_index:
            fieldlist += nodes.raw("", "<br/>", format="html")
    return fieldlist
| 179 | + |
| 180 | + |
def render_applicable_types(applicable_types: List[str]):
    """Renders each applicable type as a literal entry in a field list.

    :param applicable_types: Rendered type names the adapter applies to.
    :return: A docutils field_list node for the table cell.
    """
    fieldlist = nodes.field_list()
    for type_repr in applicable_types:
        field = nodes.field("", nodes.literal(text=type_repr), nodes.Text(""))
        fieldlist += field
        fieldlist += nodes.raw("", "<br/>", format="html")
    return fieldlist
| 187 | + |
| 188 | + |
class DataAdapterTableDirective(Directive):
    """Custom directive to render a table of all data adapters. Takes in one argument
    that is either 'loader' or 'saver' to indicate which adapters to render."""

    has_content = True
    required_arguments = 1  # "loader" or "saver"

    # Number of columns actually rendered: key / params / types / module.
    _N_COLUMNS = 4

    def run(self):
        """Runs the directive. This does the following:
        1. Collects all loaders (or savers) from the registry
        2. Creates a table with the following columns:
            - Key
            - Load/save params
            - Applicable types
            - Module (linked to the source on GitHub)
        3. Returns the table
        :return: A list of nodes that Sphinx will render, consisting of the table node
        """
        saver_or_loader = self.arguments[0]
        if saver_or_loader not in ("loader", "saver"):
            raise ValueError(
                f"loader_or_saver must be one of 'loader' or 'saver', " f"got {saver_or_loader}"
            )
        table_data = [
            AdapterInfo.from_loader(loader) for loader in _collect_loaders(saver_or_loader)
        ]

        # Bug fix: cols was hard-coded to 6, but only 4 columns are defined below;
        # cols must match the number of colspecs for the table to render correctly.
        table_node = nodes.table()
        tgroup = nodes.tgroup(cols=self._N_COLUMNS)
        table_node += tgroup
        tgroup += [
            nodes.colspec(colwidth=1),  # key
            nodes.colspec(colwidth=2),  # load/save params
            nodes.colspec(colwidth=1),  # applicable types
            nodes.colspec(colwidth=1),  # class path (module link)
        ]
        tgroup += self._build_header(saver_or_loader)
        tbody = nodes.tbody()
        tgroup += tbody
        for row_data in table_data:
            tbody += self._build_row(row_data)
        return [table_node]

    @staticmethod
    def _build_header(saver_or_loader: str) -> nodes.thead:
        """Builds the single header row: key / params / types / module."""
        thead = nodes.thead()
        row = nodes.row()
        for text in ("key", f"{saver_or_loader} params", "types", "module"):
            entry = nodes.entry()
            entry += nodes.paragraph(text=text)
            row += entry
        thead += row
        return thead

    @staticmethod
    def _build_row(row_data: AdapterInfo) -> nodes.row:
        """Builds one body row from a single adapter's metadata."""
        row = nodes.row()
        for rendered in (
            render_key(row_data.key),
            render_adapter_params(row_data.load_params),
            render_applicable_types(row_data.applicable_types),
            render_class_path(row_data.class_path, row_data.file_, *row_data.line_nos),
        ):
            entry = nodes.entry()
            entry += rendered
            row += entry
        return row
| 284 | + |
| 285 | + |
def setup(app):
    """Required to register the extension with Sphinx.

    :param app: The Sphinx application object.
    """
    app.add_directive("data_adapter_table", DataAdapterTableDirective)
0 commit comments