Skip to content

Commit 12bd1c4

Browse files
committed
openlineage: don't run task instance listener in executor
Signed-off-by: Maciej Obuchowski <[email protected]>
1 parent 5ee1bcb commit 12bd1c4

File tree

4 files changed

+164
-10
lines changed

4 files changed

+164
-10
lines changed

airflow/providers/openlineage/plugins/listener.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from __future__ import annotations
1818

1919
import logging
20-
from concurrent.futures import Executor, ThreadPoolExecutor
20+
from concurrent.futures import ThreadPoolExecutor
2121
from datetime import datetime
2222
from typing import TYPE_CHECKING
2323

@@ -42,8 +42,8 @@ class OpenLineageListener:
4242
"""OpenLineage listener sends events when task instances and DAG runs start, complete, or fail."""
4343

4444
def __init__(self):
45+
self._executor = None
4546
self.log = logging.getLogger(__name__)
46-
self.executor: Executor = None # type: ignore
4747
self.extractor_manager = ExtractorManager()
4848
self.adapter = OpenLineageAdapter()
4949

@@ -102,7 +102,7 @@ def on_running():
102102
},
103103
)
104104

105-
self.executor.submit(on_running)
105+
on_running()
106106

107107
@hookimpl
108108
def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session):
@@ -130,7 +130,7 @@ def on_success():
130130
task=task_metadata,
131131
)
132132

133-
self.executor.submit(on_success)
133+
on_success()
134134

135135
@hookimpl
136136
def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session):
@@ -158,12 +158,17 @@ def on_failure():
158158
task=task_metadata,
159159
)
160160

161-
self.executor.submit(on_failure)
161+
on_failure()
162+
163+
@property
164+
def executor(self):
165+
if not self._executor:
166+
self._executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")
167+
return self._executor
162168

163169
@hookimpl
164170
def on_starting(self, component):
165171
self.log.debug("on_starting: %s", component.__class__.__name__)
166-
self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")
167172

168173
@hookimpl
169174
def before_stopping(self, component):
@@ -174,9 +179,6 @@ def before_stopping(self, component):
174179

