Implement GCSObjectsWithPrefixExistenceSensorAsync and GCSUploadSessionCompleteSensor #90

Merged · 13 commits · Mar 8, 2022
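Most of the operators and sensors touched in this diff share the same deferral pattern: execute() hands off to an async trigger and names execute_complete as the resume point, and execute_complete raises if the trigger reports an error. A minimal sketch of that pattern follows; the trigger and sensor class names are placeholders for illustration, not classes added by this PR.

import asyncio
from typing import Any, AsyncIterator, Dict, Tuple

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.triggers.base import BaseTrigger, TriggerEvent


class MyTrigger(BaseTrigger):
    """Placeholder trigger that pretends to poll an external system."""

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        return ("example_module.MyTrigger", {})

    async def run(self) -> AsyncIterator[TriggerEvent]:
        await asyncio.sleep(5)  # stand-in for an async poll of S3/GCS/Redshift
        yield TriggerEvent({"status": "success", "message": "done"})


class MySensorAsync(BaseOperator):
    """Placeholder sensor showing the defer/execute_complete flow used throughout this PR."""

    def execute(self, context: Dict[str, Any]) -> None:
        # Hand off to the triggerer; the worker slot is released until the event fires.
        self.defer(trigger=MyTrigger(), method_name="execute_complete")

    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
        # Resume point: fail the task if the trigger reported an error.
        if event and event["status"] == "error":
            raise AirflowException(event["message"])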
@@ -14,7 +14,7 @@
catchup=False,
) as dag:
task_create_func = RedshiftSQLOperatorAsync(
task_id='task_create_func',
task_id="task_create_func",
sql="""
CREATE OR REPLACE FUNCTION janky_sleep (x float) RETURNS bool IMMUTABLE as $$
from time import sleep
@@ -25,12 +25,12 @@
)

task_long_running_query_sleep = RedshiftSQLOperatorAsync(
task_id='task_long_running_query_sleep',
task_id="task_long_running_query_sleep",
sql="select janky_sleep(10.0);",
)

task_create_table = RedshiftSQLOperatorAsync(
task_id='task_create_table',
task_id="task_create_table",
sql="""
CREATE TABLE IF NOT EXISTS fruit (
fruit_id INTEGER,
@@ -40,7 +40,7 @@
""",
)
task_insert_data = RedshiftSQLOperatorAsync(
task_id='task_insert_data',
task_id="task_insert_data",
sql=[
"INSERT INTO fruit VALUES ( 1, 'Banana', 'Yellow');",
"INSERT INTO fruit VALUES ( 2, 'Apple', 'Red');",
@@ -52,18 +52,18 @@
)

task_get_all_data = RedshiftSQLOperatorAsync(
task_id='task_get_all_data',
task_id="task_get_all_data",
sql="SELECT * FROM fruit;",
)

task_get_data_with_filter = RedshiftSQLOperatorAsync(
task_id='task_get_with_filter',
task_id="task_get_with_filter",
sql="SELECT * FROM fruit WHERE color = '{{ params.color }}';",
params={'color': 'Red'},
params={"color": "Red"},
)

task_delete_table = RedshiftSQLOperatorAsync(
task_id='task_delete_table',
task_id="task_delete_table",
sql="drop table fruit;",
)
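In the filtered query above, the params dict feeds Airflow's Jinja templating of the sql field: {{ params.color }} renders to Red, so the statement actually executed is SELECT * FROM fruit WHERE color = 'Red'.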

4 changes: 2 additions & 2 deletions astronomer/providers/amazon/aws/example_dags/example_s3.py
@@ -38,7 +38,7 @@
tags=["async"],
) as dag:
create_bucket = S3CreateBucketOperator(
task_id='create_bucket', region_name=REGION_NAME, bucket_name=S3_BUCKET_NAME
task_id="create_bucket", region_name=REGION_NAME, bucket_name=S3_BUCKET_NAME
)

create_local_to_s3_job = LocalFilesystemToS3Operator(
@@ -84,7 +84,7 @@
)

delete_bucket = S3DeleteBucketOperator(
task_id='delete_bucket', force_delete=True, bucket_name=S3_BUCKET_NAME
task_id="delete_bucket", force_delete=True, bucket_name=S3_BUCKET_NAME
)

(
26 changes: 13 additions & 13 deletions astronomer/providers/amazon/aws/hooks/s3.py
@@ -87,19 +87,19 @@ async def get_files(
bucket: str,
key: str,
wildcard_match: bool,
delimiter: Optional[str] = '/',
delimiter: Optional[str] = "/",
) -> List[Any]:
"""Gets a list of files in the bucket"""
prefix = key
if wildcard_match:
prefix = re.split(r'[\[\*\?]', key, 1)[0]
prefix = re.split(r"[\[\*\?]", key, 1)[0]

paginator = client.get_paginator('list_objects_v2')
paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter=delimiter)
keys: List[Any] = []
async for page in response:
if 'Contents' in page:
_temp = [k for k in page['Contents'] if isinstance(k.get('Size', None), (int, float))]
if "Contents" in page:
_temp = [k for k in page["Contents"] if isinstance(k.get("Size", None), (int, float))]
keys = keys + _temp
return keys

@@ -122,23 +122,23 @@ async def _list_keys(
:return: a list of matched keys
:rtype: list
"""
prefix = prefix or ''
delimiter = delimiter or ''
prefix = prefix or ""
delimiter = delimiter or ""
config = {
'PageSize': page_size,
'MaxItems': max_items,
"PageSize": page_size,
"MaxItems": max_items,
}

paginator = client.get_paginator('list_objects_v2')
paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
)

keys = []
async for page in response:
if 'Contents' in page:
for k in page['Contents']:
keys.append(k['Key'])
if "Contents" in page:
for k in page["Contents"]:
keys.append(k["Key"])

return keys
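The two async helpers above drive an S3 list_objects_v2 paginator with async for, which presumes an aiobotocore-style client. A rough, self-contained sketch of the same listing loop is below; the client creation, bucket name, and prefix are assumptions for illustration, since the hook's client setup is not part of this hunk.

import asyncio
from typing import List

from aiobotocore.session import get_session


async def list_keys_demo(bucket_name: str, prefix: str = "") -> List[str]:
    """List object keys the same way _list_keys does, using an aiobotocore client."""
    session = get_session()
    # Credentials are resolved from the environment/instance profile in this sketch.
    async with session.create_client("s3") as client:
        paginator = client.get_paginator("list_objects_v2")
        response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter="")
        keys: List[str] = []
        async for page in response:
            if "Contents" in page:
                for obj in page["Contents"]:
                    keys.append(obj["Key"])
        return keys


if __name__ == "__main__":
    # Placeholder bucket and prefix, for illustration only.
    print(asyncio.run(list_keys_demo("my-example-bucket", prefix="data/")))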

4 changes: 1 addition & 3 deletions astronomer/providers/amazon/aws/operators/redshift_sql.py
@@ -38,9 +38,7 @@ def execute(self, context: Dict[str, Any]) -> None:
method_name="execute_complete",
)

def execute_complete(
self, context: Dict[str, Any], event: Any = None
) -> None: # pylint: disable=unused-argument
def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -41,9 +41,7 @@ def execute(self, context: Dict[Any, Any]) -> None:
method_name="execute_complete",
)

def execute_complete(
self, context: Dict[Any, Any], event: Any = None
) -> None: # pylint: disable=unused-argument
def execute_complete(self, context: Dict[Any, Any], event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
16 changes: 5 additions & 11 deletions astronomer/providers/amazon/aws/sensors/s3.py
@@ -89,9 +89,7 @@ def execute(self, context: Dict[Any, Any]) -> None:
method_name="execute_complete",
)

def execute_complete(
self, context: Dict[Any, Any], event: Any = None
) -> None: # pylint: disable=unused-argument
def execute_complete(self, context: Dict[Any, Any], event: Any = None) -> None:
if event["status"] == "error":
raise AirflowException(event["message"])
return None
@@ -154,9 +152,7 @@ def execute(self, context: Dict[Any, Any]) -> None:
method_name="execute_complete",
)

def execute_complete(
self, context: Dict[Any, Any], event: Any = None
) -> None: # pylint: disable=unused-argument
def execute_complete(self, context: Dict[Any, Any], event: Any = None) -> None:
if event["status"] == "error":
raise AirflowException(event["message"])
return None
@@ -194,14 +190,14 @@ class S3KeysUnchangedSensorAsync(BaseOperator):
when this happens. If false an error will be raised.
"""

template_fields: Sequence[str] = ('bucket_name', 'prefix')
template_fields: Sequence[str] = ("bucket_name", "prefix")

def __init__(
self,
*,
bucket_name: str,
prefix: str,
aws_conn_id: str = 'aws_default',
aws_conn_id: str = "aws_default",
verify: Optional[Union[bool, str]] = None,
inactivity_period: float = 60 * 60,
min_objects: int = 1,
@@ -241,9 +237,7 @@ def execute(self, context: Dict[Any, Any]) -> None:
method_name="execute_complete",
)

def execute_complete(
self, context: Dict[Any, Any], event: Any = None
) -> None: # pylint: disable=unused-argument
def execute_complete(self, context: Dict[Any, Any], event: Any = None) -> None:
if event["status"] == "error":
raise AirflowException(event["message"])
return None
2 changes: 1 addition & 1 deletion astronomer/providers/amazon/aws/triggers/s3.py
@@ -96,7 +96,7 @@ def _check_fn(data: List[Any], object_min_size: Optional[Union[int, float]] = 0)
:param data: List of the objects in S3 bucket.
:param object_min_size: Checks if the object sizes are greater than this value.
"""
return all(f.get('Size', 0) > object_min_size for f in data if isinstance(f, dict))
return all(f.get("Size", 0) > object_min_size for f in data if isinstance(f, dict))

async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
"""
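For reference, the size check above simply asserts that every listed object is strictly larger than the threshold; a quick worked example with made-up values:

data = [{"Key": "a.csv", "Size": 10}, {"Key": "b.csv", "Size": 3}]
all(f.get("Size", 0) > 0 for f in data if isinstance(f, dict))  # True
data.append({"Key": "empty.csv", "Size": 0})
all(f.get("Size", 0) > 0 for f in data if isinstance(f, dict))  # False: a zero-byte object fails the check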
2 changes: 1 addition & 1 deletion astronomer/providers/core/sensors/external_task.py
@@ -44,7 +44,7 @@ def execute(self, context):
)

@provide_session
def execute_complete(self, context, session, event=None): # pylint: disable=unused-argument
def execute_complete(self, context, session, event=None):
"""
Callback for when the trigger fires - returns immediately.
Verifies that there is a success status for each task via execution date.
2 changes: 1 addition & 1 deletion astronomer/providers/core/sensors/filesystem.py
@@ -43,7 +43,7 @@ def execute(self, context):
method_name="execute_complete",
)

def execute_complete(self, context, event=None): # pylint: disable=unused-argument
def execute_complete(self, context, event=None):
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
4 changes: 2 additions & 2 deletions astronomer/providers/databricks/operators/databricks.py
@@ -44,7 +44,7 @@ def execute(self, context):
method_name="execute_complete",
)

def execute_complete(self, context, event=None): # pylint: disable=unused-argument
def execute_complete(self, context, event=None):
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -89,7 +89,7 @@ def execute(self, context):
method_name="execute_complete",
)

def execute_complete(self, context, event=None): # pylint: disable=unused-argument
def execute_complete(self, context, event=None):
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
76 changes: 49 additions & 27 deletions astronomer/providers/google/cloud/example_dags/example_gcs.py
@@ -1,7 +1,8 @@
"""
Example Airflow DAG for Google Object Existence Sensor.
Example Airflow DAG for Google Cloud Storage operators.
"""

import os
from datetime import datetime

from airflow import models
@@ -13,55 +14,76 @@
LocalFilesystemToGCSOperator,
)

from astronomer.providers.google.cloud.sensors.gcs import GCSObjectExistenceSensorAsync

START_DATE = datetime(2022, 1, 1)
from astronomer.providers.google.cloud.sensors.gcs import (
GCSObjectExistenceSensorAsync,
GCSObjectsWithPrefixExistenceSensorAsync,
GCSUploadSessionCompleteSensorAsync,
)

PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "astronomer-airflow-providers")
BUCKET_1 = os.environ.get("GCP_TEST_BUCKET", "test-gcs-example-bucket")
PATH_TO_UPLOAD_FILE = "dags/example_gcs.py"
CONNECTION_ID = "my_connection"
PROJECT_ID = "astronomer-airflow-providers"
BUCKET_1 = "test_bucket_for_dag"
BUCKET_FILE_LOCATION = "test.txt"
PATH_TO_UPLOAD_FILE_PREFIX = "example_"

BUCKET_FILE_LOCATION = "example_gcs.py"

with models.DAG(
"example_async_gcs_sensors",
start_date=START_DATE,
start_date=datetime(2021, 1, 1),
catchup=False,
schedule_interval="@once",
tags=["example"],
) as dag:

create_bucket1 = GCSCreateBucketOperator(
task_id="create_bucket1",
bucket_name=BUCKET_1,
project_id=PROJECT_ID,
gcp_conn_id=CONNECTION_ID,
# [START howto_create_bucket_task]
create_bucket = GCSCreateBucketOperator(
task_id="create_bucket", bucket_name=BUCKET_1, project_id=PROJECT_ID
)

# [END howto_create_bucket_task]
# [START howto_upload_file_task]
upload_file = LocalFilesystemToGCSOperator(
task_id="upload_file",
src=[PATH_TO_UPLOAD_FILE],
src=PATH_TO_UPLOAD_FILE,
dst=BUCKET_FILE_LOCATION,
bucket=BUCKET_1,
gcp_conn_id=CONNECTION_ID,
)

# [END howto_upload_file_task]
# [START howto_sensor_object_exists_task]
gcs_object_exists = GCSObjectExistenceSensorAsync(
bucket=BUCKET_1,
object=BUCKET_FILE_LOCATION,
task_id="gcs_task_object_existence_sensor",
google_cloud_conn_id=CONNECTION_ID,
task_id="gcs_object_exists_task",
)

delete_bucket_1 = GCSDeleteBucketOperator(
task_id="delete_bucket_1",
bucket_name=BUCKET_1,
gcp_conn_id=CONNECTION_ID,
# [END howto_sensor_object_exists_task]
# [START howto_sensor_object_with_prefix_exists_task]
gcs_object_with_prefix_exists = GCSObjectsWithPrefixExistenceSensorAsync(
bucket=BUCKET_1,
prefix=PATH_TO_UPLOAD_FILE_PREFIX,
task_id="gcs_object_with_prefix_exists_task",
)
# [END howto_sensor_object_with_prefix_exists_task]
# [START howto_sensor_gcs_upload_session_complete_task]
gcs_upload_session_complete = GCSUploadSessionCompleteSensorAsync(
bucket=BUCKET_1,
prefix=PATH_TO_UPLOAD_FILE_PREFIX,
inactivity_period=60,
min_objects=1,
allow_delete=True,
previous_objects=set(),
task_id="gcs_upload_session_complete_task",
)
# [END howto_sensor_gcs_upload_session_complete_task]
# [START howto_delete_buckettask]
delete_bucket = GCSDeleteBucketOperator(task_id="delete_bucket", bucket_name=BUCKET_1)
# [END howto_delete_buckettask]

create_bucket1 >> upload_file >> gcs_object_exists >> delete_bucket_1

(
create_bucket
>> upload_file
>> [gcs_object_exists, gcs_object_with_prefix_exists, gcs_upload_session_complete]
>> delete_bucket
)

if __name__ == "__main__":
dag.clear()
dag.run()
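To exercise this example end to end without a running scheduler, one option is airflow dags test example_async_gcs_sensors 2022-01-01 (assuming the file is on the DAGs path and a GCP connection is configured); the __main__ block above is the in-process equivalent.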