diff --git a/scripts/ci/prek/check_new_airflow_exception_usage.py b/scripts/ci/prek/check_new_airflow_exception_usage.py index 6ba97db865e50..5c09602192059 100755 --- a/scripts/ci/prek/check_new_airflow_exception_usage.py +++ b/scripts/ci/prek/check_new_airflow_exception_usage.py @@ -54,87 +54,40 @@ import argparse import re +from collections.abc import Iterable from pathlib import Path +from common_prek_utils import AIRFLOW_ROOT_PATH, AllowlistManager from rich.console import Console -from rich.panel import Panel console = Console(color_system="standard", width=200) -REPO_ROOT = Path(__file__).parents[3] +REPO_ROOT = AIRFLOW_ROOT_PATH # Match lines that actually raise AirflowException. Comment filtering is done # in _raise_lines() by skipping lines whose stripped form starts with "#". _RAISE_RE = re.compile(r"raise\s+AirflowException\b") -class AllowlistManager: +class AirflowExceptionAllowlistManager(AllowlistManager): def __init__(self, allowlist_file: Path) -> None: - self.allowlist_file = allowlist_file - - def load(self) -> dict[str, int]: - """Return mapping of ``relative_path -> allowed_count``.""" - if not self.allowlist_file.exists(): - return {} - - result: dict[str, int] = {} - for raw_line in self.allowlist_file.read_text().splitlines(): - if not (stripped := raw_line.strip()): - continue - - rel_str, _, count_str = stripped.rpartition("::") - if not rel_str or not count_str: - continue - - try: - result[rel_str] = int(count_str) - except ValueError: - continue - - return result - - def save(self, counts: dict[str, int]) -> None: - lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())] - self.allowlist_file.write_text("\n".join(lines) + "\n") - - def generate(self) -> int: - console.print(f"Scanning [cyan]{REPO_ROOT}[/cyan] for raise AirflowException …") - counts: dict[str, int] = {} - for path in _iter_python_files(): - n = len(_raise_lines(path)) - if n > 0: - counts[str(path.relative_to(REPO_ROOT))] = n - - self.save(counts) - total = 
sum(counts.values()) - console.print( - f"[green]✓ Generated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] " - f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] occurrences." + super().__init__(allowlist_file, repo_root=REPO_ROOT) + + def iter_files(self) -> Iterable[Path]: + return _iter_python_files() + + def count_occurrences(self, path: Path) -> int: + return len(_raise_lines(path)) + + def violation_panel_text(self) -> str: + return ( + "New [bold]raise AirflowException[/bold] usage detected.\n" + "Define a dedicated exception class or use an existing specific exception.\n" + "If this usage is intentional and pre-existing, run:\n\n" + " [cyan]uv run ./scripts/ci/prek/check_new_airflow_exception_usage.py --generate[/cyan]\n\n" + "to regenerate the allowlist, then commit the updated\n" + "[cyan]generated/known_airflow_exceptions.txt[/cyan]." ) - return 0 - - def cleanup(self) -> int: - allowlist = self.load() - if not allowlist: - console.print("[yellow]Allowlist is empty – nothing to clean up.[/yellow]") - return 0 - - stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / rel).exists()] - if stale: - console.print( - f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) == 1 else 'ies'}:[/yellow]" - ) - for s in sorted(stale): - console.print(f" [dim]-[/dim] {s}") - for s in stale: - del allowlist[s] - self.save(allowlist) - console.print( - f"\n[green]Updated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]" - ) - else: - console.print("[green]✓ No stale entries found.[/green]") - return 0 def _raise_lines(path: Path) -> list[str]: @@ -162,57 +115,9 @@ def _iter_python_files() -> list[Path]: def _check_airflow_exception_usage( - files: list[Path], allowlist: dict[str, int], manager: AllowlistManager + files: list[Path], allowlist: dict[str, int], manager: AirflowExceptionAllowlistManager ) -> int: - violations: list[tuple[Path, int, int]] = [] - tightened: list[tuple[str, int, int]] = [] # 
(rel, old_count, new_count) - - for path in files: - if not path.exists() or path.suffix != ".py": - continue - actual = len(_raise_lines(path)) - rel = str(path.relative_to(REPO_ROOT)) - allowed = allowlist.get(rel, 0) - if actual > allowed: - violations.append((path, actual, allowed)) - elif actual < allowed: - # Usage was reduced — tighten the allowlist entry so it can't creep back up. - if actual == 0: - del allowlist[rel] - else: - allowlist[rel] = actual - tightened.append((rel, allowed, actual)) - - if tightened: - manager.save(allowlist) - console.print( - f"[green]✓ Tightened {len(tightened)} entr{'y' if len(tightened) == 1 else 'ies'} " - f"in [cyan]{manager.allowlist_file.relative_to(REPO_ROOT)}[/cyan][/green] " - "(stage the updated file):" - ) - for rel, old, new in tightened: - console.print(f" [cyan]{rel}[/cyan] {old} → {new}") - - if violations: - console.print( - Panel.fit( - "New [bold]raise AirflowException[/bold] usage detected.\n" - "Define a dedicated exception class or use an existing specific exception.\n" - "If this usage is intentional and pre-existing, run:\n\n" - " [cyan]uv run ./scripts/ci/prek/check_new_airflow_exception_usage.py --generate[/cyan]\n\n" - "to regenerate the allowlist, then commit the updated\n" - "[cyan]generated/known_airflow_exceptions.txt[/cyan].", - title="[red]❌ Check failed[/red]", - border_style="red", - ) - ) - for path, actual, allowed in violations: - console.print(f" [cyan]{path.relative_to(REPO_ROOT)}[/cyan] count={actual} (allowed={allowed})") - return 1 - - # Return 1 when the allowlist was tightened so pre-commit reports the file as modified - # and prompts the user to stage the updated allowlist. 
- return 1 if tightened else 0 + return manager.check(files, allowlist) def main(argv: list[str] | None = None) -> int: @@ -239,7 +144,7 @@ def main(argv: list[str] | None = None) -> int: ) args = parser.parse_args(argv) - manager = AllowlistManager(REPO_ROOT / "generated" / "known_airflow_exceptions.txt") + manager = AirflowExceptionAllowlistManager(REPO_ROOT / "generated" / "known_airflow_exceptions.txt") if args.generate: return manager.generate() diff --git a/scripts/ci/prek/check_provide_session_kwargs.py b/scripts/ci/prek/check_provide_session_kwargs.py index 29152e13a4d86..0025126d52f0b 100755 --- a/scripts/ci/prek/check_provide_session_kwargs.py +++ b/scripts/ci/prek/check_provide_session_kwargs.py @@ -65,14 +65,16 @@ import ast import subprocess import typing +from collections.abc import Iterable from pathlib import Path +from common_prek_utils import AIRFLOW_ROOT_PATH, AllowlistManager from rich.console import Console from rich.panel import Panel console = Console(color_system="standard", width=200) -REPO_ROOT = Path(__file__).parents[3] +REPO_ROOT = AIRFLOW_ROOT_PATH _PROVIDE_SESSION_DECORATOR = "provide_session" @@ -133,109 +135,33 @@ def _count_violations(path: Path) -> int: return sum(1 for _ in _iter_positional_session_in_provide_session(path)) -def _is_safe_relative(rel: str) -> bool: - """Whether ``rel`` is a plain relative path that stays inside ``REPO_ROOT``. - - Rejects absolute paths and any entry that resolves outside the repo root so - callers can ``relative_to(REPO_ROOT)`` without fear of a ``ValueError``. 
- """ - candidate = Path(rel) - if candidate.is_absolute(): - return False - try: - (REPO_ROOT / candidate).resolve().relative_to(REPO_ROOT.resolve()) - except ValueError: - return False - return True - - -class AllowlistManager: +class ProvideSessionAllowlistManager(AllowlistManager): def __init__(self, allowlist_file: Path) -> None: - self.allowlist_file = allowlist_file - - @staticmethod - def parse(text: str) -> dict[str, int]: - """Parse allowlist *text* into a ``{rel_path: count}`` mapping. - - Same validation rules as :meth:`load` so we can reuse parsing for the - on-disk allowlist *and* for the git-tracked version fetched from - ``HEAD`` when guarding against entry-removal bypasses. - """ - result: dict[str, int] = {} - for raw_line in text.splitlines(): - if not (stripped := raw_line.strip()): - continue - - rel_str, _, count_str = stripped.rpartition("::") - if not rel_str or not count_str: - continue - - try: - count = int(count_str) - except ValueError: - continue - - if not _is_safe_relative(rel_str): - console.print( - f"[yellow]Ignoring unsafe allowlist entry (escapes repo root):[/yellow] {rel_str}" - ) - continue - - result[rel_str] = count - - return result - - def load(self) -> dict[str, int]: - if not self.allowlist_file.exists(): - return {} - return self.parse(self.allowlist_file.read_text()) - - def save(self, counts: dict[str, int]) -> None: - lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())] - self.allowlist_file.write_text("\n".join(lines) + "\n") - - def generate(self) -> int: - roots = ", ".join(_PROJECT_SOURCE_ROOTS) - console.print( - f"Scanning project source roots ([cyan]{roots}[/cyan]) under [cyan]{REPO_ROOT}[/cyan] " - "for @provide_session functions with positional session …" + super().__init__(allowlist_file, repo_root=REPO_ROOT) + + def iter_files(self) -> Iterable[Path]: + return _iter_python_files() + + def count_occurrences(self, path: Path) -> int: + return _count_violations(path) + + def 
violation_panel_text(self) -> str: + return ( + "New [bold]@provide_session[/bold] function with positional ``session`` detected.\n" + "Move ``session`` after a bare ``*`` in the signature so callers must pass it by keyword:\n\n" + " [cyan]@provide_session\n" + " def foo(arg, *, session: Session = NEW_SESSION) -> None: ...[/cyan]\n\n" + "If this usage is intentional and pre-existing, run:\n\n" + " [cyan]uv run ./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]\n\n" + "to regenerate the allowlist, then commit the updated\n" + "[cyan]scripts/ci/prek/known_provide_session_positional.txt[/cyan]." ) - counts: dict[str, int] = {} - for path in _iter_python_files(): - n = _count_violations(path) - if n > 0: - counts[str(path.relative_to(REPO_ROOT))] = n - - self.save(counts) - total = sum(counts.values()) - console.print( - f"[green]Generated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] " - f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] offenders." 
- ) - return 0 - - def cleanup(self) -> int: - allowlist = self.load() - if not allowlist: - console.print("[yellow]Allowlist is empty - nothing to clean up.[/yellow]") - return 0 - stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / rel).exists()] - if stale: - console.print( - f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) == 1 else 'ies'}:[/yellow]" - ) - for s in sorted(stale): - console.print(f" [dim]-[/dim] {s}") - for s in stale: - del allowlist[s] - self.save(allowlist) - console.print( - f"\n[green]Updated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]" - ) - else: - console.print("[green]No stale entries found.[/green]") - return 0 + def format_violation_details(self, path: Path) -> list[str]: + return [ + f" [dim]L{argument.lineno}[/dim] def {func.name}(...)" + for func, argument in _iter_positional_session_in_provide_session(path) + ] def _iter_python_files() -> list[Path]: @@ -250,7 +176,7 @@ def _iter_python_files() -> list[Path]: def _check_provide_session_kwargs( - files: list[Path], allowlist: dict[str, int], manager: AllowlistManager + files: list[Path], allowlist: dict[str, int], manager: ProvideSessionAllowlistManager ) -> int: allowlist_file = manager.allowlist_file.resolve() if any(p.resolve() == allowlist_file for p in files) and not allowlist_file.exists(): @@ -265,57 +191,7 @@ def _check_provide_session_kwargs( ) ) return 1 - - violations: list[tuple[Path, int, int]] = [] - tightened: list[tuple[str, int, int]] = [] - - for path in files: - if not path.exists() or path.suffix != ".py": - continue - actual = _count_violations(path) - rel = str(path.relative_to(REPO_ROOT)) - allowed = allowlist.get(rel, 0) - if actual > allowed: - violations.append((path, actual, allowed)) - elif actual < allowed: - if actual == 0: - del allowlist[rel] - else: - allowlist[rel] = actual - tightened.append((rel, allowed, actual)) - - if tightened: - manager.save(allowlist) - console.print( - f"[green]Tightened 
{len(tightened)} entr{'y' if len(tightened) == 1 else 'ies'} " - f"in [cyan]{manager.allowlist_file.relative_to(REPO_ROOT)}[/cyan][/green] " - "(stage the updated file):" - ) - for rel, old, new in tightened: - console.print(f" [cyan]{rel}[/cyan] {old} -> {new}") - - if violations: - console.print( - Panel.fit( - "New [bold]@provide_session[/bold] function with positional ``session`` detected.\n" - "Move ``session`` after a bare ``*`` in the signature so callers must pass it by keyword:\n\n" - " [cyan]@provide_session\n" - " def foo(arg, *, session: Session = NEW_SESSION) -> None: ...[/cyan]\n\n" - "If this usage is intentional and pre-existing, run:\n\n" - " [cyan]uv run ./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]\n\n" - "to regenerate the allowlist, then commit the updated\n" - "[cyan]scripts/ci/prek/known_provide_session_positional.txt[/cyan].", - title="[red]Check failed[/red]", - border_style="red", - ) - ) - for path, actual, allowed in violations: - console.print(f" [cyan]{path.relative_to(REPO_ROOT)}[/cyan] count={actual} (allowed={allowed})") - for func, argument in _iter_positional_session_in_provide_session(path): - console.print(f" [dim]L{argument.lineno}[/dim] def {func.name}(...)") - return 1 - - return 1 if tightened else 0 + return manager.check(files, allowlist) def main(argv: list[str] | None = None) -> int: @@ -342,7 +218,7 @@ def main(argv: list[str] | None = None) -> int: ) args = parser.parse_args(argv) - manager = AllowlistManager(Path(__file__).parent / "known_provide_session_positional.txt") + manager = ProvideSessionAllowlistManager(Path(__file__).parent / "known_provide_session_positional.txt") if args.generate: return manager.generate() @@ -366,7 +242,7 @@ def main(argv: list[str] | None = None) -> int: return _check_provide_session_kwargs(paths, allowlist, manager) -def _parse_tracked_allowlist(manager: AllowlistManager) -> dict[str, int]: +def _parse_tracked_allowlist(manager: ProvideSessionAllowlistManager) -> 
dict[str, int]: """Return the allowlist as recorded at ``HEAD`` (the git-tracked version). Used by :func:`_expand_for_allowlist_edits` so that *removing* an entry @@ -390,11 +266,11 @@ def _parse_tracked_allowlist(manager: AllowlistManager) -> dict[str, int]: return {} if completed.returncode != 0: return {} - return AllowlistManager.parse(completed.stdout) + return manager.parse(completed.stdout) def _expand_for_allowlist_edits( - paths: list[Path], manager: AllowlistManager, allowlist: dict[str, int] + paths: list[Path], manager: ProvideSessionAllowlistManager, allowlist: dict[str, int] ) -> list[Path]: """Add allowlisted files when the allowlist itself is being changed. diff --git a/scripts/ci/prek/common_prek_utils.py b/scripts/ci/prek/common_prek_utils.py index 434bf13a482ad..26e49175ce607 100644 --- a/scripts/ci/prek/common_prek_utils.py +++ b/scripts/ci/prek/common_prek_utils.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import abc import ast import difflib import os @@ -26,7 +27,7 @@ import sys import textwrap import time -from collections.abc import Callable, Generator +from collections.abc import Callable, Generator, Iterable from contextlib import contextmanager from pathlib import Path from tempfile import NamedTemporaryFile, _TemporaryFileWrapper @@ -58,10 +59,12 @@ try: from rich.console import Console + from rich.panel import Panel console = Console(width=400, color_system="standard") except ImportError: console = None # type: ignore[assignment] + Panel = None # type: ignore[assignment,misc] @contextmanager @@ -875,3 +878,182 @@ def parse_operations( commands[group_name].append(subcommand) return commands + + +def _is_safe_relative(rel: str, repo_root: Path) -> bool: + """Whether ``rel`` is a plain relative path that stays inside ``repo_root``.""" + candidate = Path(rel) + if candidate.is_absolute(): + return False + try: + (repo_root / candidate).resolve().relative_to(repo_root.resolve()) + except ValueError: + return 
False + return True + + +class AllowlistManager(abc.ABC): + """Common base for prek hooks that track per-file occurrence counts in allowlist files. + + Subclasses implement :meth:`iter_files`, :meth:`count_occurrences`, and + :meth:`violation_panel_text` to define what gets scanned, how violations + are counted, and what help text to show. Everything else — loading, saving, + generating, cleaning up, and the check loop — is handled here. + """ + + def __init__(self, allowlist_file: Path, *, repo_root: Path = AIRFLOW_ROOT_PATH) -> None: + self.allowlist_file = allowlist_file + self.repo_root = repo_root + + def parse(self, text: str) -> dict[str, int]: + """Parse allowlist *text* into a ``{rel_path: count}`` mapping. + + Entries that escape the repo root (absolute paths or ``..`` segments) + are skipped, with a warning printed when a console is available. + """ + result: dict[str, int] = {} + for raw_line in text.splitlines(): + if not (stripped := raw_line.strip()): + continue + + rel_str, _, count_str = stripped.rpartition("::") + if not rel_str or not count_str: + continue + + try: + count = int(count_str) + except ValueError: + continue + + if not _is_safe_relative(rel_str, self.repo_root): + if console: + console.print( + f"[yellow]Ignoring unsafe allowlist entry (escapes repo root):[/yellow] {rel_str}" + ) + continue + + result[rel_str] = count + + return result + + def load(self) -> dict[str, int]: + """Return mapping of ``relative_path -> allowed_count``.""" + if not self.allowlist_file.exists(): + return {} + return self.parse(self.allowlist_file.read_text()) + + def save(self, counts: dict[str, int]) -> None: + lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())] + self.allowlist_file.write_text("\n".join(lines) + "\n") + + @abc.abstractmethod + def iter_files(self) -> Iterable[Path]: + """Return all files to scan during ``--generate`` or ``--all-files``.""" + + @abc.abstractmethod + def count_occurrences(self, path: Path) -> int: + """Count the number of violations/occurrences in a 
single file.""" + + @abc.abstractmethod + def violation_panel_text(self) -> str: + """Return the rich markup body for the violation help panel.""" + + def format_violation_details(self, path: Path) -> list[str]: + """Return extra detail lines for each violating file.""" + return [] + + def check(self, files: list[Path], allowlist: dict[str, int]) -> int: + """Run the check loop: compare counts, tighten entries, report violations.""" + violations: list[tuple[Path, int, int]] = [] + tightened: list[tuple[str, int, int]] = [] + + for path in files: + if not path.exists() or path.suffix != ".py": + continue + actual = self.count_occurrences(path) + rel = str(path.relative_to(self.repo_root)) + allowed = allowlist.get(rel, 0) + if actual > allowed: + violations.append((path, actual, allowed)) + elif actual < allowed: + if actual == 0: + del allowlist[rel] + else: + allowlist[rel] = actual + tightened.append((rel, allowed, actual)) + + if tightened: + self.save(allowlist) + if console: + console.print( + f"[green]Tightened {len(tightened)} entr{'y' if len(tightened) == 1 else 'ies'} " + f"in [cyan]{self.allowlist_file.relative_to(self.repo_root)}[/cyan][/green] " + "(stage the updated file):" + ) + for rel, old, new in tightened: + console.print(f" [cyan]{rel}[/cyan] {old} → {new}") + + if violations: + if console: + console.print( + Panel.fit( + self.violation_panel_text(), + title="[red]Check failed[/red]", + border_style="red", + ) + ) + for path, actual, allowed in violations: + console.print( + f" [cyan]{path.relative_to(self.repo_root)}[/cyan] " + f"count={actual} (allowed={allowed})" + ) + for detail in self.format_violation_details(path): + console.print(detail) + return 1 + + return 1 if tightened else 0 + + def generate(self) -> int: + if console: + console.print(f"Scanning [cyan]{self.repo_root}[/cyan] …") + counts: dict[str, int] = {} + for path in self.iter_files(): + n = self.count_occurrences(path) + if n > 0: + 
counts[str(path.relative_to(self.repo_root))] = n + + self.save(counts) + total = sum(counts.values()) + if console: + console.print( + f"[green]Generated[/green] [cyan]{self.allowlist_file.relative_to(self.repo_root)}[/cyan] " + f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] occurrences." + ) + return 0 + + def cleanup(self) -> int: + allowlist = self.load() + if not allowlist: + if console: + console.print("[yellow]Allowlist is empty – nothing to clean up.[/yellow]") + return 0 + + stale: list[str] = [rel for rel in allowlist if not (self.repo_root / rel).exists()] + if stale: + if console: + console.print( + f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) == 1 else 'ies'}:[/yellow]" + ) + for s in sorted(stale): + console.print(f" [dim]-[/dim] {s}") + for s in stale: + del allowlist[s] + self.save(allowlist) + if console: + console.print( + f"\n[green]Updated[/green] [cyan]{self.allowlist_file.relative_to(self.repo_root)}[/cyan]" + ) + else: + if console: + console.print("[green]No stale entries found.[/green]") + return 0 diff --git a/scripts/tests/ci/prek/test_check_new_airflow_exception_usage.py b/scripts/tests/ci/prek/test_check_new_airflow_exception_usage.py new file mode 100644 index 0000000000000..f1abb7d6308b6 --- /dev/null +++ b/scripts/tests/ci/prek/test_check_new_airflow_exception_usage.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import textwrap +from pathlib import Path + +import pytest +from ci.prek import check_new_airflow_exception_usage as hook +from ci.prek.check_new_airflow_exception_usage import ( + AirflowExceptionAllowlistManager, + _check_airflow_exception_usage, + _raise_lines, +) + + +@pytest.fixture +def create_fake_repo(tmp_path, monkeypatch): + monkeypatch.setattr(hook, "REPO_ROOT", tmp_path) + + def _write(rel: str, code: str) -> Path: + path = tmp_path / rel + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(textwrap.dedent(code)) + return path + + return _write + + +class TestRaiseLines: + def test_counts_raise_airflow_exception(self, write_python_file): + path = write_python_file( + """\ + from airflow.exceptions import AirflowException + raise AirflowException("boom") + raise AirflowException("bang") + """ + ) + assert len(_raise_lines(path)) == 2 + + def test_ignores_commented_lines(self, write_python_file): + path = write_python_file( + """\ + # raise AirflowException("commented out") + raise AirflowException("real") + """ + ) + assert len(_raise_lines(path)) == 1 + + def test_ignores_other_raises(self, write_python_file): + path = write_python_file( + """\ + raise ValueError("not this") + raise TypeError("nor this") + """ + ) + assert len(_raise_lines(path)) == 0 + + def test_missing_file_returns_empty(self, tmp_path): + assert _raise_lines(tmp_path / "nonexistent.py") == [] + + +class TestAirflowExceptionAllowlistManager: + def test_load_missing_file_returns_empty(self, tmp_path): + manager = 
AirflowExceptionAllowlistManager(tmp_path / "missing.txt") + assert manager.load() == {} + + def test_save_and_load_round_trip(self, tmp_path): + manager = AirflowExceptionAllowlistManager(tmp_path / "allowlist.txt") + manager.save({"b/file.py": 2, "a/file.py": 1}) + text = (tmp_path / "allowlist.txt").read_text() + assert text.splitlines() == ["a/file.py::1", "b/file.py::2"] + assert manager.load() == {"a/file.py": 1, "b/file.py": 2} + + def test_load_skips_blank_and_malformed_lines(self, tmp_path): + path = tmp_path / "allowlist.txt" + path.write_text("\nvalid/file.py::3\nnocount\n::5\nbad::notanumber\n") + assert AirflowExceptionAllowlistManager(path).load() == {"valid/file.py": 3} + + @pytest.mark.usefixtures("create_fake_repo") + def test_load_skips_unsafe_entries(self, tmp_path): + path = tmp_path / "allowlist.txt" + path.write_text("airflow-core/src/airflow/safe.py::1\n../escape.py::1\n/etc/passwd::1\n") + assert AirflowExceptionAllowlistManager(path).load() == {"airflow-core/src/airflow/safe.py": 1} + + +class TestCheckAirflowExceptionUsage: + def test_no_violations_passes(self, create_fake_repo, tmp_path): + path = create_fake_repo( + "airflow-core/src/airflow/clean.py", + """\ + raise ValueError("specific exception") + """, + ) + manager = AirflowExceptionAllowlistManager(tmp_path / "allowlist.txt") + assert _check_airflow_exception_usage([path], {}, manager) == 0 + + def test_new_violation_fails(self, create_fake_repo, tmp_path): + path = create_fake_repo( + "airflow-core/src/airflow/bad.py", + """\ + raise AirflowException("boom") + """, + ) + manager = AirflowExceptionAllowlistManager(tmp_path / "allowlist.txt") + assert _check_airflow_exception_usage([path], {}, manager) == 1 + + def test_violation_within_allowlist_passes(self, create_fake_repo, tmp_path): + path = create_fake_repo( + "airflow-core/src/airflow/grandfathered.py", + """\ + raise AirflowException("old") + """, + ) + manager = AirflowExceptionAllowlistManager(tmp_path / "allowlist.txt") + 
allowlist = {"airflow-core/src/airflow/grandfathered.py": 1} + assert _check_airflow_exception_usage([path], allowlist, manager) == 0 + + def test_exceeding_allowlist_fails(self, create_fake_repo, tmp_path): + path = create_fake_repo( + "airflow-core/src/airflow/grew.py", + """\ + raise AirflowException("one") + raise AirflowException("two") + """, + ) + manager = AirflowExceptionAllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/grew.py": 1} + assert _check_airflow_exception_usage([path], allowlist, manager) == 1 + + def test_reducing_violations_tightens_allowlist(self, create_fake_repo, tmp_path): + path = create_fake_repo( + "airflow-core/src/airflow/improved.py", + """\ + raise AirflowException("one remains") + """, + ) + manager = AirflowExceptionAllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/improved.py": 2} + assert _check_airflow_exception_usage([path], allowlist, manager) == 1 + assert manager.load() == {"airflow-core/src/airflow/improved.py": 1} + + def test_fixing_all_violations_removes_entry(self, create_fake_repo, tmp_path): + path = create_fake_repo( + "airflow-core/src/airflow/fixed.py", + """\ + raise ValueError("migrated away from AirflowException") + """, + ) + manager = AirflowExceptionAllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/fixed.py": 1} + assert _check_airflow_exception_usage([path], allowlist, manager) == 1 + assert manager.load() == {} + + def test_non_python_file_is_skipped(self, create_fake_repo, tmp_path): + path = create_fake_repo( + "airflow-core/src/airflow/not_python.txt", + "raise AirflowException('in a text file')\n", + ) + manager = AirflowExceptionAllowlistManager(tmp_path / "allowlist.txt") + assert _check_airflow_exception_usage([path], {}, manager) == 0 + + +class TestCleanup: + def test_cleanup_removes_stale_entries(self, create_fake_repo, tmp_path): + create_fake_repo("airflow-core/src/airflow/keeper.py", 
"pass") + allowlist_path = tmp_path / "allowlist.txt" + manager = AirflowExceptionAllowlistManager(allowlist_path) + manager.save( + { + "airflow-core/src/airflow/keeper.py": 1, + "airflow-core/src/airflow/gone.py": 1, + } + ) + assert manager.cleanup() == 0 + assert manager.load() == {"airflow-core/src/airflow/keeper.py": 1} + + def test_cleanup_empty_allowlist(self, tmp_path): + manager = AirflowExceptionAllowlistManager(tmp_path / "allowlist.txt") + assert manager.cleanup() == 0 diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py index 78b85cd270bbb..4e96f1a413efa 100644 --- a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py +++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py @@ -25,7 +25,7 @@ import pytest from ci.prek import check_provide_session_kwargs as hook from ci.prek.check_provide_session_kwargs import ( - AllowlistManager, + ProvideSessionAllowlistManager, _check_provide_session_kwargs, _count_violations, _expand_for_allowlist_edits, @@ -228,13 +228,13 @@ def test_invalid_utf8_does_not_crash(self, tmp_path): assert _count_violations(path) == 1 -class TestAllowlistManager: +class TestProvideSessionAllowlistManager: def test_load_missing_file_returns_empty(self, tmp_path): - manager = AllowlistManager(tmp_path / "missing.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "missing.txt") assert manager.load() == {} def test_save_and_load_round_trip(self, tmp_path): - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "allowlist.txt") manager.save({"b/file.py": 2, "a/file.py": 1}) # Sorted by key in the file text = (tmp_path / "allowlist.txt").read_text() @@ -244,7 +244,7 @@ def test_save_and_load_round_trip(self, tmp_path): def test_load_skips_blank_and_malformed_lines(self, tmp_path): path = tmp_path / "allowlist.txt" path.write_text("\nvalid/file.py::3\nnocount\n::5\nbad::notanumber\n") 
- assert AllowlistManager(path).load() == {"valid/file.py": 3} + assert ProvideSessionAllowlistManager(path).load() == {"valid/file.py": 3} @pytest.mark.usefixtures("create_fake_repo") def test_load_skips_unsafe_entries(self, tmp_path): @@ -252,7 +252,7 @@ def test_load_skips_unsafe_entries(self, tmp_path): path = tmp_path / "allowlist.txt" path.write_text("airflow-core/src/airflow/safe.py::1\n../escape.py::1\n/etc/passwd::1\n") # `create_fake_repo` patches REPO_ROOT to tmp_path so the safety check is meaningful. - assert AllowlistManager(path).load() == {"airflow-core/src/airflow/safe.py": 1} + assert ProvideSessionAllowlistManager(path).load() == {"airflow-core/src/airflow/safe.py": 1} class TestCheckProvideSessionKwargs: @@ -265,7 +265,7 @@ def foo(*, session=NEW_SESSION): pass """, ) - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "allowlist.txt") assert _check_provide_session_kwargs([path], {}, manager) == 0 def test_new_violation_fails(self, create_fake_repo, tmp_path): @@ -277,7 +277,7 @@ def foo(session=NEW_SESSION): pass """, ) - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "allowlist.txt") assert _check_provide_session_kwargs([path], {}, manager) == 1 def test_violation_within_allowlist_passes(self, create_fake_repo, tmp_path): @@ -289,7 +289,7 @@ def foo(session=NEW_SESSION): pass """, ) - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "allowlist.txt") allowlist = {"airflow-core/src/airflow/grandfathered.py": 1} assert _check_provide_session_kwargs([path], allowlist, manager) == 0 @@ -306,7 +306,7 @@ def b(session=NEW_SESSION): pass """, ) - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "allowlist.txt") allowlist = {"airflow-core/src/airflow/grew.py": 1} assert _check_provide_session_kwargs([path], 
allowlist, manager) == 1 @@ -323,7 +323,7 @@ def bar(*, session=NEW_SESSION): pass """, ) - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "allowlist.txt") allowlist = {"airflow-core/src/airflow/improved.py": 2} # Exit non-zero so pre-commit reports the modified allowlist assert _check_provide_session_kwargs([path], allowlist, manager) == 1 @@ -338,7 +338,7 @@ def foo(*, session=NEW_SESSION): pass """, ) - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "allowlist.txt") allowlist = {"airflow-core/src/airflow/fixed.py": 1} assert _check_provide_session_kwargs([path], allowlist, manager) == 1 assert manager.load() == {} @@ -347,14 +347,14 @@ def test_non_python_file_is_skipped(self, create_fake_repo, tmp_path): path = create_fake_repo( "airflow-core/src/airflow/not_python.txt", "@provide_session\ndef foo(session=N): pass\n" ) - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "allowlist.txt") assert _check_provide_session_kwargs([path], {}, manager) == 0 @pytest.mark.usefixtures("create_fake_repo") def test_missing_allowlist_file_fails_loudly(self, tmp_path): """Passing the allowlist path when the file is missing must fail, not silently pass.""" allowlist_path = tmp_path / "allowlist.txt" - manager = AllowlistManager(allowlist_path) + manager = ProvideSessionAllowlistManager(allowlist_path) assert not allowlist_path.exists() assert _check_provide_session_kwargs([allowlist_path.resolve()], {}, manager) == 1 @@ -362,12 +362,12 @@ def test_missing_allowlist_file_fails_loudly(self, tmp_path): class TestExpandForAllowlistEdits: def test_unchanged_when_allowlist_not_in_paths(self, create_fake_repo, tmp_path): py = create_fake_repo("airflow-core/src/airflow/x.py", "pass") - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / 
"allowlist.txt") assert _expand_for_allowlist_edits([py], manager, {"airflow-core/src/airflow/x.py": 1}) == [py] def test_appends_allowlisted_files_when_allowlist_edited(self, create_fake_repo, tmp_path): allowlist_path = tmp_path / "allowlist.txt" - manager = AllowlistManager(allowlist_path) + manager = ProvideSessionAllowlistManager(allowlist_path) listed = create_fake_repo("airflow-core/src/airflow/listed.py", "pass") # Pass a resolved path — matches production behavior (``main()`` resolves argv). result = _expand_for_allowlist_edits( @@ -383,7 +383,7 @@ def test_appends_allowlisted_files_when_allowlist_edited(self, create_fake_repo, def test_detection_robust_to_symlinked_allowlist(self, create_fake_repo, tmp_path): """A symlink pointing at the allowlist file must still trigger expansion.""" allowlist_path = tmp_path / "allowlist.txt" - manager = AllowlistManager(allowlist_path) + manager = ProvideSessionAllowlistManager(allowlist_path) listed = create_fake_repo("airflow-core/src/airflow/listed.py", "pass") manager.save({"airflow-core/src/airflow/listed.py": 1}) @@ -410,7 +410,7 @@ def foo(session=NEW_SESSION): """, ) allowlist_path = tmp_path / "allowlist.txt" - manager = AllowlistManager(allowlist_path) + manager = ProvideSessionAllowlistManager(allowlist_path) manager.save({rel: 1}) create_git_repo("seed allowlist at HEAD") @@ -429,7 +429,7 @@ def foo(session=NEW_SESSION): @pytest.mark.usefixtures("create_fake_repo") def test_parse_tracked_allowlist_empty_when_no_git_history(self, tmp_path): """Without a git repo the git-tracked allowlist lookup returns empty and does not crash.""" - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "allowlist.txt") assert _parse_tracked_allowlist(manager) == {} def test_re_validates_listed_files_so_loosening_cannot_bypass(self, create_fake_repo, tmp_path, capsys): @@ -448,7 +448,7 @@ def bar(session=NEW_SESSION): """, ) allowlist_path = tmp_path / "allowlist.txt" - 
manager = AllowlistManager(allowlist_path) + manager = ProvideSessionAllowlistManager(allowlist_path) # Allowlist loosened to 5 although file only has 2 positional sessions. allowlist = {rel: 5} manager.save(allowlist) @@ -467,7 +467,7 @@ class TestCleanup: def test_cleanup_removes_stale_entries(self, create_fake_repo, tmp_path): create_fake_repo("airflow-core/src/airflow/keeper.py", "pass") allowlist_path = tmp_path / "allowlist.txt" - manager = AllowlistManager(allowlist_path) + manager = ProvideSessionAllowlistManager(allowlist_path) manager.save( { "airflow-core/src/airflow/keeper.py": 1, @@ -478,5 +478,5 @@ def test_cleanup_removes_stale_entries(self, create_fake_repo, tmp_path): assert manager.load() == {"airflow-core/src/airflow/keeper.py": 1} def test_cleanup_empty_allowlist(self, tmp_path): - manager = AllowlistManager(tmp_path / "allowlist.txt") + manager = ProvideSessionAllowlistManager(tmp_path / "allowlist.txt") assert manager.cleanup() == 0