Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions airflow-core/src/airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def get_execution_api_server_url(conf_source: AirflowConfigParser | ExecutorConf

from sqlalchemy.orm import Session

from airflow._shared.logging.remote import StreamingLogResponse
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
Expand Down Expand Up @@ -174,6 +175,7 @@ class BaseExecutor(LoggingMixin):
# The connection-test supervisor uses ``signal.SIGALRM`` (via ``TimeoutPosix``)
# to bound hook execution. Executors that opt in must run on POSIX systems.
supports_connection_test: bool = False
supports_streaming_logs: bool = False
sentry_integration: str = ""

is_local: bool = False
Expand Down Expand Up @@ -559,6 +561,19 @@ def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], li
"""
return [], []

def get_streaming_task_log(self, ti: TaskInstance, try_number: int) -> StreamingLogResponse:
"""
Return a streaming response for task logs.

Executors that implement this method must also set the ``supports_streaming_logs`` class
attribute to ``True``.

:param ti: A TaskInstance object
:param try_number: current try_number to read log from
:return: StreamingLogResponse
"""
raise NotImplementedError

def end(self) -> None: # pragma: no cover
"""Wait synchronously for the previously submitted job to complete."""
raise NotImplementedError
Expand Down
70 changes: 50 additions & 20 deletions airflow-core/src/airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from itertools import chain, islice
from pathlib import Path
from types import GeneratorType
from typing import IO, TYPE_CHECKING, TypedDict, cast
from typing import IO, TYPE_CHECKING, Literal, TypedDict, cast, overload
from urllib.parse import urljoin

import pendulum
Expand Down Expand Up @@ -557,27 +557,52 @@ def _render_filename(
)
raise RuntimeError(f"Unable to render log filename for {ti}. This should never happen")

def _get_executor_get_task_log(
self, ti: TaskInstance | TaskInstanceHistory
) -> Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]]:
@overload
def _get_executor_log_callable(
self, ti: TaskInstance | TaskInstanceHistory, *, streaming: Literal[True]
) -> Callable[[TaskInstance | TaskInstanceHistory, int], StreamingLogResponse] | None: ...

@overload
def _get_executor_log_callable(
self, ti: TaskInstance | TaskInstanceHistory, *, streaming: Literal[False] = ...
) -> Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]]: ...

def _get_executor_log_callable(
self, ti: TaskInstance | TaskInstanceHistory, *, streaming: bool = False
) -> (
Callable[[TaskInstance | TaskInstanceHistory, int], StreamingLogResponse]
| None
| Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]]
):
"""
Get the get_task_log method from executor of current task instance.
Get the get_task_log or get_streaming_task_log method from executor of current task instance.

Since there might be multiple executors, so we need to get the executor of current task instance instead of getting from default executor.

