
Commit 991246c

Implement AzureDataFactoryPipelineRunStatusSensorAsync (#253)
1 parent 2915a67 commit 991246c

File tree

21 files changed: +810 -0 lines changed
Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
import logging
import os
import time
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.microsoft.azure.operators.data_factory import (
    AzureDataFactoryRunPipelineOperator,
)
from azure.identity import ClientSecretCredential
from azure.mgmt.datafactory import DataFactoryManagementClient
from azure.mgmt.datafactory.models import (
    AzureBlobDataset,
    AzureStorageLinkedService,
    BlobSink,
    BlobSource,
    CopyActivity,
    DatasetReference,
    DatasetResource,
    Factory,
    LinkedServiceReference,
    LinkedServiceResource,
    PipelineResource,
    SecureString,
)
from azure.mgmt.resource import ResourceManagementClient

from astronomer.providers.microsoft.azure.sensors.data_factory import (
    AzureDataFactoryPipelineRunStatusSensorAsync,
)

default_args = {
    "execution_timeout": timedelta(minutes=30),
    "azure_data_factory_conn_id": "azure_data_factory_default",
    "factory_name": "ADFProvidersTeamDataFactory",  # This can also be specified in the ADF connection.
    "resource_group_name": "team_provider_resource_group_test",  # This can also be specified in the ADF connection.
}

CLIENT_ID = os.getenv("CLIENT_ID", "")
CLIENT_SECRET = os.getenv("CLIENT_SECRET", "")
TENANT_ID = os.getenv("TENANT_ID", "")
SUBSCRIPTION_ID = os.getenv("SUBSCRIPTION_ID", "")
RESOURCE_GROUP_NAME = os.getenv("RESOURCE_GROUP_NAME", "")
DATAFACTORY_NAME = os.getenv("DATAFACTORY_NAME", "")
LOCATION = os.getenv("LOCATION", "eastus")
CONNECTION_STRING = os.getenv("CONNECTION_STRING", "")
PIPELINE_NAME = os.getenv("PIPELINE_NAME", "pipeline1")
ACTIVITY_NAME = os.getenv("ACTIVITY_NAME", "copyBlobtoBlob")
DATASET_INPUT_NAME = os.getenv("DATASET_INPUT_NAME", "ds_in")
DATASET_OUTPUT_NAME = os.getenv("DATASET_OUTPUT_NAME", "ds_out")
BLOB_FILE_NAME = os.getenv("BLOB_FILE_NAME", "test.txt")
OUTPUT_BLOB_PATH = os.getenv("OUTPUT_BLOB_PATH", "container1/output")
BLOB_PATH = os.getenv("BLOB_PATH", "container1/input")
STORAGE_LINKED_SERVICE_NAME = os.getenv("STORAGE_LINKED_SERVICE_NAME", "storageLinkedService001")
rg_params = {"location": LOCATION}
df_params = {"location": LOCATION}


def create_adf_storage_pipeline() -> None:
    """
    Create the Azure resource group if it is not present, then the Azure Data Factory,
    the Azure Storage linked service, the input and output Azure Blob datasets,
    and the Data Factory pipeline.
    """
    credentials = ClientSecretCredential(
        client_id=CLIENT_ID, client_secret=CLIENT_SECRET, tenant_id=TENANT_ID
    )
    resource_client = ResourceManagementClient(credentials, SUBSCRIPTION_ID)
    resource_group_exist = None
    try:
        resource_group_exist = resource_client.resource_groups.get(RESOURCE_GROUP_NAME)
    except Exception:
        logging.info("Resource group not found, so creating one")
    if not resource_group_exist:
        resource_client.resource_groups.create_or_update(RESOURCE_GROUP_NAME, rg_params)

    # Create a data factory
    adf_client = DataFactoryManagementClient(credentials, SUBSCRIPTION_ID)
    df_resource = Factory(location=LOCATION)
    df = adf_client.factories.create_or_update(RESOURCE_GROUP_NAME, DATAFACTORY_NAME, df_resource)
    while df.provisioning_state != "Succeeded":
        df = adf_client.factories.get(RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
        time.sleep(1)

    # Create an Azure Storage linked service

    # IMPORTANT: specify the name and key of your Azure Storage account.
    storage_string = SecureString(value=CONNECTION_STRING)

    ls_azure_storage = LinkedServiceResource(
        properties=AzureStorageLinkedService(connection_string=storage_string)
    )
    adf_client.linked_services.create_or_update(
        RESOURCE_GROUP_NAME, DATAFACTORY_NAME, STORAGE_LINKED_SERVICE_NAME, ls_azure_storage
    )

    # Create an Azure blob dataset (input)
    ds_ls = LinkedServiceReference(reference_name=STORAGE_LINKED_SERVICE_NAME)
    ds_azure_blob = DatasetResource(
        properties=AzureBlobDataset(
            linked_service_name=ds_ls, folder_path=BLOB_PATH, file_name=BLOB_FILE_NAME
        )
    )
    adf_client.datasets.create_or_update(
        RESOURCE_GROUP_NAME, DATAFACTORY_NAME, DATASET_INPUT_NAME, ds_azure_blob
    )

    # Create an Azure blob dataset (output)
    ds_out_azure_blob = DatasetResource(
        properties=AzureBlobDataset(linked_service_name=ds_ls, folder_path=OUTPUT_BLOB_PATH)
    )
    adf_client.datasets.create_or_update(
        RESOURCE_GROUP_NAME, DATAFACTORY_NAME, DATASET_OUTPUT_NAME, ds_out_azure_blob
    )

    # Create a copy activity
    blob_source = BlobSource()
    blob_sink = BlobSink()
    ds_in_ref = DatasetReference(reference_name=DATASET_INPUT_NAME)
    ds_out_ref = DatasetReference(reference_name=DATASET_OUTPUT_NAME)
    copy_activity = CopyActivity(
        name=ACTIVITY_NAME, inputs=[ds_in_ref], outputs=[ds_out_ref], source=blob_source, sink=blob_sink
    )

    # Create a pipeline with the copy activity
    p_obj = PipelineResource(activities=[copy_activity], parameters={})
    adf_client.pipelines.create_or_update(RESOURCE_GROUP_NAME, DATAFACTORY_NAME, PIPELINE_NAME, p_obj)


def delete_azure_data_factory_storage_pipeline() -> None:
    """Delete the data factory, storage linked service, pipeline, and datasets."""
    credentials = ClientSecretCredential(
        client_id=CLIENT_ID, client_secret=CLIENT_SECRET, tenant_id=TENANT_ID
    )
    # Create resource client
    resource_client = ResourceManagementClient(credentials, SUBSCRIPTION_ID)

    # Create Data Factory client
    adf_client = DataFactoryManagementClient(credentials, SUBSCRIPTION_ID)

    # Delete pipeline
    adf_client.pipelines.delete(RESOURCE_GROUP_NAME, DATAFACTORY_NAME, PIPELINE_NAME)

    # Delete input dataset
    adf_client.datasets.delete(RESOURCE_GROUP_NAME, DATAFACTORY_NAME, DATASET_INPUT_NAME)

    # Delete output dataset
    adf_client.datasets.delete(RESOURCE_GROUP_NAME, DATAFACTORY_NAME, DATASET_OUTPUT_NAME)

    # Delete linked service
    adf_client.linked_services.delete(
        RESOURCE_GROUP_NAME, DATAFACTORY_NAME, linked_service_name=STORAGE_LINKED_SERVICE_NAME
    )

    # Delete data factory
    adf_client.factories.delete(RESOURCE_GROUP_NAME, DATAFACTORY_NAME)

    # Delete resource group
    resource_client.resource_groups.begin_delete(RESOURCE_GROUP_NAME)


with DAG(
    dag_id="example_adf_run_pipeline",
    start_date=datetime(2021, 8, 13),
    schedule_interval=None,
    catchup=False,
    default_args=default_args,
    tags=["example", "async", "Azure Pipeline"],
) as dag:
    # [START howto_create_resource_group]
    create_azure_data_factory_storage_pipeline = PythonOperator(
        task_id="create_azure_data_factory_storage_pipeline",
        python_callable=create_adf_storage_pipeline,
    )
    # [END howto_create_resource_group]

    # [START howto_operator_adf_run_pipeline]
    run_pipeline = AzureDataFactoryRunPipelineOperator(
        task_id="run_pipeline",
        pipeline_name=PIPELINE_NAME,
        wait_for_termination=False,
    )
    # [END howto_operator_adf_run_pipeline]

    # [START howto_sensor_pipeline_run_sensor_async]
    pipeline_run_sensor_async = AzureDataFactoryPipelineRunStatusSensorAsync(
        task_id="pipeline_run_sensor_async",
        run_id=run_pipeline.output["run_id"],
    )
    # [END howto_sensor_pipeline_run_sensor_async]

    remove_azure_data_factory_storage_pipeline = PythonOperator(
        task_id="remove_azure_data_factory_storage_pipeline",
        python_callable=delete_azure_data_factory_storage_pipeline,
        trigger_rule="all_done",
    )

    (
        create_azure_data_factory_storage_pipeline
        >> run_pipeline
        >> pipeline_run_sensor_async
        >> remove_azure_data_factory_storage_pipeline
    )
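The default_args above point every task at an Airflow connection named azure_data_factory_default, and the async hook added in this commit reads the tenant and subscription from that connection's extras. The snippet below is a hypothetical way to register such a connection from the same environment variables, purely for illustration: the extras key names mirror the ones the hook reads further down, while the AIRFLOW_CONN_* URI scheme is an assumption about your deployment and provider version and should be verified.

# Hypothetical helper: expose an "azure_data_factory_default" connection through an
# environment variable before Airflow starts. The extras key names match what
# AzureDataFactoryHookAsync reads; the "azure-data-factory" URI scheme is an assumption
# about the provider's connection type and should be checked against your version.
import os
from urllib.parse import quote_plus

os.environ["AIRFLOW_CONN_AZURE_DATA_FACTORY_DEFAULT"] = (
    "azure-data-factory://"
    f"{quote_plus(os.getenv('CLIENT_ID', ''))}:{quote_plus(os.getenv('CLIENT_SECRET', ''))}@"
    f"?extra__azure_data_factory__tenantId={quote_plus(os.getenv('TENANT_ID', ''))}"
    f"&extra__azure_data_factory__subscriptionId={quote_plus(os.getenv('SUBSCRIPTION_ID', ''))}"
)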

astronomer/providers/microsoft/azure/hooks/__init__.py

Whitespace-only changes.
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
from typing import Any, Optional, Union

from airflow import AirflowException
from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHook
from asgiref.sync import sync_to_async
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.datafactory.aio import DataFactoryManagementClient
from azure.mgmt.datafactory.models import PipelineRun

Credentials = Union[ClientSecretCredential, DefaultAzureCredential]


class AzureDataFactoryHookAsync(AzureDataFactoryHook):
    """
    An async hook that connects to Azure Data Factory to perform pipeline operations.

    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id<howto/connection:adf>`.
    """

    def __init__(self, azure_data_factory_conn_id: str):
        self._async_conn: DataFactoryManagementClient = None
        self.conn_id = azure_data_factory_conn_id
        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)

    async def get_async_conn(self) -> DataFactoryManagementClient:
        """Get an async connection and connect to Azure Data Factory."""
        if self._conn is not None:
            return self._conn

        conn = await sync_to_async(self.get_connection)(self.conn_id)
        tenant = conn.extra_dejson.get("extra__azure_data_factory__tenantId")

        try:
            subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"]
        except KeyError:
            raise ValueError("A Subscription ID is required to connect to Azure Data Factory.")

        credential: Credentials
        if conn.login is not None and conn.password is not None:
            if not tenant:
                raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.")

            credential = ClientSecretCredential(
                client_id=conn.login, client_secret=conn.password, tenant_id=tenant
            )
        else:
            credential = DefaultAzureCredential()

        return DataFactoryManagementClient(
            credential=credential,
            subscription_id=subscription_id,
        )

    async def get_pipeline_run(
        self,
        run_id: str,
        resource_group_name: Optional[str] = None,
        factory_name: Optional[str] = None,
        **config: Any,
    ) -> PipelineRun:
        """
        Connect to Azure Data Factory asynchronously and get the pipeline run details by run id.

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        """
        async with await self.get_async_conn() as client:
            try:
                pipeline_run = await client.pipeline_runs.get(resource_group_name, factory_name, run_id)
                return pipeline_run
            except Exception as e:
                raise AirflowException(e)

    async def get_adf_pipeline_run_status(
        self, run_id: str, resource_group_name: Optional[str] = None, factory_name: Optional[str] = None
    ) -> str:
        """
        Connect to Azure Data Factory asynchronously and get the pipeline status by run_id.

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        """
        try:
            pipeline_run = await self.get_pipeline_run(
                run_id=run_id,
                factory_name=factory_name,
                resource_group_name=resource_group_name,
            )
            status: str = pipeline_run.status
            return status
        except Exception as e:
            raise AirflowException(e)
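As a quick way to exercise the hook outside of an Airflow trigger, a minimal asyncio sketch follows. It assumes the hook module lives at astronomer.providers.microsoft.azure.hooks.data_factory (consistent with the hooks/__init__.py added in this commit) and that the azure_data_factory_default connection is configured; the run id, resource group, and factory names are placeholders.

# Minimal standalone sketch (not part of the commit): poll a pipeline run's status once
# by driving the async hook directly with asyncio. Module path and connection id are
# assumptions; replace the placeholder run id with a real one.
import asyncio

from astronomer.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHookAsync


async def check_run(run_id: str) -> str:
    hook = AzureDataFactoryHookAsync(azure_data_factory_conn_id="azure_data_factory_default")
    return await hook.get_adf_pipeline_run_status(
        run_id=run_id,
        resource_group_name="team_provider_resource_group_test",
        factory_name="ADFProvidersTeamDataFactory",
    )


print(asyncio.run(check_run("<pipeline-run-id>")))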

astronomer/providers/microsoft/azure/operators/__init__.py

Whitespace-only changes.

astronomer/providers/microsoft/azure/operators/data_factory.py

Whitespace-only changes.

astronomer/providers/microsoft/azure/sensors/__init__.py

Whitespace-only changes.
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
from typing import Any, Dict

from airflow import AirflowException
from airflow.providers.microsoft.azure.sensors.data_factory import (
    AzureDataFactoryPipelineRunStatusSensor,
)

from astronomer.providers.microsoft.azure.triggers.data_factory import (
    ADFPipelineRunStatusSensorTrigger,
)


class AzureDataFactoryPipelineRunStatusSensorAsync(AzureDataFactoryPipelineRunStatusSensor):
    """
    Checks the status of a pipeline run, deferring the polling to a trigger.

    :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory.
    :param run_id: The pipeline run identifier.
    :param resource_group_name: The resource group name.
    :param factory_name: The data factory name.
    """

    def __init__(
        self,
        *,
        poll_interval: float = 5,
        **kwargs: Any,
    ):
        self.poll_interval = poll_interval
        super().__init__(**kwargs)

    def execute(self, context: Dict[Any, Any]) -> None:
        """Defer to the trigger, which polls the state of the pipeline run until it reaches a failure or success state."""
        self.defer(
            timeout=self.execution_timeout,
            trigger=ADFPipelineRunStatusSensorTrigger(
                run_id=self.run_id,
                azure_data_factory_conn_id=self.azure_data_factory_conn_id,
                resource_group_name=self.resource_group_name,
                factory_name=self.factory_name,
                poll_interval=self.poll_interval,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
        """
        Callback for when the trigger fires - returns immediately.
        Relies on the trigger to throw an exception; otherwise it assumes execution was
        successful.
        """
        if event:
            if event["status"] == "error":
                raise AirflowException(event["message"])
            self.log.info(event["message"])
        return None
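To make the deferral contract concrete, here is an illustration of how execute_complete reacts to the event dict the trigger hands back: a success-style payload is only logged, while an error-style payload is re-raised as an AirflowException and fails the task. The payload shape shown is an assumption consistent with the handling above, not taken from the trigger's source.

# Illustration only: exercising execute_complete with sample event payloads. The exact
# payload keys/values produced by ADFPipelineRunStatusSensorTrigger are assumed here.
from airflow import AirflowException

from astronomer.providers.microsoft.azure.sensors.data_factory import (
    AzureDataFactoryPipelineRunStatusSensorAsync,
)

sensor = AzureDataFactoryPipelineRunStatusSensorAsync(
    task_id="pipeline_run_sensor_async_demo",
    run_id="<pipeline-run-id>",
)

# Success path: the message is logged and the task finishes normally.
sensor.execute_complete(context={}, event={"status": "success", "message": "Pipeline run Succeeded."})

# Error path: the sensor surfaces the trigger's failure as an AirflowException.
try:
    sensor.execute_complete(context={}, event={"status": "error", "message": "Pipeline run Failed."})
except AirflowException as err:
    print(f"Sensor raised: {err}")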

astronomer/providers/microsoft/azure/triggers/__init__.py

Whitespace-only changes.
