Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ def has_on_success_callback(self) -> bool:
def has_on_skipped_callback(self) -> bool:
return self._get_partial_kwargs_or_operator_default("has_on_skipped_callback")

@property
def has_retry_policy(self) -> bool:
    """Return the ``has_retry_policy`` value resolved by ``_get_partial_kwargs_or_operator_default``."""
    flag_name = "has_retry_policy"
    return self._get_partial_kwargs_or_operator_default(flag_name)

@property
def run_as_user(self) -> str | None:
    """Return the ``run_as_user`` value resolved by ``_get_partial_kwargs_or_operator_default``."""
    field_name = "run_as_user"
    return self._get_partial_kwargs_or_operator_default(field_name)
Expand Down
17 changes: 12 additions & 5 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@
_CALLBACK_TYPES = ("execute", "failure", "success", "retry", "skipped")
_OPERATOR_CALLBACK_FIELDS = frozenset({f"on_{kind}_callback" for kind in _CALLBACK_TYPES})
_HAS_CALLBACK_FIELDS = frozenset({f"has_on_{kind}_callback" for kind in _CALLBACK_TYPES})
# Fields that are serialized only as a boolean ``has_<field>`` flag, never by value.
# These objects have no serializer, so serializing them would fall back to str(obj) and
# embed a non-deterministic memory address (producing a new DagVersion on every parse).
# The live object is recovered by re-parsing the DAG source on the worker. The set is
# used both for a mapped operator's ``partial_kwargs`` and for a DAG's ``default_args``.
_HAS_FLAG_FIELDS = frozenset({*_OPERATOR_CALLBACK_FIELDS, "retry_policy"})


