From e34ab34c11f1a53ef6d63ec24534127f99010bd7 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Tue, 30 Jun 2026 16:56:11 +0200 Subject: [PATCH 1/2] Make UK firm Ledger selectors profile-driven --- .../build/uk_runtime/firm_generation.py | 150 ++++++++++++++---- .../tests/test_uk_firm_generation.py | 85 ++++++++++ 2 files changed, 206 insertions(+), 29 deletions(-) diff --git a/packages/populace-build/src/populace/build/uk_runtime/firm_generation.py b/packages/populace-build/src/populace/build/uk_runtime/firm_generation.py index 87d98d6..65165e2 100644 --- a/packages/populace-build/src/populace/build/uk_runtime/firm_generation.py +++ b/packages/populace-build/src/populace/build/uk_runtime/firm_generation.py @@ -55,30 +55,67 @@ "hmrc_liability_band": "hmrc_vat_liability_by_turnover_band.csv", "hmrc_liability_sector": "hmrc_vat_liability_by_sector.csv", } -LEDGER_ONS_TURNOVER_RECORD_SET = ( - "ons.uk_business.cy2025.enterprise_count.by_turnover_band" -) -LEDGER_ONS_EMPLOYMENT_RECORD_SET = ( - "ons.uk_business.cy2025.enterprise_count.by_employment_band" -) -LEDGER_ONS_SIC_TURNOVER_RECORD_SET = ( - "ons.uk_business.cy2025.enterprise_count.by_sic_turnover_band" -) -LEDGER_ONS_SIC_EMPLOYMENT_RECORD_SET = ( - "ons.uk_business.cy2025.enterprise_count.by_sic_employment_band" -) -LEDGER_HMRC_POPULATION_RECORD_SET = ( - "hmrc.vat.fy2024_25.registered_trader_count.by_turnover_band" -) -LEDGER_HMRC_LIABILITY_RECORD_SET = ( - "hmrc.vat.fy2024_25.net_liability.by_turnover_band" -) -LEDGER_HMRC_POPULATION_SIC_RECORD_SET = ( - "hmrc.vat.fy2024_25.registered_trader_count.by_sic" -) -LEDGER_HMRC_LIABILITY_SIC_RECORD_SET = ( - "hmrc.vat.fy2024_25.net_liability.by_sic" +UK_FIRM_TARGET_IDS: dict[str, str] = { + "ons_turnover": "ons.uk_business.enterprise_count.turnover_bands", + "ons_employment": "ons.uk_business.enterprise_count.employment_bands", + "hmrc_population_band": "hmrc.vat.registered_trader_count.turnover_bands", + "hmrc_liability_band": "hmrc.vat.net_liability.turnover_bands", + "ons_sic_turnover": "ons.uk_business.enterprise_count.sic_turnover_bands", + "ons_sic_employment": "ons.uk_business.enterprise_count.sic_employment_bands", + "hmrc_population_sic": "hmrc.vat.registered_trader_count.sic_sectors", + "hmrc_liability_sic": "hmrc.vat.net_liability.sic_sectors", +} + + +@dataclass(frozen=True) +class UKFirmLedgerTargetProfile: + """Ledger record-set selectors for the ``uk_firms`` target profile.""" + + ons_turnover_record_set: str + ons_employment_record_set: str + hmrc_population_band_record_set: str + hmrc_liability_band_record_set: str + ons_sic_turnover_record_set: str + ons_sic_employment_record_set: str + hmrc_population_sic_record_set: str + hmrc_liability_sic_record_set: str + + +DEFAULT_UK_FIRM_TARGET_PROFILE = UKFirmLedgerTargetProfile( + ons_turnover_record_set="ons.uk_business.cy2025.enterprise_count.by_turnover_band", + ons_employment_record_set="ons.uk_business.cy2025.enterprise_count.by_employment_band", + hmrc_population_band_record_set=( + "hmrc.vat.fy2024_25.registered_trader_count.by_turnover_band" + ), + hmrc_liability_band_record_set=( + "hmrc.vat.fy2024_25.net_liability.by_turnover_band" + ), + ons_sic_turnover_record_set=( + "ons.uk_business.cy2025.enterprise_count.by_sic_turnover_band" + ), + ons_sic_employment_record_set=( + "ons.uk_business.cy2025.enterprise_count.by_sic_employment_band" + ), + hmrc_population_sic_record_set=( + "hmrc.vat.fy2024_25.registered_trader_count.by_sic" + ), + hmrc_liability_sic_record_set="hmrc.vat.fy2024_25.net_liability.by_sic", ) + +_UK_FIRM_TARGET_PROFILE_FIELDS: dict[str, str] = { + "ons_turnover": "ons_turnover_record_set", + "ons_employment": "ons_employment_record_set", + "hmrc_population_band": "hmrc_population_band_record_set", + "hmrc_liability_band": "hmrc_liability_band_record_set", + "ons_sic_turnover": "ons_sic_turnover_record_set", + "ons_sic_employment": "ons_sic_employment_record_set", + "hmrc_population_sic": "hmrc_population_sic_record_set", + "hmrc_liability_sic": "hmrc_liability_sic_record_set", +} + +# These maps translate Ledger value IDs into the generator's temporary support +# matrix. The source selectors above are profile-owned; support layout should be +# the next piece lifted into a declarative spec. LEDGER_ONS_TURNOVER_BANDS: dict[str, str] = { "0_49k": "0-49", "50_99k": "50-99", @@ -355,10 +392,64 @@ def uk_firm_source_data_from_frames( ) +def uk_firm_target_profile_from_mapping( + raw: Mapping[str, Any], +) -> UKFirmLedgerTargetProfile: + """Extract UK firm record-set selectors from a Ledger target profile.""" + + target_rows = raw.get("targets") + if not isinstance(target_rows, Iterable) or isinstance( + target_rows, str | bytes | Mapping + ): + raise ValueError("UK firm target profile must contain a targets list.") + + targets_by_id: dict[str, Mapping[str, Any]] = {} + for row in target_rows: + if not isinstance(row, Mapping): + raise ValueError("UK firm target profile targets must be objects.") + target_id = row.get("target_id") + if isinstance(target_id, str): + targets_by_id[target_id] = row + + record_sets: dict[str, str] = {} + for logical_key, target_id in UK_FIRM_TARGET_IDS.items(): + target = targets_by_id.get(target_id) + if target is None: + raise ValueError(f"UK firm target profile missing target {target_id!r}.") + + selector = target.get("ledger_selector") + if not isinstance(selector, Mapping): + raise ValueError( + f"UK firm target {target_id!r} must define ledger_selector." + ) + + record_set_id = selector.get("record_set_id") + if not isinstance(record_set_id, str) or not record_set_id: + raise ValueError( + f"UK firm target {target_id!r} must define " + "ledger_selector.record_set_id." + ) + + record_sets[_UK_FIRM_TARGET_PROFILE_FIELDS[logical_key]] = record_set_id + + return UKFirmLedgerTargetProfile(**record_sets) + + +def _coerce_uk_firm_target_profile( + target_profile: UKFirmLedgerTargetProfile | Mapping[str, Any] | None, +) -> UKFirmLedgerTargetProfile: + if target_profile is None: + return DEFAULT_UK_FIRM_TARGET_PROFILE + if isinstance(target_profile, UKFirmLedgerTargetProfile): + return target_profile + return uk_firm_target_profile_from_mapping(target_profile) + + def uk_firm_source_data_from_ledger_facts( facts: Iterable[Mapping[str, Any] | object], *, data_vintage: str = "2024-25", + target_profile: UKFirmLedgerTargetProfile | Mapping[str, Any] | None = None, ) -> UKFirmSourceData: """Build UK firm inputs from Ledger consumer facts. @@ -368,30 +459,31 @@ def uk_firm_source_data_from_ledger_facts( contract while firm support is being migrated. """ + profile = _coerce_uk_firm_target_profile(target_profile) fact_rows = tuple(facts) ons_turnover = _ledger_ons_sic_band_matrix( fact_rows, - record_set_id=LEDGER_ONS_SIC_TURNOVER_RECORD_SET, + record_set_id=profile.ons_sic_turnover_record_set, band_map=LEDGER_ONS_TURNOVER_BANDS, band_dimension="uk.firm.turnover_band", table_name="ONS SIC turnover", ) ons_employment = _ledger_ons_sic_band_matrix( fact_rows, - record_set_id=LEDGER_ONS_SIC_EMPLOYMENT_RECORD_SET, + record_set_id=profile.ons_sic_employment_record_set, band_map=LEDGER_ONS_EMPLOYMENT_BANDS, band_dimension="uk.firm.employment_band", table_name="ONS SIC employment", ) hmrc_population_values = _ledger_values_by_band( fact_rows, - record_set_id=LEDGER_HMRC_POPULATION_RECORD_SET, + record_set_id=profile.hmrc_population_band_record_set, measure_id="vat_registered_trader_count", band_map=LEDGER_HMRC_BANDS, ) hmrc_liability_values_gbp = _ledger_values_by_band( fact_rows, - record_set_id=LEDGER_HMRC_LIABILITY_RECORD_SET, + record_set_id=profile.hmrc_liability_band_record_set, measure_id="net_vat_liability", band_map=LEDGER_HMRC_BANDS, ) @@ -404,7 +496,7 @@ def uk_firm_source_data_from_ledger_facts( ) hmrc_population_sector = _ledger_sic_series( fact_rows, - record_set_id=LEDGER_HMRC_POPULATION_SIC_RECORD_SET, + record_set_id=profile.hmrc_population_sic_record_set, measure_id="vat_registered_trader_count", data_vintage=data_vintage, value_scale=1.0, @@ -414,7 +506,7 @@ def uk_firm_source_data_from_ledger_facts( ) hmrc_liability_sector = _ledger_sic_series( fact_rows, - record_set_id=LEDGER_HMRC_LIABILITY_SIC_RECORD_SET, + record_set_id=profile.hmrc_liability_sic_record_set, measure_id="net_vat_liability", data_vintage=data_vintage, value_scale=1 / 1_000_000.0, diff --git a/packages/populace-build/tests/test_uk_firm_generation.py b/packages/populace-build/tests/test_uk_firm_generation.py index 3f11f57..6514ef8 100644 --- a/packages/populace-build/tests/test_uk_firm_generation.py +++ b/packages/populace-build/tests/test_uk_firm_generation.py @@ -6,8 +6,10 @@ import pytest from populace.build.uk_runtime.firm_generation import ( + DEFAULT_UK_FIRM_TARGET_PROFILE, HMRC_BAND_COLUMNS, INPUT_FILES, + UK_FIRM_TARGET_IDS, UKFirmGenerationConfig, employment_band_name, generate_uk_firm_population, @@ -15,6 +17,7 @@ read_uk_firm_source_data, uk_firm_source_data_from_frames, uk_firm_source_data_from_ledger_facts, + uk_firm_target_profile_from_mapping, write_uk_firm_population, ) @@ -50,6 +53,40 @@ def test_uk_firm_source_data_from_ledger_facts_uses_ledger_targets() -> None: assert data.hmrc_liability_sector["2024-25"].tolist() == [2.0, 1.5] +def test_uk_firm_source_data_from_ledger_facts_uses_target_profile_mapping() -> None: + custom_record_set_id = "custom.ons.enterprise_count.by_sic_turnover_band" + facts = _replace_fact_record_set( + _ledger_facts(), + old_record_set_id=DEFAULT_UK_FIRM_TARGET_PROFILE.ons_sic_turnover_record_set, + new_record_set_id=custom_record_set_id, + ) + target_profile = _ledger_target_profile_mapping( + ons_sic_turnover=custom_record_set_id, + ) + + data = uk_firm_source_data_from_ledger_facts( + facts, + data_vintage="2024-25", + target_profile=target_profile, + ) + + assert data.ons_turnover.loc[0, "0-49"] == 2.0 + assert data.ons_turnover.loc[1, "500-999"] == 1.0 + + +def test_uk_firm_target_profile_from_mapping_requires_declared_targets() -> None: + target_profile = _ledger_target_profile_mapping() + missing_target_id = UK_FIRM_TARGET_IDS["hmrc_liability_sic"] + target_profile["targets"] = [ + row + for row in target_profile["targets"] + if row["target_id"] != missing_target_id + ] + + with pytest.raises(ValueError, match=missing_target_id): + uk_firm_target_profile_from_mapping(target_profile) + + def test_generate_uk_firm_population_accepts_ledger_source_data() -> None: data = uk_firm_source_data_from_ledger_facts( _ledger_facts(), @@ -229,6 +266,54 @@ def _source_data(): return uk_firm_source_data_from_frames(**_source_frames()) +def _ledger_target_profile_mapping(**overrides: str) -> dict[str, object]: + record_sets = { + "ons_turnover": DEFAULT_UK_FIRM_TARGET_PROFILE.ons_turnover_record_set, + "ons_employment": DEFAULT_UK_FIRM_TARGET_PROFILE.ons_employment_record_set, + "hmrc_population_band": ( + DEFAULT_UK_FIRM_TARGET_PROFILE.hmrc_population_band_record_set + ), + "hmrc_liability_band": ( + DEFAULT_UK_FIRM_TARGET_PROFILE.hmrc_liability_band_record_set + ), + "ons_sic_turnover": ( + DEFAULT_UK_FIRM_TARGET_PROFILE.ons_sic_turnover_record_set + ), + "ons_sic_employment": ( + DEFAULT_UK_FIRM_TARGET_PROFILE.ons_sic_employment_record_set + ), + "hmrc_population_sic": ( + DEFAULT_UK_FIRM_TARGET_PROFILE.hmrc_population_sic_record_set + ), + "hmrc_liability_sic": DEFAULT_UK_FIRM_TARGET_PROFILE.hmrc_liability_sic_record_set, + **overrides, + } + return { + "targets": [ + { + "target_id": target_id, + "ledger_selector": {"record_set_id": record_sets[logical_key]}, + } + for logical_key, target_id in UK_FIRM_TARGET_IDS.items() + ] + } + + +def _replace_fact_record_set( + facts: list[dict[str, object]], + *, + old_record_set_id: str, + new_record_set_id: str, +) -> list[dict[str, object]]: + updated_facts: list[dict[str, object]] = [] + for fact in facts: + updated = {**fact, "layout": dict(fact["layout"])} + if updated["layout"]["record_set_id"] == old_record_set_id: + updated["layout"]["record_set_id"] = new_record_set_id + updated_facts.append(updated) + return updated_facts + + def _ledger_facts() -> list[dict[str, object]]: facts: list[dict[str, object]] = [] for sic, values in { From cede5f1ae01e71dc68c8d656f3b294285b2cff47 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 1 Jul 2026 17:16:55 +0200 Subject: [PATCH 2/2] Use Axiom VAT evaluator for UK firms --- packages/populace-build/README.md | 8 +- .../src/populace/build/uk_runtime/__init__.py | 4 + .../build/uk_runtime/firm_generation.py | 233 +++++++++++++++++- .../tests/test_uk_firm_generation.py | 62 ++++- 4 files changed, 292 insertions(+), 15 deletions(-) diff --git a/packages/populace-build/README.md b/packages/populace-build/README.md index b0e1f1d..a7501f6 100644 --- a/packages/populace-build/README.md +++ b/packages/populace-build/README.md @@ -106,10 +106,10 @@ data. This is not yet the production firm microsimulation path: Ledger currently covers the current paper target surface used by the adapter, including ONS SIC-by-turnover, ONS SIC-by-employment, HMRC VAT-registered firms by SIC, and -HMRC net VAT liability by SIC. VAT metrics still use the paper generator's -turnover/input heuristics rather than Axiom RuleSpec execution, so production -activation should wait for Axiom-backed VAT metrics. The processed-table reader -remains only for paper-repository migration comparisons. +HMRC net VAT liability by SIC. VAT liability is now an explicit rule-evaluator +input to generation; production runs should provide an Axiom RuleSpec artifact +through `AxiomVATRuleEvaluator`. The processed-table reader remains only for +paper-repository migration comparisons. Build the row-wise local-geography H5 from a compact Populace UK H5 with: diff --git a/packages/populace-build/src/populace/build/uk_runtime/__init__.py b/packages/populace-build/src/populace/build/uk_runtime/__init__.py index 08c325e..ae7ec4b 100644 --- a/packages/populace-build/src/populace/build/uk_runtime/__init__.py +++ b/packages/populace-build/src/populace/build/uk_runtime/__init__.py @@ -6,12 +6,14 @@ INPUT_FILES, VAT_LIABILITY_BANDS, VINTAGES, + AxiomVATRuleEvaluator, UKFirmCalibrationResult, UKFirmGenerationConfig, UKFirmGenerationResult, UKFirmSourceData, UKFirmTargetLayout, UKFirmValidationReport, + UKFirmVATRuleEvaluator, assign_employment, assign_vat_flags, build_firm_target_matrix, @@ -161,6 +163,7 @@ "AREA_TYPE_TO_LEDGER_GEOGRAPHY_LEVEL", "AREA_TYPES", "AREA_TYPE_TO_CROSSWALK_COLUMN", + "AxiomVATRuleEvaluator", "BASE_FRS_SUPPORT_CHANNEL", "BENUNIT_ID_COLUMNS", "COUNTRY_TO_REGION", @@ -208,6 +211,7 @@ "UKFirmGenerationResult", "UKFirmSourceData", "UKFirmTargetLayout", + "UKFirmVATRuleEvaluator", "UKFirmValidationReport", "UKLocalCandidateResult", "UKRowwiseDatasetResult", diff --git a/packages/populace-build/src/populace/build/uk_runtime/firm_generation.py b/packages/populace-build/src/populace/build/uk_runtime/firm_generation.py index 65165e2..1e5dae5 100644 --- a/packages/populace-build/src/populace/build/uk_runtime/firm_generation.py +++ b/packages/populace-build/src/populace/build/uk_runtime/firm_generation.py @@ -9,18 +9,20 @@ Scope is intentionally narrow and labelled experimental. Populace production calibration targets must be materialized from Ledger target profiles; the processed ONS/HMRC tables accepted here are a migration harness only, not a -production target source. Ledger-backed firm target activation and Axiom -VAT-rule execution are later integration steps, so this module does not claim -production firm microsimulation support. +production target source. VAT liability must be supplied by a configured rule +evaluator; the production path is an Axiom RuleSpec artifact, so this module +does not claim production firm microsimulation support without that artifact. """ from __future__ import annotations +import json import logging +import subprocess from collections.abc import Iterable, Mapping from dataclasses import dataclass, field, replace from pathlib import Path -from typing import Any +from typing import Any, Protocol import numpy as np import pandas as pd @@ -206,6 +208,136 @@ class UKFirmLedgerTargetProfile: } HIGH_LIABILITY_SECTORS = {11, 12, 69, 70, 78} +AXIOM_UK_VAT_NET_LIABILITY_OUTPUT = ( + "uk:statutes/ukpga/1994/23/25#net_vat_liability_after_input_tax_credit" +) +AXIOM_UK_VAT_TAXABLE_PERSON_INPUT = "uk:statutes/ukpga/1994/23/24#input.taxable_person" +AXIOM_UK_VAT_SUPPLIES_INPUT = ( + "uk:statutes/ukpga/1994/23/24#input.standard_rated_taxable_supplies_value" +) +AXIOM_UK_VAT_PURCHASES_INPUT = ( + "uk:statutes/ukpga/1994/23/24#input.standard_rated_deductible_business_purchase_value" +) +AXIOM_UK_VAT_ALLOWABLE_PROPORTION_INPUT = ( + "uk:statutes/ukpga/1994/23/26#input.allowable_input_tax_proportion" +) + + +class UKFirmVATRuleEvaluator(Protocol): + """Evaluate VAT liability for candidate firms before weight optimization.""" + + def net_liability_k( + self, + *, + turnover_k: np.ndarray, + input_cost_k: np.ndarray, + vat_registered: np.ndarray, + data_vintage: str, + ) -> np.ndarray: + """Return net VAT liability in thousands of pounds.""" + + +@dataclass(frozen=True) +class AxiomVATRuleEvaluator: + """Run UK VAT liability through a compiled Axiom RuleSpec artifact.""" + + artifact_path: str | Path + engine_binary: str | Path = "axiom-rules-engine" + allowable_input_tax_proportion: float = 1.0 + output_id: str = AXIOM_UK_VAT_NET_LIABILITY_OUTPUT + + def net_liability_k( + self, + *, + turnover_k: np.ndarray, + input_cost_k: np.ndarray, + vat_registered: np.ndarray, + data_vintage: str, + ) -> np.ndarray: + turnover = np.asarray(turnover_k, dtype=np.float64) + input_cost = np.asarray(input_cost_k, dtype=np.float64) + registered = np.asarray(vat_registered, dtype=bool) + if turnover.shape != input_cost.shape or turnover.shape != registered.shape: + raise ValueError("VAT rule inputs must have matching shapes.") + + period = _vat_rule_period(data_vintage) + interval = {"start": period["start"], "end": period["end"]} + request_inputs: list[dict[str, object]] = [] + queries: list[dict[str, object]] = [] + for index, (sales_k, purchases_k, taxable) in enumerate( + zip(turnover, input_cost, registered, strict=True), + ): + entity_id = f"firm:{index + 1}" + request_inputs.extend( + [ + _axiom_input_record( + AXIOM_UK_VAT_TAXABLE_PERSON_INPUT, + entity_id, + interval, + {"kind": "bool", "value": bool(taxable)}, + ), + _axiom_input_record( + AXIOM_UK_VAT_SUPPLIES_INPUT, + entity_id, + interval, + {"kind": "decimal", "value": _decimal_string(sales_k * 1000.0)}, + ), + _axiom_input_record( + AXIOM_UK_VAT_PURCHASES_INPUT, + entity_id, + interval, + { + "kind": "decimal", + "value": _decimal_string(purchases_k * 1000.0), + }, + ), + _axiom_input_record( + AXIOM_UK_VAT_ALLOWABLE_PROPORTION_INPUT, + entity_id, + interval, + { + "kind": "decimal", + "value": _decimal_string(self.allowable_input_tax_proportion), + }, + ), + ] + ) + queries.append( + { + "entity_id": entity_id, + "period": period, + "outputs": [self.output_id], + } + ) + + request = { + "mode": "fast", + "dataset": {"inputs": request_inputs, "relations": []}, + "queries": queries, + } + process = subprocess.run( + [ + str(self.engine_binary), + "run-compiled", + "--artifact", + str(self.artifact_path), + ], + input=json.dumps(request), + text=True, + capture_output=True, + check=False, + ) + if process.returncode != 0: + stderr = process.stderr.strip() or "Axiom VAT rule evaluation failed" + raise RuntimeError(stderr) + + response = json.loads(process.stdout) + values = [] + for result in response["results"]: + output = result["outputs"][self.output_id] + values.append(float(output["value"]["value"]) / 1000.0) + return np.asarray(values, dtype=np.float32) + @dataclass(frozen=True) class UKFirmGenerationConfig: @@ -227,6 +359,7 @@ class UKFirmGenerationConfig: vat_liability_band_importance: float = 2.0 calibrate_vat_liability_sector: bool = False input_files: dict[str, str] = field(default_factory=lambda: dict(INPUT_FILES)) + vat_rule_evaluator: UKFirmVATRuleEvaluator | None = None def __post_init__(self) -> None: if self.data_vintage not in VINTAGES: @@ -246,6 +379,34 @@ def __post_init__(self) -> None: raise ValueError("n_iterations must be positive.") +def _vat_rule_period(data_vintage: str) -> dict[str, str]: + start_year = int(data_vintage.split("-", 1)[0]) + return { + "period_kind": "tax_year", + "start": f"{start_year}-04-06", + "end": f"{start_year + 1}-04-05", + } + + +def _axiom_input_record( + name: str, + entity_id: str, + interval: Mapping[str, str], + value: Mapping[str, object], +) -> dict[str, object]: + return { + "name": name, + "entity": "Firm", + "entity_id": entity_id, + "interval": dict(interval), + "value": dict(value), + } + + +def _decimal_string(value: float) -> str: + return format(float(value), ".12g") + + @dataclass(frozen=True) class UKFirmSourceData: """Processed ONS/HMRC source tables used by the firm generator.""" @@ -564,6 +725,12 @@ def generate_uk_firm_population( cfg.device, ) base_vat_registered = assign_vat_flags(base_turnover, hmrc_bands, cfg) + base_vat_liability = calculate_vat_liability_values( + cfg, + base_turnover, + base_input, + base_vat_registered, + ) employment_band_indices = torch.tensor( [_employment_band_index(value.item()) for value in base_employment], dtype=torch.long, @@ -578,6 +745,7 @@ def generate_uk_firm_population( employment_band_indices, base_vat_registered, data, + vat_liability_values=base_vat_liability, ) calibration = solve_firm_weights(cfg, target_matrix, target_values, layout) weights = calibration.weights @@ -600,10 +768,17 @@ def generate_uk_firm_population( data.ons_employment, cfg.device, ) + final_vat_liability = calculate_vat_liability_values( + cfg, + final_turnover, + final_input, + final_vat_registered, + ) firms = _assemble_firm_rows( final_sic, final_turnover, final_input, + final_vat_liability, final_employment, final_weights, final_vat_registered, @@ -772,6 +947,7 @@ def build_firm_target_matrix( employment_band_indices: Tensor, vat_registered: Tensor, data: UKFirmSourceData, + vat_liability_values: Tensor | None = None, ) -> tuple[Tensor, Tensor, UKFirmTargetLayout, tuple[str, ...]]: """Construct the firm calibration target matrix and target vector.""" @@ -818,7 +994,16 @@ def build_firm_target_matrix( matrix[row, employment_band_indices == band_idx] = 1.0 target_names.append(f"ons_firm_employment/{band}") - vat_liability = turnover_values - input_values + vat_liability = ( + vat_liability_values + if vat_liability_values is not None + else calculate_vat_liability_values( + config, + turnover_values, + input_values, + vat_registered, + ) + ) for offset, (_, vat_row) in enumerate(vat_sector_rows.iterrows()): row = layout.vat_sector_start + offset sic_code = int(vat_row["Trade_Sector"]) @@ -967,6 +1152,40 @@ def target_diagnostics( ) +def calculate_vat_liability_values( + config: UKFirmGenerationConfig, + turnover_values: Tensor, + input_values: Tensor, + vat_registered: Tensor, +) -> Tensor: + """Evaluate net VAT liability for candidate firms via configured VAT rules.""" + + evaluator = config.vat_rule_evaluator + if evaluator is None: + raise ValueError( + "UK firm VAT liability requires a VAT rule evaluator. " + "Pass AxiomVATRuleEvaluator with a compiled UK VAT RuleSpec artifact." + ) + turnover_np = turnover_values.detach().cpu().numpy() + input_np = input_values.detach().cpu().numpy() + registered_np = vat_registered.detach().cpu().numpy().astype(bool) + liability_np = np.asarray( + evaluator.net_liability_k( + turnover_k=turnover_np, + input_cost_k=input_np, + vat_registered=registered_np, + data_vintage=config.data_vintage, + ), + dtype=np.float32, + ) + if liability_np.shape != turnover_np.shape: + raise ValueError( + "VAT rule evaluator returned shape " + f"{liability_np.shape}, expected {turnover_np.shape}." + ) + return torch.tensor(liability_np, dtype=torch.float32, device=config.device) + + def validate_uk_firm_population( firms: pd.DataFrame, data: UKFirmSourceData, @@ -1223,6 +1442,7 @@ def _assemble_firm_rows( sic_codes: Tensor, turnover: Tensor, input_values: Tensor, + vat_liability: Tensor, employment: Tensor, weights: Tensor, vat_registered: Tensor, @@ -1231,6 +1451,7 @@ def _assemble_firm_rows( sic_np = sic_codes.detach().cpu().numpy().astype(int) turnover_np = turnover.detach().cpu().numpy() input_np = input_values.detach().cpu().numpy() + vat_liability_np = vat_liability.detach().cpu().numpy() firm_ids = np.arange(1, len(sic_np) + 1, dtype=np.int64) return pd.DataFrame( { @@ -1244,7 +1465,7 @@ def _assemble_firm_rows( "sic_code": [str(sic).zfill(5) for sic in sic_np], "annual_turnover_k": turnover_np, "annual_input_k": input_np, - "vat_liability_k": turnover_np - input_np, + "vat_liability_k": vat_liability_np, "employment": employment.detach().cpu().numpy().astype(int), "firm_weight": weights.detach().cpu().numpy(), "vat_registered": vat_registered.detach().cpu().numpy().astype(bool), diff --git a/packages/populace-build/tests/test_uk_firm_generation.py b/packages/populace-build/tests/test_uk_firm_generation.py index 6514ef8..cb7418f 100644 --- a/packages/populace-build/tests/test_uk_firm_generation.py +++ b/packages/populace-build/tests/test_uk_firm_generation.py @@ -2,15 +2,22 @@ from pathlib import Path +import numpy as np import pandas as pd import pytest +import torch +from populace.build.uk_runtime import ( + AxiomVATRuleEvaluator as PublicAxiomVATRuleEvaluator, +) from populace.build.uk_runtime.firm_generation import ( DEFAULT_UK_FIRM_TARGET_PROFILE, HMRC_BAND_COLUMNS, INPUT_FILES, UK_FIRM_TARGET_IDS, + AxiomVATRuleEvaluator, UKFirmGenerationConfig, + calculate_vat_liability_values, employment_band_name, generate_uk_firm_population, hmrc_band_name, @@ -22,6 +29,10 @@ ) +def test_axiom_vat_rule_evaluator_is_public_uk_runtime_api() -> None: + assert PublicAxiomVATRuleEvaluator is AxiomVATRuleEvaluator + + def test_uk_firm_source_data_reads_processed_directory(tmp_path: Path) -> None: frames = _source_frames() for key, filename in INPUT_FILES.items(): @@ -95,7 +106,7 @@ def test_generate_uk_firm_population_accepts_ledger_source_data() -> None: result = generate_uk_firm_population( data, - UKFirmGenerationConfig(data_vintage="2024-25", n_iterations=4, seed=7), + _firm_config(data_vintage="2024-25", n_iterations=4, seed=7), ) diagnostics = result.target_diagnostics.set_index("target_name") @@ -123,6 +134,7 @@ def test_uk_firm_source_data_from_ledger_facts_requires_complete_profile() -> No def test_generate_uk_firm_population_returns_experimental_firm_rows() -> None: data = _source_data() config = UKFirmGenerationConfig( + vat_rule_evaluator=_TestVATRuleEvaluator(), data_vintage="2024-25", n_iterations=8, seed=7, @@ -183,11 +195,11 @@ def test_generate_uk_firm_population_uses_configured_vintage_targets() -> None: result_2023 = generate_uk_firm_population( data, - UKFirmGenerationConfig(data_vintage="2023-24", n_iterations=2, seed=7), + _firm_config(data_vintage="2023-24", n_iterations=2, seed=7), ) result_2024 = generate_uk_firm_population( data, - UKFirmGenerationConfig(data_vintage="2024-25", n_iterations=2, seed=7), + _firm_config(data_vintage="2024-25", n_iterations=2, seed=7), ) targets_2023 = result_2023.target_diagnostics.set_index("target_name")["target"] @@ -206,13 +218,14 @@ def test_generate_uk_firm_population_requires_requested_hmrc_vintage() -> None: with pytest.raises(ValueError, match="does not include vintage '2023-24'"): generate_uk_firm_population( data, - UKFirmGenerationConfig(data_vintage="2023-24", n_iterations=1), + _firm_config(data_vintage="2023-24", n_iterations=1), ) def test_generate_uk_firm_population_is_seed_reproducible() -> None: data = _source_data() config = UKFirmGenerationConfig( + vat_rule_evaluator=_TestVATRuleEvaluator(), n_iterations=3, seed=11, ) @@ -226,7 +239,7 @@ def test_generate_uk_firm_population_is_seed_reproducible() -> None: def test_write_uk_firm_population_writes_csv(tmp_path: Path) -> None: result = generate_uk_firm_population( _source_data(), - UKFirmGenerationConfig(n_iterations=2, seed=1), + _firm_config(n_iterations=2, seed=1), ) path = write_uk_firm_population(result, tmp_path / "firms.csv") @@ -262,6 +275,45 @@ def test_employment_band_name(employment: int, expected: str) -> None: assert employment_band_name(employment) == expected +def test_generate_uk_firm_population_requires_vat_rule_evaluator() -> None: + with pytest.raises(ValueError, match="requires a VAT rule evaluator"): + generate_uk_firm_population( + _source_data(), + UKFirmGenerationConfig(n_iterations=1, seed=1), + ) + + +def test_vat_liability_uses_input_tax_credit_not_full_input_cost() -> None: + result = calculate_vat_liability_values( + _firm_config(), + torch.tensor([100.0, 100.0], dtype=torch.float32), + torch.tensor([40.0, 120.0], dtype=torch.float32), + torch.tensor([True, True], dtype=torch.bool), + ) + + assert result.tolist() == pytest.approx([12.0, -4.0]) + + +class _TestVATRuleEvaluator: + def net_liability_k( + self, + *, + turnover_k: np.ndarray, + input_cost_k: np.ndarray, + vat_registered: np.ndarray, + data_vintage: str, + ) -> np.ndarray: + del data_vintage + return np.where(vat_registered, 0.2 * (turnover_k - input_cost_k), 0.0) + + +def _firm_config(**overrides) -> UKFirmGenerationConfig: + return UKFirmGenerationConfig( + vat_rule_evaluator=_TestVATRuleEvaluator(), + **overrides, + ) + + def _source_data(): return uk_firm_source_data_from_frames(**_source_frames())