diff --git a/.gitignore b/.gitignore index 0a20beb..5aa0e7b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ __pycache__/ .venv/ node_modules/ apps/openant-cli/bin/ -docs/ +libs/openant-core/parsers/go/go_parser/go_parser +# docs/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..bb51bb2 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,42 @@ +# Changelog + +All notable changes to OpenAnt are documented in this file. + +## [Unreleased] + +This release syncs a large body of work from internal development. Highlights: + +### Added + +- **Parallelization** across all pipeline stages: + - Stage 1 analysis (Detect), Stage 2 verification, Enhance, and Dynamic Test now run units concurrently via worker pools. + - Thread-safe `TokenTracker` and `ProgressReporter` for correct aggregate metrics under parallel execution. + - Shared HTTP client and a token-bucket `RateLimiter` (`libs/openant-core/utilities/rate_limiter.py`) to stay within Anthropic API rate limits. +- **Checkpoint / resume system** (`libs/openant-core/core/checkpoint.py`): every phase persists per-unit progress so interrupted scans can resume without re-running completed work. +- **Zig parser** (`libs/openant-core/parsers/zig/`): repository scanner, unit generator, and test pipeline. +- **HTML report improvements** (`apps/openant-cli/internal/report/`): + - Two themes: dark (`overview.gohtml`) and Knostic-branded light (`report-reskin.gohtml`). + - Report header shows repo name, commit SHA, language, total scan duration (formatted `Xd Xh Xm Xs`), and cost. + - Findings are numbered (`#N`), have anchor IDs, and are grouped into collapsible sections by verdict (vulnerable / bypassable open by default; inconclusive / protected / safe closed). + - Within each verdict group, findings are sub-sorted by dynamic test outcome (CONFIRMED first, NOT_REPRODUCED last). + - File paths link directly to the repo at the exact commit. 
+ - Pipeline Costs & Timing section with per-step breakdown and a Totals row. + - Executive Summary links to findings via `#N` references; priority labels (Critical / High / Medium) replace fabricated timeframes. +- **Dynamic testing** hardening: structured result classification (CONFIRMED / NOT_REPRODUCED / BLOCKED / INCONCLUSIVE / ERROR), Docker template updates, retry logic, and checkpoint-aware resume. +- `openant build-output` and `openant dynamic-test` subcommands with prompt-before-skip UX. + +### Changed + +- Finding verifier (`utilities/finding_verifier.py`) hardened with better error handling and agentic tool integration. +- Context enhancer (`utilities/context_enhancer.py`) overhauled for parallel, agentic enhancement. +- Report data pipeline rewritten: Python computes a `ReportData` JSON blob; Go renders the HTML template. +- Cost tracking reworked to report per-unit costs in progress output and aggregate correctly across parallel workers. + +### Fixed + +- Cost tracking no longer shows negative or incorrect totals under parallel execution. +- `merge_dynamic_results` no longer contaminates stdout, unblocking clean JSON output. +- HTML report entities (`>`, `<`) render correctly (previously double-escaped). +- "Max iterations reached" verifier timeouts now mark findings as `inconclusive` rather than leaving a stale verdict. +- Checkpoint resume behavior unified across phases. +- Stdin race during interactive signal forwarding. 
diff --git a/apps/openant-cli/cmd/analyze.go b/apps/openant-cli/cmd/analyze.go index 4ba3e18..986213b 100644 --- a/apps/openant-cli/cmd/analyze.go +++ b/apps/openant-cli/cmd/analyze.go @@ -4,6 +4,7 @@ import ( "fmt" "os" + "github.com/knostic/open-ant-cli/internal/checkpoint" "github.com/knostic/open-ant-cli/internal/output" "github.com/knostic/open-ant-cli/internal/python" "github.com/spf13/cobra" @@ -31,6 +32,9 @@ var ( analyzeExploitOnly bool analyzeLimit int analyzeModel string + analyzeWorkers int + analyzeCheckpoint string + analyzeBackoff int ) func init() { @@ -42,6 +46,9 @@ func init() { analyzeCmd.Flags().BoolVar(&analyzeExploitOnly, "exploitable-only", false, "Only analyze units classified as exploitable by enhancer") analyzeCmd.Flags().IntVar(&analyzeLimit, "limit", 0, "Max units to analyze (0 = no limit)") analyzeCmd.Flags().StringVar(&analyzeModel, "model", "opus", "Model: opus or sonnet") + analyzeCmd.Flags().IntVar(&analyzeWorkers, "workers", 8, "Number of parallel workers for LLM steps (default: 8)") + analyzeCmd.Flags().StringVar(&analyzeCheckpoint, "checkpoint", "", "Path to checkpoint directory for save/resume") + analyzeCmd.Flags().IntVar(&analyzeBackoff, "backoff", 30, "Seconds to wait when rate-limited (default: 30)") } func runAnalyze(cmd *cobra.Command, args []string) { @@ -74,6 +81,17 @@ func runAnalyze(cmd *cobra.Command, args []string) { os.Exit(2) } + // Auto-detect checkpoints from a previous interrupted run + if analyzeCheckpoint == "" && ctx != nil { + if cpInfo := checkpoint.DetectViaPython(rt.Path, ctx.ScanDir, "analyze"); cpInfo != nil { + if checkpoint.PromptResume(cpInfo, "analyze", quiet) { + analyzeCheckpoint = cpInfo.Dir + } else { + _ = checkpoint.Clean(cpInfo.Dir) + } + } + } + pyArgs := []string{"analyze", datasetPath, "--output", analyzeOutput} if analyzeVerify { pyArgs = append(pyArgs, "--verify") @@ -96,6 +114,15 @@ func runAnalyze(cmd *cobra.Command, args []string) { if analyzeModel != "opus" { pyArgs = append(pyArgs, 
"--model", analyzeModel) } + if analyzeWorkers != 8 { + pyArgs = append(pyArgs, "--workers", fmt.Sprintf("%d", analyzeWorkers)) + } + if analyzeCheckpoint != "" { + pyArgs = append(pyArgs, "--checkpoint", analyzeCheckpoint) + } + if analyzeBackoff != 30 { + pyArgs = append(pyArgs, "--backoff", fmt.Sprintf("%d", analyzeBackoff)) + } result, err := python.Invoke(rt.Path, pyArgs, "", quiet, requireAPIKey()) if err != nil { @@ -103,7 +130,9 @@ func runAnalyze(cmd *cobra.Command, args []string) { os.Exit(2) } - if jsonOutput { + if result.Envelope.Status == "interrupted" { + os.Exit(130) + } else if jsonOutput { output.PrintJSON(result.Envelope) } else if result.Envelope.Status == "success" { if data, ok := result.Envelope.Data.(map[string]any); ok { diff --git a/apps/openant-cli/cmd/dynamictest.go b/apps/openant-cli/cmd/dynamictest.go index c46ab82..5ff99e4 100644 --- a/apps/openant-cli/cmd/dynamictest.go +++ b/apps/openant-cli/cmd/dynamictest.go @@ -4,6 +4,7 @@ import ( "fmt" "os" + "github.com/knostic/open-ant-cli/internal/checkpoint" "github.com/knostic/open-ant-cli/internal/output" "github.com/knostic/open-ant-cli/internal/python" "github.com/spf13/cobra" @@ -41,6 +42,12 @@ func runDynamicTest(cmd *cobra.Command, args []string) { os.Exit(2) } + // Check pipeline_output.json exists before launching Python + if _, err := os.Stat(pipelineOutputPath); err != nil { + output.PrintError("pipeline_output.json not found. 
Run 'openant build-output' first.") + os.Exit(2) + } + // Apply project defaults if ctx != nil { if dynamicTestOutput == "" { @@ -54,6 +61,17 @@ func runDynamicTest(cmd *cobra.Command, args []string) { os.Exit(2) } + // Auto-detect checkpoints from a previous interrupted run + if ctx != nil { + if cpInfo := checkpoint.DetectViaPython(rt.Path, ctx.ScanDir, "dynamic_test"); cpInfo != nil { + if checkpoint.PromptResume(cpInfo, "dynamic-test", quiet) { + // Resume — Python auto-detects checkpoint dir in output dir + } else { + _ = checkpoint.Clean(cpInfo.Dir) + } + } + } + pyArgs := []string{"dynamic-test", pipelineOutputPath} if dynamicTestOutput != "" { pyArgs = append(pyArgs, "--output", dynamicTestOutput) @@ -68,7 +86,9 @@ func runDynamicTest(cmd *cobra.Command, args []string) { os.Exit(2) } - if jsonOutput { + if result.Envelope.Status == "interrupted" { + os.Exit(130) + } else if jsonOutput { output.PrintJSON(result.Envelope) } else if result.Envelope.Status == "success" { if data, ok := result.Envelope.Data.(map[string]any); ok { diff --git a/apps/openant-cli/cmd/enhance.go b/apps/openant-cli/cmd/enhance.go index d316c9b..5381213 100644 --- a/apps/openant-cli/cmd/enhance.go +++ b/apps/openant-cli/cmd/enhance.go @@ -1,8 +1,10 @@ package cmd import ( + "fmt" "os" + "github.com/knostic/open-ant-cli/internal/checkpoint" "github.com/knostic/open-ant-cli/internal/output" "github.com/knostic/open-ant-cli/internal/python" "github.com/spf13/cobra" @@ -28,6 +30,8 @@ var ( enhanceRepoPath string enhanceMode string enhanceCheckpoint string + enhanceWorkers int + enhanceBackoff int ) func init() { @@ -36,6 +40,8 @@ func init() { enhanceCmd.Flags().StringVar(&enhanceRepoPath, "repo-path", "", "Path to the repository (required for agentic mode)") enhanceCmd.Flags().StringVar(&enhanceMode, "mode", "agentic", "Enhancement mode: agentic (thorough) or single-shot (fast)") enhanceCmd.Flags().StringVar(&enhanceCheckpoint, "checkpoint", "", "Path to save/resume checkpoint (agentic 
mode)") + enhanceCmd.Flags().IntVar(&enhanceWorkers, "workers", 8, "Number of parallel workers for LLM steps (default: 8)") + enhanceCmd.Flags().IntVar(&enhanceBackoff, "backoff", 30, "Seconds to wait when rate-limited (default: 30)") } func runEnhance(cmd *cobra.Command, args []string) { @@ -64,6 +70,18 @@ func runEnhance(cmd *cobra.Command, args []string) { os.Exit(2) } + // Auto-detect checkpoints from a previous interrupted run + if enhanceCheckpoint == "" && ctx != nil { + if cpInfo := checkpoint.DetectViaPython(rt.Path, ctx.ScanDir, "enhance"); cpInfo != nil { + if checkpoint.PromptResume(cpInfo, "enhance", quiet) { + enhanceCheckpoint = cpInfo.Dir + } else { + // User chose fresh start — remove old checkpoints + _ = checkpoint.Clean(cpInfo.Dir) + } + } + } + pyArgs := []string{"enhance", datasetPath} if enhanceOutput != "" { pyArgs = append(pyArgs, "--output", enhanceOutput) @@ -80,6 +98,12 @@ func runEnhance(cmd *cobra.Command, args []string) { if enhanceCheckpoint != "" { pyArgs = append(pyArgs, "--checkpoint", enhanceCheckpoint) } + if enhanceWorkers != 8 { + pyArgs = append(pyArgs, "--workers", fmt.Sprintf("%d", enhanceWorkers)) + } + if enhanceBackoff != 30 { + pyArgs = append(pyArgs, "--backoff", fmt.Sprintf("%d", enhanceBackoff)) + } result, err := python.Invoke(rt.Path, pyArgs, "", quiet, requireAPIKey()) if err != nil { @@ -87,7 +111,9 @@ func runEnhance(cmd *cobra.Command, args []string) { os.Exit(2) } - if jsonOutput { + if result.Envelope.Status == "interrupted" { + os.Exit(130) + } else if jsonOutput { output.PrintJSON(result.Envelope) } else if result.Envelope.Status == "success" { if data, ok := result.Envelope.Data.(map[string]any); ok { diff --git a/apps/openant-cli/cmd/report.go b/apps/openant-cli/cmd/report.go index 2d5b872..d2b34b7 100644 --- a/apps/openant-cli/cmd/report.go +++ b/apps/openant-cli/cmd/report.go @@ -1,11 +1,18 @@ package cmd import ( + "encoding/json" + "fmt" "os" "path/filepath" + "strings" + 
"github.com/charmbracelet/huh" + "github.com/fatih/color" "github.com/knostic/open-ant-cli/internal/output" "github.com/knostic/open-ant-cli/internal/python" + "github.com/knostic/open-ant-cli/internal/report" + isatty "github.com/mattn/go-isatty" "github.com/spf13/cobra" ) @@ -15,17 +22,18 @@ var reportCmd = &cobra.Command{ Long: `Report generates reports from analysis results or pipeline output. Formats: - html HTML report with interactive findings (default) - csv CSV export of all findings - summary Markdown summary report (requires --pipeline-output) - disclosure Per-vulnerability disclosure documents (requires --pipeline-output) + disclosure Per-vulnerability disclosure documents (default, uses LLM) + summary Narrative security overview (uses LLM) + html Interactive HTML report with charts and filters + csv Spreadsheet export of all findings If no results path is given, the active project's results_verified.json is used. +Python owns default output paths — you only need -o to override. 
Examples: - openant report results.json -o report/ --dataset dataset.json - openant report --pipeline-output pipeline_output.json -f summary -o report/SUMMARY.md - openant report -f disclosure -o report/disclosures/`, + openant report -p myproject + openant report -p myproject -f summary + openant report results.json -f html -o report.html --dataset dataset.json`, Args: cobra.MaximumNArgs(1), Run: runReport, } @@ -36,14 +44,21 @@ var ( reportFormat string reportPipelineOutput string reportRepoName string + reportExtraDest string ) func init() { - reportCmd.Flags().StringVarP(&reportOutput, "output", "o", "", "Output path") + reportCmd.Flags().StringVarP(&reportOutput, "output", "o", "", "Output path (default: derived from format)") reportCmd.Flags().StringVar(&reportDataset, "dataset", "", "Path to dataset JSON (for html/csv)") - reportCmd.Flags().StringVarP(&reportFormat, "format", "f", "html", "Report format: html, csv, summary, disclosure") + reportCmd.Flags().StringVarP(&reportFormat, "format", "f", "", "Report format: disclosure, summary, html, csv") reportCmd.Flags().StringVar(&reportPipelineOutput, "pipeline-output", "", "Path to pipeline_output.json (for summary/disclosure)") reportCmd.Flags().StringVar(&reportRepoName, "repo-name", "", "Repository name (used when auto-building pipeline_output)") + reportCmd.Flags().StringVar(&reportExtraDest, "copy-to", "", "Copy reports to an additional location") +} + +// isInteractive returns true if stdin is a terminal and we're not in quiet mode. 
+func isInteractive() bool { + return !quiet && isatty.IsTerminal(os.Stdin.Fd()) } func runReport(cmd *cobra.Command, args []string) { @@ -53,13 +68,13 @@ func runReport(cmd *cobra.Command, args []string) { os.Exit(2) } - // Apply project defaults + // Apply project defaults for pipeline-output, repo-name, dataset if ctx != nil { - if reportOutput == "" { - reportOutput = filepath.Join(ctx.ScanDir, "report") - } if reportPipelineOutput == "" { - reportPipelineOutput = ctx.scanFile("pipeline_output.json") + candidate := ctx.scanFile("pipeline_output.json") + if _, err := os.Stat(candidate); err == nil { + reportPipelineOutput = candidate + } } if reportRepoName == "" { reportRepoName = ctx.Project.Name @@ -68,20 +83,279 @@ func runReport(cmd *cobra.Command, args []string) { reportDataset = ctx.scanFile("dataset_enhanced.json") } } - if reportOutput == "" { - output.PrintError("--output is required (or use openant init to set up a project)") + + // Check prerequisite steps before generating reports + if ctx != nil { + yellow := color.New(color.FgYellow, color.Bold) + + // Check if build-output has been run (needed for summary/disclosure and dynamic-test) + poPath := ctx.scanFile("pipeline_output.json") + if _, err := os.Stat(poPath); err != nil { + if isInteractive() { + yellow.Fprintln(os.Stderr, "pipeline_output.json not found — 'openant build-output' hasn't been run yet.") + fmt.Fprint(os.Stderr, "Continue without it? 
[Y/n] ") + var answer string + fmt.Scanln(&answer) + answer = strings.TrimSpace(strings.ToLower(answer)) + if answer == "n" || answer == "no" { + fmt.Fprintln(os.Stderr, "Run 'openant build-output' first, then re-run report.") + os.Exit(0) + } + } else { + yellow.Fprintln(os.Stderr, "Warning: pipeline_output.json not found — summary/disclosure reports will not be available.") + } + } + + // Check if dynamic tests have been run + dtPath := ctx.scanFile("dynamic_test_results.json") + if _, err := os.Stat(dtPath); err != nil { + if isInteractive() { + yellow.Fprintln(os.Stderr, "Dynamic tests haven't been run yet.") + fmt.Fprint(os.Stderr, "Continue without dynamic test results? [Y/n] ") + var answer string + fmt.Scanln(&answer) + answer = strings.TrimSpace(strings.ToLower(answer)) + if answer == "n" || answer == "no" { + fmt.Fprintln(os.Stderr, "Run 'openant dynamic-test' first, then re-run report.") + os.Exit(0) + } + } else { + yellow.Fprintln(os.Stderr, "Warning: dynamic tests haven't been run — reports will not include dynamic test results.") + } + } + } + + // Determine formats to generate + formatFlagSet := cmd.Flags().Changed("format") + formats := []string{} + + if formatFlagSet { + // User explicitly provided -f, use it directly + formats = []string{reportFormat} + } else if isInteractive() { + // Interactive: show format picker + selected, err := promptFormats() + if err != nil { + output.PrintError(err.Error()) + os.Exit(2) + } + if len(selected) == 0 { + output.PrintError("No formats selected") + os.Exit(2) + } + formats = selected + } else { + // Non-interactive, no flag: use default + formats = []string{"disclosure"} + } + + // Check if any selected format requires pipeline_output.json + needsPipelineOutput := false + for _, f := range formats { + if f == "summary" || f == "disclosure" { + needsPipelineOutput = true + break + } + } + if needsPipelineOutput && reportPipelineOutput == "" { + output.PrintError("It seems like you haven't run 'openant 
build-output'. You must run it first.") os.Exit(2) } + // Prompt for extra output location (interactive only, unless --copy-to given) + scanDir := "" + if ctx != nil { + scanDir = ctx.ScanDir + } + extraDest := reportExtraDest + if extraDest == "" && !formatFlagSet && isInteractive() { + extraDest, err = promptExtraLocation(scanDir) + if err != nil { + output.PrintError(err.Error()) + os.Exit(2) + } + } + + // Ensure Python runtime rt, err := ensurePython() if err != nil { output.PrintError(err.Error()) os.Exit(2) } - pyArgs := []string{"report", resultsPath, "--output", reportOutput} - if reportFormat != "html" { - pyArgs = append(pyArgs, "--format", reportFormat) + // Run each selected format + exitCode := 0 + var allResults []map[string]any + + for _, fmt := range formats { + if fmt == "html" { + // HTML reports use the Go renderer + outputPath := reportOutput + if outputPath == "" { + // Derive default: same dir as results, in final-reports/ + resultsDir := filepath.Dir(resultsPath) + outputPath = filepath.Join(resultsDir, "final-reports", "report.html") + } + + reskinPath := filepath.Join(filepath.Dir(outputPath), "report-reskin.html") + if err := runHTMLReport(rt, resultsPath, outputPath); err != nil { + output.PrintError("html: " + err.Error()) + exitCode = 2 + continue + } + + data := map[string]any{ + "output_path": outputPath, + "reskin_path": reskinPath, + "format": "html", + } + if !jsonOutput { + output.PrintReportSummary(data) + } + allResults = append(allResults, data) + } else { + // Other formats delegate to Python + pyArgs := buildReportArgs(resultsPath, fmt) + + result, err := python.Invoke(rt.Path, pyArgs, "", quiet, resolvedAPIKey()) + if err != nil { + output.PrintError(fmt + ": " + err.Error()) + exitCode = 2 + continue + } + + if result.ExitCode != 0 { + exitCode = result.ExitCode + } + + if jsonOutput { + output.PrintJSON(result.Envelope) + } else if result.Envelope.Status == "success" { + if data, ok := 
result.Envelope.Data.(map[string]any); ok { + output.PrintReportSummary(data) + allResults = append(allResults, data) + } + } else { + output.PrintErrors(result.Envelope.Errors) + } + } + } + + // Copy to extra location if requested + if extraDest != "" && len(allResults) > 0 { + copyReportsToExtra(allResults, extraDest) + } + + os.Exit(exitCode) +} + +// promptFormats shows an interactive multi-select with spacebar toggle. +func promptFormats() ([]string, error) { + var selected []string + + form := huh.NewForm( + huh.NewGroup( + huh.NewMultiSelect[string](). + Title("Select report formats"). + Options( + huh.NewOption("Disclosure docs — per-vulnerability reports for responsible disclosure ($)", "disclosure").Selected(true), + huh.NewOption("Summary — narrative security overview written by AI ($)", "summary"), + huh.NewOption("HTML — interactive report with charts and filters", "html"), + huh.NewOption("CSV — spreadsheet export of all findings", "csv"), + ). + Value(&selected), + ), + ) + + err := form.Run() + if err != nil { + return nil, err + } + + return selected, nil +} + +// promptExtraLocation asks the user for an optional extra output directory. +func promptExtraLocation(scanDir string) (string, error) { + var dest string + + title := "Copy reports to additional location?" + if scanDir != "" { + title = fmt.Sprintf("Reports will be saved to %s/final-reports/\nCopy to additional location?", scanDir) + } + + form := huh.NewForm( + huh.NewGroup( + huh.NewInput(). + Title(title). + Prompt("> "). + Placeholder("enter to skip"). + Value(&dest), + ), + ) + + err := form.Run() + if err != nil { + return "", err + } + + return strings.TrimSpace(dest), nil +} + +// runHTMLReport generates an HTML report using the Go renderer. +// It calls Python's report-data subcommand to get pre-computed data, +// then renders the HTML template. +func runHTMLReport(rt *python.RuntimeInfo, resultsPath string, outputPath string) error { + // 1. 
Call Python report-data to get pre-computed JSON + pyArgs := []string{"report-data", resultsPath} + if reportDataset != "" { + pyArgs = append(pyArgs, "--dataset", reportDataset) + } + + result, err := python.Invoke(rt.Path, pyArgs, "", quiet, resolvedAPIKey()) + if err != nil { + return fmt.Errorf("report-data failed: %w", err) + } + if result.Envelope.Status != "success" { + msg := "report-data returned error" + if len(result.Envelope.Errors) > 0 { + msg = result.Envelope.Errors[0] + } + return fmt.Errorf("%s", msg) + } + + // 2. Marshal data back to JSON, then unmarshal into our struct + dataBytes, err := json.Marshal(result.Envelope.Data) + if err != nil { + return fmt.Errorf("failed to marshal report data: %w", err) + } + + var reportData report.ReportData + if err := json.Unmarshal(dataBytes, &reportData); err != nil { + return fmt.Errorf("failed to parse report data: %w", err) + } + + // 3. Render HTML (original dark theme) + if err := report.GenerateOverview(reportData, outputPath); err != nil { + return fmt.Errorf("failed to render HTML: %w", err) + } + + // 4. Render reskin HTML (Knostic light theme) alongside the original + reskinPath := filepath.Join(filepath.Dir(outputPath), "report-reskin.html") + if err := report.GenerateReskin(reportData, reskinPath); err != nil { + return fmt.Errorf("failed to render reskin HTML: %w", err) + } + + return nil +} + +// buildReportArgs constructs the Python CLI arguments for a single format. 
+func buildReportArgs(resultsPath string, format string) []string { + pyArgs := []string{"report", resultsPath, "--format", format} + + // Only pass --output if user explicitly set it + if reportOutput != "" { + pyArgs = append(pyArgs, "--output", reportOutput) } if reportDataset != "" { pyArgs = append(pyArgs, "--dataset", reportDataset) @@ -93,21 +367,78 @@ func runReport(cmd *cobra.Command, args []string) { pyArgs = append(pyArgs, "--repo-name", reportRepoName) } - result, err := python.Invoke(rt.Path, pyArgs, "", quiet, resolvedAPIKey()) - if err != nil { - output.PrintError(err.Error()) - os.Exit(2) + return pyArgs +} + +// copyReportsToExtra copies generated report files/dirs to the extra destination. +func copyReportsToExtra(results []map[string]any, dest string) { + cyan := color.New(color.FgCyan) + + // Ensure dest directory exists + if err := os.MkdirAll(dest, 0o755); err != nil { + output.PrintError("Failed to create " + dest + ": " + err.Error()) + return } - if jsonOutput { - output.PrintJSON(result.Envelope) - } else if result.Envelope.Status == "success" { - if data, ok := result.Envelope.Data.(map[string]any); ok { - output.PrintReportSummary(data) + for _, data := range results { + srcPath, ok := data["output_path"].(string) + if !ok || srcPath == "" { + continue } - } else { - output.PrintErrors(result.Envelope.Errors) + + info, err := os.Stat(srcPath) + if err != nil { + output.PrintError("Cannot access " + srcPath + ": " + err.Error()) + continue + } + + if info.IsDir() { + // Copy directory recursively + destDir := filepath.Join(dest, filepath.Base(srcPath)) + if err := copyDir(srcPath, destDir); err != nil { + output.PrintError("Failed to copy " + srcPath + ": " + err.Error()) + continue + } + cyan.Printf(" Copied: ") + fmt.Println(destDir) + } else { + // Copy single file + destFile := filepath.Join(dest, filepath.Base(srcPath)) + if err := copyFile(srcPath, destFile); err != nil { + output.PrintError("Failed to copy " + srcPath + ": " + 
err.Error()) + continue + } + cyan.Printf(" Copied: ") + fmt.Println(destFile) + } + } +} + +// copyFile copies a single file from src to dst. +func copyFile(src, dst string) error { + data, err := os.ReadFile(src) + if err != nil { + return err } + return os.WriteFile(dst, data, 0o644) +} + +// copyDir recursively copies a directory. +func copyDir(src, dst string) error { + return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } - os.Exit(result.ExitCode) + relPath, err := filepath.Rel(src, path) + if err != nil { + return err + } + destPath := filepath.Join(dst, relPath) + + if info.IsDir() { + return os.MkdirAll(destPath, 0o755) + } + return copyFile(path, destPath) + }) } diff --git a/apps/openant-cli/cmd/scan.go b/apps/openant-cli/cmd/scan.go index e122d4a..39c1e57 100644 --- a/apps/openant-cli/cmd/scan.go +++ b/apps/openant-cli/cmd/scan.go @@ -4,6 +4,7 @@ import ( "fmt" "os" + "github.com/knostic/open-ant-cli/internal/checkpoint" "github.com/knostic/open-ant-cli/internal/output" "github.com/knostic/open-ant-cli/internal/python" "github.com/spf13/cobra" @@ -38,6 +39,8 @@ var ( scanDynamicTest bool scanLimit int scanModel string + scanWorkers int + scanBackoff int ) func init() { @@ -52,6 +55,8 @@ func init() { scanCmd.Flags().BoolVar(&scanDynamicTest, "dynamic-test", false, "Enable Docker-isolated dynamic testing (off by default)") scanCmd.Flags().IntVar(&scanLimit, "limit", 0, "Max units to analyze (0 = no limit)") scanCmd.Flags().StringVar(&scanModel, "model", "opus", "Model: opus or sonnet") + scanCmd.Flags().IntVar(&scanWorkers, "workers", 8, "Number of parallel workers for LLM steps (default: 8)") + scanCmd.Flags().IntVar(&scanBackoff, "backoff", 30, "Seconds to wait when rate-limited (default: 30)") } func runScan(cmd *cobra.Command, args []string) { @@ -80,6 +85,20 @@ func runScan(cmd *cobra.Command, args []string) { os.Exit(2) } + // Check for interrupted runs in the scan directory + if ctx != 
nil && scanOutput != "" { + steps := []string{"enhance", "analyze", "verify"} + for _, step := range steps { + if cpInfo := checkpoint.DetectViaPython(rt.Path, scanOutput, step); cpInfo != nil { + if !checkpoint.PromptResume(cpInfo, step, quiet) { + _ = checkpoint.Clean(cpInfo.Dir) + } + // Note: Python side auto-detects and uses the checkpoint dir, + // so we only need to clean if the user wants a fresh start. + } + } + } + // Build Python CLI args pyArgs := []string{"scan", repoPath} if scanOutput != "" { @@ -115,6 +134,12 @@ func runScan(cmd *cobra.Command, args []string) { if scanModel != "opus" { pyArgs = append(pyArgs, "--model", scanModel) } + if scanWorkers != 8 { + pyArgs = append(pyArgs, "--workers", fmt.Sprintf("%d", scanWorkers)) + } + if scanBackoff != 30 { + pyArgs = append(pyArgs, "--backoff", fmt.Sprintf("%d", scanBackoff)) + } result, err := python.Invoke(rt.Path, pyArgs, "", quiet, requireAPIKey()) if err != nil { @@ -122,7 +147,9 @@ func runScan(cmd *cobra.Command, args []string) { os.Exit(2) } - if jsonOutput { + if result.Envelope.Status == "interrupted" { + os.Exit(130) + } else if jsonOutput { output.PrintJSON(result.Envelope) } else if result.Envelope.Status == "success" { if data, ok := result.Envelope.Data.(map[string]any); ok { diff --git a/apps/openant-cli/cmd/verify.go b/apps/openant-cli/cmd/verify.go index e3bafaf..cad9b8a 100644 --- a/apps/openant-cli/cmd/verify.go +++ b/apps/openant-cli/cmd/verify.go @@ -1,8 +1,10 @@ package cmd import ( + "fmt" "os" + "github.com/knostic/open-ant-cli/internal/checkpoint" "github.com/knostic/open-ant-cli/internal/output" "github.com/knostic/open-ant-cli/internal/python" "github.com/spf13/cobra" @@ -29,6 +31,9 @@ var ( verifyAnalyzerOutput string verifyAppContext string verifyRepoPath string + verifyWorkers int + verifyCheckpoint string + verifyBackoff int ) func init() { @@ -36,6 +41,9 @@ func init() { verifyCmd.Flags().StringVar(&verifyAnalyzerOutput, "analyzer-output", "", "Path to 
analyzer_output.json") verifyCmd.Flags().StringVar(&verifyAppContext, "app-context", "", "Path to application_context.json") verifyCmd.Flags().StringVar(&verifyRepoPath, "repo-path", "", "Path to the repository") + verifyCmd.Flags().IntVar(&verifyWorkers, "workers", 8, "Number of parallel workers for LLM steps (default: 8)") + verifyCmd.Flags().StringVar(&verifyCheckpoint, "checkpoint", "", "Path to checkpoint directory for save/resume") + verifyCmd.Flags().IntVar(&verifyBackoff, "backoff", 30, "Seconds to wait when rate-limited (default: 30)") } func runVerify(cmd *cobra.Command, args []string) { @@ -68,6 +76,17 @@ func runVerify(cmd *cobra.Command, args []string) { os.Exit(2) } + // Auto-detect checkpoints from a previous interrupted run + if verifyCheckpoint == "" && ctx != nil { + if cpInfo := checkpoint.DetectViaPython(rt.Path, ctx.ScanDir, "verify"); cpInfo != nil { + if checkpoint.PromptResume(cpInfo, "verify", quiet) { + verifyCheckpoint = cpInfo.Dir + } else { + _ = checkpoint.Clean(cpInfo.Dir) + } + } + } + pyArgs := []string{"verify", resultsPath, "--analyzer-output", verifyAnalyzerOutput} if verifyOutput != "" { pyArgs = append(pyArgs, "--output", verifyOutput) @@ -78,6 +97,15 @@ func runVerify(cmd *cobra.Command, args []string) { if verifyRepoPath != "" { pyArgs = append(pyArgs, "--repo-path", verifyRepoPath) } + if verifyWorkers != 8 { + pyArgs = append(pyArgs, "--workers", fmt.Sprintf("%d", verifyWorkers)) + } + if verifyCheckpoint != "" { + pyArgs = append(pyArgs, "--checkpoint", verifyCheckpoint) + } + if verifyBackoff != 30 { + pyArgs = append(pyArgs, "--backoff", fmt.Sprintf("%d", verifyBackoff)) + } result, err := python.Invoke(rt.Path, pyArgs, "", quiet, requireAPIKey()) if err != nil { @@ -85,7 +113,9 @@ func runVerify(cmd *cobra.Command, args []string) { os.Exit(2) } - if jsonOutput { + if result.Envelope.Status == "interrupted" { + os.Exit(130) + } else if jsonOutput { output.PrintJSON(result.Envelope) } else if result.Envelope.Status == 
"success" { if data, ok := result.Envelope.Data.(map[string]any); ok { diff --git a/apps/openant-cli/go.mod b/apps/openant-cli/go.mod index 536783a..c63c0cb 100644 --- a/apps/openant-cli/go.mod +++ b/apps/openant-cli/go.mod @@ -8,9 +8,34 @@ require ( ) require ( + github.com/atotto/clipboard v0.1.4 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/catppuccin/go v0.3.0 // indirect + github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 // indirect + github.com/charmbracelet/bubbletea v1.3.6 // indirect + github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/huh v1.0.0 // indirect + github.com/charmbracelet/lipgloss v1.1.0 // indirect + github.com/charmbracelet/x/ansi v0.9.3 // indirect + github.com/charmbracelet/x/cellbuf v0.0.13 // indirect + github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect + github.com/charmbracelet/x/term v0.2.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/spf13/pflag v1.0.9 // indirect - golang.org/x/sys v0.25.0 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + golang.org/x/sync v0.15.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.23.0 // indirect ) 
diff --git a/apps/openant-cli/go.sum b/apps/openant-cli/go.sum index d39b047..b82079a 100644 --- a/apps/openant-cli/go.sum +++ b/apps/openant-cli/go.sum @@ -1,21 +1,75 @@ +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY= +github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc= +github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 h1:JFgG/xnwFfbezlUnFMJy0nusZvytYysV4SCS2cYbvws= +github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7/go.mod h1:ISC1gtLcVilLOf23wvTfoQuYbW2q0JevFxPfUzZ9Ybw= +github.com/charmbracelet/bubbletea v1.3.6 h1:VkHIxPJQeDt0aFJIsVxw8BQdh/F/L2KKZGsK6et5taU= +github.com/charmbracelet/bubbletea v1.3.6/go.mod h1:oQD9VCRQFF8KplacJLo28/jofOI2ToOfGYeFgBBxHOc= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/huh v1.0.0 h1:wOnedH8G4qzJbmhftTqrpppyqHakl/zbbNdXIWJyIxw= +github.com/charmbracelet/huh v1.0.0/go.mod h1:5YVc+SlZ1IhQALxRPpkGwwEKftN/+OlJlnJYlDRFqN4= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.9.3 h1:BXt5DHS/MKF+LjuK4huWrC6NCvHtexww7dMayh6GXd0= +github.com/charmbracelet/x/ansi v0.9.3/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= +github.com/charmbracelet/x/cellbuf v0.0.13 
h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k= +github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko3AQ4gK1MTS/de7F5hPGx6/k1u0w4TeYmBFwzYVP4= +github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ= +github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= +github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= +github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e 
h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/apps/openant-cli/internal/checkpoint/checkpoint.go b/apps/openant-cli/internal/checkpoint/checkpoint.go new file mode 100644 index 0000000..18ff59b --- /dev/null +++ b/apps/openant-cli/internal/checkpoint/checkpoint.go @@ -0,0 +1,240 @@ +// Package checkpoint provides auto-resume detection for LLM pipeline steps. +// +// When a step (enhance, analyze, verify) is interrupted, per-unit checkpoint +// files remain in {scanDir}/{step}_checkpoints/. On the next run the Go CLI +// detects these files and prompts the user to resume or start fresh. 
+// +// Checkpoint status (completed vs errored counts) is determined by calling +// the Python CLI's `checkpoint-status` command, which is the single source +// of truth for checkpoint semantics. +package checkpoint + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/fatih/color" + "github.com/knostic/open-ant-cli/internal/python" +) + +const summaryFile = "_summary.json" + +// Summary represents the _summary.json written by Python pipeline steps. +type Summary struct { + Step string `json:"step"` + Phase string `json:"phase"` + Timestamp string `json:"timestamp"` + TotalUnits int `json:"total_units"` + Completed int `json:"completed"` + Errors int `json:"errors"` + ErrorBreakdown map[string]int `json:"error_breakdown"` +} + +// Info describes an existing checkpoint directory. +type Info struct { + Dir string // full path to the checkpoint dir + Count int // number of successfully completed units + Errors int // number of errored units + Summary *Summary // parsed _summary.json (may have counts overridden by Python) +} + +// Phase returns the detected phase state as a human-readable string. +func (i *Info) Phase() string { + if i.Summary == nil { + return "legacy" + } + if i.Summary.Phase == "done" && i.Summary.Errors > 0 { + return "done_with_errors" + } + return i.Summary.Phase // "in_progress" or "done" +} + +// DetectViaPython checks for checkpoints by calling the Python CLI's +// checkpoint-status command for accurate completed/errored counts. +// Returns nil if no checkpoint is found or Python fails. 
+func DetectViaPython(pythonPath, scanDir, stepName string) *Info { + dir := filepath.Join(scanDir, stepName+"_checkpoints") + + // Quick filesystem check: skip Python call if dir doesn't exist + if _, err := os.Stat(dir); err != nil { + return nil + } + + // Call Python for accurate counts + result, err := python.Invoke(pythonPath, []string{"checkpoint-status", dir}, "", true, "") + if err != nil || result.Envelope.Status != "success" { + // Python failed — fall back to simple file count + return DetectFallback(scanDir, stepName) + } + + // Parse the response data + dataBytes, err := json.Marshal(result.Envelope.Data) + if err != nil { + return DetectFallback(scanDir, stepName) + } + + var status struct { + Step string `json:"step"` + Completed int `json:"completed"` + Errors int `json:"errors"` + TotalFiles int `json:"total_files"` + TotalUnits int `json:"total_units"` + Phase string `json:"phase"` + ErrorBreakdown map[string]int `json:"error_breakdown"` + } + if json.Unmarshal(dataBytes, &status) != nil { + return DetectFallback(scanDir, stepName) + } + + if status.TotalFiles == 0 { + return nil + } + + // If the previous run finished cleanly (phase=done, no errors), there's + // nothing to resume — treat it as if there are no checkpoints. The + // checkpoint files are preserved for audit/retro but don't trigger a prompt. + if status.Phase == "done" && status.Errors == 0 { + return nil + } + + info := &Info{ + Dir: dir, + Count: status.Completed, + Errors: status.Errors, + Summary: &Summary{ + Step: status.Step, + Phase: status.Phase, + TotalUnits: status.TotalUnits, + Completed: status.Completed, + Errors: status.Errors, + ErrorBreakdown: status.ErrorBreakdown, + }, + } + + return info +} + +// DetectFallback checks for checkpoints using only filesystem scanning. +// Used when Python is not available. Counts .json files without classifying +// completed vs errored — all files are counted as completed. 
+func DetectFallback(scanDir, stepName string) *Info { + dir := filepath.Join(scanDir, stepName+"_checkpoints") + + entries, err := os.ReadDir(dir) + if err != nil { + return nil + } + + count := 0 + for _, e := range entries { + if !e.IsDir() && strings.HasSuffix(e.Name(), ".json") && e.Name() != summaryFile { + count++ + } + } + if count == 0 { + return nil + } + + info := &Info{Dir: dir, Count: count} + + // Try to read _summary.json for phase state and total_units + summaryPath := filepath.Join(dir, summaryFile) + data, err := os.ReadFile(summaryPath) + if err == nil { + var s Summary + if json.Unmarshal(data, &s) == nil { + info.Summary = &s + } + } + + return info +} + +// PromptResume asks the user whether to resume an interrupted run or start +// fresh. Returns true for resume, false for fresh start. +// +// In non-interactive mode (stdin is not a terminal, or quiet mode), defaults +// to resume. +func PromptResume(info *Info, stepName string, quiet bool) bool { + if quiet || !isTerminal() { + // Non-interactive: default to resume + dim := color.New(color.Faint) + dim.Fprintf(os.Stderr, "[%s] Auto-resuming from %d checkpointed units (non-interactive)\n", + stepName, info.Count) + return true + } + + yellow := color.New(color.FgYellow, color.Bold) + red := color.New(color.FgRed, color.Bold) + bold := color.New(color.Bold) + + fmt.Fprintln(os.Stderr) + + switch info.Phase() { + case "in_progress": + // Interrupted run — show progress out of total + yellow.Fprintf(os.Stderr, "Previous %s run interrupted", stepName) + s := info.Summary + if info.Errors > 0 { + fmt.Fprintf(os.Stderr, " (%d/%d completed, %d errors)\n", + info.Count, s.TotalUnits, info.Errors) + } else { + fmt.Fprintf(os.Stderr, " (%d/%d completed)\n", + info.Count, s.TotalUnits) + } + + case "done_with_errors": + // Ran to completion but had errors — different message + red.Fprintf(os.Stderr, "Previous %s run completed with errors", stepName) + s := info.Summary + fmt.Fprintf(os.Stderr, " (%d/%d 
completed, %d errors)\n", + info.Count, s.TotalUnits, info.Errors) + + case "done": + // Clean completion — shouldn't normally get here (checkpoints cleaned up) + yellow.Fprintf(os.Stderr, "Previous %s run completed", stepName) + s := info.Summary + fmt.Fprintf(os.Stderr, " (%d/%d completed)\n", info.Count, s.TotalUnits) + + default: + // Legacy format (no _summary.json) or fallback + yellow.Fprintf(os.Stderr, "Previous %s run found", stepName) + if info.Errors > 0 { + fmt.Fprintf(os.Stderr, " (%d completed, %d errors)\n", info.Count, info.Errors) + } else { + fmt.Fprintf(os.Stderr, " (~%d units)\n", info.Count) + } + } + + fmt.Fprintf(os.Stderr, " Checkpoint: %s\n", info.Dir) + fmt.Fprintln(os.Stderr) + bold.Fprint(os.Stderr, "Resume where you left off? ") + fmt.Fprint(os.Stderr, "[Y/n] (n = discard progress, start fresh) ") + + reader := bufio.NewReader(os.Stdin) + answer, _ := reader.ReadString('\n') + answer = strings.TrimSpace(strings.ToLower(answer)) + + if answer == "" || answer == "y" || answer == "yes" { + return true + } + return false +} + +// Clean removes a checkpoint directory. +func Clean(dir string) error { + return os.RemoveAll(dir) +} + +// isTerminal checks if stdin is a terminal (not a pipe). +func isTerminal() bool { + fi, err := os.Stdin.Stat() + if err != nil { + return false + } + return fi.Mode()&os.ModeCharDevice != 0 +} diff --git a/apps/openant-cli/internal/config/project.go b/apps/openant-cli/internal/config/project.go index d9933a7..28ad925 100644 --- a/apps/openant-cli/internal/config/project.go +++ b/apps/openant-cli/internal/config/project.go @@ -97,7 +97,12 @@ func SetActiveProject(name string) error { } // ListProjects returns the names of all initialized projects. -// It walks ~/.openant/projects/ looking for project.json files. 
+// It looks for project.json at exactly one or two levels deep: +// - ~/.openant/projects/<name>/project.json → local projects +// - ~/.openant/projects/<org>/<repo>/project.json → remote (org/repo) projects +// +// It does NOT recurse deeper, to avoid picking up project.json files +// inside cloned repositories (e.g. Grafana's plugin project.json files). func ListProjects() ([]string, error) { projsDir, err := ProjectsDir() if err != nil { @@ -108,23 +113,39 @@ return nil, nil } + // Read first-level directories (e.g. "grafana", "ghostty-org", "openant") + level1Entries, err := os.ReadDir(projsDir) + if err != nil { + return nil, fmt.Errorf("failed to list projects: %w", err) + } + var names []string - err = filepath.WalkDir(projsDir, func(path string, d os.DirEntry, err error) error { + for _, l1 := range level1Entries { + if !l1.IsDir() { + continue + } + l1Path := filepath.Join(projsDir, l1.Name()) + + // Check for project.json at level 1 (local projects: "openant") + if _, err := os.Stat(filepath.Join(l1Path, "project.json")); err == nil { + names = append(names, l1.Name()) + continue // don't also scan subdirs — this is the project + } + + // Check level 2 (org/repo projects: "grafana/grafana") + l2Entries, err := os.ReadDir(l1Path) if err != nil { - return nil // skip errors + continue } - if d.Name() == "project.json" && !d.IsDir() { - // Extract project name from path: - // ~/.openant/projects/org/repo/project.json → org/repo - rel, err := filepath.Rel(projsDir, filepath.Dir(path)) - if err == nil { - names = append(names, rel) + + for _, l2 := range l2Entries { + if !l2.IsDir() { + continue + } + if _, err := os.Stat(filepath.Join(l1Path, l2.Name(), "project.json")); err == nil { + names = append(names, l1.Name()+"/"+l2.Name()) } } - return nil - }) - if err != nil { - return nil, fmt.Errorf("failed to list projects: %w", err) } return names, nil diff --git a/apps/openant-cli/internal/output/formatter.go 
b/apps/openant-cli/internal/output/formatter.go index 4987b72..39dc73c 100644 --- a/apps/openant-cli/internal/output/formatter.go +++ b/apps/openant-cli/internal/output/formatter.go @@ -73,12 +73,12 @@ func PrintScanSummary(data map[string]any) { PrintHeader("Scan Results") - total := intFromAny(metrics["total_units"]) - vulnerable := intFromAny(metrics["vulnerable_units"]) - safe := intFromAny(metrics["safe_units"]) - unclear := intFromAny(metrics["unclear_units"]) - verified := intFromAny(metrics["verified_vulnerable"]) - falsePos := intFromAny(metrics["false_positives"]) + total := intFromAny(metrics["total"]) + vulnerable := intFromAny(metrics["vulnerable"]) + safe := intFromAny(metrics["safe"]) + unclear := intFromAny(metrics["inconclusive"]) + verified := intFromAny(metrics["verified"]) + falsePos := intFromAny(metrics["stage2_disagreed"]) PrintKeyValue("Total units analyzed", fmt.Sprintf("%d", total)) @@ -108,7 +108,7 @@ func PrintScanSummary(data map[string]any) { // Usage info if usage, ok := data["usage"].(map[string]any); ok { PrintHeader("Usage") - cost := floatFromAny(usage["total_cost"]) + cost := floatFromAny(usage["total_cost_usd"]) inputTokens := intFromAny(usage["total_input_tokens"]) outputTokens := intFromAny(usage["total_output_tokens"]) @@ -171,17 +171,30 @@ func PrintAnalyzeSummary(data map[string]any) { } PrintHeader("Analysis Results") - total := intFromAny(metrics["total_units"]) - vulnerable := intFromAny(metrics["vulnerable_units"]) - safe := intFromAny(metrics["safe_units"]) + total := intFromAny(metrics["total"]) + vulnerable := intFromAny(metrics["vulnerable"]) + bypassable := intFromAny(metrics["bypassable"]) + protected := intFromAny(metrics["protected"]) + safe := intFromAny(metrics["safe"]) + inconclusive := intFromAny(metrics["inconclusive"]) + errors := intFromAny(metrics["errors"]) PrintKeyValue("Total units", fmt.Sprintf("%d", total)) - if vulnerable > 0 { - red.Printf(" Vulnerable: %d\n", vulnerable) + + combined := 
vulnerable + bypassable + if combined > 0 { + red.Printf(" Vulnerable: %d\n", combined) } else { green.Printf(" Vulnerable: 0\n") } + PrintKeyValue("Protected", fmt.Sprintf("%d", protected)) PrintKeyValue("Safe", fmt.Sprintf("%d", safe)) + if inconclusive > 0 { + yellow.Printf(" Inconclusive: %d\n", inconclusive) + } + if errors > 0 { + yellow.Printf(" Errors: %d\n", errors) + } if path, ok := data["results_path"].(string); ok { PrintKeyValue("Output", path) @@ -191,15 +204,21 @@ func PrintAnalyzeSummary(data map[string]any) { // PrintReportSummary outputs a formatted summary of report generation. func PrintReportSummary(data map[string]any) { - PrintHeader("Reports Generated") - if html, ok := data["html_path"].(string); ok && html != "" { - PrintKeyValue("HTML", html) + PrintHeader("Report Generated") + if format, ok := data["format"].(string); ok && format != "" { + PrintKeyValue("Format", format) + } + if path, ok := data["output_path"].(string); ok && path != "" { + PrintKeyValue("Output", path) } - if csv, ok := data["csv_path"].(string); ok && csv != "" { - PrintKeyValue("CSV", csv) + if path, ok := data["reskin_path"].(string); ok && path != "" { + PrintKeyValue("Reskin", path) } - if summary, ok := data["summary_path"].(string); ok && summary != "" { - PrintKeyValue("Summary", summary) + if usage, ok := data["usage"].(map[string]any); ok { + cost := floatFromAny(usage["total_cost_usd"]) + if cost > 0 { + PrintKeyValue("Cost", fmt.Sprintf("$%.4f", cost)) + } } fmt.Println() } diff --git a/apps/openant-cli/internal/python/invoke.go b/apps/openant-cli/internal/python/invoke.go index dd0e644..d127e11 100644 --- a/apps/openant-cli/internal/python/invoke.go +++ b/apps/openant-cli/internal/python/invoke.go @@ -8,7 +8,10 @@ import ( "io" "os" "os/exec" + "os/signal" "strings" + "syscall" + "time" "github.com/knostic/open-ant-cli/internal/types" ) @@ -57,6 +60,32 @@ func Invoke(pythonPath string, args []string, workDir string, quiet bool, apiKey return nil, 
fmt.Errorf("failed to start Python process: %w", err) } + // Forward SIGINT/SIGTERM to the Python subprocess so Ctrl+C kills it. + sigChan := make(chan os.Signal, 1) + interrupted := false + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + go func() { + sig, ok := <-sigChan + if !ok { + return // channel closed, normal exit + } + interrupted = true + // Forward signal to Python subprocess + if cmd.Process != nil { + _ = cmd.Process.Signal(sig) + } + // Give Python a few seconds to exit gracefully, then force kill + time.AfterFunc(5*time.Second, func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + }) + }() + defer func() { + signal.Stop(sigChan) + close(sigChan) + }() + // Stream stderr in a goroutine stderrDone := make(chan struct{}) go func() { @@ -87,6 +116,16 @@ func Invoke(pythonPath string, args []string, workDir string, quiet bool, apiKey // Parse JSON from stdout rawJSON := strings.TrimSpace(stdoutBuf.String()) if rawJSON == "" { + if interrupted { + // User interrupted with Ctrl+C — not an error + return &InvokeResult{ + Envelope: types.Envelope{ + Status: "interrupted", + Errors: []string{}, + }, + ExitCode: 130, // standard SIGINT exit code + }, nil + } return &InvokeResult{ Envelope: types.Envelope{ Status: "error", diff --git a/apps/openant-cli/internal/report/report.go b/apps/openant-cli/internal/report/report.go new file mode 100644 index 0000000..d6f882e --- /dev/null +++ b/apps/openant-cli/internal/report/report.go @@ -0,0 +1,76 @@ +package report + +import ( + "embed" + "encoding/json" + "html/template" + "io" + "os" + "path/filepath" +) + +//go:embed templates/overview.gohtml +var templateFS embed.FS + +//go:embed templates/report-reskin.gohtml +var reskinFS embed.FS + +var ( + overviewTmpl *template.Template + reskinTmpl *template.Template +) + +func init() { + funcMap := template.FuncMap{ + "toJSON": func(v any) template.JS { + b, _ := json.Marshal(v) + return template.JS(b) + }, + "even": func(i int) bool { + return i%2 == 0 
+ }, + } + + overviewTmpl = template.Must( + template.New("overview.gohtml").Funcs(funcMap).ParseFS(templateFS, "templates/overview.gohtml"), + ) + + reskinTmpl = template.Must( + template.New("report-reskin.gohtml").Funcs(funcMap).ParseFS(reskinFS, "templates/report-reskin.gohtml"), + ) +} + +// RenderOverview renders the HTML overview report to the given writer. +func RenderOverview(data ReportData, w io.Writer) error { + return overviewTmpl.Execute(w, data) +} + +// GenerateOverview renders the HTML overview report to a file. +func GenerateOverview(data ReportData, outputPath string) error { + return generateToFile(overviewTmpl, data, outputPath) +} + +// RenderReskin renders the Knostic-themed HTML report to the given writer. +func RenderReskin(data ReportData, w io.Writer) error { + return reskinTmpl.Execute(w, data) +} + +// GenerateReskin renders the Knostic-themed HTML report to a file. +func GenerateReskin(data ReportData, outputPath string) error { + return generateToFile(reskinTmpl, data, outputPath) +} + +// generateToFile renders a template to a file, creating parent directories as needed. +func generateToFile(tmpl *template.Template, data ReportData, outputPath string) error { + if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil { + return err + } + + f, err := os.Create(outputPath) + if err != nil { + return err + } + defer f.Close() + + return tmpl.Execute(f, data) +} diff --git a/apps/openant-cli/internal/report/templates/overview.gohtml b/apps/openant-cli/internal/report/templates/overview.gohtml new file mode 100644 index 0000000..209ff41 --- /dev/null +++ b/apps/openant-cli/internal/report/templates/overview.gohtml @@ -0,0 +1,372 @@ + + + + + + {{.Title}} + + + + + + + + + + + +
+ + +
+
+
+

{{.Title}}

+
+ {{if .RepoName}}{{.RepoName}}{{end}} + {{if .ShortCommit}}{{.ShortCommit}}{{end}} + {{if .Language}}{{.Language}}{{end}} +
+
+
+ {{if .FormatDuration}}{{.FormatDuration}}{{end}} + {{.Timestamp}} + Powered by Knostic +
+
+
+ + +
+
+
{{.Stats.TotalUnits}}
+
Code Units
+
+
+
{{.Stats.TotalFiles}}
+
Files
+
+
+
{{.Stats.Vulnerable}}
+
Vulnerable Units
+
+
+
{{.Stats.Bypassable}}
+
Bypassable Units
+
+
+
{{.Stats.Secure}}
+
Secure Units
+
+
+ + +
+

Distribution Overview

+
+
+

By Code Unit

+
+ +
+
+
+

By File (Worst Verdict)

+
+ +
+
+
+ + +
+ + Verdict Categories + +
+ + + + + + + + + {{range $i, $cat := .Categories}} + + + + + {{end}} + +
CategoryDescription
+ {{$cat.Verdict}} + {{$cat.Description}}
+
+
+
+ + +
+

Remediation Guidance

+
+ {{.SafeRemediation}} +
+
+ + + {{if .HasStepReports}} +
+
+ + Pipeline Costs & Timing + +
+ + + +
+ {{range .StepReports}} +
+
{{.Step}}
+
+
+ Duration +
{{.Duration}}
+
+
+ Cost +
{{.Cost}}
+
+
+ Status +
{{.Status}}
+
+
+
+ {{end}} +
+
Total
+
+
+ Duration +
{{.FormatDuration}}
+
+
+ Cost +
{{.FormatTotalCost}}
+
+
+
+
+
+
+
+ {{end}} + + + {{if .HasFindingGroups}} +
+

All Findings

+
+ {{range $g := .FindingsByVerdict}} + + + {{$g.Verdict}} + {{$g.Count}} finding{{if ne $g.Count 1}}s{{end}} + +
+ {{range $f := $g.Findings}} +
+ + #{{$f.Number}} + {{if $f.HasDynamicTest}} + {{$f.DynamicTestStatus}} + {{end}} + {{$url := $.FileURL $f.File}}{{if $url}}{{$f.File}}{{else}}{{$f.File}}{{end}} + {{$f.Function}} + {{if $f.AttackVector}} + + {{end}} + +
+ {{if $f.AttackVector}} +
+ Attack Vector +

{{$f.AttackVector}}

+
+ {{end}} +
+ Analysis +

{{$f.Analysis}}

+
+ {{if $f.HasDynamicTest}} +
+ Dynamic Test +
+ {{$f.DynamicTestStatus}} + {{if $f.DynamicTestDetails}}{{$f.DynamicTestDetails}}{{end}} +
+
+ {{end}} +
+
+ {{end}} +
+ + {{end}} +
+
+ {{end}} + + + +
+ + + + + diff --git a/apps/openant-cli/internal/report/templates/report-reskin.gohtml b/apps/openant-cli/internal/report/templates/report-reskin.gohtml new file mode 100644 index 0000000..d706300 --- /dev/null +++ b/apps/openant-cli/internal/report/templates/report-reskin.gohtml @@ -0,0 +1,402 @@ + + + + + + {{.Title}} + + + + + + + + + + + + + + +
+ + +
+
+
+

{{.Title}}

+
+ {{if .RepoName}}{{.RepoName}}{{end}} + {{if .ShortCommit}}{{.ShortCommit}}{{end}} + {{if .Language}}{{.Language}}{{end}} +
+
+
+ {{if .FormatDuration}}{{.FormatDuration}}{{end}} + {{.Timestamp}} + Powered by Knostic +
+
+
+ + +
+
+
{{.Stats.TotalUnits}}
+
Code Units
+
+
+
{{.Stats.TotalFiles}}
+
Files
+
+
+
{{.Stats.Vulnerable}}
+
Vulnerable Units
+
+
+
{{.Stats.Bypassable}}
+
Bypassable Units
+
+
+
{{.Stats.Secure}}
+
Secure Units
+
+
+ + +
+

Distribution Overview

+
+
+

By Code Unit

+
+ +
+
+
+

By File (Worst Verdict)

+
+ +
+
+
+ + +
+ + Verdict Categories + +
+ + + + + + + + + {{range $i, $cat := .Categories}} + + + + + {{end}} + +
CategoryDescription
+ {{$cat.Verdict}} + {{$cat.Description}}
+
+
+
+ + +
+

Remediation Guidance

+
+ {{.SafeRemediation}} +
+
+ + + {{if .HasStepReports}} +
+
+ + Pipeline Costs & Timing + +
+ + + +
+ {{range .StepReports}} +
+
{{.Step}}
+
+
+ Duration +
{{.Duration}}
+
+
+ Cost +
{{.Cost}}
+
+
+ Status +
{{.Status}}
+
+
+
+ {{end}} +
+
Total
+
+
+ Duration +
{{.FormatDuration}}
+
+
+ Cost +
{{.FormatTotalCost}}
+
+
+
+
+
+
+
+ {{end}} + + + {{if .HasFindingGroups}} +
+

All Findings

+
+ {{range $g := .FindingsByVerdict}} + + + {{$g.Verdict}} + {{$g.Count}} finding{{if ne $g.Count 1}}s{{end}} + +
+ {{if $g.HasSubgroups}}{{range $sg := $g.Subgroups}} +
+
+ {{$sg.Label}} + ({{len $sg.Findings}}) +
+
+ {{range $f := $sg.Findings}} +
+ + #{{$f.Number}} + {{if $f.HasDynamicTest}}{{$f.DynamicTestStatus}}{{end}} + {{$url := $.FileURL $f.File}}{{if $url}}{{$f.File}}{{else}}{{$f.File}}{{end}} + {{$f.Function}} + {{if $f.AttackVector}}{{end}} + +
+ {{if $f.AttackVector}}
Attack Vector

{{$f.AttackVector}}

{{end}} +
Analysis

{{$f.Analysis}}

+ {{if $f.HasDynamicTest}}
Dynamic Test
{{$f.DynamicTestStatus}}{{if $f.DynamicTestDetails}}{{$f.DynamicTestDetails}}{{end}}
{{end}} +
+
+ {{end}} +
+ {{end}}{{else}}{{range $f := $g.Findings}} +
+ + #{{$f.Number}} + {{if $f.HasDynamicTest}}{{$f.DynamicTestStatus}}{{end}} + {{$url := $.FileURL $f.File}}{{if $url}}{{$f.File}}{{else}}{{$f.File}}{{end}} + {{$f.Function}} + {{if $f.AttackVector}}{{end}} + +
+ {{if $f.AttackVector}}
Attack Vector

{{$f.AttackVector}}

{{end}} +
Analysis

{{$f.Analysis}}

+ {{if $f.HasDynamicTest}}
Dynamic Test
{{$f.DynamicTestStatus}}{{if $f.DynamicTestDetails}}{{$f.DynamicTestDetails}}{{end}}
{{end}} +
+
+ {{end}}{{end}} +
+ + {{end}} +
+
+ {{end}} + + + +
// apps/openant-cli/internal/report/types.go
//
// Package report provides HTML report generation from pre-computed data.
// (package report; imports: "fmt", "html/template", "strings")

// ReportData holds all pre-computed data needed to render the HTML overview report.
// This struct maps 1:1 to the JSON output of the Python `report-data` subcommand.
type ReportData struct {
	Title             string         `json:"title"`
	Timestamp         string         `json:"timestamp"`
	RepoName          string         `json:"repo_name"`
	CommitSHA         string         `json:"commit_sha"`
	Language          string         `json:"language"`
	RepoURL           string         `json:"repo_url"`
	TotalDurationS    float64        `json:"total_duration_seconds"`
	TotalCostUSD      float64        `json:"total_cost_usd"`
	Stats             Stats          `json:"stats"`
	UnitChart         ChartData      `json:"unit_chart"`
	FileChart         ChartData      `json:"file_chart"`
	RemediationHTML   string         `json:"remediation_html"`
	Findings          []Finding      `json:"findings"`
	FindingsByVerdict []FindingGroup `json:"findings_by_verdict"`
	StepReports       []StepReport   `json:"step_reports"`
	Categories        []Category     `json:"categories"`
}

// SafeRemediation returns the remediation HTML as a template.HTML
// so Go's html/template does not escape it. The caller (the Python
// report-data producer) is trusted to emit safe markup here.
func (d ReportData) SafeRemediation() template.HTML {
	return template.HTML(d.RemediationHTML)
}

// FormatDuration returns TotalDurationS as a human-readable string
// like "1d 2h 3m 4s". All zero-valued components are omitted (e.g.
// 86460s renders as "1d 1m", not "1d 0h 1m"); a zero or negative
// duration renders as the empty string.
func (d ReportData) FormatDuration() string {
	total := int(d.TotalDurationS) // truncate sub-second precision
	if total <= 0 {
		return ""
	}
	days := total / 86400
	hours := (total % 86400) / 3600
	mins := (total % 3600) / 60
	secs := total % 60

	var parts []string
	if days > 0 {
		parts = append(parts, fmt.Sprintf("%dd", days))
	}
	if hours > 0 {
		parts = append(parts, fmt.Sprintf("%dh", hours))
	}
	if mins > 0 {
		parts = append(parts, fmt.Sprintf("%dm", mins))
	}
	// Always emit at least one component: show seconds when everything
	// else is zero (e.g. total < 60 renders as "Ns").
	if secs > 0 || len(parts) == 0 {
		parts = append(parts, fmt.Sprintf("%ds", secs))
	}
	return strings.Join(parts, " ")
}

// FormatTotalCost returns TotalCostUSD as "$X.XX", or "-" if zero or negative.
func (d ReportData) FormatTotalCost() string {
	if d.TotalCostUSD <= 0 {
		return "-"
	}
	return fmt.Sprintf("$%.2f", d.TotalCostUSD)
}

// ShortCommit returns the first 10 characters of CommitSHA, or the full
// SHA unchanged when it is 10 characters or fewer (including empty).
func (d ReportData) ShortCommit() string {
	if len(d.CommitSHA) > 10 {
		return d.CommitSHA[:10]
	}
	return d.CommitSHA
}

// FileURL constructs a browseable URL for a file path in the repo,
// pinned to the scanned commit. Returns the empty string if RepoURL or
// CommitSHA is missing. The "/blob/<sha>/<path>" layout assumes a
// GitHub-style forge — NOTE(review): confirm for other hosts.
func (d ReportData) FileURL(filePath string) string {
	if d.RepoURL == "" || d.CommitSHA == "" {
		return ""
	}
	// Normalize "https://host/org/repo.git/" down to "https://host/org/repo".
	base := strings.TrimRight(d.RepoURL, "/")
	base = strings.TrimSuffix(base, ".git")
	return base + "/blob/" + d.CommitSHA + "/" + filePath
}

// HasStepReports returns true if there are step reports to display.
func (d ReportData) HasStepReports() bool {
	return len(d.StepReports) > 0
}

// HasFindings returns true if there are findings to display.
func (d ReportData) HasFindings() bool {
	return len(d.Findings) > 0
}

// HasFindingGroups returns true if there are grouped findings to display.
func (d ReportData) HasFindingGroups() bool {
	return len(d.FindingsByVerdict) > 0
}

// Stats holds the summary statistics for the report header cards.
type Stats struct {
	TotalUnits int `json:"total_units"`
	TotalFiles int `json:"total_files"`
	Vulnerable int `json:"vulnerable"`
	Bypassable int `json:"bypassable"`
	Secure     int `json:"secure"`
}

// ChartData holds the data for a Chart.js pie chart.
type ChartData struct {
	Labels []string `json:"labels"`
	Data   []int    `json:"data"`
	Colors []string `json:"colors"`
}

// FindingGroup holds findings grouped by verdict for collapsible sections.
type FindingGroup struct {
	Verdict       string            `json:"verdict"`
	VerdictColor  string            `json:"verdict_color"`
	Count         int               `json:"count"`
	OpenByDefault bool              `json:"open_by_default"`
	Findings      []Finding         `json:"findings"`
	Subgroups     []FindingSubgroup `json:"subgroups"`
	HasSubgroups  bool              `json:"has_subgroups"`
}

// FindingSubgroup holds findings within a verdict group, sub-grouped by
// dynamic test outcome (e.g. "Confirmed", "Test error", "Not tested").
type FindingSubgroup struct {
	Label    string    `json:"label"`
	Findings []Finding `json:"findings"`
}

// Finding represents a single finding row in the report table.
type Finding struct {
	Number             int    `json:"number"`
	Verdict            string `json:"verdict"`
	VerdictColor       string `json:"verdict_color"`
	File               string `json:"file"`
	Function           string `json:"function"`
	AttackVector       string `json:"attack_vector"`
	Analysis           string `json:"analysis"`
	DynamicTestStatus  string `json:"dynamic_test_status"`
	DynamicTestDetails string `json:"dynamic_test_details"`
}

// HasDynamicTest returns true if this finding has dynamic test results.
func (f Finding) HasDynamicTest() bool {
	return f.DynamicTestStatus != ""
}

// DynamicTestColor returns a hex color for the dynamic test status badge:
// red for CONFIRMED, green for NOT_REPRODUCED/BLOCKED, orange for
// INCONCLUSIVE, gray for ERROR and any unrecognized status.
func (f Finding) DynamicTestColor() string {
	switch f.DynamicTestStatus {
	case "CONFIRMED":
		return "#dc3545"
	case "NOT_REPRODUCED":
		return "#28a745"
	case "BLOCKED":
		return "#28a745"
	case "ERROR":
		return "#6c757d"
	case "INCONCLUSIVE":
		return "#fd7e14"
	default:
		return "#6c757d"
	}
}

// IsHighSeverity returns true for vulnerable/bypassable findings,
// used to auto-open their accordion in the HTML report.
func (f Finding) IsHighSeverity() bool {
	switch f.Verdict {
	case "vulnerable", "bypassable":
		return true
	default:
		return false
	}
}

// StepReport holds display-ready data for a pipeline step.
type StepReport struct {
	Step      string `json:"step"`
	Duration  string `json:"duration"`
	Cost      string `json:"cost"`
	Status    string `json:"status"`
	Timestamp string `json:"timestamp"`
}

// StatusColor returns a Tailwind text color class based on step status:
// green for "success", red for "error", gray otherwise.
func (s StepReport) StatusColor() string {
	switch s.Status {
	case "success":
		return "text-green-400"
	case "error":
		return "text-red-400"
	default:
		return "text-gray-400"
	}
}

// Category holds a verdict category description for the legend table.
type Category struct {
	Verdict     string `json:"verdict"`
	Color       string `json:"color"`
	Description string `json:"description"`
}

// NOTE(review): the original patch chunk continued here with a hunk for
// apps/openant-cli/internal/types/results.go (AnalysisMetrics field
// rename); that hunk resumes on the following line of this file.
type AnalysisMetrics struct { - TotalUnits int `json:"total_units"` - VulnerableUnits int `json:"vulnerable_units"` - SafeUnits int `json:"safe_units"` - UnclearUnits int `json:"unclear_units"` - VerifiedVulnerable int `json:"verified_vulnerable"` - FalsePositives int `json:"false_positives"` - VerificationSkipped int `json:"verification_skipped"` + Total int `json:"total"` + Vulnerable int `json:"vulnerable"` + Bypassable int `json:"bypassable"` + Inconclusive int `json:"inconclusive"` + Protected int `json:"protected"` + Safe int `json:"safe"` + Errors int `json:"errors"` + // Stage 2 metrics (optional) + Verified int `json:"verified"` + Stage2Agreed int `json:"stage2_agreed"` + Stage2Disagreed int `json:"stage2_disagreed"` } // ReportData is returned by the `report` command. type ReportData struct { - HTMLPath string `json:"html_path"` - CSVPath string `json:"csv_path"` - SummaryPath string `json:"summary_path"` + OutputPath string `json:"output_path"` + Format string `json:"format"` + Usage UsageInfo `json:"usage"` } // ScanData is returned by the `scan` command (all-in-one pipeline). diff --git a/libs/openant-core/CLAUDE.md b/libs/openant-core/CLAUDE.md index 6d76a8c..3c61665 100644 --- a/libs/openant-core/CLAUDE.md +++ b/libs/openant-core/CLAUDE.md @@ -6,6 +6,19 @@ - If unsure about scope, ask first - After any context compaction, re-read this file and referenced docs before taking any action +# Go CLI Build Rules + +**NEVER run `make install` in `apps/openant-cli/`** — it overwrites the symlink with a copy. + +The system uses a symlink: `/usr/local/bin/openant` → `apps/openant-cli/bin/openant` + +To rebuild the Go CLI: +```bash +cd apps/openant-cli && go build -o bin/openant . +``` + +The symlink automatically picks up the new binary. Running `make install` would replace the symlink with a copied file, breaking the dev workflow. + # Project Context This is OpenAnt, a two-stage SAST tool using Claude for vulnerability analysis. 
Supports Python, JavaScript/TypeScript, and Go codebases with 4-level cost optimization. @@ -93,7 +106,7 @@ python -m autopilot --repo owner/repo --api # API mode (JSON protocol) 2. Assess - Score for vuln-hunting potential (skip in --repo/--path modes) 3. Parse - Clone, parse, filter to reachable units 4. Enhance - Add call path context -5. Detect - Stage 1 vulnerability detection +5. Analyze - Stage 1 vulnerability analysis 6. Verify - Stage 2 attacker simulation 7. Dynamic Test - Docker-isolated exploit testing (requires Docker) 8. Report - Generate security reports diff --git a/libs/openant-core/core/analyzer.py b/libs/openant-core/core/analyzer.py index 4704075..2776237 100644 --- a/libs/openant-core/core/analyzer.py +++ b/libs/openant-core/core/analyzer.py @@ -5,20 +5,30 @@ hardcoded dataset names. Reuses the existing analysis functions directly. Stage 2 verification is handled separately by ``core.verifier``. + +Checkpoints are always enabled. Per-unit results are saved to +``{output_dir}/analyze_checkpoints/`` so interrupted runs can resume. +On successful completion the checkpoint dir is removed. """ import json import os import sys +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from pathlib import Path from core.schemas import AnalyzeResult, AnalysisMetrics, UsageInfo from core import tracking +from core.checkpoint import StepCheckpoint +from core.progress import ProgressReporter # Import existing analysis machinery from utilities.llm_client import AnthropicClient, get_global_tracker from utilities.json_corrector import JSONCorrector +from utilities.rate_limiter import get_rate_limiter, is_rate_limit_error, is_retryable_error # Reuse the core analysis functions from experiment.py from experiment import ( @@ -36,6 +46,216 @@ load_context = None +def _process_unit(client, unit, index, json_corrector, app_context): + """Process a single unit for Stage 1 detection. 
+ + Returns a dict with all result data. Does not mutate shared state. + """ + uid = unit.get("id", f"unit_{index}") + start = time.monotonic() + tracker = get_global_tracker() + tracker.start_unit_tracking() + + try: + result = analyze_unit( + client, unit, + use_multifile=True, + json_corrector=json_corrector, + app_context=app_context, + ) + + # Ensure unit_id is always present + result["unit_id"] = uid + + # Ensure finding field is always set (may be None after JSON correction) + if not result.get("finding") and result.get("verdict"): + result["finding"] = result["verdict"].lower() + + # Extract code for verify step + route_key = result.get("route_key", uid) + code_field = unit.get("code", {}) + if isinstance(code_field, dict): + code_for_route = code_field.get("primary_code", "") + else: + code_for_route = code_field + + finding = result.get("finding", "error") + elapsed = time.monotonic() - start + worker = threading.current_thread().name + + return { + "index": index, + "result": result, + "route_key": route_key, + "code_for_route": code_for_route, + "finding": finding, + "elapsed": elapsed, + "error": None, + "worker": worker, + "usage": tracker.get_unit_usage(), + } + + except Exception as e: + elapsed = time.monotonic() - start + worker = threading.current_thread().name + return { + "index": index, + "result": { + "unit_id": uid, + "verdict": "ERROR", + "finding": "error", + "error": str(e), + }, + "route_key": uid, + "code_for_route": "", + "finding": "error", + "elapsed": elapsed, + "error": str(e), + "worker": worker, + "usage": tracker.get_unit_usage(), + } + + +def _run_detection(units, client, json_corrector, app_context, workers, + checkpoint=None, summary_callback=None): + """Run Stage 1 detection across all units. + + Uses ThreadPoolExecutor for parallel processing when workers > 1. + Supports checkpoint/resume via the checkpoint parameter. 
+ + Args: + summary_callback: Optional callable(finding, usage=None) called from + main thread after each unit completes. Used for _summary.json updates. + + Returns (results_list, code_by_route_dict) in original unit order. + """ + total = len(units) + tracker = get_global_tracker() + + # Load checkpoint state + checkpointed = {} + if checkpoint is not None: + checkpointed = checkpoint.load() + if checkpointed: + print(f"[Detect] Restored {len(checkpointed)} units from checkpoints", + file=sys.stderr, flush=True) + + progress = ProgressReporter("Detect", total, tracker=tracker, completed=len(checkpointed)) + + mode = "sequential" if workers <= 1 else f"parallel ({workers} workers)" + remaining = total - len(checkpointed) + print(f"[Detect] Mode: {mode}, {remaining} units to process ({len(checkpointed)} already done)", + file=sys.stderr, flush=True) + + # Pre-populate results from checkpoints, but ONLY for successfully-completed + # units. Errored units are loaded into the "units_to_process" list so they + # get retried on resume (matches enhance's behavior). 
+ results = [None] * total + code_by_route = {} + units_to_process = [] + + def _cp_is_error(cp_data): + res = cp_data.get("result", {}) if cp_data else {} + return res.get("verdict") == "ERROR" or res.get("finding") == "error" + + for i, unit in enumerate(units): + uid = unit.get("id", f"unit_{i}") + cp_data = checkpointed.get(uid) + if cp_data and not _cp_is_error(cp_data): + results[i] = cp_data.get("result", {}) + code_by_route[cp_data.get("route_key", uid)] = cp_data.get("code_for_route", "") + else: + units_to_process.append((i, unit)) + + def _process_and_save(i, unit): + out = _process_unit(client, unit, i, json_corrector, app_context) + # Save checkpoint + if checkpoint is not None: + uid = out["result"].get("unit_id", f"unit_{i}") + cp_data = { + "result": out["result"], + "route_key": out["route_key"], + "code_for_route": out["code_for_route"], + } + if out.get("usage"): + cp_data["usage"] = out["usage"] + checkpoint.save(uid, cp_data) + return out + + if workers <= 1: + # Sequential mode + try: + for i, unit in units_to_process: + out = _process_and_save(i, unit) + results[i] = out["result"] + code_by_route[out["route_key"]] = out["code_for_route"] + if summary_callback: + summary_callback(out["finding"], usage=out.get("usage")) + progress.report( + out["result"].get("unit_id", f"unit_{i}"), + detail=out["finding"], + unit_elapsed=out["elapsed"], + ) + except KeyboardInterrupt: + print("[Detect] Interrupted — progress saved to checkpoints", + file=sys.stderr, flush=True) + progress.finish() + return results, code_by_route + + # Parallel mode + executor = ThreadPoolExecutor(max_workers=workers) + future_to_index = {} + for i, unit in units_to_process: + future = executor.submit(_process_and_save, i, unit) + future_to_index[future] = i + + try: + for future in as_completed(future_to_index): + out = future.result() + idx = out["index"] + results[idx] = out["result"] + code_by_route[out["route_key"]] = out["code_for_route"] + if summary_callback: + 
summary_callback(out["finding"], usage=out.get("usage")) + worker = out.get("worker", "?") + progress.report( + out["result"].get("unit_id", f"unit_{idx}"), + detail=f"{out['finding']} [{worker}]", + unit_elapsed=out["elapsed"], + ) + except KeyboardInterrupt: + print("[Detect] Interrupted — cancelling pending work...", + file=sys.stderr, flush=True) + executor.shutdown(wait=False, cancel_futures=True) + print("[Detect] Progress saved to checkpoints", + file=sys.stderr, flush=True) + else: + executor.shutdown(wait=False) + + progress.finish() + + return results, code_by_route + + +def _count_verdicts(results): + """Count verdict categories from a results list.""" + counts = { + "vulnerable": 0, + "bypassable": 0, + "inconclusive": 0, + "protected": 0, + "safe": 0, + "errors": 0, + } + for r in results: + finding = r.get("finding", r.get("verdict", "error").lower()) + if finding in counts: + counts[finding] += 1 + elif r.get("verdict") == "ERROR": + counts["errors"] += 1 + return counts + + def run_analysis( dataset_path: str, output_dir: str, @@ -44,7 +264,10 @@ def run_analysis( repo_path: str | None = None, limit: int | None = None, model: str = "opus", - exploitable_only: bool = False, + exploitable_filter: str | None = None, + workers: int = 8, + checkpoint_path: str | None = None, + backoff_seconds: int = 30, ) -> AnalyzeResult: """Run Stage 1 vulnerability detection on a dataset. @@ -52,6 +275,10 @@ def run_analysis( accepting file paths instead of dataset names. Stage 1 only — for Stage 2 verification use ``core.verifier.run_verification()``. + Checkpoints are always enabled. Per-unit results are saved to + ``{output_dir}/analyze_checkpoints/`` so interrupted runs resume + automatically. + Args: dataset_path: Path to dataset.json produced by a parser. output_dir: Directory to write results.json. @@ -61,16 +288,29 @@ def run_analysis( repo_path: Path to the repository (for context correction). limit: Max number of units to analyze. model: "opus" or "sonnet". 
- exploitable_only: If True, only analyze units classified as exploitable - by the agentic enhancer (requires enhanced dataset). + exploitable_filter: Filter by enhancement classification. Options: + None (default) — no filtering, analyze all units. + "all" — keep exploitable + vulnerable_internal (recommended). + "strict" — keep exploitable only (use after parser fixes). + checkpoint_path: Path to checkpoint directory. If None, auto-derived + from output_dir. + workers: Number of parallel workers (default: 8). + backoff_seconds: Seconds to wait on rate limit before retry (default: 30). Returns: AnalyzeResult with results path, metrics, and usage. """ os.makedirs(output_dir, exist_ok=True) - # Reset tracking for this analysis run - tracking.reset_tracking() + # Configure global rate limiter + from utilities.rate_limiter import configure_rate_limiter + configure_rate_limiter(backoff_seconds=float(backoff_seconds)) + + # Set up checkpoint + if checkpoint_path is None: + checkpoint_path = os.path.join(output_dir, "analyze_checkpoints") + checkpoint = StepCheckpoint("Analyze", output_dir) + checkpoint.dir = checkpoint_path # Select model model_id = "claude-opus-4-6" if model == "opus" else "claude-sonnet-4-20250514" @@ -95,82 +335,144 @@ def run_analysis( units = dataset.get("units", []) - # Optional: filter to exploitable units only (requires enhanced dataset) - if exploitable_only: + # Optional: filter by enhancement security classification + if exploitable_filter: original_count = len(units) + if exploitable_filter == "strict": + keep = ("exploitable",) + else: # "all" — default when filtering is enabled + keep = ("exploitable", "vulnerable_internal") units = [ u for u in units - if u.get("agent_context", {}).get("security_classification") in ("exploitable", "vulnerable") + if u.get("agent_context", {}).get("security_classification") in keep ] - print(f"[Analyze] Exploitable filter: {original_count} -> {len(units)} units", file=sys.stderr) + print(f"[Analyze] 
Exploitable filter ({exploitable_filter}): {original_count} -> {len(units)} units", file=sys.stderr) if limit: units = units[:limit] - print(f"[Analyze] Analyzing {len(units)} units...", file=sys.stderr) + total = len(units) + print(f"[Analyze] Analyzing {total} units...", file=sys.stderr) + + # Initialize summary tracking for _summary.json + # Count checkpointed units to seed the counters and sum existing usage + _existing = checkpoint.load() + _summary_completed = 0 + _summary_errors = 0 + _summary_error_breakdown = {} + _summary_input_tokens = 0 + _summary_output_tokens = 0 + _summary_cost_usd = 0.0 + for _uid, _cp in _existing.items(): + _r = _cp.get("result", {}) + if _r.get("verdict") == "ERROR" or _r.get("finding") == "error": + _summary_errors += 1 + _summary_error_breakdown["api"] = _summary_error_breakdown.get("api", 0) + 1 + else: + _summary_completed += 1 + _cp_usage = _cp.get("usage", {}) + _summary_input_tokens += _cp_usage.get("input_tokens", 0) + _summary_output_tokens += _cp_usage.get("output_tokens", 0) + _summary_cost_usd += _cp_usage.get("cost_usd", 0.0) + + def _usage_dict(): + return {"input_tokens": _summary_input_tokens, + "output_tokens": _summary_output_tokens, + "cost_usd": round(_summary_cost_usd, 6)} + + # Inject prior usage into tracker so step_report captures the total + if _summary_input_tokens or _summary_output_tokens: + tracker.add_prior_usage( + _summary_input_tokens, _summary_output_tokens, _summary_cost_usd) + + # Write initial summary + checkpoint.write_summary(total, _summary_completed, _summary_errors, + _summary_error_breakdown, phase="in_progress", + usage=_usage_dict()) + + def _summary_callback(finding, usage=None): + """Update summary counters after each unit. 
Called from main thread.""" + nonlocal _summary_completed, _summary_errors, _summary_error_breakdown + nonlocal _summary_input_tokens, _summary_output_tokens, _summary_cost_usd + if finding == "error": + _summary_errors += 1 + _summary_error_breakdown["api"] = _summary_error_breakdown.get("api", 0) + 1 + else: + _summary_completed += 1 + if usage: + _summary_input_tokens += usage.get("input_tokens", 0) + _summary_output_tokens += usage.get("output_tokens", 0) + _summary_cost_usd += usage.get("cost_usd", 0.0) + checkpoint.write_summary(total, _summary_completed, _summary_errors, + _summary_error_breakdown, phase="in_progress", + usage=_usage_dict()) # --- Stage 1: Detection --- - results = [] - code_by_route = {} - counts = { - "vulnerable": 0, - "bypassable": 0, - "inconclusive": 0, - "protected": 0, - "safe": 0, - "errors": 0, - } - - for i, unit in enumerate(units): - uid = unit.get("id", f"unit_{i}") - print(f" [{i+1}/{len(units)}] {uid}", file=sys.stderr, end="") - - try: - result = analyze_unit( - client, unit, - use_multifile=True, - json_corrector=json_corrector, - app_context=app_context, - ) - - # Ensure unit_id is always present - result["unit_id"] = uid - - # Ensure finding field is always set (may be None after JSON correction) - if not result.get("finding") and result.get("verdict"): - result["finding"] = result["verdict"].lower() - - results.append(result) - - # Track code for verify step (code_by_route persisted in results.json) - route_key = result.get("route_key", uid) - code_field = unit.get("code", {}) - if isinstance(code_field, dict): - code_by_route[route_key] = code_field.get("primary_code", "") - else: - code_by_route[route_key] = code_field - - # Count verdicts - finding = result.get("finding", "error") - if finding in counts: - counts[finding] += 1 - elif result.get("verdict") == "ERROR": - counts["errors"] += 1 + results, code_by_route = _run_detection( + units, client, json_corrector, app_context, workers, checkpoint=checkpoint, + 
summary_callback=_summary_callback, + ) - print(f" -> {finding}", file=sys.stderr) - - except Exception as e: - print(f" -> ERROR: {e}", file=sys.stderr) - counts["errors"] += 1 - results.append({ - "unit_id": uid, - "verdict": "ERROR", - "finding": "error", - "error": str(e), - }) + # Auto-retry failed units with transient errors (rate limit, connection, timeout, 5xx) + retryable_indices = [ + i for i, r in enumerate(results) + if r and is_retryable_error(r.get("error")) + ] + if retryable_indices: + rate_limiter = get_rate_limiter() + backoff = rate_limiter.time_until_ready() + if backoff > 0: + print(f"[Analyze] Retrying {len(retryable_indices)} failed units " + f"(waiting {backoff:.0f}s for rate limit to clear)...", file=sys.stderr) + rate_limiter.wait_if_needed() + else: + print(f"[Analyze] Retrying {len(retryable_indices)} failed units (transient errors)...", + file=sys.stderr) + + # Retry sequentially to avoid re-triggering rate limit + for i in retryable_indices: + unit = units[i] + out = _process_unit(client, unit, i, json_corrector, app_context) + results[i] = out["result"] + code_by_route[out["route_key"]] = out["code_for_route"] + + # Update summary: retry succeeded → flip error to completed + if out["finding"] != "error": + _summary_errors = max(0, _summary_errors - 1) + _summary_completed += 1 + retry_usage = out.get("usage", {}) + _summary_input_tokens += retry_usage.get("input_tokens", 0) + _summary_output_tokens += retry_usage.get("output_tokens", 0) + _summary_cost_usd += retry_usage.get("cost_usd", 0.0) + checkpoint.write_summary(total, _summary_completed, _summary_errors, + _summary_error_breakdown, phase="in_progress", + usage=_usage_dict()) + + # Update checkpoint + if checkpoint is not None: + uid = out["result"].get("unit_id", f"unit_{i}") + cp_data = { + "result": out["result"], + "route_key": out["route_key"], + "code_for_route": out["code_for_route"], + } + if out.get("usage"): + cp_data["usage"] = out["usage"] + checkpoint.save(uid, 
cp_data) + + print(f" Retry {i+1}/{len(retryable_indices)}: {out['finding']} (retry)", + file=sys.stderr, flush=True) + + # Write final summary with phase="done" + checkpoint.write_summary(total, _summary_completed, _summary_errors, + _summary_error_breakdown, phase="done", + usage=_usage_dict()) tracking.log_usage("Stage 1") + # Compute verdict counts from results + counts = _count_verdicts(results) + # --- Stage 1 Consistency Check --- consistency_corrections = 0 try: @@ -183,14 +485,7 @@ def run_analysis( consistency_corrections += 1 if consistency_corrections: print(f" Consistency corrections: {consistency_corrections}", file=sys.stderr) - # Recount after corrections - counts = {k: 0 for k in counts} - for r in results: - f = r.get("finding", r.get("verdict", "error").lower()) - if f in counts: - counts[f] += 1 - elif r.get("verdict") == "ERROR": - counts["errors"] += 1 + counts = _count_verdicts(results) except ImportError: print("[Analyze] Stage 1 consistency check not available, skipping.", file=sys.stderr) except Exception as e: @@ -215,6 +510,9 @@ def run_analysis( print(f"\n[Analyze] Results written to {results_path}", file=sys.stderr) + # Checkpoints are preserved as a permanent artifact alongside results. + # Final summary (phase="done") was already written before result writing. + # Build return value usage = tracking.get_usage() metrics = AnalysisMetrics( diff --git a/libs/openant-core/core/checkpoint.py b/libs/openant-core/core/checkpoint.py new file mode 100644 index 0000000..7c42f52 --- /dev/null +++ b/libs/openant-core/core/checkpoint.py @@ -0,0 +1,319 @@ +""" +Shared checkpoint utilities for resumable pipeline steps. + +Each LLM-heavy step (enhance, analyze, verify) can save per-unit checkpoint +files so interrupted runs resume where they left off. 
The checkpoint dir +lives next to the output file: + + {scan_dir}/enhance_checkpoints/ + {scan_dir}/analyze_checkpoints/ + {scan_dir}/verify_checkpoints/ + +On success (all units done), the checkpoint dir is cleaned up automatically. + +Usage: + + cp = StepCheckpoint("enhance", output_dir="/path/to/scan/dir") + completed = cp.load() # set of unit IDs already done + ...process units... + cp.save(unit_id, data_dict) # save one unit + cp.cleanup() # remove dir on success +""" + +import json +import os +import shutil +import sys +from datetime import datetime, timezone + +from utilities.safe_filename import safe_filename +from pathlib import Path + + +SUMMARY_FILE = "_summary.json" + + +class StepCheckpoint: + """Manages per-unit checkpoint files for a pipeline step.""" + + def __init__(self, step_name: str, output_dir: str): + """ + Args: + step_name: Pipeline step name (enhance, analyze, verify). + output_dir: Directory where step outputs live (scan dir). + """ + self.step_name = step_name + self.dir = os.path.join(output_dir, f"{step_name}_checkpoints") + + @property + def exists(self) -> bool: + """True if a checkpoint directory exists with at least one unit file.""" + if not os.path.isdir(self.dir): + return False + return any(f.endswith(".json") and f != SUMMARY_FILE + for f in os.listdir(self.dir)) + + def count(self) -> int: + """Number of per-unit checkpoint files (excludes _summary.json).""" + if not os.path.isdir(self.dir): + return 0 + return sum(1 for f in os.listdir(self.dir) + if f.endswith(".json") and f != SUMMARY_FILE) + + def ensure_dir(self): + """Create the checkpoint directory if it doesn't exist.""" + os.makedirs(self.dir, exist_ok=True) + + def load(self) -> dict[str, dict]: + """Load all checkpointed units. + + Returns: + Dict mapping unit_id -> checkpoint data dict. 
+ """ + results = {} + if not os.path.isdir(self.dir): + return results + + for filename in os.listdir(self.dir): + if not filename.endswith(".json"): + continue + filepath = os.path.join(self.dir, filename) + try: + with open(filepath, "r") as f: + data = json.load(f) + unit_id = data.get("id") + if unit_id: + results[unit_id] = data + except (json.JSONDecodeError, OSError): + continue + + return results + + def load_ids(self, skip_errors: bool = True) -> set[str]: + """Load just the set of completed unit IDs. + + Args: + skip_errors: If True, don't count units that errored as completed. + Supports all four phase formats: enhance, analyze, verify, dynamic-test. + """ + ids = set() + loaded = self.load() + for unit_id, data in loaded.items(): + if skip_errors: + # Enhance: agent_context.error + agent_ctx = data.get("agent_context", {}) + if agent_ctx.get("error"): + continue + # Analyze: result.verdict/finding + result = data.get("result", {}) + if result.get("verdict") == "ERROR" or result.get("finding") == "error": + continue + # Verify: verification empty or correct_finding == "error" + if "verification" in data: + v = data.get("verification", {}) + if not v or v.get("correct_finding") == "error": + continue + # Dynamic-test: top-level status == "ERROR" + if data.get("status") == "ERROR": + continue + ids.add(unit_id) + return ids + + def save(self, unit_id: str, data: dict): + """Save a single unit's checkpoint. + + Args: + unit_id: The unit identifier. + data: Dict to persist (must include 'id' key). + """ + self.ensure_dir() + filename = self._safe_filename(unit_id) + ".json" + filepath = os.path.join(self.dir, filename) + data["id"] = unit_id # ensure id is always present + with open(filepath, "w") as f: + json.dump(data, f, indent=2) + + def write_summary( + self, + total_units: int, + completed: int, + errors: int, + error_breakdown: dict, + phase: str = "in_progress", + usage: dict | None = None, + ): + """Write/overwrite _summary.json in checkpoint dir. 
+ + Called from the main thread (as_completed loop) — no lock needed. + + Args: + total_units: Total units in the step. + completed: Number of successfully completed units. + errors: Number of errored units. + error_breakdown: Dict of error_type -> count. + phase: ``"in_progress"`` or ``"done"``. + usage: Optional dict with ``input_tokens``, ``output_tokens``, + ``cost_usd`` accumulated so far for this step. + """ + self.ensure_dir() + filepath = os.path.join(self.dir, SUMMARY_FILE) + data = { + "step": self.step_name, + "phase": phase, + "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), + "total_units": total_units, + "completed": completed, + "errors": errors, + "error_breakdown": error_breakdown, + } + if usage is not None: + data["usage"] = usage + with open(filepath, "w") as f: + json.dump(data, f, indent=2) + + @staticmethod + def read_summary(checkpoint_dir: str) -> dict | None: + """Read _summary.json from a checkpoint directory. + + Returns: + Parsed dict or None if not found / unreadable. + """ + filepath = os.path.join(checkpoint_dir, SUMMARY_FILE) + if not os.path.isfile(filepath): + return None + try: + with open(filepath, "r") as f: + return json.load(f) + except (json.JSONDecodeError, OSError): + return None + + def cleanup(self): + """Remove the checkpoint directory (call on successful completion).""" + if os.path.isdir(self.dir): + shutil.rmtree(self.dir) + print(f"[{self.step_name}] Cleaned up checkpoints", file=sys.stderr) + + _safe_filename = staticmethod(safe_filename) + + @staticmethod + def status(checkpoint_dir: str) -> dict: + """Return accurate checkpoint status by reading actual checkpoint files. + + This is the single source of truth for checkpoint counts. The Go CLI + calls this via ``python -m openant checkpoint-status`` instead of + doing its own file scanning. + + Returns: + Dict with keys: step, checkpoint_dir, completed, errors, + total_files, total_units, phase, error_breakdown. 
+ """ + # Derive step name from directory name (e.g. "enhance_checkpoints" → "enhance") + dir_name = os.path.basename(checkpoint_dir.rstrip("/")) + step = dir_name.replace("_checkpoints", "") if dir_name.endswith("_checkpoints") else dir_name + + result = { + "step": step, + "checkpoint_dir": checkpoint_dir, + "completed": 0, + "errors": 0, + "total_files": 0, + "total_units": 0, + "phase": "unknown", + "error_breakdown": {}, + } + + if not os.path.isdir(checkpoint_dir): + return result + + # Read _summary.json for total_units and phase + summary = StepCheckpoint.read_summary(checkpoint_dir) + if summary: + result["total_units"] = summary.get("total_units", 0) + result["phase"] = summary.get("phase", "unknown") + + # Read all checkpoint files and classify each + completed = 0 + errors = 0 + error_breakdown = {} + + for filename in os.listdir(checkpoint_dir): + if not filename.endswith(".json") or filename == SUMMARY_FILE: + continue + filepath = os.path.join(checkpoint_dir, filename) + try: + with open(filepath, "r") as f: + data = json.load(f) + except (json.JSONDecodeError, OSError): + errors += 1 + error_breakdown["unreadable"] = error_breakdown.get("unreadable", 0) + 1 + continue + + unit_id = data.get("id") + if not unit_id: + errors += 1 + error_breakdown["missing_id"] = error_breakdown.get("missing_id", 0) + 1 + continue + + # Check for errors. 
Each phase stores checkpoint data differently: + # - enhance: agent_context.error is set + # - analyze: result.verdict == "ERROR" or result.finding == "error" + # - verify: verification is empty or verification.correct_finding == "error" + # - dynamic-test: top-level status == "ERROR" + is_error = False + err_type = None + + # Enhance-style: agent_context.error + agent_ctx = data.get("agent_context", {}) + if agent_ctx.get("error"): + is_error = True + err = agent_ctx["error"] + err_type = err.get("type", "unknown") if isinstance(err, dict) else "unknown" + + # Analyze-style: result.verdict or result.finding + elif "result" in data: + res = data.get("result", {}) + if res.get("verdict") == "ERROR" or res.get("finding") == "error": + is_error = True + err_type = "analysis_error" + + # Verify-style: verification empty or correct_finding == "error" + elif "verification" in data: + v = data.get("verification", {}) + if not v or v.get("correct_finding") == "error": + is_error = True + err_type = "verification_error" + + # Dynamic-test-style: top-level status == "ERROR" + elif data.get("status") == "ERROR": + is_error = True + err_type = "test_error" + + if is_error: + errors += 1 + if err_type: + error_breakdown[err_type] = error_breakdown.get(err_type, 0) + 1 + else: + completed += 1 + + result["completed"] = completed + result["errors"] = errors + result["total_files"] = completed + errors + result["error_breakdown"] = error_breakdown + + return result + + +def auto_checkpoint_dir(output_path: str, step_name: str) -> str: + """Derive the checkpoint directory from the output file path. 
+ + For enhance: output_path is dataset_enhanced.json + -> same dir / enhance_checkpoints/ + For analyze: output_dir contains results.json + -> output_dir / analyze_checkpoints/ + For verify: output_dir contains results_verified.json + -> output_dir / verify_checkpoints/ + """ + if os.path.isdir(output_path): + return os.path.join(output_path, f"{step_name}_checkpoints") + return os.path.join(os.path.dirname(os.path.abspath(output_path)), + f"{step_name}_checkpoints") diff --git a/libs/openant-core/core/dynamic_tester.py b/libs/openant-core/core/dynamic_tester.py index ed10fcc..7c16603 100644 --- a/libs/openant-core/core/dynamic_tester.py +++ b/libs/openant-core/core/dynamic_tester.py @@ -49,9 +49,6 @@ def run_tests( os.makedirs(output_dir, exist_ok=True) - # Reset tracking - tracking.reset_tracking() - # Check how many findings to test with open(pipeline_output_path) as f: pipeline_data = json.load(f) diff --git a/libs/openant-core/core/enhancer.py b/libs/openant-core/core/enhancer.py index 33052a9..fef1453 100644 --- a/libs/openant-core/core/enhancer.py +++ b/libs/openant-core/core/enhancer.py @@ -3,6 +3,10 @@ Wraps utilities/context_enhancer.py, providing a path-based interface for both agentic and single-shot enhancement modes. + +Checkpoints are always enabled for agentic mode. Per-unit progress is saved +to ``{output_dir}/enhance_checkpoints/`` so interrupted runs can resume +automatically. On successful completion the checkpoint dir is removed. """ import json @@ -12,6 +16,7 @@ from core.schemas import EnhanceResult, UsageInfo from core import tracking from core.progress import ProgressReporter +from utilities.rate_limiter import configure_rate_limiter def enhance_dataset( @@ -22,6 +27,8 @@ def enhance_dataset( mode: str = "agentic", checkpoint_path: str | None = None, model: str = "sonnet", + workers: int = 8, + backoff_seconds: int = 30, ) -> EnhanceResult: """Enhance a parsed dataset with security context. 
@@ -32,18 +39,26 @@ def enhance_dataset( repo_path: Path to the repository (required for agentic mode). mode: "agentic" (thorough, tool-use) or "single-shot" (fast, cheaper). checkpoint_path: Path to save/resume checkpoint (agentic mode only). + If None, auto-derived from output_path. model: "sonnet" (default, cost-effective). + workers: Number of parallel workers (default: 8). + backoff_seconds: Seconds to wait on rate limit before retry (default: 30). Returns: EnhanceResult with output path, stats, and usage. """ - # Reset tracking for this step - tracking.reset_tracking() + # Configure global rate limiter + configure_rate_limiter(backoff_seconds=float(backoff_seconds)) model_id = "claude-sonnet-4-20250514" if model == "sonnet" else "claude-opus-4-6" print(f"[Enhance] Mode: {mode}", file=sys.stderr) print(f"[Enhance] Model: {model_id}", file=sys.stderr) + # Auto-derive checkpoint path for agentic mode + if mode == "agentic" and checkpoint_path is None: + output_dir = os.path.dirname(os.path.abspath(output_path)) + checkpoint_path = os.path.join(output_dir, "enhance_checkpoints") + # Import here to avoid heavy imports at module load from utilities.llm_client import AnthropicClient, get_global_tracker from utilities.context_enhancer import ContextEnhancer @@ -70,6 +85,9 @@ def _on_unit_done(unit_id: str, classification: str, unit_elapsed: float): unit_elapsed=unit_elapsed, ) + def _on_restored(count: int): + progress.completed = count + # Run enhancement if mode == "agentic": if not analyzer_output_path: @@ -81,40 +99,52 @@ def _on_unit_done(unit_id: str, classification: str, unit_elapsed: float): repo_path=repo_path, checkpoint_path=checkpoint_path, progress_callback=_on_unit_done, + restored_callback=_on_restored, + workers=workers, ) elif mode == "single-shot": enhanced = enhancer.enhance_dataset( dataset, progress_callback=_on_unit_done, + workers=workers, ) else: raise ValueError(f"Unknown enhancement mode: {mode}. 
Use 'agentic' or 'single-shot'.") progress.finish() - # Write enhanced dataset - os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) - with open(output_path, "w") as f: - json.dump(enhanced, f, indent=2) - - print(f"[Enhance] Enhanced dataset: {output_path}", file=sys.stderr) - - # Compute classification distribution + # Compute classification distribution and error summary FIRST (before cleanup decision) classifications = {} error_count = 0 + error_summary = {} context_key = "agent_context" if mode == "agentic" else "llm_context" for unit in enhanced.get("units", []): ctx = unit.get(context_key, {}) if ctx.get("error"): error_count += 1 + err = ctx["error"] + if isinstance(err, dict): + err_type = err.get("type", "unknown") + else: + err_type = "legacy_string" + error_summary[err_type] = error_summary.get(err_type, 0) + 1 continue cls = ctx.get("security_classification", "unknown") classifications[cls] = classifications.get(cls, 0) + 1 + # Checkpoints are preserved as a permanent artifact alongside results. + # Final summary (phase="done") is written by context_enhancer. 
+ + # Write enhanced dataset + os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) + with open(output_path, "w") as f: + json.dump(enhanced, f, indent=2) + + print(f"[Enhance] Enhanced dataset: {output_path}", file=sys.stderr) print(f"[Enhance] Classifications: {classifications}", file=sys.stderr) if error_count: - print(f"[Enhance] Errors: {error_count}", file=sys.stderr) + print(f"[Enhance] Errors: {error_count} ({error_summary})", file=sys.stderr) tracking.log_usage("Enhance") @@ -124,6 +154,7 @@ def _on_unit_done(unit_id: str, classification: str, unit_elapsed: float): enhanced_dataset_path=output_path, units_enhanced=len(units) - error_count, error_count=error_count, + error_summary=error_summary, classifications=classifications, usage=usage, ) diff --git a/libs/openant-core/core/parser_adapter.py b/libs/openant-core/core/parser_adapter.py index 8e3ecc7..3969897 100644 --- a/libs/openant-core/core/parser_adapter.py +++ b/libs/openant-core/core/parser_adapter.py @@ -30,7 +30,7 @@ def detect_language(repo_path: str) -> str: "python", "javascript", or "go" """ repo = Path(repo_path) - counts = {"python": 0, "javascript": 0, "go": 0, "c": 0, "ruby": 0, "php": 0} + counts = {"python": 0, "javascript": 0, "go": 0, "c": 0, "ruby": 0, "php": 0, "zig": 0} for f in repo.rglob("*"): if not f.is_file(): @@ -56,11 +56,13 @@ def detect_language(repo_path: str) -> str: counts["ruby"] += 1 elif suffix == ".php": counts["php"] += 1 + elif suffix == ".zig": + counts["zig"] += 1 if not any(counts.values()): raise ValueError( f"No supported source files found in {repo_path}. " - "Supported languages: Python, JavaScript/TypeScript, Go, C/C++, Ruby, PHP." + "Supported languages: Python, JavaScript/TypeScript, Go, C/C++, Ruby, PHP, Zig." 
) return max(counts, key=counts.get) @@ -116,6 +118,8 @@ def parse_repository( return _parse_ruby(repo_path, output_dir, processing_level, skip_tests, name) elif language == "php": return _parse_php(repo_path, output_dir, processing_level, skip_tests, name) + elif language == "zig": + return _parse_zig(repo_path, output_dir, processing_level, skip_tests, name) else: raise ValueError(f"Unsupported language: {language}") @@ -594,3 +598,63 @@ def _parse_php(repo_path: str, output_dir: str, processing_level: str, skip_test language="php", processing_level=processing_level, ) + + +# --------------------------------------------------------------------------- +# Zig parser +# --------------------------------------------------------------------------- + +def _parse_zig(repo_path: str, output_dir: str, processing_level: str, skip_tests: bool = True, name: str = None) -> ParseResult: + """Invoke the Zig parser. + + The Zig parser uses tree-sitter for function extraction and call graph + building. Invoked via subprocess (same pattern as other parsers). 
+ + Requires: tree-sitter, tree-sitter-zig + """ + print("[Parser] Running Zig parser...", file=sys.stderr) + + parser_script = _CORE_ROOT / "parsers" / "zig" / "test_pipeline.py" + + cmd = [ + sys.executable, str(parser_script), + repo_path, + "--output", output_dir, + "--processing-level", processing_level, + ] + + if name: + cmd.extend(["--name", name]) + if skip_tests: + cmd.append("--skip-tests") + + result = subprocess.run( + cmd, + stdout=sys.stderr, + stderr=sys.stderr, + cwd=str(_CORE_ROOT), + timeout=1800, + ) + + if result.returncode != 0: + raise RuntimeError(f"Zig parser failed with exit code {result.returncode}") + + dataset_path = os.path.join(output_dir, "dataset.json") + analyzer_output_path = os.path.join(output_dir, "analyzer_output.json") + + # Count units + units_count = 0 + if os.path.exists(dataset_path): + with open(dataset_path) as f: + data = json.load(f) + units_count = len(data.get("units", [])) + + print(f" Zig parser complete: {units_count} units", file=sys.stderr) + + return ParseResult( + dataset_path=dataset_path, + analyzer_output_path=analyzer_output_path if os.path.exists(analyzer_output_path) else None, + units_count=units_count, + language="zig", + processing_level=processing_level, + ) diff --git a/libs/openant-core/core/progress.py b/libs/openant-core/core/progress.py index dade4e4..3afcfbb 100644 --- a/libs/openant-core/core/progress.py +++ b/libs/openant-core/core/progress.py @@ -6,6 +6,7 @@ """ import sys +import threading import time from typing import Optional @@ -52,12 +53,15 @@ def __init__( total: int, tracker=None, summary_interval: int | None = None, + completed: int = 0, ): self.step_name = step_name self.total = total self.tracker = tracker self.start_time = time.monotonic() - self.completed = 0 + self.completed = completed + self._lock = threading.Lock() + self._last_cost = self._get_cost() # snapshot for per-unit delta # Width for the counter so alignment stays consistent self._width = len(str(total)) @@ -100,38 
+104,41 @@ def report( detail: Extra info (e.g. classification, verdict). unit_elapsed: How long this specific unit took, in seconds. """ - self.completed += 1 - elapsed = time.monotonic() - self.start_time - eta = self._estimate_remaining(elapsed) - cost = self._get_cost() - - # Truncate label if too long - if len(unit_label) > 50: - unit_label = unit_label[:47] + "..." - - # Build the progress line - parts = [ - f"[{self.step_name}]", - f"{self.completed:>{self._width}}/{self.total}", - unit_label, - ] - if detail: - parts.append(detail) - if unit_elapsed > 0: - parts.append(f"{unit_elapsed:.1f}s") - - meta = f"(elapsed {_fmt_duration(elapsed)}, ETA {eta}, {_fmt_cost(cost)})" - parts.append(meta) - - line = " ".join(parts) - print(line, file=sys.stderr, flush=True) - - # Periodic summary - if ( - self.completed % self._summary_interval == 0 - and self.completed < self.total - ): - self._print_summary(elapsed, cost) + with self._lock: + self.completed += 1 + elapsed = time.monotonic() - self.start_time + eta = self._estimate_remaining(elapsed) + total_cost = self._get_cost() + unit_cost = total_cost - self._last_cost + self._last_cost = total_cost + + # Truncate label if too long + if len(unit_label) > 50: + unit_label = unit_label[:47] + "..." 
+ + # Build the progress line — show per-unit cost, not cumulative + parts = [ + f"[{self.step_name}]", + f"{self.completed:>{self._width}}/{self.total}", + unit_label, + ] + if detail: + parts.append(detail) + if unit_elapsed > 0: + parts.append(f"{unit_elapsed:.1f}s") + + meta = f"(elapsed {_fmt_duration(elapsed)}, ETA {eta}, {_fmt_cost(unit_cost)})" + parts.append(meta) + + line = " ".join(parts) + print(line, file=sys.stderr, flush=True) + + # Periodic summary — shows cumulative total + if ( + self.completed % self._summary_interval == 0 + and self.completed < self.total + ): + self._print_summary(elapsed, total_cost) def _print_summary(self, elapsed: float, cost: float) -> None: """Print a highlighted summary line.""" @@ -152,14 +159,15 @@ def _print_summary(self, elapsed: float, cost: float) -> None: def finish(self) -> None: """Print a final summary line when the step is done.""" - elapsed = time.monotonic() - self.start_time - cost = self._get_cost() - avg = elapsed / self.completed if self.completed else 0 - - line = ( - f"[{self.step_name}] Done: " - f"{self.completed}/{self.total} units in {_fmt_duration(elapsed)} | " - f"avg {avg:.1f}s/unit | " - f"cost {_fmt_cost(cost)}" - ) - print(line, file=sys.stderr, flush=True) + with self._lock: + elapsed = time.monotonic() - self.start_time + cost = self._get_cost() + avg = elapsed / self.completed if self.completed else 0 + + line = ( + f"[{self.step_name}] Done: " + f"{self.completed}/{self.total} units in {_fmt_duration(elapsed)} | " + f"avg {avg:.1f}s/unit | " + f"cost {_fmt_cost(cost)}" + ) + print(line, file=sys.stderr, flush=True) diff --git a/libs/openant-core/core/reporter.py b/libs/openant-core/core/reporter.py index 2131e01..4f604dd 100644 --- a/libs/openant-core/core/reporter.py +++ b/libs/openant-core/core/reporter.py @@ -191,7 +191,7 @@ def build_pipeline_output( print(f" pipeline_output.json: {len(findings_data)} findings", file=sys.stderr) print(f" Written to {output_path}", file=sys.stderr) - 
return output_path + return output_path, len(findings_data) def generate_html_report( @@ -213,8 +213,14 @@ def generate_html_report( """ print("[Report] Generating HTML report...", file=sys.stderr) + # Pass step reports dir so the HTML report can include cost/time breakdown + step_reports_dir = os.path.dirname(os.path.abspath(results_path)) + script = _CORE_ROOT / "generate_report.py" - cmd = [sys.executable, str(script), results_path, dataset_path, output_path] + cmd = [ + sys.executable, str(script), results_path, dataset_path, output_path, + "--step-reports-dir", step_reports_dir, + ] result = subprocess.run(cmd, stdout=sys.stderr, stderr=sys.stderr, cwd=str(_CORE_ROOT)) @@ -262,31 +268,45 @@ def generate_summary_report( ) -> ReportResult: """Generate LLM-based summary report (Markdown). - Wraps report/generator.py. Requires ANTHROPIC_API_KEY. + Calls report/generator.py directly (in-process) for proper cost tracking. Args: - results_path: Path to results JSON (pipeline output format). + results_path: Path to pipeline_output.json or results JSON. output_path: Path for the output Markdown file. Returns: - ReportResult with the output path. + ReportResult with the output path and usage info. 
""" + import json + from report.generator import generate_summary_report as _generate_summary, merge_dynamic_results + from report.schema import validate_pipeline_output, ValidationError + print("[Report] Generating summary report (LLM)...", file=sys.stderr) - # Use the report module via subprocess - cmd = [ - sys.executable, "-m", "report", - "summary", results_path, - "-o", output_path, - ] + with open(results_path) as f: + pipeline_data = json.load(f) - result = subprocess.run(cmd, stdout=sys.stderr, stderr=sys.stderr, cwd=str(_CORE_ROOT)) + # Merge dynamic test results if available + pipeline_data = merge_dynamic_results(pipeline_data, results_path) - if result.returncode != 0: - raise RuntimeError(f"Summary report generation failed (exit code {result.returncode})") + try: + validate_pipeline_output(pipeline_data) + except ValidationError as e: + raise RuntimeError(f"Invalid pipeline output: {e}") + + report_text, usage = _generate_summary(pipeline_data) + + os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) + with open(output_path, "w") as f: + f.write(report_text) print(f" Summary report: {output_path}", file=sys.stderr) - return ReportResult(output_path=output_path, format="summary") + print(f" Cost: ${usage['cost_usd']:.4f} ({usage['total_tokens']:,} tokens)", file=sys.stderr) + + # Record in global tracker so step_context picks it up + _record_usage_in_tracker(usage) + + return ReportResult(output_path=output_path, format="summary", usage=_usage_to_info(usage)) def generate_disclosure_docs( @@ -295,27 +315,111 @@ def generate_disclosure_docs( ) -> ReportResult: """Generate per-vulnerability disclosure documents. - Wraps report/generator.py disclosures command. Requires ANTHROPIC_API_KEY. + Calls report/generator.py directly (in-process) for proper cost tracking. Args: - results_path: Path to results JSON (pipeline output format). + results_path: Path to pipeline_output.json or results JSON. 
output_dir: Directory for disclosure Markdown files. Returns: - ReportResult with the output directory path. + ReportResult with the output directory path and usage info. """ + import json + from concurrent.futures import ThreadPoolExecutor, as_completed + from report.generator import generate_disclosure as _generate_disclosure, _merge_usage, merge_dynamic_results + from report.schema import validate_pipeline_output, ValidationError + print("[Report] Generating disclosure documents (LLM)...", file=sys.stderr) - cmd = [ - sys.executable, "-m", "report", - "disclosures", results_path, - "-o", output_dir, - ] + with open(results_path) as f: + pipeline_data = json.load(f) - result = subprocess.run(cmd, stdout=sys.stderr, stderr=sys.stderr, cwd=str(_CORE_ROOT)) + # Merge dynamic test results if available + pipeline_data = merge_dynamic_results(pipeline_data, results_path) - if result.returncode != 0: - raise RuntimeError(f"Disclosure generation failed (exit code {result.returncode})") + try: + validate_pipeline_output(pipeline_data) + except ValidationError as e: + raise RuntimeError(f"Invalid pipeline output: {e}") + + os.makedirs(output_dir, exist_ok=True) + + product_name = pipeline_data["repository"]["name"] + all_usages = [] + count = 0 + + # Collect confirmed findings first + confirmed = [ + (i, finding) for i, finding in enumerate(pipeline_data["findings"], 1) + if finding.get("stage2_verdict") in ("confirmed", "agreed", "vulnerable") + ] - print(f" Disclosures: {output_dir}", file=sys.stderr) - return ReportResult(output_path=output_dir, format="disclosure") + if not confirmed: + print(" No confirmed vulnerabilities to generate disclosures for.", file=sys.stderr) + else: + print(f" Generating {len(confirmed)} disclosures in parallel (8 workers)...", + file=sys.stderr) + + def _one(args): + i, finding = args + disclosure_text, usage = _generate_disclosure(finding, product_name) + safe_name = finding["short_name"].replace(" ", "_").upper() + filename = 
f"DISCLOSURE_{i:02d}_{safe_name}.md" + filepath = os.path.join(output_dir, filename) + with open(filepath, "w") as f: + f.write(disclosure_text) + return finding["short_name"], filepath, usage + + executor = ThreadPoolExecutor(max_workers=8) + futures = {executor.submit(_one, item): item for item in confirmed} + try: + for future in as_completed(futures): + name, filepath, usage = future.result() + all_usages.append(usage) + count += 1 + print(f" [{count}/{len(confirmed)}] {name} -> {filepath}", + file=sys.stderr) + except KeyboardInterrupt: + print("\n[Report] Interrupted — cancelling pending disclosures...", + file=sys.stderr, flush=True) + executor.shutdown(wait=False, cancel_futures=True) + raise + executor.shutdown(wait=False) + + merged_usage = _merge_usage(all_usages) if all_usages else {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0, "cost_usd": 0.0} + + print(f" Disclosures: {count} files in {output_dir}", file=sys.stderr) + print(f" Cost: ${merged_usage['cost_usd']:.4f} ({merged_usage['total_tokens']:,} tokens)", file=sys.stderr) + + # Record in global tracker so step_context picks it up + _record_usage_in_tracker(merged_usage) + + return ReportResult(output_path=output_dir, format="disclosure", usage=_usage_to_info(merged_usage)) + + +def _record_usage_in_tracker(usage: dict): + """Record usage in the global TokenTracker so step_context captures it.""" + try: + from utilities.llm_client import get_global_tracker + tracker = get_global_tracker() + # Record as a single aggregated call + if usage.get("total_tokens", 0) > 0: + tracker.record_call( + model="claude-opus-4-6", + input_tokens=usage["input_tokens"], + output_tokens=usage["output_tokens"], + ) + except Exception: + pass # Best effort — don't break report generation + + +def _usage_to_info(usage: dict): + """Convert a usage dict to a UsageInfo dataclass.""" + from core.schemas import UsageInfo + return UsageInfo( + total_calls=1, + total_input_tokens=usage.get("input_tokens", 0), + 
total_output_tokens=usage.get("output_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + total_cost_usd=usage.get("cost_usd", 0.0), + ) diff --git a/libs/openant-core/core/scanner.py b/libs/openant-core/core/scanner.py index 8678c11..08e2dfe 100644 --- a/libs/openant-core/core/scanner.py +++ b/libs/openant-core/core/scanner.py @@ -4,7 +4,7 @@ Runs the full pipeline: Parse → App Context → Enhance → Detect → Verify - → Build pipeline_output → Report → Dynamic Test + → Build pipeline_output → Dynamic Test → Report This is the implementation behind ``open-ant scan ``. @@ -53,6 +53,8 @@ def scan_repository( enhance: bool = True, enhance_mode: str = "agentic", dynamic_test: bool = False, + workers: int = 8, + backoff_seconds: int = 30, ) -> ScanResult: """Scan a repository for vulnerabilities. @@ -64,8 +66,8 @@ def scan_repository( 4. **Detect** — Stage 1 vulnerability detection 5. **Verify** — Stage 2 attacker simulation (optional) 6. **Build pipeline_output.json** — bridge format for reports + dynamic tests - 7. **Report** — summary + disclosure documents (optional) - 8. **Dynamic Test** — Docker-isolated exploit testing (optional, off by default) + 7. **Dynamic Test** — Docker-isolated exploit testing (optional, off by default) + 8. **Report** — summary + disclosure documents (optional, merges dynamic test results) Args: repo_path: Path to the repository to scan. @@ -81,6 +83,8 @@ def scan_repository( enhance: If True, run agentic/single-shot context enhancement. enhance_mode: ``"agentic"`` (thorough) or ``"single-shot"`` (fast). dynamic_test: If True, run Docker-isolated dynamic testing (requires Docker). + workers: Number of parallel workers for LLM steps (default: 8). + backoff_seconds: Seconds to wait when rate-limited (default: 30). Returns: ScanResult with paths to all generated files and metrics. 
@@ -108,7 +112,7 @@ def _step_label(name: str) -> str: _print_banner(repo_path, output_dir, language, processing_level, verify, generate_context, enhance, enhance_mode, - generate_report, dynamic_test) + generate_report, dynamic_test, workers, backoff_seconds) # --------------------------------------------------------------- # Step 1: Parse @@ -210,6 +214,9 @@ def _step_label(name: str) -> str: analyzer_output_path=parse_result.analyzer_output_path, repo_path=repo_path, mode=enhance_mode, + workers=workers, + backoff_seconds=backoff_seconds, + # checkpoint_path auto-derived from output_path ) ctx.summary = { @@ -218,6 +225,8 @@ def _step_label(name: str) -> str: "classifications": enhance_result.classifications, "mode": enhance_mode, } + if enhance_result.error_summary: + ctx.summary["error_summary"] = enhance_result.error_summary ctx.outputs = { "enhanced_dataset_path": enhance_result.enhanced_dataset_path, } @@ -228,6 +237,8 @@ def _step_label(name: str) -> str: print(f" Enhanced: {enhance_result.units_enhanced} units", file=sys.stderr) print(f" Classifications: {enhance_result.classifications}", file=sys.stderr) + if enhance_result.error_summary: + print(f" Errors: {enhance_result.error_count} ({enhance_result.error_summary})", file=sys.stderr) else: print(_step_label("Skipping enhancement (--no-enhance)."), file=sys.stderr) result.skipped_steps.append("enhance") @@ -253,6 +264,8 @@ def _step_label(name: str) -> str: repo_path=repo_path, limit=limit, model=model, + workers=workers, + backoff_seconds=backoff_seconds, ) ctx.summary = { @@ -300,6 +313,8 @@ def _step_label(name: str) -> str: analyzer_output_path=parse_result.analyzer_output_path, app_context_path=app_context_path, repo_path=repo_path, + workers=workers, + backoff_seconds=backoff_seconds, ) ctx.summary = { @@ -374,54 +389,7 @@ def _step_label(name: str) -> str: print(file=sys.stderr) # --------------------------------------------------------------- - # Step 7: Report (optional) - # 
--------------------------------------------------------------- - if generate_report: - from core.reporter import generate_summary_report, generate_disclosure_docs - - print(_step_label("Generating reports..."), file=sys.stderr) - - with step_context("report", output_dir, inputs={ - "pipeline_output_path": pipeline_output_path, - }) as ctx: - report_dir = os.path.join(output_dir, "report") - os.makedirs(report_dir, exist_ok=True) - - summary_path = os.path.join(report_dir, "SUMMARY_REPORT.md") - disclosures_dir = os.path.join(report_dir, "disclosures") - - outputs = {} - - try: - generate_summary_report(pipeline_output_path, summary_path) - result.summary_path = summary_path - outputs["summary_path"] = summary_path - print(f" Summary: {summary_path}", file=sys.stderr) - except Exception as e: - print(f" WARNING: Summary report failed: {e}", file=sys.stderr) - ctx.errors.append(f"Summary report: {e}") - - # Only generate disclosures if there are findings - if has_findings: - try: - generate_disclosure_docs(pipeline_output_path, disclosures_dir) - outputs["disclosures_dir"] = disclosures_dir - print(f" Disclosures: {disclosures_dir}", file=sys.stderr) - except Exception as e: - print(f" WARNING: Disclosure docs failed: {e}", file=sys.stderr) - ctx.errors.append(f"Disclosure docs: {e}") - - ctx.summary = {"formats_generated": list(outputs.keys())} - ctx.outputs = outputs - - collected_step_reports.append(_load_step_report(output_dir, "report")) - else: - print(_step_label("Skipping report generation (--no-report)."), file=sys.stderr) - result.skipped_steps.append("report") - print(file=sys.stderr) - - # --------------------------------------------------------------- - # Step 8: Dynamic Test (optional, off by default) + # Step 7: Dynamic Test (optional, off by default) # --------------------------------------------------------------- if dynamic_test and has_findings: if not shutil.which("docker"): @@ -470,6 +438,53 @@ def _step_label(name: str) -> str: 
result.skipped_steps.append("dynamic-test") print(file=sys.stderr) + # --------------------------------------------------------------- + # Step 8: Report (optional) + # --------------------------------------------------------------- + if generate_report: + from core.reporter import generate_summary_report, generate_disclosure_docs + + print(_step_label("Generating reports..."), file=sys.stderr) + + with step_context("report", output_dir, inputs={ + "pipeline_output_path": pipeline_output_path, + }) as ctx: + report_dir = os.path.join(output_dir, "report") + os.makedirs(report_dir, exist_ok=True) + + summary_path = os.path.join(report_dir, "SUMMARY_REPORT.md") + disclosures_dir = os.path.join(report_dir, "disclosures") + + outputs = {} + + try: + generate_summary_report(pipeline_output_path, summary_path) + result.summary_path = summary_path + outputs["summary_path"] = summary_path + print(f" Summary: {summary_path}", file=sys.stderr) + except Exception as e: + print(f" WARNING: Summary report failed: {e}", file=sys.stderr) + ctx.errors.append(f"Summary report: {e}") + + # Only generate disclosures if there are findings + if has_findings: + try: + generate_disclosure_docs(pipeline_output_path, disclosures_dir) + outputs["disclosures_dir"] = disclosures_dir + print(f" Disclosures: {disclosures_dir}", file=sys.stderr) + except Exception as e: + print(f" WARNING: Disclosure docs failed: {e}", file=sys.stderr) + ctx.errors.append(f"Disclosure docs: {e}") + + ctx.summary = {"formats_generated": list(outputs.keys())} + ctx.outputs = outputs + + collected_step_reports.append(_load_step_report(output_dir, "report")) + else: + print(_step_label("Skipping report generation (--no-report)."), file=sys.stderr) + result.skipped_steps.append("report") + print(file=sys.stderr) + # --------------------------------------------------------------- # Final: Aggregate scan report # --------------------------------------------------------------- @@ -587,6 +602,8 @@ def _print_banner( 
enhance_mode: str, generate_report: bool, dynamic_test: bool, + workers: int = 8, + backoff_seconds: int = 30, ) -> None: """Print the scan configuration banner.""" print("=" * 60, file=sys.stderr) @@ -601,6 +618,9 @@ def _print_banner( print(f" App context: {generate_context}", file=sys.stderr) print(f" Report: {generate_report}", file=sys.stderr) print(f" Dynamic test: {dynamic_test}", file=sys.stderr) + workers_label = f"{workers} (parallel)" if workers > 1 else "1 (sequential)" + print(f" Workers: {workers_label}", file=sys.stderr) + print(f" Rate backoff: {backoff_seconds}s", file=sys.stderr) print("=" * 60, file=sys.stderr) print(file=sys.stderr) diff --git a/libs/openant-core/core/schemas.py b/libs/openant-core/core/schemas.py index 3e97307..88d30d4 100644 --- a/libs/openant-core/core/schemas.py +++ b/libs/openant-core/core/schemas.py @@ -13,7 +13,7 @@ import json import os from dataclasses import dataclass, field, asdict -from datetime import datetime +from datetime import datetime, timezone from typing import Any @@ -104,9 +104,14 @@ class ReportResult: """Result of `open-ant report`.""" output_path: str format: str = "html" + usage: UsageInfo = field(default_factory=UsageInfo) def to_dict(self) -> dict: - return asdict(self) + return { + "output_path": self.output_path, + "format": self.format, + "usage": self.usage.to_dict(), + } @dataclass @@ -162,17 +167,21 @@ class EnhanceResult: enhanced_dataset_path: str units_enhanced: int = 0 error_count: int = 0 + error_summary: dict = field(default_factory=dict) classifications: dict = field(default_factory=dict) usage: UsageInfo = field(default_factory=UsageInfo) def to_dict(self) -> dict: - return { + result = { "enhanced_dataset_path": self.enhanced_dataset_path, "units_enhanced": self.units_enhanced, "error_count": self.error_count, "classifications": self.classifications, "usage": self.usage.to_dict(), } + if self.error_summary: + result["error_summary"] = self.error_summary + return result # 
--------------------------------------------------------------------------- @@ -250,7 +259,7 @@ class StepReport: def __post_init__(self): if not self.timestamp: - self.timestamp = datetime.utcnow().isoformat() + "Z" + self.timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") def to_dict(self) -> dict: return asdict(self) diff --git a/libs/openant-core/core/step_report.py b/libs/openant-core/core/step_report.py index 28a6355..065d162 100644 --- a/libs/openant-core/core/step_report.py +++ b/libs/openant-core/core/step_report.py @@ -16,7 +16,7 @@ import time import traceback from contextlib import contextmanager -from datetime import datetime +from datetime import datetime, timezone from core.schemas import StepReport @@ -38,7 +38,7 @@ def step_context(step: str, output_dir: str, inputs: dict | None = None): """ report = StepReport( step=step, - timestamp=datetime.utcnow().isoformat() + "Z", + timestamp=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), inputs=inputs or {}, ) diff --git a/libs/openant-core/core/verifier.py b/libs/openant-core/core/verifier.py index b4a0773..0f00fc6 100644 --- a/libs/openant-core/core/verifier.py +++ b/libs/openant-core/core/verifier.py @@ -3,6 +3,10 @@ Wraps FindingVerifier to run Stage 2 verification on Stage 1 results. Only verifies findings classified as vulnerable or bypassable. + +Checkpoints are always enabled. Per-finding results are saved to +``{output_dir}/verify_checkpoints/`` so interrupted runs can resume. +On successful completion the checkpoint dir is removed. 
""" import json @@ -12,6 +16,7 @@ from core.schemas import VerifyResult, UsageInfo from core import tracking +from core.checkpoint import StepCheckpoint from core.progress import ProgressReporter from utilities.llm_client import TokenTracker, get_global_tracker @@ -33,12 +38,19 @@ def run_verification( analyzer_output_path: str, app_context_path: str | None = None, repo_path: str | None = None, + workers: int = 8, + checkpoint_path: str | None = None, + backoff_seconds: int = 30, ) -> VerifyResult: """Run Stage 2 attacker-simulation verification on Stage 1 results. Only findings with verdict ``vulnerable`` or ``bypassable`` are verified. Results are written to ``results_verified.json`` in *output_dir*. + Checkpoints are always enabled. Per-finding verification results are + saved to ``{output_dir}/verify_checkpoints/`` so interrupted runs + resume automatically. + Args: results_path: Path to ``results.json`` from the analyze step. output_dir: Directory to write ``results_verified.json``. @@ -46,14 +58,25 @@ def run_verification( repository index / tool use). app_context_path: Optional path to ``application_context.json``. repo_path: Optional path to the repository root (passed to index). + checkpoint_path: Path to checkpoint directory. If None, auto-derived + from output_dir. + workers: Number of parallel workers (default: 8). + backoff_seconds: Seconds to wait on rate limit before retry (default: 30). Returns: VerifyResult with paths, counts, and usage info. 
""" os.makedirs(output_dir, exist_ok=True) - # Reset tracking for this verification run - tracking.reset_tracking() + # Configure global rate limiter + from utilities.rate_limiter import configure_rate_limiter + configure_rate_limiter(backoff_seconds=float(backoff_seconds)) + + # Set up checkpoint + if checkpoint_path is None: + checkpoint_path = os.path.join(output_dir, "verify_checkpoints") + checkpoint = StepCheckpoint("Verify", output_dir) + checkpoint.dir = checkpoint_path # Load Stage 1 results print(f"[Verify] Loading results: {results_path}", file=sys.stderr) @@ -130,10 +153,16 @@ def _on_finding_done(unit_id: str, detail: str, unit_elapsed: float): unit_elapsed=unit_elapsed, ) + def _on_restored(count: int): + progress.completed = count + try: verified_results = verifier.verify_batch( vulnerable_results, code_by_route, progress_callback=_on_finding_done, + workers=workers, + checkpoint=checkpoint, + restored_callback=_on_restored, ) except Exception as e: print(f"[Verify] ERROR during batch verification: {e}", file=sys.stderr) @@ -145,8 +174,12 @@ def _on_finding_done(unit_id: str, detail: str, unit_elapsed: float): agreed = 0 disagreed = 0 confirmed_vulnerabilities = 0 + error_count = 0 for r in verified_results: + if r.get("error"): + error_count += 1 + continue verification = r.get("verification", {}) if verification.get("agree", False): agreed += 1 @@ -158,6 +191,11 @@ def _on_finding_done(unit_id: str, detail: str, unit_elapsed: float): print(f"\n[Verify] Results: {agreed} agreed, {disagreed} disagreed, " f"{confirmed_vulnerabilities} confirmed vulnerabilities", file=sys.stderr) + if error_count: + print(f"[Verify] Errors: {error_count}", file=sys.stderr) + + # Checkpoints are preserved as a permanent artifact alongside results + # (final summary with phase="done" is written inside verify_batch). 
tracking.log_usage("Stage 2") diff --git a/libs/openant-core/experiment.py b/libs/openant-core/experiment.py index 3d7cdad..409d4fa 100644 --- a/libs/openant-core/experiment.py +++ b/libs/openant-core/experiment.py @@ -343,13 +343,13 @@ def analyze_unit( code_field = unit.get("code", {}) if isinstance(code_field, dict): code = code_field.get("primary_code", "") - # Check if this is an enhanced dataset with file metadata + # Check if dependencies were inlined into this unit's primary_code primary_origin = code_field.get("primary_origin", {}) - is_enhanced = primary_origin.get("enhanced", False) + has_deps_inlined = primary_origin.get("deps_inlined", primary_origin.get("enhanced", False)) files_included = primary_origin.get("files_included", []) else: code = code_field - is_enhanced = False + has_deps_inlined = False files_included = [] # Extract agent context (security classification from agentic parser) @@ -424,7 +424,7 @@ def analyze_unit( result["response_length"] = len(response) result["code_length"] = len(code) result["files_included"] = files_included - result["is_enhanced"] = is_enhanced + result["has_deps_inlined"] = has_deps_inlined result["context_reviewed"] = context_enhanced if additional_files_added: result["files_added_by_review"] = additional_files_added @@ -622,7 +622,7 @@ def make_prompt(expanded_code, expanded_files): # Preserve route_key and other metadata from original result corrected_result["route_key"] = result.get("route_key") corrected_result["code_length"] = result.get("code_length") - corrected_result["is_enhanced"] = result.get("is_enhanced") + corrected_result["has_deps_inlined"] = result.get("has_deps_inlined") result = corrected_result results.append(result) diff --git a/libs/openant-core/generate_report.py b/libs/openant-core/generate_report.py index 4497f29..7b1ecd8 100644 --- a/libs/openant-core/generate_report.py +++ b/libs/openant-core/generate_report.py @@ -152,11 +152,76 @@ def generate_remediation_guidance(findings: list) -> 
str: return response.content[0].text +def _build_pipeline_costs_html(step_reports: list[dict]) -> str: + """Build an HTML table with pipeline step costs and durations.""" + if not step_reports: + return "" + + # Sort by timestamp (or keep as-is) + rows = "" + total_cost = 0.0 + total_duration = 0.0 + + for sr in sorted(step_reports, key=lambda s: s.get("timestamp", "")): + step = sr.get("step", "unknown") + duration = sr.get("duration_seconds", 0) + cost = sr.get("cost_usd", 0) + status = sr.get("status", "unknown") + + total_cost += cost + total_duration += duration + + # Format duration + if duration >= 60: + dur_str = f"{duration / 60:.1f}m" + else: + dur_str = f"{duration:.1f}s" + + cost_str = f"${cost:.4f}" if cost > 0 else "-" + status_color = "#28a745" if status == "success" else "#dc3545" if status == "error" else "#6c757d" + + rows += f""" + + {html.escape(step)} + {dur_str} + {cost_str} + {html.escape(status)} + """ + + # Total row + total_dur_str = f"{total_duration / 60:.1f}m" if total_duration >= 60 else f"{total_duration:.1f}s" + + return f""" +
+

Pipeline Costs & Timing

+ + + + + + + + + + + {rows} + + + + + + + +
StepDurationCostStatus
Total{total_dur_str}${total_cost:.4f}
+
""" + + def generate_html_report( experiment: dict, dataset: dict, remediation_html: str, - output_path: str + output_path: str, + step_reports: list[dict] | None = None, ): """Generate the HTML report.""" # Prepare data @@ -556,6 +621,8 @@ def generate_html_report( + {_build_pipeline_costs_html(step_reports or [])} +

All Findings

@@ -647,11 +714,24 @@ def generate_html_report( print(f"Report generated: {output_path}") +def _load_step_reports_from_dir(directory: str) -> list[dict]: + """Load all {step}.report.json files from a directory.""" + import glob + reports = [] + for path in glob.glob(os.path.join(directory, "*.report.json")): + try: + reports.append(load_json(path)) + except Exception: + continue + return reports + + def main(): parser = argparse.ArgumentParser(description='Generate HTML security report') parser.add_argument('experiment', help='Path to experiment results JSON') parser.add_argument('dataset', help='Path to dataset JSON') parser.add_argument('output', nargs='?', default='report.html', help='Output HTML path (default: report.html)') + parser.add_argument('--step-reports-dir', help='Directory containing *.report.json step reports') args = parser.parse_args() @@ -659,6 +739,15 @@ def main(): experiment = load_json(args.experiment) dataset = load_json(args.dataset) + # Load step reports if available + step_reports = [] + if args.step_reports_dir: + step_reports = _load_step_reports_from_dir(args.step_reports_dir) + else: + # Try to auto-detect from experiment path's directory + exp_dir = os.path.dirname(os.path.abspath(args.experiment)) + step_reports = _load_step_reports_from_dir(exp_dir) + print("Preparing findings...") findings = prepare_findings_summary(experiment, dataset) @@ -666,7 +755,7 @@ def main(): remediation_html = generate_remediation_guidance(findings) print("Building HTML report...") - generate_html_report(experiment, dataset, remediation_html, args.output) + generate_html_report(experiment, dataset, remediation_html, args.output, step_reports=step_reports) # Print summary verdict_counts = {} diff --git a/libs/openant-core/openant/cli.py b/libs/openant-core/openant/cli.py index cdaf2bf..4c7d3a7 100644 --- a/libs/openant-core/openant/cli.py +++ b/libs/openant-core/openant/cli.py @@ -29,6 +29,23 @@ def _output_json(data: dict): sys.stdout.write("\n") +def 
_load_step_reports(directory: str) -> list[dict]: + """Load all {step}.report.json files from a directory. + + Used by standalone commands (build-output, report) to feed + cost/duration data into pipeline_output.json. + """ + import glob + reports = [] + for path in glob.glob(os.path.join(directory, "*.report.json")): + try: + with open(path) as f: + reports.append(json.load(f)) + except (json.JSONDecodeError, OSError): + continue + return reports + + def cmd_scan(args): """Scan a repository end-to-end.""" from core.scanner import scan_repository @@ -51,6 +68,8 @@ def cmd_scan(args): enhance=not args.no_enhance, enhance_mode=args.enhance_mode, dynamic_test=args.dynamic_test, + workers=args.workers, + backoff_seconds=args.backoff, ) _output_json(success(result.to_dict())) @@ -112,6 +131,9 @@ def cmd_enhance(args): from core.enhancer import enhance_dataset from core.schemas import success, error from core.step_report import step_context + from core import tracking + + tracking.reset_tracking() # Default output path: same dir as input, with _enhanced suffix if args.output: @@ -136,6 +158,8 @@ def cmd_enhance(args): repo_path=args.repo_path, mode=args.mode, checkpoint_path=args.checkpoint, + workers=args.workers, + backoff_seconds=args.backoff, ) ctx.summary = { @@ -144,6 +168,8 @@ def cmd_enhance(args): "classifications": result.classifications, "mode": args.mode, } + if result.error_summary: + ctx.summary["error_summary"] = result.error_summary ctx.outputs = { "enhanced_dataset_path": result.enhanced_dataset_path, } @@ -165,14 +191,19 @@ def cmd_analyze(args): from core.analyzer import run_analysis from core.schemas import success, error from core.step_report import step_context + from core import tracking + + tracking.reset_tracking() output_dir = args.output or tempfile.mkdtemp(prefix="open_ant_analyze_") + exploitable_filter = "all" if args.exploitable_all else ("strict" if args.exploitable_only else None) + try: with step_context("analyze", output_dir, inputs={ 
"dataset_path": os.path.abspath(args.dataset), "model": args.model, - "exploitable_only": args.exploitable_only, + "exploitable_filter": exploitable_filter, "limit": args.limit, }) as ctx: result = run_analysis( @@ -183,7 +214,10 @@ def cmd_analyze(args): repo_path=args.repo_path, limit=args.limit, model=args.model, - exploitable_only=args.exploitable_only, + exploitable_filter=exploitable_filter, + workers=args.workers, + checkpoint_path=getattr(args, "checkpoint", None), + backoff_seconds=args.backoff, ) ctx.summary = { @@ -219,6 +253,8 @@ def cmd_analyze(args): analyzer_output_path=args.analyzer_output, app_context_path=args.app_context, repo_path=args.repo_path, + workers=args.workers, + backoff_seconds=args.backoff, ) vctx.summary = { @@ -254,6 +290,9 @@ def cmd_verify(args): from core.verifier import run_verification from core.schemas import success, error from core.step_report import step_context + from core import tracking + + tracking.reset_tracking() output_dir = args.output or tempfile.mkdtemp(prefix="open_ant_verify_") @@ -270,6 +309,9 @@ def cmd_verify(args): analyzer_output_path=args.analyzer_output, app_context_path=args.app_context, repo_path=args.repo_path, + workers=args.workers, + checkpoint_path=getattr(args, "checkpoint", None), + backoff_seconds=args.backoff, ) ctx.summary = { @@ -303,11 +345,15 @@ def cmd_build_output(args): output_dir = os.path.dirname(os.path.abspath(args.output)) + # Load existing step reports for cost/duration data + results_dir = os.path.dirname(os.path.abspath(args.results)) + step_reports = _load_step_reports(results_dir) + try: with step_context("build-output", output_dir, inputs={ "results_path": os.path.abspath(args.results), }) as ctx: - path = build_pipeline_output( + path, findings_count = build_pipeline_output( results_path=args.results, output_path=args.output, repo_name=args.repo_name, @@ -316,11 +362,12 @@ def cmd_build_output(args): commit_sha=args.commit_sha, application_type=args.app_type or "web_app", 
processing_level=args.processing_level, + step_reports=step_reports, ) ctx.outputs = {"pipeline_output_path": path} - _output_json(success({"pipeline_output_path": path})) + _output_json(success({"pipeline_output_path": path, "findings_count": findings_count})) return 0 except Exception as e: @@ -333,6 +380,9 @@ def cmd_dynamic_test(args): from core.dynamic_tester import run_tests from core.schemas import success, error from core.step_report import step_context + from core import tracking + + tracking.reset_tracking() output_dir = args.output or tempfile.mkdtemp(prefix="openant_dyntest_") @@ -371,6 +421,18 @@ def cmd_dynamic_test(args): return 2 +def _default_report_output(results_path: str, fmt: str) -> str: + """Derive a sensible default output path based on format.""" + reports_dir = os.path.join(os.path.dirname(os.path.abspath(results_path)), "final-reports") + defaults = { + "html": os.path.join(reports_dir, "report.html"), + "csv": os.path.join(reports_dir, "report.csv"), + "summary": os.path.join(reports_dir, "report.md"), + "disclosure": os.path.join(reports_dir, "disclosures"), + } + return defaults.get(fmt, os.path.join(reports_dir, "report")) + + def cmd_report(args): """Generate reports from analysis results. 
@@ -381,7 +443,6 @@ def cmd_report(args): """ from core.reporter import ( build_pipeline_output, - generate_html_report, generate_csv_report, generate_summary_report, generate_disclosure_docs, @@ -389,32 +450,53 @@ def cmd_report(args): from core.schemas import success, error from core.step_report import step_context - output_path = args.output + fmt = args.format + output_path = args.output or _default_report_output(args.results, fmt) output_dir = os.path.dirname(os.path.abspath(output_path)) + # Check if dynamic tests have been run (for summary/disclosure formats) + if fmt in ("summary", "disclosure") and not getattr(args, "skip_dt_check", False): + results_dir = os.path.dirname(os.path.abspath(args.results)) + dt_results_path = os.path.join(results_dir, "dynamic_test_results.json") + if not os.path.exists(dt_results_path): + print( + "\nDynamic tests haven't been run yet.\n" + "If this is intentional, press Y to generate reports without dynamic test data.\n" + "Otherwise, run 'openant dynamic-test' first.\n", + file=sys.stderr, + ) + try: + answer = input("[Y/n] ").strip().lower() + except (EOFError, KeyboardInterrupt): + answer = "n" + if answer not in ("y", "yes", ""): + print("Aborted. 
Run 'openant dynamic-test' first.", file=sys.stderr) + return 0 + try: with step_context("report", output_dir, inputs={ "results_path": os.path.abspath(args.results), - "format": args.format, + "format": fmt, }) as ctx: - fmt = args.format - # For summary/disclosure, we need pipeline_output.json pipeline_output_path = args.pipeline_output if fmt in ("summary", "disclosure") and not pipeline_output_path: - # Auto-build pipeline_output from results + # Auto-build pipeline_output from results, with step report data + results_dir = os.path.dirname(os.path.abspath(args.results)) + step_reports = _load_step_reports(results_dir) pipeline_output_path = os.path.join(output_dir, "pipeline_output.json") build_pipeline_output( results_path=args.results, output_path=pipeline_output_path, repo_name=args.repo_name, + step_reports=step_reports, ) if fmt == "html": - if not args.dataset: - _output_json(error("--dataset is required for HTML reports")) - return 2 - result = generate_html_report(args.results, args.dataset, output_path) + # HTML reports are now rendered by the Go CLI via report-data. + # This code path should not be reached — Go handles html directly. + _output_json(error("HTML reports are generated by the Go CLI. Use 'openant report -f html' instead.")) + return 2 elif fmt == "csv": if not args.dataset: _output_json(error("--dataset is required for CSV reports")) @@ -439,6 +521,373 @@ def cmd_report(args): return 2 +def cmd_checkpoint_status(args): + """Report checkpoint status for a checkpoint directory. + + Internal subcommand — not user-facing. Called by the Go CLI to get + accurate completed/errored counts by reading actual checkpoint files. 
+ """ + from core.checkpoint import StepCheckpoint + from core.schemas import success, error + + checkpoint_dir = args.checkpoint_dir + if not os.path.isdir(checkpoint_dir): + _output_json(error(f"Checkpoint directory not found: {checkpoint_dir}")) + return 2 + + try: + status = StepCheckpoint.status(checkpoint_dir) + _output_json(success(status)) + return 0 + except Exception as e: + _output_json(error(str(e))) + return 2 + + +def cmd_report_data(args): + """Prepare pre-computed report data as JSON for the Go HTML renderer. + + Internal subcommand — not user-facing. Called by the Go CLI to get + all data needed to render the HTML overview report. + + Outputs a JSON blob with stats, chart data, findings, remediation HTML, + and step reports — everything display-ready. + """ + import html as html_mod + import anthropic + from core.schemas import success, error + from core.step_report import step_context + from utilities.llm_client import get_global_tracker + + results_path = args.results + dataset_path = args.dataset + + if not dataset_path: + _output_json(error("--dataset is required for report-data")) + return 2 + + results_dir = os.path.dirname(os.path.abspath(results_path)) + + try: + with step_context("report-data", results_dir, inputs={ + "results_path": os.path.abspath(results_path), + "dataset_path": os.path.abspath(dataset_path), + }) as ctx: + # Load data + with open(results_path) as f: + experiment = json.load(f) + with open(dataset_path) as f: + dataset = json.load(f) + + # --- Load dynamic test results if available --- + # Dynamic tests use VULN-XXX IDs from pipeline_output.json, + # but report-data works with route_keys from results_verified.json. + # Bridge via pipeline_output's location.function (== route_key). 
+ dt_by_route_key = {} + dt_path = os.path.join(results_dir, "dynamic_test_results.json") + po_path = os.path.join(results_dir, "pipeline_output.json") + if os.path.exists(dt_path) and os.path.exists(po_path): + with open(dt_path) as f: + dt_data = json.load(f) + with open(po_path) as f: + po_data = json.load(f) + + # Map VULN-ID → route_key from pipeline_output + vuln_id_to_route = {} + for finding in po_data.get("findings", []): + fid = finding.get("id") + route = finding.get("location", {}).get("function", "") + if fid and route: + vuln_id_to_route[fid] = route + + # Map route_key → dynamic test result + for dr in dt_data.get("results", []): + fid = dr.get("finding_id") + route = vuln_id_to_route.get(fid) + if route: + dt_by_route_key[route] = dr + + print(f"[Report] Loaded {len(dt_by_route_key)} dynamic test results", file=sys.stderr) + + # --- Prepare findings --- + units_by_id = {u["id"]: u for u in dataset.get("units", [])} + + verdict_order = ["vulnerable", "bypassable", "inconclusive", "protected", "safe"] + verdict_colors = { + "vulnerable": "#dc3545", + "bypassable": "#fd7e14", + "inconclusive": "#6c757d", + "protected": "#28a745", + "safe": "#20c997", + } + verdict_priority = {v: i for i, v in enumerate(verdict_order)} + dt_status_order = ["CONFIRMED", "INCONCLUSIVE", "ERROR", "", "BLOCKED", "NOT_REPRODUCED"] + dt_status_priority = {s: i for i, s in enumerate(dt_status_order)} + + verdict_counts = {} + file_verdicts = {} + findings = [] + + for result in experiment.get("results", []): + route_key = result.get("route_key", "") + verdict = result.get("finding", "") + file_path = route_key.rsplit(":", 1)[0] if ":" in route_key else route_key + unit = units_by_id.get(route_key, {}) + llm_context = unit.get("llm_context") or {} + verification = result.get("verification") or {} + + # Justification: prefer stage2, fallback to stage1 + justification = verification.get("explanation", "") + if not justification: + justification = result.get("reasoning", "") + 
justification = justification[:300] + + # Downgrade unverified findings to inconclusive + if justification.strip() == "Max iterations reached": + verdict = "inconclusive" + + verdict_counts[verdict] = verdict_counts.get(verdict, 0) + 1 + + # Track worst verdict per file + if file_path not in file_verdicts: + file_verdicts[file_path] = verdict + elif verdict_priority.get(verdict, 3) < verdict_priority.get(file_verdicts[file_path], 3): + file_verdicts[file_path] = verdict + + func_name = route_key.split(":")[-1] if ":" in route_key else route_key + + # Dynamic test result for this finding + dt_result = dt_by_route_key.get(route_key) + dt_status = "" + dt_details = "" + if dt_result: + dt_status = dt_result.get("status", "") + dt_details = dt_result.get("details", "") + + findings.append({ + "verdict": verdict, + "verdict_color": verdict_colors.get(verdict, "#6c757d"), + "file": file_path, + "function": func_name, + "attack_vector": result.get("attack_vector", "") or "", + "analysis": justification, + "dynamic_test_status": dt_status, + "dynamic_test_details": dt_details, + "number": 0, # assigned after sort + }) + + # Sort by verdict priority, then by dynamic test status within each group + findings.sort(key=lambda f: ( + verdict_priority.get(f["verdict"], 3), + dt_status_priority.get(f["dynamic_test_status"], 3), + )) + for i, f in enumerate(findings, 1): + f["number"] = i + + # --- Group findings by verdict, sub-grouped by dynamic test outcome --- + dt_subgroup_defs = [ + ("Confirmed", lambda s: s == "CONFIRMED"), + ("Not reproduced", lambda s: s in ("NOT_REPRODUCED", "BLOCKED")), + ("Test error", lambda s: s == "ERROR"), + ("Not tested", lambda s: s in ("", "INCONCLUSIVE")), + ] + + findings_by_verdict = [] + for v in verdict_order: + group = [f for f in findings if f["verdict"] == v] + if not group: + continue + + subgroups = [] + for label, predicate in dt_subgroup_defs: + sg_findings = [f for f in group if predicate(f.get("dynamic_test_status", ""))] + if 
sg_findings: + subgroups.append({"label": label, "findings": sg_findings}) + + findings_by_verdict.append({ + "verdict": v, + "verdict_color": verdict_colors[v], + "count": len(group), + "open_by_default": v in ("vulnerable", "bypassable"), + "findings": group, + "subgroups": subgroups, + "has_subgroups": len(subgroups) > 1, + }) + + # --- Chart data --- + unit_chart = { + "labels": [v for v in verdict_order if v in verdict_counts], + "data": [verdict_counts.get(v, 0) for v in verdict_order if v in verdict_counts], + "colors": [verdict_colors[v] for v in verdict_order if v in verdict_counts], + } + + file_verdict_counts = {} + for v in file_verdicts.values(): + file_verdict_counts[v] = file_verdict_counts.get(v, 0) + 1 + + file_chart = { + "labels": [v for v in verdict_order if v in file_verdict_counts], + "data": [file_verdict_counts.get(v, 0) for v in verdict_order if v in file_verdict_counts], + "colors": [verdict_colors[v] for v in verdict_order if v in file_verdict_counts], + } + + # --- Stats --- + total_units = len(experiment.get("results", [])) + total_files = len(file_verdicts) + + stats = { + "total_units": total_units, + "total_files": total_files, + "vulnerable": verdict_counts.get("vulnerable", 0), + "bypassable": verdict_counts.get("bypassable", 0), + "secure": verdict_counts.get("protected", 0) + verdict_counts.get("safe", 0), + } + + # --- Remediation guidance (LLM call) --- + actionable = [f for f in findings if f["verdict"] in ("vulnerable", "bypassable", "inconclusive")] + + if not actionable: + remediation_html = "

<p>No vulnerabilities or security concerns found. All code units are either safe or properly protected.</p>

" + else: + findings_text = "" + for f in actionable: + findings_text += f""" +### Finding #{f['number']}: {f['file']}:{f['function']} +- **Verdict**: {f['verdict']} +- **Attack Vector**: {f['attack_vector'] or 'Not specified'} +- **Analysis**: {f['analysis'][:500]} +""" + prompt = f"""Analyze these security findings and provide: + +1. **Executive Summary**: A brief overview of the security posture (2-3 sentences) + +2. **Prioritized Action Items**: Group remediation steps by priority: Critical Priority, High Priority, Medium Priority. + For each item: + - What to fix + - Why it's important + - How to fix it (concrete steps) + When referencing findings, use their exact numbers with # prefix (e.g. #4, #12, #13, #14). + Do NOT invent specific timeframes like "fix within 72 hours" — use only the priority labels above. + +3. **Quick Wins**: Any simple fixes that would immediately improve security + +Format your response as HTML (use

<h2>, <h3>, <p>, <ul>, <li>
  • , tags). Do not include ```html markers. + +## Findings to Analyze: +{findings_text} +""" + print("[Report] Generating remediation guidance (LLM)...", file=sys.stderr) + client = anthropic.Anthropic() + response = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=4096, + messages=[{"role": "user", "content": prompt}], + ) + remediation_html = response.content[0].text + + # Post-process: linkify finding references like #4, #12-#14 + import re + def _linkify_finding(m): + num = m.group(1) + return f'#{num}' + remediation_html = re.sub(r'#(\d+)', _linkify_finding, remediation_html) + + # Track usage + usage = response.usage + tracker = get_global_tracker() + tracker.record_call( + model="claude-sonnet-4-20250514", + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + ) + print(f" Remediation cost: ${(usage.input_tokens / 1e6) * 3.0 + (usage.output_tokens / 1e6) * 15.0:.4f}", file=sys.stderr) + + # --- Step reports --- + step_reports_data = [] + for sr in _load_step_reports(results_dir): + duration = sr.get("duration_seconds", 0) + cost = sr.get("cost_usd", 0) + if duration >= 60: + dur_str = f"{duration / 60:.1f}m" + else: + dur_str = f"{duration:.1f}s" + cost_str = f"${cost:.2f}" if cost > 0 else "-" + + step_reports_data.append({ + "step": sr.get("step", "unknown"), + "duration": dur_str, + "cost": cost_str, + "status": sr.get("status", "unknown"), + "timestamp": sr.get("timestamp", ""), + }) + + # Sort by timestamp + step_reports_data.sort(key=lambda s: s.get("timestamp", "")) + + # --- Category descriptions (static) --- + categories = [ + {"verdict": "vulnerable", "color": "#dc3545", "description": "Code contains an exploitable security vulnerability with no effective protection. Immediate remediation required."}, + {"verdict": "bypassable", "color": "#fd7e14", "description": "Security controls exist but can be circumvented under certain conditions. 
Review and strengthen protections."}, + {"verdict": "inconclusive", "color": "#6c757d", "description": "Security posture could not be determined. Manual review recommended to assess risk."}, + {"verdict": "protected", "color": "#28a745", "description": "Code handles potentially dangerous operations but has effective security controls in place."}, + {"verdict": "safe", "color": "#20c997", "description": "Code does not involve security-sensitive operations or poses no security risk."}, + ] + + from datetime import datetime + + # --- Repo info from pipeline_output.json --- + repo_name = "" + commit_sha = "" + language = "" + repo_url = "" + if os.path.exists(po_path): + try: + with open(po_path) as f: + po = json.load(f) + repo_info = po.get("repository", {}) + repo_name = repo_info.get("name", "") + commit_sha = repo_info.get("commit_sha", "") + language = repo_info.get("language", "") + repo_url = repo_info.get("url", "") + except (json.JSONDecodeError, OSError): + pass + + # --- Totals from step reports --- + total_duration_seconds = 0.0 + total_cost_usd = 0.0 + for sr in _load_step_reports(results_dir): + total_duration_seconds += sr.get("duration_seconds", 0) + total_cost_usd += sr.get("cost_usd", 0) + + report_data = { + "title": "Security Analysis Report", + "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "repo_name": repo_name, + "commit_sha": commit_sha, + "language": language, + "repo_url": repo_url, + "total_duration_seconds": total_duration_seconds, + "total_cost_usd": total_cost_usd, + "stats": stats, + "unit_chart": unit_chart, + "file_chart": file_chart, + "remediation_html": remediation_html, + "findings": findings, + "findings_by_verdict": findings_by_verdict, + "step_reports": step_reports_data, + "categories": categories, + } + + ctx.summary = {"findings": len(findings), "actionable": len(actionable)} + + _output_json(success(report_data)) + return 0 + + except Exception as e: + _output_json(error(str(e))) + return 2 + + def main(): 
parser = argparse.ArgumentParser( prog="openant", @@ -487,6 +936,10 @@ def main(): scan_p.add_argument("--no-skip-tests", action="store_true", help="Include test files in parsing (default: tests are skipped)") scan_p.add_argument("--limit", type=int, help="Max units to analyze") scan_p.add_argument("--model", choices=["opus", "sonnet"], default="opus", help="Model (default: opus)") + scan_p.add_argument("--workers", type=int, default=8, + help="Number of parallel workers for LLM steps (default: 8)") + scan_p.add_argument("--backoff", type=int, default=30, + help="Seconds to wait when rate-limited (default: 30)") scan_p.set_defaults(func=cmd_scan) # --------------------------------------------------------------- @@ -526,6 +979,10 @@ def main(): default="agentic", help="Enhancement mode (default: agentic — thorough but more expensive)", ) + enhance_p.add_argument("--workers", type=int, default=8, + help="Number of parallel workers for LLM calls (default: 8)") + enhance_p.add_argument("--backoff", type=int, default=30, + help="Seconds to wait when rate-limited (default: 30)") enhance_p.set_defaults(func=cmd_enhance) # --------------------------------------------------------------- @@ -539,9 +996,17 @@ def main(): analyze_p.add_argument("--app-context", help="Path to application_context.json") analyze_p.add_argument("--limit", type=int, help="Max units to analyze") analyze_p.add_argument("--repo-path", help="Path to the repository (for context correction)") - analyze_p.add_argument("--exploitable-only", action="store_true", - help="Only analyze units classified as exploitable/vulnerable by enhancer") + exploit_group = analyze_p.add_mutually_exclusive_group() + exploit_group.add_argument("--exploitable-all", action="store_true", + help="Analyze units classified as exploitable or vulnerable_internal (safer, compensates for parser gaps)") + exploit_group.add_argument("--exploitable-only", action="store_true", + help="Analyze only units classified as exploitable (strict, 
use after parser entry point fixes)") analyze_p.add_argument("--model", choices=["opus", "sonnet"], default="opus", help="Model (default: opus)") + analyze_p.add_argument("--workers", type=int, default=8, + help="Number of parallel workers for LLM calls (default: 8)") + analyze_p.add_argument("--checkpoint", help="Path to checkpoint directory for save/resume") + analyze_p.add_argument("--backoff", type=int, default=30, + help="Seconds to wait when rate-limited (default: 30)") analyze_p.set_defaults(func=cmd_analyze) # --------------------------------------------------------------- @@ -553,6 +1018,11 @@ def main(): verify_p.add_argument("--app-context", help="Path to application_context.json") verify_p.add_argument("--repo-path", help="Path to the repository") verify_p.add_argument("--output", "-o", help="Output directory (default: temp dir)") + verify_p.add_argument("--workers", type=int, default=8, + help="Number of parallel workers for LLM calls (default: 8)") + verify_p.add_argument("--checkpoint", help="Path to checkpoint directory for save/resume") + verify_p.add_argument("--backoff", type=int, default=30, + help="Seconds to wait when rate-limited (default: 30)") verify_p.set_defaults(func=cmd_verify) # --------------------------------------------------------------- @@ -587,15 +1057,31 @@ def main(): report_p.add_argument( "--format", "-f", choices=["html", "csv", "summary", "disclosure"], - default="html", - help="Report format (default: html)", + default="disclosure", + help="Report format (default: disclosure)", ) report_p.add_argument("--dataset", help="Path to dataset JSON (required for html/csv)") report_p.add_argument("--pipeline-output", help="Path to pipeline_output.json (for summary/disclosure; auto-built if absent)") report_p.add_argument("--repo-name", help="Repository name (used when auto-building pipeline_output)") - report_p.add_argument("--output", "-o", required=True, help="Output path") + report_p.add_argument("--output", "-o", help="Output 
path (default: derived from results path and format)") report_p.set_defaults(func=cmd_report) + # --------------------------------------------------------------- + # report-data — internal: prepare pre-computed report data as JSON + # --------------------------------------------------------------- + rd_p = subparsers.add_parser("report-data", help="(internal) Prepare report data for Go renderer") + rd_p.add_argument("results", help="Path to results/experiment JSON") + rd_p.add_argument("--dataset", required=True, help="Path to dataset JSON") + rd_p.set_defaults(func=cmd_report_data) + + # --------------------------------------------------------------- + # checkpoint-status — internal: report checkpoint status for Go CLI + # --------------------------------------------------------------- + cs_p = subparsers.add_parser("checkpoint-status", + help="(internal) Report checkpoint status for a directory") + cs_p.add_argument("checkpoint_dir", help="Path to checkpoint directory") + cs_p.set_defaults(func=cmd_checkpoint_status) + args = parser.parse_args() return args.func(args) diff --git a/libs/openant-core/parsers/c/unit_generator.py b/libs/openant-core/parsers/c/unit_generator.py index 0330950..a0391d7 100644 --- a/libs/openant-core/parsers/c/unit_generator.py +++ b/libs/openant-core/parsers/c/unit_generator.py @@ -193,7 +193,7 @@ def create_unit(self, func_id: str, func_data: Dict) -> Dict: # Assemble enhanced code enhanced_code = self.assemble_enhanced_code(func_data, upstream_deps, downstream_callers) files_included = self.collect_files_included(file_path, upstream_deps, downstream_callers) - is_enhanced = len(upstream_deps) > 0 or len(downstream_callers) > 0 + has_deps_inlined = len(upstream_deps) > 0 or len(downstream_callers) > 0 # Get direct calls/callers (depth 1 only) direct_calls = self.call_graph.get(func_id, []) @@ -211,7 +211,7 @@ def create_unit(self, func_id: str, func_data: Dict) -> Dict: 'end_line': func_data.get('end_line'), 'function_name': func_name, 
'class_name': class_name, - 'enhanced': is_enhanced, + 'deps_inlined': has_deps_inlined, 'files_included': files_included, 'original_length': len(func_data.get('code', '')), 'enhanced_length': len(enhanced_code), @@ -258,7 +258,7 @@ def update_statistics(self, unit: Dict) -> None: self.statistics['units_with_upstream'] += 1 if dep_meta.get('total_downstream', 0) > 0: self.statistics['units_with_downstream'] += 1 - if unit.get('code', {}).get('primary_origin', {}).get('enhanced', False): + if unit.get('code', {}).get('primary_origin', {}).get('deps_inlined', False): self.statistics['units_enhanced'] += 1 def generate_units(self) -> Dict: diff --git a/libs/openant-core/parsers/go/go_parser/generator.go b/libs/openant-core/parsers/go/go_parser/generator.go index 1629113..f475ecd 100644 --- a/libs/openant-core/parsers/go/go_parser/generator.go +++ b/libs/openant-core/parsers/go/go_parser/generator.go @@ -53,7 +53,7 @@ func (g *Generator) Generate() *Dataset { unitsWithDownstream++ totalDownstream += downstream } - if unit.Code.PrimaryOrigin.Enhanced { + if unit.Code.PrimaryOrigin.DepsInlined { unitsEnhanced++ } } @@ -108,7 +108,7 @@ func (g *Generator) createUnit(funcID string, funcInfo FunctionInfo) Unit { primaryCode, filesIncluded := g.assembleEnhancedCode(funcInfo, upstream) originalLength := len(funcInfo.Code) enhancedLength := len(primaryCode) - enhanced := enhancedLength > originalLength + depsInlined := enhancedLength > originalLength return Unit{ ID: funcID, @@ -121,7 +121,7 @@ func (g *Generator) createUnit(funcID string, funcInfo FunctionInfo) Unit { EndLine: funcInfo.EndLine, FunctionName: funcInfo.Name, ClassName: funcInfo.ClassName, - Enhanced: enhanced, + DepsInlined: depsInlined, FilesIncluded: filesIncluded, OriginalLength: originalLength, EnhancedLength: enhancedLength, diff --git a/libs/openant-core/parsers/go/go_parser/go_parser b/libs/openant-core/parsers/go/go_parser/go_parser index 63462b6..198b846 100755 Binary files 
a/libs/openant-core/parsers/go/go_parser/go_parser and b/libs/openant-core/parsers/go/go_parser/go_parser differ diff --git a/libs/openant-core/parsers/go/go_parser/types.go b/libs/openant-core/parsers/go/go_parser/types.go index 32f3fae..4eab273 100644 --- a/libs/openant-core/parsers/go/go_parser/types.go +++ b/libs/openant-core/parsers/go/go_parser/types.go @@ -91,7 +91,7 @@ type PrimaryOrigin struct { EndLine int `json:"end_line"` FunctionName string `json:"function_name"` ClassName string `json:"class_name,omitempty"` - Enhanced bool `json:"enhanced"` + DepsInlined bool `json:"deps_inlined"` FilesIncluded []string `json:"files_included"` OriginalLength int `json:"original_length"` EnhancedLength int `json:"enhanced_length"` diff --git a/libs/openant-core/parsers/javascript/unit_generator.js b/libs/openant-core/parsers/javascript/unit_generator.js index 699eba5..3650792 100644 --- a/libs/openant-core/parsers/javascript/unit_generator.js +++ b/libs/openant-core/parsers/javascript/unit_generator.js @@ -277,9 +277,9 @@ class UnitGenerator { const directCalls = this.resolver.callGraph[functionId] || []; const directCallers = this.resolver.reverseCallGraph[functionId] || []; - // Assemble enhanced code with dependencies (Sastinel standard format) + // Assemble enhanced code with dependencies (OpenAnt standard format) const filesIncluded = this._collectFilesIncluded(filePath, upstreamDependencies, downstreamCallers); - const isEnhanced = upstreamDependencies.length > 0 || downstreamCallers.length > 0; + const hasDepsInlined = upstreamDependencies.length > 0 || downstreamCallers.length > 0; const assembledCode = this._assembleEnhancedCode(funcData, upstreamDependencies, downstreamCallers); // Build the unit @@ -294,7 +294,7 @@ class UnitGenerator { end_line: funcData.endLine || null, function_name: funcData.name, class_name: funcData.className || null, - enhanced: isEnhanced, + deps_inlined: hasDepsInlined, files_included: filesIncluded, original_length: 
funcData.code.length, enhanced_length: assembledCode.length diff --git a/libs/openant-core/parsers/php/PARSER_PIPELINE.md b/libs/openant-core/parsers/php/PARSER_PIPELINE.md index 43e4fea..1792936 100644 --- a/libs/openant-core/parsers/php/PARSER_PIPELINE.md +++ b/libs/openant-core/parsers/php/PARSER_PIPELINE.md @@ -86,7 +86,7 @@ python call_graph_builder.py functions.json --output call_graph.json Creates self-contained analysis units with full context. **Input:** Call graph data -**Output:** Dataset compatible with Sastinel +**Output:** OpenAnt dataset format **Each unit contains:** ```json diff --git a/libs/openant-core/parsers/php/PARSER_UPGRADE_PLAN.md b/libs/openant-core/parsers/php/PARSER_UPGRADE_PLAN.md index f23b5fe..d877201 100644 --- a/libs/openant-core/parsers/php/PARSER_UPGRADE_PLAN.md +++ b/libs/openant-core/parsers/php/PARSER_UPGRADE_PLAN.md @@ -4,7 +4,7 @@ ## Quick Context -This is the PHP code parser for Sastinel (a SAST tool). It mirrors the Ruby/C parser's tree-sitter approach and the Python parser's pipeline structure. +This is the PHP code parser for OpenAnt (a SAST tool). It mirrors the Ruby/C parser's tree-sitter approach and the Python parser's pipeline structure. **Goal:** Parse PHP repositories into self-contained analysis units for vulnerability detection. diff --git a/libs/openant-core/parsers/php/__init__.py b/libs/openant-core/parsers/php/__init__.py index 1f37e03..fd37c6d 100644 --- a/libs/openant-core/parsers/php/__init__.py +++ b/libs/openant-core/parsers/php/__init__.py @@ -1 +1 @@ -# PHP parser for Sastinel +# PHP parser for OpenAnt diff --git a/libs/openant-core/parsers/php/test_pipeline.py b/libs/openant-core/parsers/php/test_pipeline.py index 7a9c43f..fd10477 100644 --- a/libs/openant-core/parsers/php/test_pipeline.py +++ b/libs/openant-core/parsers/php/test_pipeline.py @@ -6,7 +6,7 @@ 1. RepositoryScanner - Enumerates .rb/.rake files 2. FunctionExtractor - Extracts functions via tree-sitter 3. 
CallGraphBuilder - Builds bidirectional call graphs -4. UnitGenerator - Creates Sastinel dataset format +4. UnitGenerator - Creates OpenAnt dataset format 5. CodeQL (optional) - Static analysis pre-filter 6. ContextEnhancer (optional) - LLM enhancement using Claude Sonnet diff --git a/libs/openant-core/parsers/php/unit_generator.py b/libs/openant-core/parsers/php/unit_generator.py index 92e7389..9b36684 100644 --- a/libs/openant-core/parsers/php/unit_generator.py +++ b/libs/openant-core/parsers/php/unit_generator.py @@ -195,7 +195,7 @@ def create_unit(self, func_id: str, func_data: Dict) -> Dict: # Assemble enhanced code enhanced_code = self.assemble_enhanced_code(func_data, upstream_deps, downstream_callers) files_included = self.collect_files_included(file_path, upstream_deps, downstream_callers) - is_enhanced = len(upstream_deps) > 0 or len(downstream_callers) > 0 + has_deps_inlined = len(upstream_deps) > 0 or len(downstream_callers) > 0 # Get direct calls/callers (depth 1 only) direct_calls = self.call_graph.get(func_id, []) @@ -213,7 +213,7 @@ def create_unit(self, func_id: str, func_data: Dict) -> Dict: 'end_line': func_data.get('end_line'), 'function_name': func_name, 'class_name': class_name, - 'enhanced': is_enhanced, + 'deps_inlined': has_deps_inlined, 'files_included': files_included, 'original_length': len(func_data.get('code', '')), 'enhanced_length': len(enhanced_code), @@ -259,7 +259,7 @@ def update_statistics(self, unit: Dict) -> None: self.statistics['units_with_upstream'] += 1 if dep_meta.get('total_downstream', 0) > 0: self.statistics['units_with_downstream'] += 1 - if unit.get('code', {}).get('primary_origin', {}).get('enhanced', False): + if unit.get('code', {}).get('primary_origin', {}).get('deps_inlined', False): self.statistics['units_enhanced'] += 1 def generate_units(self) -> Dict: diff --git a/libs/openant-core/parsers/python/unit_generator.py b/libs/openant-core/parsers/python/unit_generator.py index 9373ea1..a7d2680 100644 --- 
a/libs/openant-core/parsers/python/unit_generator.py +++ b/libs/openant-core/parsers/python/unit_generator.py @@ -30,7 +30,7 @@ "end_line": 25, "function_name": "function_name", "class_name": null, - "enhanced": true, + "deps_inlined": true, "files_included": ["file.py", "utils.py"] }, "dependencies": [...], @@ -276,7 +276,7 @@ def create_unit(self, func_id: str, func_data: Dict) -> Dict: # Assemble enhanced code enhanced_code = self.assemble_enhanced_code(func_data, upstream_deps, downstream_callers) files_included = self.collect_files_included(file_path, upstream_deps, downstream_callers) - is_enhanced = len(upstream_deps) > 0 or len(downstream_callers) > 0 + has_deps_inlined = len(upstream_deps) > 0 or len(downstream_callers) > 0 # Get direct calls/callers (depth 1 only) direct_calls = self.call_graph.get(func_id, []) @@ -294,7 +294,7 @@ def create_unit(self, func_id: str, func_data: Dict) -> Dict: 'end_line': func_data.get('end_line'), 'function_name': func_name, 'class_name': class_name, - 'enhanced': is_enhanced, + 'deps_inlined': has_deps_inlined, 'files_included': files_included, 'original_length': len(func_data.get('code', '')), 'enhanced_length': len(enhanced_code), @@ -341,7 +341,7 @@ def update_statistics(self, unit: Dict) -> None: self.statistics['units_with_upstream'] += 1 if dep_meta.get('total_downstream', 0) > 0: self.statistics['units_with_downstream'] += 1 - if unit.get('code', {}).get('primary_origin', {}).get('enhanced', False): + if unit.get('code', {}).get('primary_origin', {}).get('deps_inlined', False): self.statistics['units_enhanced'] += 1 def generate_units(self) -> Dict: diff --git a/libs/openant-core/parsers/ruby/PARSER_PIPELINE.md b/libs/openant-core/parsers/ruby/PARSER_PIPELINE.md index ae6748c..074e287 100644 --- a/libs/openant-core/parsers/ruby/PARSER_PIPELINE.md +++ b/libs/openant-core/parsers/ruby/PARSER_PIPELINE.md @@ -86,7 +86,7 @@ python call_graph_builder.py functions.json --output call_graph.json Creates self-contained 
analysis units with full context. **Input:** Call graph data -**Output:** Dataset compatible with Sastinel +**Output:** OpenAnt dataset format **Each unit contains:** ```json diff --git a/libs/openant-core/parsers/ruby/PARSER_UPGRADE_PLAN.md b/libs/openant-core/parsers/ruby/PARSER_UPGRADE_PLAN.md index 46928bd..2a5af9e 100644 --- a/libs/openant-core/parsers/ruby/PARSER_UPGRADE_PLAN.md +++ b/libs/openant-core/parsers/ruby/PARSER_UPGRADE_PLAN.md @@ -4,7 +4,7 @@ ## Quick Context -This is the Ruby code parser for Sastinel (a SAST tool). It mirrors the C parser's tree-sitter approach and the Python parser's pipeline structure. +This is the Ruby code parser for OpenAnt (a SAST tool). It mirrors the C parser's tree-sitter approach and the Python parser's pipeline structure. **Goal:** Parse Ruby repositories into self-contained analysis units for vulnerability detection. diff --git a/libs/openant-core/parsers/ruby/test_pipeline.py b/libs/openant-core/parsers/ruby/test_pipeline.py index 033679f..cffe880 100644 --- a/libs/openant-core/parsers/ruby/test_pipeline.py +++ b/libs/openant-core/parsers/ruby/test_pipeline.py @@ -6,7 +6,7 @@ 1. RepositoryScanner - Enumerates .rb/.rake files 2. FunctionExtractor - Extracts functions via tree-sitter 3. CallGraphBuilder - Builds bidirectional call graphs -4. UnitGenerator - Creates Sastinel dataset format +4. UnitGenerator - Creates OpenAnt dataset format 5. CodeQL (optional) - Static analysis pre-filter 6. 
ContextEnhancer (optional) - LLM enhancement using Claude Sonnet diff --git a/libs/openant-core/parsers/ruby/unit_generator.py b/libs/openant-core/parsers/ruby/unit_generator.py index 1e8f6bc..184a221 100644 --- a/libs/openant-core/parsers/ruby/unit_generator.py +++ b/libs/openant-core/parsers/ruby/unit_generator.py @@ -195,7 +195,7 @@ def create_unit(self, func_id: str, func_data: Dict) -> Dict: # Assemble enhanced code enhanced_code = self.assemble_enhanced_code(func_data, upstream_deps, downstream_callers) files_included = self.collect_files_included(file_path, upstream_deps, downstream_callers) - is_enhanced = len(upstream_deps) > 0 or len(downstream_callers) > 0 + has_deps_inlined = len(upstream_deps) > 0 or len(downstream_callers) > 0 # Get direct calls/callers (depth 1 only) direct_calls = self.call_graph.get(func_id, []) @@ -213,7 +213,7 @@ def create_unit(self, func_id: str, func_data: Dict) -> Dict: 'end_line': func_data.get('end_line'), 'function_name': func_name, 'class_name': class_name, - 'enhanced': is_enhanced, + 'deps_inlined': has_deps_inlined, 'files_included': files_included, 'original_length': len(func_data.get('code', '')), 'enhanced_length': len(enhanced_code), @@ -259,7 +259,7 @@ def update_statistics(self, unit: Dict) -> None: self.statistics['units_with_upstream'] += 1 if dep_meta.get('total_downstream', 0) > 0: self.statistics['units_with_downstream'] += 1 - if unit.get('code', {}).get('primary_origin', {}).get('enhanced', False): + if unit.get('code', {}).get('primary_origin', {}).get('deps_inlined', False): self.statistics['units_enhanced'] += 1 def generate_units(self) -> Dict: diff --git a/libs/openant-core/parsers/zig/__init__.py b/libs/openant-core/parsers/zig/__init__.py new file mode 100644 index 0000000..0f5eabc --- /dev/null +++ b/libs/openant-core/parsers/zig/__init__.py @@ -0,0 +1 @@ +# Zig parser for OpenAnt diff --git a/libs/openant-core/parsers/zig/call_graph_builder.py b/libs/openant-core/parsers/zig/call_graph_builder.py 
new file mode 100644 index 0000000..52f661d --- /dev/null +++ b/libs/openant-core/parsers/zig/call_graph_builder.py @@ -0,0 +1,325 @@ +""" +Stage 3: Call Graph Builder for Zig + +Builds bidirectional call graphs showing function dependencies. +""" + +import json +import re +from collections import defaultdict +from typing import Dict, Any, List, Set + +import tree_sitter_zig as ts_zig +from tree_sitter import Language, Parser, Node + + +class CallGraphBuilder: + """Builds call graphs from extracted Zig functions.""" + + ZIG_LANGUAGE = Language(ts_zig.language()) + + # Zig standard library and builtin functions to filter out + ZIG_BUILTINS = { + # Builtin functions + "@import", + "@as", + "@intCast", + "@floatCast", + "@ptrCast", + "@alignCast", + "@enumFromInt", + "@intFromEnum", + "@intFromPtr", + "@ptrFromInt", + "@errorName", + "@tagName", + "@typeName", + "@typeInfo", + "@Type", + "@sizeOf", + "@alignOf", + "@bitSizeOf", + "@offsetOf", + "@fieldParentPtr", + "@hasField", + "@hasDecl", + "@field", + "@call", + "@src", + "@This", + "@min", + "@max", + "@add", + "@sub", + "@mul", + "@div", + "@rem", + "@mod", + "@shl", + "@shr", + "@bitReverse", + "@byteSwap", + "@truncate", + "@reduce", + "@shuffle", + "@select", + "@splat", + "@memcpy", + "@memset", + "@ctz", + "@clz", + "@popCount", + "@abs", + "@sqrt", + "@sin", + "@cos", + "@tan", + "@exp", + "@exp2", + "@log", + "@log2", + "@log10", + "@floor", + "@ceil", + "@round", + "@mulAdd", + "@panic", + "@compileError", + "@compileLog", + "@breakpoint", + "@returnAddress", + "@frameAddress", + "@cmpxchgStrong", + "@cmpxchgWeak", + "@atomicLoad", + "@atomicStore", + "@atomicRmw", + "@fence", + "@prefetch", + "@setCold", + "@setRuntimeSafety", + "@setEvalBranchQuota", + "@setFloatMode", + "@setAlignStack", + "@errorReturnTrace", + "@asyncCall", + "@cDefine", + "@cInclude", + "@cUndef", + "@embedFile", + "@export", + "@extern", + "@unionInit", + "@wasmMemorySize", + "@wasmMemoryGrow", + # Common std functions + "print", 
+ "println", + "debug", + "assert", + "expect", + "expectEqual", + "expectError", + "expectFmt", + "expectEqualSlices", + "expectEqualStrings", + "allocPrint", + "allocPrintZ", + "bufPrint", + "bufPrintZ", + "comptimePrint", + } + + def __init__(self, extractor_output: Dict[str, Any]): + self.functions = extractor_output.get("functions", {}) + self.classes = extractor_output.get("classes", {}) + self.imports = extractor_output.get("imports", {}) + self.repository = extractor_output.get("repository", "") + self.parser = Parser(self.ZIG_LANGUAGE) + + def build(self) -> Dict[str, Any]: + """ + Build the call graph. + + Returns call_graph.json structure with: + - functions (copied from extractor) + - classes (copied from extractor) + - imports (copied from extractor) + - call_graph: {caller_id: [callee_ids]} + - reverse_call_graph: {callee_id: [caller_ids]} + """ + call_graph: Dict[str, List[str]] = defaultdict(list) + reverse_call_graph: Dict[str, List[str]] = defaultdict(list) + + # Build an index of function names to IDs for resolution + name_to_ids = self._build_name_index() + + # For each function, find calls in its body + for func_id, func_info in self.functions.items(): + code = func_info.get("code", "") + file_path = func_info.get("file_path", "") + + # Parse the function code to find call sites + calls = self._find_calls_in_code(code) + + # Resolve each call to a function ID + for call_name in calls: + resolved_ids = self._resolve_call( + call_name, file_path, name_to_ids + ) + for resolved_id in resolved_ids: + if resolved_id != func_id: # No self-calls + if resolved_id not in call_graph[func_id]: + call_graph[func_id].append(resolved_id) + if func_id not in reverse_call_graph[resolved_id]: + reverse_call_graph[resolved_id].append(func_id) + + # Calculate statistics + total_edges = sum(len(callees) for callees in call_graph.values()) + out_degrees = [len(callees) for callees in call_graph.values()] + avg_out_degree = total_edges / len(self.functions) if 
self.functions else 0 + max_out_degree = max(out_degrees) if out_degrees else 0 + isolated = len( + [ + f + for f in self.functions + if f not in call_graph and f not in reverse_call_graph + ] + ) + + return { + "repository": self.repository, + "functions": self.functions, + "classes": self.classes, + "imports": self.imports, + "call_graph": dict(call_graph), + "reverse_call_graph": dict(reverse_call_graph), + "statistics": { + "total_functions": len(self.functions), + "total_edges": total_edges, + "avg_out_degree": round(avg_out_degree, 2), + "max_out_degree": max_out_degree, + "isolated_functions": isolated, + }, + } + + def _build_name_index(self) -> Dict[str, List[str]]: + """Build index from function names to function IDs.""" + name_to_ids: Dict[str, List[str]] = defaultdict(list) + + for func_id, func_info in self.functions.items(): + name = func_info.get("name", "") + qualified_name = func_info.get("qualified_name", "") + + if name: + name_to_ids[name].append(func_id) + if qualified_name and qualified_name != name: + name_to_ids[qualified_name].append(func_id) + + return name_to_ids + + def _find_calls_in_code(self, code: str) -> Set[str]: + """Find all function calls in a code snippet.""" + calls = set() + + try: + tree = self.parser.parse(code.encode("utf-8")) + self._extract_calls_from_node(tree.root_node, code.encode("utf-8"), calls) + except Exception: + # Fallback to regex-based extraction + calls = self._find_calls_with_regex(code) + + # Filter out builtins + calls = {c for c in calls if c not in self.ZIG_BUILTINS and not c.startswith("@")} + + return calls + + def _extract_calls_from_node( + self, node: Node, source: bytes, calls: Set[str] + ) -> None: + """Recursively extract call sites from AST nodes.""" + # Look for function call expressions + if node.type in ("call_expr", "call_expression", "CallExpr"): + # Get the function being called + for child in node.children: + if child.type in ("identifier", "IDENTIFIER", "field_access"): + call_name = 
self._get_node_text(child, source) + # Handle method calls (obj.method) + if "." in call_name: + parts = call_name.split(".") + calls.add(parts[-1]) # Add just the method name + calls.add(call_name) # Also add the full qualified name + else: + calls.add(call_name) + break + + # Recurse into children + for child in node.children: + self._extract_calls_from_node(child, source, calls) + + def _find_calls_with_regex(self, code: str) -> Set[str]: + """Fallback regex-based call detection.""" + calls = set() + + # Pattern for function calls: name(...) + # Matches: foo(), bar.baz(), self.method() + pattern = r"\b([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*\(" + + for match in re.finditer(pattern, code): + call_name = match.group(1) + if "." in call_name: + parts = call_name.split(".") + calls.add(parts[-1]) + calls.add(call_name) + else: + calls.add(call_name) + + return calls + + def _get_node_text(self, node: Node, source: bytes) -> str: + """Get the source text for a node.""" + return source[node.start_byte : node.end_byte].decode("utf-8", errors="replace") + + def _resolve_call( + self, + call_name: str, + caller_file: str, + name_to_ids: Dict[str, List[str]], + ) -> List[str]: + """ + Resolve a call name to function ID(s). + + Resolution order: + 1. Same file + 2. Imported files + 3. Unique name match + """ + candidates = name_to_ids.get(call_name, []) + + if not candidates: + return [] + + # 1. Prefer same file + same_file = [c for c in candidates if c.startswith(f"{caller_file}:")] + if same_file: + return same_file + + # 2. Check imported files + file_imports = self.imports.get(caller_file, []) + for candidate in candidates: + candidate_file = candidate.split(":")[0] + for imp in file_imports: + if imp in candidate_file or candidate_file.endswith(imp): + return [candidate] + + # 3. 
If unique match, use it + if len(candidates) == 1: + return candidates + + # Multiple matches, return all (conservative) + return candidates + + def save_results(self, output_path: str, results: Dict[str, Any]) -> None: + """Save call graph to a JSON file.""" + with open(output_path, "w") as f: + json.dump(results, f, indent=2) diff --git a/libs/openant-core/parsers/zig/function_extractor.py b/libs/openant-core/parsers/zig/function_extractor.py new file mode 100644 index 0000000..f3348a0 --- /dev/null +++ b/libs/openant-core/parsers/zig/function_extractor.py @@ -0,0 +1,280 @@ +""" +Stage 2: Function Extractor for Zig + +Extracts functions, methods, and structs from Zig source files using tree-sitter. +""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Optional, List + +import tree_sitter_zig as ts_zig +from tree_sitter import Language, Parser, Node + + +class FunctionExtractor: + """Extracts functions and structs from Zig source files using tree-sitter.""" + + ZIG_LANGUAGE = Language(ts_zig.language()) + + def __init__(self, repo_path: str, scan_results: Dict[str, Any]): + self.repo_path = Path(repo_path).resolve() + self.scan_results = scan_results + self.parser = Parser(self.ZIG_LANGUAGE) + + def extract(self) -> Dict[str, Any]: + """ + Extract all functions and structs from scanned files. + + Returns functions.json structure with functions, classes (structs), imports. 
+ """ + functions = {} + classes = {} # Zig structs + imports = {} + files_processed = 0 + files_with_errors = 0 + + for file_info in self.scan_results.get("files", []): + file_path = file_info["path"] + full_path = self.repo_path / file_path + + try: + with open(full_path, "rb") as f: + source = f.read() + + tree = self.parser.parse(source) + file_functions, file_structs, file_imports = self._extract_from_tree( + tree.root_node, source, file_path + ) + + functions.update(file_functions) + classes.update(file_structs) + imports[file_path] = file_imports + files_processed += 1 + + except Exception as e: + print(f"Error processing {file_path}: {e}") + files_with_errors += 1 + + return { + "repository": str(self.repo_path), + "extraction_time": datetime.now().isoformat(), + "functions": functions, + "classes": classes, + "imports": imports, + "statistics": { + "total_functions": len(functions), + "total_classes": len(classes), + "files_processed": files_processed, + "files_with_errors": files_with_errors, + }, + } + + def _extract_from_tree( + self, root: Node, source: bytes, file_path: str + ) -> tuple[Dict[str, Any], Dict[str, Any], List[str]]: + """Extract functions, structs, and imports from a parse tree.""" + functions = {} + structs = {} + imports = [] + + # Walk the AST + self._walk_node(root, source, file_path, functions, structs, imports, None) + + return functions, structs, imports + + def _walk_node( + self, + node: Node, + source: bytes, + file_path: str, + functions: Dict[str, Any], + structs: Dict[str, Any], + imports: List[str], + current_struct: Optional[str], + ) -> None: + """Recursively walk the AST to extract definitions.""" + + if node.type == "function_declaration" or node.type == "FnProto": + func_info = self._extract_function(node, source, file_path, current_struct) + if func_info: + func_id = f"{file_path}:{func_info['qualified_name']}" + functions[func_id] = func_info + + elif node.type == "VarDecl": + # Check if this is a struct/enum 
definition + struct_info = self._extract_struct_from_var_decl(node, source, file_path) + if struct_info: + struct_id = f"{file_path}:{struct_info['name']}" + structs[struct_id] = struct_info + # Extract methods within the struct + self._extract_struct_methods( + node, source, file_path, struct_info["name"], functions + ) + + elif node.type == "container_decl" or node.type == "ContainerDecl": + # Direct struct/enum declarations + struct_info = self._extract_container(node, source, file_path) + if struct_info: + struct_id = f"{file_path}:{struct_info['name']}" + structs[struct_id] = struct_info + + elif node.type == "@import" or ( + node.type == "builtin_call_expr" + and self._get_node_text(node, source).startswith("@import") + ): + import_path = self._extract_import(node, source) + if import_path: + imports.append(import_path) + + # Recurse into children + for child in node.children: + self._walk_node( + child, source, file_path, functions, structs, imports, current_struct + ) + + def _extract_function( + self, node: Node, source: bytes, file_path: str, current_struct: Optional[str] + ) -> Optional[Dict[str, Any]]: + """Extract function information from a function declaration node.""" + # Find function name + name = None + parameters = [] + + for child in node.children: + if child.type == "identifier" or child.type == "IDENTIFIER": + name = self._get_node_text(child, source) + elif child.type == "parameters" or child.type == "ParamDeclList": + parameters = self._extract_parameters(child, source) + + if not name: + return None + + # Determine qualified name and unit type + if current_struct: + qualified_name = f"{current_struct}.{name}" + unit_type = "method" + else: + qualified_name = name + unit_type = self._classify_function(name, file_path) + + start_line = node.start_point[0] + 1 # 1-indexed + end_line = node.end_point[0] + 1 + + return { + "name": name, + "qualified_name": qualified_name, + "file_path": file_path, + "start_line": start_line, + "end_line": 
end_line, + "code": self._get_node_text(node, source), + "class_name": current_struct, + "module_name": None, + "parameters": parameters, + "unit_type": unit_type, + } + + def _extract_parameters(self, node: Node, source: bytes) -> List[str]: + """Extract parameter names from a parameter list node.""" + params = [] + for child in node.children: + if child.type == "parameter" or child.type == "ParamDecl": + for subchild in child.children: + if subchild.type == "identifier" or subchild.type == "IDENTIFIER": + params.append(self._get_node_text(subchild, source)) + break + return params + + def _extract_struct_from_var_decl( + self, node: Node, source: bytes, file_path: str + ) -> Optional[Dict[str, Any]]: + """Extract struct info from a variable declaration (const Foo = struct {...}).""" + name = None + is_struct = False + + for child in node.children: + if child.type == "identifier" or child.type == "IDENTIFIER": + name = self._get_node_text(child, source) + elif child.type == "container_decl" or child.type == "ContainerDecl": + is_struct = True + + if name and is_struct: + return { + "name": name, + "file_path": file_path, + "start_line": node.start_point[0] + 1, + "end_line": node.end_point[0] + 1, + "code": self._get_node_text(node, source), + } + return None + + def _extract_container( + self, node: Node, source: bytes, file_path: str + ) -> Optional[Dict[str, Any]]: + """Extract struct/enum from a container declaration.""" + # Anonymous container - try to find name from parent + return None + + def _extract_struct_methods( + self, + node: Node, + source: bytes, + file_path: str, + struct_name: str, + functions: Dict[str, Any], + ) -> None: + """Extract methods from within a struct definition.""" + for child in node.children: + if child.type == "container_decl" or child.type == "ContainerDecl": + for member in child.children: + if ( + member.type == "function_declaration" + or member.type == "FnProto" + or member.type == "container_field" + ): + # Check if it's a 
function field + func_info = self._extract_function( + member, source, file_path, struct_name + ) + if func_info: + func_id = f"{file_path}:{func_info['qualified_name']}" + functions[func_id] = func_info + + def _extract_import(self, node: Node, source: bytes) -> Optional[str]: + """Extract import path from an @import call.""" + text = self._get_node_text(node, source) + # Parse @import("path") + if "@import" in text: + start = text.find('"') + end = text.rfind('"') + if start != -1 and end != -1 and start < end: + return text[start + 1 : end] + return None + + def _get_node_text(self, node: Node, source: bytes) -> str: + """Get the source text for a node.""" + return source[node.start_byte : node.end_byte].decode("utf-8", errors="replace") + + def _classify_function(self, name: str, file_path: str) -> str: + """Classify the function type based on name and context.""" + name_lower = name.lower() + + # Test functions + if name_lower.startswith("test") or "_test" in name_lower: + return "test" + + # Init/constructor patterns + if name in ("init", "create", "new"): + return "constructor" + + # Main entry point + if name == "main": + return "function" + + return "function" + + def save_results(self, output_path: str, results: Dict[str, Any]) -> None: + """Save extraction results to a JSON file.""" + with open(output_path, "w") as f: + json.dump(results, f, indent=2) diff --git a/libs/openant-core/parsers/zig/repository_scanner.py b/libs/openant-core/parsers/zig/repository_scanner.py new file mode 100644 index 0000000..ae09564 --- /dev/null +++ b/libs/openant-core/parsers/zig/repository_scanner.py @@ -0,0 +1,135 @@ +""" +Stage 1: Repository Scanner for Zig + +Enumerates all Zig source files in a repository. 
+""" + +import os +import json +from datetime import datetime +from pathlib import Path +from typing import List, Dict, Any, Optional + + +class RepositoryScanner: + """Scans a repository for Zig source files.""" + + # Directories to exclude from scanning + EXCLUDE_DIRS = { + ".git", + "vendor", + "node_modules", + "zig-cache", + "zig-out", + ".zig-cache", + "__pycache__", + ".venv", + "venv", + "build", + "dist", + "target", + } + + # Test directory patterns to skip when skip_tests is True + TEST_PATTERNS = {"test", "tests", "spec", "specs", "_test", "test_"} + + def __init__( + self, + repo_path: str, + skip_tests: bool = True, + exclude_patterns: Optional[List[str]] = None, + ): + self.repo_path = Path(repo_path).resolve() + self.skip_tests = skip_tests + self.exclude_patterns = exclude_patterns or [] + + def scan(self) -> Dict[str, Any]: + """ + Scan the repository for Zig files. + + Returns scan_results.json structure: + { + "repository": "/path/to/repo", + "scan_time": "2025-01-15T10:30:00", + "files": [{"path": "src/main.zig", "size": 1234}, ...], + "statistics": {...} + } + """ + files = [] + directories_scanned = 0 + directories_excluded = 0 + + for root, dirs, filenames in os.walk(self.repo_path): + # Filter out excluded directories + original_dirs = dirs.copy() + dirs[:] = [ + d + for d in dirs + if d not in self.EXCLUDE_DIRS + and not self._matches_exclude_pattern(d) + and not (self.skip_tests and self._is_test_directory(d)) + ] + directories_excluded += len(original_dirs) - len(dirs) + directories_scanned += 1 + + for filename in filenames: + if not filename.endswith(".zig"): + continue + + file_path = Path(root) / filename + relative_path = file_path.relative_to(self.repo_path) + + # Skip test files if requested + if self.skip_tests and self._is_test_file(str(relative_path)): + continue + + try: + size = file_path.stat().st_size + except OSError: + size = 0 + + files.append({"path": str(relative_path), "size": size}) + + total_size = sum(f["size"] for 
f in files) + + return { + "repository": str(self.repo_path), + "scan_time": datetime.now().isoformat(), + "files": files, + "statistics": { + "total_files": len(files), + "total_size_bytes": total_size, + "directories_scanned": directories_scanned, + "directories_excluded": directories_excluded, + }, + } + + def _matches_exclude_pattern(self, name: str) -> bool: + """Check if a name matches any exclude pattern.""" + for pattern in self.exclude_patterns: + if pattern in name: + return True + return False + + def _is_test_directory(self, dirname: str) -> bool: + """Check if a directory name indicates test code.""" + dirname_lower = dirname.lower() + return any(pattern in dirname_lower for pattern in self.TEST_PATTERNS) + + def _is_test_file(self, filepath: str) -> bool: + """Check if a file path indicates test code.""" + filepath_lower = filepath.lower() + # Check for test in path components + parts = Path(filepath_lower).parts + for part in parts: + if any(pattern in part for pattern in self.TEST_PATTERNS): + return True + # Check for _test.zig suffix + if filepath_lower.endswith("_test.zig"): + return True + return False + + def save_results(self, output_path: str, results: Dict[str, Any]) -> None: + """Save scan results to a JSON file.""" + with open(output_path, "w") as f: + json.dump(results, f, indent=2) diff --git a/libs/openant-core/parsers/zig/test_pipeline.py b/libs/openant-core/parsers/zig/test_pipeline.py new file mode 100644 index 0000000..b4a9832 --- /dev/null +++ b/libs/openant-core/parsers/zig/test_pipeline.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +""" +Zig Parser Pipeline Orchestrator + +Entry point for parsing Zig repositories. Wires together the 4-stage pipeline: +1. Repository Scanner +2. Function Extractor +3. Call Graph Builder +4. 
Unit Generator + +Usage: + python test_pipeline.py \ + --output \ + --processing-level \ + --skip-tests \ + --name +""" + +import argparse +import json +import sys +from pathlib import Path + +# Add parent directories to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from parsers.zig.repository_scanner import RepositoryScanner +from parsers.zig.function_extractor import FunctionExtractor +from parsers.zig.call_graph_builder import CallGraphBuilder +from parsers.zig.unit_generator import UnitGenerator + + +def main(): + parser = argparse.ArgumentParser( + description="Parse Zig repositories for vulnerability analysis" + ) + parser.add_argument("repo_path", help="Path to the Zig repository") + parser.add_argument( + "--output", "-o", required=True, help="Output directory for results" + ) + parser.add_argument( + "--processing-level", + choices=["all", "reachable", "codeql", "exploitable"], + default="all", + help="Processing level for filtering functions", + ) + parser.add_argument( + "--skip-tests", action="store_true", help="Skip test files and functions" + ) + parser.add_argument("--name", help="Dataset name (defaults to repo directory name)") + parser.add_argument( + "--dependency-depth", + type=int, + default=3, + help="Maximum depth for dependency resolution", + ) + + args = parser.parse_args() + + repo_path = Path(args.repo_path).resolve() + output_dir = Path(args.output).resolve() + + if not repo_path.exists(): + print(f"Error: Repository path does not exist: {repo_path}", file=sys.stderr) + return 1 + + # Create output directory + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"[Zig Parser] Parsing repository: {repo_path}", file=sys.stderr) + print(f"[Zig Parser] Output directory: {output_dir}", file=sys.stderr) + print(f"[Zig Parser] Processing level: {args.processing_level}", file=sys.stderr) + print(f"[Zig Parser] Skip tests: {args.skip_tests}", file=sys.stderr) + + try: + # Stage 1: Repository Scanner + 
print("[Zig Parser] Stage 1: Scanning repository...", file=sys.stderr) + scanner = RepositoryScanner( + str(repo_path), + skip_tests=args.skip_tests, + ) + scan_results = scanner.scan() + scanner.save_results(str(output_dir / "scan_results.json"), scan_results) + print( + f" Found {scan_results['statistics']['total_files']} Zig files", + file=sys.stderr, + ) + + if scan_results["statistics"]["total_files"] == 0: + print("[Zig Parser] No Zig files found in repository", file=sys.stderr) + # Write empty dataset + empty_dataset = { + "name": args.name or repo_path.name, + "repository": str(repo_path), + "units": [], + "statistics": {"total_units": 0, "by_type": {}}, + "metadata": {"generator": "zig_unit_generator.py"}, + } + with open(output_dir / "dataset.json", "w") as f: + json.dump(empty_dataset, f, indent=2) + with open(output_dir / "analyzer_output.json", "w") as f: + json.dump({"repository": str(repo_path), "functions": {}}, f, indent=2) + return 0 + + # Stage 2: Function Extractor + print("[Zig Parser] Stage 2: Extracting functions...", file=sys.stderr) + extractor = FunctionExtractor(str(repo_path), scan_results) + extractor_output = extractor.extract() + print( + f" Extracted {extractor_output['statistics']['total_functions']} functions", + file=sys.stderr, + ) + print( + f" Extracted {extractor_output['statistics']['total_classes']} structs", + file=sys.stderr, + ) + + # Stage 3: Call Graph Builder + print("[Zig Parser] Stage 3: Building call graph...", file=sys.stderr) + call_graph_builder = CallGraphBuilder(extractor_output) + call_graph_output = call_graph_builder.build() + call_graph_builder.save_results( + str(output_dir / "call_graph.json"), call_graph_output + ) + print( + f" Built graph with {call_graph_output['statistics']['total_edges']} edges", + file=sys.stderr, + ) + + # Apply processing level filters + if args.processing_level != "all": + call_graph_output = apply_processing_filter( + call_graph_output, args.processing_level, str(repo_path) + ) 
+ print( + f" After {args.processing_level} filter: {len(call_graph_output['functions'])} functions", + file=sys.stderr, + ) + + # Stage 4: Unit Generator + print("[Zig Parser] Stage 4: Generating analysis units...", file=sys.stderr) + generator = UnitGenerator( + call_graph_output, + str(repo_path), + dependency_depth=args.dependency_depth, + ) + dataset, analyzer_output = generator.generate(name=args.name) + generator.save_results(str(output_dir), dataset, analyzer_output) + print( + f" Generated {dataset['statistics']['total_units']} units", + file=sys.stderr, + ) + + print("[Zig Parser] Pipeline complete!", file=sys.stderr) + return 0 + + except Exception as e: + print(f"[Zig Parser] Error: {e}", file=sys.stderr) + import traceback + traceback.print_exc(file=sys.stderr) + return 1 + + +def apply_processing_filter( + call_graph_output: dict, level: str, repo_path: str +) -> dict: + """ + Apply processing level filters to reduce the function set. + + Levels: + - all: No filtering (already handled) + - reachable: Filter to functions reachable from entry points + - codeql: Filter to reachable + CodeQL-flagged functions + - exploitable: Filter to reachable + CodeQL + LLM-classified exploitable + """ + if level == "reachable": + return apply_reachability_filter(call_graph_output, repo_path) + elif level == "codeql": + # First apply reachability, then would filter by CodeQL results + filtered = apply_reachability_filter(call_graph_output, repo_path) + # CodeQL filtering would be applied here if results exist + return filtered + elif level == "exploitable": + # Apply all filters + filtered = apply_reachability_filter(call_graph_output, repo_path) + # CodeQL + LLM filtering would be applied here + return filtered + return call_graph_output + + +def apply_reachability_filter(call_graph_output: dict, repo_path: str) -> dict: + """Filter to functions reachable from entry points.""" + try: + # Try to import the reachability analyzer + from 
utilities.agentic_enhancer.entry_point_detector import EntryPointDetector + from utilities.agentic_enhancer.reachability_analyzer import ReachabilityAnalyzer + + # Detect entry points + detector = EntryPointDetector(repo_path) + entry_points = detector.detect() + + # Analyze reachability + analyzer = ReachabilityAnalyzer(call_graph_output, entry_points) + reachable = analyzer.get_reachable_functions() + + # Filter functions to only reachable ones + filtered_functions = { + fid: finfo + for fid, finfo in call_graph_output["functions"].items() + if fid in reachable + } + + # Update the output with filtered functions + result = call_graph_output.copy() + result["functions"] = filtered_functions + + # Filter call graphs too + result["call_graph"] = { + k: [v for v in vs if v in reachable] + for k, vs in call_graph_output.get("call_graph", {}).items() + if k in reachable + } + result["reverse_call_graph"] = { + k: [v for v in vs if v in reachable] + for k, vs in call_graph_output.get("reverse_call_graph", {}).items() + if k in reachable + } + + return result + + except ImportError: + print( + " Warning: Reachability analyzer not available, skipping filter", + file=sys.stderr, + ) + return call_graph_output + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/libs/openant-core/parsers/zig/unit_generator.py b/libs/openant-core/parsers/zig/unit_generator.py new file mode 100644 index 0000000..de1ce1c --- /dev/null +++ b/libs/openant-core/parsers/zig/unit_generator.py @@ -0,0 +1,253 @@ +""" +Stage 4: Unit Generator for Zig + +Creates self-contained analysis units with dependency context. 
+""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, List, Optional, Set + + +class UnitGenerator: + """Generates analysis units from call graph data.""" + + # File boundary marker using Zig comment syntax + FILE_BOUNDARY = "\n\n// ========== File Boundary ==========\n\n" + + def __init__( + self, + call_graph_output: Dict[str, Any], + repo_path: str, + dependency_depth: int = 3, + ): + self.functions = call_graph_output.get("functions", {}) + self.classes = call_graph_output.get("classes", {}) + self.call_graph = call_graph_output.get("call_graph", {}) + self.reverse_call_graph = call_graph_output.get("reverse_call_graph", {}) + self.repository = repo_path + self.dependency_depth = dependency_depth + + def generate(self, name: Optional[str] = None) -> tuple[Dict[str, Any], Dict[str, Any]]: + """ + Generate analysis units. + + Returns: + (dataset.json, analyzer_output.json) + """ + units = [] + dataset_name = name or Path(self.repository).name + + for func_id, func_info in self.functions.items(): + unit = self._generate_unit(func_id, func_info) + units.append(unit) + + # Calculate statistics + by_type: Dict[str, int] = {} + units_with_upstream = 0 + units_with_downstream = 0 + total_upstream = 0 + total_downstream = 0 + + for unit in units: + unit_type = unit["unit_type"] + by_type[unit_type] = by_type.get(unit_type, 0) + 1 + + dep_meta = unit["code"]["dependency_metadata"] + if dep_meta["total_upstream"] > 0: + units_with_upstream += 1 + total_upstream += dep_meta["total_upstream"] + if dep_meta["total_downstream"] > 0: + units_with_downstream += 1 + total_downstream += dep_meta["total_downstream"] + + avg_upstream = total_upstream / len(units) if units else 0 + avg_downstream = total_downstream / len(units) if units else 0 + + dataset = { + "name": dataset_name, + "repository": self.repository, + "units": units, + "statistics": { + "total_units": len(units), + "by_type": by_type, + "units_with_upstream": 
units_with_upstream, + "units_with_downstream": units_with_downstream, + "units_enhanced": len([u for u in units if u["code"]["primary_origin"]["deps_inlined"]]), + "avg_upstream": round(avg_upstream, 2), + "avg_downstream": round(avg_downstream, 2), + }, + "metadata": { + "generator": "zig_unit_generator.py", + "generated_at": datetime.now().isoformat(), + "dependency_depth": self.dependency_depth, + }, + } + + # Generate analyzer_output.json (camelCase for historical reasons) + analyzer_output = { + "repository": self.repository, + "functions": { + func_id: { + "name": func_info["name"], + "unitType": func_info["unit_type"], + "code": func_info["code"], + "filePath": func_info["file_path"], + "startLine": func_info["start_line"], + "endLine": func_info["end_line"], + "isExported": self._is_exported(func_info), + "parameters": func_info.get("parameters", []), + "className": func_info.get("class_name"), + } + for func_id, func_info in self.functions.items() + }, + "call_graph": self.call_graph, + "reverse_call_graph": self.reverse_call_graph, + } + + return dataset, analyzer_output + + def _generate_unit(self, func_id: str, func_info: Dict[str, Any]) -> Dict[str, Any]: + """Generate a single analysis unit.""" + # Get dependencies (upstream - functions this calls) + upstream = self._get_dependencies(func_id, self.call_graph, self.dependency_depth) + # Get dependents (downstream - functions that call this) + downstream = self._get_dependencies(func_id, self.reverse_call_graph, self.dependency_depth) + + # Get direct callers and callees + direct_calls = self.call_graph.get(func_id, []) + direct_callers = self.reverse_call_graph.get(func_id, []) + + # Build enhanced code with dependencies + primary_code, files_included = self._build_enhanced_code(func_id, func_info, upstream) + + original_length = len(func_info.get("code", "")) + enhanced_length = len(primary_code) + + return { + "id": func_id, + "unit_type": func_info["unit_type"], + "code": { + "primary_code": 
primary_code, + "primary_origin": { + "file_path": func_info["file_path"], + "start_line": func_info["start_line"], + "end_line": func_info["end_line"], + "function_name": func_info["name"], + "class_name": func_info.get("class_name"), + "deps_inlined": len(upstream) > 0, + "files_included": files_included, + "original_length": original_length, + "enhanced_length": enhanced_length, + }, + "dependencies": [], + "dependency_metadata": { + "depth": self.dependency_depth, + "total_upstream": len(upstream), + "total_downstream": len(downstream), + "direct_calls": len(direct_calls), + "direct_callers": len(direct_callers), + }, + }, + "ground_truth": { + "status": "UNKNOWN", + "vulnerability_types": [], + "issues": [], + "annotation_source": None, + "annotation_key": None, + "notes": None, + }, + "metadata": { + "parameters": func_info.get("parameters", []), + "generator": "zig_unit_generator.py", + "direct_calls": direct_calls, + "direct_callers": direct_callers, + }, + } + + def _get_dependencies( + self, func_id: str, graph: Dict[str, List[str]], max_depth: int + ) -> Set[str]: + """Get all dependencies up to max_depth.""" + dependencies: Set[str] = set() + current_level = {func_id} + + for _ in range(max_depth): + next_level: Set[str] = set() + for fid in current_level: + for dep in graph.get(fid, []): + if dep not in dependencies and dep != func_id: + dependencies.add(dep) + next_level.add(dep) + current_level = next_level + if not current_level: + break + + return dependencies + + def _build_enhanced_code( + self, func_id: str, func_info: Dict[str, Any], upstream: Set[str] + ) -> tuple[str, List[str]]: + """Build enhanced code with dependency context.""" + # Start with the primary function's code + primary_code = func_info.get("code", "") + files_included = [func_info["file_path"]] + + if not upstream: + return primary_code, files_included + + # Group dependencies by file + deps_by_file: Dict[str, List[str]] = {} + for dep_id in upstream: + dep_info = 
self.functions.get(dep_id) + if dep_info: + file_path = dep_info["file_path"] + if file_path not in deps_by_file: + deps_by_file[file_path] = [] + deps_by_file[file_path].append(dep_id) + + # Build enhanced code + code_parts = [primary_code] + + for file_path, dep_ids in deps_by_file.items(): + if file_path == func_info["file_path"]: + # Same file - add dependencies without file boundary + for dep_id in dep_ids: + dep_info = self.functions.get(dep_id) + if dep_info: + code_parts.append(dep_info.get("code", "")) + else: + # Different file - add file boundary + if file_path not in files_included: + files_included.append(file_path) + file_code = [] + for dep_id in dep_ids: + dep_info = self.functions.get(dep_id) + if dep_info: + file_code.append(dep_info.get("code", "")) + if file_code: + code_parts.append(self.FILE_BOUNDARY + "\n".join(file_code)) + + return "\n\n".join(code_parts), files_included + + def _is_exported(self, func_info: Dict[str, Any]) -> bool: + """Check if a function is exported (pub in Zig).""" + code = func_info.get("code", "") + return code.strip().startswith("pub ") + + def save_results( + self, + output_dir: str, + dataset: Dict[str, Any], + analyzer_output: Dict[str, Any], + ) -> None: + """Save generated outputs to files.""" + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + with open(output_path / "dataset.json", "w") as f: + json.dump(dataset, f, indent=2) + + with open(output_path / "analyzer_output.json", "w") as f: + json.dump(analyzer_output, f, indent=2) diff --git a/libs/openant-core/report/__main__.py b/libs/openant-core/report/__main__.py index 91db04b..fbe6515 100644 --- a/libs/openant-core/report/__main__.py +++ b/libs/openant-core/report/__main__.py @@ -28,12 +28,13 @@ def cmd_summary(args): sys.exit(1) print("Generating summary report...") - report = generate_summary_report(pipeline_data) + report, usage = generate_summary_report(pipeline_data) output_path = Path(args.output) if args.output else 
Path("SUMMARY_REPORT.md") output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(report) print(f" -> {output_path}") + print(f" Cost: ${usage['cost_usd']:.4f} ({usage['total_tokens']:,} tokens)") def cmd_disclosures(args): @@ -57,7 +58,7 @@ def cmd_disclosures(args): continue print(f"Generating disclosure for {finding['short_name']}...") - disclosure = generate_disclosure(finding, product_name) + disclosure, _usage = generate_disclosure(finding, product_name) safe_name = finding["short_name"].replace(" ", "_").upper() filename = f"DISCLOSURE_{i:02d}_{safe_name}.md" diff --git a/libs/openant-core/report/generator.py b/libs/openant-core/report/generator.py index 7889e2c..25a55e8 100644 --- a/libs/openant-core/report/generator.py +++ b/libs/openant-core/report/generator.py @@ -1,5 +1,7 @@ """ Report Generator - generates security reports and disclosure documents from pipeline output. + +Returns (text, usage_dict) tuples from LLM functions so callers can track costs. """ import json @@ -16,6 +18,39 @@ PROMPTS_DIR = Path(__file__).parent / "prompts" MODEL = "claude-opus-4-6" +# Pricing per million tokens +_PRICING = { + "claude-opus-4-6": {"input": 15.00, "output": 75.00}, + "claude-opus-4-20250514": {"input": 15.00, "output": 75.00}, + "claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00}, +} +_DEFAULT_PRICING = {"input": 3.00, "output": 15.00} + + +def _extract_usage(response, model: str = MODEL) -> dict: + """Extract usage info from an Anthropic API response.""" + usage = response.usage + pricing = _PRICING.get(model, _DEFAULT_PRICING) + input_cost = (usage.input_tokens / 1_000_000) * pricing["input"] + output_cost = (usage.output_tokens / 1_000_000) * pricing["output"] + return { + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + "total_tokens": usage.input_tokens + usage.output_tokens, + "cost_usd": round(input_cost + output_cost, 6), + } + + +def _merge_usage(usages: list[dict]) -> dict: + """Merge multiple 
usage dicts into one.""" + merged = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0, "cost_usd": 0.0} + for u in usages: + merged["input_tokens"] += u["input_tokens"] + merged["output_tokens"] += u["output_tokens"] + merged["total_tokens"] += u["total_tokens"] + merged["cost_usd"] = round(merged["cost_usd"] + u["cost_usd"], 6) + return merged + def _check_api_key(): """Check that ANTHROPIC_API_KEY is set.""" @@ -30,6 +65,44 @@ def load_prompt(name: str) -> str: return (PROMPTS_DIR / f"{name}.txt").read_text() +def merge_dynamic_results(pipeline_data: dict, pipeline_path: str) -> dict: + """Merge dynamic test results into pipeline findings if available. + + Looks for dynamic_test_results.json next to the pipeline_output.json file + and adds a 'dynamic_testing' key to each matching finding. + """ + dynamic_path = Path(pipeline_path).parent / "dynamic_test_results.json" + if not dynamic_path.exists(): + return pipeline_data + + dynamic_data = json.loads(dynamic_path.read_text()) + results_by_id = {} + for result in dynamic_data.get("results", []): + fid = result.get("finding_id") + if fid: + results_by_id[fid] = result + + if not results_by_id: + return pipeline_data + + from datetime import datetime + date_str = datetime.fromtimestamp(dynamic_path.stat().st_mtime).strftime("%B %Y") + + for finding in pipeline_data.get("findings", []): + fid = finding.get("id") + if fid and fid in results_by_id: + r = results_by_id[fid] + finding["dynamic_testing"] = { + "status": r.get("status"), + "details": r.get("details"), + "evidence": r.get("evidence", []), + "tested": f"Docker container, {date_str}", + } + + print(f" Merged {len(results_by_id)} dynamic test results from {dynamic_path.name}", file=sys.stderr) + return pipeline_data + + def _compact_for_summary(pipeline_data: dict) -> dict: """Create a compact copy of pipeline_data for the summary prompt. 
@@ -48,13 +121,19 @@ def _compact_for_summary(pipeline_data: dict) -> dict: "cwe_name": f.get("cwe_name"), "stage1_verdict": f.get("stage1_verdict"), "stage2_verdict": f.get("stage2_verdict"), + "dynamic_testing": f.get("dynamic_testing"), "impact": f.get("impact"), }) return compact -def generate_summary_report(pipeline_data: dict) -> str: - """Generate a summary report from pipeline data.""" +def generate_summary_report(pipeline_data: dict) -> tuple[str, dict]: + """Generate a summary report from pipeline data. + + Returns: + (report_text, usage_dict) where usage_dict has input_tokens, + output_tokens, total_tokens, cost_usd. + """ _check_api_key() client = anthropic.Anthropic() @@ -69,11 +148,15 @@ def generate_summary_report(pipeline_data: dict) -> str: messages=[{"role": "user", "content": user_prompt}] ) - return response.content[0].text + return response.content[0].text, _extract_usage(response) -def generate_disclosure(vulnerability_data: dict, product_name: str) -> str: - """Generate a disclosure document for a single vulnerability.""" +def generate_disclosure(vulnerability_data: dict, product_name: str) -> tuple[str, dict]: + """Generate a disclosure document for a single vulnerability. 
+ + Returns: + (disclosure_text, usage_dict) + """ _check_api_key() client = anthropic.Anthropic() @@ -92,7 +175,7 @@ def generate_disclosure(vulnerability_data: dict, product_name: str) -> str: messages=[{"role": "user", "content": user_prompt}] ) - return response.content[0].text + return response.content[0].text, _extract_usage(response) def generate_all(pipeline_path: str, output_dir: str) -> None: @@ -110,7 +193,7 @@ def generate_all(pipeline_path: str, output_dir: str) -> None: # Generate summary report print("Generating summary report...") - summary = generate_summary_report(pipeline_data) + summary, _usage = generate_summary_report(pipeline_data) (output_path / "SUMMARY_REPORT.md").write_text(summary) print(f" -> {output_path / 'SUMMARY_REPORT.md'}") @@ -125,7 +208,7 @@ def generate_all(pipeline_path: str, output_dir: str) -> None: continue print(f"Generating disclosure for {finding['short_name']}...") - disclosure = generate_disclosure(finding, product_name) + disclosure, _usage = generate_disclosure(finding, product_name) safe_name = finding["short_name"].replace(" ", "_").upper() filename = f"DISCLOSURE_{i:02d}_{safe_name}.md" diff --git a/libs/openant-core/report/schema.py b/libs/openant-core/report/schema.py index af12f40..e11e625 100644 --- a/libs/openant-core/report/schema.py +++ b/libs/openant-core/report/schema.py @@ -17,7 +17,7 @@ class Finding: cwe_name: str stage1_verdict: str stage2_verdict: str - dynamic_testing: bool = False + dynamic_testing: dict | bool = False description: Optional[str] = None vulnerable_code: Optional[str] = None impact: Optional[list] = None diff --git a/libs/openant-core/utilities/agentic_enhancer/agent.py b/libs/openant-core/utilities/agentic_enhancer/agent.py index a811b31..62061b7 100644 --- a/libs/openant-core/utilities/agentic_enhancer/agent.py +++ b/libs/openant-core/utilities/agentic_enhancer/agent.py @@ -17,6 +17,7 @@ import anthropic from ..llm_client import TokenTracker, get_global_tracker +from ..rate_limiter 
import get_rate_limiter from .repository_index import RepositoryIndex from .tools import TOOL_DEFINITIONS, ToolExecutor from .prompts import SYSTEM_PROMPT, get_user_prompt @@ -46,7 +47,10 @@ def __init__( total_tokens: int, is_entry_point: bool = False, reachable_from_entry: Optional[bool] = None, - entry_point_path: Optional[List[str]] = None + entry_point_path: Optional[List[str]] = None, + input_tokens: int = 0, + output_tokens: int = 0, + cost_usd: float = 0.0, ): self.include_functions = include_functions self.usage_context = usage_context @@ -58,6 +62,9 @@ def __init__( self.is_entry_point = is_entry_point self.reachable_from_entry = reachable_from_entry self.entry_point_path = entry_point_path + self.input_tokens = input_tokens + self.output_tokens = output_tokens + self.cost_usd = cost_usd def to_dict(self) -> dict: """Convert to dictionary for JSON serialization.""" @@ -69,7 +76,10 @@ def to_dict(self) -> dict: "confidence": self.confidence, "agent_metadata": { "iterations": self.iterations, - "total_tokens": self.total_tokens + "total_tokens": self.total_tokens, + "input_tokens": self.input_tokens, + "output_tokens": self.output_tokens, + "cost_usd": self.cost_usd, }, "reachability": { "is_entry_point": self.is_entry_point, @@ -95,7 +105,8 @@ def __init__( tracker: TokenTracker = None, verbose: bool = False, entry_points: Optional[Set[str]] = None, - reachability: Optional[ReachabilityAnalyzer] = None + reachability: Optional[ReachabilityAnalyzer] = None, + client: Optional[anthropic.Anthropic] = None, ): """ Initialize the agent. @@ -106,6 +117,8 @@ def __init__( verbose: If True, print debug information entry_points: Set of func_ids that are entry points (optional) reachability: ReachabilityAnalyzer for checking user input paths (optional) + client: Shared Anthropic client (reuse across workers to avoid FD exhaustion). + If not provided, creates a new one (only for standalone/test use). 
""" self.index = index self.tracker = tracker or get_global_tracker() @@ -113,9 +126,7 @@ def __init__( self.tool_executor = ToolExecutor(index) self.entry_points = entry_points or set() self.reachability = reachability - - # Initialize Anthropic client - self.client = anthropic.Anthropic() + self.client = client or anthropic.Anthropic(max_retries=5) def analyze_unit( self, @@ -176,14 +187,42 @@ def analyze_unit( if self.verbose: print(f" Iteration {iterations}...") - # Call Claude - response = self.client.messages.create( - model=AGENT_MODEL, - max_tokens=MAX_TOKENS_PER_RESPONSE, - system=SYSTEM_PROMPT, - tools=TOOL_DEFINITIONS, - messages=messages - ) + # Call Claude with rate limiting + try: + # Wait if we're in a global backoff period + rate_limiter = get_rate_limiter() + rate_limiter.wait_if_needed() + + response = self.client.messages.create( + model=AGENT_MODEL, + max_tokens=MAX_TOKENS_PER_RESPONSE, + system=SYSTEM_PROMPT, + tools=TOOL_DEFINITIONS, + messages=messages + ) + except anthropic.RateLimitError as exc: + # Report to global rate limiter so all workers back off + retry_after = float(exc.response.headers.get("retry-after", 0)) + get_rate_limiter().report_rate_limit(retry_after) + # Attach agent state so the caller knows how far we got + exc.agent_state = { + "iteration": iterations, + "max_iterations": MAX_ITERATIONS, + "tokens_used": total_input_tokens + total_output_tokens, + "input_tokens": total_input_tokens, + "output_tokens": total_output_tokens, + } + raise + except Exception as exc: + # Attach agent state so the caller knows how far we got + exc.agent_state = { + "iteration": iterations, + "max_iterations": MAX_ITERATIONS, + "tokens_used": total_input_tokens + total_output_tokens, + "input_tokens": total_input_tokens, + "output_tokens": total_output_tokens, + } + raise # Track tokens total_input_tokens += response.usage.input_tokens @@ -259,7 +298,7 @@ def analyze_unit( # If finish was called, return result if finish_result: # Record token 
usage - self.tracker.record_call( + call_record = self.tracker.record_call( model=AGENT_MODEL, input_tokens=total_input_tokens, output_tokens=total_output_tokens @@ -275,19 +314,42 @@ def analyze_unit( total_tokens=total_input_tokens + total_output_tokens, is_entry_point=is_entry_point, reachable_from_entry=reachable_from_entry, - entry_point_path=entry_point_path + entry_point_path=entry_point_path, + input_tokens=total_input_tokens, + output_tokens=total_output_tokens, + cost_usd=call_record.get("cost_usd", 0.0), ) # Add assistant message and tool results to conversation messages.append({"role": "assistant", "content": assistant_content}) - messages.append({"role": "user", "content": tool_results}) + + # Only add user message with tool results if there are results + # (empty content triggers API error: "user messages must have non-empty content") + if tool_results: + messages.append({"role": "user", "content": tool_results}) + else: + # No tool calls but model didn't end — treat as incomplete + if self.verbose: + print(" No tool calls in response, treating as incomplete") + return AgentResult( + include_functions=[], + usage_context="Agent response had no tool calls", + security_classification="neutral", + classification_reasoning="Analysis incomplete - no tool calls", + confidence=0.3, + iterations=iterations, + total_tokens=total_input_tokens + total_output_tokens, + is_entry_point=is_entry_point, + reachable_from_entry=reachable_from_entry, + entry_point_path=entry_point_path + ) # Max iterations reached if self.verbose: print(f" Max iterations ({MAX_ITERATIONS}) reached") # Record token usage - self.tracker.record_call( + call_record = self.tracker.record_call( model=AGENT_MODEL, input_tokens=total_input_tokens, output_tokens=total_output_tokens @@ -303,7 +365,10 @@ def analyze_unit( total_tokens=total_input_tokens + total_output_tokens, is_entry_point=is_entry_point, reachable_from_entry=reachable_from_entry, - entry_point_path=entry_point_path + 
entry_point_path=entry_point_path, + input_tokens=total_input_tokens, + output_tokens=total_output_tokens, + cost_usd=call_record.get("cost_usd", 0.0), ) @@ -313,7 +378,8 @@ def enhance_unit_with_agent( tracker: TokenTracker = None, verbose: bool = False, entry_points: Optional[Set[str]] = None, - reachability: Optional[ReachabilityAnalyzer] = None + reachability: Optional[ReachabilityAnalyzer] = None, + client: Optional[anthropic.Anthropic] = None, ) -> dict: """ Enhance a single unit using the agentic approach. @@ -325,6 +391,7 @@ def enhance_unit_with_agent( verbose: Print debug info entry_points: Set of func_ids that are entry points (optional) reachability: ReachabilityAnalyzer for checking user input paths (optional) + client: Shared Anthropic client (reuse across workers to avoid FD exhaustion). Returns: Enhanced unit with agent_context field including reachability info @@ -334,7 +401,8 @@ def enhance_unit_with_agent( tracker=tracker, verbose=verbose, entry_points=entry_points, - reachability=reachability + reachability=reachability, + client=client, ) # Extract unit info @@ -385,7 +453,7 @@ def enhance_unit_with_agent( origin = unit["code"].get("primary_origin", {}) current_files = set(origin.get("files_included", [])) origin["files_included"] = list(current_files | additional_files) - origin["enhanced"] = True + origin["deps_inlined"] = True origin["enhanced_length"] = len(assembled) unit["code"]["primary_origin"] = origin diff --git a/libs/openant-core/utilities/context_enhancer.py b/libs/openant-core/utilities/context_enhancer.py index 452a53e..df1a8ac 100644 --- a/libs/openant-core/utilities/context_enhancer.py +++ b/libs/openant-core/utilities/context_enhancer.py @@ -15,13 +15,28 @@ import json import argparse import logging +import os import sys +import threading import time +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Callable, Optional +import anthropic + from .llm_client import 
AnthropicClient, TokenTracker, get_global_tracker, reset_global_tracker from .agentic_enhancer import RepositoryIndex, enhance_unit_with_agent, load_index_from_file +from .rate_limiter import get_rate_limiter, is_rate_limit_error, is_retryable_error + +# Avoid circular import — import checkpoint at usage site +_StepCheckpoint = None +def _get_step_checkpoint(): + global _StepCheckpoint + if _StepCheckpoint is None: + from core.checkpoint import StepCheckpoint + _StepCheckpoint = StepCheckpoint + return _StepCheckpoint # Null logger that discards all messages (used when no logger provided) @@ -33,6 +48,45 @@ CONTEXT_ENHANCEMENT_MODEL = "claude-sonnet-4-20250514" +def _build_error_info(exc: Exception) -> dict: + """Build a structured error dict from an exception. + + Captures exception type, message, HTTP status, request ID, and + any agent iteration state attached by agent.py. + """ + info = { + "type": "unknown", + "exception_class": type(exc).__name__, + "message": str(exc), + } + + # Anthropic SDK specific exceptions + if isinstance(exc, anthropic.APIConnectionError): + info["type"] = "connection" + elif isinstance(exc, anthropic.APITimeoutError): + info["type"] = "timeout" + elif isinstance(exc, anthropic.RateLimitError): + info["type"] = "rate_limit" + info["status_code"] = exc.status_code + if hasattr(exc, "response") and exc.response is not None: + info["request_id"] = exc.response.headers.get("request-id") + retry_after = exc.response.headers.get("retry-after") + if retry_after: + info["retry_after"] = retry_after + elif isinstance(exc, anthropic.APIStatusError): + info["type"] = "api_status" + info["status_code"] = exc.status_code + if hasattr(exc, "response") and exc.response is not None: + info["request_id"] = exc.response.headers.get("request-id") + + # Agent iteration state (attached by agent.py) + agent_state = getattr(exc, "agent_state", None) + if agent_state: + info["agent_state"] = agent_state + + return info + + def get_context_enhancement_prompt( 
function_id: str, function_name: str, @@ -270,15 +324,19 @@ def enhance_dataset( dataset: dict, batch_size: int = 10, progress_callback: Optional[Callable] = None, + workers: int = 10, ) -> dict: """ Enhance all units in a dataset (single-shot mode). + Uses ThreadPoolExecutor for parallel processing when workers > 1. + Args: dataset: The dataset from unit_generator.js batch_size: Number of units to process before printing progress progress_callback: Optional callback(unit_id, classification, unit_elapsed) called after each unit completes. + workers: Number of parallel workers (default: 10). Returns: Enhanced dataset @@ -288,22 +346,55 @@ def enhance_dataset( self._log("info", f"Enhancing {total} units with LLM context (single-shot mode)", units=total) self._log("info", f"Model: {CONTEXT_ENHANCEMENT_MODEL}") + mode = "sequential" if workers <= 1 else f"parallel ({workers} workers)" + self._log("info", f"Mode: {mode}") # Build lookup dict for context gathering units_by_id = {u.get("id"): u for u in units} - for i, unit in enumerate(units): - if (i + 1) % batch_size == 0 or i == 0: - self._log("info", f"Processing unit {i + 1}/{total}", unit_id=unit.get("id")) - + def _process_one(unit): + """Process a single unit. 
Mutates unit in-place.""" unit_start = time.monotonic() self.enhance_unit(unit, units_by_id) unit_elapsed = time.monotonic() - unit_start - - if progress_callback: - ctx = unit.get("llm_context", {}) - classification = ctx.get("confidence", "unknown") - progress_callback(unit.get("id", "?"), str(classification), unit_elapsed) + ctx = unit.get("llm_context", {}) + classification = ctx.get("confidence", "unknown") + worker = threading.current_thread().name + return unit.get("id", "?"), str(classification), unit_elapsed, worker + + if workers <= 1: + for unit in units: + uid, classification, elapsed, _worker = _process_one(unit) + if progress_callback: + progress_callback(uid, classification, elapsed) + else: + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = {executor.submit(_process_one, unit): unit for unit in units} + for future in as_completed(futures): + uid, classification, elapsed, worker = future.result() + if progress_callback: + progress_callback(uid, f"{classification} [{worker}]", elapsed) + + # Recompute stats from unit results (thread-safe) + self.stats = { + "units_processed": 0, + "units_enhanced": 0, + "dependencies_added": 0, + "callers_added": 0, + "data_flows_extracted": 0, + "errors": 0, + } + for unit in units: + ctx = unit.get("llm_context", {}) + self.stats["units_processed"] += 1 + if ctx.get("reasoning") != "LLM analysis failed, using static analysis only": + self.stats["units_enhanced"] += 1 + self.stats["dependencies_added"] += len(ctx.get("missing_dependencies", [])) + self.stats["callers_added"] += len(ctx.get("additional_callers", [])) + if ctx.get("data_flow", {}).get("security_relevant_flows"): + self.stats["data_flows_extracted"] += 1 + if ctx.get("confidence", 1.0) <= 0.3 and ctx.get("reasoning", "").startswith("LLM analysis failed"): + self.stats["errors"] += 1 # Get token usage stats token_stats = self.tracker.get_totals() @@ -341,6 +432,8 @@ def enhance_dataset_agentic( verbose: bool = False, checkpoint_path: 
str = None, progress_callback: Optional[Callable] = None, + restored_callback: Optional[Callable] = None, + workers: int = 10, ) -> dict: """ Enhance all units using agentic approach with tool use. @@ -348,8 +441,11 @@ def enhance_dataset_agentic( This mode traces call paths iteratively to understand code intent. More accurate but slower and more expensive than single-shot mode. - Supports checkpoint/resume: if checkpoint_path is provided, saves progress - after each unit and skips already-processed units on resume. + Uses ThreadPoolExecutor for parallel processing when workers > 1. + + Supports checkpoint/resume: if checkpoint_path is provided, each completed + unit is saved to a separate file under a checkpoints directory. On resume, + completed units are loaded from their individual checkpoint files. Args: dataset: The dataset from unit_generator.js @@ -357,9 +453,13 @@ def enhance_dataset_agentic( repo_path: Repository root path (for file reading) batch_size: Number of units to process before printing progress verbose: Print debug information - checkpoint_path: Path to save/load checkpoint file (enables resume) + checkpoint_path: Path to checkpoint directory (enables resume). + If provided, per-unit results are saved under this directory. progress_callback: Optional callback(unit_id, classification, unit_elapsed) called after each unit completes. + restored_callback: Optional callback(count) called after checkpoint + loading with the number of restored units. + workers: Number of parallel workers (default: 10). 
Returns: Enhanced dataset with agent_context field @@ -367,38 +467,96 @@ def enhance_dataset_agentic( units = dataset.get("units", []) total = len(units) - # Check for existing checkpoint - checkpoint_data = None + # Checkpoint directory setup + checkpoint_dir = None processed_ids = set() if checkpoint_path: - checkpoint_file = Path(checkpoint_path) - if checkpoint_file.exists(): - self._log("info", f"Found checkpoint at {checkpoint_path}, resuming...") - with open(checkpoint_file, 'r') as f: - checkpoint_data = json.load(f) - - # Build set of already-processed unit IDs - for cp_unit in checkpoint_data.get("units", []): - if cp_unit.get("agent_context") and not cp_unit["agent_context"].get("error"): - processed_ids.add(cp_unit.get("id")) - - # Restore units from checkpoint - cp_units_by_id = {u.get("id"): u for u in checkpoint_data.get("units", [])} - for unit in units: - unit_id = unit.get("id") - if unit_id in cp_units_by_id and cp_units_by_id[unit_id].get("agent_context"): - unit["agent_context"] = cp_units_by_id[unit_id]["agent_context"] - if "code" in cp_units_by_id[unit_id]: - unit["code"] = cp_units_by_id[unit_id]["code"] - - self._log("info", f"Restored {len(processed_ids)} already-processed units", units=len(processed_ids)) + # Use checkpoint_path as a directory for per-unit files + checkpoint_dir = checkpoint_path if os.path.isdir(checkpoint_path) or not checkpoint_path.endswith(".json") else os.path.splitext(checkpoint_path)[0] + "_checkpoints" + os.makedirs(checkpoint_dir, exist_ok=True) + + # Check for legacy single-file checkpoint and migrate + if os.path.isfile(checkpoint_path) and checkpoint_path.endswith(".json"): + self._migrate_legacy_checkpoint(checkpoint_path, checkpoint_dir, units) + + # Load completed unit IDs from per-unit checkpoint files + processed_ids = self._load_completed_units(checkpoint_dir) + + # Restore agent_context from checkpoint files into units + for unit in units: + unit_id = unit.get("id") + if unit_id in processed_ids: + 
cp_file = os.path.join(checkpoint_dir, f"{self._safe_filename(unit_id)}.json") + if os.path.exists(cp_file): + with open(cp_file, 'r') as f: + cp_data = json.load(f) + unit["agent_context"] = cp_data.get("agent_context", {}) + if "code" in cp_data: + unit["code"] = cp_data["code"] + + if processed_ids: + self._log("info", f"Restored {len(processed_ids)} already-processed units from checkpoints", units=len(processed_ids)) + if restored_callback: + restored_callback(len(processed_ids)) + + # Initialize summary tracking for _summary.json + # Counts are updated in the main thread (as_completed loop) — no lock needed. + _summary_cp = None + _summary_completed = len(processed_ids) + _summary_errors = 0 + _summary_error_breakdown = {} + _summary_input_tokens = 0 + _summary_output_tokens = 0 + _summary_cost_usd = 0.0 + + if checkpoint_dir: + SC = _get_step_checkpoint() + _summary_cp = SC.__new__(SC) + _summary_cp.step_name = "enhance" + _summary_cp.dir = checkpoint_dir + + # Count errors and sum usage from already-loaded checkpoints + for unit in units: + uid = unit.get("id", "") + cp_file = os.path.join(checkpoint_dir, f"{self._safe_filename(uid)}.json") + if not os.path.exists(cp_file): + continue + try: + with open(cp_file, 'r') as f: + cp_data = json.load(f) + # Sum usage from all existing checkpoints (completed + errored) + cp_usage = cp_data.get("usage", {}) + _summary_input_tokens += cp_usage.get("input_tokens", 0) + _summary_output_tokens += cp_usage.get("output_tokens", 0) + _summary_cost_usd += cp_usage.get("cost_usd", 0.0) + # Count errors for non-completed units + if uid not in processed_ids and cp_data.get("agent_context", {}).get("error"): + _summary_errors += 1 + err = cp_data["agent_context"]["error"] + err_type = err.get("type", "unknown") if isinstance(err, dict) else "unknown" + _summary_error_breakdown[err_type] = _summary_error_breakdown.get(err_type, 0) + 1 + except (json.JSONDecodeError, OSError): + pass + + _summary_cp.write_summary(total, 
_summary_completed, _summary_errors, + _summary_error_breakdown, phase="in_progress", + usage={"input_tokens": _summary_input_tokens, + "output_tokens": _summary_output_tokens, + "cost_usd": round(_summary_cost_usd, 6)}) + + # Inject prior usage into tracker so step_report captures the total + if _summary_input_tokens or _summary_output_tokens: + self.tracker.add_prior_usage( + _summary_input_tokens, _summary_output_tokens, _summary_cost_usd) remaining = total - len(processed_ids) self._log("info", f"Enhancing {remaining} units with agentic analysis ({len(processed_ids)} already done)", units=remaining) self._log("info", "Mode: Iterative tool use (traces call paths)") self._log("info", "Model: claude-sonnet-4-20250514") - if checkpoint_path: - self._log("info", f"Checkpoint: {checkpoint_path}") + mode = "sequential" if workers <= 1 else f"parallel ({workers} workers)" + self._log("info", f"Workers: {mode}") + if checkpoint_dir: + self._log("info", f"Checkpoint dir: {checkpoint_dir}") # Load repository index self._log("info", f"Loading repository index from {analyzer_output_path}") @@ -406,83 +564,160 @@ def enhance_dataset_agentic( stats = index.get_statistics() self._log("info", f"Indexed {stats['total_functions']} functions from {stats['total_files']} files") - # Track stats - agentic_stats = { - "units_processed": len(processed_ids), # Start from checkpoint count - "units_with_context": 0, - "total_iterations": 0, - "functions_added": 0, - "security_controls_found": 0, - "vulnerable_found": 0, - "neutral_found": 0, - "errors": 0 - } + # Create a single shared Anthropic client for all workers. + # Each ContextAgent previously created its own anthropic.Anthropic() instance, + # which spawns a new httpx connection pool. With 1000+ units and 8 workers, + # this exhausted file descriptors (macOS limit ~256). The httpx.Client + # underlying anthropic.Anthropic is thread-safe, so sharing is correct. 
+ shared_client = anthropic.Anthropic(max_retries=5) - # Count stats from restored units - for unit in units: - agent_ctx = unit.get("agent_context", {}) - if agent_ctx and unit.get("id") in processed_ids: - if agent_ctx.get("include_functions"): - agentic_stats["units_with_context"] += 1 - agentic_stats["functions_added"] += len(agent_ctx["include_functions"]) - classification = agent_ctx.get("security_classification", "neutral") - if classification == "security_control": - agentic_stats["security_controls_found"] += 1 - elif classification == "vulnerable": - agentic_stats["vulnerable_found"] += 1 - else: - agentic_stats["neutral_found"] += 1 - agentic_stats["total_iterations"] += agent_ctx.get("agent_metadata", {}).get("iterations", 0) + # Filter to unprocessed units + units_to_process = [(i, unit) for i, unit in enumerate(units) if unit.get("id") not in processed_ids] - processed_this_run = 0 - for i, unit in enumerate(units): + def _enhance_one(unit): + """Enhance a single unit. Mutates unit in-place, returns metadata.""" unit_id = unit.get("id") - - # Skip already-processed units - if unit_id in processed_ids: - continue - - processed_this_run += 1 - if processed_this_run % batch_size == 1 or processed_this_run == 1: - self._log("info", f"Processing unit {agentic_stats['units_processed'] + 1}/{total}", unit_id=unit_id) - unit_start = time.monotonic() + classification = "neutral" try: - enhance_unit_with_agent(unit, index, self.tracker, verbose) - agentic_stats["units_processed"] += 1 + enhance_unit_with_agent(unit, index, self.tracker, verbose, client=shared_client) agent_ctx = unit.get("agent_context", {}) - if agent_ctx.get("include_functions"): - agentic_stats["units_with_context"] += 1 - agentic_stats["functions_added"] += len(agent_ctx["include_functions"]) - classification = agent_ctx.get("security_classification", "neutral") - if classification == "security_control": - agentic_stats["security_controls_found"] += 1 - elif classification == "vulnerable": 
- agentic_stats["vulnerable_found"] += 1 - else: - agentic_stats["neutral_found"] += 1 - - agentic_stats["total_iterations"] += agent_ctx.get("agent_metadata", {}).get("iterations", 0) except Exception as e: classification = "error" - agentic_stats["errors"] += 1 - self._log("error", f"Error processing unit", unit_id=unit_id, error=str(e)) + error_info = _build_error_info(e) + self._log("error", f"Error processing unit", + unit_id=unit_id, + error=error_info.get("message", str(e)), + error_type=error_info.get("type", "unknown")) unit["agent_context"] = { - "error": str(e), + "error": error_info, "security_classification": "neutral", "confidence": 0.0 } unit_elapsed = time.monotonic() - unit_start - if progress_callback: - progress_callback(unit_id or "?", classification, unit_elapsed) - - # Save checkpoint after each unit - if checkpoint_path: - self._save_checkpoint(dataset, checkpoint_path, agentic_stats) + worker = threading.current_thread().name + + # Save per-unit checkpoint (no lock — each file is unique) + if checkpoint_dir: + self._save_unit_checkpoint(unit, checkpoint_dir) + + return unit_id or "?", classification, unit_elapsed, worker + + def _update_summary(classification, unit): + """Update summary counters after a unit completes. 
Called from main thread.""" + nonlocal _summary_completed, _summary_errors, _summary_error_breakdown + nonlocal _summary_input_tokens, _summary_output_tokens, _summary_cost_usd + if _summary_cp is None: + return + if classification == "error": + _summary_errors += 1 + err = unit.get("agent_context", {}).get("error", {}) + err_type = err.get("type", "unknown") if isinstance(err, dict) else "unknown" + _summary_error_breakdown[err_type] = _summary_error_breakdown.get(err_type, 0) + 1 + else: + _summary_completed += 1 + # Accumulate per-unit usage + meta = unit.get("agent_context", {}).get("agent_metadata", {}) + _summary_input_tokens += meta.get("input_tokens", 0) + _summary_output_tokens += meta.get("output_tokens", 0) + _summary_cost_usd += meta.get("cost_usd", 0.0) + _summary_cp.write_summary(total, _summary_completed, _summary_errors, + _summary_error_breakdown, phase="in_progress", + usage={"input_tokens": _summary_input_tokens, + "output_tokens": _summary_output_tokens, + "cost_usd": round(_summary_cost_usd, 6)}) + + if workers <= 1: + # Sequential mode + try: + for _, unit in units_to_process: + uid, classification, elapsed, _worker = _enhance_one(unit) + _update_summary(classification, unit) + if progress_callback: + progress_callback(uid, classification, elapsed) + except KeyboardInterrupt: + self._log("warning", "Interrupted — progress saved to checkpoints") + return dataset + else: + # Parallel mode + executor = ThreadPoolExecutor(max_workers=workers) + futures = {executor.submit(_enhance_one, unit): unit for _, unit in units_to_process} + try: + for future in as_completed(futures): + unit = futures[future] + uid, classification, elapsed, worker = future.result() + _update_summary(classification, unit) + if progress_callback: + progress_callback(uid, f"{classification} [{worker}]", elapsed) + except KeyboardInterrupt: + self._log("warning", "Interrupted — cancelling pending work...") + executor.shutdown(wait=False, cancel_futures=True) + self._log("info", 
"Progress saved to checkpoints") + return dataset + executor.shutdown(wait=False) + + # Auto-retry failed units with transient errors (rate limit, connection, timeout, 5xx) + retryable_units = [ + (i, unit) for i, unit in enumerate(units) + if is_retryable_error(unit.get("agent_context", {}).get("error")) + ] + if retryable_units: + rate_limiter = get_rate_limiter() + backoff = rate_limiter.time_until_ready() + if backoff > 0: + self._log("info", + f"Retrying {len(retryable_units)} failed units " + f"(waiting {backoff:.0f}s for rate limit to clear)...") + rate_limiter.wait_if_needed() + else: + self._log("info", + f"Retrying {len(retryable_units)} failed units (transient errors)...") + + # Retry sequentially to avoid re-triggering rate limit + for i, unit in retryable_units: + # Clear previous error + unit["agent_context"] = {} + uid, classification, elapsed, _ = _enhance_one(unit) + + # Update summary: retry succeeded → flip error to completed + if classification != "error": + _summary_errors = max(0, _summary_errors - 1) + _summary_completed += 1 + # Decrement the old error type count (best effort) + # The error was already counted in _update_summary during initial pass + # Accumulate retry usage + meta = unit.get("agent_context", {}).get("agent_metadata", {}) + _summary_input_tokens += meta.get("input_tokens", 0) + _summary_output_tokens += meta.get("output_tokens", 0) + _summary_cost_usd += meta.get("cost_usd", 0.0) + if _summary_cp is not None: + _summary_cp.write_summary(total, _summary_completed, _summary_errors, + _summary_error_breakdown, phase="in_progress", + usage={"input_tokens": _summary_input_tokens, + "output_tokens": _summary_output_tokens, + "cost_usd": round(_summary_cost_usd, 6)}) + + # Save checkpoint (overwrite error with result) + if checkpoint_dir: + self._save_unit_checkpoint(unit, checkpoint_dir) + + if progress_callback: + progress_callback(uid, f"{classification} (retry)", elapsed) + + # Write final summary with phase="done" + if 
_summary_cp is not None: + _summary_cp.write_summary(total, _summary_completed, _summary_errors, + _summary_error_breakdown, phase="done", + usage={"input_tokens": _summary_input_tokens, + "output_tokens": _summary_output_tokens, + "cost_usd": round(_summary_cost_usd, 6)}) + + # Compute stats from all units (including previously checkpointed ones) + agentic_stats = self._compute_agentic_stats(units) # Get token usage stats token_stats = self.tracker.get_totals() @@ -503,7 +738,8 @@ def enhance_dataset_agentic( "units_with_context": agentic_stats['units_with_context'], "avg_iterations_per_unit": round(avg_iterations, 1), "security_controls": agentic_stats['security_controls_found'], - "vulnerable": agentic_stats['vulnerable_found'], + "exploitable": agentic_stats['exploitable_found'], + "vulnerable_internal": agentic_stats['vulnerable_found'], "neutral": agentic_stats['neutral_found'], "errors": agentic_stats['errors'] }) @@ -515,16 +751,111 @@ def enhance_dataset_agentic( return dataset - def _save_checkpoint(self, dataset: dict, checkpoint_path: str, agentic_stats: dict): - """Save checkpoint to disk after each unit is processed.""" - # Update metadata before saving - dataset["metadata"] = dataset.get("metadata", {}) - dataset["metadata"]["checkpoint"] = True - dataset["metadata"]["agentic_stats"] = agentic_stats - dataset["metadata"]["token_usage"] = self.tracker.get_totals() + @staticmethod + def _safe_filename(unit_id: str) -> str: + from utilities.safe_filename import safe_filename + return safe_filename(unit_id) + + def _save_unit_checkpoint(self, unit: dict, checkpoint_dir: str): + """Save a single unit's result to its own checkpoint file.""" + unit_id = unit.get("id", "unknown") + filename = self._safe_filename(unit_id) + ".json" + filepath = os.path.join(checkpoint_dir, filename) + cp_data = { + "id": unit_id, + "agent_context": unit.get("agent_context", {}), + } + # Include code if it was modified by the agent + if "code" in unit: + cp_data["code"] = 
unit["code"] + # Include per-unit usage from agent_metadata + meta = cp_data["agent_context"].get("agent_metadata", {}) + if meta.get("input_tokens") or meta.get("output_tokens"): + cp_data["usage"] = { + "input_tokens": meta.get("input_tokens", 0), + "output_tokens": meta.get("output_tokens", 0), + "cost_usd": meta.get("cost_usd", 0.0), + } + with open(filepath, 'w') as f: + json.dump(cp_data, f, indent=2) + + def _load_completed_units(self, checkpoint_dir: str) -> set: + """Load the set of completed unit IDs from per-unit checkpoint files.""" + completed = set() + if not os.path.isdir(checkpoint_dir): + return completed + for filename in os.listdir(checkpoint_dir): + if not filename.endswith(".json"): + continue + filepath = os.path.join(checkpoint_dir, filename) + try: + with open(filepath, 'r') as f: + cp_data = json.load(f) + unit_id = cp_data.get("id") + agent_ctx = cp_data.get("agent_context", {}) + if unit_id and agent_ctx and not agent_ctx.get("error"): + completed.add(unit_id) + except (json.JSONDecodeError, OSError): + continue + return completed + + def _migrate_legacy_checkpoint(self, checkpoint_path: str, checkpoint_dir: str, units: list): + """Migrate a legacy single-file checkpoint to per-unit checkpoint files.""" + try: + with open(checkpoint_path, 'r') as f: + checkpoint_data = json.load(f) + for cp_unit in checkpoint_data.get("units", []): + if cp_unit.get("agent_context") and not cp_unit["agent_context"].get("error"): + self._save_unit_checkpoint(cp_unit, checkpoint_dir) + self._log("info", f"Migrated legacy checkpoint to per-unit files in {checkpoint_dir}") + except Exception as e: + self._log("warning", f"Could not migrate legacy checkpoint: {e}") - with open(checkpoint_path, 'w') as f: - json.dump(dataset, f, indent=2) + @staticmethod + def _compute_agentic_stats(units: list) -> dict: + """Compute agentic stats from all units.""" + stats = { + "units_processed": 0, + "units_with_context": 0, + "total_iterations": 0, + "functions_added": 0, + 
"security_controls_found": 0, + "exploitable_found": 0, + "vulnerable_found": 0, + "neutral_found": 0, + "errors": 0, + "error_summary": {}, + } + for unit in units: + agent_ctx = unit.get("agent_context") + if not agent_ctx: + continue + if agent_ctx.get("error"): + stats["errors"] += 1 + # Tally errors by type + err = agent_ctx["error"] + if isinstance(err, dict): + err_type = err.get("type", "unknown") + else: + # Legacy string errors (from older runs) + err_type = "legacy_string" + stats["error_summary"][err_type] = stats["error_summary"].get(err_type, 0) + 1 + continue + stats["units_processed"] += 1 + if agent_ctx.get("include_functions"): + stats["units_with_context"] += 1 + stats["functions_added"] += len(agent_ctx["include_functions"]) + classification = agent_ctx.get("security_classification", "neutral") + if classification == "security_control": + stats["security_controls_found"] += 1 + elif classification == "exploitable": + stats["exploitable_found"] += 1 + elif classification == "vulnerable_internal": + stats["vulnerable_found"] += 1 + else: + stats["neutral_found"] += 1 + stats["total_iterations"] += agent_ctx.get("agent_metadata", {}).get("iterations", 0) + return stats def get_token_stats(self) -> dict: """ diff --git a/libs/openant-core/utilities/dynamic_tester/__init__.py b/libs/openant-core/utilities/dynamic_tester/__init__.py index 630f068..450d327 100644 --- a/libs/openant-core/utilities/dynamic_tester/__init__.py +++ b/libs/openant-core/utilities/dynamic_tester/__init__.py @@ -3,6 +3,9 @@ Takes pipeline_output.json from the static analysis pipeline and dynamically tests all detected vulnerabilities using Docker containers. +Supports checkpoint/resume: each completed finding is saved to a per-unit +checkpoint file so interrupted runs can resume automatically. 
+ Public API: run_dynamic_tests(pipeline_output_path, output_dir) -> list[DynamicTestResult] """ @@ -16,13 +19,14 @@ from utilities.dynamic_tester.docker_executor import run_single_container from utilities.dynamic_tester.result_collector import collect_result from utilities.dynamic_tester.reporter import generate_report -from utilities.llm_client import TokenTracker +from utilities.llm_client import get_global_tracker def run_dynamic_tests( pipeline_output_path: str, output_dir: str | None = None, max_retries: int = 3, + checkpoint_path: str | None = None, ) -> list[DynamicTestResult]: """Run dynamic tests for all findings in a pipeline output file. @@ -30,6 +34,8 @@ def run_dynamic_tests( pipeline_output_path: Path to pipeline_output.json output_dir: Directory for output files. Defaults to same directory as pipeline_output_path. + max_retries: Max retries per finding on error (default 3). + checkpoint_path: Path to checkpoint directory for resume support. Returns: List of DynamicTestResult objects @@ -53,27 +59,111 @@ def run_dynamic_tests( output_dir = os.path.dirname(os.path.abspath(pipeline_output_path)) os.makedirs(output_dir, exist_ok=True) - tracker = TokenTracker() + # Set up checkpoint support + checkpoint = None + checkpointed = {} + if checkpoint_path is None: + checkpoint_path = os.path.join(output_dir, "dynamic_test_checkpoints") + + from core.checkpoint import StepCheckpoint + checkpoint = StepCheckpoint("dynamic_test", output_dir) + checkpoint.dir = checkpoint_path + if checkpoint.exists: + checkpointed = checkpoint.load() + + # Count successful vs errored checkpoints. Errored ones are NOT "already + # done" — they'll be retried with fresh test generation on resume. 
+ successful_ids = {fid for fid, cp in checkpointed.items() + if cp.get("status") != "ERROR"} + errored_ids = {fid for fid in checkpointed.keys() if fid not in successful_ids} + + if successful_ids: + print(f"Restored {len(successful_ids)} already-tested findings from checkpoints", + file=sys.stderr, flush=True) + if errored_ids: + print(f"Retrying {len(errored_ids)} previously errored findings", + file=sys.stderr, flush=True) + + # Use the global tracker so step_context captures dynamic-test cost in + # dynamic-test.report.json (same as enhance/analyze/verify). + tracker = get_global_tracker() + + # Inject prior usage from ALL existing checkpoints (both successful and + # errored) so the report shows total cost across runs. The errored + # entries will be retried — their initial attempt cost is preserved, + # and the retry API calls get added on top. + _prior_input = 0 + _prior_output = 0 + _prior_cost = 0.0 + for _cp in checkpointed.values(): + _prior_cost += _cp.get("generation_cost_usd", 0) or 0 + _prior_input += _cp.get("generation_input_tokens", 0) or 0 + _prior_output += _cp.get("generation_output_tokens", 0) or 0 + if _prior_cost > 0 or _prior_input > 0 or _prior_output > 0: + tracker.add_prior_usage(_prior_input, _prior_output, _prior_cost) + results: list[DynamicTestResult] = [] - print(f"Dynamic testing {len(findings)} findings from {repo_info['name']}", + total = len(findings) + restored = len(successful_ids) + remaining = total - restored + _completed = restored + _errors = 0 + + # Write initial summary so Go CLI can show accurate counts + checkpoint.ensure_dir() + checkpoint.write_summary(total, _completed, _errors, {}, phase="in_progress") + + print(f"Dynamic testing {total} findings from {repo_info['name']} " + f"({restored} already done, {remaining} remaining)", file=sys.stderr) - for i, finding in enumerate(findings): + try: + for i, finding in enumerate(findings): finding_id = finding.get("id", f"FINDING-{i+1}") - 
print(f"\n[{i+1}/{len(findings)}] Testing {finding_id}: " + + # Skip already-checkpointed findings, but ONLY if they succeeded. + # Errored findings fall through to fresh test generation + Docker run, + # so code/prompt fixes take effect on resume. + cp_data = checkpointed.get(finding_id) + if cp_data and cp_data.get("status") != "ERROR": + result = DynamicTestResult( + finding_id=finding_id, + status=cp_data.get("status", "ERROR"), + details=cp_data.get("details", ""), + elapsed_seconds=cp_data.get("elapsed_seconds", 0), + generation_cost_usd=cp_data.get("generation_cost_usd", 0), + generation_input_tokens=cp_data.get("generation_input_tokens", 0), + generation_output_tokens=cp_data.get("generation_output_tokens", 0), + retry_count=cp_data.get("retry_count", 0), + test_code=cp_data.get("test_code", ""), + dockerfile=cp_data.get("dockerfile", ""), + docker_compose=cp_data.get("docker_compose", ""), + ) + results.append(result) + continue + + print(f"\n[{i+1}/{total}] Testing {finding_id}: " f"{finding.get('name', 'unknown')}...", file=sys.stderr) + # Begin per-unit tracking so we can capture token counts for this + # finding in addition to cost. + tracker.start_unit_tracking() + # Step 1: Generate test - cost_before = tracker.total_cost_usd print(" Generating test...", file=sys.stderr) generation = generate_test(finding, repo_info, tracker) - generation_cost = tracker.total_cost_usd - cost_before + unit_usage = tracker.get_unit_usage() + generation_cost = unit_usage["cost_usd"] if generation is None: print(" Test generation failed.", file=sys.stderr) result = collect_result(finding, None, None, generation_cost) + result.generation_input_tokens = unit_usage["input_tokens"] + result.generation_output_tokens = unit_usage["output_tokens"] results.append(result) + if checkpoint: + checkpoint.save(finding_id, result.to_dict()) continue print(f" Generated (${generation_cost:.4f}). 
Running in Docker...", @@ -104,12 +194,14 @@ def run_dynamic_tests( print(f" {error_type} error. Retry {retry_count}/{max_retries} " f"with error feedback...", file=sys.stderr) - retry_cost_before = tracker.total_cost_usd retry_gen = regenerate_test( finding, repo_info, generation, error_msg, tracker, ) - generation_cost += tracker.total_cost_usd - retry_cost_before + # Refresh unit usage after retry (tracker accumulates across calls + # on the same thread). + unit_usage = tracker.get_unit_usage() + generation_cost = unit_usage["cost_usd"] if retry_gen is None: print(f" Retry generation failed.", file=sys.stderr) @@ -122,10 +214,24 @@ def run_dynamic_tests( f"(${generation_cost:.4f})", file=sys.stderr) result.retry_count = retry_count + result.generation_input_tokens = unit_usage["input_tokens"] + result.generation_output_tokens = unit_usage["output_tokens"] results.append(result) + # Save checkpoint and update summary after each finding + if checkpoint: + checkpoint.save(finding_id, result.to_dict()) + _completed += 1 + if result.status == "ERROR": + _errors += 1 + checkpoint.write_summary(total, _completed, _errors, {}, phase="in_progress") + print(f" Result: {result.status} ({result.elapsed_seconds:.1f}s)", file=sys.stderr) + except KeyboardInterrupt: + print("\n[Dynamic Test] Interrupted — progress saved to checkpoints", + file=sys.stderr, flush=True) + return results # Generate report total_cost = tracker.total_cost_usd @@ -147,4 +253,9 @@ def run_dynamic_tests( }, f, indent=2, ensure_ascii=False) print(f"Results JSON written to {results_path}", file=sys.stderr) + # Mark done. Checkpoints are preserved as a permanent artifact alongside + # results — allows retroactive retry of errored findings after fixes. 
+ if checkpoint: + checkpoint.write_summary(total, _completed, _errors, {}, phase="done") + return results diff --git a/libs/openant-core/utilities/dynamic_tester/docker_executor.py b/libs/openant-core/utilities/dynamic_tester/docker_executor.py index 4908297..864ef91 100644 --- a/libs/openant-core/utilities/dynamic_tester/docker_executor.py +++ b/libs/openant-core/utilities/dynamic_tester/docker_executor.py @@ -66,13 +66,17 @@ def _write_test_files(work_dir: str, generation: dict) -> None: # Write test script test_filename = generation.get("test_filename", "test_exploit.py") - with open(os.path.join(work_dir, test_filename), "w") as f: + test_path = os.path.join(work_dir, test_filename) + os.makedirs(os.path.dirname(test_path), exist_ok=True) + with open(test_path, "w") as f: f.write(generation["test_script"]) # Write requirements/dependencies file if generation.get("requirements"): req_filename = generation.get("requirements_filename", "requirements.txt") - with open(os.path.join(work_dir, req_filename), "w") as f: + req_path = os.path.join(work_dir, req_filename) + os.makedirs(os.path.dirname(req_path), exist_ok=True) + with open(req_path, "w") as f: f.write(generation["requirements"]) # Copy attacker server if needed (before docker-compose so it's available) @@ -127,18 +131,28 @@ def run_single_container( result = DockerExecutionResult() start_time = time.time() - # Sanitize finding_id for use as Docker image tag - image_tag = f"openant-test-{finding_id.lower().replace(' ', '-')}" - network_name = f"openant-net-{finding_id.lower().replace(' ', '-')}" - - work_dir = tempfile.mkdtemp(prefix=f"openant-test-{finding_id}-") + # Sanitize finding_id for use as Docker image tag. + # Docker tags must match [a-z0-9][a-z0-9._-]*, so strip anything else. 
+ safe_id = re.sub(r"[^a-z0-9-]", "-", finding_id.lower()).strip("-_.") + image_tag = f"openant-test-{safe_id}" + network_name = f"openant-net-{safe_id}" + + # Use a deterministic, sanitized work_dir name so docker compose project + # names (derived from the dir name) are always valid Docker references. + # We still use mkdtemp for uniqueness but strip any non-alphanumeric chars. + raw_work_dir = tempfile.mkdtemp(prefix=f"openant-test-{safe_id}-") + parent = os.path.dirname(raw_work_dir) + safe_basename = re.sub(r"[^a-z0-9-]", "", os.path.basename(raw_work_dir).lower()).strip("-") + work_dir = os.path.join(parent, safe_basename) + if work_dir != raw_work_dir: + os.rename(raw_work_dir, work_dir) try: _write_test_files(work_dir, generation) if generation.get("docker_compose") and generation.get("needs_attacker_server"): - # Multi-service: use docker compose - result = _run_compose(work_dir, container_timeout, build_timeout) + # Multi-service: use docker compose with explicit project name + result = _run_compose(work_dir, safe_id, container_timeout, build_timeout) else: # Single container: docker build + run result = _run_single(work_dir, image_tag, network_name, @@ -178,7 +192,9 @@ def _run_single( # Create isolated network _run_command(["docker", "network", "create", network_name], timeout=10) - # Run with timeout, no host mounts, no privileged mode + # Run with timeout, no host mounts, no privileged mode. + # tmpfs for /tmp (writable workspace) and /root (for build caches like + # ~/.cache/go-build that some test runners write to even at runtime). 
stdout, stderr, code, timed_out = _run_command( [ "docker", "run", @@ -187,7 +203,8 @@ def _run_single( "--memory", "512m", "--cpus", "1", "--read-only", - "--tmpfs", "/tmp", + "--tmpfs", "/tmp:size=256m", + "--tmpfs", "/root:size=128m", "--security-opt", "no-new-privileges", image_tag, ], @@ -205,15 +222,22 @@ def _run_single( def _run_compose( work_dir: str, + project_name: str, container_timeout: int, build_timeout: int, ) -> DockerExecutionResult: - """Build and run multi-service test via docker compose.""" + """Build and run multi-service test via docker compose. + + Uses an explicit project name to ensure image tags are always valid + Docker references, independent of the temp dir name. + """ result = DockerExecutionResult() + compose_base = ["docker", "compose", "-p", project_name] + # Build all services stdout, stderr, code, timed_out = _run_command( - ["docker", "compose", "build"], + compose_base + ["build"], timeout=build_timeout, cwd=work_dir, ) @@ -225,7 +249,7 @@ def _run_compose( # Start services _run_command( - ["docker", "compose", "up", "-d"], + compose_base + ["up", "-d"], timeout=60, cwd=work_dir, ) @@ -233,7 +257,7 @@ def _run_compose( try: # Wait for the test container to exit (it should be the main service) stdout, stderr, code, timed_out = _run_command( - ["docker", "compose", "logs", "--no-log-prefix", "-f", "test"], + compose_base + ["logs", "--no-log-prefix", "-f", "test"], timeout=container_timeout, cwd=work_dir, ) @@ -244,7 +268,7 @@ def _run_compose( finally: # Always tear down _run_command( - ["docker", "compose", "down", "--volumes", "--remove-orphans"], + compose_base + ["down", "--volumes", "--remove-orphans"], timeout=30, cwd=work_dir, ) diff --git a/libs/openant-core/utilities/dynamic_tester/docker_templates/go.Dockerfile b/libs/openant-core/utilities/dynamic_tester/docker_templates/go.Dockerfile index 2b74c3b..97c3601 100644 --- a/libs/openant-core/utilities/dynamic_tester/docker_templates/go.Dockerfile +++ 
b/libs/openant-core/utilities/dynamic_tester/docker_templates/go.Dockerfile @@ -1,7 +1,10 @@ -FROM golang:1.22-alpine +FROM golang:1.25-alpine WORKDIR /test -COPY go.mod . -RUN go mod download COPY test_exploit.go . +# Initialize a fresh module and resolve dependencies in the container. +# This avoids needing go.sum/go.mod from the host, which is brittle +# when the LLM-generated test imports third-party packages. +RUN go mod init openant-test 2>/dev/null || true +RUN go mod tidy RUN go build -o test_exploit test_exploit.go CMD ["./test_exploit"] diff --git a/libs/openant-core/utilities/dynamic_tester/models.py b/libs/openant-core/utilities/dynamic_tester/models.py index 01dc4a9..55d7149 100644 --- a/libs/openant-core/utilities/dynamic_tester/models.py +++ b/libs/openant-core/utilities/dynamic_tester/models.py @@ -29,6 +29,8 @@ class DynamicTestResult: docker_compose: str = "" # Generated docker-compose.yml (if multi-service) elapsed_seconds: float = 0.0 generation_cost_usd: float = 0.0 + generation_input_tokens: int = 0 + generation_output_tokens: int = 0 retry_count: int = 0 def to_dict(self) -> dict: @@ -42,5 +44,7 @@ def to_dict(self) -> dict: "docker_compose": self.docker_compose, "elapsed_seconds": round(self.elapsed_seconds, 2), "generation_cost_usd": round(self.generation_cost_usd, 6), + "generation_input_tokens": self.generation_input_tokens, + "generation_output_tokens": self.generation_output_tokens, "retry_count": self.retry_count, } diff --git a/libs/openant-core/utilities/dynamic_tester/result_collector.py b/libs/openant-core/utilities/dynamic_tester/result_collector.py index 51430ed..4586f07 100644 --- a/libs/openant-core/utilities/dynamic_tester/result_collector.py +++ b/libs/openant-core/utilities/dynamic_tester/result_collector.py @@ -104,7 +104,7 @@ def collect_result( status = "INCONCLUSIVE" evidence = [] - for e in parsed.get("evidence", []): + for e in (parsed.get("evidence") or []): if isinstance(e, dict) and "type" in e and "content" in e: 
evidence.append(TestEvidence(type=e["type"], content=str(e["content"])[:5000])) diff --git a/libs/openant-core/utilities/dynamic_tester/test_generator.py b/libs/openant-core/utilities/dynamic_tester/test_generator.py index fd5cd13..422c5fa 100644 --- a/libs/openant-core/utilities/dynamic_tester/test_generator.py +++ b/libs/openant-core/utilities/dynamic_tester/test_generator.py @@ -9,6 +9,10 @@ import json import os import re +import sys +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed from utilities.llm_client import AnthropicClient, TokenTracker @@ -43,10 +47,27 @@ - Do NOT pin exact versions unless the vulnerability is version-specific. Use >= or no version pin. - For Python: put ALL dependencies in requirements.txt, use `pip install --no-cache-dir -r requirements.txt`. - For Node.js: put ALL dependencies in package.json. +- For Go: do NOT write go.mod or go.sum yourself. Instead, the Dockerfile MUST initialize + the module inside the container using `RUN go mod init && go mod tidy`. This works + whether the test uses stdlib only or third-party packages. Use golang:1.25-alpine as the + base image to support modern k8s and cloud-native packages. Example Dockerfile for Go: + FROM golang:1.25-alpine + WORKDIR /test + COPY test_exploit.go . + RUN go mod init openant-test && go mod tidy + RUN go build -o test_exploit test_exploit.go + CMD ["./test_exploit"] - The Dockerfile MUST install dependencies from the requirements/package file, NOT inline in RUN commands. - If a package has many transitive dependencies, only install the specific sub-package you need (e.g., `langchain-core` instead of `langchain`). +CONTAINER FILESYSTEM: +- The container runs with a read-only root filesystem. Only /tmp is writable. +- Do NOT write files to $HOME, /root, /app/data, or any other location outside /tmp. 
+- If the test needs a writable cache (e.g., Go build cache), set env vars to redirect + to /tmp: `ENV GOCACHE=/tmp/.gocache GOMODCACHE=/tmp/.gomodcache`. +- For Python, use `PYTHONDONTWRITEBYTECODE=1` to avoid writing .pyc files. + ATTACKER CAPTURE SERVER (for SSRF/callback/exfiltration tests): - The attacker server is provided locally and listens on port 9999. - Endpoints: GET /health (health check), GET/POST /capture (logs full request), @@ -264,29 +285,58 @@ def regenerate_test( return parsed +def _generate_one(finding, repo_info, tracker): + """Generate a test for a single finding, tracking cost.""" + cost_before = tracker.total_cost_usd + result = generate_test(finding, repo_info, tracker) + cost_after = tracker.total_cost_usd + cost = cost_after - cost_before + worker = threading.current_thread().name + return finding, result, cost, worker + + def generate_tests_batch( findings: list[dict], repo_info: dict, tracker: TokenTracker = None, + workers: int = 10, ) -> list[tuple[dict, dict | None, float]]: """Generate tests for multiple findings. + Uses ThreadPoolExecutor for parallel generation when workers > 1. + Args: findings: List of finding dicts repo_info: Repository info tracker: Optional TokenTracker + workers: Number of parallel workers (default: 10). 
Returns: List of (finding, generation_result_or_None, cost_usd) tuples """ tracker = tracker or TokenTracker() - results = [] + total = len(findings) + + mode = "sequential" if workers <= 1 else f"parallel ({workers} workers)" + print(f"[DynamicTest] Generating tests for {total} findings, mode: {mode}", file=sys.stderr, flush=True) - for finding in findings: - cost_before = tracker.total_cost_usd - result = generate_test(finding, repo_info, tracker) - cost_after = tracker.total_cost_usd - cost = cost_after - cost_before - results.append((finding, result, cost)) + if workers <= 1: + results = [] + for i, finding in enumerate(findings): + _finding, result, cost, _worker = _generate_one(finding, repo_info, tracker) + print(f"[DynamicTest] {i+1}/{total} ${cost:.2f}", file=sys.stderr, flush=True) + results.append((_finding, result, cost)) + return results + + # Parallel mode + results = [] + completed = 0 + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = [executor.submit(_generate_one, finding, repo_info, tracker) for finding in findings] + for future in as_completed(futures): + _finding, result, cost, worker = future.result() + completed += 1 + print(f"[DynamicTest] {completed}/{total} ${cost:.2f} [{worker}]", file=sys.stderr, flush=True) + results.append((_finding, result, cost)) return results diff --git a/libs/openant-core/utilities/finding_verifier.py b/libs/openant-core/utilities/finding_verifier.py index 101e90f..2e66b7c 100644 --- a/libs/openant-core/utilities/finding_verifier.py +++ b/libs/openant-core/utilities/finding_verifier.py @@ -31,13 +31,17 @@ import json import logging import re +import sys +import threading import time +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from typing import Callable, Optional import anthropic from .llm_client import TokenTracker, get_global_tracker +from .rate_limiter import get_rate_limiter # Null logger that discards all messages (used when no 
logger provided) _null_logger = logging.getLogger("null_verifier") @@ -259,14 +263,15 @@ def __init__( tracker: TokenTracker = None, verbose: bool = False, app_context: "ApplicationContext" = None, - logger: logging.Logger = None + logger: logging.Logger = None, + client: "anthropic.Anthropic | None" = None, ): self.index = index self.tracker = tracker or get_global_tracker() self.verbose = verbose self.app_context = app_context self.tool_executor = ToolExecutor(index) - self.client = anthropic.Anthropic() + self.client = client or anthropic.Anthropic(max_retries=5) self.logger = logger or _null_logger self._use_logger = logger is not None @@ -323,13 +328,23 @@ def verify_result( self._log("debug", f"Iteration {iterations}", iterations=iterations) - response = self.client.messages.create( - model=VERIFIER_MODEL, - max_tokens=MAX_TOKENS_PER_RESPONSE, - system=system_prompt, - tools=VERIFICATION_TOOLS, - messages=messages - ) + # Wait if we're in a global backoff period + rate_limiter = get_rate_limiter() + rate_limiter.wait_if_needed() + + try: + response = self.client.messages.create( + model=VERIFIER_MODEL, + max_tokens=MAX_TOKENS_PER_RESPONSE, + system=system_prompt, + tools=VERIFICATION_TOOLS, + messages=messages + ) + except anthropic.RateLimitError as exc: + # Report to global rate limiter so all workers back off + retry_after = float(exc.response.headers.get("retry-after", 0)) + get_rate_limiter().report_rate_limit(retry_after) + raise total_input_tokens += response.usage.input_tokens total_output_tokens += response.usage.output_tokens @@ -416,66 +431,258 @@ def verify_batch( results: list, code_by_route: dict, progress_callback: Optional[Callable] = None, + workers: int = 10, + checkpoint=None, + restored_callback: Optional[Callable] = None, ) -> list: """ Verify a batch of results with consistency cross-check. + Uses ThreadPoolExecutor for parallel verification when workers > 1. + Supports checkpoint/resume via the checkpoint parameter. 
+ Args: results: List of Stage 1 results to verify code_by_route: Dict mapping route_key to code progress_callback: Optional callback(unit_id, detail, unit_elapsed) called after each finding is verified. + workers: Number of parallel workers (default: 10). + checkpoint: Optional StepCheckpoint instance for resume support. + restored_callback: Optional callback(count) called after checkpoint + loading with the number of restored units. Returns: Updated results with verification and consistency check """ - # Step 1: Individual verification - for i, result in enumerate(results): - route_key = result.get("route_key", "unknown") - stage1_finding = result.get("finding", "inconclusive") + total = len(results) + + # Load checkpoint state + checkpointed = {} + if checkpoint is not None: + checkpointed = checkpoint.load() + + def _cp_is_error(cp_data): + """A verify checkpoint is errored if verification is missing/empty + or correct_finding == 'error'.""" + if not cp_data: + return True + v = cp_data.get("verification", {}) + if not v: + return True + return v.get("correct_finding") == "error" + + # Separate already-done (successful) from to-do (new + errored) + results_to_verify = [] + _restored_ok = 0 + for r in results: + key = r.get("unit_id") or r.get("route_key", "unknown") + cp_data = checkpointed.get(key) + if cp_data and not _cp_is_error(cp_data): + # Restore verification data from checkpoint + if "verification" in cp_data: + r["verification"] = cp_data["verification"] + if "finding" in cp_data: + r["finding"] = cp_data["finding"] + if "verification_note" in cp_data: + r["verification_note"] = cp_data["verification_note"] + _restored_ok += 1 + else: + # Either no checkpoint, or an errored one — re-verify + results_to_verify.append(r) + + if _restored_ok: + print(f"[Verify] Restored {_restored_ok} findings from checkpoints", + file=sys.stderr, flush=True) + if restored_callback: + restored_callback(_restored_ok) + errored_retries = len(checkpointed) - _restored_ok + 
if errored_retries: + print(f"[Verify] Retrying {errored_retries} previously errored findings", + file=sys.stderr, flush=True) + + # Initialize summary tracking for _summary.json + _summary_completed = _restored_ok + _summary_errors = 0 + _summary_error_breakdown = {} + _summary_input_tokens = 0 + _summary_output_tokens = 0 + _summary_cost_usd = 0.0 + + # Sum usage from ALL existing checkpoints (including errored ones + # — their cost was already spent in a prior run) + for _key, _cp in checkpointed.items(): + _cp_usage = _cp.get("usage", {}) + _summary_input_tokens += _cp_usage.get("input_tokens", 0) + _summary_output_tokens += _cp_usage.get("output_tokens", 0) + _summary_cost_usd += _cp_usage.get("cost_usd", 0.0) + + def _usage_dict(): + return {"input_tokens": _summary_input_tokens, + "output_tokens": _summary_output_tokens, + "cost_usd": round(_summary_cost_usd, 6)} + + # Inject prior usage into tracker so step_report captures the total + if _summary_input_tokens or _summary_output_tokens: + self.tracker.add_prior_usage( + _summary_input_tokens, _summary_output_tokens, _summary_cost_usd) + + if checkpoint is not None: + checkpoint.write_summary(total, _summary_completed, _summary_errors, + _summary_error_breakdown, phase="in_progress", + usage=_usage_dict()) + + def _summary_callback(detail, usage=None): + """Update summary counters after each unit. 
Called from main thread.""" + nonlocal _summary_completed, _summary_errors, _summary_error_breakdown + nonlocal _summary_input_tokens, _summary_output_tokens, _summary_cost_usd + if detail == "error": + _summary_errors += 1 + _summary_error_breakdown["api"] = _summary_error_breakdown.get("api", 0) + 1 + else: + _summary_completed += 1 + if usage: + _summary_input_tokens += usage.get("input_tokens", 0) + _summary_output_tokens += usage.get("output_tokens", 0) + _summary_cost_usd += usage.get("cost_usd", 0.0) + if checkpoint is not None: + checkpoint.write_summary(total, _summary_completed, _summary_errors, + _summary_error_breakdown, phase="in_progress", + usage=_usage_dict()) + + remaining = len(results_to_verify) + mode = "sequential" if workers <= 1 else f"parallel ({workers} workers)" + print(f"[Verify] Mode: {mode}, {remaining} findings to verify " + f"({len(checkpointed)} already done)", file=sys.stderr, flush=True) + + if workers <= 1: + self._verify_batch_sequential( + results_to_verify, code_by_route, progress_callback, checkpoint, + summary_callback=_summary_callback) + else: + self._verify_batch_parallel( + results_to_verify, code_by_route, progress_callback, workers, checkpoint, + summary_callback=_summary_callback) + + # Write final summary with phase="done" + if checkpoint is not None: + checkpoint.write_summary(total, _summary_completed, _summary_errors, + _summary_error_breakdown, phase="done", + usage=_usage_dict()) + + # Step 2: Consistency cross-check (barrier — needs all results) + results = self._check_consistency(results, code_by_route) - self._log("info", f"Verifying finding {i+1}/{len(results)}", - unit_id=route_key, classification=stage1_finding) + return results - unit_start = time.monotonic() - detail = "" - try: - code = code_by_route.get(route_key, "") - verification = self.verify_result( - code=code, - finding=stage1_finding, - attack_vector=result.get("attack_vector"), - reasoning=result.get("reasoning", ""), - 
files_included=result.get("files_included", []) - ) + def _verify_one(self, result, code_by_route): + """Verify a single result. Returns (route_key, detail, elapsed, worker, usage). - result["verification"] = verification.to_dict() - - if verification.agree: - detail = f"agreed:{verification.correct_finding}" - self._log("info", f"Verification agreed: {verification.correct_finding}", - unit_id=route_key, total_tokens=verification.total_tokens, - iterations=verification.iterations) - else: - detail = f"disagreed:{stage1_finding}->{verification.correct_finding}" - result["finding"] = verification.correct_finding - result["verification_note"] = f"Changed from {stage1_finding} to {verification.correct_finding}" - self._log("info", f"Verification disagreed: {stage1_finding} -> {verification.correct_finding}", - unit_id=route_key, total_tokens=verification.total_tokens, - iterations=verification.iterations) - - except Exception as e: - detail = "error" - self._log("error", f"Verification failed", unit_id=route_key, error=str(e)) - - unit_elapsed = time.monotonic() - unit_start - if progress_callback: - progress_callback(route_key, detail, unit_elapsed) - - # Step 2: Consistency cross-check - results = self._check_consistency(results, code_by_route) + Mutates the result dict in-place (each result is unique, no contention). 
+ """ + route_key = result.get("route_key", "unknown") + stage1_finding = result.get("finding", "inconclusive") + worker = threading.current_thread().name - return results + self.tracker.start_unit_tracking() + unit_start = time.monotonic() + detail = "" + try: + code = code_by_route.get(route_key, "") + verification = self.verify_result( + code=code, + finding=stage1_finding, + attack_vector=result.get("attack_vector"), + reasoning=result.get("reasoning", ""), + files_included=result.get("files_included", []) + ) + + result["verification"] = verification.to_dict() + + if verification.agree: + detail = f"agreed:{verification.correct_finding}" + self._log("info", f"Verification agreed: {verification.correct_finding}", + unit_id=route_key, total_tokens=verification.total_tokens, + iterations=verification.iterations) + else: + detail = f"disagreed:{stage1_finding}->{verification.correct_finding}" + result["finding"] = verification.correct_finding + result["verification_note"] = f"Changed from {stage1_finding} to {verification.correct_finding}" + self._log("info", f"Verification disagreed: {stage1_finding} -> {verification.correct_finding}", + unit_id=route_key, total_tokens=verification.total_tokens, + iterations=verification.iterations) + + except Exception as e: + detail = "error" + print(f"[Verify] ERROR {route_key}: {type(e).__name__}: {e}", file=sys.stderr, flush=True) + + unit_elapsed = time.monotonic() - unit_start + usage = self.tracker.get_unit_usage() + return route_key, detail, unit_elapsed, worker, usage + + def _verify_batch_sequential(self, results, code_by_route, progress_callback, + checkpoint=None, summary_callback=None): + """Verify all results sequentially.""" + try: + for i, result in enumerate(results): + route_key = result.get("route_key", "unknown") + stage1_finding = result.get("finding", "inconclusive") + self._log("info", f"Verifying finding {i+1}/{len(results)}", + unit_id=route_key, classification=stage1_finding) + + route_key, detail, 
unit_elapsed, _worker, usage = self._verify_one(result, code_by_route) + if checkpoint is not None: + key = result.get("unit_id") or route_key + cp_data = { + "verification": result.get("verification", {}), + "finding": result.get("finding", ""), + "verification_note": result.get("verification_note", ""), + } + if usage: + cp_data["usage"] = usage + checkpoint.save(key, cp_data) + if summary_callback: + summary_callback(detail, usage=usage) + if progress_callback: + progress_callback(route_key, detail, unit_elapsed) + except KeyboardInterrupt: + print("[Verify] Interrupted — progress saved to checkpoints", + file=sys.stderr, flush=True) + + def _verify_batch_parallel(self, results, code_by_route, progress_callback, workers, + checkpoint=None, summary_callback=None): + """Verify all results in parallel using ThreadPoolExecutor.""" + executor = ThreadPoolExecutor(max_workers=workers) + future_to_result = {} + for result in results: + future = executor.submit(self._verify_one, result, code_by_route) + future_to_result[future] = result + + try: + for future in as_completed(future_to_result): + result = future_to_result[future] + route_key, detail, unit_elapsed, worker, usage = future.result() + if checkpoint is not None: + key = result.get("unit_id") or route_key + cp_data = { + "verification": result.get("verification", {}), + "finding": result.get("finding", ""), + "verification_note": result.get("verification_note", ""), + } + if usage: + cp_data["usage"] = usage + checkpoint.save(key, cp_data) + if summary_callback: + summary_callback(detail, usage=usage) + if progress_callback: + progress_callback(route_key, f"{detail} [{worker}]", unit_elapsed) + except KeyboardInterrupt: + print("[Verify] Interrupted — cancelling pending work...", + file=sys.stderr, flush=True) + executor.shutdown(wait=False, cancel_futures=True) + print("[Verify] Progress saved to checkpoints", + file=sys.stderr, flush=True) + return + executor.shutdown(wait=False) def _check_consistency( self, 
@@ -623,6 +830,10 @@ def _resolve_inconsistency( prompt = get_consistency_check_prompt(group, code_by_route) try: + # Wait if we're in a global backoff period + rate_limiter = get_rate_limiter() + rate_limiter.wait_if_needed() + response = self.client.messages.create( model=VERIFIER_MODEL, max_tokens=MAX_TOKENS_PER_RESPONSE, @@ -648,6 +859,11 @@ def _resolve_inconsistency( explanation=result.get("explanation", "") ) + except anthropic.RateLimitError as e: + # Report to global rate limiter so all workers back off + retry_after = float(e.response.headers.get("retry-after", 0)) + get_rate_limiter().report_rate_limit(retry_after) + self._log("error", f"Consistency resolution rate limited", error=str(e)) except Exception as e: self._log("error", f"Consistency resolution failed", error=str(e)) diff --git a/libs/openant-core/utilities/llm_client.py b/libs/openant-core/utilities/llm_client.py index d8d32c8..ea356bf 100644 --- a/libs/openant-core/utilities/llm_client.py +++ b/libs/openant-core/utilities/llm_client.py @@ -18,10 +18,13 @@ """ import os +import threading from typing import Optional import anthropic from dotenv import load_dotenv +from .rate_limiter import get_rate_limiter + # Pricing per million tokens (as of December 2024) MODEL_PRICING = { @@ -38,14 +41,17 @@ class TokenTracker: """ def __init__(self): + self._lock = threading.Lock() + self._thread_local = threading.local() self.reset() def reset(self): """Reset all counters.""" - self.calls = [] - self.total_input_tokens = 0 - self.total_output_tokens = 0 - self.total_cost_usd = 0.0 + with self._lock: + self.calls = [] + self.total_input_tokens = 0 + self.total_output_tokens = 0 + self.total_cost_usd = 0.0 @property def total_tokens(self) -> int: @@ -79,14 +85,54 @@ def record_call(self, model: str, input_tokens: int, output_tokens: int) -> dict "cost_usd": round(total_cost, 6) } - # Update totals - self.calls.append(call_record) - self.total_input_tokens += input_tokens - self.total_output_tokens += 
output_tokens - self.total_cost_usd += total_cost + # Update totals (thread-safe) + with self._lock: + self.calls.append(call_record) + self.total_input_tokens += input_tokens + self.total_output_tokens += output_tokens + self.total_cost_usd += total_cost + + # Accumulate to thread-local unit tracking if active + tl = self._thread_local + if hasattr(tl, "unit_input"): + tl.unit_input += input_tokens + tl.unit_output += output_tokens + tl.unit_cost += total_cost return call_record + def add_prior_usage(self, input_tokens: int, output_tokens: int, cost_usd: float): + """Inject usage from a prior run (e.g. restored checkpoints). + + This ensures step reports capture the total cost across all runs, + not just the current run's API calls. + """ + with self._lock: + self.total_input_tokens += input_tokens + self.total_output_tokens += output_tokens + self.total_cost_usd += cost_usd + + def start_unit_tracking(self): + """Start tracking usage for the current unit on this thread. + + Call before processing a unit, then call ``get_unit_usage()`` + after to get the accumulated usage for just that unit. Thread-safe + because each thread has its own ``threading.local()`` storage. + """ + tl = self._thread_local + tl.unit_input = 0 + tl.unit_output = 0 + tl.unit_cost = 0.0 + + def get_unit_usage(self) -> dict: + """Return usage accumulated since ``start_unit_tracking()`` on this thread.""" + tl = self._thread_local + return { + "input_tokens": getattr(tl, "unit_input", 0), + "output_tokens": getattr(tl, "unit_output", 0), + "cost_usd": round(getattr(tl, "unit_cost", 0.0), 6), + } + def get_summary(self) -> dict: """ Get summary of all tracked calls. 
@@ -94,14 +140,15 @@ def get_summary(self) -> dict: Returns: Dict with totals and per-call breakdown """ - return { - "total_calls": len(self.calls), - "total_input_tokens": self.total_input_tokens, - "total_output_tokens": self.total_output_tokens, - "total_tokens": self.total_input_tokens + self.total_output_tokens, - "total_cost_usd": round(self.total_cost_usd, 6), - "calls": self.calls - } + with self._lock: + return { + "total_calls": len(self.calls), + "total_input_tokens": self.total_input_tokens, + "total_output_tokens": self.total_output_tokens, + "total_tokens": self.total_input_tokens + self.total_output_tokens, + "total_cost_usd": round(self.total_cost_usd, 6), + "calls": list(self.calls), + } def get_totals(self) -> dict: """ @@ -110,13 +157,14 @@ def get_totals(self) -> dict: Returns: Dict with totals only """ - return { - "total_calls": len(self.calls), - "total_input_tokens": self.total_input_tokens, - "total_output_tokens": self.total_output_tokens, - "total_tokens": self.total_input_tokens + self.total_output_tokens, - "total_cost_usd": round(self.total_cost_usd, 6) - } + with self._lock: + return { + "total_calls": len(self.calls), + "total_input_tokens": self.total_input_tokens, + "total_output_tokens": self.total_output_tokens, + "total_tokens": self.total_input_tokens + self.total_output_tokens, + "total_cost_usd": round(self.total_cost_usd, 6), + } # Global tracker instance for session-wide tracking @@ -156,7 +204,7 @@ def __init__(self, model: str = "claude-opus-4-20250514", tracker: TokenTracker if not api_key: raise ValueError("ANTHROPIC_API_KEY not found in environment") - self.client = anthropic.Anthropic(api_key=api_key) + self.client = anthropic.Anthropic(api_key=api_key, max_retries=5) self.model = model self.tracker = tracker or _global_tracker self.last_call = None # Store last call details @@ -172,13 +220,23 @@ async def analyze(self, prompt: str, max_tokens: int = 8192) -> str: Returns: Response text from Claude """ - message = 
self.client.messages.create( - model=self.model, - max_tokens=max_tokens, - messages=[ - {"role": "user", "content": prompt} - ] - ) + # Wait if we're in a global backoff period + rate_limiter = get_rate_limiter() + rate_limiter.wait_if_needed() + + try: + message = self.client.messages.create( + model=self.model, + max_tokens=max_tokens, + messages=[ + {"role": "user", "content": prompt} + ] + ) + except anthropic.RateLimitError as exc: + # Report to global rate limiter so all workers back off + retry_after = float(exc.response.headers.get("retry-after", 0)) + get_rate_limiter().report_rate_limit(retry_after) + raise # Track token usage self.last_call = self.tracker.record_call( @@ -214,7 +272,17 @@ def analyze_sync(self, prompt: str, max_tokens: int = 8192, model: str = None, s if system: kwargs["system"] = system - message = self.client.messages.create(**kwargs) + # Wait if we're in a global backoff period + rate_limiter = get_rate_limiter() + rate_limiter.wait_if_needed() + + try: + message = self.client.messages.create(**kwargs) + except anthropic.RateLimitError as exc: + # Report to global rate limiter so all workers back off + retry_after = float(exc.response.headers.get("retry-after", 0)) + get_rate_limiter().report_rate_limit(retry_after) + raise # Track token usage self.last_call = self.tracker.record_call( diff --git a/libs/openant-core/utilities/rate_limiter.py b/libs/openant-core/utilities/rate_limiter.py new file mode 100644 index 0000000..3416f1b --- /dev/null +++ b/libs/openant-core/utilities/rate_limiter.py @@ -0,0 +1,243 @@ +""" +Process-level rate limiter with coordinated backoff. + +When any worker hits a 429 rate limit error, ALL workers pause for a +configurable backoff period (default 30s). This prevents thundering herd +and ensures the rate limit window has time to reset. 
+ +Usage: + from utilities.rate_limiter import get_rate_limiter, configure_rate_limiter + + # At startup (once) + configure_rate_limiter(backoff_seconds=30) + + # Before every API call + rate_limiter = get_rate_limiter() + rate_limiter.wait_if_needed() + + # When catching RateLimitError + except anthropic.RateLimitError as e: + retry_after = float(e.response.headers.get("retry-after", 0)) + rate_limiter.report_rate_limit(retry_after) + raise +""" + +import random +import sys +import threading +import time + + +class GlobalRateLimiter: + """ + Singleton rate limiter with coordinated backoff across all threads. + + When any thread reports a rate limit error, all threads pause until + the backoff period expires. This ensures the organization-wide rate + limit window has time to reset. + """ + + _instance = None + _init_lock = threading.Lock() + + def __new__(cls, backoff_seconds: float = 30.0): + if cls._instance is None: + with cls._init_lock: + if cls._instance is None: + instance = super().__new__(cls) + instance._lock = threading.Lock() + instance._backoff_until = 0.0 + instance._backoff_seconds = backoff_seconds + instance._total_waits = 0 + instance._total_wait_time = 0.0 + cls._instance = instance + return cls._instance + + @property + def backoff_seconds(self) -> float: + return self._backoff_seconds + + @backoff_seconds.setter + def backoff_seconds(self, value: float): + self._backoff_seconds = value + + def wait_if_needed(self) -> float: + """ + Block if currently in a backoff period. + + Call this before every API request. Returns the time waited (0 if none). 
+ """ + with self._lock: + now = time.monotonic() + if now >= self._backoff_until: + return 0.0 + + wait_time = self._backoff_until - now + # Add jitter (0-2s) to prevent thundering herd when backoff expires + jitter = random.uniform(0, 2.0) + total_wait = wait_time + jitter + + # Sleep outside the lock so other threads can also read backoff_until + time.sleep(total_wait) + + with self._lock: + self._total_waits += 1 + self._total_wait_time += total_wait + + return total_wait + + def report_rate_limit(self, retry_after: float | None = None): + """ + Report a rate limit error and trigger global backoff. + + Call this when any worker receives a 429 error. All workers will + pause until the backoff period expires. + + Args: + retry_after: The retry-after header value from the API response. + If provided, uses max(retry_after, backoff_seconds). + """ + with self._lock: + # Use the larger of retry_after and our configured backoff + backoff = max(retry_after or 0.0, self._backoff_seconds) + new_backoff_until = time.monotonic() + backoff + + # Only extend if this is later than current backoff + if new_backoff_until > self._backoff_until: + self._backoff_until = new_backoff_until + print( + f"[RateLimiter] Global backoff triggered: {backoff:.0f}s", + file=sys.stderr, + flush=True, + ) + + def is_in_backoff(self) -> bool: + """Check if currently in a backoff period (for diagnostics).""" + with self._lock: + return time.monotonic() < self._backoff_until + + def time_until_ready(self) -> float: + """Seconds until backoff expires (0 if not in backoff).""" + with self._lock: + remaining = self._backoff_until - time.monotonic() + return max(0.0, remaining) + + def get_stats(self) -> dict: + """Get statistics about rate limiting (for diagnostics).""" + with self._lock: + return { + "total_waits": self._total_waits, + "total_wait_time": round(self._total_wait_time, 2), + "backoff_seconds": self._backoff_seconds, + "currently_in_backoff": time.monotonic() < self._backoff_until, + } 
+ + def reset(self): + """Reset backoff state. For testing.""" + with self._lock: + self._backoff_until = 0.0 + self._total_waits = 0 + self._total_wait_time = 0.0 + + +# Module-level singleton access +_rate_limiter: GlobalRateLimiter | None = None +_config_lock = threading.Lock() + + +def configure_rate_limiter(backoff_seconds: float = 30.0) -> GlobalRateLimiter: + """ + Configure the global rate limiter. Call once at startup. + + Args: + backoff_seconds: How long to pause all workers on rate limit (default: 30s). + + Returns: + The configured GlobalRateLimiter singleton. + """ + global _rate_limiter + with _config_lock: + if _rate_limiter is None: + _rate_limiter = GlobalRateLimiter(backoff_seconds) + else: + _rate_limiter.backoff_seconds = backoff_seconds + return _rate_limiter + + +def get_rate_limiter() -> GlobalRateLimiter: + """ + Get the global rate limiter singleton. + + If not configured, creates one with default settings (30s backoff). + """ + global _rate_limiter + if _rate_limiter is None: + with _config_lock: + if _rate_limiter is None: + _rate_limiter = GlobalRateLimiter(30.0) + return _rate_limiter + + +def reset_rate_limiter(): + """Reset the rate limiter singleton. For testing.""" + global _rate_limiter + with _config_lock: + if _rate_limiter is not None: + _rate_limiter.reset() + + +def is_rate_limit_error(error_info: dict | str | None) -> bool: + """ + Check if an error dict/string represents a rate limit error. + + Args: + error_info: The error field from agent_context or similar. + + Returns: + True if this is a rate limit error that should be retried. + """ + if not error_info: + return False + if isinstance(error_info, dict): + return error_info.get("type") == "rate_limit" + return "rate_limit" in str(error_info).lower() + + +def is_retryable_error(error_info: dict | str | None) -> bool: + """ + Check if an error is retryable (transient network/server issues). 
+ + Retryable errors include: + - rate_limit: API rate limiting (429) + - connection: Network connectivity issues + - timeout: Request timeout + - api_status with 500+: Server errors (not client errors like 400) + + Args: + error_info: The error field from agent_context or similar. + + Returns: + True if this error should be retried. + """ + if not error_info: + return False + + if isinstance(error_info, dict): + error_type = error_info.get("type", "") + + # Always retry these transient error types + if error_type in ("rate_limit", "connection", "timeout"): + return True + + # Retry server errors (5xx), but not client errors (4xx) + if error_type == "api_status": + status_code = error_info.get("status_code", 0) + return status_code >= 500 + + return False + + # String-based error checking + error_str = str(error_info).lower() + return any(term in error_str for term in ( + "rate_limit", "connection", "timeout", "500", "502", "503", "504" + )) diff --git a/libs/openant-core/utilities/safe_filename.py b/libs/openant-core/utilities/safe_filename.py new file mode 100644 index 0000000..387d49c --- /dev/null +++ b/libs/openant-core/utilities/safe_filename.py @@ -0,0 +1,25 @@ +"""Shared filename sanitizer for checkpoint files.""" + +import hashlib + + +def safe_filename(unit_id: str) -> str: + """Convert a unit ID to a safe filename. + + Handles long filenames by truncating and appending a hash for uniqueness. + macOS has a 255 character limit for filenames. 
+ """ + safe = (unit_id + .replace("/", "__") + .replace("\\", "__") + .replace(":", "_") + .replace(" ", "_")) + + # Leave room for .json extension (5 chars) and hash suffix (17 chars: _ + 16 hex) + max_len = 255 - 5 - 17 # = 233 + + if len(safe) > max_len: + h = hashlib.sha256(unit_id.encode()).hexdigest()[:16] + safe = safe[:max_len] + "_" + h + + return safe diff --git a/libs/openant-core/validate_dataset_schema.py b/libs/openant-core/validate_dataset_schema.py index 31867b5..1312bce 100755 --- a/libs/openant-core/validate_dataset_schema.py +++ b/libs/openant-core/validate_dataset_schema.py @@ -34,11 +34,13 @@ def validate_unit(unit, index): errors.append(f"Unit {index}: 'code.primary_origin' must be dict") return errors - # 5. CRITICAL: Check enhanced flag (experiment.py line 191) - if "enhanced" not in primary_origin: - errors.append(f"Unit {index}: MISSING 'code.primary_origin.enhanced'") - elif not isinstance(primary_origin.get("enhanced"), bool): - errors.append(f"Unit {index}: 'code.primary_origin.enhanced' must be bool") + # 5. CRITICAL: Check deps_inlined flag (formerly "enhanced") + # Accept either "deps_inlined" (new) or "enhanced" (legacy) for backward compat + deps_inlined_key = "deps_inlined" if "deps_inlined" in primary_origin else "enhanced" + if deps_inlined_key not in primary_origin: + errors.append(f"Unit {index}: MISSING 'code.primary_origin.deps_inlined'") + elif not isinstance(primary_origin.get(deps_inlined_key), bool): + errors.append(f"Unit {index}: 'code.primary_origin.deps_inlined' must be bool") # 6. CRITICAL: Check files_included (experiment.py line 192) if "files_included" not in primary_origin: @@ -46,12 +48,12 @@ def validate_unit(unit, index): elif not isinstance(primary_origin.get("files_included"), list): errors.append(f"Unit {index}: 'code.primary_origin.files_included' must be list") - # 7. 
If enhanced=true, files_included must have entries - if primary_origin.get("enhanced") and not primary_origin.get("files_included"): - errors.append(f"Unit {index}: enhanced=true but files_included is empty") + # 7. If deps_inlined=true, files_included must have entries + if primary_origin.get(deps_inlined_key) and not primary_origin.get("files_included"): + errors.append(f"Unit {index}: deps_inlined=true but files_included is empty") - # 8. Check file boundaries in primary_code when enhanced with multiple files - if primary_origin.get("enhanced") and len(primary_origin.get("files_included", [])) > 1: + # 8. Check file boundaries in primary_code when deps_inlined with multiple files + if primary_origin.get(deps_inlined_key) and len(primary_origin.get("files_included", [])) > 1: if "// ========== File Boundary ==========" not in primary_code: errors.append(f"Unit {index}: enhanced with multiple files but no file boundaries") @@ -65,19 +67,19 @@ def validate_dataset(path): all_errors = [] units = data.get("units", []) - enhanced_count = 0 + deps_inlined_count = 0 for i, unit in enumerate(units): errors = validate_unit(unit, i) all_errors.extend(errors) - # Count enhanced units + # Count units with dependencies inlined code_field = unit.get("code", {}) if isinstance(code_field, dict): primary_origin = code_field.get("primary_origin", {}) - if primary_origin.get("enhanced"): - enhanced_count += 1 + if primary_origin.get("deps_inlined", primary_origin.get("enhanced")): + deps_inlined_count += 1 - return all_errors, len(units), enhanced_count + return all_errors, len(units), deps_inlined_count if __name__ == "__main__": @@ -85,11 +87,11 @@ def validate_dataset(path): print("Usage: python validate_dataset_schema.py ") sys.exit(1) - errors, total, enhanced = validate_dataset(sys.argv[1]) + errors, total, deps_inlined = validate_dataset(sys.argv[1]) print(f"Dataset: {sys.argv[1]}") print(f"Total units: {total}") - print(f"Enhanced units: {enhanced}") + print(f"Units with 
deps inlined: {deps_inlined}") print() if errors: