Skip to content

Add parameter to pass role ARN to GlueJobOperator #33408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions airflow/providers/amazon/aws/hooks/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class GlueJobHook(AwsBaseHook):
:param retry_limit: Maximum number of times to retry this job if it fails
:param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job
:param region_name: aws region name (example: us-east-1)
:param iam_role_name: AWS IAM Role for Glue Job Execution
:param iam_role_name: AWS IAM Role for Glue Job Execution. If set iam_role_arn must equal None.
:param iam_role_arn: AWS IAM Role ARN for Glue Job Execution, If set iam_role_name must equal None.
:param create_job_kwargs: Extra arguments for Glue Job Creation
:param update_config: Update job configuration on Glue (default: False)

Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
retry_limit: int = 0,
num_of_dpus: int | float | None = None,
iam_role_name: str | None = None,
iam_role_arn: str | None = None,
create_job_kwargs: dict | None = None,
update_config: bool = False,
job_poll_interval: int | float = 6,
Expand All @@ -85,6 +87,7 @@ def __init__(
self.retry_limit = retry_limit
self.s3_bucket = s3_bucket
self.role_name = iam_role_name
self.role_arn = iam_role_arn
self.s3_glue_logs = "logs/glue-logs/"
self.create_job_kwargs = create_job_kwargs or {}
self.update_config = update_config
Expand All @@ -93,6 +96,8 @@ def __init__(
worker_type_exists = "WorkerType" in self.create_job_kwargs
num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs

if self.role_arn and self.role_name:
raise ValueError("Cannot set iam_role_arn and iam_role_name simultaneously")
if worker_type_exists and num_workers_exists:
if num_of_dpus is not None:
raise ValueError("Cannot specify num_of_dpus with custom WorkerType")
Expand All @@ -114,12 +119,14 @@ def create_glue_job_config(self) -> dict:
"ScriptLocation": self.script_location,
}
command = self.create_job_kwargs.pop("Command", default_command)
execution_role = self.get_iam_execution_role()
if not self.role_arn:
execution_role = self.get_iam_execution_role()
self.role_arn = execution_role["Role"]["Arn"]

config = {
"Name": self.job_name,
"Description": self.desc,
"Role": execution_role["Role"]["Arn"],
"Role": self.role_arn,
"ExecutionProperty": {"MaxConcurrentRuns": self.concurrent_run_limit},
"Command": command,
"MaxRetries": self.retry_limit,
Expand All @@ -144,7 +151,6 @@ def list_jobs(self) -> list:
return self.conn.get_jobs()

def get_iam_execution_role(self) -> dict:
"""Get IAM Role for job execution."""
try:
iam_client = self.get_session(region_name=self.region_name).client(
"iam", endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/amazon/aws/operators/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class GlueJobOperator(BaseOperator):
:param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job.
:param region_name: aws region name (example: us-east-1)
:param s3_bucket: S3 bucket where logs and local etl script will be uploaded
:param iam_role_name: AWS IAM Role for Glue Job Execution
:param iam_role_name: AWS IAM Role for Glue Job Execution. If set iam_role_arn must equal None.
:param iam_role_arn: AWS IAM ARN for Glue Job Execution. If set iam_role_name must equal None.
:param create_job_kwargs: Extra arguments for Glue Job Creation
:param run_job_kwargs: Extra arguments for Glue Job Run
:param wait_for_completion: Whether to wait for job run completion. (default: True)
Expand All @@ -72,6 +73,7 @@ class GlueJobOperator(BaseOperator):
"create_job_kwargs",
"s3_bucket",
"iam_role_name",
"iam_role_arn",
)
template_ext: Sequence[str] = ()
template_fields_renderers = {
Expand All @@ -96,6 +98,7 @@ def __init__(
region_name: str | None = None,
s3_bucket: str | None = None,
iam_role_name: str | None = None,
iam_role_arn: str | None = None,
create_job_kwargs: dict | None = None,
run_job_kwargs: dict | None = None,
wait_for_completion: bool = True,
Expand All @@ -118,6 +121,7 @@ def __init__(
self.region_name = region_name
self.s3_bucket = s3_bucket
self.iam_role_name = iam_role_name
self.iam_role_arn = iam_role_arn
self.s3_protocol = "s3://"
self.s3_artifacts_prefix = "artifacts/glue-scripts/"
self.create_job_kwargs = create_job_kwargs
Expand Down Expand Up @@ -154,6 +158,7 @@ def glue_job_hook(self) -> GlueJobHook:
region_name=self.region_name,
s3_bucket=self.s3_bucket,
iam_role_name=self.iam_role_name,
iam_role_arn=self.iam_role_arn,
create_job_kwargs=self.create_job_kwargs,
update_config=self.update_config,
job_poll_interval=self.job_poll_interval,
Expand Down
50 changes: 50 additions & 0 deletions tests/providers/amazon/aws/hooks/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,56 @@ class JobNotFoundException(Exception):
assert result is False
mock_conn.get_job.assert_called_once_with(JobName=job_name)

@mock.patch.object(GlueJobHook, "get_iam_execution_role")
@mock.patch.object(AwsBaseHook, "conn")
def test_role_arn_has_job_exists(self, mock_conn, mock_get_iam_execution_role):
"""
Calls 'create_or_update_glue_job' with no existing job.
Should create a new job.
"""

class JobNotFoundException(Exception):
pass

expected_job_name = "aws_test_glue_job"
job_description = "This is test case job from Airflow"
role_name = "my_test_role"
role_name_arn = "test_role"
some_s3_bucket = "bucket"

mock_conn.exceptions.EntityNotFoundException = JobNotFoundException
mock_conn.get_job.side_effect = JobNotFoundException()
mock_get_iam_execution_role.return_value = {"Role": {"RoleName": role_name, "Arn": role_name_arn}}

hook = GlueJobHook(
s3_bucket=some_s3_bucket,
job_name=expected_job_name,
desc=job_description,
concurrent_run_limit=2,
retry_limit=3,
num_of_dpus=5,
iam_role_arn=role_name_arn,
create_job_kwargs={"Command": {}},
region_name=self.some_aws_region,
update_config=True,
)

result = hook.create_or_update_glue_job()

mock_conn.get_job.assert_called_once_with(JobName=expected_job_name)
mock_conn.create_job.assert_called_once_with(
Command={},
Description=job_description,
ExecutionProperty={"MaxConcurrentRuns": 2},
LogUri=f"s3://{some_s3_bucket}/logs/glue-logs/{expected_job_name}",
MaxCapacity=5,
MaxRetries=3,
Name=expected_job_name,
Role=role_name_arn,
)
mock_conn.update_job.assert_not_called()
assert result == expected_job_name

@mock.patch.object(GlueJobHook, "get_iam_execution_role")
@mock.patch.object(GlueJobHook, "conn")
def test_create_or_update_glue_job_create_new_job(self, mock_conn, mock_get_iam_execution_role):
Expand Down
23 changes: 23 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_render_template(self, create_task_instance_of_operator):
script_args="{{ dag.dag_id }}",
create_job_kwargs="{{ dag.dag_id }}",
iam_role_name="{{ dag.dag_id }}",
iam_role_arn="{{ dag.dag_id }}",
s3_bucket="{{ dag.dag_id }}",
job_name="{{ dag.dag_id }}",
)
Expand All @@ -57,6 +58,7 @@ def test_render_template(self, create_task_instance_of_operator):
assert DAG_ID == rendered_template.script_args
assert DAG_ID == rendered_template.create_job_kwargs
assert DAG_ID == rendered_template.iam_role_name
assert DAG_ID == rendered_template.iam_role_arn
assert DAG_ID == rendered_template.s3_bucket
assert DAG_ID == rendered_template.job_name

Expand Down Expand Up @@ -99,6 +101,27 @@ def test_execute_without_failure(
mock_print_job_logs.assert_not_called()
assert glue.job_name == JOB_NAME

@mock.patch.object(GlueJobHook, "initialize_job")
@mock.patch.object(GlueJobHook, "get_conn")
def test_role_arn_execute_deferrable(self, _, mock_initialize_job):
glue = GlueJobOperator(
task_id=TASK_ID,
job_name=JOB_NAME,
script_location="s3://folder/file",
aws_conn_id="aws_default",
region_name="us-west-2",
s3_bucket="some_bucket",
iam_role_arn="test_role",
deferrable=True,
)
mock_initialize_job.return_value = {"JobRunState": "RUNNING", "JobRunId": JOB_RUN_ID}

with pytest.raises(TaskDeferred) as defer:
glue.execute(mock.MagicMock())

assert defer.value.trigger.job_name == JOB_NAME
assert defer.value.trigger.run_id == JOB_RUN_ID

@mock.patch.object(GlueJobHook, "initialize_job")
@mock.patch.object(GlueJobHook, "get_conn")
def test_execute_deferrable(self, _, mock_initialize_job):
Expand Down