Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/protect-ai-routes.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Protect AI and tracing routes with the shared API key.
4 changes: 3 additions & 1 deletion policyengine_api/routes/ai_prompt_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, Response, request
from copy import deepcopy
from policyengine_api.services.ai_prompt_service import AIPromptService
from policyengine_api.security import require_simulation_analysis_api_key
from policyengine_api.utils.payload_validators import validate_country
from policyengine_api.utils.payload_validators.ai import (
validate_sim_analysis_payload,
Expand All @@ -12,11 +13,12 @@
ai_prompt_service = AIPromptService()


@validate_country
@ai_prompt_bp.route(
"/<country_id>/ai-prompts/<string:prompt_name>",
methods=["POST"],
)
@validate_country
@require_simulation_analysis_api_key
def generate_ai_prompt(country_id, prompt_name: str) -> Response:
"""
Get an AI prompt with a given name, filled with the given data.
Expand Down
7 changes: 5 additions & 2 deletions policyengine_api/routes/simulation_analysis_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import json

from flask import Blueprint, request, Response, stream_with_context
from werkzeug.exceptions import BadRequest
from policyengine_api.utils.payload_validators import validate_country

from policyengine_api.security import require_simulation_analysis_api_key
from policyengine_api.services.simulation_analysis_service import (
SimulationAnalysisService,
)
Expand All @@ -10,14 +13,14 @@
from policyengine_api.utils.payload_validators.ai import (
validate_sim_analysis_payload,
)
import json

simulation_analysis_bp = Blueprint("simulation_analysis", __name__)
simulation_analysis_service = SimulationAnalysisService()


@simulation_analysis_bp.route("/<country_id>/simulation-analysis", methods=["POST"])
@validate_country
@require_simulation_analysis_api_key
def execute_simulation_analysis(country_id):
print("Got POST request for simulation analysis")

Expand Down
10 changes: 5 additions & 5 deletions policyengine_api/routes/tracer_analysis_routes.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import json

from flask import Blueprint, request, Response, stream_with_context
from werkzeug.exceptions import BadRequest

from policyengine_api.security import require_simulation_analysis_api_key
from policyengine_api.utils.payload_validators import (
validate_country,
validate_tracer_analysis_payload,
)
from policyengine_api.services.tracer_analysis_service import (
TracerAnalysisService,
)
import json
from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS
import re

tracer_analysis_bp = Blueprint("tracer_analysis", __name__)
tracer_analysis_service = TracerAnalysisService()


@tracer_analysis_bp.route("/<country_id>/tracer-analysis", methods=["POST"])
@validate_country
@require_simulation_analysis_api_key
def execute_tracer_analysis(country_id):

payload = request.json
Expand All @@ -28,8 +30,6 @@ def execute_tracer_analysis(country_id):
household_id = payload.get("household_id")
policy_id = payload.get("policy_id")
variable = payload.get("variable")
api_version = COUNTRY_PACKAGE_VERSIONS[country_id]

if not isinstance(variable, str):
raise BadRequest("variable must be a string")

Expand Down
24 changes: 24 additions & 0 deletions policyengine_api/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Security helpers for sensitive API routes."""

import os
from functools import wraps

from flask import request
from werkzeug.exceptions import Unauthorized


def require_simulation_analysis_api_key(view):
"""Require a shared API key for simulation analysis requests."""

@wraps(view)
def wrapped(*args, **kwargs):
expected_key = os.getenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "").strip()
if not expected_key:
raise Unauthorized("Simulation analysis API key is not configured")

if request.headers.get("X-PolicyEngine-Api-Key") == expected_key:
return view(*args, **kwargs)

raise Unauthorized("API key required for simulation analysis")

return wrapped
11 changes: 10 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from contextlib import contextmanager
from subprocess import Popen, TimeoutExpired
import sys
import os
import redis
import pytest
from policyengine_api.api import app
Expand Down Expand Up @@ -33,9 +33,18 @@ def running(process_arguments, seconds_to_wait_after_launch=0):
def client():
"""run the app for the tests to run against"""
app.config["TESTING"] = True
previous_api_key = os.environ.get("POLICYENGINE_API_AI_ANALYSIS_API_KEY")
os.environ["POLICYENGINE_API_AI_ANALYSIS_API_KEY"] = "test-ai-analysis-key"
with running(["redis-server"], 3):
redis_client = redis.Redis()
redis_client.ping()
with running([sys.executable, "policyengine_api/worker.py"], 3):
with app.test_client() as test_client:
test_client.environ_base["HTTP_X_POLICYENGINE_API_KEY"] = (
"test-ai-analysis-key"
)
yield test_client
if previous_api_key is None:
os.environ.pop("POLICYENGINE_API_AI_ANALYSIS_API_KEY", None)
else:
os.environ["POLICYENGINE_API_AI_ANALYSIS_API_KEY"] = previous_api_key
136 changes: 136 additions & 0 deletions tests/unit/routes/test_ai_route_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import os
from unittest.mock import patch

import pytest

