Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions airflow-core/src/airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def get_execution_api_server_url(conf_source: AirflowConfigParser | ExecutorConf

from sqlalchemy.orm import Session

from airflow._shared.logging.remote import StreamingLogResponse
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
Expand Down Expand Up @@ -174,6 +175,7 @@ class BaseExecutor(LoggingMixin):
# The connection-test supervisor uses ``signal.SIGALRM`` (via ``TimeoutPosix``)
# to bound hook execution. Executors that opt in must run on POSIX systems.
supports_connection_test: bool = False
supports_streaming_logs: bool = False
sentry_integration: str = ""

is_local: bool = False
Expand Down Expand Up @@ -559,6 +561,19 @@ def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], li
"""
return [], []

def get_streaming_task_log(self, ti: TaskInstance, try_number: int) -> StreamingLogResponse:
"""
Return a streaming response for task logs.

This base implementation always raises ``NotImplementedError``; executors that can stream
task logs are expected to override it.

Executors that implement this method must also set the ``supports_streaming_logs`` class
attribute to ``True``.

:param ti: A TaskInstance object
:param try_number: current try_number to read log from
:return: StreamingLogResponse
:raises NotImplementedError: always, in this base implementation
"""
raise NotImplementedError

def end(self) -> None: # pragma: no cover
"""Wait synchronously for the previously submitted job to complete."""
raise NotImplementedError
Expand Down
70 changes: 50 additions & 20 deletions airflow-core/src/airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from itertools import chain, islice
from pathlib import Path
from types import GeneratorType
from typing import IO, TYPE_CHECKING, TypedDict, cast
from typing import IO, TYPE_CHECKING, Literal, TypedDict, cast, overload
from urllib.parse import urljoin

import pendulum
Expand Down Expand Up @@ -557,27 +557,52 @@ def _render_filename(
)
raise RuntimeError(f"Unable to render log filename for {ti}. This should never happen")

def _get_executor_get_task_log(
self, ti: TaskInstance | TaskInstanceHistory
) -> Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]]:
@overload
def _get_executor_log_callable(
self, ti: TaskInstance | TaskInstanceHistory, *, streaming: Literal[True]
) -> Callable[[TaskInstance | TaskInstanceHistory, int], StreamingLogResponse] | None: ...

@overload
def _get_executor_log_callable(
self, ti: TaskInstance | TaskInstanceHistory, *, streaming: Literal[False] = ...
) -> Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]]: ...

def _get_executor_log_callable(
self, ti: TaskInstance | TaskInstanceHistory, *, streaming: bool = False
) -> (
Callable[[TaskInstance | TaskInstanceHistory, int], StreamingLogResponse]
| None
| Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]]
):
"""
Get the get_task_log method from executor of current task instance.
Get the get_task_log or get_streaming_task_log method from executor of current task instance.

There might be multiple executors, so we need to get the executor of the current task instance instead of the default executor.

