diff --git a/astronomer/providers/amazon/aws/operators/batch.py b/astronomer/providers/amazon/aws/operators/batch.py index e343c1b16..b29dd92fa 100644 --- a/astronomer/providers/amazon/aws/operators/batch.py +++ b/astronomer/providers/amazon/aws/operators/batch.py @@ -27,12 +27,14 @@ class BatchOperatorAsync(BatchOperator): :param job_name: the name for the job that will run on AWS Batch (templated) :param job_definition: the job definition name on AWS Batch :param job_queue: the queue name on AWS Batch - :param overrides: the `containerOverrides` parameter for boto3 (templated) + :param overrides: Removed in apache-airflow-providers-amazon release 8.0.0, use container_overrides instead with the + same value. + :param container_overrides: the `containerOverrides` parameter for boto3 (templated) :param array_properties: the `arrayProperties` parameter for boto3 :param parameters: the `parameters` for boto3 (templated) :param job_id: the job ID, usually unknown (None) until the submit_job operation gets the jobId defined by AWS Batch - :param waiters: an :class:`.BatchWaiters` object (see note below); + :param waiters: an :py:class:`.BatchWaiters` object (see note below); if None, polling is used with max_retries and status_retries. :param max_retries: exponential back-off retries, 4200 = 48 hours; polling is only used when waiters is None @@ -59,6 +61,11 @@ def execute(self, context: Context) -> None: Submit the job and get the job_id using which we defer and poll in trigger """ self.submit_job(context) + try: + container_overrides = self.container_overrides # type: ignore[attr-defined] + except AttributeError: # pragma: no cover + # For apache-airflow-providers-amazon<8.0.0 + container_overrides = self.overrides self.defer( timeout=self.execution_timeout, trigger=BatchOperatorTrigger( @@ -66,7 +73,7 @@ def execute(self, context: Context) -> None: job_name=self.job_name, job_definition=self.job_definition, job_queue=self.job_queue, - overrides=self.overrides, + container_overrides=container_overrides, array_properties=self.array_properties, parameters=self.parameters, waiters=self.waiters, diff --git a/astronomer/providers/amazon/aws/triggers/batch.py b/astronomer/providers/amazon/aws/triggers/batch.py index aaea43d48..de5eb7542 100644 --- a/astronomer/providers/amazon/aws/triggers/batch.py +++ b/astronomer/providers/amazon/aws/triggers/batch.py @@ -16,7 +16,7 @@ class BatchOperatorTrigger(BaseTrigger): :param job_name: the name for the job that will run on AWS Batch (templated) :param job_definition: the job definition name on AWS Batch :param job_queue: the queue name on AWS Batch - :param overrides: the `containerOverrides` parameter for boto3 (templated) + :param container_overrides: the `containerOverrides` parameter for boto3 (templated) :param array_properties: the `arrayProperties` parameter for boto3 :param parameters: the `parameters` for boto3 (templated) :param waiters: a :class:`.BatchWaiters` object (see note below); @@ -39,7 +39,7 @@ def __init__( job_name: str, job_definition: str, job_queue: str, - overrides: Dict[str, str], + container_overrides: Dict[str, str], array_properties: Dict[str, str], parameters: Dict[str, str], waiters: Any, @@ -54,7 +54,7 @@ def __init__( self.job_name = job_name self.job_definition = job_definition self.job_queue = job_queue - self.overrides = overrides or {} + self.container_overrides = container_overrides or {} self.array_properties = array_properties or {} self.parameters = parameters or {} self.waiters = waiters @@ -73,7 +73,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: "job_name": self.job_name, "job_definition": self.job_definition, "job_queue": self.job_queue, - "overrides": self.overrides, + "container_overrides": self.container_overrides, "array_properties": self.array_properties, "parameters": self.parameters, "waiters": self.waiters, diff --git a/tests/amazon/aws/operators/test_batch.py b/tests/amazon/aws/operators/test_batch.py index 9050fcd51..7ad3eb40b 100644 --- a/tests/amazon/aws/operators/test_batch.py +++ b/tests/amazon/aws/operators/test_batch.py @@ -31,7 +31,7 @@ def test_batch_op_async(self, get_client_type_mock): parameters=None, overrides={}, array_properties=None, - aws_conn_id="airflow_test", + aws_conn_id="aws_default", region_name="eu-west-1", tags={}, ) @@ -53,7 +53,7 @@ def test_batch_op_async_execute_failure(self, context): parameters=None, overrides={}, array_properties=None, - aws_conn_id="airflow_test", + aws_conn_id="aws_default", region_name="eu-west-1", tags={}, ) @@ -78,7 +78,7 @@ def test_batch_op_async_execute_complete(self, caplog, event): parameters=None, overrides={}, array_properties=None, - aws_conn_id="airflow_test", + aws_conn_id="aws_default", region_name="eu-west-1", tags={}, ) diff --git a/tests/amazon/aws/triggers/test_batch.py b/tests/amazon/aws/triggers/test_batch.py index 14db5d98d..e5602deb8 100644 --- a/tests/amazon/aws/triggers/test_batch.py +++ b/tests/amazon/aws/triggers/test_batch.py @@ -29,7 +29,7 @@ class TestBatchOperatorTrigger: max_retries=MAX_RETRIES, status_retries=STATUS_RETRIES, parameters={}, - overrides={}, + container_overrides={}, array_properties={}, region_name="eu-west-1", aws_conn_id="airflow_test", @@ -53,7 +53,7 @@ def test_batch_trigger_serialization(self): "max_retries": MAX_RETRIES, "status_retries": STATUS_RETRIES, "parameters": {}, - "overrides": {}, + "container_overrides": {}, "array_properties": {}, "region_name": "eu-west-1", "aws_conn_id": "airflow_test",