Skip to content

Commit a07a98e

Browse files
Implement DbtCloudJobRunOperatorAsync and DbtCloudJobRunSensorAsync (#623)
1 parent 08a9943 commit a07a98e

File tree

28 files changed

+995
-0
lines changed

28 files changed

+995
-0
lines changed

README.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ Extras
8383
- ``pip install 'astronomer-providers[databricks]'``
8484
- Databricks
8585

86+
* - ``dbt.cloud``
87+
- ``pip install 'astronomer-providers[dbt.cloud]'``
88+
- Dbt Cloud
89+
8690
* - ``google``
8791
- ``pip install 'astronomer-providers[google]'``
8892
- Google

astronomer/providers/dbt/__init__.py

Whitespace-only changes.

astronomer/providers/dbt/cloud/__init__.py

Whitespace-only changes.

astronomer/providers/dbt/cloud/example_dags/__init__.py

Whitespace-only changes.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Example use of DBTCloudAsync related providers."""
2+
3+
import os
4+
from datetime import timedelta
5+
6+
from airflow import DAG
7+
from airflow.operators.empty import EmptyOperator
8+
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
9+
from airflow.utils.timezone import datetime
10+
11+
from astronomer.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperatorAsync
12+
from astronomer.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensorAsync
13+
14+
DBT_CLOUD_CONN_ID = os.getenv("ASTRO_DBT_CLOUD_CONN", "dbt_cloud_default")
15+
DBT_CLOUD_ACCOUNT_ID = os.getenv("ASTRO_DBT_CLOUD_ACCOUNT_ID", 12345)
16+
DBT_CLOUD_JOB_ID = int(os.getenv("ASTRO_DBT_CLOUD_JOB_ID", 12345))
17+
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))
18+
19+
20+
default_args = {
21+
"execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
22+
"dbt_cloud_conn_id": DBT_CLOUD_CONN_ID,
23+
"account_id": DBT_CLOUD_ACCOUNT_ID,
24+
"retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
25+
"retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
26+
}
27+
28+
with DAG(
29+
dag_id="example_dbt_cloud",
30+
start_date=datetime(2022, 1, 1),
31+
schedule_interval=None,
32+
default_args=default_args,
33+
tags=["example", "async", "dbt-cloud"],
34+
catchup=False,
35+
) as dag:
36+
start = EmptyOperator(task_id="start")
37+
end = EmptyOperator(task_id="end")
38+
# [START howto_operator_dbt_cloud_run_job_async]
39+
trigger_dbt_job_run_async = DbtCloudRunJobOperatorAsync(
40+
task_id="trigger_dbt_job_run_async",
41+
job_id=DBT_CLOUD_JOB_ID,
42+
check_interval=10,
43+
timeout=300,
44+
)
45+
# [END howto_operator_dbt_cloud_run_job_async]
46+
47+
trigger_job_run2 = DbtCloudRunJobOperator(
48+
task_id="trigger_job_run2",
49+
job_id=DBT_CLOUD_JOB_ID,
50+
wait_for_termination=False,
51+
additional_run_config={"threads_override": 8},
52+
)
53+
54+
# [START howto_operator_dbt_cloud_run_job_sensor_async]
55+
job_run_sensor_async = DbtCloudJobRunSensorAsync(
56+
task_id="job_run_sensor_async", run_id=trigger_job_run2.output, timeout=20
57+
)
58+
# [END howto_operator_dbt_cloud_run_job_sensor_async]
59+
60+
start >> trigger_dbt_job_run_async >> end
61+
start >> trigger_job_run2 >> job_run_sensor_async >> end

astronomer/providers/dbt/cloud/hooks/__init__.py

Whitespace-only changes.
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from functools import wraps
2+
from inspect import signature
3+
from typing import Any, Dict, List, Optional, Tuple, TypeVar, cast
4+
5+
import aiohttp
6+
from aiohttp import ClientResponseError
7+
from airflow import AirflowException
8+
from airflow.hooks.base import BaseHook
9+
from airflow.models import Connection
10+
from asgiref.sync import sync_to_async
11+
12+
from astronomer.providers.package import get_provider_info
13+
14+
T = TypeVar("T", bound=Any)
15+
16+
17+
def provide_account_id(func: T) -> T:
18+
"""
19+
Decorator which provides a fallback value for ``account_id``. If the ``account_id`` is None or not passed
20+
to the decorated function, the value will be taken from the configured dbt Cloud Airflow Connection.
21+
"""
22+
function_signature = signature(func)
23+
24+
@wraps(func)
25+
async def wrapper(*args: Any, **kwargs: Any) -> Any:
26+
bound_args = function_signature.bind(*args, **kwargs)
27+
28+
if bound_args.arguments.get("account_id") is None:
29+
self = args[0]
30+
if self.dbt_cloud_conn_id:
31+
connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
32+
default_account_id = connection.login
33+
if not default_account_id:
34+
raise AirflowException("Could not determine the dbt Cloud account.")
35+
bound_args.arguments["account_id"] = int(default_account_id)
36+
37+
return await func(*bound_args.args, **bound_args.kwargs)
38+
39+
return cast(T, wrapper)
40+
41+
42+
class DbtCloudHookAsync(BaseHook):
43+
"""
44+
Interact with dbt Cloud using the V2 API.
45+
46+
:param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection <howto/connection:dbt-cloud>`.
47+
"""
48+
49+
conn_name_attr = "dbt_cloud_conn_id"
50+
default_conn_name = "dbt_cloud_default"
51+
conn_type = "dbt_cloud"
52+
hook_name = "dbt Cloud"
53+
54+
def __init__(self, dbt_cloud_conn_id: str):
55+
self.dbt_cloud_conn_id = dbt_cloud_conn_id
56+
57+
async def get_headers_tenants_from_connection(self) -> Tuple[Dict[str, Any], str]:
58+
"""Get Headers, tenants from the connection details"""
59+
headers: Dict[str, Any] = {}
60+
connection: Connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
61+
tenant: str = connection.schema if connection.schema else "cloud"
62+
provider_info = get_provider_info()
63+
package_name = provider_info["package-name"]
64+
version = provider_info["versions"]
65+
headers["User-Agent"] = f"{package_name}-v{version}"
66+
headers["Content-Type"] = "application/json"
67+
headers["Authorization"] = f"Token {connection.password}"
68+
return headers, tenant
69+
70+
@staticmethod
71+
def get_request_url_params(
72+
tenant: str, endpoint: str, include_related: Optional[List[str]] = None
73+
) -> Tuple[str, Dict[str, Any]]:
74+
"""
75+
Form URL from base url and endpoint url
76+
77+
:param tenant: The tenant name which is need to be replaced in base url.
78+
:param endpoint: Endpoint url to be requested.
79+
:param include_related: Optional. List of related fields to pull with the run.
80+
Valid values are "trigger", "job", "repository", and "environment".
81+
"""
82+
data: Dict[str, Any] = {}
83+
base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/"
84+
if include_related:
85+
data = {"include_related": include_related}
86+
if base_url and not base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
87+
url = base_url + "/" + endpoint
88+
else:
89+
url = (base_url or "") + (endpoint or "")
90+
return url, data
91+
92+
@provide_account_id
93+
async def get_job_details(
94+
self, run_id: int, account_id: Optional[int] = None, include_related: Optional[List[str]] = None
95+
) -> Any:
96+
"""
97+
Uses Http async call to retrieve metadata for a specific run of a dbt Cloud job.
98+
99+
:param run_id: The ID of a dbt Cloud job run.
100+
:param account_id: Optional. The ID of a dbt Cloud account.
101+
:param include_related: Optional. List of related fields to pull with the run.
102+
Valid values are "trigger", "job", "repository", and "environment".
103+
"""
104+
endpoint = f"{account_id}/runs/{run_id}/"
105+
headers, tenant = await self.get_headers_tenants_from_connection()
106+
url, params = self.get_request_url_params(tenant, endpoint, include_related)
107+
async with aiohttp.ClientSession(headers=headers) as session:
108+
async with session.get(url, params=params) as response:
109+
try:
110+
response.raise_for_status()
111+
return await response.json()
112+
except ClientResponseError as e:
113+
raise AirflowException(str(e.status) + ":" + e.message)
114+
115+
async def get_job_status(
116+
self, run_id: int, account_id: Optional[int] = None, include_related: Optional[List[str]] = None
117+
) -> int:
118+
"""
119+
Retrieves the status for a specific run of a dbt Cloud job.
120+
121+
:param run_id: The ID of a dbt Cloud job run.
122+
:param account_id: Optional. The ID of a dbt Cloud account.
123+
:param include_related: Optional. List of related fields to pull with the run.
124+
Valid values are "trigger", "job", "repository", and "environment".
125+
"""
126+
try:
127+
self.log.info("Getting the status of job run %s.", str(run_id))
128+
response = await self.get_job_details(
129+
run_id, account_id=account_id, include_related=include_related
130+
)
131+
job_run_status: int = response["data"]["status"]
132+
return job_run_status
133+
except Exception as e:
134+
raise e

astronomer/providers/dbt/cloud/operators/__init__.py

Whitespace-only changes.
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import time
2+
from typing import TYPE_CHECKING, Any, Dict
3+
4+
from airflow import AirflowException
5+
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook
6+
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
7+
8+
from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
9+
10+
if TYPE_CHECKING: # pragma: no cover
11+
from airflow.utils.context import Context
12+
13+
14+
class DbtCloudRunJobOperatorAsync(DbtCloudRunJobOperator):
15+
"""
16+
Executes a dbt Cloud job asynchronously. Trigger the dbt cloud job via worker to dbt and with run id in response
17+
poll for the status in trigger.
18+
19+
.. seealso::
20+
For more information on sync Operator DbtCloudRunJobOperator, take a look at the guide:
21+
:ref:`howto/operator:DbtCloudRunJobOperator`
22+
23+
:param dbt_cloud_conn_id: The connection ID for connecting to dbt Cloud.
24+
:param job_id: The ID of a dbt Cloud job.
25+
:param account_id: Optional. The ID of a dbt Cloud account.
26+
:param trigger_reason: Optional Description of the reason to trigger the job. Dbt requires the trigger reason while
27+
making an API. if it is not provided uses the default reasons.
28+
:param steps_override: Optional. List of dbt commands to execute when triggering the job instead of those
29+
configured in dbt Cloud.
30+
:param schema_override: Optional. Override the destination schema in the configured target for this job.
31+
:param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
32+
:param check_interval: Time in seconds to check on a job run's status. Defaults to 60 seconds.
33+
:param additional_run_config: Optional. Any additional parameters that should be included in the API
34+
request when triggering the job.
35+
:return: The ID of the triggered dbt Cloud job run.
36+
"""
37+
38+
def execute(self, context: "Context") -> None: # type: ignore[override]
39+
"""Submits a job which generates a run_id and gets deferred"""
40+
if self.trigger_reason is None:
41+
self.trigger_reason = (
42+
f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG."
43+
)
44+
hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id)
45+
trigger_job_response = hook.trigger_job_run(
46+
account_id=self.account_id,
47+
job_id=self.job_id,
48+
cause=self.trigger_reason,
49+
steps_override=self.steps_override,
50+
schema_override=self.schema_override,
51+
additional_run_config=self.additional_run_config,
52+
)
53+
run_id = trigger_job_response.json()["data"]["id"]
54+
job_run_url = trigger_job_response.json()["data"]["href"]
55+
56+
context["ti"].xcom_push(key="job_run_url", value=job_run_url)
57+
end_time = time.time() + self.timeout
58+
self.defer(
59+
timeout=self.execution_timeout,
60+
trigger=DbtCloudRunJobTrigger(
61+
conn_id=self.dbt_cloud_conn_id,
62+
run_id=run_id,
63+
end_time=end_time,
64+
account_id=self.account_id,
65+
poll_interval=self.check_interval,
66+
),
67+
method_name="execute_complete",
68+
)
69+
70+
def execute_complete(self, context: "Context", event: Dict[str, Any]) -> int:
71+
"""
72+
Callback for when the trigger fires - returns immediately.
73+
Relies on trigger to throw an exception, otherwise it assumes execution was
74+
successful.
75+
"""
76+
if event["status"] == "error":
77+
raise AirflowException(event["message"])
78+
self.log.info(event["message"])
79+
return int(event["run_id"])

astronomer/providers/dbt/cloud/sensors/__init__.py

Whitespace-only changes.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import time
2+
from typing import TYPE_CHECKING, Any, Dict
3+
4+
from airflow import AirflowException
5+
from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor
6+
7+
from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
8+
9+
if TYPE_CHECKING: # pragma: no cover
10+
from airflow.utils.context import Context
11+
12+
13+
class DbtCloudJobRunSensorAsync(DbtCloudJobRunSensor):
14+
"""
15+
Checks the status of a dbt Cloud job run.
16+
17+
.. seealso::
18+
For more information on sync Sensor DbtCloudJobRunSensor, take a look at the guide::
19+
:ref:`howto/operator:DbtCloudJobRunSensor`
20+
21+
:param dbt_cloud_conn_id: The connection identifier for connecting to dbt Cloud.
22+
:param run_id: The job run identifier.
23+
:param account_id: The dbt Cloud account identifier.
24+
:param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
25+
"""
26+
27+
def __init__(
28+
self,
29+
*,
30+
poll_interval: float = 5,
31+
timeout: float = 60 * 60 * 24 * 7,
32+
**kwargs: Any,
33+
):
34+
self.poll_interval = poll_interval
35+
self.timeout = timeout
36+
super().__init__(**kwargs)
37+
38+
def execute(self, context: "Context") -> None:
39+
"""Defers trigger class to poll for state of the job run until it reaches a failure state or success state"""
40+
end_time = time.time() + self.timeout
41+
self.defer(
42+
timeout=self.execution_timeout,
43+
trigger=DbtCloudRunJobTrigger(
44+
run_id=self.run_id,
45+
conn_id=self.dbt_cloud_conn_id,
46+
account_id=self.account_id,
47+
poll_interval=self.poll_interval,
48+
end_time=end_time,
49+
),
50+
method_name="execute_complete",
51+
)
52+
53+
def execute_complete(self, context: "Context", event: Dict[str, Any]) -> int:
54+
"""
55+
Callback for when the trigger fires - returns immediately.
56+
Relies on trigger to throw an exception, otherwise it assumes execution was
57+
successful.
58+
"""
59+
if event["status"] in ["error", "cancelled"]:
60+
raise AirflowException(event["message"])
61+
self.log.info(event["message"])
62+
return int(event["run_id"])

astronomer/providers/dbt/cloud/triggers/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)