Skip to content

Commit 30ca1a4

Browse files
rebase and lint
Co-authored-by: Thomas Chow <[email protected]>
1 parent 0ee9fc1 commit 30ca1a4

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

api/python/ai/chronon/dag_builder.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from typing import Union
2+
3+
import ai.chronon.utils as c_utils
4+
from ai.chronon.api.common.ttypes import TableDependency, TableInfo, TimeUnit, Window
5+
from ai.chronon.api.ttypes import BootstrapPart, GroupBy, JoinPart, Source
6+
7+
"""
8+
Given a node that represents an upstream, turn it into a TableDependency.
9+
"""
10+
11+
12+
def _zero_offset_dependency(table_info: TableInfo, start_cutoff=None, end_cutoff=None) -> TableDependency:
    """Build a TableDependency on `table_info` with zero-day start and end offsets.

    Args:
        table_info: The upstream table (and its partition column) to depend on.
        start_cutoff: Value stored as `startCutOff`; None leaves it unset.
        end_cutoff: Value stored as `endCutOff`; None leaves it unset.

    Returns:
        A TableDependency with `forceCompute=False` and both offsets set to a
        window of zero days.
    """
    return TableDependency(
        tableInfo=table_info,
        # Fresh Window instances per field so the two offsets never share state.
        startOffset=Window(timeUnit=TimeUnit.DAYS, length=0),
        endOffset=Window(timeUnit=TimeUnit.DAYS, length=0),
        startCutOff=start_cutoff,
        endCutOff=end_cutoff,
        forceCompute=False,
    )


def to_dependency(node: Union[Source, GroupBy, JoinPart, BootstrapPart], lag: int = 0) -> TableDependency:
    """Turn a node that represents an upstream into a TableDependency.

    Args:
        node: The upstream to depend on.
            - JoinPart: resolved through its `groupBy`.
            - GroupBy: depends on `metaData.executionInfo.outputTableInfo`.
            - BootstrapPart: depends on `node.table`, using the partition
              column and start/end partitions of `node.query` when present.
            - anything else is treated as a Source and resolved through
              `c_utils.get_table` / `c_utils.get_query`.
        lag: Forwarded unchanged when recursing from a JoinPart into its
            GroupBy. It does not currently influence the produced offsets,
            which are always zero days.

    Returns:
        The TableDependency describing the upstream table of `node`.
    """
    if isinstance(node, JoinPart):
        return to_dependency(node.groupBy, lag=lag)

    if isinstance(node, GroupBy):
        return _zero_offset_dependency(node.metaData.executionInfo.outputTableInfo)

    if isinstance(node, BootstrapPart):
        table_name = node.table
        query = node.query
    else:  # When type of node is a Source
        table_name = c_utils.get_table(node)
        query = c_utils.get_query(node)

    # A missing query is tolerated for both BootstrapPart and Source: the
    # partition column and both cut-offs simply stay unset. Previously the
    # Source path dereferenced `query.partitionColumn` before its own
    # `query is not None` guards, so a None query raised AttributeError.
    has_query = query is not None
    table_info = TableInfo(
        table=table_name,
        partitionColumn=query.partitionColumn if has_query else None,
    )
    return _zero_offset_dependency(
        table_info,
        start_cutoff=query.startPartition if has_query else None,
        end_cutoff=query.endPartition if has_query else None,
    )

api/python/ai/chronon/group_by.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import ai.chronon.api.common.ttypes as common
2222
import ai.chronon.api.ttypes as ttypes
23+
import ai.chronon.dag_builder as dag_builder
2324
import ai.chronon.utils as utils
2425
import ai.chronon.windows as window_utils
2526

@@ -629,6 +630,7 @@ def _normalize_source(source):
629630
sources = [sources]
630631

631632
sources = [_sanitize_columns(_normalize_source(source)) for source in sources]
633+
table_deps = [dag_builder.to_dependency(s) for s in sources]
632634

633635
# get caller's filename to assign team
634636
team = inspect.stack()[1].filename.split("/")[-2]
@@ -639,6 +641,7 @@ def _normalize_source(source):
639641
env=env_vars,
640642
stepDays=step_days,
641643
historicalBackfill=disable_historical_backfill,
644+
tableDependencies=table_deps
642645
)
643646

644647
column_tags = {}

api/python/ai/chronon/join.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import ai.chronon.api.common.ttypes as common
2323
import ai.chronon.api.ttypes as api
24+
import ai.chronon.dag_builder as dag_builder
2425
import ai.chronon.repo.extract_objects as eo
2526
import ai.chronon.utils as utils
2627

@@ -510,6 +511,13 @@ def Join(
510511
"""
511512
# create a deep copy for case: multiple LeftOuterJoin use the same left,
512513
# validation will fail after the first iteration
514+
515+
if bootstrap_parts is None:
516+
bootstrap_parts = []
517+
if online_external_parts is None:
518+
online_external_parts = []
519+
if right_parts is None:
520+
right_parts = []
513521
updated_left = copy.deepcopy(left)
514522
if left.events and left.events.query.selects:
515523
assert "ts" not in left.events.query.selects.keys(), (
@@ -562,12 +570,17 @@ def Join(
562570
)
563571
]
564572

573+
table_deps = [dag_builder.to_dependency(left)] \
574+
+ [dag_builder.to_dependency(right_part) for right_part in right_parts] \
575+
+ [dag_builder.to_dependency(b_part) for b_part in bootstrap_parts]
576+
565577
exec_info = common.ExecutionInfo(
566578
scheduleCron=offline_schedule,
567579
conf=conf,
568580
env=env_vars,
569581
stepDays=step_days,
570582
historicalBackfill=historical_backfill,
583+
tableDependencies=table_deps
571584
)
572585

573586
metadata = api.MetaData(

0 commit comments

Comments
 (0)