Skip to content

Commit b7dad41

Browse files
committed
fixup! Resolve lazy_object_proxy for PyVirtualEnvOperator
1 parent 5bec8c8 commit b7dad41

File tree

5 files changed

+85
-27
lines changed

5 files changed

+85
-27
lines changed

providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_adx.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from airflow.models import DAG
2626
from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook
2727
from airflow.providers.microsoft.azure.operators.adx import AzureDataExplorerQueryOperator
28+
from airflow.providers.microsoft.azure.version_compat import AIRFLOW_V_3_0_PLUS
2829
from airflow.utils.timezone import datetime
2930

3031
TEST_DAG_ID = "unit_tests"
@@ -88,12 +89,20 @@ def test_azure_data_explorer_query_operator_xcom_push_and_pull(
8889
mock_conn,
8990
mock_run_query,
9091
create_task_instance_of_operator,
92+
request,
9193
):
92-
ti = create_task_instance_of_operator(
93-
AzureDataExplorerQueryOperator,
94-
dag_id="test_azure_data_explorer_query_operator_xcom_push_and_pull",
95-
**MOCK_DATA,
96-
)
97-
ti.run()
98-
99-
assert ti.xcom_pull(task_ids=MOCK_DATA["task_id"]) == str(MOCK_RESULT)
94+
if AIRFLOW_V_3_0_PLUS:
95+
run_task = request.getfixturevalue("run_task")
96+
task = AzureDataExplorerQueryOperator(**MOCK_DATA)
97+
run_task(task=task)
98+
99+
assert run_task.xcom.get(key="return_value", task_id=task.task_id) == str(MOCK_RESULT)
100+
else:
101+
ti = create_task_instance_of_operator(
102+
AzureDataExplorerQueryOperator,
103+
dag_id="test_azure_data_explorer_query_operator_xcom_push_and_pull",
104+
**MOCK_DATA,
105+
)
106+
ti.run()
107+
108+
assert ti.xcom_pull(task_ids=MOCK_DATA["task_id"]) == str(MOCK_RESULT)

providers/oracle/tests/unit/oracle/operators/test_oracle.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from airflow.providers.oracle.hooks.oracle import OracleHook
2828
from airflow.providers.oracle.operators.oracle import OracleStoredProcedureOperator
2929

30+
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
31+
3032

3133
class TestOracleStoredProcedureOperator:
3234
@mock.patch.object(OracleHook, "run", autospec=OracleHook.run)
@@ -65,12 +67,20 @@ def test_push_oracle_exit_to_xcom(self, mock_callproc, request, dag_maker):
6567
error = f"ORA-{ora_exit_code}: This is a five-digit ORA error code"
6668
mock_callproc.side_effect = oracledb.DatabaseError(error)
6769

68-
with dag_maker(dag_id=f"dag_{request.node.name}"):
70+
if AIRFLOW_V_3_0_PLUS:
71+
run_task = request.getfixturevalue("run_task")
6972
task = OracleStoredProcedureOperator(
7073
procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id
7174
)
72-
dr = dag_maker.create_dagrun(run_id=task_id)
73-
ti = TaskInstance(task=task, run_id=dr.run_id)
74-
with pytest.raises(oracledb.DatabaseError, match=re.escape(error)):
75-
ti.run()
76-
assert ti.xcom_pull(task_ids=task.task_id, key="ORA") == ora_exit_code
75+
run_task(task=task)
76+
assert run_task.xcom.get(task_id=task.task_id, key="ORA") == ora_exit_code
77+
else:
78+
with dag_maker(dag_id=f"dag_{request.node.name}"):
79+
task = OracleStoredProcedureOperator(
80+
procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id
81+
)
82+
dr = dag_maker.create_dagrun(run_id=task_id)
83+
ti = TaskInstance(task=task, run_id=dr.run_id)
84+
with pytest.raises(oracledb.DatabaseError, match=re.escape(error)):
85+
ti.run()
86+
assert ti.xcom_pull(task_ids=task.task_id, key="ORA") == ora_exit_code

providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import pytest
2525

2626
from airflow.decorators import task
27+
from airflow.providers.snowflake.version_compat import AIRFLOW_V_3_0_PLUS
2728
from airflow.utils import timezone
2829

2930
if TYPE_CHECKING:
@@ -156,7 +157,7 @@ def func(session: Session):
156157
mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once()
157158

