Skip to content

Commit 0e7e502

Browse files
committed
Revert "Revert BigQueryAsyncExtractor to release 1.3.0 (#332)"
This reverts commit 39102d0.
1 parent a88481b commit 0e7e502

File tree

4 files changed

+257
-0
lines changed

4 files changed

+257
-0
lines changed

astronomer/providers/google/cloud/extractors/__init__.py

Whitespace-only changes.
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from typing import Any, List, Optional
2+
3+
from airflow.exceptions import AirflowException
4+
from airflow.models.taskinstance import TaskInstance
5+
from airflow.utils.log.logging_mixin import LoggingMixin
6+
from google.cloud.bigquery import Client
7+
from openlineage.airflow.extractors.base import BaseExtractor, TaskMetadata
8+
from openlineage.airflow.utils import get_job_name
9+
from openlineage.common.provider.bigquery import BigQueryDatasetsProvider
10+
11+
from astronomer.providers.google.cloud.operators.bigquery import (
12+
BigQueryInsertJobOperatorAsync,
13+
)
14+
15+
16+
class BigQueryAsyncExtractor(BaseExtractor, LoggingMixin):
    """
    OpenLineage extractor for ``BigQueryInsertJobOperatorAsync``.

    Provides visibility into the metadata of a BigQuery insert job —
    e.g. ``billedBytes``, ``rowCount``, ``size`` — submitted from a
    ``BigQueryInsertJobOperatorAsync`` operator.
    """

    def __init__(self, operator: BigQueryInsertJobOperatorAsync):
        super().__init__(operator)
        self._big_query_client = self._get_big_query_client()

    def _get_big_query_client(self) -> Client:
        """
        Return the BigQery client used to fetch job metadata.

        Reuses the operator's Airflow connection hook when one is configured;
        otherwise falls back to a default ``google.cloud.bigquery`` Client.
        """
        hook = getattr(self.operator, "hook", None)
        if hook:
            return hook.get_client(project_id=hook.project_id, location=hook.location)
        return Client()

    def _get_xcom_bigquery_job_id(self, task_instance: TaskInstance) -> Any:
        """
        Pull the BigQuery job ID from XCOM for the given task instance.

        :param task_instance: Instance of the Airflow task whose BigQuery
            ``job_id`` needs to be pulled from XCOM.
        :raises AirflowException: when no job ID could be found in XCOM.
        """
        job_id = task_instance.xcom_pull(task_ids=task_instance.task_id, key="job_id")
        if not job_id:
            raise AirflowException("Could not pull relevant BigQuery job ID from XCOM")
        self.log.debug("Big Query Job Id %s", job_id)
        return job_id

    @classmethod
    def get_operator_classnames(cls) -> List[str]:
        """Returns the list of operators this extractor works on."""
        return ["BigQueryInsertJobOperatorAsync"]

    def extract(self) -> Optional[TaskMetadata]:
        """Empty extract implementation for the abstractmethod of the ``BaseExtractor`` class."""
        return None

    def extract_on_complete(self, task_instance: TaskInstance) -> Optional[TaskMetadata]:
        """
        Callback on task completion; builds the metadata (inputs, outputs and
        run facets) that is pushed to the Lineage server.

        :param task_instance: Instance of the Airflow task whose metadata needs to be extracted.
        """
        try:
            job_id = self._get_xcom_bigquery_job_id(task_instance)
        except AirflowException as ae:
            # Without a job ID we cannot fetch facets; report the bare job name.
            self.log.exception("%s", str(ae))
            return TaskMetadata(name=get_job_name(task=self.operator))

        facets = BigQueryDatasetsProvider(client=self._big_query_client).get_facets(job_id)
        output_dataset = facets.output
        return TaskMetadata(
            name=get_job_name(task=self.operator),
            inputs=[ds.to_openlineage_dataset() for ds in facets.inputs],
            outputs=[output_dataset.to_openlineage_dataset()] if output_dataset else [],
            run_facets=facets.run_facets,
        )

