-
Notifications
You must be signed in to change notification settings - Fork 239
Add CORS support to the transparent MCP proxy #5588
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -123,6 +123,9 @@ | |
| // Header forwarding flags | ||
| remoteForwardHeaders []string | ||
| remoteForwardHeadersSecret []string | ||
|
|
||
| // CORS flags | ||
| proxyCORSOrigins []string | ||
| ) | ||
|
|
||
| // Environment variable names | ||
|
|
@@ -157,6 +160,19 @@ | |
| proxyCmd.Flags().StringArrayVar(&remoteForwardHeadersSecret, "remote-forward-headers-secret", []string{}, | ||
| "Headers with secret values from ToolHive secrets manager (format: Name=secret-name, can be repeated)") | ||
|
|
||
| // CORS — disabled by default; opt in explicitly to avoid widening the attack surface | ||
| proxyCmd.Flags().StringArrayVar(&proxyCORSOrigins, "allow-origins", []string{}, | ||
| `Allowed CORS origins for the MCP proxy endpoint (repeatable). | ||
| CORS is disabled by default; if you handle CORS at a gateway or reverse proxy, | ||
| leaving this unset is the correct, secure choice. Each origin must include a | ||
| scheme (e.g. http://) and no trailing slash, otherwise it can never match a | ||
| browser request. | ||
| Supported forms: | ||
| exact: http://localhost:6274 | ||
| scheme+host: http://localhost (matches any port on localhost) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The "disabled by default" note is in the code comment above, not the help string an operator sees in
Also watch the |
||
| wildcard: * (allow all — use with caution) | ||
| Example: --allow-origins http://localhost:6274`) | ||
|
|
||
| // Mark target-uri as required | ||
| if err := proxyCmd.MarkFlagRequired("target-uri"); err != nil { | ||
| slog.Warn(fmt.Sprintf("Failed to mark flag as required: %v", err)) | ||
|
|
@@ -167,7 +183,7 @@ | |
|
|
||
| } | ||
|
|
||
| func proxyCmdFunc(cmd *cobra.Command, args []string) error { | ||
| ctx, stopSignal := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) | ||
| defer stopSignal() | ||
| // Get the server name | ||
|
|
@@ -264,8 +280,20 @@ | |
| slog.Debug(fmt.Sprintf("Setting up transparent proxy to forward from host port %d to %s", | ||
| port, proxyTargetURI)) | ||
|
|
||
| // Build optional functional options (e.g. CORS), only when configured. | ||
| var proxyOptions []transparent.Option | ||
| if len(proxyCORSOrigins) > 0 { | ||
| // Validate origins at startup so a misconfigured entry (missing scheme, | ||
| // trailing slash) fails loudly instead of silently never matching. | ||
| corsOrigins, err := middleware.ValidateAndNormalizeOrigins(proxyCORSOrigins) | ||
| if err != nil { | ||
| return fmt.Errorf("invalid --allow-origins: %w", err) | ||
| } | ||
| proxyOptions = append(proxyOptions, transparent.WithAllowedOrigins(corsOrigins)) | ||
| } | ||
|
|
||
| // Create the transparent proxy with middlewares | ||
| proxy := transparent.NewTransparentProxy( | ||
| proxy := transparent.NewTransparentProxyWithOptions( | ||
| proxyHost, | ||
| port, | ||
| proxyTargetURI, | ||
|
|
@@ -279,7 +307,8 @@ | |
| nil, // onUnauthorizedResponse - not needed for local proxies | ||
| "", // endpointPrefix - not configured for proxy command | ||
| false, // trustProxyHeaders - not configured for proxy command | ||
| middlewares...) | ||
| middlewares, | ||
| proxyOptions...) | ||
| if err := proxy.Start(ctx); err != nil { | ||
| return fmt.Errorf("failed to start proxy: %w", err) | ||
| } | ||
|
|
||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| // SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| package middleware | ||
|
|
||
| import ( | ||
| "fmt" | ||
| "log/slog" | ||
| "net/http" | ||
| "strings" | ||
|
|
||
| "github.com/stacklok/toolhive/pkg/transport/types" | ||
| ) | ||
|
|
||
| const ( | ||
| // defaultCORSAllowedMethods is the fallback preflight method list used when | ||
| // the caller does not supply an explicit set. The proxy passes a set derived | ||
| // from the server's actual capabilities (see the stateless/stateful method | ||
| // sources of truth in the transparent proxy) so the preflight never | ||
| // advertises a method the backend will reject. | ||
| defaultCORSAllowedMethods = "GET, POST, DELETE, OPTIONS" | ||
|
|
||
| // corsAllowedHeaders lists request headers MCP clients may send. MCP-Protocol-Version | ||
| // must be allow-listed: ToolHive reads and validates it on the request path | ||
| // (an unsupported value yields 400), so a browser MCP client cannot send it | ||
| // through CORS unless it is listed here. | ||
| corsAllowedHeaders = "Content-Type, Accept, Mcp-Session-Id, MCP-Protocol-Version, Authorization" | ||
|
|
||
| // corsExposedHeaders lists response headers that browsers may read. MCP-Protocol-Version | ||
| // is exposed so a browser client can read the negotiated protocol version back. | ||
| // Content-Type is omitted because it is already a CORS-safelisted response | ||
| // header and does not need to be exposed explicitly. | ||
| corsExposedHeaders = "Mcp-Session-Id, MCP-Protocol-Version" | ||
|
|
||
| // corsMaxAge is the preflight cache lifetime in seconds (24 hours). | ||
| corsMaxAge = "86400" | ||
| ) | ||
|
|
||
| // CORS returns a middleware that handles CORS preflight (OPTIONS) requests and | ||
| // injects Access-Control-Allow-* response headers. When allowedOrigins is empty | ||
| // the middleware is a no-op, preserving the default security posture. | ||
| // | ||
| // Origin matching rules (applied in order): | ||
| // - "*": matches every origin; Access-Control-Allow-Origin is set to "*". | ||
| // - Exact: "http://localhost:6274" matches only that origin. | ||
| // - Scheme+host prefix: "http://localhost" also matches any | ||
| // "http://localhost:<port>" (e.g. the MCP Inspector default port). | ||
| // | ||
| // All OPTIONS requests are handled directly (returning 204) when this middleware | ||
| // is active so that CORS preflights never reach the backend, which previously | ||
| // returned 405 Method Not Allowed. | ||
| // | ||
| // allowedMethods is the value advertised in Access-Control-Allow-Methods. It | ||
| // should reflect the methods the backend actually accepts so a preflight never | ||
| // succeeds for a method the real request would reject. When empty, | ||
| // defaultCORSAllowedMethods is used. | ||
| func CORS(allowedOrigins []string, allowedMethods string) types.MiddlewareFunction { | ||
| if len(allowedOrigins) == 0 { | ||
| return func(next http.Handler) http.Handler { return next } | ||
| } | ||
|
|
||
| if allowedMethods == "" { | ||
| allowedMethods = defaultCORSAllowedMethods | ||
| } | ||
|
|
||
| slog.Debug("CORS middleware configured", | ||
| "allowed_origins", strings.Join(allowedOrigins, ", "), "allowed_methods", allowedMethods) | ||
|
|
||
| return func(next http.Handler) http.Handler { | ||
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| origin := r.Header.Get("Origin") | ||
| matched := matchCORSOrigin(origin, allowedOrigins) | ||
|
|
||
| if matched != "" { | ||
| h := w.Header() | ||
| h.Set("Access-Control-Allow-Origin", matched) | ||
| h.Set("Access-Control-Allow-Methods", allowedMethods) | ||
| h.Set("Access-Control-Allow-Headers", corsAllowedHeaders) | ||
| h.Set("Access-Control-Expose-Headers", corsExposedHeaders) | ||
| h.Add("Vary", "Origin") | ||
| } | ||
|
|
||
| // Intercept OPTIONS so preflight requests never reach the backend | ||
| // (which returns 405 because it has no OPTIONS handler). | ||
| // A matched origin gets the full preflight response; an unmatched | ||
| // origin gets 204 without CORS headers — the browser will reject | ||
| // the follow-up request, which is the correct security outcome. | ||
| if r.Method == http.MethodOptions { | ||
| if matched != "" { | ||
| w.Header().Set("Access-Control-Max-Age", corsMaxAge) | ||
| } | ||
| w.WriteHeader(http.StatusNoContent) | ||
| return | ||
| } | ||
|
|
||
| next.ServeHTTP(w, r) | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| // matchCORSOrigin returns the Access-Control-Allow-Origin value to send when | ||
| // requestOrigin matches an entry in allowed, or "" when there is no match. | ||
| // | ||
| // The returned value is the verbatim requestOrigin (a concrete origin) except | ||
| // when an allowed entry is "*", in which case "*" is returned directly. | ||
| func matchCORSOrigin(requestOrigin string, allowed []string) string { | ||
| if requestOrigin == "" { | ||
| return "" | ||
| } | ||
| for _, entry := range allowed { | ||
| switch { | ||
| case entry == "*": | ||
| return "*" | ||
| case entry == requestOrigin: | ||
| return requestOrigin | ||
| case strings.HasPrefix(requestOrigin, entry+":"): | ||
| // The trailing ":" boundary is load-bearing: it ensures the entry | ||
| // matches only "<entry>:<port>" and never a longer host. Without it, | ||
| // the entry "http://localhost" would also match | ||
| // "http://localhost.evil.com". See cors_test.go for the invariant. | ||
| return requestOrigin | ||
| } | ||
| } | ||
| return "" | ||
| } | ||
|
|
||
| // ValidateAndNormalizeOrigins validates configured CORS origins and returns a | ||
| // normalized copy. It surfaces misconfiguration at startup instead of letting an | ||
| // origin silently never match (which produces a broken browser experience with | ||
| // no signal): | ||
| // | ||
| // - "*" (wildcard) is passed through unchanged. | ||
| // - A trailing slash (e.g. "http://localhost:6274/") is stripped — a browser | ||
| // Origin header never carries one — with a warning, so the entry still matches. | ||
| // - An entry without a scheme (e.g. "localhost:6274") can never match a browser | ||
| // Origin header (always scheme://host[:port]) and is rejected with an error. | ||
| func ValidateAndNormalizeOrigins(origins []string) ([]string, error) { | ||
| normalized := make([]string, 0, len(origins)) | ||
| for _, origin := range origins { | ||
| entry := strings.TrimSpace(origin) | ||
|
|
||
| if entry == "*" { | ||
| normalized = append(normalized, entry) | ||
| continue | ||
| } | ||
|
|
||
| if strings.HasSuffix(entry, "/") { | ||
| stripped := strings.TrimRight(entry, "/") | ||
| slog.Warn("CORS origin has a trailing slash that browsers never send; normalizing", | ||
| "origin", origin, "normalized", stripped) | ||
| entry = stripped | ||
| } | ||
|
|
||
| // A browser Origin is scheme://host[:port]; without a scheme the entry | ||
| // can never match an incoming Origin header. | ||
| if !strings.Contains(entry, "://") { | ||
| return nil, fmt.Errorf( | ||
| "invalid CORS origin %q: missing scheme (expected e.g. %q)", origin, "http://localhost:6274") | ||
| } | ||
|
|
||
| normalized = append(normalized, entry) | ||
| } | ||
| return normalized, nil | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing to do before CI goes green: run
task docsfrom the repo root to regenerate the CLI reference docs. Adding--allow-originsmeans the auto-generated content underdocs/is out of date, and the docs check will fail without it.