
[airflow] -- add dependencies for airflow to customJson #648


Merged · 12 commits · Apr 16, 2025
Changes from 10 commits
138 changes: 138 additions & 0 deletions api/python/ai/chronon/airflow_helpers.py
@@ -0,0 +1,138 @@

import json

import ai.chronon.utils as utils
from ai.chronon.api.ttypes import GroupBy, Join


def create_airflow_dependency(table, partition_column):
    """
    Create an Airflow dependency object for a table.

    Args:
        table: The table name (with namespace)
        partition_column: The partition column to use (defaults to 'ds')

    Returns:
        A dictionary with name and spec for the Airflow dependency
    """
    # Default partition column to 'ds' if not specified
    partition_col = partition_column or 'ds'
Collaborator

Don't default to ds here; we'll need to default to it in teams.py.

    return {
        "name": f"wf_{utils.sanitize(table)}",
        "spec": f"{table}/{partition_col}={{{{ ds }}}}",
    }
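For illustration only, roughly what the helper above returns; the table name is made up, and the wf_ prefix assumes utils.sanitize simply replaces the dot with an underscore:

# Hypothetical call, not part of this PR
dep = create_airflow_dependency("data.purchases", "ds")
# dep == {"name": "wf_data_purchases", "spec": "data.purchases/ds={{ ds }}"}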

def _get_partition_col_from_query(query):
    """Gets partition column from query if available"""
    if query:
        return query.partitionColumn
    return None

def _get_airflow_deps_from_source(source, partition_column=None):
    """
    Given a source, return a list of Airflow dependencies.

    Args:
        source: The source object (events, entities, or joinSource)
        partition_column: The partition column to use

    Returns:
        A list of Airflow dependency objects
    """
    tables = []
    # Assumes source has already been normalized
    if source.events:
        tables = [source.events.table]
        # Use partition column from query if available, otherwise use the provided one
        source_partition_column = _get_partition_col_from_query(source.events.query) or partition_column
    elif source.entities:
        # Given the setup of Query, we currently mandate the same partition column for snapshot and mutations tables
        tables = [source.entities.snapshotTable]
        if source.entities.mutationTable:
            tables.append(source.entities.mutationTable)
        source_partition_column = _get_partition_col_from_query(source.entities.query) or partition_column
    elif source.joinSource:
        namespace = source.joinSource.join.metaData.outputNamespace
        table = utils.sanitize(source.joinSource.join.metaData.name)
        tables = [f"{namespace}.{table}"]
        source_partition_column = _get_partition_col_from_query(source.joinSource.query) or partition_column
    else:
        # Unknown source type
        return []

    return [create_airflow_dependency(table, source_partition_column) for table in tables]
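As a rough sketch of the events branch above (the Source/EventSource/Query constructors follow the ttypes used elsewhere in this file; table and column values are invented):

# Hypothetical events source; the query's partitionColumn, when set, wins over the passed-in default
from ai.chronon.api.ttypes import EventSource, Query, Source

src = Source(events=EventSource(table="search.clicks", query=Query(partitionColumn="date")))
_get_airflow_deps_from_source(src, partition_column="ds")
# -> [{"name": "wf_search_clicks", "spec": "search.clicks/date={{ ds }}"}]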


def extract_default_partition_column(obj):
    return obj.metaData.executionInfo.env.common.get("partitionColumn")
Contributor

🛠️ Refactor suggestion

Add error handling for missing metadata

The function assumes obj.metaData.executionInfo.env.common exists and is a dictionary.

def extract_default_partition_column(obj):
-    return obj.metaData.executionInfo.env.common.get("partitionColumn")
+    try:
+        return obj.metaData.executionInfo.env.common.get("PARTITION_COLUMN", "ds").lower()
+    except (AttributeError, TypeError):
+        return "ds"  # Default fallback

Collaborator

We need to grab this from spark.chronon.partition.column instead.
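One possible reading of that suggestion, purely a sketch; this PR does not show where spark.chronon.partition.column actually lives on the object, so the lookup below is an assumption:

def extract_default_partition_column(obj):
    # Assumption: the Spark conf key is surfaced through the same env.common map.
    common_env = obj.metaData.executionInfo.env.common or {}
    return common_env.get("spark.chronon.partition.column") or common_env.get("partitionColumn")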



def _set_join_deps(join):
    default_partition_col = extract_default_partition_column(join)

    deps = []

    # Handle left source
    left_query = utils.get_query(join.left)
    left_partition_column = _get_partition_col_from_query(left_query) or default_partition_col
    deps.extend(_get_airflow_deps_from_source(join.left, left_partition_column))

    # Handle right parts (join parts)
    if join.joinParts:
        for join_part in join.joinParts:
            if join_part.groupBy and join_part.groupBy.sources:
                for source in join_part.groupBy.sources:
                    source_query = utils.get_query(source)
                    source_partition_column = _get_partition_col_from_query(source_query) or default_partition_col
                    deps.extend(_get_airflow_deps_from_source(source, source_partition_column))

    # Handle label parts
    if join.labelParts and join.labelParts.labels:
        for label_part in join.labelParts.labels:
            if label_part.groupBy and label_part.groupBy.sources:
                for source in label_part.groupBy.sources:
                    source_query = utils.get_query(source)
                    source_partition_column = _get_partition_col_from_query(source_query) or default_partition_col
                    deps.extend(_get_airflow_deps_from_source(source, source_partition_column))

    # Update the metadata customJson with dependencies
    _set_airflow_deps_json(join, deps)

Comment on lines +79 to +109
Contributor

🛠️ Refactor suggestion

Deduplicate dependencies in Join

Dependencies collected from multiple sources may contain duplicates.

Add deduplication before updating metadata:

def _set_join_deps(join):
    default_partition_col = extract_default_partition_column(join)

    deps = []

    # Handle left source
    left_query = utils.get_query(join.left)
    left_partition_column = _get_partition_col_from_query(left_query) or default_partition_col
    deps.extend(_get_airflow_deps_from_source(join.left, left_partition_column))

    # Handle right parts (join parts)
    if join.joinParts:
        for join_part in join.joinParts:
            if join_part.groupBy and join_part.groupBy.sources:
                for source in join_part.groupBy.sources:
                    source_query = utils.get_query(source)
                    source_partition_column = _get_partition_col_from_query(source_query) or default_partition_col
                    deps.extend(_get_airflow_deps_from_source(source, source_partition_column))

    # Handle label parts
    if join.labelParts and join.labelParts.labels:
        for label_part in join.labelParts.labels:
            if label_part.groupBy and label_part.groupBy.sources:
                for source in label_part.groupBy.sources:
                    source_query = utils.get_query(source)
                    source_partition_column = _get_partition_col_from_query(source_query) or default_partition_col
                    deps.extend(_get_airflow_deps_from_source(source, source_partition_column))

+    # Deduplicate dependencies by converting to dict using name as key and back to list
+    unique_deps = {dep["name"]: dep for dep in deps}.values()
+    deps = list(unique_deps)
+
    # Update the metadata customJson with dependencies
    _set_airflow_deps_json(join, deps)


def _set_group_by_deps(group_by):
    if not group_by.sources:
        return

    default_partition_col = extract_default_partition_column(group_by)

    deps = []

    # Process each source in the group_by
    for source in group_by.sources:
        source_query = utils.get_query(source)
        source_partition_column = _get_partition_col_from_query(source_query) or default_partition_col
        deps.extend(_get_airflow_deps_from_source(source, source_partition_column))

    # Update the metadata customJson with dependencies
    _set_airflow_deps_json(group_by, deps)

Comment on lines +111 to +127
Contributor

🛠️ Refactor suggestion

Deduplicate dependencies in GroupBy

Similar to Join, GroupBy should deduplicate dependencies.

def _set_group_by_deps(group_by):
    if not group_by.sources:
        return
    
    default_partition_col = extract_default_partition_column(group_by)
    
    deps = []
    
    # Process each source in the group_by
    for source in group_by.sources:
        source_query = utils.get_query(source)
        source_partition_column = _get_partition_col_from_query(source_query) or default_partition_col
        deps.extend(_get_airflow_deps_from_source(source, source_partition_column))
    
+    # Deduplicate dependencies
+    unique_deps = {dep["name"]: dep for dep in deps}.values()
+    deps = list(unique_deps)
+
    # Update the metadata customJson with dependencies
    _set_airflow_deps_json(group_by, deps)


def _set_airflow_deps_json(obj, deps):
    existing_json = obj.metaData.customJson or "{}"
    json_map = json.loads(existing_json)
    json_map["airflowDependencies"] = deps
    obj.metaData.customJson = json.dumps(json_map)
Comment on lines +129 to +133
Contributor

🛠️ Refactor suggestion

Add JSON error handling

Function will fail if customJson contains invalid JSON.

def _set_airflow_deps_json(obj, deps):
    existing_json = obj.metaData.customJson or "{}"
-    json_map = json.loads(existing_json)
+    try:
+        json_map = json.loads(existing_json)
+    except json.JSONDecodeError:
+        json_map = {}
    json_map["airflowDependencies"] = deps
    obj.metaData.customJson = json.dumps(json_map)
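For context, a small sketch of what metaData.customJson carries once this helper has run; the object and table names are illustrative:

import json

# my_group_by is a hypothetical GroupBy whose deps were just set
json.loads(my_group_by.metaData.customJson)["airflowDependencies"]
# -> [{"name": "wf_data_purchases", "spec": "data.purchases/ds={{ ds }}"}]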


def set_airflow_deps(obj):
    """
    Set Airflow dependencies for a Chronon object.

    Args:
        obj: A Join or GroupBy object
    """
    # StagingQuery dependency setting is handled directly in object init
    if isinstance(obj, Join):
        _set_join_deps(obj)
    elif isinstance(obj, GroupBy):
        _set_group_by_deps(obj)
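A minimal usage sketch, assuming a compiled Join object is already in hand (my_join is hypothetical):

from ai.chronon.airflow_helpers import set_airflow_deps

set_airflow_deps(my_join)  # embeds deps into my_join.metaData.customJson

In this PR the compiler calls this automatically right after parse_teams.update_metadata, as the parse_configs.py change below shows.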
4 changes: 4 additions & 0 deletions api/python/ai/chronon/cli/compile/parse_configs.py
@@ -4,6 +4,7 @@
import os
from typing import List

from ai.chronon import airflow_helpers
from ai.chronon.cli.compile import parse_teams, serializer
from ai.chronon.cli.compile.compile_context import CompileContext
from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
@@ -30,6 +31,9 @@ def from_folder(

    for name, obj in results_dict.items():
        parse_teams.update_metadata(obj, compile_context.teams_dict)
        # Airflow deps must be set AFTER updating metadata
        airflow_helpers.set_airflow_deps(obj)

        obj.metaData.sourceFile = os.path.relpath(f, compile_context.chronon_root)

        tjson = serializer.thrift_simple_json(obj)
123 changes: 123 additions & 0 deletions api/python/ai/chronon/staging_query.py
@@ -0,0 +1,123 @@

import inspect
import json
from dataclasses import dataclass
from typing import Dict, List, Optional

import ai.chronon.airflow_helpers as airflow_helpers
import ai.chronon.api.common.ttypes as common
import ai.chronon.api.ttypes as ttypes


# Wrapper for EngineType
class EngineType:
Collaborator
actually we can kill this for now, not using.

    SPARK = ttypes.EngineType.SPARK
    BIGQUERY = ttypes.EngineType.BIGQUERY

@dataclass
class TableDependency:
    table: str
    partition_column: Optional[str] = None

def StagingQuery(
    name: str,
    query: str,
    output_namespace: Optional[str] = None,
    start_partition: Optional[str] = None,
    table_properties: Optional[Dict[str, str]] = None,
    setups: Optional[List[str]] = None,
    partition_column: Optional[str] = None,
    engine_type: Optional[EngineType] = None,
    dependencies: Optional[List[TableDependency]] = None,
    tags: Optional[Dict[str, str]] = None,
    # execution params
    offline_schedule: str = "@daily",
    conf: Optional[common.ConfigProperties] = None,
    env_vars: Optional[common.EnvironmentVariables] = None,
    step_days: Optional[int] = None,
) -> ttypes.StagingQuery:
"""
Creates a StagingQuery object for executing arbitrary SQL queries with templated date parameters.

:param query:
Arbitrary spark query that should be written with template parameters:
- `{{ start_date }}`: Initial run uses start_partition, future runs use latest partition + 1 day
- `{{ end_date }}`: The end partition of the computing range
- `{{ latest_date }}`: End partition independent of the computing range (for cumulative sources)
- `{{ max_date(table=namespace.my_table) }}`: Max partition available for a given table
These parameters can be modified with offset and bounds:
- `{{ start_date(offset=-10, lower_bound='2023-01-01', upper_bound='2024-01-01') }}`
:type query: str
:param start_partition:
On the first run, `{{ start_date }}` will be set to this user provided start date,
future incremental runs will set it to the latest existing partition + 1 day.
:type start_partition: str
:param setups:
Spark SQL setup statements. Used typically to register UDFs.
:type setups: List[str]
:param partition_column:
Only needed for `max_date` template
:type partition_column: str
:param engine_type:
By default, spark is the compute engine. You can specify an override (eg. bigquery, etc.)
Use the EngineType class constants: EngineType.SPARK, EngineType.BIGQUERY, etc.
:type engine_type: int
:param tags:
Additional metadata that does not directly affect computation, but is useful for management.
:type tags: Dict[str, str]
:param offline_schedule:
The offline schedule interval for batch jobs. Format examples:
'@hourly': '0 * * * *',
'@daily': '0 0 * * *',
'@weekly': '0 0 * * 0',
'@monthly': '0 0 1 * *',
'@yearly': '0 0 1 1 *'
:type offline_schedule: str
:param conf:
Configuration properties for the StagingQuery.
:type conf: common.ConfigProperties
:param env_vars:
Environment variables for the StagingQuery.
:type env_vars: common.EnvironmentVariables
:param step_days:
The maximum number of days to process at once
:type step_days: int
:return:
A StagingQuery object
"""
    # Get caller's filename to assign team
    team = inspect.stack()[1].filename.split("/")[-2]

    # Create execution info
    exec_info = common.ExecutionInfo(
        scheduleCron=offline_schedule,
        conf=conf,
        env=env_vars,
        stepDays=step_days,
    )

    airflow_dependencies = [airflow_helpers.create_airflow_dependency(t.table, t.partition_column) for t in dependencies] if dependencies else []
    custom_json = json.dumps({"airflow_dependencies": airflow_dependencies})

Comment on lines +99 to +101
Contributor

🛠️ Refactor suggestion

Rename variable to avoid shadowing imported module.

Variable shadows imported function.

-    airflow_dependencies = [airflow_helpers.create_airflow_dependency(t.table, t.partition_column) for t in dependencies] if dependencies else []
-    custom_json = json.dumps({"airflow_dependencies": airflow_dependencies})
+    airflow_dependencies = [airflow_helpers.create_airflow_dependency(t.table, t.partition_column) for t in dependencies] if dependencies else []
+    custom_json_str = json.dumps({"airflow_dependencies": airflow_dependencies})

Also update line 110:

-        customJson=custom_json,
+        customJson=custom_json_str,
🧰 Tools
🪛 Ruff (0.8.2)

101-101: Redefinition of unused custom_json from line 9

(F811)

    # Create metadata
    meta_data = ttypes.MetaData(
        name=name,
        outputNamespace=output_namespace,
        team=team,
        executionInfo=exec_info,
        tags=tags,
        customJson=custom_json,
        tableProperties=table_properties,
    )

    # Create and return the StagingQuery object with camelCase parameter names
    staging_query = ttypes.StagingQuery(
        metaData=meta_data,
        query=query,
        startPartition=start_partition,
        setups=setups,
        partitionColumn=partition_column,
        engineType=engine_type,
    )

    return staging_query
12 changes: 7 additions & 5 deletions api/python/test/sample/staging_queries/kaggle/outbrain.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from ai.chronon.api.ttypes import MetaData, StagingQuery
+from ai.chronon.staging_query import StagingQuery, TableDependency

base_table = StagingQuery(
+    name='outbrain_left',
    query="""
    SELECT
        clicks_train.display_id,
@@ -35,8 +36,9 @@
        AND ABS(HASH(clicks_train.display_id)) % 100 < 5
        AND ABS(HASH(events.display_id)) % 100 < 5
    """,
-    metaData=MetaData(
-        name='outbrain_left',
-        outputNamespace="default",
-    )
+    output_namespace="default",
+    dependencies=[
+        TableDependency(table="kaggle_outbrain.clicks_train", partition_column="ds"),
+        TableDependency(table="kaggle_outbrain.events", partition_column="ds")
+    ],
)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from ai.chronon.api.ttypes import MetaData, StagingQuery
+from ai.chronon.staging_query import StagingQuery, TableDependency

query = """
SELECT
@@ -30,11 +30,13 @@
WHERE purchases.ds BETWEEN '{{ start_date }}' AND '{{ end_date }}'
"""

-staging_query = StagingQuery(
+checkouts_query = StagingQuery(
    query=query,
-    startPartition="2023-10-31",
-    metaData=MetaData(
-        name='checkouts_staging_query',
-        outputNamespace="data"
-    ),
+    start_partition="2023-10-31",
+    name='checkouts_staging_query',
+    output_namespace="data",
+    dependencies=[
+        TableDependency(table="data.purchases", partition_column="ds"),
+        TableDependency(table="data.checkouts_external", partition_column="ds")
+    ],
)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from ai.chronon.types import MetaData, StagingQuery
+from ai.chronon.staging_query import StagingQuery, TableDependency

query = """
SELECT
@@ -28,15 +28,14 @@

v1 = StagingQuery(
    query=query,
-    startPartition="2020-03-01",
+    start_partition="2020-03-01",
    setups=[
        "CREATE TEMPORARY FUNCTION S2_CELL AS 'com.sample.hive.udf.S2CellId'",
    ],
-    metaData=MetaData(
-        name="sample_staging_query",
-        outputNamespace="sample_namespace",
-        tableProperties={
-            "sample_config_json": """{"sample_key": "sample value}""",
-        },
-    ),
-)
+    name="sample_staging_query",
+    output_namespace="sample_namespace",
+    table_properties={"sample_config_json": """{"sample_key": "sample value}"""},
Contributor

⚠️ Potential issue

Fix JSON syntax.
Missing quote in sample_config_json. Likely a parsing error.

+    dependencies=[
+        TableDependency(table="sample_namespace.sample_table", partition_column="ds")
+    ],
+)