Commit 201c254

Implement self sufficient DAG for LivyOperatorAsync (#235)

1 parent 07c987a
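
In short: the example_livy DAG no longer assumes a pre-existing Livy endpoint. It now provisions its own EMR cluster (Spark, Livy, Hive, Hadoop), opens inbound access to ports 8998 (Livy) and 22 (SSH) for the runner's current IP, copies the pi.py example into HDFS, creates the livy_default Airflow connection against the cluster master, runs the Java and Python LivyOperatorAsync tasks, and terminates the cluster when done. The master integration-test DAG gains a trigger for example_livy_operator, and setup.cfg picks up the new paramiko dependency.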

File tree (3 files changed, +300 −8 lines):

.circleci/integration-tests/master_dag.py
astronomer/providers/apache/livy/example_dags/example_livy.py
setup.cfg

.circleci/integration-tests/master_dag.py

Lines changed: 7 additions & 0 deletions
@@ -138,6 +138,11 @@ def prepare_dag_dependency(task_info, execution_time):
     dag_run_ids.extend(ids)
     chain(*snowflake_trigger_tasks)
 
+    livy_task_info = [{"livy_dag": "example_livy_operator"}]
+    livy_trigger_tasks, ids = prepare_dag_dependency(livy_task_info, "{{ ds }}")
+    dag_run_ids.extend(ids)
+    chain(*livy_trigger_tasks)
+
     report = PythonOperator(
         task_id="get_report",
         python_callable=get_report,
@@ -158,6 +163,7 @@ def prepare_dag_dependency(task_info, execution_time):
         databricks_trigger_tasks[0],
         http_trigger_tasks[0],
         snowflake_trigger_tasks[0],
+        livy_trigger_tasks[0],
     ]
 
     last_task = [
@@ -168,6 +174,7 @@ def prepare_dag_dependency(task_info, execution_time):
         databricks_trigger_tasks[-1],
         http_trigger_tasks[-1],
         snowflake_trigger_tasks[-1],
+        livy_trigger_tasks[-1],
    ]
 
     last_task >> end
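
The helper prepare_dag_dependency is only called in this hunk; its body is not part of the diff. For orientation, here is a minimal, purely illustrative sketch of the kind of wiring such a helper typically produces with Airflow's TriggerDagRunOperator. The name build_trigger_tasks and the run-id convention are assumptions, not the repository's actual implementation:

# Illustrative sketch only; not the prepare_dag_dependency defined in master_dag.py.
from airflow.operators.trigger_dagrun import TriggerDagRunOperator


def build_trigger_tasks(task_info, execution_date):
    """For each {task_id: dag_id} mapping, create a TriggerDagRunOperator and record its run id."""
    trigger_tasks, run_ids = [], []
    for entry in task_info:
        for task_id, dag_id in entry.items():
            run_id = f"{dag_id}_{execution_date}"  # hypothetical run-id convention
            trigger_tasks.append(
                TriggerDagRunOperator(
                    task_id=task_id,
                    trigger_dag_id=dag_id,
                    trigger_run_id=run_id,
                    execution_date=execution_date,
                    wait_for_completion=True,
                    poke_interval=30,
                )
            )
            run_ids.append(run_id)
    return trigger_tasks, run_ids

With that shape, the new lines above read as: build one trigger task for example_livy_operator, remember its run id alongside the other providers' runs, and chain the (single-element) list of trigger tasks.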

astronomer/providers/apache/livy/example_dags/example_livy.py

Lines changed: 291 additions & 8 deletions
@@ -3,38 +3,321 @@
 The tasks below trigger the computation of pi on the Spark instance
 using the Java and Python executables provided in the example library.
 """
+import logging
 import os
+import time
 from datetime import datetime
+from typing import Any, List
 
-from airflow import DAG
+import boto3
+import paramiko
+from airflow import DAG, settings
+from airflow.models import Connection, Variable
+from airflow.operators.python import PythonOperator
+from airflow.providers.amazon.aws.operators.emr import (
+    EmrCreateJobFlowOperator,
+    EmrTerminateJobFlowOperator,
+)
+from requests import get
 
 from astronomer.providers.apache.livy.operators.livy import LivyOperatorAsync
 
 LIVY_JAVA_FILE = os.environ.get("LIVY_JAVA_FILE", "/spark-examples.jar")
 LIVY_PYTHON_FILE = os.environ.get("LIVY_PYTHON_FILE", "/user/hadoop/pi.py")
+JOB_FLOW_ROLE = os.environ.get("EMR_JOB_FLOW_ROLE", "EMR_EC2_DefaultRole")
+SERVICE_ROLE = os.environ.get("EMR_SERVICE_ROLE", "EMR_DefaultRole")
+PEM_FILENAME = os.environ.get("PEM_FILENAME", "providers_team_keypair")
+PRIVATE_KEY = Variable.get("providers_team_keypair")
+
+AWS_S3_CREDS = {
+    "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY", "sample_aws_access_key_id"),
+    "aws_secret_access_key": os.environ.get("AWS_SECRET_KEY", "sample_aws_secret_access_key"),
+    "region_name": os.environ.get("AWS_REGION_NAME", "us-east-2"),
+}
+
+COMMAND_TO_CREATE_PI_FILE: List[str] = [
+    "curl https://raw.githubusercontent.com/apache/spark/master/examples/src/main/python/pi.py >> pi.py",
+    "hadoop fs -copyFromLocal pi.py /user/hadoop",
+]
+
+JOB_FLOW_OVERRIDES = {
+    "Name": "team-provider-example-dag-livy-test",
+    "ReleaseLabel": "emr-5.35.0",
+    "Applications": [
+        {"Name": "Spark"},
+        {
+            "Name": "Livy",
+        },
+        {
+            "Name": "Hive",
+        },
+        {
+            "Name": "Hadoop",
+        },
+    ],
+    "Instances": {
+        "InstanceGroups": [
+            {
+                "Name": "Primary node",
+                "Market": "ON_DEMAND",
+                "InstanceRole": "MASTER",
+                "InstanceType": "m4.large",
+                "InstanceCount": 1,
+            },
+        ],
+        "Ec2KeyName": PEM_FILENAME,
+        "KeepJobFlowAliveWhenNoSteps": True,
+        "TerminationProtected": False,
+    },
+    "Steps": [],
+    "JobFlowRole": JOB_FLOW_ROLE,
+    "ServiceRole": SERVICE_ROLE,
+}
+
+
+def create_airflow_connection(task_instance: Any) -> None:
+    """
+    Checks if airflow connection exists, if yes then deletes it.
+    Then, create a new livy_default connection.
+    """
+    conn = Connection(
+        conn_id="livy_default",
+        conn_type="livy",
+        host=task_instance.xcom_pull(
+            key="cluster_response_master_public_dns", task_ids=["describe_created_cluster"]
+        )[0],
+        login="",
+        password="",
+        port=8998,
+    )  # create a connection object
+
+    session = settings.Session()
+    connection = session.query(Connection).filter_by(conn_id=conn.conn_id).one_or_none()
+    if connection is None:
+        logging.info("Connection %s doesn't exist.", str(conn.conn_id))
+    else:
+        session.delete(connection)
+        session.commit()
+        logging.info("Connection %s deleted.", str(conn.conn_id))
+
+    session.add(conn)
+    session.commit()  # it will insert the connection object programmatically.
+    logging.info("Connection livy_default is created")
+
+
+def add_inbound_rule_for_security_group(task_instance: Any) -> None:
+    """
+    Sets the inbound rule for the aws security group, based on
+    current ip address of the system.
+    """
+    current_docker_ip = get("https://api.ipify.org").text
+    logging.info("Current ip address is: %s", str(current_docker_ip))
+    client = boto3.client("ec2", **AWS_S3_CREDS)
+
+    response = client.describe_security_groups(
+        GroupIds=[
+            task_instance.xcom_pull(
+                key="cluster_response_master_security_group", task_ids=["describe_created_cluster"]
+            )[0]
+        ]
+    )
+    current_ip_permissions = response["SecurityGroups"][0]["IpPermissions"]
+    ip_exists = False
+    for current_ip in current_ip_permissions:
+        ip_ranges = current_ip["IpRanges"]
+        for ip in ip_ranges:
+            if ip["CidrIp"] == str(current_docker_ip) + "/32":
+                ip_exists = True
+
+    if not ip_exists:
+        # open port for port 8998
+        client.authorize_security_group_ingress(
+            GroupId=task_instance.xcom_pull(
+                key="cluster_response_master_security_group", task_ids=["describe_created_cluster"]
+            )[0],
+            IpPermissions=[
+                {
+                    "IpProtocol": "tcp",
+                    "FromPort": 8998,
+                    "ToPort": 8998,
+                    "IpRanges": [{"CidrIp": str(current_docker_ip) + "/32"}],
+                }
+            ],
+        )
+
+        # open port for port 22 for ssh and copy file for hdfs
+        client.authorize_security_group_ingress(
+            GroupId=task_instance.xcom_pull(
+                key="cluster_response_master_security_group", task_ids=["describe_created_cluster"]
+            )[0],
+            IpPermissions=[
+                {
+                    "IpProtocol": "tcp",
+                    "FromPort": 22,
+                    "ToPort": 22,
+                    "IpRanges": [{"CidrIp": str(current_docker_ip) + "/32"}],
+                }
+            ],
+        )
+
+
+def create_key_pair() -> None:
+    """
+    Load the private_key from airflow variable and creates a pem_file
+    at /tmp/.
+    """
+    # remove the file if it exists
+    if os.path.exists(f"/tmp/{PEM_FILENAME}.pem"):
+        os.remove(f"/tmp/{PEM_FILENAME}.pem")
+
+    # read the content for pem file from Variable set on Airflow UI.
+    with open(f"/tmp/{PEM_FILENAME}.pem", "w+") as fh:
+        fh.write(PRIVATE_KEY)
+
+    # write private key to file with 400 permissions
+    os.chmod(f"/tmp/{PEM_FILENAME}.pem", 0o400)
+
+
+def ssh_and_run_command(task_instance: Any, **kwargs: Any) -> None:
+    """
+    SSH into the machine and execute the bash script from the list
+    of commands.
+    """
+    key = paramiko.RSAKey.from_private_key_file(kwargs["path_to_pem_file"])
+    client = paramiko.SSHClient()
+    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+    # Connect/ssh to an instance
+    cluster_response_master_public_dns = task_instance.xcom_pull(
+        key="cluster_response_master_public_dns", task_ids=["describe_created_cluster"]
+    )[0]
+    try:
+        client.connect(hostname=cluster_response_master_public_dns, username=kwargs["username"], pkey=key)
+
+        # Execute a command(cmd) after connecting/ssh to an instance
+        for command in kwargs["command"]:
+            stdin, stdout, stderr = client.exec_command(command)
+            stdout.read()
+
+        # close the client connection once the job is done
+        client.close()
+    except Exception as exc:
+        raise Exception("Got an exception as %s.", str(exc))
+
+
+def get_cluster_details(task_instance: Any) -> None:
+    """
+    Fetches the cluster details and stores EmrManagedMasterSecurityGroup and
+    MasterPublicDnsName in the XCOM.
+    """
+    client = boto3.client("emr", **AWS_S3_CREDS)
+    response = client.describe_cluster(
+        ClusterId=str(task_instance.xcom_pull(key="return_value", task_ids=["cluster_creator"])[0])
+    )
+    while (
+        "MasterPublicDnsName" not in response["Cluster"]
+        and response["Cluster"]["Status"]["State"] != "WAITING"
+    ):
+        logging.info("wait for 11 minutes to get the MasterPublicDnsName")
+        time.sleep(660)
+        response = client.describe_cluster(
+            ClusterId=str(task_instance.xcom_pull(key="return_value", task_ids=["cluster_creator"])[0])
+        )
+    logging.info("current response from ams emr: %s", str(response))
+    task_instance.xcom_push(
+        key="cluster_response_master_public_dns", value=response["Cluster"]["MasterPublicDnsName"]
+    )
+    task_instance.xcom_push(
+        key="cluster_response_master_security_group",
+        value=response["Cluster"]["Ec2InstanceAttributes"]["EmrManagedMasterSecurityGroup"],
+    )
+
 
 with DAG(
     dag_id="example_livy_operator",
-    default_args={"args": [10]},
     schedule_interval=None,
     start_date=datetime(2021, 1, 1),
     catchup=False,
     tags=["example", "async", "deferrable", "LivyOperatorAsync"],
 ) as dag:
+    # [START howto_create_key_pair_file]
+    create_key_pair_file = PythonOperator(
+        task_id="create_key_pair_file",
+        python_callable=create_key_pair,
+    )
+    # [END howto_create_key_pair_file]
+
+    # [START howto_operator_emr_create_job_flow]
+    cluster_creator = EmrCreateJobFlowOperator(
+        task_id="cluster_creator",
+        job_flow_overrides=JOB_FLOW_OVERRIDES,
+    )
+    # [END howto_operator_emr_create_job_flow]
+
+    # [START describe_created_cluster]
+    describe_created_cluster = PythonOperator(
+        task_id="describe_created_cluster", python_callable=get_cluster_details
+    )
+    # [END describe_created_cluster]
 
-    # [START create_livy]
+    # [START add_example_pi_file_in_hdfs]
+    ssh_and_copy_pifile_to_hdfs = PythonOperator(
+        task_id="ssh_and_copy_pifile_to_hdfs",
+        python_callable=ssh_and_run_command,
+        op_kwargs={
+            "path_to_pem_file": f"/tmp/{PEM_FILENAME}.pem",
+            "username": "hadoop",
+            "command": COMMAND_TO_CREATE_PI_FILE,
+        },
+    )
+    # [END add_example_pi_file_in_hdfs]
+
+    # [START add_ip_address_for_inbound_rules]
+    get_and_add_ip_address_for_inbound_rules = PythonOperator(
+        task_id="get_and_add_ip_address_for_inbound_rules",
+        python_callable=add_inbound_rule_for_security_group,
+    )
+    # [END add_ip_address_for_inbound_rules]
+
+    # [START create_airflow_connection_for_livy]
+    create_airflow_connection_for_livy = PythonOperator(
+        task_id="create_airflow_connection_for_livy", python_callable=create_airflow_connection
+    )
+    # [END create_airflow_connection_for_livy]
+
+    # [START run_pi_example_without_polling_interval]
     livy_java_task = LivyOperatorAsync(
-        task_id="pi_java_task",
+        task_id="livy_java_task",
         file=LIVY_JAVA_FILE,
         num_executors=1,
         conf={
             "spark.shuffle.compress": "false",
         },
         class_name="org.apache.spark.examples.SparkPi",
-        polling_interval=0,
     )
+    # [END run_pi_spark_without_polling_interval]
 
-    livy_python_task = LivyOperatorAsync(task_id="pi_python_task", file=LIVY_PYTHON_FILE, polling_interval=30)
+    # [START run_py_example_with_polling_interval]
+    livy_python_task = LivyOperatorAsync(
+        task_id="livy_python_task", file=LIVY_PYTHON_FILE, polling_interval=30
+    )
+    # [END run_py_example_with_polling_interval]
 
-    livy_java_task >> livy_python_task
-    # [END create_livy]
+    # [START howto_operator_emr_terminate_job_flow]
+    remove_cluster = EmrTerminateJobFlowOperator(
+        task_id="remove_cluster",
+        job_flow_id=cluster_creator.output,
+        trigger_rule="all_done",
+    )
+    # [END howto_operator_emr_terminate_job_flow]
+
+    (
+        create_key_pair_file
+        >> cluster_creator
+        >> describe_created_cluster
+        >> get_and_add_ip_address_for_inbound_rules
+        >> ssh_and_copy_pifile_to_hdfs
+        >> create_airflow_connection_for_livy
+        >> livy_java_task
+        >> livy_python_task
+        >> remove_cluster
+    )
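
The create_airflow_connection task above points the livy_default connection at the EMR master's Livy server on port 8998, the same port opened by add_inbound_rule_for_security_group. As a quick sanity check outside Airflow, one can hit Livy's REST API directly; a minimal sketch, assuming the default /batches endpoint and a placeholder hostname:

# Hypothetical check, not part of this commit: list Livy batch sessions on the EMR master.
import requests

LIVY_HOST = "<MasterPublicDnsName>"  # placeholder; the DAG stores the real value in XCom

response = requests.get(f"http://{LIVY_HOST}:8998/batches", timeout=10)
response.raise_for_status()
for batch in response.json().get("sessions", []):
    print(batch["id"], batch["state"])

Each LivyOperatorAsync submission should appear here as a batch whose state moves from starting through running to success.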

setup.cfg

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,7 @@ snowflake =
     apache-airflow-providers-snowflake
 apache.livy =
     apache-airflow-providers-apache-livy
+    paramiko
 docs =
     sphinx
     sphinx-autoapi
@@ -117,6 +118,7 @@ all =
     gcloud-aio-storage
     google-api-core>=1.25.1,<2.0.0
     kubernetes_asyncio
+    paramiko
 
 [options.packages.find]
 include =
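
paramiko backs the new ssh_and_copy_pifile_to_hdfs task in the example DAG; declaring it under both the apache.livy extra and the all extra means either install path pulls it in (for example pip install 'astronomer-providers[apache.livy]', assuming that distribution name).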
