Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cli/serve/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ func startServerWithConfigPath(ctx context.Context, run *command.Context, cfg co
Upgrade: upgradeManager,
ActivityDecider: channelActivityDecider(codexBridgeMgr),
ConfigPath: configPath,
ServerConfig: cfg.Server,
AccessToken: cfg.Server.AccessToken,
NoAuth: cfg.Server.NoAuth,
Context: ctx,
Expand Down
74 changes: 55 additions & 19 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"

"csgclaw/internal/auth"
"csgclaw/internal/config"
)

const authCallbackPath = "/api/v1/auth/callback"
Expand All @@ -35,15 +35,15 @@ var appAuthLogout = func(r *http.Request) (auth.Status, error) {
return auth.Default().Logout(r.Context())
}

var appAuthCallback = func(r *http.Request) (string, error) {
var appAuthCallback = func(r *http.Request, opts auth.CallbackOptions) (string, error) {
values := r.URL.Query()
if values.Get("jwt_token") == "" {
if token := bearerToken(r.Header.Get("Authorization")); token != "" {
values = cloneURLValues(values)
values.Set("jwt_token", token)
}
}
return auth.Default().CompleteCallback(r.Context(), values)
return auth.Default().CompleteCallback(r.Context(), values, opts)
}

func (h *Handler) handleAuthStatus(w http.ResponseWriter, r *http.Request) {
Expand All @@ -64,7 +64,9 @@ func (h *Handler) handleAuthCallback(w http.ResponseWriter, r *http.Request) {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
redirectURL, err := appAuthCallback(r)
redirectURL, err := appAuthCallback(r, auth.CallbackOptions{
AllowedReturnURLBase: h.authCallbackURL(r),
})
if err != nil {
status := http.StatusBadRequest
if !auth.IsCallbackValidationError(err) {
Expand Down Expand Up @@ -94,7 +96,7 @@ func (h *Handler) handleAuthLogin(w http.ResponseWriter, r *http.Request) {
req.ReturnURL = r.Referer()
}
if req.CallbackURL == "" {
req.CallbackURL = authLocalCallbackURL(r)
req.CallbackURL = h.authCallbackURL(r, req.ReturnURL)
}
resp, err := appAuthLogin(r, req)
if err != nil {
Expand All @@ -117,15 +119,33 @@ func (h *Handler) handleAuthLogout(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, status)
}

func authLocalCallbackURL(r *http.Request) string {
func (h *Handler) authCallbackURL(r *http.Request, pageURLs ...string) string {
if h != nil && strings.TrimSpace(h.server.AdvertiseBaseURL) != "" {
return authCallbackURLFromBase(config.ResolveAdvertiseBaseURL(h.server))
}
for _, pageURL := range pageURLs {
if callbackURL := authCallbackURLFromPageURL(pageURL); callbackURL != "" {
return callbackURL
}
}
if callbackURL := authRequestCallbackURL(r); callbackURL != "" {
return callbackURL
}
if h != nil && (strings.TrimSpace(h.server.ListenAddr) != "" || strings.TrimSpace(h.server.AdvertiseBaseURL) != "") {
return authCallbackURLFromBase(config.ResolveAdvertiseBaseURL(h.server))
}
return ""
}

func authRequestCallbackURL(r *http.Request) string {
if r == nil {
return ""
}
host := strings.TrimSpace(r.Host)
if host == "" && r.URL != nil {
host = strings.TrimSpace(r.URL.Host)
}
if !isLocalRequestHost(host) {
if host == "" {
return ""
}
scheme := "http"
Expand All @@ -140,21 +160,37 @@ func authLocalCallbackURL(r *http.Request) string {
return u.String()
}

func isLocalRequestHost(hostport string) bool {
hostport = strings.TrimSpace(hostport)
if hostport == "" {
return false
func authCallbackURLFromPageURL(pageURL string) string {
u, err := url.Parse(strings.TrimSpace(pageURL))
if err != nil || u.Scheme == "" || u.Host == "" {
return ""
}
host, _, err := net.SplitHostPort(hostport)
if err != nil {
host = hostport
scheme := strings.ToLower(u.Scheme)
if scheme != "http" && scheme != "https" {
return ""
}
switch strings.ToLower(strings.Trim(host, "[]")) {
case "127.0.0.1", "localhost", "::1":
return true
default:
return false
u.RawQuery = ""
u.Fragment = ""
return authCallbackURLFromBase(u.String())
}

func authCallbackURLFromBase(baseURL string) string {
baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
if baseURL == "" {
return ""
}
u, err := url.Parse(baseURL)
if err != nil || u.Scheme == "" || u.Host == "" {
return ""
}
scheme := strings.ToLower(u.Scheme)
if scheme != "http" && scheme != "https" {
return ""
}
u.Path = strings.TrimRight(u.Path, "/") + authCallbackPath
u.RawQuery = ""
u.Fragment = ""
return u.String()
}

func bearerToken(authHeader string) string {
Expand Down
63 changes: 61 additions & 2 deletions internal/api/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"csgclaw/internal/auth"
"csgclaw/internal/config"
)

func TestHandleAuthStatus(t *testing.T) {
Expand Down Expand Up @@ -72,16 +73,74 @@ func TestHandleAuthLogin(t *testing.T) {
}
}

func TestHandleAuthLoginUsesReturnURLOrigin(t *testing.T) {
var gotCallbackURL string
restore := stubAuthLogin(func(_ *http.Request, req authLoginRequest) (auth.LoginResponse, error) {
gotCallbackURL = req.CallbackURL
return auth.LoginResponse{LoginURL: "https://iam.example.test/login"}, nil
})
defer restore()

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(`{"return_url":"https://current.example.test/#/workspace"}`))
req.Host = "fallback.example.test"
(&Handler{}).Routes().ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
if gotCallbackURL != "https://current.example.test/api/v1/auth/callback" {
t.Fatalf("callback_url = %q", gotCallbackURL)
}
}

func TestHandleAuthLoginUsesAdvertiseBaseURL(t *testing.T) {
var gotCallbackURL string
restore := stubAuthLogin(func(_ *http.Request, req authLoginRequest) (auth.LoginResponse, error) {
gotCallbackURL = req.CallbackURL
return auth.LoginResponse{LoginURL: "https://iam.example.test/login"}, nil
})
defer restore()

srv := &Handler{}
srv.SetServerConfig(config.ServerConfig{
AdvertiseBaseURL: "https://csgclaw.example.test/base/",
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", strings.NewReader(`{"return_url":"https://csgclaw.example.test/base/#/workspace"}`))
req.Host = "evil.example.test"
srv.Routes().ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
if gotCallbackURL != "https://csgclaw.example.test/base/api/v1/auth/callback" {
t.Fatalf("callback_url = %q", gotCallbackURL)
}
}

func TestAuthCallbackURLUsesRequestHostWhenUnconfigured(t *testing.T) {
srv := &Handler{}
srv.SetServerConfig(config.ServerConfig{ListenAddr: "0.0.0.0:18080"})
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil)
req.Host = "evil.example.test"
if got := srv.authCallbackURL(req); got != "http://evil.example.test/api/v1/auth/callback" {
t.Fatalf("authCallbackURL() = %q", got)
}
}

func TestHandleAuthCallback(t *testing.T) {
var gotToken string
restore := stubAuthCallback(func(r *http.Request) (string, error) {
restore := stubAuthCallback(func(r *http.Request, opts auth.CallbackOptions) (string, error) {
gotToken = r.URL.Query().Get("jwt_token")
if opts.AllowedReturnURLBase != "http://127.0.0.1:18080/api/v1/auth/callback" {
t.Fatalf("AllowedReturnURLBase = %q", opts.AllowedReturnURLBase)
}
return "http://127.0.0.1:18080/#/workspace", nil
})
defer restore()

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/callback?jwt_token=jwt-value", nil)
req.Host = "127.0.0.1:18080"
(&Handler{}).Routes().ServeHTTP(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusFound, rec.Body.String())
Expand Down Expand Up @@ -126,7 +185,7 @@ func stubAuthLogin(fn func(*http.Request, authLoginRequest) (auth.LoginResponse,
return func() { appAuthLogin = previous }
}

func stubAuthCallback(fn func(*http.Request) (string, error)) func() {
func stubAuthCallback(fn func(*http.Request, auth.CallbackOptions) (string, error)) func() {
previous := appAuthCallback
appAuthCallback = fn
return func() { appAuthCallback = previous }
Expand Down
7 changes: 7 additions & 0 deletions internal/api/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type Handler struct {
activityDecider ActivityDecider
localDirectoryPicker func(context.Context) (string, error)
feishuRegistrationStateDir string
server config.ServerConfig

participantActivityTurnsMu sync.Mutex
participantActivityTurns map[string]participantActivityTurn
Expand Down Expand Up @@ -610,6 +611,12 @@ func (h *Handler) SetConfigPath(path string) {
}
}

func (h *Handler) SetServerConfig(server config.ServerConfig) {
if h != nil {
h.server = server
}
}

func (h *Handler) validateServerAccessToken(authHeader string) bool {
if h.serverNoAuth {
return true
Expand Down
44 changes: 26 additions & 18 deletions internal/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ type LoginOptions struct {
CallbackURL string
}

type CallbackOptions struct {
AllowedReturnURLBase string
}

type Service struct {
Store Store
HTTPClient *http.Client
Expand All @@ -51,8 +55,9 @@ func (s *Service) Login(_ context.Context, opts ...LoginOptions) (LoginResponse,
returnURL := ""
callbackURL := ""
if len(opts) > 0 {
returnURL = sanitizeReturnURL(opts[0].ReturnURL)
callbackURL = callbackURLWithReturnURL(sanitizeCallbackURL(opts[0].CallbackURL), returnURL)
callbackURL = sanitizeCallbackURL(opts[0].CallbackURL)
returnURL = sanitizeReturnURL(opts[0].ReturnURL, callbackURL)
callbackURL = callbackURLWithReturnURL(callbackURL, returnURL)
}
if callbackURL == "" {
return LoginResponse{}, fmt.Errorf("auth callback url is required")
Expand All @@ -71,11 +76,11 @@ func (s *Service) Logout(context.Context) (Status, error) {
return Status{}, nil
}

func (s *Service) CompleteCallback(ctx context.Context, values url.Values) (string, error) {
return s.completeCallback(ctx, values)
func (s *Service) CompleteCallback(ctx context.Context, values url.Values, opts ...CallbackOptions) (string, error) {
return s.completeCallback(ctx, values, opts...)
}

func (s *Service) completeCallback(ctx context.Context, values url.Values) (string, error) {
func (s *Service) completeCallback(ctx context.Context, values url.Values, opts ...CallbackOptions) (string, error) {
jwtToken := strings.TrimSpace(values.Get("jwt_token"))
if jwtToken == "" {
jwtToken = strings.TrimSpace(values.Get("jwt"))
Expand Down Expand Up @@ -140,7 +145,11 @@ func (s *Service) completeCallback(ctx context.Context, values url.Values) (stri
}
}

if returnURL := callbackReturnURL(values); returnURL != "" {
allowedReturnURLBase := ""
if len(opts) > 0 {
allowedReturnURLBase = opts[0].AllowedReturnURLBase
}
if returnURL := callbackReturnURL(values, allowedReturnURLBase); returnURL != "" {
return returnURL, nil
}
if portalURL == "" {
Expand Down Expand Up @@ -346,7 +355,7 @@ func joinAPIPath(baseURL, apiPath string) (*url.URL, error) {
return base.ResolveReference(ref), nil
}

func sanitizeReturnURL(raw string) string {
func sanitizeReturnURL(raw, allowedBase string) string {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil || u.Scheme == "" || u.Host == "" {
return ""
Expand All @@ -355,7 +364,7 @@ func sanitizeReturnURL(raw string) string {
if scheme != "http" && scheme != "https" {
return ""
}
if isLocalHostname(u.Hostname()) {
if sameOrigin(u, allowedBase) {
return u.String()
}
return ""
Expand All @@ -370,9 +379,6 @@ func sanitizeCallbackURL(raw string) string {
if scheme != "http" && scheme != "https" {
return ""
}
if !isLocalHostname(u.Hostname()) {
return ""
}
return u.String()
}

Expand All @@ -390,22 +396,24 @@ func callbackURLWithReturnURL(callbackURL, returnURL string) string {
return u.String()
}

func callbackReturnURL(values url.Values) string {
func callbackReturnURL(values url.Values, allowedBase string) string {
for _, key := range []string{"return_url", "url"} {
if returnURL := sanitizeReturnURL(values.Get(key)); returnURL != "" {
if returnURL := sanitizeReturnURL(values.Get(key), allowedBase); returnURL != "" {
return returnURL
}
}
return ""
}

func isLocalHostname(hostname string) bool {
switch strings.ToLower(strings.Trim(hostname, "[]")) {
case "127.0.0.1", "localhost", "::1":
return true
default:
func sameOrigin(u *url.URL, allowedBase string) bool {
if u == nil {
return false
}
base, err := url.Parse(strings.TrimSpace(allowedBase))
if err != nil || base.Scheme == "" || base.Host == "" {
return false
}
return strings.EqualFold(u.Scheme, base.Scheme) && strings.EqualFold(u.Host, base.Host)
}

type callbackValidationError string
Expand Down
Loading
Loading