setup.cfg

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ apache.hive =
7676
impyla
7777
microsoft.azure =
7878
apache-airflow-providers-microsoft-azure
79+
80+
# If, in the future, we move the OpenLineage extractors out of this repo, this dependency should be removed
81+
openlineage =
82+
openlineage-airflow==0.6.2
7983
docs =
8084
sphinx
8185
sphinx-autoapi
@@ -128,6 +132,7 @@ all =
128132
kubernetes_asyncio
129133
paramiko
130134
impyla
135+
openlineage-airflow==0.6.2
131136

132137
[options.packages.find]
133138
include =
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import json
2+
from unittest import mock
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
from airflow.exceptions import TaskDeferred
7+
from airflow.models.dagrun import DagRun
8+
from airflow.models.taskinstance import TaskInstance
9+
from airflow.utils.timezone import datetime
10+
from airflow.utils.types import DagRunType
11+
from openlineage.client.facet import OutputStatisticsOutputDatasetFacet
12+
from openlineage.common.dataset import Dataset, Source
13+
from openlineage.common.provider.bigquery import (
14+
BigQueryFacets,
15+
BigQueryJobRunFacet,
16+
BigQueryStatisticsDatasetFacet,
17+
)
18+
19+
from astronomer.providers.google.cloud.extractors.bigquery_async_extractor import (
20+
BigQueryAsyncExtractor,
21+
)
22+
from astronomer.providers.google.cloud.operators.bigquery import (
23+
BigQueryInsertJobOperatorAsync,
24+
)
25+
26+
TEST_DATASET_LOCATION = "EU"
TEST_GCP_PROJECT_ID = "test-project"
TEST_DATASET = "test-dataset"
TEST_TABLE = "test-table"
EXECUTION_DATE = datetime(2022, 1, 1, 0, 0, 0)
INSERT_DATE = EXECUTION_DATE.strftime("%Y-%m-%d")

# Two-row insert statement shared by every operator configuration in the tests.
INSERT_ROWS_QUERY = (
    f"INSERT {TEST_DATASET}.{TEST_TABLE} VALUES "
    f"(42, 'monthy python', '{INSERT_DATE}'), "
    f"(42, 'fishy fish', '{INSERT_DATE}');"
)

# Fully-qualified BigQuery table name used for both input and output datasets.
_BQ_TABLE_NAME = f"astronomer-airflow-providers.{TEST_DATASET}.{TEST_TABLE}"

INPUT_STATS = [
    Dataset(
        source=Source(scheme="bigquery"),
        name=_BQ_TABLE_NAME,
        fields=[],
        custom_facets={},
        input_facets={},
        output_facets={},
    )
]

OUTPUT_STATS = Dataset(
    source=Source(scheme="bigquery"),
    name=_BQ_TABLE_NAME,
    fields=[],
    custom_facets={"stats": BigQueryStatisticsDatasetFacet(rowCount=2, size=0)},
    input_facets={},
    output_facets={"outputStatistics": OutputStatisticsOutputDatasetFacet(rowCount=2, size=0)},
)

# Canned BigQuery job properties used to build the expected run facets.
with open("tests/google/cloud/extractors/job_details.json") as details_file:
    JOB_PROPERTIES = json.load(details_file)

RUN_FACETS = {
    "bigQuery_job": BigQueryJobRunFacet(billedBytes=0, cached=False, properties=json.dumps(JOB_PROPERTIES))
}
64+
65+
66+
@pytest.fixture
def context():
    """Yield an empty Airflow task execution context."""
    yield {}