def _get_registered_priority_weight_strategy(
Expand Down Expand Up @@ -974,7 +980,8 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
if cls._is_excluded(v, k, op):
continue

if k in _OPERATOR_CALLBACK_FIELDS:
if k in _HAS_FLAG_FIELDS:
# Store only a has_<field> flag, never the object (see _HAS_FLAG_FIELDS).
if bool(v):
serialized_op["partial_kwargs"][f"has_{k}"] = True
continue
Expand Down Expand Up @@ -1728,13 +1735,13 @@ def serialize_dag(cls, dag: DAG) -> dict:
# Ideally default_args goes through same logic as fields of SerializedBaseOperator.
if serialized_dag.get("default_args", {}):
default_args_dict = serialized_dag["default_args"][Encoding.VAR]
callbacks_to_remove = []
flags_to_remove = []
for k, v in list(default_args_dict.items()):
if k in _OPERATOR_CALLBACK_FIELDS:
if k in _HAS_FLAG_FIELDS:
if bool(v):
default_args_dict[f"has_{k}"] = True
callbacks_to_remove.append(k)
for k in callbacks_to_remove:
flags_to_remove.append(k)
for k in flags_to_remove:
del default_args_dict[k]

return serialized_dag
Expand Down
106 changes: 105 additions & 1 deletion airflow-core/tests/unit/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from airflow.sdk.definitions.decorators import task
from airflow.sdk.definitions.operator_resources import Resources
from airflow.sdk.definitions.param import Param
from airflow.sdk.definitions.retry_policy import ExceptionRetryPolicy, RetryAction, RetryRule
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
from airflow.serialization import serialized_objects
Expand Down Expand Up @@ -1279,7 +1280,6 @@ class TestRetryPolicySerialization:
def test_has_retry_policy_flag_set_when_policy_present(self):
"""When retry_policy is set, has_retry_policy=True in serialized form."""
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.definitions.retry_policy import ExceptionRetryPolicy, RetryAction, RetryRule
from airflow.serialization.serialized_objects import DagSerialization

policy = ExceptionRetryPolicy(
Expand Down Expand Up @@ -1309,6 +1309,110 @@ def test_has_retry_policy_flag_false_when_no_policy(self):
task = deserialized.task_dict["op_no_policy"]
assert task.has_retry_policy is False

def test_mapped_task_retry_policy_serializes_as_flag(self):
    """Serializing a mapped task stores ``has_retry_policy`` rather than the policy object."""
    from airflow.sdk import DAG  # module-level DAG is airflow.models.dag.DAG

    fail_on_value_error = RetryRule(exception=ValueError, action=RetryAction.FAIL, reason="bad data")
    policy = ExceptionRetryPolicy(rules=[fail_on_value_error])

    with DAG(dag_id="test_mapped_retry_policy_ser", start_date=DEFAULT_DATE) as dag:

        @task(retries=3, retry_policy=policy)
        def mapped(x):
            return x

        mapped.expand(x=[1, 2, 3])

    payload = DagSerialization.serialize_dag(dag)
    # An embedded RetryPolicy would show up through str(obj), which carries a memory address.
    dumped = json.dumps(payload)
    assert "ExceptionRetryPolicy object at 0x" not in dumped

    restored_task = DagSerialization.deserialize_dag(payload).task_dict["mapped"]
    assert restored_task.has_retry_policy is True

def test_mapped_task_retry_policy_serialization_is_deterministic(self):
    """Serializing two independent parses of the same mapped-task DAG yields identical output.

    Regression test: the RetryPolicy object was serialized via str(obj), embedding a
    per-process memory address, so every re-parse produced a different serialized DAG
    (and a spurious new DagVersion).
    """
    from airflow.sdk import DAG  # module-level DAG is airflow.models.dag.DAG

    def build():
        """Parse a fresh DAG that owns its own, newly created retry policy object."""
        policy = ExceptionRetryPolicy(
            rules=[RetryRule(exception=ValueError, action=RetryAction.FAIL)],
        )
        with DAG(dag_id="test_mapped_retry_policy_determinism", start_date=DEFAULT_DATE) as dag:

            @task(retries=3, retry_policy=policy)
            def mapped(x):
                return x

            mapped.expand(x=[1, 2, 3])
        return dag

    # Keep both DAGs (and therefore both policy objects) alive at the same time. Objects
    # with non-overlapping lifetimes may be allocated at the same address, which would let
    # an embedded address compare equal by accident and hide the regression.
    first_dag, second_dag = build(), build()
    assert DagSerialization.serialize_dag(first_dag) == DagSerialization.serialize_dag(second_dag)

def test_dag_default_args_retry_policy_serializes_as_flag(self):
    """A retry_policy in DAG default_args must not leak the object into serialized default_args.

    Regression test: serialize_dag serializes the raw default_args dict, so a RetryPolicy
    there hit the str(obj) fallback (memory address -> new DagVersion every parse) even
    though each task's has_retry_policy was set correctly.
    """
    from airflow.sdk import DAG  # module-level DAG is airflow.models.dag.DAG

    def build():
        """Parse a fresh DAG whose default_args carry a newly created retry policy object."""
        policy = ExceptionRetryPolicy(
            rules=[RetryRule(exception=ValueError, action=RetryAction.FAIL)],
        )
        with DAG(
            dag_id="test_default_args_retry_policy",
            start_date=DEFAULT_DATE,
            default_args={"retry_policy": policy},
        ) as dag:

            @task
            def plain():
                return 1

            plain()
        return dag

    serialized = DagSerialization.serialize_dag(build())
    # The RetryPolicy object must never be embedded in serialized default_args...
    assert "ExceptionRetryPolicy object at 0x" not in json.dumps(serialized)
    # ...and the has_retry_policy flag must be written in its place.
    default_args_dict = serialized["default_args"][Encoding.VAR]
    assert default_args_dict.get("has_retry_policy") is True
    assert "retry_policy" not in default_args_dict

    # Deterministic across independent parses (no embedded memory address). Both DAGs are
    # kept alive while serializing: objects with non-overlapping lifetimes may be allocated
    # at the same address, which would let an embedded address compare equal by accident.
    first_dag, second_dag = build(), build()
    assert DagSerialization.serialize_dag(first_dag) == DagSerialization.serialize_dag(second_dag)

def test_mapped_task_no_retry_policy_flag_false(self):
    """A mapped task declared without a retry_policy reports ``has_retry_policy`` as False.

    The mapped case resolves the flag via _get_partial_kwargs_or_operator_default (falling
    back to SerializedBaseOperator.has_retry_policy=False), which is a different path than
    the non-mapped dataclass default, hence this dedicated negative test.
    """
    from airflow.sdk import DAG  # module-level DAG is airflow.models.dag.DAG

    with DAG(dag_id="test_mapped_no_retry_policy", start_date=DEFAULT_DATE) as dag:

        @task(retries=3)
        def mapped(x):
            return x

        mapped.expand(x=[1, 2, 3])

    payload = DagSerialization.serialize_dag(dag)
    restored_task = DagSerialization.deserialize_dag(payload).task_dict["mapped"]
    assert restored_task.has_retry_policy is False


class TestKubernetesImportAvoidance:
"""Test that serialization doesn't import kubernetes unnecessarily."""
Expand Down
Loading