Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions cmd/thv/app/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@
// Header forwarding flags
remoteForwardHeaders []string
remoteForwardHeadersSecret []string

// CORS flags
proxyCORSOrigins []string
)

// Environment variable names
Expand Down Expand Up @@ -157,6 +160,19 @@
proxyCmd.Flags().StringArrayVar(&remoteForwardHeadersSecret, "remote-forward-headers-secret", []string{},
"Headers with secret values from ToolHive secrets manager (format: Name=secret-name, can be repeated)")

// CORS — disabled by default; opt in explicitly to avoid widening the attack surface
proxyCmd.Flags().StringArrayVar(&proxyCORSOrigins, "allow-origins", []string{},

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing to do before CI goes green: run task docs from the repo root to regenerate the CLI reference docs. Adding --allow-origins means the auto-generated content under docs/ is out of date, and the docs check will fail without it.

`Allowed CORS origins for the MCP proxy endpoint (repeatable).
CORS is disabled by default; if you handle CORS at a gateway or reverse proxy,
leaving this unset is the correct, secure choice. Each origin must include a
scheme (e.g. http://) and no trailing slash, otherwise it can never match a
browser request.
Supported forms:
exact: http://localhost:6274
scheme+host: http://localhost (matches any port on localhost)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "disabled by default" note is in the code comment above, not the help string an operator sees in --help. An operator who handles CORS at a gateway gets no signal that doing nothing is the correct choice. One line in the help text would fix it:

CORS is disabled by default; omit this flag when CORS is handled by an upstream gateway.

Also watch the --allow-origins "" edge: it passes the len > 0 check (length 1), activates the middleware, but matchCORSOrigin matches nothing... so preflights get eaten (204 instead of 405) with no CORS headers. Half-configured and confusing. Validate/reject empty entries or filter them before the guard.

wildcard: * (allow all — use with caution)
Example: --allow-origins http://localhost:6274`)

// Mark target-uri as required
if err := proxyCmd.MarkFlagRequired("target-uri"); err != nil {
slog.Warn(fmt.Sprintf("Failed to mark flag as required: %v", err))
Expand All @@ -167,7 +183,7 @@

}

func proxyCmdFunc(cmd *cobra.Command, args []string) error {

Check failure on line 186 in cmd/thv/app/proxy.go

View workflow job for this annotation

GitHub Actions / Linting / Lint Go Code

cyclomatic complexity 16 of func `proxyCmdFunc` is high (> 15) (gocyclo)
ctx, stopSignal := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
defer stopSignal()
// Get the server name
Expand Down Expand Up @@ -264,8 +280,20 @@
slog.Debug(fmt.Sprintf("Setting up transparent proxy to forward from host port %d to %s",
port, proxyTargetURI))

// Build optional functional options (e.g. CORS), only when configured.
var proxyOptions []transparent.Option
if len(proxyCORSOrigins) > 0 {
// Validate origins at startup so a misconfigured entry (missing scheme,
// trailing slash) fails loudly instead of silently never matching.
corsOrigins, err := middleware.ValidateAndNormalizeOrigins(proxyCORSOrigins)
if err != nil {
return fmt.Errorf("invalid --allow-origins: %w", err)
}
proxyOptions = append(proxyOptions, transparent.WithAllowedOrigins(corsOrigins))
}

// Create the transparent proxy with middlewares
proxy := transparent.NewTransparentProxy(
proxy := transparent.NewTransparentProxyWithOptions(
proxyHost,
port,
proxyTargetURI,
Expand All @@ -279,7 +307,8 @@
nil, // onUnauthorizedResponse - not needed for local proxies
"", // endpointPrefix - not configured for proxy command
false, // trustProxyHeaders - not configured for proxy command
middlewares...)
middlewares,
proxyOptions...)
if err := proxy.Start(ctx); err != nil {
return fmt.Errorf("failed to start proxy: %w", err)
}
Expand Down
10 changes: 10 additions & 0 deletions docs/cli/thv_proxy.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

164 changes: 164 additions & 0 deletions pkg/transport/middleware/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package middleware

import (
"fmt"
"log/slog"
"net/http"
"strings"

"github.com/stacklok/toolhive/pkg/transport/types"
)

const (
// defaultCORSAllowedMethods is the fallback preflight method list used when
// the caller does not supply an explicit set. The proxy passes a set derived
// from the server's actual capabilities (see the stateless/stateful method
// sources of truth in the transparent proxy) so the preflight never
// advertises a method the backend will reject.
defaultCORSAllowedMethods = "GET, POST, DELETE, OPTIONS"

// corsAllowedHeaders lists request headers MCP clients may send. MCP-Protocol-Version
// must be allow-listed: ToolHive reads and validates it on the request path
// (an unsupported value yields 400), so a browser MCP client cannot send it
// through CORS unless it is listed here.
corsAllowedHeaders = "Content-Type, Accept, Mcp-Session-Id, MCP-Protocol-Version, Authorization"

// corsExposedHeaders lists response headers that browsers may read. MCP-Protocol-Version
// is exposed so a browser client can read the negotiated protocol version back.
// Content-Type is omitted because it is already a CORS-safelisted response
// header and does not need to be exposed explicitly.
corsExposedHeaders = "Mcp-Session-Id, MCP-Protocol-Version"

// corsMaxAge is the preflight cache lifetime in seconds (24 hours).
corsMaxAge = "86400"
)

// CORS returns a middleware that handles CORS preflight (OPTIONS) requests and
// injects Access-Control-Allow-* response headers. When allowedOrigins is empty
// the middleware is a no-op, preserving the default security posture.
//
// Origin matching rules (applied in order):
// - "*": matches every origin; Access-Control-Allow-Origin is set to "*".
// - Exact: "http://localhost:6274" matches only that origin.
// - Scheme+host prefix: "http://localhost" also matches any
// "http://localhost:<port>" (e.g. the MCP Inspector default port).
//
// All OPTIONS requests are handled directly (returning 204) when this middleware
// is active so that CORS preflights never reach the backend, which previously
// returned 405 Method Not Allowed.
//
// allowedMethods is the value advertised in Access-Control-Allow-Methods. It
// should reflect the methods the backend actually accepts so a preflight never
// succeeds for a method the real request would reject. When empty,
// defaultCORSAllowedMethods is used.
func CORS(allowedOrigins []string, allowedMethods string) types.MiddlewareFunction {
if len(allowedOrigins) == 0 {
return func(next http.Handler) http.Handler { return next }
}

if allowedMethods == "" {
allowedMethods = defaultCORSAllowedMethods
}

slog.Debug("CORS middleware configured",
"allowed_origins", strings.Join(allowedOrigins, ", "), "allowed_methods", allowedMethods)

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
matched := matchCORSOrigin(origin, allowedOrigins)

if matched != "" {
h := w.Header()
h.Set("Access-Control-Allow-Origin", matched)
h.Set("Access-Control-Allow-Methods", allowedMethods)
h.Set("Access-Control-Allow-Headers", corsAllowedHeaders)
h.Set("Access-Control-Expose-Headers", corsExposedHeaders)
h.Add("Vary", "Origin")
}

// Intercept OPTIONS so preflight requests never reach the backend
// (which returns 405 because it has no OPTIONS handler).
// A matched origin gets the full preflight response; an unmatched
// origin gets 204 without CORS headers — the browser will reject
// the follow-up request, which is the correct security outcome.
if r.Method == http.MethodOptions {
if matched != "" {
w.Header().Set("Access-Control-Max-Age", corsMaxAge)
}
w.WriteHeader(http.StatusNoContent)
return
}

next.ServeHTTP(w, r)
})
}
}