73+
74+
75+
@mock.patch("astronomer.providers.google.cloud.operators.bigquery._BigQueryHook")
@mock.patch("airflow.models.TaskInstance.xcom_pull")
@mock.patch("openlineage.common.provider.bigquery.BigQueryDatasetsProvider.get_facets")
def test_extract_on_complete(mock_bg_dataset_provider, mock_xcom_pull, mock_hook, context):
    """
    Tests that the custom extractor's implementation for the BigQueryInsertJobOperatorAsync is able to process the
    operator's metadata that needs to be extracted as per OpenLineage.

    NOTE(review): ``context`` was previously not requested as a fixture, so the
    name resolved to the module-level fixture *function* and a function object
    was passed to ``operator.execute``; the fixture is now injected properly.
    """
    configuration = {
        "query": {
            "query": INSERT_ROWS_QUERY,
            "useLegacySql": False,
        }
    }
    job_id = "123456"
    mock_hook.return_value.insert_job.return_value = MagicMock(job_id=job_id, error_result=False)
    mock_bg_dataset_provider.return_value = BigQueryFacets(
        run_facets=RUN_FACETS, inputs=INPUT_STATS, output=OUTPUT_STATS
    )

    task_id = "insert_query_job"
    operator = BigQueryInsertJobOperatorAsync(
        task_id=task_id,
        configuration=configuration,
        location=TEST_DATASET_LOCATION,
        job_id=job_id,
        project_id=TEST_GCP_PROJECT_ID,
    )

    task_instance = TaskInstance(task=operator)
    # The async operator defers itself, so TaskDeferred is the expected outcome.
    with pytest.raises(TaskDeferred):
        operator.execute(context)

    bq_extractor = BigQueryAsyncExtractor(operator)
    # extract() is intentionally a no-op; metadata is only produced on completion.
    task_meta_extract = bq_extractor.extract()
    assert task_meta_extract is None

    task_meta = bq_extractor.extract_on_complete(task_instance)

    mock_xcom_pull.assert_called_once_with(task_ids=task_instance.task_id, key="job_id")

    assert task_meta.name == f"adhoc_airflow.{task_id}"

    assert task_meta.inputs[0].facets["dataSource"].name == INPUT_STATS[0].source.scheme
    assert task_meta.inputs[0].name == INPUT_STATS[0].name

    assert task_meta.outputs[0].name == OUTPUT_STATS.name
    assert task_meta.outputs[0].facets["stats"].rowCount == 2
    assert task_meta.outputs[0].facets["stats"].size == 0

    assert task_meta.run_facets["bigQuery_job"].billedBytes == 0
    run_facet_properties = json.loads(task_meta.run_facets["bigQuery_job"].properties)
    assert run_facet_properties == JOB_PROPERTIES
128+
129+
130+
def test_extractor_works_on_operator():
    """Tests that the custom extractor implementation is available for the BigQueryInsertJobOperatorAsync Operator."""
    operator = BigQueryInsertJobOperatorAsync(task_id="insert_query_job", configuration={})
    registered_classnames = BigQueryAsyncExtractor.get_operator_classnames()
    assert type(operator).__name__ in registered_classnames
135+
136+
137+
@mock.patch("astronomer.providers.google.cloud.operators.bigquery._BigQueryHook")
def test_unavailable_xcom_raises_exception(mock_hook, context):
    """
    Tests that an exception is raised when the custom extractor is not available to retrieve required XCOM for the
    BigQueryInsertJobOperatorAsync Operator.

    NOTE(review): ``context`` was previously not requested as a fixture, so the
    name resolved to the module-level fixture *function* and a function object
    was passed to ``operator.execute``; the fixture is now injected properly.
    """
    configuration = {
        "query": {
            "query": INSERT_ROWS_QUERY,
            "useLegacySql": False,
        }
    }
    job_id = "123456"
    mock_hook.return_value.insert_job.return_value = MagicMock(job_id=job_id, error_result=False)
    task_id = "insert_query_job"
    operator = BigQueryInsertJobOperatorAsync(
        task_id=task_id,
        configuration=configuration,
        location=TEST_DATASET_LOCATION,
        job_id=job_id,
        project_id=TEST_GCP_PROJECT_ID,
    )

    task_instance = TaskInstance(task=operator)
    execution_date = datetime(2022, 1, 1, 0, 0, 0)
    task_instance.run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date)

    # The async operator defers itself, so TaskDeferred is the expected outcome.
    with pytest.raises(TaskDeferred):
        operator.execute(context)
    bq_extractor = BigQueryAsyncExtractor(operator)
    # No job_id XCOM exists for this task instance, so extraction must fall back
    # to logging the failure and returning metadata with only the job name.
    with mock.patch.object(bq_extractor.log, "exception") as mock_log_exception:
        task_meta = bq_extractor.extract_on_complete(task_instance)

    mock_log_exception.assert_called_with("%s", "Could not pull relevant BigQuery job ID from XCOM")
    assert task_meta.name == f"adhoc_airflow.{task_id}"

0 commit comments

Comments
 (0)