@@ -17,14 +17,20 @@
 """This module contains Google Dataplex sensors."""
 from __future__ import annotations

+import time
 from typing import TYPE_CHECKING, Sequence

 if TYPE_CHECKING:
     from airflow.utils.context import Context
-
+from google.api_core.exceptions import GoogleAPICallError
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.api_core.retry import Retry
+from google.cloud.dataplex_v1.types import DataScanJob

 from airflow.exceptions import AirflowException
-from airflow.providers.google.cloud.hooks.dataplex import DataplexHook
+from airflow.providers.google.cloud.hooks.dataplex import (
+    AirflowDataQualityScanException,
+    AirflowDataQualityScanResultTimeoutException,
+    DataplexHook,
+)
 from airflow.sensors.base import BaseSensorOperator
@@ -114,3 +120,119 @@ def poke(self, context: Context) -> bool:
         self.log.info("Current status of the Dataplex task %s => %s", self.dataplex_task_id, task_status)

         return task_status == TaskState.ACTIVE
+
+
+class DataplexDataQualityJobStatusSensor(BaseSensorOperator):
+    """
+    Check the status of a Dataplex Data Quality scan job.
+
+    :param project_id: Required. The ID of the Google Cloud project that the task belongs to.
+    :param region: Required. The ID of the Google Cloud region that the task belongs to.
+    :param data_scan_id: Required. Data Quality scan identifier.
+    :param job_id: Required. Job ID.
+    :param api_version: The version of the API that will be requested, for example "v1".
+    :param retry: A retry object used to retry requests. If `None` is specified, requests
+        will not be retried.
+    :param metadata: Additional metadata that is provided to the method.
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with the
+        first account from the list granting this role to the originating account (templated).
+    :param result_timeout: Value in seconds for which the sensor will wait for the Data Quality scan
+        result. An exception is raised if no result is found within the specified number of seconds.
+    :param fail_on_dq_failure: If set to true and not all Data Quality scan rules have passed,
+        an exception is raised. If set to false, execution finishes with success even when
+        some Data Quality scan rules have failed.
+    :param start_sensor_time: Optional. Monotonic timestamp from which ``result_timeout`` is
+        measured; defaults to the moment the sensor starts executing.
+
+    :return: Boolean indicating if the job run has reached ``DataScanJob.State.SUCCEEDED``.
+    """
+
+    template_fields = ["job_id"]
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        data_scan_id: str,
+        job_id: str,
+        api_version: str = "v1",
+        retry: Retry | _MethodDefault = DEFAULT,
+        metadata: Sequence[tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        fail_on_dq_failure: bool = False,
+        result_timeout: float = 60.0 * 10,
+        # Do not default this to time.monotonic(): a call in the signature is
+        # evaluated once at import, so every run would share a start time taken
+        # when the DAG file was parsed instead of when the task actually ran.
+        start_sensor_time: float | None = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.data_scan_id = data_scan_id
+        self.job_id = job_id
+        self.api_version = api_version
+        self.retry = retry
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.fail_on_dq_failure = fail_on_dq_failure
+        self.result_timeout = result_timeout
+        self.start_sensor_time = start_sensor_time
+
+    def execute(self, context: Context) -> None:
+        # Start the result_timeout window when the sensor begins executing.
+        if self.start_sensor_time is None:
+            self.start_sensor_time = time.monotonic()
+        super().execute(context)
+
+    def _duration(self) -> float:
+        return time.monotonic() - self.start_sensor_time
+
+    def poke(self, context: Context) -> bool:
+        self.log.info("Waiting for job %s to be %s", self.job_id, DataScanJob.State.SUCCEEDED)
+        if self.result_timeout:
+            duration = self._duration()
+            if duration > self.result_timeout:
+                raise AirflowDataQualityScanResultTimeoutException(
+                    f"Timeout: Data Quality scan {self.job_id} is not ready after {self.result_timeout}s"
+                )
+
+        hook = DataplexHook(
+            gcp_conn_id=self.gcp_conn_id,
+            api_version=self.api_version,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        try:
+            job = hook.get_data_scan_job(
+                project_id=self.project_id,
+                region=self.region,
+                data_scan_id=self.data_scan_id,
+                job_id=self.job_id,
+                timeout=self.timeout,
+                retry=self.retry,
+                metadata=self.metadata,
+            )
+        except GoogleAPICallError as e:
+            # Chain the original API error instead of packing it into the message tuple.
+            raise AirflowException(
+                f"Error occurred when trying to retrieve Data Quality scan job: {self.data_scan_id}"
+            ) from e
+
+        job_status = job.state
+        self.log.info(
+            "Current status of the Dataplex Data Quality scan job %s => %s", self.job_id, job_status
+        )
+        if job_status == DataScanJob.State.FAILED:
+            raise AirflowException(f"Data Quality scan job failed: {self.job_id}")
+        if job_status == DataScanJob.State.CANCELLED:
+            raise AirflowException(f"Data Quality scan job cancelled: {self.job_id}")
+        if self.fail_on_dq_failure:
+            if job_status == DataScanJob.State.SUCCEEDED and not job.data_quality_result.passed:
+                raise AirflowDataQualityScanException(
+                    f"Data Quality job {self.job_id} execution failed due to failure of its scanning "
+                    f"rules: {self.data_scan_id}"
+                )
+        return job_status == DataScanJob.State.SUCCEEDED
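
For orientation, a minimal usage sketch of the new sensor follows. Nothing below is part of the diff: the project, region, and scan identifiers are hypothetical placeholders, and the upstream run_data_scan task that would push the scan job ID to XCom is assumed to exist.

# Minimal usage sketch (assumptions: the IDs below are placeholders and an
# upstream "run_data_scan" task pushes the Data Quality scan job ID to XCom).
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.sensors.dataplex import DataplexDataQualityJobStatusSensor

with DAG(
    dag_id="example_dataplex_dq_sensor",  # hypothetical DAG
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
):
    wait_for_dq_job = DataplexDataQualityJobStatusSensor(
        task_id="wait_for_dq_job",
        project_id="my-project",      # placeholder
        region="us-central1",         # placeholder
        data_scan_id="my-data-scan",  # placeholder
        # job_id is templated (see template_fields), so it can be rendered
        # from the XCom of the task that triggered the scan.
        job_id="{{ task_instance.xcom_pull('run_data_scan') }}",
        fail_on_dq_failure=True,   # fail the task if any scan rule fails
        result_timeout=600.0,      # give up after 10 minutes without a result
        poke_interval=30,          # standard BaseSensorOperator knob
    )

Since ``job_id`` is the only templated field, pulling it from the XCom of the operator that starts the scan, as above, is presumably the intended way to chain this sensor.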