os.environ.setdefault("FLASK_DEBUG", "1")

from policyengine_api.api import app
from tests.fixtures.simulation_analysis_prompt_fixtures import valid_input_us


@pytest.fixture
def client():
app.config["TESTING"] = True
with app.test_client() as test_client:
yield test_client


def test_ai_prompt_rejects_requests_without_api_key(client, monkeypatch):
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")

response = client.post(
"/us/ai-prompts/simulation_analysis",
json=valid_input_us,
environ_base={"REMOTE_ADDR": "203.0.113.10"},
)

assert response.status_code == 401
assert "API key required" in response.json["message"]


def test_ai_prompt_rejects_loopback_requests_without_api_key(client, monkeypatch):
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")

response = client.post(
"/us/ai-prompts/simulation_analysis",
json=valid_input_us,
environ_base={"REMOTE_ADDR": "127.0.0.1"},
)

assert response.status_code == 401
assert "API key required" in response.json["message"]


def test_ai_prompt_allows_requests_with_api_key(client, monkeypatch):
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")

with patch(
"policyengine_api.routes.ai_prompt_routes.ai_prompt_service.get_prompt",
return_value="Prompt text",
) as mock_get_prompt:
response = client.post(
"/us/ai-prompts/simulation_analysis",
json=valid_input_us,
headers={"X-PolicyEngine-Api-Key": "secret-key"},
environ_base={"REMOTE_ADDR": "203.0.113.10"},
)

assert response.status_code == 200
assert response.json["result"] == "Prompt text"
mock_get_prompt.assert_called_once()


def test_tracer_analysis_rejects_requests_without_api_key(client, monkeypatch):
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")

response = client.post(
"/us/tracer-analysis",
json={
"household_id": 1500,
"policy_id": 2,
"variable": "disposable_income",
},
environ_base={"REMOTE_ADDR": "203.0.113.10"},
)

assert response.status_code == 401
assert "API key required" in response.json["message"]


def test_requests_fail_closed_when_api_key_is_not_configured(client, monkeypatch):
monkeypatch.delenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", raising=False)

response = client.post(
"/us/tracer-analysis",
json={
"household_id": 1500,
"policy_id": 2,
"variable": "disposable_income",
},
environ_base={"REMOTE_ADDR": "203.0.113.10"},
)

assert response.status_code == 401
assert "not configured" in response.json["message"]


def test_env_flag_does_not_reopen_tracer_analysis(client, monkeypatch):
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")
monkeypatch.setenv("POLICYENGINE_API_ALLOW_UNAUTHENTICATED_AI_ANALYSIS", "true")

response = client.post(
"/us/tracer-analysis",
json={
"household_id": 1500,
"policy_id": 2,
"variable": "disposable_income",
},
environ_base={"REMOTE_ADDR": "203.0.113.10"},
)

assert response.status_code == 401
assert "API key required" in response.json["message"]


def test_tracer_analysis_allows_requests_with_api_key(client, monkeypatch):
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")

with patch(
"policyengine_api.routes.tracer_analysis_routes.tracer_analysis_service.execute_analysis",
return_value=("Existing analysis", "static"),
) as mock_execute_analysis:
response = client.post(
"/us/tracer-analysis",
json={
"household_id": 1500,
"policy_id": 2,
"variable": "disposable_income",
},
headers={"X-PolicyEngine-Api-Key": "secret-key"},
environ_base={"REMOTE_ADDR": "203.0.113.10"},
)

assert response.status_code == 200
assert response.json["result"] == "Existing analysis"
mock_execute_analysis.assert_called_once_with("us", 1500, 2, "disposable_income")
48 changes: 48 additions & 0 deletions tests/unit/routes/test_simulation_analysis_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
from unittest.mock import patch

import pytest

os.environ.setdefault("FLASK_DEBUG", "1")

from policyengine_api.api import app
from tests.to_refactor.fixtures.simulation_analysis_fixtures import test_json


@pytest.fixture
def client():
app.config["TESTING"] = True
with app.test_client() as test_client:
yield test_client


def test_simulation_analysis_rejects_requests_without_api_key(client, monkeypatch):
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")

response = client.post(
"/us/simulation-analysis",
json=test_json,
environ_base={"REMOTE_ADDR": "203.0.113.10"},
)

assert response.status_code == 401
assert "API key required" in response.json["message"]


def test_simulation_analysis_allows_requests_with_api_key(client, monkeypatch):
monkeypatch.setenv("POLICYENGINE_API_AI_ANALYSIS_API_KEY", "secret-key")

with patch(
"policyengine_api.routes.simulation_analysis_routes.simulation_analysis_service.execute_analysis",
return_value=("Existing analysis", "static"),
) as mock_execute_analysis:
response = client.post(
"/us/simulation-analysis",
json=test_json,
headers={"X-PolicyEngine-Api-Key": "secret-key"},
environ_base={"REMOTE_ADDR": "203.0.113.10"},
)

assert response.status_code == 200
assert response.json["result"] == "Existing analysis"
mock_execute_analysis.assert_called_once()
Loading