Skip to content

Commit 4d99705

Browse files
authored
Add deferrable option to LambdaCreateFunctionOperator (#33327)
1 parent 8cc68e2 commit 4d99705

File tree

8 files changed

+174
-3
lines changed

8 files changed

+174
-3
lines changed

airflow/providers/amazon/aws/operators/lambda_function.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@
1818
from __future__ import annotations
1919

2020
import json
21+
from datetime import timedelta
2122
from functools import cached_property
22-
from typing import TYPE_CHECKING, Sequence
23+
from typing import TYPE_CHECKING, Any, Sequence
2324

25+
from airflow import AirflowException
26+
from airflow.configuration import conf
2427
from airflow.models import BaseOperator
2528
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
29+
from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger
2630

2731
if TYPE_CHECKING:
2832
from airflow.utils.context import Context
@@ -50,6 +54,11 @@ class LambdaCreateFunctionOperator(BaseOperator):
5054
:param timeout: The amount of time (in seconds) that Lambda allows a function to run before stopping it.
5155
:param config: Optional dictionary for arbitrary parameters to the boto API create_lambda call.
5256
:param wait_for_completion: If True, the operator will wait until the function is active.
57+
:param waiter_max_attempts: Maximum number of attempts to poll the creation.
58+
:param waiter_delay: Number of seconds between polling the state of the creation.
59+
:param deferrable: If True, the operator will wait asynchronously for the creation to complete.
60+
This implies waiting for creation complete. This mode requires aiobotocore module to be installed.
61+
(default: False, but can be overridden in config file by setting default_deferrable to True)
5362
:param aws_conn_id: The AWS connection ID to use
5463
"""
5564

@@ -75,6 +84,9 @@ def __init__(
7584
timeout: int | None = None,
7685
config: dict = {},
7786
wait_for_completion: bool = False,
87+
waiter_max_attempts: int = 60,
88+
waiter_delay: int = 15,
89+
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
7890
aws_conn_id: str = "aws_default",
7991
**kwargs,
8092
):
@@ -88,6 +100,9 @@ def __init__(
88100
self.timeout = timeout
89101
self.config = config
90102
self.wait_for_completion = wait_for_completion
103+
self.waiter_delay = waiter_delay
104+
self.waiter_max_attempts = waiter_max_attempts
105+
self.deferrable = deferrable
91106
self.aws_conn_id = aws_conn_id
92107

93108
@cached_property
@@ -108,6 +123,18 @@ def execute(self, context: Context):
108123
)
109124
self.log.info("Lambda response: %r", response)
110125

126+
if self.deferrable:
127+
self.defer(
128+
trigger=LambdaCreateFunctionCompleteTrigger(
129+
function_name=self.function_name,
130+
function_arn=response["FunctionArn"],
131+
waiter_delay=self.waiter_delay,
132+
waiter_max_attempts=self.waiter_max_attempts,
133+
aws_conn_id=self.aws_conn_id,
134+
),
135+
method_name="execute_complete",
136+
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
137+
)
111138
if self.wait_for_completion:
112139
self.log.info("Wait for Lambda function to be active")
113140
waiter = self.hook.conn.get_waiter("function_active_v2")
@@ -117,6 +144,13 @@ def execute(self, context: Context):
117144

118145
return response.get("FunctionArn")
119146

147+
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
148+
if not event or event["status"] != "success":
149+
raise AirflowException(f"Trigger error: event is {event}")
150+
151+
self.log.info("Lambda function created successfully")
152+
return event["function_arn"]
153+
120154

121155
class LambdaInvokeFunctionOperator(BaseOperator):
122156
"""

airflow/providers/amazon/aws/triggers/athena.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
class AthenaTrigger(AwsBaseWaiterTrigger):
2525
"""
26-
Trigger for RedshiftCreateClusterOperator.
26+
Trigger for AthenaOperator.
2727
2828
The trigger will asynchronously poll the boto3 API and wait for the
2929
Redshift cluster to be in the `available` state.

airflow/providers/amazon/aws/triggers/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
112112
@abstractmethod
113113
def hook(self) -> AwsGenericHook:
114114
"""Override in subclasses to return the right hook."""
115-
...
116115

117116
async def run(self) -> AsyncIterator[TriggerEvent]:
118117
hook = self.hook()
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
20+
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
21+
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
22+
23+
24+
class LambdaCreateFunctionCompleteTrigger(AwsBaseWaiterTrigger):
25+
"""
26+
Trigger to poll for the completion of a Lambda function creation.
27+
28+
:param function_name: The function name
29+
:param function_arn: The function ARN
30+
:param waiter_delay: The amount of time in seconds to wait between attempts.
31+
:param waiter_max_attempts: The maximum number of attempts to be made.
32+
:param aws_conn_id: The Airflow connection used for AWS credentials.
33+
"""
34+
35+
def __init__(
36+
self,
37+
*,
38+
function_name: str,
39+
function_arn: str,
40+
waiter_delay: int = 60,
41+
waiter_max_attempts: int = 30,
42+
aws_conn_id: str | None = None,
43+
) -> None:
44+
45+
super().__init__(
46+
serialized_fields={"function_name": function_name, "function_arn": function_arn},
47+
waiter_name="function_active_v2",
48+
waiter_args={"FunctionName": function_name},
49+
failure_message="Lambda function creation failed",
50+
status_message="Status of Lambda function creation is",
51+
status_queries=[
52+
"Configuration.LastUpdateStatus",
53+
"Configuration.LastUpdateStatusReason",
54+
"Configuration.LastUpdateStatusReasonCode",
55+
],
56+
return_key="function_arn",
57+
return_value=function_arn,
58+
waiter_delay=waiter_delay,
59+
waiter_max_attempts=waiter_max_attempts,
60+
aws_conn_id=aws_conn_id,
61+
)
62+
63+
def hook(self) -> AwsGenericHook:
64+
return LambdaHook(aws_conn_id=self.aws_conn_id)

airflow/providers/amazon/provider.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,9 @@ triggers:
544544
- integration-name: Amazon EC2
545545
python-modules:
546546
- airflow.providers.amazon.aws.triggers.ec2
547+
- integration-name: AWS Lambda
548+
python-modules:
549+
- airflow.providers.amazon.aws.triggers.lambda_function
547550
- integration-name: Amazon Redshift
548551
python-modules:
549552
- airflow.providers.amazon.aws.triggers.redshift_cluster

docs/apache-airflow-providers-amazon/operators/lambda.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ Create an AWS Lambda function
4040

4141
To create an AWS lambda function you can use
4242
:class:`~airflow.providers.amazon.aws.operators.lambda_function.LambdaCreateFunctionOperator`.
43+
This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. This requires
44+
the aiobotocore module to be installed.
4345

4446
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_lambda.py
4547
:language: python

tests/providers/amazon/aws/operators/test_lambda_function.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import pytest
2424

25+
from airflow.exceptions import TaskDeferred
2526
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
2627
from airflow.providers.amazon.aws.operators.lambda_function import (
2728
LambdaCreateFunctionOperator,
@@ -69,6 +70,20 @@ def test_create_lambda_with_wait_for_completion(self, mock_hook_conn, mock_hook_
6970
mock_hook_create_lambda.assert_called_once()
7071
mock_hook_conn.get_waiter.assert_called_once_with("function_active_v2")
7172

73+
@mock.patch.object(LambdaHook, "create_lambda")
74+
def test_create_lambda_deferrable(self, _):
75+
operator = LambdaCreateFunctionOperator(
76+
task_id="task_test",
77+
function_name=FUNCTION_NAME,
78+
role=ROLE_ARN,
79+
code={
80+
"ImageUri": IMAGE_URI,
81+
},
82+
deferrable=True,
83+
)
84+
with pytest.raises(TaskDeferred):
85+
operator.execute(None)
86+
7287

7388
class TestLambdaInvokeFunctionOperator:
7489
@pytest.mark.parametrize(
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
import pytest
20+
21+
from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger
22+
23+
TEST_FUNCTION_NAME = "test-function-name"
24+
TEST_FUNCTION_ARN = "test-function-arn"
25+
TEST_WAITER_DELAY = 10
26+
TEST_WAITER_MAX_ATTEMPTS = 10
27+
TEST_AWS_CONN_ID = "test-conn-id"
28+
TEST_REGION_NAME = "test-region-name"
29+
30+
31+
class TestLambdaFunctionTriggers:
32+
@pytest.mark.parametrize(
33+
"trigger",
34+
[
35+
LambdaCreateFunctionCompleteTrigger(
36+
function_name=TEST_FUNCTION_NAME,
37+
function_arn=TEST_FUNCTION_ARN,
38+
aws_conn_id=TEST_AWS_CONN_ID,
39+
waiter_delay=TEST_WAITER_DELAY,
40+
waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
41+
)
42+
],
43+
)
44+
def test_serialize_recreate(self, trigger):
45+
class_path, args = trigger.serialize()
46+
47+
class_name = class_path.split(".")[-1]
48+
clazz = globals()[class_name]
49+
instance = clazz(**args)
50+
51+
class_path2, args2 = instance.serialize()
52+
53+
assert class_path == class_path2
54+
assert args == args2

0 commit comments

Comments
 (0)