@@ -78,11 +78,8 @@
 from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
-    AirflowFailException,
     AirflowInactiveAssetInInletOrOutletException,
     AirflowRescheduleException,
-    AirflowSensorTimeout,
-    AirflowSkipException,
     AirflowTaskTerminated,
     AirflowTaskTimeout,
     TaskDeferralError,
@@ -118,7 +115,6 @@
 from airflow.utils.span_status import SpanStatus
 from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime
 from airflow.utils.state import DagRunState, State, TaskInstanceState
-from airflow.utils.task_instance_session import set_current_task_instance_session
 from airflow.utils.timeout import timeout
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
@@ -1692,161 +1688,33 @@ def _run_raw_task(
         :param pool: specifies the pool to use to run the task instance
         :param session: SQLAlchemy ORM Session
         """
-        if TYPE_CHECKING:
-            assert self.task
-
-        if TYPE_CHECKING:
-            assert isinstance(self.task, BaseOperator)
-
-        self.test_mode = test_mode
-        self.refresh_from_task(self.task, pool_override=pool)
-        self.refresh_from_db(session=session)
-        self.hostname = get_hostname()
-        self.pid = os.getpid()
-        if not test_mode:
-            TaskInstance.save_to_db(ti=self, session=session)
-        actual_start_date = timezone.utcnow()
-        Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}", tags=self.stats_tags)
-        # Same metric with tagging
-        Stats.incr("ti.start", tags=self.stats_tags)
-        # Initialize final state counters at zero
-        for state in State.task_states:
-            Stats.incr(
-                f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}",
-                count=0,
-                tags=self.stats_tags,
-            )
-            # Same metric with tagging
-            Stats.incr(
-                "ti.finish",
-                count=0,
-                tags={**self.stats_tags, "state": str(state)},
-            )
-        with set_current_task_instance_session(session=session):
-            self.task = self.task.prepare_for_execution()
-            context = self.get_template_context(ignore_param_exceptions=False, session=session)
-
-            try:
-                if self.task:
-                    from airflow.sdk.definitions.asset import Asset
-
-                    inlets = [asset.asprofile() for asset in self.task.inlets if isinstance(asset, Asset)]
-                    outlets = [asset.asprofile() for asset in self.task.outlets if isinstance(asset, Asset)]
-                    TaskInstance.validate_inlet_outlet_assets_activeness(inlets, outlets, session=session)
-                if not mark_success:
-                    TaskInstance._execute_task_with_callbacks(
-                        self=self,  # type: ignore[arg-type]
-                        context=context,
-                        test_mode=test_mode,
-                        session=session,
-                    )
-                if not test_mode:
-                    self.refresh_from_db(lock_for_update=True, session=session, keep_local_changes=True)
-                self.state = TaskInstanceState.SUCCESS
-            except TaskDeferred as defer:
-                # The task has signalled it wants to defer execution based on
-                # a trigger.
-                if raise_on_defer:
-                    raise
-                self.defer_task(exception=defer, session=session)
-                self.log.info(
-                    "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, logical_date=%s, start_date=%s",
-                    self.dag_id,
-                    self.task_id,
-                    self.run_id,
-                    _date_or_empty(task_instance=self, attr="logical_date"),
-                    _date_or_empty(task_instance=self, attr="start_date"),
-                )
-                return TaskReturnCode.DEFERRED
-            except AirflowSkipException as e:
-                # Recording SKIP
-                # log only if exception has any arguments to prevent log flooding
-                if e.args:
-                    self.log.info(e)
-                if not test_mode:
-                    self.refresh_from_db(lock_for_update=True, session=session, keep_local_changes=True)
-                self.state = TaskInstanceState.SKIPPED
-                _run_finished_callback(callbacks=self.task.on_skipped_callback, context=context)
-                TaskInstance.save_to_db(ti=self, session=session)
-            except AirflowRescheduleException as reschedule_exception:
-                self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
-                self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
-                return None
-            except (AirflowFailException, AirflowSensorTimeout) as e:
-                # If AirflowFailException is raised, task should not retry.
-                # If a sensor in reschedule mode reaches timeout, task should not retry.
-                self.handle_failure(
-                    e, test_mode, context, force_fail=True, session=session
-                )  # already saves to db
-                raise
-            except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e:
-                if not test_mode:
-                    self.refresh_from_db(lock_for_update=True, session=session)
-                # for case when task is marked as success/failed externally
-                # or dagrun timed out and task is marked as skipped
-                # current behavior doesn't hit the callbacks
-                if self.state in State.finished:
-                    self.clear_next_method_args()
-                    TaskInstance.save_to_db(ti=self, session=session)
-                    return None
-                self.handle_failure(e, test_mode, context, session=session)
-                raise
-            except SystemExit as e:
-                # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`.
-                # Therefore, here we must handle only error codes.
-                msg = f"Task failed due to SystemExit({e.code})"
-                self.handle_failure(msg, test_mode, context, session=session)
-                raise AirflowException(msg)
-            except BaseException as e:
-                self.handle_failure(e, test_mode, context, session=session)
-                raise
-            finally:
-                # Print a marker post execution for internals of post task processing
-                log.info("::group::Post task execution logs")
-
-                Stats.incr(
-                    f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}",
-                    tags=self.stats_tags,
-                )
-                # Same metric with tagging
-                Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)})
+        from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceSDK
+        from airflow.sdk.execution_time.supervisor import run_task_in_process
 
-            # Recording SKIPPED or SUCCESS
-            self.clear_next_method_args()
-            self.end_date = timezone.utcnow()
-            _log_state(task_instance=self)
-            self.set_duration()
+        self.set_state(TaskInstanceState.QUEUED)
 
-            # run on_success_callback before db committing
-            # otherwise, the LocalTaskJob sees the state is changed to `success`,
-            # but the task_runner is still running, LocalTaskJob then treats the state is set externally!
-            if self.state == TaskInstanceState.SUCCESS:
-                _run_finished_callback(callbacks=self.task.on_success_callback, context=context)
+        if mark_success:
+            self.set_state(TaskInstanceState.SUCCESS)
+            log.info("[DAG TEST] Marking success for %s ", self.task_id)
+            return
 
-            if not test_mode:
-                _add_log(event=self.state, task_instance=self, session=session)
-                if self.state == TaskInstanceState.SUCCESS:
-                    from airflow.sdk.execution_time.task_runner import (
-                        _build_asset_profiles,
-                        _serialize_outlet_events,
-                    )
+        taskrun_result = run_task_in_process(
+            ti=TaskInstanceSDK(
+                id=self.id,
+                task_id=self.task_id,
+                dag_id=self.task.dag_id,
+                run_id=self.run_id,
+                try_number=self.try_number,
+                map_index=self.map_index,
+            ),
+            task=self.task,
+        )
 
-                    TaskInstance.register_asset_changes_in_db(
-                        self,
-                        list(_build_asset_profiles(self.task.outlets)),
-                        list(_serialize_outlet_events(context["outlet_events"])),
-                        session=session,
-                    )
+        if taskrun_result.state != TaskInstanceState.QUEUED:
+            self.set_state(taskrun_result.state)
 
-                TaskInstance.save_to_db(ti=self, session=session)
-            if self.state == TaskInstanceState.SUCCESS:
-                try:
-                    get_listener_manager().hook.on_task_instance_success(
-                        previous_state=TaskInstanceState.RUNNING, task_instance=self
-                    )
-                except Exception:
-                    log.exception("error calling listener")
-            return None
+        if taskrun_result.error:
+            raise taskrun_result.error
 
     @staticmethod
     @provide_session
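The net effect of this hunk is that `_run_raw_task` no longer executes the operator, its callbacks, asset registration, or listener hooks inline; it queues the task instance and delegates everything to the Task SDK supervisor via `run_task_in_process`, which also explains the exception and session imports dropped at the top of the file. Below is a minimal standalone sketch of the resulting control flow, based only on the added lines above. The wrapper name `_run_raw_task_sketch` is hypothetical, and `run_task_in_process` is assumed to return a result object exposing `state` and `error` attributes, as the diff uses them.

# Sketch only: mirrors the added lines of this commit, not the full method.
from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceSDK
from airflow.sdk.execution_time.supervisor import run_task_in_process
from airflow.utils.state import TaskInstanceState


def _run_raw_task_sketch(ti, mark_success: bool = False):  # hypothetical wrapper
    # Queue the task instance before handing it off to the supervisor.
    ti.set_state(TaskInstanceState.QUEUED)

    if mark_success:
        # dag.test()-style short circuit: mark success without running the task.
        ti.set_state(TaskInstanceState.SUCCESS)
        return

    # Run the task in-process under the Task SDK supervisor instead of
    # executing the operator body and its callbacks inline.
    result = run_task_in_process(
        ti=TaskInstanceSDK(
            id=ti.id,
            task_id=ti.task_id,
            dag_id=ti.task.dag_id,
            run_id=ti.run_id,
            try_number=ti.try_number,
            map_index=ti.map_index,
        ),
        task=ti.task,
    )

    # Mirror the terminal state reported by the supervisor, then surface any
    # captured error so callers still observe the failure.
    if result.state != TaskInstanceState.QUEUED:
        ti.set_state(result.state)
    if result.error:
        raise result.error

One design consequence worth noting: skip, reschedule, deferral, and failure handling now live with the supervisor/task runner rather than in `TaskInstance`, so `_run_raw_task` shrinks from roughly 160 lines to about 30 and only translates the supervisor's verdict back into task-instance state.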