Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from enum import Enum

from airflow.api_fastapi.core_api.base import StrictBaseModel


class CallbackTerminalState(str, Enum):
    """
    Terminal states a callback can transition to from RUNNING.

    Used as the type of the ``state`` field on ``CallbackTerminalStatePayload``.
    Members mix in ``str``, so each one carries its lowercase string value.
    """

    # String value "success".
    SUCCESS = "success"
    # String value "failed".
    FAILED = "failed"


class CallbackTerminalStatePayload(StrictBaseModel):
    """Payload for transitioning a callback from RUNNING to a terminal state."""

    # Required: the terminal state the callback should be moved into.
    state: CallbackTerminalState
    # Optional text to store as the callback's output; defaults to ``None``
    # when the client does not send it.
    output: str | None = None
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
asset_events,
asset_state_store,
assets,
callbacks,
connection_tests,
connections,
dag_runs,
Expand Down Expand Up @@ -52,6 +53,7 @@
authenticated_router.include_router(
connection_tests.router, prefix="/connection-tests", tags=["Connection Tests"]
)
authenticated_router.include_router(callbacks.router, prefix="/callbacks", tags=["Callbacks"])
authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"])
authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Annotated
from uuid import UUID

import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Response, Security, status
from structlog.contextvars import bind_contextvars

from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.execution_api.datamodels.callback import CallbackTerminalStatePayload
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.deps import DepContainer
from airflow.api_fastapi.execution_api.security import CurrentTIToken, ExecutionAPIRoute, require_auth
from airflow.models.callback import Callback
from airflow.utils.state import CallbackState

# Module-level structured logger, named after this module.
log = structlog.get_logger(__name__)

# Router the callback endpoints in this module are registered on; every route
# is built with ``ExecutionAPIRoute`` as its route class.
router = VersionedAPIRouter(route_class=ExecutionAPIRoute)


def _require_self(token: TIToken, callback_id: UUID) -> None:
    """
    Reject a token that was not issued for ``callback_id``.

    Callback-route counterpart of the ``ti:self`` enforcement in security.py.
    The token's ``id`` and the callback id are compared by their string forms.

    :param token: Token presented with the request.
    :param callback_id: Callback id taken from the request path.
    :raises HTTPException: 403 when the two ids differ.
    """
    subject_matches = str(token.id) == str(callback_id)
    if subject_matches:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Token subject does not match callback id",
    )


@router.post(
    "/{callback_id}/run",
    status_code=status.HTTP_204_NO_CONTENT,
    dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "Token subject does not match callback id"},
        status.HTTP_404_NOT_FOUND: {"description": "Callback not found"},
        status.HTTP_409_CONFLICT: {"description": "Callback is not in a state that can be marked running"},
    },
)
def callback_run(
    callback_id: UUID,
    response: Response,
    session: SessionDep,
    services=DepContainer,
    token: TIToken = CurrentTIToken,
) -> Response:
    """
    Mark a callback as RUNNING.

    Call this before executing the callback. A QUEUED callback is moved to
    RUNNING; a callback that is already RUNNING is left unchanged, so a retried
    request succeeds. A callback in any other state is rejected with
    ``409 Conflict``.

    This endpoint accepts a workload-scoped token as well as an execution-scoped
    one. When it is called with a workload-scoped token, the response carries a
    new execution-scoped token in the ``Refreshed-API-Token`` header; use that
    token for all subsequent requests.
    """
    # Implementation note: this mirrors ``PATCH /task-instances/{id}/run`` -- the
    # state transition and the workload -> execution token exchange happen in
    # the same request. Kept out of the docstring because the docstring is
    # published as the operation description in the generated OpenAPI spec.
    bind_contextvars(callback_id=str(callback_id))
    _require_self(token, callback_id)

    callback = session.get(Callback, callback_id)
    if callback is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"reason": "not_found", "message": "Callback not found"},
        )

    # Allow QUEUED → RUNNING transition; treat RUNNING as idempotent so a retried
    # supervisor start does not 409. Anything else (PENDING / SCHEDULED / terminal) rejects.
    if callback.state == CallbackState.RUNNING:
        log.info("Duplicate start request received", callback_id=str(callback.id))
    elif callback.state == CallbackState.QUEUED:
        callback.state = CallbackState.RUNNING
        session.add(callback)
    else:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail={
                "reason": "invalid_state",
                "message": "Callback was not in a state where it could be marked running",
                "current_state": callback.state,
            },
        )

    # Only a workload-scoped caller needs a token exchange; a caller that already
    # holds an execution-scoped token gets no refresh header.
    if token.claims.scope == "workload":
        generator: JWTGenerator = services.get(JWTGenerator)
        execution_token = generator.generate(extras={"sub": str(callback_id), "scope": "execution"})
        response.headers["Refreshed-API-Token"] = execution_token

    response.status_code = status.HTTP_204_NO_CONTENT
    return response


