diff --git a/admin/api/auth.go b/admin/api/auth.go new file mode 100644 index 0000000..fd2ef06 --- /dev/null +++ b/admin/api/auth.go @@ -0,0 +1,227 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "errors" + "fmt" + "log/slog" + "net" + "net/http" + "time" + + "github.com/cloudblue/chaperone/admin/auth" +) + +// AuthHandler handles login, logout, and password change endpoints. +type AuthHandler struct { + auth *auth.Service + secureCookies bool + sessionMaxAge time.Duration +} + +// NewAuthHandler creates a handler for auth endpoints. +func NewAuthHandler(authService *auth.Service, secureCookies bool, sessionMaxAge time.Duration) *AuthHandler { + return &AuthHandler{ + auth: authService, + secureCookies: secureCookies, + sessionMaxAge: sessionMaxAge, + } +} + +// Register mounts auth routes on the given mux. +func (h *AuthHandler) Register(mux *http.ServeMux) { + mux.HandleFunc("POST /api/login", h.login) + mux.HandleFunc("POST /api/logout", h.logout) + mux.HandleFunc("GET /api/me", h.me) + mux.HandleFunc("PUT /api/user/password", h.changePassword) +} + +type loginRequest struct { + Username string `json:"username"` + Password string `json:"password"` // #nosec G117 -- request field, not a hardcoded secret +} + +type loginResponse struct { + User loginUser `json:"user"` +} + +type loginUser struct { + ID int64 `json:"id"` + Username string `json:"username"` +} + +func (h *AuthHandler) login(w http.ResponseWriter, r *http.Request) { + var req loginRequest + if !decodeJSON(w, r, &req) { + return + } + + if req.Username == "" || req.Password == "" { + respondError(w, http.StatusBadRequest, "VALIDATION_ERROR", "username and password are required") + return + } + + ip := clientIP(r) + result, err := h.auth.Login(r.Context(), ip, req.Username, req.Password) + if errors.Is(err, auth.ErrRateLimited) { + w.Header().Set("Retry-After", "60") + respondError(w, http.StatusTooManyRequests, "RATE_LIMITED", "Too many failed login attempts. Try again later.") + return + } + if errors.Is(err, auth.ErrInvalidCredentials) { + respondError(w, http.StatusUnauthorized, "UNAUTHORIZED", "Invalid username or password") + return + } + if err != nil { + slog.Error("login failed", "error", err) + respondError(w, http.StatusInternalServerError, "INTERNAL_ERROR", "Login failed") + return + } + + h.setSessionCookie(w, result.SessionToken) + h.setCSRFCookie(w) + + respondJSON(w, http.StatusOK, loginResponse{ + User: loginUser{ + ID: result.User.ID, + Username: result.User.Username, + }, + }) +} + +func (h *AuthHandler) me(w http.ResponseWriter, r *http.Request) { + user := auth.ContextUser(r.Context()) + if user == nil { + respondError(w, http.StatusUnauthorized, "UNAUTHORIZED", "Authentication required") + return + } + respondJSON(w, http.StatusOK, loginResponse{ + User: loginUser{ID: user.ID, Username: user.Username}, + }) +} + +func (h *AuthHandler) logout(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(auth.SessionCookieName) + if err == nil { + if logoutErr := h.auth.Logout(r.Context(), cookie.Value); logoutErr != nil { + slog.Error("logout session deletion", "error", logoutErr) + } + } + h.clearCookies(w) + w.WriteHeader(http.StatusNoContent) +} + +type changePasswordRequest struct { + CurrentPassword string `json:"current_password"` + NewPassword string `json:"new_password"` +} + +func (h *AuthHandler) changePassword(w http.ResponseWriter, r *http.Request) { + user := auth.ContextUser(r.Context()) + if user == nil { + respondError(w, http.StatusUnauthorized, "UNAUTHORIZED", "Authentication required") + return + } + + cookie, err := r.Cookie(auth.SessionCookieName) + if err != nil { + respondError(w, http.StatusUnauthorized, "UNAUTHORIZED", "Authentication required") + return + } + + var req changePasswordRequest + if !decodeJSON(w, r, &req) { + return + } + + if req.CurrentPassword == "" || req.NewPassword == "" { + respondError(w, http.StatusBadRequest, "VALIDATION_ERROR", "current_password and new_password are required") + return + } + + err = h.auth.ChangePassword(r.Context(), user.ID, cookie.Value, req.CurrentPassword, req.NewPassword) + if errors.Is(err, auth.ErrInvalidCredentials) { + respondError(w, http.StatusUnauthorized, "UNAUTHORIZED", "Current password is incorrect") + return + } + if errors.Is(err, auth.ErrPasswordTooShort) { + respondError(w, http.StatusBadRequest, "VALIDATION_ERROR", + fmt.Sprintf("Password must be at least %d characters", auth.MinPasswordLength)) + return + } + if errors.Is(err, auth.ErrPasswordTooLong) { + respondError(w, http.StatusBadRequest, "VALIDATION_ERROR", + fmt.Sprintf("Password must be at most %d characters", auth.MaxPasswordLength)) + return + } + if err != nil { + slog.Error("password change failed", "user_id", user.ID, "error", err) + respondError(w, http.StatusInternalServerError, "INTERNAL_ERROR", "Failed to change password") + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (h *AuthHandler) setSessionCookie(w http.ResponseWriter, token string) { + http.SetCookie(w, &http.Cookie{ + Name: auth.SessionCookieName, + Value: token, + Path: "/", + MaxAge: int(h.sessionMaxAge.Seconds()), + HttpOnly: true, + Secure: h.secureCookies, + SameSite: http.SameSiteLaxMode, + }) +} + +func (h *AuthHandler) setCSRFCookie(w http.ResponseWriter) { + token, err := auth.GenerateToken(16) + if err != nil { + slog.Error("generating CSRF token", "error", err) + return + } + http.SetCookie(w, &http.Cookie{ + Name: auth.CSRFCookieName, + Value: token, + Path: "/", + MaxAge: int(h.sessionMaxAge.Seconds()), + HttpOnly: false, + Secure: h.secureCookies, + SameSite: http.SameSiteStrictMode, + }) +} + +func (h *AuthHandler) clearCookies(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: auth.SessionCookieName, + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + Secure: h.secureCookies, + SameSite: http.SameSiteLaxMode, + }) + http.SetCookie(w, &http.Cookie{ + Name: auth.CSRFCookieName, + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: false, + Secure: h.secureCookies, + SameSite: http.SameSiteStrictMode, + }) +} + +// clientIP extracts the client IP from the request's TCP peer address. +// The admin portal is deployed direct-to-network within Distributor infrastructure; +// X-Forwarded-For is not trusted and must be ignored for rate-limiting. +func clientIP(r *http.Request) string { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} diff --git a/admin/api/auth_test.go b/admin/api/auth_test.go new file mode 100644 index 0000000..8a833ad --- /dev/null +++ b/admin/api/auth_test.go @@ -0,0 +1,293 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/cloudblue/chaperone/admin/auth" +) + +const testPassword = "securepassword12" + +func newTestAuthMux(t *testing.T) (*http.ServeMux, *auth.Service) { + t.Helper() + st := openTestStore(t) + svc := auth.NewService(st, 24*time.Hour, 2*time.Hour) + h := NewAuthHandler(svc, false, 24*time.Hour) + mux := http.NewServeMux() + h.Register(mux) + return mux, svc +} + +func createTestUser(t *testing.T, svc *auth.Service) { + t.Helper() + if err := svc.CreateUser(context.Background(), "admin", testPassword); err != nil { + t.Fatalf("CreateUser() error = %v", err) + } +} + +// --- Login --- + +func TestLogin_Success_Returns200WithCookies(t *testing.T) { + t.Parallel() + mux, svc := newTestAuthMux(t) + createTestUser(t, svc) + + body := `{"username":"admin","password":"` + testPassword + `"}` + req := httptest.NewRequest(http.MethodPost, "/api/login", strings.NewReader(body)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp loginResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode error: %v", err) + } + if resp.User.Username != "admin" { + t.Errorf("username = %q, want %q", resp.User.Username, "admin") + } + + cookies := rec.Result().Cookies() + var sessionCookie, csrfCookie *http.Cookie + for _, c := range cookies { + switch c.Name { + case auth.SessionCookieName: + sessionCookie = c + case auth.CSRFCookieName: + csrfCookie = c + } + } + + if sessionCookie == nil { + t.Fatal("missing session cookie") + } + if !sessionCookie.HttpOnly { + t.Error("session cookie should be HttpOnly") + } + if sessionCookie.Secure { + t.Error("session cookie should not be Secure in test (secureCookies=false)") + } + + if csrfCookie == nil { + t.Fatal("missing CSRF cookie") + } + if csrfCookie.HttpOnly { + t.Error("CSRF cookie should NOT be HttpOnly") + } +} + +func TestLogin_WrongPassword_Returns401(t *testing.T) { + t.Parallel() + mux, svc := newTestAuthMux(t) + createTestUser(t, svc) + + body := `{"username":"admin","password":"wrongpassword1"}` + req := httptest.NewRequest(http.MethodPost, "/api/login", strings.NewReader(body)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusUnauthorized, rec.Body.String()) + } +} + +func TestLogin_MissingFields_Returns400(t *testing.T) { + t.Parallel() + mux, _ := newTestAuthMux(t) + + body := `{"username":"admin"}` + req := httptest.NewRequest(http.MethodPost, "/api/login", strings.NewReader(body)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestLogin_RateLimited_Returns429(t *testing.T) { + t.Parallel() + mux, svc := newTestAuthMux(t) + createTestUser(t, svc) + + for range 5 { + body := `{"username":"admin","password":"badpassword00"}` + req := httptest.NewRequest(http.MethodPost, "/api/login", strings.NewReader(body)) + req.RemoteAddr = "10.0.0.1:12345" + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + } + + body := `{"username":"admin","password":"` + testPassword + `"}` + req := httptest.NewRequest(http.MethodPost, "/api/login", strings.NewReader(body)) + req.RemoteAddr = "10.0.0.1:12345" + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusTooManyRequests { + t.Errorf("status = %d, want %d", rec.Code, http.StatusTooManyRequests) + } + if got := rec.Header().Get("Retry-After"); got != "60" { + t.Errorf("Retry-After = %q, want %q", got, "60") + } +} + +// --- Logout --- + +func TestLogout_Returns204_ClearsCookies(t *testing.T) { + t.Parallel() + mux, svc := newTestAuthMux(t) + createTestUser(t, svc) + + result, _ := svc.Login(context.Background(), "127.0.0.1", "admin", testPassword) + + req := httptest.NewRequest(http.MethodPost, "/api/logout", nil) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: result.SessionToken}) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Errorf("status = %d, want %d", rec.Code, http.StatusNoContent) + } + + for _, c := range rec.Result().Cookies() { + if c.Name == auth.SessionCookieName && c.MaxAge != -1 { + t.Error("session cookie should be cleared (MaxAge=-1)") + } + if c.Name == auth.CSRFCookieName && c.MaxAge != -1 { + t.Error("CSRF cookie should be cleared (MaxAge=-1)") + } + } +} + +// --- ChangePassword --- + +func TestChangePassword_Success_Returns204(t *testing.T) { + t.Parallel() + mux, svc := newTestAuthMux(t) + createTestUser(t, svc) + result, _ := svc.Login(context.Background(), "127.0.0.1", "admin", testPassword) + + body := `{"current_password":"` + testPassword + `","new_password":"newpassword1234"}` + req := httptest.NewRequest(http.MethodPut, "/api/user/password", strings.NewReader(body)) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: result.SessionToken}) + req = req.WithContext(auth.WithUser(req.Context(), &auth.User{ + ID: result.User.ID, + Username: result.User.Username, + })) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Errorf("status = %d, want %d; body = %s", rec.Code, http.StatusNoContent, rec.Body.String()) + } +} + +func TestChangePassword_WrongCurrent_Returns401(t *testing.T) { + t.Parallel() + mux, svc := newTestAuthMux(t) + createTestUser(t, svc) + result, _ := svc.Login(context.Background(), "127.0.0.1", "admin", testPassword) + + body := `{"current_password":"wrongcurrent1","new_password":"newpassword1234"}` + req := httptest.NewRequest(http.MethodPut, "/api/user/password", strings.NewReader(body)) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: result.SessionToken}) + req = req.WithContext(auth.WithUser(req.Context(), &auth.User{ + ID: result.User.ID, + Username: result.User.Username, + })) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } +} + +func TestChangePassword_TooShort_Returns400(t *testing.T) { + t.Parallel() + mux, svc := newTestAuthMux(t) + createTestUser(t, svc) + result, _ := svc.Login(context.Background(), "127.0.0.1", "admin", testPassword) + + body := `{"current_password":"` + testPassword + `","new_password":"short"}` + req := httptest.NewRequest(http.MethodPut, "/api/user/password", strings.NewReader(body)) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: result.SessionToken}) + req = req.WithContext(auth.WithUser(req.Context(), &auth.User{ + ID: result.User.ID, + Username: result.User.Username, + })) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestChangePassword_NoUser_Returns401(t *testing.T) { + t.Parallel() + mux, _ := newTestAuthMux(t) + + body := `{"current_password":"old","new_password":"newpassword1234"}` + req := httptest.NewRequest(http.MethodPut, "/api/user/password", strings.NewReader(body)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } +} + +// --- Me --- + +func TestMe_Authenticated_Returns200(t *testing.T) { + t.Parallel() + mux, svc := newTestAuthMux(t) + createTestUser(t, svc) + result, _ := svc.Login(context.Background(), "127.0.0.1", "admin", testPassword) + + req := httptest.NewRequest(http.MethodGet, "/api/me", nil) + req = req.WithContext(auth.WithUser(req.Context(), &auth.User{ + ID: result.User.ID, + Username: result.User.Username, + })) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp loginResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode error: %v", err) + } + if resp.User.Username != "admin" { + t.Errorf("username = %q, want %q", resp.User.Username, "admin") + } +} + +func TestMe_Unauthenticated_Returns401(t *testing.T) { + t.Parallel() + mux, _ := newTestAuthMux(t) + + req := httptest.NewRequest(http.MethodGet, "/api/me", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } +} diff --git a/admin/auth/auth.go b/admin/auth/auth.go new file mode 100644 index 0000000..fde53ba --- /dev/null +++ b/admin/auth/auth.go @@ -0,0 +1,302 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "log/slog" + "net/http" + "time" + + "golang.org/x/crypto/bcrypt" + + "github.com/cloudblue/chaperone/admin/store" +) + +// Cookie and header names used by the auth system. +const ( + SessionCookieName = "session" + CSRFCookieName = "csrf_token" + CSRFHeaderName = "X-CSRF-Token" + MinPasswordLength = 12 + MaxPasswordLength = 72 // bcrypt silently truncates beyond this + MaxUsernameLength = 64 +) + +// Sentinel errors for authentication operations. +var ( + ErrUnauthenticated = errors.New("unauthenticated") + ErrInvalidCredentials = errors.New("invalid credentials") + ErrPasswordTooShort = errors.New("password too short") + ErrPasswordTooLong = errors.New("password too long") + ErrInvalidUsername = errors.New("invalid username") + ErrRateLimited = errors.New("rate limited") + ErrSessionExpired = errors.New("session expired") +) + +// dummyHash is a pre-computed bcrypt hash used when a user is not found, +// to prevent timing-based username enumeration. +// +//nolint:errcheck // bcrypt.GenerateFromPassword with DefaultCost never fails +var dummyHash, _ = bcrypt.GenerateFromPassword([]byte("dummy-password-for-timing"), bcrypt.DefaultCost) + +// Authenticator validates a request and returns the authenticated user. +// This interface enables future auth backends (OIDC, etc.) without +// changing middleware or handlers. +type Authenticator interface { + Authenticate(r *http.Request) (*User, error) +} + +// User represents an authenticated portal user. +type User struct { + ID int64 + Username string +} + +// LoginResult holds the outcome of a successful login. +type LoginResult struct { + SessionToken string // #nosec G117 -- this is a session token, not a hardcoded secret + User User +} + +// Service implements local authentication using SQLite-backed users +// with bcrypt password hashing and session cookies. +type Service struct { + store *store.Store + limiter *RateLimiter + maxAge time.Duration + idleTimeout time.Duration +} + +// NewService creates an auth service with the given session parameters. +func NewService(st *store.Store, maxAge, idleTimeout time.Duration) *Service { + return &Service{ + store: st, + limiter: NewRateLimiter(5, time.Minute), + maxAge: maxAge, + idleTimeout: idleTimeout, + } +} + +// SweepRateLimiter removes expired entries from the rate limiter. +func (s *Service) SweepRateLimiter() { + s.limiter.Sweep() +} + +// Authenticate validates the session cookie on an HTTP request. +// It checks absolute TTL, idle timeout, and touches the session. +func (s *Service) Authenticate(r *http.Request) (*User, error) { + cookie, err := r.Cookie(SessionCookieName) + if err != nil { + return nil, ErrUnauthenticated + } + + rawToken := cookie.Value + sess, err := s.store.GetSessionByToken(r.Context(), rawToken) + if errors.Is(err, store.ErrSessionNotFound) { + return nil, ErrUnauthenticated + } + if err != nil { + return nil, fmt.Errorf("validating session: %w", err) + } + + now := time.Now() + if now.After(sess.ExpiresAt) { + if delErr := s.store.DeleteSession(r.Context(), rawToken); delErr != nil { + slog.Error("deleting expired session", "error", delErr) + } + return nil, ErrSessionExpired + } + if now.Sub(sess.LastActiveAt) > s.idleTimeout { + if delErr := s.store.DeleteSession(r.Context(), rawToken); delErr != nil { + slog.Error("deleting idle session", "error", delErr) + } + return nil, ErrSessionExpired + } + + if touchErr := s.store.TouchSession(r.Context(), rawToken); touchErr != nil { + slog.Error("touching session", "error", touchErr) + } + + user, err := s.store.GetUserByID(r.Context(), sess.UserID) + if err != nil { + return nil, fmt.Errorf("getting user for session: %w", err) + } + + return &User{ID: user.ID, Username: user.Username}, nil +} + +// Login authenticates credentials and creates a new session. +// It enforces rate limiting per IP and uses constant-time comparison +// to prevent username enumeration. +func (s *Service) Login(ctx context.Context, ip, username, password string) (*LoginResult, error) { + if !s.limiter.Allow(ip) { + return nil, ErrRateLimited + } + + user, err := s.store.GetUserByUsername(ctx, username) + if errors.Is(err, store.ErrUserNotFound) { + _ = bcrypt.CompareHashAndPassword(dummyHash, []byte(password)) + s.limiter.Record(ip) + return nil, ErrInvalidCredentials + } + if err != nil { + return nil, fmt.Errorf("looking up user: %w", err) + } + + err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) + if err != nil { + s.limiter.Record(ip) + return nil, ErrInvalidCredentials + } + + s.limiter.Reset(ip) + + token, err := GenerateToken(32) + if err != nil { + return nil, err + } + + expiresAt := time.Now().Add(s.maxAge) + if err := s.store.CreateSession(ctx, user.ID, token, expiresAt); err != nil { + return nil, fmt.Errorf("creating session: %w", err) + } + + return &LoginResult{ + SessionToken: token, + User: User{ID: user.ID, Username: user.Username}, + }, nil +} + +// Logout invalidates a session by its token. +func (s *Service) Logout(ctx context.Context, token string) error { + return s.store.DeleteSession(ctx, token) +} + +// ChangePassword verifies the current password, updates to a new one, +// and invalidates all sessions except the caller's. +func (s *Service) ChangePassword(ctx context.Context, userID int64, currentToken, currentPassword, newPassword string) error { + if err := validatePassword(newPassword); err != nil { + return err + } + + user, err := s.store.GetUserByID(ctx, userID) + if err != nil { + return fmt.Errorf("getting user: %w", err) + } + + err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(currentPassword)) + if err != nil { + return ErrInvalidCredentials + } + + hash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("hashing password: %w", err) + } + + if err := s.store.UpdateUserPassword(ctx, userID, string(hash)); err != nil { + return err + } + + if err := s.store.DeleteOtherSessions(ctx, userID, currentToken); err != nil { + return fmt.Errorf("invalidating other sessions: %w", err) + } + + return nil +} + +// CreateUser creates a new portal user (CLI operation). +func (s *Service) CreateUser(ctx context.Context, username, password string) error { + if err := validateUsername(username); err != nil { + return err + } + if err := validatePassword(password); err != nil { + return err + } + + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("hashing password: %w", err) + } + + _, err = s.store.CreateUser(ctx, username, string(hash)) + return err +} + +// ResetPassword changes a user's password and invalidates all their sessions (CLI operation). +func (s *Service) ResetPassword(ctx context.Context, username, password string) error { + if err := validatePassword(password); err != nil { + return err + } + + user, err := s.store.GetUserByUsername(ctx, username) + if err != nil { + return fmt.Errorf("looking up user: %w", err) + } + + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("hashing password: %w", err) + } + + if err := s.store.UpdateUserPassword(ctx, user.ID, string(hash)); err != nil { + return fmt.Errorf("updating password: %w", err) + } + + if err := s.store.DeleteUserSessions(ctx, user.ID); err != nil { + return fmt.Errorf("invalidating sessions: %w", err) + } + + return nil +} + +func validatePassword(password string) error { + if len(password) < MinPasswordLength { + return ErrPasswordTooShort + } + if len(password) > MaxPasswordLength { + return ErrPasswordTooLong + } + return nil +} + +func validateUsername(username string) error { + if username == "" || len(username) > MaxUsernameLength { + return ErrInvalidUsername + } + for _, r := range username { + if r < 0x20 || r > 0x7E { + return ErrInvalidUsername + } + } + return nil +} + +// GenerateToken returns a cryptographically random hex-encoded token. +func GenerateToken(byteLen int) (string, error) { + b := make([]byte, byteLen) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generating random token: %w", err) + } + return hex.EncodeToString(b), nil +} + +type contextKey struct{} + +// WithUser stores an authenticated user in the request context. +func WithUser(ctx context.Context, u *User) context.Context { + return context.WithValue(ctx, contextKey{}, u) +} + +// ContextUser extracts the authenticated user from a request context. +// Returns nil if no user is present (unauthenticated request). +func ContextUser(ctx context.Context) *User { + u, _ := ctx.Value(contextKey{}).(*User) + return u +} diff --git a/admin/auth/auth_test.go b/admin/auth/auth_test.go new file mode 100644 index 0000000..07ac071 --- /dev/null +++ b/admin/auth/auth_test.go @@ -0,0 +1,500 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/cloudblue/chaperone/admin/store" +) + +const testPassword = "securepassword12" + +func newTestService(t *testing.T) *Service { + t.Helper() + dbPath := filepath.Join(t.TempDir(), "test.db") + st, err := store.Open(context.Background(), dbPath) + if err != nil { + t.Fatalf("Open(%q) failed: %v", dbPath, err) + } + t.Cleanup(func() { st.Close() }) + return NewService(st, 24*time.Hour, 2*time.Hour) +} + +func createTestUser(t *testing.T, svc *Service) { + t.Helper() + if err := svc.CreateUser(context.Background(), "admin", testPassword); err != nil { + t.Fatalf("CreateUser() error = %v", err) + } +} + +func loginTestUser(t *testing.T, svc *Service) string { + t.Helper() + result, err := svc.Login(context.Background(), "127.0.0.1", "admin", testPassword) + if err != nil { + t.Fatalf("Login() error = %v", err) + } + return result.SessionToken +} + +// --- CreateUser --- + +func TestCreateUser_Success(t *testing.T) { + t.Parallel() + svc := newTestService(t) + + err := svc.CreateUser(context.Background(), "admin", testPassword) + if err != nil { + t.Fatalf("CreateUser() error = %v", err) + } +} + +func TestCreateUser_TooShort_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + + err := svc.CreateUser(context.Background(), "admin", "short") + if !errors.Is(err, ErrPasswordTooShort) { + t.Errorf("error = %v, want %v", err, ErrPasswordTooShort) + } +} + +func TestCreateUser_TooLong_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + + longPass := strings.Repeat("a", MaxPasswordLength+1) + err := svc.CreateUser(context.Background(), "admin", longPass) + if !errors.Is(err, ErrPasswordTooLong) { + t.Errorf("error = %v, want %v", err, ErrPasswordTooLong) + } +} + +func TestCreateUser_EmptyUsername_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + + err := svc.CreateUser(context.Background(), "", testPassword) + if !errors.Is(err, ErrInvalidUsername) { + t.Errorf("error = %v, want %v", err, ErrInvalidUsername) + } +} + +func TestCreateUser_UsernameTooLong_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + + longName := strings.Repeat("a", MaxUsernameLength+1) + err := svc.CreateUser(context.Background(), longName, testPassword) + if !errors.Is(err, ErrInvalidUsername) { + t.Errorf("error = %v, want %v", err, ErrInvalidUsername) + } +} + +func TestCreateUser_ControlCharsInUsername_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + + err := svc.CreateUser(context.Background(), "admin\x00", testPassword) + if !errors.Is(err, ErrInvalidUsername) { + t.Errorf("error = %v, want %v", err, ErrInvalidUsername) + } +} + +func TestCreateUser_Duplicate_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + ctx := context.Background() + + if err := svc.CreateUser(ctx, "admin", testPassword); err != nil { + t.Fatalf("first CreateUser() error = %v", err) + } + + err := svc.CreateUser(ctx, "admin", testPassword) + if !errors.Is(err, store.ErrDuplicateUsername) { + t.Errorf("error = %v, want %v", err, store.ErrDuplicateUsername) + } +} + +// --- Login --- + +func TestLogin_Success(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + + result, err := svc.Login(context.Background(), "127.0.0.1", "admin", testPassword) + if err != nil { + t.Fatalf("Login() error = %v", err) + } + if result.SessionToken == "" { + t.Error("expected non-empty session token") + } + if result.User.Username != "admin" { + t.Errorf("Username = %q, want %q", result.User.Username, "admin") + } +} + +func TestLogin_WrongPassword_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + + _, err := svc.Login(context.Background(), "127.0.0.1", "admin", "wrongpassword1") + if !errors.Is(err, ErrInvalidCredentials) { + t.Errorf("error = %v, want %v", err, ErrInvalidCredentials) + } +} + +func TestLogin_UserNotFound_ReturnsInvalidCredentials(t *testing.T) { + t.Parallel() + svc := newTestService(t) + + _, err := svc.Login(context.Background(), "127.0.0.1", "nobody", testPassword) + if !errors.Is(err, ErrInvalidCredentials) { + t.Errorf("error = %v, want %v", err, ErrInvalidCredentials) + } +} + +func TestLogin_RateLimited_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + ctx := context.Background() + + for range 5 { + svc.Login(ctx, "10.0.0.1", "admin", "badpassword00") + } + + _, err := svc.Login(ctx, "10.0.0.1", "admin", testPassword) + if !errors.Is(err, ErrRateLimited) { + t.Errorf("error = %v, want %v", err, ErrRateLimited) + } +} + +func TestLogin_RateLimit_ResetsOnSuccess(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + ctx := context.Background() + + // 4 failures (under limit of 5). + for range 4 { + svc.Login(ctx, "10.0.0.2", "admin", "badpassword00") + } + + // Successful login resets counter. + if _, err := svc.Login(ctx, "10.0.0.2", "admin", testPassword); err != nil { + t.Fatalf("Login() error = %v", err) + } + + // 4 more failures should be allowed (counter was reset). + for range 4 { + svc.Login(ctx, "10.0.0.2", "admin", "badpassword00") + } + + // 5th failure should still be under limit. + _, err := svc.Login(ctx, "10.0.0.2", "admin", "badpassword00") + if !errors.Is(err, ErrInvalidCredentials) { + t.Errorf("error = %v, want %v (should still be under limit)", err, ErrInvalidCredentials) + } +} + +// --- Authenticate --- + +func TestAuthenticate_ValidSession_ReturnsUser(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + token := loginTestUser(t, svc) + + req := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token}) + + user, err := svc.Authenticate(req) + if err != nil { + t.Fatalf("Authenticate() error = %v", err) + } + if user.Username != "admin" { + t.Errorf("Username = %q, want %q", user.Username, "admin") + } +} + +func TestAuthenticate_NoCookie_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + + req := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + + _, err := svc.Authenticate(req) + if !errors.Is(err, ErrUnauthenticated) { + t.Errorf("error = %v, want %v", err, ErrUnauthenticated) + } +} + +func TestAuthenticate_InvalidToken_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + + req := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: "bad-token"}) + + _, err := svc.Authenticate(req) + if !errors.Is(err, ErrUnauthenticated) { + t.Errorf("error = %v, want %v", err, ErrUnauthenticated) + } +} + +func TestAuthenticate_ExpiredSession_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + // Use very short maxAge so session expires immediately. + svc.maxAge = time.Millisecond + createTestUser(t, svc) + token := loginTestUser(t, svc) + + time.Sleep(5 * time.Millisecond) + + req := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token}) + + _, err := svc.Authenticate(req) + if !errors.Is(err, ErrSessionExpired) { + t.Errorf("error = %v, want %v", err, ErrSessionExpired) + } +} + +func TestAuthenticate_IdleSession_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + // Use very short idle timeout. + svc.idleTimeout = time.Millisecond + createTestUser(t, svc) + token := loginTestUser(t, svc) + + time.Sleep(5 * time.Millisecond) + + req := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token}) + + _, err := svc.Authenticate(req) + if !errors.Is(err, ErrSessionExpired) { + t.Errorf("error = %v, want %v", err, ErrSessionExpired) + } +} + +// --- Logout --- + +func TestLogout_DeletesSession(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + token := loginTestUser(t, svc) + + if err := svc.Logout(context.Background(), token); err != nil { + t.Fatalf("Logout() error = %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token}) + + _, err := svc.Authenticate(req) + if !errors.Is(err, ErrUnauthenticated) { + t.Errorf("after logout: error = %v, want %v", err, ErrUnauthenticated) + } +} + +// --- ChangePassword --- + +func TestChangePassword_Success(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + ctx := context.Background() + + result, _ := svc.Login(ctx, "127.0.0.1", "admin", testPassword) + + newPass := "newpassword1234" + if err := svc.ChangePassword(ctx, result.User.ID, result.SessionToken, testPassword, newPass); err != nil { + t.Fatalf("ChangePassword() error = %v", err) + } + + // Old password should fail. + _, err := svc.Login(ctx, "127.0.0.1", "admin", testPassword) + if !errors.Is(err, ErrInvalidCredentials) { + t.Errorf("old password: error = %v, want %v", err, ErrInvalidCredentials) + } + + // New password should work. + if _, err := svc.Login(ctx, "127.0.0.1", "admin", newPass); err != nil { + t.Errorf("new password: unexpected error = %v", err) + } +} + +func TestChangePassword_InvalidatesOtherSessions(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + ctx := context.Background() + + // Login twice to create two sessions. + result1, _ := svc.Login(ctx, "127.0.0.1", "admin", testPassword) + result2, _ := svc.Login(ctx, "127.0.0.2", "admin", testPassword) + + // Change password using session 1. + newPass := "newpassword1234" + if err := svc.ChangePassword(ctx, result1.User.ID, result1.SessionToken, testPassword, newPass); err != nil { + t.Fatalf("ChangePassword() error = %v", err) + } + + // Session 1 (caller) should still work. + req1 := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + req1.AddCookie(&http.Cookie{Name: SessionCookieName, Value: result1.SessionToken}) + if _, err := svc.Authenticate(req1); err != nil { + t.Errorf("caller session should remain valid: %v", err) + } + + // Session 2 (other) should be invalidated. + req2 := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + req2.AddCookie(&http.Cookie{Name: SessionCookieName, Value: result2.SessionToken}) + _, err := svc.Authenticate(req2) + if !errors.Is(err, ErrUnauthenticated) { + t.Errorf("other session: error = %v, want %v", err, ErrUnauthenticated) + } +} + +func TestChangePassword_WrongCurrent_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + + result, _ := svc.Login(context.Background(), "127.0.0.1", "admin", testPassword) + + err := svc.ChangePassword(context.Background(), result.User.ID, result.SessionToken, "wrongcurrent1", "newpassword1234") + if !errors.Is(err, ErrInvalidCredentials) { + t.Errorf("error = %v, want %v", err, ErrInvalidCredentials) + } +} + +func TestChangePassword_TooShort_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + + result, _ := svc.Login(context.Background(), "127.0.0.1", "admin", testPassword) + + err := svc.ChangePassword(context.Background(), result.User.ID, result.SessionToken, testPassword, "short") + if !errors.Is(err, ErrPasswordTooShort) { + t.Errorf("error = %v, want %v", err, ErrPasswordTooShort) + } +} + +func TestChangePassword_TooLong_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + + result, _ := svc.Login(context.Background(), "127.0.0.1", "admin", testPassword) + + longPass := strings.Repeat("a", MaxPasswordLength+1) + err := svc.ChangePassword(context.Background(), result.User.ID, result.SessionToken, testPassword, longPass) + if !errors.Is(err, ErrPasswordTooLong) { + t.Errorf("error = %v, want %v", err, ErrPasswordTooLong) + } +} + +// --- ResetPassword --- + +func TestResetPassword_Success_InvalidatesSessions(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + token := loginTestUser(t, svc) + + newPass := "resetpassword12" + if err := svc.ResetPassword(context.Background(), "admin", newPass); err != nil { + t.Fatalf("ResetPassword() error = %v", err) + } + + // Old session should be invalid. + req := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + req.AddCookie(&http.Cookie{Name: SessionCookieName, Value: token}) + _, err := svc.Authenticate(req) + if !errors.Is(err, ErrUnauthenticated) { + t.Errorf("old session: error = %v, want %v", err, ErrUnauthenticated) + } + + // New password should work. + if _, err := svc.Login(context.Background(), "127.0.0.1", "admin", newPass); err != nil { + t.Errorf("new password: unexpected error = %v", err) + } +} + +func TestResetPassword_UserNotFound_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + + err := svc.ResetPassword(context.Background(), "nobody", "newpassword1234") + if err == nil { + t.Error("expected error, got nil") + } +} + +func TestResetPassword_TooShort_ReturnsError(t *testing.T) { + t.Parallel() + svc := newTestService(t) + createTestUser(t, svc) + + err := svc.ResetPassword(context.Background(), "admin", "short") + if !errors.Is(err, ErrPasswordTooShort) { + t.Errorf("error = %v, want %v", err, ErrPasswordTooShort) + } +} + +// --- GenerateToken --- + +func TestGenerateToken_ReturnsUniqueTokens(t *testing.T) { + t.Parallel() + + t1, err := GenerateToken(32) + if err != nil { + t.Fatalf("GenerateToken() error = %v", err) + } + t2, err := GenerateToken(32) + if err != nil { + t.Fatalf("GenerateToken() error = %v", err) + } + + if len(t1) != 64 { + t.Errorf("token length = %d, want 64", len(t1)) + } + if t1 == t2 { + t.Error("consecutive tokens should be unique") + } +} + +// --- Context helpers --- + +func TestContextUser_RoundTrip(t *testing.T) { + t.Parallel() + + ctx := context.Background() + if got := ContextUser(ctx); got != nil { + t.Error("expected nil user from empty context") + } + + user := &User{ID: 42, Username: "admin"} + ctx = WithUser(ctx, user) + got := ContextUser(ctx) + if got == nil || got.ID != 42 || got.Username != "admin" { + t.Errorf("ContextUser() = %v, want %v", got, user) + } +} diff --git a/admin/auth/middleware.go b/admin/auth/middleware.go new file mode 100644 index 0000000..4759215 --- /dev/null +++ b/admin/auth/middleware.go @@ -0,0 +1,103 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "crypto/subtle" + "encoding/json" + "log/slog" + "net/http" + "strings" +) + +// RequireAuth wraps an http.Handler and enforces session authentication +// on all /api/* routes except POST /api/login and GET /api/health. +func RequireAuth(auth Authenticator, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !requiresAuth(r) { + next.ServeHTTP(w, r) + return + } + + user, err := auth.Authenticate(r) + if err != nil { + slog.Debug("authentication failed", "path", r.URL.Path, "error", err) + writeError(w, http.StatusUnauthorized, "UNAUTHORIZED", "Authentication required") + return + } + + next.ServeHTTP(w, r.WithContext(WithUser(r.Context(), user))) + }) +} + +// CSRFProtection validates the double-submit cookie pattern on all +// write requests to /api/* (except POST /api/login which has no session yet). +func CSRFProtection(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !requiresCSRF(r) { + next.ServeHTTP(w, r) + return + } + + cookie, err := r.Cookie(CSRFCookieName) + if err != nil { + writeError(w, http.StatusForbidden, "CSRF_ERROR", "Missing CSRF token") + return + } + + header := r.Header.Get(CSRFHeaderName) + if header == "" || subtle.ConstantTimeCompare([]byte(header), []byte(cookie.Value)) != 1 { + writeError(w, http.StatusForbidden, "CSRF_ERROR", "Invalid CSRF token") + return + } + + next.ServeHTTP(w, r) + }) +} + +func requiresAuth(r *http.Request) bool { + if !strings.HasPrefix(r.URL.Path, "/api/") { + return false + } + if r.Method == http.MethodPost && r.URL.Path == "/api/login" { + return false + } + if r.Method == http.MethodGet && r.URL.Path == "/api/health" { + return false + } + return true +} + +func requiresCSRF(r *http.Request) bool { + switch r.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + return false + } + if !strings.HasPrefix(r.URL.Path, "/api/") { + return false + } + if r.URL.Path == "/api/login" { + return false + } + return true +} + +type middlewareError struct { + Error middlewareErrorDetail `json:"error"` +} + +type middlewareErrorDetail struct { + Code string `json:"code"` + Message string `json:"message"` +} + +func writeError(w http.ResponseWriter, status int, code, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(middlewareError{ + Error: middlewareErrorDetail{Code: code, Message: message}, + }); err != nil { + slog.Error("writing middleware error response", "error", err) + } +} diff --git a/admin/auth/middleware_test.go b/admin/auth/middleware_test.go new file mode 100644 index 0000000..ffabab9 --- /dev/null +++ b/admin/auth/middleware_test.go @@ -0,0 +1,235 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +type mockAuthenticator struct { + user *User + err error +} + +func (m *mockAuthenticator) Authenticate(_ *http.Request) (*User, error) { + return m.user, m.err +} + +func echoUserHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := ContextUser(r.Context()) + if user != nil { + w.Header().Set("X-User", user.Username) + } + w.WriteHeader(http.StatusOK) + }) +} + +// --- RequireAuth --- + +func TestRequireAuth_ProtectedRoute_Unauthenticated_Returns401(t *testing.T) { + t.Parallel() + + handler := RequireAuth(&mockAuthenticator{err: ErrUnauthenticated}, echoUserHandler()) + req := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } + + var resp middlewareError + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode error: %v", err) + } + if resp.Error.Code != "UNAUTHORIZED" { + t.Errorf("code = %q, want %q", resp.Error.Code, "UNAUTHORIZED") + } +} + +func TestRequireAuth_ProtectedRoute_Authenticated_PassesThrough(t *testing.T) { + t.Parallel() + + user := &User{ID: 1, Username: "admin"} + handler := RequireAuth(&mockAuthenticator{user: user}, echoUserHandler()) + req := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + if got := rec.Header().Get("X-User"); got != "admin" { + t.Errorf("X-User = %q, want %q", got, "admin") + } +} + +func TestRequireAuth_LoginRoute_SkipsAuth(t *testing.T) { + t.Parallel() + + handler := RequireAuth(&mockAuthenticator{err: ErrUnauthenticated}, echoUserHandler()) + req := httptest.NewRequest(http.MethodPost, "/api/login", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d (login should skip auth)", rec.Code, http.StatusOK) + } +} + +func TestRequireAuth_HealthRoute_SkipsAuth(t *testing.T) { + t.Parallel() + + handler := RequireAuth(&mockAuthenticator{err: ErrUnauthenticated}, echoUserHandler()) + req := httptest.NewRequest(http.MethodGet, "/api/health", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d (health should skip auth)", rec.Code, http.StatusOK) + } +} + +func TestRequireAuth_SPARoute_SkipsAuth(t *testing.T) { + t.Parallel() + + handler := RequireAuth(&mockAuthenticator{err: ErrUnauthenticated}, echoUserHandler()) + req := httptest.NewRequest(http.MethodGet, "/dashboard", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d (non-API should skip auth)", rec.Code, http.StatusOK) + } +} + +// --- CSRFProtection --- + +func TestCSRF_SafeMethod_SkipsCheck(t *testing.T) { + t.Parallel() + + handler := CSRFProtection(echoUserHandler()) + req := httptest.NewRequest(http.MethodGet, "/api/instances", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d (GET should skip CSRF)", rec.Code, http.StatusOK) + } +} + +func TestCSRF_LoginRoute_SkipsCheck(t *testing.T) { + t.Parallel() + + handler := CSRFProtection(echoUserHandler()) + req := httptest.NewRequest(http.MethodPost, "/api/login", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d (login should skip CSRF)", rec.Code, http.StatusOK) + } +} + +func TestCSRF_WriteRequest_MissingCookie_Returns403(t *testing.T) { + t.Parallel() + + handler := CSRFProtection(echoUserHandler()) + req := httptest.NewRequest(http.MethodPost, "/api/instances", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Errorf("status = %d, want %d", rec.Code, http.StatusForbidden) + } +} + +func TestCSRF_WriteRequest_MissingHeader_Returns403(t *testing.T) { + t.Parallel() + + handler := CSRFProtection(echoUserHandler()) + req := httptest.NewRequest(http.MethodPost, "/api/instances", nil) + req.AddCookie(&http.Cookie{Name: CSRFCookieName, Value: "token123"}) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Errorf("status = %d, want %d", rec.Code, http.StatusForbidden) + } +} + +func TestCSRF_WriteRequest_MismatchedToken_Returns403(t *testing.T) { + t.Parallel() + + handler := CSRFProtection(echoUserHandler()) + req := httptest.NewRequest(http.MethodPost, "/api/instances", nil) + req.AddCookie(&http.Cookie{Name: CSRFCookieName, Value: "token123"}) + req.Header.Set(CSRFHeaderName, "different-token") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Errorf("status = %d, want %d", rec.Code, http.StatusForbidden) + } +} + +func TestCSRF_WriteRequest_ValidToken_PassesThrough(t *testing.T) { + t.Parallel() + + handler := CSRFProtection(echoUserHandler()) + req := httptest.NewRequest(http.MethodPost, "/api/instances", nil) + req.AddCookie(&http.Cookie{Name: CSRFCookieName, Value: "token123"}) + req.Header.Set(CSRFHeaderName, "token123") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } +} + +func TestCSRF_DeleteRequest_RequiresToken(t *testing.T) { + t.Parallel() + + handler := CSRFProtection(echoUserHandler()) + req := httptest.NewRequest(http.MethodDelete, "/api/instances/1", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Errorf("status = %d, want %d (DELETE should require CSRF)", rec.Code, http.StatusForbidden) + } +} + +func TestCSRF_PutRequest_ValidToken_PassesThrough(t *testing.T) { + t.Parallel() + + handler := CSRFProtection(echoUserHandler()) + req := httptest.NewRequest(http.MethodPut, "/api/user/password", nil) + req.AddCookie(&http.Cookie{Name: CSRFCookieName, Value: "csrf-val"}) + req.Header.Set(CSRFHeaderName, "csrf-val") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } +} + +func TestCSRF_NonAPIRoute_SkipsCheck(t *testing.T) { + t.Parallel() + + handler := CSRFProtection(echoUserHandler()) + req := httptest.NewRequest(http.MethodPost, "/some/form", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d (non-API should skip CSRF)", rec.Code, http.StatusOK) + } +} diff --git a/admin/auth/ratelimit.go b/admin/auth/ratelimit.go new file mode 100644 index 0000000..3da339f --- /dev/null +++ b/admin/auth/ratelimit.go @@ -0,0 +1,95 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "sync" + "time" +) + +// RateLimiter tracks failed login attempts per IP using a sliding window. +type RateLimiter struct { + mu sync.Mutex + attempts map[string][]time.Time + maxAttempts int + window time.Duration + now func() time.Time // injectable clock for testing +} + +// NewRateLimiter creates a rate limiter that allows maxAttempts failures +// within the given window duration per IP. +func NewRateLimiter(maxAttempts int, window time.Duration) *RateLimiter { + return &RateLimiter{ + attempts: make(map[string][]time.Time), + maxAttempts: maxAttempts, + window: window, + now: time.Now, + } +} + +// Allow returns true if the IP has not exceeded the failure limit. +func (rl *RateLimiter) Allow(ip string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + rl.prune(ip) + return len(rl.attempts[ip]) < rl.maxAttempts +} + +// Record logs a failed login attempt for the given IP. +func (rl *RateLimiter) Record(ip string) { + rl.mu.Lock() + defer rl.mu.Unlock() + + rl.attempts[ip] = append(rl.attempts[ip], rl.now()) +} + +// Reset clears the failure counter for an IP (called on successful login). +func (rl *RateLimiter) Reset(ip string) { + rl.mu.Lock() + defer rl.mu.Unlock() + + delete(rl.attempts, ip) +} + +// prune removes attempts older than the sliding window. Must be called under lock. +func (rl *RateLimiter) prune(ip string) { + attempts := rl.attempts[ip] + if len(attempts) == 0 { + return + } + + cutoff := rl.now().Add(-rl.window) + i := 0 + for i < len(attempts) && attempts[i].Before(cutoff) { + i++ + } + if i > 0 { + rl.attempts[ip] = attempts[i:] + } + if len(rl.attempts[ip]) == 0 { + delete(rl.attempts, ip) + } +} + +// Sweep removes all expired entries across all IPs. +// Call periodically from a background goroutine to prevent unbounded growth +// from IPs that record failures but never return. +func (rl *RateLimiter) Sweep() { + rl.mu.Lock() + defer rl.mu.Unlock() + + cutoff := rl.now().Add(-rl.window) + for ip, attempts := range rl.attempts { + i := 0 + for i < len(attempts) && attempts[i].Before(cutoff) { + i++ + } + if i == len(attempts) { + delete(rl.attempts, ip) + } else if i > 0 { + rl.attempts[ip] = attempts[i:] + } + } +} diff --git a/admin/auth/ratelimit_test.go b/admin/auth/ratelimit_test.go new file mode 100644 index 0000000..b552063 --- /dev/null +++ b/admin/auth/ratelimit_test.go @@ -0,0 +1,159 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "testing" + "time" +) + +func TestRateLimiter_AllowsUnderLimit(t *testing.T) { + t.Parallel() + rl := NewRateLimiter(3, time.Minute) + + for i := range 3 { + if !rl.Allow("1.2.3.4") { + t.Fatalf("attempt %d should be allowed", i+1) + } + rl.Record("1.2.3.4") + } +} + +func TestRateLimiter_BlocksAtLimit(t *testing.T) { + t.Parallel() + rl := NewRateLimiter(3, time.Minute) + + for range 3 { + rl.Record("1.2.3.4") + } + + if rl.Allow("1.2.3.4") { + t.Error("should be blocked after 3 failures") + } +} + +func TestRateLimiter_DifferentIPs_Independent(t *testing.T) { + t.Parallel() + rl := NewRateLimiter(2, time.Minute) + + rl.Record("1.1.1.1") + rl.Record("1.1.1.1") + + if rl.Allow("1.1.1.1") { + t.Error("1.1.1.1 should be blocked") + } + if !rl.Allow("2.2.2.2") { + t.Error("2.2.2.2 should be allowed (separate IP)") + } +} + +func TestRateLimiter_ResetClearsCounter(t *testing.T) { + t.Parallel() + rl := NewRateLimiter(2, time.Minute) + + rl.Record("1.2.3.4") + rl.Record("1.2.3.4") + rl.Reset("1.2.3.4") + + if !rl.Allow("1.2.3.4") { + t.Error("should be allowed after reset") + } +} + +func TestRateLimiter_SlidingWindow_PrunesOldAttempts(t *testing.T) { + t.Parallel() + rl := NewRateLimiter(2, time.Minute) + + now := time.Now() + rl.now = func() time.Time { return now } + + rl.Record("1.2.3.4") + rl.Record("1.2.3.4") + + // Advance past the window. + rl.now = func() time.Time { return now.Add(61 * time.Second) } + + if !rl.Allow("1.2.3.4") { + t.Error("old attempts should be pruned; IP should be allowed") + } +} + +func TestRateLimiter_PartialPrune_KeepsRecentAttempts(t *testing.T) { + t.Parallel() + rl := NewRateLimiter(2, time.Minute) + + now := time.Now() + rl.now = func() time.Time { return now } + rl.Record("1.2.3.4") // t=0 + + rl.now = func() time.Time { return now.Add(50 * time.Second) } + rl.Record("1.2.3.4") // t=50s + + // At t=61s, the first attempt is pruned but the second is still within window. + rl.now = func() time.Time { return now.Add(61 * time.Second) } + + if !rl.Allow("1.2.3.4") { + t.Error("should be allowed (only 1 recent attempt after prune)") + } + + rl.Record("1.2.3.4") // second recent attempt + + if rl.Allow("1.2.3.4") { + t.Error("should be blocked (2 recent attempts)") + } +} + +func TestRateLimiter_Sweep_RemovesExpiredEntries(t *testing.T) { + t.Parallel() + rl := NewRateLimiter(2, time.Minute) + + now := time.Now() + rl.now = func() time.Time { return now } + + rl.Record("1.1.1.1") + rl.Record("2.2.2.2") + rl.Record("3.3.3.3") + + // Advance past the window. + rl.now = func() time.Time { return now.Add(61 * time.Second) } + + rl.Sweep() + + rl.mu.Lock() + remaining := len(rl.attempts) + rl.mu.Unlock() + + if remaining != 0 { + t.Errorf("expected 0 entries after sweep, got %d", remaining) + } +} + +func TestRateLimiter_Sweep_KeepsRecentEntries(t *testing.T) { + t.Parallel() + rl := NewRateLimiter(2, time.Minute) + + now := time.Now() + rl.now = func() time.Time { return now } + rl.Record("1.1.1.1") // old + + rl.now = func() time.Time { return now.Add(50 * time.Second) } + rl.Record("2.2.2.2") // recent + + // At t=61s, 1.1.1.1 is expired but 2.2.2.2 is still within window. + rl.now = func() time.Time { return now.Add(61 * time.Second) } + + rl.Sweep() + + rl.mu.Lock() + remaining := len(rl.attempts) + rl.mu.Unlock() + + if remaining != 1 { + t.Errorf("expected 1 entry after sweep, got %d", remaining) + } + + if !rl.Allow("1.1.1.1") { + t.Error("1.1.1.1 should be allowed after sweep removed its expired entry") + } +} diff --git a/admin/cmd/chaperone-admin/main.go b/admin/cmd/chaperone-admin/main.go index 6f34f7d..4a9950f 100644 --- a/admin/cmd/chaperone-admin/main.go +++ b/admin/cmd/chaperone-admin/main.go @@ -12,10 +12,14 @@ import ( "net/http" "os" "os/signal" + "strings" "syscall" "time" + "golang.org/x/term" + "github.com/cloudblue/chaperone/admin" + "github.com/cloudblue/chaperone/admin/auth" "github.com/cloudblue/chaperone/admin/config" "github.com/cloudblue/chaperone/admin/metrics" "github.com/cloudblue/chaperone/admin/poller" @@ -36,9 +40,38 @@ func main() { } func run() error { - configPath := flag.String("config", "", "Path to config file (default: chaperone-admin.yaml)") - showVersion := flag.Bool("version", false, "Print version and exit") - flag.Parse() + if len(os.Args) > 1 && !strings.HasPrefix(os.Args[1], "-") { + switch os.Args[1] { + case "create-user": + return runCreateUser(os.Args[2:]) + case "reset-password": + return runResetPassword(os.Args[2:]) + case "serve": + return runServer(os.Args[2:]) + default: + return fmt.Errorf("unknown command %q (available: serve, create-user, reset-password)", os.Args[1]) + } + } + return runServer(os.Args[1:]) +} + +func runServer(args []string) error { + fs := flag.NewFlagSet("serve", flag.ExitOnError) + fs.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: chaperone-admin [command] [flags]\n\n") + fmt.Fprintf(os.Stderr, "Commands:\n") + fmt.Fprintf(os.Stderr, " serve Start the admin portal server (default)\n") + fmt.Fprintf(os.Stderr, " create-user Create a new admin user\n") + fmt.Fprintf(os.Stderr, " reset-password Reset a user's password\n") + fmt.Fprintf(os.Stderr, "\nServer flags:\n") + fs.PrintDefaults() + } + + configPath := fs.String("config", "", "Path to config file (default: chaperone-admin.yaml)") + showVersion := fs.Bool("version", false, "Print version and exit") + if err := fs.Parse(args); err != nil { + return err + } if *showVersion { fmt.Printf("chaperone-admin %s (commit: %s, built: %s)\n", Version, GitCommit, BuildDate) @@ -67,16 +100,158 @@ func run() error { return fmt.Errorf("creating server: %w", err) } - // Start the background health + metrics poller. - pollerCtx, pollerCancel := context.WithCancel(context.Background()) - defer pollerCancel() + // Start background goroutines. + bgCtx, bgCancel := context.WithCancel(context.Background()) + defer bgCancel() p := poller.New(st, collector, cfg.Scraper.Interval.Unwrap(), cfg.Scraper.Timeout.Unwrap()) - go p.Run(pollerCtx) + go p.Run(bgCtx) + go cleanupExpiredSessions(bgCtx, st) + go sweepRateLimiter(bgCtx, srv) return serve(cfg.Server.Addr, srv) } +func runCreateUser(args []string) error { + fs := flag.NewFlagSet("create-user", flag.ExitOnError) + configPath := fs.String("config", "", "Path to config file") + username := fs.String("username", "", "Username for the new user") + if err := fs.Parse(args); err != nil { + return err + } + + if *username == "" { + return fmt.Errorf("--username is required") + } + + password, err := readPasswordConfirm("Password: ", "Confirm password: ") + if err != nil { + return err + } + + svc, cleanup, err := openAuthService(*configPath) + if err != nil { + return err + } + defer cleanup() + + if err := svc.CreateUser(context.Background(), *username, password); err != nil { + return fmt.Errorf("creating user: %w", err) + } + + fmt.Fprintf(os.Stderr, "User %q created successfully.\n", *username) + return nil +} + +func runResetPassword(args []string) error { + fs := flag.NewFlagSet("reset-password", flag.ExitOnError) + configPath := fs.String("config", "", "Path to config file") + username := fs.String("username", "", "Username to reset") + if err := fs.Parse(args); err != nil { + return err + } + + if *username == "" { + return fmt.Errorf("--username is required") + } + + password, err := readPasswordConfirm("New password: ", "Confirm password: ") + if err != nil { + return err + } + + svc, cleanup, err := openAuthService(*configPath) + if err != nil { + return err + } + defer cleanup() + + if err := svc.ResetPassword(context.Background(), *username, password); err != nil { + return fmt.Errorf("resetting password: %w", err) + } + + fmt.Fprintf(os.Stderr, "Password for %q has been reset. All existing sessions invalidated.\n", *username) + return nil +} + +func openAuthService(configPath string) (*auth.Service, func(), error) { + cfg, err := config.Load(configPath) + if err != nil { + return nil, nil, fmt.Errorf("loading configuration: %w", err) + } + + st, err := store.Open(context.Background(), cfg.Database.Path) + if err != nil { + return nil, nil, fmt.Errorf("opening database: %w", err) + } + + svc := auth.NewService(st, cfg.Session.MaxAge.Unwrap(), cfg.Session.IdleTimeout.Unwrap()) + cleanup := func() { + if err := st.Close(); err != nil { + slog.Error("closing database", "error", err) + } + } + return svc, cleanup, nil +} + +func readPasswordConfirm(prompt, confirmPrompt string) (string, error) { + password, err := readPassword(prompt) + if err != nil { + return "", err + } + confirm, err := readPassword(confirmPrompt) + if err != nil { + return "", err + } + if password != confirm { + return "", fmt.Errorf("passwords do not match") + } + return password, nil +} + +func readPassword(prompt string) (string, error) { + fmt.Fprint(os.Stderr, prompt) + password, err := term.ReadPassword(int(os.Stdin.Fd())) // #nosec G115 -- stdin fd is always 0 + fmt.Fprintln(os.Stderr) // newline after hidden input + if err != nil { + return "", fmt.Errorf("reading password: %w", err) + } + return string(password), nil +} + +func cleanupExpiredSessions(ctx context.Context, st *store.Store) { + ticker := time.NewTicker(time.Hour) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + n, err := st.DeleteExpiredSessions(ctx) + if err != nil { + slog.Error("cleaning up expired sessions", "error", err) + } else if n > 0 { + slog.Info("cleaned up expired sessions", "count", n) + } + } + } +} + +func sweepRateLimiter(ctx context.Context, srv *admin.Server) { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + srv.SweepRateLimiter() + } + } +} + func serve(addr string, srv *admin.Server) error { ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() diff --git a/admin/config/config.go b/admin/config/config.go index 0f6371e..82ef646 100644 --- a/admin/config/config.go +++ b/admin/config/config.go @@ -34,7 +34,8 @@ type Config struct { // ServerConfig configures the HTTP server. type ServerConfig struct { - Addr string `yaml:"addr"` + Addr string `yaml:"addr"` + SecureCookies bool `yaml:"secure_cookies"` } // DatabaseConfig configures the SQLite database. diff --git a/admin/config/loader.go b/admin/config/loader.go index 50ca0da..341c132 100644 --- a/admin/config/loader.go +++ b/admin/config/loader.go @@ -98,6 +98,9 @@ func applyEnvOverrides(cfg *Config) error { if v := getEnv("SERVER_ADDR"); v != "" { cfg.Server.Addr = v } + if v := getEnv("SERVER_SECURE_COOKIES"); v != "" { + cfg.Server.SecureCookies = v == "true" || v == "1" + } if v := getEnv("DATABASE_PATH"); v != "" { cfg.Database.Path = v } diff --git a/admin/config/loader_test.go b/admin/config/loader_test.go index ac513e7..320597e 100644 --- a/admin/config/loader_test.go +++ b/admin/config/loader_test.go @@ -69,6 +69,7 @@ func TestLoad_ValidYAML_ParsesAllFields(t *testing.T) { path := writeTestConfig(t, ` server: addr: "0.0.0.0:9090" + secure_cookies: true database: path: "/var/lib/admin.db" scraper: @@ -94,6 +95,9 @@ log: if cfg.Server.Addr != "0.0.0.0:9090" { t.Errorf("Server.Addr = %q, want %q", cfg.Server.Addr, "0.0.0.0:9090") } + if !cfg.Server.SecureCookies { + t.Error("Server.SecureCookies = false, want true") + } if cfg.Database.Path != "/var/lib/admin.db" { t.Errorf("Database.Path = %q, want %q", cfg.Database.Path, "/var/lib/admin.db") } @@ -163,6 +167,7 @@ func TestLoad_EnvOverrides_AllFields(t *testing.T) { // Arrange path := filepath.Join(t.TempDir(), "nonexistent.yaml") t.Setenv("CHAPERONE_ADMIN_SERVER_ADDR", "0.0.0.0:3000") + t.Setenv("CHAPERONE_ADMIN_SERVER_SECURE_COOKIES", "true") t.Setenv("CHAPERONE_ADMIN_DATABASE_PATH", "/tmp/test.db") t.Setenv("CHAPERONE_ADMIN_SCRAPER_INTERVAL", "20s") t.Setenv("CHAPERONE_ADMIN_SCRAPER_TIMEOUT", "8s") @@ -182,6 +187,9 @@ func TestLoad_EnvOverrides_AllFields(t *testing.T) { if cfg.Server.Addr != "0.0.0.0:3000" { t.Errorf("Server.Addr = %q, want %q", cfg.Server.Addr, "0.0.0.0:3000") } + if !cfg.Server.SecureCookies { + t.Error("Server.SecureCookies = false, want true") + } if cfg.Database.Path != "/tmp/test.db" { t.Errorf("Database.Path = %q, want %q", cfg.Database.Path, "/tmp/test.db") } diff --git a/admin/go.mod b/admin/go.mod index 8fe936d..6de0c0d 100644 --- a/admin/go.mod +++ b/admin/go.mod @@ -5,6 +5,8 @@ go 1.26.2 require ( github.com/prometheus/client_model v0.6.2 github.com/prometheus/common v0.67.5 + golang.org/x/crypto v0.48.0 + golang.org/x/term v0.40.0 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.46.1 ) @@ -19,7 +21,7 @@ require ( github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/sys v0.39.0 // indirect + golang.org/x/sys v0.41.0 // indirect google.golang.org/protobuf v1.36.11 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/admin/go.sum b/admin/go.sum index a22dc4b..6a60705 100644 --- a/admin/go.sum +++ b/admin/go.sum @@ -37,6 +37,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= @@ -44,8 +46,10 @@ golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= diff --git a/admin/server.go b/admin/server.go index e730e7c..e61ada8 100644 --- a/admin/server.go +++ b/admin/server.go @@ -14,6 +14,7 @@ import ( "time" "github.com/cloudblue/chaperone/admin/api" + "github.com/cloudblue/chaperone/admin/auth" "github.com/cloudblue/chaperone/admin/config" "github.com/cloudblue/chaperone/admin/metrics" "github.com/cloudblue/chaperone/admin/store" @@ -21,31 +22,38 @@ import ( // Server is the admin portal HTTP server. type Server struct { - httpServer *http.Server - config *config.Config - store *store.Store - collector *metrics.Collector + httpServer *http.Server + config *config.Config + store *store.Store + collector *metrics.Collector + authService *auth.Service } // NewServer creates a new admin portal server. func NewServer(cfg *config.Config, st *store.Store, collector *metrics.Collector) (*Server, error) { mux := http.NewServeMux() + authService := auth.NewService(st, cfg.Session.MaxAge.Unwrap(), cfg.Session.IdleTimeout.Unwrap()) + secureCookies := cfg.Server.SecureCookies + + handler := securityHeaders(auth.RequireAuth(authService, auth.CSRFProtection(mux))) + s := &Server{ httpServer: &http.Server{ Addr: cfg.Server.Addr, - Handler: securityHeaders(mux), + Handler: handler, ReadHeaderTimeout: 5 * time.Second, ReadTimeout: 15 * time.Second, WriteTimeout: 30 * time.Second, IdleTimeout: 60 * time.Second, }, - config: cfg, - store: st, - collector: collector, + config: cfg, + store: st, + collector: collector, + authService: authService, } - if err := s.routes(mux); err != nil { + if err := s.routes(mux, authService, secureCookies); err != nil { return nil, fmt.Errorf("setting up routes: %w", err) } return s, nil @@ -61,10 +69,19 @@ func (s *Server) Shutdown(ctx context.Context) error { return s.httpServer.Shutdown(ctx) } -func (s *Server) routes(mux *http.ServeMux) error { +// SweepRateLimiter removes expired entries from the login rate limiter. +func (s *Server) SweepRateLimiter() { + s.authService.SweepRateLimiter() +} + +func (s *Server) routes(mux *http.ServeMux, authService *auth.Service, secureCookies bool) error { // API health check for the portal itself. mux.HandleFunc("GET /api/health", s.handleHealth) + // Auth endpoints (login, logout, password change). + authHandler := api.NewAuthHandler(authService, secureCookies, s.config.Session.MaxAge.Unwrap()) + authHandler.Register(mux) + // Instance CRUD + test connection. instances := api.NewInstanceHandler(s.store, s.config.Scraper.Timeout.Unwrap()) instances.Register(mux) diff --git a/admin/store/user.go b/admin/store/user.go new file mode 100644 index 0000000..91db384 --- /dev/null +++ b/admin/store/user.go @@ -0,0 +1,198 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package store + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "time" +) + +// Sentinel errors for user and session operations. +var ( + ErrUserNotFound = errors.New("user not found") + ErrDuplicateUsername = errors.New("duplicate username") + ErrSessionNotFound = errors.New("session not found") +) + +// User represents a portal admin user. +type User struct { + ID int64 `json:"id"` + Username string `json:"username"` + PasswordHash string `json:"-"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// Session represents an active user session. +type Session struct { + ID int64 + UserID int64 + TokenHash string + ExpiresAt time.Time + LastActiveAt time.Time + CreatedAt time.Time +} + +// CreateUser inserts a new user with the given bcrypt password hash. +func (s *Store) CreateUser(ctx context.Context, username, passwordHash string) (*User, error) { + result, err := s.db.ExecContext(ctx, + `INSERT INTO users (username, password_hash) VALUES (?, ?)`, username, passwordHash) + if err != nil { + if isUniqueConstraintError(err) { + return nil, ErrDuplicateUsername + } + return nil, fmt.Errorf("creating user: %w", err) + } + + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("getting last insert ID: %w", err) + } + return s.GetUserByID(ctx, id) +} + +// GetUserByID returns a user by their ID. +func (s *Store) GetUserByID(ctx context.Context, id int64) (*User, error) { + var u User + err := s.db.QueryRowContext(ctx, + `SELECT id, username, password_hash, created_at, updated_at FROM users WHERE id = ?`, id). + Scan(&u.ID, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrUserNotFound + } + if err != nil { + return nil, fmt.Errorf("getting user %d: %w", id, err) + } + return &u, nil +} + +// GetUserByUsername returns a user by their username. +func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) { + var u User + err := s.db.QueryRowContext(ctx, + `SELECT id, username, password_hash, created_at, updated_at FROM users WHERE username = ?`, username). + Scan(&u.ID, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrUserNotFound + } + if err != nil { + return nil, fmt.Errorf("getting user %q: %w", username, err) + } + return &u, nil +} + +// UpdateUserPassword changes a user's password hash. +func (s *Store) UpdateUserPassword(ctx context.Context, userID int64, passwordHash string) error { + result, err := s.db.ExecContext(ctx, + `UPDATE users SET password_hash = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, + passwordHash, userID) + if err != nil { + return fmt.Errorf("updating password for user %d: %w", userID, err) + } + + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("checking rows affected: %w", err) + } + if n == 0 { + return ErrUserNotFound + } + return nil +} + +// CreateSession inserts a new session record. +// The raw token is hashed before storage; callers always pass raw tokens. +func (s *Store) CreateSession(ctx context.Context, userID int64, token string, expiresAt time.Time) error { + _, err := s.db.ExecContext(ctx, + `INSERT INTO sessions (user_id, token, expires_at) VALUES (?, ?, ?)`, + userID, hashToken(token), expiresAt) + if err != nil { + return fmt.Errorf("creating session: %w", err) + } + return nil +} + +// GetSessionByToken looks up a session by its raw token (hashed for lookup). +func (s *Store) GetSessionByToken(ctx context.Context, token string) (*Session, error) { + var sess Session + err := s.db.QueryRowContext(ctx, + `SELECT id, user_id, token, expires_at, last_active_at, created_at + FROM sessions WHERE token = ?`, hashToken(token)). + Scan(&sess.ID, &sess.UserID, &sess.TokenHash, &sess.ExpiresAt, &sess.LastActiveAt, &sess.CreatedAt) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrSessionNotFound + } + if err != nil { + return nil, fmt.Errorf("getting session: %w", err) + } + return &sess, nil +} + +// TouchSession updates the last_active_at timestamp for idle timeout tracking. +// Accepts the raw token (hashed for lookup). +func (s *Store) TouchSession(ctx context.Context, token string) error { + _, err := s.db.ExecContext(ctx, + `UPDATE sessions SET last_active_at = CURRENT_TIMESTAMP WHERE token = ?`, hashToken(token)) + if err != nil { + return fmt.Errorf("touching session: %w", err) + } + return nil +} + +// DeleteSession removes a session by raw token (hashed for lookup). +func (s *Store) DeleteSession(ctx context.Context, token string) error { + _, err := s.db.ExecContext(ctx, + `DELETE FROM sessions WHERE token = ?`, hashToken(token)) + if err != nil { + return fmt.Errorf("deleting session: %w", err) + } + return nil +} + +// DeleteUserSessions removes all sessions for a user (password reset). +func (s *Store) DeleteUserSessions(ctx context.Context, userID int64) error { + _, err := s.db.ExecContext(ctx, + `DELETE FROM sessions WHERE user_id = ?`, userID) + if err != nil { + return fmt.Errorf("deleting sessions for user %d: %w", userID, err) + } + return nil +} + +// DeleteOtherSessions removes all sessions for a user except the given token. +// Accepts the raw keepToken (hashed for comparison). +func (s *Store) DeleteOtherSessions(ctx context.Context, userID int64, keepToken string) error { + _, err := s.db.ExecContext(ctx, + `DELETE FROM sessions WHERE user_id = ? AND token != ?`, userID, hashToken(keepToken)) + if err != nil { + return fmt.Errorf("deleting other sessions for user %d: %w", userID, err) + } + return nil +} + +// DeleteExpiredSessions removes sessions past their absolute expiry. +func (s *Store) DeleteExpiredSessions(ctx context.Context) (int64, error) { + result, err := s.db.ExecContext(ctx, + `DELETE FROM sessions WHERE expires_at < ?`, time.Now()) + if err != nil { + return 0, fmt.Errorf("deleting expired sessions: %w", err) + } + n, err := result.RowsAffected() + if err != nil { + return 0, fmt.Errorf("checking rows affected: %w", err) + } + return n, nil +} + +// hashToken computes the SHA-256 hash of a raw session token. +// The database stores hashes so a DB compromise does not leak usable tokens. +func hashToken(raw string) string { + h := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(h[:]) +} diff --git a/admin/store/user_test.go b/admin/store/user_test.go new file mode 100644 index 0000000..ec96121 --- /dev/null +++ b/admin/store/user_test.go @@ -0,0 +1,287 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package store + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestCreateUser_Success(t *testing.T) { + t.Parallel() + st := openTestStore(t) + + user, err := st.CreateUser(context.Background(), "admin", "$2a$10$hash") + if err != nil { + t.Fatalf("CreateUser() error = %v", err) + } + if user.ID == 0 { + t.Error("expected non-zero ID") + } + if user.Username != "admin" { + t.Errorf("Username = %q, want %q", user.Username, "admin") + } +} + +func TestCreateUser_DuplicateUsername_ReturnsError(t *testing.T) { + t.Parallel() + st := openTestStore(t) + ctx := context.Background() + + if _, err := st.CreateUser(ctx, "admin", "$2a$10$hash1"); err != nil { + t.Fatalf("first CreateUser() error = %v", err) + } + + _, err := st.CreateUser(ctx, "admin", "$2a$10$hash2") + if !errors.Is(err, ErrDuplicateUsername) { + t.Errorf("error = %v, want %v", err, ErrDuplicateUsername) + } +} + +func TestGetUserByID_Exists_ReturnsUser(t *testing.T) { + t.Parallel() + st := openTestStore(t) + ctx := context.Background() + + created, err := st.CreateUser(ctx, "admin", "$2a$10$hash") + if err != nil { + t.Fatalf("CreateUser() error = %v", err) + } + + got, err := st.GetUserByID(ctx, created.ID) + if err != nil { + t.Fatalf("GetUserByID() error = %v", err) + } + if got.Username != "admin" { + t.Errorf("Username = %q, want %q", got.Username, "admin") + } + if got.PasswordHash != "$2a$10$hash" { + t.Errorf("PasswordHash = %q, want %q", got.PasswordHash, "$2a$10$hash") + } +} + +func TestGetUserByID_NotFound_ReturnsError(t *testing.T) { + t.Parallel() + st := openTestStore(t) + + _, err := st.GetUserByID(context.Background(), 999) + if !errors.Is(err, ErrUserNotFound) { + t.Errorf("error = %v, want %v", err, ErrUserNotFound) + } +} + +func TestGetUserByUsername_Exists_ReturnsUser(t *testing.T) { + t.Parallel() + st := openTestStore(t) + ctx := context.Background() + + if _, err := st.CreateUser(ctx, "admin", "$2a$10$hash"); err != nil { + t.Fatalf("CreateUser() error = %v", err) + } + + got, err := st.GetUserByUsername(ctx, "admin") + if err != nil { + t.Fatalf("GetUserByUsername() error = %v", err) + } + if got.Username != "admin" { + t.Errorf("Username = %q, want %q", got.Username, "admin") + } +} + +func TestGetUserByUsername_NotFound_ReturnsError(t *testing.T) { + t.Parallel() + st := openTestStore(t) + + _, err := st.GetUserByUsername(context.Background(), "nonexistent") + if !errors.Is(err, ErrUserNotFound) { + t.Errorf("error = %v, want %v", err, ErrUserNotFound) + } +} + +func TestUpdateUserPassword_Success(t *testing.T) { + t.Parallel() + st := openTestStore(t) + ctx := context.Background() + + user, err := st.CreateUser(ctx, "admin", "$2a$10$oldhash") + if err != nil { + t.Fatalf("CreateUser() error = %v", err) + } + + err = st.UpdateUserPassword(ctx, user.ID, "$2a$10$newhash") + if err != nil { + t.Fatalf("UpdateUserPassword() error = %v", err) + } + + got, err := st.GetUserByID(ctx, user.ID) + if err != nil { + t.Fatalf("GetUserByID() error = %v", err) + } + if got.PasswordHash != "$2a$10$newhash" { + t.Errorf("PasswordHash = %q, want %q", got.PasswordHash, "$2a$10$newhash") + } +} + +func TestUpdateUserPassword_NotFound_ReturnsError(t *testing.T) { + t.Parallel() + st := openTestStore(t) + + err := st.UpdateUserPassword(context.Background(), 999, "$2a$10$hash") + if !errors.Is(err, ErrUserNotFound) { + t.Errorf("error = %v, want %v", err, ErrUserNotFound) + } +} + +func TestCreateSession_And_GetByToken(t *testing.T) { + t.Parallel() + st := openTestStore(t) + ctx := context.Background() + + user, err := st.CreateUser(ctx, "admin", "$2a$10$hash") + if err != nil { + t.Fatalf("CreateUser() error = %v", err) + } + + expiresAt := time.Now().Add(24 * time.Hour) + err = st.CreateSession(ctx, user.ID, "tok-abc-123", expiresAt) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + + sess, err := st.GetSessionByToken(ctx, "tok-abc-123") + if err != nil { + t.Fatalf("GetSessionByToken() error = %v", err) + } + if sess.UserID != user.ID { + t.Errorf("UserID = %d, want %d", sess.UserID, user.ID) + } + if sess.TokenHash == "" { + t.Error("expected non-empty token hash") + } + if sess.TokenHash == "tok-abc-123" { + t.Error("token should be stored as a hash, not raw") + } +} + +func TestGetSessionByToken_NotFound_ReturnsError(t *testing.T) { + t.Parallel() + st := openTestStore(t) + + _, err := st.GetSessionByToken(context.Background(), "nonexistent") + if !errors.Is(err, ErrSessionNotFound) { + t.Errorf("error = %v, want %v", err, ErrSessionNotFound) + } +} + +func TestTouchSession_UpdatesLastActiveAt(t *testing.T) { + t.Parallel() + st := openTestStore(t) + ctx := context.Background() + + user, _ := st.CreateUser(ctx, "admin", "$2a$10$hash") + expiresAt := time.Now().Add(24 * time.Hour) + _ = st.CreateSession(ctx, user.ID, "tok-touch", expiresAt) + + before, _ := st.GetSessionByToken(ctx, "tok-touch") + if err := st.TouchSession(ctx, "tok-touch"); err != nil { + t.Fatalf("TouchSession() error = %v", err) + } + after, _ := st.GetSessionByToken(ctx, "tok-touch") + + if after.LastActiveAt.Before(before.LastActiveAt) { + t.Errorf("LastActiveAt should not go backward after touch") + } +} + +func TestDeleteSession_RemovesSession(t *testing.T) { + t.Parallel() + st := openTestStore(t) + ctx := context.Background() + + user, _ := st.CreateUser(ctx, "admin", "$2a$10$hash") + _ = st.CreateSession(ctx, user.ID, "tok-del", time.Now().Add(time.Hour)) + + if err := st.DeleteSession(ctx, "tok-del"); err != nil { + t.Fatalf("DeleteSession() error = %v", err) + } + + _, err := st.GetSessionByToken(ctx, "tok-del") + if !errors.Is(err, ErrSessionNotFound) { + t.Errorf("after delete: error = %v, want %v", err, ErrSessionNotFound) + } +} + +func TestDeleteUserSessions_RemovesAll(t *testing.T) { + t.Parallel() + st := openTestStore(t) + ctx := context.Background() + + user, _ := st.CreateUser(ctx, "admin", "$2a$10$hash") + _ = st.CreateSession(ctx, user.ID, "tok-a", time.Now().Add(time.Hour)) + _ = st.CreateSession(ctx, user.ID, "tok-b", time.Now().Add(time.Hour)) + + if err := st.DeleteUserSessions(ctx, user.ID); err != nil { + t.Fatalf("DeleteUserSessions() error = %v", err) + } + + _, errA := st.GetSessionByToken(ctx, "tok-a") + _, errB := st.GetSessionByToken(ctx, "tok-b") + if !errors.Is(errA, ErrSessionNotFound) || !errors.Is(errB, ErrSessionNotFound) { + t.Errorf("sessions should be deleted; errA = %v, errB = %v", errA, errB) + } +} + +func TestDeleteOtherSessions_KeepsSpecifiedToken(t *testing.T) { + t.Parallel() + st := openTestStore(t) + ctx := context.Background() + + user, _ := st.CreateUser(ctx, "admin", "$2a$10$hash") + _ = st.CreateSession(ctx, user.ID, "tok-keep", time.Now().Add(time.Hour)) + _ = st.CreateSession(ctx, user.ID, "tok-remove-a", time.Now().Add(time.Hour)) + _ = st.CreateSession(ctx, user.ID, "tok-remove-b", time.Now().Add(time.Hour)) + + if err := st.DeleteOtherSessions(ctx, user.ID, "tok-keep"); err != nil { + t.Fatalf("DeleteOtherSessions() error = %v", err) + } + + // Kept session should still exist. + if _, err := st.GetSessionByToken(ctx, "tok-keep"); err != nil { + t.Errorf("kept session should exist: %v", err) + } + + // Other sessions should be deleted. + _, errA := st.GetSessionByToken(ctx, "tok-remove-a") + _, errB := st.GetSessionByToken(ctx, "tok-remove-b") + if !errors.Is(errA, ErrSessionNotFound) || !errors.Is(errB, ErrSessionNotFound) { + t.Errorf("other sessions should be deleted; errA = %v, errB = %v", errA, errB) + } +} + +func TestDeleteExpiredSessions_RemovesExpiredOnly(t *testing.T) { + t.Parallel() + st := openTestStore(t) + ctx := context.Background() + + user, _ := st.CreateUser(ctx, "admin", "$2a$10$hash") + + // One expired, one active. + _ = st.CreateSession(ctx, user.ID, "tok-expired", time.Now().Add(-time.Hour)) + _ = st.CreateSession(ctx, user.ID, "tok-active", time.Now().Add(time.Hour)) + + n, err := st.DeleteExpiredSessions(ctx) + if err != nil { + t.Fatalf("DeleteExpiredSessions() error = %v", err) + } + if n != 1 { + t.Errorf("deleted = %d, want 1", n) + } + + _, err = st.GetSessionByToken(ctx, "tok-active") + if err != nil { + t.Errorf("active session should still exist: %v", err) + } +} diff --git a/admin/ui/src/App.vue b/admin/ui/src/App.vue index f80ecb1..6ce1078 100644 --- a/admin/ui/src/App.vue +++ b/admin/ui/src/App.vue @@ -1,10 +1,14 @@ diff --git a/admin/ui/src/layouts/AppLayout.vue b/admin/ui/src/layouts/AppLayout.vue index 3d7f8f9..8f5301b 100644 --- a/admin/ui/src/layouts/AppLayout.vue +++ b/admin/ui/src/layouts/AppLayout.vue @@ -66,8 +66,72 @@ Audit Log + + + + + + Settings + +
+
+ + + + + {{ auth.user?.username }} +
+ +
@@ -76,9 +140,22 @@ diff --git a/admin/ui/src/router/index.js b/admin/ui/src/router/index.js index 267c2b7..3ee19de 100644 --- a/admin/ui/src/router/index.js +++ b/admin/ui/src/router/index.js @@ -1,9 +1,18 @@ import { createRouter, createWebHistory } from 'vue-router'; +import { useAuthStore } from '../stores/auth.js'; import DashboardView from '../views/DashboardView.vue'; import AuditLogView from '../views/AuditLogView.vue'; import InstanceDetailView from '../views/InstanceDetailView.vue'; +import LoginView from '../views/LoginView.vue'; +import SettingsView from '../views/SettingsView.vue'; const routes = [ + { + path: '/login', + name: 'login', + component: LoginView, + meta: { public: true }, + }, { path: '/', name: 'dashboard', @@ -19,6 +28,11 @@ const routes = [ name: 'audit-log', component: AuditLogView, }, + { + path: '/settings', + name: 'settings', + component: SettingsView, + }, { path: '/:pathMatch(.*)*', name: 'not-found', @@ -31,4 +45,18 @@ const router = createRouter({ routes, }); +router.beforeEach(async (to) => { + const auth = useAuthStore(); + + if (!auth.ready) await auth.checkSession(); + + if (!to.meta.public && !auth.isAuthenticated) { + return { name: 'login', query: { redirect: to.fullPath } }; + } + + if (to.name === 'login' && auth.isAuthenticated) { + return { name: 'dashboard' }; + } +}); + export default router; diff --git a/admin/ui/src/stores/auth.js b/admin/ui/src/stores/auth.js new file mode 100644 index 0000000..4596b91 --- /dev/null +++ b/admin/ui/src/stores/auth.js @@ -0,0 +1,47 @@ +import { ref, computed } from 'vue'; +import { defineStore } from 'pinia'; +import * as api from '../utils/api.js'; + +export const useAuthStore = defineStore('auth', () => { + const user = ref(null); + const ready = ref(false); + const isAuthenticated = computed(() => user.value !== null); + + async function checkSession() { + try { + const data = await api.get('/api/me'); + user.value = data.user; + } catch { + user.value = null; + } finally { + ready.value = true; + } + } + + async function login(username, password) { + const data = await api.post('/api/login', { username, password }); + user.value = data.user; + } + + async function logout() { + await api.post('/api/logout'); + user.value = null; + } + + async function changePassword(currentPassword, newPassword) { + await api.put('/api/user/password', { + current_password: currentPassword, + new_password: newPassword, + }); + } + + return { + user, + ready, + isAuthenticated, + checkSession, + login, + logout, + changePassword, + }; +}); diff --git a/admin/ui/src/stores/auth.test.js b/admin/ui/src/stores/auth.test.js new file mode 100644 index 0000000..c412716 --- /dev/null +++ b/admin/ui/src/stores/auth.test.js @@ -0,0 +1,89 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { setActivePinia, createPinia } from 'pinia'; +import { useAuthStore } from './auth.js'; + +vi.mock('../utils/api.js', () => ({ + get: vi.fn(), + post: vi.fn(), + put: vi.fn(), +})); + +import * as api from '../utils/api.js'; + +describe('useAuthStore', () => { + let store; + + beforeEach(() => { + setActivePinia(createPinia()); + store = useAuthStore(); + vi.restoreAllMocks(); + }); + + describe('checkSession', () => { + it('sets user on valid session', async () => { + api.get.mockResolvedValue({ user: { id: 1, username: 'admin' } }); + await store.checkSession(); + expect(store.user).toEqual({ id: 1, username: 'admin' }); + expect(store.ready).toBe(true); + expect(store.isAuthenticated).toBe(true); + }); + + it('clears user on invalid session', async () => { + api.get.mockRejectedValue(new Error('401')); + await store.checkSession(); + expect(store.user).toBeNull(); + expect(store.ready).toBe(true); + expect(store.isAuthenticated).toBe(false); + }); + }); + + describe('login', () => { + it('sets user on success', async () => { + api.post.mockResolvedValue({ user: { id: 1, username: 'admin' } }); + await store.login('admin', 'password123456'); + expect(api.post).toHaveBeenCalledWith('/api/login', { + username: 'admin', + password: 'password123456', + }); + expect(store.user).toEqual({ id: 1, username: 'admin' }); + }); + + it('propagates error on failure', async () => { + const err = new Error('Invalid'); + err.status = 401; + api.post.mockRejectedValue(err); + await expect(store.login('admin', 'wrong')).rejects.toThrow('Invalid'); + expect(store.user).toBeNull(); + }); + }); + + describe('logout', () => { + it('clears user on success', async () => { + store.user = { id: 1, username: 'admin' }; + api.post.mockResolvedValue(null); + await store.logout(); + expect(api.post).toHaveBeenCalledWith('/api/logout'); + expect(store.user).toBeNull(); + }); + }); + + describe('changePassword', () => { + it('sends correct payload', async () => { + api.put.mockResolvedValue(null); + await store.changePassword('old-password1', 'new-password1'); + expect(api.put).toHaveBeenCalledWith('/api/user/password', { + current_password: 'old-password1', + new_password: 'new-password1', + }); + }); + + it('propagates error on failure', async () => { + const err = new Error('Current password is incorrect'); + err.status = 401; + api.put.mockRejectedValue(err); + await expect( + store.changePassword('wrong', 'new-password1'), + ).rejects.toThrow('Current password is incorrect'); + }); + }); +}); diff --git a/admin/ui/src/utils/api.js b/admin/ui/src/utils/api.js index 1582b1b..d7456b3 100644 --- a/admin/ui/src/utils/api.js +++ b/admin/ui/src/utils/api.js @@ -7,16 +7,36 @@ class ApiError extends Error { } } +export function getCsrfToken() { + const match = document.cookie.match(/(?:^|;\s*)csrf_token=([^;]*)/); + return match ? decodeURIComponent(match[1]) : ''; +} + +const writeMethods = new Set(['POST', 'PUT', 'DELETE', 'PATCH']); + async function request(path, options = {}) { - const res = await fetch(path, { - ...options, - headers: { - 'Content-Type': 'application/json', - ...options.headers, - }, - }); + const headers = { + 'Content-Type': 'application/json', + ...options.headers, + }; + + if (writeMethods.has(options.method)) { + const token = getCsrfToken(); + if (token) headers['X-CSRF-Token'] = token; + } + + const res = await fetch(path, { ...options, headers }); if (!res.ok) { + if (res.status === 401 && path !== '/api/login') { + const { useAuthStore } = await import('../stores/auth.js'); + const auth = useAuthStore(); + if (auth.ready) { + auth.user = null; + window.location.href = '/login'; + } + } + let message = `Request failed (${res.status})`; let code; try { diff --git a/admin/ui/src/utils/api.test.js b/admin/ui/src/utils/api.test.js index fa7c1e2..d2d9c48 100644 --- a/admin/ui/src/utils/api.test.js +++ b/admin/ui/src/utils/api.test.js @@ -1,9 +1,13 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; -import { get, post, put, del, ApiError } from './api.js'; +import { get, post, put, del, getCsrfToken, ApiError } from './api.js'; describe('api client', () => { beforeEach(() => { vi.restoreAllMocks(); + Object.defineProperty(document, 'cookie', { + writable: true, + value: '', + }); }); function mockFetch(status, body, { json = true } = {}) { @@ -18,6 +22,23 @@ describe('api client', () => { return res; } + describe('getCsrfToken', () => { + it('returns empty string when no cookie', () => { + document.cookie = ''; + expect(getCsrfToken()).toBe(''); + }); + + it('extracts csrf_token from cookies', () => { + document.cookie = 'session=abc123; csrf_token=my-token-value'; + expect(getCsrfToken()).toBe('my-token-value'); + }); + + it('decodes URL-encoded token', () => { + document.cookie = 'csrf_token=token%20with%20spaces'; + expect(getCsrfToken()).toBe('token with spaces'); + }); + }); + describe('get', () => { it('returns parsed JSON on success', async () => { mockFetch(200, [{ id: 1 }]); @@ -28,6 +49,14 @@ describe('api client', () => { }); }); + it('does not send CSRF token on GET', async () => { + document.cookie = 'csrf_token=my-token'; + mockFetch(200, {}); + await get('/api/me'); + const [, opts] = globalThis.fetch.mock.calls[0]; + expect(opts.headers['X-CSRF-Token']).toBeUndefined(); + }); + it('throws ApiError with server message on failure', async () => { mockFetch(404, { error: { @@ -49,6 +78,41 @@ describe('api client', () => { expect(err.message).toBe('Request failed (500)'); expect(err.status).toBe(500); }); + + it('redirects to login on 401 when session is established', async () => { + mockFetch(401, { + error: { code: 'UNAUTHORIZED', message: 'No valid session' }, + }); + const mockStore = { user: { id: 1 }, ready: true }; + vi.doMock('../stores/auth.js', () => ({ + useAuthStore: () => mockStore, + })); + delete window.location; + window.location = { href: '/' }; + const err = await get('/api/me').catch((e) => e); + expect(err).toBeInstanceOf(ApiError); + expect(err.status).toBe(401); + expect(mockStore.user).toBeNull(); + expect(window.location.href).toBe('/login'); + vi.doUnmock('../stores/auth.js'); + }); + + it('does not redirect on 401 during initial session check', async () => { + mockFetch(401, { + error: { code: 'UNAUTHORIZED', message: 'No valid session' }, + }); + const mockStore = { user: null, ready: false }; + vi.doMock('../stores/auth.js', () => ({ + useAuthStore: () => mockStore, + })); + delete window.location; + window.location = { href: '/' }; + const err = await get('/api/me').catch((e) => e); + expect(err).toBeInstanceOf(ApiError); + expect(err.status).toBe(401); + expect(window.location.href).toBe('/'); + vi.doUnmock('../stores/auth.js'); + }); }); describe('post', () => { @@ -69,6 +133,14 @@ describe('api client', () => { }); }); + it('includes CSRF token on POST', async () => { + document.cookie = 'csrf_token=my-csrf-token'; + mockFetch(200, {}); + await post('/api/logout'); + const [, opts] = globalThis.fetch.mock.calls[0]; + expect(opts.headers['X-CSRF-Token']).toBe('my-csrf-token'); + }); + it('throws ApiError with server message on conflict', async () => { mockFetch(409, { error: { @@ -95,6 +167,17 @@ describe('api client', () => { const [, opts] = globalThis.fetch.mock.calls[0]; expect(opts.method).toBe('PUT'); }); + + it('includes CSRF token on PUT', async () => { + document.cookie = 'csrf_token=put-token'; + mockFetch(204, null); + await put('/api/user/password', { + current_password: 'a', + new_password: 'b', + }); + const [, opts] = globalThis.fetch.mock.calls[0]; + expect(opts.headers['X-CSRF-Token']).toBe('put-token'); + }); }); describe('del', () => { @@ -110,7 +193,8 @@ describe('api client', () => { expect(res.json).not.toHaveBeenCalled(); }); - it('sends DELETE method', async () => { + it('sends DELETE method with CSRF token', async () => { + document.cookie = 'csrf_token=del-token'; const res = { ok: true, status: 204, json: vi.fn() }; vi.spyOn(globalThis, 'fetch').mockResolvedValue(res); await del('/api/instances/1'); @@ -118,6 +202,7 @@ describe('api client', () => { const [url, opts] = globalThis.fetch.mock.calls[0]; expect(url).toBe('/api/instances/1'); expect(opts.method).toBe('DELETE'); + expect(opts.headers['X-CSRF-Token']).toBe('del-token'); }); }); }); diff --git a/admin/ui/src/utils/validation.js b/admin/ui/src/utils/validation.js index d2402e1..e6cd435 100644 --- a/admin/ui/src/utils/validation.js +++ b/admin/ui/src/utils/validation.js @@ -4,3 +4,28 @@ export function validateInstanceForm(name, address) { address: address.trim() ? '' : 'Address is required', }; } + +const MIN_PASSWORD_LENGTH = 12; +const MAX_PASSWORD_LENGTH = 72; + +export function validatePasswordChange( + currentPassword, + newPassword, + confirmPassword, +) { + const errors = {}; + if (!currentPassword) errors.currentPassword = 'Current password is required'; + if (!newPassword) { + errors.newPassword = 'New password is required'; + } else if (newPassword.length < MIN_PASSWORD_LENGTH) { + errors.newPassword = `Password must be at least ${MIN_PASSWORD_LENGTH} characters`; + } else if (newPassword.length > MAX_PASSWORD_LENGTH) { + errors.newPassword = `Password must be at most ${MAX_PASSWORD_LENGTH} characters`; + } + if (!confirmPassword) { + errors.confirmPassword = 'Please confirm your new password'; + } else if (newPassword && confirmPassword !== newPassword) { + errors.confirmPassword = 'Passwords do not match'; + } + return errors; +} diff --git a/admin/ui/src/utils/validation.test.js b/admin/ui/src/utils/validation.test.js index 4bfa41f..d58ed42 100644 --- a/admin/ui/src/utils/validation.test.js +++ b/admin/ui/src/utils/validation.test.js @@ -1,5 +1,5 @@ import { describe, it, expect } from 'vitest'; -import { validateInstanceForm } from './validation.js'; +import { validateInstanceForm, validatePasswordChange } from './validation.js'; describe('validateInstanceForm', () => { it('returns no errors for valid inputs', () => { @@ -38,3 +38,56 @@ describe('validateInstanceForm', () => { expect(errors.address).toBe(''); }); }); + +describe('validatePasswordChange', () => { + it('requires all fields', () => { + const errors = validatePasswordChange('', '', ''); + expect(errors.currentPassword).toBe('Current password is required'); + expect(errors.newPassword).toBe('New password is required'); + expect(errors.confirmPassword).toBe('Please confirm your new password'); + }); + + it('rejects passwords shorter than 12 characters', () => { + const errors = validatePasswordChange('currentpass1', 'short', 'short'); + expect(errors.newPassword).toBe('Password must be at least 12 characters'); + }); + + it('rejects passwords longer than 72 characters', () => { + const long = 'a'.repeat(73); + const errors = validatePasswordChange('currentpass1', long, long); + expect(errors.newPassword).toBe('Password must be at most 72 characters'); + }); + + it('rejects mismatched passwords', () => { + const errors = validatePasswordChange( + 'currentpass1', + 'validpassword1', + 'differentpass1', + ); + expect(errors.confirmPassword).toBe('Passwords do not match'); + }); + + it('returns empty object for valid input', () => { + const errors = validatePasswordChange( + 'currentpass1', + 'newpassword12', + 'newpassword12', + ); + expect(Object.keys(errors)).toHaveLength(0); + }); + + it('accepts exactly 12 character password', () => { + const errors = validatePasswordChange( + 'currentpass1', + 'exactly12chr', + 'exactly12chr', + ); + expect(Object.keys(errors)).toHaveLength(0); + }); + + it('accepts exactly 72 character password', () => { + const pw = 'a'.repeat(72); + const errors = validatePasswordChange('currentpass1', pw, pw); + expect(Object.keys(errors)).toHaveLength(0); + }); +}); diff --git a/admin/ui/src/views/LoginView.vue b/admin/ui/src/views/LoginView.vue new file mode 100644 index 0000000..55f7fc9 --- /dev/null +++ b/admin/ui/src/views/LoginView.vue @@ -0,0 +1,126 @@ + + + + + diff --git a/admin/ui/src/views/SettingsView.vue b/admin/ui/src/views/SettingsView.vue new file mode 100644 index 0000000..101dd61 --- /dev/null +++ b/admin/ui/src/views/SettingsView.vue @@ -0,0 +1,168 @@ + + + + + diff --git a/go.work b/go.work index 4fb25d9..f052787 100644 --- a/go.work +++ b/go.work @@ -4,4 +4,5 @@ use ( . ./plugins/contrib ./sdk + ./admin )