Skip to content

Commit c9de188

Browse files
committed
openlineage: don't run task instance listener in executor
Signed-off-by: Maciej Obuchowski <[email protected]>
1 parent 744aa60 commit c9de188

File tree

5 files changed

+157
-10
lines changed

5 files changed

+157
-10
lines changed

airflow/providers/openlineage/plugins/listener.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from __future__ import annotations
1818

1919
import logging
20-
from concurrent.futures import Executor, ThreadPoolExecutor
20+
from concurrent.futures import ThreadPoolExecutor
2121
from datetime import datetime
2222
from typing import TYPE_CHECKING
2323

@@ -42,8 +42,8 @@ class OpenLineageListener:
4242
"""OpenLineage listener sends events on task instance and dag run starts, completes and failures."""
4343

4444
def __init__(self):
45+
self._executor = None
4546
self.log = logging.getLogger(__name__)
46-
self.executor: Executor = None # type: ignore
4747
self.extractor_manager = ExtractorManager()
4848
self.adapter = OpenLineageAdapter()
4949

@@ -102,7 +102,7 @@ def on_running():
102102
},
103103
)
104104

105-
self.executor.submit(on_running)
105+
on_running()
106106

107107
@hookimpl
108108
def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session):
@@ -130,7 +130,7 @@ def on_success():
130130
task=task_metadata,
131131
)
132132

133-
self.executor.submit(on_success)
133+
on_success()
134134

135135
@hookimpl
136136
def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session):
@@ -158,12 +158,17 @@ def on_failure():
158158
task=task_metadata,
159159
)
160160

161-
self.executor.submit(on_failure)
161+
on_failure()
162+
163+
@property
def executor(self):
    """Return the listener's thread pool, creating it on first access.

    The pool is cached on ``self._executor`` so that later accesses reuse
    the same instance.
    """
    pool = self._executor
    if not pool:
        # Built lazily instead of up front: at most 8 workers, threads named "openlineage_*".
        pool = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")
        self._executor = pool
    return pool
162168

163169
@hookimpl
164170
def on_starting(self, component):
165171
self.log.debug("on_starting: %s", component.__class__.__name__)
166-
self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")
167172

168173
@hookimpl
169174
def before_stopping(self, component):
@@ -174,9 +179,6 @@ def before_stopping(self, component):
174179