// matchCORSOrigin returns the Access-Control-Allow-Origin value to send when
// requestOrigin matches an entry in allowed, or "" when there is no match.
//
// The returned value is the verbatim requestOrigin (a concrete origin) except
// when an allowed entry is "*", in which case "*" is returned directly.
func matchCORSOrigin(requestOrigin string, allowed []string) string {
if requestOrigin == "" {
return ""
}
for _, entry := range allowed {
switch {
case entry == "*":
return "*"
case entry == requestOrigin:
return requestOrigin
case strings.HasPrefix(requestOrigin, entry+":"):
// The trailing ":" boundary is load-bearing: it ensures the entry
// matches only "<entry>:<port>" and never a longer host. Without it,
// the entry "http://localhost" would also match
// "http://localhost.evil.com". See cors_test.go for the invariant.
return requestOrigin
}
}
return ""
}

// ValidateAndNormalizeOrigins validates configured CORS origins and returns a
// normalized copy. It surfaces misconfiguration at startup instead of letting an
// origin silently never match (which produces a broken browser experience with
// no signal):
//
// - "*" (wildcard) is passed through unchanged.
// - A trailing slash (e.g. "http://localhost:6274/") is stripped — a browser
// Origin header never carries one — with a warning, so the entry still matches.
// - An entry without a scheme (e.g. "localhost:6274") can never match a browser
// Origin header (always scheme://host[:port]) and is rejected with an error.
func ValidateAndNormalizeOrigins(origins []string) ([]string, error) {
normalized := make([]string, 0, len(origins))
for _, origin := range origins {
entry := strings.TrimSpace(origin)

if entry == "*" {
normalized = append(normalized, entry)
continue
}

if strings.HasSuffix(entry, "/") {
stripped := strings.TrimRight(entry, "/")
slog.Warn("CORS origin has a trailing slash that browsers never send; normalizing",
"origin", origin, "normalized", stripped)
entry = stripped
}

// A browser Origin is scheme://host[:port]; without a scheme the entry
// can never match an incoming Origin header.
if !strings.Contains(entry, "://") {
return nil, fmt.Errorf(
"invalid CORS origin %q: missing scheme (expected e.g. %q)", origin, "http://localhost:6274")
}

normalized = append(normalized, entry)
}
return normalized, nil
}
Loading
Loading