From 361f7e7b416b92876269ab3e17ec95d3d239fc9b Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 1 Jul 2026 14:05:34 +0200 Subject: [PATCH] Stream reform validation simulations --- .../build/us_runtime/reform_validation.py | 15 +---- .../tests/test_reform_validation.py | 66 +++++++++++++++++-- 2 files changed, 63 insertions(+), 18 deletions(-) diff --git a/packages/populace-build/src/populace/build/us_runtime/reform_validation.py b/packages/populace-build/src/populace/build/us_runtime/reform_validation.py index 94e43d6..08bdeb9 100644 --- a/packages/populace-build/src/populace/build/us_runtime/reform_validation.py +++ b/packages/populace-build/src/populace/build/us_runtime/reform_validation.py @@ -365,14 +365,10 @@ def reform_validation_payload( targets = in_sample_targets or {} baseline: Any = None baseline_totals: dict[tuple[int, str], float] = {} - parameter_reform_sims: dict[str, Any] = {} obbba_specs = tuple(spec for spec in specs if _is_obbba_spec(spec)) obbba_pre_baseline_changes = _merged_parameter_changes(obbba_specs) obbba_stacked: dict[str, tuple[float, float, float]] | None = None - def parameter_changes_key(changes: dict[str, Any]) -> str: - return json.dumps(changes, sort_keys=True, separators=(",", ":")) - def baseline_total(measure: str, at_period: int) -> float: nonlocal baseline if baseline is None: @@ -383,11 +379,8 @@ def baseline_total(measure: str, at_period: int) -> float: return baseline_totals[key] def simulation_for_parameter_changes(changes: dict[str, Any]) -> Any: - key = parameter_changes_key(changes) - if key not in parameter_reform_sims: - reform = None if not changes else _build_parameter_reform(changes) - parameter_reform_sims[key] = simulate(reform) # type: ignore[misc] - return parameter_reform_sims[key] + reform = None if not changes else _build_parameter_reform(changes) + return simulate(reform) # type: ignore[misc] def stacked_obbba_effects() -> dict[str, tuple[float, float, float]]: """Score the OBBBA provisions *stacked* in their JCX-35-25 order. @@ -411,9 +404,7 @@ def stacked_obbba_effects() -> dict[str, tuple[float, float, float]]: return {} measures = {(spec.budget_measure, spec.period) for spec in obbba_specs} if len(measures) != 1: - return { - spec.id: _isolated_obbba_effect(spec) for spec in obbba_specs - } + return {spec.id: _isolated_obbba_effect(spec) for spec in obbba_specs} measure, period = next(iter(measures)) # state 0: pre-OBBBA (every provision reverted). prev_total = _weighted_total( diff --git a/packages/populace-build/tests/test_reform_validation.py b/packages/populace-build/tests/test_reform_validation.py index 51f10cd..a45d79a 100644 --- a/packages/populace-build/tests/test_reform_validation.py +++ b/packages/populace-build/tests/test_reform_validation.py @@ -213,12 +213,66 @@ def simulate(reform): assert rows["obbba_b"]["populace"]["reform_total"] == pytest.approx(960.0) assert rows["obbba_b"]["populace"]["budget_effect"] == pytest.approx(60.0) # Stacked line effects telescope to the true total OBBBA effect. - total = sum( - rows[i]["populace"]["budget_effect"] for i in ("obbba_a", "obbba_b") - ) + total = sum(rows[i]["populace"]["budget_effect"] for i in ("obbba_a", "obbba_b")) assert total == pytest.approx(960.0 - 1_000.0) +def test_obbba_stacked_scoring_releases_intermediate_simulations(monkeypatch): + specs = tuple( + ReformValidationSpec( + id=f"obbba_{name}", + name=f"OBBBA {name}", + category="OBBBA", + in_sample=False, + period=2026, + jct_score=None, + jct_window="FY2026", + jct_source="JCX", + jct_source_url="", + parameter_changes={ + f"gov.example.{name}": {"2026-01-01.2026-12-31": 0}, + }, + effect_direction="baseline_minus_reform", + ) + for name in ("a", "b", "c") + ) + monkeypatch.setattr( + reform_validation_module, + "_build_parameter_reform", + lambda changes: frozenset(changes), + ) + + live_simulations = 0 + + class _TrackedSim(_FakeSim): + def __init__(self, total: float) -> None: + nonlocal live_simulations + assert live_simulations == 0 + live_simulations += 1 + super().__init__({"income_tax": total}) + + def __del__(self) -> None: + nonlocal live_simulations + live_simulations -= 1 + + def simulate(reform): + totals = { + frozenset({"gov.example.a", "gov.example.b", "gov.example.c"}): 1000.0, + frozenset({"gov.example.b", "gov.example.c"}): 900.0, + frozenset({"gov.example.c"}): 960.0, + None: 970.0, + } + return _TrackedSim(totals[reform]) + + payload = reform_validation_payload(specs, period=2026, simulate=simulate) + assert [row["id"] for row in payload["reforms"]] == [ + "obbba_a", + "obbba_b", + "obbba_c", + ] + assert live_simulations == 0 + + def test_shipped_obbba_config_is_out_of_sample_counterfactual(): specs = out_of_sample_reform_specs(period=2026) assert {s.id for s in specs} >= {"obbba_no_tax_on_tips", "obbba_no_tax_on_overtime"} @@ -247,9 +301,9 @@ def test_itemized_benefit_limit_counterfactual_keeps_pease(): ) paths = set(spec.parameter_changes or {}) assert paths == {"gov.irs.deductions.itemized.limitation.obbb.applies"}, paths - assert ( - "gov.irs.deductions.itemized.limitation.applies" not in paths - ), "must not disable the whole limitation (drops present-law Pease)" + assert "gov.irs.deductions.itemized.limitation.applies" not in paths, ( + "must not disable the whole limitation (drops present-law Pease)" + ) def test_shipped_tax_expenditure_specs_neutralize_big_provisions():