@router.patch(
    "/{callback_id}/state",
    status_code=status.HTTP_204_NO_CONTENT,
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "Token subject does not match callback id"},
        status.HTTP_404_NOT_FOUND: {"description": "Callback not found"},
        status.HTTP_409_CONFLICT: {"description": "Callback is not in RUNNING state"},
    },
)
def callback_update_state(
    callback_id: UUID,
    payload: Annotated[CallbackTerminalStatePayload, Body()],
    session: SessionDep,
    token: TIToken = CurrentTIToken,
) -> Response:
    """Mark a RUNNING callback as SUCCESS or FAILED."""
    bind_contextvars(callback_id=str(callback_id))
    # The token must have been issued for this very callback (403 otherwise).
    _require_self(token, callback_id)

    record = session.get(Callback, callback_id)
    if record is None:
        not_found = {"reason": "not_found", "message": "Callback not found"}
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=not_found)

    # Only a RUNNING callback may be moved to a terminal state.
    current_state = record.state
    if current_state != CallbackState.RUNNING:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail={
                "reason": "invalid_state",
                "message": "Callback was not in RUNNING state",
                "current_state": current_state,
            },
        )

    record.state = CallbackState(payload.state)
    # Leave any existing output untouched when the client sent none.
    reported_output = payload.output
    if reported_output is not None:
        record.output = reported_output
    session.add(record)

    return Response(status_code=status.HTTP_204_NO_CONTENT)
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
AddTeamNameField,
AddVariableKeysEndpoint,
)
from airflow.api_fastapi.execution_api.versions.v2026_07_01 import AddCallbackEndpoints

bundle = VersionBundle(
HeadVersion(),
Expand All @@ -66,6 +67,10 @@
AddAssetsByAliasEndpoint,
AddPartitionDateField,
),
Version(
"2026-06-01",

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please reconcile before this merges: the version module is named v2026_07_01.py, but this label is "2026-06-01". They should agree. Two ways:

  • If 2026-06-01 is correct, rename the file to v2026_06_01.py (the import above and this entry's position between 2026-06-30 and 2026-04-06 are already right for that date).
  • If 2026-07-01 was intended, change this label to "2026-07-01" and move the whole Version(...) block above the 2026-06-30 entry so the descending order stays correct.

It works today (Cadwyn keys off the string, and 2026-06-01 is correctly ordered), so this isn't a runtime bug — but the filename/label disagreement is a trap for the next person editing the version bundle.

AddCallbackEndpoints,
),
Version(
"2026-04-06",
AddPartitionKeyField,
Expand Down

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this has missed 3.3.0, we will want to future-date this version to ~2026-07-31 (and the release manager will fix the version when it's included in a release).

Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from cadwyn import VersionChange, endpoint


class AddCallbackEndpoints(VersionChange):
    """Add the ``POST /callbacks/{callback_id}/run`` and ``PATCH /callbacks/{callback_id}/state`` endpoints."""

    # Reuse the class docstring as the human-readable change description.
    description = __doc__

    # Neither endpoint exists in API versions before this change.
    instructions_to_migrate_to_previous_version = (
        endpoint("/callbacks/{callback_id}/run", ["POST"]).didnt_exist,
        endpoint("/callbacks/{callback_id}/state", ["PATCH"]).didnt_exist,
    )
33 changes: 22 additions & 11 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,12 +1327,12 @@ def process_executor_events(
cls.logger().debug("Draining executor event with state %s for connection test %s", state, key)
elif isinstance(key, CallbackKey):
cls.logger().info("Received executor event with state %s for callback %s", state, key)
if state in (CallbackState.RUNNING, CallbackState.FAILED, CallbackState.SUCCESS):
if state in (CallbackState.FAILED, CallbackState.SUCCESS):
callback_keys_with_events.append(key)
else:
cls.logger().error("Unknown workload key type in event buffer: %r", key)

# Handle callback state events
# Handle callback state events.
for callback_id in callback_keys_with_events:
state, info = event_buffer.pop(callback_id)
callback = session.get(Callback, UUID(str(callback_id)))
Expand All @@ -1344,17 +1344,28 @@ def process_executor_events(
)
continue

if state == CallbackState.RUNNING:
callback.state = CallbackState.RUNNING
cls.logger().info("Callback %s is currently running", callback_id)
elif state == CallbackState.SUCCESS:
callback.state = CallbackState.SUCCESS
# Callback state transitions are now driven by the supervisor through
# the Execution API (POST /callbacks/{id}/run, PATCH /callbacks/{id}/state).
# The in-process events from the executor are kept as a fallback safety
# net for cases where the supervisor crashed before reporting a terminal state

need_to_modify = False

if state == CallbackState.SUCCESS:
cls.logger().info("Callback %s completed successfully", callback_id)
if callback.state in (CallbackState.QUEUED, CallbackState.RUNNING):
callback.state = CallbackState.SUCCESS
need_to_modify = True
elif state == CallbackState.FAILED:
callback.state = CallbackState.FAILED
callback.output = str(info) if info else "Execution failed"
cls.logger().error("Callback %s failed: %s", callback_id, callback.output)
session.add(callback)
callback_output = str(info) if info else "Execution failed"
cls.logger().error("Callback %s failed: %s", callback_id, callback_output)
if callback.state in (CallbackState.QUEUED, CallbackState.RUNNING):
callback.state = CallbackState.FAILED
callback.output = callback_output
need_to_modify = True
Comment thread
wjddn279 marked this conversation as resolved.

if need_to_modify:
session.add(callback)

# Return if no finished tasks
if not tis_with_right_state:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
NON_DEFAULT_TOKEN_POLICY: dict[str, set[str]] = {
# The /run endpoint exchanges a workload token for a short-lived execution token.
"PATCH /task-instances/{task_instance_id}/run": {"execution", "workload"},
"POST /callbacks/{callback_id}/run": {"execution", "workload"},
# Connection test routes run from a queued worker context (workload-only).
"PATCH /connection-tests/{connection_test_id}": {"workload"},
"GET /connection-tests/{connection_test_id}/connection": {"workload"},
Expand Down
Loading
Loading