175180
@hookimpl
176181
def on_dag_run_running(self, dag_run: DagRun, msg: str):
177-
if not self.executor:
178-
self.log.error("Executor have not started before `on_dag_run_running`")
179-
return
180182
data_interval_start = dag_run.data_interval_start.isoformat() if dag_run.data_interval_start else None
181183
data_interval_end = dag_run.data_interval_end.isoformat() if dag_run.data_interval_end else None
182184
self.executor.submit(
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
import datetime
21+
22+
from airflow.models import DAG
23+
from airflow.operators.python import PythonOperator
24+
25+
# DAG scheduled with the cron expression "0 0 * * *"; each run times out after 60 minutes.
dag = DAG(
    dag_id="test_dag_xcom_openlineage",
    default_args={
        "owner": "airflow",
        "retries": 3,
        "start_date": datetime.datetime(2022, 1, 1),
    },
    schedule="0 0 * * *",
    dagrun_timeout=datetime.timedelta(minutes=60),
)
31+
32+
33+
def push_and_pull(ti, **kwargs):
    """Push the value "asdf" under one XCom key and immediately pull that key back.

    :param ti: object exposing ``xcom_push`` and ``xcom_pull``.
    :param kwargs: accepted and ignored.
    """
    xcom_key = "pushed_key"
    ti.xcom_push(key=xcom_key, value="asdf")
    # The pulled value is intentionally discarded; only the call itself matters here.
    ti.xcom_pull(key=xcom_key)
36+
37+
38+
# The DAG's single task: runs ``push_and_pull`` through a PythonOperator.
task = PythonOperator(
    task_id="push_and_pull",
    python_callable=push_and_pull,
    dag=dag,
)


if __name__ == "__main__":
    # When executed as a script, hand control to the DAG's command-line interface.
    dag.cli()

tests/listeners/test_listeners.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import os
20+
1921
import pytest as pytest
2022

2123
from airflow import AirflowException
@@ -46,6 +48,8 @@
4648
TASK_ID = "test_listener_task"
4749
EXECUTION_DATE = timezone.utcnow()
4850

51+
TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"]
52+
4953

5054
@pytest.fixture(autouse=True)
5155
def clean_listener_manager():

tests/listeners/xcom_listener.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
from airflow.listeners import hookimpl
21+
22+
23+
class XComListener:
    """Listener that pushes and pulls XComs from task-instance hooks.

    Every hook invocation appends a line to the file at ``path`` so the
    sequence of calls (and the pulled XCom value) can be inspected afterwards.
    """

    def __init__(self, path: str, task_id: str):
        # File that the hooks append their output lines to.
        self.path = path
        # Task whose "listener" XCom is read back in ``on_task_instance_success``.
        self.task_id = task_id

    def write(self, line: str):
        """Append ``line`` followed by a newline to the file at ``self.path``."""
        # Explicit encoding keeps the output independent of the locale default.
        # Formatting (rather than ``line + "\n"``) means a missing XCom (``None``)
        # is recorded as text instead of raising ``TypeError`` inside the hook.
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(f"{line}\n")

    @hookimpl
    def on_task_instance_running(self, previous_state, task_instance, session):
        """Push and pull a "listener" XCom on the running task instance, then log the hook."""
        task_instance.xcom_push(key="listener", value="listener")
        # Result is discarded; the pull is performed only to exercise XCom reads here.
        task_instance.xcom_pull(task_ids=task_instance.task_id, key="listener")
        self.write("on_task_instance_running")

    @hookimpl
    def on_task_instance_success(self, previous_state, task_instance, session):
        """Pull the "listener" XCom of ``self.task_id`` and log the hook name and the value."""
        read = task_instance.xcom_pull(task_ids=self.task_id, key="listener")
        self.write("on_task_instance_success")
        self.write(read)
43+
44+
45+
def clear():
    """Do nothing; this function has no state to reset."""
    return None

tests/task/task_runner/test_standard_task_runner.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from airflow.utils.platform import getuser
4040
from airflow.utils.state import State
4141
from airflow.utils.timeout import timeout
42+
from tests.listeners import xcom_listener
4243
from tests.listeners.file_write_listener import FileWriteListener
4344
from tests.test_utils.db import clear_db_runs
4445

@@ -85,10 +86,14 @@ def setup_class(self):
8586
(as the test environment does not have enough context for the normal
8687
way to run) and ensures they reset back to normal on the way out.
8788
"""
88-
get_listener_manager().clear()
8989
clear_db_runs()
9090
yield
9191
clear_db_runs()
92+
93+
@pytest.fixture(autouse=True)
def clean_listener_manager(self):
    """Clear the listener manager before the test body runs and again afterwards."""
    get_listener_manager().clear()
    yield
    # Teardown: drop whatever listeners the test registered.
    get_listener_manager().clear()
9398

9499
@patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file")
@@ -215,6 +220,55 @@ def test_notifies_about_fail(self):
215220
assert f.readline() == "on_task_instance_failed\n"
216221
assert f.readline() == "before_stopping\n"
217222

223+
def test_ol_does_not_block_xcoms(self):
    """
    Check that pushing and pulling XComs from both a listener and the task itself does not collide.
    """
    listener_log = "/tmp/test_ol_does_not_block_xcoms"
    # Remove output left behind by an earlier run; a missing file is fine.
    try:
        os.unlink(listener_log)
    except OSError:
        pass

    get_listener_manager().add_listener(xcom_listener.XComListener(listener_log, "push_and_pull"))

    dagbag = DagBag(dag_folder=TEST_DAG_FOLDER, include_examples=False)
    dag = dagbag.dags.get("test_dag_xcom_openlineage")
    task = dag.get_task("push_and_pull")
    dag.create_dagrun(
        run_id="test",
        data_interval=(DEFAULT_DATE, DEFAULT_DATE),
        state=State.RUNNING,
        start_date=DEFAULT_DATE,
    )

    ti = TaskInstance(task=task, run_id="test")
    job_runner = LocalTaskJobRunner(
        job=Job(dag_id=ti.dag_id),
        task_instance=ti,
        ignore_ti_state=True,
    )
    task_runner = StandardTaskRunner(job_runner)
    task_runner.start()

    # Poll until the runner process has become the leader of its own process group.
    with timeout(seconds=1):
        while os.getpgid(task_runner.process.pid) != task_runner.process.pid:
            time.sleep(0.01)

    # Block until the runner process has finished.
    assert task_runner.return_code(timeout=10) is not None

    # The listener must have logged both hooks, followed by the XCom value it pulled.
    with open(listener_log) as f:
        recorded = [f.readline() for _ in range(3)]
    assert recorded == [
        "on_task_instance_running\n",
        "on_task_instance_success\n",
        "listener\n",
    ]
271+
218272
@patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file")
219273
def test_start_and_terminate_run_as_user(self, mock_init):
220274
mock_init.return_value = "/tmp/any"

0 commit comments

Comments
 (0)