diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index ef5772331108f..2d45dea4c9321 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -71,6 +71,7 @@ def get_execution_api_server_url(conf_source: AirflowConfigParser | ExecutorConf from sqlalchemy.orm import Session + from airflow._shared.logging.remote import StreamingLogResponse from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest @@ -174,6 +175,7 @@ class BaseExecutor(LoggingMixin): # The connection-test supervisor uses ``signal.SIGALRM`` (via ``TimeoutPosix``) # to bound hook execution. Executors that opt in must run on POSIX systems. supports_connection_test: bool = False + supports_streaming_logs: bool = False sentry_integration: str = "" is_local: bool = False @@ -559,6 +561,19 @@ def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], li """ return [], [] + def get_streaming_task_log(self, ti: TaskInstance, try_number: int) -> StreamingLogResponse: + """ + Return a streaming response for task logs. + + Executors that implement this method must also set the ``supports_streaming_logs`` class + attribute to ``True``. + + :param ti: A TaskInstance object + :param try_number: current try_number to read log from + :return: StreamingLogResponse + """ + raise NotImplementedError + def end(self) -> None: # pragma: no cover """Wait synchronously for the previously submitted job to complete.""" raise NotImplementedError diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py b/airflow-core/src/airflow/utils/log/file_task_handler.py index 9eec24b18134e..4ef1099f417e8 100644 --- a/airflow-core/src/airflow/utils/log/file_task_handler.py +++ b/airflow-core/src/airflow/utils/log/file_task_handler.py @@ -30,7 +30,7 @@ from itertools import chain, islice from pathlib import Path from types import GeneratorType -from typing import IO, TYPE_CHECKING, TypedDict, cast +from typing import IO, TYPE_CHECKING, Literal, TypedDict, cast, overload from urllib.parse import urljoin import pendulum @@ -557,27 +557,52 @@ def _render_filename( ) raise RuntimeError(f"Unable to render log filename for {ti}. This should never happen") - def _get_executor_get_task_log( - self, ti: TaskInstance | TaskInstanceHistory - ) -> Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]]: + @overload + def _get_executor_log_callable( + self, ti: TaskInstance | TaskInstanceHistory, *, streaming: Literal[True] + ) -> Callable[[TaskInstance | TaskInstanceHistory, int], StreamingLogResponse] | None: ... + + @overload + def _get_executor_log_callable( + self, ti: TaskInstance | TaskInstanceHistory, *, streaming: Literal[False] = ... + ) -> Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]]: ... + + def _get_executor_log_callable( + self, ti: TaskInstance | TaskInstanceHistory, *, streaming: bool = False + ) -> ( + Callable[[TaskInstance | TaskInstanceHistory, int], StreamingLogResponse] + | None + | Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]] + ): """ - Get the get_task_log method from executor of current task instance. + Get the get_task_log or get_streaming_task_log method from executor of current task instance. Since there might be multiple executors, so we need to get the executor of current task instance instead of getting from default executor. :param ti: task instance object - :return: get_task_log method of the executor + :param streaming: if True, get the get_streaming_task_log method, otherwise get the get_task_log method + :return: get_task_log or get_streaming_task_log method of the executor """ executor_name = ti.executor or self.DEFAULT_EXECUTOR_KEY executor = self.executor_instances.get(executor_name) - if executor is not None: - return executor.get_task_log + if executor is None: + if executor_name == self.DEFAULT_EXECUTOR_KEY: + executor = ExecutorLoader.get_default_executor() + else: + executor = ExecutorLoader.load_executor(executor_name) + self.executor_instances[executor_name] = executor + + if streaming: + # The `supports_streaming_logs` class attribute and `get_streaming_task_log` method was added in Airflow 3.2.0. + # And some of the provider executors or custom executors haven't supported `get_streaming_task_log` yet. + # For backward compatibility with earlier versions, we need to check for their existence. + if hasattr(executor, "get_streaming_task_log") and getattr( + executor, "supports_streaming_logs", False + ): + return executor.get_streaming_task_log + return None - if executor_name == self.DEFAULT_EXECUTOR_KEY: - self.executor_instances[executor_name] = ExecutorLoader.get_default_executor() - else: - self.executor_instances[executor_name] = ExecutorLoader.load_executor(executor_name) - return self.executor_instances[executor_name].get_task_log + return executor.get_task_log def _read( self, @@ -632,23 +657,28 @@ def _read( raise TypeError("Logs should be either a list of strings or a generator of log lines.") # Extend LogSourceInfo source_list.extend(sources) - has_k8s_exec_pod = False + + has_executor_log = False if ti.state == TaskInstanceState.RUNNING: - executor_get_task_log = self._get_executor_get_task_log(ti) - response = executor_get_task_log(ti, try_number) - if response: - sources, logs = response + # check for streaming logs first + if executor_streaming_get_task_log := self._get_executor_log_callable(ti, streaming=True): + sources, executor_logs = executor_streaming_get_task_log(ti, try_number) + else: # fallback to non-streaming logs if streaming not supported + executor_get_task_log = self._get_executor_log_callable(ti) + sources, logs = executor_get_task_log(ti, try_number) # make the logs stream-like compatible executor_logs = [_get_compatible_log_stream(logs)] + if sources: source_list.extend(sources) - has_k8s_exec_pod = True + has_executor_log = True + if not (remote_logs and ti.state not in State.unfinished): # when finished, if we have remote logs, no need to check local worker_log_full_path = Path(self.local_base, worker_log_rel_path) sources, local_logs = self._read_from_local(worker_log_full_path) source_list.extend(sources) - if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not has_k8s_exec_pod: + if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not has_executor_log: sources, served_logs = self._read_from_logs_server(ti, worker_log_rel_path) source_list.extend(sources) elif (ti.state not in State.unfinished or ti.state in _STATES_WITH_COMPLETED_ATTEMPT) and not ( diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index cb646f7ce8d55..78d5e078ba0f0 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -67,6 +67,10 @@ def test_supports_multi_team_default_value(): assert not BaseExecutor.supports_multi_team +def test_supports_streaming_logs_default_value(): + assert not BaseExecutor.supports_streaming_logs + + def test_invalid_slotspool(): with pytest.raises(ValueError, match="parallelism is set to 0 or lower"): BaseExecutor(0) @@ -78,6 +82,14 @@ def test_get_task_log(): assert executor.get_task_log(ti=ti, try_number=1) == ([], []) +def test_get_streaming_task_log_not_implemented(): + executor = BaseExecutor() + ti = TaskInstance(task=SerializedBaseOperator(task_id="dummy"), dag_version_id=mock.MagicMock(spec=UUID)) + + with pytest.raises(NotImplementedError): + executor.get_streaming_task_log(ti=ti, try_number=1) + + def test_serve_logs_default_value(): assert not BaseExecutor.serve_logs diff --git a/airflow-core/tests/unit/utils/log/test_file_task_handler.py b/airflow-core/tests/unit/utils/log/test_file_task_handler.py index ceecffafa3265..c46f5af9d2476 100644 --- a/airflow-core/tests/unit/utils/log/test_file_task_handler.py +++ b/airflow-core/tests/unit/utils/log/test_file_task_handler.py @@ -21,6 +21,9 @@ from unittest.mock import MagicMock, patch from airflow.utils.log.file_task_handler import FileTaskHandler +from airflow.utils.state import TaskInstanceState + +from tests_common.test_utils.file_task_handler import convert_list_to_stream, extract_events class TestFileTaskHandlerLogServer: @@ -220,3 +223,75 @@ def test_handles_base_log_folder_that_is_itself_a_symlink(self, tmp_path): assert len(sources) == 1 assert "through-symlink content" in self._drain(streams[0]) + + +class TestFileTaskHandlerExecutorLogs: + """Tests for executor log retrieval selection.""" + + @staticmethod + def _running_ti(executor_name: str) -> MagicMock: + ti = MagicMock() + ti.executor = executor_name + ti.state = TaskInstanceState.RUNNING + ti.try_number = 1 + return ti + + def test_running_task_prefers_streaming_executor_logs(self): + """Use executor streaming logs when the executor advertises streaming support.""" + handler = FileTaskHandler(base_log_folder="") + executor = MagicMock() + executor.supports_streaming_logs = True + executor.get_streaming_task_log.return_value = ( + ["streaming source"], + [convert_list_to_stream(["streaming log"])], + ) + executor.get_task_log.return_value = (["legacy source"], ["legacy log"]) + handler.executor_instances = {"StreamingExecutor": executor} + ti = self._running_ti("StreamingExecutor") + + with ( + patch.object(handler, "_render_filename", return_value="dag/run/task/1.log"), + patch.object(handler, "_read_remote_logs", side_effect=NotImplementedError), + patch.object(handler, "_read_from_local", return_value=([], [])), + patch.object(handler, "_read_from_logs_server", return_value=([], [])) as read_from_logs_server, + ): + logs, metadata = handler._read(ti=ti, try_number=1) + + executor.get_streaming_task_log.assert_called_once_with(ti, 1) + executor.get_task_log.assert_not_called() + read_from_logs_server.assert_not_called() + assert extract_events(logs, skip_source_info=False) == [ + "::group::Log message source details", + "streaming source", + "::endgroup::", + "streaming log", + ] + assert metadata == {"end_of_log": False, "log_pos": 1} + + def test_running_task_falls_back_to_legacy_executor_logs(self): + """Use legacy executor logs when streaming support is not advertised.""" + handler = FileTaskHandler(base_log_folder="") + executor = MagicMock() + executor.supports_streaming_logs = False + executor.get_task_log.return_value = (["legacy source"], ["legacy log"]) + handler.executor_instances = {"LegacyExecutor": executor} + ti = self._running_ti("LegacyExecutor") + + with ( + patch.object(handler, "_render_filename", return_value="dag/run/task/1.log"), + patch.object(handler, "_read_remote_logs", side_effect=NotImplementedError), + patch.object(handler, "_read_from_local", return_value=([], [])), + patch.object(handler, "_read_from_logs_server", return_value=([], [])) as read_from_logs_server, + ): + logs, metadata = handler._read(ti=ti, try_number=1) + + executor.get_streaming_task_log.assert_not_called() + executor.get_task_log.assert_called_once_with(ti, 1) + read_from_logs_server.assert_not_called() + assert extract_events(logs, skip_source_info=False) == [ + "::group::Log message source details", + "legacy source", + "::endgroup::", + "legacy log", + ] + assert metadata == {"end_of_log": False, "log_pos": 1}