Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ func makeAPICallWithAutoRefresh(
return fmt.Errorf("refresh failed: %w", err)
}

storage.AccessToken = newStorage.AccessToken
storage.RefreshToken = newStorage.RefreshToken
storage.ExpiresAt = newStorage.ExpiresAt
// Adopt every refreshed field (TokenType and ClientID too), rather than a
// partial copy that would silently drop any field added to the token.
*storage = *newStorage

ui.ShowStatus(tui.StatusUpdate{Event: tui.EventTokenRefreshedRetrying})

Expand Down
12 changes: 9 additions & 3 deletions browser_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/url"
"sync"
"time"

"github.com/go-authgate/cli/tui"
Expand Down Expand Up @@ -126,11 +127,16 @@ func performBrowserFlowWithUpdates(
},
}

// Start goroutine to send timer updates
// Start goroutine to send timer updates. Join it before returning (Wait runs
// after close(done) thanks to LIFO defer order) so the goroutine can never
// send on `updates` after the caller closes that channel — which would panic
// with "send on closed channel".
done := make(chan struct{})
var timerWG sync.WaitGroup
defer timerWG.Wait()
defer close(done)

go func() {
timerWG.Go(func() {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
startTime := time.Now()
Expand Down Expand Up @@ -164,7 +170,7 @@ func performBrowserFlowWithUpdates(
}
}
}
}()
})

