diff --git a/DESIGN.md b/DESIGN.md index ebaf0e3..fff5158 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -104,6 +104,8 @@ ServeHTTP(w, r) - `WithAutoRouterInterceptor(i)` — Add interceptor to chain - `WithAutoRouterHTTPClient(c)` — Custom HTTP client - `WithAutoRouterFallbackProvider(p)` — Provider when detection fails +- `WithAutoRouterWebSocket(upgrader, dialer)` — Enable WebSocket mode (opt-in, see [WebSocket Mode](#websocket-mode)) +- `WithAutoRouterWSBillingCallback(cb)` — Per-turn billing callback for WebSocket connections **Example:** @@ -462,13 +464,24 @@ All built-in providers implement this interface for streaming support. ### Usage Extraction -**OpenAI**: Usage is sent in the final chunk before `[DONE]`: +**OpenAI Chat Completions**: Usage is sent in the final chunk before `[DONE]`: ```json data: {"id":"...","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}} data: [DONE] ``` +**OpenAI Responses API**: Usage is in the `response.completed` event. The `StreamingMultiAPIExtractor` automatically detects the API type from the request context and dispatches to the correct streaming extractor: + +```json +data: {"type":"response.created","response":{"id":"resp_123","model":"gpt-4o"}} +data: {"type":"response.output_text.delta","delta":"Hello"} +data: {"type":"response.completed","response":{"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}} +data: [DONE] +``` + +No `stream_options.include_usage` is needed for the Responses API — usage is always included in `response.completed`. The proxy automatically skips `stream_options` injection for Responses API requests. + **Anthropic**: Usage is sent in `message_start` and `message_delta` events: ```json @@ -480,7 +493,7 @@ data: {"type":"message_stop"} ### Auto stream_options Injection -When `BillingCalculator` is configured and the request has `stream: true`, the proxy automatically injects: +When `BillingCalculator` is configured and the request has `stream: true`, the proxy automatically injects `stream_options.include_usage` for **OpenAI-compatible Chat Completions** endpoints only: ```json { @@ -489,7 +502,13 @@ When `BillingCalculator` is configured and the request has `stream: true`, the p } ``` -This ensures OpenAI returns token usage in the streaming response for billing calculation. +This ensures providers return token usage in their streaming responses for billing calculation. + +The following are **excluded** from this injection because they already include usage natively: + +- **Responses API** — usage is always present in the `response.completed` event +- **Anthropic** — usage is sent in `message_start` and `message_delta` events +- **Bedrock** and **Google AI** — usage is included in their streaming event formats ### Efficient Flushing @@ -526,6 +545,189 @@ After the stream completes, the billing callback is invoked with the extracted t --- +## WebSocket Mode + +The proxy supports the OpenAI Responses API WebSocket mode for persistent, multi-turn connections. This is useful for tool-call-heavy workflows where multiple `response.create` / `response.completed` cycles happen on a single connection. + +### Adapter Pattern (Zero Dependencies) + +The library defines abstract WebSocket interfaces — **no WebSocket library is vendored**. Consumers bring their own implementation (gorilla/websocket, nhooyr.io/websocket, etc.) and wire it in via thin adapters. + +#### Interfaces + +```go +// WSConn abstracts a WebSocket connection. +// gorilla/websocket's *Conn satisfies this directly. 
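+// The messageType values mirror RFC 6455 frame opcodes as used by
+// gorilla/websocket (TextMessage = 1, BinaryMessage = 2).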
+type WSConn interface { + ReadMessage() (messageType int, p []byte, err error) + WriteMessage(messageType int, data []byte) error + Close() error +} + +// WSUpgrader upgrades an HTTP request to a WebSocket connection. +type WSUpgrader interface { + Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (WSConn, error) +} + +// WSDialer dials a WebSocket connection to an upstream server. +type WSDialer interface { + DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (WSConn, *http.Response, error) +} + +// WebSocketCapableProvider is implemented by providers that support WebSocket mode. +type WebSocketCapableProvider interface { + Provider + WebSocketURL(meta BodyMetadata) (*url.URL, error) +} +``` + +The OpenAI provider implements `WebSocketCapableProvider`. Other providers can opt in by implementing the same interface. + +#### gorilla/websocket Example + +gorilla's `*websocket.Conn` already satisfies `WSConn` — no wrapper needed. You only need thin adapters for `Upgrader` and `Dialer` because their return types differ (`*websocket.Conn` vs `WSConn`): + +```go +package myadapter + +import ( + "context" + "net/http" + + "github.com/agentuity/llmproxy" + "github.com/gorilla/websocket" +) + +// GorillaUpgrader wraps gorilla's Upgrader to satisfy llmproxy.WSUpgrader. +type GorillaUpgrader struct { + Upgrader websocket.Upgrader +} + +func (u *GorillaUpgrader) Upgrade(w http.ResponseWriter, r *http.Request, h http.Header) (llmproxy.WSConn, error) { + conn, err := u.Upgrader.Upgrade(w, r, h) + if err != nil { + return nil, err + } + return conn, nil // *websocket.Conn satisfies WSConn +} + +// GorillaDialer wraps gorilla's Dialer to satisfy llmproxy.WSDialer. +type GorillaDialer struct { + Dialer websocket.Dialer +} + +func (d *GorillaDialer) DialContext(ctx context.Context, urlStr string, h http.Header) (llmproxy.WSConn, *http.Response, error) { + conn, resp, err := d.Dialer.DialContext(ctx, urlStr, h) + if err != nil { + return nil, resp, err + } + return conn, resp, nil // *websocket.Conn satisfies WSConn +} +``` + +#### Wiring It Together + +```go +router := llmproxy.NewAutoRouter( + llmproxy.WithAutoRouterFallbackProvider(openaiProvider), + // In production, configure CheckOrigin to validate against trusted origins. + llmproxy.WithAutoRouterWebSocket( + &myadapter.GorillaUpgrader{ + Upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + // Validate r.Header.Get("Origin") against a whitelist + return isAllowedOrigin(r) + }, + }, + }, + &myadapter.GorillaDialer{ + Dialer: websocket.Dialer{}, + }, + ), + llmproxy.WithAutoRouterWSBillingCallback(func(turn int, meta llmproxy.ResponseMetadata, billing *llmproxy.BillingResult) { + log.Printf("Turn %d: model=%s prompt=%d completion=%d", + turn, meta.Model, meta.Usage.PromptTokens, meta.Usage.CompletionTokens) + if billing != nil { + log.Printf(" Cost: $%.6f", billing.TotalCost) + } + }), +) +router.RegisterProvider(openaiProvider) + +http.Handle("/", router) +``` + +WebSocket support is **opt-in** — if `WithAutoRouterWebSocket` is not called, WebSocket upgrade requests are rejected and normal HTTP handling is unchanged. + +### WebSocket Protocol + +The OpenAI Responses API WebSocket mode operates at `wss://api.openai.com/v1/responses`: + +| Aspect | Detail | +|--------|--------| +| **Client sends** | `{"type":"response.create","model":"gpt-4o","input":[...]}` | +| **Server sends** | Same events as SSE: `response.created`, `response.output_text.delta`, `response.completed`, etc. 
| +| **Multi-turn** | Client sends new `response.create` with `previous_response_id` on same connection | +| **Usage** | In `response.completed` → `response.usage.{input_tokens, output_tokens, total_tokens}` | +| **Frames** | JSON text frames — no `data:` prefix (unlike SSE) | + +### WebSocket Flow + +```text ++------------------+ +------------------+ +------------------+ +| Client | | AutoRouter | | Upstream | +| | | | | (OpenAI) | ++--------+---------+ +--------+---------+ +--------+---------+ + | | | + 1. WS Upgrade --------> 2. Accept upgrade | + | 3. Read first message | + | (response.create) | + | 4. Detect provider/model | + | 5. Strip model prefix | + | 6. Enrich headers (auth) | + | 7. Dial upstream WS --------> 8. Accept + | 9. Forward first msg --------> | + | | | + | === Bidirectional Relay === | + | | | + 10. response.create -----> Strip prefix ------------> response.create + | | | + | | <----------- response.created + response.created <------ | | + | | <----------- response.output_text.delta + text delta <---------------- | | + | | <----------- response.completed + response.completed <---- Extract usage, (includes usage) + | compute billing | + | | | + 11. response.create -----> Strip prefix ------------> (next turn) + | ... ... + | | | + 12. Close ----------------> Close both ------------> Close +``` + +### Billing + +Billing is calculated **per turn** — each `response.completed` event triggers: + +1. Usage extraction (`input_tokens`, `output_tokens`, `total_tokens`, cache details) +2. Cost calculation via `BillingCalculator` (if configured) +3. `WSBillingCallback` invocation with the turn number, response metadata, and billing result + +For multi-turn connections, each turn is billed independently. The callback receives an incrementing turn counter so consumers can aggregate if needed. + +### Model Prefix Stripping + +Works the same as HTTP mode — `openai/gpt-4o` is stripped to `gpt-4o` in all `response.create` messages. Non-`response.create` messages are forwarded byte-for-byte without modification. + +### Connection Lifecycle + +- Both relay goroutines use a `sync.Once`-guarded close — when either side closes, the other is closed immediately +- WebSocket close errors (`io.EOF`, "connection closed") are treated as normal termination, not errors +- The `ForwardWebSocket` method blocks until both relay goroutines complete + +--- + ## Providers Nine providers are included. Six share the OpenAI-compatible base; three have fully custom implementations. @@ -554,7 +756,7 @@ The OpenAI provider also supports the **Responses API** (`/v1/responses`) with a **OpenAI** — Wraps `openai_compatible` with support for multiple APIs: - **Chat Completions** (`/v1/chat/completions`) — Standard messages-based API -- **Responses** (`/v1/responses`) — Newer API with `input` field, built-in tools support +- **Responses** (`/v1/responses`) — Newer API with `input` field, built-in tools support. Supports both HTTP (SSE streaming) and WebSocket modes. - **Legacy Completions** (`/v1/completions`) — Older prompt-based API The provider auto-detects the API type from the request body: @@ -562,6 +764,8 @@ The provider auto-detects the API type from the request body: - `prompt` field → Completions API - `messages` field → Chat Completions API +The OpenAI provider implements `WebSocketCapableProvider`, enabling persistent WebSocket connections for multi-turn Responses API workflows when `WithAutoRouterWebSocket` is configured. 
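+
+For example (a sketch, not part of the library), a custom provider can opt in by wrapping an existing `Provider` and returning its upstream WebSocket endpoint; the wrapper type and URL below are illustrative:
+
+```go
+// CustomWSProvider wraps an existing llmproxy.Provider and opts in to WebSocket mode.
+type CustomWSProvider struct {
+	llmproxy.Provider
+}
+
+// WebSocketURL satisfies llmproxy.WebSocketCapableProvider.
+func (p *CustomWSProvider) WebSocketURL(meta llmproxy.BodyMetadata) (*url.URL, error) {
+	// Illustrative endpoint; a real provider would derive this from its base URL.
+	return url.Parse("wss://api.example.com/v1/responses")
+}
+```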
+ **Anthropic** — Custom body parser translates between the proxy's canonical format and Anthropic's Messages API. Custom extractor maps Anthropic's response shape (content blocks, stop_reason) back to `ResponseMetadata`. Auth uses the `x-api-key` header alongside an `anthropic-version` header. **Groq, Fireworks, x.AI** — Each wraps `openai_compatible` with its own base URL and provider name. No custom parsing or extraction logic needed. @@ -1017,6 +1221,7 @@ Matches the signature of `github.com/agentuity/go-common/logger` without requiri llmproxy/ ├── apitype.go # API type detection and constants ├── autorouter.go # AutoRouter, provider/API auto-detection, streaming +├── autorouter_websocket.go # WebSocket forwarding, bidirectional relay ├── billing.go # CostInfo, CostLookup, BillingResult, CalculateCost ├── billing_calculator.go # BillingCalculator for streaming/non-streaming ├── detection.go # Provider detection from model/header @@ -1034,6 +1239,7 @@ llmproxy/ ├── registry.go # Registry interface, MapRegistry ├── resolver.go # URLResolver interface ├── streaming.go # SSE parser, streaming types, usage extraction +├── websocket.go # WebSocket interfaces, message parsing, usage extraction ├── interceptors/ │ ├── addheader.go # AddHeaderInterceptor │ ├── billing.go # BillingInterceptor @@ -1056,12 +1262,14 @@ llmproxy/ │ ├── googleai/ # Google AI Gemini │ │ └── streaming_extractor.go # Google AI streaming │ ├── groq/ # Groq (OpenAI-compatible) -│ ├── openai/ # OpenAI (Chat Completions + Responses) +│ ├── openai/ # OpenAI (Chat Completions + Responses + WebSocket) │ ├── openai_compatible/ # Base for OpenAI-compatible providers -│ │ ├── multiapi.go # Multi-API parser/extractor -│ │ ├── streaming_extractor.go # SSE streaming with usage extraction -│ │ ├── responses_parser.go # Responses API parser -│ │ └── responses_extractor.go # Responses API extractor +│ │ ├── multiapi.go # Multi-API parser/extractor dispatch +│ │ ├── streaming_extractor.go # Chat Completions SSE streaming +│ │ ├── responses_parser.go # Responses API parser +│ │ ├── responses_extractor.go # Responses API extractor +│ │ ├── responses_streaming_extractor.go # Responses API SSE streaming +│ │ └── websocket.go # WebSocket URL resolver │ └── xai/ # x.AI (OpenAI-compatible) └── examples/ └── basic/ # Multi-provider proxy server example (uses AutoRouter) diff --git a/README.md b/README.md index 3ee6a7b..0d5f2d3 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,8 @@ curl -X POST http://localhost:8080/ \ - **9 Provider Implementations**: OpenAI, Anthropic, Groq, Fireworks, x.AI, Google AI, AWS Bedrock, Azure OpenAI, OpenAI-compatible base - **AutoRouter**: Single endpoint with automatic provider/API detection -- **Responses API**: Full support for OpenAI's new Responses API +- **Responses API**: Full support for OpenAI's Responses API (HTTP streaming and WebSocket mode) +- **WebSocket Mode**: Persistent connections for multi-turn Responses API workflows with per-turn billing - **SSE Streaming**: Full streaming support with efficient token usage extraction - **8 Built-in Interceptors**: Logging, Metrics, Retry, Billing, Tracing (OTel), HeaderBan, AddHeader, PromptCaching - **Pricing Integration**: models.dev adapter with markup support @@ -185,11 +186,109 @@ router := llmproxy.NewAutoRouter( ) ``` +## WebSocket Mode + +The Responses API supports persistent WebSocket connections for multi-turn, tool-call-heavy workflows. WebSocket support is **opt-in** with a zero-dependency adapter pattern — bring your own WebSocket library. 
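+
+For reference, the adapter interfaces are small; these definitions are reproduced from the library (see [DESIGN.md](DESIGN.md#websocket-mode)):
+
+```go
+// WSConn abstracts a WebSocket connection.
+type WSConn interface {
+	ReadMessage() (messageType int, p []byte, err error)
+	WriteMessage(messageType int, data []byte) error
+	Close() error
+}
+
+// WSUpgrader upgrades an HTTP request to a WebSocket connection.
+type WSUpgrader interface {
+	Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (WSConn, error)
+}
+
+// WSDialer dials a WebSocket connection to an upstream server.
+type WSDialer interface {
+	DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (WSConn, *http.Response, error)
+}
+```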
+
+### gorilla/websocket Example
+
+```go
+package main
+
+import (
+	"context"
+	"log"
+	"net/http"
+
+	"github.com/agentuity/llmproxy"
+	"github.com/agentuity/llmproxy/providers/openai"
+	"github.com/gorilla/websocket"
+)
+
+// Configure allowed origins for WebSocket upgrades.
+var trustedOrigins = []string{"https://myapp.example.com"}
+
+// Thin adapters — gorilla's *Conn already satisfies llmproxy.WSConn
+
+type gorillaUpgrader struct{ websocket.Upgrader }
+
+func (u *gorillaUpgrader) Upgrade(w http.ResponseWriter, r *http.Request, h http.Header) (llmproxy.WSConn, error) {
+	conn, err := u.Upgrader.Upgrade(w, r, h)
+	if err != nil {
+		return nil, err
+	}
+	return conn, nil
+}
+
+type gorillaDialer struct{ websocket.Dialer }
+
+func (d *gorillaDialer) DialContext(ctx context.Context, urlStr string, h http.Header) (llmproxy.WSConn, *http.Response, error) {
+	conn, resp, err := d.Dialer.DialContext(ctx, urlStr, h)
+	if err != nil {
+		return nil, resp, err
+	}
+	return conn, resp, nil
+}
+
+func main() {
+	// Only accept upgrades whose Origin header is in the trusted list.
+	upgrader := websocket.Upgrader{
+		CheckOrigin: func(r *http.Request) bool {
+			origin := r.Header.Get("Origin")
+			for _, allowed := range trustedOrigins {
+				if origin == allowed {
+					return true
+				}
+			}
+			return false
+		},
+	}
+
+	provider, _ := openai.New("sk-your-key")
+
+	router := llmproxy.NewAutoRouter(
+		llmproxy.WithAutoRouterFallbackProvider(provider),
+		llmproxy.WithAutoRouterWebSocket(
+			&gorillaUpgrader{upgrader},
+			&gorillaDialer{websocket.Dialer{}},
+		),
+		llmproxy.WithAutoRouterWSBillingCallback(func(turn int, meta llmproxy.ResponseMetadata, billing *llmproxy.BillingResult) {
+			log.Printf("Turn %d: %d prompt + %d completion tokens",
+				turn, meta.Usage.PromptTokens, meta.Usage.CompletionTokens)
+		}),
+	)
+	router.RegisterProvider(provider)
+
+	http.Handle("/", router)
+	log.Fatal(http.ListenAndServe(":8080", nil))
+}
+```
+
+Clients connect with any WebSocket library:
+
+```python
+from websocket import create_connection
+import json
+
+ws = create_connection("ws://localhost:8080/v1/responses",
+                       header=["Authorization: Bearer sk-your-key"])
+
+ws.send(json.dumps({
+    "type": "response.create",
+    "model": "gpt-4o",
+    "input": [{"type": "message", "role": "user",
+               "content": [{"type": "input_text", "text": "Hello!"}]}],
+}))
+
+while True:
+    event = json.loads(ws.recv())
+    print(event["type"], event.get("delta", ""))
+    if event["type"] == "response.completed":
+        break
+```
+
+The proxy handles model prefix stripping, auth header forwarding, usage extraction, and per-turn billing automatically. See [DESIGN.md](DESIGN.md#websocket-mode) for full protocol details.
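+
+Multi-turn requests reuse the same connection. Continuing the Python example above, send another `response.create` that references the prior turn; the `previous_response_id` value comes from the previous `response.completed` event (`resp_abc123` below is illustrative):
+
+```python
+ws.send(json.dumps({
+    "type": "response.create",
+    "model": "gpt-4o",
+    "previous_response_id": "resp_abc123",  # illustrative; use the ID from the prior turn
+    "input": [{"type": "message", "role": "user",
+               "content": [{"type": "input_text", "text": "Now answer in French."}]}],
+}))
+```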
+ ## Providers | Provider | Auth | API Format | Notes | | ------------ | --------------------- | ------------------------------ | ----- | -| OpenAI | Bearer token | Chat completions, Responses | Supports both `/v1/chat/completions` and `/v1/responses` | +| OpenAI | Bearer token | Chat completions, Responses, WebSocket | HTTP + WebSocket for `/v1/responses` | | Anthropic | `x-api-key` | Messages API | | | Groq | Bearer token | OpenAI-compatible | | | Fireworks | Bearer token | OpenAI-compatible | | diff --git a/autorouter.go b/autorouter.go index b5cbc73..7744431 100644 --- a/autorouter.go +++ b/autorouter.go @@ -19,6 +19,9 @@ type AutoRouter struct { client *http.Client fallbackProvider Provider billingCalculator *BillingCalculator + wsUpgrader WSUpgrader + wsDialer WSDialer + wsBillingCallback WSBillingCallback } type AutoRouterOption func(*AutoRouter) @@ -51,6 +54,17 @@ func WithAutoRouterBillingCalculator(calculator *BillingCalculator) AutoRouterOp return func(a *AutoRouter) { a.billingCalculator = calculator } } +func WithAutoRouterWebSocket(upgrader WSUpgrader, dialer WSDialer) AutoRouterOption { + return func(a *AutoRouter) { + a.wsUpgrader = upgrader + a.wsDialer = dialer + } +} + +func WithAutoRouterWSBillingCallback(cb WSBillingCallback) AutoRouterOption { + return func(a *AutoRouter) { a.wsBillingCallback = cb } +} + func NewAutoRouter(opts ...AutoRouterOption) *AutoRouter { a := &AutoRouter{ registry: NewRegistry(), @@ -233,6 +247,11 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w } } + apiType := DetectAPITypeFromPath(req.URL.Path) + if apiType == "" { + apiType = DetectAPITypeFromBodyAndProvider(body, providerName) + } + if raw != nil { if strippedModel, hasPrefix := stripProviderPrefix(model); hasPrefix { raw["model"] = strippedModel @@ -240,7 +259,7 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w } if a.billingCalculator != nil { if stream, ok := raw["stream"].(bool); ok && stream { - if !nativeStreamUsageProviders[providerName] { + if !nativeStreamUsageProviders[providerName] && apiType != APITypeResponses { // Merge include_usage into existing stream_options if present streamOpts, ok := raw["stream_options"].(map[string]any) if !ok { @@ -258,11 +277,6 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w } } - apiType := DetectAPITypeFromPath(req.URL.Path) - if apiType == "" { - apiType = DetectAPITypeFromBodyAndProvider(body, providerName) - } - meta, _, err := provider.BodyParser().Parse(io.NopCloser(bytes.NewReader(body))) if err != nil { return ResponseMetadata{}, err @@ -410,6 +424,15 @@ func (a *AutoRouter) streamResponseWithFlush(r io.Reader, w http.ResponseWriter, } func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if isWebSocketUpgrade(r) && a.wsUpgrader != nil && a.wsDialer != nil { + if err := a.ForwardWebSocket(r.Context(), w, r); err != nil { + if !headerSent(w) { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } + return + } + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -486,6 +509,12 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func isWebSocketUpgrade(r *http.Request) bool { + connection := strings.ToLower(r.Header.Get("Connection")) + upgrade := strings.ToLower(r.Header.Get("Upgrade")) + return strings.Contains(connection, "upgrade") && strings.Contains(upgrade, "websocket") +} + func headerSent(w 
http.ResponseWriter) bool { type headerChecker interface { WroteHeader() bool diff --git a/autorouter_test.go b/autorouter_test.go index 23933a2..4f06b66 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -735,3 +735,88 @@ func TestAutoRouter_AnthropicStreamingNoStreamOptions(t *testing.T) { t.Error("stream_options should NOT be injected for Anthropic (always sends usage in events)") } } + +func TestAutoRouter_ResponsesAPIStreamingNoStreamOptions(t *testing.T) { + var receivedBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte("event: response.completed\ndata: {\"id\":\"resp_test\",\"output\":[],\"usage\":{\"input_tokens\":10,\"output_tokens\":20}}\n\n")) + })) + defer upstream.Close() + + provider := &mockStreamingProvider{ + mockProvider: &mockProvider{ + name: "test", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4o", Stream: true}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + }, + streamingExtractor: &mockStreamingExtractor{ + isStreaming: true, + extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + io.Copy(w, resp.Body) + rc.Flush() + return ResponseMetadata{ID: "resp_test"}, nil + }, + }, + } + provider.mockProvider.extractFn = func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "resp_test"}, body, nil + } + + billing := NewBillingCalculator( + func(provider, model string) (CostInfo, bool) { + return CostInfo{Input: 1, Output: 2}, true + }, + nil, + ) + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "test" })), + WithAutoRouterBillingCalculator(billing), + ) + router.RegisterProvider(provider) + + t.Run("path-based detection", func(t *testing.T) { + receivedBody = nil + req := httptest.NewRequest("POST", "/v1/responses", bytes.NewReader([]byte(`{"model":"gpt-4o","stream":true,"input":"Hello"}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", w.Code) + } + + if _, ok := receivedBody["stream_options"]; ok { + t.Error("stream_options should NOT be injected for Responses API requests") + } + }) + + t.Run("body-based detection", func(t *testing.T) { + receivedBody = nil + req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4o","stream":true,"input":"Hello"}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", w.Code) + } + + if _, ok := receivedBody["stream_options"]; ok { + t.Error("stream_options should NOT be injected for Responses API requests (body-detected)") + } + }) +} diff --git a/autorouter_websocket.go b/autorouter_websocket.go new file mode 100644 index 0000000..22d099f --- /dev/null +++ b/autorouter_websocket.go @@ -0,0 +1,290 @@ +package llmproxy + +import ( + "bytes" + "context" + 
"encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" +) + +var ErrWebSocketNotConfigured = errors.New("websocket forwarding is not configured") + +func (a *AutoRouter) ForwardWebSocket(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + if a.wsUpgrader == nil || a.wsDialer == nil { + return ErrWebSocketNotConfigured + } + + clientConn, err := a.wsUpgrader.Upgrade(w, r, nil) + if err != nil { + return fmt.Errorf("upgrade websocket: %w", err) + } + + firstType, firstData, err := clientConn.ReadMessage() + if err != nil { + _ = clientConn.Close() + return fmt.Errorf("read first websocket message: %w", err) + } + + firstMsg, err := ParseWSMessage(firstData) + if err != nil { + _ = clientConn.Close() + return fmt.Errorf("parse first websocket message: %w", err) + } + if firstMsg == nil || firstMsg.Type != "response.create" { + _ = clientConn.Close() + return errors.New("first websocket message must be type=response.create") + } + + model := firstMsg.Model + hint := ProviderHint{Model: model, Headers: r.Header} + providerName := a.detector.Detect(hint) + if providerName == "" && a.modelProviderLookup != nil && model != "" { + providerName = a.modelProviderLookup(model) + } + + var provider Provider + if providerName != "" { + provider, _ = a.registry.Get(providerName) + if provider == nil { + _ = clientConn.Close() + return ErrNoProvider + } + } else { + provider = a.fallbackProvider + if provider == nil { + _ = clientConn.Close() + return ErrNoProvider + } + providerName = provider.Name() + } + + wsProvider, ok := provider.(WebSocketCapableProvider) + if !ok { + _ = clientConn.Close() + return fmt.Errorf("provider %q does not support websocket mode", provider.Name()) + } + + firstOut := firstData + if strippedModel, hasPrefix := stripProviderPrefix(model); hasPrefix { + firstOut, err = rewriteWSCreateModel(firstData, strippedModel) + if err != nil { + _ = clientConn.Close() + return fmt.Errorf("rewrite first websocket message: %w", err) + } + model = strippedModel + } + + meta, _, err := provider.BodyParser().Parse(io.NopCloser(bytes.NewReader(firstOut))) + if err != nil { + _ = clientConn.Close() + return fmt.Errorf("parse websocket metadata: %w", err) + } + if meta.Custom == nil { + meta.Custom = make(map[string]any) + } + meta.Custom["api_type"] = APITypeResponses + meta.Custom["provider"] = providerName + if meta.Model == "" { + meta.Model = model + } + + upstreamURL, err := wsProvider.WebSocketURL(meta) + if err != nil { + _ = clientConn.Close() + return fmt.Errorf("resolve websocket url: %w", err) + } + + headers := cloneHeader(r.Header) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodGet, upstreamURL.String(), bytes.NewReader(firstOut)) + if err != nil { + _ = clientConn.Close() + return fmt.Errorf("create websocket upstream request: %w", err) + } + upstreamReq.Header = headers + if err := provider.RequestEnricher().Enrich(upstreamReq, meta, firstOut); err != nil { + _ = clientConn.Close() + return fmt.Errorf("enrich websocket request: %w", err) + } + + upstreamConn, _, err := a.wsDialer.DialContext(ctx, upstreamURL.String(), upstreamReq.Header) + if err != nil { + _ = clientConn.Close() + return fmt.Errorf("dial upstream websocket: %w", err) + } + + var closeOnce sync.Once + closeBoth := func() { + closeOnce.Do(func() { + _ = clientConn.Close() + _ = upstreamConn.Close() + }) + } + defer closeBoth() + + if err := upstreamConn.WriteMessage(firstType, firstOut); err != nil { + return fmt.Errorf("forward first websocket message: 
%w", err) + } + + var modelMu sync.RWMutex + currentModel := model + setModel := func(m string) { + if m == "" { + return + } + modelMu.Lock() + currentModel = m + meta.Model = m + modelMu.Unlock() + } + getModel := func() string { + modelMu.RLock() + defer modelMu.RUnlock() + return currentModel + } + + errCh := make(chan error, 2) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + if relayErr := relayClientToUpstream(clientConn, upstreamConn, setModel); relayErr != nil { + errCh <- relayErr + } + closeBoth() + }() + + wg.Add(1) + go func() { + defer wg.Done() + if relayErr := a.relayUpstreamToClient(upstreamConn, clientConn, providerName, getModel); relayErr != nil { + errCh <- relayErr + } + closeBoth() + }() + + wg.Wait() + close(errCh) + + for relayErr := range errCh { + if relayErr != nil && !isWSRelayCloseError(relayErr) { + return relayErr + } + } + + return nil +} + +func relayClientToUpstream(clientConn, upstreamConn WSConn, setModel func(string)) error { + for { + messageType, data, err := clientConn.ReadMessage() + if err != nil { + return err + } + + outData := data + if messageType == TextMessage { + msg, parseErr := ParseWSMessage(data) + if parseErr == nil && msg != nil && msg.Type == "response.create" { + model := msg.Model + if strippedModel, hasPrefix := stripProviderPrefix(model); hasPrefix { + var rewriteErr error + outData, rewriteErr = rewriteWSCreateModel(data, strippedModel) + if rewriteErr != nil { + return rewriteErr + } + model = strippedModel + } + setModel(model) + } + } + + if err := upstreamConn.WriteMessage(messageType, outData); err != nil { + return err + } + } +} + +func (a *AutoRouter) relayUpstreamToClient(upstreamConn, clientConn WSConn, providerName string, model func() string) error { + turn := 0 + for { + messageType, data, err := upstreamConn.ReadMessage() + if err != nil { + return err + } + + if messageType == TextMessage { + msg, parseErr := ParseWSMessage(data) + if parseErr == nil && msg != nil && msg.Type == "response.completed" { + usage, usageErr := ExtractWSUsage(data) + if usageErr != nil { + return usageErr + } + if usage != nil { + turn++ + respMeta := ResponseMetadata{ + Model: model(), + Usage: Usage{ + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + }, + Custom: map[string]any{}, + } + if usage.CacheUsage != nil { + respMeta.Custom["cache_usage"] = *usage.CacheUsage + } + if usage.ReasoningTokens > 0 { + respMeta.Custom["reasoning_tokens"] = usage.ReasoningTokens + } + + var billing *BillingResult + if a.billingCalculator != nil { + meta := BodyMetadata{Model: model(), Custom: map[string]any{"provider": providerName}} + billing = a.billingCalculator.Calculate(meta, &respMeta) + } + if a.wsBillingCallback != nil { + a.wsBillingCallback(turn, respMeta, billing) + } + } + } + } + + if err := clientConn.WriteMessage(messageType, data); err != nil { + return err + } + } +} + +func rewriteWSCreateModel(data []byte, model string) ([]byte, error) { + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + raw["model"] = model + return json.Marshal(raw) +} + +func cloneHeader(h http.Header) http.Header { + out := make(http.Header, len(h)) + for k, vv := range h { + out[k] = append([]string(nil), vv...) 
+ } + return out +} + +func isWSRelayCloseError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) { + return true + } + errText := strings.ToLower(err.Error()) + return strings.Contains(errText, "closed") || strings.Contains(errText, "websocket: close") +} diff --git a/autorouter_websocket_test.go b/autorouter_websocket_test.go new file mode 100644 index 0000000..b480585 --- /dev/null +++ b/autorouter_websocket_test.go @@ -0,0 +1,609 @@ +package llmproxy + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/url" + "sync" + "sync/atomic" + "testing" + "time" +) + +type wsFrame struct { + messageType int + data []byte +} + +type mockWSConn struct { + incoming chan wsFrame + peer *mockWSConn + closed atomic.Bool + closeCh chan struct{} +} + +func newMockWSLinkedPair() (*mockWSConn, *mockWSConn) { + a := &mockWSConn{incoming: make(chan wsFrame, 32), closeCh: make(chan struct{})} + b := &mockWSConn{incoming: make(chan wsFrame, 32), closeCh: make(chan struct{})} + a.peer = b + b.peer = a + return a, b +} + +func (c *mockWSConn) ReadMessage() (int, []byte, error) { + select { + case frame := <-c.incoming: + return frame.messageType, append([]byte(nil), frame.data...), nil + case <-c.closeCh: + return 0, nil, io.EOF + } +} + +func (c *mockWSConn) WriteMessage(messageType int, data []byte) error { + if c.closed.Load() { + return io.EOF + } + if c.peer == nil || c.peer.closed.Load() { + return io.EOF + } + select { + case c.peer.incoming <- wsFrame{messageType: messageType, data: append([]byte(nil), data...)}: + return nil + case <-c.closeCh: + return io.EOF + case <-c.peer.closeCh: + return io.EOF + } +} + +func (c *mockWSConn) Close() error { + if c.closed.CompareAndSwap(false, true) { + close(c.closeCh) + if c.peer != nil { + c.peer.closeFromPeer() + } + } + return nil +} + +func (c *mockWSConn) closeFromPeer() { + if c.closed.CompareAndSwap(false, true) { + close(c.closeCh) + } +} + +type mockWSUpgrader struct { + conn WSConn + err error + called atomic.Bool +} + +func (u *mockWSUpgrader) Upgrade(w http.ResponseWriter, r *http.Request, h http.Header) (WSConn, error) { + u.called.Store(true) + if u.err != nil { + return nil, u.err + } + return u.conn, nil +} + +type mockWSDialer struct { + conn WSConn + err error + dialedURL string + dialedHeader http.Header + mu sync.Mutex +} + +func (d *mockWSDialer) DialContext(ctx context.Context, urlStr string, h http.Header) (WSConn, *http.Response, error) { + d.mu.Lock() + d.dialedURL = urlStr + d.dialedHeader = cloneHeader(h) + d.mu.Unlock() + if d.err != nil { + return nil, nil, d.err + } + return d.conn, &http.Response{StatusCode: http.StatusSwitchingProtocols}, nil +} + +type mockWSProvider struct { + *mockProvider + wsURL *url.URL + err error +} + +func (m *mockWSProvider) WebSocketURL(meta BodyMetadata) (*url.URL, error) { + if m.err != nil { + return nil, m.err + } + return m.wsURL, nil +} + +func wsTestProvider(t *testing.T, name string, wsURL string) *mockWSProvider { + t.Helper() + u, err := url.Parse(wsURL) + if err != nil { + t.Fatalf("url parse: %v", err) + } + + return &mockWSProvider{ + mockProvider: &mockProvider{ + name: name, + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + var raw map[string]any + _ = json.Unmarshal(data, &raw) + model, _ := raw["model"].(string) + return BodyMetadata{Model: model, Custom: map[string]any{}}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body 
[]byte) error { + if req.Header.Get("Authorization") == "" { + req.Header.Set("Authorization", "Bearer enriched") + } + return nil + }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return u, nil + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + return ResponseMetadata{}, nil, nil + }, + }, + wsURL: u, + } +} + +type wsFixture struct { + router *AutoRouter + client *mockWSConn + upstream *mockWSConn + upgrader *mockWSUpgrader + dialer *mockWSDialer + providerName string +} + +func newWSFixture(t *testing.T, provider Provider) *wsFixture { + t.Helper() + clientApp, routerClient := newMockWSLinkedPair() + upstreamApp, routerUpstream := newMockWSLinkedPair() + + upgrader := &mockWSUpgrader{conn: routerClient} + dialer := &mockWSDialer{conn: routerUpstream} + + router := NewAutoRouter( + WithAutoRouterWebSocket(upgrader, dialer), + WithAutoRouterDetector(ProviderDetectorFunc(func(h ProviderHint) string { + if h.Model == "" { + return "openai" + } + return DetectProviderFromModel(h.Model) + })), + ) + router.RegisterProvider(provider) + + return &wsFixture{ + router: router, + client: clientApp, + upstream: upstreamApp, + upgrader: upgrader, + dialer: dialer, + providerName: provider.Name(), + } +} + +func mustReadFrame(t *testing.T, c *mockWSConn) wsFrame { + t.Helper() + ch := make(chan wsFrame, 1) + errCh := make(chan error, 1) + go func() { + mt, data, err := c.ReadMessage() + if err != nil { + errCh <- err + return + } + ch <- wsFrame{messageType: mt, data: data} + }() + select { + case f := <-ch: + return f + case err := <-errCh: + t.Fatalf("read frame error: %v", err) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for frame") + } + return wsFrame{} +} + +func mustReadError(t *testing.T, c *mockWSConn) error { + t.Helper() + ch := make(chan error, 1) + go func() { + _, _, err := c.ReadMessage() + ch <- err + }() + select { + case err := <-ch: + return err + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for read error") + } + return nil +} + +func startForwardWS(t *testing.T, f *wsFixture, reqHeaders http.Header) chan error { + t.Helper() + req := httptest.NewRequest(http.MethodGet, "http://localhost/v1/responses", nil) + for k, vv := range reqHeaders { + req.Header[k] = vv + } + w := httptest.NewRecorder() + errCh := make(chan error, 1) + go func() { + errCh <- f.router.ForwardWebSocket(context.Background(), w, req) + }() + return errCh +} + +func TestForwardWebSocket_BasicRelay(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + + errCh := startForwardWS(t, f, http.Header{"Authorization": []string{"Bearer test"}}) + + if err := f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)); err != nil { + t.Fatalf("client write: %v", err) + } + + first := mustReadFrame(t, f.upstream) + if first.messageType != TextMessage { + t.Fatalf("upstream messageType=%d, want TextMessage", first.messageType) + } + + _ = f.upstream.WriteMessage(TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_1"}}`)) + _ = f.upstream.WriteMessage(TextMessage, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}}`)) + + created := mustReadFrame(t, f.client) + completed := mustReadFrame(t, f.client) + if !bytes.Contains(created.data, []byte(`"response.created"`)) { + t.Fatalf("expected response.created frame, got %s", string(created.data)) + } + 
if !bytes.Contains(completed.data, []byte(`"response.completed"`)) { + t.Fatalf("expected response.completed frame, got %s", string(completed.data)) + } + + _ = f.client.Close() + if err := <-errCh; err != nil { + t.Fatalf("ForwardWebSocket() error = %v", err) + } +} + +func TestForwardWebSocket_UsageExtraction(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + + var gotMeta ResponseMetadata + f.router.wsBillingCallback = func(turn int, meta ResponseMetadata, billing *BillingResult) { gotMeta = meta } + + errCh := startForwardWS(t, f, http.Header{}) + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + _ = mustReadFrame(t, f.upstream) + + _ = f.upstream.WriteMessage(TextMessage, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":22,"total_tokens":33}}}`)) + _ = mustReadFrame(t, f.client) + _ = f.client.Close() + _ = <-errCh + + if gotMeta.Usage.PromptTokens != 11 || gotMeta.Usage.CompletionTokens != 22 || gotMeta.Usage.TotalTokens != 33 { + t.Fatalf("unexpected usage: %+v", gotMeta.Usage) + } +} + +func TestForwardWebSocket_CacheUsage(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + + var gotCache int + f.router.wsBillingCallback = func(turn int, meta ResponseMetadata, billing *BillingResult) { + if cu, ok := meta.Custom["cache_usage"].(CacheUsage); ok { + gotCache = cu.CachedTokens + } + } + + errCh := startForwardWS(t, f, http.Header{}) + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + _ = mustReadFrame(t, f.upstream) + _ = f.upstream.WriteMessage(TextMessage, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":20,"output_tokens":5,"total_tokens":25,"input_tokens_details":{"cached_tokens":4}}}}`)) + _ = mustReadFrame(t, f.client) + _ = f.client.Close() + _ = <-errCh + + if gotCache != 4 { + t.Fatalf("cached tokens = %d, want 4", gotCache) + } +} + +func TestForwardWebSocket_ReasoningTokens(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + + var gotPrompt int + var gotReasoning int + f.router.wsBillingCallback = func(turn int, meta ResponseMetadata, billing *BillingResult) { + gotPrompt = meta.Usage.PromptTokens + if rt, ok := meta.Custom["reasoning_tokens"].(int); ok { + gotReasoning = rt + } + } + + errCh := startForwardWS(t, f, http.Header{}) + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + _ = mustReadFrame(t, f.upstream) + _ = f.upstream.WriteMessage(TextMessage, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":8,"output_tokens":6,"total_tokens":14,"output_tokens_details":{"reasoning_tokens":2}}}}`)) + _ = mustReadFrame(t, f.client) + _ = f.client.Close() + _ = <-errCh + + if gotPrompt != 8 { + t.Fatalf("prompt tokens = %d, want 8", gotPrompt) + } + if gotReasoning != 2 { + t.Fatalf("reasoning_tokens = %d, want 2", gotReasoning) + } +} + +func TestForwardWebSocket_ModelPrefixStripping(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + errCh := startForwardWS(t, f, http.Header{}) + + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"openai/gpt-4o","input":[]}`)) + first := 
mustReadFrame(t, f.upstream) + if bytes.Contains(first.data, []byte("openai/gpt-4o")) { + t.Fatalf("expected stripped model, got %s", string(first.data)) + } + if !bytes.Contains(first.data, []byte(`"model":"gpt-4o"`)) { + t.Fatalf("expected model gpt-4o, got %s", string(first.data)) + } + + _ = f.client.Close() + _ = <-errCh +} + +func TestForwardWebSocket_MultiTurn(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + var turns int + f.router.wsBillingCallback = func(turn int, meta ResponseMetadata, billing *BillingResult) { turns = turn } + + errCh := startForwardWS(t, f, http.Header{}) + + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + _ = mustReadFrame(t, f.upstream) + _ = f.upstream.WriteMessage(TextMessage, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}}`)) + _ = mustReadFrame(t, f.client) + + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"openai/gpt-4o","previous_response_id":"resp_1","input":[]}`)) + second := mustReadFrame(t, f.upstream) + if bytes.Contains(second.data, []byte("openai/gpt-4o")) { + t.Fatalf("expected stripped model in second turn") + } + _ = f.upstream.WriteMessage(TextMessage, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":4,"output_tokens":5,"total_tokens":9}}}`)) + _ = mustReadFrame(t, f.client) + + _ = f.client.Close() + _ = <-errCh + + if turns != 2 { + t.Fatalf("turns = %d, want 2", turns) + } +} + +func TestForwardWebSocket_BillingCallback(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + + calc := NewBillingCalculator(func(provider, model string) (CostInfo, bool) { + return CostInfo{Input: 1.0, Output: 2.0, CacheRead: 0.5}, true + }, nil) + f.router.billingCalculator = calc + + var gotTurns []int + var gotCosts []float64 + f.router.wsBillingCallback = func(turn int, meta ResponseMetadata, billing *BillingResult) { + gotTurns = append(gotTurns, turn) + if billing != nil { + gotCosts = append(gotCosts, billing.TotalCost) + } + } + + errCh := startForwardWS(t, f, http.Header{}) + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + _ = mustReadFrame(t, f.upstream) + _ = f.upstream.WriteMessage(TextMessage, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":1000,"output_tokens":500,"total_tokens":1500}}}`)) + _ = mustReadFrame(t, f.client) + _ = f.client.Close() + _ = <-errCh + + if len(gotTurns) != 1 || gotTurns[0] != 1 { + t.Fatalf("got turns = %v, want [1]", gotTurns) + } + if len(gotCosts) != 1 || gotCosts[0] <= 0 { + t.Fatalf("unexpected billing callback costs: %v", gotCosts) + } +} + +func TestForwardWebSocket_ClientClose(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + errCh := startForwardWS(t, f, http.Header{}) + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + _ = mustReadFrame(t, f.upstream) + + _ = f.client.Close() + if err := <-errCh; err != nil { + t.Fatalf("ForwardWebSocket() error = %v", err) + } + if err := mustReadError(t, f.upstream); err == nil { + t.Fatal("expected upstream side to close") + } +} + +func TestForwardWebSocket_UpstreamClose(t *testing.T) { + provider := wsTestProvider(t, 
"openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + errCh := startForwardWS(t, f, http.Header{}) + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + _ = mustReadFrame(t, f.upstream) + + _ = f.upstream.Close() + if err := <-errCh; err != nil { + t.Fatalf("ForwardWebSocket() error = %v", err) + } + if err := mustReadError(t, f.client); err == nil { + t.Fatal("expected client side to close") + } +} + +func TestForwardWebSocket_ErrorEvent(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + errCh := startForwardWS(t, f, http.Header{}) + + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + _ = mustReadFrame(t, f.upstream) + _ = f.upstream.WriteMessage(TextMessage, []byte(`{"type":"response.failed","error":{"message":"boom"}}`)) + fail := mustReadFrame(t, f.client) + if !bytes.Contains(fail.data, []byte(`"response.failed"`)) { + t.Fatalf("expected response.failed passthrough, got %s", string(fail.data)) + } + _ = f.client.Close() + _ = <-errCh +} + +func TestForwardWebSocket_NoWSUpgrader(t *testing.T) { + router := NewAutoRouter() + req := httptest.NewRequest(http.MethodGet, "http://localhost/v1/responses", nil) + w := httptest.NewRecorder() + err := router.ForwardWebSocket(context.Background(), w, req) + if !errors.Is(err, ErrWebSocketNotConfigured) { + t.Fatalf("error = %v, want ErrWebSocketNotConfigured", err) + } +} + +func TestForwardWebSocket_NonWSProvider(t *testing.T) { + provider := &mockProvider{ + name: "openai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4o", Custom: map[string]any{}}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse("https://api.openai.com/v1/responses") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + return ResponseMetadata{}, nil, nil + }, + } + + f := newWSFixture(t, provider) + errCh := startForwardWS(t, f, http.Header{}) + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + err := <-errCh + if err == nil || !bytes.Contains([]byte(err.Error()), []byte("does not support websocket mode")) { + t.Fatalf("expected non-websocket provider error, got %v", err) + } +} + +func TestForwardWebSocket_PassthroughNonCreateMessages(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + errCh := startForwardWS(t, f, http.Header{}) + + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + _ = mustReadFrame(t, f.upstream) + + original := []byte(`{"type":"response.cancel","response_id":"resp_1"}`) + _ = f.client.WriteMessage(TextMessage, original) + passthrough := mustReadFrame(t, f.upstream) + if !bytes.Equal(passthrough.data, original) { + t.Fatalf("expected byte-for-byte passthrough\n got: %s\nwant: %s", string(passthrough.data), string(original)) + } + + _ = f.client.Close() + _ = <-errCh +} + +func TestServeHTTP_WebSocketDetection(t *testing.T) { + provider := wsTestProvider(t, "openai", "wss://api.openai.com/v1/responses") + f := newWSFixture(t, provider) + + req := httptest.NewRequest(http.MethodGet, 
"http://localhost/v1/responses", nil) + req.Header.Set("Connection", "keep-alive, upgrade") + req.Header.Set("Upgrade", "websocket") + w := httptest.NewRecorder() + + done := make(chan struct{}) + go func() { + f.router.ServeHTTP(w, req) + close(done) + }() + + _ = f.client.WriteMessage(TextMessage, []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)) + _ = mustReadFrame(t, f.upstream) + _ = f.client.Close() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("ServeHTTP did not complete") + } + + if !f.upgrader.called.Load() { + t.Fatal("expected websocket upgrader to be called") + } +} + +func TestServeHTTP_NonWebSocketUnchanged(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"ok"}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "openai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4o", Custom: map[string]any{}}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + data, _ := io.ReadAll(resp.Body) + return ResponseMetadata{}, data, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(h ProviderHint) string { return "openai" })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4o"}`))) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if got, want := w.Body.String(), `{"id":"ok"}`; got != want { + t.Fatalf("body=%q, want %q", got, want) + } +} diff --git a/providers/openai/provider.go b/providers/openai/provider.go index f163c6c..6496e08 100644 --- a/providers/openai/provider.go +++ b/providers/openai/provider.go @@ -10,9 +10,12 @@ package openai import ( + "github.com/agentuity/llmproxy" "github.com/agentuity/llmproxy/providers/openai_compatible" ) +var _ llmproxy.WebSocketCapableProvider = (*openai_compatible.Provider)(nil) + // New creates a new OpenAI provider with the given API key. // The provider is configured to use OpenAI's API endpoint (https://api.openai.com). 
// diff --git a/providers/openai_compatible/extractor.go b/providers/openai_compatible/extractor.go index fc92a9b..b5253cf 100644 --- a/providers/openai_compatible/extractor.go +++ b/providers/openai_compatible/extractor.go @@ -52,6 +52,10 @@ func (e *Extractor) Extract(resp *http.Response) (llmproxy.ResponseMetadata, []b } } + if openaiResp.Usage.CompletionTokensDetails != nil && openaiResp.Usage.CompletionTokensDetails.ReasoningTokens > 0 { + meta.Custom["reasoning_tokens"] = openaiResp.Usage.CompletionTokensDetails.ReasoningTokens + } + for i, c := range openaiResp.Choices { meta.Choices[i] = llmproxy.Choice{ Index: c.Index, diff --git a/providers/openai_compatible/extractor_test.go b/providers/openai_compatible/extractor_test.go new file mode 100644 index 0000000..edf44af --- /dev/null +++ b/providers/openai_compatible/extractor_test.go @@ -0,0 +1,135 @@ +package openai_compatible + +import ( + "bytes" + "io" + "net/http" + "testing" + + "github.com/agentuity/llmproxy" +) + +func TestExtractor_ReasoningTokens(t *testing.T) { + body := `{ + "id": "chatcmpl-abc", + "object": "chat.completion", + "model": "o1", + "usage": { + "prompt_tokens": 75, + "completion_tokens": 1186, + "total_tokens": 1261, + "completion_tokens_details": { + "reasoning_tokens": 1024 + } + }, + "choices": [] + }` + + extractor := NewExtractor() + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(body))), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("Extract() error = %v", err) + } + + if meta.Usage.PromptTokens != 75 { + t.Errorf("PromptTokens = %d, want 75", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 1186 { + t.Errorf("CompletionTokens = %d, want 1186", meta.Usage.CompletionTokens) + } + + rt, ok := meta.Custom["reasoning_tokens"].(int) + if !ok { + t.Fatal("expected reasoning_tokens in custom metadata") + } + if rt != 1024 { + t.Errorf("reasoning_tokens = %d, want 1024", rt) + } +} + +func TestExtractor_ReasoningTokensZero(t *testing.T) { + body := `{ + "id": "chatcmpl-abc", + "object": "chat.completion", + "model": "gpt-4o", + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "choices": [] + }` + + extractor := NewExtractor() + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(body))), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("Extract() error = %v", err) + } + + if _, ok := meta.Custom["reasoning_tokens"]; ok { + t.Error("expected no reasoning_tokens when value is 0") + } +} + +func TestExtractor_CacheAndReasoningTokens(t *testing.T) { + body := `{ + "id": "chatcmpl-abc", + "object": "chat.completion", + "model": "o1", + "usage": { + "prompt_tokens": 100, + "completion_tokens": 500, + "total_tokens": 600, + "prompt_tokens_details": { + "cached_tokens": 80 + }, + "completion_tokens_details": { + "reasoning_tokens": 256 + } + }, + "choices": [] + }` + + extractor := NewExtractor() + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(body))), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("Extract() error = %v", err) + } + + cu, ok := meta.Custom["cache_usage"].(llmproxy.CacheUsage) + if !ok { + t.Fatal("expected cache_usage in custom metadata") + } + if cu.CachedTokens != 80 { + t.Errorf("CachedTokens = 
%d, want 80", cu.CachedTokens) + } + + rt, ok := meta.Custom["reasoning_tokens"].(int) + if !ok { + t.Fatal("expected reasoning_tokens in custom metadata") + } + if rt != 256 { + t.Errorf("reasoning_tokens = %d, want 256", rt) + } +} diff --git a/providers/openai_compatible/multiapi.go b/providers/openai_compatible/multiapi.go index 4b4e013..4ef6ea3 100644 --- a/providers/openai_compatible/multiapi.go +++ b/providers/openai_compatible/multiapi.go @@ -79,12 +79,14 @@ func (e *MultiAPIExtractor) Extract(resp *http.Response) (llmproxy.ResponseMetad type StreamingMultiAPIExtractor struct { *MultiAPIExtractor chatCompletionsStreaming *StreamingExtractor + responsesStreaming *ResponsesStreamingExtractor } func NewStreamingMultiAPIExtractor() *StreamingMultiAPIExtractor { return &StreamingMultiAPIExtractor{ MultiAPIExtractor: NewMultiAPIExtractor(), chatCompletionsStreaming: NewStreamingExtractor(), + responsesStreaming: NewResponsesStreamingExtractor(), } } @@ -96,6 +98,14 @@ func (e *StreamingMultiAPIExtractor) ExtractStreamingWithController(resp *http.R if !e.IsStreamingResponse(resp) { return e.extractNonStreamingWithController(resp, w, rc) } + + if resp.Request != nil { + metaCtx := llmproxy.GetMetaFromContext(resp.Request.Context()) + if apiType, ok := metaCtx.Meta.Custom["api_type"].(llmproxy.APIType); ok && apiType == llmproxy.APITypeResponses { + return e.responsesStreaming.ExtractStreamingWithController(resp, w, rc) + } + } + return e.chatCompletionsStreaming.ExtractStreamingWithController(resp, w, rc) } diff --git a/providers/openai_compatible/provider.go b/providers/openai_compatible/provider.go index 95810e5..eebd6c5 100644 --- a/providers/openai_compatible/provider.go +++ b/providers/openai_compatible/provider.go @@ -1,6 +1,9 @@ package openai_compatible import ( + "errors" + "net/url" + "github.com/agentuity/llmproxy" ) @@ -10,6 +13,8 @@ type Provider struct { *llmproxy.BaseProvider } +var _ llmproxy.WebSocketCapableProvider = (*Provider)(nil) + // New creates a new OpenAI-compatible provider with the given configuration. 
// // Parameters: @@ -65,3 +70,18 @@ func NewMultiAPI(name, apiKey, baseURL string) (*Provider, error) { func NewWithProvider(name string, p *llmproxy.BaseProvider) *Provider { return &Provider{BaseProvider: p} } + +func (p *Provider) WebSocketURL(meta llmproxy.BodyMetadata) (*url.URL, error) { + resolver := p.URLResolver() + if resolver == nil { + return nil, errors.New("provider has no URL resolver") + } + + if wsResolver, ok := resolver.(interface { + WebSocketURL(llmproxy.BodyMetadata) (*url.URL, error) + }); ok { + return wsResolver.WebSocketURL(meta) + } + + return nil, errors.New("provider URL resolver does not support websocket URL") +} diff --git a/providers/openai_compatible/resolver.go b/providers/openai_compatible/resolver.go index 2234b2c..bb3d5b0 100644 --- a/providers/openai_compatible/resolver.go +++ b/providers/openai_compatible/resolver.go @@ -2,6 +2,7 @@ package openai_compatible import ( "net/url" + "strings" "github.com/agentuity/llmproxy" ) @@ -32,7 +33,7 @@ func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) { } func NewResolver(baseURL string) (*Resolver, error) { - u, err := url.Parse(baseURL) + u, err := url.Parse(normalizeBaseURL(baseURL)) if err != nil { return nil, err } @@ -40,9 +41,20 @@ func NewResolver(baseURL string) (*Resolver, error) { } func NewResolverWithAPIType(baseURL string, apiType llmproxy.APIType) (*Resolver, error) { - u, err := url.Parse(baseURL) + u, err := url.Parse(normalizeBaseURL(baseURL)) if err != nil { return nil, err } return &Resolver{BaseURL: u, APIType: apiType}, nil } + +// normalizeBaseURL strips a trailing "/v1" or "/v1/" suffix from the base URL +// so that Resolve and WebSocketURL can unconditionally prepend "/v1/..." without +// producing double segments like "/v1/v1/responses". 
diff --git a/providers/openai_compatible/resolver.go b/providers/openai_compatible/resolver.go
index 2234b2c..bb3d5b0 100644
--- a/providers/openai_compatible/resolver.go
+++ b/providers/openai_compatible/resolver.go
@@ -2,6 +2,7 @@ package openai_compatible
 
 import (
 	"net/url"
+	"strings"
 
 	"github.com/agentuity/llmproxy"
 )
@@ -32,7 +33,7 @@ func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) {
 }
 
 func NewResolver(baseURL string) (*Resolver, error) {
-	u, err := url.Parse(baseURL)
+	u, err := url.Parse(normalizeBaseURL(baseURL))
 	if err != nil {
 		return nil, err
 	}
@@ -40,9 +41,20 @@
 func NewResolverWithAPIType(baseURL string, apiType llmproxy.APIType) (*Resolver, error) {
-	u, err := url.Parse(baseURL)
+	u, err := url.Parse(normalizeBaseURL(baseURL))
 	if err != nil {
 		return nil, err
 	}
 	return &Resolver{BaseURL: u, APIType: apiType}, nil
 }
+
+// normalizeBaseURL strips a trailing "/v1" or "/v1/" suffix from the base URL
+// so that Resolve and WebSocketURL can unconditionally prepend "/v1/..." without
+// producing double segments like "/v1/v1/responses".
+func normalizeBaseURL(raw string) string {
+	raw = strings.TrimRight(raw, "/")
+	if strings.HasSuffix(raw, "/v1") {
+		raw = raw[:len(raw)-3]
+	}
+	return raw
+}
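To make the normalization concrete, here is a self-contained sketch that reproduces `normalizeBaseURL` and shows the four equivalent base-URL spellings collapsing to one form:

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeBaseURL is reproduced from resolver.go for illustration.
func normalizeBaseURL(raw string) string {
	raw = strings.TrimRight(raw, "/")
	if strings.HasSuffix(raw, "/v1") {
		raw = raw[:len(raw)-3]
	}
	return raw
}

func main() {
	for _, base := range []string{
		"https://api.openai.com",
		"https://api.openai.com/",
		"https://api.openai.com/v1",
		"https://api.openai.com/v1/",
	} {
		// All four spellings normalize to "https://api.openai.com",
		// so Resolve and WebSocketURL can always append "/v1/...".
		fmt.Println(normalizeBaseURL(base))
	}
}
```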
diff --git a/providers/openai_compatible/responses_extractor.go b/providers/openai_compatible/responses_extractor.go
index 403c674..a6f3602 100644
--- a/providers/openai_compatible/responses_extractor.go
+++ b/providers/openai_compatible/responses_extractor.go
@@ -58,6 +58,10 @@ func (e *ResponsesExtractor) Extract(resp *http.Response) (llmproxy.ResponseMeta
 		}
 	}
 
+	if responsesResp.Usage.OutputTokensDetails != nil && responsesResp.Usage.OutputTokensDetails.ReasoningTokens > 0 {
+		meta.Custom["reasoning_tokens"] = responsesResp.Usage.OutputTokensDetails.ReasoningTokens
+	}
+
 	if len(responsesResp.Output) > 0 {
 		content := extractResponsesContent(responsesResp.Output)
 		meta.Choices = []llmproxy.Choice{
diff --git a/providers/openai_compatible/responses_streaming_extractor.go b/providers/openai_compatible/responses_streaming_extractor.go
new file mode 100644
index 0000000..7ce2de5
--- /dev/null
+++ b/providers/openai_compatible/responses_streaming_extractor.go
@@ -0,0 +1,180 @@
+package openai_compatible
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"io"
+	"net/http"
+
+	"github.com/agentuity/llmproxy"
+)
+
+type ResponsesStreamingExtractor struct {
+	*ResponsesExtractor
+}
+
+func NewResponsesStreamingExtractor() *ResponsesStreamingExtractor {
+	return &ResponsesStreamingExtractor{
+		ResponsesExtractor: NewResponsesExtractor(),
+	}
+}
+
+func (e *ResponsesStreamingExtractor) IsStreamingResponse(resp *http.Response) bool {
+	return llmproxy.IsSSEStream(resp.Header.Get("Content-Type"))
+}
+
+func (e *ResponsesStreamingExtractor) ExtractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) {
+	if !e.IsStreamingResponse(resp) {
+		return e.extractNonStreamingWithController(resp, w, rc)
+	}
+
+	return e.extractResponsesStreamingWithController(resp, w, rc)
+}
+
+func (e *ResponsesStreamingExtractor) extractNonStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) {
+	var buf bytes.Buffer
+	tee := io.TeeReader(resp.Body, &buf)
+
+	meta, _, err := e.ResponsesExtractor.Extract(&http.Response{
+		StatusCode: resp.StatusCode,
+		Header:     resp.Header,
+		Body:       io.NopCloser(tee),
+	})
+	if err != nil {
+		return meta, err
+	}
+
+	readBuf := make([]byte, 1024*512)
+	for {
+		n, err := buf.Read(readBuf)
+		if err != nil {
+			if err == io.EOF {
+				if n > 0 {
+					if _, writeErr := w.Write(readBuf[:n]); writeErr != nil {
+						return meta, writeErr
+					}
+				}
+				break
+			}
+			if errors.Is(err, context.Canceled) {
+				break
+			}
+			return meta, err
+		}
+		if n == 0 {
+			break
+		}
+		if _, writeErr := w.Write(readBuf[:n]); writeErr != nil {
+			return meta, writeErr
+		}
+		if flushErr := rc.Flush(); flushErr != nil {
+			return meta, flushErr
+		}
+	}
+
+	return meta, nil
+}
+
+func (e *ResponsesStreamingExtractor) extractResponsesStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) {
+	meta := llmproxy.ResponseMetadata{
+		Choices: make([]llmproxy.Choice, 0),
+		Custom:  map[string]any{"api_type": llmproxy.APITypeResponses},
+	}
+
+	w.Header().Set("Content-Type", "text/event-stream")
+	w.Header().Set("Cache-Control", "no-cache")
+	w.Header().Set("Connection", "keep-alive")
+	w.Header().Set("X-Accel-Buffering", "no")
+
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Buffer(make([]byte, 64*1024), 1024*1024)
+
+	var accumulatedUsage *llmproxy.StreamingUsage
+
+	for scanner.Scan() {
+		line := scanner.Bytes()
+
+		if _, err := w.Write(line); err != nil {
+			return meta, err
+		}
+		if _, err := w.Write([]byte("\n")); err != nil {
+			return meta, err
+		}
+		_ = rc.Flush()
+
+		if len(line) == 0 {
+			continue
+		}
+
+		if !bytes.HasPrefix(line, []byte("data:")) {
+			continue
+		}
+
+		data := bytes.TrimPrefix(line, []byte("data:"))
+		data = bytes.TrimSpace(data)
+
+		if bytes.Equal(data, []byte("[DONE]")) {
+			continue
+		}
+
+		event, err := llmproxy.ParseResponsesSSEEvent(data)
+		if err != nil {
+			if errors.Is(err, llmproxy.ErrStreamComplete) {
+				continue
+			}
+			continue
+		}
+		if event == nil {
+			continue
+		}
+
+		if len(event.Response) > 0 {
+			var response llmproxy.ResponsesStreamResponse
+			if err := json.Unmarshal(event.Response, &response); err == nil {
+				if response.ID != "" {
+					meta.ID = response.ID
+				}
+				if response.Model != "" {
+					meta.Model = response.Model
+				}
+				if response.Object != "" {
+					meta.Object = response.Object
+				}
+				if response.Status != "" {
+					meta.Custom["status"] = response.Status
+				}
+				if response.Usage != nil && response.Usage.OutputTokensDetails != nil && response.Usage.OutputTokensDetails.ReasoningTokens > 0 {
+					meta.Custom["reasoning_tokens"] = response.Usage.OutputTokensDetails.ReasoningTokens
+				}
+			}
+		}
+
+		usage := llmproxy.ExtractUsageFromResponsesEvent(event)
+		if usage != nil {
+			accumulatedUsage = usage
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		return meta, err
+	}
+
+	if accumulatedUsage != nil {
+		meta.Usage = llmproxy.Usage{
+			PromptTokens:     accumulatedUsage.PromptTokens,
+			CompletionTokens: accumulatedUsage.CompletionTokens,
+			TotalTokens:      accumulatedUsage.TotalTokens,
+		}
+		if accumulatedUsage.CacheUsage != nil {
+			meta.Custom["cache_usage"] = *accumulatedUsage.CacheUsage
+		}
+		if accumulatedUsage.ReasoningTokens > 0 {
+			meta.Custom["reasoning_tokens"] = accumulatedUsage.ReasoningTokens
+		}
+	}
+
+	return meta, nil
+}
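A sketch of driving `ResponsesStreamingExtractor` from a hand-rolled handler, assuming the package import path follows the repo layout; `upstreamResponses` is a simplified stand-in for the proxy's forwarding step:

```go
package example

import (
	"log"
	"net/http"

	"github.com/agentuity/llmproxy/providers/openai_compatible"
)

// upstreamResponses (simplified stand-in) replays the incoming request
// against the upstream Responses endpoint.
func upstreamResponses(r *http.Request) (*http.Response, error) {
	req, err := http.NewRequestWithContext(r.Context(), http.MethodPost,
		"https://api.openai.com/v1/responses", r.Body)
	if err != nil {
		return nil, err
	}
	req.Header = r.Header.Clone()
	return http.DefaultClient.Do(req)
}

// Handler streams the upstream SSE body to the client while collecting
// the metadata (ID, model, usage) the extractor gathers along the way.
func Handler(w http.ResponseWriter, r *http.Request) {
	resp, err := upstreamResponses(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	extractor := openai_compatible.NewResponsesStreamingExtractor()
	rc := http.NewResponseController(w)

	meta, err := extractor.ExtractStreamingWithController(resp, w, rc)
	if err != nil {
		log.Printf("stream error: %v", err)
		return
	}
	log.Printf("turn complete: model=%s total_tokens=%d", meta.Model, meta.Usage.TotalTokens)
}
```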
{\"type\":\"response.output_text.delta\",\"delta\":\"Hello\",\"content_index\":0,\"output_index\":0,\"item_id\":\"msg_1\"}\n\n" + + "data: {\"type\":\"response.output_text.delta\",\"delta\":\" world\",\"content_index\":0,\"output_index\":0,\"item_id\":\"msg_1\"}\n\n" + + "data: {\"type\":\"response.output_text.done\",\"text\":\"Hello world\",\"content_index\":0,\"output_index\":0}\n\n" + + "data: {\"type\":\"response.content_part.done\",\"part\":{\"type\":\"output_text\",\"text\":\"Hello world\"},\"content_index\":0,\"output_index\":0}\n\n" + + "data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_123\",\"object\":\"response\",\"model\":\"gpt-4o\",\"status\":\"completed\",\"usage\":{\"input_tokens\":10,\"output_tokens\":5,\"total_tokens\":15}}}\n\n" + + "data: [DONE]\n\n" + + meta, output, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if output != stream { + t.Fatalf("stream passthrough mismatch\nwant:\n%s\ngot:\n%s", stream, output) + } + if meta.ID != "resp_123" { + t.Errorf("ID = %q, want resp_123", meta.ID) + } + if meta.Model != "gpt-4o" { + t.Errorf("Model = %q, want gpt-4o", meta.Model) + } + if meta.Object != "response" { + t.Errorf("Object = %q, want response", meta.Object) + } + if meta.Usage.TotalTokens != 15 { + t.Errorf("TotalTokens = %d, want 15", meta.Usage.TotalTokens) + } + if meta.Custom["api_type"] != llmproxy.APITypeResponses { + t.Errorf("api_type = %v, want responses", meta.Custom["api_type"]) + } +} + +func TestResponsesStreamingExtractor_UsageExtraction(t *testing.T) { + stream := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-4o\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":101,\"output_tokens\":44,\"total_tokens\":145}}}\n\n" + + "data: [DONE]\n\n" + + meta, _, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if meta.Usage.PromptTokens != 101 { + t.Errorf("PromptTokens = %d, want 101", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 44 { + t.Errorf("CompletionTokens = %d, want 44", meta.Usage.CompletionTokens) + } + if meta.Usage.TotalTokens != 145 { + t.Errorf("TotalTokens = %d, want 145", meta.Usage.TotalTokens) + } +} + +func TestResponsesStreamingExtractor_CacheUsageExtraction(t *testing.T) { + stream := "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":100,\"output_tokens\":20,\"total_tokens\":120,\"input_tokens_details\":{\"cached_tokens\":80}}}}\n\n" + + "data: [DONE]\n\n" + + meta, _, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cacheUsage, ok := meta.Custom["cache_usage"].(llmproxy.CacheUsage) + if !ok { + t.Fatalf("expected cache_usage in custom metadata") + } + if cacheUsage.CachedTokens != 80 { + t.Errorf("CachedTokens = %d, want 80", cacheUsage.CachedTokens) + } +} + +func TestResponsesStreamingExtractor_ReasoningTokens(t *testing.T) { + stream := "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":10,\"output_tokens\":10,\"total_tokens\":20,\"output_tokens_details\":{\"reasoning_tokens\":7}}}}\n\n" + + "data: [DONE]\n\n" + + meta, _, err := runResponsesStreamExtraction(t, 
"text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + value, ok := meta.Custom["reasoning_tokens"].(int) + if !ok { + t.Fatalf("expected reasoning_tokens custom field") + } + if value != 7 { + t.Errorf("reasoning_tokens = %d, want 7", value) + } +} + +func TestResponsesStreamingExtractor_FunctionCallStream(t *testing.T) { + stream := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_fc\",\"model\":\"gpt-4o\"}}\n\n" + + "data: {\"type\":\"response.function_call_arguments.delta\",\"delta\":\"{\\\"city\\\":\"}\n\n" + + "data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"id\":\"fc_1\",\"name\":\"get_weather\",\"arguments\":\"{\\\"city\\\":\\\"SF\\\"}\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":15,\"output_tokens\":9,\"total_tokens\":24}}}\n\n" + + "data: [DONE]\n\n" + + meta, output, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(output, "function_call_arguments.delta") { + t.Fatalf("expected function call delta in passthrough") + } + if meta.Usage.TotalTokens != 24 { + t.Errorf("TotalTokens = %d, want 24", meta.Usage.TotalTokens) + } +} + +func TestResponsesStreamingExtractor_ErrorEvent(t *testing.T) { + stream := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_err\",\"model\":\"gpt-4o\"}}\n\n" + + "data: {\"type\":\"response.failed\",\"response\":{\"id\":\"resp_err\",\"status\":\"failed\"},\"error\":{\"message\":\"upstream failed\"}}\n\n" + + "data: [DONE]\n\n" + + meta, output, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(output, "response.failed") { + t.Fatalf("expected failed event in output") + } + if meta.ID != "resp_err" { + t.Errorf("ID = %q, want resp_err", meta.ID) + } +} + +func TestResponsesStreamingExtractor_NoCompletedEvent(t *testing.T) { + stream := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-4o\"}}\n\n" + + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello\"}\n\n" + + meta, _, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if meta.Usage.TotalTokens != 0 { + t.Errorf("expected no usage extraction, got %+v", meta.Usage) + } +} + +func TestResponsesStreamingExtractor_EmptyStream(t *testing.T) { + stream := "data: [DONE]\n\n" + + meta, output, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if output != stream { + t.Fatalf("expected passthrough output to match input") + } + if meta.Usage.TotalTokens != 0 { + t.Errorf("expected empty usage, got %+v", meta.Usage) + } +} + +func TestResponsesStreamingExtractor_MalformedEvents(t *testing.T) { + stream := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-4o\"}}\n\n" + + "data: {\"type\":\"response.output_text.delta\",\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":4,\"total_tokens\":7}}}\n\n" + + "data: [DONE]\n\n" + + meta, output, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(output, "response.output_text.delta") { + t.Fatalf("expected 
malformed line to still be forwarded") + } + if meta.Usage.TotalTokens != 7 { + t.Errorf("TotalTokens = %d, want 7", meta.Usage.TotalTokens) + } +} + +func TestResponsesStreamingExtractor_NonStreamingFallback(t *testing.T) { + body := `{"id":"resp_123","object":"response","model":"gpt-4o","status":"completed","output":[{"id":"msg_1","type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello"}]}],"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}` + + meta, output, err := runResponsesStreamExtraction(t, "application/json", body) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if meta.Custom["api_type"] != llmproxy.APITypeResponses { + t.Errorf("api_type = %v, want responses", meta.Custom["api_type"]) + } + if output != body { + t.Fatalf("non-streaming passthrough mismatch") + } +} + +func TestResponsesStreamingExtractor_IsStreamingResponse(t *testing.T) { + extractor := NewResponsesStreamingExtractor() + + tests := []struct { + name string + contentType string + expected bool + }{ + {name: "sse", contentType: "text/event-stream", expected: true}, + {name: "sse with charset", contentType: "text/event-stream; charset=utf-8", expected: true}, + {name: "json", contentType: "application/json", expected: false}, + {name: "plain", contentType: "text/plain", expected: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &http.Response{Header: http.Header{"Content-Type": []string{tt.contentType}}} + if got := extractor.IsStreamingResponse(resp); got != tt.expected { + t.Errorf("IsStreamingResponse() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestResponsesStreamingExtractor_WithEventPrefixes(t *testing.T) { + stream := "event: response.created\n" + + "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-4o\"}}\n\n" + + "event: response.completed\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"total_tokens\":3}}}\n\n" + + "data: [DONE]\n\n" + + meta, output, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(output, "event: response.created") { + t.Fatalf("expected event line passthrough") + } + if meta.Usage.TotalTokens != 3 { + t.Errorf("TotalTokens = %d, want 3", meta.Usage.TotalTokens) + } +} + +func TestResponsesStreamingExtractor_DataPassthrough(t *testing.T) { + stream := "event: response.created\n" + + "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_passthrough\",\"model\":\"gpt-4o\"}}\n\n" + + ": ping\n" + + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello\"}\n\n" + + "data: [DONE]\n\n" + + _, output, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bytes.Equal([]byte(output), []byte(stream)) { + t.Fatalf("expected byte-accurate passthrough") + } +} diff --git a/providers/openai_compatible/responses_test.go b/providers/openai_compatible/responses_test.go index a02d7d7..5f995fc 100644 --- a/providers/openai_compatible/responses_test.go +++ b/providers/openai_compatible/responses_test.go @@ -2,9 +2,12 @@ package openai_compatible import ( "bytes" + "context" "encoding/json" "io" "net/http" + "net/http/httptest" + "strings" "testing" "github.com/agentuity/llmproxy" @@ -1043,6 +1046,14 @@ func TestResponsesExtractor_ReasoningTokensInUsage(t *testing.T) { 
t.Errorf("CompletionTokens = %d, want 200", meta.Usage.CompletionTokens) } + rt, ok := meta.Custom["reasoning_tokens"].(int) + if !ok { + t.Fatal("expected reasoning_tokens in custom metadata") + } + if rt != 150 { + t.Errorf("reasoning_tokens = %d, want 150", rt) + } + output := meta.Custom["output"].([]ResponsesOutputItem) if output[0].Type != "reasoning" { t.Errorf("First output should be reasoning, got %q", output[0].Type) @@ -1705,3 +1716,110 @@ func TestResponsesExtractor_AnnotationMissingSpanFields(t *testing.T) { t.Errorf("EndIndex should be nil when not provided, got %v", annotation.EndIndex) } } + +func TestStreamingMultiAPIExtractor_ResponsesAPIDispatch(t *testing.T) { + stream := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_dispatch\",\"model\":\"gpt-4o\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_dispatch\",\"model\":\"gpt-4o\",\"usage\":{\"input_tokens\":10,\"output_tokens\":5,\"total_tokens\":15}}}\n\n" + + "data: [DONE]\n\n" + + req, err := http.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil) + if err != nil { + t.Fatalf("NewRequest() error = %v", err) + } + ctxValue := llmproxy.MetaContextValue{ + Meta: llmproxy.BodyMetadata{Custom: map[string]any{"api_type": llmproxy.APITypeResponses}}, + } + req = req.WithContext(context.WithValue(req.Context(), llmproxy.MetaContextKey{}, ctxValue)) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(stream)), + Request: req, + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingMultiAPIExtractor() + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("ExtractStreamingWithController() error = %v", err) + } + + if meta.ID != "resp_dispatch" { + t.Errorf("ID = %q, want resp_dispatch", meta.ID) + } + if meta.Usage.TotalTokens != 15 { + t.Errorf("TotalTokens = %d, want 15", meta.Usage.TotalTokens) + } + if meta.Custom["api_type"] != llmproxy.APITypeResponses { + t.Errorf("api_type = %v, want responses", meta.Custom["api_type"]) + } +} + +func TestStreamingMultiAPIExtractor_ChatCompletionsDispatch(t *testing.T) { + stream := "data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n" + + "data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":3,\"total_tokens\":12}}\n\n" + + "data: [DONE]\n\n" + + req, err := http.NewRequest(http.MethodPost, "https://example.com/v1/chat/completions", nil) + if err != nil { + t.Fatalf("NewRequest() error = %v", err) + } + ctxValue := llmproxy.MetaContextValue{ + Meta: llmproxy.BodyMetadata{Custom: map[string]any{"api_type": llmproxy.APITypeChatCompletions}}, + } + req = req.WithContext(context.WithValue(req.Context(), llmproxy.MetaContextKey{}, ctxValue)) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(stream)), + Request: req, + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingMultiAPIExtractor() + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != 
nil { + t.Fatalf("ExtractStreamingWithController() error = %v", err) + } + + if meta.ID != "chatcmpl_1" { + t.Errorf("ID = %q, want chatcmpl_1", meta.ID) + } + if meta.Usage.TotalTokens != 12 { + t.Errorf("TotalTokens = %d, want 12", meta.Usage.TotalTokens) + } +} + +func TestStreamingMultiAPIExtractor_NoContextFallback(t *testing.T) { + stream := "data: {\"id\":\"chatcmpl_fallback\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n" + + "data: {\"id\":\"chatcmpl_fallback\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":4,\"completion_tokens\":2,\"total_tokens\":6}}\n\n" + + "data: [DONE]\n\n" + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(stream)), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingMultiAPIExtractor() + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("ExtractStreamingWithController() error = %v", err) + } + + if meta.ID != "chatcmpl_fallback" { + t.Errorf("ID = %q, want chatcmpl_fallback", meta.ID) + } + if meta.Usage.TotalTokens != 6 { + t.Errorf("TotalTokens = %d, want 6", meta.Usage.TotalTokens) + } +} diff --git a/providers/openai_compatible/streaming_extractor.go b/providers/openai_compatible/streaming_extractor.go index 1a8bb2f..99b526a 100644 --- a/providers/openai_compatible/streaming_extractor.go +++ b/providers/openai_compatible/streaming_extractor.go @@ -173,6 +173,9 @@ func (e *StreamingExtractor) extractStreamingWithController(resp *http.Response, if accumulatedUsage.CacheUsage != nil { meta.Custom["cache_usage"] = *accumulatedUsage.CacheUsage } + if accumulatedUsage.ReasoningTokens > 0 { + meta.Custom["reasoning_tokens"] = accumulatedUsage.ReasoningTokens + } } else if lastChunk != nil { for _, choice := range lastChunk.Choices { c := llmproxy.Choice{ diff --git a/providers/openai_compatible/streaming_extractor_test.go b/providers/openai_compatible/streaming_extractor_test.go index c25dd4c..0d99fd4 100644 --- a/providers/openai_compatible/streaming_extractor_test.go +++ b/providers/openai_compatible/streaming_extractor_test.go @@ -145,3 +145,86 @@ data: [DONE] t.Errorf("expected cached tokens 80, got %d", cacheUsage.CachedTokens) } } + +func TestStreamingExtractor_ExtractStreamingWithReasoning(t *testing.T) { + streamData := `data: {"id":"chatcmpl-456","model":"o1","choices":[{"index":0,"delta":{"content":"test"}}]} + +data: {"id":"chatcmpl-456","model":"o1","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":75,"completion_tokens":1186,"total_tokens":1261,"completion_tokens_details":{"reasoning_tokens":1024}}} + +data: [DONE] + +` + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(streamData)), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingExtractor() + + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Usage.PromptTokens != 75 { + t.Errorf("expected prompt tokens 75, got %d", meta.Usage.PromptTokens) + } + if 
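Downstream billing code reads these optional details back out of `meta.Custom`; both keys are absent unless the provider reported them. A small consumption sketch:

```go
package example

import (
	"fmt"

	"github.com/agentuity/llmproxy"
)

// reportUsageDetails reads the optional usage details that the extractors
// stash in meta.Custom. Billing code must treat both keys as optional.
func reportUsageDetails(meta llmproxy.ResponseMetadata) {
	if rt, ok := meta.Custom["reasoning_tokens"].(int); ok {
		// Reasoning tokens are already counted inside CompletionTokens;
		// they are surfaced separately for reporting, not double-billing.
		fmt.Printf("reasoning tokens: %d of %d completion tokens\n", rt, meta.Usage.CompletionTokens)
	}
	if cu, ok := meta.Custom["cache_usage"].(llmproxy.CacheUsage); ok {
		fmt.Printf("cached prompt tokens: %d\n", cu.CachedTokens)
	}
}
```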
diff --git a/providers/openai_compatible/streaming_extractor_test.go b/providers/openai_compatible/streaming_extractor_test.go
index c25dd4c..0d99fd4 100644
--- a/providers/openai_compatible/streaming_extractor_test.go
+++ b/providers/openai_compatible/streaming_extractor_test.go
@@ -145,3 +145,86 @@ data: [DONE]
 		t.Errorf("expected cached tokens 80, got %d", cacheUsage.CachedTokens)
 	}
 }
+
+func TestStreamingExtractor_ExtractStreamingWithReasoning(t *testing.T) {
+	streamData := `data: {"id":"chatcmpl-456","model":"o1","choices":[{"index":0,"delta":{"content":"test"}}]}
+
+data: {"id":"chatcmpl-456","model":"o1","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":75,"completion_tokens":1186,"total_tokens":1261,"completion_tokens_details":{"reasoning_tokens":1024}}}
+
+data: [DONE]
+
+`
+
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
+		Body:       io.NopCloser(strings.NewReader(streamData)),
+	}
+
+	recorder := httptest.NewRecorder()
+	rc := http.NewResponseController(recorder)
+
+	extractor := NewStreamingExtractor()
+
+	meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if meta.Usage.PromptTokens != 75 {
+		t.Errorf("expected prompt tokens 75, got %d", meta.Usage.PromptTokens)
+	}
+	if meta.Usage.CompletionTokens != 1186 {
+		t.Errorf("expected completion tokens 1186, got %d", meta.Usage.CompletionTokens)
+	}
+
+	rt, ok := meta.Custom["reasoning_tokens"].(int)
+	if !ok {
+		t.Fatal("expected reasoning_tokens in custom map")
+	}
+	if rt != 1024 {
+		t.Errorf("expected reasoning tokens 1024, got %d", rt)
+	}
+}
+
+func TestStreamingExtractor_ExtractStreamingWithCacheAndReasoning(t *testing.T) {
+	streamData := `data: {"id":"chatcmpl-789","model":"o1","choices":[{"index":0,"delta":{"content":"test"}}]}
+
+data: {"id":"chatcmpl-789","model":"o1","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":100,"completion_tokens":500,"total_tokens":600,"prompt_tokens_details":{"cached_tokens":80},"completion_tokens_details":{"reasoning_tokens":256}}}
+
+data: [DONE]
+
+`
+
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
+		Body:       io.NopCloser(strings.NewReader(streamData)),
+	}
+
+	recorder := httptest.NewRecorder()
+	rc := http.NewResponseController(recorder)
+
+	extractor := NewStreamingExtractor()
+
+	meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	cacheUsage, ok := meta.Custom["cache_usage"].(llmproxy.CacheUsage)
+	if !ok {
+		t.Fatal("expected cache_usage in custom map")
+	}
+	if cacheUsage.CachedTokens != 80 {
+		t.Errorf("expected cached tokens 80, got %d", cacheUsage.CachedTokens)
+	}
+
+	rt, ok := meta.Custom["reasoning_tokens"].(int)
+	if !ok {
+		t.Fatal("expected reasoning_tokens in custom map")
+	}
+	if rt != 256 {
+		t.Errorf("expected reasoning tokens 256, got %d", rt)
+	}
+}
diff --git a/providers/openai_compatible/websocket.go b/providers/openai_compatible/websocket.go
new file mode 100644
index 0000000..ab190ca
--- /dev/null
+++ b/providers/openai_compatible/websocket.go
@@ -0,0 +1,28 @@
+package openai_compatible
+
+import (
+	"fmt"
+	"net/url"
+
+	"github.com/agentuity/llmproxy"
+)
+
+// WebSocketURL converts the provider HTTP base URL to a WebSocket URL.
+//
+// https://api.openai.com -> wss://api.openai.com/v1/responses
+// http://localhost:8080 -> ws://localhost:8080/v1/responses
+func (r *Resolver) WebSocketURL(meta llmproxy.BodyMetadata) (*url.URL, error) {
+	if r == nil || r.BaseURL == nil {
+		return nil, fmt.Errorf("resolver base URL is nil")
+	}
+
+	u := *r.BaseURL
+	switch u.Scheme {
+	case "https":
+		u.Scheme = "wss"
+	case "http":
+		u.Scheme = "ws"
+	}
+
+	return u.JoinPath("v1", "responses"), nil
+}
diff --git a/providers/openai_compatible/websocket_test.go b/providers/openai_compatible/websocket_test.go
new file mode 100644
index 0000000..115280f
--- /dev/null
+++ b/providers/openai_compatible/websocket_test.go
@@ -0,0 +1,78 @@
+package openai_compatible
+
+import (
+	"testing"
+
+	"github.com/agentuity/llmproxy"
+)
+
+func TestWebSocketURL_HTTPS(t *testing.T) {
+	r, err := NewResolver("https://api.openai.com")
+	if err != nil {
+		t.Fatalf("NewResolver() error = %v", err)
+	}
+	u, err := r.WebSocketURL(llmproxy.BodyMetadata{})
+	if err != nil {
+		t.Fatalf("WebSocketURL() error = %v", err)
+	}
+	if got, want := u.String(), "wss://api.openai.com/v1/responses"; got != want {
+		t.Fatalf("URL = %q, want %q", got, want)
+	}
+}
+
+func TestWebSocketURL_HTTP(t *testing.T) {
+	r, err := NewResolver("http://localhost:8080")
+	if err != nil {
+		t.Fatalf("NewResolver() error = %v", err)
+	}
+	u, err := r.WebSocketURL(llmproxy.BodyMetadata{})
+	if err != nil {
+		t.Fatalf("WebSocketURL() error = %v", err)
+	}
+	if got, want := u.String(), "ws://localhost:8080/v1/responses"; got != want {
+		t.Fatalf("URL = %q, want %q", got, want)
+	}
+}
+
+func TestWebSocketURL_WithTrailingSlash(t *testing.T) {
+	r, err := NewResolver("https://api.openai.com/")
+	if err != nil {
+		t.Fatalf("NewResolver() error = %v", err)
+	}
+	u, err := r.WebSocketURL(llmproxy.BodyMetadata{})
+	if err != nil {
+		t.Fatalf("WebSocketURL() error = %v", err)
+	}
+	if got, want := u.String(), "wss://api.openai.com/v1/responses"; got != want {
+		t.Fatalf("URL = %q, want %q", got, want)
+	}
+}
+
+func TestWebSocketURL_WithExistingPath(t *testing.T) {
+	r, err := NewResolver("https://api.openai.com/v1")
+	if err != nil {
+		t.Fatalf("NewResolver() error = %v", err)
+	}
+	u, err := r.WebSocketURL(llmproxy.BodyMetadata{})
+	if err != nil {
+		t.Fatalf("WebSocketURL() error = %v", err)
+	}
+	// NewResolver normalizes trailing /v1 so WebSocketURL doesn't double it
+	if got, want := u.String(), "wss://api.openai.com/v1/responses"; got != want {
+		t.Fatalf("URL = %q, want %q", got, want)
+	}
+}
+
+func TestWebSocketURL_WithExistingV1Slash(t *testing.T) {
+	r, err := NewResolver("https://api.openai.com/v1/")
+	if err != nil {
+		t.Fatalf("NewResolver() error = %v", err)
+	}
+	u, err := r.WebSocketURL(llmproxy.BodyMetadata{})
+	if err != nil {
+		t.Fatalf("WebSocketURL() error = %v", err)
+	}
+	if got, want := u.String(), "wss://api.openai.com/v1/responses"; got != want {
+		t.Fatalf("URL = %q, want %q", got, want)
+	}
+}
diff --git a/streaming.go b/streaming.go
index 6a10195..45d7ffc 100644
--- a/streaming.go
+++ b/streaming.go
@@ -93,6 +93,7 @@ type StreamingUsage struct {
 	CompletionTokens int
 	TotalTokens      int
 	CacheUsage       *CacheUsage
+	ReasoningTokens  int
 }
 
 type OpenAIStreamChunk struct {
@@ -200,6 +201,35 @@ type AnthropicStreamMessage struct {
 	Usage *AnthropicStreamUsage `json:"usage,omitempty"`
 }
 
+type ResponsesStreamEvent struct {
+	Type     string          `json:"type"`
+	Response json.RawMessage `json:"response,omitempty"`
+}
+
+type ResponsesStreamResponse struct {
+	ID     string                `json:"id"`
+	Object string                `json:"object"`
+	Model  string                `json:"model"`
+	Status string                `json:"status"`
+	Usage  *ResponsesStreamUsage `json:"usage,omitempty"`
+}
+
+type ResponsesStreamUsage struct {
+	InputTokens         int                           `json:"input_tokens"`
+	OutputTokens        int                           `json:"output_tokens"`
+	TotalTokens         int                           `json:"total_tokens"`
+	InputTokensDetails  *ResponsesStreamInputDetails  `json:"input_tokens_details,omitempty"`
+	OutputTokensDetails *ResponsesStreamOutputDetails `json:"output_tokens_details,omitempty"`
+}
+
+type ResponsesStreamInputDetails struct {
+	CachedTokens int `json:"cached_tokens,omitempty"`
+}
+
+type ResponsesStreamOutputDetails struct {
+	ReasoningTokens int `json:"reasoning_tokens,omitempty"`
+}
+
 func ParseAnthropicSSEEvent(data []byte) (*AnthropicStreamEvent, error) {
 	data = bytes.TrimSpace(data)
 	if len(data) == 0 {
@@ -214,6 +244,24 @@ func ParseAnthropicSSEEvent(data []byte) (*AnthropicStreamEvent, error) {
 	return &event, nil
 }
 
+func ParseResponsesSSEEvent(data []byte) (*ResponsesStreamEvent, error) {
+	data = bytes.TrimSpace(data)
+	if len(data) == 0 {
+		return nil, nil
+	}
+
+	if bytes.Equal(data, []byte("[DONE]")) {
+		return nil, ErrStreamComplete
+	}
+
+	var event ResponsesStreamEvent
+	if err := json.Unmarshal(data, &event); err != nil {
+		return nil, err
+	}
+
+	return &event, nil
+}
+
 func IsSSEStream(contentType string) bool {
 	return strings.Contains(strings.ToLower(contentType), "text/event-stream")
 }
@@ -235,6 +283,10 @@ func ExtractUsageFromOpenAIChunk(chunk *OpenAIStreamChunk) *StreamingUsage {
 		}
 	}
 
+	if chunk.Usage.CompletionTokensDetails != nil && chunk.Usage.CompletionTokensDetails.ReasoningTokens > 0 {
+		usage.ReasoningTokens = chunk.Usage.CompletionTokensDetails.ReasoningTokens
+	}
+
 	return usage
 }
 
@@ -279,6 +331,39 @@ func ExtractUsageFromAnthropicEvent(event *AnthropicStreamEvent) *StreamingUsage
 	return result
 }
 
+func ExtractUsageFromResponsesEvent(event *ResponsesStreamEvent) *StreamingUsage {
+	if event == nil || event.Type != "response.completed" || len(event.Response) == 0 {
+		return nil
+	}
+
+	var response ResponsesStreamResponse
+	if err := json.Unmarshal(event.Response, &response); err != nil {
+		return nil
+	}
+
+	if response.Usage == nil {
+		return nil
+	}
+
+	usage := &StreamingUsage{
+		PromptTokens:     response.Usage.InputTokens,
+		CompletionTokens: response.Usage.OutputTokens,
+		TotalTokens:      response.Usage.TotalTokens,
+	}
+
+	if response.Usage.InputTokensDetails != nil && response.Usage.InputTokensDetails.CachedTokens > 0 {
+		usage.CacheUsage = &CacheUsage{
+			CachedTokens: response.Usage.InputTokensDetails.CachedTokens,
+		}
+	}
+
+	if response.Usage.OutputTokensDetails != nil && response.Usage.OutputTokensDetails.ReasoningTokens > 0 {
+		usage.ReasoningTokens = response.Usage.OutputTokensDetails.ReasoningTokens
+	}
+
+	return usage
+}
+
 func FormatSSEEvent(event string, data []byte) []byte {
 	var buf bytes.Buffer
 	if len(event) > 0 {
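A sketch of consuming a Responses API SSE body with the two helpers above; a real caller would scan `resp.Body` rather than stdin:

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"os"

	"github.com/agentuity/llmproxy"
)

func main() {
	scanner := bufio.NewScanner(os.Stdin)
	for scanner.Scan() {
		line := scanner.Bytes()
		if !bytes.HasPrefix(line, []byte("data:")) {
			continue // skip "event:" lines, comments, and blank separators
		}
		data := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))

		event, err := llmproxy.ParseResponsesSSEEvent(data)
		if err != nil || event == nil {
			continue // ErrStreamComplete ([DONE]), empty data, or malformed JSON
		}

		// Usage only appears on response.completed; nil for all other events.
		if usage := llmproxy.ExtractUsageFromResponsesEvent(event); usage != nil {
			fmt.Printf("tokens: prompt=%d completion=%d reasoning=%d\n",
				usage.PromptTokens, usage.CompletionTokens, usage.ReasoningTokens)
		}
	}
}
```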
diff --git a/streaming_test.go b/streaming_test.go
index 2cb5db9..ceeaf42 100644
--- a/streaming_test.go
+++ b/streaming_test.go
@@ -214,6 +214,50 @@ func TestExtractUsageFromOpenAIChunk(t *testing.T) {
 			},
 		},
 	},
+	{
+		name: "chunk with reasoning tokens",
+		chunk: &OpenAIStreamChunk{
+			Usage: &OpenAIStreamUsage{
+				PromptTokens:     75,
+				CompletionTokens: 1186,
+				TotalTokens:      1261,
+				CompletionTokensDetails: &OpenAIStreamCompletionDetails{
+					ReasoningTokens: 1024,
+				},
+			},
+		},
+		expected: &StreamingUsage{
+			PromptTokens:     75,
+			CompletionTokens: 1186,
+			TotalTokens:      1261,
+			ReasoningTokens:  1024,
+		},
+	},
+	{
+		name: "chunk with both cache and reasoning tokens",
+		chunk: &OpenAIStreamChunk{
+			Usage: &OpenAIStreamUsage{
+				PromptTokens:     100,
+				CompletionTokens: 200,
+				TotalTokens:      300,
+				PromptTokensDetails: &OpenAIStreamPromptDetails{
+					CachedTokens: 50,
+				},
+				CompletionTokensDetails: &OpenAIStreamCompletionDetails{
+					ReasoningTokens: 128,
+				},
+			},
+		},
+		expected: &StreamingUsage{
+			PromptTokens:     100,
+			CompletionTokens: 200,
+			TotalTokens:      300,
+			CacheUsage: &CacheUsage{
+				CachedTokens: 50,
+			},
+			ReasoningTokens: 128,
+		},
+	},
 }
 
 for _, tt := range tests {
@@ -248,6 +292,10 @@ func TestExtractUsageFromOpenAIChunk(t *testing.T) {
 				t.Errorf("expected CachedTokens %d, got %d", tt.expected.CacheUsage.CachedTokens, result.CacheUsage.CachedTokens)
 			}
 		}
+
+		if result.ReasoningTokens != tt.expected.ReasoningTokens {
+			t.Errorf("expected ReasoningTokens %d, got %d", tt.expected.ReasoningTokens, result.ReasoningTokens)
+		}
 	})
 }
 }
@@ -559,3 +607,202 @@ data: {"type":"message_stop"}
 		t.Errorf("expected 25 output tokens, got %d", deltaEvent.Usage.OutputTokens)
 	}
 }
+
+func TestParseResponsesSSEEvent_Created(t *testing.T) {
+	data := []byte(`{"type":"response.created","response":{"id":"resp_123","object":"response","model":"gpt-4o","status":"in_progress"}}`)
+
+	event, err := ParseResponsesSSEEvent(data)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if event == nil {
+		t.Fatal("expected non-nil event")
+	}
+	if event.Type != "response.created" {
+		t.Errorf("Type = %q, want response.created", event.Type)
+	}
+	if len(event.Response) == 0 {
+		t.Fatal("expected response payload")
+	}
+}
+
+func TestParseResponsesSSEEvent_TextDelta(t *testing.T) {
+	data := []byte(`{"type":"response.output_text.delta","delta":"Hello","content_index":0,"output_index":0}`)
+
+	event, err := ParseResponsesSSEEvent(data)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if event == nil {
+		t.Fatal("expected non-nil event")
+	}
+	if event.Type != "response.output_text.delta" {
+		t.Errorf("Type = %q, want response.output_text.delta", event.Type)
+	}
+}
+
+func TestParseResponsesSSEEvent_Completed(t *testing.T) {
+	data := []byte(`{"type":"response.completed","response":{"id":"resp_123","model":"gpt-4o","status":"completed","usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}}`)
+
+	event, err := ParseResponsesSSEEvent(data)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if event == nil {
+		t.Fatal("expected non-nil event")
+	}
+	if event.Type != "response.completed" {
+		t.Errorf("Type = %q, want response.completed", event.Type)
+	}
+}
+
+func TestParseResponsesSSEEvent_Empty(t *testing.T) {
+	event, err := ParseResponsesSSEEvent([]byte(" \n\t "))
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if event != nil {
+		t.Fatalf("expected nil event, got %+v", event)
+	}
+}
+
+func TestParseResponsesSSEEvent_Done(t *testing.T) {
+	event, err := ParseResponsesSSEEvent([]byte("[DONE]"))
+	if err != ErrStreamComplete {
+		t.Fatalf("expected ErrStreamComplete, got %v", err)
+	}
+	if event != nil {
+		t.Fatalf("expected nil event for done marker, got %+v", event)
+	}
+}
+
+func TestParseResponsesSSEEvent_MalformedJSON(t *testing.T) {
+	event, err := ParseResponsesSSEEvent([]byte(`{"type":"response.created",`))
+	if err == nil {
+		t.Fatal("expected error for malformed JSON")
+	}
+	if event != nil {
+		t.Fatalf("expected nil event on malformed input, got %+v", event)
+	}
+}
+
+func TestExtractUsageFromResponsesEvent(t *testing.T) {
+	tests := []struct {
+		name               string
+		event              *ResponsesStreamEvent
+		expectedPrompt     int
+		expectedCompletion int
+		expectedTotal      int
+		expectedCached     int
+		expectedReasoning  int
+		expectNil          bool
+	}{
+		{
+			name:      "nil event",
+			event:     nil,
+			expectNil: true,
+		},
+		{
+			name: "completed with usage",
+			event: &ResponsesStreamEvent{
+				Type:     "response.completed",
+				Response: []byte(`{"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}`),
+			},
+			expectedPrompt:     10,
+			expectedCompletion: 5,
+			expectedTotal:      15,
+		},
+		{
+			name: "completed without usage",
+			event: &ResponsesStreamEvent{
+				Type:     "response.completed",
+				Response: []byte(`{"id":"resp_1","status":"completed"}`),
+			},
+			expectNil: true,
+		},
+		{
+			name: "non-completed event",
+			event: &ResponsesStreamEvent{
+				Type:     "response.created",
+				Response: []byte(`{"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}`),
+			},
+			expectNil: true,
+		},
+		{
+			name: "usage with cached tokens",
+			event: &ResponsesStreamEvent{
+				Type: "response.completed",
+				Response: []byte(`{
+					"usage":{
+						"input_tokens":100,
+						"output_tokens":20,
+						"total_tokens":120,
+						"input_tokens_details":{"cached_tokens":80}
+					}
+				}`),
+			},
+			expectedPrompt:     100,
+			expectedCompletion: 20,
+			expectedTotal:      120,
+			expectedCached:     80,
+		},
+		{
+			name: "usage with reasoning tokens",
+			event: &ResponsesStreamEvent{
+				Type: "response.completed",
+				Response: []byte(`{
+					"usage":{
+						"input_tokens":30,
+						"output_tokens":10,
+						"total_tokens":40,
+						"output_tokens_details":{"reasoning_tokens":7}
+					}
+				}`),
+			},
+			expectedPrompt:     30,
+			expectedCompletion: 10,
+			expectedTotal:      40,
+			expectedReasoning:  7,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := ExtractUsageFromResponsesEvent(tt.event)
+			if tt.expectNil {
+				if result != nil {
+					t.Fatalf("expected nil, got %+v", result)
+				}
+				return
+			}
+
+			if result == nil {
+				t.Fatal("expected non-nil usage")
+			}
+			if result.PromptTokens != tt.expectedPrompt {
+				t.Errorf("PromptTokens = %d, want %d", result.PromptTokens, tt.expectedPrompt)
+			}
+			if result.CompletionTokens != tt.expectedCompletion {
+				t.Errorf("CompletionTokens = %d, want %d", result.CompletionTokens, tt.expectedCompletion)
+			}
+			if result.TotalTokens != tt.expectedTotal {
+				t.Errorf("TotalTokens = %d, want %d", result.TotalTokens, tt.expectedTotal)
+			}
+
+			if tt.expectedCached > 0 {
+				if result.CacheUsage == nil {
+					t.Fatal("expected cache usage")
+				}
+				if result.CacheUsage.CachedTokens != tt.expectedCached {
+					t.Errorf("CachedTokens = %d, want %d", result.CacheUsage.CachedTokens, tt.expectedCached)
+				}
+			}
+
+			if tt.expectedReasoning > 0 {
+				if result.ReasoningTokens != tt.expectedReasoning {
+					t.Errorf("ReasoningTokens = %d, want %d", result.ReasoningTokens, tt.expectedReasoning)
+				}
+			}
+		})
+	}
+}
diff --git a/websocket.go b/websocket.go
new file mode 100644
index 0000000..9c44745
--- /dev/null
+++ b/websocket.go
@@ -0,0 +1,135 @@
+package llmproxy
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/url"
+)
+
+// RFC 6455 WebSocket message type constants.
+const (
+	TextMessage   = 1
+	BinaryMessage = 2
+	CloseMessage  = 8
+	PingMessage   = 9
+	PongMessage   = 10
+)
+
+// WSConn abstracts a WebSocket connection for reading and writing messages.
+//
+// gorilla/websocket's *Conn satisfies this interface directly.
+type WSConn interface {
+	ReadMessage() (messageType int, p []byte, err error)
+	WriteMessage(messageType int, data []byte) error
+	Close() error
+}
+
+// WSUpgrader upgrades an HTTP request to a WebSocket connection.
+// Consumers wrap their WebSocket library's upgrader to implement this.
+type WSUpgrader interface {
+	Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (WSConn, error)
+}
+
+// WSDialer dials a WebSocket connection to an upstream server.
+// Consumers wrap their WebSocket library's dialer to implement this.
+type WSDialer interface {
+	DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (WSConn, *http.Response, error)
+}
+
+// WebSocketCapableProvider is implemented by providers that support WebSocket mode.
+type WebSocketCapableProvider interface {
+	Provider
+	// WebSocketURL returns the upstream WebSocket URL for this provider.
+	WebSocketURL(meta BodyMetadata) (*url.URL, error)
+}
+
+// WSEventCallback is an optional callback for WebSocket events.
+// usage is non-nil for response.completed events that include usage data.
+type WSEventCallback func(eventType string, data []byte, usage *StreamingUsage)
+
+// WSBillingCallback is invoked per completed response turn.
+type WSBillingCallback func(turn int, meta ResponseMetadata, billing *BillingResult)
+
+// WSMessage is a lightweight parsed view of a WebSocket JSON message.
+type WSMessage struct {
+	Type               string          `json:"type"`
+	Model              string          `json:"model,omitempty"`
+	PreviousResponseID string          `json:"previous_response_id,omitempty"`
+	Raw                json.RawMessage `json:"-"`
+}
+
+// ParseWSMessage parses a WebSocket JSON message and extracts commonly used fields.
+func ParseWSMessage(data []byte) (*WSMessage, error) {
+	var msg WSMessage
+	if err := json.Unmarshal(data, &msg); err != nil {
+		return nil, err
+	}
+	msg.Raw = append(json.RawMessage(nil), data...)
+	return &msg, nil
+}
+
+// WSResponseCompleted is the minimal shape needed to extract usage from
+// OpenAI Responses API WebSocket response.completed events.
+type WSResponseCompleted struct {
+	Type     string              `json:"type"`
+	Response *WSResponseEnvelope `json:"response,omitempty"`
+	Usage    *WSResponseUsage    `json:"usage,omitempty"`
+}
+
+type WSResponseEnvelope struct {
+	Usage *WSResponseUsage `json:"usage,omitempty"`
+}
+
+type WSResponseUsage struct {
+	InputTokens         int                      `json:"input_tokens"`
+	OutputTokens        int                      `json:"output_tokens"`
+	TotalTokens         int                      `json:"total_tokens"`
+	InputTokensDetails  *WSResponseInputDetails  `json:"input_tokens_details,omitempty"`
+	OutputTokensDetails *WSResponseOutputDetails `json:"output_tokens_details,omitempty"`
+}
+
+type WSResponseInputDetails struct {
+	CachedTokens int `json:"cached_tokens,omitempty"`
+}
+
+type WSResponseOutputDetails struct {
+	ReasoningTokens int `json:"reasoning_tokens,omitempty"`
+}
+
+// ExtractWSUsage extracts usage from a response.completed WebSocket message.
+// It returns (nil, nil) for events other than response.completed.
+func ExtractWSUsage(data []byte) (*StreamingUsage, error) {
+	var msg WSResponseCompleted
+	if err := json.Unmarshal(data, &msg); err != nil {
+		return nil, err
+	}
+
+	if msg.Type != "response.completed" {
+		return nil, nil
+	}
+
+	usage := msg.Usage
+	if usage == nil && msg.Response != nil {
+		usage = msg.Response.Usage
+	}
+	if usage == nil {
+		return nil, nil
+	}
+
+	out := &StreamingUsage{
+		PromptTokens:     usage.InputTokens,
+		CompletionTokens: usage.OutputTokens,
+		TotalTokens:      usage.TotalTokens,
+	}
+
+	if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
+		out.CacheUsage = &CacheUsage{CachedTokens: usage.InputTokensDetails.CachedTokens}
+	}
+
+	if usage.OutputTokensDetails != nil && usage.OutputTokensDetails.ReasoningTokens > 0 {
+		out.ReasoningTokens = usage.OutputTokensDetails.ReasoningTokens
+	}
+
+	return out, nil
+}
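A sketch of the per-turn accounting these types enable on the upstream side of a proxied connection; `pumpTurns` is a hypothetical reduction of the real connection loop, which also relays frames to the client:

```go
package example

import (
	"fmt"
	"log"

	"github.com/agentuity/llmproxy"
)

// pumpTurns (hypothetical) reads frames from an upstream WSConn and invokes
// the billing callback once per response.completed event.
func pumpTurns(conn llmproxy.WSConn, onTurn llmproxy.WSBillingCallback) error {
	turn := 0
	for {
		msgType, data, err := conn.ReadMessage()
		if err != nil {
			return err // includes normal connection close
		}
		if msgType != llmproxy.TextMessage {
			continue
		}

		usage, err := llmproxy.ExtractWSUsage(data)
		if err != nil {
			log.Printf("unparseable frame: %v", err)
			continue
		}
		if usage == nil {
			continue // not a response.completed event
		}

		turn++
		meta := llmproxy.ResponseMetadata{
			Usage: llmproxy.Usage{
				PromptTokens:     usage.PromptTokens,
				CompletionTokens: usage.CompletionTokens,
				TotalTokens:      usage.TotalTokens,
			},
		}
		onTurn(turn, meta, nil) // billing result omitted in this sketch
		fmt.Printf("turn %d: %d total tokens\n", turn, usage.TotalTokens)
	}
}
```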
diff --git a/websocket_test.go b/websocket_test.go
new file mode 100644
index 0000000..360cdbe
--- /dev/null
+++ b/websocket_test.go
@@ -0,0 +1,100 @@
+package llmproxy
+
+import "testing"
+
+func TestParseWSMessage_ResponseCreate(t *testing.T) {
+	msg, err := ParseWSMessage([]byte(`{"type":"response.create","model":"gpt-4o","input":[]}`))
+	if err != nil {
+		t.Fatalf("ParseWSMessage() error = %v", err)
+	}
+	if msg.Type != "response.create" {
+		t.Fatalf("Type = %q, want response.create", msg.Type)
+	}
+	if msg.Model != "gpt-4o" {
+		t.Fatalf("Model = %q, want gpt-4o", msg.Model)
+	}
+}
+
+func TestParseWSMessage_ResponseCreateWithPreviousID(t *testing.T) {
+	msg, err := ParseWSMessage([]byte(`{"type":"response.create","model":"gpt-4o","previous_response_id":"resp_123"}`))
+	if err != nil {
+		t.Fatalf("ParseWSMessage() error = %v", err)
+	}
+	if msg.PreviousResponseID != "resp_123" {
+		t.Fatalf("PreviousResponseID = %q, want resp_123", msg.PreviousResponseID)
+	}
+}
+
+func TestParseWSMessage_NonCreate(t *testing.T) {
+	msg, err := ParseWSMessage([]byte(`{"type":"response.output_text.delta","delta":"hi"}`))
+	if err != nil {
+		t.Fatalf("ParseWSMessage() error = %v", err)
+	}
+	if msg.Type != "response.output_text.delta" {
+		t.Fatalf("Type = %q, want response.output_text.delta", msg.Type)
+	}
+}
+
+func TestParseWSMessage_MalformedJSON(t *testing.T) {
+	if _, err := ParseWSMessage([]byte(`{"type":`)); err == nil {
+		t.Fatal("expected error for malformed JSON")
+	}
+}
+
+func TestExtractWSUsage_Completed(t *testing.T) {
+	usage, err := ExtractWSUsage([]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":12,"output_tokens":7,"total_tokens":19}}}`))
+	if err != nil {
+		t.Fatalf("ExtractWSUsage() error = %v", err)
+	}
+	if usage == nil {
+		t.Fatal("usage is nil")
+	}
+	if usage.PromptTokens != 12 || usage.CompletionTokens != 7 || usage.TotalTokens != 19 {
+		t.Fatalf("unexpected usage: %+v", usage)
+	}
+}
+
+func TestExtractWSUsage_CompletedWithCache(t *testing.T) {
+	usage, err := ExtractWSUsage([]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":20,"output_tokens":5,"total_tokens":25,"input_tokens_details":{"cached_tokens":4}}}}`))
+	if err != nil {
+		t.Fatalf("ExtractWSUsage() error = %v", err)
+	}
+	if usage == nil || usage.CacheUsage == nil {
+		t.Fatal("expected cache usage")
+	}
+	if usage.CacheUsage.CachedTokens != 4 {
+		t.Fatalf("CachedTokens = %d, want 4", usage.CacheUsage.CachedTokens)
+	}
+}
+
+func TestExtractWSUsage_CompletedWithReasoning(t *testing.T) {
+	usage, err := ExtractWSUsage([]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":10,"output_tokens":8,"total_tokens":18,"output_tokens_details":{"reasoning_tokens":3}}}}`))
+	if err != nil {
+		t.Fatalf("ExtractWSUsage() error = %v", err)
+	}
+	if usage == nil {
+		t.Fatal("expected non-nil usage")
+	}
+	if usage.PromptTokens != 10 || usage.CompletionTokens != 8 || usage.TotalTokens != 18 {
+		t.Fatalf("unexpected usage: %+v", usage)
+	}
+	if usage.ReasoningTokens != 3 {
+		t.Fatalf("ReasoningTokens = %d, want 3", usage.ReasoningTokens)
+	}
+}
+
+func TestExtractWSUsage_NonCompleted(t *testing.T) {
+	usage, err := ExtractWSUsage([]byte(`{"type":"response.created","response":{"id":"resp_1"}}`))
+	if err != nil {
+		t.Fatalf("ExtractWSUsage() error = %v", err)
+	}
+	if usage != nil {
+		t.Fatalf("usage = %+v, want nil", usage)
+	}
+}
+
+func TestExtractWSUsage_MalformedJSON(t *testing.T) {
+	if _, err := ExtractWSUsage([]byte(`{"type":`)); err == nil {
+		t.Fatal("expected error for malformed JSON")
+	}
+}