From 532b7714c2cb91b277c0a7c87e0e5fa6d144819c Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 1 Jul 2026 12:25:38 +0200 Subject: [PATCH] Harden US L0 refit H5 reconstruction --- .../build/us_runtime/l0_refit_export.py | 59 ++++++++++++++++++- .../tests/test_us_l0_refit_export.py | 43 +++++++++++++- 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/packages/populace-build/src/populace/build/us_runtime/l0_refit_export.py b/packages/populace-build/src/populace/build/us_runtime/l0_refit_export.py index b8634c9..881e996 100644 --- a/packages/populace-build/src/populace/build/us_runtime/l0_refit_export.py +++ b/packages/populace-build/src/populace/build/us_runtime/l0_refit_export.py @@ -14,6 +14,11 @@ from populace.frame import US_SCHEMA, Frame, MassChange, WeightKind, Weights from populace.frame.adapters.policyengine_us import PolicyEngineUSEngine +US_RELEASE_REQUIRED_TAX_UNIT_SOURCE_COLUMNS = ( + "takes_up_aca_if_eligible", + "selected_marketplace_plan_benchmark_ratio", +) + @dataclass(frozen=True) class L0RefitWeights: @@ -211,6 +216,29 @@ def attach_l0_refit_entity_weights( ) +def assert_required_us_release_source_columns( + frame: Frame, + *, + columns: tuple[str, ...] = US_RELEASE_REQUIRED_TAX_UNIT_SOURCE_COLUMNS, +) -> None: + """Require source-stage tax-unit columns needed by US release gates.""" + + tax_unit = frame.table("tax_unit") + failures: list[str] = [] + for column in columns: + if column not in tax_unit.columns: + failures.append(f"{column}: missing") + continue + unique = tax_unit[column].dropna().unique() + if len(unique) < 2: + failures.append(f"{column}: not nonconstant") + if failures: + raise ValueError( + "US L0/refit release export requires source-stage tax-unit columns: " + + "; ".join(failures) + ) + + def load_us_frame(path: str | Path) -> Frame: """Load a PolicyEngine-US single-year H5 into a Populace frame.""" @@ -263,9 +291,14 @@ def export_us_l0_refit_h5( weight_key: str = "weights", zero_weight_tolerance: float = 0.0, summary_json: str | Path | None = None, + require_source_columns: bool = True, + root_attrs_h5: str | Path | None = None, ) -> dict[str, Any]: """Write a selected US H5 from a base H5 and saved L0/refit weights.""" + root_attrs_source = ( + Path(root_attrs_h5) if root_attrs_h5 is not None else Path(base_h5) + ) base_frame = load_us_frame(base_h5) solution = load_l0_refit_npz( weights_npz, @@ -274,18 +307,23 @@ def export_us_l0_refit_h5( zero_weight_tolerance=zero_weight_tolerance, ) export_frame = attach_l0_refit_weights(base_frame, solution) + if require_source_columns: + assert_required_us_release_source_columns(export_frame) destination = Path(output_h5) destination.parent.mkdir(parents=True, exist_ok=True) PolicyEngineUSEngine().write_dataset(export_frame, destination, period=period) - copied_attrs = copy_populace_root_attrs(base_h5, destination) + copied_attrs = copy_populace_root_attrs(root_attrs_source, destination) summary = { "schema_version": 1, "kind": "us_l0_refit_h5_export", "base_h5": _file_manifest(base_h5), + "root_attrs_h5": _file_manifest(root_attrs_source), "weights_npz": _file_manifest(weights_npz), "output_h5": _file_manifest(destination), "period": int(period), "weight_key": weight_key, + "required_source_columns": list(US_RELEASE_REQUIRED_TAX_UNIT_SOURCE_COLUMNS), + "required_source_columns_checked": bool(require_source_columns), "candidate_households": int(base_frame.n("household")), "selected_households": int(export_frame.n("household")), "selected_weight_sum": float(export_frame.weights_for("household").total), @@ -327,6 +365,22 @@ def _parser() -> argparse.ArgumentParser: ".l0_refit_export_summary.json file beside --output-h5." ), ) + parser.add_argument( + "--root-attrs-h5", + type=Path, + help=( + "Optional H5 whose Populace-owned root attrs are copied to the " + "exported dataset. Defaults to --base-h5." + ), + ) + parser.add_argument( + "--allow-missing-source-columns", + action="store_true", + help=( + "Diagnostic escape hatch. By default, reconstruction requires the " + "US release source-stage tax-unit columns used by release gates." + ), + ) return parser @@ -340,6 +394,8 @@ def main(argv: list[str] | None = None) -> None: weight_key=args.weight_key, zero_weight_tolerance=args.zero_weight_tolerance, summary_json=args.summary_json, + require_source_columns=not args.allow_missing_source_columns, + root_attrs_h5=args.root_attrs_h5, ) print(json.dumps(summary, indent=2, sort_keys=True)) @@ -348,6 +404,7 @@ def main(argv: list[str] | None = None) -> None: "L0RefitWeights", "attach_l0_refit_entity_weights", "attach_l0_refit_weights", + "assert_required_us_release_source_columns", "copy_populace_root_attrs", "export_us_l0_refit_h5", "load_l0_refit_npz", diff --git a/packages/populace-build/tests/test_us_l0_refit_export.py b/packages/populace-build/tests/test_us_l0_refit_export.py index b986055..d9760cc 100644 --- a/packages/populace-build/tests/test_us_l0_refit_export.py +++ b/packages/populace-build/tests/test_us_l0_refit_export.py @@ -8,6 +8,7 @@ from populace.build.us_runtime import l0_refit_export from populace.build.us_runtime.l0_refit_export import ( + assert_required_us_release_source_columns, attach_l0_refit_entity_weights, attach_l0_refit_weights, export_us_l0_refit_h5, @@ -36,7 +37,15 @@ def _us_frame() -> Frame: {"household_id": np.asarray([10, 20], dtype="int64")} ), "tax_unit": pd.DataFrame( - {"tax_unit_id": np.asarray([100, 200, 201], dtype="int64")} + { + "tax_unit_id": np.asarray([100, 200, 201], dtype="int64"), + "takes_up_aca_if_eligible": np.asarray( + [False, True, False], dtype=bool + ), + "selected_marketplace_plan_benchmark_ratio": np.asarray( + [1.0, 0.8, 1.2], dtype="float64" + ), + } ), "spm_unit": pd.DataFrame( {"spm_unit_id": np.asarray([1000, 2000], dtype="int64")} @@ -117,6 +126,27 @@ def test_attach_l0_refit_entity_weights_rejects_misaligned_weights() -> None: ) +def test_required_us_release_source_columns_rejects_missing_source_stage() -> None: + frame = _us_frame() + raw_tax_units = frame.table("tax_unit").drop( + columns=[ + "takes_up_aca_if_eligible", + "selected_marketplace_plan_benchmark_ratio", + ] + ) + raw_frame = Frame( + { + **{entity: frame.table(entity).copy() for entity in frame.schema.entities}, + "tax_unit": raw_tax_units, + }, + frame.schema, + {"household": frame.weights_for("household")}, + ) + + with pytest.raises(ValueError, match="takes_up_aca_if_eligible: missing"): + assert_required_us_release_source_columns(raw_frame) + + def test_export_us_l0_refit_h5_uses_existing_policyengine_writer( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, @@ -130,13 +160,18 @@ def test_export_us_l0_refit_h5_uses_existing_policyengine_writer( ) base_h5 = tmp_path / "base.h5" base_h5.write_text("base") + attrs_h5 = tmp_path / "attrs.h5" + attrs_h5.write_text("attrs") output = tmp_path / "populace_us_2024.h5" + copied_from: list[Path] = [] monkeypatch.setattr(l0_refit_export, "load_us_frame", lambda path: frame) monkeypatch.setattr( l0_refit_export, "copy_populace_root_attrs", - lambda source, destination: ("populace_test_attr",), + lambda source, destination: ( + copied_from.append(Path(source)) or ("populace_test_attr",) + ), ) class FakeEngine: @@ -156,15 +191,19 @@ def write_dataset(self, bundle, path, period): base_h5=base_h5, weights_npz=npz, output_h5=output, + root_attrs_h5=attrs_h5, ) assert output.read_text() == "sentinel" + assert copied_from == [attrs_h5] manifest_path = output.with_suffix(".l0_refit_export_summary.json") manifest = json.loads(manifest_path.read_text()) assert manifest["schema_version"] == 1 assert manifest["kind"] == "us_l0_refit_h5_export" assert manifest["summary_json_path"] == str(manifest_path) + assert manifest["required_source_columns_checked"] is True assert manifest["base_h5"]["sha256"] == sha256(b"base").hexdigest() + assert manifest["root_attrs_h5"]["sha256"] == sha256(b"attrs").hexdigest() assert manifest["weights_npz"]["sha256"] == sha256(npz.read_bytes()).hexdigest() assert manifest["output_h5"]["sha256"] == sha256(b"sentinel").hexdigest() assert summary["candidate_households"] == 2