:param ti: task instance object
:return: get_task_log method of the executor
:param streaming: if True, get the get_streaming_task_log method, otherwise get the get_task_log method
:return: get_task_log or get_streaming_task_log method of the executor
"""
executor_name = ti.executor or self.DEFAULT_EXECUTOR_KEY
executor = self.executor_instances.get(executor_name)
if executor is not None:
return executor.get_task_log
if executor is None:
if executor_name == self.DEFAULT_EXECUTOR_KEY:
executor = ExecutorLoader.get_default_executor()
else:
executor = ExecutorLoader.load_executor(executor_name)
self.executor_instances[executor_name] = executor

if streaming:
# The `supports_streaming_logs` class attribute and `get_streaming_task_log` method was added in Airflow 3.2.0.
# And some of the provider executors or custom executors haven't supported `get_streaming_task_log` yet.
# For backward compatibility with earlier versions, we need to check for their existence.
if hasattr(executor, "get_streaming_task_log") and getattr(
executor, "supports_streaming_logs", False
):
return executor.get_streaming_task_log
return None

if executor_name == self.DEFAULT_EXECUTOR_KEY:
self.executor_instances[executor_name] = ExecutorLoader.get_default_executor()
else:
self.executor_instances[executor_name] = ExecutorLoader.load_executor(executor_name)
return self.executor_instances[executor_name].get_task_log
return executor.get_task_log

def _read(
self,
Expand Down Expand Up @@ -632,23 +657,28 @@ def _read(
raise TypeError("Logs should be either a list of strings or a generator of log lines.")
# Extend LogSourceInfo
source_list.extend(sources)
has_k8s_exec_pod = False

has_executor_log = False
if ti.state == TaskInstanceState.RUNNING:
executor_get_task_log = self._get_executor_get_task_log(ti)
response = executor_get_task_log(ti, try_number)
if response:
sources, logs = response
# check for streaming logs first
if executor_streaming_get_task_log := self._get_executor_log_callable(ti, streaming=True):
sources, executor_logs = executor_streaming_get_task_log(ti, try_number)
else: # fallback to non-streaming logs if streaming not supported
executor_get_task_log = self._get_executor_log_callable(ti)
sources, logs = executor_get_task_log(ti, try_number)
# make the logs stream-like compatible
executor_logs = [_get_compatible_log_stream(logs)]

if sources:
source_list.extend(sources)
has_k8s_exec_pod = True
has_executor_log = True

if not (remote_logs and ti.state not in State.unfinished):
# when finished, if we have remote logs, no need to check local
worker_log_full_path = Path(self.local_base, worker_log_rel_path)
sources, local_logs = self._read_from_local(worker_log_full_path)
source_list.extend(sources)
if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not has_k8s_exec_pod:
if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not has_executor_log:
sources, served_logs = self._read_from_logs_server(ti, worker_log_rel_path)
source_list.extend(sources)
elif (ti.state not in State.unfinished or ti.state in _STATES_WITH_COMPLETED_ATTEMPT) and not (
Expand Down
12 changes: 12 additions & 0 deletions airflow-core/tests/unit/executors/test_base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def test_supports_multi_team_default_value():
assert not BaseExecutor.supports_multi_team


def test_supports_streaming_logs_default_value():
assert not BaseExecutor.supports_streaming_logs


def test_invalid_slotspool():
with pytest.raises(ValueError, match="parallelism is set to 0 or lower"):
BaseExecutor(0)
Expand All @@ -78,6 +82,14 @@ def test_get_task_log():
assert executor.get_task_log(ti=ti, try_number=1) == ([], [])


def test_get_streaming_task_log_not_implemented():
executor = BaseExecutor()
ti = TaskInstance(task=SerializedBaseOperator(task_id="dummy"), dag_version_id=mock.MagicMock(spec=UUID))

with pytest.raises(NotImplementedError):
executor.get_streaming_task_log(ti=ti, try_number=1)


def test_serve_logs_default_value():
assert not BaseExecutor.serve_logs

Expand Down
75 changes: 75 additions & 0 deletions airflow-core/tests/unit/utils/log/test_file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from unittest.mock import MagicMock, patch

from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.state import TaskInstanceState

from tests_common.test_utils.file_task_handler import convert_list_to_stream, extract_events


class TestFileTaskHandlerLogServer:
Expand Down Expand Up @@ -220,3 +223,75 @@ def test_handles_base_log_folder_that_is_itself_a_symlink(self, tmp_path):

assert len(sources) == 1
assert "through-symlink content" in self._drain(streams[0])


class TestFileTaskHandlerExecutorLogs:
"""Tests for executor log retrieval selection."""

@staticmethod
def _running_ti(executor_name: str) -> MagicMock:
ti = MagicMock()
ti.executor = executor_name
ti.state = TaskInstanceState.RUNNING
ti.try_number = 1
return ti

def test_running_task_prefers_streaming_executor_logs(self):
"""Use executor streaming logs when the executor advertises streaming support."""
handler = FileTaskHandler(base_log_folder="")
executor = MagicMock()
executor.supports_streaming_logs = True
executor.get_streaming_task_log.return_value = (
["streaming source"],
[convert_list_to_stream(["streaming log"])],
)
executor.get_task_log.return_value = (["legacy source"], ["legacy log"])
handler.executor_instances = {"StreamingExecutor": executor}
ti = self._running_ti("StreamingExecutor")

with (
patch.object(handler, "_render_filename", return_value="dag/run/task/1.log"),
patch.object(handler, "_read_remote_logs", side_effect=NotImplementedError),
patch.object(handler, "_read_from_local", return_value=([], [])),
patch.object(handler, "_read_from_logs_server", return_value=([], [])) as read_from_logs_server,
):
logs, metadata = handler._read(ti=ti, try_number=1)

executor.get_streaming_task_log.assert_called_once_with(ti, 1)
executor.get_task_log.assert_not_called()
read_from_logs_server.assert_not_called()
assert extract_events(logs, skip_source_info=False) == [
"::group::Log message source details",
"streaming source",
"::endgroup::",
"streaming log",
]
assert metadata == {"end_of_log": False, "log_pos": 1}

def test_running_task_falls_back_to_legacy_executor_logs(self):
"""Use legacy executor logs when streaming support is not advertised."""
handler = FileTaskHandler(base_log_folder="")
executor = MagicMock()
executor.supports_streaming_logs = False
executor.get_task_log.return_value = (["legacy source"], ["legacy log"])
handler.executor_instances = {"LegacyExecutor": executor}
ti = self._running_ti("LegacyExecutor")

with (
patch.object(handler, "_render_filename", return_value="dag/run/task/1.log"),
patch.object(handler, "_read_remote_logs", side_effect=NotImplementedError),
patch.object(handler, "_read_from_local", return_value=([], [])),
patch.object(handler, "_read_from_logs_server", return_value=([], [])) as read_from_logs_server,
):
logs, metadata = handler._read(ti=ti, try_number=1)

executor.get_streaming_task_log.assert_not_called()
executor.get_task_log.assert_called_once_with(ti, 1)
read_from_logs_server.assert_not_called()
assert extract_events(logs, skip_source_info=False) == [
"::group::Log message source details",
"legacy source",
"::endgroup::",
"legacy log",
]
assert metadata == {"end_of_log": False, "log_pos": 1}
7 changes: 4 additions & 3 deletions airflow-core/tests/unit/utils/test_log_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,9 @@ def test_file_task_handler_with_multiple_executors(
else:
path_to_executor_class = executors_mapping.get(executor_name)

with patch(f"{path_to_executor_class}.get_task_log", return_value=([], [])) as mock_get_task_log:
mock_get_task_log.return_value = ([], [])
with patch(
f"{path_to_executor_class}.get_streaming_task_log", return_value=([], [])
) as mock_get_streaming_task_log:
ti = create_task_instance(
dag_id="dag_for_testing_multiple_executors",
task_id="task_for_testing_multiple_executors",
Expand Down Expand Up @@ -326,7 +327,7 @@ def test_file_task_handler_with_multiple_executors(
assert hasattr(file_handler, "read")
file_handler.read(ti)
os.remove(log_filename)
mock_get_task_log.assert_called_once()
mock_get_streaming_task_log.assert_called_once()

if executor_name is None:
mock_get_default_executor.assert_called_once()
Expand Down
3 changes: 1 addition & 2 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
aarch
abc
AbstractFileSystem
AbstractToolset
accessor
Expand Down Expand Up @@ -1605,6 +1603,7 @@ StoredInfoType
storedInfoType
str
Streamable
StreamingLogResponse
strftime
Stringified
stringified
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.utils.providers_configuration_loader import providers_configuration_loaded

if TYPE_CHECKING:
from airflow._shared.logging.remote import StreamingLogResponse
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.cli.cli_config import GroupCommand
Expand All @@ -55,6 +56,7 @@ class CeleryKubernetesExecutor(BaseExecutor):
"""

