Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions packages/populace-build/tests/test_us_fiscal_refresh_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,213 @@ def _load_scorer_module():
return module


def test__given_matching_warm_start_npz__then_builder_loads_household_weights(
tmp_path,
) -> None:
builder = _load_builder_module()
path = tmp_path / "populace_us_2024_calibration.npz"
initial = np.asarray([10.0, 20.0, 30.0])
weights = np.asarray([12.0, 18.0, 35.0])
np.savez_compressed(
path,
household_weight=weights,
initial_household_weight=initial,
)

loaded, payload = builder._load_warm_start_calibration_npz(
path,
expected_initial_weights=initial.copy(),
)

np.testing.assert_allclose(loaded, weights)
assert payload["enabled"] is True
assert payload["n_households"] == 3
assert payload["sha256"] == builder._sha256(path)


def test__given_mismatched_warm_start_initial_weights__then_builder_rejects_npz(
tmp_path,
) -> None:
builder = _load_builder_module()
path = tmp_path / "populace_us_2024_calibration.npz"
np.savez_compressed(
path,
household_weight=np.asarray([12.0, 18.0, 35.0]),
initial_household_weight=np.asarray([10.0, 20.0, 30.0]),
)

with pytest.raises(ValueError, match="different initial household weights"):
builder._load_warm_start_calibration_npz(
path,
expected_initial_weights=np.asarray([10.0, 20.0, 31.0]),
)


def test__given_target_frame_checkpoint__then_builder_round_trips_frame(
monkeypatch,
tmp_path,
small_frame,
) -> None:
builder = _load_builder_module()
monkeypatch.setattr(builder, "US_SCHEMA", small_frame.schema)
tables = {
entity: small_frame.table(entity).copy() for entity in small_frame.entities
}
tables["household"]["mock_measure"] = np.asarray([1.5, 2.5])
tables["household"]["mock_filter"] = np.asarray([1, 0], dtype=np.int64)
frame = Frame(
tables,
small_frame.schema,
{"household": small_frame.weights_for("household")},
small_frame.strata,
)
target = TargetSpec(
name="mock.measure",
entity="household",
measure="mock_measure",
filter="mock_filter",
value=1500.0,
source="Mock source",
)
identity = builder._target_frame_checkpoint_identity(
base_dataset_sha256="base-sha",
policyengine_us_version="1.2.3",
seed=0,
target_period=builder.PERIOD,
target_registry_version="registry-sha",
congressional_district_vintage_crosswalk_sha256="crosswalk-sha",
)
path = tmp_path / "target_frame_checkpoint.h5"

payload = builder._write_target_frame_checkpoint(
path,
frame=frame,
identity=identity,
compilation={"declared_targets": 1},
)
loaded = builder._read_target_frame_checkpoint(
path,
identity=identity,
target_specs=(target,),
)

assert payload["status"] == "miss_written"
assert loaded is not None
loaded_frame, loaded_registry, loaded_compilation = loaded
assert np.array_equal(
loaded_frame.table("household")["mock_measure"].to_numpy(),
np.asarray([1.5, 2.5]),
)
assert np.array_equal(
loaded_frame.table("household")["mock_filter"].to_numpy(),
np.asarray([1, 0], dtype=np.int64),
)
assert np.array_equal(
loaded_frame.weights_for("household").values,
small_frame.weights_for("household").values,
)
assert loaded_frame.weights_for("household").kind is WeightKind.DESIGN
pd.testing.assert_series_equal(
loaded_frame.strata,
small_frame.strata,
check_dtype=False,
)
assert len(loaded_registry) == 1
assert loaded_compilation["compiled_candidate_targets"] == 1
assert loaded_compilation["target_frame_checkpoint"]["status"] == "hit"
assert (
loaded_compilation["target_frame_checkpoint"]["stored_compilation"][
"declared_targets"
]
== 1
)


def test__given_stale_target_frame_checkpoint__then_builder_ignores_it(
tmp_path,
small_frame,
) -> None:
builder = _load_builder_module()
fresh_identity = builder._target_frame_checkpoint_identity(
base_dataset_sha256="base-sha",
policyengine_us_version="1.2.3",
seed=0,
target_period=builder.PERIOD,
target_registry_version="registry-sha",
congressional_district_vintage_crosswalk_sha256="crosswalk-sha",
)
stale_identity = {
**fresh_identity,
"target_registry_version": "old-registry-sha",
}
path = tmp_path / "target_frame_checkpoint.h5"
builder._write_target_frame_checkpoint(
path,
frame=small_frame,
identity=stale_identity,
compilation={},
)

loaded = builder._read_target_frame_checkpoint(
path,
identity=fresh_identity,
target_specs=(),
)

assert loaded is None


def test__given_matching_target_frame_checkpoint__then_builder_skips_materialization(
monkeypatch,
tmp_path,
small_frame,
) -> None:
builder = _load_builder_module()
target = TargetSpec(
name="mock.measure",
entity="household",
measure="household_id",
value=1.0,
source="Mock source",
)
registry = TargetRegistry((target,), country="us")
identity = builder._target_frame_checkpoint_identity(
base_dataset_sha256="base-sha",
policyengine_us_version="1.2.3",
seed=0,
target_period=builder.PERIOD,
target_registry_version=registry.version,
congressional_district_vintage_crosswalk_sha256=None,
)

def fail_materialize(*args, **kwargs):
raise AssertionError("materialization should not run on checkpoint hit")

monkeypatch.setattr(builder, "_materialize_target_frame", fail_materialize)
monkeypatch.setattr(
builder,
"_read_target_frame_checkpoint",
lambda path, **kwargs: (
small_frame,
registry,
{"target_frame_checkpoint": {"status": "hit"}},
),
)

loaded_frame, loaded_registry, compilation = (
builder._load_or_materialize_target_frame(
small_frame,
(target,),
target_frame_checkpoint_path=tmp_path / "target_frame_checkpoint.h5",
target_frame_checkpoint_identity=identity,
)
)

assert loaded_frame is small_frame
assert loaded_registry is registry
assert compilation["target_frame_checkpoint"]["status"] == "hit"


def test_runtime_versions_use_local_workspace_package_version(
monkeypatch, tmp_path
) -> None:
Expand Down Expand Up @@ -1694,6 +1901,7 @@ class FakeFrame:
str(out),
"--release-id",
release_id,
"--no-target-frame-checkpoint",
],
)
monkeypatch.setattr(builder, "_git_dirty", lambda: False)
Expand Down
Loading
Loading