diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 056151cefa1c1..0f6eee118ec81 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -735,16 +735,33 @@ def get_queued_dag_runs_to_set_running(cls, session: Session) -> ScalarResult[Da .subquery() ) - query = ( - select(cls) - .where(cls.state == DagRunState.QUEUED) + available_dagruns_rn = ( + select( + DagRun.dag_id, + DagRun.id, + running_drs.c.num_running, + func.row_number() + .over( + partition_by=[DagRun.dag_id, DagRun.backfill_id], + order_by=[ + nulls_first(cast("ColumnElement[Any]", BackfillDagRun.sort_ordinal), session=session), + nulls_first( + cast("ColumnElement[Any]", cls.last_scheduling_decision), session=session + ), + nulls_first(running_drs.c.num_running, session=session), + DagRun.run_after, + ], + ) + .label("rn"), + ) + .where(DagRun.state == DagRunState.QUEUED) .join( - DagModel, + running_drs, and_( - DagModel.dag_id == cls.dag_id, - DagModel.is_paused == false(), - DagModel.is_stale == false(), + running_drs.c.dag_id == DagRun.dag_id, + running_drs.c.backfill_id == DagRun.backfill_id, ), + isouter=True, ) .join( BackfillDagRun, @@ -754,37 +771,39 @@ def get_queued_dag_runs_to_set_running(cls, session: Session) -> ScalarResult[Da ), isouter=True, ) - .join(Backfill, isouter=True) + .subquery() + ) + + query = ( + select(cls) .join( - running_drs, + available_dagruns_rn, and_( - running_drs.c.dag_id == DagRun.dag_id, - coalesce(running_drs.c.backfill_id, text("-1")) - == coalesce(DagRun.backfill_id, text("-1")), + available_dagruns_rn.c.id == DagRun.id, + available_dagruns_rn.c.dag_id == DagRun.dag_id, + ), + ) + .join( + DagModel, + and_( + DagModel.dag_id == cls.dag_id, + DagModel.is_paused == false(), + DagModel.is_stale == false(), ), - isouter=True, ) + .join(Backfill, isouter=True) .where( - # there are two levels of checks for num_running - # the one done in this query verifies that the dag is not maxed out - # it could return many more dag runs than runnable if there is even - # capacity for 1. this could be improved. - coalesce(running_drs.c.num_running, text("0")) - < coalesce(Backfill.max_active_runs, DagModel.max_active_runs), + # this check returns strictly only the amount of dagruns + # which can run according to the max active runs limit + available_dagruns_rn.c.rn + <= coalesce( + Backfill.max_active_runs, + DagModel.max_active_runs, + ) + - coalesce(available_dagruns_rn.c.num_running, 0), # don't set paused dag runs as running not_(coalesce(cast("ColumnElement[bool]", Backfill.is_paused), False)), ) - .order_by( - # ordering by backfill sort ordinal first ensures that backfill dag runs - # have lower priority than all other dag run types (since sort_ordinal >= 1). - # additionally, sorting by sort_ordinal ensures that the backfill - # dag runs are created in the right order when that matters. - # todo: AIP-78 use row_number to avoid starvation; limit the number of returned runs per-dag - nulls_first(cast("ColumnElement[Any]", BackfillDagRun.sort_ordinal), session=session), - nulls_first(cast("ColumnElement[Any]", cls.last_scheduling_decision), session=session), - nulls_first(running_drs.c.num_running, session=session), # many running -> lower priority - cls.run_after, - ) .limit(cls.DEFAULT_DAGRUNS_TO_EXAMINE) ) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 603fcf4bf66cf..03f746bd76408 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -2632,7 +2632,10 @@ def test_find_executable_task_instances_max_active_tis_per_dagrun_deferred(self, dag_id = "SchedulerJobTest.test_max_active_tis_per_dagrun_deferred" with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session): task_a = EmptyOperator.partial(task_id="task_a", max_active_tis_per_dagrun=1).expand_kwargs( - [{"inputs": 1}, {"inputs": 2}] + [ + {"inputs": 1}, + {"inputs": 2}, + ] ) EmptyOperator(task_id="task_b") @@ -4076,6 +4079,94 @@ def test_runs_are_created_after_max_active_runs_was_reached(self, dag_maker, ses dag_runs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(dag_runs) == 2 + def test_runs_are_not_starved_by_max_active_runs_limit(self, dag_maker, session): + """ + Test that dagruns are not starved by max_active_runs + """ + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, executors=[self.null_exec]) + + dag_ids = ["dag1", "dag2", "dag3"] + + max_active_runs = 3 + + for dag_id in dag_ids: + with dag_maker( + dag_id=dag_id, + max_active_runs=max_active_runs, + session=session, + catchup=True, + schedule=timedelta(seconds=60), + start_date=DEFAULT_DATE, + ): + # Need to use something that doesn't immediately get marked as success by the scheduler + BashOperator(task_id="task", bash_command="true") + + dag_run = dag_maker.create_dagrun( + state=State.QUEUED, session=session, run_type=DagRunType.SCHEDULED + ) + + for _ in range(50): + # create a bunch of dagruns in queued state, to make sure they are filtered by max_active_runs + dag_run = dag_maker.create_dagrun_after( + dag_run, run_type=DagRunType.SCHEDULED, state=State.QUEUED + ) + + self.job_runner._start_queued_dagruns(session) + session.flush() + + running_dagrun_count = session.scalar( + select(func.count()).select_from(DagRun).where(DagRun.state == DagRunState.RUNNING) + ) + + assert running_dagrun_count == max_active_runs * len(dag_ids) + + def test_no_more_dagruns_are_set_to_running_when_max_active_runs_exceeded(self, dag_maker, session): + """ + Test that dagruns are not moved to running if there are more than the max_active_runs running dagruns + """ + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, executors=[self.null_exec]) + + max_active_runs = 1 + with dag_maker( + dag_id="test_dag", + max_active_runs=max_active_runs, + session=session, + catchup=True, + schedule=timedelta(seconds=60), + start_date=DEFAULT_DATE, + ): + # Need to use something that doesn't immediately get marked as success by the scheduler + BashOperator(task_id="task", bash_command="true") + + dag_run = dag_maker.create_dagrun(state=State.RUNNING, session=session, run_type=DagRunType.SCHEDULED) + + for _ in range(5): + # create a bunch of dagruns in running state, to exceed max_active_runs + dag_run = dag_maker.create_dagrun_after( + dag_run, run_type=DagRunType.SCHEDULED, state=State.RUNNING + ) + + running_dagruns_pre = session.scalar( + select(func.count()).select_from(DagRun).where(DagRun.state == DagRunState.RUNNING) + ) + + for _ in range(5): + # create a bunch of dagruns in queued state, to make sure they are filtered by max_active_runs + dag_run = dag_maker.create_dagrun_after( + dag_run, run_type=DagRunType.SCHEDULED, state=State.QUEUED + ) + + self.job_runner._start_queued_dagruns(session) + session.flush() + + running_dagruns_post = session.scalar( + select(func.count()).select_from(DagRun).where(DagRun.state == DagRunState.RUNNING) + ) + + assert running_dagruns_pre == running_dagruns_post + def test_dagrun_timeout_verify_max_active_runs(self, dag_maker, session): """ Test if a dagrun will not be scheduled if max_dag_runs @@ -6802,14 +6893,14 @@ def _running_counts(): EmptyOperator(task_id="mytask") dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED) - for _ in range(9): + for _ in range(29): dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.QUEUED) # initial state -- nothing is running assert dag1_non_b_running == 0 assert dag1_b_running == 0 assert total_running == 0 - assert session.scalar(select(func.count(DagRun.id))) == 46 + assert session.scalar(select(func.count(DagRun.id))) == 66 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 # now let's run it once @@ -6817,26 +6908,40 @@ def _running_counts(): session.flush() # after running the scheduler one time, observe that only one dag run is started - # this is because there are 30 runs for dag 1 so neither the backfills nor - # any runs for dag2 get started + # and 3 backfill dagruns are started + # this is because there are 30 queued dagruns, many of which get filtered because their DAGs + # have already reached max_active_runs + # and so due to the default dagruns-to-examine limit, we look at the first 20 dagruns that CAN be run + # according to the max_active_runs parameter, meaning 3 backfill runs will start, 1 non-backfill, + # and all runnable dagruns for dag2 assert DagRun.DEFAULT_DAGRUNS_TO_EXAMINE == 20 dag1_non_b_running, dag1_b_running, total_running = _running_counts() assert dag1_non_b_running == 1 - assert dag1_b_running == 0 - assert total_running == 1 - assert session.scalar(select(func.count()).select_from(DagRun)) == 46 + assert dag1_b_running == 3 + assert total_running == 20 + assert session.scalar(select(func.count()).select_from(DagRun)) == 66 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 + # now we finish all lower priority scheduled runs, and observe new higher priority tasks are started + session.execute( + update(DagRun) + .where(DagRun.dag_id == "test_dag2", DagRun.state == DagRunState.RUNNING) + .values(state=DagRunState.SUCCESS) + ) + session.commit() + session.flush() # we run scheduler again and observe that now all the runs are created + # other than the finished runs of the backfill # this must be because sorting is working + # new tasks from test dag 2 should run, and so they are scheduled self.job_runner._start_queued_dagruns(session) session.flush() dag1_non_b_running, dag1_b_running, total_running = _running_counts() assert dag1_non_b_running == 1 assert dag1_b_running == 3 - assert total_running == 14 - assert session.scalar(select(func.count()).select_from(DagRun)) == 46 + assert total_running == 18 + assert session.scalar(select(func.count()).select_from(DagRun)) == 66 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 # run it a 3rd time and nothing changes @@ -6846,8 +6951,8 @@ def _running_counts(): dag1_non_b_running, dag1_b_running, total_running = _running_counts() assert dag1_non_b_running == 1 assert dag1_b_running == 3 - assert total_running == 14 - assert session.scalar(select(func.count()).select_from(DagRun)) == 46 + assert total_running == 18 + assert session.scalar(select(func.count()).select_from(DagRun)) == 66 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 def test_backfill_runs_are_started_with_lower_priority_catchup_false(self, dag_maker, session): @@ -7067,25 +7172,11 @@ def _running_counts(): assert dag1_non_b_running == 1 assert dag1_b_running == 3 - # this should be 14 but it is not. why? - # answer: because dag2 got starved out by dag1 - # if we run the scheduler again, dag2 should get queued - assert total_running == 4 + assert total_running == 14 assert session.scalar(select(func.count()).select_from(DagRun)) == 46 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 - # run scheduler a second time - self.job_runner._start_queued_dagruns(session) - session.flush() - - dag1_non_b_running, dag1_b_running, total_running = _running_counts() - assert dag1_non_b_running == 1 - assert dag1_b_running == 3 - - # on the second try, dag 2's 10 runs now start running - assert total_running == 14 - assert session.scalar(select(func.count()).select_from(DagRun)) == 46 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36