Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog.d/745.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Split the new `calibration.local_h5` contracts into themed request, input,
validation, and result modules; extract test-only fixtures into dedicated
fixture helpers; and tighten the new request boundary so construction logic
stays outside the value objects.
29 changes: 9 additions & 20 deletions modal_app/local_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
modal run modal_app/local_area.py --branch=main --num-workers=8
"""

import heapq
import json
import os
import subprocess
Expand All @@ -30,6 +29,9 @@

from modal_app.images import cpu_image as image # noqa: E402
from modal_app.resilience import reconcile_run_dir_fingerprint # noqa: E402
from policyengine_us_data.calibration.local_h5.partitioning import ( # noqa: E402
partition_weighted_work_items,
)

app = modal.App("policyengine-us-data-local-area")

Expand Down Expand Up @@ -309,25 +311,12 @@ def partition_work(
num_workers: int,
completed: set,
) -> List[List[Dict]]:
"""Partition work items across N workers using LPT scheduling."""
remaining = [
item for item in work_items if f"{item['type']}:{item['id']}" not in completed
]
remaining.sort(key=lambda x: -x["weight"])

n_workers = min(num_workers, len(remaining))
if n_workers == 0:
return []

heap = [(0, i) for i in range(n_workers)]
chunks = [[] for _ in range(n_workers)]

for item in remaining:
load, idx = heapq.heappop(heap)
chunks[idx].append(item)
heapq.heappush(heap, (load + item["weight"], idx))

return [c for c in chunks if c]
"""Compatibility wrapper over the extracted pure partitioning seam."""
return partition_weighted_work_items(
work_items=work_items,
num_workers=num_workers,
completed=completed,
)


def get_completed_from_volume(run_dir: Path) -> set:
Expand Down
6 changes: 6 additions & 0 deletions policyengine_us_data/calibration/local_h5/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Internal package for the incremental local H5 migration.

Modules in this package should land only when they become active runtime
seams rather than speculative placeholders. The first migration slice
introduces only ``partitioning.py``.
"""
42 changes: 42 additions & 0 deletions policyengine_us_data/calibration/local_h5/partitioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Pure helpers for assigning weighted work items to worker chunks."""

from __future__ import annotations

import heapq
from collections.abc import Mapping, Sequence
from typing import Any


def work_item_key(item: Mapping[str, Any]) -> str:
    """Build the ``"<type>:<id>"`` completion key for one work item.

    Only ``item["type"]`` and ``item["id"]`` are read; a missing key raises
    ``KeyError``.
    """
    kind = item["type"]
    identifier = item["id"]
    return f"{kind}:{identifier}"


def partition_weighted_work_items(
    work_items: Sequence[Mapping[str, Any]],
    num_workers: int,
    completed: set[str] | None = None,
) -> list[list[Mapping[str, Any]]]:
    """Spread unfinished work items over worker chunks, heaviest items first.

    Items whose ``work_item_key`` is in ``completed`` are dropped. The rest
    are ordered by descending ``item["weight"]`` (ties keep input order) and
    each one is handed to the chunk with the smallest accumulated weight,
    the lowest chunk index winning ties.

    Returns an empty list when ``num_workers`` is not positive or nothing is
    left to schedule. Otherwise at most ``min(num_workers, len(remaining))``
    chunks are returned, and chunks that received no item are omitted.
    """
    if num_workers <= 0:
        return []

    done = completed if completed else set()
    pending = sorted(
        (entry for entry in work_items if work_item_key(entry) not in done),
        key=lambda entry: -entry["weight"],
    )

    worker_count = min(num_workers, len(pending))
    if not worker_count:
        return []

    # ``(load, slot)`` pairs; a list already sorted by slot with equal loads
    # is a valid heap, so no heapify call is needed.
    loads: list[tuple[int | float, int]] = [
        (0, slot) for slot in range(worker_count)
    ]
    buckets: list[list[Mapping[str, Any]]] = [[] for _ in range(worker_count)]

    for entry in pending:
        lightest_load, slot = loads[0]
        buckets[slot].append(entry)
        # Pop the lightest worker and push it back with its updated load.
        heapq.heapreplace(loads, (lightest_load + entry["weight"], slot))

    return [bucket for bucket in buckets if bucket]
5 changes: 5 additions & 0 deletions tests/unit/calibration/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Shared test helpers for unit calibration tests.

Fixture modules in this package hold setup code and reusable test-only
helpers so individual test files stay focused on assertions.
"""
50 changes: 50 additions & 0 deletions tests/unit/calibration/fixtures/test_local_h5_partitioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Fixture helpers for ``test_local_h5_partitioning.py``."""

from __future__ import annotations

import importlib.util
import sys
from pathlib import Path

__test__ = False


def _load_partitioning_module():
"""Load the pure partitioning module directly from disk."""

repo_root = Path(__file__).resolve().parents[4]
module_path = (
repo_root
/ "policyengine_us_data"
/ "calibration"
/ "local_h5"
/ "partitioning.py"
)
spec = importlib.util.spec_from_file_location(
"local_h5_partitioning",
module_path,
)
module = importlib.util.module_from_spec(spec)
assert spec is not None
assert spec.loader is not None
sys.modules[spec.name] = module
spec.loader.exec_module(module)
return module


def flatten_chunks(chunks):
    """Concatenate the items of every chunk, in order, into one flat list."""

    flat = []
    for chunk in chunks:
        flat.extend(chunk)
    return flat


def load_partitioning_exports():
    """Load the partitioning module and collect what the tests need.

    Returns a dict with the loaded module under ``"module"``, this package's
    ``flatten_chunks`` helper, and the module's
    ``partition_weighted_work_items`` and ``work_item_key`` callables, each
    under its own name.
    """

    partitioning = _load_partitioning_module()
    exports = {"module": partitioning, "flatten_chunks": flatten_chunks}
    for public_name in ("partition_weighted_work_items", "work_item_key"):
        exports[public_name] = getattr(partitioning, public_name)
    return exports
75 changes: 75 additions & 0 deletions tests/unit/calibration/test_local_h5_partitioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from tests.unit.calibration.fixtures.test_local_h5_partitioning import (
load_partitioning_exports,
)


# Load the module under test once, at import time, and bind the helpers the
# tests below call directly.
partitioning = load_partitioning_exports()
flatten_chunks = partitioning["flatten_chunks"]
partition_weighted_work_items = partitioning["partition_weighted_work_items"]
work_item_key = partitioning["work_item_key"]


def test_work_item_key_uses_existing_completion_shape():
    """The key is ``<type>:<id>``; the weight plays no part in it."""
    district = dict(type="district", id="CA-12", weight=1)
    expected_key = "district:CA-12"
    assert work_item_key(district) == expected_key


def test_partition_filters_completed_items():
    """Items whose key is already completed never reach any chunk."""
    already_done = {"district:CA-12"}
    catalogue = [
        {"type": "state", "id": "CA", "weight": 3},
        {"type": "district", "id": "CA-12", "weight": 1},
        {"type": "city", "id": "NYC", "weight": 2},
    ]

    scheduled = flatten_chunks(
        partition_weighted_work_items(
            catalogue,
            num_workers=2,
            completed=already_done,
        )
    )
    scheduled_ids = [entry["id"] for entry in scheduled]

    assert "CA-12" not in scheduled_ids
    assert set(scheduled_ids) == {"CA", "NYC"}


def test_partition_returns_empty_for_zero_workers_or_zero_remaining():
    """No workers, or nothing left to do, both yield an empty partition."""
    only_item = [{"type": "state", "id": "CA", "weight": 1}]

    without_workers = partition_weighted_work_items(only_item, num_workers=0)
    assert without_workers == []

    nothing_remaining = partition_weighted_work_items(
        only_item,
        num_workers=3,
        completed={"state:CA"},
    )
    assert nothing_remaining == []


def test_partition_uses_no_more_workers_than_remaining_items():
    """Asking for ten workers with two items produces two one-item chunks."""
    two_states = [
        {"type": "state", "id": "CA", "weight": 5},
        {"type": "state", "id": "NY", "weight": 4},
    ]

    chunks = partition_weighted_work_items(two_states, num_workers=10)

    assert len(chunks) == 2
    for chunk in chunks:
        assert len(chunk) == 1


def test_partition_is_weight_balancing_and_deterministic_for_equal_weights():
    """Equal weights are split evenly, with ties resolved in input order."""
    weight_by_id = {"A": 5, "B": 5, "C": 2, "D": 2}
    work_items = [
        {"type": "district", "id": district_id, "weight": weight}
        for district_id, weight in weight_by_id.items()
    ]

    chunks = partition_weighted_work_items(work_items, num_workers=2)

    summary = [
        ([entry["id"] for entry in chunk], sum(entry["weight"] for entry in chunk))
        for chunk in chunks
    ]

    assert [ids for ids, _ in summary] == [["A", "C"], ["B", "D"]]
    assert [load for _, load in summary] == [7, 7]
1 change: 1 addition & 0 deletions tests/unit/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Shared test helpers for top-level unit tests."""
88 changes: 88 additions & 0 deletions tests/unit/fixtures/test_modal_local_area.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Fixture helpers for `test_modal_local_area.py`."""

import importlib
import sys
from contextlib import contextmanager
from types import ModuleType, SimpleNamespace

__test__ = False


@contextmanager
def _patched_module_registry(overrides: dict[str, ModuleType]):
"""Temporarily replace selected `sys.modules` entries for one import."""

sentinel = object()
previous = {
name: sys.modules.get(name, sentinel)
for name in [*overrides.keys(), "modal_app.local_area"]
}

try:
for name, module in overrides.items():
sys.modules[name] = module
sys.modules.pop("modal_app.local_area", None)
yield
finally:
for name, module in previous.items():
if module is sentinel:
sys.modules.pop(name, None)
else:
sys.modules[name] = module


def load_local_area_module():
    """Import `modal_app.local_area` against stand-in dependencies.

    `modal`, `modal_app.images`, `modal_app.resilience` and the
    `policyengine_us_data.calibration.local_h5.partitioning` package chain
    are replaced by stub modules, via `_patched_module_registry`, only while
    the import runs. The freshly imported module is returned.
    """

    def _stub_module(name, **attributes):
        # Build an empty module and attach the requested attributes to it.
        stub = ModuleType(name)
        for attribute, value in attributes.items():
            setattr(stub, attribute, value)
        return stub

    def _passthrough(func):
        # Decorator that leaves the decorated function untouched.
        return func

    def _opaque_handle(*args, **kwargs):
        return object()

    class _FakeApp:
        def __init__(self, *args, **kwargs):
            pass

        def function(self, *args, **kwargs):
            return _passthrough

        def local_entrypoint(self, *args, **kwargs):
            return _passthrough

    replacements = {
        "modal": _stub_module(
            "modal",
            App=_FakeApp,
            Secret=SimpleNamespace(from_name=_opaque_handle),
            Volume=SimpleNamespace(from_name=_opaque_handle),
        ),
        "modal_app.images": _stub_module("modal_app.images", cpu_image=object()),
        "modal_app.resilience": _stub_module(
            "modal_app.resilience",
            reconcile_run_dir_fingerprint=lambda *args, **kwargs: None,
        ),
    }

    # Parent packages get an empty ``__path__`` so they are treated as
    # packages when the dotted partitioning import is resolved.
    for package_name in (
        "policyengine_us_data",
        "policyengine_us_data.calibration",
        "policyengine_us_data.calibration.local_h5",
    ):
        replacements[package_name] = _stub_module(package_name, __path__=[])

    replacements["policyengine_us_data.calibration.local_h5.partitioning"] = (
        _stub_module(
            "policyengine_us_data.calibration.local_h5.partitioning",
            partition_weighted_work_items=lambda *args, **kwargs: [],
        )
    )

    with _patched_module_registry(replacements):
        return importlib.import_module("modal_app.local_area")
44 changes: 3 additions & 41 deletions tests/unit/test_modal_local_area.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,8 @@
import importlib
import sys
from types import ModuleType, SimpleNamespace


def _load_local_area_module():
fake_modal = ModuleType("modal")

class _FakeApp:
def __init__(self, *args, **kwargs):
pass

def function(self, *args, **kwargs):
def decorator(func):
return func

return decorator

def local_entrypoint(self, *args, **kwargs):
def decorator(func):
return func

return decorator

fake_modal.App = _FakeApp
fake_modal.Secret = SimpleNamespace(from_name=lambda *args, **kwargs: object())
fake_modal.Volume = SimpleNamespace(from_name=lambda *args, **kwargs: object())

fake_images = ModuleType("modal_app.images")
fake_images.cpu_image = object()

fake_resilience = ModuleType("modal_app.resilience")
fake_resilience.reconcile_run_dir_fingerprint = lambda *args, **kwargs: None

sys.modules["modal"] = fake_modal
sys.modules["modal_app.images"] = fake_images
sys.modules["modal_app.resilience"] = fake_resilience
sys.modules.pop("modal_app.local_area", None)
return importlib.import_module("modal_app.local_area")
from tests.unit.fixtures.test_modal_local_area import load_local_area_module


def test_build_promote_national_publish_script_imports_version_manifest_helpers():
local_area = _load_local_area_module()
local_area = load_local_area_module()

script = local_area._build_promote_national_publish_script(
version="1.73.0",
Expand All @@ -55,7 +17,7 @@ def test_build_promote_national_publish_script_imports_version_manifest_helpers(


def test_build_promote_publish_script_finalizes_complete_release():
local_area = _load_local_area_module()
local_area = load_local_area_module()

script = local_area._build_promote_publish_script(
version="1.73.0",
Expand Down
Loading