diff --git a/README.md b/README.md index 6a0807f..239ba5c 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ services: --providers.docker=true --providers.docker.network=default --experimental.plugins.captcha-protect.modulename=github.com/libops/captcha-protect - --experimental.plugins.captcha-protect.version=v1.13.0 + --experimental.plugins.captcha-protect.version=v1.13.1 volumes: - /var/run/docker.sock:/var/run/docker.sock:z - /CHANGEME/TO/A/HOST/PATH/FOR/STATE/FILE:/tmp/state.json:rw diff --git a/ci/test.go b/ci/test.go index 90ffe7e..f4d2abb 100755 --- a/ci/test.go +++ b/ci/test.go @@ -18,15 +18,18 @@ func main() { _ = os.Remove("./tmp/state.json") fmt.Println("Bringing traefik/nginx online") - runCommand("docker", "compose", "up", "-d") + runCommand("docker", "compose", "down", "--remove-orphans") + runCommand("docker", "compose", "up", "-d", "--force-recreate") waitForService("http://localhost") waitForService("http://localhost/app2") + assertTraefikPluginLogsClean() fmt.Println("Testing Traefik plugin smoke path...") assertProtectedRoute(rootSmokeIP, "http://localhost", "http://localhost/challenge?destination=%2F") assertNoRedirect(rootSmokeIP, "http://localhost/node/123/manifest") assertNoRedirect(rootSmokeIP, "http://localhost/oai/request?foo=bar") assertProtectedRoute(app2SmokeIP, "http://localhost/app2", "http://localhost/challenge?destination=%2Fapp2") + assertTraefikPluginLogsClean() _ = os.Remove("./tmp/state.json") fmt.Println("✓ Traefik plugin smoke test passed") @@ -115,18 +118,58 @@ func httpRequest(ip, url string) (string, error) { return strings.TrimSpace(location.String()), nil } +func assertTraefikPluginLogsClean() { + output, err := commandOutput("docker", "compose", "logs", "--no-color", "traefik") + if err != nil { + slog.Error("Failed to inspect Traefik logs", "err", err, "output", output) + os.Exit(1) + } + + if failure, found := traefikPluginLogFailure(output); found { + slog.Error("Traefik plugin load failure detected", "failure", failure, "logs", output) + os.Exit(1) + } +} + +func traefikPluginLogFailure(output string) (string, bool) { + failures := []string{ + "Plugins are disabled", + "failed to create Yaegi interpreter", + "failed to import plugin code", + "cannot use type", + "cannot define new methods", + } + for _, failure := range failures { + if strings.Contains(output, failure) { + return failure, true + } + } + return "", false +} + +func commandOutput(name string, args ...string) (string, error) { + cmd := exec.Command(name, args...) // #nosec G204 -- CI smoke test invokes fixed docker compose commands. + cmd.Env = testCommandEnv() + output, err := cmd.CombinedOutput() + return string(output), err +} + func runCommand(name string, args ...string) { cmd := exec.Command(name, args...) // #nosec G204 -- CI smoke test invokes fixed docker compose commands. cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - cmd.Env = append(os.Environ(), fmt.Sprintf("RATE_LIMIT=%d", rateLimit)) - - if traefikTag := os.Getenv("TRAEFIK_TAG"); traefikTag != "" { - cmd.Env = append(cmd.Env, fmt.Sprintf("TRAEFIK_TAG=%s", traefikTag)) - } + cmd.Env = testCommandEnv() if err := cmd.Run(); err != nil { slog.Error("Command failed", "err", err) os.Exit(1) } } + +func testCommandEnv() []string { + env := append(os.Environ(), fmt.Sprintf("RATE_LIMIT=%d", rateLimit)) + if traefikTag := os.Getenv("TRAEFIK_TAG"); traefikTag != "" { + env = append(env, fmt.Sprintf("TRAEFIK_TAG=%s", traefikTag)) + } + return env +} diff --git a/ci/test_test.go b/ci/test_test.go new file mode 100644 index 0000000..751abaf --- /dev/null +++ b/ci/test_test.go @@ -0,0 +1,23 @@ +package main + +import "testing" + +func TestTraefikPluginLogFailureDetectsYaegiImportErrors(t *testing.T) { + logs := `traefik-1 | {"level":"error","plugins":["captcha-protect"],"error":"failed to create Yaegi interpreter: failed to import plugin code \"github.com/libops/captcha-protect\": 1:21: import \"github.com/libops/captcha-protect\" error: plugins-local/src/github.com/libops/captcha-protect/main.go:304:23: cannot use type func(string,[]string) bool as type func(context.Context,string,[]string) bool in struct literal","time":"2026-06-24T09:18:16Z","message":"Plugins are disabled because an error has occurred."}` + + failure, found := traefikPluginLogFailure(logs) + if !found { + t.Fatal("expected Traefik plugin load failure to be detected") + } + if failure != "Plugins are disabled" { + t.Fatalf("expected first detected failure %q, got %q", "Plugins are disabled", failure) + } +} + +func TestTraefikPluginLogFailureAllowsCleanLogs(t *testing.T) { + logs := `traefik-1 | {"level":"info","message":"Configuration loaded from flags."}` + + if failure, found := traefikPluginLogFailure(logs); found { + t.Fatalf("did not expect clean logs to fail, got %q", failure) + } +} diff --git a/internal/helper/ip.go b/internal/helper/ip.go index 36b11ad..c96fd58 100644 --- a/internal/helper/ip.go +++ b/internal/helper/ip.go @@ -30,7 +30,11 @@ func ParseCIDR(cidr string) (*net.IPNet, error) { return ipNet, nil } -func IsIpGoodBot(ctx context.Context, clientIP string, goodBots []string) bool { +func IsIpGoodBot(clientIP string, goodBots []string) bool { + return IsIpGoodBotContext(context.Background(), clientIP, goodBots) +} + +func IsIpGoodBotContext(ctx context.Context, clientIP string, goodBots []string) bool { if len(goodBots) == 0 { return false } diff --git a/internal/helper/ip_test.go b/internal/helper/ip_test.go index f7ada1e..c1bcad7 100644 --- a/internal/helper/ip_test.go +++ b/internal/helper/ip_test.go @@ -134,7 +134,7 @@ func TestIsIpGoodBot(t *testing.T) { } t.Run(tc.name, func(t *testing.T) { - result := IsIpGoodBot(context.Background(), tc.clientIP, tc.goodBots) + result := IsIpGoodBotContext(context.Background(), tc.clientIP, tc.goodBots) if result != tc.expected { t.Errorf("IsIpGoodBot(%q) = %v; expected %v", tc.clientIP, result, tc.expected) } diff --git a/internal/helper/uptimerobot.go b/internal/helper/uptimerobot.go index 80b93c6..6ad9b43 100644 --- a/internal/helper/uptimerobot.go +++ b/internal/helper/uptimerobot.go @@ -17,11 +17,25 @@ const maxUptimeRobotIPResponseSize = 1 << 20 var UptimeRobotIPRangeURL = "https://api.uptimerobot.com/meta/ips" // UptimeRobotIPs is a thread-safe set of UptimeRobot IP ranges. -type UptimeRobotIPs = GooglebotIPs +type UptimeRobotIPs struct { + ranges *GooglebotIPs +} // NewUptimeRobotIPs creates an empty UptimeRobot IP range set. func NewUptimeRobotIPs() *UptimeRobotIPs { - return NewGooglebotIPs() + return &UptimeRobotIPs{ + ranges: NewGooglebotIPs(), + } +} + +// Update parses a slice of CIDR strings and replaces the existing IP ranges with the new ones. +func (u *UptimeRobotIPs) Update(cidrs []string, log *slog.Logger) { + u.ranges.Update(cidrs, log) +} + +// Contains checks if the given IP address is within any stored UptimeRobot IP range. +func (u *UptimeRobotIPs) Contains(ip net.IP) bool { + return u.ranges.Contains(ip) } type uptimeRobotIPsJSON struct { diff --git a/main.go b/main.go index e0241e5..41c410c 100644 --- a/main.go +++ b/main.go @@ -39,6 +39,7 @@ const ( DefaultHealthCheckPeriodSeconds = 0 // How often to check captcha provider health DefaultHealthCheckFailureThreshold = 0 // Number of consecutive health check failures before opening circuit goodBotLookupTimeout = 2 * time.Second + maxCaptchaChallengeAge = 5 * time.Minute ) type circuitState int @@ -125,7 +126,9 @@ type CaptchaConfig struct { } type captchaResponse struct { - Success bool `json:"success"` + Success bool `json:"success"` + Hostname string `json:"hostname"` + ChallengeTS string `json:"challenge_ts"` } type challengeData struct { @@ -297,7 +300,7 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n }, rateCache: lru.New(expiration, 1*time.Minute), botCache: lru.New(expiration, 1*time.Hour), - goodBotLookup: helper.IsIpGoodBot, + goodBotLookup: helper.IsIpGoodBotContext, verifiedCache: lru.New(expiration, 1*time.Hour), exemptIps: ips, tmpl: tmpl, @@ -357,33 +360,35 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n if config.EnableUptimeRobotBypass == "true" { log.Info("UptimeRobot bypass enabled") bc.uptimeRobotIPs = helper.NewUptimeRobotIPs() - go bc.uptimeRobotIPCheckLoop(ctx) + go uptimeRobotIPCheckLoop(ctx, log, bc.httpClient, bc.uptimeRobotIPs) } return &bc, nil } -func (bc *CaptchaProtect) uptimeRobotIPCheckLoop(ctx context.Context) { +func uptimeRobotIPCheckLoop(ctx context.Context, log *slog.Logger, httpClient *http.Client, uptimeRobotIPs *helper.UptimeRobotIPs) { ticker := time.NewTicker(24 * time.Hour) defer ticker.Stop() - refresh := func() { - count, err := helper.RefreshUptimeRobotIPs(ctx, bc.log, bc.httpClient, bc.uptimeRobotIPs, helper.UptimeRobotIPRangeURL) - if err != nil { - bc.log.Error("failed to fetch UptimeRobot IPs", "err", err) - return - } - bc.log.Info("Updated UptimeRobot IPs", "count", count) - } - if ctx.Err() != nil { return } - refresh() + count, err := helper.RefreshUptimeRobotIPs(ctx, log, httpClient, uptimeRobotIPs, helper.UptimeRobotIPRangeURL) + if err != nil { + log.Error("failed to fetch UptimeRobot IPs", "err", err) + } else { + log.Info("Updated UptimeRobot IPs", "count", count) + } + for { select { case <-ticker.C: - refresh() + count, err := helper.RefreshUptimeRobotIPs(ctx, log, httpClient, uptimeRobotIPs, helper.UptimeRobotIPRangeURL) + if err != nil { + log.Error("failed to fetch UptimeRobot IPs", "err", err) + continue + } + log.Info("Updated UptimeRobot IPs", "count", count) case <-ctx.Done(): return } @@ -687,6 +692,16 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. var body = url.Values{} body.Add("secret", bc.config.SecretKey) body.Add("response", response) + if activeConfig.key == "cf-turnstile" { + idempotencyKey, err := randomUUID() + if err != nil { + bc.log.Error("unable to create turnstile idempotency key", "err", err) + http.Error(rw, "Internal error", http.StatusInternalServerError) + return http.StatusInternalServerError + } + body.Add("remoteip", ip) + body.Add("idempotency_key", idempotencyKey) + } validationReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, activeConfig.validate, strings.NewReader(body.Encode())) if err != nil { bc.log.Error("unable to create captcha validation request", "url", activeConfig.validate, "err", err) @@ -715,6 +730,28 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. } success = captchaResponse.Success + if success && activeConfig.key == "cf-turnstile" { + expectedHostname := captchaValidationHostname(req) + if captchaResponse.Hostname != expectedHostname { + bc.log.Warn("captcha hostname mismatch", "hostname", captchaResponse.Hostname, "expectedHostname", expectedHostname) + success = false + } else { + challengeTime, err := time.Parse(time.RFC3339Nano, captchaResponse.ChallengeTS) + if err != nil { + bc.log.Warn("invalid captcha challenge timestamp", "challenge_ts", captchaResponse.ChallengeTS, "err", err) + success = false + } else { + age := time.Since(challengeTime) + if age < 0 { + age = 0 + } + if age > maxCaptchaChallengeAge { + bc.log.Warn("stale captcha challenge rejected", "challenge_ts", captchaResponse.ChallengeTS, "age", age) + success = false + } + } + } + } } if success { @@ -731,6 +768,35 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. return http.StatusForbidden } +func captchaValidationHostname(req *http.Request) string { + host := req.Host + if host == "" { + host = req.URL.Host + } + if hostname, _, err := net.SplitHostPort(host); err == nil { + return hostname + } + return host +} + +func randomUUID() (string, error) { + var b [16]byte + if _, err := crand.Read(b[:]); err != nil { + return "", err + } + + b[6] = (b[6] & 0x0f) | 0x40 + b[8] = (b[8] & 0x3f) | 0x80 + + return fmt.Sprintf("%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + b[0], b[1], b[2], b[3], + b[4], b[5], + b[6], b[7], + b[8], b[9], + b[10], b[11], b[12], b[13], b[14], b[15], + ), nil +} + func normalizeDestination(destination string) string { if destination == "" { return "/" diff --git a/main_test.go b/main_test.go index c7fc803..2b69c73 100644 --- a/main_test.go +++ b/main_test.go @@ -3,6 +3,7 @@ package captcha_protect import ( "context" "encoding/json" + "fmt" "log/slog" "net" "net/http" @@ -1210,6 +1211,9 @@ func TestStateBookkeepingErrorBranches(t *testing.T) { } func TestVerifyChallengePage(t *testing.T) { + validChallengeTS := time.Now().Add(-1 * time.Minute).Format(time.RFC3339Nano) + validCaptchaResponse := fmt.Sprintf(`{"success":true,"hostname":"example.com","challenge_ts":%q}`, validChallengeTS) + tests := []struct { name string provider string @@ -1233,7 +1237,7 @@ func TestVerifyChallengePage(t *testing.T) { "cf-turnstile-response": "valid-token", "destination": "%2Fhome", }, - mockResponse: `{"success":true}`, + mockResponse: validCaptchaResponse, expectedStatus: http.StatusFound, shouldSetCache: true, expectedLocation: "/home", @@ -1245,7 +1249,7 @@ func TestVerifyChallengePage(t *testing.T) { "cf-turnstile-response": "valid-token", "destination": "/Chrome%20+%20MariaDB%20", }, - mockResponse: `{"success":true}`, + mockResponse: validCaptchaResponse, expectedStatus: http.StatusFound, shouldSetCache: true, expectedLocation: "/Chrome%20+%20MariaDB%20", @@ -1278,7 +1282,7 @@ func TestVerifyChallengePage(t *testing.T) { "cf-turnstile-response": "valid-token", "destination": "%ZZ", }, - mockResponse: `{"success":true}`, + mockResponse: validCaptchaResponse, expectedStatus: http.StatusFound, shouldSetCache: true, expectedLocation: "/", @@ -1334,6 +1338,215 @@ func TestVerifyChallengePage(t *testing.T) { } } +func TestVerifyChallengePageRejectsInvalidSiteverifyMetadata(t *testing.T) { + tests := []struct { + name string + mockResponse string + }{ + { + name: "success false", + mockResponse: fmt.Sprintf(`{"success":false,"hostname":"example.com","challenge_ts":%q}`, time.Now().Format(time.RFC3339Nano)), + }, + { + name: "hostname mismatch", + mockResponse: fmt.Sprintf(`{"success":true,"hostname":"evil.example","challenge_ts":%q}`, time.Now().Format(time.RFC3339Nano)), + }, + { + name: "stale challenge", + mockResponse: fmt.Sprintf(`{"success":true,"hostname":"example.com","challenge_ts":%q}`, time.Now().Add(-6*time.Minute).Format(time.RFC3339Nano)), + }, + { + name: "missing challenge timestamp", + mockResponse: `{"success":true,"hostname":"example.com"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(tt.mockResponse)) + })) + defer mockServer.Close() + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.CaptchaProvider = "turnstile" + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + bc.captchaConfig.validate = mockServer.URL + + req := httptest.NewRequest(http.MethodPost, "http://example.com/challenge", nil) + req.Form = make(map[string][]string) + req.Form.Set("cf-turnstile-response", "token-"+tt.name) + + rr := httptest.NewRecorder() + status := bc.verifyChallengePage(rr, req, "1.2.3.4") + + if status != http.StatusForbidden { + t.Fatalf("expected status %d, got %d", http.StatusForbidden, status) + } + if _, found := bc.verifiedCache.Get("1.2.3.4"); found { + t.Fatal("did not expect invalid siteverify response to set verified cache") + } + }) + } +} + +func TestVerifyChallengePageAllowsNonTurnstileWithoutMetadata(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true}`)) + })) + defer mockServer.Close() + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.CaptchaProvider = "hcaptcha" + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + bc.captchaConfig.validate = mockServer.URL + + req := httptest.NewRequest(http.MethodPost, "http://example.com/challenge", nil) + req.Form = make(map[string][]string) + req.Form.Set("h-captcha-response", "valid-token") + + rr := httptest.NewRecorder() + status := bc.verifyChallengePage(rr, req, "1.2.3.4") + + if status != http.StatusFound { + t.Fatalf("expected status %d, got %d", http.StatusFound, status) + } + if _, found := bc.verifiedCache.Get("1.2.3.4"); !found { + t.Fatal("expected non-turnstile success response to set verified cache") + } +} + +func TestVerifyChallengePageMatchesTurnstileHostnameWithoutRequestPort(t *testing.T) { + validChallengeTS := time.Now().Format(time.RFC3339Nano) + mockResponse := fmt.Sprintf(`{"success":true,"hostname":"example.com","challenge_ts":%q}`, validChallengeTS) + + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(mockResponse)) + })) + defer mockServer.Close() + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + config.CaptchaProvider = "turnstile" + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + bc.captchaConfig.validate = mockServer.URL + + req := httptest.NewRequest(http.MethodPost, "http://example.com:8443/challenge", nil) + req.Form = make(map[string][]string) + req.Form.Set("cf-turnstile-response", "valid-token") + + rr := httptest.NewRecorder() + status := bc.verifyChallengePage(rr, req, "1.2.3.4") + + if status != http.StatusFound { + t.Fatalf("expected status %d, got %d", http.StatusFound, status) + } +} + +func TestVerifyChallengePageSendsTurnstileAdvancedValidationFields(t *testing.T) { + validChallengeTS := time.Now().Format(time.RFC3339Nano) + mockResponse := fmt.Sprintf(`{"success":true,"hostname":"example.com","challenge_ts":%q}`, validChallengeTS) + var siteverifyForm url.Values + + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm failed: %v", err) + } + siteverifyForm = r.PostForm + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(mockResponse)) + })) + defer mockServer.Close() + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test-secret" + config.ProtectRoutes = []string{"/"} + config.CaptchaProvider = "turnstile" + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + bc.captchaConfig.validate = mockServer.URL + + req := httptest.NewRequest(http.MethodPost, "http://example.com/challenge", nil) + req.Form = make(map[string][]string) + req.Form.Set("cf-turnstile-response", "valid-token") + + status := bc.verifyChallengePage(httptest.NewRecorder(), req, "1.2.3.4") + if status != http.StatusFound { + t.Fatalf("expected status %d, got %d", http.StatusFound, status) + } + + if got := siteverifyForm.Get("secret"); got != "test-secret" { + t.Fatalf("expected secret %q, got %q", "test-secret", got) + } + if got := siteverifyForm.Get("response"); got != "valid-token" { + t.Fatalf("expected response %q, got %q", "valid-token", got) + } + if got := siteverifyForm.Get("remoteip"); got != "1.2.3.4" { + t.Fatalf("expected remoteip %q, got %q", "1.2.3.4", got) + } + + idempotencyKey := siteverifyForm.Get("idempotency_key") + if idempotencyKey == "" { + t.Fatal("expected idempotency_key to be sent") + } + if !regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`).MatchString(idempotencyKey) { + t.Fatalf("expected idempotency_key to be UUID v4, got %q", idempotencyKey) + } +} + +func TestVerifyChallengePageDoesNotSendTurnstileFieldsToOtherProviders(t *testing.T) { + var siteverifyForm url.Values + + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm failed: %v", err) + } + siteverifyForm = r.PostForm + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true}`)) + })) + defer mockServer.Close() + + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test-secret" + config.ProtectRoutes = []string{"/"} + config.CaptchaProvider = "hcaptcha" + + bc, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + bc.captchaConfig.validate = mockServer.URL + + req := httptest.NewRequest(http.MethodPost, "http://example.com/challenge", nil) + req.Form = make(map[string][]string) + req.Form.Set("h-captcha-response", "valid-token") + + status := bc.verifyChallengePage(httptest.NewRecorder(), req, "1.2.3.4") + if status != http.StatusFound { + t.Fatalf("expected status %d, got %d", http.StatusFound, status) + } + if got := siteverifyForm.Get("remoteip"); got != "" { + t.Fatalf("expected remoteip to be omitted for hcaptcha, got %q", got) + } + if got := siteverifyForm.Get("idempotency_key"); got != "" { + t.Fatalf("expected idempotency_key to be omitted for hcaptcha, got %q", got) + } +} + func TestVerifyChallengePageHTTPError(t *testing.T) { // Test HTTP client error config := CreateConfig() @@ -2111,7 +2324,7 @@ func TestUptimeRobotIPCheckLoopInitialFetch(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) done := make(chan struct{}) go func() { - bc.uptimeRobotIPCheckLoop(ctx) + uptimeRobotIPCheckLoop(ctx, bc.log, bc.httpClient, bc.uptimeRobotIPs) close(done) }()