The tasks below trigger the computation of pi on the Spark instance
using the Java and Python executables provided in the example library.
"""
import logging
import os
import time
from datetime import datetime
from typing import Any, List

import boto3
import paramiko
from airflow import DAG, settings
from airflow.models import Connection, Variable
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.operators.emr import (
    EmrCreateJobFlowOperator,
    EmrTerminateJobFlowOperator,
)
from requests import get

from astronomer.providers.apache.livy.operators.livy import LivyOperatorAsync

LIVY_JAVA_FILE = os.environ.get("LIVY_JAVA_FILE", "/spark-examples.jar")
LIVY_PYTHON_FILE = os.environ.get("LIVY_PYTHON_FILE", "/user/hadoop/pi.py")
JOB_FLOW_ROLE = os.environ.get("EMR_JOB_FLOW_ROLE", "EMR_EC2_DefaultRole")
SERVICE_ROLE = os.environ.get("EMR_SERVICE_ROLE", "EMR_DefaultRole")
PEM_FILENAME = os.environ.get("PEM_FILENAME", "providers_team_keypair")
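# This example assumes an Airflow Variable named "providers_team_keypair" containing the
# PEM-formatted private key of the EC2 key pair referenced above, set for instance with
#   airflow variables set providers_team_keypair "$(cat providers_team_keypair.pem)"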
PRIVATE_KEY = Variable.get("providers_team_keypair")

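# Credentials passed to the boto3 EC2 and EMR clients used below; the defaults are
# placeholders and real values are expected to come from the environment.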
AWS_S3_CREDS = {
    "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY", "sample_aws_access_key_id"),
    "aws_secret_access_key": os.environ.get("AWS_SECRET_KEY", "sample_aws_secret_access_key"),
    "region_name": os.environ.get("AWS_REGION_NAME", "us-east-2"),
}

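# Shell commands executed over SSH on the EMR master node: download the upstream pi.py
# example and copy it into HDFS at /user/hadoop, where LIVY_PYTHON_FILE points by default.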
COMMAND_TO_CREATE_PI_FILE: List[str] = [
    "curl https://raw.githubusercontent.com/apache/spark/master/examples/src/main/python/pi.py >> pi.py",
    "hadoop fs -copyFromLocal pi.py /user/hadoop",
]

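# Configuration for EmrCreateJobFlowOperator: a minimal single-node EMR 5.35.0 cluster with
# Spark, Livy, Hive and Hadoop, created without steps and kept alive so that jobs can be
# submitted to it through Livy.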
JOB_FLOW_OVERRIDES = {
    "Name": "team-provider-example-dag-livy-test",
    "ReleaseLabel": "emr-5.35.0",
    "Applications": [
        {"Name": "Spark"},
        {"Name": "Livy"},
        {"Name": "Hive"},
        {"Name": "Hadoop"},
    ],
    "Instances": {
        "InstanceGroups": [
            {
                "Name": "Primary node",
                "Market": "ON_DEMAND",
                "InstanceRole": "MASTER",
                "InstanceType": "m4.large",
                "InstanceCount": 1,
            },
        ],
        "Ec2KeyName": PEM_FILENAME,
        "KeepJobFlowAliveWhenNoSteps": True,
        "TerminationProtected": False,
    },
    "Steps": [],
    "JobFlowRole": JOB_FLOW_ROLE,
    "ServiceRole": SERVICE_ROLE,
}


def create_airflow_connection(task_instance: Any) -> None:
    """
    Check whether an Airflow connection with conn_id ``livy_default`` already exists;
    if it does, delete it, then create a fresh connection pointing at the EMR master node.
    """
    conn = Connection(
        conn_id="livy_default",
        conn_type="livy",
        host=task_instance.xcom_pull(
            key="cluster_response_master_public_dns", task_ids=["describe_created_cluster"]
        )[0],
        login="",
        password="",
        port=8998,
    )  # create a connection object

    session = settings.Session()
    connection = session.query(Connection).filter_by(conn_id=conn.conn_id).one_or_none()
    if connection is None:
        logging.info("Connection %s doesn't exist.", str(conn.conn_id))
    else:
        session.delete(connection)
        session.commit()
        logging.info("Connection %s deleted.", str(conn.conn_id))

    session.add(conn)
    session.commit()  # insert the new connection into the Airflow metadata database
    logging.info("Connection livy_default is created")


def add_inbound_rule_for_security_group(task_instance: Any) -> None:
    """
    Add inbound rules to the EMR master security group for the current machine's
    public IP address (port 8998 for Livy and port 22 for SSH).
    """
    current_docker_ip = get("https://api.ipify.org").text
    logging.info("Current ip address is: %s", str(current_docker_ip))
    client = boto3.client("ec2", **AWS_S3_CREDS)

    response = client.describe_security_groups(
        GroupIds=[
            task_instance.xcom_pull(
                key="cluster_response_master_security_group", task_ids=["describe_created_cluster"]
            )[0]
        ]
    )
    current_ip_permissions = response["SecurityGroups"][0]["IpPermissions"]
    ip_exists = False
    for current_ip in current_ip_permissions:
        ip_ranges = current_ip["IpRanges"]
        for ip in ip_ranges:
            if ip["CidrIp"] == str(current_docker_ip) + "/32":
                ip_exists = True

    if not ip_exists:
        # open port 8998 so the Livy REST endpoint is reachable
        client.authorize_security_group_ingress(
            GroupId=task_instance.xcom_pull(
                key="cluster_response_master_security_group", task_ids=["describe_created_cluster"]
            )[0],
            IpPermissions=[
                {
                    "IpProtocol": "tcp",
                    "FromPort": 8998,
                    "ToPort": 8998,
                    "IpRanges": [{"CidrIp": str(current_docker_ip) + "/32"}],
                }
            ],
        )

        # open port 22 for SSH, used to copy the example file into HDFS
        client.authorize_security_group_ingress(
            GroupId=task_instance.xcom_pull(
                key="cluster_response_master_security_group", task_ids=["describe_created_cluster"]
            )[0],
            IpPermissions=[
                {
                    "IpProtocol": "tcp",
                    "FromPort": 22,
                    "ToPort": 22,
                    "IpRanges": [{"CidrIp": str(current_docker_ip) + "/32"}],
                }
            ],
        )


def create_key_pair() -> None:
    """
    Load the private key from the Airflow Variable and create a pem file at /tmp/.
    """
    # remove the file if it exists
    if os.path.exists(f"/tmp/{PEM_FILENAME}.pem"):
        os.remove(f"/tmp/{PEM_FILENAME}.pem")

    # write the private key from the Airflow Variable into the pem file
    with open(f"/tmp/{PEM_FILENAME}.pem", "w+") as fh:
        fh.write(PRIVATE_KEY)

    # restrict the pem file to read-only (0400) permissions
    os.chmod(f"/tmp/{PEM_FILENAME}.pem", 0o400)


def ssh_and_run_command(task_instance: Any, **kwargs: Any) -> None:
    """
    SSH into the EMR master node and execute each command from the given list.
    """
    key = paramiko.RSAKey.from_private_key_file(kwargs["path_to_pem_file"])
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    # Connect/ssh to the instance
    cluster_response_master_public_dns = task_instance.xcom_pull(
        key="cluster_response_master_public_dns", task_ids=["describe_created_cluster"]
    )[0]
    try:
        client.connect(hostname=cluster_response_master_public_dns, username=kwargs["username"], pkey=key)

        # Execute each command after connecting to the instance
        for command in kwargs["command"]:
            stdin, stdout, stderr = client.exec_command(command)
            stdout.read()

        # close the client connection once the job is done
        client.close()
    except Exception as exc:
        raise Exception(f"Got an exception: {exc}") from exc


def get_cluster_details(task_instance: Any) -> None:
    """
    Fetches the cluster details and stores EmrManagedMasterSecurityGroup and
    MasterPublicDnsName in XCom.
    """
    client = boto3.client("emr", **AWS_S3_CREDS)
    response = client.describe_cluster(
        ClusterId=str(task_instance.xcom_pull(key="return_value", task_ids=["cluster_creator"])[0])
    )
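    # MasterPublicDnsName only becomes available once the cluster has been provisioned,
    # so keep polling describe_cluster until the DNS name appears or the cluster is WAITING.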
    while (
        "MasterPublicDnsName" not in response["Cluster"]
        and response["Cluster"]["Status"]["State"] != "WAITING"
    ):
        logging.info("Waiting 11 minutes for the MasterPublicDnsName to become available")
        time.sleep(660)
        response = client.describe_cluster(
            ClusterId=str(task_instance.xcom_pull(key="return_value", task_ids=["cluster_creator"])[0])
        )
    logging.info("Current response from AWS EMR: %s", str(response))
    task_instance.xcom_push(
        key="cluster_response_master_public_dns", value=response["Cluster"]["MasterPublicDnsName"]
    )
    task_instance.xcom_push(
        key="cluster_response_master_security_group",
        value=response["Cluster"]["Ec2InstanceAttributes"]["EmrManagedMasterSecurityGroup"],
    )


with DAG(
    dag_id="example_livy_operator",
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example", "async", "deferrable", "LivyOperatorAsync"],
) as dag:
    # [START howto_create_key_pair_file]
    create_key_pair_file = PythonOperator(
        task_id="create_key_pair_file",
        python_callable=create_key_pair,
    )
    # [END howto_create_key_pair_file]

    # [START howto_operator_emr_create_job_flow]
    cluster_creator = EmrCreateJobFlowOperator(
        task_id="cluster_creator",
        job_flow_overrides=JOB_FLOW_OVERRIDES,
    )
    # [END howto_operator_emr_create_job_flow]

    # [START describe_created_cluster]
    describe_created_cluster = PythonOperator(
        task_id="describe_created_cluster", python_callable=get_cluster_details
    )
    # [END describe_created_cluster]

    # [START add_example_pi_file_in_hdfs]
    ssh_and_copy_pifile_to_hdfs = PythonOperator(
        task_id="ssh_and_copy_pifile_to_hdfs",
        python_callable=ssh_and_run_command,
        op_kwargs={
            "path_to_pem_file": f"/tmp/{PEM_FILENAME}.pem",
            "username": "hadoop",
            "command": COMMAND_TO_CREATE_PI_FILE,
        },
    )
    # [END add_example_pi_file_in_hdfs]

    # [START add_ip_address_for_inbound_rules]
    get_and_add_ip_address_for_inbound_rules = PythonOperator(
        task_id="get_and_add_ip_address_for_inbound_rules",
        python_callable=add_inbound_rule_for_security_group,
    )
    # [END add_ip_address_for_inbound_rules]

    # [START create_airflow_connection_for_livy]
    create_airflow_connection_for_livy = PythonOperator(
        task_id="create_airflow_connection_for_livy", python_callable=create_airflow_connection
    )
    # [END create_airflow_connection_for_livy]

    # [START run_pi_example_without_polling_interval]
    livy_java_task = LivyOperatorAsync(
        task_id="livy_java_task",
        file=LIVY_JAVA_FILE,
        num_executors=1,
        conf={
            "spark.shuffle.compress": "false",
        },
        class_name="org.apache.spark.examples.SparkPi",
    )
    # [END run_pi_example_without_polling_interval]

    # [START run_py_example_with_polling_interval]
    livy_python_task = LivyOperatorAsync(
        task_id="livy_python_task", file=LIVY_PYTHON_FILE, polling_interval=30
    )
    # [END run_py_example_with_polling_interval]

    # [START howto_operator_emr_terminate_job_flow]
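    # trigger_rule="all_done" runs this task regardless of upstream success or failure,
    # so the EMR cluster is cleaned up even when one of the previous tasks fails.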
    remove_cluster = EmrTerminateJobFlowOperator(
        task_id="remove_cluster",
        job_flow_id=cluster_creator.output,
        trigger_rule="all_done",
    )
    # [END howto_operator_emr_terminate_job_flow]

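    # End-to-end flow: write the pem file, create the EMR cluster, wait for and describe it,
    # open the Livy/SSH ports for the current IP, copy pi.py into HDFS, register the
    # livy_default connection, run the Java and Python pi jobs through Livy, then terminate
    # the cluster.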
    (
        create_key_pair_file
        >> cluster_creator
        >> describe_created_cluster
        >> get_and_add_ip_address_for_inbound_rules
        >> ssh_and_copy_pifile_to_hdfs
        >> create_airflow_connection_for_livy
        >> livy_java_task
        >> livy_python_task
        >> remove_cluster
    )