Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,7 +1463,17 @@ def process_executor_events(
)
)

if ti_queued and not ti_requeued:
# A running task that's still sending heartbeats is alive -- a worker is running it right now.
# This event is probably from a duplicate that already lost and died, so don't fail the live
# run. If the task really did die, heartbeat detection will fail it once the heartbeat stops.
heartbeat_timeout = conf.getint("scheduler", "task_instance_heartbeat_timeout")
ti_alive = (
ti.state == TaskInstanceState.RUNNING
and ti.last_heartbeat_at is not None
and ti.last_heartbeat_at >= timezone.utcnow() - timedelta(seconds=heartbeat_timeout)
)

if ti_queued and not ti_requeued and not ti_alive:
team_name = (
DagModel.get_team_name(ti.dag_id, session=session)
if conf.getboolean("core", "multi_team")
Expand Down
55 changes: 55 additions & 0 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,61 @@ def test_process_executor_events_stale_success_when_scheduled_after_defer(
tags={"dag_id": dag_id, "task_id": ti1.task_id},
)

@pytest.mark.parametrize("event_state", [State.FAILED, State.SUCCESS])
@mock.patch("airflow.jobs.scheduler_job_runner.TaskCallbackRequest")
@mock.patch("airflow._shared.observability.metrics.stats._get_backend")
def test_process_executor_events_does_not_fail_running_ti_with_fresh_heartbeat(
    self, mock_get_backend, mock_task_callback, dag_maker, event_state
):
    """
    Check how a terminal executor event is handled for a RUNNING task instance.

    Phase 1: the task instance has a heartbeat stamped "now". After the executor event
    is processed it must still be RUNNING, no callback may be sent and no metric bumped.

    Phase 2: the heartbeat is moved two hours into the past. Processing the same event
    again must emit the ``scheduler.tasks.killed_externally`` metric.
    """
    stats_backend = mock.MagicMock(spec=StatsLogger)
    mock_get_backend.return_value = stats_backend
    mock_task_callback.return_value = mock.MagicMock()

    dag_id = "test_process_executor_events_running_fresh_heartbeat"
    session = settings.Session()
    with dag_maker(dag_id=dag_id, fileloc="/test_path1/"):
        task = EmptyOperator(task_id="dummy_task")
    ti = dag_maker.create_dagrun().get_task_instance(task.task_id)

    executor = MockExecutor(do_update=False)
    scheduler_job = Job()
    session.add(scheduler_job)
    session.flush()
    self.job_runner = SchedulerJobRunner(scheduler_job, executors=[executor])

    def replay_executor_event():
        # Put the terminal event back in the buffer, forget earlier metric calls,
        # then let the scheduler consume the buffer.
        executor.event_buffer[ti.key] = (event_state, None)
        stats_backend.incr.reset_mock()
        self.job_runner._process_executor_events(executor=executor, session=session)

    # Phase 1: running task owned by this scheduler job, heartbeat just received.
    ti.state = State.RUNNING
    ti.queued_by_job_id = scheduler_job.id
    ti.last_heartbeat_at = timezone.utcnow()
    session.merge(ti)
    session.commit()

    executor.has_task = mock.MagicMock(return_value=False)
    replay_executor_event()

    ti.refresh_from_db(session=session)
    assert ti.state == State.RUNNING
    self.job_runner.executor.callback_sink.send.assert_not_called()
    stats_backend.incr.assert_not_called()

    # Phase 2: a stale heartbeat means it really is dead, so the event is acted upon.
    ti.last_heartbeat_at = timezone.utcnow() - timedelta(hours=2)
    session.merge(ti)
    session.commit()

    replay_executor_event()

    expected_tags = {"dag_id": dag_id, "task_id": ti.task_id}
    stats_backend.incr.assert_any_call("scheduler.tasks.killed_externally", tags=expected_tags)

@mock.patch("airflow.jobs.scheduler_job_runner.TaskCallbackRequest")
@mock.patch("airflow._shared.observability.metrics.stats._get_backend")
def test_process_executor_events_stale_success_when_queued_after_defer(
Expand Down
11 changes: 10 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
TaskInstanceState,
)
from airflow.sdk.configuration import conf
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError
from airflow.sdk.execution_time import comms
from airflow.sdk.execution_time.comms import (
AssetEventsResult,
Expand Down Expand Up @@ -2548,6 +2548,15 @@ def supervise_task(
final_state=result.final_state,
)
return result.exit_code
except TaskAlreadyRunningError:
# Another worker is already running this task, so the server told us to back off. This is
# normal -- it just means we were a duplicate that lost the race. Our task never started any
# real work, so exit quietly instead of reporting a failure that would look like a crash.
log.info(
"Task instance already running on another worker; standing down without failing it",
workload_id=str(ti.id),
)
return 0
finally:
if log_path and log_file_descriptor:
log_file_descriptor.close()
Expand Down
34 changes: 34 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,40 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker
"timestamp": "2024-11-07T12:34:56.078901Z",
} in captured_logs

def test_supervise_task_stands_down_when_already_running_on_another_worker(
    self, test_dags_dir, captured_logs, mocker
):
    """
    ``supervise_task`` returns 0 when starting the task raises ``TaskAlreadyRunningError``.

    The client's ``task_instances.start`` is mocked to raise; the test then checks that
    it was called exactly once and that neither ``finish`` nor ``succeed`` was called.
    """
    duplicate_ti = TaskInstance(
        id=uuid7(),
        task_id="hello",
        dag_id="super_basic_run",
        run_id="c",
        try_number=1,
        dag_version_id=uuid7(),
        queue="default",
    )

    api_client = mocker.Mock(spec=sdk_client.Client)
    task_instances_api = api_client.task_instances
    # Make the start call raise, as if another worker already owns the task.
    task_instances_api.start.side_effect = TaskAlreadyRunningError(
        f"Task instance {duplicate_ti.id} is already running"
    )

    bundle = BundleInfo(name="my-bundle", version=None)
    bundle_env = local_dag_bundle_cfg(test_dags_dir, bundle.name)
    with patch.dict(os.environ, bundle_env):
        exit_code = supervise_task(
            ti=duplicate_ti,
            dag_rel_path="super_basic_run.py",
            token="",
            client=api_client,
            bundle_info=bundle,
        )

    assert exit_code == 0, captured_logs
    task_instances_api.start.assert_called_once()
    task_instances_api.finish.assert_not_called()
    task_instances_api.succeed.assert_not_called()

def test_supervise_handles_deferred_task(
self, test_dags_dir, captured_logs, time_machine, mocker, make_ti_context
):
Expand Down
Loading