storage, err := startCallbackServer(ctx, cfg.CallbackPort, state, cfg.CallbackTimeout,
func(callbackCtx context.Context, code string) (*credstore.Token, error) {
Expand Down
75 changes: 41 additions & 34 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func loadConfig() *AppConfig {
portStr = strconv.Itoa(flagCallbackPort)
}
portStr = getConfig(portStr, "CALLBACK_PORT", "8888")
if port, err := strconv.Atoi(portStr); err != nil || port <= 0 {
if port, err := strconv.Atoi(portStr); err != nil || port <= 0 || port > 65535 {
cfg.CallbackPort = 8888
} else {
cfg.CallbackPort = port
Expand All @@ -246,6 +246,11 @@ func loadConfig() *AppConfig {
os.Exit(1)
}

// Drop trailing slashes so the fallback default endpoints (used when OIDC
// Discovery is unavailable) don't double the separator, e.g.
// "https://host/" + "/oauth/token" → "https://host//oauth/token".
cfg.ServerURL = strings.TrimRight(cfg.ServerURL, "/")

if strings.HasPrefix(strings.ToLower(cfg.ServerURL), "http://") {
fmt.Fprintln(
os.Stderr,
Expand Down Expand Up @@ -402,11 +407,18 @@ func loadExtraClaimsFile(path string) (map[string]any, error) {
}

func parseExtraClaimPair(pair string) (string, any, error) {
idx := strings.IndexByte(pair, '=')
if idx <= 0 {
rawKey, rawVal, ok := strings.Cut(pair, "=")
if !ok {
return "", nil, errors.New("must be key=value with a non-empty key")
}
// Trim the key so a flag like "name =v" matches the godotenv-trimmed keys
// from --extra-claims-file, preserving the documented "flags override file
// entries" contract on conflicting keys.
key := strings.TrimSpace(rawKey)
if key == "" {
return "", nil, errors.New("must be key=value with a non-empty key")
}
return pair[:idx], parseClaimValue(pair[idx+1:]), nil
return key, parseClaimValue(rawVal), nil
}

// parseClaimValue tries to decode raw as JSON so users can write count=42
Expand Down Expand Up @@ -520,31 +532,12 @@ func resolveEndpoints(ctx context.Context, cfg *AppConfig) {
cfg.Endpoints = meta.Endpoints()
}

// getDurationConfig resolves a time.Duration from flag → env → default.
// The value is parsed with time.ParseDuration (e.g. "10s", "2m", "1m30s").
// On parse error or non-positive value, it falls back to the default and prints a warning.
// getDurationConfig resolves a time.Duration from flag → env → default, capped
// at maxDurationConfig. The value is parsed with time.ParseDuration (e.g. "10s",
// "2m", "1m30s"). On parse error or non-positive value, it falls back to the
// default and prints a warning.
func getDurationConfig(flagValue, envKey string, defaultValue time.Duration) time.Duration {
raw := getConfig(flagValue, envKey, "")
if raw == "" {
return defaultValue
}
d, err := time.ParseDuration(raw)
if err != nil {
fmt.Fprintf(os.Stderr, "WARNING: invalid duration %q for %s, using default %s\n",
raw, envKey, defaultValue)
return defaultValue
}
if d <= 0 {
fmt.Fprintf(os.Stderr, "WARNING: %s must be positive, got %s, using default %s\n",
envKey, d, defaultValue)
return defaultValue
}
if d > maxDurationConfig {
fmt.Fprintf(os.Stderr, "WARNING: %s exceeds maximum %s, capping at %s\n",
envKey, maxDurationConfig, maxDurationConfig)
return maxDurationConfig
}
return d
return parseDurationConfig(flagValue, envKey, defaultValue, maxDurationConfig)
}

// getRefreshThresholdConfig resolves the proactive-refresh threshold from
Expand All @@ -554,21 +547,35 @@ func getDurationConfig(flagValue, envKey string, defaultValue time.Duration) tim
// the CLI refresh sooner (it can't cause a hang), so users may legitimately
// want 30m, 1h, etc.
func getRefreshThresholdConfig(flagValue string) time.Duration {
const envKey = "REFRESH_THRESHOLD"
return parseDurationConfig(flagValue, "REFRESH_THRESHOLD", defaultRefreshThreshold, 0)
}

// parseDurationConfig is the shared flag → env → default duration resolver. A
// maxValue of 0 means "no upper cap"; any positive maxValue caps the result and
// warns. On parse error or non-positive value it returns defaultValue.
func parseDurationConfig(
flagValue, envKey string,
defaultValue, maxValue time.Duration,
) time.Duration {
raw := getConfig(flagValue, envKey, "")
if raw == "" {
return defaultRefreshThreshold
return defaultValue
}
d, err := time.ParseDuration(raw)
if err != nil {
fmt.Fprintf(os.Stderr, "WARNING: invalid duration %q for %s, using default %s\n",
raw, envKey, defaultRefreshThreshold)
return defaultRefreshThreshold
raw, envKey, defaultValue)
return defaultValue
}
if d <= 0 {
fmt.Fprintf(os.Stderr, "WARNING: %s must be positive, got %s, using default %s\n",
envKey, d, defaultRefreshThreshold)
return defaultRefreshThreshold
envKey, d, defaultValue)
return defaultValue
}
if maxValue > 0 && d > maxValue {
fmt.Fprintf(os.Stderr, "WARNING: %s exceeds maximum %s, capping at %s\n",
envKey, maxValue, maxValue)
return maxValue
}
return d
}
Expand Down
19 changes: 10 additions & 9 deletions device_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ type pollErrorResult struct {
func handleDevicePollError(
err error,
pollInterval *time.Duration,
backoffMultiplier *float64,
pollTicker *time.Ticker,
) pollErrorResult {
var oauthErr *oauth2.RetrieveError
Expand All @@ -116,11 +115,12 @@ func handleDevicePollError(
return pollErrorResult{action: pollContinue}

case "slow_down":
*backoffMultiplier *= 1.5
*pollInterval = min(
time.Duration(float64(*pollInterval)*(*backoffMultiplier)),
maxPollInterval,
)
// RFC 8628 §3.5: lengthen the interval on each slow_down. Grow the
// current interval by 1.5x (capped at maxPollInterval) instead of
// compounding a separate multiplier, which ballooned the interval
// super-exponentially (base × 1.5^(1+2+…)) and hit the cap almost
// immediately.
*pollInterval = min(*pollInterval*3/2, maxPollInterval)
pollTicker.Reset(*pollInterval)
return pollErrorResult{action: pollBackoff}

Expand Down Expand Up @@ -288,13 +288,14 @@ func pollForTokenWithUpdates(
deviceAuth *oauth2.DeviceAuthResponse,
updates chan<- tui.FlowUpdate,
) (*oauth2.Token, error) {
// Clamp non-positive server intervals (a missing, zero, or malicious
// negative value) to the RFC 8628 default — time.NewTicker panics on <= 0.
interval := deviceAuth.Interval
if interval == 0 {
if interval <= 0 {
interval = defaultPollInterval
}

pollInterval := time.Duration(interval) * time.Second
backoffMultiplier := 1.0
pollCount := 0
startTime := time.Now()

Expand Down Expand Up @@ -330,7 +331,7 @@ func pollForTokenWithUpdates(
)
if err != nil {
oldInterval := pollInterval
result := handleDevicePollError(err, &pollInterval, &backoffMultiplier, pollTicker)
result := handleDevicePollError(err, &pollInterval, pollTicker)
switch result.action {
case pollContinue:
continue
Expand Down
16 changes: 16 additions & 0 deletions extra_claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,22 @@ func TestParseExtraClaimPair(t *testing.T) {
t.Fatalf("got (%q, %v, %v)", k, v, err)
}
})

t.Run("surrounding key whitespace trimmed", func(t *testing.T) {
// Matches godotenv key-trimming on the --extra-claims-file path so a
// flag with an incidental space still overrides the file's key.
k, v, err := parseExtraClaimPair(" project =acme")
if err != nil || k != "project" || v != "acme" {
t.Fatalf("got (%q, %v, %v)", k, v, err)
}
})

t.Run("whitespace-only key rejected", func(t *testing.T) {
_, _, err := parseExtraClaimPair(" =value")
if err == nil {
t.Fatal("expected error for whitespace-only key")
}
})
}

// -----------------------------------------------------------------------
Expand Down
15 changes: 9 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,18 @@ func run(ctx context.Context, ui tui.Manager, cfg *AppConfig) int {
}
if err == nil {
ui.ShowStatus(tui.StatusUpdate{Event: tui.EventExistingTokens})
// Reuse the already-loaded token as-is. Shared by the still-valid,
// no-refresh-token, and refresh-failed branches below.
useCached := func() {
storage = &existing
flow = "cached"
}
// Capture now once so the refresh decision and the graceful-degradation
// check below reason about the same instant.
now := time.Now()
if !needsRefresh(existing, cfg.RefreshThreshold, now) {
ui.ShowStatus(tui.StatusUpdate{Event: tui.EventTokenStillValid})
storage = &existing
flow = "cached"
useCached()
} else {
// Whether the old token is still usable as-is — shared with
// ensureFreshToken via tokenUsable so a corrupt token (empty access
Expand All @@ -103,8 +108,7 @@ func run(ctx context.Context, ui tui.Manager, cfg *AppConfig) int {
// network call entirely (no refresh step is shown). Reuse the old
// token while it's still valid; otherwise fall through to re-auth.
if reuseValid {
storage = &existing
flow = "cached"
useCached()
}
} else {
// About to refresh: show an accurate in-progress status —
Expand All @@ -121,8 +125,7 @@ func run(ctx context.Context, ui tui.Manager, cfg *AppConfig) int {
// fall through to re-authentication once expired.
ui.ShowStatus(tui.StatusUpdate{Event: tui.EventRefreshFailed, Err: refreshErr})
if reuseValid {
storage = &existing
flow = "cached"
useCached()
}
} else {
storage = newStorage
Expand Down
45 changes: 45 additions & 0 deletions polling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,51 @@ func TestPollForToken_ContextTimeout(t *testing.T) {
}
}

func TestPollForToken_NegativeInterval(t *testing.T) {
// A malicious or buggy server may return a negative poll interval. It must
// be clamped to the default rather than reaching time.NewTicker (which
// panics on a non-positive duration and would crash the whole CLI).
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(map[string]string{
"error": "authorization_pending",
"error_description": "User has not yet authorized",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
defer server.Close()

cfg := testConfig(t)

config := &oauth2.Config{
ClientID: "test-client",
Endpoint: oauth2.Endpoint{TokenURL: server.URL},
}
deviceAuth := &oauth2.DeviceAuthResponse{
DeviceCode: "test-device-code",
Interval: -1, // negative — must be clamped, not passed to NewTicker
}

// Short context: with the interval clamped to the 5s default, the first poll
// never fires before this deadline, so the call returns via ctx instead of
// waiting. The assertion is simply that it returns rather than panicking.
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()

updates := make(chan tui.FlowUpdate, 100)
drainUpdates(t, updates)

_, err := pollForTokenWithUpdates(ctx, cfg, config, deviceAuth, updates)
if err == nil {
t.Fatal("expected context error, got nil")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("expected context.DeadlineExceeded, got: %v", err)
}
}

func TestExchangeDeviceCode_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
Expand Down
6 changes: 6 additions & 0 deletions token_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,12 @@ func parseJWTPayload(token string) (map[string]any, error) {
if err := dec.Decode(&claims); err != nil {
return nil, fmt.Errorf("parse claims: %w", err)
}
// A payload of the JSON literal `null` unmarshals into a nil map with no
// error; reject it so a malformed token isn't reported as an empty-but-valid
// claim set.
if claims == nil {
return nil, errors.New("parse claims: payload is not a JSON object")
}
// Reject trailing bytes (or a second concatenated value) so a malformed
// payload can't masquerade as valid by hiding extra data after the object.
var sink struct{}
Expand Down
2 changes: 1 addition & 1 deletion tui/styles.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func FormatDurationHuman(d time.Duration) string {
hours := int(d.Hours())
minutes := int(d.Minutes()) % 60

if hours > 24 {
if hours >= 24 {
days := hours / 24
hours %= 24
return fmt.Sprintf("%dd %dh", days, hours)
Expand Down
Loading