:param ti: task instance object
:return: get_task_log method of the executor
:param streaming: if True, get the get_streaming_task_log method, otherwise get the get_task_log method
:return: get_task_log or get_streaming_task_log method of the executor
"""
executor_name = ti.executor or self.DEFAULT_EXECUTOR_KEY
executor = self.executor_instances.get(executor_name)
if executor is not None:
return executor.get_task_log
if executor is None:
if executor_name == self.DEFAULT_EXECUTOR_KEY:
executor = ExecutorLoader.get_default_executor()
else:
executor = ExecutorLoader.load_executor(executor_name)
self.executor_instances[executor_name] = executor

if streaming:
# The `supports_streaming_logs` class attribute and the `get_streaming_task_log` method were added in Airflow 3.2.0.
# Some provider executors and custom executors do not support `get_streaming_task_log` yet.
# For backward compatibility with earlier versions, we need to check that both exist.
if hasattr(executor, "get_streaming_task_log") and getattr(
executor, "supports_streaming_logs", False
):
return executor.get_streaming_task_log
return None

if executor_name == self.DEFAULT_EXECUTOR_KEY:
self.executor_instances[executor_name] = ExecutorLoader.get_default_executor()
else:
self.executor_instances[executor_name] = ExecutorLoader.load_executor(executor_name)
return self.executor_instances[executor_name].get_task_log
return executor.get_task_log

def _read(
self,
Expand Down Expand Up @@ -632,23 +657,28 @@ def _read(
raise TypeError("Logs should be either a list of strings or a generator of log lines.")
# Extend LogSourceInfo
source_list.extend(sources)
has_k8s_exec_pod = False

has_executor_log = False
if ti.state == TaskInstanceState.RUNNING:
executor_get_task_log = self._get_executor_get_task_log(ti)
response = executor_get_task_log(ti, try_number)
if response:
sources, logs = response
# check for streaming logs first
if executor_streaming_get_task_log := self._get_executor_log_callable(ti, streaming=True):
sources, executor_logs = executor_streaming_get_task_log(ti, try_number)
else: # fallback to non-streaming logs if streaming not supported
executor_get_task_log = self._get_executor_log_callable(ti)
sources, logs = executor_get_task_log(ti, try_number)
# make the logs stream-like compatible
executor_logs = [_get_compatible_log_stream(logs)]

if sources:
source_list.extend(sources)
has_k8s_exec_pod = True
has_executor_log = True

if not (remote_logs and ti.state not in State.unfinished):
# when finished, if we have remote logs, no need to check local
worker_log_full_path = Path(self.local_base, worker_log_rel_path)
sources, local_logs = self._read_from_local(worker_log_full_path)
source_list.extend(sources)
if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not has_k8s_exec_pod:
if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not has_executor_log:
sources, served_logs = self._read_from_logs_server(ti, worker_log_rel_path)
source_list.extend(sources)
elif (ti.state not in State.unfinished or ti.state in _STATES_WITH_COMPLETED_ATTEMPT) and not (
Expand Down
12 changes: 12 additions & 0 deletions airflow-core/tests/unit/executors/test_base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def test_supports_multi_team_default_value():
assert not BaseExecutor.supports_multi_team


def test_supports_streaming_logs_default_value():
# Streaming log support is opt-in: the base class attribute must default to a falsy value.
assert not BaseExecutor.supports_streaming_logs


def test_invalid_slotspool():
with pytest.raises(ValueError, match="parallelism is set to 0 or lower"):
BaseExecutor(0)
Expand All @@ -78,6 +82,14 @@ def test_get_task_log():
assert executor.get_task_log(ti=ti, try_number=1) == ([], [])


def test_get_streaming_task_log_not_implemented():
# The base executor ships no streaming implementation, so calling the method must raise.
executor = BaseExecutor()
ti = TaskInstance(task=SerializedBaseOperator(task_id="dummy"), dag_version_id=mock.MagicMock(spec=UUID))

with pytest.raises(NotImplementedError):
executor.get_streaming_task_log(ti=ti, try_number=1)


def test_serve_logs_default_value():
assert not BaseExecutor.serve_logs

Expand Down
75 changes: 75 additions & 0 deletions airflow-core/tests/unit/utils/log/test_file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from unittest.mock import MagicMock, patch

from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.state import TaskInstanceState

from tests_common.test_utils.file_task_handler import convert_list_to_stream, extract_events


class TestFileTaskHandlerLogServer:
Expand Down Expand Up @@ -220,3 +223,75 @@ def test_handles_base_log_folder_that_is_itself_a_symlink(self, tmp_path):

assert len(sources) == 1
assert "through-symlink content" in self._drain(streams[0])


class TestFileTaskHandlerExecutorLogs:
"""Tests for executor log retrieval selection."""

@staticmethod
def _running_ti(executor_name: str) -> MagicMock:
# Build a mock task instance in the RUNNING state that is assigned to ``executor_name``.
ti = MagicMock()
ti.executor = executor_name
ti.state = TaskInstanceState.RUNNING
ti.try_number = 1
return ti

def test_running_task_prefers_streaming_executor_logs(self):
"""Use executor streaming logs when the executor advertises streaming support."""
handler = FileTaskHandler(base_log_folder="")
executor = MagicMock()
# Advertise streaming support explicitly rather than relying on MagicMock's auto-created attribute.
executor.supports_streaming_logs = True
# The streaming method returns a (sources, list of log streams) pair.
executor.get_streaming_task_log.return_value = (
["streaming source"],
[convert_list_to_stream(["streaming log"])],
)
# The legacy method is configured too, so the test can prove it is not the one being used.
executor.get_task_log.return_value = (["legacy source"], ["legacy log"])
# Pre-register the mock under the task's executor name so the handler finds it in its cache.
handler.executor_instances = {"StreamingExecutor": executor}
ti = self._running_ti("StreamingExecutor")

# Stub out the other log sources so only executor-provided logs can reach the output.
# NOTE(review): raising NotImplementedError from _read_remote_logs presumably mimics
# "no remote logging configured" -- confirm against FileTaskHandler._read.
with (
patch.object(handler, "_render_filename", return_value="dag/run/task/1.log"),
patch.object(handler, "_read_remote_logs", side_effect=NotImplementedError),
patch.object(handler, "_read_from_local", return_value=([], [])),
patch.object(handler, "_read_from_logs_server", return_value=([], [])) as read_from_logs_server,
):
logs, metadata = handler._read(ti=ti, try_number=1)

executor.get_streaming_task_log.assert_called_once_with(ti, 1)
executor.get_task_log.assert_not_called()
# Once the executor has supplied log sources, the log server must not be queried.
read_from_logs_server.assert_not_called()
assert extract_events(logs, skip_source_info=False) == [
"::group::Log message source details",
"streaming source",
"::endgroup::",
"streaming log",
]
assert metadata == {"end_of_log": False, "log_pos": 1}

def test_running_task_falls_back_to_legacy_executor_logs(self):
"""Use legacy executor logs when streaming support is not advertised."""
handler = FileTaskHandler(base_log_folder="")
executor = MagicMock()
# MagicMock would otherwise auto-create a truthy ``supports_streaming_logs`` attribute,
# so streaming support has to be switched off explicitly.
executor.supports_streaming_logs = False
executor.get_task_log.return_value = (["legacy source"], ["legacy log"])
# Pre-register the mock under the task's executor name so the handler finds it in its cache.
handler.executor_instances = {"LegacyExecutor": executor}
ti = self._running_ti("LegacyExecutor")

# Stub out the other log sources so only executor-provided logs can reach the output.
with (
patch.object(handler, "_render_filename", return_value="dag/run/task/1.log"),
patch.object(handler, "_read_remote_logs", side_effect=NotImplementedError),
patch.object(handler, "_read_from_local", return_value=([], [])),
patch.object(handler, "_read_from_logs_server", return_value=([], [])) as read_from_logs_server,
):
logs, metadata = handler._read(ti=ti, try_number=1)

# The streaming method exists on the mock but must be ignored when support is not advertised.
executor.get_streaming_task_log.assert_not_called()
executor.get_task_log.assert_called_once_with(ti, 1)
# Once the executor has supplied log sources, the log server must not be queried.
read_from_logs_server.assert_not_called()
assert extract_events(logs, skip_source_info=False) == [
"::group::Log message source details",
"legacy source",
"::endgroup::",
"legacy log",
]
assert metadata == {"end_of_log": False, "log_pos": 1}
Loading