Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
from populace.frame import US_SCHEMA, Frame, MassChange, WeightKind, Weights
from populace.frame.adapters.policyengine_us import PolicyEngineUSEngine

US_RELEASE_REQUIRED_TAX_UNIT_SOURCE_COLUMNS = (
"takes_up_aca_if_eligible",
"selected_marketplace_plan_benchmark_ratio",
)


@dataclass(frozen=True)
class L0RefitWeights:
Expand Down Expand Up @@ -211,6 +216,29 @@ def attach_l0_refit_entity_weights(
)


def assert_required_us_release_source_columns(
frame: Frame,
*,
columns: tuple[str, ...] = US_RELEASE_REQUIRED_TAX_UNIT_SOURCE_COLUMNS,
) -> None:
"""Require source-stage tax-unit columns needed by US release gates."""

tax_unit = frame.table("tax_unit")
failures: list[str] = []
for column in columns:
if column not in tax_unit.columns:
failures.append(f"{column}: missing")
continue
unique = tax_unit[column].dropna().unique()
if len(unique) < 2:
failures.append(f"{column}: not nonconstant")
if failures:
raise ValueError(
"US L0/refit release export requires source-stage tax-unit columns: "
+ "; ".join(failures)
)


def load_us_frame(path: str | Path) -> Frame:
"""Load a PolicyEngine-US single-year H5 into a Populace frame."""

Expand Down Expand Up @@ -263,9 +291,14 @@ def export_us_l0_refit_h5(
weight_key: str = "weights",
zero_weight_tolerance: float = 0.0,
summary_json: str | Path | None = None,
require_source_columns: bool = True,
root_attrs_h5: str | Path | None = None,
) -> dict[str, Any]:
"""Write a selected US H5 from a base H5 and saved L0/refit weights."""

root_attrs_source = (
Path(root_attrs_h5) if root_attrs_h5 is not None else Path(base_h5)
)
base_frame = load_us_frame(base_h5)
solution = load_l0_refit_npz(
weights_npz,
Expand All @@ -274,18 +307,23 @@ def export_us_l0_refit_h5(
zero_weight_tolerance=zero_weight_tolerance,
)
export_frame = attach_l0_refit_weights(base_frame, solution)
if require_source_columns:
assert_required_us_release_source_columns(export_frame)
destination = Path(output_h5)
destination.parent.mkdir(parents=True, exist_ok=True)
PolicyEngineUSEngine().write_dataset(export_frame, destination, period=period)
copied_attrs = copy_populace_root_attrs(base_h5, destination)
copied_attrs = copy_populace_root_attrs(root_attrs_source, destination)
summary = {
"schema_version": 1,
"kind": "us_l0_refit_h5_export",
"base_h5": _file_manifest(base_h5),
"root_attrs_h5": _file_manifest(root_attrs_source),
"weights_npz": _file_manifest(weights_npz),
"output_h5": _file_manifest(destination),
"period": int(period),
"weight_key": weight_key,
"required_source_columns": list(US_RELEASE_REQUIRED_TAX_UNIT_SOURCE_COLUMNS),
"required_source_columns_checked": bool(require_source_columns),
"candidate_households": int(base_frame.n("household")),
"selected_households": int(export_frame.n("household")),
"selected_weight_sum": float(export_frame.weights_for("household").total),
Expand Down Expand Up @@ -327,6 +365,22 @@ def _parser() -> argparse.ArgumentParser:
".l0_refit_export_summary.json file beside --output-h5."
),
)
parser.add_argument(
"--root-attrs-h5",
type=Path,
help=(
"Optional H5 whose Populace-owned root attrs are copied to the "
"exported dataset. Defaults to --base-h5."
),
)
parser.add_argument(
"--allow-missing-source-columns",
action="store_true",
help=(
"Diagnostic escape hatch. By default, reconstruction requires the "
"US release source-stage tax-unit columns used by release gates."
),
)
return parser


Expand All @@ -340,6 +394,8 @@ def main(argv: list[str] | None = None) -> None:
weight_key=args.weight_key,
zero_weight_tolerance=args.zero_weight_tolerance,
summary_json=args.summary_json,
require_source_columns=not args.allow_missing_source_columns,
root_attrs_h5=args.root_attrs_h5,
)
print(json.dumps(summary, indent=2, sort_keys=True))