158159
@mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
159-
def test_snowpark_decorator_multiple_output(self, mock_snowflake_hook, dag_maker):
160+
def test_snowpark_decorator_multiple_output(self, mock_snowflake_hook, dag_maker, request):
160161
@task.snowpark(
161162
task_id=TASK_ID,
162163
snowflake_conn_id=CONN_ID,
@@ -171,15 +172,23 @@ def func(session: Session):
171172
assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value
172173
return {"a": 1, "b": "2"}
173174

174-
with dag_maker(dag_id=TEST_DAG_ID):
175-
ret = func()
176-
177-
dr = dag_maker.create_dagrun()
178-
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
179-
ti = dr.get_task_instances()[0]
180-
assert ti.xcom_pull(key="a") == 1
181-
assert ti.xcom_pull(key="b") == "2"
182-
assert ti.xcom_pull() == {"a": 1, "b": "2"}
175+
if AIRFLOW_V_3_0_PLUS:
176+
run_task = request.getfixturevalue("run_task")
177+
op = func().operator
178+
run_task(task=op)
179+
assert run_task.xcom.get(key="a") == 1
180+
assert run_task.xcom.get(key="b") == "2"
181+
assert run_task.xcom.get(key="return_value") == {"a": 1, "b": "2"}
182+
else:
183+
with dag_maker(dag_id=TEST_DAG_ID):
184+
ret = func()
185+
186+
dr = dag_maker.create_dagrun()
187+
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
188+
ti = dr.get_task_instances()[0]
189+
assert ti.xcom_pull(key="a") == 1
190+
assert ti.xcom_pull(key="b") == "2"
191+
assert ti.xcom_pull() == {"a": 1, "b": "2"}
183192
mock_snowflake_hook.assert_called_once()
184193
mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once()
185194

providers/standard/tests/unit/standard/decorators/test_python.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,22 @@ def identity_notyping_with_decorator_call(x: int):
215215

216216
assert identity_notyping_with_decorator_call(5).operator.multiple_outputs is False
217217

218-
def test_manual_multiple_outputs_false_with_typings(self):
218+
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
219+
def test_manual_multiple_outputs_false_with_typings(self, run_task):
220+
@task_decorator(multiple_outputs=False)
221+
def identity2(x: int, y: int) -> tuple[int, int]:
222+
return x, y
223+
224+
res = identity2(8, 4)
225+
run_task(task=res.operator)
226+
227+
assert not res.operator.multiple_outputs
228+
assert run_task.xcom.get(key=res.key) == (8, 4)
229+
assert run_task.xcom.get(key="return_value_0") is None
230+
assert run_task.xcom.get(key="return_value_1") is None
231+
232+
@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 3")
233+
def test_manual_multiple_outputs_false_with_typings_af2(self):
219234
@task_decorator(multiple_outputs=False)
220235
def identity2(x: int, y: int) -> tuple[int, int]:
221236
return x, y
@@ -233,7 +248,22 @@ def identity2(x: int, y: int) -> tuple[int, int]:
233248
assert ti.xcom_pull(key="return_value_0") is None
234249
assert ti.xcom_pull(key="return_value_1") is None
235250

236-
def test_multiple_outputs_ignore_typing(self):
251+
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
252+
def test_multiple_outputs_ignore_typing(self, run_task):
253+
@task_decorator
254+
def identity_tuple(x: int, y: int) -> tuple[int, int]:
255+
return x, y
256+
257+
ident = identity_tuple(35, 36)
258+
run_task(task=ident.operator)
259+
260+
assert not ident.operator.multiple_outputs
261+
assert run_task.xcom.get(key=ident.key) == (35, 36)
262+
assert run_task.xcom.get(key="return_value_0") is None
263+
assert run_task.xcom.get(key="return_value_1") is None
264+
265+
@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 3")
266+
def test_multiple_outputs_ignore_typing_af2(self):
237267
@task_decorator
238268
def identity_tuple(x: int, y: int) -> tuple[int, int]:
239269
return x, y

providers/standard/tests/unit/standard/operators/test_datetime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def test_branch_datetime_operator_upper_comparison_outside_range(self, target_up
214214

215215
@pytest.mark.parametrize("target_lower", [target_lower for (target_lower, _) in targets])
216216
@time_machine.travel("2020-07-07 09:00:00")
217-
def test_branch_datetime_operator_lower_comparison_outside_range(self, target_lower, run_task):
217+
def test_branch_datetime_operator_lower_comparison_outside_range(self, target_lower):
218218
"""Check BranchDateTimeOperator branch operation"""
219219
self.branch_op.target_lower = target_lower
220220
self.branch_op.target_upper = None

0 commit comments

Comments
 (0)