From 087335dfdcd4519a483c236af980b2b8ac370964 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Tue, 31 Mar 2026 22:03:26 +0200 Subject: [PATCH 01/18] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20replace?= =?UTF-8?q?=20diskcache=20with=20aiosqlite,=20make=20all=20bench=20routes?= =?UTF-8?q?=20async?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/bench/app/api.py | 145 +++++++++++------- scripts/bench/app/main.py | 4 +- .../apps/bench-apx/requirements.txt | 2 +- .../apps/bench-granian/requirements.txt | 2 +- .../apps/bench-uvicorn/requirements.txt | 2 +- 5 files changed, 98 insertions(+), 57 deletions(-) diff --git a/scripts/bench/app/api.py b/scripts/bench/app/api.py index bda8ff50..82852aa2 100644 --- a/scripts/bench/app/api.py +++ b/scripts/bench/app/api.py @@ -1,17 +1,16 @@ from __future__ import annotations import asyncio +import json +from contextlib import asynccontextmanager -import diskcache -from fastapi import APIRouter, Depends, HTTPException, Request +import aiosqlite +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request from fastapi.responses import Response, StreamingResponse from .models import Item, ItemCreate, ItemUpdate -router = APIRouter() - -_CACHE_DIR = "/tmp/bench_items_cache" -_cache = diskcache.Cache(_CACHE_DIR) +DB_PATH = "/tmp/bench_items.db" _DEFAULT_ITEMS = [ Item( @@ -24,25 +23,55 @@ for i in range(1, 11) ] +_CREATE_TABLE = """\ +CREATE TABLE IF NOT EXISTS items ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + description TEXT, + price REAL NOT NULL, + tags TEXT NOT NULL DEFAULT '[]' +)""" + +_INSERT_ITEM = "INSERT INTO items (id, name, description, price, tags) VALUES (?, ?, ?, ?, ?)" +_INSERT_ITEM_AUTO = "INSERT INTO items (name, description, price, tags) VALUES (?, ?, ?, ?)" + + +def _row_to_item(row: aiosqlite.Row) -> Item: + return Item(id=row[0], name=row[1], description=row[2], price=row[3], tags=json.loads(row[4])) -def _populate_defaults(): - _cache.clear() + +async def _seed_defaults(db: aiosqlite.Connection) -> None: for item in _DEFAULT_ITEMS: - _cache[f"item:{item.id}"] = item.model_dump() - _cache["_counter"] = 10 + await db.execute( + _INSERT_ITEM, + (item.id, item.name, item.description, item.price, json.dumps(item.tags)), + ) + await db.commit() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + db = await aiosqlite.connect(DB_PATH) + await db.execute("PRAGMA journal_mode=WAL") + await db.execute(_CREATE_TABLE) + cursor = await db.execute("SELECT count(*) FROM items") + row = await cursor.fetchone() + if row is not None and row[0] == 0: + await _seed_defaults(db) + app.state.db = db + yield + await db.close() -# Auto-populate on first boot -if "_counter" not in _cache: - _populate_defaults() +router = APIRouter() -def _next_id() -> int: - return _cache.incr("_counter") +async def _get_db(request: Request) -> aiosqlite.Connection: + return request.app.state.db @router.get("/version") -def version() -> dict[str, str]: +async def version() -> dict[str, str]: """Return the APX package version (includes build timestamp).""" try: from importlib.metadata import version as pkg_version @@ -53,70 +82,84 @@ def version() -> dict[str, str]: @router.get("/echo") -def echo() -> dict[str, bool]: +async def echo() -> dict[str, bool]: """Minimal handler — isolates framework overhead from app logic.""" return {"echo": True} @router.get("/request-id") -def request_id(request: Request) -> dict[str, str | None]: +async def request_id(request: Request) -> dict[str, str | None]: """Return the X-Request-Id seen by the ASGI app.""" return {"request_id": request.headers.get("x-request-id")} @router.get("/health") -def health() -> dict[str, str]: +async def health() -> dict[str, str]: return {"status": "ok"} @router.get("/items", response_model=list[Item]) -def list_items() -> list[Item]: - items = [] - for key in _cache: - if isinstance(key, str) and key.startswith("item:"): - items.append(Item(**_cache[key])) - items.sort(key=lambda x: x.id) - return items +async def list_items(db: aiosqlite.Connection = Depends(_get_db)) -> list[Item]: + cursor = await db.execute("SELECT id, name, description, price, tags FROM items ORDER BY id") + rows = await cursor.fetchall() + return [_row_to_item(row) for row in rows] @router.get("/items/{item_id}", response_model=Item) -def get_item(item_id: int) -> Item: - data = _cache.get(f"item:{item_id}") - if data is None: +async def get_item(item_id: int, db: aiosqlite.Connection = Depends(_get_db)) -> Item: + cursor = await db.execute( + "SELECT id, name, description, price, tags FROM items WHERE id = ?", (item_id,) + ) + row = await cursor.fetchone() + if row is None: raise HTTPException(status_code=404, detail="Item not found") - return Item(**data) + return _row_to_item(row) @router.post("/items", response_model=Item, status_code=201) -def create_item(body: ItemCreate) -> Item: - item = Item(id=_next_id(), **body.model_dump()) - _cache[f"item:{item.id}"] = item.model_dump() - return item +async def create_item(body: ItemCreate, db: aiosqlite.Connection = Depends(_get_db)) -> Item: + cursor = await db.execute( + _INSERT_ITEM_AUTO, + (body.name, body.description, body.price, json.dumps(body.tags)), + ) + await db.commit() + assert cursor.lastrowid is not None + return Item(id=cursor.lastrowid, **body.model_dump()) @router.patch("/items/{item_id}", response_model=Item) -def update_item(item_id: int, body: ItemUpdate) -> Item: - data = _cache.get(f"item:{item_id}") - if data is None: +async def update_item( + item_id: int, body: ItemUpdate, db: aiosqlite.Connection = Depends(_get_db) +) -> Item: + cursor = await db.execute( + "SELECT id, name, description, price, tags FROM items WHERE id = ?", (item_id,) + ) + row = await cursor.fetchone() + if row is None: raise HTTPException(status_code=404, detail="Item not found") - existing = Item(**data) + existing = _row_to_item(row) updated = existing.model_copy(update=body.model_dump(exclude_unset=True)) - _cache[f"item:{item_id}"] = updated.model_dump() + await db.execute( + "UPDATE items SET name = ?, description = ?, price = ?, tags = ? WHERE id = ?", + (updated.name, updated.description, updated.price, json.dumps(updated.tags), item_id), + ) + await db.commit() return updated @router.delete("/items/{item_id}", status_code=204) -def delete_item(item_id: int): - _cache.pop(f"item:{item_id}", None) - from fastapi.responses import Response - +async def delete_item(item_id: int, db: aiosqlite.Connection = Depends(_get_db)): + await db.execute("DELETE FROM items WHERE id = ?", (item_id,)) + await db.commit() return Response(status_code=204) @router.post("/items/reset") -def items_reset(): +async def items_reset(db: aiosqlite.Connection = Depends(_get_db)): """Clear all items and repopulate with defaults.""" - _populate_defaults() + await db.execute("DELETE FROM items") + await db.execute("DELETE FROM sqlite_sequence WHERE name = 'items'") + await _seed_defaults(db) return {"status": "reset", "items": 10} @@ -147,7 +190,7 @@ async def yield_once(): @router.get("/cpu/{n}") -def cpu_work(n: int): +async def cpu_work(n: int): """GIL hold under concurrency — sum of squares.""" n = min(n, 1_000_000) total = sum(i * i for i in range(n)) @@ -155,7 +198,7 @@ def cpu_work(n: int): @router.get("/large/{kb}") -def large_response(kb: int): +async def large_response(kb: int): """Large body — send() overhead.""" kb = min(kb, 1024) data = "x" * (kb * 1024) @@ -372,7 +415,7 @@ async def telemetry_cross_signal(): @router.get("/profile/dump") -def profile_dump(): +async def profile_dump(): """Return profiling JSONL over HTTP (for remote extraction).""" from .profiling import PROFILE_PATH, flush @@ -383,7 +426,7 @@ def profile_dump(): @router.delete("/profile/reset") -def profile_reset(): +async def profile_reset(): """Clear profiling data for a fresh run.""" from . import profiling @@ -392,7 +435,7 @@ def profile_reset(): @router.get("/_bench/scheduler-stats") -def scheduler_stats(): +async def scheduler_stats(): """Return scheduler counters as JSON.""" try: from apx._core import scheduler_stats_json @@ -401,6 +444,4 @@ def scheduler_stats(): data = scheduler_stats_json() if data is None: raise HTTPException(status_code=404, detail="no scheduler stats") - import json - return json.loads(data) diff --git a/scripts/bench/app/main.py b/scripts/bench/app/main.py index 370b613d..f70413b0 100644 --- a/scripts/bench/app/main.py +++ b/scripts/bench/app/main.py @@ -3,10 +3,10 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles -from .api import router +from .api import lifespan, router from .profiling import install_profiling -app = FastAPI(title="APX Bench App") +app = FastAPI(title="APX Bench App", lifespan=lifespan) app.include_router(router, prefix="/api") # Install ASGI profiling middleware when APX_BENCH_PROFILE=1. diff --git a/scripts/bench/databricks/apps/bench-apx/requirements.txt b/scripts/bench/databricks/apps/bench-apx/requirements.txt index ae6b3b55..7d4f9614 100644 --- a/scripts/bench/databricks/apps/bench-apx/requirements.txt +++ b/scripts/bench/databricks/apps/bench-apx/requirements.txt @@ -1,3 +1,3 @@ -diskcache +aiosqlite fastapi uvloop diff --git a/scripts/bench/databricks/apps/bench-granian/requirements.txt b/scripts/bench/databricks/apps/bench-granian/requirements.txt index 22bb5569..0d41e3bc 100644 --- a/scripts/bench/databricks/apps/bench-granian/requirements.txt +++ b/scripts/bench/databricks/apps/bench-granian/requirements.txt @@ -1,3 +1,3 @@ -diskcache +aiosqlite fastapi granian[uvloop] diff --git a/scripts/bench/databricks/apps/bench-uvicorn/requirements.txt b/scripts/bench/databricks/apps/bench-uvicorn/requirements.txt index 133ac53b..0d5db4be 100644 --- a/scripts/bench/databricks/apps/bench-uvicorn/requirements.txt +++ b/scripts/bench/databricks/apps/bench-uvicorn/requirements.txt @@ -1,3 +1,3 @@ -diskcache +aiosqlite fastapi uvicorn[standard] From b80d0619cb26e1a259d122cd819cae90d22680b5 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Tue, 31 Mar 2026 23:34:45 +0200 Subject: [PATCH 02/18] =?UTF-8?q?=E2=9C=A8=20feat:=20implement=20ASGI=20li?= =?UTF-8?q?fespan=20protocol=20for=20startup/shutdown=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/framework/src/asgi/app.rs | 25 +- crates/framework/src/asgi/lifespan.rs | 729 +++++++++++++++++++++ crates/framework/src/asgi/mod.rs | 1 + crates/framework/src/pyapi.rs | 4 + crates/framework/src/supervision/worker.rs | 113 +++- tests/integration/test_apx.py | 31 + 6 files changed, 882 insertions(+), 21 deletions(-) create mode 100644 crates/framework/src/asgi/lifespan.rs diff --git a/crates/framework/src/asgi/app.rs b/crates/framework/src/asgi/app.rs index e0f63b76..c68267ba 100644 --- a/crates/framework/src/asgi/app.rs +++ b/crates/framework/src/asgi/app.rs @@ -133,6 +133,7 @@ const DEFAULT_BODY_LIMIT: usize = 10 * 1024 * 1024; /// Implementations decide how the app is located (runtime import, manifest, /// etc.) and which dispatch strategy to use. The returned `Arc` /// is handed to `ApxService` and shared across all connections. +#[expect(dead_code, reason = "extension seam for future app loading strategies")] pub trait AppSource: Send + Sync + std::fmt::Debug { /// Load the app and construct its dispatch pipeline. /// @@ -212,15 +213,18 @@ impl ModuleImport { } } -impl AppSource for ModuleImport { - fn build( +impl ModuleImport { + /// Load the app and build dispatch, returning both the dispatch and the + /// raw ASGI callable reference (needed for the lifespan protocol). + pub fn build_with_app( &self, py: Python<'_>, ctx: Arc, event_loop_py: &Py, server_addr: SocketAddr, - ) -> Result, AppLoadError> { + ) -> Result<(Arc, Py), AppLoadError> { let app = self.load_callable(py)?; + let asgi_app = app.inner().clone_ref(py); let interns = Arc::new(ScopeInterns::new(py, server_addr)); let queue = RequestQueue::new( @@ -264,7 +268,20 @@ impl AppSource for ModuleImport { interns, ctx, ); - Ok(Arc::new(dispatch)) + Ok((Arc::new(dispatch), asgi_app)) + } +} + +impl AppSource for ModuleImport { + fn build( + &self, + py: Python<'_>, + ctx: Arc, + event_loop_py: &Py, + server_addr: SocketAddr, + ) -> Result, AppLoadError> { + self.build_with_app(py, ctx, event_loop_py, server_addr) + .map(|(dispatch, _)| dispatch) } } diff --git a/crates/framework/src/asgi/lifespan.rs b/crates/framework/src/asgi/lifespan.rs new file mode 100644 index 00000000..c0c5b57b --- /dev/null +++ b/crates/framework/src/asgi/lifespan.rs @@ -0,0 +1,729 @@ +//! ASGI lifespan protocol — startup/shutdown hooks for the application. +//! +//! Implements the ASGI lifespan spec: the server calls `app(scope, receive, send)` +//! with `scope["type"] == "lifespan"`, then exchanges startup/shutdown events via +//! the receive and send callables. +//! +//! The protocol runs on the asyncio thread as a long-lived task. Three tokio +//! oneshot channels bridge it to the tokio thread: +//! - **startup**: `LifespanSend` signals startup result +//! - **shutdown_trigger**: tokio thread tells `LifespanReceive` to deliver shutdown +//! - **shutdown**: `LifespanSend` signals shutdown result + +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyString}; +use std::sync::Mutex; +use std::time::Duration; +use tokio::sync::oneshot; + +use super::scope::{ResolvedAwaitable, ResolvedAwaitableWithValue}; +use crate::io::EventLoop; + +// ── Protocol types (pure, no I/O) ──────────────────────────────────────── + +/// Outcome of a lifespan startup or shutdown phase. +#[derive(Debug)] +pub enum LifespanResult { + /// App sent `lifespan.startup.complete` or `lifespan.shutdown.complete`. + Complete, + /// App sent `lifespan.startup.failed` or `lifespan.shutdown.failed`. + Failed(String), + /// App raised during `app(scope, receive, send)` — does not support lifespan. + Unsupported, +} + +/// Internal state machine for [`LifespanReceive`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ReceiveState { + /// Next call returns `{"type": "lifespan.startup"}`. + Startup, + /// Next call blocks until shutdown trigger, then returns `{"type": "lifespan.shutdown"}`. + WaitingShutdown, + /// No more events — pend forever. + Done, +} + +// ── LifespanReceive ────────────────────────────────────────────────────── + +/// ASGI `receive` callable for the lifespan protocol. +/// +/// First `await receive()` returns `{"type": "lifespan.startup"}` immediately. +/// Second `await receive()` blocks until the server triggers shutdown, then +/// returns `{"type": "lifespan.shutdown"}`. Subsequent calls pend forever. +#[pyclass(module = "apx._core")] +pub struct LifespanReceive { + state: Mutex, + shutdown_trigger_rx: Mutex>>, +} + +crate::opaque_debug!(LifespanReceive); + +impl LifespanReceive { + /// Create a new lifespan receive callable. + pub(crate) fn new(shutdown_trigger_rx: oneshot::Receiver<()>) -> Self { + Self { + state: Mutex::new(ReceiveState::Startup), + shutdown_trigger_rx: Mutex::new(Some(shutdown_trigger_rx)), + } + } +} + +#[pymethods] +impl LifespanReceive { + fn __call__<'py>(&self, py: Python<'py>) -> PyResult> { + let mut state = self + .state + .lock() + .map_err(|_| pyo3::exceptions::PyRuntimeError::new_err("receive mutex poisoned"))?; + + match *state { + ReceiveState::Startup => { + *state = ReceiveState::WaitingShutdown; + drop(state); + let event = build_startup_event(py)?; + Py::new(py, ResolvedAwaitableWithValue::new(event)) + .map(|obj| obj.into_bound(py).into_any()) + } + ReceiveState::WaitingShutdown => { + let rx = self + .shutdown_trigger_rx + .lock() + .map_err(|_| { + pyo3::exceptions::PyRuntimeError::new_err("shutdown trigger mutex poisoned") + })? + .take() + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("shutdown already triggered") + })?; + *state = ReceiveState::Done; + drop(state); + + let handle = crate::io::with_tokio_handle(|h| h.clone()).ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err( + "no tokio runtime for lifespan shutdown wait", + ) + })?; + let _guard = handle.enter(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let _ = rx.await; + Python::attach(|py| { + let event = build_shutdown_event(py)?; + Ok(event) + }) + }) + } + ReceiveState::Done => { + drop(state); + let handle = crate::io::with_tokio_handle(|h| h.clone()).ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err( + "no tokio runtime for lifespan pending", + ) + })?; + let _guard = handle.enter(); + pyo3_async_runtimes::tokio::future_into_py( + py, + std::future::pending::>>(), + ) + } + } + } +} + +// ── Send event classification (sans-I/O) ───────────────────────────────── + +/// ASGI lifespan send event type: startup completed successfully. +const STARTUP_COMPLETE: &str = "lifespan.startup.complete"; + +/// ASGI lifespan send event type: startup failed. +const STARTUP_FAILED: &str = "lifespan.startup.failed"; + +/// ASGI lifespan send event type: shutdown completed successfully. +const SHUTDOWN_COMPLETE: &str = "lifespan.shutdown.complete"; + +/// ASGI lifespan send event type: shutdown failed. +const SHUTDOWN_FAILED: &str = "lifespan.shutdown.failed"; + +/// Classified lifespan send event — pure protocol, no I/O. +enum SendEvent { + StartupComplete, + StartupFailed(String), + ShutdownComplete, + ShutdownFailed(String), +} + +/// Parse a lifespan send event dict into a classified event. +fn classify_send_event(event: &Bound<'_, PyDict>) -> PyResult { + let py = event.py(); + let event_type: String = event + .get_item(pyo3::intern!(py, "type"))? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("type"))? + .extract()?; + + match event_type.as_str() { + STARTUP_COMPLETE => Ok(SendEvent::StartupComplete), + STARTUP_FAILED => Ok(SendEvent::StartupFailed(extract_message(event)?)), + SHUTDOWN_COMPLETE => Ok(SendEvent::ShutdownComplete), + SHUTDOWN_FAILED => Ok(SendEvent::ShutdownFailed(extract_message(event)?)), + _ => Err(pyo3::exceptions::PyValueError::new_err(format!( + "unsupported lifespan event type: {event_type}" + ))), + } +} + +/// Extract the optional `"message"` field from a lifespan event dict. +fn extract_message(event: &Bound<'_, PyDict>) -> PyResult { + let py = event.py(); + event + .get_item(pyo3::intern!(py, "message"))? + .map(|v| v.extract::()) + .transpose() + .map(|opt| opt.unwrap_or_default()) +} + +/// Send a result through a guarded oneshot channel. +fn signal(tx: &Mutex>>, result: LifespanResult) { + if let Ok(mut guard) = tx.lock() + && let Some(tx) = guard.take() + { + let _ = tx.send(result); + } +} + +// ── LifespanSend ───────────────────────────────────────────────────────── + +/// ASGI `send` callable for the lifespan protocol. +/// +/// Parses `lifespan.startup.complete`, `lifespan.startup.failed`, +/// `lifespan.shutdown.complete`, and `lifespan.shutdown.failed` events, +/// signaling results through oneshot channels. +#[pyclass(module = "apx._core")] +pub struct LifespanSend { + startup_tx: Mutex>>, + shutdown_tx: Mutex>>, + resolved: Py, +} + +crate::opaque_debug!(LifespanSend); + +impl LifespanSend { + /// Create a new lifespan send callable. + pub(crate) fn new( + py: Python<'_>, + startup_tx: oneshot::Sender, + shutdown_tx: oneshot::Sender, + ) -> PyResult { + Ok(Self { + startup_tx: Mutex::new(Some(startup_tx)), + shutdown_tx: Mutex::new(Some(shutdown_tx)), + resolved: Py::new(py, ResolvedAwaitable)?, + }) + } +} + +#[pymethods] +impl LifespanSend { + /// Forward an unhandled app exception — signals lifespan unsupported or shutdown failed. + fn send_error(&self, traceback: String) { + if let Ok(mut guard) = self.startup_tx.lock() + && let Some(tx) = guard.take() + { + let _ = tx.send(LifespanResult::Unsupported); + return; + } + signal(&self.shutdown_tx, LifespanResult::Failed(traceback)); + } + + fn __call__<'py>( + &self, + py: Python<'py>, + event: Bound<'py, PyDict>, + ) -> PyResult> { + match classify_send_event(&event)? { + SendEvent::StartupComplete => { + signal(&self.startup_tx, LifespanResult::Complete); + } + SendEvent::StartupFailed(msg) => { + signal(&self.startup_tx, LifespanResult::Failed(msg)); + } + SendEvent::ShutdownComplete => { + signal(&self.shutdown_tx, LifespanResult::Complete); + } + SendEvent::ShutdownFailed(msg) => { + signal(&self.shutdown_tx, LifespanResult::Failed(msg)); + } + } + Ok(self.resolved.clone_ref(py).into_bound(py).into_any()) + } +} + +// ── Scope builder ──────────────────────────────────────────────────────── + +/// ASGI protocol version string. +const ASGI_VERSION: &str = "3.0"; + +/// ASGI spec version string. +const ASGI_SPEC_VERSION: &str = "2.4"; + +/// Build the ASGI lifespan scope dict. +fn build_lifespan_scope(py: Python<'_>) -> PyResult> { + let scope = PyDict::new(py); + scope.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "lifespan"))?; + + let asgi = PyDict::new(py); + asgi.set_item( + pyo3::intern!(py, "version"), + PyString::intern(py, ASGI_VERSION), + )?; + asgi.set_item( + pyo3::intern!(py, "spec_version"), + PyString::intern(py, ASGI_SPEC_VERSION), + )?; + scope.set_item(pyo3::intern!(py, "asgi"), asgi)?; + + scope.set_item(pyo3::intern!(py, "state"), PyDict::new(py))?; + Ok(scope.unbind()) +} + +/// Build `{"type": "lifespan.startup"}` event for receive. +fn build_startup_event(py: Python<'_>) -> PyResult> { + let event = PyDict::new(py); + event.set_item( + pyo3::intern!(py, "type"), + pyo3::intern!(py, "lifespan.startup"), + )?; + Ok(event.into_any().unbind()) +} + +/// Build `{"type": "lifespan.shutdown"}` event for receive. +fn build_shutdown_event(py: Python<'_>) -> PyResult> { + let event = PyDict::new(py); + event.set_item( + pyo3::intern!(py, "type"), + pyo3::intern!(py, "lifespan.shutdown"), + )?; + Ok(event.into_any().unbind()) +} + +// ── Handles ────────────────────────────────────────────────────────────── + +/// Pre-startup handle — awaiting startup result. +/// +/// Returned by [`launch_lifespan`]. Call [`wait_startup`](Self::wait_startup) +/// to consume the startup channel and obtain a [`LifespanHandle`] for shutdown. +pub struct LifespanPending { + startup_rx: oneshot::Receiver, + shutdown_trigger_tx: oneshot::Sender<()>, + shutdown_rx: oneshot::Receiver, +} + +crate::opaque_debug!(LifespanPending); + +/// Lifespan startup timeout — if the app does not respond within this +/// duration, startup is treated as a failure. +const STARTUP_TIMEOUT: Duration = Duration::from_secs(30); + +impl LifespanPending { + /// Wait for the app to complete lifespan startup. + /// + /// Returns `Ok(Some(handle))` on success, `Ok(None)` if the app does + /// not support lifespan, or `Err(message)` on failure or timeout. + pub async fn wait_startup(self) -> Result, String> { + let result = tokio::time::timeout(STARTUP_TIMEOUT, self.startup_rx).await; + + match result { + Ok(Ok(LifespanResult::Complete)) => Ok(Some(LifespanHandle { + shutdown_trigger_tx: Some(self.shutdown_trigger_tx), + shutdown_rx: Some(self.shutdown_rx), + })), + Ok(Ok(LifespanResult::Unsupported)) => Ok(None), + Ok(Ok(LifespanResult::Failed(msg))) => Err(msg), + Ok(Err(_)) => Err("lifespan task died unexpectedly".to_owned()), + Err(_) => Err("lifespan startup timed out (30s)".to_owned()), + } + } +} + +/// Post-startup handle — the lifespan coroutine is alive and waiting for shutdown. +pub struct LifespanHandle { + shutdown_trigger_tx: Option>, + shutdown_rx: Option>, +} + +crate::opaque_debug!(LifespanHandle); + +/// Lifespan shutdown timeout. +const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30); + +impl LifespanHandle { + /// Trigger lifespan shutdown and wait for completion. + pub async fn trigger_shutdown(mut self) -> Result<(), String> { + if let Some(tx) = self.shutdown_trigger_tx.take() { + let _ = tx.send(()); + } + let Some(rx) = self.shutdown_rx.take() else { + return Ok(()); + }; + match tokio::time::timeout(SHUTDOWN_TIMEOUT, rx).await { + Ok(Ok(LifespanResult::Failed(msg))) => Err(msg), + Ok(Ok(LifespanResult::Complete | LifespanResult::Unsupported) | Err(_)) => Ok(()), + Err(_) => { + tracing::warn!( + name: "apx.lifespan.shutdown_timeout", + "lifespan shutdown timed out (30s)" + ); + Ok(()) + } + } + } +} + +// ── Launcher ───────────────────────────────────────────────────────────── + +/// Launch the ASGI lifespan protocol on the asyncio thread. +/// +/// Builds the lifespan scope, receive, and send callables, then submits +/// `launch(app, scope, receive, send)` via `call_soon_threadsafe`. +/// Returns a [`LifespanPending`] for awaiting the startup result. +pub fn launch_lifespan( + py: Python<'_>, + event_loop: &EventLoop, + app: &Py, + launch_fn: &Py, +) -> PyResult { + let (startup_tx, startup_rx) = oneshot::channel(); + let (shutdown_trigger_tx, shutdown_trigger_rx) = oneshot::channel(); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let scope = build_lifespan_scope(py)?; + let receive = Py::new(py, LifespanReceive::new(shutdown_trigger_rx))?; + let send = Py::new(py, LifespanSend::new(py, startup_tx, shutdown_tx)?)?; + + event_loop + .call_soon_threadsafe() + .call1(py, (launch_fn, app, &scope, &receive, &send))?; + + Ok(LifespanPending { + startup_rx, + shutdown_trigger_tx, + shutdown_rx, + }) +} + +// ── Tests ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +#[expect( + clippy::unwrap_used, + reason = "test code uses unwrap/assert for clarity" +)] +mod tests { + use super::*; + use crate::with_py; + + #[test] + fn build_lifespan_scope_fields() { + with_py(|py| { + let scope = build_lifespan_scope(py).unwrap(); + let scope = scope.bind(py); + + let scope_type: String = scope.get_item("type").unwrap().unwrap().extract().unwrap(); + assert_eq!(scope_type, "lifespan"); + + let asgi = scope.get_item("asgi").unwrap().unwrap(); + let version: String = asgi.get_item("version").unwrap().extract().unwrap(); + assert_eq!(version, "3.0"); + let spec: String = asgi.get_item("spec_version").unwrap().extract().unwrap(); + assert_eq!(spec, "2.4"); + + let state = scope.get_item("state").unwrap().unwrap(); + assert_eq!(state.len().unwrap(), 0); + }); + } + + #[test] + fn startup_event_has_correct_type() { + with_py(|py| { + let event = build_startup_event(py).unwrap(); + let event = event.bind(py); + let t: String = event.get_item("type").unwrap().extract().unwrap(); + assert_eq!(t, "lifespan.startup"); + }); + } + + #[test] + fn shutdown_event_has_correct_type() { + with_py(|py| { + let event = build_shutdown_event(py).unwrap(); + let event = event.bind(py); + let t: String = event.get_item("type").unwrap().extract().unwrap(); + assert_eq!(t, "lifespan.shutdown"); + }); + } + + #[test] + fn lifespan_send_startup_complete() { + with_py(|py| { + let (startup_tx, mut startup_rx) = oneshot::channel(); + let (shutdown_tx, _shutdown_rx) = oneshot::channel(); + let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); + + let event = PyDict::new(py); + event.set_item("type", "lifespan.startup.complete").unwrap(); + send.__call__(py, event).unwrap(); + + let result = startup_rx.try_recv().unwrap(); + assert!(matches!(result, LifespanResult::Complete)); + }); + } + + #[test] + fn lifespan_send_startup_failed() { + with_py(|py| { + let (startup_tx, mut startup_rx) = oneshot::channel(); + let (shutdown_tx, _shutdown_rx) = oneshot::channel(); + let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); + + let event = PyDict::new(py); + event.set_item("type", "lifespan.startup.failed").unwrap(); + event.set_item("message", "db connection refused").unwrap(); + send.__call__(py, event).unwrap(); + + let result = startup_rx.try_recv().unwrap(); + assert!( + matches!(&result, LifespanResult::Failed(msg) if msg == "db connection refused"), + "expected Failed(\"db connection refused\"), got {result:?}" + ); + }); + } + + #[test] + fn lifespan_send_startup_failed_no_message() { + with_py(|py| { + let (startup_tx, mut startup_rx) = oneshot::channel(); + let (shutdown_tx, _shutdown_rx) = oneshot::channel(); + let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); + + let event = PyDict::new(py); + event.set_item("type", "lifespan.startup.failed").unwrap(); + send.__call__(py, event).unwrap(); + + let result = startup_rx.try_recv().unwrap(); + assert!( + matches!(&result, LifespanResult::Failed(msg) if msg.is_empty()), + "expected Failed(\"\"), got {result:?}" + ); + }); + } + + #[test] + fn lifespan_send_shutdown_complete() { + with_py(|py| { + let (startup_tx, _startup_rx) = oneshot::channel(); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); + + // Consume startup first (simulates normal flow). + let event = PyDict::new(py); + event.set_item("type", "lifespan.startup.complete").unwrap(); + send.__call__(py, event).unwrap(); + + let event = PyDict::new(py); + event + .set_item("type", "lifespan.shutdown.complete") + .unwrap(); + send.__call__(py, event).unwrap(); + + let result = shutdown_rx.try_recv().unwrap(); + assert!(matches!(result, LifespanResult::Complete)); + }); + } + + #[test] + fn lifespan_send_shutdown_failed() { + with_py(|py| { + let (startup_tx, _startup_rx) = oneshot::channel(); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); + + // Consume startup first. + let event = PyDict::new(py); + event.set_item("type", "lifespan.startup.complete").unwrap(); + send.__call__(py, event).unwrap(); + + let event = PyDict::new(py); + event.set_item("type", "lifespan.shutdown.failed").unwrap(); + event.set_item("message", "cleanup error").unwrap(); + send.__call__(py, event).unwrap(); + + let result = shutdown_rx.try_recv().unwrap(); + assert!( + matches!(&result, LifespanResult::Failed(msg) if msg == "cleanup error"), + "expected Failed(\"cleanup error\"), got {result:?}" + ); + }); + } + + #[test] + fn lifespan_send_unknown_event_type() { + with_py(|py| { + let (startup_tx, _startup_rx) = oneshot::channel(); + let (shutdown_tx, _shutdown_rx) = oneshot::channel(); + let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); + + let event = PyDict::new(py); + event.set_item("type", "lifespan.unknown").unwrap(); + let result = send.__call__(py, event); + assert!(result.is_err()); + }); + } + + #[test] + fn send_error_during_startup_signals_unsupported() { + with_py(|py| { + let (startup_tx, mut startup_rx) = oneshot::channel(); + let (shutdown_tx, _shutdown_rx) = oneshot::channel(); + let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); + + send.send_error("TypeError: ...".to_owned()); + + let result = startup_rx.try_recv().unwrap(); + assert!(matches!(result, LifespanResult::Unsupported)); + }); + } + + #[test] + fn send_error_during_shutdown_signals_failed() { + with_py(|py| { + let (startup_tx, _startup_rx) = oneshot::channel(); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); + + // Consume startup to transition phase. + let event = PyDict::new(py); + event.set_item("type", "lifespan.startup.complete").unwrap(); + send.__call__(py, event).unwrap(); + + send.send_error("RuntimeError: cleanup failed".to_owned()); + + let result = shutdown_rx.try_recv().unwrap(); + assert!( + matches!(&result, LifespanResult::Failed(msg) if msg.contains("cleanup failed")), + "expected Failed containing \"cleanup failed\", got {result:?}" + ); + }); + } + + #[test] + fn lifespan_receive_first_call_returns_startup() { + with_py(|py| { + let (_tx, rx) = oneshot::channel(); + let receive = LifespanReceive::new(rx); + + let awaitable = receive.__call__(py).unwrap(); + // The awaitable should be a ResolvedAwaitableWithValue. + // We can check it implements __await__. + assert!(awaitable.hasattr("__await__").unwrap()); + }); + } + + #[test] + fn lifespan_result_debug() { + let c = LifespanResult::Complete; + assert!(format!("{c:?}").contains("Complete")); + let f = LifespanResult::Failed("err".to_owned()); + assert!(format!("{f:?}").contains("Failed")); + let u = LifespanResult::Unsupported; + assert!(format!("{u:?}").contains("Unsupported")); + } + + #[tokio::test] + async fn lifespan_pending_wait_startup_complete() { + let (startup_tx, startup_rx) = oneshot::channel(); + let (shutdown_trigger_tx, _shutdown_trigger_rx) = oneshot::channel(); + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + + let pending = LifespanPending { + startup_rx, + shutdown_trigger_tx, + shutdown_rx, + }; + + let _ = startup_tx.send(LifespanResult::Complete); + let result = pending.wait_startup().await; + assert!(result.is_ok()); + assert!(result.unwrap().is_some()); + } + + #[tokio::test] + async fn lifespan_pending_wait_startup_unsupported() { + let (startup_tx, startup_rx) = oneshot::channel(); + let (shutdown_trigger_tx, _shutdown_trigger_rx) = oneshot::channel(); + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + + let pending = LifespanPending { + startup_rx, + shutdown_trigger_tx, + shutdown_rx, + }; + + let _ = startup_tx.send(LifespanResult::Unsupported); + let result = pending.wait_startup().await; + assert!(result.unwrap().is_none()); + } + + #[tokio::test] + async fn lifespan_pending_wait_startup_failed() { + let (startup_tx, startup_rx) = oneshot::channel(); + let (shutdown_trigger_tx, _shutdown_trigger_rx) = oneshot::channel(); + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + + let pending = LifespanPending { + startup_rx, + shutdown_trigger_tx, + shutdown_rx, + }; + + let _ = startup_tx.send(LifespanResult::Failed("db error".to_owned())); + let result = pending.wait_startup().await; + assert_eq!(result.unwrap_err(), "db error"); + } + + #[tokio::test] + async fn lifespan_handle_trigger_shutdown_complete() { + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let (trigger_tx, trigger_rx) = oneshot::channel(); + + let handle = LifespanHandle { + shutdown_trigger_tx: Some(trigger_tx), + shutdown_rx: Some(shutdown_rx), + }; + + // Simulate the app responding to shutdown trigger. + tokio::spawn(async move { + let _ = trigger_rx.await; + let _ = shutdown_tx.send(LifespanResult::Complete); + }); + + let result = handle.trigger_shutdown().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn lifespan_handle_trigger_shutdown_failed() { + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let (trigger_tx, trigger_rx) = oneshot::channel(); + + let handle = LifespanHandle { + shutdown_trigger_tx: Some(trigger_tx), + shutdown_rx: Some(shutdown_rx), + }; + + tokio::spawn(async move { + let _ = trigger_rx.await; + let _ = shutdown_tx.send(LifespanResult::Failed("cleanup err".to_owned())); + }); + + let result = handle.trigger_shutdown().await; + assert_eq!(result.unwrap_err(), "cleanup err"); + } +} diff --git a/crates/framework/src/asgi/mod.rs b/crates/framework/src/asgi/mod.rs index e0310cd7..e09d7cb6 100644 --- a/crates/framework/src/asgi/mod.rs +++ b/crates/framework/src/asgi/mod.rs @@ -6,6 +6,7 @@ pub mod app; pub mod channel_body; pub mod dispatch; +pub mod lifespan; pub mod queue; pub mod scope; pub mod slot_receive; diff --git a/crates/framework/src/pyapi.rs b/crates/framework/src/pyapi.rs index 9383ad95..28a3f506 100644 --- a/crates/framework/src/pyapi.rs +++ b/crates/framework/src/pyapi.rs @@ -11,6 +11,10 @@ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + // ASGI lifespan protocol types + m.add_class::()?; + m.add_class::()?; + // 3-thread dispatch pipeline types m.add_class::()?; m.add_class::()?; diff --git a/crates/framework/src/supervision/worker.rs b/crates/framework/src/supervision/worker.rs index ee5a1832..f4a37692 100644 --- a/crates/framework/src/supervision/worker.rs +++ b/crates/framework/src/supervision/worker.rs @@ -8,7 +8,7 @@ use super::ipc::channel::WorkerChannel; use super::ipc::protocol::{BootstrapError, IpcMessage, Nonce, WorkerBootstrap}; use super::signal::shutdown_signal; use super::worker_context::WorkerContext; -use crate::asgi::app::{AppSource, ModuleImport, format_pyerr}; +use crate::asgi::app::{ModuleImport, format_pyerr}; use crate::io::EventLoop; use crate::protocol::http::service::{ApxService, ServiceConfig, serve_tcp}; use crate::transport::{Listener, TransportConfig, TransportError}; @@ -39,6 +39,10 @@ pub enum WorkerError { /// Serving requests failed. #[error("serve failed: {0}")] Serve(std::io::Error), + + /// ASGI lifespan startup failed. + #[error("lifespan startup failed: {0}")] + LifespanStartup(String), } /// Format a worker error with full Python traceback when available. @@ -47,10 +51,13 @@ pub enum WorkerError { /// `AppLoadError::ImportFailed` and renders its traceback. Falls back to /// the standard `Display` chain for non-Python errors. pub fn format_worker_error(err: &WorkerError) -> String { - if let WorkerError::AppLoad(crate::asgi::app::AppLoadError::ImportFailed { source, .. }) = err { - return Python::attach(|py| format_pyerr(py, source)); + match err { + WorkerError::AppLoad(crate::asgi::app::AppLoadError::ImportFailed { source, .. }) => { + Python::attach(|py| format_pyerr(py, source)) + } + WorkerError::LifespanStartup(msg) => format!("lifespan startup failed: {msg}"), + _ => err.to_string(), } - err.to_string() } /// Phase 1 runtime: TCP listener + Python interpreter (expensive, survives reloads). @@ -118,6 +125,10 @@ async fn signal_readiness(channel: &mut WorkerChannel) -> Result<(), WorkerError struct AppReady { dispatch: Arc, telemetry: crate::telemetry::config::TelemetryConfig, + /// Raw ASGI callable for the lifespan protocol. + asgi_app: Py, + /// Cached `apx._bridge.launch` for submitting coroutines to the asyncio thread. + launch_fn: Py, } crate::opaque_debug!(AppReady); @@ -134,23 +145,27 @@ fn load_app(runtime: &WorkerRuntime, bootstrap: &WorkerBootstrap) -> Result Result, WorkerError> { - let launch_fn = register_launch(py) - .map_err(|e| WorkerError::PythonInit(format!("register launch: {e}")))?; - Ok(Arc::new(WorkerContext { - pipeline: Arc::clone(&pipeline), - call_soon_threadsafe: el.call_soon_threadsafe().clone_ref(py), - launch_fn, - })) - })? + Python::attach( + |py| -> Result<(Arc, Py), WorkerError> { + let launch_fn = register_launch(py) + .map_err(|e| WorkerError::PythonInit(format!("register launch: {e}")))?; + let launch_fn_ref = launch_fn.clone_ref(py); + let ctx = Arc::new(WorkerContext { + pipeline: Arc::clone(&pipeline), + call_soon_threadsafe: el.call_soon_threadsafe().clone_ref(py), + launch_fn, + }); + Ok((ctx, launch_fn_ref)) + }, + )? }; let server_addr = runtime.listener.local_addr(); let event_loop_py = runtime.event_loop.event_loop_py(); - let dispatch = Python::attach(|py| { - ModuleImport::new(bootstrap.app_module.as_str(), bootstrap.dev_mode).build( + let (dispatch, asgi_app) = Python::attach(|py| { + ModuleImport::new(bootstrap.app_module.as_str(), bootstrap.dev_mode).build_with_app( py, ctx, event_loop_py, @@ -168,6 +183,8 @@ fn load_app(runtime: &WorkerRuntime, bootstrap: &WorkerBootstrap) -> Result, ) -> Result<(), WorkerError> { let (ipc_reader, mut ipc_writer) = runtime.channel.split(); @@ -280,6 +298,17 @@ async fn serve( .await; } + // Trigger ASGI lifespan shutdown while the asyncio loop is still running. + if let Some(handle) = lifespan + && let Err(e) = handle.trigger_shutdown().await + { + tracing::warn!( + name: "apx.worker.lifespan_shutdown_failed", + error = %e, + "ASGI lifespan shutdown failed" + ); + } + let _ = ipc_writer.send(&IpcMessage::Drained).await; apx_core::tracing_init::shutdown_telemetry(); @@ -288,6 +317,43 @@ async fn serve( Ok(()) } +/// Launch the ASGI lifespan protocol and await startup completion. +/// +/// Returns `Ok(Some(handle))` if the app completed lifespan startup, +/// `Ok(None)` if the app does not support lifespan, or `Err` on failure. +async fn launch_and_await_lifespan( + runtime: &WorkerRuntime, + ready: &AppReady, +) -> Result, WorkerError> { + let pending = Python::attach(|py| { + crate::asgi::lifespan::launch_lifespan( + py, + &runtime.event_loop, + &ready.asgi_app, + &ready.launch_fn, + ) + }) + .map_err(|e| WorkerError::LifespanStartup(format!("{e}")))?; + + match pending.wait_startup().await { + Ok(Some(handle)) => { + tracing::info!( + name: "apx.worker.lifespan_startup_complete", + "ASGI lifespan startup complete" + ); + Ok(Some(handle)) + } + Ok(None) => { + tracing::info!( + name: "apx.worker.lifespan_unsupported", + "app does not support ASGI lifespan protocol" + ); + Ok(None) + } + Err(msg) => Err(WorkerError::LifespanStartup(msg)), + } +} + /// Connect, init, load app, signal readiness, and serve. /// /// If app loading fails, sends `StartupFailed` over IPC so the supervisor @@ -314,6 +380,19 @@ pub async fn run_worker( } }; + // ASGI lifespan startup — run app startup hooks before accepting traffic. + let lifespan = match launch_and_await_lifespan(&runtime, &ready).await { + Ok(handle) => handle, + Err(e) => { + let detail = format_worker_error(&e); + let _ = runtime + .channel + .send(&IpcMessage::StartupFailed { error: detail }) + .await; + return Err(e); + } + }; + signal_readiness(&mut runtime.channel).await?; relay_telemetry(&mut runtime.channel, &bootstrap, &ready.telemetry).await?; init_metrics(&ready.telemetry); @@ -325,7 +404,7 @@ pub async fn run_worker( } let service = build_service(&runtime, &bootstrap, ready.dispatch); - serve(runtime, service, bootstrap.drain_timeout_secs).await + serve(runtime, service, bootstrap.drain_timeout_secs, lifespan).await } /// Detect worker mode and connect to the supervisor's IPC channel. diff --git a/tests/integration/test_apx.py b/tests/integration/test_apx.py index 31775d65..751eea0b 100644 --- a/tests/integration/test_apx.py +++ b/tests/integration/test_apx.py @@ -11,6 +11,37 @@ import pytest +# --------------------------------------------------------------------------- +# Lifespan +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestLifespan: + """Verify that the ASGI lifespan protocol fires startup hooks. + + The bench app's lifespan context manager opens an SQLite connection, + creates the items table, and seeds 10 default rows. If lifespan + never runs, ``app.state.db`` is unset and all DB-dependent routes fail. + """ + + def test_lifespan_startup_ran(self, client: httpx.Client) -> None: + """DB-dependent route works — proves lifespan startup completed.""" + r = client.get("/api/items") + assert r.status_code == 200 + items = r.json() + assert isinstance(items, list) + assert len(items) >= 10, "lifespan should have seeded 10 default items" + + def test_db_state_accessible(self, client: httpx.Client) -> None: + """Individual item fetch works — app.state.db is live.""" + r = client.get("/api/items/1") + assert r.status_code == 200 + item = r.json() + assert item["id"] == 1 + assert item["name"] == "Item 1" + + # --------------------------------------------------------------------------- # Health & meta # --------------------------------------------------------------------------- From 262d4a2d50fd60e408217160c05d8e28522e7d3f Mon Sep 17 00:00:00 2001 From: renardeinside Date: Wed, 1 Apr 2026 00:03:33 +0200 Subject: [PATCH 03/18] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20deduplic?= =?UTF-8?q?ate=20ASGI=20version=20constants=20into=20asgi/mod.rs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/framework/src/asgi/lifespan.rs | 6 +----- crates/framework/src/asgi/mod.rs | 6 ++++++ crates/framework/src/asgi/scope.rs | 6 +----- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/crates/framework/src/asgi/lifespan.rs b/crates/framework/src/asgi/lifespan.rs index c0c5b57b..98592a33 100644 --- a/crates/framework/src/asgi/lifespan.rs +++ b/crates/framework/src/asgi/lifespan.rs @@ -258,11 +258,7 @@ impl LifespanSend { // ── Scope builder ──────────────────────────────────────────────────────── -/// ASGI protocol version string. -const ASGI_VERSION: &str = "3.0"; - -/// ASGI spec version string. -const ASGI_SPEC_VERSION: &str = "2.4"; +use super::{ASGI_SPEC_VERSION, ASGI_VERSION}; /// Build the ASGI lifespan scope dict. fn build_lifespan_scope(py: Python<'_>) -> PyResult> { diff --git a/crates/framework/src/asgi/mod.rs b/crates/framework/src/asgi/mod.rs index e09d7cb6..22bfdde7 100644 --- a/crates/framework/src/asgi/mod.rs +++ b/crates/framework/src/asgi/mod.rs @@ -3,6 +3,12 @@ //! Translates Rust domain types (InboundRequest, OutboundResponse) to/from //! ASGI protocol objects (scope, receive, send). +/// ASGI protocol version string. +pub const ASGI_VERSION: &str = "3.0"; + +/// ASGI spec version string. +pub const ASGI_SPEC_VERSION: &str = "2.4"; + pub mod app; pub mod channel_body; pub mod dispatch; diff --git a/crates/framework/src/asgi/scope.rs b/crates/framework/src/asgi/scope.rs index b1be510e..49cc19cf 100644 --- a/crates/framework/src/asgi/scope.rs +++ b/crates/framework/src/asgi/scope.rs @@ -18,11 +18,7 @@ use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::{Mutex, mpsc, oneshot}; -/// ASGI protocol version string. -const ASGI_VERSION: &str = "3.0"; - -/// ASGI spec version string. -const ASGI_SPEC_VERSION: &str = "2.4"; +use super::{ASGI_SPEC_VERSION, ASGI_VERSION}; /// Default HTTP scheme (TLS detection is a future extension). const DEFAULT_SCHEME: &str = "http"; From 75c5925030db2b86bf1a12ccca05a16a38b2cdae Mon Sep 17 00:00:00 2001 From: renardeinside Date: Wed, 1 Apr 2026 00:46:48 +0200 Subject: [PATCH 04/18] =?UTF-8?q?=F0=9F=9A=80=20perf:=20optimize=20ASGI=20?= =?UTF-8?q?data=20path=20(P3/P4/P5/P6/P12/P13)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/framework/src/asgi/dispatch.rs | 32 ++++--- crates/framework/src/asgi/queue.rs | 22 +---- crates/framework/src/asgi/scope.rs | 114 +++++++++++++++++++++---- crates/framework/src/asgi/slot_send.rs | 55 ++++++------ crates/framework/src/io/channel.rs | 18 ++-- src/apx/_dispatch.py | 4 +- 6 files changed, 165 insertions(+), 80 deletions(-) diff --git a/crates/framework/src/asgi/dispatch.rs b/crates/framework/src/asgi/dispatch.rs index bc61cd0c..a340d553 100644 --- a/crates/framework/src/asgi/dispatch.rs +++ b/crates/framework/src/asgi/dispatch.rs @@ -9,7 +9,7 @@ use crate::asgi::channel_body::ChannelBody; use crate::asgi::scope::ScopeInterns; use crate::dispatch::Dispatch; -use crate::io::channel::{RequestSlot, ResponseData, Wakeup}; +use crate::io::channel::{RequestSlot, ResponseData, SlotBody, Wakeup}; use crate::protocol::http::error::AppError; use crate::supervision::worker_context::WorkerContext; use crate::telemetry::context::TraceContext; @@ -216,15 +216,11 @@ async fn dispatch_pipeline( fn response_data_to_outbound(data: ResponseData) -> Result { let status = http::StatusCode::from_u16(data.status).unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR); - let mut headers = HeaderMap::with_capacity(data.headers.len()); - for (name, value) in &data.headers { - let header_name = HeaderName::from_bytes(name) - .map_err(|e| AppError::Internal(format!("invalid header name: {e}")))?; - let header_value = HeaderValue::from_bytes(value) - .map_err(|e| AppError::Internal(format!("invalid header value: {e}")))?; - headers.append(header_name, header_value); - } - let body = ResponseBody::Stream(Box::pin(ChannelBody::new(data.body_rx))); + let headers = build_response_headers(&data.headers)?; + let body = match data.body { + SlotBody::Complete(bytes) => ResponseBody::Fixed(bytes), + SlotBody::Chunked(rx) => ResponseBody::Stream(Box::pin(ChannelBody::new(rx))), + }; Ok(OutboundResponse { status, @@ -234,6 +230,22 @@ fn response_data_to_outbound(data: ResponseData) -> Result Result { + let mut headers = HeaderMap::with_capacity(raw.len()); + for (name, value) in raw { + let header_name = HeaderName::from_lowercase(name) + .map_err(|e| AppError::Internal(format!("invalid header name: {e}")))?; + let header_value = HeaderValue::from_bytes(value) + .map_err(|e| AppError::Internal(format!("invalid header value: {e}")))?; + headers.append(header_name, header_value); + } + Ok(headers) +} + /// Client-visible body for internal errors. const INTERNAL_ERROR_BODY: &[u8] = b"Internal Server Error"; diff --git a/crates/framework/src/asgi/queue.rs b/crates/framework/src/asgi/queue.rs index a8eecfe6..d2ea271d 100644 --- a/crates/framework/src/asgi/queue.rs +++ b/crates/framework/src/asgi/queue.rs @@ -11,7 +11,6 @@ use crate::asgi::scope::{ use crate::asgi::slot_receive::SlotReceive; use crate::asgi::slot_send::SlotSend; use crate::io::channel::{InboundChannel, RequestSlot, Wakeup}; -use crate::transport::types::{BodyStream, InboundRequest, TransportKind}; use pyo3::prelude::*; use pyo3::types::PyDict; use std::sync::Arc; @@ -111,11 +110,10 @@ impl RequestQueue { crate::telemetry::context::set_python_context(py, ctx)?; } - let request = slot_to_inbound_request(&slot); let scope = scope_from_template( py, &self.scope_interns.scope_template, - &request, + &slot, None, &self.scope_interns, )?; @@ -131,21 +129,3 @@ impl RequestQueue { }) } } - -/// Build a temporary [`InboundRequest`] from a [`RequestSlot`] for -/// `scope_from_template`. The body is already consumed so we pass `Empty`. -fn slot_to_inbound_request(slot: &RequestSlot) -> InboundRequest { - InboundRequest::new( - slot.method.clone(), - slot.path.clone(), - slot.query_string.clone(), - slot.headers.clone(), - BodyStream::Empty, - slot.protocol, - TransportKind::Tcp, - slot.client_addr, - slot.server_addr, - Vec::new(), - http::Extensions::new(), - ) -} diff --git a/crates/framework/src/asgi/scope.rs b/crates/framework/src/asgi/scope.rs index 49cc19cf..fa099d6a 100644 --- a/crates/framework/src/asgi/scope.rs +++ b/crates/framework/src/asgi/scope.rs @@ -6,6 +6,7 @@ //! These types enable Starlette's `Request`, `StreamingResponse`, and `WebSocket` //! to work unmodified against a Rust-backed ASGI server. +use crate::io::channel::RequestSlot; use crate::protocol::http::error::AppError; use crate::transport::types::{InboundRequest, OutboundResponse, ProtocolVersion, ResponseBody}; use bytes::Bytes; @@ -53,6 +54,9 @@ pub struct ScopeInterns { // ── Scope template ── /// Pre-built HTTP scope dict with fixed fields. `dict.copy()` per request. pub(crate) scope_template: Py, + // ── Cached empty dict ── + /// Shared empty dict for parameterless routes (avoids `PyDict::new` per request). + pub(crate) empty_dict: Py, } /// Fixed dict keys used in ASGI scope construction. @@ -269,6 +273,7 @@ impl ScopeInterns { server_tuple, versions, scope_template, + empty_dict: PyDict::new(py).unbind(), } } } @@ -1119,6 +1124,71 @@ fn extract_bytes_field(obj: &Bound<'_, PyAny>) -> PyResult> { } } +// ── ScopeSource ───────────────────────────────────────────────────────── + +/// Read-only access to the fields needed for ASGI scope construction. +/// +/// Implemented by both [`InboundRequest`] (legacy ASGI path, WS path) and +/// [`RequestSlot`] (zero-GIL crossbeam path), so `scope_from_template` +/// can accept either without an intermediate clone. +pub trait ScopeSource { + fn method(&self) -> &http::Method; + fn path(&self) -> &str; + fn query_string(&self) -> &Bytes; + fn headers(&self) -> &HeaderMap; + fn protocol(&self) -> ProtocolVersion; + fn client_addr(&self) -> Option; + fn path_params(&self) -> &[(String, String)]; +} + +impl ScopeSource for InboundRequest { + fn method(&self) -> &http::Method { + &self.method + } + fn path(&self) -> &str { + &self.path + } + fn query_string(&self) -> &Bytes { + &self.query_string + } + fn headers(&self) -> &HeaderMap { + &self.headers + } + fn protocol(&self) -> ProtocolVersion { + self.protocol + } + fn client_addr(&self) -> Option { + self.client_addr + } + fn path_params(&self) -> &[(String, String)] { + &self.path_params + } +} + +impl ScopeSource for RequestSlot { + fn method(&self) -> &http::Method { + &self.method + } + fn path(&self) -> &str { + &self.path + } + fn query_string(&self) -> &Bytes { + &self.query_string + } + fn headers(&self) -> &HeaderMap { + &self.headers + } + fn protocol(&self) -> ProtocolVersion { + self.protocol + } + fn client_addr(&self) -> Option { + self.client_addr + } + fn path_params(&self) -> &[(String, String)] { + &[] + } +} + // ── scope_from_template ────────────────────────────────────────────────── /// Build an HTTP scope from the pre-populated template. @@ -1128,7 +1198,7 @@ fn extract_bytes_field(obj: &Bound<'_, PyAny>) -> PyResult> { pub fn scope_from_template( py: Python<'_>, template: &Py, - request: &InboundRequest, + request: &impl ScopeSource, fastapi_app: Option<&Py>, interns: &ScopeInterns, ) -> PyResult> { @@ -1141,10 +1211,10 @@ pub fn scope_from_template( "scope template copy returned non-dict: {e}" )) })?; - if request.protocol != ProtocolVersion::Http11 { + if request.protocol() != ProtocolVersion::Http11 { scope.set_item( interns.keys.http_version.bind(py), - interns.versions.get(py, request.protocol), + interns.versions.get(py, request.protocol()), )?; } set_scope_request_fields(py, &scope, request, interns)?; @@ -1230,23 +1300,22 @@ fn set_ws_scope_request_fields( fn set_scope_request_fields( py: Python<'_>, dict: &Bound<'_, PyDict>, - request: &InboundRequest, + request: &impl ScopeSource, interns: &ScopeInterns, ) -> PyResult<()> { dict.set_item( interns.keys.http_version.bind(py), - interns.versions.get(py, request.protocol), + interns.versions.get(py, request.protocol()), )?; - dict.set_item(interns.keys.method.bind(py), request.method.as_str())?; - // ASGI spec: "path" is the decoded URL path, "raw_path" is the raw bytes. - dict.set_item(interns.keys.path.bind(py), percent_decode(&request.path))?; + dict.set_item(interns.keys.method.bind(py), request.method().as_str())?; + dict.set_item(interns.keys.path.bind(py), percent_decode(request.path()))?; dict.set_item( interns.keys.raw_path.bind(py), - PyBytes::new(py, request.path.as_bytes()), + PyBytes::new(py, request.path().as_bytes()), )?; dict.set_item( interns.keys.query_string.bind(py), - PyBytes::new(py, &request.query_string), + PyBytes::new(py, request.query_string()), )?; Ok(()) } @@ -1258,11 +1327,11 @@ fn set_scope_request_fields( fn set_scope_headers( py: Python<'_>, dict: &Bound<'_, PyDict>, - request: &InboundRequest, + request: &impl ScopeSource, interns: &ScopeInterns, ) -> PyResult<()> { - let mut pairs: Vec> = Vec::with_capacity(request.headers.len()); - for (name, value) in &request.headers { + let mut pairs: Vec> = Vec::with_capacity(request.headers().len()); + for (name, value) in request.headers() { let n = interns .headers .get(py, name) @@ -1280,11 +1349,11 @@ fn set_scope_headers( fn set_scope_addresses( py: Python<'_>, dict: &Bound<'_, PyDict>, - request: &InboundRequest, + request: &impl ScopeSource, interns: &ScopeInterns, ) -> PyResult<()> { dict.set_item(interns.keys.server.bind(py), interns.server_tuple.bind(py))?; - match request.client_addr { + match request.client_addr() { Some(addr) => { dict.set_item( interns.keys.client.bind(py), @@ -1301,14 +1370,25 @@ fn set_scope_addresses( /// Values are URL-decoded because axum's `RawPathParams` provides percent-encoded /// strings, but Starlette/FastAPI expects decoded values (matching what Starlette's /// own router would produce). +/// +/// When path_params is empty (parameterless routes), reuses a pre-built +/// empty dict singleton from `ScopeInterns` to avoid a `PyDict::new` per request. fn set_scope_path_params( py: Python<'_>, dict: &Bound<'_, PyDict>, - request: &InboundRequest, + request: &impl ScopeSource, interns: &ScopeInterns, ) -> PyResult<()> { + let params = request.path_params(); + if params.is_empty() { + dict.set_item( + interns.keys.path_params.bind(py), + interns.empty_dict.bind(py), + )?; + return Ok(()); + } let pp = PyDict::new(py); - for (k, v) in &request.path_params { + for (k, v) in params { pp.set_item(k.as_str(), percent_decode(v.as_str()))?; } dict.set_item(interns.keys.path_params.bind(py), pp)?; diff --git a/crates/framework/src/asgi/slot_send.rs b/crates/framework/src/asgi/slot_send.rs index aad915d9..d35589db 100644 --- a/crates/framework/src/asgi/slot_send.rs +++ b/crates/framework/src/asgi/slot_send.rs @@ -6,7 +6,7 @@ //! Subsequent body chunks are pushed via the mpsc sender. Dropping the //! sender signals EOF. -use crate::io::channel::ResponseData; +use crate::io::channel::{ResponseData, SlotBody}; use bytes::Bytes; use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; @@ -76,21 +76,18 @@ impl SlotSend { "{traceback}", ); if let Some(response_tx) = self.response_tx.take() { - let (body_tx, body_rx) = mpsc::unbounded_channel(); let body = if self.dev_mode { Bytes::from(traceback) } else { INTERNAL_ERROR_BODY }; - let _ = body_tx.send(body); - drop(body_tx); let response = ResponseData { status: 500, headers: vec![( Bytes::from_static(b"content-type"), Bytes::from_static(b"text/plain; charset=utf-8"), )], - body_rx, + body: SlotBody::Complete(body), }; let _ = response_tx.send(response); } @@ -155,7 +152,10 @@ impl SlotSend { Ok(self.resolved.clone_ref(py).into_bound(py).into_any()) } - /// First body chunk: create mpsc, build `ResponseData`, push `OutboundSlot`. + /// First body chunk: build `ResponseData` and fire the tokio oneshot. + /// + /// Non-streaming (`more_body == false`): carries the body inline, + /// skipping the mpsc channel + ChannelBody + Box::pin allocation. fn send_first_body_chunk(&mut self, body: Bytes, more_body: bool) -> PyResult<()> { let status = self.status.take().ok_or_else(|| { pyo3::exceptions::PyRuntimeError::new_err( @@ -164,21 +164,21 @@ impl SlotSend { })?; let headers = self.raw_headers.take().unwrap_or_default(); - let (body_tx, body_rx) = mpsc::unbounded_channel(); - if !body.is_empty() { - let _ = body_tx.send(body); - } - - if more_body { + let slot_body = if more_body { + let (body_tx, body_rx) = mpsc::unbounded_channel(); + if !body.is_empty() { + let _ = body_tx.send(body); + } self.body_tx = Some(body_tx); + SlotBody::Chunked(body_rx) } else { - drop(body_tx); - } + SlotBody::Complete(body) + }; let response = ResponseData { status, headers, - body_rx, + body: slot_body, }; if let Some(response_tx) = self.response_tx.take() { @@ -254,7 +254,8 @@ fn extract_body_bytes(event: &Bound<'_, PyDict>) -> PyResult { #[cfg(test)] #[expect( clippy::unwrap_used, - reason = "test code uses unwrap/assert for clarity" + clippy::panic, + reason = "test code uses unwrap/assert/panic for clarity" )] mod tests { use super::*; @@ -271,10 +272,12 @@ mod tests { let traceback = "Traceback (most recent call last):\n NameError: x\n".to_owned(); slot.send_error(traceback); - let mut response = rx.try_recv().unwrap(); + let response = rx.try_recv().unwrap(); assert_eq!(response.status, 500); - let body = response.body_rx.try_recv().unwrap(); - assert_eq!(body.as_ref(), b"Internal Server Error"); + match response.body { + SlotBody::Complete(b) => assert_eq!(b.as_ref(), b"Internal Server Error"), + SlotBody::Chunked(_) => panic!("expected Complete body"), + } } #[test] @@ -283,12 +286,16 @@ mod tests { let traceback = "Traceback (most recent call last):\n NameError: x\n".to_owned(); slot.send_error(traceback); - let mut response = rx.try_recv().unwrap(); + let response = rx.try_recv().unwrap(); assert_eq!(response.status, 500); - let body = response.body_rx.try_recv().unwrap(); - let body_str = std::str::from_utf8(body.as_ref()).unwrap(); - assert!(body_str.contains("Traceback")); - assert!(body_str.contains("NameError")); + match response.body { + SlotBody::Complete(b) => { + let body_str = std::str::from_utf8(b.as_ref()).unwrap(); + assert!(body_str.contains("Traceback")); + assert!(body_str.contains("NameError")); + } + SlotBody::Chunked(_) => panic!("expected Complete body"), + } } #[test] diff --git a/crates/framework/src/io/channel.rs b/crates/framework/src/io/channel.rs index b8cc7ee0..2f9077a5 100644 --- a/crates/framework/src/io/channel.rs +++ b/crates/framework/src/io/channel.rs @@ -51,18 +51,24 @@ pub struct RequestSlot { // ── ResponseData ───────────────────────────────────────────────────────── -/// Response flowing from Thread 2 → Thread 3 → Thread 1. -/// -/// Uses an unbounded mpsc channel for the body to unify streaming and -/// non-streaming responses under a single code path. +/// Body payload flowing from Thread 2 (asyncio) to Thread 1 (tokio). +#[derive(Debug)] +pub enum SlotBody { + /// Complete body for non-streaming responses (95% of traffic). + Complete(Bytes), + /// Streaming body fed chunk-by-chunk via an mpsc channel. + Chunked(mpsc::UnboundedReceiver), +} + +/// Response flowing from Thread 2 → Thread 1 via tokio oneshot. #[derive(Debug)] pub struct ResponseData { /// HTTP status code. pub status: u16, /// Response headers as raw byte pairs (name, value). pub headers: Vec<(Bytes, Bytes)>, - /// Streaming body channel — one chunk per `send(http.response.body)`. - pub body_rx: mpsc::UnboundedReceiver, + /// Response body — complete or streaming. + pub body: SlotBody, } // ── Wakeup ─────────────────────────────────────────────────────────────── diff --git a/src/apx/_dispatch.py b/src/apx/_dispatch.py index 276d0fae..ed2b8a45 100644 --- a/src/apx/_dispatch.py +++ b/src/apx/_dispatch.py @@ -31,10 +31,10 @@ def install_dispatch( which appends ``_drain_queue`` directly to ``_ready`` (no fd needed). """ - # At ~85µs per materialize(), 8 items ≈ 680µs GIL hold — well under + # At ~31µs per materialize(), 32 items ≈ 1ms GIL hold — well under # the 5ms GIL switch interval (sys.getswitchinterval()), keeping the # drain responsive without excessive re-scheduling overhead. - max_drain_batch: int = 8 + max_drain_batch: int = 32 async def _guarded( scope: dict[str, Any], From 4bd227a190f91ed929e63a6d21067840bfc36bf5 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Wed, 1 Apr 2026 12:21:21 +0200 Subject: [PATCH 05/18] =?UTF-8?q?=F0=9F=9A=80=20perf:=20inline=20coroutine?= =?UTF-8?q?=20driving=20to=20eliminate=20create=5Ftask=20overhead?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/bench/app/api.py | 36 ++- src/apx/_continuation.py | 103 ++++++++ src/apx/_core.pyi | 12 +- src/apx/_dispatch.py | 71 ++++-- src/apx/_scheduler.py | 223 +++++++++++++++++ src/apx/telemetry.py | 28 ++- tests/integration/test_telemetry.py | 3 +- tests/telemetry/test_cross_signal.py | 20 +- tests/telemetry/test_dispatch_metrics.py | 16 +- tests/telemetry/test_error_handling.py | 20 +- tests/telemetry/test_event_name.py | 21 +- tests/telemetry/test_otlp_fields.py | 84 ++----- tests/telemetry/test_sequential_spans.py | 8 +- tests/test_dispatch.py | 292 +++++++++++++++++++---- tests/test_telemetry.py | 14 +- 15 files changed, 730 insertions(+), 221 deletions(-) create mode 100644 src/apx/_continuation.py create mode 100644 src/apx/_scheduler.py diff --git a/scripts/bench/app/api.py b/scripts/bench/app/api.py index 82852aa2..46bf8143 100644 --- a/scripts/bench/app/api.py +++ b/scripts/bench/app/api.py @@ -32,12 +32,22 @@ tags TEXT NOT NULL DEFAULT '[]' )""" -_INSERT_ITEM = "INSERT INTO items (id, name, description, price, tags) VALUES (?, ?, ?, ?, ?)" -_INSERT_ITEM_AUTO = "INSERT INTO items (name, description, price, tags) VALUES (?, ?, ?, ?)" +_INSERT_ITEM = ( + "INSERT INTO items (id, name, description, price, tags) VALUES (?, ?, ?, ?, ?)" +) +_INSERT_ITEM_AUTO = ( + "INSERT INTO items (name, description, price, tags) VALUES (?, ?, ?, ?)" +) def _row_to_item(row: aiosqlite.Row) -> Item: - return Item(id=row[0], name=row[1], description=row[2], price=row[3], tags=json.loads(row[4])) + return Item( + id=row[0], + name=row[1], + description=row[2], + price=row[3], + tags=json.loads(row[4]), + ) async def _seed_defaults(db: aiosqlite.Connection) -> None: @@ -100,7 +110,9 @@ async def health() -> dict[str, str]: @router.get("/items", response_model=list[Item]) async def list_items(db: aiosqlite.Connection = Depends(_get_db)) -> list[Item]: - cursor = await db.execute("SELECT id, name, description, price, tags FROM items ORDER BY id") + cursor = await db.execute( + "SELECT id, name, description, price, tags FROM items ORDER BY id" + ) rows = await cursor.fetchall() return [_row_to_item(row) for row in rows] @@ -117,7 +129,9 @@ async def get_item(item_id: int, db: aiosqlite.Connection = Depends(_get_db)) -> @router.post("/items", response_model=Item, status_code=201) -async def create_item(body: ItemCreate, db: aiosqlite.Connection = Depends(_get_db)) -> Item: +async def create_item( + body: ItemCreate, db: aiosqlite.Connection = Depends(_get_db) +) -> Item: cursor = await db.execute( _INSERT_ITEM_AUTO, (body.name, body.description, body.price, json.dumps(body.tags)), @@ -141,7 +155,13 @@ async def update_item( updated = existing.model_copy(update=body.model_dump(exclude_unset=True)) await db.execute( "UPDATE items SET name = ?, description = ?, price = ?, tags = ? WHERE id = ?", - (updated.name, updated.description, updated.price, json.dumps(updated.tags), item_id), + ( + updated.name, + updated.description, + updated.price, + json.dumps(updated.tags), + item_id, + ), ) await db.commit() return updated @@ -402,9 +422,7 @@ async def telemetry_cross_signal(): ) histogram.observe(42.0, attributes={"scenario": "cross_signal"}) - logging.getLogger("test.cross_signal").warning( - "cross signal stdlib warning" - ) + logging.getLogger("test.cross_signal").warning("cross signal stdlib warning") return {"ok": True} diff --git a/src/apx/_continuation.py b/src/apx/_continuation.py new file mode 100644 index 00000000..16b30be1 --- /dev/null +++ b/src/apx/_continuation.py @@ -0,0 +1,103 @@ +"""Callback-based continuation for suspended coroutines. + +When ``drive_inline`` returns ``Suspended``, the coroutine has yielded +an asyncio Future (real I/O like a database query). ``Continuation`` +attaches a done-callback to that Future and resumes driving when the +I/O completes — entirely on the asyncio thread, no ``create_task``. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable, Coroutine +from typing import Any + +from apx._scheduler import ( + CallSoonCapture, + Completed, + Failed, + SchedulerTask, + Suspended, + _enter_task, + _leave_task, + drive_inline, +) + + +class Continuation: + """Drives a suspended coroutine via done-callbacks. + + Each step uses per-step ``_enter_task`` / ``_leave_task`` brackets, + keeping invariant I1. Runs entirely on the asyncio thread. + """ + + __slots__ = ("_coro", "_loop", "_task", "_capture", "_on_complete") + + def __init__( + self, + coro: Coroutine[Any, Any, Any], + yielded: object, + loop: asyncio.AbstractEventLoop, + task: SchedulerTask, + capture: CallSoonCapture, + on_complete: Callable[[], None] | None = None, + ) -> None: + self._coro: Coroutine[Any, Any, Any] | None = coro + self._loop = loop + self._task = task + self._capture = capture + self._on_complete = on_complete + self._attach(yielded) + + def _attach(self, yielded: object) -> None: + """Attach to a yielded value to resume when ready.""" + if yielded is None: + self._loop.call_soon(self._step) + elif hasattr(yielded, "add_done_callback"): + self._task._waiter = yielded # type: ignore[assignment, ty:invalid-assignment] + yielded.add_done_callback(self._on_future_done) # type: ignore[union-attr, ty:call-non-callable] + else: + self._finish() + + def _on_future_done(self, future: asyncio.Future[Any]) -> None: + self._task._waiter = None + if asyncio.current_task() is not None: + self._loop.call_soon(self._step) + return + self._step() + + def _step(self) -> None: + if self._coro is None: + return + + if self._task._cancel_flag: + self._task._cancel_flag = False + _enter_task(self._loop, self._task) + try: + yielded = self._coro.throw( + asyncio.CancelledError(self._task._cancel_msg) + ) + except (StopIteration, BaseException): + _leave_task(self._loop, self._task) + self._finish() + return + _leave_task(self._loop, self._task) + self._attach(yielded) + return + + self._capture.enter() + result = drive_inline(self._coro, self._task, self._loop, self._capture) + self._capture.leave() + + if isinstance(result, Completed): + self._finish() + elif isinstance(result, Failed): + self._finish() + elif isinstance(result, Suspended): + self._attach(result.yielded) + + def _finish(self) -> None: + self._coro = None + self._task._waiter = None + if self._on_complete is not None: + self._on_complete() diff --git a/src/apx/_core.pyi b/src/apx/_core.pyi index a627ffd7..d79ba8a0 100644 --- a/src/apx/_core.pyi +++ b/src/apx/_core.pyi @@ -65,7 +65,9 @@ class StatusCode(enum.IntEnum): class SpanHandle: """OTEL span usable as sync/async context manager.""" - def __init__(self, name: str, attributes: dict[str, str] | None = None, kind: int = 1) -> None: ... + def __init__( + self, name: str, attributes: dict[str, str] | None = None, kind: int = 1 + ) -> None: ... def __enter__(self) -> SpanHandle: ... def __exit__( self, @@ -95,7 +97,9 @@ class RustCounter: class RustHistogram: """OTLP histogram backed by Rust.""" - def observe(self, value: float, attributes: dict[str, str] | None = None) -> None: ... + def observe( + self, value: float, attributes: dict[str, str] | None = None + ) -> None: ... class RustGauge: """OTLP gauge backed by Rust.""" @@ -106,7 +110,9 @@ def create_histogram( name: str, description: str = "", unit: str = "" ) -> RustHistogram: ... def create_gauge(name: str, description: str = "", unit: str = "") -> RustGauge: ... -def _emit_log(level: int, message: str, logger_name: str, event_name: str = "") -> None: ... +def _emit_log( + level: int, message: str, logger_name: str, event_name: str = "" +) -> None: ... class PyMetricDefinition: """A framework metric definition.""" diff --git a/src/apx/_dispatch.py b/src/apx/_dispatch.py index ed2b8a45..455080ed 100644 --- a/src/apx/_dispatch.py +++ b/src/apx/_dispatch.py @@ -1,7 +1,10 @@ -"""Zero-GIL dispatch loop for the 3-thread architecture. +"""Zero-GIL dispatch loop with inline coroutine driving. -Installs an fd-based wakeup on the asyncio event loop (Unix) or -exposes a drain callback for ``call_soon_threadsafe`` (Windows). +Installs an fd-based wakeup on the asyncio event loop. Drains +requests from the Rust crossbeam channel and drives each ASGI +coroutine inline. Simple handlers complete in ~21us with zero +event loop scheduling. Handlers that suspend on real I/O fall +back to callback-based continuation. Called once from Rust during reactor init via ``py.import(c"apx._dispatch")?.call_method1(c"install_dispatch", ...)``. @@ -15,7 +18,16 @@ from collections.abc import Coroutine from typing import Any, Callable +from apx._continuation import Continuation from apx._core import RequestQueue +from apx._scheduler import ( + CallSoonCapture, + Completed, + Failed, + SchedulerTask, + Suspended, + drive_inline, +) def install_dispatch( @@ -24,17 +36,10 @@ def install_dispatch( app: Callable[..., Coroutine[Any, Any, None]], wakeup_fd: int | None = None, ) -> None: - """Install the zero-GIL dispatch reader on the asyncio event loop. + """Install the inline dispatch driver on the asyncio event loop.""" - On Unix: registers ``wakeup_fd`` with the loop's selector via ``add_reader``. - On Windows: ``wakeup_fd`` is ``None`` — Rust uses ``call_soon_threadsafe`` - which appends ``_drain_queue`` directly to ``_ready`` (no fd needed). - """ - - # At ~31µs per materialize(), 32 items ≈ 1ms GIL hold — well under - # the 5ms GIL switch interval (sys.getswitchinterval()), keeping the - # drain responsive without excessive re-scheduling overhead. - max_drain_batch: int = 32 + max_drain_batch: int = 8 + capture = CallSoonCapture(loop) async def _guarded( scope: dict[str, Any], @@ -44,10 +49,38 @@ async def _guarded( try: await app(scope, receive, send) except Exception as exc: + tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + send.send_error(tb) + + def _dispatch_inline( + scope: dict[str, Any], + receive: Any, + send: Any, + ) -> None: + """Drive one request inline. Falls back on suspension.""" + coro = _guarded(scope, receive, send) + try: + task = SchedulerTask(loop=loop) + + capture.enter() + result = drive_inline(coro, task, loop, capture) + capture.leave() + except BaseException: + coro.close() + raise + + if isinstance(result, Completed): + return + elif isinstance(result, Failed): tb = "".join( - traceback.format_exception(type(exc), exc, exc.__traceback__) + traceback.format_exception( + type(result.exc), result.exc, result.exc.__traceback__ + ) ) send.send_error(tb) + return + elif isinstance(result, Suspended): + Continuation(coro, result.yielded, loop, task, capture) def _drain_queue() -> None: for _ in range(max_drain_batch): @@ -55,12 +88,7 @@ def _drain_queue() -> None: if result is None: return scope, receive, send = result - loop.create_task(_guarded(scope, receive, send)) - # Batch full — more items may remain. Yield to the event loop - # so _run_once can process I/O, fire done callbacks, and give - # thread pool workers a GIL window before we drain more. - # call_soon (not threadsafe): we're already on the asyncio - # thread, no selector wake needed. + _dispatch_inline(scope, receive, send) loop.call_soon(_drain_queue) if wakeup_fd is not None: @@ -72,7 +100,6 @@ def _on_readable() -> None: pass _drain_queue() - # add_reader is not thread-safe; schedule it onto the asyncio thread. loop.call_soon_threadsafe(loop.add_reader, wakeup_fd, _on_readable) else: - install_dispatch._drain_queue = _drain_queue # type: ignore[attr-defined] + install_dispatch._drain_queue = _drain_queue # type: ignore[attr-defined, ty:unresolved-attribute] diff --git a/src/apx/_scheduler.py b/src/apx/_scheduler.py new file mode 100644 index 00000000..10d038e1 --- /dev/null +++ b/src/apx/_scheduler.py @@ -0,0 +1,223 @@ +"""Inline coroutine driver for the 3-thread dispatch architecture. + +Drives ASGI coroutines to completion within a single ``_run_once`` +callback, eliminating ``create_task`` scheduling overhead. Falls back +to callback-based continuation for coroutines that suspend on real I/O. + +Safety: all driving happens on the asyncio thread during callback +processing (``current_task() is None``). Per-step ``_enter_task`` / +``_leave_task`` brackets maintain invariant I1. See +``.plans/framework/io/pythonic-inlining.md`` for the full analysis. +""" + +from __future__ import annotations + +import asyncio +import contextvars +import time +from collections import deque +from collections.abc import Callable, Coroutine +from typing import Any + +# ── Constants ──────────────────────────────────────────────────────── + +STEP_BUDGET: int = 256 +"""Maximum coroutine steps before falling back to continuation.""" + +TIME_BUDGET_S: float = 0.005 +"""Maximum wall-clock seconds for inline driving (5ms).""" + +FLUSH_BUDGET: int = 64 +"""Maximum captured callbacks to process between drive steps.""" + +# ── asyncio internals (stable since 3.7) ──────────────────────────── + +_enter_task = asyncio.tasks._enter_task # type: ignore[attr-defined] +_leave_task = asyncio.tasks._leave_task # type: ignore[attr-defined] + + +# ── Scheduler task ─────────────────────────────────────────────────── + + +async def _park_forever() -> None: + """Sentinel coroutine that parks on an unresolved Future.""" + await asyncio.get_event_loop().create_future() + + +class SchedulerTask(asyncio.Task): + """Placeholder task for ``_enter_task`` / ``_leave_task`` bracketing. + + Parks on an unresolved Future (``done() == False``). Provides + cancel forwarding so ``asyncio.timeout`` and ``anyio.fail_after`` + can signal the inline driver or continuation. + """ + + __slots__ = ( + "_cancel_flag", + "_cancel_msg", + "_cancel_count", + "_waiter", + "_drive_context", + ) + + def __init__(self, *, loop: asyncio.AbstractEventLoop) -> None: + # Capture context BEFORE super().__init__ which may alter it. + # We store it explicitly because CPython's C-implemented Task + # does not expose ``_context`` as a Python-accessible attribute. + self._drive_context: contextvars.Context = contextvars.copy_context() + super().__init__(_park_forever(), loop=loop) + self._log_destroy_pending: bool = False + self._cancel_flag: bool = False + self._cancel_msg: str | None = None + self._cancel_count: int = 0 + self._waiter: asyncio.Future[Any] | None = None + + def cancel(self, msg: str | None = None) -> bool: + self._cancel_flag = True + self._cancel_msg = msg + self._cancel_count += 1 + if self._waiter is not None and not self._waiter.done(): + self._waiter.cancel(msg=msg) + return True + + def cancelling(self) -> int: + return self._cancel_count + + def uncancel(self) -> int: + self._cancel_count = max(0, self._cancel_count - 1) + return self._cancel_count + + +# ── call_soon capture ──────────────────────────────────────────────── + + +class CallSoonCapture: + """Intercepts ``loop.call_soon`` during inline driving. + + While active, callbacks are captured into an internal queue instead + of being appended to the event loop's ``_ready`` deque. This + prevents the sentinel ``__step`` from ``SchedulerTask.__init__`` + (invariant I7) from polluting ``_run_once``. + + Captured callbacks are processed between drive steps via + ``flush()`` or spilled back to the real ``call_soon`` on ``leave()``. + """ + + __slots__ = ("_original", "_queue", "_active") + + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + self._original: Callable[..., Any] = loop.call_soon + self._queue: deque[tuple[Callable[..., Any], tuple[Any, ...]]] = deque() + self._active: bool = False + loop.call_soon = self._intercept # type: ignore[assignment, ty:invalid-assignment] + + def _intercept( + self, + callback: Callable[..., Any], + *args: Any, + context: Any = None, + ) -> None: + if self._active: + self._queue.append((callback, args)) + else: + self._original(callback, *args) + + def enter(self) -> None: + """Start capturing ``call_soon`` callbacks.""" + self._active = True + self._queue.clear() + + def leave(self) -> None: + """Stop capturing and spill remaining callbacks to the real loop.""" + self._active = False + original = self._original + while self._queue: + cb, args = self._queue.popleft() + original(cb, *args) + + def flush(self, budget: int = FLUSH_BUDGET) -> None: + """Process captured callbacks inline (between drive steps).""" + queue = self._queue + while queue and budget > 0: + cb, args = queue.popleft() + cb(*args) + budget -= 1 + + +# ── Inline driver ──────────────────────────────────────────────────── + + +class _DriveResult: + """Base class for inline drive outcomes.""" + + +class Completed(_DriveResult): + """Coroutine ran to completion.""" + + +class Suspended(_DriveResult): + """Coroutine yielded an asyncio Future (real I/O).""" + + __slots__ = ("yielded",) + + def __init__(self, yielded: object) -> None: + self.yielded = yielded + + +class Failed(_DriveResult): + """Coroutine raised an exception.""" + + __slots__ = ("exc",) + + def __init__(self, exc: BaseException) -> None: + self.exc = exc + + +_COMPLETED = Completed() + + +def drive_inline( + coro: Coroutine[Any, Any, Any], + task: SchedulerTask, + loop: asyncio.AbstractEventLoop, + capture: CallSoonCapture, +) -> _DriveResult: + """Drive a coroutine to completion or first real suspension. + + Must be called from a ``_run_once`` callback where + ``current_task() is None``. Uses per-step ``_enter_task`` / + ``_leave_task`` brackets (invariant I1). + + Returns: + Completed — coroutine finished; response already fired. + Suspended — coroutine yielded an asyncio Future; needs + callback-based continuation. + Failed — coroutine raised; caller should log / error. + """ + budget = STEP_BUDGET + deadline = time.monotonic() + TIME_BUDGET_S + context_run = task._drive_context.run + + while True: + _enter_task(loop, task) + try: + result = context_run(coro.send, None) + except StopIteration: + _leave_task(loop, task) + return _COMPLETED + except BaseException as exc: + _leave_task(loop, task) + return Failed(exc) + _leave_task(loop, task) + + capture.flush() + + if result is None: + budget -= 1 + if budget <= 0 or time.monotonic() > deadline: + return Suspended(None) + continue + + if getattr(result, "_asyncio_future_blocking", False): + result._asyncio_future_blocking = False + return Suspended(result) diff --git a/src/apx/telemetry.py b/src/apx/telemetry.py index e498d3f7..f5131c4b 100644 --- a/src/apx/telemetry.py +++ b/src/apx/telemetry.py @@ -383,7 +383,9 @@ class _Log: __slots__ = () @staticmethod - def trace(message: str, *, event_name: str | None = None, **attributes: Any) -> None: + def trace( + message: str, *, event_name: str | None = None, **attributes: Any + ) -> None: """Emit a TRACE-level log. Example:: @@ -393,7 +395,9 @@ def trace(message: str, *, event_name: str | None = None, **attributes: Any) -> _emit_log_span("trace", message, event_name=event_name, **attributes) @staticmethod - def debug(message: str, *, event_name: str | None = None, **attributes: Any) -> None: + def debug( + message: str, *, event_name: str | None = None, **attributes: Any + ) -> None: """Emit a DEBUG-level log. Example:: @@ -414,7 +418,9 @@ def info(message: str, *, event_name: str | None = None, **attributes: Any) -> N _emit_log_span("info", message, event_name=event_name, **attributes) @staticmethod - def notice(message: str, *, event_name: str | None = None, **attributes: Any) -> None: + def notice( + message: str, *, event_name: str | None = None, **attributes: Any + ) -> None: """Emit a NOTICE-level log. Example:: @@ -434,7 +440,9 @@ def warn(message: str, *, event_name: str | None = None, **attributes: Any) -> N _emit_log_span("warn", message, event_name=event_name, **attributes) @staticmethod - def error(message: str, *, event_name: str | None = None, **attributes: Any) -> None: + def error( + message: str, *, event_name: str | None = None, **attributes: Any + ) -> None: """Emit an ERROR-level log. Example:: @@ -444,7 +452,9 @@ def error(message: str, *, event_name: str | None = None, **attributes: Any) -> _emit_log_span("error", message, event_name=event_name, **attributes) @staticmethod - def fatal(message: str, *, event_name: str | None = None, **attributes: Any) -> None: + def fatal( + message: str, *, event_name: str | None = None, **attributes: Any + ) -> None: """Emit a FATAL-level log. Example:: @@ -454,7 +464,9 @@ def fatal(message: str, *, event_name: str | None = None, **attributes: Any) -> _emit_log_span("fatal", message, event_name=event_name, **attributes) @staticmethod - def exception(message: str, *, event_name: str | None = None, **attributes: Any) -> None: + def exception( + message: str, *, event_name: str | None = None, **attributes: Any + ) -> None: """Emit an ERROR-level log with the current exception attached. Must be called from an ``except`` block. Automatically captures @@ -545,7 +557,9 @@ def __init__( name, description, str(unit) ) - def observe(self, value: float, *, attributes: dict[str, str] | None = None) -> None: + def observe( + self, value: float, *, attributes: dict[str, str] | None = None + ) -> None: """Record an observation. Example:: diff --git a/tests/integration/test_telemetry.py b/tests/integration/test_telemetry.py index 840626f8..2f1db6cb 100644 --- a/tests/integration/test_telemetry.py +++ b/tests/integration/test_telemetry.py @@ -1175,8 +1175,7 @@ def test_log_span_is_zero_duration(self, otel_collector: OtelCollector) -> None: end = int(s.endTimeUnixNano) delta_us = (end - start) / 1_000 assert delta_us < 1_000, ( - f"log span should be near-zero-duration; " - f"delta={delta_us:.1f}µs" + f"log span should be near-zero-duration; delta={delta_us:.1f}µs" ) return pytest.fail("log span 'integration test log message' not found") diff --git a/tests/telemetry/test_cross_signal.py b/tests/telemetry/test_cross_signal.py index 8200d8ad..4b4c376c 100644 --- a/tests/telemetry/test_cross_signal.py +++ b/tests/telemetry/test_cross_signal.py @@ -34,14 +34,14 @@ class TestCrossSignal: """Verify instrumentation scope names and cross-signal correlation.""" _setup = make_setup_fixture( - "/api/telemetry/cross-signal", sleep_time=5, require_logs=True, + "/api/telemetry/cross-signal", + sleep_time=5, + require_logs=True, ) # ── Instrumentation scope: user spans ───────────────────────────────── - def test_user_span_scope_is_apx_user( - self, otel_collector: OtelCollector - ) -> None: + def test_user_span_scope_is_apx_user(self, otel_collector: OtelCollector) -> None: """User spans (via apx.telemetry.span) should have scope 'apx.user'.""" for scope, span in flat_spans_with_scope(otel_collector): if span.name == "test.cross_signal_span": @@ -149,10 +149,7 @@ def test_stdlib_log_not_duplicated_as_span( ) -> None: """Stdlib logs should only appear as log records, not as spans.""" spans = flat_spans(otel_collector) - matching = [ - s for s in spans - if "cross signal stdlib warning" in s.name - ] + matching = [s for s in spans if "cross signal stdlib warning" in s.name] assert not matching, ( f"stdlib log should not produce a span; found {len(matching)} span(s)" ) @@ -169,9 +166,7 @@ def test_user_counter_scope_is_apx_user( ) return all_names = sorted({m.name for _, m in flat_metrics_with_scope(otel_collector)}) - pytest.fail( - f"test.cross_signal_counter not found; available: {all_names}" - ) + pytest.fail(f"test.cross_signal_counter not found; available: {all_names}") def test_user_histogram_scope_is_apx_user( self, otel_collector: OtelCollector @@ -219,8 +214,7 @@ def test_log_info_span_is_zero_duration( end = int(span.endTimeUnixNano) delta_us = (end - start) / 1_000 assert delta_us < 1_000, ( - f"log span should be near-zero-duration; " - f"delta={delta_us:.1f}µs" + f"log span should be near-zero-duration; delta={delta_us:.1f}µs" ) return pytest.fail("'cross signal info log' span not found") diff --git a/tests/telemetry/test_dispatch_metrics.py b/tests/telemetry/test_dispatch_metrics.py index 7a13fe0a..cf2d7842 100644 --- a/tests/telemetry/test_dispatch_metrics.py +++ b/tests/telemetry/test_dispatch_metrics.py @@ -48,13 +48,9 @@ def _setup( time.sleep(5) wait_for_collector_data(otel_collector) - def test_all_dispatch_metrics_present( - self, otel_collector: OtelCollector - ) -> None: + def test_all_dispatch_metrics_present(self, otel_collector: OtelCollector) -> None: """Every APX dispatch histogram must have at least one data point.""" - collected_names = { - m.name for _, m in flat_metrics_with_scope(otel_collector) - } + collected_names = {m.name for _, m in flat_metrics_with_scope(otel_collector)} missing = APX_DISPATCH_METRICS - collected_names assert not missing, ( f"Missing APX dispatch metrics: {sorted(missing)}. " @@ -78,9 +74,7 @@ def test_dispatch_metrics_unit_is_microseconds( duration_metrics = {n for n in APX_DISPATCH_METRICS if n.endswith(".duration")} for _, m in flat_metrics_with_scope(otel_collector): if m.name in duration_metrics: - assert m.unit == "us", ( - f"{m.name} unit should be 'us', got {m.unit!r}" - ) + assert m.unit == "us", f"{m.name} unit should be 'us', got {m.unit!r}" def test_queue_depth_unit_is_dimensionless( self, otel_collector: OtelCollector @@ -88,6 +82,4 @@ def test_queue_depth_unit_is_dimensionless( """queue_depth is a count, not a duration — unit must be '1'.""" for _, m in flat_metrics_with_scope(otel_collector): if m.name == "apx.dispatch.queue_depth": - assert m.unit == "1", ( - f"{m.name} unit should be '1', got {m.unit!r}" - ) + assert m.unit == "1", f"{m.name} unit should be '1', got {m.unit!r}" diff --git a/tests/telemetry/test_error_handling.py b/tests/telemetry/test_error_handling.py index a966ce44..9cf54452 100644 --- a/tests/telemetry/test_error_handling.py +++ b/tests/telemetry/test_error_handling.py @@ -39,9 +39,7 @@ def test_erroring_span_has_error_status( f"expected status Error (2); got {span.status.code}" ) - def test_erroring_span_status_message( - self, otel_collector: OtelCollector - ) -> None: + def test_erroring_span_status_message(self, otel_collector: OtelCollector) -> None: span = find_span(otel_collector, "test.erroring_span") assert "deliberate test error" in span.status.message, ( f"expected 'deliberate test error' in status message; " @@ -58,9 +56,7 @@ def test_erroring_span_has_exception_event( f"got events: {[e.name for e in span.events]}" ) - def test_erroring_span_exception_type( - self, otel_collector: OtelCollector - ) -> None: + def test_erroring_span_exception_type(self, otel_collector: OtelCollector) -> None: span = find_span(otel_collector, "test.erroring_span") exc_event = next(e for e in span.events if e.name == "exception") attrs = {a.key: (a.value.stringValue or "") for a in exc_event.attributes} @@ -91,9 +87,7 @@ def test_erroring_span_exception_stacktrace( # ── (b) log.exception() captures exception info ─────────────────────── - def test_log_exception_span_exists( - self, otel_collector: OtelCollector - ) -> None: + def test_log_exception_span_exists(self, otel_collector: OtelCollector) -> None: span = find_span(otel_collector, "caught runtime error") attrs = span_attrs(span) assert attrs.get("log.level") == "error" @@ -117,9 +111,7 @@ def test_log_exception_has_exception_message( f"got {attrs.get('exception.message')!r}" ) - def test_log_exception_has_stacktrace( - self, otel_collector: OtelCollector - ) -> None: + def test_log_exception_has_stacktrace(self, otel_collector: OtelCollector) -> None: span = find_span(otel_collector, "caught runtime error") attrs = span_attrs(span) assert attrs.get("exception.stacktrace"), ( @@ -153,9 +145,7 @@ def test_explicit_error_span_status_message( # ── (d) Clean span — no error ───────────────────────────────────────── - def test_clean_span_has_unset_status( - self, otel_collector: OtelCollector - ) -> None: + def test_clean_span_has_unset_status(self, otel_collector: OtelCollector) -> None: span = find_span(otel_collector, "test.clean_span") assert span.status.code == OTEL_STATUS_UNSET, ( f"clean span should have status Unset (0); got {span.status.code}" diff --git a/tests/telemetry/test_event_name.py b/tests/telemetry/test_event_name.py index 004cd69d..f77a45f1 100644 --- a/tests/telemetry/test_event_name.py +++ b/tests/telemetry/test_event_name.py @@ -28,14 +28,14 @@ class TestEventName: """Verify event_name propagation through spans and OTLP log records.""" _setup = make_setup_fixture( - "/api/telemetry/event-name", sleep_time=5, require_logs=True, + "/api/telemetry/event-name", + sleep_time=5, + require_logs=True, ) # ── log.info / log.warn produce spans with event.name attribute ── - def test_log_info_span_has_event_name( - self, otel_collector: OtelCollector - ) -> None: + def test_log_info_span_has_event_name(self, otel_collector: OtelCollector) -> None: """log.info(event_name='user.login') produces span with event.name attr.""" for s in flat_spans(otel_collector): if s.name == "user logged in": @@ -46,9 +46,7 @@ def test_log_info_span_has_event_name( return pytest.fail("span 'user logged in' not found") - def test_log_warn_span_has_event_name( - self, otel_collector: OtelCollector - ) -> None: + def test_log_warn_span_has_event_name(self, otel_collector: OtelCollector) -> None: """log.warn(event_name='rate_limit.warning') produces span with event.name attr.""" for s in flat_spans(otel_collector): if s.name == "rate limit near": @@ -108,10 +106,7 @@ def test_rust_event_names_follow_convention( name = lr.eventName if not name.startswith("apx."): continue - assert name == name.lower(), ( - f"eventName should be lowercase: {name!r}" + assert name == name.lower(), f"eventName should be lowercase: {name!r}" + assert all(part.replace("_", "").isalnum() for part in name.split(".")), ( + f"eventName has invalid segment: {name!r}" ) - assert all( - part.replace("_", "").isalnum() - for part in name.split(".") - ), f"eventName has invalid segment: {name!r}" diff --git a/tests/telemetry/test_otlp_fields.py b/tests/telemetry/test_otlp_fields.py index d3b65c3c..aa68bd70 100644 --- a/tests/telemetry/test_otlp_fields.py +++ b/tests/telemetry/test_otlp_fields.py @@ -37,14 +37,14 @@ class TestOtlpFields: """Verify all OTLP proto fields are properly populated in exported telemetry.""" _setup = make_setup_fixture( - "/api/telemetry/otlp-fields", sleep_time=5, require_logs=True, + "/api/telemetry/otlp-fields", + sleep_time=5, + require_logs=True, ) # ── Resource attributes ─────────────────────────────────────────────── - def test_resource_has_service_name( - self, otel_collector: OtelCollector - ) -> None: + def test_resource_has_service_name(self, otel_collector: OtelCollector) -> None: """resource.attributes must include service.name.""" for rs in flat_resource_spans(otel_collector): attr_keys = {a.key for a in rs.resource.attributes} @@ -52,9 +52,7 @@ def test_resource_has_service_name( return pytest.fail("service.name not found in resource.attributes") - def test_resource_has_apx_attributes( - self, otel_collector: OtelCollector - ) -> None: + def test_resource_has_apx_attributes(self, otel_collector: OtelCollector) -> None: """resource.attributes must include apx.process.type and apx.worker.id.""" for rs in flat_resource_spans(otel_collector): attr_keys = {a.key for a in rs.resource.attributes} @@ -64,18 +62,14 @@ def test_resource_has_apx_attributes( # ── Resource schema_url ─────────────────────────────────────────────── - def test_resource_schema_url_on_spans( - self, otel_collector: OtelCollector - ) -> None: + def test_resource_schema_url_on_spans(self, otel_collector: OtelCollector) -> None: for rs in flat_resource_spans(otel_collector): if rs.schemaUrl: assert "opentelemetry.io/schemas" in rs.schemaUrl return pytest.fail("ResourceSpans.schemaUrl not populated") - def test_resource_schema_url_on_logs( - self, otel_collector: OtelCollector - ) -> None: + def test_resource_schema_url_on_logs(self, otel_collector: OtelCollector) -> None: for rl in flat_resource_logs(otel_collector): if rl.schemaUrl: assert "opentelemetry.io/schemas" in rl.schemaUrl @@ -93,9 +87,7 @@ def test_resource_schema_url_on_metrics( # ── InstrumentationScope version ────────────────────────────────────── - def test_span_scope_has_version( - self, otel_collector: OtelCollector - ) -> None: + def test_span_scope_has_version(self, otel_collector: OtelCollector) -> None: for scope, s in flat_spans_with_scope(otel_collector): if s.name == "test.client_call" and scope.version: return @@ -106,17 +98,13 @@ def test_span_scope_has_version( "from target, dropping version (opentelemetry-proto transform)", strict=False, ) - def test_log_scope_has_version( - self, otel_collector: OtelCollector - ) -> None: + def test_log_scope_has_version(self, otel_collector: OtelCollector) -> None: for scope, lr in flat_log_records(otel_collector): if scope.name == "apx.python" and scope.version: return pytest.fail("Log scope version not populated for apx.python logger") - def test_metric_scope_has_version( - self, otel_collector: OtelCollector - ) -> None: + def test_metric_scope_has_version(self, otel_collector: OtelCollector) -> None: for scope, m in flat_metrics_with_scope(otel_collector): if m.name.startswith("test.otlp_fields") and scope.version: return @@ -124,9 +112,7 @@ def test_metric_scope_has_version( # ── InstrumentationScope schema_url ─────────────────────────────────── - def test_span_scope_has_schema_url( - self, otel_collector: OtelCollector - ) -> None: + def test_span_scope_has_schema_url(self, otel_collector: OtelCollector) -> None: for rs in flat_resource_spans(otel_collector): for ss in rs.scopeSpans: if ss.scope.name == "apx.user" and ss.schemaUrl: @@ -134,9 +120,7 @@ def test_span_scope_has_schema_url( return pytest.fail("ScopeSpans.schemaUrl not populated for apx.user scope") - def test_log_scope_has_schema_url( - self, otel_collector: OtelCollector - ) -> None: + def test_log_scope_has_schema_url(self, otel_collector: OtelCollector) -> None: for rl in flat_resource_logs(otel_collector): for sl in rl.scopeLogs: if sl.scope.name == "apx.python" and sl.schemaUrl: @@ -144,9 +128,7 @@ def test_log_scope_has_schema_url( return pytest.fail("ScopeLogs.schemaUrl not populated for apx.python scope") - def test_metric_scope_has_schema_url( - self, otel_collector: OtelCollector - ) -> None: + def test_metric_scope_has_schema_url(self, otel_collector: OtelCollector) -> None: for rm in flat_resource_metrics(otel_collector): for sm in rm.scopeMetrics: if sm.scope.name == "apx.user" and sm.schemaUrl: @@ -156,9 +138,7 @@ def test_metric_scope_has_schema_url( # ── Metric start_time_unix_nano ─────────────────────────────────────── - def test_counter_has_start_time( - self, otel_collector: OtelCollector - ) -> None: + def test_counter_has_start_time(self, otel_collector: OtelCollector) -> None: """Counter (Sum) should have non-empty start_time_unix_nano.""" for _, m in flat_metrics_with_scope(otel_collector): if m.name == "test.otlp_fields_counter" and m.sum: @@ -167,9 +147,7 @@ def test_counter_has_start_time( return pytest.fail("Counter start_time_unix_nano not populated") - def test_histogram_has_start_time( - self, otel_collector: OtelCollector - ) -> None: + def test_histogram_has_start_time(self, otel_collector: OtelCollector) -> None: """Histogram should have non-empty start_time_unix_nano.""" for _, m in flat_metrics_with_scope(otel_collector): if m.name == "test.otlp_fields_histogram" and m.histogram: @@ -178,9 +156,7 @@ def test_histogram_has_start_time( return pytest.fail("Histogram start_time_unix_nano not populated") - def test_gauge_start_time( - self, otel_collector: OtelCollector - ) -> None: + def test_gauge_start_time(self, otel_collector: OtelCollector) -> None: """Gauge start_time_unix_nano may or may not be populated depending on SDK version.""" for _, m in flat_metrics_with_scope(otel_collector): if m.name == "test.otlp_fields_gauge" and m.gauge: @@ -189,9 +165,7 @@ def test_gauge_start_time( # ── Log observed_time_unix_nano ─────────────────────────────────────── - def test_log_has_observed_timestamp( - self, otel_collector: OtelCollector - ) -> None: + def test_log_has_observed_timestamp(self, otel_collector: OtelCollector) -> None: """Log records should have non-empty observedTimeUnixNano.""" for _, lr in flat_log_records(otel_collector): body = lr.body.stringValue or "" @@ -204,9 +178,7 @@ def test_log_has_observed_timestamp( # ── Log flags ───────────────────────────────────────────────────────── - def test_log_flags_has_trace_flags( - self, otel_collector: OtelCollector - ) -> None: + def test_log_flags_has_trace_flags(self, otel_collector: OtelCollector) -> None: """When trace context is present, log flags bits 0-7 should carry trace flags.""" for _, lr in flat_log_records(otel_collector): body = lr.body.stringValue or "" @@ -218,9 +190,7 @@ def test_log_flags_has_trace_flags( # ── Span kind ───────────────────────────────────────────────────────── - def test_span_kind_client( - self, otel_collector: OtelCollector - ) -> None: + def test_span_kind_client(self, otel_collector: OtelCollector) -> None: """test.client_call should have kind=3 (CLIENT).""" for s in flat_spans(otel_collector): if s.name == "test.client_call": @@ -228,9 +198,7 @@ def test_span_kind_client( return pytest.fail("test.client_call span not found") - def test_span_kind_internal_default( - self, otel_collector: OtelCollector - ) -> None: + def test_span_kind_internal_default(self, otel_collector: OtelCollector) -> None: """test.internal_work should have kind=1 (INTERNAL).""" for s in flat_spans(otel_collector): if s.name == "test.internal_work": @@ -240,9 +208,7 @@ def test_span_kind_internal_default( # ── Span flags ──────────────────────────────────────────────────────── - def test_span_flags_nonzero( - self, otel_collector: OtelCollector - ) -> None: + def test_span_flags_nonzero(self, otel_collector: OtelCollector) -> None: """Sampled spans should have flags with bit 0 set (SAMPLED).""" for s in flat_spans(otel_collector): if s.name == "test.client_call": @@ -253,9 +219,7 @@ def test_span_flags_nonzero( # ── Span events ─────────────────────────────────────────────────────── - def test_span_events_array_populated( - self, otel_collector: OtelCollector - ) -> None: + def test_span_events_array_populated(self, otel_collector: OtelCollector) -> None: """test.client_call should have a dns.resolved event with attributes.""" for s in flat_spans(otel_collector): if s.name == "test.client_call": @@ -271,9 +235,7 @@ def test_span_events_array_populated( # ── Log event_name (proto field) ────────────────────────────────────── - def test_log_event_name_populated( - self, otel_collector: OtelCollector - ) -> None: + def test_log_event_name_populated(self, otel_collector: OtelCollector) -> None: """Log with event_name should carry it as an attribute. The OTEL Rust SDK ``set_event_name`` requires ``&'static str``, diff --git a/tests/telemetry/test_sequential_spans.py b/tests/telemetry/test_sequential_spans.py index 01ffaddc..063b0b7b 100644 --- a/tests/telemetry/test_sequential_spans.py +++ b/tests/telemetry/test_sequential_spans.py @@ -97,9 +97,7 @@ def test_sibling_a_finishes_before_b_starts( f"a.end={a_end} b.start={b_start}" ) - def test_parent_encloses_both_siblings( - self, otel_collector: OtelCollector - ) -> None: + def test_parent_encloses_both_siblings(self, otel_collector: OtelCollector) -> None: """Parent span should start before and end after both children.""" parent = find_span(otel_collector, "test.parent") a = find_span(otel_collector, "test.sibling_a") @@ -113,6 +111,4 @@ def test_parent_encloses_both_siblings( assert p_start <= a_start, ( f"parent should start before sibling_a: {p_start} > {a_start}" ) - assert p_end >= b_end, ( - f"parent should end after sibling_b: {p_end} < {b_end}" - ) + assert p_end >= b_end, f"parent should end after sibling_b: {p_end} < {b_end}" diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py index 4b05198b..26f0cde8 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -1,42 +1,46 @@ -"""Unit tests for the batch-limited drain. +"""Unit tests for the inline dispatch driver. Tests cover: - Batch drain stops after ``max_drain_batch`` items and re-schedules - Batch drain exhausts small queues without re-scheduling - ``_guarded`` calls ``send.send_error()`` on app exceptions +- Inline driving completes simple coroutines without tasks +- ``CallSoonCapture`` intercepts and flushes callbacks correctly """ from __future__ import annotations import asyncio +import traceback from collections.abc import Coroutine from typing import Any, Callable from unittest.mock import AsyncMock, MagicMock - -def _cancel_all( - loop: asyncio.AbstractEventLoop, tasks: list[asyncio.Task[None]] -) -> None: - """Cancel and await all tasks so the loop can close cleanly.""" - for t in tasks: - t.cancel() - if tasks: - loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) +from apx._continuation import Continuation +from apx._scheduler import ( + CallSoonCapture, + Completed, + Failed, + SchedulerTask, + Suspended, + drive_inline, +) # --------------------------------------------------------------------------- # Helpers — replicate the dispatch wiring without the Rust native module. # --------------------------------------------------------------------------- + def _make_dispatch( queue_items: list[tuple[Any, Any, Any] | None], app: Callable[..., Coroutine[Any, Any, None]] | None = None, max_drain_batch: int = 8, ) -> tuple[asyncio.AbstractEventLoop, Callable[[], None], MagicMock]: - """Build an ``install_dispatch``-style closure with a mock queue. + """Build an inline dispatch drain with a mock queue. Returns ``(loop, _drain_queue, mock_queue)`` so callers can invoke - ``_drain_queue()`` and inspect what was scheduled. + ``_drain_queue()`` and inspect what was dispatched. """ items = list(queue_items) @@ -44,19 +48,7 @@ def _make_dispatch( mock_queue.try_recv.side_effect = lambda: items.pop(0) if items else None loop = asyncio.new_event_loop() - tasks_created: list[asyncio.Task[None]] = [] - - original_create_task = loop.create_task - - def tracking_create_task(coro: Any, **kwargs: Any) -> Any: - task = original_create_task(coro, **kwargs) - tasks_created.append(task) - return task - - loop.create_task = tracking_create_task # type: ignore[assignment] - - if app is None: - app = AsyncMock() + dispatched_count: list[int] = [0] call_soon_calls: list[Any] = [] original_call_soon = loop.call_soon @@ -67,7 +59,10 @@ def tracking_call_soon(cb: Any, *args: Any, **kwargs: Any) -> Any: loop.call_soon = tracking_call_soon # type: ignore[assignment] - import traceback + capture = CallSoonCapture(loop) + + if app is None: + app = AsyncMock() async def _guarded( scope: dict[str, Any], @@ -77,21 +72,44 @@ async def _guarded( try: await app(scope, receive, send) # type: ignore[misc] except Exception as exc: - tb = "".join( - traceback.format_exception(type(exc), exc, exc.__traceback__) - ) + tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) send.send_error(tb) + def _dispatch_inline( + scope: dict[str, Any], + receive: Any, + send: Any, + ) -> None: + coro = _guarded(scope, receive, send) + task = SchedulerTask(loop=loop) + + capture.enter() + result = drive_inline(coro, task, loop, capture) + capture.leave() + + dispatched_count[0] += 1 + if isinstance(result, Completed): + return + elif isinstance(result, Failed): + tb_str = "".join( + traceback.format_exception( + type(result.exc), result.exc, result.exc.__traceback__ + ) + ) + send.send_error(tb_str) + elif isinstance(result, Suspended): + Continuation(coro, result.yielded, loop, task, capture) + def _drain_queue() -> None: for _ in range(max_drain_batch): result: tuple[Any, Any, Any] | None = mock_queue.try_recv() if result is None: return scope, receive, send = result - loop.create_task(_guarded(scope, receive, send)) + _dispatch_inline(scope, receive, send) loop.call_soon(_drain_queue) - _drain_queue._tasks_created = tasks_created # type: ignore[attr-defined] + _drain_queue._dispatched_count = dispatched_count # type: ignore[attr-defined] _drain_queue._call_soon_calls = call_soon_calls # type: ignore[attr-defined] return loop, _drain_queue, mock_queue @@ -107,9 +125,10 @@ def _make_item( # --------------------------------------------------------------------------- -# Tests +# Tests — batch drain limits # --------------------------------------------------------------------------- + class TestDrainBatchLimit: """_drain_queue stops after max_drain_batch and re-schedules.""" @@ -121,13 +140,12 @@ def test_batch_limit_triggers_reschedule(self) -> None: try: drain() - assert len(drain._tasks_created) == 8 # type: ignore[attr-defined] + assert drain._dispatched_count[0] == 8 # type: ignore[attr-defined] assert any( cb is drain for cb, _ in drain._call_soon_calls # type: ignore[attr-defined] ) finally: - _cancel_all(loop, drain._tasks_created) # type: ignore[attr-defined] loop.close() def test_small_queue_no_reschedule(self) -> None: @@ -138,13 +156,12 @@ def test_small_queue_no_reschedule(self) -> None: try: drain() - assert len(drain._tasks_created) == 3 # type: ignore[attr-defined] + assert drain._dispatched_count[0] == 3 # type: ignore[attr-defined] assert not any( cb is drain for cb, _ in drain._call_soon_calls # type: ignore[attr-defined] ) finally: - _cancel_all(loop, drain._tasks_created) # type: ignore[attr-defined] loop.close() def test_empty_queue_noop(self) -> None: @@ -153,10 +170,9 @@ def test_empty_queue_noop(self) -> None: try: drain() - assert len(drain._tasks_created) == 0 # type: ignore[attr-defined] + assert drain._dispatched_count[0] == 0 # type: ignore[attr-defined] mock_q.try_recv.assert_called_once() finally: - _cancel_all(loop, drain._tasks_created) # type: ignore[attr-defined] loop.close() def test_exact_batch_size_triggers_reschedule(self) -> None: @@ -168,35 +184,37 @@ def test_exact_batch_size_triggers_reschedule(self) -> None: try: drain() - assert len(drain._tasks_created) == 8 # type: ignore[attr-defined] + assert drain._dispatched_count[0] == 8 # type: ignore[attr-defined] assert any( cb is drain for cb, _ in drain._call_soon_calls # type: ignore[attr-defined] ) finally: - _cancel_all(loop, drain._tasks_created) # type: ignore[attr-defined] loop.close() +# --------------------------------------------------------------------------- +# Tests — error handling +# --------------------------------------------------------------------------- + + class TestGuarded: """_guarded handles app errors correctly.""" def test_app_error_calls_send_error(self) -> None: - """If the ASGI app raises, ``send.send_error(tb)`` is called.""" + """If the ASGI app raises, ``send.send_error(tb)`` is called + inline — no event loop ticking required.""" mock_send = MagicMock() item = _make_item(send=mock_send) app = AsyncMock(side_effect=ValueError("handler failed")) loop, drain, _ = _make_dispatch([item], app=app) try: drain() - loop.run_until_complete(asyncio.sleep(0)) - loop.run_until_complete(asyncio.sleep(0)) mock_send.send_error.assert_called_once() tb_arg: str = mock_send.send_error.call_args[0][0] assert "handler failed" in tb_arg finally: - _cancel_all(loop, drain._tasks_created) # type: ignore[attr-defined] loop.close() def test_successful_app_call(self) -> None: @@ -209,11 +227,191 @@ def test_successful_app_call(self) -> None: loop, drain, _ = _make_dispatch([item], app=app) try: drain() - loop.run_until_complete(asyncio.sleep(0)) - loop.run_until_complete(asyncio.sleep(0)) app.assert_called_once_with(mock_scope, mock_receive, mock_send) mock_send.send_error.assert_not_called() finally: - _cancel_all(loop, drain._tasks_created) # type: ignore[attr-defined] + loop.close() + + +# --------------------------------------------------------------------------- +# Tests — inline driving +# --------------------------------------------------------------------------- + + +class TestInlineDriving: + """drive_inline completes simple coroutines without creating tasks.""" + + def test_sync_completing_coroutine_returns_completed(self) -> None: + """A coroutine that finishes without yielding returns Completed.""" + loop = asyncio.new_event_loop() + capture = CallSoonCapture(loop) + try: + + async def simple() -> None: + pass + + task = SchedulerTask(loop=loop) + capture.enter() + result = drive_inline(simple(), task, loop, capture) + capture.leave() + + assert isinstance(result, Completed) + finally: + loop.close() + + def test_coroutine_with_return_value_completes(self) -> None: + """A coroutine that returns a value still results in Completed.""" + loop = asyncio.new_event_loop() + capture = CallSoonCapture(loop) + try: + + async def with_return() -> str: + return "hello" + + task = SchedulerTask(loop=loop) + capture.enter() + result = drive_inline(with_return(), task, loop, capture) + capture.leave() + + assert isinstance(result, Completed) + finally: + loop.close() + + def test_exception_returns_failed(self) -> None: + """A coroutine that raises an exception returns Failed.""" + loop = asyncio.new_event_loop() + capture = CallSoonCapture(loop) + try: + + async def raises() -> None: + raise RuntimeError("boom") + + task = SchedulerTask(loop=loop) + capture.enter() + result = drive_inline(raises(), task, loop, capture) + capture.leave() + + assert isinstance(result, Failed) + assert isinstance(result.exc, RuntimeError) + assert "boom" in str(result.exc) + finally: + loop.close() + + def test_real_future_returns_suspended(self) -> None: + """A coroutine awaiting a real asyncio Future returns Suspended.""" + loop = asyncio.new_event_loop() + capture = CallSoonCapture(loop) + try: + fut = loop.create_future() + + async def waits_on_future() -> None: + await fut + + task = SchedulerTask(loop=loop) + capture.enter() + result = drive_inline(waits_on_future(), task, loop, capture) + capture.leave() + + assert isinstance(result, Suspended) + assert result.yielded is not None + finally: + loop.close() + + def test_multiple_items_complete_inline(self) -> None: + """Multiple queue items driven inline complete synchronously.""" + items = [_make_item() for _ in range(5)] + app = AsyncMock() + loop, drain, _ = _make_dispatch(items, app=app, max_drain_batch=8) + try: + drain() + + assert drain._dispatched_count[0] == 5 # type: ignore[attr-defined] + assert app.call_count == 5 + finally: + loop.close() + + +# --------------------------------------------------------------------------- +# Tests — CallSoonCapture +# --------------------------------------------------------------------------- + + +class TestCallSoonCapture: + """CallSoonCapture intercepts and processes callbacks correctly.""" + + def test_active_captures_callbacks(self) -> None: + """When active, call_soon callbacks are captured, not executed.""" + loop = asyncio.new_event_loop() + try: + capture = CallSoonCapture(loop) + called: list[str] = [] + + capture.enter() + loop.call_soon(lambda: called.append("should_be_captured")) + assert len(called) == 0 + + capture.leave() + finally: + loop.close() + + def test_inactive_passes_through(self) -> None: + """When not active, call_soon delegates to the original.""" + loop = asyncio.new_event_loop() + try: + scheduled: list[Any] = [] + original = loop.call_soon + + def tracking(cb: Any, *args: Any, **kw: Any) -> Any: + scheduled.append(cb) + return original(cb, *args, **kw) + + loop.call_soon = tracking # type: ignore[assignment] + capture = CallSoonCapture(loop) + + callback = lambda: None # noqa: E731 + loop.call_soon(callback) + assert callback in scheduled + finally: + loop.close() + + def test_flush_processes_captured(self) -> None: + """flush() runs captured callbacks inline.""" + loop = asyncio.new_event_loop() + try: + capture = CallSoonCapture(loop) + called: list[str] = [] + + capture.enter() + loop.call_soon(lambda: called.append("a")) + loop.call_soon(lambda: called.append("b")) + + capture.flush() + assert called == ["a", "b"] + + capture.leave() + finally: + loop.close() + + def test_leave_spills_remaining(self) -> None: + """leave() schedules remaining callbacks via the real call_soon.""" + loop = asyncio.new_event_loop() + try: + scheduled: list[Any] = [] + original = loop.call_soon + + def tracking(cb: Any, *args: Any, **kw: Any) -> Any: + scheduled.append(cb) + return original(cb, *args, **kw) + + loop.call_soon = tracking # type: ignore[assignment] + capture = CallSoonCapture(loop) + + callback = lambda: None # noqa: E731 + capture.enter() + loop.call_soon(callback) + + capture.leave() + assert callback in scheduled + finally: loop.close() diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index d965ad53..4f3fb5a1 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -964,9 +964,7 @@ def test_default_resource_is_empty(self) -> None: def test_resource_in_config(self) -> None: c = Configuration( - resource=Resource( - attributes=[Attribute(key="env", value="staging")] - ) + resource=Resource(attributes=[Attribute(key="env", value="staging")]) ) assert len(c.resource.attributes) == 1 assert c.resource.attributes[0].key == "env" @@ -974,9 +972,7 @@ def test_resource_in_config(self) -> None: def test_get_config_includes_resource(self) -> None: configure( Configuration( - resource=Resource( - attributes=[Attribute(key="team", value="platform")] - ) + resource=Resource(attributes=[Attribute(key="team", value="platform")]) ) ) from apx.telemetry import _get_config @@ -988,11 +984,7 @@ def test_get_config_includes_resource(self) -> None: configure(Configuration()) def test_get_config_resource_schema_url(self) -> None: - configure( - Configuration( - resource=Resource(schema_url="https://custom.schema") - ) - ) + configure(Configuration(resource=Resource(schema_url="https://custom.schema"))) from apx.telemetry import _get_config config = _get_config() From 75866b7503442f21ea2bb904b0aeaca74a52a31f Mon Sep 17 00:00:00 2001 From: renardeinside Date: Wed, 1 Apr 2026 12:50:24 +0200 Subject: [PATCH 06/18] =?UTF-8?q?=F0=9F=90=9B=20fix:=20deliver=20Future=20?= =?UTF-8?q?results=20in=20continuation=20and=20preserve=20contextvars=20in?= =?UTF-8?q?=20call=5Fsoon=20wire?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/apx/_continuation.py | 84 ++++++++++++++++++++++++++++++++++++++-- src/apx/_scheduler.py | 45 ++++++++++++++++----- 2 files changed, 116 insertions(+), 13 deletions(-) diff --git a/src/apx/_continuation.py b/src/apx/_continuation.py index 16b30be1..beb57926 100644 --- a/src/apx/_continuation.py +++ b/src/apx/_continuation.py @@ -29,9 +29,21 @@ class Continuation: Each step uses per-step ``_enter_task`` / ``_leave_task`` brackets, keeping invariant I1. Runs entirely on the asyncio thread. + + When an asyncio Future resolves, the continuation delivers the + result (or exception) to the coroutine via ``drive_inline``'s + ``send_value`` / ``send_exception`` parameters — matching the + standard ``Task.__step`` protocol. """ - __slots__ = ("_coro", "_loop", "_task", "_capture", "_on_complete") + __slots__ = ( + "_coro", + "_loop", + "_task", + "_capture", + "_on_complete", + "_resolved_future", + ) def __init__( self, @@ -47,13 +59,16 @@ def __init__( self._task = task self._capture = capture self._on_complete = on_complete + self._resolved_future: asyncio.Future[Any] | None = None self._attach(yielded) def _attach(self, yielded: object) -> None: """Attach to a yielded value to resume when ready.""" if yielded is None: + # yield None (e.g. asyncio.sleep(0)) — re-enter next cycle. self._loop.call_soon(self._step) elif hasattr(yielded, "add_done_callback"): + # asyncio.Future — resume when I/O completes. self._task._waiter = yielded # type: ignore[assignment, ty:invalid-assignment] yielded.add_done_callback(self._on_future_done) # type: ignore[union-attr, ty:call-non-callable] else: @@ -61,15 +76,39 @@ def _attach(self, yielded: object) -> None: def _on_future_done(self, future: asyncio.Future[Any]) -> None: self._task._waiter = None + self._resolved_future = future + # If another task is currently entered (defensive guard), + # defer to next callback cycle. if asyncio.current_task() is not None: self._loop.call_soon(self._step) return self._step() + def _extract_resume( + self, + ) -> tuple[Any, BaseException | None]: + """Extract send_value / send_exception from a resolved Future. + + Mirrors the ``Task.__step`` protocol: deliver the Future's + result to the coroutine, or throw its exception. + """ + future = self._resolved_future + self._resolved_future = None + if future is None: + # yield-None re-entry — no value to deliver. + return None, None + if future.cancelled(): + return None, asyncio.CancelledError() + exc = future.exception() + if exc is not None: + return None, exc + return future.result(), None + def _step(self) -> None: if self._coro is None: return + # Check cancellation flag (asyncio.timeout / anyio.fail_after). if self._task._cancel_flag: self._task._cancel_flag = False _enter_task(self._loop, self._task) @@ -77,16 +116,39 @@ def _step(self) -> None: yielded = self._coro.throw( asyncio.CancelledError(self._task._cancel_msg) ) - except (StopIteration, BaseException): + except StopIteration: + # Coroutine caught CancelledError and returned normally. + _leave_task(self._loop, self._task) + self._finish() + return + except asyncio.CancelledError: + # Coroutine re-raised CancelledError — expected. _leave_task(self._loop, self._task) self._finish() return + except BaseException as exc: + # Coroutine raised a different exception during cancel + # cleanup (e.g. error in a yield-dep finalizer). + _leave_task(self._loop, self._task) + _log_cancel_exception(exc) + self._finish() + return _leave_task(self._loop, self._task) self._attach(yielded) return + # Normal step: deliver the Future result and resume driving. + send_value, send_exception = self._extract_resume() + self._capture.enter() - result = drive_inline(self._coro, self._task, self._loop, self._capture) + result = drive_inline( + self._coro, + self._task, + self._loop, + self._capture, + send_value=send_value, + send_exception=send_exception, + ) self._capture.leave() if isinstance(result, Completed): @@ -99,5 +161,21 @@ def _step(self) -> None: def _finish(self) -> None: self._coro = None self._task._waiter = None + self._resolved_future = None if self._on_complete is not None: self._on_complete() + + +def _log_cancel_exception(exc: BaseException) -> None: + """Log an unexpected exception from the cancel path. + + Avoids silent swallowing — if cleanup code (e.g. a yield-dep + finalizer) raises during cancellation, it shows up in logs. + """ + import logging + import traceback + + tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + logging.getLogger("apx.dispatch").warning( + "exception during cancellation cleanup:\n%s", tb + ) diff --git a/src/apx/_scheduler.py b/src/apx/_scheduler.py index 10d038e1..65db90ac 100644 --- a/src/apx/_scheduler.py +++ b/src/apx/_scheduler.py @@ -105,9 +105,14 @@ class CallSoonCapture: __slots__ = ("_original", "_queue", "_active") + # Queue entry: (callback, args, context). Context is preserved so + # that Task.__step and Future done-callbacks run in their correct + # contextvars snapshot (invariant I2). + _Entry = tuple[Callable[..., Any], tuple[Any, ...], contextvars.Context | None] + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._original: Callable[..., Any] = loop.call_soon - self._queue: deque[tuple[Callable[..., Any], tuple[Any, ...]]] = deque() + self._queue: deque[CallSoonCapture._Entry] = deque() self._active: bool = False loop.call_soon = self._intercept # type: ignore[assignment, ty:invalid-assignment] @@ -115,12 +120,12 @@ def _intercept( self, callback: Callable[..., Any], *args: Any, - context: Any = None, + context: contextvars.Context | None = None, ) -> None: if self._active: - self._queue.append((callback, args)) + self._queue.append((callback, args, context)) else: - self._original(callback, *args) + self._original(callback, *args, context=context) def enter(self) -> None: """Start capturing ``call_soon`` callbacks.""" @@ -132,15 +137,22 @@ def leave(self) -> None: self._active = False original = self._original while self._queue: - cb, args = self._queue.popleft() - original(cb, *args) + cb, args, ctx = self._queue.popleft() + original(cb, *args, context=ctx) def flush(self, budget: int = FLUSH_BUDGET) -> None: - """Process captured callbacks inline (between drive steps).""" + """Process captured callbacks inline (between drive steps). + + Callbacks run in their original context so that contextvars + (e.g. OTEL trace propagation) are preserved correctly. + """ queue = self._queue while queue and budget > 0: - cb, args = queue.popleft() - cb(*args) + cb, args, ctx = queue.popleft() + if ctx is not None: + ctx.run(cb, *args) + else: + cb(*args) budget -= 1 @@ -181,6 +193,9 @@ def drive_inline( task: SchedulerTask, loop: asyncio.AbstractEventLoop, capture: CallSoonCapture, + *, + send_value: Any = None, + send_exception: BaseException | None = None, ) -> _DriveResult: """Drive a coroutine to completion or first real suspension. @@ -188,6 +203,11 @@ def drive_inline( ``current_task() is None``. Uses per-step ``_enter_task`` / ``_leave_task`` brackets (invariant I1). + On initial entry ``send_value`` is ``None`` (starts the coroutine). + On continuation re-entry after a Future resolves, pass the Future's + result via ``send_value`` or its exception via ``send_exception`` + so the coroutine receives the I/O result at its ``await`` point. + Returns: Completed — coroutine finished; response already fired. Suspended — coroutine yielded an asyncio Future; needs @@ -201,7 +221,12 @@ def drive_inline( while True: _enter_task(loop, task) try: - result = context_run(coro.send, None) + if send_exception is not None: + result = context_run(coro.throw, send_exception) + send_exception = None + else: + result = context_run(coro.send, send_value) + send_value = None except StopIteration: _leave_task(loop, task) return _COMPLETED From 00c89257db96c80a8a36c5c7aec7720c8c1aec73 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Wed, 1 Apr 2026 14:56:28 +0200 Subject: [PATCH 07/18] increase drainer --- src/apx/_dispatch.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/apx/_dispatch.py b/src/apx/_dispatch.py index 455080ed..f2f396ca 100644 --- a/src/apx/_dispatch.py +++ b/src/apx/_dispatch.py @@ -29,6 +29,7 @@ drive_inline, ) +MAX_DRAIN_BATCH: int = 32 def install_dispatch( loop: asyncio.AbstractEventLoop, @@ -36,11 +37,8 @@ def install_dispatch( app: Callable[..., Coroutine[Any, Any, None]], wakeup_fd: int | None = None, ) -> None: - """Install the inline dispatch driver on the asyncio event loop.""" - - max_drain_batch: int = 8 + """Install the inline dispatch driver on the asyncio event loop.""" capture = CallSoonCapture(loop) - async def _guarded( scope: dict[str, Any], receive: Any, @@ -83,7 +81,7 @@ def _dispatch_inline( Continuation(coro, result.yielded, loop, task, capture) def _drain_queue() -> None: - for _ in range(max_drain_batch): + for _ in range(MAX_DRAIN_BATCH): result: tuple[Any, Any, Any] | None = queue.try_recv() if result is None: return From 3d0686a467b752c11bf3bcd979e5c83d66ea65ce Mon Sep 17 00:00:00 2001 From: renardeinside Date: Wed, 1 Apr 2026 19:36:46 +0200 Subject: [PATCH 08/18] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20migrate?= =?UTF-8?q?=20to=20asyncio=20protocol=20with=20Rust=20HTTP=20primitives?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 45 +- Cargo.toml | 4 + crates/framework/Cargo.toml | 16 +- crates/framework/src/asgi/app.rs | 197 +- crates/framework/src/asgi/channel_body.rs | 81 - crates/framework/src/asgi/dispatch.rs | 310 --- crates/framework/src/asgi/lifespan.rs | 637 +---- crates/framework/src/asgi/mod.rs | 10 +- crates/framework/src/asgi/queue.rs | 131 - crates/framework/src/asgi/scope.rs | 2192 +---------------- crates/framework/src/asgi/slot_receive.rs | 72 - crates/framework/src/asgi/slot_send.rs | 307 --- crates/framework/src/asgi/streaming.rs | 184 -- crates/framework/src/dispatch.rs | 52 - crates/framework/src/io/channel.rs | 312 --- crates/framework/src/io/mod.rs | 97 - crates/framework/src/io/reactor/mod.rs | 272 -- crates/framework/src/lib.rs | 2 - crates/framework/src/protocol/connection.rs | 596 +++++ crates/framework/src/protocol/http/error.rs | 152 -- crates/framework/src/protocol/http/mod.rs | 4 - crates/framework/src/protocol/http/service.rs | 698 ------ crates/framework/src/protocol/mod.rs | 8 +- crates/framework/src/protocol/parser.rs | 406 +++ crates/framework/src/protocol/router.rs | 155 ++ crates/framework/src/protocol/writer.rs | 537 ++++ crates/framework/src/protocol/ws/mod.rs | 3 - crates/framework/src/protocol/ws/session.rs | 524 ---- crates/framework/src/pyapi.rs | 13 +- crates/framework/src/supervision/mod.rs | 1 - crates/framework/src/supervision/worker.rs | 662 ++--- .../src/supervision/worker_context.rs | 24 - crates/framework/src/telemetry/config.rs | 61 +- crates/framework/src/telemetry/context.rs | 10 +- crates/framework/src/telemetry/defs.rs | 119 +- .../src/telemetry/dispatch_metrics.rs | 161 +- crates/framework/src/telemetry/http.rs | 218 +- crates/framework/src/telemetry/mod.rs | 14 - crates/framework/src/transport/listener.rs | 154 -- crates/framework/src/transport/mod.rs | 13 +- crates/framework/src/transport/tcp.rs | 173 -- crates/framework/src/transport/types.rs | 652 +---- pyproject.toml | 6 +- src/apx/_core.pyi | 44 +- src/apx/_dispatch.py | 103 - src/apx/_scheduler.py | 2 +- src/apx/_server.py | 194 ++ src/apx/telemetry.py | 28 +- tests/telemetry/test_dispatch_metrics.py | 59 +- tests/test_telemetry.py | 98 +- uv.lock | 70 + 51 files changed, 2860 insertions(+), 8023 deletions(-) delete mode 100644 crates/framework/src/asgi/channel_body.rs delete mode 100644 crates/framework/src/asgi/dispatch.rs delete mode 100644 crates/framework/src/asgi/queue.rs delete mode 100644 crates/framework/src/asgi/slot_receive.rs delete mode 100644 crates/framework/src/asgi/slot_send.rs delete mode 100644 crates/framework/src/asgi/streaming.rs delete mode 100644 crates/framework/src/dispatch.rs delete mode 100644 crates/framework/src/io/channel.rs delete mode 100644 crates/framework/src/io/mod.rs delete mode 100644 crates/framework/src/io/reactor/mod.rs create mode 100644 crates/framework/src/protocol/connection.rs delete mode 100644 crates/framework/src/protocol/http/error.rs delete mode 100644 crates/framework/src/protocol/http/mod.rs delete mode 100644 crates/framework/src/protocol/http/service.rs create mode 100644 crates/framework/src/protocol/parser.rs create mode 100644 crates/framework/src/protocol/router.rs create mode 100644 crates/framework/src/protocol/writer.rs delete mode 100644 crates/framework/src/protocol/ws/mod.rs delete mode 100644 crates/framework/src/protocol/ws/session.rs delete mode 100644 crates/framework/src/supervision/worker_context.rs delete mode 100644 crates/framework/src/transport/listener.rs delete mode 100644 crates/framework/src/transport/tcp.rs delete mode 100644 src/apx/_dispatch.py create mode 100644 src/apx/_server.py diff --git a/Cargo.lock b/Cargo.lock index fa1b7567..e29d970d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -301,40 +301,28 @@ dependencies = [ "apx-common", "apx-core", "bytes", - "crossbeam-channel", - "crossbeam-queue", - "futures-core", "futures-util", "hex", "http", - "http-body", - "http-body-util", - "hyper", - "hyper-tungstenite", - "hyper-util", + "httparse", + "matchit", "mimalloc", "notify", "opentelemetry 0.29.1", "opentelemetry_sdk 0.29.0", "pyo3", - "pyo3-async-runtimes", "rand 0.8.5", - "reqwest 0.13.1", "rmp-serde", "serde", "serde_json", - "socket2", "sysinfo", "tempfile", "thiserror 2.0.18", "tikv-jemallocator", "tokio", - "tokio-stream", - "tokio-tungstenite", "tracing", "tracing-opentelemetry", "tracing-subscriber", - "uuid", "which", ] @@ -2708,21 +2696,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "hyper-tungstenite" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc778da281a749ed28d2be73a9f2cd13030680a1574bc729debd1195e44f00e9" -dependencies = [ - "http-body-util", - "hyper", - "hyper-util", - "pin-project-lite", - "tokio", - "tokio-tungstenite", - "tungstenite", -] - [[package]] name = "hyper-util" version = "0.1.19" @@ -4608,20 +4581,6 @@ dependencies = [ "pyo3-macros", ] -[[package]] -name = "pyo3-async-runtimes" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e7364a95bf00e8377bbf9b0f09d7ff9715a29d8fcf93b47d1a967363b973178" -dependencies = [ - "futures-channel", - "futures-util", - "once_cell", - "pin-project-lite", - "pyo3", - "tokio", -] - [[package]] name = "pyo3-build-config" version = "0.28.2" diff --git a/Cargo.toml b/Cargo.toml index 4f8efeb1..4ddb998d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -142,6 +142,10 @@ http = "1" bytes = "1" http-body-util = "0.1" +# Framework: oneshot protocol primitives +httparse = "1.10" +matchit = "0.8" + [workspace.lints.rust] unsafe_code = "deny" warnings = "deny" diff --git a/crates/framework/Cargo.toml b/crates/framework/Cargo.toml index bebb8057..21d95054 100644 --- a/crates/framework/Cargo.toml +++ b/crates/framework/Cargo.toml @@ -21,9 +21,7 @@ jemalloc = ["dep:tikv-jemallocator"] [dependencies] mimalloc = { version = "0.1", optional = true, default-features = false } pyo3.workspace = true -pyo3-async-runtimes.workspace = true tokio.workspace = true -socket2.workspace = true apx-common.workspace = true tracing.workspace = true tracing-opentelemetry.workspace = true @@ -39,27 +37,17 @@ rmp-serde.workspace = true tempfile.workspace = true http.workspace = true bytes.workspace = true -http-body = "1" -http-body-util.workspace = true sysinfo.workspace = true which.workspace = true notify.workspace = true -futures-core.workspace = true futures-util.workspace = true -crossbeam-channel.workspace = true -crossbeam-queue.workspace = true -tokio-stream.workspace = true -hyper = { version = "1", features = ["http1", "http2", "server"] } -hyper-util = { version = "0.1", features = ["tokio", "server-auto"] } -hyper-tungstenite = { workspace = true } -uuid = { version = "1", features = ["v4"] } +httparse.workspace = true +matchit.workspace = true [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = { version = "0.6", optional = true } [dev-dependencies] tempfile.workspace = true -reqwest = { workspace = true } -tokio-tungstenite = { workspace = true } opentelemetry_sdk = { workspace = true, features = ["testing"] } tracing-subscriber = { workspace = true } diff --git a/crates/framework/src/asgi/app.rs b/crates/framework/src/asgi/app.rs index c68267ba..3ee8b636 100644 --- a/crates/framework/src/asgi/app.rs +++ b/crates/framework/src/asgi/app.rs @@ -3,18 +3,9 @@ //! Parses specifiers like `"myapp.main:app"` or `"myapp"` (attr defaults to //! `"app"`) and imports the callable via `importlib.import_module`. //! -//! The [`AppSource`] trait is the extension seam for app loading strategies. -//! `ModuleImport` is the live-import implementation; a future `ManifestSource` -//! will provide pre-built dispatch pipelines. - -use crate::asgi::dispatch::AsgiDispatch; -use crate::asgi::queue::RequestQueue; -use crate::asgi::scope::ScopeInterns; -use crate::dispatch::Dispatch; -use crate::supervision::worker_context::WorkerContext; +//! [`ModuleImport`] performs a live import from a specifier string. + use pyo3::prelude::*; -use std::net::SocketAddr; -use std::sync::Arc; // ── Constants ──────────────────────────────────────────────────────────── @@ -24,20 +15,7 @@ const SPECIFIER_SEPARATOR: char = ':'; /// Default attribute name when no separator is present. const DEFAULT_ATTR: &str = "app"; -// ── AsgiApp ────────────────────────────────────────────────────────────── - -/// A Python ASGI callable, held as a GIL-independent reference. -#[derive(Debug)] -pub struct AsgiApp(Py); - -impl AsgiApp { - /// Access the inner Python object. - pub fn inner(&self) -> &Py { - &self.0 - } -} - -// ── AppLoadError ───────────────────────────────────────────────────────── +// ── AppLoadError ─────────────────────────────────────────────────────────── /// Errors that can occur when loading a Python ASGI app. #[derive(Debug, thiserror::Error)] @@ -77,34 +55,30 @@ pub enum AppLoadError { }, } -// ── format_pyerr ───────────────────────────────────────────────────────── +// ── AsgiApp ──────────────────────────────────────────────────────────────── -/// Render a Python exception with its full traceback. -/// -/// Uses `traceback.format_exception(err)` to produce the same multi-line -/// output that Python prints on unhandled exceptions. Falls back to -/// `PyErr`'s `Display` (type + message only) if the traceback module is -/// unavailable or the formatting call itself fails. -pub fn format_pyerr(py: Python<'_>, err: &PyErr) -> String { - format_pyerr_inner(py, err).unwrap_or_else(|| err.to_string()) -} +/// A Python ASGI callable, held as a GIL-independent reference. +#[derive(Debug)] +pub struct AsgiApp(Py); -/// Inner helper: returns `None` on any failure so the caller can fall back. -fn format_pyerr_inner(py: Python<'_>, err: &PyErr) -> Option { - let tb_mod = py.import(c"traceback").ok()?; - let lines = tb_mod - .call_method1(c"format_exception", (err.value(py),)) - .ok()?; - let joined: String = lines.extract::>().ok()?.join(""); - Some(joined) +impl AsgiApp { + /// Access the inner Python object. + pub fn inner(&self) -> &Py { + &self.0 + } } -// ── parse_specifier ────────────────────────────────────────────────────── +// ── parse_specifier / ModuleImport ───────────────────────────────────────── /// Split a specifier into `(module, attr)`. /// /// Supports `"module:attr"` (explicit) and `"module"` (attr defaults to `"app"`). -fn parse_specifier(specifier: &str) -> Result<(&str, &str), AppLoadError> { +/// +/// # Errors +/// +/// Returns [`AppLoadError::InvalidSpecifier`] if the string is empty, or if either +/// side of `:` is empty when a separator is present. +pub fn parse_specifier(specifier: &str) -> Result<(&str, &str), AppLoadError> { if specifier.is_empty() { return Err(AppLoadError::InvalidSpecifier { specifier: specifier.to_owned(), @@ -123,50 +97,17 @@ fn parse_specifier(specifier: &str) -> Result<(&str, &str), AppLoadError> { } } -// ── AppSource ──────────────────────────────────────────────────────────── - -/// Maximum request body size: 10 MiB. -const DEFAULT_BODY_LIMIT: usize = 10 * 1024 * 1024; - -/// Load an ASGI application and build its dispatch pipeline. -/// -/// Implementations decide how the app is located (runtime import, manifest, -/// etc.) and which dispatch strategy to use. The returned `Arc` -/// is handed to `ApxService` and shared across all connections. -#[expect(dead_code, reason = "extension seam for future app loading strategies")] -pub trait AppSource: Send + Sync + std::fmt::Debug { - /// Load the app and construct its dispatch pipeline. - /// - /// Called once per worker with the GIL held. `event_loop_py` is the - /// asyncio event loop object needed by `install_dispatch`. - /// - /// # Errors - /// - /// Returns [`AppLoadError`] if the app cannot be loaded. - fn build( - &self, - py: Python<'_>, - ctx: Arc, - event_loop_py: &Py, - server_addr: SocketAddr, - ) -> Result, AppLoadError>; -} - -// ── ModuleImport ───────────────────────────────────────────────────────── - /// Runtime import of a Python ASGI callable from a `"module:attr"` specifier. #[derive(Debug)] pub struct ModuleImport { specifier: String, - dev_mode: bool, } impl ModuleImport { /// Create a new loader from a specifier string. - pub fn new(specifier: impl Into, dev_mode: bool) -> Self { + pub fn new(specifier: impl Into) -> Self { Self { specifier: specifier.into(), - dev_mode, } } @@ -213,79 +154,29 @@ impl ModuleImport { } } -impl ModuleImport { - /// Load the app and build dispatch, returning both the dispatch and the - /// raw ASGI callable reference (needed for the lifespan protocol). - pub fn build_with_app( - &self, - py: Python<'_>, - ctx: Arc, - event_loop_py: &Py, - server_addr: SocketAddr, - ) -> Result<(Arc, Py), AppLoadError> { - let app = self.load_callable(py)?; - let asgi_app = app.inner().clone_ref(py); - let interns = Arc::new(ScopeInterns::new(py, server_addr)); - - let queue = RequestQueue::new( - py, - &ctx.pipeline.inbound, - Arc::clone(&ctx.pipeline.wakeup), - Arc::clone(&interns), - self.dev_mode, - ) - .map_err(|e| AppLoadError::ImportFailed { - module: "RequestQueue".to_owned(), - source: e, - })?; - let queue_obj = Py::new(py, queue).map_err(|e| AppLoadError::ImportFailed { - module: "RequestQueue".to_owned(), - source: e, - })?; - - let wakeup_fd = ctx.pipeline.wakeup.reader_fd(); - let dispatch_mod = py - .import(c"apx._dispatch") - .map_err(|e| AppLoadError::ImportFailed { - module: "apx._dispatch".to_owned(), - source: e, - })?; - dispatch_mod - .call_method1( - c"install_dispatch", - (event_loop_py, queue_obj, app.inner(), wakeup_fd), - ) - .map_err(|e| AppLoadError::ImportFailed { - module: "install_dispatch".to_owned(), - source: e, - })?; +// ── format_pyerr ─────────────────────────────────────────────────────────── - let dispatch = AsgiDispatch::new( - ctx.pipeline.inbound.sender().clone(), - Arc::clone(&ctx.pipeline.wakeup), - DEFAULT_BODY_LIMIT, - app.inner().clone_ref(py), - interns, - ctx, - ); - Ok((Arc::new(dispatch), asgi_app)) - } +/// Render a Python exception with its full traceback. +/// +/// Uses `traceback.format_exception(err)` to produce the same multi-line +/// output that Python prints on unhandled exceptions. Falls back to +/// `PyErr`'s `Display` (type + message only) if the traceback module is +/// unavailable or the formatting call itself fails. +pub fn format_pyerr(py: Python<'_>, err: &PyErr) -> String { + format_pyerr_inner(py, err).unwrap_or_else(|| err.to_string()) } -impl AppSource for ModuleImport { - fn build( - &self, - py: Python<'_>, - ctx: Arc, - event_loop_py: &Py, - server_addr: SocketAddr, - ) -> Result, AppLoadError> { - self.build_with_app(py, ctx, event_loop_py, server_addr) - .map(|(dispatch, _)| dispatch) - } +/// Inner helper: returns `None` on any failure so the caller can fall back. +fn format_pyerr_inner(py: Python<'_>, err: &PyErr) -> Option { + let tb_mod = py.import(c"traceback").ok()?; + let lines = tb_mod + .call_method1(c"format_exception", (err.value(py),)) + .ok()?; + let joined: String = lines.extract::>().ok()?.join(""); + Some(joined) } -// ── Tests ──────────────────────────────────────────────────────────────── +// ── Tests ────────────────────────────────────────────────────────────────── #[cfg(test)] #[expect( @@ -347,7 +238,7 @@ mod tests { #[test] fn load_builtin_callable() { crate::with_py(|py| { - let loader = ModuleImport::new("json:dumps", false); + let loader = ModuleImport::new("json:dumps"); let app = loader.load_callable(py).unwrap(); assert!(app.inner().bind(py).is_callable()); }); @@ -356,7 +247,7 @@ mod tests { #[test] fn load_plain_module_default_attr_fails() { crate::with_py(|py| { - let loader = ModuleImport::new("json", false); + let loader = ModuleImport::new("json"); let err = loader.load_callable(py).unwrap_err(); assert!(matches!(err, AppLoadError::MissingAttribute { .. })); }); @@ -365,7 +256,7 @@ mod tests { #[test] fn load_missing_module() { crate::with_py(|py| { - let loader = ModuleImport::new("nonexistent_module_xyz:app", false); + let loader = ModuleImport::new("nonexistent_module_xyz:app"); let err = loader.load_callable(py).unwrap_err(); assert!(matches!(err, AppLoadError::ImportFailed { .. })); }); @@ -374,7 +265,7 @@ mod tests { #[test] fn load_missing_attr() { crate::with_py(|py| { - let loader = ModuleImport::new("json:nonexistent_attr_xyz", false); + let loader = ModuleImport::new("json:nonexistent_attr_xyz"); let err = loader.load_callable(py).unwrap_err(); assert!(matches!(err, AppLoadError::MissingAttribute { .. })); }); @@ -383,7 +274,7 @@ mod tests { #[test] fn load_not_callable() { crate::with_py(|py| { - let loader = ModuleImport::new("json:__name__", false); + let loader = ModuleImport::new("json:__name__"); let err = loader.load_callable(py).unwrap_err(); assert!(matches!(err, AppLoadError::NotCallable { .. })); }); @@ -403,7 +294,7 @@ mod tests { #[test] fn error_display_import_failed() { crate::with_py(|py| { - let loader = ModuleImport::new("nonexistent_module_xyz:app", false); + let loader = ModuleImport::new("nonexistent_module_xyz:app"); let err = loader.load_callable(py).unwrap_err(); let msg = format!("{err}"); assert!(msg.contains("import")); diff --git a/crates/framework/src/asgi/channel_body.rs b/crates/framework/src/asgi/channel_body.rs deleted file mode 100644 index d86d3dc2..00000000 --- a/crates/framework/src/asgi/channel_body.rs +++ /dev/null @@ -1,81 +0,0 @@ -//! Streaming response body backed by a tokio mpsc channel. -//! -//! [`ChannelBody`] wraps `mpsc::UnboundedReceiver` and implements -//! `futures_core::Stream`. It replaces `AsgiBodyStream` for the 3-thread -//! architecture — `SlotSend` pushes chunks from Thread 2, hyper consumes -//! them on Thread 1 via `ResponseBody::Stream`. - -use bytes::Bytes; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::sync::mpsc; - -/// Streaming response body fed by an mpsc channel. -/// -/// EOF is signaled by dropping the sender half. -pub struct ChannelBody { - rx: mpsc::UnboundedReceiver, -} - -impl ChannelBody { - /// Wrap a receiver into a stream of body chunks. - pub fn new(rx: mpsc::UnboundedReceiver) -> Self { - Self { rx } - } -} - -crate::opaque_debug!(ChannelBody); - -impl futures_core::Stream for ChannelBody { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.rx.poll_recv(cx) { - Poll::Ready(Some(chunk)) => Poll::Ready(Some(Ok(chunk))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} - -// ── Tests ──────────────────────────────────────────────────────────────── - -#[cfg(test)] -#[expect(clippy::unwrap_used, reason = "test code uses unwrap for clarity")] -mod tests { - use super::*; - use tokio_stream::StreamExt; - - #[tokio::test] - async fn channel_body_single_chunk() { - let (tx, rx) = mpsc::unbounded_channel(); - tx.send(Bytes::from("hello")).unwrap(); - drop(tx); - - let mut body = ChannelBody::new(rx); - let chunk = body.next().await.unwrap().unwrap(); - assert_eq!(chunk, Bytes::from("hello")); - assert!(body.next().await.is_none()); - } - - #[tokio::test] - async fn channel_body_multiple_chunks() { - let (tx, rx) = mpsc::unbounded_channel(); - tx.send(Bytes::from("hel")).unwrap(); - tx.send(Bytes::from("lo")).unwrap(); - drop(tx); - - let mut body = ChannelBody::new(rx); - assert_eq!(body.next().await.unwrap().unwrap(), Bytes::from("hel")); - assert_eq!(body.next().await.unwrap().unwrap(), Bytes::from("lo")); - assert!(body.next().await.is_none()); - } - - #[tokio::test] - async fn channel_body_empty() { - let (tx, rx) = mpsc::unbounded_channel::(); - drop(tx); - let mut body = ChannelBody::new(rx); - assert!(body.next().await.is_none()); - } -} diff --git a/crates/framework/src/asgi/dispatch.rs b/crates/framework/src/asgi/dispatch.rs deleted file mode 100644 index a340d553..00000000 --- a/crates/framework/src/asgi/dispatch.rs +++ /dev/null @@ -1,310 +0,0 @@ -//! ASGI dispatch — zero-GIL 3-thread HTTP dispatch + legacy WS dispatch. -//! -//! HTTP requests flow through the crossbeam pipeline: -//! Thread 1 (tokio) → crossbeam → Thread 2 (asyncio) → crossbeam → Thread 3 → oneshot → Thread 1 -//! -//! WebSocket upgrades still use the legacy `call_soon_threadsafe(launch_fn, ...)` -//! path until WS is migrated to crossbeam. - -use crate::asgi::channel_body::ChannelBody; -use crate::asgi::scope::ScopeInterns; -use crate::dispatch::Dispatch; -use crate::io::channel::{RequestSlot, ResponseData, SlotBody, Wakeup}; -use crate::protocol::http::error::AppError; -use crate::supervision::worker_context::WorkerContext; -use crate::telemetry::context::TraceContext; -use crate::telemetry::dispatch_metrics; -use crate::telemetry::timed; -use crate::transport::types::{BodyStream, InboundRequest, OutboundResponse, ResponseBody}; -use bytes::Bytes; -use http::header::{HeaderMap, HeaderName, HeaderValue}; -use hyper::body::Incoming; -use hyper::{Request, Response}; -use pyo3::prelude::*; -use std::future::Future; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; - -// ── AsgiDispatch ───────────────────────────────────────────────────────── - -/// ASGI dispatch: HTTP via crossbeam pipeline (no GIL), WS via legacy path. -pub struct AsgiDispatch { - /// Inbound channel sender — pushes `RequestSlot` to Thread 2. - inbound_tx: crossbeam_channel::Sender, - /// Wakeup signal for the asyncio thread. - wakeup: Arc, - /// Maximum request body size in bytes. - body_limit: usize, - - // ── WS legacy fields (until WS migrates to crossbeam) ── - /// The Python ASGI callable. - app: Arc>, - /// Pre-interned scope strings (shared with RequestQueue on Thread 2). - scope_interns: Arc, - /// Shared worker context (carries call_soon_threadsafe + launch_fn for WS). - ctx: Arc, -} - -impl AsgiDispatch { - /// Create a new `AsgiDispatch` with crossbeam pipeline for HTTP. - pub fn new( - inbound_tx: crossbeam_channel::Sender, - wakeup: Arc, - body_limit: usize, - app: Py, - scope_interns: Arc, - ctx: Arc, - ) -> Self { - Self { - inbound_tx, - wakeup, - body_limit, - app: Arc::new(app), - scope_interns, - ctx, - } - } -} - -impl std::fmt::Debug for AsgiDispatch { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("AsgiDispatch") - .field("body_limit", &self.body_limit) - .finish_non_exhaustive() - } -} - -impl Dispatch for AsgiDispatch { - fn dispatch( - &self, - mut request: InboundRequest, - ) -> Pin + Send>> { - let body_stream = request.take_body(); - let body_limit = self.body_limit; - let inbound_tx = self.inbound_tx.clone(); - let wakeup = Arc::clone(&self.wakeup); - - Box::pin(async move { - let result = dispatch_inner(request, body_stream, body_limit, inbound_tx, wakeup).await; - result.unwrap_or_else(error_response) - }) - } - - fn dispatch_ws( - &self, - request: Request, - server_addr: SocketAddr, - client_addr: Option, - ) -> Pin> + Send>> { - let app = Arc::clone(&self.app); - let interns = Arc::clone(&self.scope_interns); - let ctx = Arc::clone(&self.ctx); - - Box::pin(async move { - match crate::protocol::ws::session::handle_upgrade( - request, - server_addr, - client_addr, - app, - interns, - ctx, - ) { - Ok(response) => response, - Err(err) => { - tracing::error!(name: "apx.dispatch.websocket_upgrade_error", error = %err, "websocket upgrade error"); - Response::builder() - .status(http::StatusCode::INTERNAL_SERVER_ERROR) - .header(http::header::CONTENT_TYPE, "text/plain") - .body(ResponseBody::Fixed(Bytes::from_static( - b"Internal Server Error", - ))) - .unwrap_or_else(|_| unreachable!()) - } - } - }) - } -} - -// ── Dispatch internals ─────────────────────────────────────────────────── - -/// Zero-GIL HTTP dispatch: extract trace context, then time the full pipeline. -async fn dispatch_inner( - request: InboundRequest, - body_stream: BodyStream, - body_limit: usize, - inbound_tx: crossbeam_channel::Sender, - wakeup: Arc, -) -> Result { - if let Some(id) = request - .headers - .get(&crate::protocol::http::service::REQUEST_ID_HEADER) - && let Ok(val) = id.to_str() - { - tracing::Span::current().record("request.id", val); - } - - let trace_context = crate::telemetry::context::extract_trace_context(); - - timed!( - dispatch_metrics::record_dispatch_total, - dispatch_pipeline( - request, - body_stream, - body_limit, - inbound_tx, - wakeup, - trace_context - ) - .await - ) -} - -/// Collect body → build RequestSlot → push to crossbeam → await response. -async fn dispatch_pipeline( - request: InboundRequest, - body_stream: BodyStream, - body_limit: usize, - inbound_tx: crossbeam_channel::Sender, - wakeup: Arc, - trace_context: Option, -) -> Result { - let body_bytes = timed!( - dispatch_metrics::record_body_collect, - body_stream - .collect(body_limit) - .await - .map_err(|e| AppError::Internal(format!("body collect: {e}")))? - ); - - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - - let raw_path = Bytes::copy_from_slice(request.path.as_bytes()); - let slot = RequestSlot { - method: request.method.clone(), - path: request.path.clone(), - raw_path, - query_string: request.query_string.clone(), - headers: request.headers.clone(), - body: body_bytes, - protocol: request.protocol, - client_addr: request.client_addr, - server_addr: request.server_addr, - trace_context, - created_at: std::time::Instant::now(), - response_tx, - }; - - timed!(dispatch_metrics::record_crossbeam_send, { - inbound_tx - .send(slot) - .map_err(|_| AppError::Internal("inbound channel closed".to_owned()))?; - wakeup.signal(); - }); - - let response_data = timed!( - dispatch_metrics::record_response_wait, - response_rx - .await - .map_err(|_| AppError::Internal("response channel closed".to_owned()))? - ); - - response_data_to_outbound(response_data) -} - -/// Convert a `ResponseData` from the crossbeam pipeline into an `OutboundResponse`. -fn response_data_to_outbound(data: ResponseData) -> Result { - let status = - http::StatusCode::from_u16(data.status).unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR); - let headers = build_response_headers(&data.headers)?; - let body = match data.body { - SlotBody::Complete(bytes) => ResponseBody::Fixed(bytes), - SlotBody::Chunked(rx) => ResponseBody::Stream(Box::pin(ChannelBody::new(rx))), - }; - - Ok(OutboundResponse { - status, - headers, - body, - server_route: None, - }) -} - -/// Parse raw byte-pair headers into an `http::HeaderMap`. -/// -/// Uses `from_lowercase` because ASGI guarantees response header names -/// are lowercase byte strings, skipping the case-folding pass. -fn build_response_headers(raw: &[(Bytes, Bytes)]) -> Result { - let mut headers = HeaderMap::with_capacity(raw.len()); - for (name, value) in raw { - let header_name = HeaderName::from_lowercase(name) - .map_err(|e| AppError::Internal(format!("invalid header name: {e}")))?; - let header_value = HeaderValue::from_bytes(value) - .map_err(|e| AppError::Internal(format!("invalid header value: {e}")))?; - headers.append(header_name, header_value); - } - Ok(headers) -} - -/// Client-visible body for internal errors. -const INTERNAL_ERROR_BODY: &[u8] = b"Internal Server Error"; - -/// Client-visible body for request timeout. -const TIMEOUT_BODY: &[u8] = b"request timeout"; - -/// Map an [`AppError`] to a generic HTTP error response. -/// -/// The error detail is logged but NOT leaked to the client. -fn error_response(err: AppError) -> OutboundResponse { - let status = err.status_code(); - let body = match &err { - AppError::Timeout => TIMEOUT_BODY, - AppError::Internal(msg) => { - tracing::error!(name: "apx.dispatch.internal_error", error = %msg, "internal dispatch error"); - INTERNAL_ERROR_BODY - } - }; - OutboundResponse { - status, - headers: { - let mut h = HeaderMap::new(); - h.insert( - http::header::CONTENT_TYPE, - HeaderValue::from_static("text/plain"), - ); - h - }, - body: ResponseBody::Fixed(Bytes::from_static(body)), - server_route: None, - } -} - -// ── Tests ──────────────────────────────────────────────────────────────── - -#[cfg(test)] -#[expect(clippy::panic, reason = "test code uses unwrap/assert for clarity")] -mod tests { - use super::*; - - #[test] - fn error_response_internal() { - let err = AppError::Internal("db connection failed".to_owned()); - let resp = error_response(err); - assert_eq!(resp.status, http::StatusCode::INTERNAL_SERVER_ERROR); - match &resp.body { - ResponseBody::Fixed(b) => assert_eq!(b.as_ref(), b"Internal Server Error"), - ResponseBody::Stream(_) => panic!("expected Fixed body"), - } - } - - #[test] - fn error_response_timeout() { - let err = AppError::Timeout; - let resp = error_response(err); - assert_eq!(resp.status, http::StatusCode::REQUEST_TIMEOUT); - match &resp.body { - ResponseBody::Fixed(b) => assert_eq!(b.as_ref(), b"request timeout"), - ResponseBody::Stream(_) => panic!("expected Fixed body"), - } - } -} diff --git a/crates/framework/src/asgi/lifespan.rs b/crates/framework/src/asgi/lifespan.rs index 98592a33..62920f66 100644 --- a/crates/framework/src/asgi/lifespan.rs +++ b/crates/framework/src/asgi/lifespan.rs @@ -1,43 +1,24 @@ //! ASGI lifespan protocol — startup/shutdown hooks for the application. //! -//! Implements the ASGI lifespan spec: the server calls `app(scope, receive, send)` -//! with `scope["type"] == "lifespan"`, then exchanges startup/shutdown events via -//! the receive and send callables. +//! Implements the ASGI lifespan spec: the server calls +//! `app(scope, receive, send)` with `scope["type"] == "lifespan"`, +//! then exchanges startup/shutdown events via the receive and send callables. //! -//! The protocol runs on the asyncio thread as a long-lived task. Three tokio -//! oneshot channels bridge it to the tokio thread: -//! - **startup**: `LifespanSend` signals startup result -//! - **shutdown_trigger**: tokio thread tells `LifespanReceive` to deliver shutdown -//! - **shutdown**: `LifespanSend` signals shutdown result +//! Lifespan runs as an asyncio task on the event loop. Signaling uses +//! `asyncio.Event` for startup/shutdown coordination. use pyo3::prelude::*; -use pyo3::types::{PyDict, PyString}; +use pyo3::types::PyDict; use std::sync::Mutex; -use std::time::Duration; -use tokio::sync::oneshot; use super::scope::{ResolvedAwaitable, ResolvedAwaitableWithValue}; -use crate::io::EventLoop; - -// ── Protocol types (pure, no I/O) ──────────────────────────────────────── - -/// Outcome of a lifespan startup or shutdown phase. -#[derive(Debug)] -pub enum LifespanResult { - /// App sent `lifespan.startup.complete` or `lifespan.shutdown.complete`. - Complete, - /// App sent `lifespan.startup.failed` or `lifespan.shutdown.failed`. - Failed(String), - /// App raised during `app(scope, receive, send)` — does not support lifespan. - Unsupported, -} /// Internal state machine for [`LifespanReceive`]. #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum ReceiveState { /// Next call returns `{"type": "lifespan.startup"}`. Startup, - /// Next call blocks until shutdown trigger, then returns `{"type": "lifespan.shutdown"}`. + /// Next call waits for shutdown, then returns `{"type": "lifespan.shutdown"}`. WaitingShutdown, /// No more events — pend forever. Done, @@ -48,28 +29,26 @@ enum ReceiveState { /// ASGI `receive` callable for the lifespan protocol. /// /// First `await receive()` returns `{"type": "lifespan.startup"}` immediately. -/// Second `await receive()` blocks until the server triggers shutdown, then -/// returns `{"type": "lifespan.shutdown"}`. Subsequent calls pend forever. +/// Second `await receive()` waits for the shutdown event, then returns +/// `{"type": "lifespan.shutdown"}`. #[pyclass(module = "apx._core")] pub struct LifespanReceive { state: Mutex, - shutdown_trigger_rx: Mutex>>, + shutdown_event: Py, } crate::opaque_debug!(LifespanReceive); +#[pymethods] impl LifespanReceive { - /// Create a new lifespan receive callable. - pub(crate) fn new(shutdown_trigger_rx: oneshot::Receiver<()>) -> Self { + #[new] + fn new(shutdown_event: Py) -> Self { Self { state: Mutex::new(ReceiveState::Startup), - shutdown_trigger_rx: Mutex::new(Some(shutdown_trigger_rx)), + shutdown_event, } } -} -#[pymethods] -impl LifespanReceive { fn __call__<'py>(&self, py: Python<'py>) -> PyResult> { let mut state = self .state @@ -85,69 +64,74 @@ impl LifespanReceive { .map(|obj| obj.into_bound(py).into_any()) } ReceiveState::WaitingShutdown => { - let rx = self - .shutdown_trigger_rx - .lock() - .map_err(|_| { - pyo3::exceptions::PyRuntimeError::new_err("shutdown trigger mutex poisoned") - })? - .take() - .ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("shutdown already triggered") - })?; *state = ReceiveState::Done; drop(state); - - let handle = crate::io::with_tokio_handle(|h| h.clone()).ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err( - "no tokio runtime for lifespan shutdown wait", - ) - })?; - let _guard = handle.enter(); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let _ = rx.await; - Python::attach(|py| { - let event = build_shutdown_event(py)?; - Ok(event) - }) - }) + build_shutdown_awaitable(py, &self.shutdown_event) } ReceiveState::Done => { drop(state); - let handle = crate::io::with_tokio_handle(|h| h.clone()).ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err( - "no tokio runtime for lifespan pending", - ) - })?; - let _guard = handle.enter(); - pyo3_async_runtimes::tokio::future_into_py( - py, - std::future::pending::>>(), - ) + let fut = py + .import(c"asyncio")? + .call_method0(pyo3::intern!(py, "get_running_loop"))? + .call_method0(pyo3::intern!(py, "create_future"))?; + Ok(fut) } } } } -// ── Send event classification (sans-I/O) ───────────────────────────────── +/// Build an awaitable that waits for the shutdown event, then returns the +/// lifespan.shutdown event dict. +fn build_shutdown_awaitable<'py>( + py: Python<'py>, + shutdown_event: &Py, +) -> PyResult> { + let globals = PyDict::new(py); + globals.set_item("_shutdown_event", shutdown_event.bind(py))?; + let locals = PyDict::new(py); + + py.run( + c" +import asyncio + +async def _wait_shutdown(): + await _shutdown_event.wait() + return {'type': 'lifespan.shutdown'} + +_coro = _wait_shutdown() +", + Some(&globals), + Some(&locals), + )?; -/// ASGI lifespan send event type: startup completed successfully. -const STARTUP_COMPLETE: &str = "lifespan.startup.complete"; + let coro = locals + .get_item(pyo3::intern!(py, "_coro"))? + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("failed to create shutdown coro") + })?; + Ok(coro.clone()) +} -/// ASGI lifespan send event type: startup failed. -const STARTUP_FAILED: &str = "lifespan.startup.failed"; +// ── Send event classification ──────────────────────────────────────────── -/// ASGI lifespan send event type: shutdown completed successfully. +/// ASGI lifespan send event types. +const STARTUP_COMPLETE: &str = "lifespan.startup.complete"; +/// ASGI lifespan startup failed event. +const STARTUP_FAILED: &str = "lifespan.startup.failed"; +/// ASGI lifespan shutdown complete event. const SHUTDOWN_COMPLETE: &str = "lifespan.shutdown.complete"; - -/// ASGI lifespan send event type: shutdown failed. +/// ASGI lifespan shutdown failed event. const SHUTDOWN_FAILED: &str = "lifespan.shutdown.failed"; -/// Classified lifespan send event — pure protocol, no I/O. +/// Classified lifespan send event. enum SendEvent { + /// Startup completed. StartupComplete, + /// Startup failed with message. StartupFailed(String), + /// Shutdown completed. ShutdownComplete, + /// Shutdown failed with message. ShutdownFailed(String), } @@ -180,57 +164,52 @@ fn extract_message(event: &Bound<'_, PyDict>) -> PyResult { .map(|opt| opt.unwrap_or_default()) } -/// Send a result through a guarded oneshot channel. -fn signal(tx: &Mutex>>, result: LifespanResult) { - if let Ok(mut guard) = tx.lock() - && let Some(tx) = guard.take() - { - let _ = tx.send(result); - } -} - // ── LifespanSend ───────────────────────────────────────────────────────── /// ASGI `send` callable for the lifespan protocol. /// -/// Parses `lifespan.startup.complete`, `lifespan.startup.failed`, -/// `lifespan.shutdown.complete`, and `lifespan.shutdown.failed` events, -/// signaling results through oneshot channels. +/// Parses startup/shutdown events and signals results via Python +/// `asyncio.Event` objects and a shared result slot. #[pyclass(module = "apx._core")] pub struct LifespanSend { - startup_tx: Mutex>>, - shutdown_tx: Mutex>>, + startup_result: Py, + shutdown_result: Py, + startup_event: Py, + shutdown_done_event: Py, resolved: Py, } crate::opaque_debug!(LifespanSend); +#[pymethods] impl LifespanSend { - /// Create a new lifespan send callable. - pub(crate) fn new( + #[new] + fn new( py: Python<'_>, - startup_tx: oneshot::Sender, - shutdown_tx: oneshot::Sender, + startup_event: Py, + startup_result: Py, + shutdown_done_event: Py, + shutdown_result: Py, ) -> PyResult { Ok(Self { - startup_tx: Mutex::new(Some(startup_tx)), - shutdown_tx: Mutex::new(Some(shutdown_tx)), + startup_result, + shutdown_result, + startup_event, + shutdown_done_event, resolved: Py::new(py, ResolvedAwaitable)?, }) } -} -#[pymethods] -impl LifespanSend { - /// Forward an unhandled app exception — signals lifespan unsupported or shutdown failed. - fn send_error(&self, traceback: String) { - if let Ok(mut guard) = self.startup_tx.lock() - && let Some(tx) = guard.take() - { - let _ = tx.send(LifespanResult::Unsupported); - return; - } - signal(&self.shutdown_tx, LifespanResult::Failed(traceback)); + /// Forward an unhandled app exception — signals lifespan unsupported. + fn send_error(&self, py: Python<'_>, _traceback: String) -> PyResult<()> { + self.startup_result.call_method1( + py, + pyo3::intern!(py, "__setitem__"), + (0, "unsupported"), + )?; + self.startup_event + .call_method0(py, pyo3::intern!(py, "set"))?; + Ok(()) } fn __call__<'py>( @@ -240,16 +219,39 @@ impl LifespanSend { ) -> PyResult> { match classify_send_event(&event)? { SendEvent::StartupComplete => { - signal(&self.startup_tx, LifespanResult::Complete); + self.startup_result.call_method1( + py, + pyo3::intern!(py, "__setitem__"), + (0, "complete"), + )?; + self.startup_event + .call_method0(py, pyo3::intern!(py, "set"))?; } SendEvent::StartupFailed(msg) => { - signal(&self.startup_tx, LifespanResult::Failed(msg)); + let val = format!("failed:{msg}"); + self.startup_result + .call_method1(py, pyo3::intern!(py, "__setitem__"), (0, val))?; + self.startup_event + .call_method0(py, pyo3::intern!(py, "set"))?; } SendEvent::ShutdownComplete => { - signal(&self.shutdown_tx, LifespanResult::Complete); + self.shutdown_result.call_method1( + py, + pyo3::intern!(py, "__setitem__"), + (0, "complete"), + )?; + self.shutdown_done_event + .call_method0(py, pyo3::intern!(py, "set"))?; } SendEvent::ShutdownFailed(msg) => { - signal(&self.shutdown_tx, LifespanResult::Failed(msg)); + let val = format!("failed:{msg}"); + self.shutdown_result.call_method1( + py, + pyo3::intern!(py, "__setitem__"), + (0, val), + )?; + self.shutdown_done_event + .call_method0(py, pyo3::intern!(py, "set"))?; } } Ok(self.resolved.clone_ref(py).into_bound(py).into_any()) @@ -258,10 +260,13 @@ impl LifespanSend { // ── Scope builder ──────────────────────────────────────────────────────── -use super::{ASGI_SPEC_VERSION, ASGI_VERSION}; - /// Build the ASGI lifespan scope dict. -fn build_lifespan_scope(py: Python<'_>) -> PyResult> { +#[cfg(test)] +pub fn build_lifespan_scope(py: Python<'_>) -> PyResult> { + use pyo3::types::PyString; + + use super::{ASGI_SPEC_VERSION, ASGI_VERSION}; + let scope = PyDict::new(py); scope.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "lifespan"))?; @@ -275,8 +280,8 @@ fn build_lifespan_scope(py: Python<'_>) -> PyResult> { PyString::intern(py, ASGI_SPEC_VERSION), )?; scope.set_item(pyo3::intern!(py, "asgi"), asgi)?; - scope.set_item(pyo3::intern!(py, "state"), PyDict::new(py))?; + Ok(scope.unbind()) } @@ -290,123 +295,6 @@ fn build_startup_event(py: Python<'_>) -> PyResult> { Ok(event.into_any().unbind()) } -/// Build `{"type": "lifespan.shutdown"}` event for receive. -fn build_shutdown_event(py: Python<'_>) -> PyResult> { - let event = PyDict::new(py); - event.set_item( - pyo3::intern!(py, "type"), - pyo3::intern!(py, "lifespan.shutdown"), - )?; - Ok(event.into_any().unbind()) -} - -// ── Handles ────────────────────────────────────────────────────────────── - -/// Pre-startup handle — awaiting startup result. -/// -/// Returned by [`launch_lifespan`]. Call [`wait_startup`](Self::wait_startup) -/// to consume the startup channel and obtain a [`LifespanHandle`] for shutdown. -pub struct LifespanPending { - startup_rx: oneshot::Receiver, - shutdown_trigger_tx: oneshot::Sender<()>, - shutdown_rx: oneshot::Receiver, -} - -crate::opaque_debug!(LifespanPending); - -/// Lifespan startup timeout — if the app does not respond within this -/// duration, startup is treated as a failure. -const STARTUP_TIMEOUT: Duration = Duration::from_secs(30); - -impl LifespanPending { - /// Wait for the app to complete lifespan startup. - /// - /// Returns `Ok(Some(handle))` on success, `Ok(None)` if the app does - /// not support lifespan, or `Err(message)` on failure or timeout. - pub async fn wait_startup(self) -> Result, String> { - let result = tokio::time::timeout(STARTUP_TIMEOUT, self.startup_rx).await; - - match result { - Ok(Ok(LifespanResult::Complete)) => Ok(Some(LifespanHandle { - shutdown_trigger_tx: Some(self.shutdown_trigger_tx), - shutdown_rx: Some(self.shutdown_rx), - })), - Ok(Ok(LifespanResult::Unsupported)) => Ok(None), - Ok(Ok(LifespanResult::Failed(msg))) => Err(msg), - Ok(Err(_)) => Err("lifespan task died unexpectedly".to_owned()), - Err(_) => Err("lifespan startup timed out (30s)".to_owned()), - } - } -} - -/// Post-startup handle — the lifespan coroutine is alive and waiting for shutdown. -pub struct LifespanHandle { - shutdown_trigger_tx: Option>, - shutdown_rx: Option>, -} - -crate::opaque_debug!(LifespanHandle); - -/// Lifespan shutdown timeout. -const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30); - -impl LifespanHandle { - /// Trigger lifespan shutdown and wait for completion. - pub async fn trigger_shutdown(mut self) -> Result<(), String> { - if let Some(tx) = self.shutdown_trigger_tx.take() { - let _ = tx.send(()); - } - let Some(rx) = self.shutdown_rx.take() else { - return Ok(()); - }; - match tokio::time::timeout(SHUTDOWN_TIMEOUT, rx).await { - Ok(Ok(LifespanResult::Failed(msg))) => Err(msg), - Ok(Ok(LifespanResult::Complete | LifespanResult::Unsupported) | Err(_)) => Ok(()), - Err(_) => { - tracing::warn!( - name: "apx.lifespan.shutdown_timeout", - "lifespan shutdown timed out (30s)" - ); - Ok(()) - } - } - } -} - -// ── Launcher ───────────────────────────────────────────────────────────── - -/// Launch the ASGI lifespan protocol on the asyncio thread. -/// -/// Builds the lifespan scope, receive, and send callables, then submits -/// `launch(app, scope, receive, send)` via `call_soon_threadsafe`. -/// Returns a [`LifespanPending`] for awaiting the startup result. -pub fn launch_lifespan( - py: Python<'_>, - event_loop: &EventLoop, - app: &Py, - launch_fn: &Py, -) -> PyResult { - let (startup_tx, startup_rx) = oneshot::channel(); - let (shutdown_trigger_tx, shutdown_trigger_rx) = oneshot::channel(); - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - - let scope = build_lifespan_scope(py)?; - let receive = Py::new(py, LifespanReceive::new(shutdown_trigger_rx))?; - let send = Py::new(py, LifespanSend::new(py, startup_tx, shutdown_tx)?)?; - - event_loop - .call_soon_threadsafe() - .call1(py, (launch_fn, app, &scope, &receive, &send))?; - - Ok(LifespanPending { - startup_rx, - shutdown_trigger_tx, - shutdown_rx, - }) -} - -// ── Tests ──────────────────────────────────────────────────────────────── - #[cfg(test)] #[expect( clippy::unwrap_used, @@ -437,7 +325,7 @@ mod tests { } #[test] - fn startup_event_has_correct_type() { + fn build_startup_event_type() { with_py(|py| { let event = build_startup_event(py).unwrap(); let event = event.bind(py); @@ -445,281 +333,4 @@ mod tests { assert_eq!(t, "lifespan.startup"); }); } - - #[test] - fn shutdown_event_has_correct_type() { - with_py(|py| { - let event = build_shutdown_event(py).unwrap(); - let event = event.bind(py); - let t: String = event.get_item("type").unwrap().extract().unwrap(); - assert_eq!(t, "lifespan.shutdown"); - }); - } - - #[test] - fn lifespan_send_startup_complete() { - with_py(|py| { - let (startup_tx, mut startup_rx) = oneshot::channel(); - let (shutdown_tx, _shutdown_rx) = oneshot::channel(); - let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); - - let event = PyDict::new(py); - event.set_item("type", "lifespan.startup.complete").unwrap(); - send.__call__(py, event).unwrap(); - - let result = startup_rx.try_recv().unwrap(); - assert!(matches!(result, LifespanResult::Complete)); - }); - } - - #[test] - fn lifespan_send_startup_failed() { - with_py(|py| { - let (startup_tx, mut startup_rx) = oneshot::channel(); - let (shutdown_tx, _shutdown_rx) = oneshot::channel(); - let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); - - let event = PyDict::new(py); - event.set_item("type", "lifespan.startup.failed").unwrap(); - event.set_item("message", "db connection refused").unwrap(); - send.__call__(py, event).unwrap(); - - let result = startup_rx.try_recv().unwrap(); - assert!( - matches!(&result, LifespanResult::Failed(msg) if msg == "db connection refused"), - "expected Failed(\"db connection refused\"), got {result:?}" - ); - }); - } - - #[test] - fn lifespan_send_startup_failed_no_message() { - with_py(|py| { - let (startup_tx, mut startup_rx) = oneshot::channel(); - let (shutdown_tx, _shutdown_rx) = oneshot::channel(); - let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); - - let event = PyDict::new(py); - event.set_item("type", "lifespan.startup.failed").unwrap(); - send.__call__(py, event).unwrap(); - - let result = startup_rx.try_recv().unwrap(); - assert!( - matches!(&result, LifespanResult::Failed(msg) if msg.is_empty()), - "expected Failed(\"\"), got {result:?}" - ); - }); - } - - #[test] - fn lifespan_send_shutdown_complete() { - with_py(|py| { - let (startup_tx, _startup_rx) = oneshot::channel(); - let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); - let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); - - // Consume startup first (simulates normal flow). - let event = PyDict::new(py); - event.set_item("type", "lifespan.startup.complete").unwrap(); - send.__call__(py, event).unwrap(); - - let event = PyDict::new(py); - event - .set_item("type", "lifespan.shutdown.complete") - .unwrap(); - send.__call__(py, event).unwrap(); - - let result = shutdown_rx.try_recv().unwrap(); - assert!(matches!(result, LifespanResult::Complete)); - }); - } - - #[test] - fn lifespan_send_shutdown_failed() { - with_py(|py| { - let (startup_tx, _startup_rx) = oneshot::channel(); - let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); - let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); - - // Consume startup first. - let event = PyDict::new(py); - event.set_item("type", "lifespan.startup.complete").unwrap(); - send.__call__(py, event).unwrap(); - - let event = PyDict::new(py); - event.set_item("type", "lifespan.shutdown.failed").unwrap(); - event.set_item("message", "cleanup error").unwrap(); - send.__call__(py, event).unwrap(); - - let result = shutdown_rx.try_recv().unwrap(); - assert!( - matches!(&result, LifespanResult::Failed(msg) if msg == "cleanup error"), - "expected Failed(\"cleanup error\"), got {result:?}" - ); - }); - } - - #[test] - fn lifespan_send_unknown_event_type() { - with_py(|py| { - let (startup_tx, _startup_rx) = oneshot::channel(); - let (shutdown_tx, _shutdown_rx) = oneshot::channel(); - let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); - - let event = PyDict::new(py); - event.set_item("type", "lifespan.unknown").unwrap(); - let result = send.__call__(py, event); - assert!(result.is_err()); - }); - } - - #[test] - fn send_error_during_startup_signals_unsupported() { - with_py(|py| { - let (startup_tx, mut startup_rx) = oneshot::channel(); - let (shutdown_tx, _shutdown_rx) = oneshot::channel(); - let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); - - send.send_error("TypeError: ...".to_owned()); - - let result = startup_rx.try_recv().unwrap(); - assert!(matches!(result, LifespanResult::Unsupported)); - }); - } - - #[test] - fn send_error_during_shutdown_signals_failed() { - with_py(|py| { - let (startup_tx, _startup_rx) = oneshot::channel(); - let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); - let send = LifespanSend::new(py, startup_tx, shutdown_tx).unwrap(); - - // Consume startup to transition phase. - let event = PyDict::new(py); - event.set_item("type", "lifespan.startup.complete").unwrap(); - send.__call__(py, event).unwrap(); - - send.send_error("RuntimeError: cleanup failed".to_owned()); - - let result = shutdown_rx.try_recv().unwrap(); - assert!( - matches!(&result, LifespanResult::Failed(msg) if msg.contains("cleanup failed")), - "expected Failed containing \"cleanup failed\", got {result:?}" - ); - }); - } - - #[test] - fn lifespan_receive_first_call_returns_startup() { - with_py(|py| { - let (_tx, rx) = oneshot::channel(); - let receive = LifespanReceive::new(rx); - - let awaitable = receive.__call__(py).unwrap(); - // The awaitable should be a ResolvedAwaitableWithValue. - // We can check it implements __await__. - assert!(awaitable.hasattr("__await__").unwrap()); - }); - } - - #[test] - fn lifespan_result_debug() { - let c = LifespanResult::Complete; - assert!(format!("{c:?}").contains("Complete")); - let f = LifespanResult::Failed("err".to_owned()); - assert!(format!("{f:?}").contains("Failed")); - let u = LifespanResult::Unsupported; - assert!(format!("{u:?}").contains("Unsupported")); - } - - #[tokio::test] - async fn lifespan_pending_wait_startup_complete() { - let (startup_tx, startup_rx) = oneshot::channel(); - let (shutdown_trigger_tx, _shutdown_trigger_rx) = oneshot::channel(); - let (_shutdown_tx, shutdown_rx) = oneshot::channel(); - - let pending = LifespanPending { - startup_rx, - shutdown_trigger_tx, - shutdown_rx, - }; - - let _ = startup_tx.send(LifespanResult::Complete); - let result = pending.wait_startup().await; - assert!(result.is_ok()); - assert!(result.unwrap().is_some()); - } - - #[tokio::test] - async fn lifespan_pending_wait_startup_unsupported() { - let (startup_tx, startup_rx) = oneshot::channel(); - let (shutdown_trigger_tx, _shutdown_trigger_rx) = oneshot::channel(); - let (_shutdown_tx, shutdown_rx) = oneshot::channel(); - - let pending = LifespanPending { - startup_rx, - shutdown_trigger_tx, - shutdown_rx, - }; - - let _ = startup_tx.send(LifespanResult::Unsupported); - let result = pending.wait_startup().await; - assert!(result.unwrap().is_none()); - } - - #[tokio::test] - async fn lifespan_pending_wait_startup_failed() { - let (startup_tx, startup_rx) = oneshot::channel(); - let (shutdown_trigger_tx, _shutdown_trigger_rx) = oneshot::channel(); - let (_shutdown_tx, shutdown_rx) = oneshot::channel(); - - let pending = LifespanPending { - startup_rx, - shutdown_trigger_tx, - shutdown_rx, - }; - - let _ = startup_tx.send(LifespanResult::Failed("db error".to_owned())); - let result = pending.wait_startup().await; - assert_eq!(result.unwrap_err(), "db error"); - } - - #[tokio::test] - async fn lifespan_handle_trigger_shutdown_complete() { - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - let (trigger_tx, trigger_rx) = oneshot::channel(); - - let handle = LifespanHandle { - shutdown_trigger_tx: Some(trigger_tx), - shutdown_rx: Some(shutdown_rx), - }; - - // Simulate the app responding to shutdown trigger. - tokio::spawn(async move { - let _ = trigger_rx.await; - let _ = shutdown_tx.send(LifespanResult::Complete); - }); - - let result = handle.trigger_shutdown().await; - assert!(result.is_ok()); - } - - #[tokio::test] - async fn lifespan_handle_trigger_shutdown_failed() { - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - let (trigger_tx, trigger_rx) = oneshot::channel(); - - let handle = LifespanHandle { - shutdown_trigger_tx: Some(trigger_tx), - shutdown_rx: Some(shutdown_rx), - }; - - tokio::spawn(async move { - let _ = trigger_rx.await; - let _ = shutdown_tx.send(LifespanResult::Failed("cleanup err".to_owned())); - }); - - let result = handle.trigger_shutdown().await; - assert_eq!(result.unwrap_err(), "cleanup err"); - } } diff --git a/crates/framework/src/asgi/mod.rs b/crates/framework/src/asgi/mod.rs index 22bfdde7..7e54a14d 100644 --- a/crates/framework/src/asgi/mod.rs +++ b/crates/framework/src/asgi/mod.rs @@ -1,7 +1,7 @@ //! Python ASGI boundary layer. //! -//! Translates Rust domain types (InboundRequest, OutboundResponse) to/from -//! ASGI protocol objects (scope, receive, send). +//! Translates Rust domain types to/from ASGI protocol objects +//! (scope, receive, send). /// ASGI protocol version string. pub const ASGI_VERSION: &str = "3.0"; @@ -10,11 +10,5 @@ pub const ASGI_VERSION: &str = "3.0"; pub const ASGI_SPEC_VERSION: &str = "2.4"; pub mod app; -pub mod channel_body; -pub mod dispatch; pub mod lifespan; -pub mod queue; pub mod scope; -pub mod slot_receive; -pub mod slot_send; -pub mod streaming; diff --git a/crates/framework/src/asgi/queue.rs b/crates/framework/src/asgi/queue.rs deleted file mode 100644 index d2ea271d..00000000 --- a/crates/framework/src/asgi/queue.rs +++ /dev/null @@ -1,131 +0,0 @@ -//! Request queue exposed to the Python asyncio thread. -//! -//! [`RequestQueue`] wraps the inbound crossbeam receiver and builds -//! `(scope, receive, send)` tuples on each `try_recv()`. The scope dict -//! is built via `scope_from_template` (the same optimized path used by -//! the existing dispatch), and runs entirely on Thread 2 (100% GIL). - -use crate::asgi::scope::{ - ResolvedAwaitable, ScopeInterns, build_receive_template, scope_from_template, -}; -use crate::asgi::slot_receive::SlotReceive; -use crate::asgi::slot_send::SlotSend; -use crate::io::channel::{InboundChannel, RequestSlot, Wakeup}; -use pyo3::prelude::*; -use pyo3::types::PyDict; -use std::sync::Arc; - -// ── RequestQueue ───────────────────────────────────────────────────────── - -/// Python-visible request queue for the 2-thread dispatch pipeline. -/// -/// Created once per worker, passed to `install_dispatch()` in Python. -/// `try_recv()` is called from the `_on_readable` callback whenever -/// the wakeup pipe signals new requests. -#[pyclass(module = "apx._core")] -pub struct RequestQueue { - inbound_rx: crossbeam_channel::Receiver, - wakeup: Arc, - scope_interns: Arc, - receive_template: Py, - resolved: Py, - dev_mode: bool, -} - -crate::opaque_debug!(RequestQueue); - -impl RequestQueue { - /// Create a new request queue from the inbound channel and scope interns. - /// - /// Must be called with the GIL held (needs `py` for template construction). - pub fn new( - py: Python<'_>, - inbound: &InboundChannel, - wakeup: Arc, - scope_interns: Arc, - dev_mode: bool, - ) -> PyResult { - let receive_template = build_receive_template(py)?; - let resolved = Py::new(py, ResolvedAwaitable)?; - Ok(Self { - inbound_rx: inbound.receiver().clone(), - wakeup, - scope_interns, - receive_template, - resolved, - dev_mode, - }) - } -} - -#[pymethods] -impl RequestQueue { - /// Try to receive one request, returning `(scope, receive, send)` or `None`. - /// - /// Called from the asyncio `_on_readable` callback. Non-blocking — - /// returns `None` immediately when the queue is empty, clearing the - /// wakeup coalescing flag so the next `signal()` writes a fresh byte. - fn try_recv<'py>(&self, py: Python<'py>) -> PyResult>> { - use crate::telemetry::dispatch_metrics; - - dispatch_metrics::record_queue_depth(self.inbound_rx.len() as f64); - - let slot = match self.inbound_rx.try_recv() { - Ok(slot) => slot, - Err( - crossbeam_channel::TryRecvError::Empty - | crossbeam_channel::TryRecvError::Disconnected, - ) => { - // Clear the coalescing flag so the next signal() writes a - // fresh wakeup byte. We must re-check the channel AFTER - // clearing: a signal() racing between our first try_recv - // (Empty) and this drain() would lose the CAS (pending was - // still true) and never write a byte. The re-check picks - // up that orphaned item. - self.wakeup.drain(); - match self.inbound_rx.try_recv() { - Ok(slot) => slot, - Err(_) => return Ok(None), - } - } - }; - - dispatch_metrics::record_pickup_delay(slot.created_at.elapsed().as_micros() as f64); - - self.materialize(py, slot).map(Some) - } -} - -impl RequestQueue { - /// Build `(scope, receive, send)` Python tuple from a pure-Rust `RequestSlot`. - fn materialize<'py>( - &self, - py: Python<'py>, - slot: RequestSlot, - ) -> PyResult> { - use crate::telemetry::{dispatch_metrics, timed}; - - timed!(dispatch_metrics::record_materialize, { - if let Some(ref ctx) = slot.trace_context { - crate::telemetry::context::set_python_context(py, ctx)?; - } - - let scope = scope_from_template( - py, - &self.scope_interns.scope_template, - &slot, - None, - &self.scope_interns, - )?; - - let receive = SlotReceive::new(slot.body, self.receive_template.clone_ref(py)); - let receive_obj = Py::new(py, receive)?.into_bound(py).into_any(); - - let send = SlotSend::new(slot.response_tx, self.resolved.clone_ref(py), self.dev_mode); - let send_obj = Py::new(py, send)?.into_bound(py).into_any(); - - let scope_any = scope.into_bound(py).into_any(); - pyo3::types::PyTuple::new(py, [scope_any, receive_obj, send_obj]) - }) - } -} diff --git a/crates/framework/src/asgi/scope.rs b/crates/framework/src/asgi/scope.rs index fa099d6a..358fe2f8 100644 --- a/crates/framework/src/asgi/scope.rs +++ b/crates/framework/src/asgi/scope.rs @@ -1,61 +1,40 @@ -//! ASGI protocol primitives backed by Rust. +//! ASGI scope interning and template building. //! -//! Provides `AsgiReceive`, `AsgiSend` (Python callables), `scope_from_template`, -//! and `build_ws_scope` for constructing ASGI scope dicts from [`InboundRequest`]. -//! -//! These types enable Starlette's `Request`, `StreamingResponse`, and `WebSocket` -//! to work unmodified against a Rust-backed ASGI server. +//! Provides [`ScopeInterns`] for pre-building scope dictionaries and +//! [`ResolvedAwaitable`] / [`ResolvedAwaitableWithValue`] for zero-overhead +//! Python awaitables. -use crate::io::channel::RequestSlot; -use crate::protocol::http::error::AppError; -use crate::transport::types::{InboundRequest, OutboundResponse, ProtocolVersion, ResponseBody}; -use bytes::Bytes; -use http::header::{self, HeaderMap, HeaderName, HeaderValue}; +use crate::transport::types::ProtocolVersion; +use http::header::{self, HeaderName}; use pyo3::prelude::*; -use pyo3::pybacked::PyBackedBytes; -use pyo3::types::{PyBytes, PyDict, PyDictMethods, PyList, PyString, PyTuple}; -use std::borrow::Cow; +use pyo3::types::{PyBytes, PyDict, PyString, PyTuple}; use std::net::SocketAddr; -use std::sync::Arc; -use tokio::sync::{Mutex, mpsc, oneshot}; use super::{ASGI_SPEC_VERSION, ASGI_VERSION}; /// Default HTTP scheme (TLS detection is a future extension). const DEFAULT_SCHEME: &str = "http"; -/// Default WebSocket scheme. -const WS_SCHEME: &str = "ws"; - // ── ScopeInterns ───────────────────────────────────────────────────────── crate::opaque_debug!(ScopeInterns); /// Pre-interned Python strings for ASGI scope construction. /// -/// Created once at worker startup, shared across all requests via `AppState`. +/// Created once at worker startup, shared across all requests. /// Eliminates ~25 transient `PyString` allocations per request. pub struct ScopeInterns { - // ── Scope dict keys ── /// Fixed keys used in every ASGI scope dict. pub(crate) keys: ScopeKeys, - // ── Scope dict fixed values ── - /// Fixed values (type strings, version strings, empty root_path). - pub(crate) vals: ScopeValues, - // ── Header name cache ── /// Cached `PyBytes` for common HTTP header names. pub(crate) headers: HeaderInterns, - // ── Per-worker address cache ── /// Pre-built `(host_str, port)` tuple for the server address. pub(crate) server_tuple: Py, - // ── HTTP version interns ── /// Cached `PyString` for HTTP protocol versions. pub(crate) versions: VersionInterns, - // ── Scope template ── /// Pre-built HTTP scope dict with fixed fields. `dict.copy()` per request. pub(crate) scope_template: Py, - // ── Cached empty dict ── - /// Shared empty dict for parameterless routes (avoids `PyDict::new` per request). + /// Shared empty dict for parameterless routes. pub(crate) empty_dict: Py, } @@ -75,20 +54,9 @@ pub struct ScopeKeys { pub(crate) root_path: Py, pub(crate) state: Py, pub(crate) path_params: Py, - pub(crate) app: Py, - pub(crate) router: Py, } -/// Fixed dict values used in ASGI scope construction. -pub struct ScopeValues { - pub(crate) type_http: Py, - pub(crate) type_websocket: Py, - pub(crate) scheme_http: Py, - pub(crate) scheme_ws: Py, - pub(crate) root_path_empty: Py, - /// Pre-built `{"version": "3.0", "spec_version": "2.3"}` dict, shared per-request. - pub(crate) asgi_dict: Py, -} +crate::opaque_debug!(ScopeKeys); /// Common HTTP header names, ordered by frequency in typical HTTP/1.1 traffic. const COMMON_HEADERS: &[HeaderName] = &[ @@ -112,13 +80,12 @@ const COMMON_HEADERS: &[HeaderName] = &[ ]; /// Pre-built `PyBytes` for common HTTP header names. -/// -/// `http::HeaderName` standard constants compare by pointer, so the -/// lookup is a pointer match — not a string hash. pub struct HeaderInterns { - map: Vec<(HeaderName, Py)>, + pub(crate) map: Vec<(HeaderName, Py)>, } +crate::opaque_debug!(HeaderInterns); + impl HeaderInterns { /// Create cached `PyBytes` for common header names. Call once at worker startup. pub fn new(py: Python<'_>) -> Self { @@ -128,20 +95,8 @@ impl HeaderInterns { .collect(); Self { map } } - - /// Look up a cached `PyBytes` for this header name. - /// Returns `None` for non-standard headers (fallback to `PyBytes::new`). - pub fn get<'py>(&self, py: Python<'py>, name: &HeaderName) -> Option> { - self.map - .iter() - .find(|(h, _)| h == name) - .map(|(_, cached)| cached.bind(py).clone()) - } } -/// Pre-interned `PyString` for common HTTP methods. -/// -/// Uses pointer comparison on `http::Method` constants for O(1) lookup. /// Pre-interned `PyString` for HTTP protocol versions ("1.0", "1.1", "2"). pub struct VersionInterns { http10: Py, @@ -149,6 +104,8 @@ pub struct VersionInterns { h2: Py, } +crate::opaque_debug!(VersionInterns); + impl VersionInterns { /// Create cached `PyString` for protocol versions. Call once at worker startup. fn new(py: Python<'_>) -> Self { @@ -169,28 +126,6 @@ impl VersionInterns { } } -// ── SendCache ──────────────────────────────────────────────────────────── - -/// Cached Python objects for the ASGI send path. -/// -/// Separate from `ScopeInterns` (smallest possible scope): scope-building -/// code never touches these, and send code never touches scope interns. -pub struct SendCache { - /// Singleton `ResolvedAwaitable` — stateless, reused via `clone_ref`. - pub(crate) resolved: Py, -} - -crate::opaque_debug!(SendCache); - -impl SendCache { - /// Create the send cache. Call once at worker startup with GIL held. - pub fn new(py: Python<'_>) -> PyResult { - Ok(Self { - resolved: Py::new(py, ResolvedAwaitable)?, - }) - } -} - impl ScopeInterns { /// Create all interned strings and cached objects. /// @@ -241,26 +176,15 @@ impl ScopeInterns { root_path: s("root_path"), state: s("state"), path_params: s("path_params"), - app: s("app"), - router: s("router"), - }; - let vals = ScopeValues { - type_http: s("http"), - type_websocket: s("websocket"), - scheme_http: s(DEFAULT_SCHEME), - scheme_ws: s(WS_SCHEME), - root_path_empty: s(""), - asgi_dict: asgi_dict.unbind(), }; let versions = VersionInterns::new(py); - // Build scope template with fixed HTTP fields pre-populated. let scope_template = { let tpl = PyDict::new(py); - let _ = tpl.set_item(keys.r#type.bind(py), vals.type_http.bind(py)); - let _ = tpl.set_item(keys.asgi.bind(py), vals.asgi_dict.bind(py)); - let _ = tpl.set_item(keys.scheme.bind(py), vals.scheme_http.bind(py)); - let _ = tpl.set_item(keys.root_path.bind(py), vals.root_path_empty.bind(py)); + let _ = tpl.set_item(keys.r#type.bind(py), pyo3::intern!(py, "http")); + let _ = tpl.set_item(keys.asgi.bind(py), &asgi_dict); + let _ = tpl.set_item(keys.scheme.bind(py), pyo3::intern!(py, DEFAULT_SCHEME)); + let _ = tpl.set_item(keys.root_path.bind(py), pyo3::intern!(py, "")); let _ = tpl.set_item(keys.http_version.bind(py), versions.http11.bind(py)); let _ = tpl.set_item(keys.server.bind(py), server_tuple.bind(py)); tpl.unbind() @@ -268,7 +192,6 @@ impl ScopeInterns { Self { keys, - vals, headers: HeaderInterns::new(py), server_tuple, versions, @@ -278,174 +201,14 @@ impl ScopeInterns { } } -// ── AsgiEvent ──────────────────────────────────────────────────────────── - -/// Parsed ASGI send event (Rust-side representation). -/// -/// Pushed through a channel from [`AsgiSend`] (Python side) to the response -/// collector (Rust side) that assembles the final HTTP response or relays -/// WebSocket frames. -#[derive(Debug)] -pub enum AsgiEvent { - /// `http.response.start` — status code and headers. - ResponseStart { - /// HTTP status code. - status: u16, - /// Response headers, built directly from Python bytes. - headers: HeaderMap, - }, - /// `http.response.body` — body chunk with continuation flag. - ResponseBody { - /// Body bytes. - body: Bytes, - /// Whether more body chunks follow. - more_body: bool, - }, - /// `websocket.accept` — server accepts the WebSocket connection. - WsAccept { - /// Optional subprotocol. - subprotocol: Option, - /// Response headers as raw byte pairs. - headers: Vec<(Vec, Vec)>, - }, - /// `websocket.send` — server sends a frame to the client. - WsSend { - /// Text frame payload. - text: Option, - /// Binary frame payload (zero-copy from Python via `PyBackedBytes`). - bytes: Option, - }, - /// `websocket.close` — server closes the connection. - WsClose { - /// WebSocket close code (default 1000). - code: u16, - }, -} - -// ── AsgiReceive ────────────────────────────────────────────────────────── - -/// Build the `receive` template dict: `{"type": "http.request", "body": b"", "more_body": False}`. -/// -/// Created once per worker, cloned per-request via `PyDict::copy`. This is -/// faster than building 3 dict keys from scratch each time. -pub fn build_receive_template(py: Python<'_>) -> PyResult> { - let d = PyDict::new(py); - d.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request"))?; - d.set_item(pyo3::intern!(py, "body"), PyBytes::new(py, b""))?; - d.set_item(pyo3::intern!(py, "more_body"), false)?; - Ok(d.unbind()) -} - -/// ASGI `receive` callable backed by Rust. -/// -/// For HTTP: first call returns `http.request` with the pre-buffered body -/// synchronously (via `ResolvedAwaitableWithValue`, no tokio task overhead). -/// Subsequent calls pend forever via `future_into_py` + `pending()`, -/// preventing Starlette's `listen_for_disconnect` from prematurely firing. -#[pyclass(module = "apx._core", freelist = 64)] -pub struct AsgiReceive { - body: std::sync::Mutex>, - disconnect_rx: std::sync::Mutex>>, - receive_template: Py, -} - -crate::opaque_debug!(AsgiReceive); - -impl AsgiReceive { - /// Create for an HTTP request with a known body. - pub fn http( - body: Bytes, - disconnect_rx: oneshot::Receiver<()>, - receive_template: Py, - ) -> Self { - Self { - body: std::sync::Mutex::new(Some(body)), - disconnect_rx: std::sync::Mutex::new(Some(disconnect_rx)), - receive_template, - } - } - - /// Create for an HTTP request with no body (GET, HEAD, DELETE). - pub fn empty(disconnect_rx: oneshot::Receiver<()>, receive_template: Py) -> Self { - Self { - body: std::sync::Mutex::new(Some(Bytes::new())), - disconnect_rx: std::sync::Mutex::new(Some(disconnect_rx)), - receive_template, - } - } -} - -#[pymethods] -impl AsgiReceive { - /// Python: `event = await receive()` - /// - /// First call: returns body synchronously via `ResolvedAwaitableWithValue` - /// (no tokio task, no `future_into_py` overhead). - /// Subsequent calls: pend forever via `future_into_py` + `pending()` - /// (proper asyncio suspension for the disconnect listener). - fn __call__<'py>(&self, py: Python<'py>) -> PyResult> { - let taken = self - .body - .lock() - .map_err(|_| pyo3::exceptions::PyRuntimeError::new_err("receive mutex poisoned"))? - .take(); - - if let Some(bytes) = taken { - let event = crate::telemetry::timed!( - crate::telemetry::dispatch_metrics::record_receive_build, - { - let event = self.receive_template.bind(py).copy()?; - event.set_item(pyo3::intern!(py, "body"), PyBytes::new(py, &bytes))?; - event - } - ); - let event = event.unbind().into_any(); - Py::new(py, ResolvedAwaitableWithValue { value: Some(event) }) - .map(|obj| obj.into_bound(py).into_any()) - } else { - let maybe_disconnect = self - .disconnect_rx - .lock() - .map_err(|_| { - pyo3::exceptions::PyRuntimeError::new_err("disconnect mutex poisoned") - })? - .take(); - let handle = crate::io::with_tokio_handle(|h| h.clone()).ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("no tokio runtime for disconnect watch") - })?; - let _guard = handle.enter(); - if let Some(disconnect_rx) = maybe_disconnect { - let disconnect_type = pyo3::intern!(py, "http.disconnect").clone().unbind(); - let type_key = pyo3::intern!(py, "type").clone().unbind(); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let _ = disconnect_rx.await; - Python::attach(|py| -> PyResult> { - let event = PyDict::new(py); - event.set_item(&type_key, &disconnect_type)?; - Ok(event.unbind().into_any()) - }) - }) - } else { - pyo3_async_runtimes::tokio::future_into_py( - py, - std::future::pending::>>(), - ) - } - } - } -} - -// ── ResolvedAwaitable ───────────────────────────────────────────────────── +// ── ResolvedAwaitable ──────────────────────────────────────────────────── /// Zero-overhead Python awaitable that completes immediately. /// -/// Used by buffered `AsgiSend` to avoid `pyo3_async_runtimes::future_into_py` -/// and its tokio task overhead. Implements the Python iterator protocol -/// so `await resolved_awaitable` returns `None` with no scheduling. -#[expect( - clippy::redundant_pub_crate, - reason = "visible to sibling modules in asgi/" -)] +/// Used by the response writer to return from `send()` without scheduling. +/// Implements the Python iterator protocol so `await resolved` returns +/// `None` with no scheduling overhead. +#[expect(clippy::redundant_pub_crate, reason = "used from protocol::writer")] #[pyclass(module = "apx._core", freelist = 128)] pub(crate) struct ResolvedAwaitable; @@ -461,19 +224,15 @@ impl ResolvedAwaitable { #[expect(clippy::unused_self, reason = "required by Python iterator protocol")] fn __next__(&self) -> Option> { - None // StopIteration — completes immediately + None } } /// Zero-overhead Python awaitable that completes immediately with a value. /// -/// Used by `AsgiReceive` and `SlotReceive` to return the receive dict -/// without `future_into_py` (which requires a tokio runtime, unavailable -/// on `spawn_blocking` threads). -#[expect( - clippy::redundant_pub_crate, - reason = "visible to sibling modules in asgi/" -)] +/// Used by [`HttpReceive`](crate::protocol::connection::HttpReceive) +/// to return the receive dict without scheduling. +#[expect(clippy::redundant_pub_crate, reason = "used from protocol::connection")] #[pyclass(module = "apx._core", freelist = 64)] pub(crate) struct ResolvedAwaitableWithValue { value: Option>, @@ -497,7 +256,6 @@ impl ResolvedAwaitableWithValue { } fn __next__(&mut self) -> PyResult> { - // Raise StopIteration(value) — this is how Python awaitables return results. let val = self .value .take() @@ -505,1897 +263,3 @@ impl ResolvedAwaitableWithValue { Err(pyo3::exceptions::PyStopIteration::new_err((val,))) } } - -// ── AsgiSend ───────────────────────────────────────────────────────────── - -/// Channel capacity for streaming body chunks after the first. -/// -/// Must be at least as large as the drive step budget (128) so that a -/// streaming handler producing many small chunks never blocks during a -/// single drive cycle. Backpressure still engages for very large -/// responses — the driver suspends and the drain task resumes once hyper -/// drains the channel. -const STREAM_CHANNEL_CAPACITY: usize = 256; - -/// Internal state for [`AsgiSend`] — HTTP vs WebSocket mode. -enum SendInner { - /// HTTP mode — accumulates response, sends via oneshot. - Http { - status: Option, - headers: Option, - response_tx: Option>>, - disconnect_tx: Option>, - stream_tx: Option>, - }, - /// WebSocket mode — forwards events via mpsc (unchanged). - Ws { tx: mpsc::Sender }, -} - -/// ASGI `send` callable backed by Rust. -/// -/// In HTTP mode, accumulates status/headers from `ResponseStart` and builds -/// an [`OutboundResponse`] directly — no intermediate mpsc channel for the -/// common fixed-response case. -/// -/// In WebSocket mode, forwards events via mpsc (same as before). -#[pyclass(module = "apx._core", freelist = 64)] -pub struct AsgiSend { - inner: SendInner, - resolved: Option>, -} - -impl std::fmt::Debug for AsgiSend { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.inner { - SendInner::Http { .. } => f.debug_struct("AsgiSend::Http").finish_non_exhaustive(), - SendInner::Ws { .. } => f.debug_struct("AsgiSend::Ws").finish_non_exhaustive(), - } - } -} - -impl AsgiSend { - /// Create an HTTP-mode sender backed by a oneshot response channel. - pub fn http( - response_tx: oneshot::Sender>, - disconnect_tx: oneshot::Sender<()>, - send_cache: &SendCache, - py: Python<'_>, - ) -> Self { - Self { - inner: SendInner::Http { - status: None, - headers: None, - response_tx: Some(response_tx), - disconnect_tx: Some(disconnect_tx), - stream_tx: None, - }, - resolved: Some(send_cache.resolved.clone_ref(py)), - } - } - - /// Create a WebSocket-mode sender backed by an mpsc channel. - pub fn new(tx: mpsc::Sender) -> Self { - Self { - inner: SendInner::Ws { tx }, - resolved: None, - } - } -} - -#[pymethods] -impl AsgiSend { - /// Forward an unhandled app exception through the response channel as a 500. - /// - /// Called by the `_guarded` wrapper when the ASGI app raises an - /// `Exception`. Without this, `response_tx` drops silently and - /// `response_rx` gets `RecvError`. - fn send_error(&mut self, traceback: String) { - if let SendInner::Http { response_tx, .. } = &mut self.inner - && let Some(tx) = response_tx.take() - { - let _ = tx.send(Err(AppError::Internal(traceback))); - } - } - - /// Python: `await send({"type": "http.response.start", ...})` - fn __call__<'py>( - &mut self, - py: Python<'py>, - event: Bound<'py, PyDict>, - ) -> PyResult> { - let parsed = crate::telemetry::timed!( - crate::telemetry::dispatch_metrics::record_send_parse, - parse_asgi_send_event(&event)? - ); - - let resolved = self.resolved.as_ref(); - match &mut self.inner { - SendInner::Http { - status, - headers, - response_tx, - disconnect_tx, - stream_tx, - } => Self::handle_http( - py, - parsed, - status, - headers, - response_tx, - disconnect_tx, - stream_tx, - resolved, - ), - SendInner::Ws { tx } => Self::handle_ws(py, parsed, tx, resolved), - } - } -} - -impl AsgiSend { - /// Return a `ResolvedAwaitable` from the cached singleton or a fresh allocation. - fn resolved_awaitable<'py>( - resolved: Option<&Py>, - py: Python<'py>, - ) -> PyResult> { - if let Some(cached) = resolved { - Ok(cached.clone_ref(py).into_bound(py).into_any()) - } else { - Py::new(py, ResolvedAwaitable).map(|obj| obj.into_bound(py).into_any()) - } - } - - /// Handle an event in HTTP mode. - #[expect( - clippy::too_many_arguments, - reason = "mutable refs to send state fields" - )] - fn handle_http<'py>( - py: Python<'py>, - event: AsgiEvent, - status: &mut Option, - headers: &mut Option, - response_tx: &mut Option>>, - disconnect_tx: &mut Option>, - stream_tx: &mut Option>, - resolved: Option<&Py>, - ) -> PyResult> { - match event { - AsgiEvent::ResponseStart { - status: s, - headers: h, - } => { - tracing::trace!(name: "apx.asgi.send_response_start", status = s, "asgi_send: response_start"); - *status = Some(s); - *headers = Some(h); - Self::resolved_awaitable(resolved, py) - } - AsgiEvent::ResponseBody { body, more_body } if stream_tx.is_none() => { - Self::handle_first_body( - py, - body, - more_body, - status, - headers, - response_tx, - disconnect_tx, - stream_tx, - resolved, - ) - } - AsgiEvent::ResponseBody { body, more_body } => { - Self::handle_stream_body(py, body, more_body, stream_tx, resolved) - } - _ => Err(pyo3::exceptions::PyRuntimeError::new_err( - "unexpected event type in HTTP mode", - )), - } - } - - /// First `http.response.body` — decide streaming vs fixed and send the response. - #[expect( - clippy::too_many_arguments, - reason = "mutable refs to send state fields" - )] - fn handle_first_body<'py>( - py: Python<'py>, - body: Bytes, - more_body: bool, - status: &mut Option, - headers: &mut Option, - response_tx: &mut Option>>, - disconnect_tx: &mut Option>, - stream_tx: &mut Option>, - resolved: Option<&Py>, - ) -> PyResult> { - let Some(raw_status) = status.take() else { - return Err(pyo3::exceptions::PyRuntimeError::new_err( - "ASGI protocol error: body before response start", - )); - }; - let resp_headers = headers.take().unwrap_or_default(); - let http_status = http::StatusCode::from_u16(raw_status) - .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR); - let server_route = None; - - if more_body { - let (stx, srx) = mpsc::channel(STREAM_CHANNEL_CAPACITY); - let dtx = disconnect_tx.take(); - let body_len = body.len(); - let stream = super::streaming::AsgiBodyStream::new(srx, Some(body), dtx); - if let Some(tx) = response_tx.take() { - let _ = tx.send(Ok(OutboundResponse { - status: http_status, - headers: resp_headers, - body: ResponseBody::Stream(Box::pin(stream)), - server_route, - })); - } - *stream_tx = Some(stx); - tracing::trace!(name: "apx.asgi.send_first_body_chunk", body_len, "asgi_send: first body chunk (streaming started)"); - } else { - let _ = disconnect_tx.take(); - let body_len = body.len(); - if let Some(tx) = response_tx.take() { - let _ = tx.send(Ok(OutboundResponse { - status: http_status, - headers: resp_headers, - body: ResponseBody::Fixed(body), - server_route, - })); - } - tracing::trace!(name: "apx.asgi.send_fixed_body", body_len, "asgi_send: fixed body (complete)"); - } - Self::resolved_awaitable(resolved, py) - } - - /// Subsequent `http.response.body` — push to the streaming channel with backpressure. - fn handle_stream_body<'py>( - py: Python<'py>, - body: Bytes, - more_body: bool, - stream_tx: &mut Option>, - resolved: Option<&Py>, - ) -> PyResult> { - let Some(tx) = stream_tx.as_ref() else { - return Err(pyo3::exceptions::PyRuntimeError::new_err( - "ASGI protocol error: body after stream closed", - )); - }; - let body_len = body.len(); - match tx.try_send(AsgiEvent::ResponseBody { body, more_body }) { - Ok(()) => { - tracing::trace!( - name: "apx.asgi.send_stream_chunk", - body_len, - more_body, - "asgi_send: stream chunk sent (no backpressure)" - ); - if !more_body { - *stream_tx = None; - } - Self::resolved_awaitable(resolved, py) - } - Err(mpsc::error::TrySendError::Full(event)) => { - tracing::trace!( - name: "apx.asgi.send_stream_backpressure", - body_len, - more_body, - "asgi_send: stream chunk BACKPRESSURE (channel full)" - ); - let tx = tx.clone(); - let drop_stream = !more_body; - let handle = crate::io::with_tokio_handle(|h| h.clone()).ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err( - "no tokio runtime for backpressure send", - ) - })?; - let _guard = handle.enter(); - if drop_stream { - *stream_tx = None; - } - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let _ = tx.send(event).await; - tracing::trace!(name: "apx.asgi.send_backpressure_resolved", "asgi_send: backpressure resolved"); - Ok(Python::attach(|py| py.None())) - }) - } - Err(mpsc::error::TrySendError::Closed(_)) => { - tracing::trace!(name: "apx.asgi.send_stream_channel_closed", "asgi_send: stream channel CLOSED"); - *stream_tx = None; - Err(pyo3::exceptions::PyRuntimeError::new_err( - "stream channel closed", - )) - } - } - } - - /// Handle an event in WebSocket mode (unchanged logic). - fn handle_ws<'py>( - py: Python<'py>, - event: AsgiEvent, - tx: &mpsc::Sender, - resolved: Option<&Py>, - ) -> PyResult> { - match tx.try_send(event) { - Ok(()) => Self::resolved_awaitable(resolved, py), - Err(mpsc::error::TrySendError::Full(event)) => { - let tx = tx.clone(); - let handle = crate::io::with_tokio_handle(|h| h.clone()).ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err( - "no tokio runtime for backpressure send", - ) - })?; - let _guard = handle.enter(); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let _ = tx.send(event).await; - Ok(Python::attach(|py| py.None())) - }) - } - Err(mpsc::error::TrySendError::Closed(_)) => Err( - pyo3::exceptions::PyRuntimeError::new_err("response channel closed"), - ), - } - } -} - -// ── WebSocket incoming events ──────────────────────────────────────────── - -/// Incoming WebSocket event from the client (axum WS → Python handler). -#[derive(Debug)] -pub enum WsIncomingEvent { - /// `websocket.connect` — initial connection event. - Connect, - /// `websocket.receive` — client sent a text or binary frame. - Receive { - /// Text frame payload. - text: Option, - /// Binary frame payload (zero-copy from tungstenite `Bytes`). - bytes: Option, - }, - /// `websocket.disconnect` — client disconnected. - Disconnect { - /// WebSocket close code (default 1000). - code: u16, - }, -} - -/// ASGI `receive` callable for WebSocket connections. -/// -/// Returns ASGI dicts for `websocket.connect`, `websocket.receive`, -/// and `websocket.disconnect` events by reading from a channel fed -/// by the axum WebSocket frame forwarder. -#[pyclass(module = "apx._core")] -pub struct AsgiWsReceive { - rx: Arc>>, -} - -crate::opaque_debug!(AsgiWsReceive); - -impl AsgiWsReceive { - /// Create a new WebSocket receive callable. - pub fn new(rx: mpsc::Receiver) -> Self { - Self { - rx: Arc::new(Mutex::new(rx)), - } - } -} - -#[pymethods] -impl AsgiWsReceive { - /// Python: `event = await receive()` - fn __call__<'py>(&self, py: Python<'py>) -> PyResult> { - let rx = Arc::clone(&self.rx); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let mut guard = rx.lock().await; - let event = guard.recv().await; - Python::attach(|py| build_ws_receive_event(py, event)) - }) - } -} - -/// Build an ASGI WebSocket receive event dict. -fn build_ws_receive_event(py: Python<'_>, event: Option) -> PyResult> { - let dict = PyDict::new(py); - let type_key = pyo3::intern!(py, "type"); - match event { - Some(WsIncomingEvent::Connect) => { - dict.set_item(type_key, pyo3::intern!(py, "websocket.connect"))?; - } - Some(WsIncomingEvent::Receive { text, bytes }) => { - dict.set_item(type_key, pyo3::intern!(py, "websocket.receive"))?; - if let Some(t) = text { - dict.set_item(pyo3::intern!(py, "text"), t)?; - } - if let Some(b) = bytes { - dict.set_item(pyo3::intern!(py, "bytes"), PyBytes::new(py, &b))?; - } - } - Some(WsIncomingEvent::Disconnect { code }) => { - dict.set_item(type_key, pyo3::intern!(py, "websocket.disconnect"))?; - dict.set_item(pyo3::intern!(py, "code"), code)?; - } - None => { - dict.set_item(type_key, pyo3::intern!(py, "websocket.disconnect"))?; - dict.set_item(pyo3::intern!(py, "code"), 1000u16)?; - } - } - Ok(dict.into_any().unbind()) -} - -// ── Parse helpers ──────────────────────────────────────────────────────── - -/// Parse an ASGI send event dict into a typed [`AsgiEvent`]. -/// -/// Compares the `"type"` value against interned Python strings directly, -/// avoiding a Rust `String` allocation on every call. Only the error path -/// (unsupported event type) extracts the string for the error message. -fn parse_asgi_send_event(event: &Bound<'_, PyDict>) -> PyResult { - let py = event.py(); - let type_obj = event - .get_item(pyo3::intern!(py, "type"))? - .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("type"))?; - - if type_obj.eq(pyo3::intern!(py, "http.response.start"))? { - parse_response_start(event) - } else if type_obj.eq(pyo3::intern!(py, "http.response.body"))? { - parse_response_body(event) - } else if type_obj.eq(pyo3::intern!(py, "websocket.accept"))? { - parse_ws_accept(event) - } else if type_obj.eq(pyo3::intern!(py, "websocket.send"))? { - parse_ws_send(event) - } else if type_obj.eq(pyo3::intern!(py, "websocket.close"))? { - parse_ws_close(event) - } else { - let event_type: String = type_obj.extract()?; - Err(pyo3::exceptions::PyValueError::new_err(format!( - "unsupported ASGI event type: {event_type}" - ))) - } -} - -/// Parse `http.response.start` — extract status and build `HeaderMap` directly. -/// -/// Builds the `HeaderMap` from `PyBytes` references without intermediate -/// `Vec` allocations. Standard header names (content-type, etc.) are -/// recognized as constants with zero allocation. -fn parse_response_start(event: &Bound<'_, PyDict>) -> PyResult { - let py = event.py(); - let status: u16 = event - .get_item(pyo3::intern!(py, "status"))? - .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("status"))? - .extract()?; - let headers = parse_header_map(event)?; - Ok(AsgiEvent::ResponseStart { status, headers }) -} - -/// Parse `http.response.body` — extract body bytes and more_body flag. -fn parse_response_body(event: &Bound<'_, PyDict>) -> PyResult { - let py = event.py(); - let body = extract_body_bytes(event)?; - let more_body: bool = event - .get_item(pyo3::intern!(py, "more_body"))? - .map(|b| b.extract()) - .transpose()? - .unwrap_or(false); - Ok(AsgiEvent::ResponseBody { body, more_body }) -} - -/// Extract body bytes from an ASGI event dict via zero-copy ownership transfer. -/// -/// `PyBackedBytes` borrows Python's buffer; `Bytes::from_owner` wraps it for -/// hyper. Python's refcount keeps the buffer alive until Rust drops it. -fn extract_body_bytes(event: &Bound<'_, PyDict>) -> PyResult { - let py = event.py(); - let Some(obj) = event.get_item(pyo3::intern!(py, "body"))? else { - return Ok(Bytes::new()); - }; - match obj.cast::() { - Ok(py_bytes) => { - let backed: PyBackedBytes = py_bytes.clone().into(); - Ok(Bytes::from_owner(backed)) - } - Err(_) => Ok(Bytes::from(obj.extract::>()?)), - } -} - -/// Parse `websocket.accept` — extract optional subprotocol and headers. -fn parse_ws_accept(event: &Bound<'_, PyDict>) -> PyResult { - let py = event.py(); - let subprotocol: Option = event - .get_item(pyo3::intern!(py, "subprotocol"))? - .and_then(|v| v.extract().ok()); - let headers = extract_header_list(event)?; - Ok(AsgiEvent::WsAccept { - subprotocol, - headers, - }) -} - -/// Parse `websocket.send` — extract text or binary payload. -/// -/// Binary frames use `PyBackedBytes` + `Bytes::from_owner` for zero-copy -/// transfer from Python to tungstenite. -fn parse_ws_send(event: &Bound<'_, PyDict>) -> PyResult { - let py = event.py(); - let text: Option = event - .get_item(pyo3::intern!(py, "text"))? - .and_then(|v| v.extract().ok()); - let bytes: Option = match event.get_item(pyo3::intern!(py, "bytes"))? { - Some(v) => match v.cast::() { - Ok(py_bytes) => { - let backed: PyBackedBytes = py_bytes.clone().into(); - Some(Bytes::from_owner(backed)) - } - Err(_) => None, - }, - None => None, - }; - Ok(AsgiEvent::WsSend { text, bytes }) -} - -/// Parse `websocket.close` — extract close code. -fn parse_ws_close(event: &Bound<'_, PyDict>) -> PyResult { - let py = event.py(); - let code: u16 = event - .get_item(pyo3::intern!(py, "code"))? - .map(|v| v.extract()) - .transpose()? - .unwrap_or(1000); - Ok(AsgiEvent::WsClose { code }) -} - -/// Build an `http::HeaderMap` directly from an ASGI headers list. -/// -/// Reads `[(b"name", b"value"), ...]` from the Python dict and constructs -/// `HeaderName`/`HeaderValue` directly from `PyBytes::as_bytes()` borrows, -/// eliminating intermediate `Vec` allocations per header. -fn parse_header_map(event: &Bound<'_, PyDict>) -> PyResult { - let py = event.py(); - let Some(list) = event.get_item(pyo3::intern!(py, "headers"))? else { - return Ok(HeaderMap::new()); - }; - // Direct C-API indexing (PyList_GET_ITEM) avoids Python iterator protocol overhead. - let list: &Bound<'_, PyList> = list.cast()?; - let len = list.len(); - let mut headers = HeaderMap::with_capacity(len); - for i in 0..len { - let tuple = list.get_item(i)?; - let name = header_name_from_py(&tuple.get_item(0)?)?; - let value = header_value_from_py(&tuple.get_item(1)?)?; - headers.insert(name, value); - } - Ok(headers) -} - -/// Build a `HeaderName` from a Python bytes-like object. -fn header_name_from_py(obj: &Bound<'_, PyAny>) -> PyResult { - let bytes = match obj.cast::() { - Ok(py_bytes) => py_bytes.as_bytes(), - Err(_) => return header_name_from_extracted(obj), - }; - HeaderName::from_bytes(bytes) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("invalid header name: {e}"))) -} - -/// Fallback: extract bytes then parse header name. -fn header_name_from_extracted(obj: &Bound<'_, PyAny>) -> PyResult { - let bytes: Vec = obj.extract()?; - HeaderName::from_bytes(&bytes) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("invalid header name: {e}"))) -} - -/// Build a `HeaderValue` from a Python bytes-like object. -fn header_value_from_py(obj: &Bound<'_, PyAny>) -> PyResult { - let bytes = match obj.cast::() { - Ok(py_bytes) => py_bytes.as_bytes(), - Err(_) => return header_value_from_extracted(obj), - }; - HeaderValue::from_bytes(bytes) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("invalid header value: {e}"))) -} - -/// Fallback: extract bytes then parse header value. -fn header_value_from_extracted(obj: &Bound<'_, PyAny>) -> PyResult { - let bytes: Vec = obj.extract()?; - HeaderValue::from_bytes(&bytes) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("invalid header value: {e}"))) -} - -/// Extract raw byte pairs from an ASGI headers list (for WebSocket events). -fn extract_header_list(event: &Bound<'_, PyDict>) -> PyResult, Vec)>> { - let Some(list) = event.get_item(pyo3::intern!(event.py(), "headers"))? else { - return Ok(Vec::new()); - }; - list.try_iter()? - .map(|item| { - let tuple = item?; - let name = extract_bytes_field(&tuple.get_item(0)?)?; - let value = extract_bytes_field(&tuple.get_item(1)?)?; - Ok((name, value)) - }) - .collect() -} - -/// Extract a `Vec` from a Python object, preferring direct `PyBytes` borrow. -fn extract_bytes_field(obj: &Bound<'_, PyAny>) -> PyResult> { - match obj.cast::() { - Ok(py_bytes) => Ok(py_bytes.as_bytes().to_vec()), - Err(_) => obj.extract::>(), - } -} - -// ── ScopeSource ───────────────────────────────────────────────────────── - -/// Read-only access to the fields needed for ASGI scope construction. -/// -/// Implemented by both [`InboundRequest`] (legacy ASGI path, WS path) and -/// [`RequestSlot`] (zero-GIL crossbeam path), so `scope_from_template` -/// can accept either without an intermediate clone. -pub trait ScopeSource { - fn method(&self) -> &http::Method; - fn path(&self) -> &str; - fn query_string(&self) -> &Bytes; - fn headers(&self) -> &HeaderMap; - fn protocol(&self) -> ProtocolVersion; - fn client_addr(&self) -> Option; - fn path_params(&self) -> &[(String, String)]; -} - -impl ScopeSource for InboundRequest { - fn method(&self) -> &http::Method { - &self.method - } - fn path(&self) -> &str { - &self.path - } - fn query_string(&self) -> &Bytes { - &self.query_string - } - fn headers(&self) -> &HeaderMap { - &self.headers - } - fn protocol(&self) -> ProtocolVersion { - self.protocol - } - fn client_addr(&self) -> Option { - self.client_addr - } - fn path_params(&self) -> &[(String, String)] { - &self.path_params - } -} - -impl ScopeSource for RequestSlot { - fn method(&self) -> &http::Method { - &self.method - } - fn path(&self) -> &str { - &self.path - } - fn query_string(&self) -> &Bytes { - &self.query_string - } - fn headers(&self) -> &HeaderMap { - &self.headers - } - fn protocol(&self) -> ProtocolVersion { - self.protocol - } - fn client_addr(&self) -> Option { - self.client_addr - } - fn path_params(&self) -> &[(String, String)] { - &[] - } -} - -// ── scope_from_template ────────────────────────────────────────────────── - -/// Build an HTTP scope from the pre-populated template. -/// -/// `dict.copy()` + per-request fields only. For HTTP/1.1 (>95% of traffic), -/// the `http_version` field is already correct from the template. -pub fn scope_from_template( - py: Python<'_>, - template: &Py, - request: &impl ScopeSource, - fastapi_app: Option<&Py>, - interns: &ScopeInterns, -) -> PyResult> { - let scope = template - .bind(py) - .call_method0(pyo3::intern!(py, "copy"))? - .cast_into::() - .map_err(|e| { - pyo3::exceptions::PyTypeError::new_err(format!( - "scope template copy returned non-dict: {e}" - )) - })?; - if request.protocol() != ProtocolVersion::Http11 { - scope.set_item( - interns.keys.http_version.bind(py), - interns.versions.get(py, request.protocol()), - )?; - } - set_scope_request_fields(py, &scope, request, interns)?; - set_scope_headers(py, &scope, request, interns)?; - set_scope_addresses(py, &scope, request, interns)?; - set_scope_path_params(py, &scope, request, interns)?; - scope.set_item(interns.keys.state.bind(py), PyDict::new(py))?; - if let Some(app) = fastapi_app { - scope.set_item(interns.keys.app.bind(py), app.bind(py))?; - scope.set_item( - interns.keys.router.bind(py), - app.bind(py).getattr(c"router")?, - )?; - } - Ok(scope.unbind()) -} - -/// Construct an ASGI WebSocket scope dict from an [`InboundRequest`]. -/// -/// Similar to [`build_http_scope`] but sets `type: "websocket"` and `scheme: "ws"`. -/// No body-related fields. -pub fn build_ws_scope( - py: Python<'_>, - request: &InboundRequest, - interns: &ScopeInterns, -) -> PyResult> { - let dict = PyDict::new(py); - set_ws_scope_metadata(py, &dict, interns)?; - set_ws_scope_request_fields(py, &dict, request, interns)?; - set_scope_headers(py, &dict, request, interns)?; - set_scope_addresses(py, &dict, request, interns)?; - set_scope_path_params(py, &dict, request, interns)?; - dict.set_item(interns.keys.state.bind(py), PyDict::new(py))?; - Ok(dict.unbind()) -} - -/// Set ASGI WebSocket scope metadata fields. -fn set_ws_scope_metadata( - py: Python<'_>, - dict: &Bound<'_, PyDict>, - interns: &ScopeInterns, -) -> PyResult<()> { - dict.set_item( - interns.keys.r#type.bind(py), - interns.vals.type_websocket.bind(py), - )?; - dict.set_item(interns.keys.asgi.bind(py), interns.vals.asgi_dict.bind(py))?; - dict.set_item( - interns.keys.scheme.bind(py), - interns.vals.scheme_ws.bind(py), - )?; - dict.set_item( - interns.keys.root_path.bind(py), - interns.vals.root_path_empty.bind(py), - )?; - Ok(()) -} - -/// Set WebSocket request-specific scope fields. -fn set_ws_scope_request_fields( - py: Python<'_>, - dict: &Bound<'_, PyDict>, - request: &InboundRequest, - interns: &ScopeInterns, -) -> PyResult<()> { - dict.set_item( - interns.keys.http_version.bind(py), - request.protocol.as_asgi_version(), - )?; - dict.set_item(interns.keys.path.bind(py), percent_decode(&request.path))?; - dict.set_item( - interns.keys.raw_path.bind(py), - PyBytes::new(py, request.path.as_bytes()), - )?; - dict.set_item( - interns.keys.query_string.bind(py), - PyBytes::new(py, &request.query_string), - )?; - Ok(()) -} - -/// Set request-specific scope fields: http_version, method, path, raw_path, query_string. -fn set_scope_request_fields( - py: Python<'_>, - dict: &Bound<'_, PyDict>, - request: &impl ScopeSource, - interns: &ScopeInterns, -) -> PyResult<()> { - dict.set_item( - interns.keys.http_version.bind(py), - interns.versions.get(py, request.protocol()), - )?; - dict.set_item(interns.keys.method.bind(py), request.method().as_str())?; - dict.set_item(interns.keys.path.bind(py), percent_decode(request.path()))?; - dict.set_item( - interns.keys.raw_path.bind(py), - PyBytes::new(py, request.path().as_bytes()), - )?; - dict.set_item( - interns.keys.query_string.bind(py), - PyBytes::new(py, request.query_string()), - )?; - Ok(()) -} - -/// Set ASGI headers as a list of `(bytes, bytes)` tuples. -/// -/// Uses cached `PyBytes` for common header names (cache hit = zero allocation) -/// and constructs the list from a presized `Vec` (zero list resizes). -fn set_scope_headers( - py: Python<'_>, - dict: &Bound<'_, PyDict>, - request: &impl ScopeSource, - interns: &ScopeInterns, -) -> PyResult<()> { - let mut pairs: Vec> = Vec::with_capacity(request.headers().len()); - for (name, value) in request.headers() { - let n = interns - .headers - .get(py, name) - .unwrap_or_else(|| PyBytes::new(py, name.as_str().as_bytes())); - let v = PyBytes::new(py, value.as_bytes()); - let pair = PyTuple::new(py, [n.into_any(), v.into_any()])?; - pairs.push(pair.into_any()); - } - let headers_list = PyList::new(py, &pairs)?; - dict.set_item(interns.keys.headers.bind(py), headers_list)?; - Ok(()) -} - -/// Set server and client address tuples in scope. -fn set_scope_addresses( - py: Python<'_>, - dict: &Bound<'_, PyDict>, - request: &impl ScopeSource, - interns: &ScopeInterns, -) -> PyResult<()> { - dict.set_item(interns.keys.server.bind(py), interns.server_tuple.bind(py))?; - match request.client_addr() { - Some(addr) => { - dict.set_item( - interns.keys.client.bind(py), - (addr.ip().to_string(), addr.port()), - )?; - } - None => dict.set_item(interns.keys.client.bind(py), py.None())?, - } - Ok(()) -} - -/// Set path_params dict in scope (Starlette reads `scope["path_params"]`). -/// -/// Values are URL-decoded because axum's `RawPathParams` provides percent-encoded -/// strings, but Starlette/FastAPI expects decoded values (matching what Starlette's -/// own router would produce). -/// -/// When path_params is empty (parameterless routes), reuses a pre-built -/// empty dict singleton from `ScopeInterns` to avoid a `PyDict::new` per request. -fn set_scope_path_params( - py: Python<'_>, - dict: &Bound<'_, PyDict>, - request: &impl ScopeSource, - interns: &ScopeInterns, -) -> PyResult<()> { - let params = request.path_params(); - if params.is_empty() { - dict.set_item( - interns.keys.path_params.bind(py), - interns.empty_dict.bind(py), - )?; - return Ok(()); - } - let pp = PyDict::new(py); - for (k, v) in params { - pp.set_item(k.as_str(), percent_decode(v.as_str()))?; - } - dict.set_item(interns.keys.path_params.bind(py), pp)?; - Ok(()) -} - -/// Decode percent-encoded UTF-8 strings (e.g., `hello%20world` → `hello world`). -/// -/// Returns the original string borrowed if no percent sequences are present, -/// avoiding a heap allocation on the common path. -pub(super) fn percent_decode(input: &str) -> Cow<'_, str> { - if !input.contains('%') { - return Cow::Borrowed(input); - } - let mut bytes = Vec::with_capacity(input.len()); - let mut chars = input.as_bytes().iter().copied(); - while let Some(b) = chars.next() { - if b == b'%' { - let hi = chars.next(); - let lo = chars.next(); - if let (Some(h), Some(l)) = (hi, lo) { - if let (Some(hv), Some(lv)) = (hex_val(h), hex_val(l)) { - bytes.push(hv << 4 | lv); - continue; - } - // Invalid hex — emit literally - bytes.extend_from_slice(&[b'%', h, l]); - } else { - // Truncated — emit literally - bytes.push(b'%'); - if let Some(h) = hi { - bytes.push(h); - } - } - } else { - bytes.push(b); - } - } - Cow::Owned(String::from_utf8(bytes).unwrap_or_else(|_| input.to_owned())) -} - -/// Convert an ASCII hex digit to its 4-bit value. -const fn hex_val(b: u8) -> Option { - match b { - b'0'..=b'9' => Some(b - b'0'), - b'a'..=b'f' => Some(b - b'a' + 10), - b'A'..=b'F' => Some(b - b'A' + 10), - _ => None, - } -} - -// ── Tests ──────────────────────────────────────────────────────────────── - -#[cfg(test)] -#[expect( - clippy::unwrap_used, - clippy::panic, - reason = "test code uses unwrap/assert for clarity" -)] -mod tests { - use super::*; - use crate::transport::types::{BodyStream, ProtocolVersion, TransportKind}; - use crate::with_py; - use http::header::HeaderMap; - - const TEST_SERVER_ADDR: SocketAddr = - SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), 8080); - use std::net::SocketAddr; - - // ── Pure Rust tests ────────────────────────────────────────────────── - - #[test] - fn asgi_event_debug_response_start() { - let mut h = HeaderMap::new(); - h.insert("content-type", "text/plain".parse().unwrap()); - let event = AsgiEvent::ResponseStart { - status: 200, - headers: h, - }; - let dbg = format!("{event:?}"); - assert!(dbg.contains("ResponseStart")); - assert!(dbg.contains("200")); - } - - #[test] - fn asgi_event_debug_response_body() { - let event = AsgiEvent::ResponseBody { - body: Bytes::from("hello"), - more_body: false, - }; - let dbg = format!("{event:?}"); - assert!(dbg.contains("ResponseBody")); - } - - #[test] - fn asgi_send_debug_http() { - with_py(|py| { - let (tx, _rx) = oneshot::channel(); - let (dtx, _drx) = oneshot::channel(); - let cache = SendCache::new(py).unwrap(); - let send = AsgiSend::http(tx, dtx, &cache, py); - let dbg = format!("{send:?}"); - assert!(dbg.contains("AsgiSend::Http")); - }); - } - - #[test] - fn asgi_send_debug_ws() { - let (tx, _rx) = mpsc::channel(1); - let send = AsgiSend::new(tx); - let dbg = format!("{send:?}"); - assert!(dbg.contains("AsgiSend::Ws")); - } - - // ── Helper ─────────────────────────────────────────────────────────── - - fn make_inbound_request( - method: http::Method, - path: &str, - query: &[u8], - headers: HeaderMap, - path_params: Vec<(String, String)>, - client_addr: Option, - ) -> InboundRequest { - InboundRequest::new( - method, - path.to_owned(), - Bytes::copy_from_slice(query), - headers, - BodyStream::Empty, - ProtocolVersion::Http11, - TransportKind::Tcp, - client_addr, - SocketAddr::from(([127, 0, 0, 1], 8080)), - path_params, - http::Extensions::new(), - ) - } - - // ── build_http_scope tests (require Python) ────────────────────────── - - #[test] - fn scope_basic_fields() { - let req = make_inbound_request( - http::Method::GET, - "/", - b"", - HeaderMap::new(), - Vec::new(), - Some(SocketAddr::from(([10, 0, 0, 1], 5555))), - ); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let scope = - scope_from_template(py, &interns.scope_template, &req, None, &interns).unwrap(); - let scope = scope.bind(py); - assert_eq!( - scope - .get_item("type") - .unwrap() - .unwrap() - .extract::() - .unwrap(), - "http" - ); - assert_eq!( - scope - .get_item("method") - .unwrap() - .unwrap() - .extract::() - .unwrap(), - "GET" - ); - assert_eq!( - scope - .get_item("path") - .unwrap() - .unwrap() - .extract::() - .unwrap(), - "/" - ); - assert_eq!( - scope - .get_item("scheme") - .unwrap() - .unwrap() - .extract::() - .unwrap(), - "http" - ); - assert_eq!( - scope - .get_item("root_path") - .unwrap() - .unwrap() - .extract::() - .unwrap(), - "" - ); - // asgi version - let asgi = scope.get_item("asgi").unwrap().unwrap(); - assert_eq!( - asgi.get_item("version") - .unwrap() - .extract::() - .unwrap(), - "3.0" - ); - assert_eq!( - asgi.get_item("spec_version") - .unwrap() - .extract::() - .unwrap(), - "2.4" - ); - }); - } - - #[test] - fn scope_protocol_versions() { - with_py(|py| { - for (version, expected) in [ - (ProtocolVersion::Http10, "1.0"), - (ProtocolVersion::Http11, "1.1"), - (ProtocolVersion::H2, "2"), - ] { - let req = InboundRequest::new( - http::Method::GET, - "/".to_owned(), - Bytes::new(), - HeaderMap::new(), - BodyStream::Empty, - version, - TransportKind::Tcp, - None, - SocketAddr::from(([127, 0, 0, 1], 8080)), - Vec::new(), - http::Extensions::new(), - ); - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let scope = - scope_from_template(py, &interns.scope_template, &req, None, &interns).unwrap(); - let scope = scope.bind(py); - let http_version: String = scope - .get_item("http_version") - .unwrap() - .unwrap() - .extract() - .unwrap(); - assert_eq!(http_version, expected, "version {version:?}"); - } - }); - } - - #[test] - fn scope_with_query_string() { - let req = make_inbound_request( - http::Method::GET, - "/search", - b"q=hello&page=1", - HeaderMap::new(), - Vec::new(), - None, - ); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let scope = - scope_from_template(py, &interns.scope_template, &req, None, &interns).unwrap(); - let scope = scope.bind(py); - let qs: Vec = scope - .get_item("query_string") - .unwrap() - .unwrap() - .extract() - .unwrap(); - assert_eq!(qs, b"q=hello&page=1"); - }); - } - - #[test] - fn scope_with_headers() { - let mut headers = HeaderMap::new(); - headers.insert("content-type", "application/json".parse().unwrap()); - headers.insert("x-custom", "value".parse().unwrap()); - let req = make_inbound_request(http::Method::POST, "/api", b"", headers, Vec::new(), None); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let scope = - scope_from_template(py, &interns.scope_template, &req, None, &interns).unwrap(); - let scope = scope.bind(py); - let headers_list = scope.get_item("headers").unwrap().unwrap(); - let len = headers_list.len().unwrap(); - assert_eq!(len, 2); - }); - } - - #[test] - fn scope_with_path_params() { - let req = make_inbound_request( - http::Method::GET, - "/items/42", - b"", - HeaderMap::new(), - vec![("item_id".to_owned(), "42".to_owned())], - None, - ); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let scope = - scope_from_template(py, &interns.scope_template, &req, None, &interns).unwrap(); - let scope = scope.bind(py); - let pp = scope.get_item("path_params").unwrap().unwrap(); - let val: String = pp.get_item("item_id").unwrap().extract().unwrap(); - assert_eq!(val, "42"); - }); - } - - #[test] - fn scope_with_client_addr() { - let req = make_inbound_request( - http::Method::GET, - "/", - b"", - HeaderMap::new(), - Vec::new(), - Some(SocketAddr::from(([192, 168, 1, 100], 12345))), - ); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let scope = - scope_from_template(py, &interns.scope_template, &req, None, &interns).unwrap(); - let scope = scope.bind(py); - let client = scope.get_item("client").unwrap().unwrap(); - let host: String = client.get_item(0).unwrap().extract().unwrap(); - let port: u16 = client.get_item(1).unwrap().extract().unwrap(); - assert_eq!(host, "192.168.1.100"); - assert_eq!(port, 12345); - }); - } - - #[test] - fn scope_no_client() { - let req = make_inbound_request( - http::Method::GET, - "/", - b"", - HeaderMap::new(), - Vec::new(), - None, - ); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let scope = - scope_from_template(py, &interns.scope_template, &req, None, &interns).unwrap(); - let scope = scope.bind(py); - let client = scope.get_item("client").unwrap().unwrap(); - assert!(client.is_none()); - }); - } - - #[test] - fn scope_server_addr() { - let req = make_inbound_request( - http::Method::GET, - "/", - b"", - HeaderMap::new(), - Vec::new(), - None, - ); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let scope = - scope_from_template(py, &interns.scope_template, &req, None, &interns).unwrap(); - let scope = scope.bind(py); - let server = scope.get_item("server").unwrap().unwrap(); - let host: String = server.get_item(0).unwrap().extract().unwrap(); - let port: u16 = server.get_item(1).unwrap().extract().unwrap(); - assert_eq!(host, "127.0.0.1"); - assert_eq!(port, 8080); - }); - } - - #[test] - fn receive_disconnect_event() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item( - pyo3::intern!(py, "type"), - pyo3::intern!(py, "http.disconnect"), - ) - .unwrap(); - - let event_type: String = dict.get_item("type").unwrap().unwrap().extract().unwrap(); - assert_eq!(event_type, "http.disconnect"); - }); - } - - // ── AsgiSend parse + channel tests ─────────────────────────────────── - - #[test] - fn parse_response_start_event() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item("type", "http.response.start").unwrap(); - dict.set_item("status", 200u16).unwrap(); - let headers = PyList::empty(py); - let h = PyTuple::new( - py, - [ - PyBytes::new(py, b"content-type").into_any(), - PyBytes::new(py, b"text/plain").into_any(), - ], - ) - .unwrap(); - headers.append(h).unwrap(); - dict.set_item("headers", headers).unwrap(); - - let event = parse_asgi_send_event(&dict).unwrap(); - match event { - AsgiEvent::ResponseStart { status, headers } => { - assert_eq!(status, 200); - assert_eq!(headers.len(), 1); - assert_eq!(headers.get("content-type").unwrap(), "text/plain"); - } - other => panic!("expected ResponseStart, got {other:?}"), - } - }); - } - - #[test] - fn parse_response_body_event() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item("type", "http.response.body").unwrap(); - dict.set_item("body", PyBytes::new(py, b"hello")).unwrap(); - dict.set_item("more_body", false).unwrap(); - - let event = parse_asgi_send_event(&dict).unwrap(); - match event { - AsgiEvent::ResponseBody { body, more_body } => { - assert_eq!(body.as_ref(), b"hello"); - assert!(!more_body); - } - other => panic!("expected ResponseBody, got {other:?}"), - } - }); - } - - #[tokio::test] - async fn asgi_send_http_fixed_response() { - let (response_tx, response_rx) = oneshot::channel(); - let (disconnect_tx, _disconnect_rx) = oneshot::channel(); - - with_py(|py| { - let cache = SendCache::new(py).unwrap(); - let mut send = AsgiSend::http(response_tx, disconnect_tx, &cache, py); - - let start_dict = PyDict::new(py); - start_dict.set_item("type", "http.response.start").unwrap(); - start_dict.set_item("status", 200u16).unwrap(); - let headers = PyList::empty(py); - start_dict.set_item("headers", headers).unwrap(); - send.__call__(py, start_dict.clone()).unwrap(); - - let body_dict = PyDict::new(py); - body_dict.set_item("type", "http.response.body").unwrap(); - body_dict - .set_item("body", PyBytes::new(py, b"hello")) - .unwrap(); - body_dict.set_item("more_body", false).unwrap(); - send.__call__(py, body_dict.clone()).unwrap(); - }); - - let resp = response_rx.await.unwrap().unwrap(); - assert_eq!(resp.status, http::StatusCode::OK); - match resp.body { - ResponseBody::Fixed(b) => assert_eq!(b.as_ref(), b"hello"), - ResponseBody::Stream(_) => panic!("expected Fixed body"), - } - } - - #[tokio::test] - async fn asgi_send_http_streaming_response() { - let (response_tx, response_rx) = oneshot::channel(); - let (disconnect_tx, _disconnect_rx) = oneshot::channel(); - - with_py(|py| { - let cache = SendCache::new(py).unwrap(); - let mut send = AsgiSend::http(response_tx, disconnect_tx, &cache, py); - - let start_dict = PyDict::new(py); - start_dict.set_item("type", "http.response.start").unwrap(); - start_dict.set_item("status", 200u16).unwrap(); - let headers = PyList::empty(py); - start_dict.set_item("headers", headers).unwrap(); - send.__call__(py, start_dict.clone()).unwrap(); - - let body_dict = PyDict::new(py); - body_dict.set_item("type", "http.response.body").unwrap(); - body_dict - .set_item("body", PyBytes::new(py, b"chunk1")) - .unwrap(); - body_dict.set_item("more_body", true).unwrap(); - send.__call__(py, body_dict.clone()).unwrap(); - }); - - let resp = response_rx.await.unwrap().unwrap(); - assert_eq!(resp.status, http::StatusCode::OK); - match resp.body { - ResponseBody::Stream(mut stream) => { - use futures_core::Stream; - let waker = futures_util::task::noop_waker(); - let mut cx = std::task::Context::from_waker(&waker); - match std::pin::Pin::new(&mut stream).poll_next(&mut cx) { - std::task::Poll::Ready(Some(Ok(chunk))) => { - assert_eq!(chunk.as_ref(), b"chunk1"); - } - other => panic!("expected Ready(Some(Ok(...))), got {other:?}"), - } - } - ResponseBody::Fixed(_) => panic!("expected Stream body"), - } - } - - #[test] - fn send_unknown_event_type() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item("type", "http.unknown").unwrap(); - let result = parse_asgi_send_event(&dict); - assert!(result.is_err()); - let err_str = result.unwrap_err().to_string(); - assert!(err_str.contains("unsupported ASGI event type")); - }); - } - - #[test] - fn send_missing_type_key() { - with_py(|py| { - let dict = PyDict::new(py); - let result = parse_asgi_send_event(&dict); - assert!(result.is_err()); - }); - } - - // ── WebSocket event parse tests ───────────────────────────────────── - - #[test] - fn parse_ws_accept_event() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item("type", "websocket.accept").unwrap(); - dict.set_item("subprotocol", "graphql-ws").unwrap(); - - let event = parse_asgi_send_event(&dict).unwrap(); - match event { - AsgiEvent::WsAccept { - subprotocol, - headers, - } => { - assert_eq!(subprotocol.as_deref(), Some("graphql-ws")); - assert!(headers.is_empty()); - } - other => panic!("expected WsAccept, got {other:?}"), - } - }); - } - - #[test] - fn parse_ws_send_text_event() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item("type", "websocket.send").unwrap(); - dict.set_item("text", "hello").unwrap(); - - let event = parse_asgi_send_event(&dict).unwrap(); - match event { - AsgiEvent::WsSend { text, bytes } => { - assert_eq!(text.as_deref(), Some("hello")); - assert!(bytes.is_none()); - } - other => panic!("expected WsSend, got {other:?}"), - } - }); - } - - #[test] - fn parse_ws_send_binary_event() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item("type", "websocket.send").unwrap(); - dict.set_item("bytes", PyBytes::new(py, b"\x01\x02\x03")) - .unwrap(); - - let event = parse_asgi_send_event(&dict).unwrap(); - match event { - AsgiEvent::WsSend { text, bytes } => { - assert!(text.is_none()); - assert_eq!(bytes.as_deref(), Some(b"\x01\x02\x03".as_ref())); - } - other => panic!("expected WsSend, got {other:?}"), - } - }); - } - - #[test] - fn parse_ws_close_event() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item("type", "websocket.close").unwrap(); - dict.set_item("code", 1001u16).unwrap(); - - let event = parse_asgi_send_event(&dict).unwrap(); - match event { - AsgiEvent::WsClose { code } => { - assert_eq!(code, 1001); - } - other => panic!("expected WsClose, got {other:?}"), - } - }); - } - - #[test] - fn parse_ws_close_default_code() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item("type", "websocket.close").unwrap(); - - let event = parse_asgi_send_event(&dict).unwrap(); - match event { - AsgiEvent::WsClose { code } => { - assert_eq!(code, 1000); - } - other => panic!("expected WsClose, got {other:?}"), - } - }); - } - - #[test] - fn ws_incoming_event_debug() { - let connect = WsIncomingEvent::Connect; - assert!(format!("{connect:?}").contains("Connect")); - - let recv = WsIncomingEvent::Receive { - text: Some("hello".to_owned()), - bytes: None, - }; - assert!(format!("{recv:?}").contains("Receive")); - - let disc = WsIncomingEvent::Disconnect { code: 1000 }; - assert!(format!("{disc:?}").contains("Disconnect")); - } - - #[test] - fn asgi_ws_receive_debug() { - let (_tx, rx) = mpsc::channel(1); - let recv = AsgiWsReceive::new(rx); - let dbg = format!("{recv:?}"); - assert!(dbg.contains("AsgiWsReceive")); - } - - // ── build_ws_scope tests ──────────────────────────────────────────── - - #[test] - fn build_ws_scope_basic() { - let req = make_inbound_request( - http::Method::GET, - "/ws", - b"token=abc", - HeaderMap::new(), - vec![("room".to_owned(), "main".to_owned())], - Some(SocketAddr::from(([10, 0, 0, 1], 5555))), - ); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let scope = build_ws_scope(py, &req, &interns).unwrap(); - let scope = scope.bind(py); - - let scope_type: String = scope.get_item("type").unwrap().unwrap().extract().unwrap(); - assert_eq!(scope_type, "websocket"); - - let scheme: String = scope - .get_item("scheme") - .unwrap() - .unwrap() - .extract() - .unwrap(); - assert_eq!(scheme, "ws"); - - let path: String = scope.get_item("path").unwrap().unwrap().extract().unwrap(); - assert_eq!(path, "/ws"); - - let qs: Vec = scope - .get_item("query_string") - .unwrap() - .unwrap() - .extract() - .unwrap(); - assert_eq!(qs, b"token=abc"); - - // path params - let pp = scope.get_item("path_params").unwrap().unwrap(); - let room: String = pp.get_item("room").unwrap().extract().unwrap(); - assert_eq!(room, "main"); - - // no 'method' key (WS scope doesn't have method) - assert!(scope.get_item("method").unwrap().is_none()); - }); - } - - // ── build_ws_receive_event tests ───────────────────────────────────── - - #[test] - fn build_ws_receive_event_connect() { - with_py(|py| { - let result = build_ws_receive_event(py, Some(WsIncomingEvent::Connect)).unwrap(); - let dict = result.bind(py); - let event_type: String = dict.get_item("type").unwrap().extract().unwrap(); - assert_eq!(event_type, "websocket.connect"); - }); - } - - #[test] - fn build_ws_receive_event_receive_text() { - with_py(|py| { - let event = WsIncomingEvent::Receive { - text: Some("hello".to_owned()), - bytes: None, - }; - let result = build_ws_receive_event(py, Some(event)).unwrap(); - let dict = result.bind(py); - let event_type: String = dict.get_item("type").unwrap().extract().unwrap(); - assert_eq!(event_type, "websocket.receive"); - let text: String = dict.get_item("text").unwrap().extract().unwrap(); - assert_eq!(text, "hello"); - }); - } - - #[test] - fn build_ws_receive_event_receive_bytes() { - with_py(|py| { - let event = WsIncomingEvent::Receive { - text: None, - bytes: Some(Bytes::from_static(&[0x01, 0x02, 0x03])), - }; - let result = build_ws_receive_event(py, Some(event)).unwrap(); - let dict = result.bind(py); - let event_type: String = dict.get_item("type").unwrap().extract().unwrap(); - assert_eq!(event_type, "websocket.receive"); - let bytes: Vec = dict.get_item("bytes").unwrap().extract().unwrap(); - assert_eq!(bytes, vec![0x01, 0x02, 0x03]); - }); - } - - #[test] - fn build_ws_receive_event_disconnect_with_code() { - with_py(|py| { - let event = WsIncomingEvent::Disconnect { code: 1001 }; - let result = build_ws_receive_event(py, Some(event)).unwrap(); - let dict = result.bind(py); - let event_type: String = dict.get_item("type").unwrap().extract().unwrap(); - assert_eq!(event_type, "websocket.disconnect"); - let code: u16 = dict.get_item("code").unwrap().extract().unwrap(); - assert_eq!(code, 1001); - }); - } - - #[test] - fn build_ws_receive_event_channel_closed() { - with_py(|py| { - let result = build_ws_receive_event(py, None).unwrap(); - let dict = result.bind(py); - let event_type: String = dict.get_item("type").unwrap().extract().unwrap(); - assert_eq!(event_type, "websocket.disconnect"); - let code: u16 = dict.get_item("code").unwrap().extract().unwrap(); - assert_eq!(code, 1000); - }); - } - - // ── parse edge case tests ──────────────────────────────────────────── - - #[test] - fn parse_response_body_missing_body_key() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item("type", "http.response.body").unwrap(); - // No "body" key, no "more_body" key — defaults to empty body, more_body=false - let event = parse_asgi_send_event(&dict).unwrap(); - match event { - AsgiEvent::ResponseBody { body, more_body } => { - assert!(body.is_empty()); - assert!(!more_body); - } - other => panic!("expected ResponseBody, got {other:?}"), - } - }); - } - - #[test] - fn parse_ws_accept_no_subprotocol() { - with_py(|py| { - let dict = PyDict::new(py); - dict.set_item("type", "websocket.accept").unwrap(); - let event = parse_asgi_send_event(&dict).unwrap(); - match event { - AsgiEvent::WsAccept { - subprotocol, - headers, - } => { - assert!(subprotocol.is_none()); - assert!(headers.is_empty()); - } - other => panic!("expected WsAccept, got {other:?}"), - } - }); - } - - // ── Microbenchmarks ───────────────────────────────────────────────── - // - // Run with: cargo test -p apx-framework -- --nocapture microbench - // - // These are not assert-based tests — they print timing comparisons - // for manual inspection. They isolate specific operations to validate - // (or refute) performance hypotheses. - - const MICROBENCH_ITERATIONS: usize = 100_000; - - fn bench_loop(label: &str, mut f: F) -> std::time::Duration { - // Warmup - for _ in 0..1000 { - f(); - } - let start = std::time::Instant::now(); - for _ in 0..MICROBENCH_ITERATIONS { - f(); - } - let elapsed = start.elapsed(); - let per_op = elapsed / MICROBENCH_ITERATIONS as u32; - eprintln!(" {label:40} {per_op:>8?} ({elapsed:?} / {MICROBENCH_ITERATIONS})"); - elapsed - } - - #[test] - fn microbench_version_intern_vs_direct_str() { - eprintln!("\n=== VersionInterns.get() vs direct as_asgi_version() ==="); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let key = interns.keys.http_version.bind(py); - let dict = PyDict::new(py); - let protocol = ProtocolVersion::Http11; - - bench_loop("VersionInterns.get(Http11) + set_item", || { - dict.set_item(key, interns.versions.get(py, protocol)) - .unwrap(); - }); - - bench_loop("as_asgi_version() + set_item", || { - dict.set_item(key, protocol.as_asgi_version()).unwrap(); - }); - }); - } - - #[test] - fn microbench_server_tuple_cached_vs_dynamic() { - eprintln!("\n=== Cached server_tuple vs dynamic ip().to_string() ==="); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let key = interns.keys.server.bind(py); - let dict = PyDict::new(py); - - bench_loop("cached server_tuple + set_item", || { - dict.set_item(key, interns.server_tuple.bind(py)).unwrap(); - }); - - bench_loop("dynamic (ip.to_string(), port) + set_item", || { - dict.set_item( - key, - (TEST_SERVER_ADDR.ip().to_string(), TEST_SERVER_ADDR.port()), - ) - .unwrap(); - }); - }); - } - - #[test] - fn microbench_resolved_awaitable_singleton_vs_freelist() { - eprintln!("\n=== ResolvedAwaitable: clone_ref (singleton) vs Py::new (freelist) ==="); - with_py(|py| { - let cache = SendCache::new(py).unwrap(); - - bench_loop("clone_ref (singleton)", || { - let _ = cache.resolved.clone_ref(py); - }); - - bench_loop("Py::new (freelist=128)", || { - let _ = Py::new(py, ResolvedAwaitable).unwrap(); - }); - }); - } - - #[test] - fn microbench_receive_dict_build_vs_template_copy() { - eprintln!("\n=== Receive dict: direct build vs template.copy() ==="); - with_py(|py| { - let body = PyBytes::new(py, b"hello world"); - - // NEW: direct dict construction with interned keys - bench_loop("direct PyDict + 3x set_item (interned)", || { - let event = PyDict::new(py); - event - .set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request")) - .unwrap(); - event.set_item(pyo3::intern!(py, "body"), &body).unwrap(); - event - .set_item(pyo3::intern!(py, "more_body"), false) - .unwrap(); - std::hint::black_box(&event); - }); - - // OLD: template.copy() + set_item(body) - let template = PyDict::new(py); - template - .set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request")) - .unwrap(); - template - .set_item(pyo3::intern!(py, "body"), PyBytes::new(py, b"")) - .unwrap(); - template - .set_item(pyo3::intern!(py, "more_body"), false) - .unwrap(); - let template = template.unbind(); - - bench_loop("template.copy() + set_item(body)", || { - let event: Bound<'_, PyDict> = template - .bind(py) - .call_method0(pyo3::intern!(py, "copy")) - .unwrap() - .cast_into() - .unwrap(); - event.set_item(pyo3::intern!(py, "body"), &body).unwrap(); - std::hint::black_box(&event); - }); - }); - } - - #[test] - fn microbench_full_scope_build() { - eprintln!("\n=== Full scope_from_template (new interns) ==="); - with_py(|py| { - let interns = ScopeInterns::new(py, TEST_SERVER_ADDR); - let req = make_inbound_request( - http::Method::GET, - "/api/health", - b"", - HeaderMap::new(), - Vec::new(), - Some(SocketAddr::from(([10, 0, 0, 1], 5555))), - ); - - bench_loop("scope_from_template", || { - let _ = - scope_from_template(py, &interns.scope_template, &req, None, &interns).unwrap(); - }); - }); - } - - #[test] - fn microbench_pylist_direct_index_vs_iterator() { - eprintln!("\n=== PyList: direct index vs try_iter() ==="); - with_py(|py| { - let items: Vec> = (0..20i32) - .map(|i| { - PyTuple::new( - py, - [ - PyBytes::new(py, format!("header-{i}").as_bytes()).into_any(), - PyBytes::new(py, format!("value-{i}").as_bytes()).into_any(), - ], - ) - .unwrap() - .into_any() - }) - .collect(); - let list = PyList::new(py, &items).unwrap(); - - bench_loop("direct index: list.get_item(i)", || { - let mut count = 0usize; - for i in 0..list.len() { - let tuple = list.get_item(i).unwrap(); - std::hint::black_box(&tuple); - count += 1; - } - std::hint::black_box(count); - }); - - bench_loop("try_iter() protocol", || { - let mut count = 0usize; - for item in list.try_iter().unwrap() { - let tuple = item.unwrap(); - std::hint::black_box(&tuple); - count += 1; - } - std::hint::black_box(count); - }); - }); - } -} diff --git a/crates/framework/src/asgi/slot_receive.rs b/crates/framework/src/asgi/slot_receive.rs deleted file mode 100644 index f69727b3..00000000 --- a/crates/framework/src/asgi/slot_receive.rs +++ /dev/null @@ -1,72 +0,0 @@ -//! ASGI `receive()` callable for the 3-thread dispatch pipeline. -//! -//! [`SlotReceive`] wraps a pre-collected request body from [`RequestSlot`]. -//! First call returns `http.request` with the body as a resolved awaitable. -//! Subsequent calls pend indefinitely (disconnect watch). - -use bytes::Bytes; -use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyDict}; - -use super::scope::ResolvedAwaitableWithValue; - -/// ASGI `receive` callable for the 3-thread pipeline. -/// -/// Runs entirely on Thread 2 (Python thread, 100% GIL). The body is -/// pre-collected on Thread 1 and passed in as `Bytes`. -#[pyclass(module = "apx._core", freelist = 64)] -pub struct SlotReceive { - body: std::sync::Mutex>, - receive_template: Py, -} - -crate::opaque_debug!(SlotReceive); - -impl SlotReceive { - /// Create for an HTTP request with a pre-collected body. - pub fn new(body: Bytes, receive_template: Py) -> Self { - Self { - body: std::sync::Mutex::new(Some(body)), - receive_template, - } - } -} - -#[pymethods] -impl SlotReceive { - /// `event = await receive()` - /// - /// First call: returns `http.request` with body via `ResolvedAwaitableWithValue`. - /// Subsequent calls: pend forever (disconnect watch — the connection - /// outlives the ASGI handler in normal operation). - fn __call__<'py>(&self, py: Python<'py>) -> PyResult> { - let taken = self - .body - .lock() - .map_err(|_| pyo3::exceptions::PyRuntimeError::new_err("receive mutex poisoned"))? - .take(); - - if let Some(bytes) = taken { - let event = crate::telemetry::timed!( - crate::telemetry::dispatch_metrics::record_receive_build, - { - let event = self.receive_template.bind(py).copy()?; - event.set_item(pyo3::intern!(py, "body"), PyBytes::new(py, &bytes))?; - event - } - ); - let event = event.unbind().into_any(); - Py::new(py, ResolvedAwaitableWithValue::new(event)) - .map(|obj| obj.into_bound(py).into_any()) - } else { - let handle = crate::io::with_tokio_handle(|h| h.clone()).ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("no tokio runtime for disconnect watch") - })?; - let _guard = handle.enter(); - pyo3_async_runtimes::tokio::future_into_py( - py, - std::future::pending::>>(), - ) - } - } -} diff --git a/crates/framework/src/asgi/slot_send.rs b/crates/framework/src/asgi/slot_send.rs deleted file mode 100644 index d35589db..00000000 --- a/crates/framework/src/asgi/slot_send.rs +++ /dev/null @@ -1,307 +0,0 @@ -//! ASGI `send()` callable for the 2-thread dispatch pipeline. -//! -//! [`SlotSend`] collects `http.response.start` (status + headers) and on -//! the first `http.response.body` creates an mpsc channel, builds a -//! [`ResponseData`], and fires the tokio oneshot directly to Thread 1. -//! Subsequent body chunks are pushed via the mpsc sender. Dropping the -//! sender signals EOF. - -use crate::io::channel::{ResponseData, SlotBody}; -use bytes::Bytes; -use pyo3::prelude::*; -use pyo3::pybacked::PyBackedBytes; -use pyo3::types::{PyBytes, PyDict}; -use tokio::sync::{mpsc, oneshot}; - -use super::scope::ResolvedAwaitable; - -/// Generic 500 body returned to clients in production mode. -const INTERNAL_ERROR_BODY: Bytes = Bytes::from_static(b"Internal Server Error"); - -// ── SlotSend ───────────────────────────────────────────────────────────── - -/// ASGI `send` callable for the 2-thread pipeline. -/// -/// Runs entirely on Thread 2 (Python thread, 100% GIL). On the first -/// body chunk, creates the response and fires the tokio oneshot directly -/// to Thread 1. -#[pyclass(module = "apx._core", freelist = 64)] -pub struct SlotSend { - status: Option, - raw_headers: Option>, - response_tx: Option>, - body_tx: Option>, - resolved: Py, - dev_mode: bool, -} - -impl std::fmt::Debug for SlotSend { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SlotSend") - .field("status", &self.status) - .field("has_response_tx", &self.response_tx.is_some()) - .field("streaming", &self.body_tx.is_some()) - .finish_non_exhaustive() - } -} - -impl SlotSend { - /// Create a new `SlotSend` for an HTTP request. - pub(crate) fn new( - response_tx: oneshot::Sender, - resolved: Py, - dev_mode: bool, - ) -> Self { - Self { - status: None, - raw_headers: None, - response_tx: Some(response_tx), - body_tx: None, - resolved, - dev_mode, - } - } -} - -#[pymethods] -impl SlotSend { - /// Forward an unhandled app exception as a 500 response. - /// - /// Called by the `_guarded` wrapper in `_dispatch.py` when the ASGI - /// app raises an `Exception`. Always logs the full traceback - /// server-side; the response body depends on `dev_mode`. - fn send_error(&mut self, traceback: String) { - tracing::error!( - name: "apx.dispatch.unhandled_exception", - "{traceback}", - ); - if let Some(response_tx) = self.response_tx.take() { - let body = if self.dev_mode { - Bytes::from(traceback) - } else { - INTERNAL_ERROR_BODY - }; - let response = ResponseData { - status: 500, - headers: vec![( - Bytes::from_static(b"content-type"), - Bytes::from_static(b"text/plain; charset=utf-8"), - )], - body: SlotBody::Complete(body), - }; - let _ = response_tx.send(response); - } - } - - /// `await send({"type": "http.response.start"|"http.response.body", ...})` - fn __call__<'py>( - &mut self, - py: Python<'py>, - event: Bound<'py, PyDict>, - ) -> PyResult> { - let type_obj = event - .get_item(pyo3::intern!(py, "type"))? - .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("type"))?; - - if type_obj.eq(pyo3::intern!(py, "http.response.start"))? { - crate::telemetry::timed!( - crate::telemetry::dispatch_metrics::record_send_parse, - self.handle_response_start(py, &event) - ) - } else if type_obj.eq(pyo3::intern!(py, "http.response.body"))? { - let (body, more_body) = - crate::telemetry::timed!(crate::telemetry::dispatch_metrics::record_send_parse, { - let body = extract_body_bytes(&event)?; - let more_body: bool = event - .get_item(pyo3::intern!(py, "more_body"))? - .map(|b| b.extract()) - .transpose()? - .unwrap_or(false); - (body, more_body) - }); - - if self.body_tx.is_none() { - self.send_first_body_chunk(body, more_body)?; - } else { - self.send_subsequent_chunk(body, more_body); - } - Ok(self.resolved.clone_ref(py).into_bound(py).into_any()) - } else { - let event_type: String = type_obj.extract()?; - Err(pyo3::exceptions::PyValueError::new_err(format!( - "unsupported ASGI event type: {event_type}" - ))) - } - } -} - -impl SlotSend { - /// Handle `http.response.start` — extract status + headers. - fn handle_response_start<'py>( - &mut self, - py: Python<'py>, - event: &Bound<'py, PyDict>, - ) -> PyResult> { - let status: u16 = event - .get_item(pyo3::intern!(py, "status"))? - .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("status"))? - .extract()?; - let headers = extract_raw_headers(event)?; - self.status = Some(status); - self.raw_headers = Some(headers); - Ok(self.resolved.clone_ref(py).into_bound(py).into_any()) - } - - /// First body chunk: build `ResponseData` and fire the tokio oneshot. - /// - /// Non-streaming (`more_body == false`): carries the body inline, - /// skipping the mpsc channel + ChannelBody + Box::pin allocation. - fn send_first_body_chunk(&mut self, body: Bytes, more_body: bool) -> PyResult<()> { - let status = self.status.take().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err( - "ASGI protocol error: body before response start", - ) - })?; - let headers = self.raw_headers.take().unwrap_or_default(); - - let slot_body = if more_body { - let (body_tx, body_rx) = mpsc::unbounded_channel(); - if !body.is_empty() { - let _ = body_tx.send(body); - } - self.body_tx = Some(body_tx); - SlotBody::Chunked(body_rx) - } else { - SlotBody::Complete(body) - }; - - let response = ResponseData { - status, - headers, - body: slot_body, - }; - - if let Some(response_tx) = self.response_tx.take() { - let _ = response_tx.send(response); - } - - Ok(()) - } - - /// Subsequent body chunks: push via mpsc, drop sender on EOF. - fn send_subsequent_chunk(&mut self, body: Bytes, more_body: bool) { - if let Some(tx) = &self.body_tx { - let _ = tx.send(body); - } - if !more_body { - self.body_tx = None; - } - } -} - -// ── Header extraction ──────────────────────────────────────────────────── - -/// Extract response headers as raw byte pairs from the ASGI event dict. -/// -/// Returns `(name, value)` pairs as `Bytes` for zero-copy transfer to -/// Thread 1. The ASGI spec represents headers as a list of 2-tuples -/// of byte strings. -fn extract_raw_headers(event: &Bound<'_, PyDict>) -> PyResult> { - let py = event.py(); - let Some(obj) = event.get_item(pyo3::intern!(py, "headers"))? else { - return Ok(vec![]); - }; - let iter = obj.try_iter()?; - let mut result = Vec::with_capacity(8); - for item in iter { - let pair = item?; - let tuple = pair.cast::()?; - let name = extract_bytes_from_obj(&tuple.get_item(0)?)?; - let value = extract_bytes_from_obj(&tuple.get_item(1)?)?; - result.push((name, value)); - } - Ok(result) -} - -/// Extract `Bytes` from a Python bytes object via zero-copy `PyBackedBytes`. -fn extract_bytes_from_obj(obj: &Bound<'_, PyAny>) -> PyResult { - match obj.cast::() { - Ok(py_bytes) => { - let backed: PyBackedBytes = py_bytes.clone().into(); - Ok(Bytes::from_owner(backed)) - } - Err(_) => Ok(Bytes::from(obj.extract::>()?)), - } -} - -/// Extract body bytes from an ASGI event dict. -fn extract_body_bytes(event: &Bound<'_, PyDict>) -> PyResult { - let py = event.py(); - let Some(obj) = event.get_item(pyo3::intern!(py, "body"))? else { - return Ok(Bytes::new()); - }; - match obj.cast::() { - Ok(py_bytes) => { - let backed: PyBackedBytes = py_bytes.clone().into(); - Ok(Bytes::from_owner(backed)) - } - Err(_) => Ok(Bytes::from(obj.extract::>()?)), - } -} - -// ── Tests ──────────────────────────────────────────────────────────────── - -#[cfg(test)] -#[expect( - clippy::unwrap_used, - clippy::panic, - reason = "test code uses unwrap/assert/panic for clarity" -)] -mod tests { - use super::*; - - fn make_slot_send(dev_mode: bool) -> (SlotSend, oneshot::Receiver) { - let (tx, rx) = oneshot::channel(); - let resolved = crate::with_py(|py| Py::new(py, ResolvedAwaitable).unwrap()); - (SlotSend::new(tx, resolved, dev_mode), rx) - } - - #[test] - fn send_error_prod_mode_returns_generic_body() { - let (mut slot, mut rx) = make_slot_send(false); - let traceback = "Traceback (most recent call last):\n NameError: x\n".to_owned(); - slot.send_error(traceback); - - let response = rx.try_recv().unwrap(); - assert_eq!(response.status, 500); - match response.body { - SlotBody::Complete(b) => assert_eq!(b.as_ref(), b"Internal Server Error"), - SlotBody::Chunked(_) => panic!("expected Complete body"), - } - } - - #[test] - fn send_error_dev_mode_returns_traceback_body() { - let (mut slot, mut rx) = make_slot_send(true); - let traceback = "Traceback (most recent call last):\n NameError: x\n".to_owned(); - slot.send_error(traceback); - - let response = rx.try_recv().unwrap(); - assert_eq!(response.status, 500); - match response.body { - SlotBody::Complete(b) => { - let body_str = std::str::from_utf8(b.as_ref()).unwrap(); - assert!(body_str.contains("Traceback")); - assert!(body_str.contains("NameError")); - } - SlotBody::Chunked(_) => panic!("expected Complete body"), - } - } - - #[test] - fn send_error_without_response_tx_does_not_panic() { - let (mut slot, _rx) = make_slot_send(false); - drop(slot.response_tx.take()); - slot.send_error("some error".to_owned()); - } -} diff --git a/crates/framework/src/asgi/streaming.rs b/crates/framework/src/asgi/streaming.rs deleted file mode 100644 index dcd6a38f..00000000 --- a/crates/framework/src/asgi/streaming.rs +++ /dev/null @@ -1,184 +0,0 @@ -//! Streaming ASGI response body. -//! -//! [`AsgiBodyStream`] wraps an mpsc channel of body chunks into a -//! [`futures_core::Stream`] suitable for HTTP chunked/SSE responses. - -use super::scope::AsgiEvent; -use bytes::Bytes; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::sync::{mpsc, oneshot}; - -/// Stream wrapper over an ASGI send channel. -/// -/// Yields `ResponseBody` chunks until `more_body=false` or the channel closes. -/// Fires `disconnect_tx` on drop to signal the ASGI handler via `http.disconnect`. -pub struct AsgiBodyStream { - rx: mpsc::Receiver, - initial_chunk: Option, - disconnect_tx: Option>, - done: bool, -} - -impl AsgiBodyStream { - /// Create a new body stream with an optional initial chunk and disconnect signal. - pub(super) fn new( - rx: mpsc::Receiver, - initial_chunk: Option, - disconnect_tx: Option>, - ) -> Self { - Self { - rx, - initial_chunk, - disconnect_tx, - done: false, - } - } -} - -impl futures_core::Stream for AsgiBodyStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.done { - return Poll::Ready(None); - } - - if let Some(chunk) = self.initial_chunk.take() { - tracing::trace!(name: "apx.asgi.streaming.initial_chunk", chunk_len = chunk.len(), "body_stream: initial chunk"); - return Poll::Ready(Some(Ok(chunk))); - } - - match self.rx.poll_recv(cx) { - Poll::Ready(Some(AsgiEvent::ResponseBody { body, more_body })) => { - tracing::trace!( - name: "apx.asgi.streaming.chunk_received", - body_len = body.len(), - more_body, - "body_stream: chunk received" - ); - if !more_body { - self.done = true; - } - Poll::Ready(Some(Ok(body))) - } - Poll::Ready(Some(_) | None) => { - tracing::trace!(name: "apx.asgi.streaming.channel_closed_or_unexpected", "body_stream: channel closed or unexpected event"); - self.done = true; - Poll::Ready(None) - } - Poll::Pending => { - tracing::trace!(name: "apx.asgi.streaming.pending", "body_stream: pending (waiting for next chunk)"); - Poll::Pending - } - } - } -} - -impl Drop for AsgiBodyStream { - fn drop(&mut self) { - // Signal disconnect to AsgiReceive. Sending () is enough — - // the receiver resolves its Future with http.disconnect. - // If the coroutine already finished, the signal is harmless. - if let Some(tx) = self.disconnect_tx.take() { - let _ = tx.send(()); - } - } -} - -// ── Tests ──────────────────────────────────────────────────────────────── - -#[cfg(test)] -#[expect( - clippy::unwrap_used, - reason = "test code uses unwrap/assert for clarity" -)] -mod tests { - use super::*; - use tokio_stream::StreamExt; - - #[tokio::test] - async fn asgi_body_stream_single_chunk() { - let (tx, rx) = mpsc::channel(4); - tx.send(AsgiEvent::ResponseBody { - body: Bytes::from("hello"), - more_body: false, - }) - .await - .unwrap(); - drop(tx); - - let (disconnect_tx, _disconnect_rx) = oneshot::channel(); - let mut stream = AsgiBodyStream::new(rx, None, Some(disconnect_tx)); - let chunk = stream.next().await.unwrap().unwrap(); - assert_eq!(chunk.as_ref(), b"hello"); - assert!(stream.next().await.is_none()); - } - - #[tokio::test] - async fn asgi_body_stream_multiple_chunks() { - let (tx, rx) = mpsc::channel(4); - tx.send(AsgiEvent::ResponseBody { - body: Bytes::from("hel"), - more_body: true, - }) - .await - .unwrap(); - tx.send(AsgiEvent::ResponseBody { - body: Bytes::from("lo"), - more_body: false, - }) - .await - .unwrap(); - drop(tx); - - let (disconnect_tx, _disconnect_rx) = oneshot::channel(); - let mut stream = AsgiBodyStream::new(rx, None, Some(disconnect_tx)); - let c1 = stream.next().await.unwrap().unwrap(); - assert_eq!(c1.as_ref(), b"hel"); - let c2 = stream.next().await.unwrap().unwrap(); - assert_eq!(c2.as_ref(), b"lo"); - assert!(stream.next().await.is_none()); - } - - #[tokio::test] - async fn asgi_body_stream_channel_closed() { - let (tx, rx) = mpsc::channel::(4); - drop(tx); - - let (disconnect_tx, _disconnect_rx) = oneshot::channel(); - let mut stream = AsgiBodyStream::new(rx, None, Some(disconnect_tx)); - assert!(stream.next().await.is_none()); - } - - #[tokio::test] - async fn asgi_body_stream_initial_chunk() { - let (tx, rx) = mpsc::channel(4); - tx.send(AsgiEvent::ResponseBody { - body: Bytes::from("world"), - more_body: false, - }) - .await - .unwrap(); - drop(tx); - - let (disconnect_tx, _disconnect_rx) = oneshot::channel(); - let mut stream = AsgiBodyStream::new(rx, Some(Bytes::from("hello ")), Some(disconnect_tx)); - let c1 = stream.next().await.unwrap().unwrap(); - assert_eq!(c1.as_ref(), b"hello "); - let c2 = stream.next().await.unwrap().unwrap(); - assert_eq!(c2.as_ref(), b"world"); - assert!(stream.next().await.is_none()); - } - - #[tokio::test] - async fn asgi_body_stream_drop_fires_disconnect() { - let (disconnect_tx, disconnect_rx) = oneshot::channel(); - let (_tx, rx) = mpsc::channel::(4); - let stream = AsgiBodyStream::new(rx, None, Some(disconnect_tx)); - drop(stream); - - // disconnect_rx should have received the signal. - assert!(disconnect_rx.await.is_ok()); - } -} diff --git a/crates/framework/src/dispatch.rs b/crates/framework/src/dispatch.rs deleted file mode 100644 index 36237ba1..00000000 --- a/crates/framework/src/dispatch.rs +++ /dev/null @@ -1,52 +0,0 @@ -//! Request dispatch abstraction. -//! -//! Defines the [`Dispatch`] trait — the extension seam between the HTTP -//! service layer and the application layer. The service layer calls -//! `dispatch()` after health probes, concurrency checks, and timeout -//! wrapping. For WebSocket upgrades, the service calls `dispatch_ws()` -//! with the raw hyper request (before body consumption). - -use crate::transport::types::ResponseBody; -use crate::transport::{InboundRequest, OutboundResponse}; -use bytes::Bytes; -use hyper::body::Incoming; -use hyper::{Request, Response}; -use std::future::Future; -use std::net::SocketAddr; -use std::pin::Pin; - -/// Dispatch a request to the application layer. -/// -/// Implementations decide the strategy: ASGI bridge, direct dispatch, etc. -/// The service layer calls `dispatch()` after health probes, concurrency -/// checks, and timeout wrapping. -pub trait Dispatch: Send + Sync + std::fmt::Debug { - /// Handle a single inbound request. - fn dispatch( - &self, - request: InboundRequest, - ) -> Pin + Send>>; - - /// Handle a WebSocket upgrade request. - /// - /// Called with the raw hyper request before body consumption, since - /// `hyper_tungstenite::upgrade` consumes the request. The default - /// implementation returns 400 Bad Request. - fn dispatch_ws( - &self, - _request: Request, - _server_addr: SocketAddr, - _client_addr: Option, - ) -> Pin> + Send>> { - Box::pin(async { - // 400 Bad Request — WebSocket not supported by this dispatch. - Response::builder() - .status(http::StatusCode::BAD_REQUEST) - .header(http::header::CONTENT_TYPE, "text/plain") - .body(ResponseBody::Fixed(Bytes::from_static( - b"websocket not supported", - ))) - .unwrap_or_else(|_| unreachable!()) - }) - } -} diff --git a/crates/framework/src/io/channel.rs b/crates/framework/src/io/channel.rs deleted file mode 100644 index 2f9077a5..00000000 --- a/crates/framework/src/io/channel.rs +++ /dev/null @@ -1,312 +0,0 @@ -//! Cross-thread dispatch channels for the zero-GIL 3-thread architecture. -//! -//! Thread 1 (tokio) pushes [`RequestSlot`] into the inbound channel and -//! signals the asyncio thread via [`Wakeup`]. Thread 2 (Python/asyncio) -//! drains requests, runs the ASGI app, and pushes [`OutboundSlot`] into -//! the outbound channel. Thread 3 relays responses back to tokio via -//! oneshot senders. -//! -//! All types in this module are pure Rust — no `Py`, no GIL. - -use bytes::Bytes; -use http::header::HeaderMap; -use std::io; -use std::net::SocketAddr; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, Ordering}; -use tokio::sync::{mpsc, oneshot}; - -use crate::transport::types::ProtocolVersion; - -// ── RequestSlot ────────────────────────────────────────────────────────── - -/// Request flowing from Thread 1 (tokio) → Thread 2 (asyncio). Pure Rust. -#[derive(Debug)] -pub struct RequestSlot { - /// HTTP method. - pub method: http::Method, - /// Request path (without query string). - pub path: String, - /// Raw path bytes for ASGI `raw_path`. - pub raw_path: Bytes, - /// Raw query string bytes. - pub query_string: Bytes, - /// HTTP headers. - pub headers: HeaderMap, - /// Pre-collected request body. - pub body: Bytes, - /// HTTP protocol version. - pub protocol: ProtocolVersion, - /// Client socket address (if available). - pub client_addr: Option, - /// Server socket address. - pub server_addr: SocketAddr, - /// Trace context extracted from the active OTEL span on the tokio thread. - pub trace_context: Option, - /// Timestamp when the slot was created (for pickup_delay measurement). - pub created_at: std::time::Instant, - /// Thread 1 awaits this for the response. - pub response_tx: oneshot::Sender, -} - -// ── ResponseData ───────────────────────────────────────────────────────── - -/// Body payload flowing from Thread 2 (asyncio) to Thread 1 (tokio). -#[derive(Debug)] -pub enum SlotBody { - /// Complete body for non-streaming responses (95% of traffic). - Complete(Bytes), - /// Streaming body fed chunk-by-chunk via an mpsc channel. - Chunked(mpsc::UnboundedReceiver), -} - -/// Response flowing from Thread 2 → Thread 1 via tokio oneshot. -#[derive(Debug)] -pub struct ResponseData { - /// HTTP status code. - pub status: u16, - /// Response headers as raw byte pairs (name, value). - pub headers: Vec<(Bytes, Bytes)>, - /// Response body — complete or streaming. - pub body: SlotBody, -} - -// ── Wakeup ─────────────────────────────────────────────────────────────── - -/// Cross-platform wakeup signal for the asyncio thread. -/// -/// Unix: socket fd pair — `signal()` writes 1 byte, asyncio wakes via -/// `loop.add_reader(fd)`. No GIL needed. -/// -/// Under burst load, multiple tokio tasks may call `signal()` concurrently. -/// An [`AtomicBool`] flag coalesces redundant writes: only the first -/// `signal()` after a `drain()` actually writes to the pipe. This -/// eliminates the `Mutex` contention that serialized all signalers. -pub struct Wakeup { - reader: std::os::unix::net::UnixStream, - writer: std::os::unix::net::UnixStream, - /// Coalescing flag — `true` means a wakeup byte is already in the pipe. - pending: AtomicBool, -} - -crate::opaque_debug!(Wakeup); - -impl Wakeup { - /// Create a new wakeup pipe pair. - /// - /// Both ends are set to non-blocking so neither `signal()` nor the - /// asyncio `on_readable` callback can block. - /// - /// # Errors - /// - /// Returns an IO error if the Unix socket pair cannot be created. - pub fn new() -> io::Result { - let (reader, writer) = std::os::unix::net::UnixStream::pair()?; - reader.set_nonblocking(true)?; - writer.set_nonblocking(true)?; - Ok(Self { - reader, - writer, - pending: AtomicBool::new(false), - }) - } - - /// Signal the asyncio thread by writing 1 byte to the pipe. - /// - /// Uses CAS to coalesce: only the thread that flips `false→true` - /// writes the byte. All others skip — a wakeup is already pending. - /// POSIX guarantees atomicity for writes ≤ `PIPE_BUF`, so one byte - /// is safe without a mutex. - pub fn signal(&self) { - if self - .pending - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) - .is_ok() - { - let _ = io::Write::write(&mut &self.writer, &[1u8]); - } - } - - /// Clear the pending flag after the asyncio thread drains the pipe. - /// - /// Called from [`crate::asgi::queue::RequestQueue::try_recv`] when - /// the crossbeam queue is empty, allowing the next `signal()` to - /// write a fresh wakeup byte. - pub fn drain(&self) { - self.pending.store(false, Ordering::Release); - } - - /// Raw file descriptor for the reader end. - /// - /// Passed to `loop.add_reader(fd, callback)` during asyncio init. - pub fn reader_fd(&self) -> std::os::unix::io::RawFd { - std::os::unix::io::AsRawFd::as_raw_fd(&self.reader) - } -} - -// ── InboundChannel ─────────────────────────────────────────────────────── - -/// Thread 1 → Thread 2 request channel. -/// -/// Unbounded crossbeam channel — backpressure is handled by the HTTP -/// semaphore in `ApxService`, not by the channel itself. -#[derive(Debug)] -pub struct InboundChannel { - tx: crossbeam_channel::Sender, - rx: crossbeam_channel::Receiver, -} - -impl InboundChannel { - /// Create a new unbounded inbound channel. - pub fn new() -> Self { - let (tx, rx) = crossbeam_channel::unbounded(); - Self { tx, rx } - } - - /// Sender half — cloned into each tokio task on Thread 1. - pub fn sender(&self) -> &crossbeam_channel::Sender { - &self.tx - } - - /// Receiver half — used by `RequestQueue` on Thread 2. - pub fn receiver(&self) -> &crossbeam_channel::Receiver { - &self.rx - } -} - -// ── DispatchPipeline ───────────────────────────────────────────────────── - -/// Bundles the inbound channel + wakeup. Created once per worker. -/// -/// Responses flow directly from `SlotSend` to Thread 1 via tokio oneshot — -/// no outbound channel or relay thread needed. -#[derive(Debug)] -pub struct DispatchPipeline { - /// Thread 1 → Thread 2 request channel. - pub inbound: InboundChannel, - /// Wakeup signal for the asyncio thread. - pub wakeup: Arc, -} - -impl DispatchPipeline { - /// Create a new dispatch pipeline with Unix pipe wakeup. - /// - /// # Errors - /// - /// Returns an IO error if the wakeup pipe cannot be created. - pub fn new() -> io::Result { - Ok(Self { - inbound: InboundChannel::new(), - wakeup: Arc::new(Wakeup::new()?), - }) - } -} - -// ── Tests ──────────────────────────────────────────────────────────────── - -#[cfg(test)] -#[expect(clippy::unwrap_used, reason = "test code uses unwrap for clarity")] -mod tests { - use super::*; - - #[test] - fn wakeup_signal_roundtrip() { - let wakeup = Wakeup::new().unwrap(); - wakeup.signal(); - let mut buf = [0u8; 16]; - let n = io::Read::read(&mut &wakeup.reader, &mut buf).unwrap(); - assert!(n > 0); - } - - #[test] - fn wakeup_coalescing_skips_redundant_writes() { - let wakeup = Wakeup::new().unwrap(); - wakeup.signal(); - wakeup.signal(); - wakeup.signal(); - - let mut buf = [0u8; 16]; - let n = io::Read::read(&mut &wakeup.reader, &mut buf).unwrap(); - assert_eq!(n, 1, "coalesced signals should produce exactly 1 byte"); - - let err = io::Read::read(&mut &wakeup.reader, &mut buf).unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::WouldBlock); - } - - #[test] - fn wakeup_drain_resets_flag() { - let wakeup = Wakeup::new().unwrap(); - - wakeup.signal(); - let mut buf = [0u8; 16]; - let _ = io::Read::read(&mut &wakeup.reader, &mut buf).unwrap(); - - // Before drain: second signal is suppressed (flag still true). - wakeup.signal(); - let err = io::Read::read(&mut &wakeup.reader, &mut buf).unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::WouldBlock); - - // After drain: flag cleared, next signal writes again. - wakeup.drain(); - wakeup.signal(); - let n = io::Read::read(&mut &wakeup.reader, &mut buf).unwrap(); - assert_eq!(n, 1); - } - - #[test] - fn inbound_channel_send_recv() { - let ch = InboundChannel::new(); - let (response_tx, _response_rx) = oneshot::channel(); - let slot = RequestSlot { - method: http::Method::GET, - path: "/test".to_owned(), - raw_path: Bytes::from_static(b"/test"), - query_string: Bytes::new(), - headers: HeaderMap::new(), - body: Bytes::new(), - protocol: ProtocolVersion::Http11, - client_addr: None, - server_addr: SocketAddr::from(([127, 0, 0, 1], 8080)), - trace_context: None, - created_at: std::time::Instant::now(), - response_tx, - }; - ch.sender().send(slot).unwrap(); - let received = ch.receiver().try_recv().unwrap(); - assert_eq!(received.path, "/test"); - } - - #[test] - fn dispatch_pipeline_creates_successfully() { - let pipeline = DispatchPipeline::new().unwrap(); - assert!(format!("{pipeline:?}").contains("DispatchPipeline")); - } - - #[test] - fn wakeup_reader_fd_is_valid() { - let wakeup = Wakeup::new().unwrap(); - let fd = wakeup.reader_fd(); - assert!(fd >= 0); - } - - #[test] - fn request_slot_debug() { - let (response_tx, _) = oneshot::channel(); - let slot = RequestSlot { - method: http::Method::POST, - path: "/api".to_owned(), - raw_path: Bytes::from_static(b"/api"), - query_string: Bytes::from_static(b"q=1"), - headers: HeaderMap::new(), - body: Bytes::from_static(b"{}"), - protocol: ProtocolVersion::Http11, - client_addr: Some(SocketAddr::from(([10, 0, 0, 1], 5000))), - server_addr: SocketAddr::from(([0, 0, 0, 0], 8000)), - trace_context: None, - created_at: std::time::Instant::now(), - response_tx, - }; - let dbg = format!("{slot:?}"); - assert!(dbg.contains("RequestSlot")); - } -} diff --git a/crates/framework/src/io/mod.rs b/crates/framework/src/io/mod.rs deleted file mode 100644 index 38a01c78..00000000 --- a/crates/framework/src/io/mod.rs +++ /dev/null @@ -1,97 +0,0 @@ -//! Python I/O interop — asyncio event loop lifecycle and coroutine submission. -//! -//! [`EventLoop`] is the composition root — it creates the asyncio reactor -//! and exposes cached Python callables for coroutine submission via -//! `call_soon_threadsafe(create_task, coro)`. - -pub mod channel; -pub mod reactor; - -use pyo3::prelude::*; - -// ── EventLoop ──────────────────────────────────────────────────────────── - -/// Asyncio event loop lifecycle — owns the Reactor and exposes -/// cached Python callables for coroutine submission. -pub struct EventLoop { - reactor: reactor::Reactor, -} - -impl EventLoop { - /// Initialize the event loop on the current thread. - /// - /// 1. Creates the asyncio reactor (event loop, thread). - /// 2. Stores the tokio runtime handle in the thread-local. - /// - /// # Errors - /// - /// Returns an error if Python initialization fails. - pub fn init(py: Python<'_>, loop_policy: &str) -> Result { - let reactor = reactor::Reactor::init(py, loop_policy)?; - - if let Ok(handle) = tokio::runtime::Handle::try_current() { - set_tokio_handle(handle); - } - - tracing::debug!(name: "apx.io.event_loop_initialized", "event loop initialized (asyncio delegation)"); - - Ok(Self { reactor }) - } - - /// Cached `loop.call_soon_threadsafe` — the cross-thread submission primitive. - pub fn call_soon_threadsafe(&self) -> &Py { - self.reactor.call_soon_threadsafe() - } - - /// Cached `loop.create_task` — creates a standard asyncio.Task. - pub fn create_task(&self) -> &Py { - self.reactor.create_task() - } - - /// The Python asyncio event loop object. - pub fn event_loop_py(&self) -> &Py { - self.reactor.event_loop_py() - } - - /// Shut down the event loop. - pub fn shutdown(&self) { - self.reactor.shutdown(); - } -} - -crate::opaque_debug!(EventLoop); - -// ── Thread-local tokio runtime handle ──────────────────────────────────── - -use std::cell::RefCell; - -thread_local! { - /// Tokio runtime handle cached on the event loop thread. - /// - /// Set once during [`EventLoop::init`] and on the asyncio thread - /// during reactor init. `AsgiReceive` disconnect detection and - /// `AsgiSend` backpressure handling need the runtime handle on - /// the asyncio thread. - static TOKIO_HANDLE: RefCell> = const { RefCell::new(None) }; -} - -/// Store the tokio runtime handle for the current thread. -pub fn set_tokio_handle(handle: tokio::runtime::Handle) { - TOKIO_HANDLE.with(|cell| *cell.borrow_mut() = Some(handle)); -} - -/// Run a closure with a tokio runtime handle. -/// -/// Checks the thread-local first (set via [`set_tokio_handle`]), then -/// falls back to [`tokio::runtime::Handle::try_current`]. -pub fn with_tokio_handle(f: F) -> Option -where - F: FnOnce(&tokio::runtime::Handle) -> R, -{ - TOKIO_HANDLE.with(|cell| { - if let Some(h) = cell.borrow().as_ref() { - return Some(f(h)); - } - tokio::runtime::Handle::try_current().ok().map(|h| f(&h)) - }) -} diff --git a/crates/framework/src/io/reactor/mod.rs b/crates/framework/src/io/reactor/mod.rs deleted file mode 100644 index 3e3a1f71..00000000 --- a/crates/framework/src/io/reactor/mod.rs +++ /dev/null @@ -1,272 +0,0 @@ -//! Asyncio event loop lifecycle — init, shutdown, task submission. -//! -//! The [`Reactor`] manages the Python asyncio event loop on a dedicated -//! thread. It owns the loop object, `call_soon_threadsafe`, and -//! `create_task` for coroutine submission. - -use std::sync::Mutex; -use std::thread::JoinHandle; - -use pyo3::prelude::*; - -// ── Asyncio event loop utilities ───────────────────────────────────────── - -/// Install the event loop policy (uvloop or asyncio) before creating the loop. -/// -/// Must be called before `asyncio.new_event_loop()` so the factory picks up -/// the right policy. -fn install_loop_policy(py: Python<'_>, policy: &str) { - if policy == "uvloop" { - match py.import(c"uvloop") { - Ok(uvloop) => { - let Ok(asyncio) = py.import(c"asyncio") else { - tracing::error!(name: "apx.reactor.asyncio_import_failed", "failed to import asyncio for uvloop policy install"); - return; - }; - let Ok(policy_obj) = uvloop.call_method0(c"EventLoopPolicy") else { - tracing::error!(name: "apx.reactor.uvloop_event_loop_policy_failed", "uvloop.EventLoopPolicy() call failed"); - return; - }; - if let Err(e) = asyncio.call_method1(c"set_event_loop_policy", (policy_obj,)) { - tracing::error!(name: "apx.reactor.set_event_loop_policy_failed", error = %e, "asyncio.set_event_loop_policy() failed"); - return; - } - tracing::debug!(name: "apx.reactor.uvloop_policy_installed", "installed uvloop event loop policy"); - } - Err(e) => { - tracing::warn!(name: "apx.reactor.uvloop_unavailable_fallback", error = %e, "uvloop not available, falling back to asyncio"); - } - } - } else { - tracing::debug!(name: "apx.reactor.event_loop_policy", policy, "using asyncio event loop policy"); - } -} - -/// Create an asyncio event loop. -fn create_event_loop(py: Python<'_>) -> PyResult> { - tracing::debug!(name: "apx.reactor.creating_event_loop", "creating asyncio event loop"); - py.import(c"asyncio")?.call_method0(c"new_event_loop") -} - -/// Cancel all pending asyncio tasks and run them to completion. -/// -/// Without this step, `loop.close()` leaves live tasks whose cleanup -/// callbacks call `call_soon_threadsafe` on the already-closed loop, -/// producing `RuntimeError: Event loop is closed` on stderr. -fn cancel_pending_tasks(py: Python<'_>, event_loop: &Bound<'_, PyAny>) { - let Ok(asyncio) = py.import(c"asyncio") else { - return; - }; - let Ok(tasks) = asyncio.call_method1(c"all_tasks", (event_loop,)) else { - return; - }; - let Ok(task_list) = pyo3::types::PyList::new( - py, - tasks - .try_iter() - .into_iter() - .flatten() - .flatten() - .collect::>(), - ) else { - return; - }; - for task in task_list.iter() { - let _ = task.call_method0(c"cancel"); - } - let Ok(gather) = asyncio.call_method(c"gather", (&task_list,), Some(&gather_kwargs(py))) else { - return; - }; - let _ = event_loop.call_method1(c"run_until_complete", (gather,)); -} - -/// Build `return_exceptions=True` kwargs for `asyncio.gather`. -fn gather_kwargs(py: Python<'_>) -> Bound<'_, pyo3::types::PyDict> { - let kwargs = pyo3::types::PyDict::new(py); - let _ = kwargs.set_item("return_exceptions", true); - kwargs -} - -/// Shut down all async generators — run their `aclose()` finalizers. -fn shutdown_asyncgens(_py: Python<'_>, event_loop: &Bound<'_, PyAny>) { - let Ok(coro) = event_loop.call_method0(c"shutdown_asyncgens") else { - return; - }; - if let Err(e) = event_loop.call_method1(c"run_until_complete", (&coro,)) { - tracing::warn!(name: "apx.reactor.shutdown_asyncgens_failed", error = %e, "shutdown_asyncgens failed"); - } -} - -/// Shut down the default thread pool executor with a timeout. -/// -/// Uses a 5-second timeout to avoid the Ctrl+C deadlock documented -/// in CPython #111358. `asyncio.run()` uses 5 minutes — we use 5s -/// because our executor usage is minimal (DNS, file I/O). -fn shutdown_default_executor(py: Python<'_>, event_loop: &Bound<'_, PyAny>) { - let Ok(coro) = event_loop.call_method0(c"shutdown_default_executor") else { - return; - }; - let Ok(asyncio) = py.import(c"asyncio") else { - let _ = event_loop.call_method1(c"run_until_complete", (&coro,)); - return; - }; - let Ok(wait_for) = asyncio.call_method1(c"wait_for", (&coro, 5.0)) else { - let _ = event_loop.call_method1(c"run_until_complete", (&coro,)); - return; - }; - if let Err(e) = event_loop.call_method1(c"run_until_complete", (&wait_for,)) { - tracing::warn!(name: "apx.reactor.shutdown_default_executor_failed", error = %e, "shutdown_default_executor failed"); - } -} - -// ── Reactor ────────────────────────────────────────────────────────────── - -/// Asyncio event loop lifecycle manager. -/// -/// Owns the Python asyncio event loop running on a dedicated OS thread, -/// the cached `call_soon_threadsafe` and `create_task` bound methods. -pub struct Reactor { - /// Python asyncio event loop object. - event_loop: Py, - /// Cached `loop.call_soon_threadsafe` bound method. - call_soon_threadsafe: Py, - /// Cached `loop.create_task` bound method. - create_task: Py, - /// Dedicated OS thread running `loop.run_forever()`. - asyncio_thread: Mutex>>, -} - -impl Reactor { - /// Initialize the reactor on the current thread. - /// - /// Sets up the asyncio event loop, marks it as running, enables eager - /// task factory (Python 3.12+), caches submission callables, and - /// spawns a dedicated OS thread running `run_forever()`. - /// - /// # Errors - /// - /// Returns an error if Python initialization fails. - pub fn init(py: Python<'_>, loop_policy: &str) -> Result { - install_loop_policy(py, loop_policy); - - let event_loop = create_event_loop(py).map_err(|e| format!("create_event_loop: {e}"))?; - - let asyncio = py - .import(c"asyncio") - .map_err(|e| format!("import asyncio: {e}"))?; - asyncio - .call_method1(c"set_event_loop", (&event_loop,)) - .map_err(|e| format!("set_event_loop: {e}"))?; - - // Mark as running loop so asyncio.get_running_loop() works for - // libraries (Starlette middleware, DB drivers, etc.). - let events = py - .import(c"asyncio.events") - .map_err(|e| format!("import asyncio.events: {e}"))?; - events - .call_method1(c"_set_running_loop", (&event_loop,)) - .map_err(|e| format!("_set_running_loop: {e}"))?; - tracing::debug!(name: "apx.reactor.set_running_loop_installed", "reactor: _set_running_loop installed"); - - // Eager task factory (Python 3.12+). - if let Ok(eager_factory) = asyncio.getattr(c"eager_task_factory") { - match event_loop.call_method1(c"set_task_factory", (eager_factory,)) { - Ok(_) => { - tracing::debug!(name: "apx.reactor.eager_task_factory_enabled", "eager task factory enabled (Python 3.12+)"); - } - Err(e) => { - tracing::debug!(name: "apx.reactor.eager_task_factory_unavailable", "eager task factory not available: {e}"); - } - } - } - - let call_soon_threadsafe = event_loop - .getattr(c"call_soon_threadsafe") - .map_err(|e| format!("missing call_soon_threadsafe: {e}"))? - .unbind(); - let create_task = event_loop - .getattr(c"create_task") - .map_err(|e| format!("missing create_task: {e}"))? - .unbind(); - - // Spawn dedicated asyncio thread with tokio handle for I/O. - let el_for_thread = event_loop.clone().unbind(); - let tokio_handle = tokio::runtime::Handle::try_current().ok(); - let asyncio_thread = std::thread::Builder::new() - .name("apx-asyncio".to_owned()) - .spawn(move || { - if let Some(handle) = tokio_handle { - super::set_tokio_handle(handle); - } - Python::attach(|py| { - let el = el_for_thread.bind(py); - if let Err(e) = el.call_method0(c"run_forever") { - tracing::error!(name: "apx.reactor.run_forever_failed", error = %e, "asyncio thread: run_forever failed"); - } - }); - }) - .map_err(|e| format!("spawn asyncio thread: {e}"))?; - - tracing::debug!(name: "apx.reactor.initialized", "reactor initialized (asyncio delegation)"); - - Ok(Self { - event_loop: event_loop.unbind(), - call_soon_threadsafe, - create_task, - asyncio_thread: Mutex::new(Some(asyncio_thread)), - }) - } - - /// The Python asyncio event loop object. - pub fn event_loop_py(&self) -> &Py { - &self.event_loop - } - - /// Cached `loop.call_soon_threadsafe` bound method. - pub fn call_soon_threadsafe(&self) -> &Py { - &self.call_soon_threadsafe - } - - /// Cached `loop.create_task` bound method. - pub fn create_task(&self) -> &Py { - &self.create_task - } - - /// Shut down the reactor. - /// - /// 1. Stops the asyncio loop (wakes `run_forever` via `call_soon_threadsafe`). - /// 2. Joins the dedicated asyncio thread. - /// 3. Cancels pending tasks and closes the loop. - pub fn shutdown(&self) { - Python::attach(|py| { - let el = self.event_loop.bind(py); - if let Ok(stop) = el.getattr(c"stop") { - let _ = el.call_method1(c"call_soon_threadsafe", (stop,)); - } - }); - - let handle = self - .asyncio_thread - .lock() - .unwrap_or_else(|e| e.into_inner()) - .take(); - if let Some(h) = handle - && let Err(e) = h.join() - { - tracing::warn!(name: "apx.reactor.asyncio_thread_panicked", "asyncio thread panicked: {e:?}"); - } - - Python::attach(|py| { - let el = self.event_loop.bind(py); - if let Ok(events) = py.import(c"asyncio.events") { - let _ = events.call_method1(c"_set_running_loop", (py.None(),)); - } - shutdown_asyncgens(py, el); - shutdown_default_executor(py, el); - cancel_pending_tasks(py, el); - let _ = el.call_method0(c"close"); - }); - } -} - -crate::opaque_debug!(Reactor); diff --git a/crates/framework/src/lib.rs b/crates/framework/src/lib.rs index b327b230..38a1081a 100644 --- a/crates/framework/src/lib.rs +++ b/crates/framework/src/lib.rs @@ -34,14 +34,12 @@ macro_rules! opaque_debug { } pub(crate) use opaque_debug; -pub mod dispatch; pub(crate) mod protocol; pub mod pyapi; pub mod telemetry; pub mod transport; pub(crate) mod asgi; -pub(crate) mod io; pub mod supervision; #[cfg(test)] diff --git a/crates/framework/src/protocol/connection.rs b/crates/framework/src/protocol/connection.rs new file mode 100644 index 00000000..f3b2e9eb --- /dev/null +++ b/crates/framework/src/protocol/connection.rs @@ -0,0 +1,596 @@ +//! HTTP/1.1 connection handler for asyncio transport/protocol. +//! +//! Implements the asyncio Protocol interface using Rust-accelerated +//! parsing, scope building, and response writing. + +use std::borrow::Cow; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::time::Instant; + +use bytes::Bytes; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyDict, PyList, PyTuple}; + +use crate::asgi::scope::{ResolvedAwaitableWithValue, ScopeInterns}; +use crate::telemetry::dispatch_metrics; +use crate::transport::types::ProtocolVersion; + +use super::parser::{HttpVersion, ParsedHead, ParsedRequest, RequestParser}; +use super::writer::RustResponseWriter; + +/// Maximum concurrent in-flight requests per protocol instance. +const MAX_CONCURRENT: u32 = 256; + +/// Shared state for all protocol instances on this worker. +struct ProtocolShared { + on_request: Py, + interns: ScopeInterns, + server_host: String, + server_port: u16, + active_requests: AtomicU32, +} + +/// Factory that creates [`RustProtocol`] instances for `loop.create_server()`. +/// +/// Holds shared worker state (interns, dispatch callback, concurrency limit). +/// Created in Rust (worker init), passed to Python as a callable. +#[pyclass(module = "apx._core")] +pub struct ProtocolFactory { + shared: Arc, +} + +crate::opaque_debug!(ProtocolFactory); + +impl ProtocolFactory { + /// Create a factory with shared worker state (Rust-side constructor). + pub fn new( + on_request: Py, + interns: ScopeInterns, + server_host: String, + server_port: u16, + ) -> Self { + Self { + shared: Arc::new(ProtocolShared { + on_request, + interns, + server_host, + server_port, + active_requests: AtomicU32::new(0), + }), + } + } +} + +#[pymethods] +impl ProtocolFactory { + /// Called by asyncio as the protocol factory (`loop.create_server(factory)`). + fn __call__(&self, py: Python<'_>) -> PyResult> { + Py::new( + py, + RustProtocol { + transport: None, + parser: RequestParser::new(), + shared: Arc::clone(&self.shared), + client_addr: None, + }, + ) + } +} + +/// HTTP/1.1 protocol for asyncio `loop.create_server()`. +/// +/// Plugs into asyncio's transport/protocol layer. Parses HTTP +/// requests in Rust and dispatches to a Python callback. +#[pyclass(module = "apx._core")] +pub struct RustProtocol { + transport: Option>, + parser: RequestParser, + shared: Arc, + client_addr: Option, +} + +crate::opaque_debug!(RustProtocol); + +#[pymethods] +impl RustProtocol { + /// Called by asyncio when a connection is established. + fn connection_made(&mut self, py: Python<'_>, transport: Py) -> PyResult<()> { + self.client_addr = extract_peer_addr(py, &transport); + self.transport = Some(transport); + dispatch_metrics::inc_connections(); + Ok(()) + } + + /// Called by asyncio when data is received on the connection. + fn data_received(&mut self, py: Python<'_>, data: &[u8]) -> PyResult<()> { + let t0 = Instant::now(); + let requests = self + .parser + .feed(data) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + dispatch_metrics::record_parse(t0.elapsed().as_micros() as f64); + + for parsed in requests { + self.dispatch_request(py, parsed)?; + } + Ok(()) + } + + /// Called by asyncio when the peer sends EOF. + #[expect(clippy::unused_self, reason = "required by asyncio protocol interface")] + fn eof_received(&self) -> bool { + false + } + + /// Called by asyncio when the connection is lost. + fn connection_lost(&mut self, _py: Python<'_>, _exc: Option<&Bound<'_, PyAny>>) { + self.transport = None; + self.parser.reset(); + dispatch_metrics::dec_connections(); + } +} + +impl RustProtocol { + fn dispatch_request(&self, py: Python<'_>, parsed: ParsedRequest) -> PyResult<()> { + let t_dispatch = Instant::now(); + let Some(transport) = &self.transport else { + return Ok(()); + }; + + let active = self.shared.active_requests.fetch_add(1, Ordering::Relaxed); + if active >= MAX_CONCURRENT { + self.shared.active_requests.fetch_sub(1, Ordering::Relaxed); + write_503(py, transport)?; + return Ok(()); + } + dispatch_metrics::inc_active_requests(); + crate::telemetry::http::inc_active_requests(); + + transport.call_method0(py, pyo3::intern!(py, "pause_reading"))?; + + let request_id = resolve_request_id(&parsed.head.headers); + + let t_scope = Instant::now(); + let scope = build_scope_from_parsed( + py, + &parsed, + &self.shared.interns, + &self.shared.server_host, + self.shared.server_port, + self.client_addr, + &request_id, + )?; + dispatch_metrics::record_scope_build(t_scope.elapsed().as_micros() as f64); + + let t_receive = Instant::now(); + let receive = HttpReceive::new(py, parsed.body)?; + dispatch_metrics::record_receive_build(t_receive.elapsed().as_micros() as f64); + + let method = parsed.head.method.as_str().to_owned(); + let path = parsed.head.path; + + let (request_span, trace_ctx) = + crate::telemetry::http::begin_request_span(&request_id, &method, &path); + crate::telemetry::context::set_python_context(py, &trace_ctx)?; + + let transport_clone = transport.clone_ref(py); + let on_complete = OnRequestComplete::create( + py, + transport_clone, + Arc::clone(&self.shared), + t_dispatch, + method, + path, + request_span, + )?; + let send = + RustResponseWriter::new(py, transport.clone_ref(py), Some(on_complete.into_any()))?; + + self.shared.on_request.call1(py, (scope, receive, send))?; + dispatch_metrics::record_dispatch_total(t_dispatch.elapsed().as_micros() as f64); + Ok(()) + } +} + +/// Callback invoked when a response is fully written. +/// +/// Resumes reading on the transport, decrements the active count, +/// records handler_wait duration, emits `http.server.request.duration`, +/// and ends the OTEL request span. +#[pyclass(module = "apx._core")] +struct OnRequestComplete { + transport: Py, + shared: Arc, + dispatch_start: Instant, + method: String, + path: String, + request_span: tracing::Span, +} + +crate::opaque_debug!(OnRequestComplete); + +impl OnRequestComplete { + fn create( + py: Python<'_>, + transport: Py, + shared: Arc, + dispatch_start: Instant, + method: String, + path: String, + request_span: tracing::Span, + ) -> PyResult> { + Py::new( + py, + Self { + transport, + shared, + dispatch_start, + method, + path, + request_span, + }, + ) + } +} + +#[pymethods] +impl OnRequestComplete { + fn __call__(&mut self, py: Python<'_>, status: u16) -> PyResult<()> { + let elapsed = self.dispatch_start.elapsed(); + + { + let _guard = self.request_span.enter(); + dispatch_metrics::record_handler_wait(elapsed.as_micros() as f64); + + crate::telemetry::http::record_duration( + elapsed.as_secs_f64(), + &self.method, + "http", + status, + &self.path, + None, + ); + + crate::telemetry::http::finish_request_span(&self.request_span, status); + } + + self.transport + .call_method0(py, pyo3::intern!(py, "resume_reading"))?; + self.shared.active_requests.fetch_sub(1, Ordering::Relaxed); + dispatch_metrics::dec_active_requests(); + crate::telemetry::http::dec_active_requests(); + Ok(()) + } +} + +// ── HttpReceive ────────────────────────────────────────────────────── + +/// ASGI `receive` callable for non-streaming HTTP requests. +/// +/// First call returns the request body immediately via +/// `ResolvedAwaitableWithValue`. Subsequent calls return a pending +/// future (disconnect sentinel). +#[pyclass(module = "apx._core", freelist = 64)] +pub struct HttpReceive { + body: std::sync::Mutex>, +} + +crate::opaque_debug!(HttpReceive); + +impl HttpReceive { + fn new(py: Python<'_>, body: Bytes) -> PyResult> { + Py::new( + py, + Self { + body: std::sync::Mutex::new(Some(body)), + }, + ) + } +} + +#[pymethods] +impl HttpReceive { + fn __call__(&self, py: Python<'_>) -> PyResult> { + let body = self + .body + .lock() + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))? + .take(); + + if let Some(b) = body { + let event = PyDict::new(py); + event.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request"))?; + event.set_item(pyo3::intern!(py, "body"), PyBytes::new(py, &b))?; + event.set_item(pyo3::intern!(py, "more_body"), false)?; + let awaitable = Py::new( + py, + ResolvedAwaitableWithValue::new(event.into_any().unbind()), + )?; + Ok(awaitable.into_any()) + } else { + let fut = py + .import("asyncio")? + .call_method0(pyo3::intern!(py, "get_running_loop"))? + .call_method0(pyo3::intern!(py, "create_future"))?; + Ok(fut.unbind()) + } + } +} + +// ── Scope building ────────────────────────────────────────────────────── + +/// Build an ASGI HTTP scope dict directly from [`ParsedRequest`]. +/// +/// Bypasses `ScopeSource` trait and `HeaderMap` — works with raw byte +/// pairs from the parser, avoiding the intermediate allocation. +fn build_scope_from_parsed( + py: Python<'_>, + parsed: &ParsedRequest, + interns: &ScopeInterns, + server_host: &str, + server_port: u16, + client_addr: Option, + request_id: &str, +) -> PyResult> { + let scope = interns + .scope_template + .bind(py) + .call_method0(pyo3::intern!(py, "copy"))? + .cast_into::()?; + + let version = match parsed.head.version { + HttpVersion::Http10 => ProtocolVersion::Http10, + HttpVersion::Http11 => ProtocolVersion::Http11, + }; + if version != ProtocolVersion::Http11 { + scope.set_item( + interns.keys.http_version.bind(py), + interns.versions.get(py, version), + )?; + } + + scope.set_item(interns.keys.method.bind(py), parsed.head.method.as_str())?; + scope.set_item( + interns.keys.path.bind(py), + percent_decode(&parsed.head.path).as_ref(), + )?; + scope.set_item( + interns.keys.raw_path.bind(py), + PyBytes::new(py, parsed.head.path.as_bytes()), + )?; + scope.set_item( + interns.keys.query_string.bind(py), + PyBytes::new(py, &parsed.head.query_string), + )?; + + set_headers_from_parsed(py, &scope, &parsed.head, interns, request_id)?; + set_addresses(py, &scope, interns, server_host, server_port, client_addr)?; + scope.set_item( + interns.keys.path_params.bind(py), + interns.empty_dict.bind(py), + )?; + scope.set_item(interns.keys.state.bind(py), PyDict::new(py))?; + + Ok(scope.unbind()) +} + +/// Extract existing `x-request-id` from headers or generate a UUID v4. +fn resolve_request_id(headers: &[(Bytes, Bytes)]) -> String { + for (name, value) in headers { + if name.eq_ignore_ascii_case(b"x-request-id") + && let Ok(s) = std::str::from_utf8(value) + { + return s.to_owned(); + } + } + generate_uuid_v4() +} + +/// Set headers list from raw byte pairs (no `HeaderMap` intermediary). +/// +/// Prepends `x-request-id` if not already present in the request. +fn set_headers_from_parsed( + py: Python<'_>, + scope: &Bound<'_, PyDict>, + head: &ParsedHead, + interns: &ScopeInterns, + request_id: &str, +) -> PyResult<()> { + let has_request_id = head + .headers + .iter() + .any(|(name, _)| name.eq_ignore_ascii_case(b"x-request-id")); + + let extra_cap = usize::from(!has_request_id); + let mut pairs: Vec> = Vec::with_capacity(head.headers.len() + extra_cap); + + if !has_request_id { + let id_name = PyBytes::new(py, b"x-request-id"); + let id_value = PyBytes::new(py, request_id.as_bytes()); + let pair = PyTuple::new(py, [id_name.into_any(), id_value.into_any()])?; + pairs.push(pair.into_any()); + } + + for (name, value) in &head.headers { + let n = intern_header_name(py, name, interns); + let v = PyBytes::new(py, value); + let pair = PyTuple::new(py, [n.into_any(), v.into_any()])?; + pairs.push(pair.into_any()); + } + let headers_list = PyList::new(py, &pairs)?; + scope.set_item(interns.keys.headers.bind(py), headers_list)?; + Ok(()) +} + +/// Generate a UUID v4 string (random, RFC 4122 variant 1). +fn generate_uuid_v4() -> String { + let mut bytes: [u8; 16] = rand::random(); + bytes[6] = (bytes[6] & 0x0f) | 0x40; + bytes[8] = (bytes[8] & 0x3f) | 0x80; + format!( + "{:08x}-{:04x}-{:04x}-{:04x}-{:012x}", + u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]), + u16::from_be_bytes([bytes[4], bytes[5]]), + u16::from_be_bytes([bytes[6], bytes[7]]), + u16::from_be_bytes([bytes[8], bytes[9]]), + u64::from_be_bytes([ + 0, 0, bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15] + ]) + ) +} + +/// Try to use the header intern cache, falling back to `PyBytes::new`. +fn intern_header_name<'py>( + py: Python<'py>, + name: &Bytes, + interns: &ScopeInterns, +) -> Bound<'py, PyBytes> { + let name_lower = name.to_ascii_lowercase(); + for (cached_name, cached_py) in &interns.headers.map { + if cached_name.as_str().as_bytes() == name_lower.as_slice() { + return cached_py.bind(py).clone(); + } + } + PyBytes::new(py, &name_lower) +} + +/// Set server and client address tuples. +fn set_addresses( + py: Python<'_>, + scope: &Bound<'_, PyDict>, + interns: &ScopeInterns, + _server_host: &str, + _server_port: u16, + client_addr: Option, +) -> PyResult<()> { + scope.set_item(interns.keys.server.bind(py), interns.server_tuple.bind(py))?; + match client_addr { + Some(addr) => { + scope.set_item( + interns.keys.client.bind(py), + (addr.ip().to_string(), addr.port()), + )?; + } + None => scope.set_item(interns.keys.client.bind(py), py.None())?, + } + Ok(()) +} + +/// Percent-decode a URL path. +fn percent_decode(input: &str) -> Cow<'_, str> { + if !input.contains('%') { + return Cow::Borrowed(input); + } + let mut bytes = Vec::with_capacity(input.len()); + let mut chars = input.as_bytes().iter().copied(); + while let Some(b) = chars.next() { + if b == b'%' { + let hi = chars.next(); + let lo = chars.next(); + if let (Some(h), Some(l)) = (hi, lo) { + if let (Some(hv), Some(lv)) = (hex_val(h), hex_val(l)) { + bytes.push(hv << 4 | lv); + continue; + } + bytes.extend_from_slice(&[b'%', h, l]); + } else { + bytes.push(b'%'); + if let Some(h) = hi { + bytes.push(h); + } + } + } else { + bytes.push(b); + } + } + match String::from_utf8(bytes) { + Ok(s) => Cow::Owned(s), + Err(e) => Cow::Owned(String::from_utf8_lossy(e.as_bytes()).into_owned()), + } +} + +/// Convert a hex ASCII char to its value. +fn hex_val(b: u8) -> Option { + match b { + b'0'..=b'9' => Some(b - b'0'), + b'a'..=b'f' => Some(b - b'a' + 10), + b'A'..=b'F' => Some(b - b'A' + 10), + _ => None, + } +} + +/// Extract the peer address from an asyncio transport. +fn extract_peer_addr(py: Python<'_>, transport: &Py) -> Option { + let peername = transport + .call_method1( + py, + pyo3::intern!(py, "get_extra_info"), + (pyo3::intern!(py, "peername"),), + ) + .ok()?; + if peername.is_none(py) { + return None; + } + let bound = peername.bind(py); + let tuple: &Bound<'_, PyTuple> = bound.cast().ok()?; + let host: String = tuple.get_item(0).ok()?.extract().ok()?; + let port: u16 = tuple.get_item(1).ok()?.extract().ok()?; + format!("{host}:{port}").parse().ok() +} + +/// Write a 503 Service Unavailable response directly. +fn write_503(py: Python<'_>, transport: &Py) -> PyResult<()> { + let body = b"Service Unavailable"; + let response = format!( + "HTTP/1.1 503 Service Unavailable\r\ncontent-length: {}\r\ncontent-type: text/plain\r\n\r\nService Unavailable", + body.len(), + ); + let py_bytes = PyBytes::new(py, response.as_bytes()); + transport.call_method1(py, pyo3::intern!(py, "write"), (py_bytes,))?; + Ok(()) +} + +#[cfg(test)] +#[expect(clippy::expect_used, reason = "test code uses expect for clarity")] +mod tests { + use super::*; + + #[test] + fn test_generate_uuid_v4_format() { + let id = generate_uuid_v4(); + assert_eq!(id.len(), 36); + let parts: Vec<&str> = id.split('-').collect(); + assert_eq!(parts.len(), 5); + assert_eq!(parts[0].len(), 8); + assert_eq!(parts[1].len(), 4); + assert_eq!(parts[2].len(), 4); + assert_eq!(parts[3].len(), 4); + assert_eq!(parts[4].len(), 12); + } + + #[test] + fn test_generate_uuid_v4_version_bits() { + let id = generate_uuid_v4(); + let version_char = id.chars().nth(14).expect("version char"); + assert_eq!(version_char, '4', "UUID version nibble should be 4"); + } + + #[test] + fn test_generate_uuid_v4_variant_bits() { + let id = generate_uuid_v4(); + let variant_char = id.chars().nth(19).expect("variant char"); + assert!( + matches!(variant_char, '8' | '9' | 'a' | 'b'), + "UUID variant nibble should be 8/9/a/b, got {variant_char}" + ); + } + + #[test] + fn test_generate_uuid_v4_uniqueness() { + let a = generate_uuid_v4(); + let b = generate_uuid_v4(); + assert_ne!(a, b); + } +} diff --git a/crates/framework/src/protocol/http/error.rs b/crates/framework/src/protocol/http/error.rs deleted file mode 100644 index 1d9c6cbd..00000000 --- a/crates/framework/src/protocol/http/error.rs +++ /dev/null @@ -1,152 +0,0 @@ -//! Structured error types for the framework runtime. - -use http::StatusCode; - -/// Max depth to walk the error source chain (fixed loop bound). -#[cfg(test)] -const MAX_ERROR_CHAIN_DEPTH: usize = 10; - -/// Walk an error's source chain looking for a specific error type. -#[cfg(test)] -pub fn find_in_error_chain( - err: &dyn std::error::Error, -) -> Option<&T> { - let mut source = err.source(); - for _ in 0..MAX_ERROR_CHAIN_DEPTH { - let e = source?; - if let Some(found) = e.downcast_ref::() { - return Some(found); - } - source = e.source(); - } - None -} - -/// Application error. -/// -/// **Security**: `Internal` logs the full error via `tracing::error!` but -/// returns a generic "Internal Server Error" detail to the client. Never -/// leak exception messages, file paths, or connection strings in 500 responses. -#[derive(Debug, thiserror::Error)] -pub enum AppError { - /// Internal error (500) — detail is logged, NOT sent to client. - #[error("internal error: {0}")] - Internal(String), - - /// Request timeout (408). - #[error("request timeout")] - Timeout, -} - -impl AppError { - /// Convert to status code. - pub(crate) fn status_code(&self) -> StatusCode { - match self { - Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR, - Self::Timeout => StatusCode::REQUEST_TIMEOUT, - } - } -} - -// ── Tests ─────────────────────────────────────────────────────────────── - -#[cfg(test)] -#[expect( - clippy::unwrap_used, - reason = "test code uses unwrap/assert for clarity" -)] -mod tests { - use super::*; - - #[test] - fn find_in_error_chain_not_found() { - #[derive(Debug)] - struct SimpleErr; - impl std::fmt::Display for SimpleErr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("simple") - } - } - impl std::error::Error for SimpleErr {} - - let err = SimpleErr; - assert!(find_in_error_chain::(&err).is_none()); - } - - #[test] - fn app_error_internal_does_not_leak() { - let err = AppError::Internal("secret db password: hunter2".to_owned()); - // status_code is correct - assert_eq!(err.status_code(), StatusCode::INTERNAL_SERVER_ERROR); - } - - #[test] - fn app_error_status_codes() { - assert_eq!( - AppError::Internal("x".to_owned()).status_code(), - StatusCode::INTERNAL_SERVER_ERROR - ); - assert_eq!(AppError::Timeout.status_code(), StatusCode::REQUEST_TIMEOUT); - } - - /// Produce a boxed error whose chain contains `LengthLimitError`. - async fn make_length_limit_boxed_error() -> Box { - use http_body_util::{BodyExt, Full, Limited}; - Limited::new(Full::new(bytes::Bytes::from("xx")), 0) - .collect() - .await - .unwrap_err() - } - - #[tokio::test] - async fn find_in_error_chain_positive() { - #[derive(Debug)] - struct Wrapper(Box); - impl std::fmt::Display for Wrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("wrap") - } - } - impl std::error::Error for Wrapper { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - Some(self.0.as_ref()) - } - } - let lle = make_length_limit_boxed_error().await; - let err = Wrapper(lle); - assert!(find_in_error_chain::(&err).is_some()); - } - - #[tokio::test] - async fn find_in_error_chain_depth_two() { - #[derive(Debug)] - struct Inner(Box); - impl std::fmt::Display for Inner { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("inner") - } - } - impl std::error::Error for Inner { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - Some(self.0.as_ref()) - } - } - - #[derive(Debug)] - struct Outer(Inner); - impl std::fmt::Display for Outer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("outer") - } - } - impl std::error::Error for Outer { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - Some(&self.0) - } - } - - let lle = make_length_limit_boxed_error().await; - let err = Outer(Inner(lle)); - assert!(find_in_error_chain::(&err).is_some()); - } -} diff --git a/crates/framework/src/protocol/http/mod.rs b/crates/framework/src/protocol/http/mod.rs deleted file mode 100644 index 37596d9d..00000000 --- a/crates/framework/src/protocol/http/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -//! HTTP protocol: Hyper service, health probes, concurrency. - -pub mod error; -pub mod service; diff --git a/crates/framework/src/protocol/http/service.rs b/crates/framework/src/protocol/http/service.rs deleted file mode 100644 index 1a3dde74..00000000 --- a/crates/framework/src/protocol/http/service.rs +++ /dev/null @@ -1,698 +0,0 @@ -//! Hyper `Service` implementation with health probes, concurrency limiting, -//! and request timeout. -//! -//! `ApxService` is the HTTP layer between hyper and the application dispatch. -//! It short-circuits health probes, enforces a per-worker concurrency limit -//! via `Arc`, and wraps dispatch in `tokio::time::timeout`. - -use crate::dispatch::Dispatch; -use crate::protocol::ws::session as websocket; -use crate::telemetry::http::{self, ActiveRequestGuard}; -use crate::transport::tcp::TcpListener; -use crate::transport::types::{ - BodyStream, InboundRequest, OutboundResponse, ProtocolVersion, ResponseBody, TransportKind, -}; -use ::http::header::{HeaderMap, HeaderName, HeaderValue}; -use bytes::Bytes; -use hyper::body::Incoming; -use hyper::server::conn::http1; -use hyper::service::Service; -use hyper::{Request, Response}; -use hyper_util::rt::TokioIo; -use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState}; -use std::convert::Infallible; -use std::future::Future; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::time::Duration; -use tokio::sync::OwnedSemaphorePermit; -use tracing::Instrument; -use tracing_opentelemetry::OpenTelemetrySpanExt; - -// ── Constants ──────────────────────────────────────────────────────────── - -/// Default request timeout. -const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); - -/// Default max concurrent requests per worker. -const DEFAULT_MAX_CONCURRENT: usize = 256; - -/// Health probe response: alive. -const HEALTH_ALIVE: &[u8] = br#"{"status":"alive"}"#; - -/// Health probe response: ready. -const HEALTH_READY: &[u8] = br#"{"status":"ready"}"#; - -/// JSON content type for health responses. -const JSON_CONTENT_TYPE: &str = "application/json"; - -/// Databricks Apps `X-Request-Id` header. -pub const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id"); - -/// Ensure every request carries an `X-Request-Id` header. -/// -/// Databricks Apps always sets this header. For local dev (no proxy), -/// a UUID v4 is generated so downstream telemetry can always rely on it. -pub fn ensure_request_id(headers: &mut HeaderMap) { - if headers.contains_key(&REQUEST_ID_HEADER) { - return; - } - let id = uuid::Uuid::new_v4().to_string(); - if let Ok(val) = HeaderValue::from_str(&id) { - headers.insert(REQUEST_ID_HEADER, val); - } -} - -// ── Config ─────────────────────────────────────────────────────────────── - -/// Configuration for the HTTP service layer. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ServiceConfig { - /// Per-request timeout. - pub timeout: Duration, - /// Maximum concurrent requests per worker. - pub max_concurrent: usize, -} - -impl Default for ServiceConfig { - fn default() -> Self { - Self { - timeout: DEFAULT_TIMEOUT, - max_concurrent: DEFAULT_MAX_CONCURRENT, - } - } -} - -// ── Service ────────────────────────────────────────────────────────────── - -/// Hyper service implementation. -/// -/// Cloned per-connection so that `client_addr` can be set for each connection. -/// The `dispatch` and `semaphore` are shared via `Arc`. -#[derive(Clone)] -pub struct ApxService { - dispatch: Arc, - semaphore: Arc, - timeout: Duration, - server_addr: SocketAddr, - client_addr: Option, -} - -impl ApxService { - /// Create a new `ApxService`. - pub fn new( - dispatch: Arc, - server_addr: SocketAddr, - config: &ServiceConfig, - ) -> Self { - Self { - dispatch, - semaphore: Arc::new(tokio::sync::Semaphore::new(config.max_concurrent)), - timeout: config.timeout, - server_addr, - client_addr: None, - } - } - - /// Set the client address for this per-connection clone. - pub fn with_client_addr(mut self, addr: SocketAddr) -> Self { - self.client_addr = Some(addr); - self - } -} - -impl std::fmt::Debug for ApxService { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ApxService") - .field("dispatch", &self.dispatch) - .field("timeout", &self.timeout) - .field("server_addr", &self.server_addr) - .field("client_addr", &self.client_addr) - .finish_non_exhaustive() - } -} - -impl Service> for ApxService { - type Response = Response; - type Error = Infallible; - type Future = Pin> + Send>>; - - fn call(&self, req: Request) -> Self::Future { - let this = self.clone(); - Box::pin(async move { Ok(this.handle(req).await) }) - } -} - -// ── Trace context ──────────────────────────────────────────────────────── - -/// Parse the `tracestate` HTTP header into an OTEL `TraceState`. -/// -/// Falls back to an empty `TraceState` if the header is missing or malformed. -fn parse_tracestate_header(headers: &HeaderMap) -> TraceState { - headers - .get("tracestate") - .and_then(|v| v.to_str().ok()) - .and_then(|raw| { - TraceState::from_key_value( - raw.split(',') - .filter_map(|pair| pair.split_once('=').map(|(k, v)| (k.trim(), v.trim()))), - ) - .ok() - }) - .unwrap_or_default() -} - -/// Parse `x-request-id` UUID into an OTEL `TraceId`. -/// -/// Databricks Apps always sends a UUID v4 (128 bits = OTEL TraceId size). -/// For locally-generated UUIDs the same mapping applies. -fn parse_request_id_as_trace_id(headers: &HeaderMap) -> Option { - let val = headers.get(&REQUEST_ID_HEADER)?.to_str().ok()?; - let uuid = uuid::Uuid::parse_str(val).ok()?; - Some(TraceId::from_bytes(*uuid.as_bytes())) -} - -/// Build a `tracing` span for an HTTP request with OTEL semantic conventions. -/// -/// If the `x-request-id` header contains a valid UUID, the span's trace_id -/// is set to match so all downstream spans share the Databricks correlation ID. -fn build_request_span( - headers: &HeaderMap, - method: &str, - scheme: &str, - path: &str, -) -> tracing::Span { - let request_id = headers - .get(&REQUEST_ID_HEADER) - .and_then(|v| v.to_str().ok()) - .unwrap_or(""); - - let span = tracing::info_span!( - target: "apx::http", - "http.server.request", - otel.kind = "server", - http.request.method = method, - url.scheme = scheme, - url.path = path, - request.id = request_id, - http.response.status_code = tracing::field::Empty, - ); - - if let Some(tid) = parse_request_id_as_trace_id(headers) { - let parent_span_id = SpanId::from_bytes( - uuid::Uuid::new_v4().as_bytes()[..8] - .try_into() - .unwrap_or([0; 8]), - ); - let trace_state = parse_tracestate_header(headers); - let parent_sc = - SpanContext::new(tid, parent_span_id, TraceFlags::SAMPLED, true, trace_state); - let parent_cx = opentelemetry::Context::new().with_remote_span_context(parent_sc); - span.set_parent(parent_cx); - } - - span -} - -// ── Request pipeline ───────────────────────────────────────────────────── - -impl ApxService { - /// Main request handler — orchestrates probe check, semaphore, timeout, dispatch. - async fn handle(self, req: Request) -> Response { - let method = req.method().as_str().to_owned(); - let scheme = "http"; - - // Health probe short-circuit — no span needed. - if let Some(probe_resp) = probe_response(req.uri().path()) { - return probe_resp; - } - - // WebSocket upgrade — must happen before consuming the request body. - if websocket::is_websocket_upgrade(&req) { - return self.handle_ws(req, &method, scheme).await; - } - - let path = req.uri().path().to_owned(); - let inbound = inbound_from_hyper(req, path.clone(), self.server_addr, self.client_addr); - let span = build_request_span(&inbound.headers, &method, scheme, &path); - - self.handle_http(inbound, method, scheme, path, span).await - } - - /// WebSocket upgrade path — no OTEL span (short-lived handshake). - async fn handle_ws( - self, - req: Request, - method: &str, - scheme: &str, - ) -> Response { - let path = req.uri().path().to_owned(); - let start = std::time::Instant::now(); - let response = self - .dispatch - .dispatch_ws(req, self.server_addr, self.client_addr) - .await; - let status = response.status().as_u16(); - http::record_duration( - start.elapsed().as_secs_f64(), - method, - scheme, - status, - &path, - None, - ); - response - } - - /// HTTP dispatch path — wrapped in an OTEL span. - async fn handle_http( - self, - inbound: InboundRequest, - method: String, - scheme: &str, - path: String, - span: tracing::Span, - ) -> Response { - async { - let _active = ActiveRequestGuard::enter(&method, scheme); - let start = std::time::Instant::now(); - - tracing::info!(name: "apx.http.request", "~> {} {}", method, path); - - let Ok(permit) = Arc::clone(&self.semaphore).try_acquire_owned() else { - let elapsed_ms = start.elapsed().as_millis(); - tracing::info!(name: "apx.http.response", "<~ {} {} 503 [{}ms]", method, path, elapsed_ms); - let resp = - error_response(hyper::StatusCode::SERVICE_UNAVAILABLE, "service overloaded"); - http::record_duration( - start.elapsed().as_secs_f64(), - &method, - scheme, - 503, - "", - Some("503"), - ); - return resp; - }; - - let result = tokio::time::timeout(self.timeout, self.dispatch.dispatch(inbound)).await; - - let (response, server_route) = match result { - Ok(mut outbound) => { - let route = outbound.server_route.take(); - if let ResponseBody::Stream(stream) = outbound.body { - outbound.body = ResponseBody::Stream(Box::pin(PermitGuardedStream { - inner: stream, - _permit: permit, - })); - } else { - drop(permit); - } - (outbound_to_hyper(outbound), route) - } - Err(_elapsed) => { - drop(permit); - ( - error_response(hyper::StatusCode::REQUEST_TIMEOUT, "request timeout"), - None, - ) - } - }; - - let route = server_route.as_deref().unwrap_or(&path); - let status = response.status().as_u16(); - let elapsed = start.elapsed().as_secs_f64(); - let elapsed_ms = (elapsed * 1000.0) as u64; - let error_type = if status >= 400 { - Some(status.to_string()) - } else { - None - }; - - tracing::Span::current().record("http.response.status_code", status); - http::record_duration( - elapsed, - &method, - scheme, - status, - route, - error_type.as_deref(), - ); - - tracing::info!(name: "apx.http.response", "<~ {} {} {} [{}ms]", method, route, status, elapsed_ms); - - response - } - .instrument(span) - .await - } -} - -/// Check if the path is a health probe and return the response. -fn probe_response(path: &str) -> Option> { - let body = match path { - "/healthz" => HEALTH_ALIVE, - "/readyz" => HEALTH_READY, - _ => return None, - }; - - // Builder with static status + header cannot fail. - let resp = Response::builder() - .status(hyper::StatusCode::OK) - .header(hyper::header::CONTENT_TYPE, JSON_CONTENT_TYPE) - .body(ResponseBody::Fixed(Bytes::from_static(body))) - .unwrap_or_else(|_| unreachable!()); - - Some(resp) -} - -/// Convert a hyper request to an `InboundRequest`. -/// -/// Accepts a pre-extracted `path` to avoid re-extracting from the URI -/// (the caller already needs the path for metrics recording). -fn inbound_from_hyper( - req: Request, - path: String, - server_addr: SocketAddr, - client_addr: Option, -) -> InboundRequest { - use http_body::Body as _; - - let (parts, body) = req.into_parts(); - - let method = parts.method; - let query_string = parts - .uri - .query() - .map(|q| Bytes::copy_from_slice(q.as_bytes())) - .unwrap_or_default(); - let mut headers = parts.headers; - ensure_request_id(&mut headers); - - let protocol = match parts.version { - hyper::Version::HTTP_10 => ProtocolVersion::Http10, - hyper::Version::HTTP_2 => ProtocolVersion::H2, - _ => ProtocolVersion::Http11, - }; - - let body_stream = if body.is_end_stream() { - BodyStream::Empty - } else { - let stream = http_body_util::BodyStream::new(body); - let mapped = futures_util::StreamExt::map(stream, |result| { - result - .map(|frame| frame.into_data().unwrap_or_default()) - .map_err(|e| std::io::Error::other(e.to_string())) - }); - BodyStream::Stream(Box::pin(mapped)) - }; - - InboundRequest::new( - method, - path, - query_string, - headers, - body_stream, - protocol, - TransportKind::Tcp, - client_addr, - server_addr, - Vec::new(), - parts.extensions, - ) -} - -/// Convert an `OutboundResponse` to a hyper response. -fn outbound_to_hyper(resp: OutboundResponse) -> Response { - let mut builder = Response::builder().status(resp.status); - if let Some(headers) = builder.headers_mut() { - *headers = resp.headers; - } - // Builder with valid status cannot fail. - builder.body(resp.body).unwrap_or_else(|_| unreachable!()) -} - -/// Construct an error response with a plain-text body. -fn error_response(status: hyper::StatusCode, body: &str) -> Response { - // Builder with valid status + header cannot fail. - Response::builder() - .status(status) - .header(hyper::header::CONTENT_TYPE, "text/plain") - .body(ResponseBody::Fixed(Bytes::copy_from_slice(body.as_bytes()))) - .unwrap_or_else(|_| unreachable!()) -} - -// ── PermitGuardedStream ────────────────────────────────────────────────── - -/// Streaming body that holds a semaphore permit for its lifetime. -/// -/// Ensures SSE connections count against the concurrency limit -/// until the stream ends or the client disconnects. -struct PermitGuardedStream { - inner: Pin> + Send>>, - _permit: OwnedSemaphorePermit, -} - -impl futures_core::Stream for PermitGuardedStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner.as_mut().poll_next(cx) - } -} - -// ── Serve loop ─────────────────────────────────────────────────────────── - -/// Accept connections and serve them using the given `ApxService`. -/// -/// Runs until the `shutdown` future completes, then stops accepting new -/// connections. Returns a `JoinSet` of in-flight connections so the caller -/// can await their completion (graceful drain). -pub async fn serve_tcp( - listener: TcpListener, - service: ApxService, - shutdown: impl Future + Send + 'static, -) -> Result, std::io::Error> { - tokio::pin!(shutdown); - let mut connections = tokio::task::JoinSet::new(); - - loop { - // Reap finished tasks to avoid unbounded growth. - while connections.try_join_next().is_some() {} - - tokio::select! { - result = listener.accept() => { - let (stream, client_addr) = result?; - let svc = service.clone().with_client_addr(client_addr); - connections.spawn(serve_connection(stream, svc)); - } - () = &mut shutdown => { - tracing::debug!(name: "apx.http.accept_shutdown", "shutdown signal received, stopping accept loop"); - break; - } - } - } - - Ok(connections) -} - -/// Serve a single connection using HTTP/1 auto-detection. -async fn serve_connection(stream: tokio::net::TcpStream, service: ApxService) { - if let Err(e) = stream.set_nodelay(true) { - tracing::debug!(name: "apx.http.tcp_nodelay_failed", error = %e, "failed to set TCP_NODELAY"); - } - let io = TokioIo::new(stream); - let result = http1::Builder::new() - .pipeline_flush(true) - .serve_connection(io, service) - .with_upgrades() - .await; - if let Err(e) = result { - tracing::debug!(name: "apx.http.connection_error", error = %e, "connection error"); - } -} - -// ── Tests ──────────────────────────────────────────────────────────────── - -#[cfg(test)] -#[expect( - clippy::unwrap_used, - clippy::panic, - reason = "test code uses unwrap/assert for clarity" -)] -mod tests { - use super::*; - use ::http::header::HeaderMap; - - /// Stub dispatch for testing — returns 200 with "ok" body. - #[derive(Debug)] - struct StubDispatch; - - impl Dispatch for StubDispatch { - fn dispatch( - &self, - _request: InboundRequest, - ) -> Pin + Send>> { - Box::pin(async { - OutboundResponse { - status: hyper::StatusCode::OK, - headers: HeaderMap::new(), - body: ResponseBody::Fixed(Bytes::from_static(b"ok")), - server_route: None, - } - }) - } - } - - fn stub_service() -> ApxService { - let dispatch: Arc = Arc::new(StubDispatch); - let config = ServiceConfig::default(); - let addr = SocketAddr::from(([127, 0, 0, 1], 8080)); - ApxService::new(dispatch, addr, &config) - } - - #[test] - fn probe_healthz_returns_200_with_json() { - let resp = probe_response("/healthz").unwrap(); - assert_eq!(resp.status(), hyper::StatusCode::OK); - assert_eq!( - resp.headers().get("content-type").unwrap(), - JSON_CONTENT_TYPE - ); - match resp.body() { - ResponseBody::Fixed(b) => assert_eq!(b.as_ref(), HEALTH_ALIVE), - ResponseBody::Stream(_) => panic!("expected Fixed body"), - } - } - - #[test] - fn probe_readyz_returns_200_with_json() { - let resp = probe_response("/readyz").unwrap(); - assert_eq!(resp.status(), hyper::StatusCode::OK); - assert_eq!( - resp.headers().get("content-type").unwrap(), - JSON_CONTENT_TYPE - ); - match resp.body() { - ResponseBody::Fixed(b) => assert_eq!(b.as_ref(), HEALTH_READY), - ResponseBody::Stream(_) => panic!("expected Fixed body"), - } - } - - #[test] - fn probe_unknown_path_returns_none() { - assert!(probe_response("/api/users").is_none()); - assert!(probe_response("/").is_none()); - assert!(probe_response("/health").is_none()); - } - - #[test] - fn service_config_default_values() { - let config = ServiceConfig::default(); - assert_eq!(config.timeout, Duration::from_secs(30)); - assert_eq!(config.max_concurrent, 256); - } - - #[test] - fn error_response_503_service_unavailable() { - let resp = error_response(hyper::StatusCode::SERVICE_UNAVAILABLE, "overloaded"); - assert_eq!(resp.status(), hyper::StatusCode::SERVICE_UNAVAILABLE); - assert_eq!(resp.headers().get("content-type").unwrap(), "text/plain"); - match resp.body() { - ResponseBody::Fixed(b) => assert_eq!(b.as_ref(), b"overloaded"), - ResponseBody::Stream(_) => panic!("expected Fixed body"), - } - } - - #[test] - fn error_response_408_request_timeout() { - let resp = error_response(hyper::StatusCode::REQUEST_TIMEOUT, "timeout"); - assert_eq!(resp.status(), hyper::StatusCode::REQUEST_TIMEOUT); - match resp.body() { - ResponseBody::Fixed(b) => assert_eq!(b.as_ref(), b"timeout"), - ResponseBody::Stream(_) => panic!("expected Fixed body"), - } - } - - #[test] - fn outbound_to_hyper_preserves_status_and_headers() { - let mut headers = HeaderMap::new(); - headers.insert("x-custom", "value".parse().unwrap()); - headers.insert("content-type", "application/json".parse().unwrap()); - - let outbound = OutboundResponse { - status: hyper::StatusCode::CREATED, - headers, - body: ResponseBody::Fixed(Bytes::from_static(b"{}")), - server_route: None, - }; - - let resp = outbound_to_hyper(outbound); - assert_eq!(resp.status(), hyper::StatusCode::CREATED); - assert_eq!(resp.headers().get("x-custom").unwrap(), "value"); - assert_eq!( - resp.headers().get("content-type").unwrap(), - "application/json" - ); - } - - #[test] - fn apx_service_debug_does_not_panic() { - let service = stub_service(); - let dbg = format!("{service:?}"); - assert!(dbg.contains("ApxService")); - assert!(dbg.contains("8080")); - } - - #[tokio::test] - async fn permit_guarded_stream_holds_permit() { - let sem = Arc::new(tokio::sync::Semaphore::new(1)); - let permit = Arc::clone(&sem).try_acquire_owned().unwrap(); - - // Wrap a stream with the permit. - let chunks = vec![Ok(Bytes::from("hello")), Ok(Bytes::from(" world"))]; - let inner_stream = tokio_stream::iter(chunks); - let mut stream = PermitGuardedStream { - inner: Box::pin(inner_stream), - _permit: permit, - }; - - // While the stream is alive, the semaphore should have 0 permits. - assert_eq!(sem.available_permits(), 0); - - // Consume the stream. - use futures_core::Stream; - let waker = futures_util::task::noop_waker(); - let mut cx = Context::from_waker(&waker); - let _ = Pin::new(&mut stream).poll_next(&mut cx); - assert_eq!(sem.available_permits(), 0); - - drop(stream); - assert_eq!(sem.available_permits(), 1); - } - - #[tokio::test] - async fn fixed_response_drops_permit_immediately() { - let sem = Arc::new(tokio::sync::Semaphore::new(1)); - let permit = Arc::clone(&sem).try_acquire_owned().unwrap(); - assert_eq!(sem.available_permits(), 0); - - // Simulate what handle() does for fixed responses. - let outbound = OutboundResponse { - status: hyper::StatusCode::OK, - headers: HeaderMap::new(), - body: ResponseBody::Fixed(Bytes::from_static(b"ok")), - server_route: None, - }; - - // Fixed body — permit is dropped immediately. - if matches!(outbound.body, ResponseBody::Stream(_)) { - unreachable!(); - } else { - drop(permit); - } - assert_eq!(sem.available_permits(), 1); - } -} diff --git a/crates/framework/src/protocol/mod.rs b/crates/framework/src/protocol/mod.rs index f381e7fc..b8e33eeb 100644 --- a/crates/framework/src/protocol/mod.rs +++ b/crates/framework/src/protocol/mod.rs @@ -1,4 +1,6 @@ -//! Application protocol handling (HTTP, WebSocket). +//! HTTP protocol: parsing, connection handling, and response writing. -pub mod http; -pub mod ws; +pub mod connection; +pub mod parser; +pub mod router; +pub mod writer; diff --git a/crates/framework/src/protocol/parser.rs b/crates/framework/src/protocol/parser.rs new file mode 100644 index 00000000..e5f84901 --- /dev/null +++ b/crates/framework/src/protocol/parser.rs @@ -0,0 +1,406 @@ +//! Sans-I/O HTTP/1.1 request parser. +//! +//! Pure parsing — no I/O, no Python, no async. Takes bytes in, +//! returns parsed request data out. Testable with `#[test]`. + +use bytes::{Bytes, BytesMut}; + +/// Maximum number of HTTP headers supported per request. +pub const MAX_HEADERS: usize = 96; + +/// HTTP method (small enum for the common methods). +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Method { + /// GET + Get, + /// POST + Post, + /// PUT + Put, + /// DELETE + Delete, + /// PATCH + Patch, + /// HEAD + Head, + /// OPTIONS + Options, + /// Any other method stored as a string. + Other(String), +} + +impl Method { + fn from_str(s: &str) -> Self { + match s { + "GET" => Self::Get, + "POST" => Self::Post, + "PUT" => Self::Put, + "DELETE" => Self::Delete, + "PATCH" => Self::Patch, + "HEAD" => Self::Head, + "OPTIONS" => Self::Options, + other => Self::Other(other.to_owned()), + } + } + + /// ASGI-compatible method string. + pub fn as_str(&self) -> &str { + match self { + Self::Get => "GET", + Self::Post => "POST", + Self::Put => "PUT", + Self::Delete => "DELETE", + Self::Patch => "PATCH", + Self::Head => "HEAD", + Self::Options => "OPTIONS", + Self::Other(s) => s, + } + } +} + +/// HTTP protocol version. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HttpVersion { + /// HTTP/1.0 + Http10, + /// HTTP/1.1 + Http11, +} + +/// Parsed request head (status line + headers). +#[derive(Debug, Clone)] +pub struct ParsedHead { + /// HTTP method. + pub method: Method, + /// Request path without query string. + pub path: String, + /// Raw query string (without leading `?`), empty if none. + pub query_string: Bytes, + /// Header pairs as raw bytes. + pub headers: Vec<(Bytes, Bytes)>, + /// HTTP version. + pub version: HttpVersion, + /// Content-Length value, if present. + pub content_length: Option, +} + +/// A fully parsed HTTP request (head + body). +#[derive(Debug, Clone)] +pub struct ParsedRequest { + /// Request head. + pub head: ParsedHead, + /// Request body (may be empty). + pub body: Bytes, +} + +/// Parser state machine. +#[derive(Debug)] +enum ParseState { + /// Waiting for a complete request line + headers. + AwaitingHead, + /// Head is parsed, accumulating body bytes. + AwaitingBody { + /// Parsed head. + head: ParsedHead, + /// Remaining body bytes to read. + remaining: usize, + }, +} + +/// Incremental HTTP/1.1 request parser. +/// +/// Accumulates bytes via `feed()` and returns zero or more complete +/// requests (supports HTTP pipelining). +#[derive(Debug)] +pub struct RequestParser { + buf: BytesMut, + state: ParseState, +} + +/// Parser error. +#[derive(Debug, thiserror::Error)] +pub enum ParseError { + /// httparse returned an error (malformed request). + #[error("invalid HTTP request: {0}")] + Invalid(String), +} + +impl RequestParser { + /// Create a new parser. + pub fn new() -> Self { + Self { + buf: BytesMut::with_capacity(8192), + state: ParseState::AwaitingHead, + } + } + + /// Reset the parser state, discarding any partial data. + pub fn reset(&mut self) { + self.buf.clear(); + self.state = ParseState::AwaitingHead; + } + + /// Feed bytes into the parser. + /// + /// Returns zero or more complete requests. Partial data is buffered + /// for the next `feed()` call. Supports HTTP pipelining: a single + /// `data_received` call can contain multiple requests. + /// + /// # Errors + /// + /// Returns `ParseError` if the data contains malformed HTTP. + pub fn feed(&mut self, data: &[u8]) -> Result, ParseError> { + self.buf.extend_from_slice(data); + let mut requests = Vec::new(); + + loop { + match &self.state { + ParseState::AwaitingHead => { + let Some((head, consumed)) = self.try_parse_head()? else { + break; + }; + self.advance_buffer(consumed); + let content_length = head.content_length.unwrap_or(0); + if content_length == 0 { + requests.push(ParsedRequest { + head, + body: Bytes::new(), + }); + } else { + self.state = ParseState::AwaitingBody { + head, + remaining: content_length, + }; + } + } + ParseState::AwaitingBody { remaining, .. } => { + let remaining = *remaining; + if self.buf.len() < remaining { + break; + } + let body = Bytes::copy_from_slice(&self.buf[..remaining]); + self.advance_buffer(remaining); + + let state = std::mem::replace(&mut self.state, ParseState::AwaitingHead); + if let ParseState::AwaitingBody { head, .. } = state { + requests.push(ParsedRequest { head, body }); + } + } + } + } + + Ok(requests) + } + + /// Try to parse a complete request head from the buffer. + /// + /// Returns `None` if more data is needed. + fn try_parse_head(&self) -> Result, ParseError> { + let mut headers_buf = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut req = httparse::Request::new(&mut headers_buf); + + match req.parse(&self.buf) { + Ok(httparse::Status::Complete(consumed)) => { + let head = build_head(&req)?; + Ok(Some((head, consumed))) + } + Ok(httparse::Status::Partial) => Ok(None), + Err(e) => Err(ParseError::Invalid(e.to_string())), + } + } + + fn advance_buffer(&mut self, n: usize) { + let _ = self.buf.split_to(n); + } +} + +/// Build a `ParsedHead` from a completed httparse request. +fn build_head(req: &httparse::Request<'_, '_>) -> Result { + let method_str = req.method.unwrap_or("GET"); + let method = Method::from_str(method_str); + + let raw_path = req.path.unwrap_or("/"); + let (path, query_string) = split_path_query(raw_path); + + let version = match req.version { + Some(0) => HttpVersion::Http10, + _ => HttpVersion::Http11, + }; + + let mut content_length = None; + let mut headers = Vec::with_capacity(req.headers.len()); + + for header in req.headers.iter() { + let name = Bytes::copy_from_slice(header.name.as_bytes()); + let value = Bytes::copy_from_slice(header.value); + if header.name.eq_ignore_ascii_case("content-length") + && let Ok(s) = std::str::from_utf8(header.value) + { + content_length = s.trim().parse().ok(); + } + headers.push((name, value)); + } + + Ok(ParsedHead { + method, + path: path.to_owned(), + query_string, + headers, + version, + content_length, + }) +} + +/// Split a raw path into path and query string. +fn split_path_query(raw: &str) -> (&str, Bytes) { + match raw.find('?') { + Some(pos) => { + let path = &raw[..pos]; + let qs = &raw[pos + 1..]; + (path, Bytes::copy_from_slice(qs.as_bytes())) + } + None => (raw, Bytes::new()), + } +} + +#[cfg(test)] +#[expect(clippy::expect_used, reason = "test code uses expect for clarity")] +mod tests { + use super::*; + + fn simple_get() -> Vec { + b"GET /hello?name=world HTTP/1.1\r\nHost: localhost\r\n\r\n".to_vec() + } + + fn post_with_body() -> Vec { + b"POST /data HTTP/1.1\r\nHost: localhost\r\nContent-Length: 13\r\n\r\nHello, world!" + .to_vec() + } + + #[test] + fn test_parse_simple_get() { + let mut parser = RequestParser::new(); + let requests = parser.feed(&simple_get()).expect("parse failed"); + assert_eq!(requests.len(), 1); + let req = &requests[0]; + assert_eq!(req.head.method, Method::Get); + assert_eq!(req.head.path, "/hello"); + assert_eq!(req.head.query_string, "name=world"); + assert_eq!(req.head.version, HttpVersion::Http11); + assert!(req.body.is_empty()); + assert_eq!(req.head.headers.len(), 1); + } + + #[test] + fn test_parse_post_with_body() { + let mut parser = RequestParser::new(); + let requests = parser.feed(&post_with_body()).expect("parse failed"); + assert_eq!(requests.len(), 1); + let req = &requests[0]; + assert_eq!(req.head.method, Method::Post); + assert_eq!(req.head.path, "/data"); + assert_eq!(req.body, "Hello, world!"); + assert_eq!(req.head.content_length, Some(13)); + } + + #[test] + fn test_partial_head() { + let mut parser = RequestParser::new(); + let requests = parser + .feed(b"GET /hello HTTP/1.1\r\nHost: loc") + .expect("parse failed"); + assert!(requests.is_empty()); + + let requests = parser.feed(b"alhost\r\n\r\n").expect("parse failed"); + assert_eq!(requests.len(), 1); + assert_eq!(requests[0].head.path, "/hello"); + } + + #[test] + fn test_partial_body() { + let mut parser = RequestParser::new(); + let requests = parser + .feed(b"POST /data HTTP/1.1\r\nContent-Length: 12\r\n\r\nHello") + .expect("parse failed"); + assert!(requests.is_empty()); + + let requests = parser.feed(b", wor").expect("parse failed"); + assert!(requests.is_empty()); + + let requests = parser.feed(b"ld").expect("parse failed"); + assert_eq!(requests.len(), 1); + assert_eq!(requests[0].body, "Hello, world"); + } + + #[test] + fn test_pipelining() { + let mut parser = RequestParser::new(); + let mut data = Vec::new(); + data.extend_from_slice(b"GET /a HTTP/1.1\r\nHost: h\r\n\r\n"); + data.extend_from_slice(b"GET /b HTTP/1.1\r\nHost: h\r\n\r\n"); + + let requests = parser.feed(&data).expect("parse failed"); + assert_eq!(requests.len(), 2); + assert_eq!(requests[0].head.path, "/a"); + assert_eq!(requests[1].head.path, "/b"); + } + + #[test] + fn test_no_query_string() { + let mut parser = RequestParser::new(); + let requests = parser + .feed(b"GET /path HTTP/1.1\r\nHost: h\r\n\r\n") + .expect("parse failed"); + assert_eq!(requests.len(), 1); + assert!(requests[0].head.query_string.is_empty()); + } + + #[test] + fn test_http10() { + let mut parser = RequestParser::new(); + let requests = parser + .feed(b"GET / HTTP/1.0\r\nHost: h\r\n\r\n") + .expect("parse failed"); + assert_eq!(requests.len(), 1); + assert_eq!(requests[0].head.version, HttpVersion::Http10); + } + + #[test] + fn test_multiple_headers() { + let mut parser = RequestParser::new(); + let requests = parser + .feed(b"GET / HTTP/1.1\r\nHost: h\r\nAccept: text/html\r\nX-Custom: val\r\n\r\n") + .expect("parse failed"); + assert_eq!(requests.len(), 1); + assert_eq!(requests[0].head.headers.len(), 3); + } + + #[test] + fn test_malformed_request() { + let mut parser = RequestParser::new(); + let result = parser.feed(b"INVALID\r\n\r\n"); + assert!(result.is_err() || result.expect("unexpected ok").is_empty()); + } + + #[test] + fn test_reset() { + let mut parser = RequestParser::new(); + parser.feed(b"GET /partial HTTP/1.1\r\n").expect("ok"); + parser.reset(); + let requests = parser + .feed(b"GET /fresh HTTP/1.1\r\nHost: h\r\n\r\n") + .expect("parse failed"); + assert_eq!(requests.len(), 1); + assert_eq!(requests[0].head.path, "/fresh"); + } + + #[test] + fn test_zero_content_length() { + let mut parser = RequestParser::new(); + let requests = parser + .feed(b"POST /data HTTP/1.1\r\nContent-Length: 0\r\n\r\n") + .expect("parse failed"); + assert_eq!(requests.len(), 1); + assert!(requests[0].body.is_empty()); + } +} diff --git a/crates/framework/src/protocol/router.rs b/crates/framework/src/protocol/router.rs new file mode 100644 index 00000000..b6c047de --- /dev/null +++ b/crates/framework/src/protocol/router.rs @@ -0,0 +1,155 @@ +//! Route matching powered by [`matchit`]. +//! +//! Pure data transformation — no I/O, no async, no Python callbacks. +//! Takes a path string + method, returns a match result. + +use pyo3::prelude::*; +use pyo3::types::PyDict; + +/// Fast path-based router backed by [`matchit::Router`]. +/// +/// Route patterns use matchit syntax (`{param}`, `{*catch_all}`). +/// Python-side code converts framework-specific syntax before insertion. +#[pyclass(module = "apx._core")] +#[derive(Debug)] +pub struct RustRouter { + inner: matchit::Router, +} + +/// Error inserting a route pattern. +#[derive(Debug, thiserror::Error)] +#[error("route insert error: {0}")] +struct InsertError(#[from] matchit::InsertError); + +#[pymethods] +impl RustRouter { + /// Create an empty router. + #[new] + fn new() -> Self { + Self { + inner: matchit::Router::new(), + } + } + + /// Insert a route pattern with an opaque integer identifier. + /// + /// # Errors + /// + /// Returns `ValueError` if the pattern is invalid or conflicts with + /// an existing route. + fn insert(&mut self, pattern: &str, route_id: u32) -> PyResult<()> { + self.inner + .insert(pattern, route_id) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(InsertError(e).to_string())) + } + + /// Match a request path against registered routes. + /// + /// Returns `(route_id, params_dict)` on match, or `None` if no route + /// matches the path. + fn match_route<'py>( + &self, + py: Python<'py>, + path: &str, + ) -> PyResult)>> { + let Ok(matched) = self.inner.at(path) else { + return Ok(None); + }; + + let route_id = *matched.value; + let params = PyDict::new(py); + for (key, value) in matched.params.iter() { + params.set_item(key, value)?; + } + + Ok(Some((route_id, params))) + } +} + +#[cfg(test)] +#[expect(clippy::unwrap_used, reason = "test code uses unwrap for clarity")] +mod tests { + use super::*; + + #[test] + fn test_basic_match() { + let mut router = RustRouter::new(); + router.insert("/users/{id}", 1).ok(); + router.insert("/health", 2).ok(); + + crate::with_py(|py| { + let result = router.match_route(py, "/users/42").ok().flatten(); + assert!(result.is_some()); + let (route_id, params) = result.as_ref().map(|(id, p)| (*id, p)).unwrap(); + assert_eq!(route_id, 1); + let id_val: String = params.get_item("id").unwrap().unwrap().extract().unwrap(); + assert_eq!(id_val, "42"); + + let result = router.match_route(py, "/health").ok().flatten(); + assert!(result.is_some()); + let (route_id, params) = result.as_ref().map(|(id, p)| (*id, p)).unwrap(); + assert_eq!(route_id, 2); + assert!(params.is_empty()); + }); + } + + #[test] + fn test_no_match() { + let router = RustRouter::new(); + + crate::with_py(|py| { + let result = router.match_route(py, "/nonexistent").ok().flatten(); + assert!(result.is_none()); + }); + } + + #[test] + fn test_catch_all() { + let mut router = RustRouter::new(); + router.insert("/static/{*filepath}", 1).ok(); + + crate::with_py(|py| { + let result = router + .match_route(py, "/static/css/style.css") + .ok() + .flatten(); + assert!(result.is_some()); + let (route_id, params) = result.as_ref().map(|(id, p)| (*id, p)).unwrap(); + assert_eq!(route_id, 1); + let fp: String = params + .get_item("filepath") + .unwrap() + .unwrap() + .extract() + .unwrap(); + assert_eq!(fp, "css/style.css"); + }); + } + + #[test] + fn test_multiple_params() { + let mut router = RustRouter::new(); + router.insert("/orgs/{org}/repos/{repo}", 1).ok(); + + crate::with_py(|py| { + let result = router + .match_route(py, "/orgs/acme/repos/widgets") + .ok() + .flatten(); + assert!(result.is_some()); + let (_, params) = result.as_ref().map(|(id, p)| (*id, p)).unwrap(); + let org: String = params.get_item("org").unwrap().unwrap().extract().unwrap(); + let repo: String = params.get_item("repo").unwrap().unwrap().extract().unwrap(); + assert_eq!(org, "acme"); + assert_eq!(repo, "widgets"); + }); + } + + #[test] + fn test_insert_conflict() { + let mut router = RustRouter::new(); + router.insert("/users/{id}", 1).ok(); + let result = router.insert("/users/{name}", 2); + assert!(result.is_err()); + } +} diff --git a/crates/framework/src/protocol/writer.rs b/crates/framework/src/protocol/writer.rs new file mode 100644 index 00000000..b7a93f89 --- /dev/null +++ b/crates/framework/src/protocol/writer.rs @@ -0,0 +1,537 @@ +//! HTTP/1.1 response writer backed by an asyncio transport. +//! +//! Builds HTTP response bytes and writes to the asyncio transport. +//! Sans-I/O core (`build_status_and_headers`, `parse_send_event`) is +//! testable with `#[test]`. + +use std::time::Instant; + +use bytes::{BufMut, Bytes, BytesMut}; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyDict, PyList}; + +use crate::asgi::scope::ResolvedAwaitable; +use crate::telemetry::dispatch_metrics; + +/// ASGI send event parsed from a Python dict. +#[derive(Debug)] +pub enum SendEvent { + /// `http.response.start` — status code + headers. + Start { + /// HTTP status code. + status: u16, + /// Response headers as raw byte pairs. + headers: Vec<(Bytes, Bytes)>, + }, + /// `http.response.body` — body chunk. + Body { + /// Body bytes. + data: Py, + /// Whether more body chunks will follow. + more_body: bool, + }, +} + +/// Writer state machine. +#[derive(Debug)] +enum WriteState { + /// Waiting for `http.response.start`. + AwaitingStart, + /// Got start, waiting for first body chunk. + HeadersPending { + /// HTTP status code. + status: u16, + /// Response headers. + headers: Vec<(Bytes, Bytes)>, + }, + /// Streaming body chunks (with or without chunked encoding). + Streaming { chunked: bool }, + /// Response complete. + Done, +} + +/// HTTP/1.1 response writer backed by an asyncio transport. +/// +/// Implements the ASGI `send` callable. Builds HTTP response bytes +/// in Rust and writes them to `transport.write()`. +#[pyclass(module = "apx._core")] +pub struct RustResponseWriter { + transport: Py, + state: WriteState, + resolved: Py, + on_complete: Option>, + /// HTTP status code from `http.response.start`, for metrics. + response_status: u16, +} + +crate::opaque_debug!(RustResponseWriter); + +impl RustResponseWriter { + /// Create a new response writer. + pub fn new( + py: Python<'_>, + transport: Py, + on_complete: Option>, + ) -> PyResult> { + let resolved = Py::new(py, ResolvedAwaitable)?; + Py::new( + py, + Self { + transport, + state: WriteState::AwaitingStart, + resolved, + on_complete, + response_status: 0, + }, + ) + } +} + +#[pymethods] +impl RustResponseWriter { + /// ASGI send callable. + fn __call__(&mut self, py: Python<'_>, event: &Bound<'_, PyDict>) -> PyResult> { + let t0 = Instant::now(); + let parsed = parse_send_event(py, event)?; + dispatch_metrics::record_send_parse(t0.elapsed().as_micros() as f64); + + match parsed { + SendEvent::Start { status, headers } => { + self.response_status = status; + self.state = WriteState::HeadersPending { status, headers }; + } + SendEvent::Body { data, more_body } => { + self.write_body(py, &data, more_body)?; + } + } + Ok(self.resolved.clone_ref(py).into_any()) + } + + /// Write a 500 error response directly (bypasses ASGI state machine). + fn send_error(&mut self, py: Python<'_>, traceback: &str) -> PyResult<()> { + let body = traceback.as_bytes(); + let headers = vec![( + Bytes::from_static(b"content-type"), + Bytes::from_static(b"text/plain; charset=utf-8"), + )]; + let response = build_full_response(500, &headers, body); + let py_bytes = PyBytes::new(py, &response); + self.transport + .call_method1(py, pyo3::intern!(py, "write"), (py_bytes,))?; + self.state = WriteState::Done; + self.response_status = 500; + self.signal_complete(py)?; + Ok(()) + } +} + +impl RustResponseWriter { + fn write_body(&mut self, py: Python<'_>, data: &Py, more_body: bool) -> PyResult<()> { + match std::mem::replace(&mut self.state, WriteState::Done) { + WriteState::HeadersPending { status, headers } => { + self.write_first_body(py, status, &headers, data, more_body)?; + } + WriteState::Streaming { chunked } => { + self.write_continuation(py, data, more_body, chunked)?; + } + _ => { + return Err(pyo3::exceptions::PyRuntimeError::new_err( + "send: body before start or after done", + )); + } + } + Ok(()) + } + + fn write_first_body( + &mut self, + py: Python<'_>, + status: u16, + headers: &[(Bytes, Bytes)], + data: &Py, + more_body: bool, + ) -> PyResult<()> { + let body_bytes = data.bind(py).as_bytes(); + let has_content_length = headers + .iter() + .any(|(name, _)| name.eq_ignore_ascii_case(b"content-length")); + + let chunked = more_body && !has_content_length; + + let t_build = Instant::now(); + let hdr_bytes = if chunked { + build_status_and_headers_chunked(status, headers) + } else if !more_body && !has_content_length { + build_status_and_headers_with_length(status, headers, body_bytes.len()) + } else { + build_status_and_headers(status, headers) + }; + dispatch_metrics::record_response_build(t_build.elapsed().as_micros() as f64); + + let t_write = Instant::now(); + if chunked { + let mut buf = BytesMut::with_capacity(hdr_bytes.len() + body_bytes.len() + 32); + buf.put_slice(&hdr_bytes); + write_chunk(&mut buf, body_bytes); + let py_bytes = PyBytes::new(py, &buf); + self.transport + .call_method1(py, pyo3::intern!(py, "write"), (py_bytes,))?; + } else { + let hdr_py = PyBytes::new(py, &hdr_bytes); + self.transport + .call_method1(py, pyo3::intern!(py, "write"), (hdr_py,))?; + self.transport + .call_method1(py, pyo3::intern!(py, "write"), (data.bind(py),))?; + } + dispatch_metrics::record_response_write(t_write.elapsed().as_micros() as f64); + + if more_body { + self.state = WriteState::Streaming { chunked }; + } else { + self.signal_complete(py)?; + } + Ok(()) + } + + fn write_continuation( + &mut self, + py: Python<'_>, + data: &Py, + more_body: bool, + chunked: bool, + ) -> PyResult<()> { + let t_write = Instant::now(); + let body_bytes = data.bind(py).as_bytes(); + + if chunked { + let terminator_len = if more_body { 0 } else { LAST_CHUNK.len() }; + let mut buf = BytesMut::with_capacity(body_bytes.len() + 32 + terminator_len); + write_chunk(&mut buf, body_bytes); + if !more_body { + buf.put_slice(LAST_CHUNK); + } + let py_bytes = PyBytes::new(py, &buf); + self.transport + .call_method1(py, pyo3::intern!(py, "write"), (py_bytes,))?; + } else { + self.transport + .call_method1(py, pyo3::intern!(py, "write"), (data.bind(py),))?; + } + dispatch_metrics::record_response_write(t_write.elapsed().as_micros() as f64); + + if more_body { + self.state = WriteState::Streaming { chunked }; + } else { + self.signal_complete(py)?; + } + Ok(()) + } + + fn signal_complete(&self, py: Python<'_>) -> PyResult<()> { + if let Some(cb) = &self.on_complete { + cb.call1(py, (self.response_status,))?; + } + Ok(()) + } +} + +/// Parse an ASGI send event dict into a [`SendEvent`]. +pub fn parse_send_event(py: Python<'_>, event: &Bound<'_, PyDict>) -> PyResult { + let type_obj = event + .get_item(pyo3::intern!(py, "type"))? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("missing 'type' key"))?; + let type_val: String = type_obj.extract()?; + + match type_val.as_str() { + "http.response.start" => { + let status: u16 = event + .get_item(pyo3::intern!(py, "status"))? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("missing 'status' key"))? + .extract()?; + let headers = extract_response_headers(py, event)?; + Ok(SendEvent::Start { status, headers }) + } + "http.response.body" => { + let body_obj = event + .get_item(pyo3::intern!(py, "body"))? + .unwrap_or_else(|| PyBytes::new(py, b"").into_any()); + let data: Py = body_obj.extract()?; + let more_body: bool = event + .get_item(pyo3::intern!(py, "more_body"))? + .map(|v| v.extract()) + .transpose()? + .unwrap_or(false); + Ok(SendEvent::Body { data, more_body }) + } + other => Err(pyo3::exceptions::PyValueError::new_err(format!( + "unknown send event type: {other}" + ))), + } +} + +/// Extract response headers from the ASGI send event. +fn extract_response_headers( + py: Python<'_>, + event: &Bound<'_, PyDict>, +) -> PyResult> { + let headers_obj = event.get_item(pyo3::intern!(py, "headers"))?; + let Some(headers_list) = headers_obj else { + return Ok(Vec::new()); + }; + let list = headers_list.cast_into::()?; + let mut result = Vec::with_capacity(list.len()); + for item in list.iter() { + let tuple = item.cast_into::()?; + let name_obj = tuple.get_item(0)?.cast_into::()?; + let value_obj = tuple.get_item(1)?.cast_into::()?; + result.push(( + Bytes::copy_from_slice(name_obj.as_bytes()), + Bytes::copy_from_slice(value_obj.as_bytes()), + )); + } + Ok(result) +} + +/// HTTP/1.1 chunked transfer encoding terminator. +const LAST_CHUNK: &[u8] = b"0\r\n\r\n"; + +/// Write a single HTTP chunk frame: `{hex_len}\r\n{data}\r\n`. +fn write_chunk(buf: &mut BytesMut, data: &[u8]) { + if data.is_empty() { + return; + } + buf.put_slice(format!("{:x}\r\n", data.len()).as_bytes()); + buf.put_slice(data); + buf.put_slice(b"\r\n"); +} + +/// Build the HTTP/1.1 status line + headers as bytes. +pub fn build_status_and_headers(status: u16, headers: &[(Bytes, Bytes)]) -> Bytes { + let reason = reason_phrase(status); + let mut buf = BytesMut::with_capacity(256); + buf.put_slice(b"HTTP/1.1 "); + buf.put_slice(status.to_string().as_bytes()); + buf.put_slice(b" "); + buf.put_slice(reason.as_bytes()); + buf.put_slice(b"\r\n"); + for (name, value) in headers { + buf.put_slice(name); + buf.put_slice(b": "); + buf.put_slice(value); + buf.put_slice(b"\r\n"); + } + buf.put_slice(b"\r\n"); + buf.freeze() +} + +/// Build status line + headers with `Transfer-Encoding: chunked`. +fn build_status_and_headers_chunked(status: u16, headers: &[(Bytes, Bytes)]) -> Bytes { + let reason = reason_phrase(status); + let mut buf = BytesMut::with_capacity(256); + buf.put_slice(b"HTTP/1.1 "); + buf.put_slice(status.to_string().as_bytes()); + buf.put_slice(b" "); + buf.put_slice(reason.as_bytes()); + buf.put_slice(b"\r\n"); + for (name, value) in headers { + buf.put_slice(name); + buf.put_slice(b": "); + buf.put_slice(value); + buf.put_slice(b"\r\n"); + } + buf.put_slice(b"transfer-encoding: chunked\r\n"); + buf.put_slice(b"\r\n"); + buf.freeze() +} + +/// Build status line + headers with an auto-added `Content-Length`. +fn build_status_and_headers_with_length( + status: u16, + headers: &[(Bytes, Bytes)], + body_len: usize, +) -> Bytes { + let reason = reason_phrase(status); + let mut buf = BytesMut::with_capacity(256); + buf.put_slice(b"HTTP/1.1 "); + buf.put_slice(status.to_string().as_bytes()); + buf.put_slice(b" "); + buf.put_slice(reason.as_bytes()); + buf.put_slice(b"\r\n"); + for (name, value) in headers { + buf.put_slice(name); + buf.put_slice(b": "); + buf.put_slice(value); + buf.put_slice(b"\r\n"); + } + buf.put_slice(b"content-length: "); + buf.put_slice(body_len.to_string().as_bytes()); + buf.put_slice(b"\r\n\r\n"); + buf.freeze() +} + +/// Build a complete HTTP/1.1 response (status + headers + body). +fn build_full_response(status: u16, headers: &[(Bytes, Bytes)], body: &[u8]) -> Bytes { + let reason = reason_phrase(status); + let mut buf = BytesMut::with_capacity(256 + body.len()); + buf.put_slice(b"HTTP/1.1 "); + buf.put_slice(status.to_string().as_bytes()); + buf.put_slice(b" "); + buf.put_slice(reason.as_bytes()); + buf.put_slice(b"\r\n"); + + let mut has_content_length = false; + for (name, value) in headers { + buf.put_slice(name); + buf.put_slice(b": "); + buf.put_slice(value); + buf.put_slice(b"\r\n"); + if name.eq_ignore_ascii_case(b"content-length") { + has_content_length = true; + } + } + if !has_content_length { + buf.put_slice(b"content-length: "); + buf.put_slice(body.len().to_string().as_bytes()); + buf.put_slice(b"\r\n"); + } + buf.put_slice(b"\r\n"); + buf.put_slice(body); + buf.freeze() +} + +/// Standard HTTP reason phrase for common status codes. +fn reason_phrase(status: u16) -> &'static str { + match status { + 200 => "OK", + 201 => "Created", + 204 => "No Content", + 301 => "Moved Permanently", + 302 => "Found", + 304 => "Not Modified", + 307 => "Temporary Redirect", + 308 => "Permanent Redirect", + 400 => "Bad Request", + 401 => "Unauthorized", + 403 => "Forbidden", + 404 => "Not Found", + 405 => "Method Not Allowed", + 409 => "Conflict", + 422 => "Unprocessable Entity", + 429 => "Too Many Requests", + 500 => "Internal Server Error", + 502 => "Bad Gateway", + 503 => "Service Unavailable", + 504 => "Gateway Timeout", + _ => "Unknown", + } +} + +#[cfg(test)] +#[expect(clippy::expect_used, reason = "test code uses expect for clarity")] +mod tests { + use super::*; + + #[test] + fn test_build_status_and_headers_200() { + let headers = vec![( + Bytes::from_static(b"content-type"), + Bytes::from_static(b"text/html"), + )]; + let result = build_status_and_headers(200, &headers); + let s = std::str::from_utf8(&result).expect("valid utf8"); + assert!(s.starts_with("HTTP/1.1 200 OK\r\n")); + assert!(s.contains("content-type: text/html\r\n")); + assert!(s.ends_with("\r\n\r\n")); + } + + #[test] + fn test_build_status_and_headers_404() { + let result = build_status_and_headers(404, &[]); + let s = std::str::from_utf8(&result).expect("valid utf8"); + assert!(s.starts_with("HTTP/1.1 404 Not Found\r\n")); + } + + #[test] + fn test_build_full_response() { + let headers = vec![( + Bytes::from_static(b"content-type"), + Bytes::from_static(b"text/plain"), + )]; + let result = build_full_response(200, &headers, b"hello"); + let s = std::str::from_utf8(&result).expect("valid utf8"); + assert!(s.contains("content-length: 5\r\n")); + assert!(s.ends_with("hello")); + } + + #[test] + fn test_build_full_response_with_content_length() { + let headers = vec![( + Bytes::from_static(b"Content-Length"), + Bytes::from_static(b"5"), + )]; + let result = build_full_response(200, &headers, b"hello"); + let s = std::str::from_utf8(&result).expect("valid utf8"); + let lower = s.to_ascii_lowercase(); + let count = lower.matches("content-length").count(); + assert_eq!(count, 1, "should not add duplicate content-length"); + } + + #[test] + fn test_multiple_headers() { + let headers = vec![ + (Bytes::from_static(b"x-a"), Bytes::from_static(b"1")), + (Bytes::from_static(b"x-b"), Bytes::from_static(b"2")), + ]; + let result = build_status_and_headers(200, &headers); + let s = std::str::from_utf8(&result).expect("valid utf8"); + assert!(s.contains("x-a: 1\r\n")); + assert!(s.contains("x-b: 2\r\n")); + } + + #[test] + fn test_reason_phrase_unknown() { + assert_eq!(reason_phrase(999), "Unknown"); + } + + #[test] + fn test_build_chunked_headers() { + let headers = vec![( + Bytes::from_static(b"content-type"), + Bytes::from_static(b"text/plain"), + )]; + let result = build_status_and_headers_chunked(200, &headers); + let s = std::str::from_utf8(&result).expect("valid utf8"); + assert!(s.contains("transfer-encoding: chunked\r\n")); + assert!(s.contains("content-type: text/plain\r\n")); + } + + #[test] + fn test_build_headers_with_length() { + let headers = vec![( + Bytes::from_static(b"content-type"), + Bytes::from_static(b"text/plain"), + )]; + let result = build_status_and_headers_with_length(200, &headers, 42); + let s = std::str::from_utf8(&result).expect("valid utf8"); + assert!(s.contains("content-length: 42\r\n")); + } + + #[test] + fn test_write_chunk_framing() { + let mut buf = BytesMut::new(); + write_chunk(&mut buf, b"hello"); + assert_eq!(&buf[..], b"5\r\nhello\r\n"); + } + + #[test] + fn test_write_chunk_empty_is_noop() { + let mut buf = BytesMut::new(); + write_chunk(&mut buf, b""); + assert!(buf.is_empty()); + } + + #[test] + fn test_last_chunk_constant() { + assert_eq!(LAST_CHUNK, b"0\r\n\r\n"); + } +} diff --git a/crates/framework/src/protocol/ws/mod.rs b/crates/framework/src/protocol/ws/mod.rs deleted file mode 100644 index eb74ea86..00000000 --- a/crates/framework/src/protocol/ws/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -//! WebSocket protocol: upgrade, session, frame bridging. - -pub mod session; diff --git a/crates/framework/src/protocol/ws/session.rs b/crates/framework/src/protocol/ws/session.rs deleted file mode 100644 index f8846059..00000000 --- a/crates/framework/src/protocol/ws/session.rs +++ /dev/null @@ -1,524 +0,0 @@ -//! WebSocket bridge — upgrades HTTP connections and bridges frames between -//! tungstenite and ASGI receive/send channels. -//! -//! The flow: -//! 1. [`is_websocket_upgrade`] detects upgrade requests at the service layer. -//! 2. [`handle_upgrade`] performs the hyper-tungstenite handshake and spawns -//! a background session task. -//! 3. The session bridges tungstenite frames ↔ ASGI `websocket.*` events -//! through mpsc channels, with the ASGI app driven by the Rust scheduler. - -use crate::asgi::scope::{ - AsgiEvent, AsgiSend, AsgiWsReceive, ScopeInterns, WsIncomingEvent, build_ws_scope, -}; -use crate::protocol::http::error::AppError; -use crate::supervision::worker_context::WorkerContext; -use crate::transport::types::{ - BodyStream, InboundRequest, ProtocolVersion, ResponseBody, TransportKind, -}; -use bytes::Bytes; -use futures_util::{Sink, SinkExt, Stream, StreamExt}; -use hyper::body::Incoming; -use hyper::{Request, Response}; -use hyper_tungstenite::tungstenite; -use pyo3::prelude::*; -use std::net::SocketAddr; -use std::sync::Arc; -use tokio::sync::mpsc; -use tungstenite::Message; -use tungstenite::protocol::frame::CloseFrame; -use tungstenite::protocol::frame::coding::CloseCode; - -/// Buffer size for both incoming and outgoing WebSocket channels. -const WS_CHANNEL_CAPACITY: usize = 32; - -/// Default WebSocket close code (normal closure). -const WS_CLOSE_NORMAL: u16 = 1000; - -// ── Upgrade detection ─────────────────────────────────────────────────── - -/// Check if a request is a WebSocket upgrade request. -pub fn is_websocket_upgrade(req: &Request) -> bool { - hyper_tungstenite::is_upgrade_request(req) -} - -// ── Upgrade handler ───────────────────────────────────────────────────── - -/// Perform the WebSocket upgrade handshake and spawn the session task. -/// -/// Returns the 101 Switching Protocols response immediately. The actual -/// WebSocket session runs in a spawned tokio task. -/// -/// # Errors -/// -/// Returns an error if the upgrade handshake fails. -pub fn handle_upgrade( - mut request: Request, - server_addr: SocketAddr, - client_addr: Option, - app: Arc>, - interns: Arc, - ctx: Arc, -) -> Result, AppError> { - // Extract request metadata before the upgrade consumes the request. - let inbound = extract_request_info(&request, server_addr, client_addr); - - // Perform the upgrade handshake. - let (response, ws_future) = hyper_tungstenite::upgrade(&mut request, None) - .map_err(|e| AppError::Internal(format!("websocket upgrade failed: {e}")))?; - - // Convert the Full body to our ResponseBody (101 body is empty). - let (parts, _full_body) = response.into_parts(); - let response = Response::from_parts(parts, ResponseBody::Fixed(Bytes::new())); - - // Spawn the WebSocket session as a background task. - tokio::spawn(ws_session(ws_future, inbound, app, interns, ctx)); - - Ok(response) -} - -// ── Request extraction ────────────────────────────────────────────────── - -/// Extract request metadata from a borrowed hyper request. -/// -/// Builds an [`InboundRequest`] with [`BodyStream::Empty`] since the WS -/// upgrade path doesn't use the HTTP body. Generic over the body type -/// because only request parts (URI, headers, method, version) are accessed. -fn extract_request_info( - req: &Request, - server_addr: SocketAddr, - client_addr: Option, -) -> InboundRequest { - let path = req.uri().path().to_owned(); - let query_string = req - .uri() - .query() - .map(|q| Bytes::copy_from_slice(q.as_bytes())) - .unwrap_or_default(); - let mut headers = req.headers().clone(); - crate::protocol::http::service::ensure_request_id(&mut headers); - let method = req.method().clone(); - - let protocol = match req.version() { - hyper::Version::HTTP_10 => ProtocolVersion::Http10, - hyper::Version::HTTP_2 => ProtocolVersion::H2, - _ => ProtocolVersion::Http11, - }; - - InboundRequest::new( - method, - path, - query_string, - headers, - BodyStream::Empty, - protocol, - TransportKind::Tcp, - client_addr, - server_addr, - Vec::new(), - http::Extensions::default(), - ) -} - -// ── WebSocket session ─────────────────────────────────────────────────── - -/// Run a WebSocket session: bridge tungstenite frames ↔ ASGI events. -/// -/// This is a long-lived tokio task that runs for the lifetime of the -/// WebSocket connection. -async fn ws_session( - ws_future: hyper_tungstenite::HyperWebsocket, - request: InboundRequest, - app: Arc>, - interns: Arc, - ctx: Arc, -) { - // Await the upgrade completion to get the WebSocket stream. - let ws_stream = match ws_future.await { - Ok(stream) => stream, - Err(e) => { - tracing::error!(name: "apx.ws.upgrade_completion_failed", error = %e, "websocket upgrade completion failed"); - return; - } - }; - - let (sink, stream) = ws_stream.split(); - - // Create channels for ASGI ↔ tungstenite communication. - let (incoming_tx, incoming_rx) = mpsc::channel::(WS_CHANNEL_CAPACITY); - let (outgoing_tx, outgoing_rx) = mpsc::channel::(WS_CHANNEL_CAPACITY); - - // Send initial connect event per ASGI spec. - if incoming_tx.send(WsIncomingEvent::Connect).await.is_err() { - return; - } - - // Spawn forwarding tasks. - let recv_handle = tokio::spawn(forward_incoming(stream, incoming_tx)); - let send_handle = tokio::spawn(forward_outgoing(outgoing_rx, sink)); - - // Build scope, call app, submit to asyncio. - let schedule_result = Python::attach(|py| -> Result<(), AppError> { - let scope = build_ws_scope(py, &request, &interns) - .map_err(|e| AppError::Internal(format!("ws scope build: {e}")))?; - let receive = Py::new(py, AsgiWsReceive::new(incoming_rx)) - .map_err(|e| AppError::Internal(format!("wrap ws receive: {e}")))?; - let send = Py::new(py, AsgiSend::new(outgoing_tx)) - .map_err(|e| AppError::Internal(format!("wrap ws send: {e}")))?; - ctx.call_soon_threadsafe - .call1(py, (&ctx.launch_fn, &*app, &scope, &receive, &send)) - .map_err(|e| AppError::Internal(format!("submit ws to asyncio: {e}")))?; - Ok(()) - }); - if let Err(e) = schedule_result { - tracing::error!(name: "apx.ws.schedule_coroutine_failed", error = %e, "failed to schedule websocket coroutine"); - } - - // Clean up forwarding tasks. - recv_handle.abort(); - send_handle.abort(); -} - -// ── Frame forwarding ──────────────────────────────────────────────────── - -/// Forward incoming tungstenite frames to the ASGI receive channel. -/// -/// Generic over the stream type so it can be tested with mock streams. -async fn forward_incoming(mut stream: S, tx: mpsc::Sender) -where - S: Stream> + Unpin, -{ - loop { - match stream.next().await { - Some(Ok(msg)) => { - let event = match msg { - Message::Text(t) => WsIncomingEvent::Receive { - text: Some(t.to_string()), - bytes: None, - }, - Message::Binary(b) => WsIncomingEvent::Receive { - text: None, - bytes: Some(b), - }, - Message::Close(frame) => { - let code = frame - .as_ref() - .map_or(WS_CLOSE_NORMAL, |f| u16::from(f.code)); - let event = WsIncomingEvent::Disconnect { code }; - let _ = tx.send(event).await; - break; - } - Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue, - }; - if tx.send(event).await.is_err() { - break; - } - } - Some(Err(e)) => { - tracing::debug!(name: "apx.ws.read_error", error = %e, "websocket read error"); - let _ = tx - .send(WsIncomingEvent::Disconnect { - code: WS_CLOSE_NORMAL, - }) - .await; - break; - } - None => { - // Stream ended — send disconnect. - let _ = tx - .send(WsIncomingEvent::Disconnect { - code: WS_CLOSE_NORMAL, - }) - .await; - break; - } - } - } -} - -/// Forward outgoing ASGI events to the tungstenite sink. -/// -/// Generic over the sink type so it can be tested with mock sinks. -async fn forward_outgoing(mut rx: mpsc::Receiver, mut sink: K) -where - K: Sink + Unpin, -{ - let mut accepted = false; - - while let Some(event) = rx.recv().await { - match event { - AsgiEvent::WsAccept { .. } => { - // Protocol acknowledgement — 101 already sent, no WS frame needed. - accepted = true; - } - AsgiEvent::WsSend { text, bytes } => { - if !accepted { - tracing::warn!(name: "apx.ws.send_before_accept", "websocket send before accept — dropping frame"); - continue; - } - let msg = if let Some(t) = text { - Message::text(t) - } else if let Some(b) = bytes { - Message::binary(b) - } else { - continue; - }; - if let Err(e) = sink.send(msg).await { - tracing::debug!(name: "apx.ws.write_error", error = %e, "websocket write error"); - break; - } - } - AsgiEvent::WsClose { code } => { - let close_frame = CloseFrame { - code: CloseCode::from(code), - reason: "".into(), - }; - let _ = sink.send(Message::Close(Some(close_frame))).await; - break; - } - AsgiEvent::ResponseStart { .. } | AsgiEvent::ResponseBody { .. } => { - tracing::error!(name: "apx.ws.http_response_in_ws_context", "HTTP response event in websocket context — ignoring"); - } - } - } - - // Best-effort close the sink when the channel closes. - let _ = sink.close().await; -} - -// ── Tests ─────────────────────────────────────────────────────────────── - -#[cfg(test)] -#[expect( - clippy::unwrap_used, - clippy::panic, - reason = "test code uses unwrap/assert for clarity" -)] -mod tests { - use super::*; - use std::pin::Pin; - - /// Empty body type usable in test requests (hyper's `Incoming` has no `Default`). - type EmptyBody = http_body_util::Empty; - - fn empty_body() -> EmptyBody { - http_body_util::Empty::new() - } - - // ── is_websocket_upgrade ──────────────────────────────────────────── - - #[test] - fn is_websocket_upgrade_positive() { - let req = Request::builder() - .header("Connection", "Upgrade") - .header("Upgrade", "websocket") - .body(empty_body()) - .unwrap(); - assert!(hyper_tungstenite::is_upgrade_request(&req)); - } - - #[test] - fn is_websocket_upgrade_negative() { - let req = Request::builder().method("GET").body(empty_body()).unwrap(); - assert!(!hyper_tungstenite::is_upgrade_request(&req)); - } - - #[test] - fn is_websocket_upgrade_case_insensitive() { - let req = Request::builder() - .header("Connection", "upgrade") - .header("Upgrade", "WEBSOCKET") - .body(empty_body()) - .unwrap(); - assert!(hyper_tungstenite::is_upgrade_request(&req)); - } - - // ── extract_request_info ──────────────────────────────────────────── - - #[test] - fn extract_request_info_preserves_fields() { - let req = Request::builder() - .uri("/ws/chat?token=abc") - .header("Host", "localhost") - .header("Connection", "Upgrade") - .header("Upgrade", "websocket") - .body(empty_body()) - .unwrap(); - - let server = SocketAddr::from(([127, 0, 0, 1], 8080)); - let client = SocketAddr::from(([10, 0, 0, 1], 54321)); - let inbound = extract_request_info(&req, server, Some(client)); - - assert_eq!(inbound.path, "/ws/chat"); - assert_eq!(inbound.query_string.as_ref(), b"token=abc"); - assert_eq!(inbound.server_addr, server); - assert_eq!(inbound.client_addr, Some(client)); - assert!(inbound.headers.contains_key("host")); - assert!(!inbound.has_body()); - } - - #[test] - fn extract_request_info_no_query() { - let req = Request::builder().uri("/ws").body(empty_body()).unwrap(); - let server = SocketAddr::from(([127, 0, 0, 1], 8080)); - let inbound = extract_request_info(&req, server, None); - assert_eq!(inbound.path, "/ws"); - assert!(inbound.query_string.is_empty()); - assert!(inbound.client_addr.is_none()); - } - - #[test] - fn extract_request_info_headers() { - let req = Request::builder() - .uri("/ws") - .header("X-Custom", "test-value") - .header("Authorization", "Bearer xyz") - .body(empty_body()) - .unwrap(); - let server = SocketAddr::from(([0, 0, 0, 0], 3000)); - let inbound = extract_request_info(&req, server, None); - assert_eq!(inbound.headers.get("x-custom").unwrap(), "test-value"); - assert_eq!(inbound.headers.get("authorization").unwrap(), "Bearer xyz"); - } - - // ── forward_incoming ──────────────────────────────────────────────── - - #[tokio::test] - async fn forward_incoming_text_message() { - let stream = futures_util::stream::iter(vec![Ok(Message::text("hello"))]); - let (tx, mut rx) = mpsc::channel(8); - tokio::spawn(forward_incoming(stream, tx)); - - match rx.recv().await.unwrap() { - WsIncomingEvent::Receive { text, bytes } => { - assert_eq!(text.as_deref(), Some("hello")); - assert!(bytes.is_none()); - } - other => panic!("expected Receive, got {other:?}"), - } - } - - #[tokio::test] - async fn forward_incoming_binary_message() { - let stream = futures_util::stream::iter(vec![Ok(Message::binary(vec![1u8, 2, 3]))]); - let (tx, mut rx) = mpsc::channel(8); - tokio::spawn(forward_incoming(stream, tx)); - - match rx.recv().await.unwrap() { - WsIncomingEvent::Receive { text, bytes } => { - assert!(text.is_none()); - assert_eq!(bytes.as_deref(), Some(&[1u8, 2, 3][..])); - } - other => panic!("expected Receive, got {other:?}"), - } - } - - #[tokio::test] - async fn forward_incoming_close_frame() { - let close = Message::Close(Some(CloseFrame { - code: CloseCode::Normal, - reason: "bye".into(), - })); - let stream = futures_util::stream::iter(vec![Ok(close)]); - let (tx, mut rx) = mpsc::channel(8); - tokio::spawn(forward_incoming(stream, tx)); - - match rx.recv().await.unwrap() { - WsIncomingEvent::Disconnect { code } => assert_eq!(code, 1000), - other => panic!("expected Disconnect, got {other:?}"), - } - } - - #[tokio::test] - async fn forward_incoming_stream_end() { - let stream = futures_util::stream::empty::>(); - let (tx, mut rx) = mpsc::channel(8); - tokio::spawn(forward_incoming(stream, tx)); - - match rx.recv().await.unwrap() { - WsIncomingEvent::Disconnect { code } => assert_eq!(code, WS_CLOSE_NORMAL), - other => panic!("expected Disconnect, got {other:?}"), - } - } - - // ── forward_outgoing ──────────────────────────────────────────────── - - /// Sink that collects messages + shared Vec for assertions. - type MockSink = Pin + Send>>; - - fn mock_sink() -> (MockSink, Arc>>) { - let messages = Arc::new(std::sync::Mutex::new(Vec::new())); - let msgs = Arc::clone(&messages); - let sink = futures_util::sink::unfold(msgs, |msgs, msg: Message| async move { - msgs.lock().unwrap().push(msg); - Ok::<_, tungstenite::Error>(msgs) - }); - (Box::pin(sink), messages) - } - - #[tokio::test] - async fn forward_outgoing_accept_then_send() { - let (event_tx, event_rx) = mpsc::channel(8); - let (sink, messages) = mock_sink(); - - event_tx - .send(AsgiEvent::WsAccept { - subprotocol: None, - headers: Vec::new(), - }) - .await - .unwrap(); - event_tx - .send(AsgiEvent::WsSend { - text: Some("world".to_owned()), - bytes: None, - }) - .await - .unwrap(); - event_tx - .send(AsgiEvent::WsClose { code: 1000 }) - .await - .unwrap(); - drop(event_tx); - - forward_outgoing(event_rx, sink).await; - - let msgs = messages.lock().unwrap(); - assert_eq!(msgs.len(), 2); // WsSend + WsClose - assert_eq!(msgs[0], Message::text("world")); - assert!(matches!(msgs[1], Message::Close(Some(_)))); - } - - #[tokio::test] - async fn forward_outgoing_close() { - let (event_tx, event_rx) = mpsc::channel(8); - let (sink, messages) = mock_sink(); - - event_tx - .send(AsgiEvent::WsAccept { - subprotocol: None, - headers: Vec::new(), - }) - .await - .unwrap(); - event_tx - .send(AsgiEvent::WsClose { code: 1001 }) - .await - .unwrap(); - drop(event_tx); - - forward_outgoing(event_rx, sink).await; - - let msgs = messages.lock().unwrap(); - assert_eq!(msgs.len(), 1); - match &msgs[0] { - Message::Close(Some(frame)) => assert_eq!(u16::from(frame.code), 1001), - other => panic!("expected Close, got {other:?}"), - } - } - - // Note: `dispatch_ws` default (400 response) and `is_websocket_upgrade` - // are tested indirectly via the `service.rs` integration path and the - // `is_upgrade_request` tests above. hyper's `Incoming` type has no public - // constructors, so direct unit tests require a real HTTP connection. -} diff --git a/crates/framework/src/pyapi.rs b/crates/framework/src/pyapi.rs index 28a3f506..c6a72b0e 100644 --- a/crates/framework/src/pyapi.rs +++ b/crates/framework/src/pyapi.rs @@ -8,17 +8,16 @@ use pyo3::types::PyModule; /// Register framework types into the `apx._core` extension module. pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_class::()?; - m.add_class::()?; - // ASGI lifespan protocol types m.add_class::()?; m.add_class::()?; - // 3-thread dispatch pipeline types - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + // HTTP protocol types + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; // Telemetry m.add_class::()?; diff --git a/crates/framework/src/supervision/mod.rs b/crates/framework/src/supervision/mod.rs index 5ccd54d3..3e86a520 100644 --- a/crates/framework/src/supervision/mod.rs +++ b/crates/framework/src/supervision/mod.rs @@ -5,4 +5,3 @@ pub mod ipc; pub mod signal; pub mod supervisor; pub mod worker; -pub mod worker_context; diff --git a/crates/framework/src/supervision/worker.rs b/crates/framework/src/supervision/worker.rs index f4a37692..2f082760 100644 --- a/crates/framework/src/supervision/worker.rs +++ b/crates/framework/src/supervision/worker.rs @@ -1,28 +1,24 @@ -//! Single worker: initialize Python, bind TCP, serve requests. +//! Single worker: initialize Python, serve requests via asyncio. //! //! A worker is a child process spawned by the supervisor. It owns one Python -//! interpreter, one inline asyncio event loop, and one TCP listener bound via -//! `SO_REUSEPORT`. +//! interpreter with an asyncio event loop. TCP binding happens via asyncio's +//! `loop.create_server()` with `SO_REUSEPORT`. use super::ipc::channel::WorkerChannel; use super::ipc::protocol::{BootstrapError, IpcMessage, Nonce, WorkerBootstrap}; use super::signal::shutdown_signal; -use super::worker_context::WorkerContext; use crate::asgi::app::{ModuleImport, format_pyerr}; -use crate::io::EventLoop; -use crate::protocol::http::service::{ApxService, ServiceConfig, serve_tcp}; -use crate::transport::{Listener, TransportConfig, TransportError}; +use crate::asgi::scope::ScopeInterns; +use crate::protocol::connection::ProtocolFactory; use pyo3::prelude::*; -use std::net::IpAddr; -use std::sync::Arc; -use std::time::Duration; +use std::net::{IpAddr, SocketAddr}; /// Errors during worker operation. #[derive(Debug, thiserror::Error)] pub enum WorkerError { /// TCP listener creation failed. #[error("transport: {0}")] - Transport(#[from] TransportError), + Transport(#[from] crate::transport::TransportError), /// Python interpreter initialization failed. #[error("python init failed: {0}")] @@ -38,7 +34,7 @@ pub enum WorkerError { /// Serving requests failed. #[error("serve failed: {0}")] - Serve(std::io::Error), + Serve(String), /// ASGI lifespan startup failed. #[error("lifespan startup failed: {0}")] @@ -46,10 +42,6 @@ pub enum WorkerError { } /// Format a worker error with full Python traceback when available. -/// -/// Pattern-matches through the error chain to find a `PyErr` inside -/// `AppLoadError::ImportFailed` and renders its traceback. Falls back to -/// the standard `Display` chain for non-Python errors. pub fn format_worker_error(err: &WorkerError) -> String { match err { WorkerError::AppLoad(crate::asgi::app::AppLoadError::ImportFailed { source, .. }) => { @@ -60,118 +52,47 @@ pub fn format_worker_error(err: &WorkerError) -> String { } } -/// Phase 1 runtime: TCP listener + Python interpreter (expensive, survives reloads). -pub struct WorkerRuntime { - /// TCP listener bound via the `Listener` trait. - pub listener: crate::transport::TcpListener, - /// IPC channel to supervisor — stays open for the worker's lifetime. - pub channel: WorkerChannel, - /// Asyncio event loop (dedicated thread, asyncio delegation). - pub event_loop: EventLoop, -} - -crate::opaque_debug!(WorkerRuntime); - -/// Phase 1: Create TCP listener and initialize the Python interpreter. -/// -/// Uses `io::EventLoop` — creates the asyncio loop on a dedicated thread. -/// Coroutines are submitted via `call_soon_threadsafe(create_task, coro)`. -/// -/// # Errors -/// -/// Returns an error if the listener cannot be created or Python init fails. -pub async fn init_worker( - bootstrap: &WorkerBootstrap, - channel: WorkerChannel, -) -> Result { - let host: IpAddr = bootstrap - .host - .parse() - .map_err(|e| TransportError::InvalidHost { - host: bootstrap.host.clone(), - source: e, - })?; - let config = TransportConfig::tcp(host, bootstrap.port); - let listener = crate::transport::TcpListener::bind(&config).await?; - - // Initialize the Python interpreter. - // IMPORTANT: must only be called once per process, only in worker processes. - Python::initialize(); - - // Initialize asyncio event loop (dedicated thread, asyncio delegation). - let event_loop = Python::attach(|py| EventLoop::init(py, &bootstrap.loop_policy)) - .map_err(|e| WorkerError::PythonInit(format!("event loop: {e}")))?; - - Ok(WorkerRuntime { - listener, - channel, - event_loop, - }) -} - -/// Signal readiness to supervisor over the IPC channel. -/// -/// # Errors -/// -/// Returns an error if the IPC send fails. -async fn signal_readiness(channel: &mut WorkerChannel) -> Result<(), WorkerError> { - channel - .send(&IpcMessage::Ready) - .await - .map_err(WorkerError::from) -} - -/// Loaded app and telemetry config, ready for readiness signal. +/// Loaded app state, ready to serve. struct AppReady { - dispatch: Arc, - telemetry: crate::telemetry::config::TelemetryConfig, - /// Raw ASGI callable for the lifespan protocol. + /// ASGI application callable. asgi_app: Py, - /// Cached `apx._bridge.launch` for submitting coroutines to the asyncio thread. - launch_fn: Py, + /// Pre-built scope interns for the server address. + interns: ScopeInterns, + /// Telemetry configuration. + telemetry: crate::telemetry::config::TelemetryConfig, + /// Server socket address. + server_addr: SocketAddr, } crate::opaque_debug!(AppReady); +/// Initialize the Python interpreter. +fn init_python() { + Python::initialize(); +} + /// Load the Python app and read telemetry configuration. -/// -/// Covers every fallible step between `init_worker` and `signal_readiness`. -/// On failure the caller sends `StartupFailed` over IPC before exiting. -fn load_app(runtime: &WorkerRuntime, bootstrap: &WorkerBootstrap) -> Result { +fn load_app(bootstrap: &WorkerBootstrap) -> Result { apply_python_log_config()?; - let pipeline = Arc::new( - crate::io::channel::DispatchPipeline::new() - .map_err(|e| WorkerError::PythonInit(format!("dispatch pipeline: {e}")))?, - ); - - let (ctx, launch_fn_ref) = { - let el = &runtime.event_loop; - Python::attach( - |py| -> Result<(Arc, Py), WorkerError> { - let launch_fn = register_launch(py) - .map_err(|e| WorkerError::PythonInit(format!("register launch: {e}")))?; - let launch_fn_ref = launch_fn.clone_ref(py); - let ctx = Arc::new(WorkerContext { - pipeline: Arc::clone(&pipeline), - call_soon_threadsafe: el.call_soon_threadsafe().clone_ref(py), - launch_fn, - }); - Ok((ctx, launch_fn_ref)) - }, - )? - }; - - let server_addr = runtime.listener.local_addr(); - let event_loop_py = runtime.event_loop.event_loop_py(); - let (dispatch, asgi_app) = Python::attach(|py| { - ModuleImport::new(bootstrap.app_module.as_str(), bootstrap.dev_mode).build_with_app( - py, - ctx, - event_loop_py, - server_addr, - ) - })?; + let host: IpAddr = + bootstrap + .host + .parse() + .map_err(|e| crate::transport::TransportError::InvalidHost { + host: bootstrap.host.clone(), + source: e, + })?; + let server_addr = SocketAddr::new(host, bootstrap.port); + + let (asgi_app, interns) = + Python::attach(|py| -> Result<(Py, ScopeInterns), WorkerError> { + let app_import = ModuleImport::new(bootstrap.app_module.as_str()); + let app = app_import.load_callable(py).map_err(WorkerError::AppLoad)?; + let asgi_app = app.inner().clone_ref(py); + let interns = ScopeInterns::new(py, server_addr); + Ok((asgi_app, interns)) + })?; let telemetry = Python::attach(|py| { crate::telemetry::bootstrap_python_telemetry(py) @@ -181,13 +102,21 @@ fn load_app(runtime: &WorkerRuntime, bootstrap: &WorkerBootstrap) -> Result Result<(), WorkerError> { + channel + .send(&IpcMessage::Ready) + .await + .map_err(WorkerError::from) +} + /// Relay telemetry config to the supervisor (worker 0 only). async fn relay_telemetry( channel: &mut WorkerChannel, @@ -230,135 +159,132 @@ fn init_metrics(telemetry: &crate::telemetry::config::TelemetryConfig) { ); } -/// Build the HTTP service from dispatch and bootstrap config. -fn build_service( - runtime: &WorkerRuntime, - bootstrap: &WorkerBootstrap, - dispatch: Arc, -) -> ApxService { - let mut config = ServiceConfig { - timeout: Duration::from_secs(bootstrap.request_timeout_secs), - ..ServiceConfig::default() - }; - if let Some(mc) = bootstrap.max_concurrent { - config.max_concurrent = mc; - } - let server_addr = runtime.listener.local_addr(); - ApxService::new(dispatch, server_addr, &config) -} - -/// Accept connections and serve until shutdown or drain. -async fn serve( - runtime: WorkerRuntime, - service: ApxService, - drain_timeout_secs: u64, - lifespan: Option, +/// Run the asyncio server via `asyncio.run(serve(...))`. +/// +/// The asyncio event loop owns everything: TCP accept, HTTP parsing, +/// request dispatch, and response writing. Rust provides accelerated +/// primitives as PyO3 `#[pyclass]` types. +fn run_server( + ready: AppReady, + shutdown_rx: tokio::sync::oneshot::Receiver<()>, ) -> Result<(), WorkerError> { - let (ipc_reader, mut ipc_writer) = runtime.channel.split(); - - let (drain_tx, drain_rx) = tokio::sync::oneshot::channel::<()>(); - tokio::spawn(async move { - let mut reader = ipc_reader; - match reader.recv().await { - Ok(IpcMessage::Drain) => { - tracing::info!( - name: "apx.worker.drain_received", - "received Drain from supervisor" - ); - let _ = drain_tx.send(()); - } - Ok(msg) => tracing::warn!( - name: "apx.worker.drain_unexpected_ipc", - ?msg, - "unexpected IPC message" - ), - Err(e) => tracing::debug!( - name: "apx.worker.drain_ipc_closed", - error = %e, - "IPC channel closed" - ), - } - }); - - let combined_shutdown = async { - tokio::select! { - () = shutdown_signal() => {} - _ = drain_rx => {} - } - }; - - let mut connections = serve_tcp(runtime.listener, service, combined_shutdown) - .await - .map_err(WorkerError::Serve)?; + Python::attach(|py| { + let asyncio = py + .import(c"asyncio") + .map_err(|e| WorkerError::Serve(format!("import asyncio: {e}")))?; + + let shutdown_event = asyncio + .call_method0(c"Event") + .map_err(|e| WorkerError::Serve(format!("create Event: {e}")))?; + + let host = ready.server_addr.ip().to_string(); + let port = ready.server_addr.port(); + + let factory_builder = + create_factory_builder(py, ready.asgi_app.clone_ref(py), ready.interns, &host, port)?; + + py.run( + c" +import asyncio as _asyncio +from apx._server import serve as _serve, _build_on_request +from apx._scheduler import CallSoonCapture + +async def _boot(_app, _factory_fn, _host, _port, _shutdown_event): + loop = _asyncio.get_running_loop() + capture = CallSoonCapture(loop) + on_request = _build_on_request(_app, loop, capture) + factory = _factory_fn(on_request) + await _serve(_host, _port, _app, factory, shutdown_event=_shutdown_event) +", + None, + None, + ) + .map_err(|e| WorkerError::Serve(format!("compile bootstrap: {e}")))?; + + let boot_fn = py + .eval(c"_boot", None, None) + .map_err(|e| WorkerError::Serve(format!("get _boot: {e}")))?; + + let coro = boot_fn + .call1(( + &ready.asgi_app, + factory_builder, + &host, + port, + &shutdown_event, + )) + .map_err(|e| WorkerError::Serve(format!("create boot coro: {e}")))?; + + let shutdown_event_ref = shutdown_event.unbind(); + std::thread::spawn(move || { + let _ = shutdown_rx.blocking_recv(); + Python::attach(|py| { + let _ = shutdown_event_ref.call_method0(py, pyo3::intern!(py, "set")); + }); + }); - if drain_timeout_secs > 0 { - let _ = tokio::time::timeout(Duration::from_secs(drain_timeout_secs), async { - while connections.join_next().await.is_some() {} - }) - .await; - } + asyncio + .call_method1(c"run", (coro,)) + .map_err(|e| WorkerError::Serve(format!("asyncio.run: {e}")))?; - // Trigger ASGI lifespan shutdown while the asyncio loop is still running. - if let Some(handle) = lifespan - && let Err(e) = handle.trigger_shutdown().await - { - tracing::warn!( - name: "apx.worker.lifespan_shutdown_failed", - error = %e, - "ASGI lifespan shutdown failed" - ); - } + Ok(()) + }) +} - let _ = ipc_writer.send(&IpcMessage::Drained).await; +/// Create a Python callable that, given `on_request`, returns a `ProtocolFactory`. +/// +/// This is a partial application: `ScopeInterns`, host, and port are +/// captured; the `on_request` callback is provided later (once the +/// event loop is running and `CallSoonCapture` can be created). +fn create_factory_builder( + py: Python<'_>, + _app: Py, + interns: ScopeInterns, + host: &str, + port: u16, +) -> Result, WorkerError> { + let host = host.to_owned(); + + let builder = FactoryBuilder { + interns: std::sync::Mutex::new(Some(interns)), + host, + port, + }; + let builder_py = Py::new(py, builder) + .map_err(|e| WorkerError::Serve(format!("create FactoryBuilder: {e}")))?; - apx_core::tracing_init::shutdown_telemetry(); - runtime.event_loop.shutdown(); + Ok(builder_py.into_any()) +} - Ok(()) +/// Python-callable that captures `ScopeInterns` and produces a +/// `ProtocolFactory` when called with `on_request`. +#[pyclass(module = "apx._core")] +struct FactoryBuilder { + interns: std::sync::Mutex>, + host: String, + port: u16, } -/// Launch the ASGI lifespan protocol and await startup completion. -/// -/// Returns `Ok(Some(handle))` if the app completed lifespan startup, -/// `Ok(None)` if the app does not support lifespan, or `Err` on failure. -async fn launch_and_await_lifespan( - runtime: &WorkerRuntime, - ready: &AppReady, -) -> Result, WorkerError> { - let pending = Python::attach(|py| { - crate::asgi::lifespan::launch_lifespan( - py, - &runtime.event_loop, - &ready.asgi_app, - &ready.launch_fn, - ) - }) - .map_err(|e| WorkerError::LifespanStartup(format!("{e}")))?; - - match pending.wait_startup().await { - Ok(Some(handle)) => { - tracing::info!( - name: "apx.worker.lifespan_startup_complete", - "ASGI lifespan startup complete" - ); - Ok(Some(handle)) - } - Ok(None) => { - tracing::info!( - name: "apx.worker.lifespan_unsupported", - "app does not support ASGI lifespan protocol" - ); - Ok(None) - } - Err(msg) => Err(WorkerError::LifespanStartup(msg)), +crate::opaque_debug!(FactoryBuilder); + +#[pymethods] +impl FactoryBuilder { + fn __call__(&self, py: Python<'_>, on_request: Py) -> PyResult> { + let interns = self + .interns + .lock() + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))? + .take() + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("FactoryBuilder already consumed") + })?; + let factory = ProtocolFactory::new(on_request, interns, self.host.clone(), self.port); + Py::new(py, factory) } } /// Connect, init, load app, signal readiness, and serve. /// -/// If app loading fails, sends `StartupFailed` over IPC so the supervisor -/// receives the error message instead of waiting for a readiness timeout. -/// /// # Errors /// /// Returns an error at any step in the worker lifecycle. @@ -366,35 +292,21 @@ pub async fn run_worker( channel: WorkerChannel, bootstrap: WorkerBootstrap, ) -> Result<(), WorkerError> { - let mut runtime = init_worker(&bootstrap, channel).await?; + init_python(); - let ready = match load_app(&runtime, &bootstrap) { + let ready = match load_app(&bootstrap) { Ok(ready) => ready, Err(e) => { let detail = format_worker_error(&e); - let _ = runtime - .channel - .send(&IpcMessage::StartupFailed { error: detail }) - .await; + let mut ch = channel; + let _ = ch.send(&IpcMessage::StartupFailed { error: detail }).await; return Err(e); } }; - // ASGI lifespan startup — run app startup hooks before accepting traffic. - let lifespan = match launch_and_await_lifespan(&runtime, &ready).await { - Ok(handle) => handle, - Err(e) => { - let detail = format_worker_error(&e); - let _ = runtime - .channel - .send(&IpcMessage::StartupFailed { error: detail }) - .await; - return Err(e); - } - }; - - signal_readiness(&mut runtime.channel).await?; - relay_telemetry(&mut runtime.channel, &bootstrap, &ready.telemetry).await?; + let mut channel = channel; + signal_readiness(&mut channel).await?; + relay_telemetry(&mut channel, &bootstrap, &ready.telemetry).await?; init_metrics(&ready.telemetry); if ready.telemetry.process.enabled { @@ -403,8 +315,49 @@ pub async fn run_worker( ); } - let service = build_service(&runtime, &bootstrap, ready.dispatch); - serve(runtime, service, bootstrap.drain_timeout_secs, lifespan).await + // Set up shutdown coordination between IPC reader and asyncio. + let (drain_tx, drain_rx) = tokio::sync::oneshot::channel::<()>(); + let (ipc_reader, mut ipc_writer) = channel.split(); + + tokio::spawn(async move { + let mut reader = ipc_reader; + tokio::select! { + msg = reader.recv() => { + match msg { + Ok(IpcMessage::Drain) => { + tracing::info!( + name: "apx.worker.drain_received", + "received Drain from supervisor" + ); + let _ = drain_tx.send(()); + } + Ok(msg) => tracing::warn!( + name: "apx.worker.drain_unexpected_ipc", + ?msg, + "unexpected IPC message" + ), + Err(e) => tracing::debug!( + name: "apx.worker.drain_ipc_closed", + error = %e, + "IPC channel closed" + ), + } + } + () = shutdown_signal() => { + let _ = drain_tx.send(()); + } + } + }); + + // Run the asyncio server (blocking — this IS the event loop). + let serve_result = tokio::task::spawn_blocking(move || run_server(ready, drain_rx)) + .await + .map_err(|e| WorkerError::Serve(format!("server task panicked: {e}")))?; + + let _ = ipc_writer.send(&IpcMessage::Drained).await; + apx_core::tracing_init::shutdown_telemetry(); + + serve_result } /// Detect worker mode and connect to the supervisor's IPC channel. @@ -451,8 +404,7 @@ pub async fn connect_to_supervisor() // ── Python logging config ─────────────────────────────────────────────── /// Apply the customer's Python logging config when `APX_PYTHON_LOG_CONFIG` -/// is set (dev mode only). Supports JSON (`logging.config.dictConfig`) and -/// Python (`logging.config.fileConfig`) config files. +/// is set. fn apply_python_log_config() -> Result<(), WorkerError> { let config_path = match std::env::var("APX_PYTHON_LOG_CONFIG") { Ok(p) if !p.is_empty() => p, @@ -482,19 +434,6 @@ fn apply_python_log_config() -> Result<(), WorkerError> { .map_err(|e: PyErr| WorkerError::PythonInit(format!("log config: {e}"))) } -// ── launch wrapper ────────────────────────────────────────────────────── - -/// Import `launch` from `apx._bridge`. -/// -/// `launch(app, scope, receive, send)` runs on the asyncio thread as a -/// `call_soon_threadsafe` callback. It calls `app(scope, receive, send)` -/// and wraps the coroutine in error-guarding + `create_task` — all in a -/// single `_run_once` callback, keeping the tokio thread GIL-free. -fn register_launch(py: Python<'_>) -> PyResult> { - let bridge = py.import(c"apx._bridge")?; - bridge.getattr(c"launch").map(|f| f.unbind()) -} - // ── Tests ─────────────────────────────────────────────────────────────── #[cfg(test)] @@ -523,7 +462,7 @@ mod tests { #[test] fn worker_error_display_transport() { use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - let err = WorkerError::Transport(TransportError::Bind { + let err = WorkerError::Transport(crate::transport::TransportError::Bind { addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 8000), source: std::io::Error::other("in use"), }); @@ -531,175 +470,6 @@ mod tests { assert!(msg.contains("transport")); } - /// `launch` must forward app exceptions through `send.send_error()` - /// without re-raising — otherwise asyncio logs "Task exception was - /// never retrieved" on every app error (the task is fire-and-forget). - #[test] - fn launch_forwards_error_without_asyncio_leak() { - crate::with_py(|py| { - let launch_fn = register_launch(py).expect("register_launch"); - - py.run( - c" -import asyncio, gc - -_leak_errors = [] - -def _capture(loop, ctx): - _leak_errors.append(ctx.get('message', '')) - -class _MockSend: - def __init__(self): - self.errors = [] - def send_error(self, tb): - self.errors.append(tb) - -_mock = _MockSend() - -async def _failing_app(scope, receive, send): - raise RuntimeError('deliberate test error') - -_el = asyncio.new_event_loop() -_el.set_exception_handler(_capture) -", - None, - None, - ) - .expect("define fixtures"); - - let app = py.eval(c"_failing_app", None, None).expect("get app"); - let mock = py.eval(c"_mock", None, None).expect("get mock"); - let scope = pyo3::types::PyDict::new(py); - let el = py.eval(c"_el", None, None).expect("get el"); - - // Submit via call_soon_threadsafe — same as production. - let csts = el.getattr(c"call_soon_threadsafe").expect("csts"); - csts.call1((&launch_fn, &app, &scope, py.None(), &mock)) - .expect("submit via csTS"); - - py.run( - c" -async def _drain(): - await asyncio.sleep(0) - await asyncio.sleep(0) - gc.collect() - gc.collect() - await asyncio.sleep(0) - -_el.run_until_complete(_drain()) -_el.close() -", - None, - None, - ) - .expect("drain loop"); - - let send_errors: Vec = py - .eval(c"_mock.errors", None, None) - .expect("get send_errors") - .extract() - .expect("extract"); - assert!( - !send_errors.is_empty(), - "send_error must be called on app exception" - ); - assert!( - send_errors[0].contains("deliberate test error"), - "traceback must contain the error: {}", - send_errors[0] - ); - - let leaks: Vec = py - .eval(c"_leak_errors", None, None) - .expect("get leak errors") - .extract() - .expect("extract"); - let task_leaks: Vec<_> = leaks - .iter() - .filter(|e| e.contains("Task exception was never retrieved")) - .collect(); - assert!( - task_leaks.is_empty(), - "launch re-raised, causing asyncio log spam: {task_leaks:?}" - ); - }); - } - - /// `CancelledError` must propagate through `launch` — it's a control - /// flow signal, not an app error. It must NOT be forwarded to - /// `send.send_error()`. - #[test] - fn launch_propagates_cancellation() { - crate::with_py(|py| { - let launch_fn = register_launch(py).expect("register_launch"); - - py.run( - c" -import asyncio - -class _MockSend2: - def __init__(self): - self.errors = [] - def send_error(self, tb): - self.errors.append(tb) - -_mock2 = _MockSend2() - -async def _slow_app(scope, receive, send): - await asyncio.sleep(10) - -_el2 = asyncio.new_event_loop() -", - None, - None, - ) - .expect("define fixtures"); - - let app = py.eval(c"_slow_app", None, None).expect("get app"); - let mock = py.eval(c"_mock2", None, None).expect("get mock"); - let scope = pyo3::types::PyDict::new(py); - let el = py.eval(c"_el2", None, None).expect("get el"); - - let csts = el.getattr(c"call_soon_threadsafe").expect("csts"); - csts.call1((&launch_fn, &app, &scope, py.None(), &mock)) - .expect("submit via csTS"); - - py.run( - c" -async def _run(): - await asyncio.sleep(0) # let launch create the task - # Find the app task (not ourselves). - app_tasks = [t for t in asyncio.all_tasks(_el2) - if not t.done() and t is not asyncio.current_task()] - for t in app_tasks: - t.cancel() - # Let cancel propagate. - for t in app_tasks: - try: - await t - except asyncio.CancelledError: - pass - -_el2.run_until_complete(_run()) -_el2.close() -", - None, - None, - ) - .expect("run test"); - - let send_errors: Vec = py - .eval(c"_mock2.errors", None, None) - .expect("get send_errors") - .extract() - .expect("extract"); - assert!( - send_errors.is_empty(), - "CancelledError must not be forwarded to send_error: {send_errors:?}" - ); - }); - } - // SAFETY: these env-var tests are single-threaded (#[test] with no async // or parallel spawns), so set_var / remove_var cannot race. #[expect(unsafe_code, reason = "env-var manipulation in single-threaded test")] diff --git a/crates/framework/src/supervision/worker_context.rs b/crates/framework/src/supervision/worker_context.rs deleted file mode 100644 index 7920e427..00000000 --- a/crates/framework/src/supervision/worker_context.rs +++ /dev/null @@ -1,24 +0,0 @@ -//! Shared worker infrastructure passed to every dispatch strategy. -//! -//! `WorkerContext` holds the dispatch pipeline for the 3-thread -//! architecture plus legacy asyncio submission callables used by the -//! WebSocket path (which still submits via `call_soon_threadsafe`). - -use crate::io::channel::DispatchPipeline; -use pyo3::Py; -use std::sync::Arc; - -/// Shared infrastructure available to all dispatch strategies. -/// -/// Created once per worker in `run_worker`, wrapped in `Arc`, and passed -/// to the dispatch implementation. -pub struct WorkerContext { - /// 3-thread dispatch pipeline (channels + wakeup). - pub pipeline: Arc, - /// Cached `loop.call_soon_threadsafe` — used by WS dispatch path. - pub call_soon_threadsafe: Py, - /// Cached `_bridge.launch` — used by WS dispatch path. - pub launch_fn: Py, -} - -crate::opaque_debug!(WorkerContext); diff --git a/crates/framework/src/telemetry/config.rs b/crates/framework/src/telemetry/config.rs index 8817b444..ba320ab2 100644 --- a/crates/framework/src/telemetry/config.rs +++ b/crates/framework/src/telemetry/config.rs @@ -142,7 +142,7 @@ impl Default for HttpMetricToggles { } } -/// APX framework dispatch timing instrumentation configuration. +/// APX protocol instrumentation configuration. #[derive(Debug, Clone, Copy)] pub struct ApxConfig { /// Whether APX dispatch metrics are enabled. @@ -151,31 +151,33 @@ pub struct ApxConfig { pub metrics: ApxMetricToggles, } -/// Per-metric boolean toggles for APX dispatch timing. +/// Per-metric boolean toggles for the APX dispatch pipeline. #[derive(Debug, Clone, Copy, Default)] #[expect( clippy::struct_excessive_bools, reason = "one bool per OTEL metric toggle" )] pub struct ApxMetricToggles { - /// Toggle for [`super::defs::DISPATCH_BODY_COLLECT`]. - pub dispatch_body_collect: bool, - /// Toggle for [`super::defs::DISPATCH_CROSSBEAM_SEND`]. - pub dispatch_crossbeam_send: bool, - /// Toggle for [`super::defs::DISPATCH_RESPONSE_WAIT`]. - pub dispatch_response_wait: bool, - /// Toggle for [`super::defs::DISPATCH_TOTAL`]. - pub dispatch_total: bool, - /// Toggle for [`super::defs::ASGI_RECEIVE_BUILD`]. - pub asgi_receive_build: bool, - /// Toggle for [`super::defs::ASGI_SEND_PARSE`]. - pub asgi_send_parse: bool, - /// Toggle for [`super::defs::DISPATCH_PICKUP_DELAY`]. - pub dispatch_pickup_delay: bool, - /// Toggle for [`super::defs::DISPATCH_MATERIALIZE`]. - pub dispatch_materialize: bool, - /// Toggle for [`super::defs::DISPATCH_QUEUE_DEPTH`]. - pub dispatch_queue_depth: bool, + /// Toggle for [`super::defs::PARSE`]. + pub parse: bool, + /// Toggle for [`super::defs::SCOPE_BUILD`]. + pub scope_build: bool, + /// Toggle for [`super::defs::RECEIVE_BUILD`]. + pub receive_build: bool, + /// Toggle for [`super::defs::SEND_PARSE`]. + pub send_parse: bool, + /// Toggle for [`super::defs::RESPONSE_BUILD`]. + pub response_build: bool, + /// Toggle for [`super::defs::RESPONSE_WRITE`]. + pub response_write: bool, + /// Toggle for [`super::defs::HANDLER_WAIT`]. + pub handler_wait: bool, + /// Toggle for [`super::defs::REQUEST_TOTAL`]. + pub request_total: bool, + /// Toggle for [`super::defs::ACTIVE_REQUESTS`]. + pub active_requests: bool, + /// Toggle for [`super::defs::CONNECTIONS`]. + pub connections: bool, } // ── Public defaults (used by supervisor) ───────────────────────────────── @@ -444,15 +446,16 @@ fn parse_apx_config(dict: &Bound<'_, PyDict>) -> PyResult { fn parse_apx_metric_toggles(dict: &Bound<'_, PyDict>) -> ApxMetricToggles { ApxMetricToggles { - dispatch_body_collect: extract_bool_or(dict, "dispatch_body_collect", false), - dispatch_crossbeam_send: extract_bool_or(dict, "dispatch_crossbeam_send", false), - dispatch_response_wait: extract_bool_or(dict, "dispatch_response_wait", false), - dispatch_total: extract_bool_or(dict, "dispatch_total", false), - asgi_receive_build: extract_bool_or(dict, "asgi_receive_build", false), - asgi_send_parse: extract_bool_or(dict, "asgi_send_parse", false), - dispatch_pickup_delay: extract_bool_or(dict, "dispatch_pickup_delay", false), - dispatch_materialize: extract_bool_or(dict, "dispatch_materialize", false), - dispatch_queue_depth: extract_bool_or(dict, "dispatch_queue_depth", false), + parse: extract_bool_or(dict, "parse", false), + scope_build: extract_bool_or(dict, "scope_build", false), + receive_build: extract_bool_or(dict, "receive_build", false), + send_parse: extract_bool_or(dict, "send_parse", false), + response_build: extract_bool_or(dict, "response_build", false), + response_write: extract_bool_or(dict, "response_write", false), + handler_wait: extract_bool_or(dict, "handler_wait", false), + request_total: extract_bool_or(dict, "request_total", false), + active_requests: extract_bool_or(dict, "active_requests", false), + connections: extract_bool_or(dict, "connections", false), } } diff --git a/crates/framework/src/telemetry/context.rs b/crates/framework/src/telemetry/context.rs index b1d97dd3..de6885d4 100644 --- a/crates/framework/src/telemetry/context.rs +++ b/crates/framework/src/telemetry/context.rs @@ -1,9 +1,9 @@ -//! Trace context propagation between Rust `tracing` spans and Python. +//! Trace context propagation between Rust and Python. //! -//! Rust request spans live on tokio threads. Before scheduling Python work -//! on the event loop, we extract the current trace context and inject it -//! into a Python `ContextVar` so that user-created `SpanHandle` instances -//! attach as children. +//! Rust protocol code runs on the asyncio thread alongside Python. +//! The trace context is propagated via a Python `ContextVar` so that +//! user-created `SpanHandle` instances attach as children of the +//! request span. //! //! Shared primitives (`SerializedContext`, `read_context_var_raw`, //! `parse_span_context`) are used by both `spans.rs` and `logging.rs`. diff --git a/crates/framework/src/telemetry/defs.rs b/crates/framework/src/telemetry/defs.rs index 996d174d..0c81a85e 100644 --- a/crates/framework/src/telemetry/defs.rs +++ b/crates/framework/src/telemetry/defs.rs @@ -4,7 +4,9 @@ //! creation, config doc comments, and Python toggle models all reference these //! constants instead of duplicating string literals. -use opentelemetry::metrics::{AsyncInstrument, Gauge, Histogram, Meter, ObservableGauge}; +use opentelemetry::metrics::{ + AsyncInstrument, Gauge, Histogram, Meter, ObservableGauge, UpDownCounter, +}; /// Descriptor for an OTEL metric instrument. #[derive(Debug, Clone, Copy)] @@ -41,6 +43,15 @@ impl MetricDef { .build() } + /// Build an i64 up-down counter from this definition. + pub fn up_down_counter(self, meter: &Meter) -> UpDownCounter { + meter + .i64_up_down_counter(self.name) + .with_description(self.description) + .with_unit(self.unit) + .build() + } + /// Build an observable f64 gauge that reports via a callback. pub fn observable_gauge(self, meter: &Meter, callback: F) -> ObservableGauge where @@ -131,68 +142,75 @@ pub const HTTP_ACTIVE_REQUESTS: MetricDef = MetricDef { unit: "1", }; -// ── APX dispatch metrics (per-worker) ──────────────────────────────────── +// ── APX protocol metrics (per-worker) ───────────────────────────────── -/// Time to collect the request body from the network stream. -pub const DISPATCH_BODY_COLLECT: MetricDef = MetricDef { - name: "apx.dispatch.body_collect.duration", - description: "Time to collect the request body from the network stream", +/// HTTP request parsing time. +pub const PARSE: MetricDef = MetricDef { + name: "apx.parse", + description: "HTTP request parsing time", unit: "us", }; -/// Time to push the request slot to the crossbeam channel and signal wakeup. -pub const DISPATCH_CROSSBEAM_SEND: MetricDef = MetricDef { - name: "apx.dispatch.crossbeam_send.duration", - description: "Time to push the request slot to the crossbeam channel and signal wakeup", +/// ASGI scope dict construction time. +pub const SCOPE_BUILD: MetricDef = MetricDef { + name: "apx.scope_build", + description: "ASGI scope dict construction time", unit: "us", }; -/// Time waiting for the Python handler to produce a response. -pub const DISPATCH_RESPONSE_WAIT: MetricDef = MetricDef { - name: "apx.dispatch.response_wait.duration", - description: "Time waiting for the Python handler to produce a response", +/// ASGI receive dict construction time. +pub const RECEIVE_BUILD: MetricDef = MetricDef { + name: "apx.receive_build", + description: "ASGI receive dict construction time", unit: "us", }; -/// Total dispatch duration from body collect start to response ready. -pub const DISPATCH_TOTAL: MetricDef = MetricDef { - name: "apx.dispatch.total.duration", - description: "Total dispatch duration from body collect start to response ready", +/// ASGI send event parsing time. +pub const SEND_PARSE: MetricDef = MetricDef { + name: "apx.send_parse", + description: "ASGI send event parsing time", unit: "us", }; -/// Time to build the ASGI receive dict for the Python handler. -pub const ASGI_RECEIVE_BUILD: MetricDef = MetricDef { - name: "apx.asgi.receive_build.duration", - description: "Time to build the ASGI receive dict for the Python handler", +/// HTTP response header construction time. +pub const RESPONSE_BUILD: MetricDef = MetricDef { + name: "apx.response_build", + description: "HTTP response header construction time", unit: "us", }; -/// Time to parse the ASGI send event dict from the Python handler. -pub const ASGI_SEND_PARSE: MetricDef = MetricDef { - name: "apx.asgi.send_parse.duration", - description: "Time to parse the ASGI send event dict from the Python handler", +/// Transport write time. +pub const RESPONSE_WRITE: MetricDef = MetricDef { + name: "apx.response_write", + description: "Transport write time", unit: "us", }; -/// Time from slot creation on the tokio thread to pickup on the asyncio thread. -pub const DISPATCH_PICKUP_DELAY: MetricDef = MetricDef { - name: "apx.dispatch.pickup_delay.duration", - description: "Time from slot creation to asyncio thread pickup", +/// Handler execution time (dispatch to response complete). +pub const HANDLER_WAIT: MetricDef = MetricDef { + name: "apx.handler_wait", + description: "Handler execution time", unit: "us", }; -/// Time to build the ASGI scope dict and receive/send callables. -pub const DISPATCH_MATERIALIZE: MetricDef = MetricDef { - name: "apx.dispatch.materialize.duration", - description: "Time to build ASGI scope and receive/send callables", +/// Total request processing time. +pub const REQUEST_TOTAL: MetricDef = MetricDef { + name: "apx.request_total", + description: "Total request processing time", unit: "us", }; -/// Number of pending request slots in the crossbeam channel at drain time. -pub const DISPATCH_QUEUE_DEPTH: MetricDef = MetricDef { - name: "apx.dispatch.queue_depth", - description: "Pending request slots in the crossbeam channel at drain time", +/// In-flight requests on this worker. +pub const ACTIVE_REQUESTS: MetricDef = MetricDef { + name: "apx.active_requests", + description: "In-flight requests on this worker", + unit: "1", +}; + +/// Active TCP connections on this worker. +pub const CONNECTIONS: MetricDef = MetricDef { + name: "apx.connections", + description: "Active TCP connections on this worker", unit: "1", }; @@ -264,49 +282,54 @@ pub static ALL_METRICS: &[MetricCatalogEntry] = &[ group: "http", scope: "worker", }, - // APX dispatch (per-worker) + // APX request pipeline (per-worker) + MetricCatalogEntry { + def: PARSE, + group: "apx", + scope: "worker", + }, MetricCatalogEntry { - def: DISPATCH_BODY_COLLECT, + def: SCOPE_BUILD, group: "apx", scope: "worker", }, MetricCatalogEntry { - def: DISPATCH_CROSSBEAM_SEND, + def: RECEIVE_BUILD, group: "apx", scope: "worker", }, MetricCatalogEntry { - def: DISPATCH_RESPONSE_WAIT, + def: SEND_PARSE, group: "apx", scope: "worker", }, MetricCatalogEntry { - def: DISPATCH_TOTAL, + def: RESPONSE_BUILD, group: "apx", scope: "worker", }, MetricCatalogEntry { - def: ASGI_RECEIVE_BUILD, + def: RESPONSE_WRITE, group: "apx", scope: "worker", }, MetricCatalogEntry { - def: ASGI_SEND_PARSE, + def: HANDLER_WAIT, group: "apx", scope: "worker", }, MetricCatalogEntry { - def: DISPATCH_PICKUP_DELAY, + def: REQUEST_TOTAL, group: "apx", scope: "worker", }, MetricCatalogEntry { - def: DISPATCH_MATERIALIZE, + def: ACTIVE_REQUESTS, group: "apx", scope: "worker", }, MetricCatalogEntry { - def: DISPATCH_QUEUE_DEPTH, + def: CONNECTIONS, group: "apx", scope: "worker", }, diff --git a/crates/framework/src/telemetry/dispatch_metrics.rs b/crates/framework/src/telemetry/dispatch_metrics.rs index 48bfcca8..a91c9649 100644 --- a/crates/framework/src/telemetry/dispatch_metrics.rs +++ b/crates/framework/src/telemetry/dispatch_metrics.rs @@ -1,6 +1,6 @@ -//! APX framework dispatch timing histograms. +//! APX request pipeline histograms. //! -//! Records per-phase latency for the ASGI dispatch pipeline via OTEL +//! Records per-phase latency for the request dispatch pipeline via OTEL //! histograms. All instruments are lazily created on first use and guarded //! by the `ApxMetricToggles` boolean flags — disabled metrics have zero //! overhead. @@ -10,7 +10,7 @@ use std::sync::OnceLock; -use opentelemetry::metrics::Histogram; +use opentelemetry::metrics::{Gauge, Histogram}; use super::config::ApxMetricToggles; use super::defs; @@ -19,15 +19,16 @@ use super::http::framework_meter; // ── Global toggles ──────────────────────────────────────────────────────── super::toggle_store!(TOGGLES: ApxMetricToggles = ApxMetricToggles { - dispatch_body_collect: false, - dispatch_crossbeam_send: false, - dispatch_response_wait: false, - dispatch_total: false, - asgi_receive_build: false, - asgi_send_parse: false, - dispatch_pickup_delay: false, - dispatch_materialize: false, - dispatch_queue_depth: false, + parse: false, + scope_build: false, + receive_build: false, + send_parse: false, + response_build: false, + response_write: false, + handler_wait: false, + request_total: false, + active_requests: false, + connections: false, }); // ── Metric declarations ─────────────────────────────────────────────────── @@ -51,66 +52,112 @@ macro_rules! dispatch_metric { }; } +/// Generate a lazy gauge getter and gated `inc_*` / `dec_*` functions. +macro_rules! dispatch_gauge { + ($inc_fn:ident, $dec_fn:ident, $gauge_fn:ident, $toggle:ident, $def:expr, $doc:literal) => { + fn $gauge_fn() -> &'static Gauge { + static INST: OnceLock> = OnceLock::new(); + INST.get_or_init(|| $def.gauge(&framework_meter())) + } + + #[doc = $doc] + pub fn $inc_fn() { + if toggles().$toggle { + $gauge_fn().record(1.0, NO_ATTRS); + } + } + + /// Decrement the gauge. + pub fn $dec_fn() { + if toggles().$toggle { + $gauge_fn().record(-1.0, NO_ATTRS); + } + } + }; +} + +// ── Histograms ─────────────────────────────────────────────────────────── + dispatch_metric!( - record_body_collect, - body_collect_hist, - dispatch_body_collect, - defs::DISPATCH_BODY_COLLECT, - "Record `apx.dispatch.body_collect.duration` if enabled." -); -dispatch_metric!( - record_crossbeam_send, - crossbeam_send_hist, - dispatch_crossbeam_send, - defs::DISPATCH_CROSSBEAM_SEND, - "Record `apx.dispatch.crossbeam_send.duration` if enabled." -); -dispatch_metric!( - record_response_wait, - response_wait_hist, - dispatch_response_wait, - defs::DISPATCH_RESPONSE_WAIT, - "Record `apx.dispatch.response_wait.duration` if enabled." + record_parse, + parse_hist, + parse, + defs::PARSE, + "Record `apx.parse` if enabled." ); + dispatch_metric!( - record_dispatch_total, - dispatch_total_hist, - dispatch_total, - defs::DISPATCH_TOTAL, - "Record `apx.dispatch.total.duration` if enabled." + record_scope_build, + scope_build_hist, + scope_build, + defs::SCOPE_BUILD, + "Record `apx.scope_build` if enabled." ); + dispatch_metric!( record_receive_build, receive_build_hist, - asgi_receive_build, - defs::ASGI_RECEIVE_BUILD, - "Record `apx.asgi.receive_build.duration` if enabled." + receive_build, + defs::RECEIVE_BUILD, + "Record `apx.receive_build` if enabled." ); + dispatch_metric!( record_send_parse, send_parse_hist, - asgi_send_parse, - defs::ASGI_SEND_PARSE, - "Record `apx.asgi.send_parse.duration` if enabled." + send_parse, + defs::SEND_PARSE, + "Record `apx.send_parse` if enabled." ); + +dispatch_metric!( + record_response_build, + response_build_hist, + response_build, + defs::RESPONSE_BUILD, + "Record `apx.response_build` if enabled." +); + dispatch_metric!( - record_pickup_delay, - pickup_delay_hist, - dispatch_pickup_delay, - defs::DISPATCH_PICKUP_DELAY, - "Record `apx.dispatch.pickup_delay.duration` if enabled." + record_response_write, + response_write_hist, + response_write, + defs::RESPONSE_WRITE, + "Record `apx.response_write` if enabled." ); + dispatch_metric!( - record_materialize, - materialize_hist, - dispatch_materialize, - defs::DISPATCH_MATERIALIZE, - "Record `apx.dispatch.materialize.duration` if enabled." + record_handler_wait, + handler_wait_hist, + handler_wait, + defs::HANDLER_WAIT, + "Record `apx.handler_wait` if enabled." ); + dispatch_metric!( - record_queue_depth, - queue_depth_hist, - dispatch_queue_depth, - defs::DISPATCH_QUEUE_DEPTH, - "Record `apx.dispatch.queue_depth` if enabled." + record_dispatch_total, + dispatch_total_hist, + request_total, + defs::REQUEST_TOTAL, + "Record `apx.request_total` if enabled." +); + +// ── Gauges ─────────────────────────────────────────────────────────────── + +dispatch_gauge!( + inc_active_requests, + dec_active_requests, + active_requests_gauge, + active_requests, + defs::ACTIVE_REQUESTS, + "Increment `apx.active_requests`." +); + +dispatch_gauge!( + inc_connections, + dec_connections, + connections_gauge, + connections, + defs::CONNECTIONS, + "Increment `apx.connections`." ); diff --git a/crates/framework/src/telemetry/http.rs b/crates/framework/src/telemetry/http.rs index c70b5cf6..579cb126 100644 --- a/crates/framework/src/telemetry/http.rs +++ b/crates/framework/src/telemetry/http.rs @@ -9,11 +9,12 @@ use std::sync::OnceLock; -use crate::protocol::http::error::AppError; use crate::telemetry::config::HttpMetricToggles; +use crate::telemetry::context::TraceContext; use crate::telemetry::defs; use opentelemetry::KeyValue; use opentelemetry::metrics::{Histogram, UpDownCounter}; +use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState}; // ── Global HTTP metric toggles ──────────────────────────────────────────── @@ -29,62 +30,18 @@ pub(crate) fn framework_meter() -> opentelemetry::metrics::Meter { super::get_meter("apx.framework") } -// ── Active requests instrument ──────────────────────────────────────────── - -fn active_requests_counter() -> &'static UpDownCounter { - static COUNTER: OnceLock> = OnceLock::new(); - COUNTER.get_or_init(|| { - framework_meter() - .i64_up_down_counter(defs::HTTP_ACTIVE_REQUESTS.name) - .with_description(defs::HTTP_ACTIVE_REQUESTS.description) - .with_unit(defs::HTTP_ACTIVE_REQUESTS.unit) - .build() - }) -} - -// ── Active requests guard ───────────────────────────────────────────────── - -/// RAII guard that decrements `http.server.active_requests` on drop. -/// -/// Covers panics, timeouts, and early returns — the counter is always -/// decremented when the guard goes out of scope. -/// -/// Returns `None` when the `server_active_requests` toggle is disabled. -#[derive(Debug)] -pub struct ActiveRequestGuard { - attrs: [KeyValue; 2], -} - -impl ActiveRequestGuard { - /// Increment active requests and return a guard that decrements on drop. - /// - /// Returns `None` if the `server_active_requests` metric is disabled. - pub fn enter(method: &str, scheme: &str) -> Option { - if !toggles().server_active_requests { - return None; - } - let attrs = [ - KeyValue::new("http.request.method", method.to_owned()), - KeyValue::new("url.scheme", scheme.to_owned()), - ]; - active_requests_counter().add(1, &attrs); - Some(Self { attrs }) - } -} - -impl Drop for ActiveRequestGuard { - fn drop(&mut self) { - active_requests_counter().add(-1, &self.attrs); - } -} - -// ── Request duration instrument ─────────────────────────────────────────── +// ── Instruments ────────────────────────────────────────────────────────── fn duration_histogram() -> &'static Histogram { static HIST: OnceLock> = OnceLock::new(); HIST.get_or_init(|| defs::HTTP_REQUEST_DURATION.histogram(&framework_meter())) } +fn active_requests_counter() -> &'static UpDownCounter { + static CTR: OnceLock> = OnceLock::new(); + CTR.get_or_init(|| defs::HTTP_ACTIVE_REQUESTS.up_down_counter(&framework_meter())) +} + // ── Request duration ────────────────────────────────────────────────────── /// Record `http.server.request.duration` with standard attributes. @@ -128,85 +85,106 @@ pub fn record_duration( }); } -// ── Error / protocol helpers ────────────────────────────────────────────── +// ── Active requests ───────────────────────────────────────────────────── -/// Map an `AppError` variant to an OTEL semconv `error.type` value. -pub fn error_type_for(err: &AppError) -> &'static str { - match err { - AppError::Internal(_) => "500", - AppError::Timeout => "408", +/// Increment the `http.server.active_requests` counter. +pub fn inc_active_requests() { + if toggles().server_active_requests { + active_requests_counter().add(1, &[]); } } -/// Map `http::Version` to the semconv `network.protocol.version` string. -pub fn protocol_version(version: http::Version) -> &'static str { - match version { - http::Version::HTTP_09 => "0.9", - http::Version::HTTP_10 => "1.0", - http::Version::HTTP_2 => "2", - http::Version::HTTP_3 => "3", - _ => "1.1", +/// Decrement the `http.server.active_requests` counter. +pub fn dec_active_requests() { + if toggles().server_active_requests { + active_requests_counter().add(-1, &[]); } } -// ── Header capture ─────────────────────────────────────────────────────── - -use super::config::HttpConfig; - -const REDACTED: &str = "[REDACTED]"; - -/// Extract header values as OTEL span attributes for the given direction. -fn capture_headers( - direction: &str, - header_names: &[String], - headers: &http::HeaderMap, - sanitize_patterns: &[String], -) -> Vec { - let mut attrs = Vec::new(); - for name in header_names { - let lower = name.to_lowercase(); - let values: Vec<&str> = headers - .get_all( - http::header::HeaderName::from_bytes(lower.as_bytes()) - .unwrap_or(http::header::HeaderName::from_static("x-unknown")), - ) - .iter() - .filter_map(|v| v.to_str().ok()) - .collect(); - if values.is_empty() { - continue; - } - let normalized = name.to_lowercase().replace('-', "_"); - let attr_name = format!("http.{direction}.header.{normalized}"); - let value = if sanitize_patterns - .iter() - .any(|p| lower.contains(&p.to_lowercase())) - { - REDACTED.to_owned() - } else { - values.join(", ") - }; - attrs.push(KeyValue::new(attr_name, value)); +// ── Request span ──────────────────────────────────────────────────────── + +/// Parse a UUID string into a 16-byte OTEL trace ID. +fn uuid_to_trace_id(uuid: &str) -> Option<[u8; 16]> { + let hex: String = uuid.chars().filter(|c| *c != '-').collect(); + hex::decode(&hex).ok()?.try_into().ok() +} + +/// Create a `tracing` span for an HTTP request. +/// +/// Uses the `x-request-id` UUID as the trace ID so that all spans +/// and logs within a request share the same trace. Returns the span +/// and a [`TraceContext`] for propagation to Python. +/// +/// The returned span must be entered (via `.enter()`) when recording +/// metrics or logs that should carry the request's trace context. +pub fn begin_request_span( + request_id: &str, + method: &str, + path: &str, +) -> (tracing::Span, TraceContext) { + use tracing_opentelemetry::OpenTelemetrySpanExt; + + let span = tracing::info_span!( + "http.server.request", + "http.request.method" = method, + "url.path" = path, + "http.response.status_code" = tracing::field::Empty, + otel.kind = "server", + ); + + if let Some(tid) = uuid_to_trace_id(request_id) { + let parent_sc = SpanContext::new( + TraceId::from_bytes(tid), + SpanId::from_bytes(rand::random()), + TraceFlags::SAMPLED, + true, + TraceState::default(), + ); + let parent_cx = opentelemetry::Context::new().with_remote_span_context(parent_sc); + span.set_parent(parent_cx); } - attrs + + let ctx = { + let _guard = span.enter(); + super::context::extract_trace_context().unwrap_or(TraceContext { + trace_id: [0; 16], + span_id: [0; 8], + trace_flags: 0, + trace_state: String::new(), + }) + }; + + (span, ctx) } -/// Extract request header values as OTEL span attributes. -pub fn capture_request_headers(headers: &http::HeaderMap, config: &HttpConfig) -> Vec { - capture_headers( - "request", - &config.capture_request_headers, - headers, - &config.sanitize_headers, - ) +/// Record response status on a request span before it ends. +pub fn finish_request_span(span: &tracing::Span, status: u16) { + span.record("http.response.status_code", i64::from(status)); } -/// Extract response header values as OTEL span attributes. -pub fn capture_response_headers(headers: &http::HeaderMap, config: &HttpConfig) -> Vec { - capture_headers( - "response", - &config.capture_response_headers, - headers, - &config.sanitize_headers, - ) +#[cfg(test)] +#[expect(clippy::expect_used, reason = "test code uses expect for clarity")] +mod tests { + use super::*; + + #[test] + fn test_uuid_to_trace_id_valid() { + let id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"; + let tid = uuid_to_trace_id(id).expect("valid UUID"); + assert_eq!(hex::encode(tid), "a1b2c3d4e5f64a7b8c9d0e1f2a3b4c5d"); + } + + #[test] + fn test_uuid_to_trace_id_no_dashes() { + let id = "a1b2c3d4e5f64a7b8c9d0e1f2a3b4c5d"; + let tid = uuid_to_trace_id(id).expect("valid hex"); + assert_eq!(hex::encode(tid), id); + } + + #[test] + fn test_uuid_to_trace_id_invalid() { + assert!(uuid_to_trace_id("not-a-uuid").is_none()); + assert!(uuid_to_trace_id("").is_none()); + assert!(uuid_to_trace_id("zzzzzzzz-zzzz-zzzz-zzzz-zzzzzzzzzzzz").is_none()); + } } diff --git a/crates/framework/src/telemetry/mod.rs b/crates/framework/src/telemetry/mod.rs index 1429f94c..c62c77bb 100644 --- a/crates/framework/src/telemetry/mod.rs +++ b/crates/framework/src/telemetry/mod.rs @@ -61,20 +61,6 @@ macro_rules! toggle_store { } pub(crate) use toggle_store; -/// Time an expression and record elapsed microseconds via a metric function. -/// -/// Rust equivalent of Python's `with timing(metric): `. -/// Works with `?`, `.await`, blocks, and nested `timed!` calls. -macro_rules! timed { - ($record:path, $expr:expr) => {{ - let __t0 = ::std::time::Instant::now(); - let __val = $expr; - $record(__t0.elapsed().as_micros() as f64); - __val - }}; -} -pub(crate) use timed; - /// State that can be refreshed before reading. /// /// Implemented by `SystemState` and `ProcessState` to share the diff --git a/crates/framework/src/transport/listener.rs b/crates/framework/src/transport/listener.rs deleted file mode 100644 index fb60f358..00000000 --- a/crates/framework/src/transport/listener.rs +++ /dev/null @@ -1,154 +0,0 @@ -//! Transport listener trait and configuration. - -use super::types::TransportKind; -use std::future::Future; -use std::net::{IpAddr, SocketAddr}; - -/// Errors during transport operations. -#[derive(Debug, thiserror::Error)] -pub enum TransportError { - /// The host string could not be parsed as an IP address. - #[error("invalid host address '{host}': {source}")] - InvalidHost { - /// The host string that failed to parse. - host: String, - /// The underlying parse error. - source: std::net::AddrParseError, - }, - - /// Socket creation failed. - #[error("failed to create socket: {0}")] - SocketCreate(std::io::Error), - - /// Binding to the address failed. - #[error("failed to bind {addr}: {source}")] - Bind { - /// The socket address that failed to bind. - addr: SocketAddr, - /// The underlying IO error. - source: std::io::Error, - }, - - /// Transitioning to listen mode failed. - #[error("failed to listen: {0}")] - Listen(std::io::Error), - - /// Converting to a tokio listener failed. - #[error("failed to convert to tokio listener: {0}")] - TokioConvert(std::io::Error), - - /// Serving requests failed. - #[error("serve failed: {0}")] - Serve(std::io::Error), -} - -/// Configuration for transport binding. -#[derive(Debug, Clone, Copy)] -pub struct TransportConfig { - /// IP address to bind. - pub host: IpAddr, - /// Port to bind. - pub port: u16, - /// Which transport to use. - pub transport_kind: TransportKind, -} - -impl TransportConfig { - /// Create a TCP transport config. - pub fn tcp(host: IpAddr, port: u16) -> Self { - Self { - host, - port, - transport_kind: TransportKind::Tcp, - } - } -} - -/// Transport-agnostic listener trait. -/// -/// v1: `TcpListener` (hyper for HTTP/1 + HTTP/2). -/// Future: `QuicListener` (quinn for HTTP/3), `UnixListener`, `InMemoryListener`. -/// -/// The `serve()` method is intentionally absent — it lives in the hyper service -/// layer (`protocol::http::service`). -pub trait Listener: Send + Sync + 'static { - /// Bind to the configured address. - fn bind(config: &TransportConfig) -> impl Future> + Send - where - Self: Sized; - - /// Return the locally bound socket address. - fn local_addr(&self) -> SocketAddr; - - /// Return the transport kind. - fn transport_kind(&self) -> TransportKind; -} - -#[cfg(test)] -#[expect( - clippy::unwrap_used, - reason = "test code uses unwrap/assert for clarity" -)] -mod tests { - use super::*; - use std::net::IpAddr; - - #[test] - fn transport_config_tcp() { - let config = TransportConfig::tcp(IpAddr::from([127, 0, 0, 1]), 8080); - assert_eq!(config.host, IpAddr::from([127, 0, 0, 1])); - assert_eq!(config.port, 8080); - assert!(matches!(config.transport_kind, TransportKind::Tcp)); - } - - #[test] - fn transport_error_display_invalid_host() { - let source = "bad".parse::().unwrap_err(); - let err = TransportError::InvalidHost { - host: "bad".to_owned(), - source, - }; - let msg = format!("{err}"); - assert!(msg.contains("bad")); - assert!(msg.contains("invalid")); - } - - #[test] - fn transport_error_display_socket_create() { - let err = TransportError::SocketCreate(std::io::Error::other("create fail")); - let msg = format!("{err}"); - assert!(msg.contains("create")); - } - - #[test] - fn transport_error_display_bind() { - let addr = SocketAddr::from(([127, 0, 0, 1], 80)); - let err = TransportError::Bind { - addr, - source: std::io::Error::other("in use"), - }; - let msg = format!("{err}"); - assert!(msg.contains("bind")); - } - - #[test] - fn transport_error_display_listen() { - let err = TransportError::Listen(std::io::Error::other("listen fail")); - let msg = format!("{err}"); - assert!(msg.contains("listen")); - } - - #[test] - fn transport_error_display_tokio_convert() { - let err = TransportError::TokioConvert(std::io::Error::other("convert fail")); - let msg = format!("{err}"); - assert!(msg.contains("tokio")); - } - - #[test] - fn transport_error_display_serve() { - let err = TransportError::Serve(std::io::Error::other("serve fail")); - let msg = format!("{err}"); - assert!(msg.contains("serve")); - } -} diff --git a/crates/framework/src/transport/mod.rs b/crates/framework/src/transport/mod.rs index a3748a7e..5e759bcb 100644 --- a/crates/framework/src/transport/mod.rs +++ b/crates/framework/src/transport/mod.rs @@ -1,14 +1,5 @@ -//! Transport layer abstraction. -//! -//! Separates transport-specific code (TCP/QUIC/Unix/in-memory) from the -//! application layer. The [`Listener`] trait is the binding point. +//! Transport type definitions. -pub mod listener; -pub mod tcp; pub mod types; -pub use listener::{Listener, TransportConfig, TransportError}; -pub use tcp::TcpListener; -pub use types::{ - BodyError, BodyStream, InboundRequest, OutboundResponse, ProtocolVersion, ResponseBody, -}; +pub use types::{ProtocolVersion, TransportError}; diff --git a/crates/framework/src/transport/tcp.rs b/crates/framework/src/transport/tcp.rs deleted file mode 100644 index 3b0f035b..00000000 --- a/crates/framework/src/transport/tcp.rs +++ /dev/null @@ -1,173 +0,0 @@ -//! TCP listener with `SO_REUSEPORT` implementing the [`Listener`] trait. -//! -//! Each worker creates its own `TcpListener` bound to the same port via -//! `SO_REUSEPORT`. The kernel distributes incoming connections across -//! all listeners (on Linux; macOS behavior differs — see spike-results.md). - -use super::types::TransportKind; -use super::{Listener, TransportConfig, TransportError}; -use std::net::SocketAddr; - -/// TCP listener with `SO_REUSEPORT` for multi-worker sharing. -pub struct TcpListener { - /// The underlying tokio TCP listener. - inner: tokio::net::TcpListener, - /// Bound address (resolved after `bind()`). - addr: SocketAddr, -} - -impl std::fmt::Debug for TcpListener { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TcpListener") - .field("addr", &self.addr) - .finish() - } -} - -impl TcpListener { - /// Accept a new TCP connection. - pub async fn accept(&self) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> { - self.inner.accept().await - } - - /// Expose the inner tokio listener for the hyper service layer (Step 2). - pub fn into_inner(self) -> tokio::net::TcpListener { - self.inner - } -} - -impl Listener for TcpListener { - async fn bind(config: &TransportConfig) -> Result - where - Self: Sized, - { - let socket = create_socket(config)?; - let listener = tokio::net::TcpListener::from_std(socket.into()) - .map_err(TransportError::TokioConvert)?; - let addr = listener.local_addr().map_err(|e| TransportError::Bind { - addr: SocketAddr::new(config.host, config.port), - source: e, - })?; - Ok(Self { - inner: listener, - addr, - }) - } - - fn local_addr(&self) -> SocketAddr { - self.addr - } - - fn transport_kind(&self) -> TransportKind { - TransportKind::Tcp - } -} - -/// TCP listen backlog — max number of pending connections queued by the kernel. -const LISTEN_BACKLOG: i32 = 1024; - -/// Create a `socket2::Socket` configured for `SO_REUSEPORT` TCP listening. -fn create_socket(config: &TransportConfig) -> Result { - let addr = SocketAddr::new(config.host, config.port); - let domain = match addr { - SocketAddr::V4(_) => socket2::Domain::IPV4, - SocketAddr::V6(_) => socket2::Domain::IPV6, - }; - - let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP)) - .map_err(TransportError::SocketCreate)?; - - socket - .set_reuse_port(true) - .map_err(TransportError::SocketCreate)?; - - // IPv6 dual-stack: on Linux, set IPV6_V6ONLY=false for dual-stack. - // macOS handles dual-stack differently — binding "::" already accepts - // IPv4 connections by default. - #[cfg(target_os = "linux")] - if addr.is_ipv6() { - socket - .set_only_v6(false) - .map_err(TransportError::SocketCreate)?; - } - - socket - .bind(&addr.into()) - .map_err(|e| TransportError::Bind { addr, source: e })?; - - socket - .listen(LISTEN_BACKLOG) - .map_err(TransportError::Listen)?; - - socket - .set_nonblocking(true) - .map_err(TransportError::Listen)?; - - Ok(socket) -} - -#[cfg(test)] -#[expect( - clippy::unwrap_used, - reason = "test code uses unwrap/assert for clarity" -)] -mod tests { - use super::*; - use std::net::IpAddr; - - #[tokio::test] - async fn tcp_listener_bind_ipv4() { - let config = TransportConfig::tcp(IpAddr::from([127, 0, 0, 1]), 0); - let listener = TcpListener::bind(&config).await; - assert!(listener.is_ok(), "IPv4 listener should succeed"); - } - - #[tokio::test] - async fn tcp_listener_bind_ipv6() { - let config = TransportConfig::tcp(IpAddr::from([0, 0, 0, 0, 0, 0, 0, 1]), 0); - let listener = TcpListener::bind(&config).await; - assert!(listener.is_ok(), "IPv6 listener should succeed"); - } - - #[tokio::test] - async fn tcp_listener_local_addr() { - let config = TransportConfig::tcp(IpAddr::from([127, 0, 0, 1]), 0); - let listener = TcpListener::bind(&config).await; - assert!(listener.is_ok()); - let listener = listener.unwrap_or_else(|_| unreachable!()); - assert_ne!(listener.local_addr().port(), 0); - } - - #[tokio::test] - async fn tcp_listener_transport_kind() { - let config = TransportConfig::tcp(IpAddr::from([127, 0, 0, 1]), 0); - let listener = TcpListener::bind(&config).await; - assert!(listener.is_ok()); - let listener = listener.unwrap_or_else(|_| unreachable!()); - assert_eq!(listener.transport_kind(), TransportKind::Tcp); - } - - #[tokio::test] - async fn tcp_listener_debug() { - let config = TransportConfig::tcp(IpAddr::from([127, 0, 0, 1]), 0); - let listener = TcpListener::bind(&config).await.unwrap(); - let dbg = format!("{listener:?}"); - assert!(dbg.contains("TcpListener")); - assert!(dbg.contains("addr")); - } - - #[tokio::test] - async fn tcp_listener_accept_returns_connection() { - let config = TransportConfig::tcp(IpAddr::from([127, 0, 0, 1]), 0); - let listener = TcpListener::bind(&config).await.unwrap(); - let addr = listener.local_addr(); - - // Connect from a client - let _client = tokio::net::TcpStream::connect(addr).await.unwrap(); - - // Accept should succeed - let (stream, client_addr) = listener.accept().await.unwrap(); - assert!(stream.peer_addr().is_ok()); - assert_eq!(client_addr.ip(), IpAddr::from([127, 0, 0, 1])); - } -} diff --git a/crates/framework/src/transport/types.rs b/crates/framework/src/transport/types.rs index e5dd75bd..0a17e3d2 100644 --- a/crates/framework/src/transport/types.rs +++ b/crates/framework/src/transport/types.rs @@ -1,38 +1,10 @@ -//! Transport-neutral request and response types. -//! -//! These types sit between the protocol layer and the application layer. -//! They are the architectural pivot that keeps ASGI and dispatch transport-agnostic. -//! -//! `InboundRequest` / `OutboundResponse` are the sole interface between the -//! transport-specific code (hyper, future quinn) and the transport-agnostic -//! application code (routing, dispatch, ASGI adapter). +//! Transport type definitions. -use bytes::{Bytes, BytesMut}; -use http::header::HeaderMap; -use http_body::Frame; -use serde::{Deserialize, Serialize}; use std::net::SocketAddr; -use std::pin::Pin; -use std::task::{Context, Poll}; - -/// Which transport carried this request. -/// -/// Closed enum — new transports are added as variants. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -pub enum TransportKind { - /// TCP socket. - Tcp, - /// Unix domain socket. - Unix, - /// In-memory channel (tests). - InMemory, - // Quic, // future -} /// HTTP protocol version. /// -/// Tracked per-request so ASGI scope can set `http_version` correctly -/// and future h3 responses can set appropriate headers. +/// Tracked per-request so ASGI scope can set `http_version` correctly. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ProtocolVersion { /// HTTP/1.0. @@ -41,611 +13,25 @@ pub enum ProtocolVersion { Http11, /// HTTP/2. H2, - // H3, // future — added when QUIC transport lands } -impl ProtocolVersion { - /// ASGI spec string for `scope["http_version"]`. - pub fn as_asgi_version(&self) -> &'static str { - match self { - Self::Http10 => "1.0", - Self::Http11 => "1.1", - Self::H2 => "2", - } - } -} - -/// Error reading or limiting a request/response body. +/// Errors during transport setup. #[derive(Debug, thiserror::Error)] -pub enum BodyError { - /// Body exceeded the configured size limit. - #[error("body exceeds size limit of {limit} bytes")] - TooLarge { - /// Configured limit. - limit: usize, +pub enum TransportError { + /// Bind to the requested address failed. + #[error("failed to bind {addr}: {source}")] + Bind { + /// Address we tried to bind. + addr: SocketAddr, + /// OS error. + source: std::io::Error, + }, + /// Invalid host string (not a valid IP address). + #[error("invalid host {host:?}: {source}")] + InvalidHost { + /// The unparseable host string. + host: String, + /// Parse error. + source: std::net::AddrParseError, }, - /// IO error while reading body stream. - #[error("body read error: {0}")] - Io(#[from] std::io::Error), -} - -/// Transport-neutral request body. -/// -/// Abstracts over pre-buffered bodies (HTTP/1.1) and streamed bodies -/// (HTTP/2 DATA frames, future HTTP/3 streams). -/// -/// Rule: Body is always a stream interface, never a concrete type. -pub enum BodyStream { - /// No body (GET, HEAD, DELETE), or body already taken via `take_body()`. - Empty, - /// Fully buffered body (small POST/PUT, already read by transport). - Buffered(Bytes), - /// Streaming body (large uploads, chunked transfer, h2/h3 streams). - Stream(Pin> + Send>>), -} - -impl std::fmt::Debug for BodyStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Empty => f.write_str("BodyStream::Empty"), - Self::Buffered(b) => write!(f, "BodyStream::Buffered({} bytes)", b.len()), - Self::Stream(_) => f.write_str("BodyStream::Stream(...)"), - } - } -} - -impl BodyStream { - /// Read the full body into a single `Bytes`, respecting the size limit. - /// - /// For `Stream` variant, collects chunks up to the limit. - /// - /// # Errors - /// - /// Returns `BodyError::TooLarge` if the collected body exceeds `limit`. - pub async fn collect(self, limit: usize) -> Result { - match self { - Self::Empty => Ok(Bytes::new()), - Self::Buffered(b) => { - if b.len() > limit { - return Err(BodyError::TooLarge { limit }); - } - Ok(b) - } - Self::Stream(mut stream) => collect_stream(&mut stream, limit).await, - } - } - - /// True if there is no body content. - pub fn is_empty(&self) -> bool { - matches!(self, Self::Empty) - } -} - -/// Default initial capacity for streaming body collection. -/// -/// Sized to hold a typical small API request body without reallocation. -/// Capped independently of the body size limit to avoid over-allocating -/// for endpoints with large limits but small actual bodies. -const STREAM_COLLECT_INITIAL_CAPACITY: usize = 4096; - -/// Read the next chunk from the stream. -async fn next_chunk( - stream: &mut Pin> + Send>>, -) -> Result, BodyError> { - match std::future::poll_fn(|cx| Pin::as_mut(stream).poll_next(cx)).await { - Some(Ok(bytes)) => Ok(Some(bytes)), - Some(Err(e)) => Err(BodyError::Io(e)), - None => Ok(None), - } -} - -/// Collect a stream of bytes into a single `Bytes`, enforcing a size limit. -/// -/// Fast path: single-chunk bodies (common for small JSON payloads) are returned -/// directly without any copy or concatenation. Multi-chunk bodies use `BytesMut` -/// for contiguous concatenation, then `freeze()` for zero-copy conversion to `Bytes`. -async fn collect_stream( - stream: &mut Pin> + Send>>, - limit: usize, -) -> Result { - // Read the first chunk. - let first = match next_chunk(stream).await? { - Some(bytes) => { - if bytes.len() > limit { - return Err(BodyError::TooLarge { limit }); - } - bytes - } - None => return Ok(Bytes::new()), - }; - - // Fast path: if the stream is exhausted after one chunk, return it directly (zero-copy). - let Some(second) = next_chunk(stream).await? else { - return Ok(first); - }; - - // Multi-chunk: concatenate into BytesMut. - let mut buf = BytesMut::with_capacity(limit.min(STREAM_COLLECT_INITIAL_CAPACITY)); - buf.extend_from_slice(&first); - - if buf.len() + second.len() > limit { - return Err(BodyError::TooLarge { limit }); - } - buf.extend_from_slice(&second); - - loop { - match next_chunk(stream).await? { - Some(bytes) => { - if buf.len() + bytes.len() > limit { - return Err(BodyError::TooLarge { limit }); - } - buf.extend_from_slice(&bytes); - } - None => return Ok(buf.freeze()), - } - } -} - -/// Transport-neutral HTTP request. -/// -/// Constructed by the transport/protocol layer (hyper, future quinn). -/// Consumed by the application layer (routing, dispatch, ASGI adapter). -/// This is the architectural boundary between transport-specific and -/// transport-agnostic code. -/// -/// Body ownership: the body is taken once via `take_body()` (returns the -/// `BodyStream` and replaces it with `Empty`). After that, the request -/// can still be borrowed for scope construction, header access, etc. -pub struct InboundRequest { - /// HTTP method. - pub method: http::Method, - /// Request path (without query string). - pub path: String, - /// Raw query string bytes. - pub query_string: Bytes, - /// HTTP headers. - pub headers: HeaderMap, - /// Request body (private — use `take_body()`). - body: BodyStream, - /// HTTP protocol version. - pub protocol: ProtocolVersion, - /// Which transport carried this request. - pub transport: TransportKind, - /// Client socket address (if available). - pub client_addr: Option, - /// Server socket address. - pub server_addr: SocketAddr, - /// Path parameters extracted by the router (populated after routing). - pub path_params: Vec<(String, String)>, - /// Opaque extensions for middleware (trace ids, etc.). - pub extensions: http::Extensions, -} - -impl std::fmt::Debug for InboundRequest { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InboundRequest") - .field("method", &self.method) - .field("path", &self.path) - .field("protocol", &self.protocol) - .field("transport", &self.transport) - .finish_non_exhaustive() - } -} - -impl InboundRequest { - /// Construct a new `InboundRequest`. Called by `transport/convert.rs`. - #[expect( - clippy::too_many_arguments, - reason = "constructor mirrors HTTP request fields" - )] - pub fn new( - method: http::Method, - path: String, - query_string: Bytes, - headers: HeaderMap, - body: BodyStream, - protocol: ProtocolVersion, - transport: TransportKind, - client_addr: Option, - server_addr: SocketAddr, - path_params: Vec<(String, String)>, - extensions: http::Extensions, - ) -> Self { - Self { - method, - path, - query_string, - headers, - body, - protocol, - transport, - client_addr, - server_addr, - path_params, - extensions, - } - } - - /// Take the body out, replacing it with `BodyStream::Empty`. - /// - /// Call this once before reading the body. After this, the request - /// can still be borrowed for `build_http_scope`, header access, etc. - pub fn take_body(&mut self) -> BodyStream { - std::mem::replace(&mut self.body, BodyStream::Empty) - } - - /// Whether the request still has a body (not yet taken). - pub fn has_body(&self) -> bool { - !self.body.is_empty() - } -} - -/// Transport-neutral HTTP response. -/// -/// Constructed by dispatch/ASGI adapter. Consumed by transport layer -/// to write the response back over the wire. -pub struct OutboundResponse { - /// HTTP status code. - pub status: http::StatusCode, - /// Response headers. - pub headers: HeaderMap, - /// Response body. - pub body: ResponseBody, - /// Matched route template extracted from the ASGI scope (e.g. `/users/{user_id}`). - pub server_route: Option, -} - -impl std::fmt::Debug for OutboundResponse { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("OutboundResponse") - .field("status", &self.status) - .finish_non_exhaustive() - } -} - -/// Response body — either fixed or streaming. -pub enum ResponseBody { - /// Complete body, known length. - Fixed(Bytes), - /// Streaming body (SSE, chunked, large responses). - Stream(Pin> + Send>>), -} - -impl std::fmt::Debug for ResponseBody { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Fixed(b) => write!(f, "ResponseBody::Fixed({} bytes)", b.len()), - Self::Stream(_) => f.write_str("ResponseBody::Stream(...)"), - } - } -} - -impl http_body::Body for ResponseBody { - type Data = Bytes; - type Error = Box; - - fn poll_frame( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - let this = self.get_mut(); - match this { - Self::Fixed(bytes) => { - if bytes.is_empty() { - return Poll::Ready(None); - } - let data = std::mem::replace(bytes, Bytes::new()); - Poll::Ready(Some(Ok(Frame::data(data)))) - } - Self::Stream(stream) => match Pin::as_mut(stream).poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => Poll::Ready(Some(Ok(Frame::data(chunk)))), - Poll::Ready(Some(Err(e))) => { - let err: Box = Box::new(e); - Poll::Ready(Some(Err(err))) - } - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - }, - } - } - - fn is_end_stream(&self) -> bool { - match self { - Self::Fixed(bytes) => bytes.is_empty(), - Self::Stream(_) => false, - } - } - - fn size_hint(&self) -> http_body::SizeHint { - match self { - Self::Fixed(bytes) => http_body::SizeHint::with_exact(bytes.len() as u64), - Self::Stream(_) => http_body::SizeHint::default(), - } - } -} - -#[cfg(test)] -#[expect( - clippy::unwrap_used, - reason = "test code uses unwrap/assert for clarity" -)] -mod tests { - use super::*; - - #[test] - fn transport_kind_serde_roundtrip() { - for kind in [ - TransportKind::Tcp, - TransportKind::Unix, - TransportKind::InMemory, - ] { - let json = serde_json::to_string(&kind).ok(); - assert!(json.is_some(), "serialize {kind:?}"); - let back: TransportKind = - serde_json::from_str(json.as_deref().unwrap_or("")).unwrap_or(TransportKind::Tcp); - assert_eq!(kind, back); - } - } - - #[test] - fn protocol_version_asgi_string() { - assert_eq!(ProtocolVersion::Http10.as_asgi_version(), "1.0"); - assert_eq!(ProtocolVersion::Http11.as_asgi_version(), "1.1"); - assert_eq!(ProtocolVersion::H2.as_asgi_version(), "2"); - } - - #[test] - fn body_stream_empty_is_empty() { - assert!(BodyStream::Empty.is_empty()); - assert!(!BodyStream::Buffered(Bytes::from_static(b"x")).is_empty()); - } - - #[tokio::test] - async fn body_stream_buffered_collect() { - let body = BodyStream::Buffered(Bytes::from_static(b"hello")); - let result = body.collect(1024).await; - assert!(result.is_ok()); - assert_eq!(result.ok().as_deref(), Some(b"hello".as_slice())); - } - - #[tokio::test] - async fn body_stream_collect_exceeds_limit() { - let body = BodyStream::Buffered(Bytes::from_static(b"hello world")); - let result = body.collect(5).await; - assert!(matches!(result, Err(BodyError::TooLarge { limit: 5 }))); - } - - #[test] - fn inbound_request_construction() { - let req = InboundRequest::new( - http::Method::GET, - "/test".to_owned(), - Bytes::new(), - HeaderMap::new(), - BodyStream::Empty, - ProtocolVersion::Http11, - TransportKind::Tcp, - None, - SocketAddr::from(([127, 0, 0, 1], 8080)), - Vec::new(), - http::Extensions::new(), - ); - assert_eq!(req.method, http::Method::GET); - assert_eq!(req.path, "/test"); - assert!(!req.has_body()); - } - - #[test] - fn inbound_request_take_body() { - let mut req = InboundRequest::new( - http::Method::POST, - "/upload".to_owned(), - Bytes::new(), - HeaderMap::new(), - BodyStream::Buffered(Bytes::from_static(b"data")), - ProtocolVersion::Http11, - TransportKind::Tcp, - None, - SocketAddr::from(([127, 0, 0, 1], 8080)), - Vec::new(), - http::Extensions::new(), - ); - assert!(req.has_body()); - let body = req.take_body(); - assert!(!req.has_body()); - assert!(matches!(body, BodyStream::Buffered(_))); - } - - #[test] - fn outbound_response_construction() { - let resp = OutboundResponse { - status: http::StatusCode::OK, - headers: HeaderMap::new(), - body: ResponseBody::Fixed(Bytes::from_static(b"ok")), - server_route: None, - }; - assert_eq!(resp.status, http::StatusCode::OK); - } - - #[test] - fn response_body_fixed_vs_stream() { - let fixed = ResponseBody::Fixed(Bytes::from_static(b"hello")); - assert!(matches!(fixed, ResponseBody::Fixed(_))); - - let empty_stream = tokio_stream::empty::>(); - let stream_body = ResponseBody::Stream(Box::pin(empty_stream)); - assert!(matches!(stream_body, ResponseBody::Stream(_))); - } - - #[tokio::test] - async fn body_stream_empty_collect() { - let body = BodyStream::Empty; - let result = body.collect(0).await; - assert!(result.is_ok()); - assert!(result.unwrap().is_empty()); - } - - #[tokio::test] - async fn body_stream_stream_collect_success() { - let chunks = vec![Ok(Bytes::from("hello")), Ok(Bytes::from(" world"))]; - let stream = tokio_stream::iter(chunks); - let body = BodyStream::Stream(Box::pin(stream)); - let result = body.collect(1024).await; - assert!(result.is_ok()); - assert_eq!(result.unwrap().as_ref(), b"hello world"); - } - - #[tokio::test] - async fn body_stream_stream_collect_io_error() { - let chunks: Vec> = vec![ - Ok(Bytes::from("ok")), - Err(std::io::Error::other("stream failed")), - ]; - let stream = tokio_stream::iter(chunks); - let body = BodyStream::Stream(Box::pin(stream)); - let result = body.collect(1024).await; - assert!(matches!(result, Err(BodyError::Io(_)))); - } - - #[tokio::test] - async fn body_stream_stream_collect_over_limit() { - let chunks = vec![ - Ok(Bytes::from(vec![0u8; 600])), - Ok(Bytes::from(vec![0u8; 600])), - ]; - let stream = tokio_stream::iter(chunks); - let body = BodyStream::Stream(Box::pin(stream)); - let result = body.collect(1000).await; - assert!(matches!(result, Err(BodyError::TooLarge { limit: 1000 }))); - } - - #[test] - fn body_error_display_too_large() { - let err = BodyError::TooLarge { limit: 1024 }; - let msg = format!("{err}"); - assert!(msg.contains("1024")); - } - - #[test] - fn body_error_display_io() { - let err = BodyError::Io(std::io::Error::other("read fail")); - let msg = format!("{err}"); - assert!(msg.contains("read fail")); - } - - #[test] - fn body_stream_debug_all_variants() { - let empty_dbg = format!("{:?}", BodyStream::Empty); - assert!(empty_dbg.contains("Empty")); - - let buf_dbg = format!("{:?}", BodyStream::Buffered(Bytes::from("hi"))); - assert!(buf_dbg.contains("2 bytes")); - - let stream = tokio_stream::empty::>(); - let stream_dbg = format!("{:?}", BodyStream::Stream(Box::pin(stream))); - assert!(stream_dbg.contains("Stream")); - } - - #[test] - fn inbound_request_debug() { - let req = InboundRequest::new( - http::Method::POST, - "/api/test".to_owned(), - Bytes::new(), - HeaderMap::new(), - BodyStream::Empty, - ProtocolVersion::Http11, - TransportKind::Tcp, - None, - SocketAddr::from(([127, 0, 0, 1], 8080)), - Vec::new(), - http::Extensions::new(), - ); - let dbg = format!("{req:?}"); - assert!(dbg.contains("InboundRequest")); - assert!(dbg.contains("POST")); - } - - #[test] - fn outbound_response_debug() { - let resp = OutboundResponse { - status: http::StatusCode::OK, - headers: HeaderMap::new(), - body: ResponseBody::Fixed(Bytes::from("ok")), - server_route: None, - }; - let dbg = format!("{resp:?}"); - assert!(dbg.contains("OutboundResponse")); - assert!(dbg.contains("200")); - } - - #[test] - fn response_body_debug() { - let fixed = ResponseBody::Fixed(Bytes::from("hello")); - let dbg = format!("{fixed:?}"); - assert!(dbg.contains("Fixed")); - assert!(dbg.contains("5 bytes")); - - let stream = tokio_stream::empty::>(); - let stream_body = ResponseBody::Stream(Box::pin(stream)); - let dbg = format!("{stream_body:?}"); - assert!(dbg.contains("Stream")); - } - - // ── http_body::Body impl tests ────────────────────────────────────── - - #[tokio::test] - async fn response_body_fixed_yields_data_then_none() { - use http_body::Body; - let mut body = ResponseBody::Fixed(Bytes::from_static(b"hello")); - let frame = std::future::poll_fn(|cx| Pin::new(&mut body).poll_frame(cx)) - .await - .unwrap() - .unwrap(); - assert_eq!(frame.into_data().unwrap(), Bytes::from_static(b"hello")); - let end = std::future::poll_fn(|cx| Pin::new(&mut body).poll_frame(cx)).await; - assert!(end.is_none()); - } - - #[test] - fn response_body_empty_fixed_is_end_stream() { - use http_body::Body; - let body = ResponseBody::Fixed(Bytes::new()); - assert!(body.is_end_stream()); - - let non_empty = ResponseBody::Fixed(Bytes::from_static(b"x")); - assert!(!non_empty.is_end_stream()); - } - - #[test] - fn response_body_fixed_size_hint_exact() { - use http_body::Body; - let body = ResponseBody::Fixed(Bytes::from_static(b"hello")); - let hint = body.size_hint(); - assert_eq!(hint.lower(), 5); - assert_eq!(hint.upper(), Some(5)); - } - - #[tokio::test] - async fn response_body_stream_yields_chunks() { - use http_body::Body; - let chunks = vec![Ok(Bytes::from("hel")), Ok(Bytes::from("lo"))]; - let stream = tokio_stream::iter(chunks); - let mut body = ResponseBody::Stream(Box::pin(stream)); - - let f1 = std::future::poll_fn(|cx| Pin::new(&mut body).poll_frame(cx)) - .await - .unwrap() - .unwrap(); - assert_eq!(f1.into_data().unwrap(), Bytes::from("hel")); - - let f2 = std::future::poll_fn(|cx| Pin::new(&mut body).poll_frame(cx)) - .await - .unwrap() - .unwrap(); - assert_eq!(f2.into_data().unwrap(), Bytes::from("lo")); - - let end = std::future::poll_fn(|cx| Pin::new(&mut body).poll_frame(cx)).await; - assert!(end.is_none()); - } } diff --git a/pyproject.toml b/pyproject.toml index 4bcd9f46..08525626 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,11 @@ description = "apx is the toolkit for building Databricks Apps" readme = { file = "README.md", content-type = "text/markdown" } authors = [{ name = "renardeinside", email = "polarpersonal@gmail.com" }] requires-python = ">=3.11" -dependencies = ["pydantic>=2.0", "uvloop>=0.21.0; sys_platform != 'win32'"] +dependencies = [ + "orjson>=3.11.7", + "pydantic>=2.0", + "uvloop>=0.21.0; sys_platform != 'win32'", +] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", diff --git a/src/apx/_core.pyi b/src/apx/_core.pyi index d79ba8a0..02edbef2 100644 --- a/src/apx/_core.pyi +++ b/src/apx/_core.pyi @@ -25,20 +25,46 @@ class Forbidden(Exception): ... -# ── Dispatch primitives ────────────────────────────────────────────────── +# ── HTTP protocol types ─────────────────────────────────────────────────── -class RequestQueue: - """Inbound request queue drained by the asyncio dispatch loop.""" - def try_recv(self) -> tuple[Any, Any, Any] | None: ... +class ProtocolFactory: + """Creates ``RustProtocol`` instances for ``loop.create_server``.""" + def __call__(self) -> RustProtocol: ... -class SlotReceive: - """ASGI receive() callable for the 3-thread dispatch path.""" +class RustProtocol: + """asyncio Protocol for HTTP/1.1 connections.""" + ... + +class HttpReceive: + """ASGI ``receive`` callable for HTTP requests.""" def __call__(self) -> Any: ... -class SlotSend: - """ASGI send() callable for the 3-thread dispatch path.""" +class RustRouter: + """High-performance HTTP path router.""" + def insert(self, path: str, route_id: int) -> None: ... + def match_route(self, path: str) -> tuple[int, dict[str, str]] | None: ... + +class RustResponseWriter: + """ASGI ``send`` callable that writes HTTP responses.""" + def __call__(self, event: dict[str, Any]) -> Any: ... + +# ── Lifespan types ──────────────────────────────────────────────────────── + +class LifespanReceive: + """ASGI ``receive`` callable for lifespan protocol.""" + def __init__(self, shutdown_event: Any) -> None: ... + def __call__(self) -> Any: ... + +class LifespanSend: + """ASGI ``send`` callable for lifespan protocol.""" + def __init__( + self, + startup_event: Any, + startup_result: Any, + shutdown_done_event: Any, + shutdown_result: Any, + ) -> None: ... def __call__(self, event: dict[str, Any]) -> Any: ... - def send_error(self, traceback: str) -> None: ... # ── Scheduler primitives ───────────────────────────────────────────────── diff --git a/src/apx/_dispatch.py b/src/apx/_dispatch.py deleted file mode 100644 index f2f396ca..00000000 --- a/src/apx/_dispatch.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Zero-GIL dispatch loop with inline coroutine driving. - -Installs an fd-based wakeup on the asyncio event loop. Drains -requests from the Rust crossbeam channel and drives each ASGI -coroutine inline. Simple handlers complete in ~21us with zero -event loop scheduling. Handlers that suspend on real I/O fall -back to callback-based continuation. - -Called once from Rust during reactor init via -``py.import(c"apx._dispatch")?.call_method1(c"install_dispatch", ...)``. -""" - -from __future__ import annotations - -import asyncio -import os -import traceback -from collections.abc import Coroutine -from typing import Any, Callable - -from apx._continuation import Continuation -from apx._core import RequestQueue -from apx._scheduler import ( - CallSoonCapture, - Completed, - Failed, - SchedulerTask, - Suspended, - drive_inline, -) - -MAX_DRAIN_BATCH: int = 32 - -def install_dispatch( - loop: asyncio.AbstractEventLoop, - queue: RequestQueue, - app: Callable[..., Coroutine[Any, Any, None]], - wakeup_fd: int | None = None, -) -> None: - """Install the inline dispatch driver on the asyncio event loop.""" - capture = CallSoonCapture(loop) - async def _guarded( - scope: dict[str, Any], - receive: Any, - send: Any, - ) -> None: - try: - await app(scope, receive, send) - except Exception as exc: - tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) - send.send_error(tb) - - def _dispatch_inline( - scope: dict[str, Any], - receive: Any, - send: Any, - ) -> None: - """Drive one request inline. Falls back on suspension.""" - coro = _guarded(scope, receive, send) - try: - task = SchedulerTask(loop=loop) - - capture.enter() - result = drive_inline(coro, task, loop, capture) - capture.leave() - except BaseException: - coro.close() - raise - - if isinstance(result, Completed): - return - elif isinstance(result, Failed): - tb = "".join( - traceback.format_exception( - type(result.exc), result.exc, result.exc.__traceback__ - ) - ) - send.send_error(tb) - return - elif isinstance(result, Suspended): - Continuation(coro, result.yielded, loop, task, capture) - - def _drain_queue() -> None: - for _ in range(MAX_DRAIN_BATCH): - result: tuple[Any, Any, Any] | None = queue.try_recv() - if result is None: - return - scope, receive, send = result - _dispatch_inline(scope, receive, send) - loop.call_soon(_drain_queue) - - if wakeup_fd is not None: - - def _on_readable() -> None: - try: - os.read(wakeup_fd, 4096) - except BlockingIOError: - pass - _drain_queue() - - loop.call_soon_threadsafe(loop.add_reader, wakeup_fd, _on_readable) - else: - install_dispatch._drain_queue = _drain_queue # type: ignore[attr-defined, ty:unresolved-attribute] diff --git a/src/apx/_scheduler.py b/src/apx/_scheduler.py index 65db90ac..53e99ef5 100644 --- a/src/apx/_scheduler.py +++ b/src/apx/_scheduler.py @@ -1,4 +1,4 @@ -"""Inline coroutine driver for the 3-thread dispatch architecture. +"""Inline coroutine driver for ASGI request handling. Drives ASGI coroutines to completion within a single ``_run_once`` callback, eliminating ``create_task`` scheduling overhead. Falls back diff --git a/src/apx/_server.py b/src/apx/_server.py new file mode 100644 index 00000000..1284bee0 --- /dev/null +++ b/src/apx/_server.py @@ -0,0 +1,194 @@ +"""ASGI server using Rust-accelerated protocol and inline driving. + +Uses asyncio ``loop.create_server()`` with Rust HTTP parsing, scope +building, and response writing. +""" + +from __future__ import annotations + +import asyncio +import logging +import traceback +from collections.abc import Callable, Coroutine +from typing import Any + +from apx._continuation import Continuation +from apx._core import LifespanReceive, LifespanSend +from apx._scheduler import ( + CallSoonCapture, + Completed, + Failed, + SchedulerTask, + Suspended, + drive_inline, +) + +logger = logging.getLogger("apx.server") + +LIFESPAN_TIMEOUT = 30.0 + + +async def _guarded( + app: Callable[..., Coroutine[Any, Any, None]], + scope: dict[str, Any], + receive: Any, + send: Any, +) -> None: + """Wrap the ASGI app call with error handling.""" + try: + await app(scope, receive, send) + except Exception: + tb = traceback.format_exc() + logger.error("ASGI application error:\n%s", tb) + try: + send.send_error(tb) + except Exception: + logger.error("Failed to send error response:\n%s", traceback.format_exc()) + + +def _build_on_request( + app: Callable[..., Coroutine[Any, Any, None]], + loop: asyncio.AbstractEventLoop, + capture: CallSoonCapture, +) -> Callable[..., None]: + """Build the on_request dispatch callback with inline driving.""" + + def on_request( + scope: dict[str, Any], + receive: Any, + send: Any, + ) -> None: + coro = _guarded(app, scope, receive, send) + task = SchedulerTask(loop=loop) + capture.enter() + try: + result = drive_inline(coro, task, loop, capture) + except BaseException: + coro.close() + capture.leave() + raise + capture.leave() + if isinstance(result, Completed): + return + if isinstance(result, Failed): + try: + send.send_error(traceback.format_exc()) + except Exception: + logger.error("Failed to send error: %s", traceback.format_exc()) + return + if isinstance(result, Suspended): + Continuation(coro, result.yielded, loop, task, capture) + + return on_request + + +async def _run_lifespan( + app: Callable[..., Coroutine[Any, Any, None]], + shutdown_event: asyncio.Event, +) -> tuple[asyncio.Task[None], str]: + """Run ASGI lifespan startup; return (task, result_str). + + result_str is "complete", "failed:", or "unsupported". + """ + startup_event = asyncio.Event() + startup_result: list[str | None] = [None] + shutdown_done_event = asyncio.Event() + shutdown_result: list[str | None] = [None] + + receive = LifespanReceive(shutdown_event) + send = LifespanSend( + startup_event, startup_result, shutdown_done_event, shutdown_result + ) + + scope: dict[str, Any] = { + "type": "lifespan", + "asgi": {"version": "3.0", "spec_version": "2.4"}, + "state": {}, + } + + lifespan_task = asyncio.create_task( + _guarded_lifespan(app, scope, receive, send, startup_event, startup_result) + ) + + try: + await asyncio.wait_for(startup_event.wait(), timeout=LIFESPAN_TIMEOUT) + except asyncio.TimeoutError: + lifespan_task.cancel() + raise RuntimeError("ASGI lifespan startup timed out") from None + + result = startup_result[0] or "unsupported" + return lifespan_task, result + + +async def _guarded_lifespan( + app: Callable[..., Coroutine[Any, Any, None]], + scope: dict[str, Any], + receive: Any, + send: Any, + startup_event: asyncio.Event, + startup_result: list[str | None], +) -> None: + """Run the ASGI lifespan protocol with error handling.""" + try: + await app(scope, receive, send) + except Exception: + tb = traceback.format_exc() + logger.warning("ASGI lifespan not supported or errored:\n%s", tb) + if startup_result[0] is None: + startup_result[0] = "unsupported" + startup_event.set() + + +async def serve( + host: str, + port: int, + app: Callable[..., Coroutine[Any, Any, None]], + protocol_factory: Any, + *, + shutdown_event: asyncio.Event | None = None, +) -> None: + """Run the ASGI server. + + Parameters + ---------- + host: + Bind address. + port: + Bind port. + app: + ASGI application callable. + protocol_factory: + Rust ``ProtocolFactory`` (callable returning ``RustProtocol``). + shutdown_event: + When set, the server will initiate graceful shutdown. + """ + if shutdown_event is None: + shutdown_event = asyncio.Event() + + lifespan_task, lifespan_result = await _run_lifespan(app, shutdown_event) + + if lifespan_result.startswith("failed"): + msg = lifespan_result.removeprefix("failed:") + lifespan_task.cancel() + raise RuntimeError(f"ASGI lifespan startup failed: {msg}") + + logger.debug("lifespan startup: %s", lifespan_result) + + server = await asyncio.get_running_loop().create_server( + protocol_factory, + host, + port, + reuse_port=True, + ) + + async with server: + await shutdown_event.wait() + server.close() + await server.wait_closed() + + shutdown_event.set() + try: + await asyncio.wait_for(lifespan_task, timeout=LIFESPAN_TIMEOUT) + except asyncio.TimeoutError: + logger.warning("lifespan shutdown timed out, cancelling") + lifespan_task.cancel() diff --git a/src/apx/telemetry.py b/src/apx/telemetry.py index f5131c4b..3bf0dd37 100644 --- a/src/apx/telemetry.py +++ b/src/apx/telemetry.py @@ -691,25 +691,25 @@ class HttpMetrics(BaseModel): class ApxMetrics(BaseModel): - """APX framework dispatch pipeline metric toggles. + """APX per-worker request pipeline metric toggles. - Collected per-worker. Each histogram records latency for the - dispatch phases within a single worker process. Use - ``apx.worker.id`` to drill down; aggregate across workers for - server-wide distributions. + Each histogram records latency for a phase of request handling + within a single worker process. Use ``apx.worker.id`` to drill + down; aggregate across workers for server-wide distributions. If APX_PERF environment variable is not set, none of these metrics are collected. """ - dispatch_body_collect: bool = True - dispatch_crossbeam_send: bool = True - dispatch_response_wait: bool = True - dispatch_total: bool = True - asgi_receive_build: bool = True - asgi_send_parse: bool = True - dispatch_pickup_delay: bool = True - dispatch_materialize: bool = True - dispatch_queue_depth: bool = True + parse: bool = True + scope_build: bool = True + receive_build: bool = True + send_parse: bool = True + response_build: bool = True + response_write: bool = True + handler_wait: bool = True + request_total: bool = True + active_requests: bool = True + connections: bool = True class CaptureHeaders(BaseModel): diff --git a/tests/telemetry/test_dispatch_metrics.py b/tests/telemetry/test_dispatch_metrics.py index cf2d7842..37ef9377 100644 --- a/tests/telemetry/test_dispatch_metrics.py +++ b/tests/telemetry/test_dispatch_metrics.py @@ -1,9 +1,9 @@ -"""Verify all 9 APX dispatch pipeline metrics are collected when APX_PERF=1. +"""Verify APX request pipeline metrics are collected when APX_PERF=1. The telemetry_container fixture sets ``APX_PERF=1``, which enables all ``ApxMetrics`` toggles via ``_apx_perf_enabled()``. After sending HTTP -requests that exercise the full dispatch pipeline (receive + send), all -9 histogram metrics must appear in the OTEL collector output. +requests that exercise the full request pipeline, histogram and gauge +metrics must appear in the OTEL collector output. """ from __future__ import annotations @@ -19,22 +19,28 @@ wait_for_collector_data, ) -APX_DISPATCH_METRICS = { - "apx.dispatch.body_collect.duration", - "apx.dispatch.crossbeam_send.duration", - "apx.dispatch.response_wait.duration", - "apx.dispatch.total.duration", - "apx.asgi.receive_build.duration", - "apx.asgi.send_parse.duration", - "apx.dispatch.pickup_delay.duration", - "apx.dispatch.materialize.duration", - "apx.dispatch.queue_depth", +APX_HISTOGRAM_METRICS = { + "apx.parse", + "apx.scope_build", + "apx.receive_build", + "apx.send_parse", + "apx.response_build", + "apx.response_write", + "apx.handler_wait", + "apx.request_total", } +APX_GAUGE_METRICS = { + "apx.active_requests", + "apx.connections", +} + +APX_ALL_METRICS = APX_HISTOGRAM_METRICS | APX_GAUGE_METRICS + @pytest.mark.integration class TestDispatchMetrics: - """All APX dispatch histograms must appear when APX_PERF is enabled.""" + """APX request pipeline metrics must appear when APX_PERF is enabled.""" @pytest.fixture(autouse=True, scope="class") def _setup( @@ -49,37 +55,36 @@ def _setup( wait_for_collector_data(otel_collector) def test_all_dispatch_metrics_present(self, otel_collector: OtelCollector) -> None: - """Every APX dispatch histogram must have at least one data point.""" + """Every APX metric must have at least one data point.""" collected_names = {m.name for _, m in flat_metrics_with_scope(otel_collector)} - missing = APX_DISPATCH_METRICS - collected_names + missing = APX_ALL_METRICS - collected_names assert not missing, ( - f"Missing APX dispatch metrics: {sorted(missing)}. " + f"Missing APX metrics: {sorted(missing)}. " f"Collected metric names: {sorted(collected_names)}" ) - def test_dispatch_metrics_are_histograms( + def test_histogram_metrics_are_histograms( self, otel_collector: OtelCollector ) -> None: - """APX dispatch metrics must be exported as histograms.""" + """APX histogram metrics must be exported as histograms.""" for _, m in flat_metrics_with_scope(otel_collector): - if m.name in APX_DISPATCH_METRICS: + if m.name in APX_HISTOGRAM_METRICS: assert m.histogram is not None, ( f"{m.name} should be a histogram, got sum={m.sum} gauge={m.gauge}" ) - def test_dispatch_metrics_unit_is_microseconds( + def test_histogram_metrics_unit_is_microseconds( self, otel_collector: OtelCollector ) -> None: - """APX dispatch duration metrics must report in microseconds.""" - duration_metrics = {n for n in APX_DISPATCH_METRICS if n.endswith(".duration")} + """APX histogram metrics must report in microseconds.""" for _, m in flat_metrics_with_scope(otel_collector): - if m.name in duration_metrics: + if m.name in APX_HISTOGRAM_METRICS: assert m.unit == "us", f"{m.name} unit should be 'us', got {m.unit!r}" - def test_queue_depth_unit_is_dimensionless( + def test_gauge_metrics_unit_is_dimensionless( self, otel_collector: OtelCollector ) -> None: - """queue_depth is a count, not a duration — unit must be '1'.""" + """Gauge metrics (active_requests, connections) use dimensionless unit.""" for _, m in flat_metrics_with_scope(otel_collector): - if m.name == "apx.dispatch.queue_depth": + if m.name in APX_GAUGE_METRICS: assert m.unit == "1", f"{m.name} unit should be '1', got {m.unit!r}" diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index 4f3fb5a1..bd1f806c 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -113,21 +113,22 @@ def test_disable_both(self) -> None: class TestApxMetricsDefaults: def test_all_enabled_by_default(self) -> None: m = ApxMetrics() - assert m.dispatch_body_collect is True - assert m.dispatch_crossbeam_send is True - assert m.dispatch_response_wait is True - assert m.dispatch_total is True - assert m.asgi_receive_build is True - assert m.asgi_send_parse is True - assert m.dispatch_pickup_delay is True - assert m.dispatch_materialize is True - assert m.dispatch_queue_depth is True + assert m.parse is True + assert m.scope_build is True + assert m.receive_build is True + assert m.send_parse is True + assert m.response_build is True + assert m.response_write is True + assert m.handler_wait is True + assert m.request_total is True + assert m.active_requests is True + assert m.connections is True def test_disable_selective(self) -> None: - m = ApxMetrics(dispatch_total=False, asgi_send_parse=False) - assert m.dispatch_total is False - assert m.asgi_send_parse is False - assert m.dispatch_body_collect is True + m = ApxMetrics(request_total=False, send_parse=False) + assert m.request_total is False + assert m.send_parse is False + assert m.parse is True # ── Instrumentation models ──────────────────────────────────────────────── @@ -169,7 +170,7 @@ def test_apx_defaults(self) -> None: a = ApxInstrumentation() assert a.type == "apx" assert a.enabled is True - assert a.metrics.dispatch_total is True + assert a.metrics.request_total is True def test_discriminated_union_from_dict(self) -> None: """Configuration parses typed dicts via the discriminated union.""" @@ -178,7 +179,7 @@ def test_discriminated_union_from_dict(self) -> None: HttpInstrumentation(enabled=False), SystemInstrumentation(metrics=SystemMetrics(paging=True)), ProcessInstrumentation(metrics=ProcessMetrics(threads=True)), - ApxInstrumentation(metrics=ApxMetrics(dispatch_total=True)), + ApxInstrumentation(metrics=ApxMetrics(request_total=True)), ] ) types = [i.type for i in config.instrumentations] @@ -199,7 +200,7 @@ def test_discriminated_union_from_dict(self) -> None: apx = config.instrumentations[3] assert isinstance(apx, ApxInstrumentation) - assert apx.metrics.dispatch_total is True + assert apx.metrics.request_total is True # ── APX_PERF conditional defaults ───────────────────────────────────────── @@ -226,7 +227,7 @@ def test_apx_added_via_user_configure(self) -> None: configure( Configuration( instrumentations=[ - ApxInstrumentation(metrics=ApxMetrics(dispatch_total=True)) + ApxInstrumentation(metrics=ApxMetrics(request_total=True)) ] ) ) @@ -234,7 +235,7 @@ def test_apx_added_via_user_configure(self) -> None: types = [i["type"] for i in config["instrumentations"]] assert "apx" in types apx = next(i for i in config["instrumentations"] if i["type"] == "apx") - assert apx["metrics"]["dispatch_total"] is True + assert apx["metrics"]["request_total"] is True def test_apx_in_defaults_with_env(self) -> None: """With APX_PERF=1, default config includes 'apx' instrumentation.""" @@ -298,7 +299,7 @@ def test_override_preserves_unmentioned_defaults(self) -> None: configure( Configuration( instrumentations=[ - ApxInstrumentation(metrics=ApxMetrics(dispatch_total=True)) + ApxInstrumentation(metrics=ApxMetrics(request_total=True)) ] ) ) @@ -378,15 +379,16 @@ def test_full_serialization_roundtrip(self) -> None: } EXPECTED_APX_METRICS = { - "apx.dispatch.body_collect.duration", - "apx.dispatch.crossbeam_send.duration", - "apx.dispatch.response_wait.duration", - "apx.dispatch.total.duration", - "apx.asgi.receive_build.duration", - "apx.asgi.send_parse.duration", - "apx.dispatch.pickup_delay.duration", - "apx.dispatch.materialize.duration", - "apx.dispatch.queue_depth", + "apx.parse", + "apx.scope_build", + "apx.receive_build", + "apx.send_parse", + "apx.response_build", + "apx.response_write", + "apx.handler_wait", + "apx.request_total", + "apx.active_requests", + "apx.connections", } @@ -397,7 +399,7 @@ def test_returns_list(self) -> None: def test_count(self) -> None: catalog = metric_catalog() - assert len(catalog) == 19 + assert len(catalog) == 20 def test_entry_type(self) -> None: catalog = metric_catalog() @@ -510,15 +512,16 @@ def test_completeness_against_all_known_metrics(self) -> None: "process.thread.count": "1", "http.server.request.duration": "s", "http.server.active_requests": "1", - "apx.dispatch.body_collect.duration": "us", - "apx.dispatch.crossbeam_send.duration": "us", - "apx.dispatch.response_wait.duration": "us", - "apx.dispatch.total.duration": "us", - "apx.asgi.receive_build.duration": "us", - "apx.asgi.send_parse.duration": "us", - "apx.dispatch.pickup_delay.duration": "us", - "apx.dispatch.materialize.duration": "us", - "apx.dispatch.queue_depth": "1", + "apx.parse": "us", + "apx.scope_build": "us", + "apx.receive_build": "us", + "apx.send_parse": "us", + "apx.response_build": "us", + "apx.response_write": "us", + "apx.handler_wait": "us", + "apx.request_total": "us", + "apx.active_requests": "1", + "apx.connections": "1", } EXPECTED_DESCRIPTIONS: dict[str, str] = { @@ -532,15 +535,16 @@ def test_completeness_against_all_known_metrics(self) -> None: "process.thread.count": "Number of threads in the process", "http.server.request.duration": "Duration of HTTP server requests", "http.server.active_requests": "Number of in-flight HTTP server requests", - "apx.dispatch.body_collect.duration": "Time to collect the request body from the network stream", - "apx.dispatch.crossbeam_send.duration": "Time to push the request slot to the crossbeam channel and signal wakeup", - "apx.dispatch.response_wait.duration": "Time waiting for the Python handler to produce a response", - "apx.dispatch.total.duration": "Total dispatch duration from body collect start to response ready", - "apx.asgi.receive_build.duration": "Time to build the ASGI receive dict for the Python handler", - "apx.asgi.send_parse.duration": "Time to parse the ASGI send event dict from the Python handler", - "apx.dispatch.pickup_delay.duration": "Time from slot creation to asyncio thread pickup", - "apx.dispatch.materialize.duration": "Time to build ASGI scope and receive/send callables", - "apx.dispatch.queue_depth": "Pending request slots in the crossbeam channel at drain time", + "apx.parse": "HTTP request parsing time", + "apx.scope_build": "ASGI scope dict construction time", + "apx.receive_build": "ASGI receive dict construction time", + "apx.send_parse": "ASGI send event parsing time", + "apx.response_build": "HTTP response header construction time", + "apx.response_write": "Transport write time", + "apx.handler_wait": "Handler execution time", + "apx.request_total": "Total request processing time", + "apx.active_requests": "In-flight requests on this worker", + "apx.connections": "Active TCP connections on this worker", } diff --git a/uv.lock b/uv.lock index 7c843fad..99228028 100644 --- a/uv.lock +++ b/uv.lock @@ -33,6 +33,7 @@ wheels = [ name = "apx" source = { editable = "." } dependencies = [ + { name = "orjson" }, { name = "pydantic" }, { name = "uvloop", marker = "sys_platform != 'win32'" }, ] @@ -85,6 +86,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "orjson", specifier = ">=3.11.7" }, { name = "pydantic", specifier = ">=2.0" }, { name = "uvloop", marker = "sys_platform != 'win32'", specifier = ">=0.21.0" }, ] @@ -794,6 +796,74 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/b2/189b2577dde745b15625b3214302605b1353436219d42b7912e77fa8dc24/opentelemetry_proto-1.40.0-py3-none-any.whl", hash = "sha256:266c4385d88923a23d63e353e9761af0f47a6ed0d486979777fe4de59dc9b25f", size = 72073, upload-time = "2026-03-04T14:17:16.673Z" }, ] +[[package]] +name = "orjson" +version = "3.11.7" +source = { registry = "https://pypi-proxy.dev.databricks.com/simple/" } +sdist = { url = "https://files.pythonhosted.org/packages/53/45/b268004f745ede84e5798b48ee12b05129d19235d0e15267aa57dcdb400b/orjson-3.11.7.tar.gz", hash = "sha256:9b1a67243945819ce55d24a30b59d6a168e86220452d2c96f4d1f093e71c0c49", size = 6144992, upload-time = "2026-02-02T15:38:49.29Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/37/02/da6cb01fc6087048d7f61522c327edf4250f1683a58a839fdcc435746dd5/orjson-3.11.7-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9487abc2c2086e7c8eb9a211d2ce8855bae0e92586279d0d27b341d5ad76c85c", size = 228664, upload-time = "2026-02-02T15:37:25.542Z" }, + { url = "https://files.pythonhosted.org/packages/c1/c2/5885e7a5881dba9a9af51bc564e8967225a642b3e03d089289a35054e749/orjson-3.11.7-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:79cacb0b52f6004caf92405a7e1f11e6e2de8bdf9019e4f76b44ba045125cd6b", size = 125344, upload-time = "2026-02-02T15:37:26.92Z" }, + { url = "https://files.pythonhosted.org/packages/a4/1d/4e7688de0a92d1caf600dfd5fb70b4c5bfff51dfa61ac555072ef2d0d32a/orjson-3.11.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2e85fe4698b6a56d5e2ebf7ae87544d668eb6bde1ad1226c13f44663f20ec9e", size = 128404, upload-time = "2026-02-02T15:37:28.108Z" }, + { url = "https://files.pythonhosted.org/packages/2f/b2/ec04b74ae03a125db7bd69cffd014b227b7f341e3261bf75b5eb88a1aa92/orjson-3.11.7-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8d14b71c0b12963fe8a62aac87119f1afdf4cb88a400f61ca5ae581449efcb5", size = 123677, upload-time = "2026-02-02T15:37:30.287Z" }, + { url = "https://files.pythonhosted.org/packages/4c/69/f95bdf960605f08f827f6e3291fe243d8aa9c5c9ff017a8d7232209184c3/orjson-3.11.7-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:91c81ef070c8f3220054115e1ef468b1c9ce8497b4e526cb9f68ab4dc0a7ac62", size = 128950, upload-time = "2026-02-02T15:37:31.595Z" }, + { url = "https://files.pythonhosted.org/packages/a4/1b/de59c57bae1d148ef298852abd31909ac3089cff370dfd4cd84cc99cbc42/orjson-3.11.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:411ebaf34d735e25e358a6d9e7978954a9c9d58cfb47bc6683cdc3964cd2f910", size = 141756, upload-time = "2026-02-02T15:37:32.985Z" }, + { url = "https://files.pythonhosted.org/packages/ee/9e/9decc59f4499f695f65c650f6cfa6cd4c37a3fbe8fa235a0a3614cb54386/orjson-3.11.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a16bcd08ab0bcdfc7e8801d9c4a9cc17e58418e4d48ddc6ded4e9e4b1a94062b", size = 130812, upload-time = "2026-02-02T15:37:34.204Z" }, + { url = "https://files.pythonhosted.org/packages/28/e6/59f932bcabd1eac44e334fe8e3281a92eacfcb450586e1f4bde0423728d8/orjson-3.11.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c0b51672e466fd7e56230ffbae7f1639e18d0ce023351fb75da21b71bc2c960", size = 133444, upload-time = "2026-02-02T15:37:35.446Z" }, + { url = "https://files.pythonhosted.org/packages/f1/36/b0f05c0eaa7ca30bc965e37e6a2956b0d67adb87a9872942d3568da846ae/orjson-3.11.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:136dcd6a2e796dfd9ffca9fc027d778567b0b7c9968d092842d3c323cef88aa8", size = 138609, upload-time = "2026-02-02T15:37:36.657Z" }, + { url = "https://files.pythonhosted.org/packages/b8/03/58ec7d302b8d86944c60c7b4b82975d5161fcce4c9bc8c6cb1d6741b6115/orjson-3.11.7-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:7ba61079379b0ae29e117db13bda5f28d939766e410d321ec1624afc6a0b0504", size = 408918, upload-time = "2026-02-02T15:37:38.076Z" }, + { url = "https://files.pythonhosted.org/packages/06/3a/868d65ef9a8b99be723bd510de491349618abd9f62c826cf206d962db295/orjson-3.11.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0527a4510c300e3b406591b0ba69b5dc50031895b0a93743526a3fc45f59d26e", size = 143998, upload-time = "2026-02-02T15:37:39.706Z" }, + { url = "https://files.pythonhosted.org/packages/5b/c7/1e18e1c83afe3349f4f6dc9e14910f0ae5f82eac756d1412ea4018938535/orjson-3.11.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a709e881723c9b18acddcfb8ba357322491ad553e277cf467e1e7e20e2d90561", size = 134802, upload-time = "2026-02-02T15:37:41.002Z" }, + { url = "https://files.pythonhosted.org/packages/d4/0b/ccb7ee1a65b37e8eeb8b267dc953561d72370e85185e459616d4345bab34/orjson-3.11.7-cp311-cp311-win32.whl", hash = "sha256:c43b8b5bab288b6b90dac410cca7e986a4fa747a2e8f94615aea407da706980d", size = 127828, upload-time = "2026-02-02T15:37:42.241Z" }, + { url = "https://files.pythonhosted.org/packages/af/9e/55c776dffda3f381e0f07d010a4f5f3902bf48eaba1bb7684d301acd4924/orjson-3.11.7-cp311-cp311-win_amd64.whl", hash = "sha256:6543001328aa857187f905308a028935864aefe9968af3848401b6fe80dbb471", size = 124941, upload-time = "2026-02-02T15:37:43.444Z" }, + { url = "https://files.pythonhosted.org/packages/aa/8e/424a620fa7d263b880162505fb107ef5e0afaa765b5b06a88312ac291560/orjson-3.11.7-cp311-cp311-win_arm64.whl", hash = "sha256:1ee5cc7160a821dfe14f130bc8e63e7611051f964b463d9e2a3a573204446a4d", size = 126245, upload-time = "2026-02-02T15:37:45.18Z" }, + { url = "https://files.pythonhosted.org/packages/80/bf/76f4f1665f6983385938f0e2a5d7efa12a58171b8456c252f3bae8a4cf75/orjson-3.11.7-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:bd03ea7606833655048dab1a00734a2875e3e86c276e1d772b2a02556f0d895f", size = 228545, upload-time = "2026-02-02T15:37:46.376Z" }, + { url = "https://files.pythonhosted.org/packages/79/53/6c72c002cb13b5a978a068add59b25a8bdf2800ac1c9c8ecdb26d6d97064/orjson-3.11.7-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:89e440ebc74ce8ab5c7bc4ce6757b4a6b1041becb127df818f6997b5c71aa60b", size = 125224, upload-time = "2026-02-02T15:37:47.697Z" }, + { url = "https://files.pythonhosted.org/packages/2c/83/10e48852865e5dd151bdfe652c06f7da484578ed02c5fca938e3632cb0b8/orjson-3.11.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ede977b5fe5ac91b1dffc0a517ca4542d2ec8a6a4ff7b2652d94f640796342a", size = 128154, upload-time = "2026-02-02T15:37:48.954Z" }, + { url = "https://files.pythonhosted.org/packages/6e/52/a66e22a2b9abaa374b4a081d410edab6d1e30024707b87eab7c734afe28d/orjson-3.11.7-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b7b1dae39230a393df353827c855a5f176271c23434cfd2db74e0e424e693e10", size = 123548, upload-time = "2026-02-02T15:37:50.187Z" }, + { url = "https://files.pythonhosted.org/packages/de/38/605d371417021359f4910c496f764c48ceb8997605f8c25bf1dfe58c0ebe/orjson-3.11.7-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed46f17096e28fb28d2975834836a639af7278aa87c84f68ab08fbe5b8bd75fa", size = 129000, upload-time = "2026-02-02T15:37:51.426Z" }, + { url = "https://files.pythonhosted.org/packages/44/98/af32e842b0ffd2335c89714d48ca4e3917b42f5d6ee5537832e069a4b3ac/orjson-3.11.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3726be79e36e526e3d9c1aceaadbfb4a04ee80a72ab47b3f3c17fefb9812e7b8", size = 141686, upload-time = "2026-02-02T15:37:52.607Z" }, + { url = "https://files.pythonhosted.org/packages/96/0b/fc793858dfa54be6feee940c1463370ece34b3c39c1ca0aa3845f5ba9892/orjson-3.11.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0724e265bc548af1dedebd9cb3d24b4e1c1e685a343be43e87ba922a5c5fff2f", size = 130812, upload-time = "2026-02-02T15:37:53.944Z" }, + { url = "https://files.pythonhosted.org/packages/dc/91/98a52415059db3f374757d0b7f0f16e3b5cd5976c90d1c2b56acaea039e6/orjson-3.11.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7745312efa9e11c17fbd3cb3097262d079da26930ae9ae7ba28fb738367cbad", size = 133440, upload-time = "2026-02-02T15:37:55.615Z" }, + { url = "https://files.pythonhosted.org/packages/dc/b6/cb540117bda61791f46381f8c26c8f93e802892830a6055748d3bb1925ab/orjson-3.11.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f904c24bdeabd4298f7a977ef14ca2a022ca921ed670b92ecd16ab6f3d01f867", size = 138386, upload-time = "2026-02-02T15:37:56.814Z" }, + { url = "https://files.pythonhosted.org/packages/63/1a/50a3201c334a7f17c231eee5f841342190723794e3b06293f26e7cf87d31/orjson-3.11.7-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b9fc4d0f81f394689e0814617aadc4f2ea0e8025f38c226cbf22d3b5ddbf025d", size = 408853, upload-time = "2026-02-02T15:37:58.291Z" }, + { url = "https://files.pythonhosted.org/packages/87/cd/8de1c67d0be44fdc22701e5989c0d015a2adf391498ad42c4dc589cd3013/orjson-3.11.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:849e38203e5be40b776ed2718e587faf204d184fc9a008ae441f9442320c0cab", size = 144130, upload-time = "2026-02-02T15:38:00.163Z" }, + { url = "https://files.pythonhosted.org/packages/0f/fe/d605d700c35dd55f51710d159fc54516a280923cd1b7e47508982fbb387d/orjson-3.11.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4682d1db3bcebd2b64757e0ddf9e87ae5f00d29d16c5cdf3a62f561d08cc3dd2", size = 134818, upload-time = "2026-02-02T15:38:01.507Z" }, + { url = "https://files.pythonhosted.org/packages/e4/e4/15ecc67edb3ddb3e2f46ae04475f2d294e8b60c1825fbe28a428b93b3fbd/orjson-3.11.7-cp312-cp312-win32.whl", hash = "sha256:f4f7c956b5215d949a1f65334cf9d7612dde38f20a95f2315deef167def91a6f", size = 127923, upload-time = "2026-02-02T15:38:02.75Z" }, + { url = "https://files.pythonhosted.org/packages/34/70/2e0855361f76198a3965273048c8e50a9695d88cd75811a5b46444895845/orjson-3.11.7-cp312-cp312-win_amd64.whl", hash = "sha256:bf742e149121dc5648ba0a08ea0871e87b660467ef168a3a5e53bc1fbd64bb74", size = 125007, upload-time = "2026-02-02T15:38:04.032Z" }, + { url = "https://files.pythonhosted.org/packages/68/40/c2051bd19fc467610fed469dc29e43ac65891571138f476834ca192bc290/orjson-3.11.7-cp312-cp312-win_arm64.whl", hash = "sha256:26c3b9132f783b7d7903bf1efb095fed8d4a3a85ec0d334ee8beff3d7a4749d5", size = 126089, upload-time = "2026-02-02T15:38:05.297Z" }, + { url = "https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:1d98b30cc1313d52d4af17d9c3d307b08389752ec5f2e5febdfada70b0f8c733", size = 228390, upload-time = "2026-02-02T15:38:06.8Z" }, + { url = "https://files.pythonhosted.org/packages/a5/29/a77f48d2fc8a05bbc529e5ff481fb43d914f9e383ea2469d4f3d51df3d00/orjson-3.11.7-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:d897e81f8d0cbd2abb82226d1860ad2e1ab3ff16d7b08c96ca00df9d45409ef4", size = 125189, upload-time = "2026-02-02T15:38:08.181Z" }, + { url = "https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:814be4b49b228cfc0b3c565acf642dd7d13538f966e3ccde61f4f55be3e20785", size = 128106, upload-time = "2026-02-02T15:38:09.41Z" }, + { url = "https://files.pythonhosted.org/packages/66/da/a2e505469d60666a05ab373f1a6322eb671cb2ba3a0ccfc7d4bc97196787/orjson-3.11.7-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d06e5c5fed5caedd2e540d62e5b1c25e8c82431b9e577c33537e5fa4aa909539", size = 123363, upload-time = "2026-02-02T15:38:10.73Z" }, + { url = "https://files.pythonhosted.org/packages/23/bf/ed73f88396ea35c71b38961734ea4a4746f7ca0768bf28fd551d37e48dd0/orjson-3.11.7-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:31c80ce534ac4ea3739c5ee751270646cbc46e45aea7576a38ffec040b4029a1", size = 129007, upload-time = "2026-02-02T15:38:12.138Z" }, + { url = "https://files.pythonhosted.org/packages/73/3c/b05d80716f0225fc9008fbf8ab22841dcc268a626aa550561743714ce3bf/orjson-3.11.7-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f50979824bde13d32b4320eedd513431c921102796d86be3eee0b58e58a3ecd1", size = 141667, upload-time = "2026-02-02T15:38:13.398Z" }, + { url = "https://files.pythonhosted.org/packages/61/e8/0be9b0addd9bf86abfc938e97441dcd0375d494594b1c8ad10fe57479617/orjson-3.11.7-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e54f3808e2b6b945078c41aa8d9b5834b28c50843846e97807e5adb75fa9705", size = 130832, upload-time = "2026-02-02T15:38:14.698Z" }, + { url = "https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12b80df61aab7b98b490fe9e4879925ba666fccdfcd175252ce4d9035865ace", size = 133373, upload-time = "2026-02-02T15:38:16.109Z" }, + { url = "https://files.pythonhosted.org/packages/d2/45/f3466739aaafa570cc8e77c6dbb853c48bf56e3b43738020e2661e08b0ac/orjson-3.11.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:996b65230271f1a97026fd0e6a753f51fbc0c335d2ad0c6201f711b0da32693b", size = 138307, upload-time = "2026-02-02T15:38:17.453Z" }, + { url = "https://files.pythonhosted.org/packages/e1/84/9f7f02288da1ffb31405c1be07657afd1eecbcb4b64ee2817b6fe0f785fa/orjson-3.11.7-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:ab49d4b2a6a1d415ddb9f37a21e02e0d5dbfe10b7870b21bf779fc21e9156157", size = 408695, upload-time = "2026-02-02T15:38:18.831Z" }, + { url = "https://files.pythonhosted.org/packages/18/07/9dd2f0c0104f1a0295ffbe912bc8d63307a539b900dd9e2c48ef7810d971/orjson-3.11.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:390a1dce0c055ddf8adb6aa94a73b45a4a7d7177b5c584b8d1c1947f2ba60fb3", size = 144099, upload-time = "2026-02-02T15:38:20.28Z" }, + { url = "https://files.pythonhosted.org/packages/a5/66/857a8e4a3292e1f7b1b202883bcdeb43a91566cf59a93f97c53b44bd6801/orjson-3.11.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1eb80451a9c351a71dfaf5b7ccc13ad065405217726b59fdbeadbcc544f9d223", size = 134806, upload-time = "2026-02-02T15:38:22.186Z" }, + { url = "https://files.pythonhosted.org/packages/0a/5b/6ebcf3defc1aab3a338ca777214966851e92efb1f30dc7fc8285216e6d1b/orjson-3.11.7-cp313-cp313-win32.whl", hash = "sha256:7477aa6a6ec6139c5cb1cc7b214643592169a5494d200397c7fc95d740d5fcf3", size = 127914, upload-time = "2026-02-02T15:38:23.511Z" }, + { url = "https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl", hash = "sha256:b9f95dcdea9d4f805daa9ddf02617a89e484c6985fa03055459f90e87d7a0757", size = 124986, upload-time = "2026-02-02T15:38:24.836Z" }, + { url = "https://files.pythonhosted.org/packages/03/ba/077a0f6f1085d6b806937246860fafbd5b17f3919c70ee3f3d8d9c713f38/orjson-3.11.7-cp313-cp313-win_arm64.whl", hash = "sha256:800988273a014a0541483dc81021247d7eacb0c845a9d1a34a422bc718f41539", size = 126045, upload-time = "2026-02-02T15:38:26.216Z" }, + { url = "https://files.pythonhosted.org/packages/e9/1e/745565dca749813db9a093c5ebc4bac1a9475c64d54b95654336ac3ed961/orjson-3.11.7-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:de0a37f21d0d364954ad5de1970491d7fbd0fb1ef7417d4d56a36dc01ba0c0a0", size = 228391, upload-time = "2026-02-02T15:38:27.757Z" }, + { url = "https://files.pythonhosted.org/packages/46/19/e40f6225da4d3aa0c8dc6e5219c5e87c2063a560fe0d72a88deb59776794/orjson-3.11.7-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:c2428d358d85e8da9d37cba18b8c4047c55222007a84f97156a5b22028dfbfc0", size = 125188, upload-time = "2026-02-02T15:38:29.241Z" }, + { url = "https://files.pythonhosted.org/packages/9d/7e/c4de2babef2c0817fd1f048fd176aa48c37bec8aef53d2fa932983032cce/orjson-3.11.7-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c4bc6c6ac52cdaa267552544c73e486fecbd710b7ac09bc024d5a78555a22f6", size = 128097, upload-time = "2026-02-02T15:38:30.618Z" }, + { url = "https://files.pythonhosted.org/packages/eb/74/233d360632bafd2197f217eee7fb9c9d0229eac0c18128aee5b35b0014fe/orjson-3.11.7-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bd0d68edd7dfca1b2eca9361a44ac9f24b078de3481003159929a0573f21a6bf", size = 123364, upload-time = "2026-02-02T15:38:32.363Z" }, + { url = "https://files.pythonhosted.org/packages/79/51/af79504981dd31efe20a9e360eb49c15f06df2b40e7f25a0a52d9ae888e8/orjson-3.11.7-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:623ad1b9548ef63886319c16fa317848e465a21513b31a6ad7b57443c3e0dcf5", size = 129076, upload-time = "2026-02-02T15:38:33.68Z" }, + { url = "https://files.pythonhosted.org/packages/67/e2/da898eb68b72304f8de05ca6715870d09d603ee98d30a27e8a9629abc64b/orjson-3.11.7-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6e776b998ac37c0396093d10290e60283f59cfe0fc3fccbd0ccc4bd04dd19892", size = 141705, upload-time = "2026-02-02T15:38:34.989Z" }, + { url = "https://files.pythonhosted.org/packages/c5/89/15364d92acb3d903b029e28d834edb8780c2b97404cbf7929aa6b9abdb24/orjson-3.11.7-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:652c6c3af76716f4a9c290371ba2e390ede06f6603edb277b481daf37f6f464e", size = 130855, upload-time = "2026-02-02T15:38:36.379Z" }, + { url = "https://files.pythonhosted.org/packages/c2/8b/ecdad52d0b38d4b8f514be603e69ccd5eacf4e7241f972e37e79792212ec/orjson-3.11.7-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a56df3239294ea5964adf074c54bcc4f0ccd21636049a2cf3ca9cf03b5d03cf1", size = 133386, upload-time = "2026-02-02T15:38:37.704Z" }, + { url = "https://files.pythonhosted.org/packages/b9/0e/45e1dcf10e17d0924b7c9162f87ec7b4ca79e28a0548acf6a71788d3e108/orjson-3.11.7-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:bda117c4148e81f746655d5a3239ae9bd00cb7bc3ca178b5fc5a5997e9744183", size = 138295, upload-time = "2026-02-02T15:38:39.096Z" }, + { url = "https://files.pythonhosted.org/packages/63/d7/4d2e8b03561257af0450f2845b91fbd111d7e526ccdf737267108075e0ba/orjson-3.11.7-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:23d6c20517a97a9daf1d48b580fcdc6f0516c6f4b5038823426033690b4d2650", size = 408720, upload-time = "2026-02-02T15:38:40.634Z" }, + { url = "https://files.pythonhosted.org/packages/78/cf/d45343518282108b29c12a65892445fc51f9319dc3c552ceb51bb5905ed2/orjson-3.11.7-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:8ff206156006da5b847c9304b6308a01e8cdbc8cce824e2779a5ba71c3def141", size = 144152, upload-time = "2026-02-02T15:38:42.262Z" }, + { url = "https://files.pythonhosted.org/packages/a9/3a/d6001f51a7275aacd342e77b735c71fa04125a3f93c36fee4526bc8c654e/orjson-3.11.7-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:962d046ee1765f74a1da723f4b33e3b228fe3a48bd307acce5021dfefe0e29b2", size = 134814, upload-time = "2026-02-02T15:38:43.627Z" }, + { url = "https://files.pythonhosted.org/packages/1d/d3/f19b47ce16820cc2c480f7f1723e17f6d411b3a295c60c8ad3aa9ff1c96a/orjson-3.11.7-cp314-cp314-win32.whl", hash = "sha256:89e13dd3f89f1c38a9c9eba5fbf7cdc2d1feca82f5f290864b4b7a6aac704576", size = 127997, upload-time = "2026-02-02T15:38:45.06Z" }, + { url = "https://files.pythonhosted.org/packages/12/df/172771902943af54bf661a8d102bdf2e7f932127968080632bda6054b62c/orjson-3.11.7-cp314-cp314-win_amd64.whl", hash = "sha256:845c3e0d8ded9c9271cd79596b9b552448b885b97110f628fb687aee2eed11c1", size = 124985, upload-time = "2026-02-02T15:38:46.388Z" }, + { url = "https://files.pythonhosted.org/packages/6f/1c/f2a8d8a1b17514660a614ce5f7aac74b934e69f5abc2700cc7ced882a009/orjson-3.11.7-cp314-cp314-win_arm64.whl", hash = "sha256:4a2e9c5be347b937a2e0203866f12bba36082e89b402ddb9e927d5822e43088d", size = 126038, upload-time = "2026-02-02T15:38:47.703Z" }, +] + [[package]] name = "packaging" version = "26.0" From e38f3f004e34a43e00ea7184d1e307f9dc3dd639 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Thu, 2 Apr 2026 14:25:56 +0200 Subject: [PATCH 09/18] =?UTF-8?q?=F0=9F=90=9B=20fix:=20correct=20telemetry?= =?UTF-8?q?=20gauges,=20request=5Ftotal=20lifecycle,=20and=20histogram=20b?= =?UTF-8?q?oundaries?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/core/src/tracing_init.rs | 84 ++++++++++++++-- .../src/integration_tests/telemetry.rs | 77 +++++++++++++++ crates/framework/src/protocol/connection.rs | 2 +- .../src/telemetry/dispatch_metrics.rs | 30 +++--- src/apx/_scheduler.py | 6 +- src/apx/telemetry.py | 5 +- tests/integration/test_telemetry.py | 19 ++++ tests/telemetry/test_dispatch_metrics.py | 97 +++++++++++++++++-- 8 files changed, 286 insertions(+), 34 deletions(-) diff --git a/crates/core/src/tracing_init.rs b/crates/core/src/tracing_init.rs index b68ad28c..7c30436a 100644 --- a/crates/core/src/tracing_init.rs +++ b/crates/core/src/tracing_init.rs @@ -181,19 +181,22 @@ impl opentelemetry_sdk::logs::LogProcessor for TimestampProcessor { /// Histogram boundaries for duration metrics recorded in seconds. /// -/// Aligned with OpenTelemetry HTTP semantic conventions for -/// `http.server.request.duration`. +/// Extends the OpenTelemetry HTTP semantic convention boundaries with +/// sub-millisecond resolution (100µs–2.5ms) so fast endpoints like health +/// probes get meaningful percentile estimates instead of landing in the +/// catch-all `[0, 5ms)` bucket. const DURATION_SECONDS_BOUNDARIES: &[f64] = &[ - 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, + 0.000_1, 0.000_25, 0.000_5, 0.001, 0.002_5, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, + 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, ]; /// Histogram boundaries for duration metrics recorded in microseconds. /// -/// Covers sub-millisecond dispatch latencies (body collect, crossbeam send, -/// ASGI parse) up to 100ms outliers. +/// Covers sub-microsecond dispatch phases (scope build, send parse) through +/// 100ms handler latencies with ≤2.5× gaps between adjacent boundaries. const DURATION_MICROSECONDS_BOUNDARIES: &[f64] = &[ - 1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1_000.0, 2_500.0, 5_000.0, 10_000.0, 50_000.0, - 100_000.0, + 1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1_000.0, 2_500.0, 5_000.0, 10_000.0, 25_000.0, + 50_000.0, 100_000.0, ]; /// SDK View that assigns appropriate histogram bucket boundaries based on unit. @@ -414,3 +417,70 @@ fn init_tracing_with_otel( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + fn is_strictly_ascending(boundaries: &[f64]) -> bool { + boundaries.windows(2).all(|w| w[0] < w[1]) + } + + #[test] + fn seconds_boundaries_are_strictly_ascending() { + assert!( + is_strictly_ascending(DURATION_SECONDS_BOUNDARIES), + "seconds boundaries must be sorted with no duplicates" + ); + } + + #[test] + fn microseconds_boundaries_are_strictly_ascending() { + assert!( + is_strictly_ascending(DURATION_MICROSECONDS_BOUNDARIES), + "microseconds boundaries must be sorted with no duplicates" + ); + } + + #[test] + fn seconds_boundaries_resolve_sub_millisecond() { + let sub_ms_boundaries: Vec = DURATION_SECONDS_BOUNDARIES + .iter() + .copied() + .filter(|&b| b < 0.001) + .collect(); + assert!( + sub_ms_boundaries.len() >= 3, + "expected ≥3 boundaries below 1ms for sub-millisecond resolution, \ + got {sub_ms_boundaries:?}" + ); + } + + #[test] + fn microseconds_boundaries_max_gap_ratio() { + for w in DURATION_MICROSECONDS_BOUNDARIES.windows(2) { + let ratio = w[1] / w[0]; + assert!( + ratio <= 5.1, + "gap between {:.0} and {:.0} is {:.1}×, exceeds 5× maximum", + w[0], + w[1], + ratio + ); + } + } + + #[test] + fn seconds_boundaries_max_gap_ratio() { + for w in DURATION_SECONDS_BOUNDARIES.windows(2) { + let ratio = w[1] / w[0]; + assert!( + ratio <= 5.1, + "gap between {} and {} is {:.1}×, exceeds 5× maximum", + w[0], + w[1], + ratio + ); + } + } +} diff --git a/crates/framework/src/integration_tests/telemetry.rs b/crates/framework/src/integration_tests/telemetry.rs index e82ba013..e65b24c7 100644 --- a/crates/framework/src/integration_tests/telemetry.rs +++ b/crates/framework/src/integration_tests/telemetry.rs @@ -659,5 +659,82 @@ logging.getLogger('test.handler').warning('hello from python') }); } +// ── Up-down counter / gauge semantics tests ──────────────────────────── + +#[test] +fn up_down_counter_add_semantics() { + let tt = setup(); + let _lock = EXPORT_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + tt.metric_exporter.reset(); + + let meter = opentelemetry::global::meter("test.up_down"); + let counter = meter.i64_up_down_counter("test.inflight").build(); + + counter.add(1, &[]); + counter.add(1, &[]); + counter.add(1, &[]); + counter.add(-1, &[]); + + tt.meter_provider.force_flush().unwrap(); + let metrics = tt.metric_exporter.get_finished_metrics().unwrap(); + + let mut found_value = None; + for rm in &metrics { + for sm in &rm.scope_metrics { + for m in &sm.metrics { + if m.name == "test.inflight" + && let Some(sum) = m.data.as_any().downcast_ref::>() + { + found_value = sum.data_points.first().map(|dp| dp.value); + } + } + } + } + + assert_eq!( + found_value, + Some(2), + "up-down counter: 3 increments - 1 decrement = 2" + ); +} + +#[test] +fn gauge_record_is_absolute_not_additive() { + let tt = setup(); + let _lock = EXPORT_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + tt.metric_exporter.reset(); + + let meter = opentelemetry::global::meter("test.gauge_abs"); + let gauge = meter.f64_gauge("test.last_value").build(); + + gauge.record(42.0, &[]); + gauge.record(7.0, &[]); + + tt.meter_provider.force_flush().unwrap(); + let metrics = tt.metric_exporter.get_finished_metrics().unwrap(); + + let mut found_value = None; + for rm in &metrics { + for sm in &rm.scope_metrics { + for m in &sm.metrics { + if m.name == "test.last_value" + && let Some(g) = m + .data + .as_any() + .downcast_ref::>() + { + found_value = g.data_points.first().map(|dp| dp.value); + } + } + } + } + + assert_eq!( + found_value, + Some(7.0), + "gauge records absolute value, not cumulative: last record(7.0) wins" + ); +} + // ── Full HTTP request tests removed (depend on TestServer) ───────────── // Require TestServer infrastructure not yet available in this crate. diff --git a/crates/framework/src/protocol/connection.rs b/crates/framework/src/protocol/connection.rs index f3b2e9eb..1a280415 100644 --- a/crates/framework/src/protocol/connection.rs +++ b/crates/framework/src/protocol/connection.rs @@ -189,7 +189,6 @@ impl RustProtocol { RustResponseWriter::new(py, transport.clone_ref(py), Some(on_complete.into_any()))?; self.shared.on_request.call1(py, (scope, receive, send))?; - dispatch_metrics::record_dispatch_total(t_dispatch.elapsed().as_micros() as f64); Ok(()) } } @@ -243,6 +242,7 @@ impl OnRequestComplete { { let _guard = self.request_span.enter(); dispatch_metrics::record_handler_wait(elapsed.as_micros() as f64); + dispatch_metrics::record_dispatch_total(elapsed.as_micros() as f64); crate::telemetry::http::record_duration( elapsed.as_secs_f64(), diff --git a/crates/framework/src/telemetry/dispatch_metrics.rs b/crates/framework/src/telemetry/dispatch_metrics.rs index a91c9649..d6f8841d 100644 --- a/crates/framework/src/telemetry/dispatch_metrics.rs +++ b/crates/framework/src/telemetry/dispatch_metrics.rs @@ -1,16 +1,16 @@ -//! APX request pipeline histograms. +//! APX request pipeline metrics. //! -//! Records per-phase latency for the request dispatch pipeline via OTEL -//! histograms. All instruments are lazily created on first use and guarded -//! by the `ApxMetricToggles` boolean flags — disabled metrics have zero -//! overhead. +//! Records per-phase latency via OTEL histograms and connection/request +//! counts via up-down counters. All instruments are lazily created on +//! first use and guarded by the `ApxMetricToggles` boolean flags — +//! disabled metrics have zero overhead. //! //! Toggles are initialized once per worker process via [`init`] after //! reading the Python telemetry config. use std::sync::OnceLock; -use opentelemetry::metrics::{Gauge, Histogram}; +use opentelemetry::metrics::{Histogram, UpDownCounter}; use super::config::ApxMetricToggles; use super::defs; @@ -52,25 +52,25 @@ macro_rules! dispatch_metric { }; } -/// Generate a lazy gauge getter and gated `inc_*` / `dec_*` functions. +/// Generate a lazy up-down counter getter and gated `inc_*` / `dec_*` functions. macro_rules! dispatch_gauge { - ($inc_fn:ident, $dec_fn:ident, $gauge_fn:ident, $toggle:ident, $def:expr, $doc:literal) => { - fn $gauge_fn() -> &'static Gauge { - static INST: OnceLock> = OnceLock::new(); - INST.get_or_init(|| $def.gauge(&framework_meter())) + ($inc_fn:ident, $dec_fn:ident, $counter_fn:ident, $toggle:ident, $def:expr, $doc:literal) => { + fn $counter_fn() -> &'static UpDownCounter { + static INST: OnceLock> = OnceLock::new(); + INST.get_or_init(|| $def.up_down_counter(&framework_meter())) } #[doc = $doc] pub fn $inc_fn() { if toggles().$toggle { - $gauge_fn().record(1.0, NO_ATTRS); + $counter_fn().add(1, NO_ATTRS); } } - /// Decrement the gauge. + /// Decrement the counter. pub fn $dec_fn() { if toggles().$toggle { - $gauge_fn().record(-1.0, NO_ATTRS); + $counter_fn().add(-1, NO_ATTRS); } } }; @@ -142,7 +142,7 @@ dispatch_metric!( "Record `apx.request_total` if enabled." ); -// ── Gauges ─────────────────────────────────────────────────────────────── +// ── Up-down counters ───────────────────────────────────────────────────── dispatch_gauge!( inc_active_requests, diff --git a/src/apx/_scheduler.py b/src/apx/_scheduler.py index 53e99ef5..73de18d7 100644 --- a/src/apx/_scheduler.py +++ b/src/apx/_scheduler.py @@ -235,6 +235,10 @@ def drive_inline( return Failed(exc) _leave_task(loop, task) + if result is not None and getattr(result, "_asyncio_future_blocking", False): + result._asyncio_future_blocking = False + return Suspended(result) + capture.flush() if result is None: @@ -243,6 +247,4 @@ def drive_inline( return Suspended(None) continue - if getattr(result, "_asyncio_future_blocking", False): - result._asyncio_future_blocking = False return Suspended(result) diff --git a/src/apx/telemetry.py b/src/apx/telemetry.py index 3bf0dd37..5ba9a0d7 100644 --- a/src/apx/telemetry.py +++ b/src/apx/telemetry.py @@ -693,8 +693,9 @@ class HttpMetrics(BaseModel): class ApxMetrics(BaseModel): """APX per-worker request pipeline metric toggles. - Each histogram records latency for a phase of request handling - within a single worker process. Use ``apx.worker.id`` to drill + Histograms record latency (microseconds) for each phase of request + handling. Up-down counters (``active_requests``, ``connections``) + track current in-flight counts. Use ``apx.worker.id`` to drill down; aggregate across workers for server-wide distributions. If APX_PERF environment variable is not set, none of these metrics are collected. diff --git a/tests/integration/test_telemetry.py b/tests/integration/test_telemetry.py index 2f1db6cb..c63f40a9 100644 --- a/tests/integration/test_telemetry.py +++ b/tests/integration/test_telemetry.py @@ -1027,6 +1027,25 @@ def test_custom_counter_has_data_points( return pytest.fail("test.custom_counter sum not found") + # ── Histogram bucket boundaries ─────────────────────────────────────── + + def test_http_duration_has_sub_millisecond_boundaries( + self, otel_collector: OtelCollector + ) -> None: + """The view must inject sub-ms boundaries so fast endpoints get useful percentiles.""" + for m in _flat_metrics(otel_collector): + if m.name == "http.server.request.duration" and m.histogram: + for dp in m.histogram.dataPoints: + if not dp.explicitBounds: + continue + sub_ms = [b for b in dp.explicitBounds if b < 0.001] + assert len(sub_ms) >= 3, ( + f"expected ≥3 boundaries below 1ms for sub-millisecond resolution, " + f"got {sub_ms} (all bounds: {dp.explicitBounds})" + ) + return + pytest.fail("http.server.request.duration histogram with explicitBounds not found") + # --------------------------------------------------------------------------- # Tests — span attributes and structure diff --git a/tests/telemetry/test_dispatch_metrics.py b/tests/telemetry/test_dispatch_metrics.py index 37ef9377..4e3cd41d 100644 --- a/tests/telemetry/test_dispatch_metrics.py +++ b/tests/telemetry/test_dispatch_metrics.py @@ -2,8 +2,8 @@ The telemetry_container fixture sets ``APX_PERF=1``, which enables all ``ApxMetrics`` toggles via ``_apx_perf_enabled()``. After sending HTTP -requests that exercise the full request pipeline, histogram and gauge -metrics must appear in the OTEL collector output. +requests that exercise the full request pipeline, histogram and +up-down counter metrics must appear in the OTEL collector output. """ from __future__ import annotations @@ -30,12 +30,34 @@ "apx.request_total", } -APX_GAUGE_METRICS = { +APX_UP_DOWN_COUNTER_METRICS = { "apx.active_requests", "apx.connections", } -APX_ALL_METRICS = APX_HISTOGRAM_METRICS | APX_GAUGE_METRICS +APX_ALL_METRICS = APX_HISTOGRAM_METRICS | APX_UP_DOWN_COUNTER_METRICS + + +def _histogram_count(otel_collector: OtelCollector, name: str) -> int: + """Sum observation counts across all exported histogram data points.""" + total = 0 + for _, m in flat_metrics_with_scope(otel_collector): + if m.name == name and m.histogram is not None: + for dp in m.histogram.dataPoints: + if dp.count is not None: + total += int(dp.count) + return total + + +def _histogram_sum(otel_collector: OtelCollector, name: str) -> float: + """Sum all histogram sums across exported data points.""" + total = 0.0 + for _, m in flat_metrics_with_scope(otel_collector): + if m.name == name and m.histogram is not None: + for dp in m.histogram.dataPoints: + if dp.sum is not None: + total += dp.sum + return total @pytest.mark.integration @@ -81,10 +103,71 @@ def test_histogram_metrics_unit_is_microseconds( if m.name in APX_HISTOGRAM_METRICS: assert m.unit == "us", f"{m.name} unit should be 'us', got {m.unit!r}" - def test_gauge_metrics_unit_is_dimensionless( + def test_up_down_counter_metrics_are_sum_type( + self, otel_collector: OtelCollector + ) -> None: + """active_requests and connections must be OTLP sum (UpDownCounter), not gauge.""" + for _, m in flat_metrics_with_scope(otel_collector): + if m.name in APX_UP_DOWN_COUNTER_METRICS: + assert m.sum is not None, ( + f"{m.name} should be sum (UpDownCounter), " + f"got gauge={m.gauge} histogram={m.histogram}" + ) + assert m.gauge is None, ( + f"{m.name} must not be a gauge (was migrated to UpDownCounter)" + ) + + def test_up_down_counter_metrics_unit_is_dimensionless( self, otel_collector: OtelCollector ) -> None: - """Gauge metrics (active_requests, connections) use dimensionless unit.""" + """Up-down counter metrics (active_requests, connections) use dimensionless unit.""" for _, m in flat_metrics_with_scope(otel_collector): - if m.name in APX_GAUGE_METRICS: + if m.name in APX_UP_DOWN_COUNTER_METRICS: assert m.unit == "1", f"{m.name} unit should be '1', got {m.unit!r}" + + def test_send_parse_count_gte_request_total( + self, otel_collector: OtelCollector + ) -> None: + """send_parse fires per ASGI send event (>= 2 per request), so its count must exceed request_total.""" + send_parse_count = _histogram_count(otel_collector, "apx.send_parse") + request_total_count = _histogram_count(otel_collector, "apx.request_total") + assert send_parse_count >= request_total_count, ( + f"send_parse count ({send_parse_count}) should be >= " + f"request_total count ({request_total_count}): " + f"each request has at least 2 send events (start + body)" + ) + + def test_histogram_boundaries_have_sub_millisecond_resolution( + self, otel_collector: OtelCollector + ) -> None: + """µs histograms must have boundaries below 1000µs for sub-ms latency phases.""" + for _, m in flat_metrics_with_scope(otel_collector): + if m.name == "apx.parse" and m.histogram is not None: + for dp in m.histogram.dataPoints: + if not dp.explicitBounds: + continue + sub_ms = [b for b in dp.explicitBounds if b < 1000.0] + assert len(sub_ms) >= 6, ( + f"expected ≥6 boundaries below 1000µs, got {sub_ms}" + ) + return + pytest.skip("apx.parse histogram with explicitBounds not found") + + def test_request_total_measures_full_lifecycle( + self, otel_collector: OtelCollector + ) -> None: + """request_total and handler_wait should measure the same interval.""" + rt_count = _histogram_count(otel_collector, "apx.request_total") + hw_count = _histogram_count(otel_collector, "apx.handler_wait") + if rt_count == 0 or hw_count == 0: + pytest.skip("no request_total or handler_wait observations") + + rt_mean = _histogram_sum(otel_collector, "apx.request_total") / rt_count + hw_mean = _histogram_sum(otel_collector, "apx.handler_wait") / hw_count + + ratio = rt_mean / hw_mean if hw_mean > 0 else float("inf") + assert 0.5 <= ratio <= 2.0, ( + f"request_total mean ({rt_mean:.0f}µs) and handler_wait mean " + f"({hw_mean:.0f}µs) should be in the same order of magnitude " + f"(ratio={ratio:.2f}), since both measure dispatch-to-response-complete" + ) From 326b05eb23aefd1ddff97742e30ad8bb2604a91d Mon Sep 17 00:00:00 2001 From: renardeinside Date: Thu, 2 Apr 2026 15:43:33 +0200 Subject: [PATCH 10/18] =?UTF-8?q?=F0=9F=90=9B=20fix:=20apply=20uvloop=20ev?= =?UTF-8?q?ent=20loop=20policy=20in=20oneshot=20worker?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/framework/src/supervision/worker.rs | 57 ++++++++++++++++++++-- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/crates/framework/src/supervision/worker.rs b/crates/framework/src/supervision/worker.rs index 2f082760..4801c469 100644 --- a/crates/framework/src/supervision/worker.rs +++ b/crates/framework/src/supervision/worker.rs @@ -71,6 +71,49 @@ fn init_python() { Python::initialize(); } +/// Apply the asyncio event loop policy before the loop is created. +/// +/// uvloop provides ~5-10x faster transport.write() and selector dispatch +/// compared to the default `_UnixSelectorEventLoop`. +fn install_loop_policy( + py: Python<'_>, + asyncio: &Bound<'_, PyModule>, + loop_policy: &str, +) -> Result<(), WorkerError> { + if loop_policy == "uvloop" { + match py.import(c"uvloop") { + Ok(uvloop) => { + let policy = uvloop + .call_method0(c"EventLoopPolicy") + .map_err(|e| WorkerError::Serve(format!("uvloop.EventLoopPolicy: {e}")))?; + asyncio + .call_method1(c"set_event_loop_policy", (policy,)) + .map_err(|e| WorkerError::Serve(format!("set_event_loop_policy: {e}")))?; + tracing::info!( + name: "apx.worker.loop_policy", + policy = "uvloop", + "event loop policy set" + ); + } + Err(_) => { + tracing::warn!( + name: "apx.worker.loop_policy_fallback", + requested = "uvloop", + fallback = "asyncio", + "uvloop not available, falling back to default asyncio" + ); + } + } + } else { + tracing::info!( + name: "apx.worker.loop_policy", + policy = loop_policy, + "using default asyncio event loop" + ); + } + Ok(()) +} + /// Load the Python app and read telemetry configuration. fn load_app(bootstrap: &WorkerBootstrap) -> Result { apply_python_log_config()?; @@ -167,12 +210,18 @@ fn init_metrics(telemetry: &crate::telemetry::config::TelemetryConfig) { fn run_server( ready: AppReady, shutdown_rx: tokio::sync::oneshot::Receiver<()>, + loop_policy: &str, ) -> Result<(), WorkerError> { Python::attach(|py| { let asyncio = py .import(c"asyncio") .map_err(|e| WorkerError::Serve(format!("import asyncio: {e}")))?; + // Apply the event loop policy before asyncio.run() creates the loop. + // uvloop is ~5-10x faster than the default selector loop for + // transport.write() and selector dispatch. + install_loop_policy(py, &asyncio, loop_policy)?; + let shutdown_event = asyncio .call_method0(c"Event") .map_err(|e| WorkerError::Serve(format!("create Event: {e}")))?; @@ -350,9 +399,11 @@ pub async fn run_worker( }); // Run the asyncio server (blocking — this IS the event loop). - let serve_result = tokio::task::spawn_blocking(move || run_server(ready, drain_rx)) - .await - .map_err(|e| WorkerError::Serve(format!("server task panicked: {e}")))?; + let loop_policy = bootstrap.loop_policy.clone(); + let serve_result = + tokio::task::spawn_blocking(move || run_server(ready, drain_rx, &loop_policy)) + .await + .map_err(|e| WorkerError::Serve(format!("server task panicked: {e}")))?; let _ = ipc_writer.send(&IpcMessage::Drained).await; apx_core::tracing_init::shutdown_telemetry(); From ae2e41d499116d99b2d750ed6c98bf2b4f8a9ee0 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Thu, 2 Apr 2026 17:13:35 +0200 Subject: [PATCH 11/18] =?UTF-8?q?=F0=9F=90=9B=20fix:=20prevent=20active=5F?= =?UTF-8?q?requests=20counter=20leak=20on=20connection=20close?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/framework/src/protocol/connection.rs | 17 +++++++-- crates/framework/src/protocol/writer.rs | 40 ++++++++++++++------- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/crates/framework/src/protocol/connection.rs b/crates/framework/src/protocol/connection.rs index 1a280415..b5dd5ca1 100644 --- a/crates/framework/src/protocol/connection.rs +++ b/crates/framework/src/protocol/connection.rs @@ -256,11 +256,24 @@ impl OnRequestComplete { crate::telemetry::http::finish_request_span(&self.request_span, status); } - self.transport - .call_method0(py, pyo3::intern!(py, "resume_reading"))?; + // Always decrement counters, even if the transport is gone. + // If resume_reading fails (connection already closed), we must + // still release the concurrency slot — otherwise the counter + // leaks and eventually all requests get 503. + let resume_result = self + .transport + .call_method0(py, pyo3::intern!(py, "resume_reading")); self.shared.active_requests.fetch_sub(1, Ordering::Relaxed); dispatch_metrics::dec_active_requests(); crate::telemetry::http::dec_active_requests(); + + if let Err(e) = resume_result { + tracing::debug!( + name: "apx.protocol.resume_reading_failed", + error = %e, + "resume_reading failed (connection likely closed)" + ); + } Ok(()) } } diff --git a/crates/framework/src/protocol/writer.rs b/crates/framework/src/protocol/writer.rs index b7a93f89..5da7c5be 100644 --- a/crates/framework/src/protocol/writer.rs +++ b/crates/framework/src/protocol/writer.rs @@ -169,28 +169,35 @@ impl RustResponseWriter { dispatch_metrics::record_response_build(t_build.elapsed().as_micros() as f64); let t_write = Instant::now(); - if chunked { + let write_result = if chunked { let mut buf = BytesMut::with_capacity(hdr_bytes.len() + body_bytes.len() + 32); buf.put_slice(&hdr_bytes); write_chunk(&mut buf, body_bytes); let py_bytes = PyBytes::new(py, &buf); self.transport - .call_method1(py, pyo3::intern!(py, "write"), (py_bytes,))?; + .call_method1(py, pyo3::intern!(py, "write"), (py_bytes,)) + .map(|_| ()) } else { let hdr_py = PyBytes::new(py, &hdr_bytes); self.transport - .call_method1(py, pyo3::intern!(py, "write"), (hdr_py,))?; - self.transport - .call_method1(py, pyo3::intern!(py, "write"), (data.bind(py),))?; - } + .call_method1(py, pyo3::intern!(py, "write"), (hdr_py,)) + .and_then(|_| { + self.transport + .call_method1(py, pyo3::intern!(py, "write"), (data.bind(py),)) + .map(|_| ()) + }) + }; dispatch_metrics::record_response_write(t_write.elapsed().as_micros() as f64); if more_body { self.state = WriteState::Streaming { chunked }; } else { + // Always signal completion — even if write failed. The callback + // decrements active_requests; skipping it leaks concurrency slots. self.signal_complete(py)?; } - Ok(()) + // Propagate write error after completion callback has fired. + write_result } fn write_continuation( @@ -203,7 +210,7 @@ impl RustResponseWriter { let t_write = Instant::now(); let body_bytes = data.bind(py).as_bytes(); - if chunked { + let write_result = if chunked { let terminator_len = if more_body { 0 } else { LAST_CHUNK.len() }; let mut buf = BytesMut::with_capacity(body_bytes.len() + 32 + terminator_len); write_chunk(&mut buf, body_bytes); @@ -212,21 +219,30 @@ impl RustResponseWriter { } let py_bytes = PyBytes::new(py, &buf); self.transport - .call_method1(py, pyo3::intern!(py, "write"), (py_bytes,))?; + .call_method1(py, pyo3::intern!(py, "write"), (py_bytes,)) + .map(|_| ()) } else { self.transport - .call_method1(py, pyo3::intern!(py, "write"), (data.bind(py),))?; - } + .call_method1(py, pyo3::intern!(py, "write"), (data.bind(py),)) + .map(|_| ()) + }; dispatch_metrics::record_response_write(t_write.elapsed().as_micros() as f64); if more_body { self.state = WriteState::Streaming { chunked }; } else { + // Always signal completion — even if write failed. self.signal_complete(py)?; } - Ok(()) + write_result } + /// Signal response completion to the protocol layer. + /// + /// Must be called even when `transport.write()` fails — the callback + /// decrements the active-request counter and resumes reading. Failing + /// to call it leaks concurrency slots until `MAX_CONCURRENT` is hit + /// and all new requests receive 503. fn signal_complete(&self, py: Python<'_>) -> PyResult<()> { if let Some(cb) = &self.on_complete { cb.call1(py, (self.response_status,))?; From a58498d9bfbc1f476e262726b3159dbcc2410a39 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Thu, 2 Apr 2026 17:48:23 +0200 Subject: [PATCH 12/18] =?UTF-8?q?=E2=9C=A8=20feat:=20implement=20HTTP=20pr?= =?UTF-8?q?otocol=20compliance=20and=20performance=20from=20uvi-compare?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 3 + Cargo.toml | 3 + crates/framework/Cargo.toml | 3 + crates/framework/src/asgi/scope.rs | 15 +- crates/framework/src/protocol/connection.rs | 295 ++++++++++++------- crates/framework/src/protocol/parser.rs | 35 +++ crates/framework/src/protocol/writer.rs | 301 ++++++++++++++++---- src/apx/_bridge.py | 43 +-- src/apx/_scheduler.py | 21 +- 9 files changed, 514 insertions(+), 205 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e29d970d..e571c65d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -305,11 +305,13 @@ dependencies = [ "hex", "http", "httparse", + "httpdate", "matchit", "mimalloc", "notify", "opentelemetry 0.29.1", "opentelemetry_sdk 0.29.0", + "percent-encoding", "pyo3", "rand 0.8.5", "rmp-serde", @@ -323,6 +325,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", + "uuid", "which", ] diff --git a/Cargo.toml b/Cargo.toml index 4ddb998d..14d1c26e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -144,7 +144,10 @@ http-body-util = "0.1" # Framework: oneshot protocol primitives httparse = "1.10" +httpdate = "1" matchit = "0.8" +percent-encoding = "2" +uuid = { version = "1", features = ["v4"] } [workspace.lints.rust] unsafe_code = "deny" diff --git a/crates/framework/Cargo.toml b/crates/framework/Cargo.toml index 21d95054..e629b10b 100644 --- a/crates/framework/Cargo.toml +++ b/crates/framework/Cargo.toml @@ -42,7 +42,10 @@ which.workspace = true notify.workspace = true futures-util.workspace = true httparse.workspace = true +httpdate.workspace = true matchit.workspace = true +percent-encoding.workspace = true +uuid.workspace = true [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = { version = "0.6", optional = true } diff --git a/crates/framework/src/asgi/scope.rs b/crates/framework/src/asgi/scope.rs index 358fe2f8..f22087e0 100644 --- a/crates/framework/src/asgi/scope.rs +++ b/crates/framework/src/asgi/scope.rs @@ -4,11 +4,13 @@ //! [`ResolvedAwaitable`] / [`ResolvedAwaitableWithValue`] for zero-overhead //! Python awaitables. +use std::collections::HashMap; +use std::net::SocketAddr; + use crate::transport::types::ProtocolVersion; use http::header::{self, HeaderName}; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyString, PyTuple}; -use std::net::SocketAddr; use super::{ASGI_SPEC_VERSION, ASGI_VERSION}; @@ -79,9 +81,10 @@ const COMMON_HEADERS: &[HeaderName] = &[ header::REFERER, ]; -/// Pre-built `PyBytes` for common HTTP header names. +/// Pre-built `PyBytes` for common HTTP header names, keyed by +/// lowercase header bytes for O(1) lookup. pub struct HeaderInterns { - pub(crate) map: Vec<(HeaderName, Py)>, + pub(crate) map: HashMap, Py>, } crate::opaque_debug!(HeaderInterns); @@ -91,7 +94,11 @@ impl HeaderInterns { pub fn new(py: Python<'_>) -> Self { let map = COMMON_HEADERS .iter() - .map(|h| (h.clone(), PyBytes::new(py, h.as_str().as_bytes()).unbind())) + .map(|h| { + let key: Box<[u8]> = h.as_str().as_bytes().into(); + let val = PyBytes::new(py, h.as_str().as_bytes()).unbind(); + (key, val) + }) .collect(); Self { map } } diff --git a/crates/framework/src/protocol/connection.rs b/crates/framework/src/protocol/connection.rs index b5dd5ca1..a282cb09 100644 --- a/crates/framework/src/protocol/connection.rs +++ b/crates/framework/src/protocol/connection.rs @@ -6,7 +6,7 @@ use std::borrow::Cow; use std::net::SocketAddr; use std::sync::Arc; -use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::time::Instant; use bytes::Bytes; @@ -23,6 +23,9 @@ use super::writer::RustResponseWriter; /// Maximum concurrent in-flight requests per protocol instance. const MAX_CONCURRENT: u32 = 256; +/// Seconds of idle time before closing a keep-alive connection. +const KEEPALIVE_TIMEOUT_S: f64 = 5.0; + /// Shared state for all protocol instances on this worker. struct ProtocolShared { on_request: Py, @@ -30,6 +33,7 @@ struct ProtocolShared { server_host: String, server_port: u16, active_requests: AtomicU32, + write_paused: AtomicBool, } /// Factory that creates [`RustProtocol`] instances for `loop.create_server()`. @@ -58,6 +62,7 @@ impl ProtocolFactory { server_host, server_port, active_requests: AtomicU32::new(0), + write_paused: AtomicBool::new(false), }), } } @@ -74,6 +79,7 @@ impl ProtocolFactory { parser: RequestParser::new(), shared: Arc::clone(&self.shared), client_addr: None, + keepalive_handle: None, }, ) } @@ -89,6 +95,7 @@ pub struct RustProtocol { parser: RequestParser, shared: Arc, client_addr: Option, + keepalive_handle: Option>, } crate::opaque_debug!(RustProtocol); @@ -103,17 +110,51 @@ impl RustProtocol { Ok(()) } + /// Close the connection if idle (no active requests). + /// + /// Called by the event loop's `call_later` as the keep-alive timeout. + fn close_idle(&mut self, py: Python<'_>) { + if self.shared.active_requests.load(Ordering::Relaxed) == 0 + && let Some(transport) = &self.transport + { + let _ = transport.call_method0(py, pyo3::intern!(py, "close")); + } + } + /// Called by asyncio when data is received on the connection. - fn data_received(&mut self, py: Python<'_>, data: &[u8]) -> PyResult<()> { + fn data_received(slf: &Bound<'_, Self>, py: Python<'_>, data: &[u8]) -> PyResult<()> { + let py_self = slf.clone().unbind(); + let mut this = py_self.borrow_mut(py); + this.cancel_keepalive_timer(py); let t0 = Instant::now(); - let requests = self - .parser - .feed(data) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + let requests = match this.parser.feed(data) { + Ok(r) => r, + Err(e) => { + tracing::debug!( + name: "apx.protocol.parse_error", + error = %e, + "malformed HTTP request" + ); + if let Some(transport) = &this.transport { + let _ = transport_write(py, transport, REJECT_BAD_REQUEST); + let _ = transport.call_method0(py, pyo3::intern!(py, "close")); + } + return Ok(()); + } + }; dispatch_metrics::record_parse(t0.elapsed().as_micros() as f64); + let event_loop = py + .import("asyncio")? + .call_method0(pyo3::intern!(py, "get_running_loop"))? + .unbind(); + for parsed in requests { - self.dispatch_request(py, parsed)?; + // Temporarily drop the borrow so dispatch_request can create + // a Py reference for the OnRequestComplete callback. + drop(this); + Self::dispatch_request_inner(py, &py_self, &event_loop, parsed)?; + this = py_self.borrow_mut(py); } Ok(()) } @@ -124,8 +165,23 @@ impl RustProtocol { false } + /// Called by asyncio when the transport's write buffer exceeds + /// the high-water mark. + fn pause_writing(&self) { + self.shared.write_paused.store(true, Ordering::Release); + tracing::debug!(name: "apx.protocol.pause_writing", "transport write buffer full"); + } + + /// Called by asyncio when the transport's write buffer drains + /// below the low-water mark. + fn resume_writing(&self) { + self.shared.write_paused.store(false, Ordering::Release); + tracing::debug!(name: "apx.protocol.resume_writing", "transport write buffer drained"); + } + /// Called by asyncio when the connection is lost. - fn connection_lost(&mut self, _py: Python<'_>, _exc: Option<&Bound<'_, PyAny>>) { + fn connection_lost(&mut self, py: Python<'_>, _exc: Option<&Bound<'_, PyAny>>) { + self.cancel_keepalive_timer(py); self.transport = None; self.parser.reset(); dispatch_metrics::dec_connections(); @@ -133,16 +189,28 @@ impl RustProtocol { } impl RustProtocol { - fn dispatch_request(&self, py: Python<'_>, parsed: ParsedRequest) -> PyResult<()> { + fn cancel_keepalive_timer(&mut self, py: Python<'_>) { + if let Some(handle) = self.keepalive_handle.take() { + let _ = handle.call_method0(py, pyo3::intern!(py, "cancel")); + } + } + + fn dispatch_request_inner( + py: Python<'_>, + py_self: &Py, + event_loop: &Py, + parsed: ParsedRequest, + ) -> PyResult<()> { + let this = py_self.borrow(py); let t_dispatch = Instant::now(); - let Some(transport) = &self.transport else { + let Some(transport) = &this.transport else { return Ok(()); }; - let active = self.shared.active_requests.fetch_add(1, Ordering::Relaxed); + let active = this.shared.active_requests.fetch_add(1, Ordering::Relaxed); if active >= MAX_CONCURRENT { - self.shared.active_requests.fetch_sub(1, Ordering::Relaxed); - write_503(py, transport)?; + this.shared.active_requests.fetch_sub(1, Ordering::Relaxed); + transport_write(py, transport, REJECT_OVERLOADED)?; return Ok(()); } dispatch_metrics::inc_active_requests(); @@ -150,22 +218,28 @@ impl RustProtocol { transport.call_method0(py, pyo3::intern!(py, "pause_reading"))?; - let request_id = resolve_request_id(&parsed.head.headers); + let (request_id, has_request_id) = resolve_request_id(&parsed.head.headers); let t_scope = Instant::now(); let scope = build_scope_from_parsed( py, &parsed, - &self.shared.interns, - &self.shared.server_host, - self.shared.server_port, - self.client_addr, + &this.shared.interns, + &this.shared.server_host, + this.shared.server_port, + this.client_addr, &request_id, + has_request_id, )?; dispatch_metrics::record_scope_build(t_scope.elapsed().as_micros() as f64); let t_receive = Instant::now(); - let receive = HttpReceive::new(py, parsed.body)?; + let receive = HttpReceive::new( + py, + parsed.body, + Some(transport.clone_ref(py)), + parsed.head.expect_continue, + )?; dispatch_metrics::record_receive_build(t_receive.elapsed().as_micros() as f64); let method = parsed.head.method.as_str().to_owned(); @@ -176,19 +250,31 @@ impl RustProtocol { crate::telemetry::context::set_python_context(py, &trace_ctx)?; let transport_clone = transport.clone_ref(py); + let shared = Arc::clone(&this.shared); + let on_request = this.shared.on_request.clone_ref(py); + drop(this); + let on_complete = OnRequestComplete::create( py, transport_clone, - Arc::clone(&self.shared), + shared, t_dispatch, method, path, request_span, + py_self.clone_ref(py), + event_loop.clone_ref(py), )?; + + let this = py_self.borrow(py); + let Some(transport) = &this.transport else { + return Ok(()); + }; let send = RustResponseWriter::new(py, transport.clone_ref(py), Some(on_complete.into_any()))?; + drop(this); - self.shared.on_request.call1(py, (scope, receive, send))?; + on_request.call1(py, (scope, receive, send))?; Ok(()) } } @@ -206,11 +292,17 @@ struct OnRequestComplete { method: String, path: String, request_span: tracing::Span, + protocol: Py, + event_loop: Py, } crate::opaque_debug!(OnRequestComplete); impl OnRequestComplete { + #[expect( + clippy::too_many_arguments, + reason = "all fields needed for completion callback; struct builder would add overhead" + )] fn create( py: Python<'_>, transport: Py, @@ -219,6 +311,8 @@ impl OnRequestComplete { method: String, path: String, request_span: tracing::Span, + protocol: Py, + event_loop: Py, ) -> PyResult> { Py::new( py, @@ -229,6 +323,8 @@ impl OnRequestComplete { method, path, request_span, + protocol, + event_loop, }, ) } @@ -274,6 +370,18 @@ impl OnRequestComplete { "resume_reading failed (connection likely closed)" ); } + + if let Ok(close_idle) = self.protocol.getattr(py, pyo3::intern!(py, "close_idle")) + && let Ok(handle) = self.event_loop.call_method1( + py, + pyo3::intern!(py, "call_later"), + (KEEPALIVE_TIMEOUT_S, close_idle), + ) + && let Ok(mut proto) = self.protocol.try_borrow_mut(py) + { + proto.keepalive_handle = Some(handle); + } + Ok(()) } } @@ -284,20 +392,30 @@ impl OnRequestComplete { /// /// First call returns the request body immediately via /// `ResolvedAwaitableWithValue`. Subsequent calls return a pending -/// future (disconnect sentinel). +/// future (disconnect sentinel). Handles `Expect: 100-continue` by +/// writing the informational response before delivering the body. #[pyclass(module = "apx._core", freelist = 64)] pub struct HttpReceive { body: std::sync::Mutex>, + transport: Option>, + expect_continue: bool, } crate::opaque_debug!(HttpReceive); impl HttpReceive { - fn new(py: Python<'_>, body: Bytes) -> PyResult> { + fn new( + py: Python<'_>, + body: Bytes, + transport: Option>, + expect_continue: bool, + ) -> PyResult> { Py::new( py, Self { body: std::sync::Mutex::new(Some(body)), + transport, + expect_continue, }, ) } @@ -313,6 +431,12 @@ impl HttpReceive { .take(); if let Some(b) = body { + if self.expect_continue + && let Some(transport) = &self.transport + { + let _ = transport_write(py, transport, INFORMATIONAL_CONTINUE); + } + let event = PyDict::new(py); event.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request"))?; event.set_item(pyo3::intern!(py, "body"), PyBytes::new(py, &b))?; @@ -338,6 +462,10 @@ impl HttpReceive { /// /// Bypasses `ScopeSource` trait and `HeaderMap` — works with raw byte /// pairs from the parser, avoiding the intermediate allocation. +#[expect( + clippy::too_many_arguments, + reason = "scope construction needs all ASGI fields; splitting would fragment the hot path" +)] fn build_scope_from_parsed( py: Python<'_>, parsed: &ParsedRequest, @@ -346,6 +474,7 @@ fn build_scope_from_parsed( server_port: u16, client_addr: Option, request_id: &str, + has_request_id: bool, ) -> PyResult> { let scope = interns .scope_template @@ -378,7 +507,14 @@ fn build_scope_from_parsed( PyBytes::new(py, &parsed.head.query_string), )?; - set_headers_from_parsed(py, &scope, &parsed.head, interns, request_id)?; + set_headers_from_parsed( + py, + &scope, + &parsed.head, + interns, + request_id, + has_request_id, + )?; set_addresses(py, &scope, interns, server_host, server_port, client_addr)?; scope.set_item( interns.keys.path_params.bind(py), @@ -390,15 +526,18 @@ fn build_scope_from_parsed( } /// Extract existing `x-request-id` from headers or generate a UUID v4. -fn resolve_request_id(headers: &[(Bytes, Bytes)]) -> String { +/// +/// Returns `(request_id, has_request_id)` so the scope builder can +/// skip a second header scan. +fn resolve_request_id(headers: &[(Bytes, Bytes)]) -> (String, bool) { for (name, value) in headers { if name.eq_ignore_ascii_case(b"x-request-id") && let Ok(s) = std::str::from_utf8(value) { - return s.to_owned(); + return (s.to_owned(), true); } } - generate_uuid_v4() + (generate_uuid_v4(), false) } /// Set headers list from raw byte pairs (no `HeaderMap` intermediary). @@ -410,12 +549,8 @@ fn set_headers_from_parsed( head: &ParsedHead, interns: &ScopeInterns, request_id: &str, + has_request_id: bool, ) -> PyResult<()> { - let has_request_id = head - .headers - .iter() - .any(|(name, _)| name.eq_ignore_ascii_case(b"x-request-id")); - let extra_cap = usize::from(!has_request_id); let mut pairs: Vec> = Vec::with_capacity(head.headers.len() + extra_cap); @@ -439,19 +574,7 @@ fn set_headers_from_parsed( /// Generate a UUID v4 string (random, RFC 4122 variant 1). fn generate_uuid_v4() -> String { - let mut bytes: [u8; 16] = rand::random(); - bytes[6] = (bytes[6] & 0x0f) | 0x40; - bytes[8] = (bytes[8] & 0x3f) | 0x80; - format!( - "{:08x}-{:04x}-{:04x}-{:04x}-{:012x}", - u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]), - u16::from_be_bytes([bytes[4], bytes[5]]), - u16::from_be_bytes([bytes[6], bytes[7]]), - u16::from_be_bytes([bytes[8], bytes[9]]), - u64::from_be_bytes([ - 0, 0, bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15] - ]) - ) + uuid::Uuid::new_v4().to_string() } /// Try to use the header intern cache, falling back to `PyBytes::new`. @@ -461,10 +584,8 @@ fn intern_header_name<'py>( interns: &ScopeInterns, ) -> Bound<'py, PyBytes> { let name_lower = name.to_ascii_lowercase(); - for (cached_name, cached_py) in &interns.headers.map { - if cached_name.as_str().as_bytes() == name_lower.as_slice() { - return cached_py.bind(py).clone(); - } + if let Some(cached) = interns.headers.map.get(name_lower.as_slice()) { + return cached.bind(py).clone(); } PyBytes::new(py, &name_lower) } @@ -493,45 +614,7 @@ fn set_addresses( /// Percent-decode a URL path. fn percent_decode(input: &str) -> Cow<'_, str> { - if !input.contains('%') { - return Cow::Borrowed(input); - } - let mut bytes = Vec::with_capacity(input.len()); - let mut chars = input.as_bytes().iter().copied(); - while let Some(b) = chars.next() { - if b == b'%' { - let hi = chars.next(); - let lo = chars.next(); - if let (Some(h), Some(l)) = (hi, lo) { - if let (Some(hv), Some(lv)) = (hex_val(h), hex_val(l)) { - bytes.push(hv << 4 | lv); - continue; - } - bytes.extend_from_slice(&[b'%', h, l]); - } else { - bytes.push(b'%'); - if let Some(h) = hi { - bytes.push(h); - } - } - } else { - bytes.push(b); - } - } - match String::from_utf8(bytes) { - Ok(s) => Cow::Owned(s), - Err(e) => Cow::Owned(String::from_utf8_lossy(e.as_bytes()).into_owned()), - } -} - -/// Convert a hex ASCII char to its value. -fn hex_val(b: u8) -> Option { - match b { - b'0'..=b'9' => Some(b - b'0'), - b'a'..=b'f' => Some(b - b'a' + 10), - b'A'..=b'F' => Some(b - b'A' + 10), - _ => None, - } + percent_encoding::percent_decode_str(input).decode_utf8_lossy() } /// Extract the peer address from an asyncio transport. @@ -550,17 +633,35 @@ fn extract_peer_addr(py: Python<'_>, transport: &Py) -> Option = bound.cast().ok()?; let host: String = tuple.get_item(0).ok()?.extract().ok()?; let port: u16 = tuple.get_item(1).ok()?.extract().ok()?; - format!("{host}:{port}").parse().ok() + let ip: std::net::IpAddr = host.parse().ok()?; + Some(SocketAddr::new(ip, port)) } -/// Write a 503 Service Unavailable response directly. -fn write_503(py: Python<'_>, transport: &Py) -> PyResult<()> { - let body = b"Service Unavailable"; - let response = format!( - "HTTP/1.1 503 Service Unavailable\r\ncontent-length: {}\r\ncontent-type: text/plain\r\n\r\nService Unavailable", - body.len(), - ); - let py_bytes = PyBytes::new(py, response.as_bytes()); +// ── Pre-built error responses (sans-I/O: pure data, no transport) ─── + +/// Sent when the parser cannot decode the incoming bytes as valid HTTP. +const REJECT_BAD_REQUEST: &[u8] = b"HTTP/1.1 400 Bad Request\r\n\ + content-length: 11\r\n\ + content-type: text/plain\r\n\ + connection: close\r\n\ + \r\n\ + Bad Request"; + +/// Sent when the per-connection concurrency limit is reached. +const REJECT_OVERLOADED: &[u8] = b"HTTP/1.1 503 Service Unavailable\r\n\ + content-length: 19\r\n\ + content-type: text/plain\r\n\ + \r\n\ + Service Unavailable"; + +/// Informational response for `Expect: 100-continue`. +const INFORMATIONAL_CONTINUE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n"; + +// ── Transport I/O helper ──────────────────────────────────────────── + +/// Write raw bytes to an asyncio transport. +fn transport_write(py: Python<'_>, transport: &Py, data: &[u8]) -> PyResult<()> { + let py_bytes = PyBytes::new(py, data); transport.call_method1(py, pyo3::intern!(py, "write"), (py_bytes,))?; Ok(()) } diff --git a/crates/framework/src/protocol/parser.rs b/crates/framework/src/protocol/parser.rs index e5f84901..bf8801a5 100644 --- a/crates/framework/src/protocol/parser.rs +++ b/crates/framework/src/protocol/parser.rs @@ -82,6 +82,8 @@ pub struct ParsedHead { pub version: HttpVersion, /// Content-Length value, if present. pub content_length: Option, + /// Whether the request includes `Expect: 100-continue`. + pub expect_continue: bool, } /// A fully parsed HTTP request (head + body). @@ -228,6 +230,7 @@ fn build_head(req: &httparse::Request<'_, '_>) -> Result }; let mut content_length = None; + let mut expect_continue = false; let mut headers = Vec::with_capacity(req.headers.len()); for header in req.headers.iter() { @@ -238,6 +241,11 @@ fn build_head(req: &httparse::Request<'_, '_>) -> Result { content_length = s.trim().parse().ok(); } + if header.name.eq_ignore_ascii_case("expect") + && header.value.eq_ignore_ascii_case(b"100-continue") + { + expect_continue = true; + } headers.push((name, value)); } @@ -248,6 +256,7 @@ fn build_head(req: &httparse::Request<'_, '_>) -> Result headers, version, content_length, + expect_continue, }) } @@ -403,4 +412,30 @@ mod tests { assert_eq!(requests.len(), 1); assert!(requests[0].body.is_empty()); } + + #[test] + fn test_expect_100_continue_detected() { + let mut parser = RequestParser::new(); + let requests = parser + .feed( + b"POST /upload HTTP/1.1\r\n\ + Host: h\r\n\ + Expect: 100-continue\r\n\ + Content-Length: 5\r\n\r\n\ + hello", + ) + .expect("parse failed"); + assert_eq!(requests.len(), 1); + assert!(requests[0].head.expect_continue); + } + + #[test] + fn test_expect_100_continue_absent() { + let mut parser = RequestParser::new(); + let requests = parser + .feed(b"GET / HTTP/1.1\r\nHost: h\r\n\r\n") + .expect("parse failed"); + assert_eq!(requests.len(), 1); + assert!(!requests[0].head.expect_continue); + } } diff --git a/crates/framework/src/protocol/writer.rs b/crates/framework/src/protocol/writer.rs index 5da7c5be..474ab93d 100644 --- a/crates/framework/src/protocol/writer.rs +++ b/crates/framework/src/protocol/writer.rs @@ -4,6 +4,7 @@ //! Sans-I/O core (`build_status_and_headers`, `parse_send_event`) is //! testable with `#[test]`. +use std::cell::RefCell; use std::time::Instant; use bytes::{BufMut, Bytes, BytesMut}; @@ -13,6 +14,28 @@ use pyo3::types::{PyBytes, PyDict, PyList}; use crate::asgi::scope::ResolvedAwaitable; use crate::telemetry::dispatch_metrics; +// ── Date header cache ─────────────────────────────────────────────── + +thread_local! { + static CACHED_DATE: RefCell<(Instant, Bytes)> = RefCell::new(( + Instant::now(), + Bytes::from_static(b""), + )); +} + +/// RFC 7231 `Date` header, cached and refreshed every second. +fn cached_date_header() -> Bytes { + CACHED_DATE.with(|cell| { + let mut cached = cell.borrow_mut(); + if cached.0.elapsed().as_secs() >= 1 || cached.1.is_empty() { + let now = httpdate::fmt_http_date(std::time::SystemTime::now()); + cached.1 = Bytes::from(format!("date: {now}\r\n")); + cached.0 = Instant::now(); + } + cached.1.clone() + }) +} + /// ASGI send event parsed from a Python dict. #[derive(Debug)] pub enum SendEvent { @@ -168,6 +191,8 @@ impl RustResponseWriter { }; dispatch_metrics::record_response_build(t_build.elapsed().as_micros() as f64); + const MERGE_THRESHOLD: usize = 65_536; + let t_write = Instant::now(); let write_result = if chunked { let mut buf = BytesMut::with_capacity(hdr_bytes.len() + body_bytes.len() + 32); @@ -177,6 +202,14 @@ impl RustResponseWriter { self.transport .call_method1(py, pyo3::intern!(py, "write"), (py_bytes,)) .map(|_| ()) + } else if body_bytes.len() <= MERGE_THRESHOLD { + let mut buf = BytesMut::with_capacity(hdr_bytes.len() + body_bytes.len()); + buf.put_slice(&hdr_bytes); + buf.put_slice(body_bytes); + let py_bytes = PyBytes::new(py, &buf); + self.transport + .call_method1(py, pyo3::intern!(py, "write"), (py_bytes,)) + .map(|_| ()) } else { let hdr_py = PyBytes::new(py, &hdr_bytes); self.transport @@ -311,108 +344,201 @@ fn extract_response_headers( /// HTTP/1.1 chunked transfer encoding terminator. const LAST_CHUNK: &[u8] = b"0\r\n\r\n"; +const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef"; + /// Write a single HTTP chunk frame: `{hex_len}\r\n{data}\r\n`. fn write_chunk(buf: &mut BytesMut, data: &[u8]) { if data.is_empty() { return; } - buf.put_slice(format!("{:x}\r\n", data.len()).as_bytes()); + write_hex(buf, data.len()); + buf.put_slice(b"\r\n"); buf.put_slice(data); buf.put_slice(b"\r\n"); } -/// Build the HTTP/1.1 status line + headers as bytes. -pub fn build_status_and_headers(status: u16, headers: &[(Bytes, Bytes)]) -> Bytes { - let reason = reason_phrase(status); - let mut buf = BytesMut::with_capacity(256); +/// Write a `usize` as lowercase hex directly into `buf` (no heap alloc). +fn write_hex(buf: &mut BytesMut, mut n: usize) { + if n == 0 { + buf.put_u8(b'0'); + return; + } + let mut stack = [0u8; 16]; + let mut pos = stack.len(); + while n > 0 { + pos -= 1; + stack[pos] = HEX_DIGITS[n & 0xf]; + n >>= 4; + } + buf.put_slice(&stack[pos..]); +} + +// ── Response head encoding ────────────────────────────────────────── + +/// Write the status line for a given code, using cached bytes for +/// common codes to avoid per-response `to_string()` + concatenation. +fn write_status_line(buf: &mut BytesMut, status: u16) { + match status { + 200 => { + buf.put_slice(b"HTTP/1.1 200 OK\r\n"); + return; + } + 201 => { + buf.put_slice(b"HTTP/1.1 201 Created\r\n"); + return; + } + 204 => { + buf.put_slice(b"HTTP/1.1 204 No Content\r\n"); + return; + } + 301 => { + buf.put_slice(b"HTTP/1.1 301 Moved Permanently\r\n"); + return; + } + 302 => { + buf.put_slice(b"HTTP/1.1 302 Found\r\n"); + return; + } + 304 => { + buf.put_slice(b"HTTP/1.1 304 Not Modified\r\n"); + return; + } + 307 => { + buf.put_slice(b"HTTP/1.1 307 Temporary Redirect\r\n"); + return; + } + 308 => { + buf.put_slice(b"HTTP/1.1 308 Permanent Redirect\r\n"); + return; + } + 400 => { + buf.put_slice(b"HTTP/1.1 400 Bad Request\r\n"); + return; + } + 401 => { + buf.put_slice(b"HTTP/1.1 401 Unauthorized\r\n"); + return; + } + 403 => { + buf.put_slice(b"HTTP/1.1 403 Forbidden\r\n"); + return; + } + 404 => { + buf.put_slice(b"HTTP/1.1 404 Not Found\r\n"); + return; + } + 405 => { + buf.put_slice(b"HTTP/1.1 405 Method Not Allowed\r\n"); + return; + } + 409 => { + buf.put_slice(b"HTTP/1.1 409 Conflict\r\n"); + return; + } + 422 => { + buf.put_slice(b"HTTP/1.1 422 Unprocessable Entity\r\n"); + return; + } + 429 => { + buf.put_slice(b"HTTP/1.1 429 Too Many Requests\r\n"); + return; + } + 500 => { + buf.put_slice(b"HTTP/1.1 500 Internal Server Error\r\n"); + return; + } + 502 => { + buf.put_slice(b"HTTP/1.1 502 Bad Gateway\r\n"); + return; + } + 503 => { + buf.put_slice(b"HTTP/1.1 503 Service Unavailable\r\n"); + return; + } + 504 => { + buf.put_slice(b"HTTP/1.1 504 Gateway Timeout\r\n"); + return; + } + _ => {} + } buf.put_slice(b"HTTP/1.1 "); buf.put_slice(status.to_string().as_bytes()); buf.put_slice(b" "); - buf.put_slice(reason.as_bytes()); + buf.put_slice(reason_phrase(status).as_bytes()); buf.put_slice(b"\r\n"); +} + +/// Encode status line + app headers + extra trailer headers + Date +/// into a `BytesMut`. Optionally appends body. +fn encode_head( + status: u16, + headers: &[(Bytes, Bytes)], + extra: &[(&[u8], &[u8])], + body: Option<&[u8]>, +) -> Bytes { + let body_len = body.map_or(0, <[u8]>::len); + let mut buf = BytesMut::with_capacity(256 + body_len); + write_status_line(&mut buf, status); for (name, value) in headers { buf.put_slice(name); buf.put_slice(b": "); buf.put_slice(value); buf.put_slice(b"\r\n"); } - buf.put_slice(b"\r\n"); - buf.freeze() -} - -/// Build status line + headers with `Transfer-Encoding: chunked`. -fn build_status_and_headers_chunked(status: u16, headers: &[(Bytes, Bytes)]) -> Bytes { - let reason = reason_phrase(status); - let mut buf = BytesMut::with_capacity(256); - buf.put_slice(b"HTTP/1.1 "); - buf.put_slice(status.to_string().as_bytes()); - buf.put_slice(b" "); - buf.put_slice(reason.as_bytes()); - buf.put_slice(b"\r\n"); - for (name, value) in headers { + for &(name, value) in extra { buf.put_slice(name); buf.put_slice(b": "); buf.put_slice(value); buf.put_slice(b"\r\n"); } - buf.put_slice(b"transfer-encoding: chunked\r\n"); + buf.put_slice(&cached_date_header()); buf.put_slice(b"\r\n"); + if let Some(b) = body { + buf.put_slice(b); + } buf.freeze() } +/// Build the HTTP/1.1 status line + headers as bytes. +pub fn build_status_and_headers(status: u16, headers: &[(Bytes, Bytes)]) -> Bytes { + encode_head(status, headers, &[], None) +} + +/// Build status line + headers with `Transfer-Encoding: chunked`. +fn build_status_and_headers_chunked(status: u16, headers: &[(Bytes, Bytes)]) -> Bytes { + encode_head(status, headers, &[(b"transfer-encoding", b"chunked")], None) +} + /// Build status line + headers with an auto-added `Content-Length`. fn build_status_and_headers_with_length( status: u16, headers: &[(Bytes, Bytes)], body_len: usize, ) -> Bytes { - let reason = reason_phrase(status); - let mut buf = BytesMut::with_capacity(256); - buf.put_slice(b"HTTP/1.1 "); - buf.put_slice(status.to_string().as_bytes()); - buf.put_slice(b" "); - buf.put_slice(reason.as_bytes()); - buf.put_slice(b"\r\n"); - for (name, value) in headers { - buf.put_slice(name); - buf.put_slice(b": "); - buf.put_slice(value); - buf.put_slice(b"\r\n"); - } - buf.put_slice(b"content-length: "); - buf.put_slice(body_len.to_string().as_bytes()); - buf.put_slice(b"\r\n\r\n"); - buf.freeze() + let len_str = body_len.to_string(); + encode_head( + status, + headers, + &[(b"content-length", len_str.as_bytes())], + None, + ) } /// Build a complete HTTP/1.1 response (status + headers + body). fn build_full_response(status: u16, headers: &[(Bytes, Bytes)], body: &[u8]) -> Bytes { - let reason = reason_phrase(status); - let mut buf = BytesMut::with_capacity(256 + body.len()); - buf.put_slice(b"HTTP/1.1 "); - buf.put_slice(status.to_string().as_bytes()); - buf.put_slice(b" "); - buf.put_slice(reason.as_bytes()); - buf.put_slice(b"\r\n"); - - let mut has_content_length = false; - for (name, value) in headers { - buf.put_slice(name); - buf.put_slice(b": "); - buf.put_slice(value); - buf.put_slice(b"\r\n"); - if name.eq_ignore_ascii_case(b"content-length") { - has_content_length = true; - } - } - if !has_content_length { - buf.put_slice(b"content-length: "); - buf.put_slice(body.len().to_string().as_bytes()); - buf.put_slice(b"\r\n"); + let has_content_length = headers + .iter() + .any(|(name, _)| name.eq_ignore_ascii_case(b"content-length")); + if has_content_length { + return encode_head(status, headers, &[], Some(body)); } - buf.put_slice(b"\r\n"); - buf.put_slice(body); - buf.freeze() + let len_str = body.len().to_string(); + encode_head( + status, + headers, + &[(b"content-length", len_str.as_bytes())], + Some(body), + ) } /// Standard HTTP reason phrase for common status codes. @@ -457,6 +583,7 @@ mod tests { let s = std::str::from_utf8(&result).expect("valid utf8"); assert!(s.starts_with("HTTP/1.1 200 OK\r\n")); assert!(s.contains("content-type: text/html\r\n")); + assert!(s.contains("date: ")); assert!(s.ends_with("\r\n\r\n")); } @@ -465,6 +592,7 @@ mod tests { let result = build_status_and_headers(404, &[]); let s = std::str::from_utf8(&result).expect("valid utf8"); assert!(s.starts_with("HTTP/1.1 404 Not Found\r\n")); + assert!(s.contains("date: ")); } #[test] @@ -476,6 +604,7 @@ mod tests { let result = build_full_response(200, &headers, b"hello"); let s = std::str::from_utf8(&result).expect("valid utf8"); assert!(s.contains("content-length: 5\r\n")); + assert!(s.contains("date: ")); assert!(s.ends_with("hello")); } @@ -519,6 +648,7 @@ mod tests { let s = std::str::from_utf8(&result).expect("valid utf8"); assert!(s.contains("transfer-encoding: chunked\r\n")); assert!(s.contains("content-type: text/plain\r\n")); + assert!(s.contains("date: ")); } #[test] @@ -530,6 +660,7 @@ mod tests { let result = build_status_and_headers_with_length(200, &headers, 42); let s = std::str::from_utf8(&result).expect("valid utf8"); assert!(s.contains("content-length: 42\r\n")); + assert!(s.contains("date: ")); } #[test] @@ -550,4 +681,48 @@ mod tests { fn test_last_chunk_constant() { assert_eq!(LAST_CHUNK, b"0\r\n\r\n"); } + + #[test] + fn test_date_header_cached_within_second() { + let a = cached_date_header(); + let b = cached_date_header(); + assert_eq!(a, b, "two calls within 1s must return the same value"); + assert!(a.starts_with(b"date: ")); + assert!(a.ends_with(b"\r\n")); + } + + #[test] + fn test_write_hex_zero() { + let mut buf = BytesMut::new(); + write_hex(&mut buf, 0); + assert_eq!(&buf[..], b"0"); + } + + #[test] + fn test_write_hex_small() { + let mut buf = BytesMut::new(); + write_hex(&mut buf, 255); + assert_eq!(&buf[..], b"ff"); + } + + #[test] + fn test_write_hex_large() { + let mut buf = BytesMut::new(); + write_hex(&mut buf, 0x1a2b); + assert_eq!(&buf[..], b"1a2b"); + } + + #[test] + fn test_write_status_line_cached() { + let mut buf = BytesMut::new(); + write_status_line(&mut buf, 200); + assert_eq!(&buf[..], b"HTTP/1.1 200 OK\r\n"); + } + + #[test] + fn test_write_status_line_uncached() { + let mut buf = BytesMut::new(); + write_status_line(&mut buf, 418); + assert_eq!(&buf[..], b"HTTP/1.1 418 Unknown\r\n"); + } } diff --git a/src/apx/_bridge.py b/src/apx/_bridge.py index 7a08c92d..52b51ce4 100644 --- a/src/apx/_bridge.py +++ b/src/apx/_bridge.py @@ -7,13 +7,8 @@ from __future__ import annotations import logging -import traceback -from collections.abc import Coroutine -from typing import Any, Callable, Protocol - - -class _ErrorSink(Protocol): - def send_error(self, tb: str) -> None: ... +from collections.abc import Callable +from typing import Any class _ApxHandler(logging.Handler): @@ -29,7 +24,7 @@ def emit(self, record: logging.LogRecord) -> None: pass -def install_log_handler(emit_fn: Callable[[int, str, str], None]) -> None: +def install_log_handler(emit_fn: Callable[[int, str, str, str], None]) -> None: handler = _ApxHandler(emit_fn) logging.root.addHandler(handler) logging.root.setLevel(logging.DEBUG) @@ -37,35 +32,3 @@ def install_log_handler(emit_fn: Callable[[int, str, str], None]) -> None: async def resolved(val: Any) -> Any: return val - - -async def guarded(coro: Coroutine[Any, Any, None], send: _ErrorSink) -> None: - try: - await coro - except Exception as exc: - tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) - send.send_error(tb) - - -_AsgiApp = Callable[..., Coroutine[Any, Any, None]] - - -def launch( - app: _AsgiApp, scope: dict[str, Any], receive: Any, send: _ErrorSink -) -> None: - """Create an ASGI coroutine and submit it as a guarded task. - - Called on the asyncio thread via ``call_soon_threadsafe``. - Combines ``app(scope, receive, send)`` + error guard + ``create_task`` - into a single ``_run_once`` callback so the tokio thread does no Python work. - """ - import asyncio - - async def _run() -> None: - try: - await app(scope, receive, send) - except Exception as exc: - tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) - send.send_error(tb) - - asyncio.get_running_loop().create_task(_run()) diff --git a/src/apx/_scheduler.py b/src/apx/_scheduler.py index 73de18d7..086c28a8 100644 --- a/src/apx/_scheduler.py +++ b/src/apx/_scheduler.py @@ -14,11 +14,14 @@ import asyncio import contextvars +import sys import time from collections import deque from collections.abc import Callable, Coroutine from typing import Any +_PY312 = sys.version_info >= (3, 12) + # ── Constants ──────────────────────────────────────────────────────── STEP_BUDGET: int = 256 @@ -44,6 +47,15 @@ async def _park_forever() -> None: await asyncio.get_event_loop().create_future() +async def _sentinel() -> None: + """Minimal coroutine for eager_start (3.12+). + + Completes immediately so the Task reaches ``done()`` after + ``eager_start`` finishes synchronously, preventing any stale + ``__step`` callback from polluting ``_ready``. + """ + + class SchedulerTask(asyncio.Task): """Placeholder task for ``_enter_task`` / ``_leave_task`` bracketing. @@ -65,7 +77,14 @@ def __init__(self, *, loop: asyncio.AbstractEventLoop) -> None: # We store it explicitly because CPython's C-implemented Task # does not expose ``_context`` as a Python-accessible attribute. self._drive_context: contextvars.Context = contextvars.copy_context() - super().__init__(_park_forever(), loop=loop) + if _PY312: + super().__init__(_sentinel(), loop=loop, eager_start=True) # type: ignore[call-arg] + else: + ready = getattr(loop, "_ready", None) + n_before = len(ready) if ready is not None else 0 + super().__init__(_park_forever(), loop=loop) + if ready is not None and len(ready) > n_before: + ready.pop() self._log_destroy_pending: bool = False self._cancel_flag: bool = False self._cancel_msg: str | None = None From 5efd1dcaacf2eaaa7924eef92bcb5c49dc7ce502 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Thu, 2 Apr 2026 20:45:48 +0200 Subject: [PATCH 13/18] =?UTF-8?q?=F0=9F=90=9B=20fix:=20RAII=20request=20sl?= =?UTF-8?q?ot=20guard,=20pause=5Fwriting=20via=20multiple-pymethods,=20tim?= =?UTF-8?q?ed!=20macro?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 10 + Cargo.toml | 2 +- crates/framework/src/protocol/connection.rs | 283 ++++++++++++++------ crates/framework/src/protocol/writer.rs | 26 +- crates/framework/src/telemetry/mod.rs | 16 ++ src/apx/_core.pyi | 1 + tests/integration/test_telemetry.py | 4 +- 7 files changed, 244 insertions(+), 98 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e571c65d..e51eba0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2963,6 +2963,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71dd52191aae121e8611f1e8dc3e324dd0dd1dee1e6dd91d10ee07a3cfb4d9d8" +[[package]] +name = "inventory" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc61209c082fbeb19919bee74b176221b27223e27b65d781eb91af24eb1fb46e" +dependencies = [ + "rustversion", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -4576,6 +4585,7 @@ version = "0.28.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf85e27e86080aafd5a22eae58a162e133a589551542b3e5cee4beb27e54f8e1" dependencies = [ + "inventory", "libc", "once_cell", "portable-atomic", diff --git a/Cargo.toml b/Cargo.toml index 14d1c26e..08a3569d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -125,7 +125,7 @@ which = "7" # instead of silent interpreter corruption in the multi-worker fork model. # NOTE: "extension-module" is NOT set here — it is set only in crates/apx # (the cdylib). Library crates need to link against libpython for `cargo test`. -pyo3 = { version = "0.28.2", features = ["macros"] } +pyo3 = { version = "0.28.2", features = ["macros", "multiple-pymethods"] } pyo3-async-runtimes = { version = "0.28.0", features = ["tokio-runtime"] } socket2 = "0.6.2" diff --git a/crates/framework/src/protocol/connection.rs b/crates/framework/src/protocol/connection.rs index a282cb09..5d6911be 100644 --- a/crates/framework/src/protocol/connection.rs +++ b/crates/framework/src/protocol/connection.rs @@ -36,6 +36,54 @@ struct ProtocolShared { write_paused: AtomicBool, } +/// RAII guard for a concurrency slot in `active_requests`. +/// +/// Increments the counter on [`acquire`] and decrements on [`Drop`]. +/// This guarantees the counter is always released — even if the ASGI +/// app never calls `send`, the handler raises, or the connection +/// closes mid-request. +/// +/// The [`release`] method allows explicit decrement (e.g. in +/// `OnRequestComplete.__call__`) while making `Drop` a no-op. +struct RequestSlot { + shared: Arc, + released: bool, +} + +impl RequestSlot { + /// Try to acquire a concurrency slot. Returns `None` if the + /// worker is at `MAX_CONCURRENT` — the caller should send 503. + fn acquire(shared: &Arc) -> Option { + let active = shared.active_requests.fetch_add(1, Ordering::Relaxed); + if active >= MAX_CONCURRENT { + shared.active_requests.fetch_sub(1, Ordering::Relaxed); + return None; + } + dispatch_metrics::inc_active_requests(); + crate::telemetry::http::inc_active_requests(); + Some(Self { + shared: Arc::clone(shared), + released: false, + }) + } + + /// Explicitly release the slot. Subsequent calls and `Drop` are no-ops. + fn release(&mut self) { + if !self.released { + self.released = true; + self.shared.active_requests.fetch_sub(1, Ordering::Relaxed); + dispatch_metrics::dec_active_requests(); + crate::telemetry::http::dec_active_requests(); + } + } +} + +impl Drop for RequestSlot { + fn drop(&mut self) { + self.release(); + } +} + /// Factory that creates [`RustProtocol`] instances for `loop.create_server()`. /// /// Holds shared worker state (interns, dispatch callback, concurrency limit). @@ -110,24 +158,36 @@ impl RustProtocol { Ok(()) } - /// Close the connection if idle (no active requests). - /// - /// Called by the event loop's `call_later` as the keep-alive timeout. - fn close_idle(&mut self, py: Python<'_>) { - if self.shared.active_requests.load(Ordering::Relaxed) == 0 - && let Some(transport) = &self.transport - { - let _ = transport.call_method0(py, pyo3::intern!(py, "close")); - } - } - /// Called by asyncio when data is received on the connection. + /// + /// The borrow_mut is held ONLY for pure-Rust operations (parse, + /// timer handle read). It is dropped BEFORE any Python calls + /// (cancel, close, import) to prevent PyO3 BorrowError when + /// uvloop re-enters via `pause_writing` during `transport.write`. fn data_received(slf: &Bound<'_, Self>, py: Python<'_>, data: &[u8]) -> PyResult<()> { let py_self = slf.clone().unbind(); - let mut this = py_self.borrow_mut(py); - this.cancel_keepalive_timer(py); - let t0 = Instant::now(); - let requests = match this.parser.feed(data) { + + // Phase 1: borrow_mut for pure Rust (parse + extract handles). + let (feed_result, keepalive_handle, error_transport) = { + let mut this = py_self.borrow_mut(py); + let keepalive = this.keepalive_handle.take(); + let result = + crate::telemetry::timed!(dispatch_metrics::record_parse, this.parser.feed(data)); + let err_transport = if result.is_err() { + this.transport.as_ref().map(|t| t.clone_ref(py)) + } else { + None + }; + (result, keepalive, err_transport) + // borrow_mut dropped here — BEFORE any Python calls. + }; + + // Phase 2: Python calls with NO borrow held. + if let Some(handle) = keepalive_handle { + let _ = handle.call_method0(py, pyo3::intern!(py, "cancel")); + } + + let requests = match feed_result { Ok(r) => r, Err(e) => { tracing::debug!( @@ -135,14 +195,13 @@ impl RustProtocol { error = %e, "malformed HTTP request" ); - if let Some(transport) = &this.transport { - let _ = transport_write(py, transport, REJECT_BAD_REQUEST); + if let Some(transport) = error_transport { + let _ = transport_write(py, &transport, REJECT_BAD_REQUEST); let _ = transport.call_method0(py, pyo3::intern!(py, "close")); } return Ok(()); } }; - dispatch_metrics::record_parse(t0.elapsed().as_micros() as f64); let event_loop = py .import("asyncio")? @@ -150,11 +209,7 @@ impl RustProtocol { .unbind(); for parsed in requests { - // Temporarily drop the borrow so dispatch_request can create - // a Py reference for the OnRequestComplete callback. - drop(this); Self::dispatch_request_inner(py, &py_self, &event_loop, parsed)?; - this = py_self.borrow_mut(py); } Ok(()) } @@ -165,26 +220,43 @@ impl RustProtocol { false } - /// Called by asyncio when the transport's write buffer exceeds - /// the high-water mark. + /// Called by asyncio/uvloop when the transport's write buffer + /// exceeds the high-water mark. + /// + /// No-op stub: the method must exist to prevent uvloop from logging + /// "protocol.pause_writing() failed". We don't implement write-side + /// flow control — this is a safety stub to prevent uvloop errors. + /// + /// Empty body: we don't implement write-side flow control. + /// The `&self` borrow is needed for the asyncio protocol interface. + /// Called by asyncio when the connection is lost. + fn connection_lost(&mut self, py: Python<'_>, _exc: Option<&Bound<'_, PyAny>>) { + self.cancel_keepalive_timer(py); + self.transport = None; + self.parser.reset(); + dispatch_metrics::dec_connections(); + } + /// Called by asyncio/uvloop when the transport's write buffer + /// exceeds the high-water mark. fn pause_writing(&self) { self.shared.write_paused.store(true, Ordering::Release); - tracing::debug!(name: "apx.protocol.pause_writing", "transport write buffer full"); } - /// Called by asyncio when the transport's write buffer drains - /// below the low-water mark. + /// Called by asyncio/uvloop when the transport's write buffer + /// drains below the low-water mark. fn resume_writing(&self) { self.shared.write_paused.store(false, Ordering::Release); - tracing::debug!(name: "apx.protocol.resume_writing", "transport write buffer drained"); } - /// Called by asyncio when the connection is lost. - fn connection_lost(&mut self, py: Python<'_>, _exc: Option<&Bound<'_, PyAny>>) { - self.cancel_keepalive_timer(py); - self.transport = None; - self.parser.reset(); - dispatch_metrics::dec_connections(); + /// Close the connection if idle (no active requests). + /// + /// Called by the event loop's `call_later` as the keep-alive timeout. + fn close_idle(&mut self, py: Python<'_>) { + if self.shared.active_requests.load(Ordering::Relaxed) == 0 + && let Some(transport) = &self.transport + { + let _ = transport.call_method0(py, pyo3::intern!(py, "close")); + } } } @@ -207,40 +279,94 @@ impl RustProtocol { return Ok(()); }; - let active = this.shared.active_requests.fetch_add(1, Ordering::Relaxed); - if active >= MAX_CONCURRENT { - this.shared.active_requests.fetch_sub(1, Ordering::Relaxed); + let Some(slot) = RequestSlot::acquire(&this.shared) else { transport_write(py, transport, REJECT_OVERLOADED)?; return Ok(()); + }; + + // Release the borrow before calling into Python. + // `slot` owns the concurrency decrement — if anything below + // fails or the writer is GC'd without calling send, `Drop` + // fires and the counter is released automatically. + drop(this); + + let result = Self::dispatch_body(py, py_self, event_loop, parsed, t_dispatch, slot); + + if let Err(e) = result { + // `slot` was either moved into OnRequestComplete (and will + // be released via Drop when OnRequestComplete is GC'd) or + // it was dropped by dispatch_body on error (Drop fires). + // Either way, the counter is handled. Just resume reading. + if let Ok(this) = py_self.try_borrow(py) + && let Some(transport) = &this.transport + { + let _ = transport.call_method0(py, pyo3::intern!(py, "resume_reading")); + } + tracing::debug!( + name: "apx.protocol.dispatch_error", + error = %e, + "request dispatch failed" + ); + return Err(e); } - dispatch_metrics::inc_active_requests(); - crate::telemetry::http::inc_active_requests(); + Ok(()) + } + + /// Inner dispatch body. The `slot` RAII guard ensures the + /// active_requests counter is decremented if this function errors + /// before transferring the slot into `OnRequestComplete`. + fn dispatch_body( + py: Python<'_>, + py_self: &Py, + event_loop: &Py, + parsed: ParsedRequest, + t_dispatch: Instant, + slot: RequestSlot, + ) -> PyResult<()> { + // Borrow briefly to extract what we need, then release before + // calling into Python (transport.write may trigger pause_writing + // which needs to borrow &self). + let (transport, shared, on_request, client_addr) = { + let this = py_self.borrow(py); + let Some(transport) = &this.transport else { + return Ok(()); + }; + let t = transport.clone_ref(py); + let s = Arc::clone(&this.shared); + let o = this.shared.on_request.clone_ref(py); + let c = this.client_addr; + (t, s, o, c) + }; + // Borrow released — safe to call Python methods that may + // re-enter RustProtocol (e.g. pause_writing via transport.write). transport.call_method0(py, pyo3::intern!(py, "pause_reading"))?; let (request_id, has_request_id) = resolve_request_id(&parsed.head.headers); - let t_scope = Instant::now(); - let scope = build_scope_from_parsed( - py, - &parsed, - &this.shared.interns, - &this.shared.server_host, - this.shared.server_port, - this.client_addr, - &request_id, - has_request_id, - )?; - dispatch_metrics::record_scope_build(t_scope.elapsed().as_micros() as f64); + let scope = crate::telemetry::timed!( + dispatch_metrics::record_scope_build, + build_scope_from_parsed( + py, + &parsed, + &shared.interns, + &shared.server_host, + shared.server_port, + client_addr, + &request_id, + has_request_id, + )? + ); - let t_receive = Instant::now(); - let receive = HttpReceive::new( - py, - parsed.body, - Some(transport.clone_ref(py)), - parsed.head.expect_continue, - )?; - dispatch_metrics::record_receive_build(t_receive.elapsed().as_micros() as f64); + let receive = crate::telemetry::timed!( + dispatch_metrics::record_receive_build, + HttpReceive::new( + py, + parsed.body, + Some(transport.clone_ref(py)), + parsed.head.expect_continue, + )? + ); let method = parsed.head.method.as_str().to_owned(); let path = parsed.head.path; @@ -249,30 +375,19 @@ impl RustProtocol { crate::telemetry::http::begin_request_span(&request_id, &method, &path); crate::telemetry::context::set_python_context(py, &trace_ctx)?; - let transport_clone = transport.clone_ref(py); - let shared = Arc::clone(&this.shared); - let on_request = this.shared.on_request.clone_ref(py); - drop(this); - let on_complete = OnRequestComplete::create( py, - transport_clone, - shared, + transport.clone_ref(py), t_dispatch, method, path, request_span, py_self.clone_ref(py), event_loop.clone_ref(py), + slot, )?; - let this = py_self.borrow(py); - let Some(transport) = &this.transport else { - return Ok(()); - }; - let send = - RustResponseWriter::new(py, transport.clone_ref(py), Some(on_complete.into_any()))?; - drop(this); + let send = RustResponseWriter::new(py, transport, Some(on_complete.into_any()))?; on_request.call1(py, (scope, receive, send))?; Ok(()) @@ -287,13 +402,16 @@ impl RustProtocol { #[pyclass(module = "apx._core")] struct OnRequestComplete { transport: Py, - shared: Arc, dispatch_start: Instant, method: String, path: String, request_span: tracing::Span, protocol: Py, event_loop: Py, + /// RAII concurrency slot — `Drop` decrements `active_requests` + /// if `__call__` was never invoked (e.g. app never called `send`, + /// or the `RustResponseWriter` was GC'd without completion). + slot: RequestSlot, } crate::opaque_debug!(OnRequestComplete); @@ -306,25 +424,25 @@ impl OnRequestComplete { fn create( py: Python<'_>, transport: Py, - shared: Arc, dispatch_start: Instant, method: String, path: String, request_span: tracing::Span, protocol: Py, event_loop: Py, + slot: RequestSlot, ) -> PyResult> { Py::new( py, Self { transport, - shared, dispatch_start, method, path, request_span, protocol, event_loop, + slot, }, ) } @@ -352,17 +470,10 @@ impl OnRequestComplete { crate::telemetry::http::finish_request_span(&self.request_span, status); } - // Always decrement counters, even if the transport is gone. - // If resume_reading fails (connection already closed), we must - // still release the concurrency slot — otherwise the counter - // leaks and eventually all requests get 503. + // Resume reading (may fail if connection closed — that's OK). let resume_result = self .transport .call_method0(py, pyo3::intern!(py, "resume_reading")); - self.shared.active_requests.fetch_sub(1, Ordering::Relaxed); - dispatch_metrics::dec_active_requests(); - crate::telemetry::http::dec_active_requests(); - if let Err(e) = resume_result { tracing::debug!( name: "apx.protocol.resume_reading_failed", @@ -371,6 +482,10 @@ impl OnRequestComplete { ); } + // Release the concurrency slot explicitly. This makes the + // Drop a no-op — the counter won't be double-decremented. + self.slot.release(); + if let Ok(close_idle) = self.protocol.getattr(py, pyo3::intern!(py, "close_idle")) && let Ok(handle) = self.event_loop.call_method1( py, diff --git a/crates/framework/src/protocol/writer.rs b/crates/framework/src/protocol/writer.rs index 474ab93d..c03a0269 100644 --- a/crates/framework/src/protocol/writer.rs +++ b/crates/framework/src/protocol/writer.rs @@ -114,9 +114,10 @@ impl RustResponseWriter { impl RustResponseWriter { /// ASGI send callable. fn __call__(&mut self, py: Python<'_>, event: &Bound<'_, PyDict>) -> PyResult> { - let t0 = Instant::now(); - let parsed = parse_send_event(py, event)?; - dispatch_metrics::record_send_parse(t0.elapsed().as_micros() as f64); + let parsed = crate::telemetry::timed!( + dispatch_metrics::record_send_parse, + parse_send_event(py, event)? + ); match parsed { SendEvent::Start { status, headers } => { @@ -181,15 +182,16 @@ impl RustResponseWriter { let chunked = more_body && !has_content_length; - let t_build = Instant::now(); - let hdr_bytes = if chunked { - build_status_and_headers_chunked(status, headers) - } else if !more_body && !has_content_length { - build_status_and_headers_with_length(status, headers, body_bytes.len()) - } else { - build_status_and_headers(status, headers) - }; - dispatch_metrics::record_response_build(t_build.elapsed().as_micros() as f64); + let hdr_bytes = crate::telemetry::timed!( + dispatch_metrics::record_response_build, + if chunked { + build_status_and_headers_chunked(status, headers) + } else if !more_body && !has_content_length { + build_status_and_headers_with_length(status, headers, body_bytes.len()) + } else { + build_status_and_headers(status, headers) + } + ); const MERGE_THRESHOLD: usize = 65_536; diff --git a/crates/framework/src/telemetry/mod.rs b/crates/framework/src/telemetry/mod.rs index c62c77bb..1d072f11 100644 --- a/crates/framework/src/telemetry/mod.rs +++ b/crates/framework/src/telemetry/mod.rs @@ -61,6 +61,22 @@ macro_rules! toggle_store { } pub(crate) use toggle_store; +/// Time an expression and record its duration in microseconds via a +/// `dispatch_metrics::record_*` function. +/// +/// Wraps the common `let t = Instant::now(); let v = expr; record(t.elapsed()); v` +/// pattern into a single call, enforcing consistent timing across the +/// request pipeline. +macro_rules! timed { + ($record:path, $expr:expr) => {{ + let __t0 = ::std::time::Instant::now(); + let __val = $expr; + $record(__t0.elapsed().as_micros() as f64); + __val + }}; +} +pub(crate) use timed; + /// State that can be refreshed before reading. /// /// Implemented by `SystemState` and `ProcessState` to share the diff --git a/src/apx/_core.pyi b/src/apx/_core.pyi index 02edbef2..2ef98970 100644 --- a/src/apx/_core.pyi +++ b/src/apx/_core.pyi @@ -33,6 +33,7 @@ class ProtocolFactory: class RustProtocol: """asyncio Protocol for HTTP/1.1 connections.""" + ... class HttpReceive: diff --git a/tests/integration/test_telemetry.py b/tests/integration/test_telemetry.py index c63f40a9..b9199bd5 100644 --- a/tests/integration/test_telemetry.py +++ b/tests/integration/test_telemetry.py @@ -1044,7 +1044,9 @@ def test_http_duration_has_sub_millisecond_boundaries( f"got {sub_ms} (all bounds: {dp.explicitBounds})" ) return - pytest.fail("http.server.request.duration histogram with explicitBounds not found") + pytest.fail( + "http.server.request.duration histogram with explicitBounds not found" + ) # --------------------------------------------------------------------------- From 4fd7bb9ad19bee6a8a5e93f43ae067a4fbca458c Mon Sep 17 00:00:00 2001 From: renardeinside Date: Thu, 2 Apr 2026 21:49:59 +0200 Subject: [PATCH 14/18] =?UTF-8?q?=F0=9F=8E=A8=20style:=20remove=20stale=20?= =?UTF-8?q?comments,=20dead=20code,=20and=20old=20metric=20names?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/framework/src/protocol/connection.rs | 34 ++++++--------------- src/apx/_bridge.py | 7 ----- src/apx/_continuation.py | 2 +- src/apx/_scheduler.py | 10 +++--- src/apx/telemetry.py | 2 +- 5 files changed, 17 insertions(+), 38 deletions(-) diff --git a/crates/framework/src/protocol/connection.rs b/crates/framework/src/protocol/connection.rs index 5d6911be..85776faa 100644 --- a/crates/framework/src/protocol/connection.rs +++ b/crates/framework/src/protocol/connection.rs @@ -6,7 +6,7 @@ use std::borrow::Cow; use std::net::SocketAddr; use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::sync::atomic::{AtomicU32, Ordering}; use std::time::Instant; use bytes::Bytes; @@ -33,7 +33,6 @@ struct ProtocolShared { server_host: String, server_port: u16, active_requests: AtomicU32, - write_paused: AtomicBool, } /// RAII guard for a concurrency slot in `active_requests`. @@ -110,7 +109,6 @@ impl ProtocolFactory { server_host, server_port, active_requests: AtomicU32::new(0), - write_paused: AtomicBool::new(false), }), } } @@ -167,7 +165,7 @@ impl RustProtocol { fn data_received(slf: &Bound<'_, Self>, py: Python<'_>, data: &[u8]) -> PyResult<()> { let py_self = slf.clone().unbind(); - // Phase 1: borrow_mut for pure Rust (parse + extract handles). + // Borrow mutably only for pure-Rust work (parse + extract handles). let (feed_result, keepalive_handle, error_transport) = { let mut this = py_self.borrow_mut(py); let keepalive = this.keepalive_handle.take(); @@ -182,7 +180,7 @@ impl RustProtocol { // borrow_mut dropped here — BEFORE any Python calls. }; - // Phase 2: Python calls with NO borrow held. + // All Python calls below — borrow is released to avoid PyO3 BorrowError. if let Some(handle) = keepalive_handle { let _ = handle.call_method0(py, pyo3::intern!(py, "cancel")); } @@ -220,15 +218,6 @@ impl RustProtocol { false } - /// Called by asyncio/uvloop when the transport's write buffer - /// exceeds the high-water mark. - /// - /// No-op stub: the method must exist to prevent uvloop from logging - /// "protocol.pause_writing() failed". We don't implement write-side - /// flow control — this is a safety stub to prevent uvloop errors. - /// - /// Empty body: we don't implement write-side flow control. - /// The `&self` borrow is needed for the asyncio protocol interface. /// Called by asyncio when the connection is lost. fn connection_lost(&mut self, py: Python<'_>, _exc: Option<&Bound<'_, PyAny>>) { self.cancel_keepalive_timer(py); @@ -236,17 +225,14 @@ impl RustProtocol { self.parser.reset(); dispatch_metrics::dec_connections(); } - /// Called by asyncio/uvloop when the transport's write buffer - /// exceeds the high-water mark. - fn pause_writing(&self) { - self.shared.write_paused.store(true, Ordering::Release); - } + /// Flow control callback from uvloop. Required by the asyncio + /// protocol interface; uvloop logs an error if the method is missing. + #[expect(clippy::unused_self, reason = "required by asyncio protocol interface")] + fn pause_writing(&self) {} - /// Called by asyncio/uvloop when the transport's write buffer - /// drains below the low-water mark. - fn resume_writing(&self) { - self.shared.write_paused.store(false, Ordering::Release); - } + /// Counterpart to `pause_writing` — write buffer drained. + #[expect(clippy::unused_self, reason = "required by asyncio protocol interface")] + fn resume_writing(&self) {} /// Close the connection if idle (no active requests). /// diff --git a/src/apx/_bridge.py b/src/apx/_bridge.py index 52b51ce4..da42b8ac 100644 --- a/src/apx/_bridge.py +++ b/src/apx/_bridge.py @@ -8,9 +8,6 @@ import logging from collections.abc import Callable -from typing import Any - - class _ApxHandler(logging.Handler): def __init__(self, emit_fn: Callable[[int, str, str, str], None]) -> None: super().__init__() @@ -28,7 +25,3 @@ def install_log_handler(emit_fn: Callable[[int, str, str, str], None]) -> None: handler = _ApxHandler(emit_fn) logging.root.addHandler(handler) logging.root.setLevel(logging.DEBUG) - - -async def resolved(val: Any) -> Any: - return val diff --git a/src/apx/_continuation.py b/src/apx/_continuation.py index beb57926..26462ba9 100644 --- a/src/apx/_continuation.py +++ b/src/apx/_continuation.py @@ -28,7 +28,7 @@ class Continuation: """Drives a suspended coroutine via done-callbacks. Each step uses per-step ``_enter_task`` / ``_leave_task`` brackets, - keeping invariant I1. Runs entirely on the asyncio thread. + keeping one task entered per loop at a time. Runs entirely on the asyncio thread. When an asyncio Future resolves, the continuation delivers the result (or exception) to the coroutine via ``drive_inline``'s diff --git a/src/apx/_scheduler.py b/src/apx/_scheduler.py index 086c28a8..5394c3f2 100644 --- a/src/apx/_scheduler.py +++ b/src/apx/_scheduler.py @@ -6,8 +6,8 @@ Safety: all driving happens on the asyncio thread during callback processing (``current_task() is None``). Per-step ``_enter_task`` / -``_leave_task`` brackets maintain invariant I1. See -``.plans/framework/io/pythonic-inlining.md`` for the full analysis. +``_leave_task`` brackets ensure only one task is entered at a time +per event loop. """ from __future__ import annotations @@ -116,7 +116,7 @@ class CallSoonCapture: While active, callbacks are captured into an internal queue instead of being appended to the event loop's ``_ready`` deque. This prevents the sentinel ``__step`` from ``SchedulerTask.__init__`` - (invariant I7) from polluting ``_run_once``. + from ``Task.__init__``'s ``call_soon(__step)`` polluting ``_run_once``. Captured callbacks are processed between drive steps via ``flush()`` or spilled back to the real ``call_soon`` on ``leave()``. @@ -126,7 +126,7 @@ class CallSoonCapture: # Queue entry: (callback, args, context). Context is preserved so # that Task.__step and Future done-callbacks run in their correct - # contextvars snapshot (invariant I2). + # contextvars snapshot for per-request isolation. _Entry = tuple[Callable[..., Any], tuple[Any, ...], contextvars.Context | None] def __init__(self, loop: asyncio.AbstractEventLoop) -> None: @@ -220,7 +220,7 @@ def drive_inline( Must be called from a ``_run_once`` callback where ``current_task() is None``. Uses per-step ``_enter_task`` / - ``_leave_task`` brackets (invariant I1). + ``_leave_task`` brackets (one task entered per loop at a time). On initial entry ``send_value`` is ``None`` (starts the coroutine). On continuation re-entry after a Future resolves, pass the Future's diff --git a/src/apx/telemetry.py b/src/apx/telemetry.py index 5ba9a0d7..e3ce5753 100644 --- a/src/apx/telemetry.py +++ b/src/apx/telemetry.py @@ -775,7 +775,7 @@ class ApxInstrumentation(BaseModel): Collected per-worker. Records per-phase histograms for the ASGI dispatch pipeline. If APX_PERF environment variable is not set, none of these metrics are collected:: - ApxInstrumentation(metrics=ApxMetrics(dispatch_total=True)) + ApxInstrumentation(metrics=ApxMetrics(request_total=True)) """ type: Literal["apx"] = "apx" From c6b1bcf6b6d15d4102b8f85033b67e4f2baf1637 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Fri, 3 Apr 2026 00:52:07 +0200 Subject: [PATCH 15/18] =?UTF-8?q?=F0=9F=9A=80=20perf:=20optimize=20continu?= =?UTF-8?q?ation=20resume=20with=20inline=20fast=20path?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/apx/_bridge.py | 2 + src/apx/_continuation.py | 102 ++++++++++++++++++++++----------------- src/apx/_scheduler.py | 19 ++++++++ 3 files changed, 79 insertions(+), 44 deletions(-) diff --git a/src/apx/_bridge.py b/src/apx/_bridge.py index da42b8ac..76492931 100644 --- a/src/apx/_bridge.py +++ b/src/apx/_bridge.py @@ -8,6 +8,8 @@ import logging from collections.abc import Callable + + class _ApxHandler(logging.Handler): def __init__(self, emit_fn: Callable[[int, str, str, str], None]) -> None: super().__init__() diff --git a/src/apx/_continuation.py b/src/apx/_continuation.py index 26462ba9..62b8aec8 100644 --- a/src/apx/_continuation.py +++ b/src/apx/_continuation.py @@ -76,35 +76,58 @@ def _attach(self, yielded: object) -> None: def _on_future_done(self, future: asyncio.Future[Any]) -> None: self._task._waiter = None - self._resolved_future = future - # If another task is currently entered (defensive guard), - # defer to next callback cycle. + # If another task is entered, defer to next callback cycle. if asyncio.current_task() is not None: + self._resolved_future = future self._loop.call_soon(self._step) return - self._step() + # Fast path: extract result inline and resume immediately. + # Avoids _step → _extract_resume → drive_inline indirection. + if future.cancelled(): + self._resume(None, asyncio.CancelledError()) + else: + fut_exc = future.exception() + if fut_exc is not None: + self._resume(None, fut_exc) + else: + self._resume(future.result(), None) - def _extract_resume( + def _resume( self, - ) -> tuple[Any, BaseException | None]: - """Extract send_value / send_exception from a resolved Future. + send_value: Any, + send_exception: BaseException | None, + ) -> None: + """Drive the coroutine with the given value or exception. - Mirrors the ``Task.__step`` protocol: deliver the Future's - result to the coroutine, or throw its exception. + Skips the ``_extract_resume`` overhead used by ``_step``. + Called directly from ``_on_future_done`` for the fast path. """ - future = self._resolved_future - self._resolved_future = None - if future is None: - # yield-None re-entry — no value to deliver. - return None, None - if future.cancelled(): - return None, asyncio.CancelledError() - exc = future.exception() - if exc is not None: - return None, exc - return future.result(), None + if self._coro is None: + return + self._capture.enter() + result = drive_inline( + self._coro, + self._task, + self._loop, + self._capture, + send_value=send_value, + send_exception=send_exception, + ) + self._capture.leave() + + if isinstance(result, Completed): + self._finish() + elif isinstance(result, Failed): + self._finish() + elif isinstance(result, Suspended): + self._attach(result.yielded) def _step(self) -> None: + """Resume from yield-None or deferred Future (via call_soon). + + Used when ``_on_future_done`` defers (another task entered) + or when the coroutine yielded None (asyncio.sleep(0)). + """ if self._coro is None: return @@ -117,18 +140,14 @@ def _step(self) -> None: asyncio.CancelledError(self._task._cancel_msg) ) except StopIteration: - # Coroutine caught CancelledError and returned normally. _leave_task(self._loop, self._task) self._finish() return except asyncio.CancelledError: - # Coroutine re-raised CancelledError — expected. _leave_task(self._loop, self._task) self._finish() return except BaseException as exc: - # Coroutine raised a different exception during cancel - # cleanup (e.g. error in a yield-dep finalizer). _leave_task(self._loop, self._task) _log_cancel_exception(exc) self._finish() @@ -137,26 +156,21 @@ def _step(self) -> None: self._attach(yielded) return - # Normal step: deliver the Future result and resume driving. - send_value, send_exception = self._extract_resume() - - self._capture.enter() - result = drive_inline( - self._coro, - self._task, - self._loop, - self._capture, - send_value=send_value, - send_exception=send_exception, - ) - self._capture.leave() - - if isinstance(result, Completed): - self._finish() - elif isinstance(result, Failed): - self._finish() - elif isinstance(result, Suspended): - self._attach(result.yielded) + # Deferred Future resume (from _on_future_done via call_soon). + future = self._resolved_future + self._resolved_future = None + if future is not None: + if future.cancelled(): + self._resume(None, asyncio.CancelledError()) + else: + fut_exc = future.exception() + if fut_exc is not None: + self._resume(None, fut_exc) + else: + self._resume(future.result(), None) + else: + # yield-None re-entry. + self._resume(None, None) def _finish(self) -> None: self._coro = None diff --git a/src/apx/_scheduler.py b/src/apx/_scheduler.py index 5394c3f2..cad23e5c 100644 --- a/src/apx/_scheduler.py +++ b/src/apx/_scheduler.py @@ -256,6 +256,25 @@ def drive_inline( if result is not None and getattr(result, "_asyncio_future_blocking", False): result._asyncio_future_blocking = False + # Fast path: if the Future is already resolved, extract its + # result and continue driving inline — avoids a full + # Continuation round-trip through the event loop. + if result.done(): + if result.cancelled(): + send_exception = asyncio.CancelledError() + send_value = None + else: + fut_exc = result.exception() + if fut_exc is not None: + send_exception = fut_exc + send_value = None + else: + send_value = result.result() + send_exception = None + budget -= 1 + if budget <= 0 or time.monotonic() > deadline: + return Suspended(None) + continue return Suspended(result) capture.flush() From 8bb262949c7483ad9f64608fc5d56c3182a51501 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Fri, 3 Apr 2026 01:35:16 +0200 Subject: [PATCH 16/18] =?UTF-8?q?=F0=9F=90=9B=20fix:=20add=20SQLite=20busy?= =?UTF-8?q?=5Ftimeout=20to=20eliminate=20write=20contention=20errors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/bench/app/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/bench/app/api.py b/scripts/bench/app/api.py index 46bf8143..b7c0da37 100644 --- a/scripts/bench/app/api.py +++ b/scripts/bench/app/api.py @@ -63,6 +63,7 @@ async def _seed_defaults(db: aiosqlite.Connection) -> None: async def lifespan(app: FastAPI): db = await aiosqlite.connect(DB_PATH) await db.execute("PRAGMA journal_mode=WAL") + await db.execute("PRAGMA busy_timeout=5000") await db.execute(_CREATE_TABLE) cursor = await db.execute("SELECT count(*) FROM items") row = await cursor.fetchone() From 0accff5bfe4db6421c04346da8d2e9d968595e39 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Tue, 7 Apr 2026 19:04:32 +0200 Subject: [PATCH 17/18] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20WebSocket=20sup?= =?UTF-8?q?port=20via=20wsproto=20bridge?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/framework/src/protocol/connection.rs | 170 ++++++++++++++++ crates/framework/src/supervision/worker.rs | 20 +- pyproject.toml | 1 + scripts/bench/app/api.py | 32 ++- src/apx/_server.py | 34 ++++ src/apx/_websocket.py | 215 ++++++++++++++++++++ tests/integration/test_websocket.py | 81 ++++++++ uv.lock | 14 ++ 8 files changed, 562 insertions(+), 5 deletions(-) create mode 100644 src/apx/_websocket.py create mode 100644 tests/integration/test_websocket.py diff --git a/crates/framework/src/protocol/connection.rs b/crates/framework/src/protocol/connection.rs index 85776faa..a74a84db 100644 --- a/crates/framework/src/protocol/connection.rs +++ b/crates/framework/src/protocol/connection.rs @@ -29,6 +29,7 @@ const KEEPALIVE_TIMEOUT_S: f64 = 5.0; /// Shared state for all protocol instances on this worker. struct ProtocolShared { on_request: Py, + on_ws_connect: Option>, interns: ScopeInterns, server_host: String, server_port: u16, @@ -98,6 +99,7 @@ impl ProtocolFactory { /// Create a factory with shared worker state (Rust-side constructor). pub fn new( on_request: Py, + on_ws_connect: Option>, interns: ScopeInterns, server_host: String, server_port: u16, @@ -105,6 +107,7 @@ impl ProtocolFactory { Self { shared: Arc::new(ProtocolShared { on_request, + on_ws_connect, interns, server_host, server_port, @@ -126,6 +129,7 @@ impl ProtocolFactory { shared: Arc::clone(&self.shared), client_addr: None, keepalive_handle: None, + ws_bridge: None, }, ) } @@ -142,6 +146,9 @@ pub struct RustProtocol { shared: Arc, client_addr: Option, keepalive_handle: Option>, + /// Active WebSocket bridge — when set, `data_received` forwards + /// raw bytes here instead of parsing HTTP. + ws_bridge: Option>, } crate::opaque_debug!(RustProtocol); @@ -165,6 +172,17 @@ impl RustProtocol { fn data_received(slf: &Bound<'_, Self>, py: Python<'_>, data: &[u8]) -> PyResult<()> { let py_self = slf.clone().unbind(); + // WebSocket fast path: if a bridge is active, forward raw bytes + // to the wsproto parser instead of the HTTP parser. + { + let this = py_self.borrow(py); + if let Some(bridge) = &this.ws_bridge { + let py_bytes = PyBytes::new(py, data); + bridge.call_method1(py, pyo3::intern!(py, "feed_data"), (py_bytes,))?; + return Ok(()); + } + } + // Borrow mutably only for pure-Rust work (parse + extract handles). let (feed_result, keepalive_handle, error_transport) = { let mut this = py_self.borrow_mut(py); @@ -221,6 +239,9 @@ impl RustProtocol { /// Called by asyncio when the connection is lost. fn connection_lost(&mut self, py: Python<'_>, _exc: Option<&Bound<'_, PyAny>>) { self.cancel_keepalive_timer(py); + if let Some(bridge) = self.ws_bridge.take() { + let _ = bridge.call_method0(py, pyo3::intern!(py, "connection_lost")); + } self.transport = None; self.parser.reset(); dispatch_metrics::dec_connections(); @@ -239,11 +260,20 @@ impl RustProtocol { /// Called by the event loop's `call_later` as the keep-alive timeout. fn close_idle(&mut self, py: Python<'_>) { if self.shared.active_requests.load(Ordering::Relaxed) == 0 + && self.ws_bridge.is_none() && let Some(transport) = &self.transport { let _ = transport.call_method0(py, pyo3::intern!(py, "close")); } } + + /// Store a WebSocket bridge reference on this protocol. + /// + /// Called from the Python-side WebSocket handler after upgrade. + /// Subsequent `data_received` calls will route to the bridge. + fn set_ws_bridge(&mut self, bridge: Py) { + self.ws_bridge = Some(bridge); + } } impl RustProtocol { @@ -309,6 +339,11 @@ impl RustProtocol { t_dispatch: Instant, slot: RequestSlot, ) -> PyResult<()> { + // WebSocket upgrade: detect and dispatch separately. + if is_websocket_upgrade(&parsed.head.headers) { + return Self::dispatch_websocket(py, py_self, parsed, slot); + } + // Borrow briefly to extract what we need, then release before // calling into Python (transport.write may trigger pause_writing // which needs to borrow &self). @@ -378,6 +413,141 @@ impl RustProtocol { on_request.call1(py, (scope, receive, send))?; Ok(()) } + + /// Dispatch a WebSocket upgrade request to the Python bridge. + /// + /// Builds the ASGI WebSocket scope, writes the 101 Switching + /// Protocols response, creates a `WebSocketBridge`, and stores + /// it so subsequent `data_received` calls route to it. + fn dispatch_websocket( + py: Python<'_>, + py_self: &Py, + parsed: ParsedRequest, + _slot: RequestSlot, + ) -> PyResult<()> { + let (transport, shared, client_addr) = { + let this = py_self.borrow(py); + let Some(transport) = &this.transport else { + return Ok(()); + }; + ( + transport.clone_ref(py), + Arc::clone(&this.shared), + this.client_addr, + ) + }; + + let Some(on_ws_connect) = &shared.on_ws_connect else { + // No WS handler registered — reject with 400. + transport_write(py, &transport, REJECT_BAD_REQUEST)?; + return Ok(()); + }; + + // Build WebSocket ASGI scope. + let scope = PyDict::new(py); + scope.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket"))?; + let asgi_dict = PyDict::new(py); + asgi_dict.set_item(pyo3::intern!(py, "version"), "3.0")?; + asgi_dict.set_item(pyo3::intern!(py, "spec_version"), "2.4")?; + scope.set_item(pyo3::intern!(py, "asgi"), asgi_dict)?; + scope.set_item(pyo3::intern!(py, "http_version"), "1.1")?; + scope.set_item(pyo3::intern!(py, "scheme"), "ws")?; + scope.set_item(pyo3::intern!(py, "path"), &parsed.head.path)?; + scope.set_item( + pyo3::intern!(py, "raw_path"), + PyBytes::new(py, parsed.head.path.as_bytes()), + )?; + scope.set_item( + pyo3::intern!(py, "query_string"), + PyBytes::new(py, &parsed.head.query_string), + )?; + scope.set_item(pyo3::intern!(py, "root_path"), "")?; + + // Build headers list (same format as HTTP scope). + let header_list = PyList::empty(py); + for (name, value) in &parsed.head.headers { + let tuple = PyTuple::new( + py, + [ + PyBytes::new(py, name).as_any(), + PyBytes::new(py, value).as_any(), + ], + )?; + header_list.append(tuple)?; + } + scope.set_item(pyo3::intern!(py, "headers"), header_list)?; + + // Client/server addresses. + if let Some(addr) = client_addr { + scope.set_item( + pyo3::intern!(py, "client"), + (addr.ip().to_string(), addr.port()), + )?; + } else { + scope.set_item(pyo3::intern!(py, "client"), py.None())?; + } + scope.set_item( + pyo3::intern!(py, "server"), + (&*shared.server_host, shared.server_port), + )?; + + // Extract subprotocols from Sec-WebSocket-Protocol header. + let subprotocols = PyList::empty(py); + for (name, value) in &parsed.head.headers { + if name.eq_ignore_ascii_case(b"sec-websocket-protocol") + && let Ok(s) = std::str::from_utf8(value) + { + for proto in s.split(',') { + subprotocols.append(proto.trim())?; + } + } + } + scope.set_item(pyo3::intern!(py, "subprotocols"), subprotocols)?; + scope.set_item(pyo3::intern!(py, "state"), PyDict::new(py))?; + + // Extract Sec-WebSocket-Key for the 101 response. + let mut ws_key = String::new(); + for (name, value) in &parsed.head.headers { + if name.eq_ignore_ascii_case(b"sec-websocket-key") + && let Ok(s) = std::str::from_utf8(value) + { + s.clone_into(&mut ws_key); + } + } + + // Call the Python-side WebSocket handler. + // It builds the 101 response, creates the bridge, and stores + // a reference back on this protocol via set_ws_bridge(). + let scope_obj = scope.unbind().into_any(); + on_ws_connect.call1(py, (scope_obj, &transport, &ws_key, py_self))?; + Ok(()) + } +} + +// ── WebSocket upgrade detection ───────────────────────────────── + +/// Check if a parsed HTTP request is a WebSocket upgrade. +fn is_websocket_upgrade(headers: &[(Bytes, Bytes)]) -> bool { + let mut has_upgrade = false; + let mut has_connection = false; + for (name, value) in headers { + if name.eq_ignore_ascii_case(b"upgrade") && value.eq_ignore_ascii_case(b"websocket") { + has_upgrade = true; + } + if name.eq_ignore_ascii_case(b"connection") { + for part in value.split(|&b| b == b',') { + let trimmed = part + .iter() + .copied() + .skip_while(u8::is_ascii_whitespace) + .collect::>(); + if trimmed.eq_ignore_ascii_case(b"upgrade") { + has_connection = true; + } + } + } + } + has_upgrade && has_connection } /// Callback invoked when a response is fully written. diff --git a/crates/framework/src/supervision/worker.rs b/crates/framework/src/supervision/worker.rs index 4801c469..b991b32b 100644 --- a/crates/framework/src/supervision/worker.rs +++ b/crates/framework/src/supervision/worker.rs @@ -235,14 +235,15 @@ fn run_server( py.run( c" import asyncio as _asyncio -from apx._server import serve as _serve, _build_on_request +from apx._server import serve as _serve, _build_on_request, _build_on_ws_connect from apx._scheduler import CallSoonCapture async def _boot(_app, _factory_fn, _host, _port, _shutdown_event): loop = _asyncio.get_running_loop() capture = CallSoonCapture(loop) on_request = _build_on_request(_app, loop, capture) - factory = _factory_fn(on_request) + on_ws_connect = _build_on_ws_connect(_app) + factory = _factory_fn(on_request, on_ws_connect) await _serve(_host, _port, _app, factory, shutdown_event=_shutdown_event) ", None, @@ -318,7 +319,12 @@ crate::opaque_debug!(FactoryBuilder); #[pymethods] impl FactoryBuilder { - fn __call__(&self, py: Python<'_>, on_request: Py) -> PyResult> { + fn __call__( + &self, + py: Python<'_>, + on_request: Py, + on_ws_connect: Option>, + ) -> PyResult> { let interns = self .interns .lock() @@ -327,7 +333,13 @@ impl FactoryBuilder { .ok_or_else(|| { pyo3::exceptions::PyRuntimeError::new_err("FactoryBuilder already consumed") })?; - let factory = ProtocolFactory::new(on_request, interns, self.host.clone(), self.port); + let factory = ProtocolFactory::new( + on_request, + on_ws_connect, + interns, + self.host.clone(), + self.port, + ); Py::new(py, factory) } } diff --git a/pyproject.toml b/pyproject.toml index 08525626..819a11e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "orjson>=3.11.7", "pydantic>=2.0", "uvloop>=0.21.0; sys_platform != 'win32'", + "wsproto>=1.2", ] classifiers = [ "Development Status :: 3 - Alpha", diff --git a/scripts/bench/app/api.py b/scripts/bench/app/api.py index b7c0da37..bf25ef47 100644 --- a/scripts/bench/app/api.py +++ b/scripts/bench/app/api.py @@ -5,7 +5,7 @@ from contextlib import asynccontextmanager import aiosqlite -from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect from fastapi.responses import Response, StreamingResponse from .models import Item, ItemCreate, ItemUpdate @@ -252,6 +252,36 @@ async def deps_endpoint(a: str = Depends(dep_level_a)): return {"chain": a} +# --------------------------------------------------------------------------- +# WebSocket endpoints +# --------------------------------------------------------------------------- + + +@router.websocket("/ws/echo") +async def ws_echo(ws: WebSocket): + """Echo WebSocket — returns each message back to the client.""" + await ws.accept() + try: + while True: + data = await ws.receive_text() + await ws.send_text(data) + except WebSocketDisconnect: + pass + + +@router.websocket("/ws/json") +async def ws_json(ws: WebSocket): + """JSON WebSocket — echoes JSON with a server timestamp.""" + await ws.accept() + try: + while True: + msg = await ws.receive_json() + msg["server_ts"] = asyncio.get_event_loop().time() + await ws.send_json(msg) + except WebSocketDisconnect: + pass + + # --------------------------------------------------------------------------- # Telemetry test endpoint # --------------------------------------------------------------------------- diff --git a/src/apx/_server.py b/src/apx/_server.py index 1284bee0..f726e67e 100644 --- a/src/apx/_server.py +++ b/src/apx/_server.py @@ -14,6 +14,7 @@ from apx._continuation import Continuation from apx._core import LifespanReceive, LifespanSend +from apx._websocket import WebSocketBridge, build_upgrade_response from apx._scheduler import ( CallSoonCapture, Completed, @@ -82,6 +83,39 @@ def on_request( return on_request +def _build_on_ws_connect( + app: Callable[..., Coroutine[Any, Any, None]], +) -> Callable[..., None]: + """Build the on_ws_connect callback for WebSocket upgrades. + + Called from Rust when an HTTP Upgrade: websocket request is detected. + Writes the 101 Switching Protocols response, creates a WebSocketBridge, + and stores it on the protocol so subsequent data_received calls route + to the wsproto parser. + """ + + def on_ws_connect( + scope: dict[str, Any], + transport: Any, + ws_key: str, + protocol: Any, + ) -> None: + logger.debug("WebSocket upgrade: key=%r len=%d", ws_key, len(ws_key)) + # Write the 101 Switching Protocols response. + response = build_upgrade_response(ws_key) + logger.debug("101 response: %r", response[:100]) + transport.write(response) + + # Create the bridge and store it on the protocol. + bridge = WebSocketBridge(transport, scope, app) + protocol.set_ws_bridge(bridge) + + # Start the ASGI WebSocket lifecycle as an asyncio task. + bridge.start() + + return on_ws_connect + + async def _run_lifespan( app: Callable[..., Coroutine[Any, Any, None]], shutdown_event: asyncio.Event, diff --git a/src/apx/_websocket.py b/src/apx/_websocket.py new file mode 100644 index 00000000..21412bce --- /dev/null +++ b/src/apx/_websocket.py @@ -0,0 +1,215 @@ +"""WebSocket ASGI bridge using wsproto for frame encoding/decoding. + +Bridges an asyncio transport to the ASGI WebSocket protocol. The HTTP +upgrade handshake (101 Switching Protocols) is handled externally by +the Rust protocol layer; this module handles only the post-handshake +WebSocket frame lifecycle. + +Uses wsproto as a sans-I/O state machine — we own the transport, wsproto +owns the frame parsing. +""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import logging +from collections.abc import Callable, Coroutine +from typing import Any + +from wsproto import ConnectionType +from wsproto.connection import Connection +from wsproto.events import ( + BytesMessage, + CloseConnection, + Ping, + TextMessage, +) + +logger = logging.getLogger("apx.websocket") + +# RFC 6455 §4.2.2: magic GUID for Sec-WebSocket-Accept computation. +_WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def compute_accept_key(sec_websocket_key: str) -> str: + """Compute ``Sec-WebSocket-Accept`` from the client's key (RFC 6455 §4.2.2).""" + digest = hashlib.sha1(sec_websocket_key.encode() + _WS_GUID).digest() + return base64.b64encode(digest).decode() + + +def build_upgrade_response( + sec_websocket_key: str, + subprotocol: str | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, +) -> bytes: + """Build the HTTP 101 Switching Protocols response.""" + accept = compute_accept_key(sec_websocket_key) + lines = [ + b"HTTP/1.1 101 Switching Protocols", + b"Upgrade: websocket", + b"Connection: Upgrade", + f"Sec-WebSocket-Accept: {accept}".encode(), + ] + if subprotocol: + lines.append(f"Sec-WebSocket-Protocol: {subprotocol}".encode()) + if extra_headers: + for name, value in extra_headers: + lines.append(name + b": " + value) + lines.append(b"") + lines.append(b"") + return b"\r\n".join(lines) + + +class WebSocketBridge: + """Bridges an asyncio transport to the ASGI WebSocket protocol. + + The Rust protocol creates this after detecting an HTTP upgrade + request. Subsequent ``data_received`` calls on the protocol are + forwarded to ``feed_data``, which parses WebSocket frames via + wsproto and enqueues ASGI receive events. + + The ASGI app interacts with this bridge through the ``receive`` + and ``send`` callables passed to ``app(scope, receive, send)``. + """ + + __slots__ = ( + "_transport", + "_app", + "_scope", + "_ws", + "_receive_queue", + "_closed", + "_task", + ) + + def __init__( + self, + transport: Any, + scope: dict[str, Any], + app: Callable[..., Coroutine[Any, Any, None]], + ) -> None: + self._transport = transport + self._scope = scope + self._app = app + # wsproto server connection starts in OPEN state — handshake + # was handled by the Rust protocol (101 already written). + self._ws = Connection(ConnectionType.SERVER) + self._receive_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + self._closed = False + self._task: asyncio.Task[None] | None = None + + def start(self) -> None: + """Start the ASGI WebSocket lifecycle as an asyncio task.""" + # Queue the initial connect event per ASGI spec. + self._receive_queue.put_nowait({"type": "websocket.connect"}) + self._task = asyncio.get_running_loop().create_task(self._run()) + + async def _run(self) -> None: + """Run the ASGI app with WebSocket receive/send callables.""" + try: + await self._app(self._scope, self._receive, self._send) + except Exception: + logger.exception("WebSocket app error") + finally: + if not self._closed: + self._close_transport() + + async def _receive(self) -> dict[str, Any]: + """ASGI receive callable — blocks until a WS event is available.""" + return await self._receive_queue.get() + + async def _send(self, message: dict[str, Any]) -> None: + """ASGI send callable — processes WebSocket events from the app.""" + msg_type = message["type"] + if msg_type == "websocket.accept": + self._handle_accept(message) + elif msg_type == "websocket.send": + self._handle_send(message) + elif msg_type == "websocket.close": + self._handle_close(message) + + def feed_data(self, data: bytes) -> None: + """Feed raw bytes from the transport into the wsproto parser. + + Called by ``RustProtocol.data_received`` after the WebSocket + upgrade is complete. + """ + self._ws.receive_data(data) + for event in self._ws.events(): + if isinstance(event, TextMessage): + if event.message_finished: + self._receive_queue.put_nowait( + {"type": "websocket.receive", "text": event.data} + ) + elif isinstance(event, BytesMessage): + if event.message_finished: + self._receive_queue.put_nowait( + {"type": "websocket.receive", "bytes": event.data} + ) + elif isinstance(event, CloseConnection): + self._receive_queue.put_nowait( + { + "type": "websocket.disconnect", + "code": event.code or 1005, + } + ) + # Send close acknowledgment + data_out = self._ws.send(event.response()) + self._transport.write(data_out) + self._closed = True + elif isinstance(event, Ping): + # Auto-respond to pings + data_out = self._ws.send(event.response()) + self._transport.write(data_out) + + def connection_lost(self) -> None: + """Called when the transport connection is lost.""" + if not self._closed: + self._closed = True + self._receive_queue.put_nowait( + {"type": "websocket.disconnect", "code": 1006} + ) + + def _handle_accept(self, message: dict[str, Any]) -> None: + """Process websocket.accept — the 101 was already sent by Rust. + + Extract subprotocol and extra headers if the app provided them, + but in practice the 101 is already written. This is a no-op + for the transport but required by the ASGI lifecycle. + """ + # The 101 response was already written by Rust before this + # bridge was created. Nothing to do here. + + def _handle_send(self, message: dict[str, Any]) -> None: + """Process websocket.send — encode and write a WS frame.""" + if self._closed: + return + text = message.get("text") + data_bytes = message.get("bytes") + if text is not None: + out = self._ws.send(TextMessage(data=text)) + elif data_bytes is not None: + out = self._ws.send(BytesMessage(data=data_bytes)) + else: + return + self._transport.write(out) + + def _handle_close(self, message: dict[str, Any]) -> None: + """Process websocket.close — send close frame and close transport.""" + if self._closed: + return + code = message.get("code", 1000) + reason = message.get("reason", "") + out = self._ws.send(CloseConnection(code=code, reason=reason)) + self._transport.write(out) + self._closed = True + self._close_transport() + + def _close_transport(self) -> None: + """Close the underlying transport.""" + try: + self._transport.close() + except Exception: + pass diff --git a/tests/integration/test_websocket.py b/tests/integration/test_websocket.py new file mode 100644 index 00000000..8dd0c1f7 --- /dev/null +++ b/tests/integration/test_websocket.py @@ -0,0 +1,81 @@ +"""WebSocket integration tests for APX. + +Tests run against the bench app's WebSocket endpoints (ws/echo, ws/json) +inside the Docker container managed by the session fixtures. +""" + +from __future__ import annotations + +import asyncio +import json + +import pytest +import websockets + + +@pytest.fixture(scope="session") +def ws_url(apx_container: str) -> str: + """WebSocket base URL derived from the HTTP container URL.""" + return apx_container.replace("http://", "ws://") + + +@pytest.mark.integration +class TestWebSocketEcho: + def test_single_message(self, ws_url: str) -> None: + async def _run(): + async with websockets.connect(f"{ws_url}/api/ws/echo") as ws: + await ws.send("hello") + reply = await asyncio.wait_for(ws.recv(), timeout=5) + assert reply == "hello" + + asyncio.run(_run()) + + def test_multiple_messages(self, ws_url: str) -> None: + async def _run(): + async with websockets.connect(f"{ws_url}/api/ws/echo") as ws: + for i in range(10): + await ws.send(f"msg-{i}") + reply = await asyncio.wait_for(ws.recv(), timeout=5) + assert reply == f"msg-{i}" + + asyncio.run(_run()) + + def test_concurrent_connections(self, ws_url: str) -> None: + async def worker(n: int) -> None: + async with websockets.connect(f"{ws_url}/api/ws/echo") as ws: + for i in range(5): + msg = f"w{n}-{i}" + await ws.send(msg) + reply = await asyncio.wait_for(ws.recv(), timeout=5) + assert reply == msg + + async def _run(): + await asyncio.gather(*[worker(i) for i in range(10)]) + + asyncio.run(_run()) + + def test_client_disconnect(self, ws_url: str) -> None: + """Server handles client disconnect gracefully.""" + + async def _run(): + async with websockets.connect(f"{ws_url}/api/ws/echo") as ws: + await ws.send("hello") + await ws.recv() + # Connection closed — no crash expected. + + asyncio.run(_run()) + + +@pytest.mark.integration +class TestWebSocketJSON: + def test_json_echo(self, ws_url: str) -> None: + async def _run(): + async with websockets.connect(f"{ws_url}/api/ws/json") as ws: + payload = {"action": "ping", "n": 42} + await ws.send(json.dumps(payload)) + reply = json.loads(await asyncio.wait_for(ws.recv(), timeout=5)) + assert reply["action"] == "ping" + assert reply["n"] == 42 + assert "server_ts" in reply + + asyncio.run(_run()) diff --git a/uv.lock b/uv.lock index 99228028..5cf1e851 100644 --- a/uv.lock +++ b/uv.lock @@ -36,6 +36,7 @@ dependencies = [ { name = "orjson" }, { name = "pydantic" }, { name = "uvloop", marker = "sys_platform != 'win32'" }, + { name = "wsproto" }, ] [package.dev-dependencies] @@ -89,6 +90,7 @@ requires-dist = [ { name = "orjson", specifier = ">=3.11.7" }, { name = "pydantic", specifier = ">=2.0" }, { name = "uvloop", marker = "sys_platform != 'win32'", specifier = ">=0.21.0" }, + { name = "wsproto", specifier = ">=1.2" }, ] [package.metadata.requires-dev] @@ -1909,3 +1911,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9a/3f/f70e03f40ffc9a30d817eef7da1be72ee4956ba8d7255c399a01b135902a/websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767", size = 178735, upload-time = "2026-01-10T09:23:42.259Z" }, { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, ] + +[[package]] +name = "wsproto" +version = "1.3.2" +source = { registry = "https://pypi-proxy.dev.databricks.com/simple/" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" }, +] From 226384c587e4bf41126224ffa198a2650b6017b9 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Tue, 7 Apr 2026 19:22:08 +0200 Subject: [PATCH 18/18] =?UTF-8?q?=E2=9C=85=20test:=20add=20thorough=20WebS?= =?UTF-8?q?ocket=20integration=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/bench/app/api.py | 43 +++ tests/integration/test_websocket.py | 407 ++++++++++++++++++++++++++-- 2 files changed, 424 insertions(+), 26 deletions(-) diff --git a/scripts/bench/app/api.py b/scripts/bench/app/api.py index bf25ef47..37420190 100644 --- a/scripts/bench/app/api.py +++ b/scripts/bench/app/api.py @@ -282,6 +282,49 @@ async def ws_json(ws: WebSocket): pass +@router.websocket("/ws/binary") +async def ws_binary(ws: WebSocket): + """Binary echo WebSocket — returns each binary frame unchanged.""" + await ws.accept() + try: + while True: + data = await ws.receive_bytes() + await ws.send_bytes(data) + except WebSocketDisconnect: + pass + + +@router.websocket("/ws/close-with-code") +async def ws_close_with_code(ws: WebSocket): + """Accepts, reads one JSON message, then closes with the requested code.""" + await ws.accept() + msg = await ws.receive_json() + await ws.close(code=msg.get("code", 1000), reason=msg.get("reason", "")) + + +@router.websocket("/ws/reject") +async def ws_reject(ws: WebSocket): + """Returns immediately without calling accept.""" + pass + + +@router.websocket("/ws/error-in-handler") +async def ws_error_in_handler(ws: WebSocket): + """Accepts then raises to exercise server-side error handling.""" + await ws.accept() + raise RuntimeError("deliberate test error") + + +@router.websocket("/ws/subprotocol") +async def ws_subprotocol(ws: WebSocket): + """Echoes the negotiated subprotocol back as JSON.""" + protos: list[str] = ws.scope.get("subprotocols", []) + selected = protos[0] if protos else "" + await ws.accept(subprotocol=selected or None) + await ws.send_json({"selected": selected}) + await ws.close() + + # --------------------------------------------------------------------------- # Telemetry test endpoint # --------------------------------------------------------------------------- diff --git a/tests/integration/test_websocket.py b/tests/integration/test_websocket.py index 8dd0c1f7..ec9428c5 100644 --- a/tests/integration/test_websocket.py +++ b/tests/integration/test_websocket.py @@ -1,16 +1,25 @@ """WebSocket integration tests for APX. -Tests run against the bench app's WebSocket endpoints (ws/echo, ws/json) -inside the Docker container managed by the session fixtures. +Tests run against the bench app's WebSocket endpoints inside the Docker +container managed by the session fixtures. """ from __future__ import annotations import asyncio import json +import os +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any +import httpx import pytest -import websockets +from websockets.asyncio.client import ClientConnection, connect +from websockets.exceptions import ConnectionClosed, InvalidStatus +from websockets.frames import CloseCode + +RECV_TIMEOUT: float = 5.0 @pytest.fixture(scope="session") @@ -19,37 +28,63 @@ def ws_url(apx_container: str) -> str: return apx_container.replace("http://", "ws://") +@asynccontextmanager +async def ws_connect( + base_url: str, + path: str, + **kwargs: Any, +) -> AsyncIterator[ClientConnection]: + """Typed wrapper around websockets.connect.""" + async with connect(f"{base_url}{path}", **kwargs) as ws: + yield ws + + +async def echo_roundtrip(ws: ClientConnection, message: str) -> str: + """Send a text message and return the echo reply.""" + await ws.send(message) + reply = await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + assert isinstance(reply, str) + return reply + + +async def binary_roundtrip(ws: ClientConnection, data: bytes) -> bytes: + """Send a binary message and return the echo reply.""" + await ws.send(data) + reply = await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + assert isinstance(reply, bytes) + return reply + + +# --------------------------------------------------------------------------- +# Echo +# --------------------------------------------------------------------------- + + @pytest.mark.integration class TestWebSocketEcho: def test_single_message(self, ws_url: str) -> None: - async def _run(): - async with websockets.connect(f"{ws_url}/api/ws/echo") as ws: - await ws.send("hello") - reply = await asyncio.wait_for(ws.recv(), timeout=5) - assert reply == "hello" + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/echo") as ws: + assert await echo_roundtrip(ws, "hello") == "hello" asyncio.run(_run()) def test_multiple_messages(self, ws_url: str) -> None: - async def _run(): - async with websockets.connect(f"{ws_url}/api/ws/echo") as ws: + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/echo") as ws: for i in range(10): - await ws.send(f"msg-{i}") - reply = await asyncio.wait_for(ws.recv(), timeout=5) - assert reply == f"msg-{i}" + assert await echo_roundtrip(ws, f"msg-{i}") == f"msg-{i}" asyncio.run(_run()) def test_concurrent_connections(self, ws_url: str) -> None: async def worker(n: int) -> None: - async with websockets.connect(f"{ws_url}/api/ws/echo") as ws: + async with ws_connect(ws_url, "/api/ws/echo") as ws: for i in range(5): msg = f"w{n}-{i}" - await ws.send(msg) - reply = await asyncio.wait_for(ws.recv(), timeout=5) - assert reply == msg + assert await echo_roundtrip(ws, msg) == msg - async def _run(): + async def _run() -> None: await asyncio.gather(*[worker(i) for i in range(10)]) asyncio.run(_run()) @@ -57,25 +92,345 @@ async def _run(): def test_client_disconnect(self, ws_url: str) -> None: """Server handles client disconnect gracefully.""" - async def _run(): - async with websockets.connect(f"{ws_url}/api/ws/echo") as ws: - await ws.send("hello") - await ws.recv() - # Connection closed — no crash expected. + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/echo") as ws: + await echo_roundtrip(ws, "hello") asyncio.run(_run()) +# --------------------------------------------------------------------------- +# JSON echo +# --------------------------------------------------------------------------- + + @pytest.mark.integration class TestWebSocketJSON: def test_json_echo(self, ws_url: str) -> None: - async def _run(): - async with websockets.connect(f"{ws_url}/api/ws/json") as ws: + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/json") as ws: payload = {"action": "ping", "n": 42} await ws.send(json.dumps(payload)) - reply = json.loads(await asyncio.wait_for(ws.recv(), timeout=5)) + raw = await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + assert isinstance(raw, str) + reply: dict[str, Any] = json.loads(raw) assert reply["action"] == "ping" assert reply["n"] == 42 assert "server_ts" in reply asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# Binary +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestWebSocketBinary: + @pytest.mark.parametrize( + "payload", + [ + pytest.param(b"\x00\x01\x02\xff", id="small"), + pytest.param(b"", id="empty"), + pytest.param(b"\x80" * 1024, id="1kb-repeated"), + ], + ) + def test_binary_echo(self, ws_url: str, payload: bytes) -> None: + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/binary") as ws: + assert await binary_roundtrip(ws, payload) == payload + + asyncio.run(_run()) + + def test_large_binary_payload(self, ws_url: str) -> None: + payload = os.urandom(1_000_000) + + async def _run() -> None: + async with ws_connect( + ws_url, + "/api/ws/binary", + max_size=2_000_000, + ) as ws: + assert await binary_roundtrip(ws, payload) == payload + + asyncio.run(_run()) + + def test_mixed_text_and_binary(self, ws_url: str) -> None: + """Alternate between text and binary on separate connections.""" + + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/echo") as text_ws: + async with ws_connect(ws_url, "/api/ws/binary") as bin_ws: + for i in range(5): + assert await echo_roundtrip(text_ws, f"t-{i}") == f"t-{i}" + data = os.urandom(64) + assert await binary_roundtrip(bin_ws, data) == data + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# Close codes +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestWebSocketClose: + @pytest.mark.parametrize( + ("code", "reason"), + [ + pytest.param(1000, "normal", id="normal"), + pytest.param(1001, "going away", id="going-away"), + pytest.param(4000, "app-defined", id="app-defined"), + ], + ) + def test_server_close(self, ws_url: str, code: int, reason: str) -> None: + """Server closes with a specific code after receiving a JSON command.""" + + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/close-with-code") as ws: + await ws.send(json.dumps({"code": code, "reason": reason})) + with pytest.raises(ConnectionClosed) as exc_info: + await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + assert exc_info.value.rcvd is not None + assert exc_info.value.rcvd.code == code + + asyncio.run(_run()) + + @pytest.mark.parametrize( + "code", + [ + pytest.param(CloseCode.NORMAL_CLOSURE, id="normal"), + pytest.param(CloseCode.GOING_AWAY, id="going-away"), + ], + ) + def test_client_close(self, ws_url: str, code: CloseCode) -> None: + """Client-initiated close with specific code; server stays healthy.""" + + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/echo") as ws: + await echo_roundtrip(ws, "before-close") + await ws.close(code) + + async with ws_connect(ws_url, "/api/ws/echo") as ws: + assert await echo_roundtrip(ws, "still-alive") == "still-alive" + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# Reject / routing errors +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestWebSocketReject: + def test_handler_rejects(self, ws_url: str) -> None: + """App returns without calling accept; connection closes promptly. + + APX eagerly sends 101 before the ASGI app runs, so the WS + handshake succeeds but the server-side immediately tears down. + """ + + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/reject") as ws: + with pytest.raises(ConnectionClosed): + await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + + asyncio.run(_run()) + + def test_nonexistent_route(self, ws_url: str) -> None: + """Connecting to an unregistered path yields a refused upgrade.""" + + async def _run() -> None: + with pytest.raises((InvalidStatus, ConnectionClosed, OSError)): + async with ws_connect(ws_url, "/api/ws/does-not-exist") as ws: + await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + + asyncio.run(_run()) + + def test_http_get_on_ws_route(self, client: httpx.Client) -> None: + """Plain HTTP GET to a WebSocket route returns an error, not a crash.""" + r = client.get("/api/ws/echo") + assert r.status_code >= 400 + + +# --------------------------------------------------------------------------- +# Payload edge cases +# --------------------------------------------------------------------------- + +_UNICODE_SAMPLES: str = ( + "\U0001f600" # emoji (grinning face) + "\u4e16\u754c" # CJK (世界) + "\u0645\u0631\u062d\u0628\u0627" # Arabic (مرحبا) + "\u00e9\u00e0\u00fc" # Latin accented +) + + +@pytest.mark.integration +class TestWebSocketPayload: + @pytest.mark.parametrize( + "text", + [ + pytest.param("", id="empty"), + pytest.param("a", id="single-char"), + pytest.param("hello world", id="ascii"), + pytest.param(_UNICODE_SAMPLES, id="unicode"), + pytest.param("x" * 1_000_000, id="1mb"), + ], + ) + def test_text_payload(self, ws_url: str, text: str) -> None: + async def _run() -> None: + async with ws_connect( + ws_url, "/api/ws/echo", max_size=2_000_000 + ) as ws: + assert await echo_roundtrip(ws, text) == text + + asyncio.run(_run()) + + @pytest.mark.parametrize( + "value", + [ + pytest.param(None, id="null"), + pytest.param([], id="empty-list"), + pytest.param({}, id="empty-dict"), + pytest.param( + {"nested": {"a": [1, 2, None]}}, + id="nested", + ), + pytest.param("", id="empty-string"), + pytest.param(0, id="zero"), + pytest.param(True, id="bool"), + ], + ) + def test_json_special_values( + self, ws_url: str, value: Any + ) -> None: + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/json") as ws: + payload = {"action": "test", "v": value} + await ws.send(json.dumps(payload)) + raw = await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + assert isinstance(raw, str) + reply: dict[str, Any] = json.loads(raw) + assert reply["action"] == "test" + assert reply["v"] == value + assert "server_ts" in reply + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# Concurrency +# --------------------------------------------------------------------------- + +RAPID_FIRE_COUNT: int = 100 +CONCURRENT_WORKERS: int = 10 +MESSAGES_PER_WORKER: int = 5 + + +@pytest.mark.integration +class TestWebSocketConcurrency: + def test_rapid_fire(self, ws_url: str) -> None: + """Send many messages without waiting, then verify all replies in order.""" + + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/echo") as ws: + for i in range(RAPID_FIRE_COUNT): + await ws.send(f"rf-{i}") + + for i in range(RAPID_FIRE_COUNT): + reply = await asyncio.wait_for( + ws.recv(), timeout=RECV_TIMEOUT + ) + assert reply == f"rf-{i}" + + asyncio.run(_run()) + + def test_reconnect_after_close(self, ws_url: str) -> None: + """Close, reconnect, and verify the second session works.""" + + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/echo") as ws: + assert await echo_roundtrip(ws, "session-1") == "session-1" + + async with ws_connect(ws_url, "/api/ws/echo") as ws: + assert await echo_roundtrip(ws, "session-2") == "session-2" + + asyncio.run(_run()) + + def test_concurrent_binary(self, ws_url: str) -> None: + """Multiple workers sending binary frames concurrently.""" + + async def worker(n: int) -> None: + async with ws_connect(ws_url, "/api/ws/binary") as ws: + for i in range(MESSAGES_PER_WORKER): + data = f"w{n}-{i}".encode() + assert await binary_roundtrip(ws, data) == data + + async def _run() -> None: + await asyncio.gather( + *[worker(i) for i in range(CONCURRENT_WORKERS)] + ) + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# Error recovery +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestWebSocketErrorRecovery: + def test_server_error_closes_cleanly(self, ws_url: str) -> None: + """Server-side RuntimeError closes the WS; server stays healthy.""" + + async def _run() -> None: + async with ws_connect(ws_url, "/api/ws/error-in-handler") as ws: + with pytest.raises(ConnectionClosed): + await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + + async with ws_connect(ws_url, "/api/ws/echo") as ws: + assert await echo_roundtrip(ws, "post-error") == "post-error" + + asyncio.run(_run()) + + def test_send_after_close(self, ws_url: str) -> None: + """Sending on a closed connection raises ConnectionClosed.""" + + async def _run() -> None: + ws: ClientConnection + async with ws_connect(ws_url, "/api/ws/echo") as ws: + await echo_roundtrip(ws, "before") + await ws.close() + + with pytest.raises(ConnectionClosed): + await ws.send("after-close") + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# Subprotocol +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestWebSocketSubprotocol: + def test_subprotocol_negotiation(self, ws_url: str) -> None: + """Server receives offered subprotocols and echoes the selected one.""" + + async def _run() -> None: + async with ws_connect( + ws_url, + "/api/ws/subprotocol", + subprotocols=["graphql-ws", "graphql-transport-ws"], + ) as ws: + raw = await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + assert isinstance(raw, str) + reply: dict[str, str] = json.loads(raw) + assert reply["selected"] == "graphql-ws" + + asyncio.run(_run())