From 16adb697257c04a55fc926d9b772623a88afcb84 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sat, 30 May 2026 09:37:10 +0800 Subject: [PATCH] fix(cli): prevent crashes and edge-case bugs in auth flows - Clamp non-positive device-code poll intervals so a negative server value no longer panics time.NewTicker - Join the browser-flow timer goroutine before returning to avoid a send-on-closed-channel panic - Fix slow_down backoff to grow the poll interval 1.5x per step instead of compounding super-exponentially - Strip trailing slashes from the server URL to avoid double-slash default endpoints - Reject a JWT payload of null instead of reporting it as an empty-but-valid claim set - Trim extra-claims flag keys so they override file entries as documented - Bound the callback port and fix an off-by-one in the human duration formatter - Adopt all refreshed token fields wholesale instead of a partial copy - Deduplicate duration config resolution into a shared resolver - Add regression tests for the negative interval and key trimming Co-Authored-By: Claude Opus 4.8 (1M context) --- auth.go | 6 ++-- browser_flow.go | 12 +++++-- config.go | 75 ++++++++++++++++++++++++-------------------- device_flow.go | 19 +++++------ extra_claims_test.go | 16 ++++++++++ main.go | 15 +++++---- polling_test.go | 45 ++++++++++++++++++++++++++ token_cmd.go | 6 ++++ tui/styles.go | 2 +- 9 files changed, 140 insertions(+), 56 deletions(-) diff --git a/auth.go b/auth.go index 403b236..cb153ec 100644 --- a/auth.go +++ b/auth.go @@ -254,9 +254,9 @@ func makeAPICallWithAutoRefresh( return fmt.Errorf("refresh failed: %w", err) } - storage.AccessToken = newStorage.AccessToken - storage.RefreshToken = newStorage.RefreshToken - storage.ExpiresAt = newStorage.ExpiresAt + // Adopt every refreshed field (TokenType and ClientID too), rather than a + // partial copy that would silently drop any field added to the token. + *storage = *newStorage ui.ShowStatus(tui.StatusUpdate{Event: tui.EventTokenRefreshedRetrying}) diff --git a/browser_flow.go b/browser_flow.go index 659a1b1..8dc161f 100644 --- a/browser_flow.go +++ b/browser_flow.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/url" + "sync" "time" "github.com/go-authgate/cli/tui" @@ -126,11 +127,16 @@ func performBrowserFlowWithUpdates( }, } - // Start goroutine to send timer updates + // Start goroutine to send timer updates. Join it before returning (Wait runs + // after close(done) thanks to LIFO defer order) so the goroutine can never + // send on `updates` after the caller closes that channel — which would panic + // with "send on closed channel". done := make(chan struct{}) + var timerWG sync.WaitGroup + defer timerWG.Wait() defer close(done) - go func() { + timerWG.Go(func() { ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() startTime := time.Now() @@ -164,7 +170,7 @@ func performBrowserFlowWithUpdates( } } } - }() + }) storage, err := startCallbackServer(ctx, cfg.CallbackPort, state, cfg.CallbackTimeout, func(callbackCtx context.Context, code string) (*credstore.Token, error) { diff --git a/config.go b/config.go index 256b0d4..307c305 100644 --- a/config.go +++ b/config.go @@ -231,7 +231,7 @@ func loadConfig() *AppConfig { portStr = strconv.Itoa(flagCallbackPort) } portStr = getConfig(portStr, "CALLBACK_PORT", "8888") - if port, err := strconv.Atoi(portStr); err != nil || port <= 0 { + if port, err := strconv.Atoi(portStr); err != nil || port <= 0 || port > 65535 { cfg.CallbackPort = 8888 } else { cfg.CallbackPort = port @@ -246,6 +246,11 @@ func loadConfig() *AppConfig { os.Exit(1) } + // Drop trailing slashes so the fallback default endpoints (used when OIDC + // Discovery is unavailable) don't double the separator, e.g. + // "https://host/" + "/oauth/token" → "https://host//oauth/token". + cfg.ServerURL = strings.TrimRight(cfg.ServerURL, "/") + if strings.HasPrefix(strings.ToLower(cfg.ServerURL), "http://") { fmt.Fprintln( os.Stderr, @@ -402,11 +407,18 @@ func loadExtraClaimsFile(path string) (map[string]any, error) { } func parseExtraClaimPair(pair string) (string, any, error) { - idx := strings.IndexByte(pair, '=') - if idx <= 0 { + rawKey, rawVal, ok := strings.Cut(pair, "=") + if !ok { + return "", nil, errors.New("must be key=value with a non-empty key") + } + // Trim the key so a flag like "name =v" matches the godotenv-trimmed keys + // from --extra-claims-file, preserving the documented "flags override file + // entries" contract on conflicting keys. + key := strings.TrimSpace(rawKey) + if key == "" { return "", nil, errors.New("must be key=value with a non-empty key") } - return pair[:idx], parseClaimValue(pair[idx+1:]), nil + return key, parseClaimValue(rawVal), nil } // parseClaimValue tries to decode raw as JSON so users can write count=42 @@ -520,31 +532,12 @@ func resolveEndpoints(ctx context.Context, cfg *AppConfig) { cfg.Endpoints = meta.Endpoints() } -// getDurationConfig resolves a time.Duration from flag → env → default. -// The value is parsed with time.ParseDuration (e.g. "10s", "2m", "1m30s"). -// On parse error or non-positive value, it falls back to the default and prints a warning. +// getDurationConfig resolves a time.Duration from flag → env → default, capped +// at maxDurationConfig. The value is parsed with time.ParseDuration (e.g. "10s", +// "2m", "1m30s"). On parse error or non-positive value, it falls back to the +// default and prints a warning. func getDurationConfig(flagValue, envKey string, defaultValue time.Duration) time.Duration { - raw := getConfig(flagValue, envKey, "") - if raw == "" { - return defaultValue - } - d, err := time.ParseDuration(raw) - if err != nil { - fmt.Fprintf(os.Stderr, "WARNING: invalid duration %q for %s, using default %s\n", - raw, envKey, defaultValue) - return defaultValue - } - if d <= 0 { - fmt.Fprintf(os.Stderr, "WARNING: %s must be positive, got %s, using default %s\n", - envKey, d, defaultValue) - return defaultValue - } - if d > maxDurationConfig { - fmt.Fprintf(os.Stderr, "WARNING: %s exceeds maximum %s, capping at %s\n", - envKey, maxDurationConfig, maxDurationConfig) - return maxDurationConfig - } - return d + return parseDurationConfig(flagValue, envKey, defaultValue, maxDurationConfig) } // getRefreshThresholdConfig resolves the proactive-refresh threshold from @@ -554,21 +547,35 @@ func getDurationConfig(flagValue, envKey string, defaultValue time.Duration) tim // the CLI refresh sooner (it can't cause a hang), so users may legitimately // want 30m, 1h, etc. func getRefreshThresholdConfig(flagValue string) time.Duration { - const envKey = "REFRESH_THRESHOLD" + return parseDurationConfig(flagValue, "REFRESH_THRESHOLD", defaultRefreshThreshold, 0) +} + +// parseDurationConfig is the shared flag → env → default duration resolver. A +// maxValue of 0 means "no upper cap"; any positive maxValue caps the result and +// warns. On parse error or non-positive value it returns defaultValue. +func parseDurationConfig( + flagValue, envKey string, + defaultValue, maxValue time.Duration, +) time.Duration { raw := getConfig(flagValue, envKey, "") if raw == "" { - return defaultRefreshThreshold + return defaultValue } d, err := time.ParseDuration(raw) if err != nil { fmt.Fprintf(os.Stderr, "WARNING: invalid duration %q for %s, using default %s\n", - raw, envKey, defaultRefreshThreshold) - return defaultRefreshThreshold + raw, envKey, defaultValue) + return defaultValue } if d <= 0 { fmt.Fprintf(os.Stderr, "WARNING: %s must be positive, got %s, using default %s\n", - envKey, d, defaultRefreshThreshold) - return defaultRefreshThreshold + envKey, d, defaultValue) + return defaultValue + } + if maxValue > 0 && d > maxValue { + fmt.Fprintf(os.Stderr, "WARNING: %s exceeds maximum %s, capping at %s\n", + envKey, maxValue, maxValue) + return maxValue } return d } diff --git a/device_flow.go b/device_flow.go index 312657a..72e53b4 100644 --- a/device_flow.go +++ b/device_flow.go @@ -95,7 +95,6 @@ type pollErrorResult struct { func handleDevicePollError( err error, pollInterval *time.Duration, - backoffMultiplier *float64, pollTicker *time.Ticker, ) pollErrorResult { var oauthErr *oauth2.RetrieveError @@ -116,11 +115,12 @@ func handleDevicePollError( return pollErrorResult{action: pollContinue} case "slow_down": - *backoffMultiplier *= 1.5 - *pollInterval = min( - time.Duration(float64(*pollInterval)*(*backoffMultiplier)), - maxPollInterval, - ) + // RFC 8628 §3.5: lengthen the interval on each slow_down. Grow the + // current interval by 1.5x (capped at maxPollInterval) instead of + // compounding a separate multiplier, which ballooned the interval + // super-exponentially (base × 1.5^(1+2+…)) and hit the cap almost + // immediately. + *pollInterval = min(*pollInterval*3/2, maxPollInterval) pollTicker.Reset(*pollInterval) return pollErrorResult{action: pollBackoff} @@ -288,13 +288,14 @@ func pollForTokenWithUpdates( deviceAuth *oauth2.DeviceAuthResponse, updates chan<- tui.FlowUpdate, ) (*oauth2.Token, error) { + // Clamp non-positive server intervals (a missing, zero, or malicious + // negative value) to the RFC 8628 default — time.NewTicker panics on <= 0. interval := deviceAuth.Interval - if interval == 0 { + if interval <= 0 { interval = defaultPollInterval } pollInterval := time.Duration(interval) * time.Second - backoffMultiplier := 1.0 pollCount := 0 startTime := time.Now() @@ -330,7 +331,7 @@ func pollForTokenWithUpdates( ) if err != nil { oldInterval := pollInterval - result := handleDevicePollError(err, &pollInterval, &backoffMultiplier, pollTicker) + result := handleDevicePollError(err, &pollInterval, pollTicker) switch result.action { case pollContinue: continue diff --git a/extra_claims_test.go b/extra_claims_test.go index 17f0f16..f18f587 100644 --- a/extra_claims_test.go +++ b/extra_claims_test.go @@ -115,6 +115,22 @@ func TestParseExtraClaimPair(t *testing.T) { t.Fatalf("got (%q, %v, %v)", k, v, err) } }) + + t.Run("surrounding key whitespace trimmed", func(t *testing.T) { + // Matches godotenv key-trimming on the --extra-claims-file path so a + // flag with an incidental space still overrides the file's key. + k, v, err := parseExtraClaimPair(" project =acme") + if err != nil || k != "project" || v != "acme" { + t.Fatalf("got (%q, %v, %v)", k, v, err) + } + }) + + t.Run("whitespace-only key rejected", func(t *testing.T) { + _, _, err := parseExtraClaimPair(" =value") + if err == nil { + t.Fatal("expected error for whitespace-only key") + } + }) } // ----------------------------------------------------------------------- diff --git a/main.go b/main.go index 008b94f..086b038 100644 --- a/main.go +++ b/main.go @@ -86,13 +86,18 @@ func run(ctx context.Context, ui tui.Manager, cfg *AppConfig) int { } if err == nil { ui.ShowStatus(tui.StatusUpdate{Event: tui.EventExistingTokens}) + // Reuse the already-loaded token as-is. Shared by the still-valid, + // no-refresh-token, and refresh-failed branches below. + useCached := func() { + storage = &existing + flow = "cached" + } // Capture now once so the refresh decision and the graceful-degradation // check below reason about the same instant. now := time.Now() if !needsRefresh(existing, cfg.RefreshThreshold, now) { ui.ShowStatus(tui.StatusUpdate{Event: tui.EventTokenStillValid}) - storage = &existing - flow = "cached" + useCached() } else { // Whether the old token is still usable as-is — shared with // ensureFreshToken via tokenUsable so a corrupt token (empty access @@ -103,8 +108,7 @@ func run(ctx context.Context, ui tui.Manager, cfg *AppConfig) int { // network call entirely (no refresh step is shown). Reuse the old // token while it's still valid; otherwise fall through to re-auth. if reuseValid { - storage = &existing - flow = "cached" + useCached() } } else { // About to refresh: show an accurate in-progress status — @@ -121,8 +125,7 @@ func run(ctx context.Context, ui tui.Manager, cfg *AppConfig) int { // fall through to re-authentication once expired. ui.ShowStatus(tui.StatusUpdate{Event: tui.EventRefreshFailed, Err: refreshErr}) if reuseValid { - storage = &existing - flow = "cached" + useCached() } } else { storage = newStorage diff --git a/polling_test.go b/polling_test.go index 4f3fa50..81ac5a9 100644 --- a/polling_test.go +++ b/polling_test.go @@ -262,6 +262,51 @@ func TestPollForToken_ContextTimeout(t *testing.T) { } } +func TestPollForToken_NegativeInterval(t *testing.T) { + // A malicious or buggy server may return a negative poll interval. It must + // be clamped to the default rather than reaching time.NewTicker (which + // panics on a non-positive duration and would crash the whole CLI). + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(map[string]string{ + "error": "authorization_pending", + "error_description": "User has not yet authorized", + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + })) + defer server.Close() + + cfg := testConfig(t) + + config := &oauth2.Config{ + ClientID: "test-client", + Endpoint: oauth2.Endpoint{TokenURL: server.URL}, + } + deviceAuth := &oauth2.DeviceAuthResponse{ + DeviceCode: "test-device-code", + Interval: -1, // negative — must be clamped, not passed to NewTicker + } + + // Short context: with the interval clamped to the 5s default, the first poll + // never fires before this deadline, so the call returns via ctx instead of + // waiting. The assertion is simply that it returns rather than panicking. + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + updates := make(chan tui.FlowUpdate, 100) + drainUpdates(t, updates) + + _, err := pollForTokenWithUpdates(ctx, cfg, config, deviceAuth, updates) + if err == nil { + t.Fatal("expected context error, got nil") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected context.DeadlineExceeded, got: %v", err) + } +} + func TestExchangeDeviceCode_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { diff --git a/token_cmd.go b/token_cmd.go index 31049c9..52649e5 100644 --- a/token_cmd.go +++ b/token_cmd.go @@ -422,6 +422,12 @@ func parseJWTPayload(token string) (map[string]any, error) { if err := dec.Decode(&claims); err != nil { return nil, fmt.Errorf("parse claims: %w", err) } + // A payload of the JSON literal `null` unmarshals into a nil map with no + // error; reject it so a malformed token isn't reported as an empty-but-valid + // claim set. + if claims == nil { + return nil, errors.New("parse claims: payload is not a JSON object") + } // Reject trailing bytes (or a second concatenated value) so a malformed // payload can't masquerade as valid by hiding extra data after the object. var sink struct{} diff --git a/tui/styles.go b/tui/styles.go index 2b7a0ce..811304d 100644 --- a/tui/styles.go +++ b/tui/styles.go @@ -31,7 +31,7 @@ func FormatDurationHuman(d time.Duration) string { hours := int(d.Hours()) minutes := int(d.Minutes()) % 60 - if hours > 24 { + if hours >= 24 { days := hours / 24 hours %= 24 return fmt.Sprintf("%dd %dh", days, hours)