Expand All @@ -348,6 +404,7 @@ def main(argv: list[str] | None = None) -> None:
"L0RefitWeights",
"attach_l0_refit_entity_weights",
"attach_l0_refit_weights",
"assert_required_us_release_source_columns",
"copy_populace_root_attrs",
"export_us_l0_refit_h5",
"load_l0_refit_npz",
Expand Down
43 changes: 41 additions & 2 deletions packages/populace-build/tests/test_us_l0_refit_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from populace.build.us_runtime import l0_refit_export
from populace.build.us_runtime.l0_refit_export import (
assert_required_us_release_source_columns,
attach_l0_refit_entity_weights,
attach_l0_refit_weights,
export_us_l0_refit_h5,
Expand Down Expand Up @@ -36,7 +37,15 @@ def _us_frame() -> Frame:
{"household_id": np.asarray([10, 20], dtype="int64")}
),
"tax_unit": pd.DataFrame(
{"tax_unit_id": np.asarray([100, 200, 201], dtype="int64")}
{
"tax_unit_id": np.asarray([100, 200, 201], dtype="int64"),
"takes_up_aca_if_eligible": np.asarray(
[False, True, False], dtype=bool
),
"selected_marketplace_plan_benchmark_ratio": np.asarray(
[1.0, 0.8, 1.2], dtype="float64"
),
}
),
"spm_unit": pd.DataFrame(
{"spm_unit_id": np.asarray([1000, 2000], dtype="int64")}
Expand Down Expand Up @@ -117,6 +126,27 @@ def test_attach_l0_refit_entity_weights_rejects_misaligned_weights() -> None:
)


def test_required_us_release_source_columns_rejects_missing_source_stage() -> None:
frame = _us_frame()
raw_tax_units = frame.table("tax_unit").drop(
columns=[
"takes_up_aca_if_eligible",
"selected_marketplace_plan_benchmark_ratio",
]
)
raw_frame = Frame(
{
**{entity: frame.table(entity).copy() for entity in frame.schema.entities},
"tax_unit": raw_tax_units,
},
frame.schema,
{"household": frame.weights_for("household")},
)

with pytest.raises(ValueError, match="takes_up_aca_if_eligible: missing"):
assert_required_us_release_source_columns(raw_frame)


def test_export_us_l0_refit_h5_uses_existing_policyengine_writer(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
Expand All @@ -130,13 +160,18 @@ def test_export_us_l0_refit_h5_uses_existing_policyengine_writer(
)
base_h5 = tmp_path / "base.h5"
base_h5.write_text("base")
attrs_h5 = tmp_path / "attrs.h5"
attrs_h5.write_text("attrs")
output = tmp_path / "populace_us_2024.h5"
copied_from: list[Path] = []

monkeypatch.setattr(l0_refit_export, "load_us_frame", lambda path: frame)
monkeypatch.setattr(
l0_refit_export,
"copy_populace_root_attrs",
lambda source, destination: ("populace_test_attr",),
lambda source, destination: (
copied_from.append(Path(source)) or ("populace_test_attr",)
),
)

class FakeEngine:
Expand All @@ -156,15 +191,19 @@ def write_dataset(self, bundle, path, period):
base_h5=base_h5,
weights_npz=npz,
output_h5=output,
root_attrs_h5=attrs_h5,
)

assert output.read_text() == "sentinel"
assert copied_from == [attrs_h5]
manifest_path = output.with_suffix(".l0_refit_export_summary.json")
manifest = json.loads(manifest_path.read_text())
assert manifest["schema_version"] == 1
assert manifest["kind"] == "us_l0_refit_h5_export"
assert manifest["summary_json_path"] == str(manifest_path)
assert manifest["required_source_columns_checked"] is True
assert manifest["base_h5"]["sha256"] == sha256(b"base").hexdigest()
assert manifest["root_attrs_h5"]["sha256"] == sha256(b"attrs").hexdigest()
assert manifest["weights_npz"]["sha256"] == sha256(npz.read_bytes()).hexdigest()
assert manifest["output_h5"]["sha256"] == sha256(b"sentinel").hexdigest()
assert summary["candidate_households"] == 2
Expand Down
Loading