supports_ad_hoc_ti_run: bool = True
supports_streaming_logs: bool = True
# TODO: Remove this flag once providers depend on Airflow 3.0
supports_pickling: bool = True
supports_sentry: bool = False
Expand Down Expand Up @@ -206,6 +208,12 @@ def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], li
return self.kubernetes_executor.get_task_log(ti=ti, try_number=try_number)
return [], []

def get_streaming_task_log(self, ti: TaskInstance, try_number: int) -> StreamingLogResponse:
"""Fetch streaming task log from Kubernetes executor."""
if ti.queue == self.kubernetes_executor.kubernetes_queue:
return self.kubernetes_executor.get_streaming_task_log(ti=ti, try_number=try_number)
return [], []

def has_task(self, task_instance: TaskInstance) -> bool:
"""
Check if a task is either queued or running in either celery or kubernetes executor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def test_is_production_default_value(self):
def test_serve_logs_default_value(self):
assert not CeleryKubernetesExecutor.serve_logs

def test_supports_streaming_logs(self):
assert CeleryKubernetesExecutor.supports_streaming_logs

def test_cli_commands_vended(self):
assert CeleryKubernetesExecutor.get_cli_commands()

Expand Down Expand Up @@ -198,6 +201,27 @@ def test_log_is_fetched_from_k8s_executor_only_for_k8s_queue(self):
k8s_executor_mock.get_task_log.assert_not_called()
assert log == ([], [])

def test_streaming_log_is_fetched_from_k8s_executor_only_for_k8s_queue(self):
celery_executor_mock = mock.MagicMock()
k8s_executor_mock = mock.MagicMock()
cke = CeleryKubernetesExecutor(celery_executor_mock, k8s_executor_mock)
simple_task_instance = mock.MagicMock()
simple_task_instance.queue = KUBERNETES_QUEUE

cke.get_streaming_task_log(ti=simple_task_instance, try_number=1)

k8s_executor_mock.get_streaming_task_log.assert_called_once_with(
ti=simple_task_instance, try_number=1
)

k8s_executor_mock.reset_mock()
simple_task_instance.queue = "test-queue"

log = cke.get_streaming_task_log(ti=simple_task_instance, try_number=1)

k8s_executor_mock.get_streaming_task_log.assert_not_called()
assert log == ([], [])

def test_get_event_buffer(self):
celery_executor_mock = mock.MagicMock()
k8s_executor_mock = mock.MagicMock()
Expand Down
Loading
Loading