From 24d013566f84782fda79b0f4450b153a947ed425 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 1 Jul 2026 11:16:47 +0200 Subject: [PATCH] Add US L0 refit H5 reconstruction utility --- packages/populace-build/pyproject.toml | 3 + .../build/us_runtime/l0_refit_export.py | 360 ++++++++++++++++++ .../tests/test_us_l0_refit_export.py | 177 +++++++++ tools/build_us_fiscal_refresh_release.py | 35 +- tools/export_us_l0_refit_h5.py | 6 + 5 files changed, 556 insertions(+), 25 deletions(-) create mode 100644 packages/populace-build/src/populace/build/us_runtime/l0_refit_export.py create mode 100644 packages/populace-build/tests/test_us_l0_refit_export.py create mode 100644 tools/export_us_l0_refit_h5.py diff --git a/packages/populace-build/pyproject.toml b/packages/populace-build/pyproject.toml index 1e8977a..91043db 100644 --- a/packages/populace-build/pyproject.toml +++ b/packages/populace-build/pyproject.toml @@ -25,6 +25,9 @@ us = ["policyengine-us>=1.745.0,<2", "h5py>=3", "microunit>=0.1.0", "tables>=3"] # still does not import policyengine-uk at import time. uk = ["policyengine-uk>=2.88", "h5py>=3", "tables>=3"] +[project.scripts] +populace-export-us-l0-refit-h5 = "populace.build.us_runtime.l0_refit_export:main" + [project.urls] Homepage = "https://populace.dev" Repository = "https://github.com/PolicyEngine/populace" diff --git a/packages/populace-build/src/populace/build/us_runtime/l0_refit_export.py b/packages/populace-build/src/populace/build/us_runtime/l0_refit_export.py new file mode 100644 index 0000000..b8634c9 --- /dev/null +++ b/packages/populace-build/src/populace/build/us_runtime/l0_refit_export.py @@ -0,0 +1,360 @@ +"""Export a US Populace H5 from a base frame plus saved L0/refit weights.""" + +from __future__ import annotations + +import argparse +import hashlib +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np + +from populace.frame import US_SCHEMA, Frame, MassChange, WeightKind, Weights +from populace.frame.adapters.policyengine_us import PolicyEngineUSEngine + + +@dataclass(frozen=True) +class L0RefitWeights: + """A saved post-L0 refit solution aligned to candidate entity rows.""" + + weights: np.ndarray + selected_mask: np.ndarray + metadata: dict[str, Any] + + @property + def selected_weights(self) -> np.ndarray: + return self.weights[self.selected_mask] + + +def _metadata_from_npz(value: np.ndarray | None) -> dict[str, Any]: + if value is None: + return {} + if value.shape != (): + raise ValueError("metadata_json must be a scalar JSON string.") + raw = value.item() + if isinstance(raw, bytes): + raw = raw.decode() + if not isinstance(raw, str): + raise ValueError("metadata_json must be a scalar JSON string.") + metadata = json.loads(raw) + if not isinstance(metadata, dict): + raise ValueError("metadata_json must decode to an object.") + return metadata + + +def _sha256(path: str | Path) -> str: + digest = hashlib.sha256() + with Path(path).open("rb") as stream: + for block in iter(lambda: stream.read(1024 * 1024), b""): + digest.update(block) + return digest.hexdigest() + + +def _file_manifest(path: str | Path) -> dict[str, Any]: + source = Path(path) + return { + "path": str(source), + "size_bytes": int(source.stat().st_size), + "sha256": _sha256(source), + } + + +def load_l0_refit_npz( + path: str | Path, + *, + expected_candidate_records: int, + weight_key: str = "weights", + zero_weight_tolerance: float = 0.0, +) -> L0RefitWeights: + """Load and validate saved L0/refit weights. + + The returned full-length vector stays aligned to the candidate household + table. Selection is the positive-weight support after ``zero_weight_tolerance``. + """ + + source = Path(path) + with np.load(source, allow_pickle=False) as payload: + if weight_key not in payload.files: + raise ValueError(f"{source} is missing required key {weight_key!r}.") + weights = np.asarray(payload[weight_key], dtype=np.float64) + metadata = _metadata_from_npz( + payload["metadata_json"] if "metadata_json" in payload.files else None + ) + + if weights.shape != (expected_candidate_records,): + raise ValueError( + f"{weight_key!r} must have shape {(expected_candidate_records,)}, " + f"got {weights.shape}." + ) + if not np.isfinite(weights).all(): + raise ValueError(f"{weight_key!r} must be finite.") + if (weights < 0).any(): + raise ValueError(f"{weight_key!r} must be non-negative.") + if zero_weight_tolerance < 0: + raise ValueError("zero_weight_tolerance must be non-negative.") + + selected_mask = weights > zero_weight_tolerance + n_selected = int(selected_mask.sum()) + if n_selected == 0: + raise ValueError("L0/refit weights select zero candidate rows.") + + metadata_candidate_records = metadata.get("candidate_records") + if ( + metadata_candidate_records is not None + and int(metadata_candidate_records) != expected_candidate_records + ): + raise ValueError( + "metadata candidate_records does not match base frame: " + f"{metadata_candidate_records} != {expected_candidate_records}." + ) + metadata_selected = metadata.get("n_selected", metadata.get("budget_achieved")) + if metadata_selected is not None and int(metadata_selected) != n_selected: + raise ValueError( + "metadata selected count does not match positive-weight support: " + f"{metadata_selected} != {n_selected}." + ) + metadata_weight_entity = metadata.get("weight_entity") + if metadata_weight_entity is not None and metadata_weight_entity != "household": + raise ValueError( + "US L0/refit export currently supports household weights only, " + f"got {metadata_weight_entity!r}." + ) + + return L0RefitWeights( + weights=weights, + selected_mask=selected_mask, + metadata=metadata, + ) + + +def attach_l0_refit_weights( + base_frame: Frame, + solution: L0RefitWeights, +) -> Frame: + """Return the selected base-frame support with post-L0 refit weights.""" + + schema = base_frame.schema + weight_entity = "household" + if weight_entity not in schema.group_entities: + raise ValueError("US L0/refit export requires a household group entity.") + if solution.weights.shape != (base_frame.n(weight_entity),): + raise ValueError( + "L0/refit weights must align to household rows: " + f"{solution.weights.shape} != {(base_frame.n(weight_entity),)}." + ) + + selected_ids = base_frame.table(weight_entity)[ + schema.id_column(weight_entity) + ].to_numpy()[solution.selected_mask] + return attach_l0_refit_entity_weights( + base_frame, + weight_entity=weight_entity, + selected_entity_ids=selected_ids, + selected_weights=solution.selected_weights, + reason="US L0/refit saved-weight export", + ) + + +def attach_l0_refit_entity_weights( + base_frame: Frame, + *, + weight_entity: str, + selected_entity_ids: np.ndarray, + selected_weights: np.ndarray, + reason: str, +) -> Frame: + """Return selected support with post-L0 refit weights for one entity.""" + + schema = base_frame.schema + selected_ids = np.asarray(selected_entity_ids) + selected_weights = np.asarray(selected_weights, dtype=np.float64) + if selected_ids.shape != selected_weights.shape: + raise ValueError( + "Selected entity ids and weights must have the same shape: " + f"{selected_ids.shape} != {selected_weights.shape}." + ) + if selected_weights.size == 0: + raise ValueError("L0/refit selected support is empty.") + if not np.isfinite(selected_weights).all(): + raise ValueError("L0/refit selected weights must be finite.") + if (selected_weights < 0).any(): + raise ValueError("L0/refit selected weights must be non-negative.") + if weight_entity == schema.person_entity: + person_ids = base_frame.person[schema.person_id_column].to_numpy() + person_mask = np.isin(person_ids, selected_ids) + elif weight_entity in schema.group_entities: + membership = base_frame.person[ + schema.membership_column(weight_entity) + ].to_numpy() + person_mask = np.isin(membership, selected_ids) + else: + raise ValueError(f"L0/refit export cannot map weight entity {weight_entity!r}.") + selected_base = base_frame.select(person_mask) + exported_ids = selected_base.table(weight_entity)[ + schema.id_column(weight_entity) + ].to_numpy() + if not np.array_equal(exported_ids, selected_ids): + raise ValueError( + "Selected support is not aligned with the base-frame export support " + f"for {weight_entity!r}." + ) + return selected_base.with_weights( + weight_entity, + Weights(selected_weights, WeightKind.CALIBRATED), + mass=MassChange( + factor=selected_weights.sum() + / selected_base.weights_for(weight_entity).total, + reason=reason, + ), + ) + + +def load_us_frame(path: str | Path) -> Frame: + """Load a PolicyEngine-US single-year H5 into a Populace frame.""" + + from policyengine_us.data import USSingleYearDataset + + dataset = USSingleYearDataset(file_path=str(path)) + tables = { + "person": dataset.person.copy(), + "household": dataset.household.copy(), + "tax_unit": dataset.tax_unit.copy(), + "spm_unit": dataset.spm_unit.copy(), + "family": dataset.family.copy(), + "marital_unit": dataset.marital_unit.copy(), + } + weights = tables["household"].pop("household_weight").to_numpy(dtype=np.float64) + return Frame( + tables, + US_SCHEMA, + {"household": Weights(weights, WeightKind.CALIBRATED)}, + ) + + +def copy_populace_root_attrs( + source_h5: str | Path, + destination_h5: str | Path, +) -> tuple[str, ...]: + """Copy Populace-owned root attrs from the base H5 to the exported H5.""" + + import h5py + + copied: list[str] = [] + with ( + h5py.File(source_h5, "r") as source, + h5py.File(destination_h5, "a") as destination, + ): + for name, value in source.attrs.items(): + if not str(name).startswith("populace_"): + continue + destination.attrs[name] = value + copied.append(str(name)) + return tuple(copied) + + +def export_us_l0_refit_h5( + *, + base_h5: str | Path, + weights_npz: str | Path, + output_h5: str | Path, + period: int = 2024, + weight_key: str = "weights", + zero_weight_tolerance: float = 0.0, + summary_json: str | Path | None = None, +) -> dict[str, Any]: + """Write a selected US H5 from a base H5 and saved L0/refit weights.""" + + base_frame = load_us_frame(base_h5) + solution = load_l0_refit_npz( + weights_npz, + expected_candidate_records=base_frame.n("household"), + weight_key=weight_key, + zero_weight_tolerance=zero_weight_tolerance, + ) + export_frame = attach_l0_refit_weights(base_frame, solution) + destination = Path(output_h5) + destination.parent.mkdir(parents=True, exist_ok=True) + PolicyEngineUSEngine().write_dataset(export_frame, destination, period=period) + copied_attrs = copy_populace_root_attrs(base_h5, destination) + summary = { + "schema_version": 1, + "kind": "us_l0_refit_h5_export", + "base_h5": _file_manifest(base_h5), + "weights_npz": _file_manifest(weights_npz), + "output_h5": _file_manifest(destination), + "period": int(period), + "weight_key": weight_key, + "candidate_households": int(base_frame.n("household")), + "selected_households": int(export_frame.n("household")), + "selected_weight_sum": float(export_frame.weights_for("household").total), + "copied_root_attrs": list(copied_attrs), + "metadata": solution.metadata, + } + summary_destination = ( + Path(summary_json) + if summary_json is not None + else destination.with_suffix(".l0_refit_export_summary.json") + ) + summary["summary_json_path"] = str(summary_destination) + summary_destination.parent.mkdir(parents=True, exist_ok=True) + summary_destination.write_text( + json.dumps(summary, indent=2, sort_keys=True, allow_nan=False) + "\n" + ) + summary["summary_json"] = _file_manifest(summary_destination) + return summary + + +def _parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Export a PolicyEngine-US H5 by attaching saved L0/refit weights " + "to a Populace base support frame." + ) + ) + parser.add_argument("--base-h5", required=True, type=Path) + parser.add_argument("--weights-npz", required=True, type=Path) + parser.add_argument("--output-h5", required=True, type=Path) + parser.add_argument("--period", type=int, default=2024) + parser.add_argument("--weight-key", default="weights") + parser.add_argument("--zero-weight-tolerance", type=float, default=0.0) + parser.add_argument( + "--summary-json", + type=Path, + help=( + "Path for the reconstruction manifest. Defaults to a " + ".l0_refit_export_summary.json file beside --output-h5." + ), + ) + return parser + + +def main(argv: list[str] | None = None) -> None: + args = _parser().parse_args(argv) + summary = export_us_l0_refit_h5( + base_h5=args.base_h5, + weights_npz=args.weights_npz, + output_h5=args.output_h5, + period=args.period, + weight_key=args.weight_key, + zero_weight_tolerance=args.zero_weight_tolerance, + summary_json=args.summary_json, + ) + print(json.dumps(summary, indent=2, sort_keys=True)) + + +__all__ = [ + "L0RefitWeights", + "attach_l0_refit_entity_weights", + "attach_l0_refit_weights", + "copy_populace_root_attrs", + "export_us_l0_refit_h5", + "load_l0_refit_npz", + "load_us_frame", + "main", +] + + +if __name__ == "__main__": + main() diff --git a/packages/populace-build/tests/test_us_l0_refit_export.py b/packages/populace-build/tests/test_us_l0_refit_export.py new file mode 100644 index 0000000..b986055 --- /dev/null +++ b/packages/populace-build/tests/test_us_l0_refit_export.py @@ -0,0 +1,177 @@ +import json +from hashlib import sha256 +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from populace.build.us_runtime import l0_refit_export +from populace.build.us_runtime.l0_refit_export import ( + attach_l0_refit_entity_weights, + attach_l0_refit_weights, + export_us_l0_refit_h5, + load_l0_refit_npz, +) +from populace.frame import US_SCHEMA, Frame, WeightKind, Weights + + +def _us_frame() -> Frame: + person = pd.DataFrame( + { + "person_id": np.asarray([1, 2, 3], dtype="int64"), + "person_household_id": np.asarray([10, 20, 20], dtype="int64"), + "person_tax_unit_id": np.asarray([100, 200, 201], dtype="int64"), + "person_spm_unit_id": np.asarray([1000, 2000, 2000], dtype="int64"), + "person_family_id": np.asarray([10000, 20000, 20000], dtype="int64"), + "person_marital_unit_id": np.asarray( + [100000, 200000, 200001], dtype="int64" + ), + } + ) + return Frame( + { + "person": person, + "household": pd.DataFrame( + {"household_id": np.asarray([10, 20], dtype="int64")} + ), + "tax_unit": pd.DataFrame( + {"tax_unit_id": np.asarray([100, 200, 201], dtype="int64")} + ), + "spm_unit": pd.DataFrame( + {"spm_unit_id": np.asarray([1000, 2000], dtype="int64")} + ), + "family": pd.DataFrame( + {"family_id": np.asarray([10000, 20000], dtype="int64")} + ), + "marital_unit": pd.DataFrame( + {"marital_unit_id": np.asarray([100000, 200000, 200001], dtype="int64")} + ), + }, + US_SCHEMA, + {"household": Weights(np.asarray([100.0, 200.0]), WeightKind.CALIBRATED)}, + ) + + +def test_load_l0_refit_npz_validates_full_candidate_vector(tmp_path: Path) -> None: + path = tmp_path / "weights.npz" + metadata = {"candidate_records": 3, "n_selected": 2, "weight_entity": "household"} + np.savez_compressed( + path, + weights=np.asarray([0.0, 4.0, 5.0]), + metadata_json=json.dumps(metadata), + ) + + solution = load_l0_refit_npz(path, expected_candidate_records=3) + + assert solution.selected_mask.tolist() == [False, True, True] + np.testing.assert_allclose(solution.selected_weights, np.asarray([4.0, 5.0])) + assert solution.metadata == metadata + + +def test_load_l0_refit_npz_rejects_metadata_candidate_count_mismatch( + tmp_path: Path, +) -> None: + path = tmp_path / "weights.npz" + np.savez_compressed( + path, + weights=np.asarray([0.0, 4.0, 5.0]), + metadata_json=json.dumps({"candidate_records": 4, "n_selected": 2}), + ) + + with pytest.raises(ValueError, match="candidate_records"): + load_l0_refit_npz(path, expected_candidate_records=3) + + +def test_attach_l0_refit_weights_subsets_clean_base_support() -> None: + frame = _us_frame() + solution = l0_refit_export.L0RefitWeights( + weights=np.asarray([0.0, 333.0]), + selected_mask=np.asarray([False, True]), + metadata={}, + ) + + exported = attach_l0_refit_weights(frame, solution) + + assert exported.table("household")["household_id"].to_list() == [20] + assert exported.table("person")["person_id"].to_list() == [2, 3] + assert exported.table("tax_unit")["tax_unit_id"].to_list() == [200, 201] + assert exported.table("spm_unit")["spm_unit_id"].to_list() == [2000] + np.testing.assert_allclose( + exported.weights_for("household").values, + np.asarray([333.0]), + ) + assert exported.weights_for("household").kind == WeightKind.CALIBRATED + + +def test_attach_l0_refit_entity_weights_rejects_misaligned_weights() -> None: + frame = _us_frame() + + with pytest.raises(ValueError, match="same shape"): + attach_l0_refit_entity_weights( + frame, + weight_entity="household", + selected_entity_ids=np.asarray([10, 20]), + selected_weights=np.asarray([333.0]), + reason="test", + ) + + +def test_export_us_l0_refit_h5_uses_existing_policyengine_writer( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + frame = _us_frame() + npz = tmp_path / "weights.npz" + np.savez_compressed( + npz, + weights=np.asarray([0.0, 333.0]), + metadata_json=json.dumps({"candidate_records": 2, "n_selected": 1}), + ) + base_h5 = tmp_path / "base.h5" + base_h5.write_text("base") + output = tmp_path / "populace_us_2024.h5" + + monkeypatch.setattr(l0_refit_export, "load_us_frame", lambda path: frame) + monkeypatch.setattr( + l0_refit_export, + "copy_populace_root_attrs", + lambda source, destination: ("populace_test_attr",), + ) + + class FakeEngine: + def write_dataset(self, bundle, path, period): + assert Path(path) == output + assert period == 2024 + assert bundle.table("household")["household_id"].to_list() == [20] + np.testing.assert_allclose( + bundle.weights_for("household").values, + np.asarray([333.0]), + ) + Path(path).write_text("sentinel") + + monkeypatch.setattr(l0_refit_export, "PolicyEngineUSEngine", FakeEngine) + + summary = export_us_l0_refit_h5( + base_h5=base_h5, + weights_npz=npz, + output_h5=output, + ) + + assert output.read_text() == "sentinel" + manifest_path = output.with_suffix(".l0_refit_export_summary.json") + manifest = json.loads(manifest_path.read_text()) + assert manifest["schema_version"] == 1 + assert manifest["kind"] == "us_l0_refit_h5_export" + assert manifest["summary_json_path"] == str(manifest_path) + assert manifest["base_h5"]["sha256"] == sha256(b"base").hexdigest() + assert manifest["weights_npz"]["sha256"] == sha256(npz.read_bytes()).hexdigest() + assert manifest["output_h5"]["sha256"] == sha256(b"sentinel").hexdigest() + assert summary["candidate_households"] == 2 + assert summary["selected_households"] == 1 + assert summary["selected_weight_sum"] == pytest.approx(333.0) + assert summary["copied_root_attrs"] == ["populace_test_attr"] + assert ( + summary["summary_json"]["sha256"] + == sha256(manifest_path.read_bytes()).hexdigest() + ) diff --git a/tools/build_us_fiscal_refresh_release.py b/tools/build_us_fiscal_refresh_release.py index f6b1d02..7a7f815 100644 --- a/tools/build_us_fiscal_refresh_release.py +++ b/tools/build_us_fiscal_refresh_release.py @@ -68,6 +68,7 @@ population_by_age_from_sim, write_demographics, ) +from populace.build.us_runtime.l0_refit_export import attach_l0_refit_entity_weights from populace.build.us_runtime.reform_validation import ( default_simulate_factory, load_default_reform_specs, @@ -637,8 +638,7 @@ def _parse_args() -> argparse.Namespace: "is set." ) if not args.dense_default_dataset and not ( - math.isfinite(args.l0_refit_lambda_share) - and args.l0_refit_lambda_share > 0.0 + math.isfinite(args.l0_refit_lambda_share) and args.l0_refit_lambda_share > 0.0 ): parser.error( "--l0-refit-lambda-share must be positive unless " @@ -2929,29 +2929,14 @@ def _with_calibrated_weights( def _with_l0_refit_weights(base_frame: Frame, result) -> Frame: """Attach post-L0 refit weights to the clean selected base-frame support.""" - selected_ids = np.asarray(result.selected_entity_ids) - schema = base_frame.schema - weight_entity = result.weight_entity - if weight_entity == schema.person_entity: - person_ids = base_frame.person[schema.person_id_column].to_numpy() - person_mask = np.isin(person_ids, selected_ids) - elif weight_entity in schema.group_entities: - membership = base_frame.person[schema.membership_column(weight_entity)].to_numpy() - person_mask = np.isin(membership, selected_ids) - else: - raise ValueError( - f"L0 refit default export cannot map weight entity {weight_entity!r}." - ) - selected_base = base_frame.select(person_mask) - exported_ids = selected_base.table(weight_entity)[ - schema.id_column(weight_entity) - ].to_numpy() - if not np.array_equal(exported_ids, selected_ids): - raise ValueError( - "L0 refit selected support is not aligned with the base-frame export " - f"support for {weight_entity!r}." - ) - return _with_calibrated_weights(selected_base, result.weights) + _assert_no_formula_owned_columns(base_frame) + return attach_l0_refit_entity_weights( + base_frame, + weight_entity=result.weight_entity, + selected_entity_ids=np.asarray(result.selected_entity_ids), + selected_weights=np.asarray(result.weights), + reason="US fiscal target refresh L0/refit calibration", + ) def _selected_plan_ratio_bucket(values: np.ndarray) -> dict[str, object]: diff --git a/tools/export_us_l0_refit_h5.py b/tools/export_us_l0_refit_h5.py new file mode 100644 index 0000000..9fdebc4 --- /dev/null +++ b/tools/export_us_l0_refit_h5.py @@ -0,0 +1,6 @@ +"""CLI wrapper for exporting a US Populace H5 from saved L0/refit weights.""" + +from populace.build.us_runtime.l0_refit_export import main + +if __name__ == "__main__": + main()