Fix Batch operator async for Amazon provider release 8.0.0 #978

Merged: 4 commits, Apr 24, 2023
Changes from all commits
13 changes: 10 additions & 3 deletions astronomer/providers/amazon/aws/operators/batch.py
@@ -27,12 +27,14 @@ class BatchOperatorAsync(BatchOperator):
     :param job_name: the name for the job that will run on AWS Batch (templated)
     :param job_definition: the job definition name on AWS Batch
     :param job_queue: the queue name on AWS Batch
-    :param overrides: the `containerOverrides` parameter for boto3 (templated)
+    :param overrides: Removed in apache-airflow-providers-amazon release 8.0.0, use container_overrides instead with the
+        same value.
+    :param container_overrides: the `containerOverrides` parameter for boto3 (templated)
     :param array_properties: the `arrayProperties` parameter for boto3
     :param parameters: the `parameters` for boto3 (templated)
     :param job_id: the job ID, usually unknown (None) until the
         submit_job operation gets the jobId defined by AWS Batch
-    :param waiters: an :class:`.BatchWaiters` object (see note below);
+    :param waiters: an :py:class:`.BatchWaiters` object (see note below);
         if None, polling is used with max_retries and status_retries.
     :param max_retries: exponential back-off retries, 4200 = 48 hours;
         polling is only used when waiters is None
@@ -59,14 +61,19 @@ def execute(self, context: Context) -> None:
         Submit the job and get the job_id using which we defer and poll in trigger
         """
         self.submit_job(context)
+        try:
+            container_overrides = self.container_overrides  # type: ignore[attr-defined]
+        except AttributeError:  # pragma: no cover
+            # For apache-airflow-providers-amazon<8.0.0
+            container_overrides = self.overrides
         self.defer(
             timeout=self.execution_timeout,
             trigger=BatchOperatorTrigger(
                 job_id=self.job_id,
                 job_name=self.job_name,
                 job_definition=self.job_definition,
                 job_queue=self.job_queue,
-                overrides=self.overrides,
+                container_overrides=container_overrides,
                 array_properties=self.array_properties,
                 parameters=self.parameters,
                 waiters=self.waiters,
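For context, apache-airflow-providers-amazon 8.0.0 renamed `BatchOperator`'s `overrides` argument to `container_overrides`; the `execute` fallback above keeps `BatchOperatorAsync` working on older provider versions. A minimal usage sketch under that assumption (task and job names are illustrative, not taken from this PR):

```python
# Sketch only: assumes apache-airflow-providers-amazon>=8.0.0, where the kwarg
# is ``container_overrides``; on older providers execute() falls back to the
# legacy ``overrides`` attribute as shown in the diff above.
from astronomer.providers.amazon.aws.operators.batch import BatchOperatorAsync

submit_batch_job = BatchOperatorAsync(
    task_id="submit_batch_job",
    job_name="hello-world",                      # illustrative job/queue names
    job_definition="hello-world-def",
    job_queue="hello-world-queue",
    container_overrides={"command": ["echo", "hello world"]},
    aws_conn_id="aws_default",
    region_name="eu-west-1",
)
```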
8 changes: 4 additions & 4 deletions astronomer/providers/amazon/aws/triggers/batch.py
@@ -16,7 +16,7 @@ class BatchOperatorTrigger(BaseTrigger):
     :param job_name: the name for the job that will run on AWS Batch (templated)
     :param job_definition: the job definition name on AWS Batch
     :param job_queue: the queue name on AWS Batch
-    :param overrides: the `containerOverrides` parameter for boto3 (templated)
+    :param container_overrides: the `containerOverrides` parameter for boto3 (templated)
     :param array_properties: the `arrayProperties` parameter for boto3
     :param parameters: the `parameters` for boto3 (templated)
     :param waiters: a :class:`.BatchWaiters` object (see note below);
@@ -39,7 +39,7 @@ def __init__(
         job_name: str,
         job_definition: str,
         job_queue: str,
-        overrides: Dict[str, str],
+        container_overrides: Dict[str, str],
         array_properties: Dict[str, str],
         parameters: Dict[str, str],
         waiters: Any,
@@ -54,7 +54,7 @@ def __init__(
         self.job_name = job_name
         self.job_definition = job_definition
         self.job_queue = job_queue
-        self.overrides = overrides or {}
+        self.container_overrides = container_overrides or {}
         self.array_properties = array_properties or {}
         self.parameters = parameters or {}
         self.waiters = waiters
@@ -73,7 +73,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
                 "job_name": self.job_name,
                 "job_definition": self.job_definition,
                 "job_queue": self.job_queue,
-                "overrides": self.overrides,
+                "container_overrides": self.container_overrides,
                 "array_properties": self.array_properties,
                 "parameters": self.parameters,
                 "waiters": self.waiters,
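A deferred task is resumed by rebuilding the trigger from the dict returned by `serialize()`, so the serialized key has to match the constructor argument name. A hedged round-trip sketch follows; the argument values are illustrative, and the constructor arguments are assumed from the diff and tests in this PR, not copied from them:

```python
from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger

# Illustrative values only; the full argument list is assumed to mirror what the
# operator passes in execute() plus the fixture in tests/amazon/aws/triggers/test_batch.py.
trigger = BatchOperatorTrigger(
    job_id="jobId-123",
    job_name="hello-world",
    job_definition="hello-world-def",
    job_queue="hello-world-queue",
    container_overrides={"command": ["echo", "hello world"]},
    array_properties={},
    parameters={},
    waiters=None,
    tags={},
    max_retries=5,
    status_retries=5,
    region_name="eu-west-1",
    aws_conn_id="aws_default",
)

classpath, kwargs = trigger.serialize()
# The serialized key now matches the __init__ argument, so the triggerer can
# reconstruct the trigger via BatchOperatorTrigger(**kwargs) without hitting
# an unexpected-keyword error for "overrides".
assert kwargs["container_overrides"] == {"command": ["echo", "hello world"]}
rebuilt = BatchOperatorTrigger(**kwargs)
assert rebuilt.container_overrides == trigger.container_overrides
```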
6 changes: 3 additions & 3 deletions tests/amazon/aws/operators/test_batch.py
@@ -31,7 +31,7 @@ def test_batch_op_async(self, get_client_type_mock):
             parameters=None,
             overrides={},
             array_properties=None,
-            aws_conn_id="airflow_test",
+            aws_conn_id="aws_default",
             region_name="eu-west-1",
             tags={},
         )
@@ -53,7 +53,7 @@ def test_batch_op_async_execute_failure(self, context):
             parameters=None,
             overrides={},
             array_properties=None,
-            aws_conn_id="airflow_test",
+            aws_conn_id="aws_default",
             region_name="eu-west-1",
             tags={},
         )
@@ -78,7 +78,7 @@ def test_batch_op_async_execute_complete(self, caplog, event):
             parameters=None,
             overrides={},
             array_properties=None,
-            aws_conn_id="airflow_test",
+            aws_conn_id="aws_default",
             region_name="eu-west-1",
             tags={},
         )
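A hedged sketch of how the deferral path could be exercised in a test; the mocking approach and values are assumptions rather than this PR's test code, and it presumes apache-airflow-providers-amazon>=8.0.0. `submit_job` is stubbed out and the `TaskDeferred` raised by `defer()` is inspected for the trigger type:

```python
from unittest import mock

import pytest
from airflow.exceptions import TaskDeferred

from astronomer.providers.amazon.aws.operators.batch import BatchOperatorAsync
from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger


@mock.patch.object(BatchOperatorAsync, "submit_job")
def test_execute_defers_with_batch_trigger(mock_submit_job):
    # Illustrative operator arguments; names are not taken from the PR's fixtures.
    op = BatchOperatorAsync(
        task_id="submit_batch_job",
        job_name="hello-world",
        job_definition="hello-world-def",
        job_queue="hello-world-queue",
        container_overrides={},
        aws_conn_id="aws_default",
        region_name="eu-west-1",
    )
    op.job_id = "jobId-123"  # normally populated by the (mocked) submit_job call

    # execute() should submit the job and then defer with a BatchOperatorTrigger.
    with pytest.raises(TaskDeferred) as exc:
        op.execute(context={})
    assert isinstance(exc.value.trigger, BatchOperatorTrigger)
```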
4 changes: 2 additions & 2 deletions tests/amazon/aws/triggers/test_batch.py
@@ -29,7 +29,7 @@ class TestBatchOperatorTrigger:
         max_retries=MAX_RETRIES,
         status_retries=STATUS_RETRIES,
         parameters={},
-        overrides={},
+        container_overrides={},
         array_properties={},
         region_name="eu-west-1",
         aws_conn_id="airflow_test",
@@ -53,7 +53,7 @@ def test_batch_trigger_serialization(self):
             "max_retries": MAX_RETRIES,
             "status_retries": STATUS_RETRIES,
             "parameters": {},
-            "overrides": {},
+            "container_overrides": {},
             "array_properties": {},
             "region_name": "eu-west-1",
             "aws_conn_id": "airflow_test",