diff --git a/.circleci/integration-tests/master_dag.py b/.circleci/integration-tests/master_dag.py index c74e1165f..423722845 100644 --- a/.circleci/integration-tests/master_dag.py +++ b/.circleci/integration-tests/master_dag.py @@ -133,6 +133,8 @@ def prepare_dag_dependency(task_info, execution_time): {"big_query_sensor_dag": "example_bigquery_sensors"}, {"dataproc_dag": "example_gcp_dataproc"}, {"kubernetes_engine_dag": "example_google_kubernetes_engine"}, + {"bigquery_impersonation_dag": "example_bigquery_impersonation"}, + {"dataproc_impersonation_dag": "example_gcp_dataproc_impersonation"}, ] google_trigger_tasks, ids = prepare_dag_dependency(google_task_info, "{{ ds }}") dag_run_ids.extend(ids) diff --git a/astronomer/providers/google/cloud/example_dags/example_bigquery_impersonation_chain.py b/astronomer/providers/google/cloud/example_dags/example_bigquery_impersonation_chain.py new file mode 100644 index 000000000..2937b9ed3 --- /dev/null +++ b/astronomer/providers/google/cloud/example_dags/example_bigquery_impersonation_chain.py @@ -0,0 +1,135 @@ +""" +Example Airflow DAG which uses impersonation and delegate_to +parameters for authenticating with Google BigQuery service +""" +import os +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.operators.empty import EmptyOperator +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, +) + +from astronomer.providers.google.cloud.operators.bigquery import ( + BigQueryCheckOperatorAsync, + BigQueryInsertJobOperatorAsync, +) + +PROJECT_ID = os.getenv("GCP_PROJECT_ID", "astronomer-airflow-providers") +DATASET_NAME = os.getenv("GCP_BIGQUERY_DATASET_NAME", "astro_dataset") +GCP_IMPERSONATION_CONN_ID = os.getenv("GCP_IMPERSONATION_CONN_ID", "google_impersonation") +LOCATION = os.getenv("GCP_LOCATION", "us") +EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6)) +IMPERSONATION_CHAIN = os.getenv("IMPERSONATION_CHAIN", "") +DELEGATE_TO = os.getenv("DELEGATE_TO", "") + + +TABLE_1 = "table1" +TABLE_2 = "table2" + +SCHEMA = [ + {"name": "value", "type": "INTEGER", "mode": "REQUIRED"}, + {"name": "name", "type": "STRING", "mode": "NULLABLE"}, + {"name": "ds", "type": "STRING", "mode": "NULLABLE"}, +] + +DATASET = DATASET_NAME +INSERT_DATE = datetime.now().strftime("%Y-%m-%d") +INSERT_ROWS_QUERY = ( + f"INSERT {DATASET}.{TABLE_1} VALUES " + f"(42, 'monthy python', '{INSERT_DATE}'), " + f"(42, 'fishy fish', '{INSERT_DATE}');" +) + +default_args = { + "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT), + "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)), + "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))), +} + +with DAG( + dag_id="example_bigquery_impersonation", + schedule_interval=None, + start_date=datetime(2022, 1, 1), + catchup=False, + default_args=default_args, + tags=["example", "async", "bigquery"], + user_defined_macros={"DATASET": DATASET, "TABLE": TABLE_1}, +) as dag: + create_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create_dataset", + dataset_id=DATASET, + location=LOCATION, + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + create_table_1 = BigQueryCreateEmptyTableOperator( + task_id="create_table_1", + dataset_id=DATASET, + table_id=TABLE_1, + schema_fields=SCHEMA, + location=LOCATION, + bigquery_conn_id=GCP_IMPERSONATION_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + 
create_dataset >> create_table_1 + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", + dataset_id=DATASET, + delete_contents=True, + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + trigger_rule="all_done", + impersonation_chain=IMPERSONATION_CHAIN, + ) + + # [START howto_operator_bigquery_insert_job_async] + insert_query_job = BigQueryInsertJobOperatorAsync( + task_id="insert_query_job", + configuration={ + "query": { + "query": INSERT_ROWS_QUERY, + "useLegacySql": False, + } + }, + location=LOCATION, + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + # [END howto_operator_bigquery_insert_job_async] + + # [START howto_operator_bigquery_select_job_async] + select_query_job = BigQueryInsertJobOperatorAsync( + task_id="select_query_job", + configuration={ + "query": { + "query": "{% include 'example_bigquery_query.sql' %}", + "useLegacySql": False, + } + }, + location=LOCATION, + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + delegate_to=DELEGATE_TO, + ) + # [END howto_operator_bigquery_select_job_async] + + # [START howto_operator_bigquery_check_async] + check_count = BigQueryCheckOperatorAsync( + task_id="check_count", + sql=f"SELECT COUNT(*) FROM {DATASET}.{TABLE_1}", + use_legacy_sql=False, + location=LOCATION, + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + # [END howto_operator_bigquery_check_async] + + end = EmptyOperator(task_id="end") + + create_table_1 >> insert_query_job >> select_query_job >> check_count >> delete_dataset >> end diff --git a/astronomer/providers/google/cloud/example_dags/example_dataproc_impersonation.py b/astronomer/providers/google/cloud/example_dags/example_dataproc_impersonation.py new file mode 100644 index 000000000..430b0fe0d --- /dev/null +++ b/astronomer/providers/google/cloud/example_dags/example_dataproc_impersonation.py @@ -0,0 +1,252 @@ +"""Example Airflow DAG which uses impersonation parameters for authenticating with Dataproc operators.""" + +import os +from datetime import datetime, timedelta + +from airflow import models +from airflow.operators.empty import EmptyOperator +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, +) + +from astronomer.providers.google.cloud.operators.dataproc import ( + DataprocCreateClusterOperatorAsync, + DataprocDeleteClusterOperatorAsync, + DataprocSubmitJobOperatorAsync, + DataprocUpdateClusterOperatorAsync, +) + +PROJECT_ID = os.getenv("GCP_PROJECT_ID", "astronomer-airflow-providers") +CLUSTER_NAME = os.getenv("GCP_DATAPROC_CLUSTER_NAME", "example-cluster-astronomer-providers") +REGION = os.getenv("GCP_LOCATION", "us-central1") +GCP_IMPERSONATION_CONN_ID = os.getenv("GCP_IMPERSONATION_CONN_ID", "google_impersonation") +ZONE = os.getenv("GCP_REGION", "us-central1-a") +BUCKET = os.getenv("GCP_DATAPROC_BUCKET", "dataproc-system-tests-astronomer-providers") +EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6)) +OUTPUT_FOLDER = "wordcount" +OUTPUT_PATH = f"gs://{BUCKET}/{OUTPUT_FOLDER}/" +IMPERSONATION_CHAIN = os.getenv("IMPERSONATION_CHAIN", "") + +# Cluster definition +# [START how_to_cloud_dataproc_create_cluster] + +CLUSTER_CONFIG = { + "master_config": { + "num_instances": 1, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024}, + }, +} + +# [END how_to_cloud_dataproc_create_cluster] + + +CLUSTER_UPDATE = {"config": {"worker_config": 
{"num_instances": 2}}} +UPDATE_MASK = { + "paths": ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances"] +} + +TIMEOUT = {"seconds": 1 * 24 * 60 * 60} + +# Jobs definitions +# [START how_to_cloud_dataproc_pig_config] +PIG_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "pig_job": {"query_list": {"queries": ["define sin HiveUDF('sin');"]}}, +} +# [END how_to_cloud_dataproc_pig_config] + +# [START how_to_cloud_dataproc_sparksql_config] +SPARK_SQL_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "spark_sql_job": {"query_list": {"queries": ["SHOW DATABASES;"]}}, +} +# [END how_to_cloud_dataproc_sparksql_config] + +# [START how_to_cloud_dataproc_spark_config] +SPARK_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "spark_job": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, +} +# [END how_to_cloud_dataproc_spark_config] + +# [START how_to_cloud_dataproc_hive_config] +HIVE_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "hive_job": {"query_list": {"queries": ["SHOW DATABASES;"]}}, +} +# [END how_to_cloud_dataproc_hive_config] + +# [START how_to_cloud_dataproc_hadoop_config] +HADOOP_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "hadoop_job": { + "main_jar_file_uri": "file:///usr/lib/hadoop-mapreduce/hadoop-mapreduce-examples.jar", + "args": ["wordcount", "gs://pub/shakespeare/rose.txt", OUTPUT_PATH], + }, +} +# [END how_to_cloud_dataproc_hadoop_config] +WORKFLOW_NAME = "airflow-dataproc-test" +WORKFLOW_TEMPLATE = { + "id": WORKFLOW_NAME, + "placement": { + "managed_cluster": { + "cluster_name": CLUSTER_NAME, + "config": CLUSTER_CONFIG, + } + }, + "jobs": [{"step_id": "pig_job_1", "pig_job": PIG_JOB["pig_job"]}], +} + +default_args = { + "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT), + "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)), + "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))), +} + + +with models.DAG( + dag_id="example_gcp_dataproc_impersonation", + schedule_interval=None, + start_date=datetime(2021, 1, 1), + catchup=False, + default_args=default_args, + tags=["example", "async", "dataproc"], +) as dag: + # [START howto_operator_dataproc_create_cluster_async] + create_cluster = DataprocCreateClusterOperatorAsync( + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + task_id="create_cluster", + project_id=PROJECT_ID, + cluster_config=CLUSTER_CONFIG, + region=REGION, + cluster_name=CLUSTER_NAME, + impersonation_chain=IMPERSONATION_CHAIN, + ) + # [END howto_operator_dataproc_create_cluster_async] + + # [START howto_operator_dataproc_update_cluster_async] + update_cluster = DataprocUpdateClusterOperatorAsync( + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + task_id="update_cluster", + cluster_name=CLUSTER_NAME, + cluster=CLUSTER_UPDATE, + update_mask=UPDATE_MASK, + graceful_decommission_timeout=TIMEOUT, + project_id=PROJECT_ID, + region=REGION, + impersonation_chain=IMPERSONATION_CHAIN, + ) + # [END howto_operator_dataproc_update_cluster_async] + + # [START howto_create_bucket_task] + create_bucket = GCSCreateBucketOperator( + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + task_id="create_bucket", + bucket_name=BUCKET, + project_id=PROJECT_ID, + resource={ + "iamConfiguration": { + 
"uniformBucketLevelAccess": { + "enabled": False, + }, + }, + }, + ) + # [END howto_create_bucket_task] + + # [START howto_operator_dataproc_submit_pig_job_async] + pig_task = DataprocSubmitJobOperatorAsync( + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + task_id="pig_task", + job=PIG_JOB, + region=REGION, + project_id=PROJECT_ID, + ) + # [END howto_operator_dataproc_submit_pig_job_async] + + # [START howto_DataprocSubmitJobOperatorAsync] + spark_sql_task = DataprocSubmitJobOperatorAsync( + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + task_id="spark_sql_task", + job=SPARK_SQL_JOB, + region=REGION, + project_id=PROJECT_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + # [END howto_DataprocSubmitJobOperatorAsync] + # [START howto_DataprocSubmitJobOperatorAsync] + spark_task = DataprocSubmitJobOperatorAsync( + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + task_id="spark_task", + job=SPARK_JOB, + region=REGION, + project_id=PROJECT_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + # [END howto_DataprocSubmitJobOperatorAsync] + # [START howto_DataprocSubmitJobOperatorAsync] + hive_task = DataprocSubmitJobOperatorAsync( + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + task_id="hive_task", + job=HIVE_JOB, + region=REGION, + project_id=PROJECT_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + # [END howto_DataprocSubmitJobOperatorAsync] + # [START howto_DataprocSubmitJobOperatorAsync] + hadoop_task = DataprocSubmitJobOperatorAsync( + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + task_id="hadoop_task", + job=HADOOP_JOB, + region=REGION, + project_id=PROJECT_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + # [END howto_DataprocSubmitJobOperatorAsync] + + # [START howto_operator_dataproc_delete_cluster_async] + delete_cluster = DataprocDeleteClusterOperatorAsync( + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + task_id="delete_cluster", + project_id=PROJECT_ID, + cluster_name=CLUSTER_NAME, + region=REGION, + trigger_rule="all_done", + impersonation_chain=IMPERSONATION_CHAIN, + ) + # [END howto_operator_dataproc_delete_cluster_async] + + # [START howto_delete_buckettask] + delete_bucket = GCSDeleteBucketOperator( + gcp_conn_id=GCP_IMPERSONATION_CONN_ID, + task_id="delete_bucket", + bucket_name=BUCKET, + trigger_rule="all_done", + ) + # [END howto_delete_buckettask] + + end = EmptyOperator(task_id="end") + + create_cluster >> update_cluster >> hive_task >> spark_task >> spark_sql_task >> delete_cluster + ( + create_cluster + >> update_cluster + >> pig_task + >> create_bucket + >> hadoop_task + >> delete_bucket + >> delete_cluster + ) + + [spark_sql_task, hadoop_task, delete_cluster, delete_bucket] >> end diff --git a/astronomer/providers/google/cloud/hooks/bigquery.py b/astronomer/providers/google/cloud/hooks/bigquery.py index dedcf99f4..15fe71866 100644 --- a/astronomer/providers/google/cloud/hooks/bigquery.py +++ b/astronomer/providers/google/cloud/hooks/bigquery.py @@ -25,6 +25,9 @@ class BigQueryHookAsync(GoogleBaseHookAsync): sync_hook_class = BigQueryHook + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + async def get_job_instance( self, project_id: Optional[str], job_id: Optional[str], session: ClientSession ) -> Job: diff --git a/astronomer/providers/google/cloud/operators/bigquery.py b/astronomer/providers/google/cloud/operators/bigquery.py index 004badbc8..d700f0352 100644 --- a/astronomer/providers/google/cloud/operators/bigquery.py +++ b/astronomer/providers/google/cloud/operators/bigquery.py @@ -72,7 +72,11 @@ class BigQueryInsertJobOperatorAsync(BigQueryInsertJobOperator, BaseOperator): 
""" def execute(self, context: Context) -> None: # noqa: D102 - hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) + hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) self.hook = hook job_id = self.hook.generate_job_id( @@ -114,6 +118,8 @@ def execute(self, context: Context) -> None: # noqa: D102 conn_id=self.gcp_conn_id, job_id=self.job_id, project_id=self.project_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, ), method_name="execute_complete", ) @@ -158,6 +164,7 @@ def _submit_job( def execute(self, context: Context) -> None: # noqa: D102 hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, ) job = self._submit_job(hook, job_id="") context["ti"].xcom_push(key="job_id", value=job.job_id) @@ -167,6 +174,7 @@ def execute(self, context: Context) -> None: # noqa: D102 conn_id=self.gcp_conn_id, job_id=job.job_id, project_id=hook.project_id, + impersonation_chain=self.impersonation_chain, ), method_name="execute_complete", ) @@ -288,6 +296,8 @@ def execute(self, context: Context) -> None: # noqa: D102 dataset_id=self.dataset_id, table_id=self.table_id, project_id=hook.project_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, ), method_name="execute_complete", ) @@ -376,6 +386,7 @@ def execute(self, context: Context) -> None: days_back=self.days_back, ratio_formula=self.ratio_formula, ignore_zero=self.ignore_zero, + impersonation_chain=self.impersonation_chain, ), method_name="execute_complete", ) @@ -434,6 +445,7 @@ def execute(self, context: Context) -> None: # noqa: D102 sql=self.sql, pass_value=self.pass_value, tolerance=self.tol, + impersonation_chain=self.impersonation_chain, ), method_name="execute_complete", ) diff --git a/astronomer/providers/google/cloud/operators/dataproc.py b/astronomer/providers/google/cloud/operators/dataproc.py index b1d28308e..e241f1228 100644 --- a/astronomer/providers/google/cloud/operators/dataproc.py +++ b/astronomer/providers/google/cloud/operators/dataproc.py @@ -74,7 +74,7 @@ def __init__( def execute(self, context: Context) -> None: # type: ignore[override] """Call create cluster API and defer to DataprocCreateClusterTrigger to check the status""" - hook = DataprocHook(gcp_conn_id=self.gcp_conn_id) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) DataprocLink.persist( context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name ) @@ -108,6 +108,7 @@ def execute(self, context: Context) -> None: # type: ignore[override] cluster_config=self.cluster_config, labels=self.labels, gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, polling_interval=self.polling_interval, ), method_name="execute_complete", @@ -185,6 +186,7 @@ def execute(self, context: Context) -> None: self.defer( trigger=DataprocDeleteClusterTrigger( + gcp_conn_id=self.gcp_conn_id, project_id=self.project_id, region=self.region, cluster_name=self.cluster_name, @@ -192,6 +194,7 @@ def execute(self, context: Context) -> None: retry=self.retry, end_time=end_time, metadata=self.metadata, + impersonation_chain=self.impersonation_chain, ), method_name="execute_complete", ) @@ -265,6 +268,7 @@ def execute(self, context: Context) -> None: dataproc_job_id=job_id, project_id=self.project_id, region=self.region, + impersonation_chain=self.impersonation_chain, ), method_name="execute_complete", ) @@ 
-366,6 +370,7 @@ def execute(self, context: "Context") -> None: end_time=end_time, metadata=self.metadata, gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, polling_interval=self.polling_interval, ), method_name="execute_complete", diff --git a/astronomer/providers/google/cloud/triggers/bigquery.py b/astronomer/providers/google/cloud/triggers/bigquery.py index 3f621ba1d..5ab6ef87f 100644 --- a/astronomer/providers/google/cloud/triggers/bigquery.py +++ b/astronomer/providers/google/cloud/triggers/bigquery.py @@ -1,5 +1,14 @@ import asyncio -from typing import Any, AsyncIterator, Dict, Optional, SupportsAbs, Tuple, Union +from typing import ( + Any, + AsyncIterator, + Dict, + Optional, + Sequence, + SupportsAbs, + Tuple, + Union, +) from aiohttp import ClientSession from aiohttp.client_exceptions import ClientResponseError @@ -20,6 +29,9 @@ class BigQueryInsertJobTrigger(BaseTrigger): :param project_id: Google Cloud Project where the job is running :param dataset_id: The dataset ID of the requested table. (templated) :param table_id: The table ID of the requested table. (templated) + :param delegate_to: This performs a task on one host with reference to other hosts. + :param impersonation_chain: This is the optional service account to impersonate using short term + credentials. :param poll_interval: polling period in seconds to check for the status """ @@ -30,6 +42,8 @@ def __init__( project_id: Optional[str], dataset_id: Optional[str] = None, table_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, poll_interval: float = 4.0, ): super().__init__() @@ -40,6 +54,8 @@ def __init__( self.dataset_id = dataset_id self.project_id = project_id self.table_id = table_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain self.poll_interval = poll_interval def serialize(self) -> Tuple[str, Dict[str, Any]]: @@ -52,6 +68,8 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: "dataset_id": self.dataset_id, "project_id": self.project_id, "table_id": self.table_id, + "delegate_to": self.delegate_to, + "impersonation_chain": self.impersonation_chain, "poll_interval": self.poll_interval, }, ) @@ -85,7 +103,11 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] yield TriggerEvent({"status": "error", "message": str(e)}) def _get_async_hook(self) -> BigQueryHookAsync: - return BigQueryHookAsync(gcp_conn_id=self.conn_id) + return BigQueryHookAsync( + gcp_conn_id=self.conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) class BigQueryCheckTrigger(BigQueryInsertJobTrigger): @@ -101,6 +123,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: "dataset_id": self.dataset_id, "project_id": self.project_id, "table_id": self.table_id, + "impersonation_chain": self.impersonation_chain, "poll_interval": self.poll_interval, }, ) @@ -160,6 +183,8 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: "dataset_id": self.dataset_id, "project_id": self.project_id, "table_id": self.table_id, + "delegate_to": self.delegate_to, + "impersonation_chain": self.impersonation_chain, "poll_interval": self.poll_interval, }, ) @@ -213,6 +238,8 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger): :param ratio_formula: ration formula :param ignore_zero: boolean value to consider zero or not :param table_id: The table ID of the requested table. 
(templated) + :param impersonation_chain: This is the optional service account to impersonate using short term + credentials. :param poll_interval: polling period in seconds to check for the status """ @@ -230,6 +257,7 @@ def __init__( ignore_zero: bool = True, dataset_id: Optional[str] = None, table_id: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, poll_interval: float = 4.0, ): super().__init__( @@ -238,6 +266,7 @@ def __init__( project_id=project_id, dataset_id=dataset_id, table_id=table_id, + impersonation_chain=impersonation_chain, poll_interval=poll_interval, ) self.conn_id = conn_id @@ -353,6 +382,8 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger): :param tolerance: certain metrics for tolerance :param dataset_id: The dataset ID of the requested table. (templated) :param table_id: The table ID of the requested table. (templated) + :param impersonation_chain: This is the optional service account to impersonate using short term + credentials. :param poll_interval: polling period in seconds to check for the status """ @@ -366,6 +397,7 @@ def __init__( tolerance: Any = None, dataset_id: Optional[str] = None, table_id: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, poll_interval: float = 4.0, ): super().__init__( @@ -374,6 +406,7 @@ def __init__( project_id=project_id, dataset_id=dataset_id, table_id=table_id, + impersonation_chain=impersonation_chain, poll_interval=poll_interval, ) self.sql = sql @@ -468,7 +501,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: ) def _get_async_hook(self) -> BigQueryTableHookAsync: - return BigQueryTableHookAsync(gcp_conn_id=self.gcp_conn_id) + return BigQueryTableHookAsync(gcp_conn_id=self.gcp_conn_id, **self.hook_params) async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] """Will run until the table exists in the Google Big Query.""" diff --git a/astronomer/providers/google/cloud/triggers/dataproc.py b/astronomer/providers/google/cloud/triggers/dataproc.py index 10a57a3b8..0f656603d 100644 --- a/astronomer/providers/google/cloud/triggers/dataproc.py +++ b/astronomer/providers/google/cloud/triggers/dataproc.py @@ -22,6 +22,14 @@ class DataprocCreateClusterTrigger(BaseTrigger): :param end_time: Time in second left to check the cluster status :param metadata: Additional metadata that is provided to the method :param gcp_conn_id: The connection ID to use when fetching connection info. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. 
:param polling_interval: Time in seconds to sleep between checks of cluster status """ @@ -37,6 +45,7 @@ def __init__( cluster_config: Optional[Union[Dict[str, Any], clusters.Cluster]] = None, labels: Optional[Dict[str, str]] = None, gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, polling_interval: float = 5.0, **kwargs: Any, ): @@ -50,6 +59,7 @@ def __init__( self.cluster_config = cluster_config self.labels = labels self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain self.polling_interval = polling_interval def serialize(self) -> Tuple[str, Dict[str, Any]]: @@ -66,6 +76,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: "cluster_config": self.cluster_config, "labels": self.labels, "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, "polling_interval": self.polling_interval, }, ) @@ -122,7 +133,7 @@ async def _handle_error(self, cluster: clusters.Cluster) -> None: ) def _delete_cluster(self) -> None: - hook = DataprocHook(gcp_conn_id=self.gcp_conn_id) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) hook.delete_cluster( project_id=self.project_id, region=self.region, @@ -147,7 +158,7 @@ async def _wait_for_deleting(self) -> None: raise e def _create_cluster(self) -> Any: - hook = DataprocHook(gcp_conn_id=self.gcp_conn_id) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) return hook.create_cluster( region=self.region, project_id=self.project_id, @@ -158,7 +169,7 @@ def _create_cluster(self) -> Any: ) async def _get_cluster(self) -> clusters.Cluster: - hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id) + hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) return await hook.get_cluster( region=self.region, # type: ignore[arg-type] cluster_name=self.cluster_name, @@ -167,7 +178,7 @@ async def _get_cluster(self) -> clusters.Cluster: ) def _diagnose_cluster(self) -> Any: - hook = DataprocHook(gcp_conn_id=self.gcp_conn_id) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) return hook.diagnose_cluster( project_id=self.project_id, region=self.region, @@ -186,6 +197,14 @@ class DataprocDeleteClusterTrigger(BaseTrigger): :param region: The Cloud Dataproc region in which to handle the request :param metadata: Additional metadata that is provided to the method :param gcp_conn_id: The connection ID to use when fetching connection info. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. 
:param polling_interval: Time in seconds to sleep between checks of cluster status """ @@ -197,6 +216,7 @@ def __init__( region: Optional[str] = None, metadata: Sequence[Tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, polling_interval: float = 5.0, **kwargs: Any, ): @@ -207,6 +227,7 @@ def __init__( self.region = region self.metadata = metadata self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain self.polling_interval = polling_interval def serialize(self) -> Tuple[str, Dict[str, Any]]: @@ -220,13 +241,14 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: "region": self.region, "metadata": self.metadata, "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, "polling_interval": self.polling_interval, }, ) async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] """Wait until cluster is deleted completely""" - hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id) + hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) while self.end_time > time.time(): try: cluster = await hook.get_cluster( @@ -259,6 +281,14 @@ class DataProcSubmitTrigger(BaseTrigger): :param location: (To be deprecated). The Cloud Dataproc region in which to handle the request. (templated) :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform. :param wait_timeout: How many seconds wait for job to be ready. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. 
""" def __init__( @@ -268,6 +298,7 @@ def __init__( region: Optional[str] = None, project_id: Optional[str] = None, gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, polling_interval: float = 5.0, ) -> None: super().__init__() @@ -275,6 +306,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.dataproc_job_id = dataproc_job_id self.region = region + self.impersonation_chain = impersonation_chain self.polling_interval = polling_interval def serialize(self) -> Tuple[str, Dict[str, Any]]: @@ -286,6 +318,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: "dataproc_job_id": self.dataproc_job_id, "region": self.region, "polling_interval": self.polling_interval, + "impersonation_chain": self.impersonation_chain, "gcp_conn_id": self.gcp_conn_id, }, ) @@ -293,7 +326,9 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] """Simple loop until the job running on Google Cloud DataProc is completed or not""" try: - hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id) + hook = DataprocHookAsync( + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + ) while True: job_status = await self._get_job_status(hook) if "status" in job_status and job_status["status"] == "success": diff --git a/tests/google/cloud/triggers/test_bigquery.py b/tests/google/cloud/triggers/test_bigquery.py index c1b719057..5b7190fe2 100644 --- a/tests/google/cloud/triggers/test_bigquery.py +++ b/tests/google/cloud/triggers/test_bigquery.py @@ -40,6 +40,8 @@ TEST_IGNORE_ZERO = True TEST_GCP_CONN_ID = "TEST_GCP_CONN_ID" TEST_HOOK_PARAMS = {} +TEST_DELEGATE_TO = None +TEST_IMPERSONATION_CHAIN = None def test_bigquery_insert_job_op_trigger_serialization(): @@ -53,6 +55,8 @@ def test_bigquery_insert_job_op_trigger_serialization(): TEST_GCP_PROJECT_ID, TEST_DATASET_ID, TEST_TABLE_ID, + TEST_DELEGATE_TO, + TEST_IMPERSONATION_CHAIN, POLLING_PERIOD_SECONDS, ) classpath, kwargs = trigger.serialize() @@ -63,6 +67,8 @@ def test_bigquery_insert_job_op_trigger_serialization(): "project_id": TEST_GCP_PROJECT_ID, "dataset_id": TEST_DATASET_ID, "table_id": TEST_TABLE_ID, + "delegate_to": TEST_DELEGATE_TO, + "impersonation_chain": TEST_IMPERSONATION_CHAIN, "poll_interval": POLLING_PERIOD_SECONDS, } @@ -185,12 +191,13 @@ def test_bigquery_check_op_trigger_serialization(): and classpath. 
""" trigger = BigQueryCheckTrigger( - TEST_CONN_ID, - TEST_JOB_ID, - TEST_GCP_PROJECT_ID, - TEST_DATASET_ID, - TEST_TABLE_ID, - POLLING_PERIOD_SECONDS, + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + poll_interval=POLLING_PERIOD_SECONDS, ) classpath, kwargs = trigger.serialize() assert classpath == "astronomer.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger" @@ -200,6 +207,7 @@ def test_bigquery_check_op_trigger_serialization(): "dataset_id": TEST_DATASET_ID, "project_id": TEST_GCP_PROJECT_ID, "table_id": TEST_TABLE_ID, + "impersonation_chain": TEST_IMPERSONATION_CHAIN, "poll_interval": POLLING_PERIOD_SECONDS, } @@ -297,6 +305,8 @@ def test_bigquery_get_data_trigger_serialization(): project_id=TEST_GCP_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, + delegate_to=TEST_DELEGATE_TO, + impersonation_chain=TEST_IMPERSONATION_CHAIN, poll_interval=POLLING_PERIOD_SECONDS, ) classpath, kwargs = trigger.serialize() @@ -307,6 +317,8 @@ def test_bigquery_get_data_trigger_serialization(): "dataset_id": TEST_DATASET_ID, "project_id": TEST_GCP_PROJECT_ID, "table_id": TEST_TABLE_ID, + "delegate_to": TEST_DELEGATE_TO, + "impersonation_chain": TEST_IMPERSONATION_CHAIN, "poll_interval": POLLING_PERIOD_SECONDS, } @@ -386,6 +398,7 @@ def test_bigquery_interval_check_trigger_serialization(): TEST_IGNORE_ZERO, TEST_DATASET_ID, TEST_TABLE_ID, + TEST_IMPERSONATION_CHAIN, POLLING_PERIOD_SECONDS, ) classpath, kwargs = trigger.serialize() diff --git a/tests/google/cloud/triggers/test_dataproc.py b/tests/google/cloud/triggers/test_dataproc.py index 2e20b5100..1a8508f73 100644 --- a/tests/google/cloud/triggers/test_dataproc.py +++ b/tests/google/cloud/triggers/test_dataproc.py @@ -24,6 +24,7 @@ TEST_ZONE = "us-central1-a" TEST_JOB_ID = "test-job" TEST_POLLING_INTERVAL = 3.0 +TEST_IMPERSONATION_CHAIN = None def test_dataproc_submit_trigger_serialization(): @@ -36,6 +37,7 @@ def test_dataproc_submit_trigger_serialization(): dataproc_job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, region=TEST_REGION, + impersonation_chain=TEST_IMPERSONATION_CHAIN, polling_interval=TEST_POLLING_INTERVAL, ) classpath, kwargs = trigger.serialize() @@ -46,6 +48,7 @@ def test_dataproc_submit_trigger_serialization(): "region": TEST_REGION, "polling_interval": TEST_POLLING_INTERVAL, "gcp_conn_id": TEST_GCP_CONN_ID, + "impersonation_chain": TEST_IMPERSONATION_CHAIN, } @@ -66,6 +69,7 @@ async def test_dataproc_submit_return_success_and_failure(mock_get_job_status, s dataproc_job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, region=TEST_REGION, + impersonation_chain=TEST_IMPERSONATION_CHAIN, polling_interval=TEST_POLLING_INTERVAL, ) generator = trigger.run() @@ -162,6 +166,7 @@ def test_dataproc_create_cluster_trigger_serialization(): polling_interval=TEST_POLLING_INTERVAL, end_time=100, metadata=(), + impersonation_chain=TEST_IMPERSONATION_CHAIN, ) classpath, kwargs = trigger.serialize() assert classpath == "astronomer.providers.google.cloud.triggers.dataproc.DataprocCreateClusterTrigger" @@ -171,6 +176,7 @@ def test_dataproc_create_cluster_trigger_serialization(): "cluster_name": "test_cluster", "gcp_conn_id": TEST_GCP_CONN_ID, "polling_interval": TEST_POLLING_INTERVAL, + "impersonation_chain": TEST_IMPERSONATION_CHAIN, "delete_on_error": True, "labels": None, "cluster_config": None, @@ -460,6 +466,7 @@ def test_dataproc_delete_cluster_trigger_serialization(): 
cluster_name="test_cluster", gcp_conn_id=TEST_GCP_CONN_ID, polling_interval=TEST_POLLING_INTERVAL, + impersonation_chain=TEST_IMPERSONATION_CHAIN, end_time=100, metadata=(), ) @@ -471,6 +478,7 @@ def test_dataproc_delete_cluster_trigger_serialization(): "cluster_name": "test_cluster", "gcp_conn_id": TEST_GCP_CONN_ID, "polling_interval": TEST_POLLING_INTERVAL, + "impersonation_chain": TEST_IMPERSONATION_CHAIN, "end_time": 100, "metadata": (), }