175180
@hookimpl
176181
def on_dag_run_running(self, dag_run: DagRun, msg: str):
177-
if not self.executor:
178-
self.log.error("Executor have not started before `on_dag_run_running`")
179-
return
180182
data_interval_start = dag_run.data_interval_start.isoformat() if dag_run.data_interval_start else None
181183
data_interval_end = dag_run.data_interval_end.isoformat() if dag_run.data_interval_end else None
182184
self.executor.submit(
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
##
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
import datetime
21+
22+
from airflow.models import DAG
23+
from airflow.operators.python import PythonOperator
24+
25+
dag = DAG(
26+
dag_id="test_dag_xcom_openlineage",
27+
default_args={"owner": "airflow", "retries": 3, "start_date": datetime.datetime(2022, 1, 1)},
28+
schedule="0 0 * * *",
29+
dagrun_timeout=datetime.timedelta(minutes=60),
30+
)
31+
32+
33+
def push_and_pull(ti, **kwargs):
34+
ti.xcom_push(key="pushed_key", value="asdf")
35+
ti.xcom_pull(key="pushed_key")
36+
37+
38+
task = PythonOperator(task_id="push_and_pull", python_callable=push_and_pull, dag=dag)
39+
40+
if __name__ == "__main__":
41+
dag.cli()

tests/listeners/test_listeners.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,32 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import logging
20+
import os
21+
import time
22+
1923
import pytest as pytest
2024

2125
from airflow import AirflowException
2226
from airflow.jobs.job import Job, run_job
27+
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
2328
from airflow.listeners.listener import get_listener_manager
29+
from airflow.models import DagBag, TaskInstance
2430
from airflow.operators.bash import BashOperator
31+
from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
2532
from airflow.utils import timezone
2633
from airflow.utils.session import provide_session
27-
from airflow.utils.state import DagRunState, TaskInstanceState
34+
from airflow.utils.state import DagRunState, State, TaskInstanceState
35+
from airflow.utils.timeout import timeout
2836
from tests.listeners import (
2937
class_listener,
3038
full_listener,
3139
lifecycle_listener,
3240
partial_listener,
3341
throwing_listener,
42+
xcom_listener,
3443
)
44+
from tests.models import DEFAULT_DATE
3545
from tests.utils.test_helpers import MockJobRunner
3646

3747
LISTENERS = [
@@ -46,6 +56,8 @@
4656
TASK_ID = "test_listener_task"
4757
EXECUTION_DATE = timezone.utcnow()
4858

59+
TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"]
60+
4961

5062
@pytest.fixture(autouse=True)
5163
def clean_listener_manager():
@@ -163,3 +175,56 @@ def test_class_based_listener(create_task_instance, session=None):
163175

164176
assert len(listener.state) == 2
165177
assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]
178+
179+
180+
def test_ol_does_not_block_xcoms():
181+
"""
182+
Test that a listener which pushes and pulls XComs in its task instance hooks
183+
does not block the task from finishing, and that the hooks run in order
184+
"""
185+
186+
path_listener_writer = "/tmp/test_ol_does_not_block_xcoms"
187+
try:
188+
os.unlink(path_listener_writer)
189+
except OSError:
190+
pass
191+
192+
listener = xcom_listener.XComListener(path_listener_writer, "push_and_pull")
193+
get_listener_manager().add_listener(listener)
194+
log = logging.getLogger("airflow")
195+
196+
dagbag = DagBag(
197+
dag_folder=TEST_DAG_FOLDER,
198+
include_examples=False,
199+
)
200+
dag = dagbag.dags.get("test_dag_xcom_openlineage")
201+
task = dag.get_task("push_and_pull")
202+
dag.create_dagrun(
203+
run_id="test",
204+
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
205+
state=State.RUNNING,
206+
start_date=DEFAULT_DATE,
207+
)
208+
209+
ti = TaskInstance(task=task, run_id="test")
210+
job = Job(dag_id=ti.dag_id)
211+
job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True)
212+
task_runner = StandardTaskRunner(job_runner)
213+
task_runner.start()
214+
215+
# Wait until process makes itself the leader of its own process group
216+
with timeout(seconds=1):
217+
while True:
218+
runner_pgid = os.getpgid(task_runner.process.pid)
219+
if runner_pgid == task_runner.process.pid:
220+
break
221+
time.sleep(0.01)
222+
223+
# Wait till process finishes
224+
assert task_runner.return_code(timeout=10) is not None
225+
log.error(task_runner.return_code())
226+
227+
with open(path_listener_writer) as f:
228+
assert f.readline() == "on_task_instance_running\n"
229+
assert f.readline() == "on_task_instance_success\n"
230+
assert f.readline() == "listener\n"

tests/listeners/xcom_listener.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
from airflow.listeners import hookimpl
21+
22+
23+
class XComListener:
24+
def __init__(self, path: str, task_id: str):
25+
self.path = path
26+
self.task_id = task_id
27+
28+
def write(self, line: str):
29+
with open(self.path, "a") as f:
30+
f.write(line + "\n")
31+
32+
@hookimpl
33+
def on_task_instance_running(self, previous_state, task_instance, session):
34+
task_instance.xcom_push(key="listener", value="listener")
35+
task_instance.xcom_pull(task_ids=task_instance.task_id, key="listener")
36+
self.write("on_task_instance_running")
37+
38+
@hookimpl
39+
def on_task_instance_success(self, previous_state, task_instance, session):
40+
read = task_instance.xcom_pull(task_ids=self.task_id, key="listener")
41+
self.write("on_task_instance_success")
42+
self.write(read)
43+
44+
45+
def clear():
46+
pass

0 commit comments

Comments
 (0)