diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..8eb3281 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,30 @@ +name: CI + +on: + push: + branches: ["**"] + pull_request: + +concurrency: + group: ci-${{ github.ref }} + cancel-in-progress: true + +jobs: + linux: + name: Build & Test (Linux / Swift) + runs-on: ubuntu-latest + # Official Swift toolchain image. Must be >= the package's + # swift-tools-version (6.1); bump this tag when raising the manifest. + container: swift:6.1 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Swift version + run: swift --version + + - name: Build + run: swift build --build-tests + + - name: Test + run: swift test --skip-build diff --git a/Package.swift b/Package.swift index 3bc01d3..6a05933 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version: 6.3 +// swift-tools-version: 6.1 // The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription diff --git a/README.md b/README.md new file mode 100644 index 0000000..a073aa7 --- /dev/null +++ b/README.md @@ -0,0 +1,100 @@ +# GraphRAG (Swift) + +[![CI](https://github.com/PicoMLX/GraphRAG/actions/workflows/ci.yml/badge.svg)](https://github.com/PicoMLX/GraphRAG/actions/workflows/ci.yml) + +A Swift port of the Rust crate [`graphrag-rs`](https://github.com/automataIA/graphrag-rs): +Graph-based Retrieval Augmented Generation. It builds a knowledge graph from +documents and answers natural-language questions using graph-based context +retrieval. + +This package ports the **core library** (`graphrag-core`) — the parts that make +GraphRAG work end to end — into idiomatic, Swift 6, dependency-free code. It runs +fully offline out of the box, and can optionally talk to a local +[Ollama](https://ollama.com) server for LLM-backed extraction and answer +generation. + +## Installation + +Add the package to your `Package.swift`: + +```swift +.package(url: "https://github.com/picomlx/graphrag.git", branch: "main") +``` + +and depend on the `GraphRAG` product. + +## Quick start + +```swift +import GraphRAG + +// Offline pipeline: hash embeddings + pattern-based entity extraction. +let rag = try GraphRAGBuilder() + .withChunkSize(800) + .withChunkOverlap(100) + .withTopK(5) + .build() + +await rag.addDocument(text: """ + Ada Lovelace collaborated with Charles Babbage on the Analytical Engine, + an early mechanical general-purpose computer. + """) + +try await rag.build() // chunk → extract → embed → index +let answer = try await rag.ask("Who worked on the Analytical Engine?") +print(answer.text) +print(answer.sources) // grounding chunk ids +``` + +### Using a local LLM (Ollama) + +```swift +let rag = try GraphRAGBuilder() + .withLocalDefaults() // Ollama chat + embeddings + .build() +``` + +With Ollama enabled, entity/relationship extraction uses the LLM extraction +prompt, and `ask` synthesizes a natural-language answer from the retrieved +context. Without it, extraction is pattern-based and `ask` returns an extractive +summary of the top chunks. + +## What's included + +| Area | Types | +| --- | --- | +| Core model | `Document`, `TextChunk`, `Entity`, `Relationship`, `EntityMention`, typed IDs, `GraphRAGError` | +| Abstractions | `LanguageModel`, `EmbeddingModel`, `EntityExtracting`, `ChunkingStrategy` | +| Text | `HierarchicalChunker`, `TextProcessor`, `TfIdfKeywordExtractor` | +| Graph | `KnowledgeGraph`, `PageRank`, `GraphTraversal` (BFS/DFS/ego/paths), `GraphAnalytics` (degree/closeness/betweenness/components) | +| Retrieval | `BM25Retriever`, `InMemoryVectorStore` (cosine), `HybridRetriever` (RRF / weighted / CombSUM / MaxScore fusion) | +| Extraction | `PatternEntityExtractor`, `LLMEntityExtractor`, `Prompts` | +| Embeddings | `HashEmbedder` (offline, deterministic), `OllamaEmbedder` | +| LLM | `OllamaClient` | +| Orchestration | `GraphRAG` (actor), `GraphRAGBuilder`, `Config` | + +## Design notes / port fidelity + +- **Defaults match the Rust crate**: PageRank damping `0.85` / tolerance `1e-6`, + BM25 `k1 = 1.2`, `b = 0.75`, hybrid `RRF k = 60`, semantic/keyword weights + `0.7 / 0.3`, traversal `maxDepth = 3`, min relationship strength `0.5`, etc. +- **Concurrency**: `GraphRAG` is an `actor`; backends are `Sendable` existentials + (`any EmbeddingModel`, `any LanguageModel`, `any EntityExtracting`). Builds + cleanly under Swift 6 strict concurrency. +- **Unicode safety**: the Rust chunker works on UTF-8 byte offsets guarded by + `is_char_boundary`. This port operates on `Character` (grapheme) arrays, which + are always valid boundaries; sizes and offsets are measured in characters. +- **Scope**: this is the portable core pipeline. The Rust workspace's + server/WASM/CLI crates and heavier optional subsystems (LightRAG, ROGRAG, + Leiden communities, distributed caching, persistence backends) are out of + scope for this port. + +## Testing + +```bash +swift test +``` + +The suite covers chunking, keyword extraction, BM25 ranking, cosine/vector +search, the knowledge graph, PageRank, traversal, analytics, pattern extraction, +and the end-to-end offline build/ask pipeline. diff --git a/Sources/GraphRAG/Core/Error.swift b/Sources/GraphRAG/Core/Error.swift new file mode 100644 index 0000000..a03944c --- /dev/null +++ b/Sources/GraphRAG/Core/Error.swift @@ -0,0 +1,85 @@ +// Error.swift +// Ported from graphrag-rs `core::error::GraphRAGError`. + +import Foundation + +/// The unified error type for every fallible GraphRAG operation. +/// +/// Mirrors the variants of the Rust `GraphRAGError` enum. Each case carries a +/// human-readable message (and, where relevant, structured fields) so callers +/// can pattern-match or surface a description. +public enum GraphRAGError: Error, Sendable, CustomStringConvertible { + case config(message: String) + case notInitialized + case noDocuments + case io(message: String) + case http(message: String) + case json(message: String) + case textProcessing(message: String) + case graphConstruction(message: String) + case vectorSearch(message: String) + case entityExtraction(message: String) + case retrieval(message: String) + case generation(message: String) + case functionCall(message: String) + case storage(message: String) + case embedding(message: String) + case languageModel(message: String) + case parallel(message: String) + case serialization(message: String) + case validation(message: String) + case network(message: String) + case auth(message: String) + case notFound(resource: String, id: String) + case alreadyExists(resource: String, id: String) + case timeout(operation: String, seconds: Double) + case resourceLimit(resource: String, limit: Int) + case dataCorruption(message: String) + case unsupported(operation: String, reason: String) + case rateLimit(message: String) + case conflictResolution(message: String) + case incrementalUpdate(message: String) + + public var description: String { + switch self { + case .config(let m): return "Configuration error: \(m)" + case .notInitialized: return "GraphRAG system is not initialized" + case .noDocuments: return "No documents have been added" + case .io(let m): return "I/O error: \(m)" + case .http(let m): return "HTTP error: \(m)" + case .json(let m): return "JSON error: \(m)" + case .textProcessing(let m): return "Text processing error: \(m)" + case .graphConstruction(let m): return "Graph construction error: \(m)" + case .vectorSearch(let m): return "Vector search error: \(m)" + case .entityExtraction(let m): return "Entity extraction error: \(m)" + case .retrieval(let m): return "Retrieval error: \(m)" + case .generation(let m): return "Generation error: \(m)" + case .functionCall(let m): return "Function call error: \(m)" + case .storage(let m): return "Storage error: \(m)" + case .embedding(let m): return "Embedding error: \(m)" + case .languageModel(let m): return "Language model error: \(m)" + case .parallel(let m): return "Parallel processing error: \(m)" + case .serialization(let m): return "Serialization error: \(m)" + case .validation(let m): return "Validation error: \(m)" + case .network(let m): return "Network error: \(m)" + case .auth(let m): return "Authentication error: \(m)" + case .notFound(let resource, let id): + return "\(resource) not found: \(id)" + case .alreadyExists(let resource, let id): + return "\(resource) already exists: \(id)" + case .timeout(let operation, let seconds): + return "Operation '\(operation)' timed out after \(seconds)s" + case .resourceLimit(let resource, let limit): + return "Resource limit exceeded for \(resource): \(limit)" + case .dataCorruption(let m): return "Data corruption: \(m)" + case .unsupported(let operation, let reason): + return "Unsupported operation '\(operation)': \(reason)" + case .rateLimit(let m): return "Rate limit exceeded: \(m)" + case .conflictResolution(let m): return "Conflict resolution error: \(m)" + case .incrementalUpdate(let m): return "Incremental update error: \(m)" + } + } +} + +/// Convenience matching the Rust `pub type Result = ...` alias. +public typealias GraphRAGResult = Swift.Result diff --git a/Sources/GraphRAG/Core/Identifiers.swift b/Sources/GraphRAG/Core/Identifiers.swift new file mode 100644 index 0000000..03671c0 --- /dev/null +++ b/Sources/GraphRAG/Core/Identifiers.swift @@ -0,0 +1,39 @@ +// Identifiers.swift +// Strongly-typed identifier wrappers, ported from graphrag-rs `core::DocumentId`, +// `core::EntityId` and `core::ChunkId`. + +/// Stable identifier for a `Document`. +public struct DocumentID: Hashable, Codable, Sendable, CustomStringConvertible, + ExpressibleByStringLiteral +{ + public var raw: String + + public init(_ raw: String) { self.raw = raw } + public init(stringLiteral value: String) { self.raw = value } + + public var description: String { raw } +} + +/// Stable identifier for an `Entity`. +public struct EntityID: Hashable, Codable, Sendable, CustomStringConvertible, + ExpressibleByStringLiteral +{ + public var raw: String + + public init(_ raw: String) { self.raw = raw } + public init(stringLiteral value: String) { self.raw = value } + + public var description: String { raw } +} + +/// Stable identifier for a `TextChunk`. +public struct ChunkID: Hashable, Codable, Sendable, CustomStringConvertible, + ExpressibleByStringLiteral +{ + public var raw: String + + public init(_ raw: String) { self.raw = raw } + public init(stringLiteral value: String) { self.raw = value } + + public var description: String { raw } +} diff --git a/Sources/GraphRAG/Core/Models.swift b/Sources/GraphRAG/Core/Models.swift new file mode 100644 index 0000000..dd08b44 --- /dev/null +++ b/Sources/GraphRAG/Core/Models.swift @@ -0,0 +1,160 @@ +// Models.swift +// Core domain model, ported from graphrag-rs `core::mod`. + +import Foundation + +/// Optional metadata attached to a chunk during enrichment. +/// +/// In the Rust source this is a dedicated `ChunkMetadata` struct; here it keeps +/// the most useful fields plus an open key/value bag for extensions. +public struct ChunkMetadata: Codable, Sendable, Equatable { + /// Zero-based index of the chunk within its source document. + public var index: Int + /// Approximate token / word count of the chunk content. + public var wordCount: Int + /// Keywords extracted from the chunk, if any. + public var keywords: [String] + /// Arbitrary extra fields. + public var extra: [String: String] + + public init( + index: Int = 0, + wordCount: Int = 0, + keywords: [String] = [], + extra: [String: String] = [:] + ) { + self.index = index + self.wordCount = wordCount + self.keywords = keywords + self.extra = extra + } +} + +/// A contiguous span of a document produced by the chunking stage. +public struct TextChunk: Codable, Sendable, Identifiable, Equatable { + public var id: ChunkID + public var documentID: DocumentID + public var content: String + /// Character (grapheme) offset of the chunk start within the original + /// document content — not a UTF-8 byte offset. + public var startOffset: Int + /// Character (grapheme) offset of the chunk end within the original document + /// content — not a UTF-8 byte offset. + public var endOffset: Int + /// Optional dense embedding for semantic search. + public var embedding: [Float]? + /// Entities mentioned within this chunk. + public var entities: [EntityID] + public var metadata: ChunkMetadata + + public init( + id: ChunkID, + documentID: DocumentID, + content: String, + startOffset: Int, + endOffset: Int, + embedding: [Float]? = nil, + entities: [EntityID] = [], + metadata: ChunkMetadata = ChunkMetadata() + ) { + self.id = id + self.documentID = documentID + self.content = content + self.startOffset = startOffset + self.endOffset = endOffset + self.embedding = embedding + self.entities = entities + self.metadata = metadata + } +} + +/// A source document and its derived chunks. +public struct Document: Codable, Sendable, Identifiable, Equatable { + public var id: DocumentID + public var title: String + public var content: String + public var metadata: [String: String] + public var chunks: [TextChunk] + + public init( + id: DocumentID, + title: String, + content: String, + metadata: [String: String] = [:], + chunks: [TextChunk] = [] + ) { + self.id = id + self.title = title + self.content = content + self.metadata = metadata + self.chunks = chunks + } +} + +/// A single mention (occurrence) of an entity inside a chunk. +public struct EntityMention: Codable, Sendable, Equatable { + public var chunkID: ChunkID + public var startOffset: Int + public var endOffset: Int + public var confidence: Float + + public init(chunkID: ChunkID, startOffset: Int, endOffset: Int, confidence: Float) { + self.chunkID = chunkID + self.startOffset = startOffset + self.endOffset = endOffset + self.confidence = confidence + } +} + +/// A node in the knowledge graph. +public struct Entity: Codable, Sendable, Identifiable, Equatable { + public var id: EntityID + public var name: String + public var entityType: String + public var confidence: Float + public var mentions: [EntityMention] + public var embedding: [Float]? + + public init( + id: EntityID, + name: String, + entityType: String, + confidence: Float = 1.0, + mentions: [EntityMention] = [], + embedding: [Float]? = nil + ) { + self.id = id + self.name = name + self.entityType = entityType + self.confidence = confidence + self.mentions = mentions + self.embedding = embedding + } +} + +/// A directed, typed edge between two entities. +public struct Relationship: Codable, Sendable, Equatable { + public var source: EntityID + public var target: EntityID + public var relationType: String + public var confidence: Float + /// Chunks that provide evidence for this relationship. + public var context: [ChunkID] + public var embedding: [Float]? + + public init( + source: EntityID, + target: EntityID, + relationType: String, + confidence: Float = 1.0, + context: [ChunkID] = [], + embedding: [Float]? = nil + ) { + self.source = source + self.target = target + self.relationType = relationType + self.confidence = confidence + self.context = context + self.embedding = embedding + } +} diff --git a/Sources/GraphRAG/Core/Protocols.swift b/Sources/GraphRAG/Core/Protocols.swift new file mode 100644 index 0000000..155ddd4 --- /dev/null +++ b/Sources/GraphRAG/Core/Protocols.swift @@ -0,0 +1,59 @@ +// Protocols.swift +// Pluggable abstractions, ported from graphrag-rs `core::traits`. +// +// The Rust crate exposes both synchronous and async variants of each trait. +// In Swift we model the async variants (the ones the pipeline actually uses) +// with `async` requirements and require `Sendable` so implementations can cross +// concurrency domains. + +/// A text-generation backend (the "LLM"). +public protocol LanguageModel: Sendable { + /// Complete `prompt` with default parameters. + func complete(_ prompt: String) async throws -> String + /// Complete `prompt` with explicit generation parameters. + func complete(_ prompt: String, params: GenerationParams) async throws -> String + /// Whether the backend is reachable / configured. + func isAvailable() async -> Bool + /// Static model description. + var modelInfo: ModelInfo { get } +} + +extension LanguageModel { + public func complete(_ prompt: String) async throws -> String { + try await complete(prompt, params: .default) + } +} + +/// An embedding backend that turns text into dense vectors. +public protocol EmbeddingModel: Sendable { + /// Embed a single string. + func embed(_ text: String) async throws -> [Float] + /// Embed a batch of strings (default: sequential `embed`). + func embedBatch(_ texts: [String]) async throws -> [[Float]] + /// Dimensionality of produced vectors. + var dimension: Int { get } + /// Whether the backend is ready. + func isAvailable() async -> Bool +} + +extension EmbeddingModel { + public func embedBatch(_ texts: [String]) async throws -> [[Float]] { + var out: [[Float]] = [] + out.reserveCapacity(texts.count) + for text in texts { + out.append(try await embed(text)) + } + return out + } +} + +/// A strategy that splits raw text into chunks. +public protocol ChunkingStrategy: Sendable { + /// Split `text` belonging to `documentID` into ordered chunks. + func chunk(_ text: String, documentID: DocumentID) -> [TextChunk] +} + +/// Extracts entities (and optionally relationships) from text. +public protocol EntityExtracting: Sendable { + func extract(from chunk: TextChunk) async throws -> (entities: [Entity], relationships: [Relationship]) +} diff --git a/Sources/GraphRAG/Core/Types.swift b/Sources/GraphRAG/Core/Types.swift new file mode 100644 index 0000000..1f6d737 --- /dev/null +++ b/Sources/GraphRAG/Core/Types.swift @@ -0,0 +1,113 @@ +// Types.swift +// Shared supporting value types, ported from graphrag-rs `core::traits` helpers. + +import Foundation + +/// A single hit returned by a vector store search. +public struct SearchResult: Sendable, Equatable { + public var id: String + /// Distance (lower is closer) — for cosine stores this is `1 - similarity`. + public var distance: Float + public var metadata: [String: String]? + + public init(id: String, distance: Float, metadata: [String: String]? = nil) { + self.id = id + self.distance = distance + self.metadata = metadata + } + + /// Convenience similarity score for cosine-based stores. + public var similarity: Float { 1.0 - distance } +} + +/// Knobs passed to a `LanguageModel` completion call. +public struct GenerationParams: Sendable, Equatable { + public var maxTokens: Int? + public var temperature: Float? + public var topP: Float? + public var stopSequences: [String]? + + public init( + maxTokens: Int? = nil, + temperature: Float? = nil, + topP: Float? = nil, + stopSequences: [String]? = nil + ) { + self.maxTokens = maxTokens + self.temperature = temperature + self.topP = topP + self.stopSequences = stopSequences + } + + public static let `default` = GenerationParams() +} + +/// Static description of a language model. +public struct ModelInfo: Sendable, Equatable { + public var name: String + public var version: String? + public var maxContextLength: Int? + public var supportsStreaming: Bool + + public init( + name: String, + version: String? = nil, + maxContextLength: Int? = nil, + supportsStreaming: Bool = false + ) { + self.name = name + self.version = version + self.maxContextLength = maxContextLength + self.supportsStreaming = supportsStreaming + } +} + +/// Aggregate counts describing a graph. +public struct GraphStats: Sendable, Equatable { + public var nodeCount: Int + public var edgeCount: Int + public var averageDegree: Float + public var maxDepth: Int + + public init(nodeCount: Int, edgeCount: Int, averageDegree: Float, maxDepth: Int) { + self.nodeCount = nodeCount + self.edgeCount = edgeCount + self.averageDegree = averageDegree + self.maxDepth = maxDepth + } +} + +/// Counts produced by `GraphRAG.stats()`. +public struct Stats: Sendable, Equatable { + public var documentCount: Int + public var chunkCount: Int + public var entityCount: Int + public var relationshipCount: Int + + public init( + documentCount: Int = 0, + chunkCount: Int = 0, + entityCount: Int = 0, + relationshipCount: Int = 0 + ) { + self.documentCount = documentCount + self.chunkCount = chunkCount + self.entityCount = entityCount + self.relationshipCount = relationshipCount + } +} + +/// The result of an `ask` query. +public struct Answer: Sendable, Equatable { + public var text: String + /// Confidence in `[0, 1]`, when available. + public var confidence: Float + /// Chunk identifiers used to ground the answer. + public var sources: [ChunkID] + + public init(text: String, confidence: Float = 0.0, sources: [ChunkID] = []) { + self.text = text + self.confidence = confidence + self.sources = sources + } +} diff --git a/Sources/GraphRAG/Embeddings/HashEmbedder.swift b/Sources/GraphRAG/Embeddings/HashEmbedder.swift new file mode 100644 index 0000000..ca7c9fc --- /dev/null +++ b/Sources/GraphRAG/Embeddings/HashEmbedder.swift @@ -0,0 +1,75 @@ +// HashEmbedder.swift +// Offline, deterministic embedding backend (the default in graphrag-rs when no +// neural/remote provider is configured). +// +// Each token is hashed (FNV-1a, stable across runs) into a bucket with a signed +// contribution; the accumulated vector is L2-normalized. Texts sharing tokens +// land near each other under cosine similarity, which is enough to drive the +// retrieval pipeline without any model download or network call. + +import Foundation + +public struct HashEmbedder: EmbeddingModel { + public let dimension: Int + + public init(dimension: Int = 384) { + self.dimension = max(1, dimension) + } + + public func isAvailable() async -> Bool { true } + + public func embed(_ text: String) async throws -> [Float] { + embedSync(text) + } + + /// Synchronous variant (the hashing is pure and cheap). + public func embedSync(_ text: String) -> [Float] { + var vector = [Float](repeating: 0, count: dimension) + let tokens = tokenize(text) + guard !tokens.isEmpty else { return vector } + + for token in tokens { + let hash = HashEmbedder.fnv1a(token) + let bucket = Int(hash % UInt64(dimension)) + let sign: Float = (hash & 0x1) == 0 ? 1 : -1 + vector[bucket] += sign + } + + // L2 normalize. + var norm: Float = 0 + for value in vector { norm += value * value } + norm = norm.squareRoot() + if norm > 0 { + for i in 0.. [String] { + // Split on any non-alphanumeric so "graph-based" hashes as the same two + // tokens as the query "graph based" (preserves semantic overlap). + var tokens: [String] = [] + var current = "" + for ch in text { + if ch.isLetter || ch.isNumber { + current.append(contentsOf: ch.lowercased()) + } else if !current.isEmpty { + tokens.append(current) + current = "" + } + } + if !current.isEmpty { tokens.append(current) } + return tokens + } + + /// 64-bit FNV-1a hash — stable across processes (unlike Swift's `Hasher`). + static func fnv1a(_ string: String) -> UInt64 { + var hash: UInt64 = 0xcbf2_9ce4_8422_2325 + let prime: UInt64 = 0x0000_0100_0000_01b3 + for byte in string.utf8 { + hash ^= UInt64(byte) + hash = hash &* prime + } + return hash + } +} diff --git a/Sources/GraphRAG/Embeddings/Ollama.swift b/Sources/GraphRAG/Embeddings/Ollama.swift new file mode 100644 index 0000000..124a439 --- /dev/null +++ b/Sources/GraphRAG/Embeddings/Ollama.swift @@ -0,0 +1,203 @@ +// Ollama.swift +// Ported from graphrag-rs `ollama` and `embeddings::ollama`. +// +// Talks to a local Ollama daemon over HTTP: `/api/generate` for completions and +// `/api/embeddings` for embeddings. Network calls go through URLSession. + +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// Connection + generation settings for a local Ollama server. +public struct OllamaConfig: Sendable { + public var host: String + public var port: Int + public var chatModel: String + public var embeddingModel: String + public var embeddingDimension: Int + public var temperature: Float + public var maxTokens: Int + public var timeoutSeconds: Double + public var keepAlive: String? + public var numCtx: Int? + + public init( + host: String = "http://localhost", + port: Int = 11434, + chatModel: String = "llama3.2:3b", + embeddingModel: String = "nomic-embed-text", + embeddingDimension: Int = 1024, + temperature: Float = 0.7, + maxTokens: Int = 2000, + timeoutSeconds: Double = 30, + keepAlive: String? = nil, + numCtx: Int? = nil + ) { + self.host = host + self.port = port + self.chatModel = chatModel + self.embeddingModel = embeddingModel + self.embeddingDimension = embeddingDimension + self.temperature = temperature + self.maxTokens = maxTokens + self.timeoutSeconds = timeoutSeconds + self.keepAlive = keepAlive + self.numCtx = numCtx + } + + var baseURL: String { + // Accept bare hosts ("localhost", "127.0.0.1"): without a scheme, URL + // parses the host as the scheme and the request fails. + let normalizedHost = host.contains("://") ? host : "http://\(host)" + // If the host already includes a port (e.g. "http://localhost:11434"), + // don't append another. + if let schemeRange = normalizedHost.range(of: "://"), + normalizedHost[schemeRange.upperBound...].contains(":") + { + return normalizedHost + } + return "\(normalizedHost):\(port)" + } +} + +/// Shared low-level HTTP helpers for the Ollama REST API. +enum OllamaHTTP { + /// Serialize a JSON object to `Data` (synchronous; nothing crosses an await). + static func encode(_ body: [String: Any]) throws -> Data { + do { + return try JSONSerialization.data(withJSONObject: body) + } catch { + throw GraphRAGError.serialization(message: error.localizedDescription) + } + } + + static func post( + urlString: String, jsonBody: Data, timeout: Double + ) async throws -> Data { + guard let url = URL(string: urlString) else { + throw GraphRAGError.network(message: "Invalid Ollama URL: \(urlString)") + } + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.timeoutInterval = timeout + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.httpBody = jsonBody + return try await perform(request) + } + + static func get(urlString: String, timeout: Double) async throws -> Data { + guard let url = URL(string: urlString) else { + throw GraphRAGError.network(message: "Invalid Ollama URL: \(urlString)") + } + var request = URLRequest(url: url) + request.httpMethod = "GET" + request.timeoutInterval = timeout + return try await perform(request) + } + + private static func perform(_ request: URLRequest) async throws -> Data { + // Async-native URLSession supports task cancellation, unlike the legacy + // callback API wrapped in a continuation. + let data: Data + let response: URLResponse + do { + (data, response) = try await URLSession.shared.data(for: request) + } catch { + throw GraphRAGError.network(message: error.localizedDescription) + } + if let http = response as? HTTPURLResponse, !(200..<300).contains(http.statusCode) { + throw GraphRAGError.http(message: "HTTP \(http.statusCode)") + } + return data + } +} + +/// `LanguageModel` backed by Ollama's `/api/generate`. +public struct OllamaClient: LanguageModel { + public let config: OllamaConfig + + public init(config: OllamaConfig = OllamaConfig()) { + self.config = config + } + + public var modelInfo: ModelInfo { + ModelInfo( + name: config.chatModel, maxContextLength: config.numCtx, supportsStreaming: true) + } + + public func isAvailable() async -> Bool { + do { + _ = try await OllamaHTTP.get( + urlString: "\(config.baseURL)/api/tags", timeout: config.timeoutSeconds) + return true + } catch { + return false + } + } + + public func complete(_ prompt: String, params: GenerationParams) async throws -> String { + var options: [String: Any] = [ + "temperature": Double(params.temperature ?? config.temperature), + "num_predict": params.maxTokens ?? config.maxTokens, + ] + if let topP = params.topP { options["top_p"] = Double(topP) } + if let numCtx = config.numCtx { options["num_ctx"] = numCtx } + if let stop = params.stopSequences { options["stop"] = stop } + + var body: [String: Any] = [ + "model": config.chatModel, + "prompt": prompt, + "stream": false, + "options": options, + ] + if let keepAlive = config.keepAlive { body["keep_alive"] = keepAlive } + + let jsonBody = try OllamaHTTP.encode(body) + let data = try await OllamaHTTP.post( + urlString: "\(config.baseURL)/api/generate", jsonBody: jsonBody, + timeout: config.timeoutSeconds) + struct GenerateResponse: Codable { let response: String } + do { + return try JSONDecoder().decode(GenerateResponse.self, from: data).response + } catch { + throw GraphRAGError.generation(message: "Failed to decode Ollama response") + } + } +} + +/// `EmbeddingModel` backed by Ollama's `/api/embeddings`. +public struct OllamaEmbedder: EmbeddingModel { + public let config: OllamaConfig + + public init(config: OllamaConfig = OllamaConfig()) { + self.config = config + } + + public var dimension: Int { config.embeddingDimension } + + public func isAvailable() async -> Bool { + do { + _ = try await OllamaHTTP.get( + urlString: "\(config.baseURL)/api/tags", timeout: config.timeoutSeconds) + return true + } catch { + return false + } + } + + public func embed(_ text: String) async throws -> [Float] { + let body: [String: Any] = ["model": config.embeddingModel, "prompt": text] + let jsonBody = try OllamaHTTP.encode(body) + let data = try await OllamaHTTP.post( + urlString: "\(config.baseURL)/api/embeddings", jsonBody: jsonBody, + timeout: config.timeoutSeconds) + struct EmbeddingResponse: Codable { let embedding: [Float] } + do { + return try JSONDecoder().decode(EmbeddingResponse.self, from: data).embedding + } catch { + throw GraphRAGError.embedding(message: "Failed to decode Ollama embedding") + } + } +} diff --git a/Sources/GraphRAG/Entity/LLMExtractor.swift b/Sources/GraphRAG/Entity/LLMExtractor.swift new file mode 100644 index 0000000..9c81faa --- /dev/null +++ b/Sources/GraphRAG/Entity/LLMExtractor.swift @@ -0,0 +1,284 @@ +// LLMExtractor.swift +// Ported from graphrag-rs `entity::llm_extractor`. + +import Foundation + +/// LLM-driven entity & relationship extractor. +/// +/// Builds the extraction prompt, calls a `LanguageModel`, and parses the JSON +/// response with the same staged fallbacks as the Rust version (direct decode → +/// fenced code block → first/last brace slice). +public struct LLMEntityExtractor: EntityExtracting { + public let model: Model + public var entityTypes: [String] + public var temperature: Float + public var maxTokens: Int + /// Extra gleaning passes to recover missed items (0 = single pass). + public var gleaningRounds: Int + + public init( + model: Model, + entityTypes: [String] = Prompts.defaultEntityTypes, + temperature: Float = 0.0, + maxTokens: Int = 1500, + gleaningRounds: Int = 0 + ) { + self.model = model + self.entityTypes = entityTypes + self.temperature = temperature + self.maxTokens = maxTokens + self.gleaningRounds = gleaningRounds + } + + public func extract(from chunk: TextChunk) async throws + -> (entities: [Entity], relationships: [Relationship]) + { + let typesList = entityTypes.joined(separator: ", ") + let prompt = Prompts.fill( + Prompts.entityExtraction, + ["entity_types": typesList, "input_text": chunk.content]) + let params = GenerationParams(maxTokens: maxTokens, temperature: temperature) + let response = try await model.complete(prompt, params: params) + + var output = LLMEntityExtractor.parse(response) ?? ExtractionOutput() + + // Optional gleaning passes. + var round = 0 + while round < gleaningRounds { + let prevEntities = output.entities.map { "- \($0.name) (\($0.type))" } + .joined(separator: "\n") + let prevRelationships = output.relationships + .map { "- \($0.source) -> \($0.target)" }.joined(separator: "\n") + let gleanPrompt = Prompts.fill( + Prompts.gleaningContinuation, + [ + "entity_types": typesList, + "input_text": chunk.content, + "previous_entities": prevEntities.isEmpty ? "(none)" : prevEntities, + "previous_relationships": prevRelationships.isEmpty ? "(none)" : prevRelationships, + ]) + let gleanResponse = try await model.complete(gleanPrompt, params: params) + if let extra = LLMEntityExtractor.parse(gleanResponse) { + if extra.entities.isEmpty && extra.relationships.isEmpty { break } + output.entities.append(contentsOf: extra.entities) + output.relationships.append(contentsOf: extra.relationships) + } else { + break + } + round += 1 + } + + return convert(output, chunk: chunk) + } + + // MARK: - Conversion + + private func convert(_ output: ExtractionOutput, chunk: TextChunk) + -> (entities: [Entity], relationships: [Relationship]) + { + var entities: [Entity] = [] + var indexByID: [EntityID: Int] = [:] + var idByName: [String: EntityID] = [:] + let lowerContent = chunk.content.lowercased() + + for data in output.entities { + let name = data.name.trimmingCharacters(in: .whitespacesAndNewlines) + guard !name.isEmpty else { continue } + let type = data.type.isEmpty ? "CONCEPT" : data.type.uppercased() + let id = PatternEntityExtractor.makeEntityID(type: type, name: name) + + var mentions: [EntityMention] = [] + if let range = LLMEntityExtractor.tokenBoundaryRange( + of: name.lowercased(), in: lowerContent) + { + // Derive both offsets from the matched range; case folding can + // change grapheme counts, so `start + name.count` is unreliable. + let start = lowerContent.distance(from: lowerContent.startIndex, to: range.lowerBound) + let end = lowerContent.distance(from: lowerContent.startIndex, to: range.upperBound) + mentions.append( + EntityMention( + chunkID: chunk.id, startOffset: start, + endOffset: end, confidence: 0.9)) + } + + // Deduplicate by id (e.g. a gleaning pass repeating an entity) so a + // duplicate doesn't later consume the per-chunk cap. + if let idx = indexByID[id] { + entities[idx].mentions.append(contentsOf: mentions) + } else { + indexByID[id] = entities.count + entities.append( + Entity(id: id, name: name, entityType: type, confidence: 0.9, mentions: mentions)) + } + idByName[name.lowercased()] = id + } + + var relationships: [Relationship] = [] + for data in output.relationships { + let src = data.source.lowercased().trimmingCharacters(in: .whitespacesAndNewlines) + let tgt = data.target.lowercased().trimmingCharacters(in: .whitespacesAndNewlines) + guard let sourceID = idByName[src], let targetID = idByName[tgt] else { continue } + let relType = LLMEntityExtractor.relationTypeLabel(from: data.description) + // Clamp model-provided strength to a valid confidence; out-of-range + // values would distort traversal filtering and PageRank weights. + let confidence = min(max(data.strength ?? 0.7, 0), 1) + relationships.append( + Relationship( + source: sourceID, target: targetID, relationType: relType, + confidence: confidence, context: [chunk.id])) + } + + return (entities, relationships) + } + + private static func relationTypeLabel(from description: String) -> String { + let trimmed = description.trimmingCharacters(in: .whitespacesAndNewlines) + if trimmed.isEmpty { return "RELATED_TO" } + // Use the first few words, upper-snake-cased, as a coarse relation label. + let words = trimmed.split(whereSeparator: { $0.isWhitespace }).prefix(3) + let label = words.map { word in + String(word.filter { $0.isLetter || $0.isNumber }).uppercased() + }.filter { !$0.isEmpty }.joined(separator: "_") + return label.isEmpty ? "RELATED_TO" : label + } + + /// First occurrence of `needle` in `haystack` that sits on token boundaries + /// (not embedded inside a larger word), so "Ann" won't match "Annabelle". + static func tokenBoundaryRange(of needle: String, in haystack: String) -> Range? { + guard !needle.isEmpty else { return nil } + func isWordChar(_ c: Character) -> Bool { c.isLetter || c.isNumber } + var searchStart = haystack.startIndex + while let range = haystack.range(of: needle, range: searchStart.. ExtractionOutput? { + let decoder = JSONDecoder() + + // 1. Direct decode. + if let data = response.data(using: .utf8), + let parsed = try? decoder.decode(ExtractionOutput.self, from: data) + { + return parsed + } + // 2. Fenced code block. + if let fenced = extractFencedJSON(response), + let data = fenced.data(using: .utf8), + let parsed = try? decoder.decode(ExtractionOutput.self, from: data) + { + return parsed + } + // 3. First '{' to last '}'. + if let first = response.firstIndex(of: "{"), + let last = response.lastIndex(of: "}"), first < last + { + let slice = String(response[first...last]) + if let data = slice.data(using: .utf8), + let parsed = try? decoder.decode(ExtractionOutput.self, from: data) + { + return parsed + } + } + return nil + } + + private static func extractFencedJSON(_ text: String) -> String? { + guard let fenceStart = text.range(of: "```") else { return nil } + var afterFence = text[fenceStart.upperBound...] + // Skip an optional language tag line ("json"). + if let newline = afterFence.firstIndex(of: "\n") { + let firstLine = afterFence[afterFence.startIndex.. (entities: [Entity], relationships: [Relationship]) + { + let candidates = capitalizedSpans(in: chunk.content) + + var byName: [String: Entity] = [:] + var orderedNames: [String] = [] + for candidate in candidates { + guard let (type, confidence) = classify(candidate.text) else { continue } + guard confidence >= minConfidence else { continue } + let name = candidate.text + let mention = EntityMention( + chunkID: chunk.id, + startOffset: candidate.start, + endOffset: candidate.end, + confidence: confidence) + if var existing = byName[name] { + existing.mentions.append(mention) + existing.confidence = max(existing.confidence, confidence) + byName[name] = existing + } else { + let entity = Entity( + id: PatternEntityExtractor.makeEntityID(type: type, name: name), + name: name, + entityType: type, + confidence: confidence, + mentions: [mention]) + byName[name] = entity + orderedNames.append(name) + } + } + + let entities = orderedNames.compactMap { byName[$0] } + let relationships = inferRelationships(entities: entities, chunk: chunk) + return (entities, relationships) + } + + /// Stable `"TYPE_normalized_name"` identifier. + public static func makeEntityID(type: String, name: String) -> EntityID { + let normalized = name.lowercased().map { ch -> Character in + (ch.isLetter || ch.isNumber) ? ch : "_" + } + var collapsed = "" + var lastUnderscore = false + for ch in normalized { + if ch == "_" { + if !lastUnderscore { collapsed.append(ch) } + lastUnderscore = true + } else { + collapsed.append(ch) + lastUnderscore = false + } + } + let trimmed = collapsed.trimmingCharacters(in: CharacterSet(charactersIn: "_")) + return EntityID("\(type.lowercased())_\(trimmed)") + } + + // MARK: - Span detection + + private struct Span { var text: String; var start: Int; var end: Int } + + /// Maximal runs of Title-Case words (allowing a leading title like "Dr."). + private func capitalizedSpans(in text: String) -> [Span] { + let chars = Array(text) + let n = chars.count + var spans: [Span] = [] + var i = 0 + while i < n { + if isWordStart(chars, i) && chars[i].isUppercase { + let runStart = i + var j = i + // Consume consecutive capitalized words (optionally separated by a + // single space and an optional connector like "of"/"the"). + while true { + // advance to end of current word + while j < n && !chars[j].isWhitespace { j += 1 } + // Clause punctuation (comma/semicolon/colon) ends the run so + // "Alice, Bob" stays two entities. Sentence punctuation + // (./!/?) also ends it ("Acme. Bob") — unless the word is a + // known abbreviation/title like "Dr." so "Dr. Smith" merges. + if j > runStart, let last = chars[j - 1].unicodeScalars.first { + if CharacterSet(charactersIn: ",;:").contains(last) { + // Keep "Acme, Inc." together: a comma immediately + // followed by an org suffix doesn't end the run. + if !(last == "," && nextWordIsOrgSuffix(chars, from: j)) { break } + } else if CharacterSet(charactersIn: ".!?").contains(last) { + var ws = j - 1 + while ws > runStart && !chars[ws - 1].isWhitespace { ws -= 1 } + let word = String(chars[ws.. spanStart, + let scalar = chars[spanEnd - 1].unicodeScalars.first, + trimSet.contains(scalar) + { + spanEnd -= 1 + } + if spanEnd - spanStart >= 2 { + let cleaned = String(chars[spanStart.. Bool { + if i == 0 { return true } + let prev = chars[i - 1] + // Start a run after whitespace or opening punctuation, so quoted or + // parenthesized names (e.g. "Ada Lovelace" or (Paris)) aren't skipped. + return prev.isWhitespace || "\"'([{".contains(prev) + } + + /// Whether the next word starting at/after `from` is an organization suffix + /// (e.g. "Inc"/"Inc."), used to keep "Acme, Inc." as one span. + private func nextWordIsOrgSuffix(_ chars: [Character], from: Int) -> Bool { + var t = from + while t < chars.count && (chars[t] == " " || chars[t] == "\t") { t += 1 } + var e = t + while e < chars.count && chars[e].isLetter { e += 1 } + guard e > t else { return false } + let word = String(chars[t.. (type: String, confidence: Float)? { + var words = text.split(separator: " ").map(String.init) + guard !words.isEmpty else { return nil } + // Drop a leading article so "The United States" / "The University of + // California" classify as location/organization instead of person. + if words.count > 1, ["the", "a", "an"].contains(words[0].lowercased()) { + words.removeFirst() + } + guard !words.isEmpty else { return nil } + + // Blocklist single sentence-initial words that are common/structural. + if words.count == 1, PatternEntityExtractor.blocklist.contains(words[0].lowercased()) { + return nil + } + + // Organizations by suffix. + if let last = words.last, + PatternEntityExtractor.orgSuffixes.contains(last.lowercased()) + { + return ("ORGANIZATION", 0.9) + } + // Organizations by prefix ("University of ...", etc.). + let lower = words.joined(separator: " ").lowercased() + for prefix in PatternEntityExtractor.orgPrefixes where lower.hasPrefix(prefix) { + return ("ORGANIZATION", 0.9) + } + // Known locations. + if PatternEntityExtractor.knownLocations.contains(lower) { + return ("LOCATION", 0.9) + } + // Titled persons. + if let first = words.first, + PatternEntityExtractor.personTitles.contains( + first.trimmingCharacters(in: CharacterSet(charactersIn: ".")).lowercased()) + { + return ("PERSON", 0.9) + } + // Multi-word Title Case -> likely a person/proper noun. + if words.count >= 2 { + return ("PERSON", 0.8) + } + // Single capitalized word -> generic concept. + if words[0].count >= 3 { + return ("CONCEPT", 0.6) + } + return nil + } + + // MARK: - Relationship inference + + private func inferRelationships(entities: [Entity], chunk: TextChunk) -> [Relationship] { + guard entities.count >= 2 else { return [] } + let chars = Array(chunk.content) + + // Assign a sentence id to every character offset (incremented after + // ./!/?), so relationships are only inferred between entities that + // co-occur in the SAME sentence — otherwise one "works for" phrase would + // wrongly link every person/org pair sharing a chunk. A period that ends + // a person title ("Dr.") is an abbreviation, not a sentence boundary, so + // "Dr. Smith works for Acme Inc." stays one sentence. + func periodIsSentenceEnd(_ periodIndex: Int) -> Bool { + var s = periodIndex + while s > 0 && chars[s - 1].isLetter { s -= 1 } + let word = String(chars[s..= chars.count || chars[t].isNewline { return true } + return chars[t].isUppercase + } + return true + } + var sentenceID = [Int](repeating: 0, count: chars.count + 1) + var sid = 0 + for k in 0.. Int { sentenceID[max(0, min(offset, chars.count))] } + + var relationships: [Relationship] = [] + var seen: Set = [] + + for i in 0.. = [ + "WORKS_FOR", "LEADS", "BORN_IN", "LOCATED_IN", "HEADQUARTERED_IN", + "MARRIED_TO", "COLLEAGUE_OF", + ] + + /// Whether an entity other than `a`/`b` has a mention after `offset` within + /// the same sentence (used to decide if a trailing cue window is safe). + private func entityMentionFollows( + after offset: Int, sentence: (Int) -> Int, entities: [Entity], a: Entity, b: Entity + ) -> Bool { + let s = sentence(offset) + for c in entities where c.id != a.id && c.id != b.id { + for m in c.mentions where m.startOffset > offset && sentence(m.startOffset) == s { + return true + } + } + return false + } + + /// Whether an entity (other than `a`/`b`) of the same type as one endpoint + /// has a mention strictly between offsets `lo` and `hi`. + private func hasInterveningSameType( + _ a: Entity, _ b: Entity, lo: Int, hi: Int, among entities: [Entity] + ) -> Bool { + for c in entities where c.id != a.id && c.id != b.id { + guard c.entityType == a.entityType || c.entityType == b.entityType else { continue } + for m in c.mentions where m.startOffset > lo && m.startOffset < hi { + return true + } + } + return false + } + + /// Order (source, target) for a typed relation by the entities' roles. + /// Symmetric relations keep their text order. + private func orient(_ relType: String, _ a: Entity, _ b: Entity) -> (Entity, Entity) { + func pick(_ source: String, _ target: String) -> (Entity, Entity)? { + if a.entityType == source && b.entityType == target { return (a, b) } + if b.entityType == source && a.entityType == target { return (b, a) } + return nil + } + switch relType { + case "WORKS_FOR", "LEADS": + return pick("PERSON", "ORGANIZATION") ?? (a, b) + case "BORN_IN": + return pick("PERSON", "LOCATION") ?? (a, b) + case "HEADQUARTERED_IN": + return pick("ORGANIZATION", "LOCATION") ?? (a, b) + case "LOCATED_IN": + // Whichever endpoint is the location is the target. + if a.entityType == "LOCATION" { return (b, a) } + if b.entityType == "LOCATION" { return (a, b) } + return (a, b) + default: + return (a, b) // ASSOCIATED_WITH / KNOWS / MARRIED_TO / RELATED_TO ... + } + } + + private func relationType(for a: String, _ b: String, context: String) -> String { + func has(_ s: String) -> Bool { context.contains(s) } + switch (a, b) { + case ("PERSON", "ORGANIZATION"), ("ORGANIZATION", "PERSON"): + if has("works for") || has("employed by") { return "WORKS_FOR" } + if has("founded") || has("ceo") { return "LEADS" } + return "ASSOCIATED_WITH" + case ("PERSON", "LOCATION"), ("LOCATION", "PERSON"): + if has("born in") || has(" from ") { return "BORN_IN" } + if has("lives in") || has("based in") { return "LOCATED_IN" } + return "ASSOCIATED_WITH" + case ("ORGANIZATION", "LOCATION"), ("LOCATION", "ORGANIZATION"): + if has("headquartered") || has("based in") { return "HEADQUARTERED_IN" } + return "LOCATED_IN" + case ("PERSON", "PERSON"): + if has("married") || has("spouse") { return "MARRIED_TO" } + if has("colleague") || has("partner") { return "COLLEAGUE_OF" } + return "KNOWS" + default: + return "RELATED_TO" + } + } + + // MARK: - Lexicons + + static let orgSuffixes: Set = [ + "inc", "inc.", "corp", "corp.", "llc", "ltd", "ltd.", "company", + "corporation", "group", "solutions", "technologies", + ] + static let orgPrefixes: [String] = ["university of", "institute of", "department of"] + static let knownLocations: Set = [ + "united states", "new york", "california", "london", "paris", "tokyo", + "berlin", "washington", "boston", "chicago", + ] + static let personTitles: Set = ["dr", "prof", "mr", "mrs", "ms"] + /// Words whose trailing period is an abbreviation rather than a sentence end, + /// used only for sentence segmentation in relationship inference (so + /// "Acme Inc. was founded by Sam Altman" stays one sentence). Entity-span + /// splitting still uses the narrower `personTitles`. + static let sentenceAbbreviations: Set = [ + "dr", "prof", "mr", "mrs", "ms", "jr", "sr", "st", + "inc", "corp", "ltd", "llc", "co", "etc", "vs", + ] + static let blocklist: Set = [ + "the", "and", "but", "or", "chapter", "section", "however", "therefore", + "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday", + "january", "february", "march", "april", "may", "june", "july", "august", + "september", "october", "november", "december", "this", "that", "these", + "those", "there", "here", "when", "where", "what", "who", "why", "how", + ] +} diff --git a/Sources/GraphRAG/Entity/Prompts.swift b/Sources/GraphRAG/Entity/Prompts.swift new file mode 100644 index 0000000..1aa1fbb --- /dev/null +++ b/Sources/GraphRAG/Entity/Prompts.swift @@ -0,0 +1,148 @@ +// Prompts.swift +// Ported from graphrag-rs `entity::prompts` and the answer-generation template +// in `graphrag::ask`. Templates use `{placeholder}` markers filled by callers. + +import Foundation + +public enum Prompts { + /// Default entity types requested from the LLM. + public static let defaultEntityTypes: [String] = [ + "PERSON", "ORGANIZATION", "LOCATION", "EVENT", "CONCEPT", "OBJECT", + ] + + /// Single-pass entity + relationship extraction prompt. + public static let entityExtraction = """ + -Goal- + Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. + + -Steps- + 1. Identify all entities. For each identified entity, extract the following information: + - entity_name: Name of the entity, capitalized + - entity_type: One of the following types: [{entity_types}] + - entity_description: Comprehensive description of the entity's attributes and activities + + 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. + For each pair of related entities, extract the following information: + - source_entity: name of the source entity, as identified in step 1 + - target_entity: name of the target entity, as identified in step 1 + - relationship_description: explanation as to why you think the source entity and the target entity are related to each other + - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity + + 3. Return output in JSON format with the following structure: + { + "entities": [ + { "name": "entity name", "type": "entity type", "description": "entity description" } + ], + "relationships": [ + { "source": "source entity name", "target": "target entity name", "description": "relationship description", "strength": 0.8 } + ] + } + + -Real Data- + ###################### + Entity Types: {entity_types} + Text: {input_text} + ###################### + Output: + """ + + /// Gleaning continuation prompt to catch entities/relationships missed on + /// the first pass. + public static let gleaningContinuation = """ + -Goal- + You previously extracted entities and relationships from a text document. Review your previous extraction and the original text to identify any additional entities or relationships you may have missed in the first pass. + + -Steps- + 1. Review the entities you previously identified: + {previous_entities} + + 2. Review the relationships you previously identified: + {previous_relationships} + + 3. Carefully review the original text again and identify any entities or relationships you may have missed. + + 4. Return ONLY the NEW entities and relationships you discovered in this pass, using the same JSON format: + { + "entities": [ + { "name": "entity name", "type": "entity type", "description": "entity description" } + ], + "relationships": [ + { "source": "source entity name", "target": "target entity name", "description": "relationship description", "strength": 0.8 } + ] + } + + If you found no additional entities or relationships, return empty arrays. + + -Real Data- + ###################### + Entity Types: {entity_types} + Text: {input_text} + ###################### + Output: + """ + + /// Completion-check prompt; the model answers only YES or NO. + public static let completionCheck = """ + Based on the text below and the entities/relationships already extracted, are there any significant entities or relationships that have been missed? + + Text: + {input_text} + + Current Entities ({entity_count}): + {entities_summary} + + Current Relationships ({relationship_count}): + {relationships_summary} + + Respond with ONLY "YES" if the extraction is complete and thorough, or "NO" if there are still significant entities or relationships missing. + + Answer (YES or NO): + """ + + /// Answer-generation prompt used by `GraphRAG.ask`. + public static let answerGeneration = """ + You are a knowledgeable assistant specialized in answering questions based on a knowledge graph. + + IMPORTANT INSTRUCTIONS: + - Answer ONLY using information from the provided context below + - Synthesize information from ALL context sections to give a comprehensive answer + - Provide direct, conversational, and natural responses + - Do NOT show your reasoning process or use tags + - If the context lacks sufficient information, clearly state: "I don't have enough information to answer this question." + - Aim for a complete answer (3-6 sentences) that covers different aspects found across the context + - Use a natural, helpful tone as if speaking to a person + + CONTEXT: + {context} + + QUESTION: {query} + + ANSWER (direct response only, no reasoning): + """ + + /// Fill `{key}` placeholders in `template` with `values`. + /// + /// Scans the template once: a `{key}` is replaced only when `key` is a known + /// value, and inserted values are never re-scanned. This keeps literal braces + /// in the template (e.g. JSON examples) intact and prevents a value that + /// itself contains `{query}`-style text from being substituted further. + public static func fill(_ template: String, _ values: [String: String]) -> String { + var result = "" + var i = template.startIndex + while i < template.endIndex { + if template[i] == "{", + let close = template[template.index(after: i)...].firstIndex(of: "}") + { + let key = String(template[template.index(after: i)..] = [:] + for id in nodeList { adj[id] = [] } + for rel in graph.relationships { + // Only connect endpoints that are actual graph nodes. + guard nodeSet.contains(rel.source), nodeSet.contains(rel.target) else { continue } + adj[rel.source, default: []].insert(rel.target) + adj[rel.target, default: []].insert(rel.source) + } + self.adjacency = adj.mapValues { Array($0) } + } + + private func neighbors(_ id: EntityID) -> [EntityID] { adjacency[id] ?? [] } + + // MARK: - Degree + + /// Degree centrality: `degree / (n - 1)`, in `[0, 1]`. + public func degreeCentrality(_ id: EntityID) -> Float { + let n = nodes.count + guard n > 1 else { return 0 } + return Float(neighbors(id).count) / Float(n - 1) + } + + // MARK: - Closeness + + /// Closeness centrality: reachable node count divided by total distance. + public func closenessCentrality(_ id: EntityID) -> Float { + let distances = bfsDistances(from: id) + var total = 0 + var reachable = 0 + for (node, dist) in distances where node != id { + total += dist + reachable += 1 + } + guard total > 0 else { return 0 } + return Float(reachable) / Float(total) + } + + // MARK: - Betweenness (Brandes, unweighted) + + /// Normalized betweenness centrality for every node, via Brandes' algorithm. + public func betweennessCentrality() -> [EntityID: Float] { + var betweenness: [EntityID: Double] = [:] + for id in nodes { betweenness[id] = 0 } + let n = nodes.count + guard n > 2 else { return betweenness.mapValues { Float($0) } } + + for source in nodes { + var stack: [EntityID] = [] + var predecessors: [EntityID: [EntityID]] = [:] + var sigma: [EntityID: Double] = [:] + var dist: [EntityID: Int] = [:] + for id in nodes { sigma[id] = 0; dist[id] = -1; predecessors[id] = [] } + sigma[source] = 1 + dist[source] = 0 + + var queue: [EntityID] = [source] + var head = 0 + while head < queue.count { + let v = queue[head]; head += 1 + stack.append(v) + for w in neighbors(v) { + if dist[w]! < 0 { + dist[w] = dist[v]! + 1 + queue.append(w) + } + if dist[w]! == dist[v]! + 1 { + sigma[w]! += sigma[v]! + predecessors[w]!.append(v) + } + } + } + + var delta: [EntityID: Double] = [:] + for id in nodes { delta[id] = 0 } + while let w = stack.popLast() { + for v in predecessors[w]! { + delta[v]! += (sigma[v]! / sigma[w]!) * (1 + delta[w]!) + } + if w != source { betweenness[w]! += delta[w]! } + } + } + + // Undirected: each pair counted twice; normalize to [0, 1]. + let norm = Double((n - 1) * (n - 2)) + var result: [EntityID: Float] = [:] + for (id, value) in betweenness { + result[id] = norm > 0 ? Float(value / norm) : 0 + } + return result + } + + /// Combined centrality scores for a single node. + public func centrality(_ id: EntityID) -> CentralityScores { + CentralityScores( + degree: degreeCentrality(id), + betweenness: betweennessCentrality()[id] ?? 0, + closeness: closenessCentrality(id) + ) + } + + // MARK: - Components + + /// The connected component (undirected) containing `start`. + public func connectedComponent(containing start: EntityID) -> [EntityID] { + guard graph.contains(start) else { return [] } + var visited: Set = [start] + var queue: [EntityID] = [start] + var component: [EntityID] = [] + var head = 0 + while head < queue.count { + let current = queue[head]; head += 1 + component.append(current) + for neighbor in neighbors(current) where !visited.contains(neighbor) { + visited.insert(neighbor) + queue.append(neighbor) + } + } + return component + } + + /// All connected components of the graph. + public func connectedComponents() -> [[EntityID]] { + var visited: Set = [] + var components: [[EntityID]] = [] + for node in nodes where !visited.contains(node) { + let component = connectedComponent(containing: node) + for c in component { visited.insert(c) } + components.append(component) + } + return components + } + + // MARK: - Global + + /// Graph density: `2E / (n(n-1))`, where `E` is the number of unique + /// undirected pairs. Counting unique pairs (rather than raw stored edges) + /// keeps density in `[0, 1]` even with reciprocal or multiple typed edges + /// between the same two entities. + public func density() -> Float { + let n = nodes.count + guard n > 1 else { return 0 } + let uniqueEdges = adjacency.values.reduce(0) { $0 + $1.count } / 2 + return Float(2 * uniqueEdges) / Float(n * (n - 1)) + } + + /// Local clustering coefficient: fraction of a node's neighbour pairs that + /// are themselves connected. + public func clusteringCoefficient(_ id: EntityID) -> Float { + let ns = neighbors(id) + let k = ns.count + guard k > 1 else { return 0 } + var links = 0 + for i in 0.. 0 ? Float(links) / Float(possible) : 0 + } + + private func bfsDistances(from source: EntityID) -> [EntityID: Int] { + var dist: [EntityID: Int] = [source: 0] + var queue: [EntityID] = [source] + var head = 0 + while head < queue.count { + let current = queue[head]; head += 1 + let d = dist[current]! + for neighbor in neighbors(current) where dist[neighbor] == nil { + dist[neighbor] = d + 1 + queue.append(neighbor) + } + } + return dist + } +} diff --git a/Sources/GraphRAG/Graph/KnowledgeGraph.swift b/Sources/GraphRAG/Graph/KnowledgeGraph.swift new file mode 100644 index 0000000..5ab1d83 --- /dev/null +++ b/Sources/GraphRAG/Graph/KnowledgeGraph.swift @@ -0,0 +1,365 @@ +// KnowledgeGraph.swift +// Ported from graphrag-rs `core::KnowledgeGraph`. +// +// The Rust version is backed by petgraph plus side indexes. This port uses a +// value-type adjacency representation: entities/relationships are stored in +// insertion order with `[ID: Int]` indexes for O(1) lookup, mirroring the +// `entity_index` HashMap and IndexMap behaviour. + +import Foundation + +public struct KnowledgeGraph: Sendable, Codable { + // Entities, in insertion order. + private var entitiesByID: [EntityID: Entity] + private var entityOrder: [EntityID] + + // Relationships, in insertion order, with adjacency indexes into the array. + public private(set) var relationships: [Relationship] + private var outgoing: [EntityID: [Int]] + private var incoming: [EntityID: [Int]] + + // Documents and chunks, in insertion order. + private var documentsByID: [DocumentID: Document] + private var documentOrder: [DocumentID] + private var chunksByID: [ChunkID: TextChunk] + private var chunkOrder: [ChunkID] + + public init() { + entitiesByID = [:] + entityOrder = [] + relationships = [] + outgoing = [:] + incoming = [:] + documentsByID = [:] + documentOrder = [] + chunksByID = [:] + chunkOrder = [] + } + + // MARK: - Mutation + + /// Insert an entity. If one with the same id already exists, mentions are + /// merged and the higher confidence / any available embedding is kept. + public mutating func addEntity(_ entity: Entity) { + if var existing = entitiesByID[entity.id] { + existing.mentions.append(contentsOf: entity.mentions) + existing.confidence = max(existing.confidence, entity.confidence) + if existing.embedding == nil { existing.embedding = entity.embedding } + if existing.entityType.isEmpty { existing.entityType = entity.entityType } + entitiesByID[entity.id] = existing + } else { + entitiesByID[entity.id] = entity + entityOrder.append(entity.id) + } + } + + /// Insert a directed relationship. Duplicate (source, target, type) edges are + /// merged: their evidence context is unioned and the max confidence kept. + public mutating func addRelationship(_ relationship: Relationship) { + // Ignore dangling edges: both endpoints must be nodes, otherwise + // `neighbors(of:)`/traversals could surface an EntityID with no node. + guard entitiesByID[relationship.source] != nil, + entitiesByID[relationship.target] != nil + else { return } + // Merge duplicates. + if let existingIndices = outgoing[relationship.source] { + for idx in existingIndices + where relationships[idx].target == relationship.target + && relationships[idx].relationType == relationship.relationType + { + relationships[idx].confidence = max( + relationships[idx].confidence, relationship.confidence) + for ctx in relationship.context where !relationships[idx].context.contains(ctx) { + relationships[idx].context.append(ctx) + } + return + } + } + let index = relationships.count + relationships.append(relationship) + outgoing[relationship.source, default: []].append(index) + incoming[relationship.target, default: []].append(index) + } + + public mutating func addDocument(_ document: Document) { + if documentsByID[document.id] == nil { + documentOrder.append(document.id) + } else { + // Replacing an existing id: purge the previous version's chunks so + // direct KnowledgeGraph callers don't retain stale chunk text. + removeChunks(forDocument: document.id) + } + documentsByID[document.id] = document + } + + public mutating func addChunk(_ chunk: TextChunk) { + if chunksByID[chunk.id] == nil { chunkOrder.append(chunk.id) } + chunksByID[chunk.id] = chunk + // Keep the copy embedded in its document in sync, so + // `document(id)?.chunks` and saved JSON reflect enrichment too. + if var doc = documentsByID[chunk.documentID] { + if let idx = doc.chunks.firstIndex(where: { $0.id == chunk.id }) { + doc.chunks[idx] = chunk + } else { + doc.chunks.append(chunk) + } + documentsByID[chunk.documentID] = doc + } + } + + /// Remove all chunks belonging to a document (used when a document is + /// replaced so stale chunks don't survive). + public mutating func removeChunks(forDocument documentID: DocumentID) { + let removed = Set(chunkOrder.filter { chunksByID[$0]?.documentID == documentID }) + guard !removed.isEmpty else { return } + chunkOrder.removeAll { removed.contains($0) } + for id in removed { chunksByID.removeValue(forKey: id) } + if var doc = documentsByID[documentID] { + doc.chunks.removeAll { removed.contains($0.id) } + documentsByID[documentID] = doc + } + // Scrub evidence pointing at the removed chunks. Drop mentions that + // reference them; if an entity's mentions are entirely exhausted (no + // independent evidence remains), remove the entity too so it doesn't + // linger in stats/traversal/JSON from a document version that's gone. + var removedEntities: Set = [] + for eid in entityOrder { + guard var entity = entitiesByID[eid], !entity.mentions.isEmpty else { continue } + let kept = entity.mentions.filter { !removed.contains($0.chunkID) } + if kept.isEmpty { + removedEntities.insert(eid) + } else if kept.count != entity.mentions.count { + entity.mentions = kept + entitiesByID[eid] = entity + } + } + if !removedEntities.isEmpty { + for eid in removedEntities { entitiesByID.removeValue(forKey: eid) } + entityOrder.removeAll { removedEntities.contains($0) } + } + // Scrub relationship context; drop a relationship whose only evidence was + // a removed chunk, or whose endpoint entity was just removed (a leftover + // would expose a stale fact). Edges that never had context are kept. + var survived: [Relationship] = [] + survived.reserveCapacity(relationships.count) + for var rel in relationships { + if removedEntities.contains(rel.source) || removedEntities.contains(rel.target) { + continue + } + let hadContext = !rel.context.isEmpty + rel.context.removeAll { removed.contains($0) } + if hadContext && rel.context.isEmpty { continue } + survived.append(rel) + } + relationships = survived + rebuildAdjacency() + } + + /// Rebuild the outgoing/incoming index after the `relationships` array changes. + private mutating func rebuildAdjacency() { + outgoing.removeAll() + incoming.removeAll() + for (index, rel) in relationships.enumerated() { + outgoing[rel.source, default: []].append(index) + incoming[rel.target, default: []].append(index) + } + } + + /// Drop all entities and relationships, preserving documents and chunks. + /// Chunk entity references are cleared too (in both `chunksByID` and the + /// document copies) so no chunk points at an entity id that no longer exists. + public mutating func clearEntitiesAndRelationships() { + entitiesByID.removeAll() + entityOrder.removeAll() + relationships.removeAll() + outgoing.removeAll() + incoming.removeAll() + for id in chunkOrder where !(chunksByID[id]?.entities.isEmpty ?? true) { + chunksByID[id]?.entities = [] + } + for did in documentOrder { + guard var doc = documentsByID[did] else { continue } + for i in doc.chunks.indices where !doc.chunks[i].entities.isEmpty { + doc.chunks[i].entities = [] + } + documentsByID[did] = doc + } + } + + // MARK: - Lookup + + public func entity(_ id: EntityID) -> Entity? { entitiesByID[id] } + public func document(_ id: DocumentID) -> Document? { documentsByID[id] } + public func chunk(_ id: ChunkID) -> TextChunk? { chunksByID[id] } + public func contains(_ id: EntityID) -> Bool { entitiesByID[id] != nil } + + public var entities: [Entity] { entityOrder.compactMap { entitiesByID[$0] } } + public var documents: [Document] { documentOrder.compactMap { documentsByID[$0] } } + public var chunks: [TextChunk] { chunkOrder.compactMap { chunksByID[$0] } } + + public var entityCount: Int { entitiesByID.count } + public var relationshipCount: Int { relationships.count } + public var documentCount: Int { documentsByID.count } + public var chunkCount: Int { chunksByID.count } + + /// Bidirectional neighbors: for every incident edge, the other endpoint and + /// the relationship. Deduplicated per (neighbor, relationType). + public func neighbors(of id: EntityID) -> [(neighbor: EntityID, relationship: Relationship)] { + // Keep the highest-confidence edge per (neighbor, relationType) so a weak + // A->B can't hide a stronger reciprocal B->A from strength-filtered + // traversals. + var bestIndexByKey: [String: Int] = [:] + var order: [String] = [] + func consider(_ index: Int, neighbor: EntityID) { + let key = "\(neighbor.raw)|\(relationships[index].relationType)" + if let existing = bestIndexByKey[key] { + if relationships[index].confidence > relationships[existing].confidence { + bestIndexByKey[key] = index + } + } else { + bestIndexByKey[key] = index + order.append(key) + } + } + for idx in outgoing[id] ?? [] { consider(idx, neighbor: relationships[idx].target) } + for idx in incoming[id] ?? [] { consider(idx, neighbor: relationships[idx].source) } + return order.map { key in + let rel = relationships[bestIndexByKey[key]!] + let neighbor = rel.source == id ? rel.target : rel.source + return (neighbor, rel) + } + } + + /// All relationships where `id` is the source or target. + public func entityRelationships(_ id: EntityID) -> [Relationship] { + var out: [Relationship] = [] + for idx in outgoing[id] ?? [] { out.append(relationships[idx]) } + for idx in incoming[id] ?? [] { out.append(relationships[idx]) } + return out + } + + public func outDegree(_ id: EntityID) -> Int { (outgoing[id] ?? []).count } + public func inDegree(_ id: EntityID) -> Int { (incoming[id] ?? []).count } + public func degree(_ id: EntityID) -> Int { outDegree(id) + inDegree(id) } + + /// Case-insensitive substring match against entity names. + public func findEntitiesByName(_ name: String) -> [Entity] { + let needle = name.lowercased() + return entities.filter { $0.name.lowercased().contains(needle) } + } + + /// Shortest path (by hop count) between two entities via BFS, inclusive of + /// endpoints, or nil if unreachable within `maxDepth`. + public func findRelationshipPath( + from source: EntityID, to target: EntityID, maxDepth: Int = 5 + ) -> [EntityID]? { + // Endpoints must exist; otherwise even the self-path is meaningless. + guard contains(source), contains(target) else { return nil } + if source == target { return [source] } + var visited: Set = [source] + var queue: [(EntityID, [EntityID])] = [(source, [source])] + while !queue.isEmpty { + let (current, path) = queue.removeFirst() + if path.count > maxDepth { continue } + for (neighbor, _) in neighbors(of: current) where !visited.contains(neighbor) { + let newPath = path + [neighbor] + if neighbor == target { return newPath } + visited.insert(neighbor) + queue.append((neighbor, newPath)) + } + } + return nil + } + + public func stats() -> GraphStats { + let n = entityCount + let avgDegree = n > 0 ? Float(2 * relationshipCount) / Float(n) : 0 + return GraphStats( + nodeCount: n, + edgeCount: relationshipCount, + averageDegree: avgDegree, + maxDepth: diameter() + ) + } + + /// Longest shortest-path (in hops) over the undirected graph — i.e. the + /// graph's diameter. O(V·(V+E)); intended for occasional stats calls. + private func diameter() -> Int { + guard entityOrder.count > 1 else { return 0 } + var adjacency: [EntityID: [EntityID]] = [:] + for rel in relationships { + adjacency[rel.source, default: []].append(rel.target) + adjacency[rel.target, default: []].append(rel.source) + } + var maxDist = 0 + for start in entityOrder { + var dist: [EntityID: Int] = [start: 0] + var queue: [EntityID] = [start] + var head = 0 + while head < queue.count { + let current = queue[head] + head += 1 + let d = dist[current]! + if d > maxDist { maxDist = d } + for neighbor in adjacency[current] ?? [] where dist[neighbor] == nil { + dist[neighbor] = d + 1 + queue.append(neighbor) + } + } + } + return maxDist + } + + // MARK: - Codable + + private enum CodingKeys: String, CodingKey { + case entities, relationships, documents, chunks + } + + public init(from decoder: Decoder) throws { + self.init() + let container = try decoder.container(keyedBy: CodingKeys.self) + let decodedEntities = try container.decode([Entity].self, forKey: .entities) + let decodedDocuments = try container.decode([Document].self, forKey: .documents) + let decodedChunks = try container.decode([TextChunk].self, forKey: .chunks) + let decodedRelationships = try container.decode([Relationship].self, forKey: .relationships) + for e in decodedEntities { addEntity(e) } + for d in decodedDocuments { addDocument(d) } + for c in decodedChunks { addChunk(c) } + for r in decodedRelationships { addRelationship(r) } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(entities, forKey: .entities) + try container.encode(relationships, forKey: .relationships) + try container.encode(documents, forKey: .documents) + try container.encode(chunks, forKey: .chunks) + } + + /// Serialize the graph to a JSON file. + public func save(toJSON path: String) throws { + let encoder = JSONEncoder() + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + do { + let data = try encoder.encode(self) + try data.write(to: URL(fileURLWithPath: path)) + } catch let error as GraphRAGError { + throw error + } catch { + throw GraphRAGError.io(message: error.localizedDescription) + } + } + + /// Load a graph from a JSON file. + public static func load(fromJSON path: String) throws -> KnowledgeGraph { + do { + let data = try Data(contentsOf: URL(fileURLWithPath: path)) + return try JSONDecoder().decode(KnowledgeGraph.self, from: data) + } catch let error as GraphRAGError { + throw error + } catch { + throw GraphRAGError.io(message: error.localizedDescription) + } + } +} diff --git a/Sources/GraphRAG/Graph/PageRank.swift b/Sources/GraphRAG/Graph/PageRank.swift new file mode 100644 index 0000000..a6f8284 --- /dev/null +++ b/Sources/GraphRAG/Graph/PageRank.swift @@ -0,0 +1,95 @@ +// PageRank.swift +// Ported from graphrag-rs `graph::pagerank`. + +import Foundation + +/// Weighted PageRank over the knowledge graph's directed relationships. +public struct PageRank: Sendable { + /// Probability of following a link vs. teleporting (default 0.85). + public var dampingFactor: Double + /// Maximum power iterations (default 100). + public var maxIterations: Int + /// L-infinity convergence threshold (default 1e-6). + public var tolerance: Double + + public init(dampingFactor: Double = 0.85, maxIterations: Int = 100, tolerance: Double = 1e-6) { + self.dampingFactor = dampingFactor + self.maxIterations = maxIterations + self.tolerance = tolerance + } + + /// Compute a PageRank score in `[0, 1]` for each entity. Scores sum to 1. + public func compute(_ graph: KnowledgeGraph) -> [EntityID: Double] { + let nodes = graph.entities.map(\.id) + let n = nodes.count + guard n > 0 else { return [:] } + if n == 1 { return [nodes[0]: 1.0] } + + var indexOf: [EntityID: Int] = [:] + for (i, id) in nodes.enumerated() { indexOf[id] = i } + + // Incoming contributions: for each target i, list of (source j, weight). + var incomingEdges: [[(source: Int, weight: Double)]] = Array(repeating: [], count: n) + var outWeight = [Double](repeating: 0, count: n) + for rel in graph.relationships { + guard let s = indexOf[rel.source], let t = indexOf[rel.target] else { continue } + // Skip non-positive-confidence edges entirely; otherwise a single + // zero-confidence edge would receive all of a node's PageRank mass. + // Nodes left with no positive out-edge are handled as dangling. + let w = Double(rel.confidence) + guard w > 0 else { continue } + incomingEdges[t].append((s, w)) + outWeight[s] += w + } + + // Clamp to a valid probability so a misconfigured factor can't produce a + // negative teleport term (and negative scores). + let d = min(max(dampingFactor, 0), 1) + let teleport = (1.0 - d) / Double(n) + var scores = [Double](repeating: 1.0 / Double(n), count: n) + + for _ in 0.. 0 { + for i in 0.. [(id: EntityID, score: Double)] { + guard k > 0 else { return [] } + let scores = compute(graph) + return scores.sorted { lhs, rhs in + if lhs.value == rhs.value { return lhs.key.raw < rhs.key.raw } + return lhs.value > rhs.value + } + .prefix(k) + .map { (id: $0.key, score: $0.value) } + } +} diff --git a/Sources/GraphRAG/Graph/Traversal.swift b/Sources/GraphRAG/Graph/Traversal.swift new file mode 100644 index 0000000..2712fff --- /dev/null +++ b/Sources/GraphRAG/Graph/Traversal.swift @@ -0,0 +1,208 @@ +// Traversal.swift +// Ported from graphrag-rs `graph::traversal`. + +import Foundation + +/// Tunables that govern graph traversal. +public struct TraversalConfig: Sendable { + public var maxDepth: Int + public var maxPaths: Int + public var useEdgeWeights: Bool + public var minRelationshipStrength: Float + + public init( + maxDepth: Int = 3, + maxPaths: Int = 100, + useEdgeWeights: Bool = true, + minRelationshipStrength: Float = 0.5 + ) { + self.maxDepth = maxDepth + self.maxPaths = maxPaths + self.useEdgeWeights = useEdgeWeights + self.minRelationshipStrength = minRelationshipStrength + } +} + +/// The product of a traversal: discovered entities, the edges walked, and the +/// depth/distance of each entity from the source(s). +public struct TraversalResult: Sendable { + public var entities: [EntityID] + public var relationships: [Relationship] + public var distances: [EntityID: Int] + + public init( + entities: [EntityID] = [], + relationships: [Relationship] = [], + distances: [EntityID: Int] = [:] + ) { + self.entities = entities + self.relationships = relationships + self.distances = distances + } +} + +/// Breadth-/depth-first traversal of the knowledge graph with edge-strength +/// filtering. +public struct GraphTraversal: Sendable { + public var config: TraversalConfig + + public init(config: TraversalConfig = TraversalConfig()) { + self.config = config + } + + private func passesFilter(_ relationship: Relationship) -> Bool { + !config.useEdgeWeights || relationship.confidence >= config.minRelationshipStrength + } + + /// Breadth-first search from a single source. + public func bfs(_ graph: KnowledgeGraph, from source: EntityID) -> TraversalResult { + multiSourceBFS(graph, from: [source]) + } + + /// Breadth-first search from multiple sources simultaneously. + public func multiSourceBFS(_ graph: KnowledgeGraph, from sources: [EntityID]) -> TraversalResult { + var result = TraversalResult() + var visited: Set = [] + var queue: [EntityID] = [] + for source in sources where graph.contains(source) && !visited.contains(source) { + visited.insert(source) + result.distances[source] = 0 + result.entities.append(source) + queue.append(source) + } + + var head = 0 + while head < queue.count { + let current = queue[head] + head += 1 + let depth = result.distances[current] ?? 0 + if depth >= config.maxDepth { continue } + for (neighbor, relationship) in graph.neighbors(of: current) { + guard passesFilter(relationship) else { continue } + if !visited.contains(neighbor) { + visited.insert(neighbor) + result.distances[neighbor] = depth + 1 + result.entities.append(neighbor) + result.relationships.append(relationship) + queue.append(neighbor) + } + } + } + return result + } + + /// Depth-first search from a single source. + public func dfs(_ graph: KnowledgeGraph, from source: EntityID) -> TraversalResult { + var result = TraversalResult() + guard graph.contains(source) else { return result } + var visited: Set = [] + dfsVisit(graph, current: source, depth: 0, visited: &visited, result: &result) + return result + } + + private func dfsVisit( + _ graph: KnowledgeGraph, + current: EntityID, + depth: Int, + visited: inout Set, + result: inout TraversalResult + ) { + if depth > config.maxDepth || visited.contains(current) { return } + visited.insert(current) + result.distances[current] = depth + result.entities.append(current) + // Stop expanding at the depth limit so we never record an edge to a node + // that won't itself be visited (matches BFS and the documented limit). + guard depth < config.maxDepth else { return } + for (neighbor, relationship) in graph.neighbors(of: current) { + guard passesFilter(relationship) else { continue } + if !visited.contains(neighbor) { + result.relationships.append(relationship) + dfsVisit(graph, current: neighbor, depth: depth + 1, visited: &visited, result: &result) + } + } + } + + /// k-hop ego network expanding layer by layer around `center`. + public func egoNetwork(_ graph: KnowledgeGraph, center: EntityID, hops: Int? = nil) -> TraversalResult { + let k = hops ?? config.maxDepth + var result = TraversalResult() + guard graph.contains(center) else { return result } + var visited: Set = [center] + result.distances[center] = 0 + result.entities.append(center) + var currentLayer = [center] + // De-duplicate emitted edges: neighbors of adjacent layers can revisit + // the same edge, which would otherwise overcount evidence/degrees. + var emittedEdges: Set = [] + + var hop = 1 + while hop <= k && !currentLayer.isEmpty { + var nextLayer: [EntityID] = [] + for entity in currentLayer { + for (neighbor, relationship) in graph.neighbors(of: entity) { + guard passesFilter(relationship) else { continue } + let edgeKey = + "\(relationship.source.raw)|\(relationship.target.raw)|\(relationship.relationType)" + if emittedEdges.insert(edgeKey).inserted { + result.relationships.append(relationship) + } + if !visited.contains(neighbor) { + visited.insert(neighbor) + result.distances[neighbor] = hop + result.entities.append(neighbor) + nextLayer.append(neighbor) + } + } + } + currentLayer = nextLayer + hop += 1 + } + return result + } + + /// Enumerate simple paths from `source` to `target` up to `maxDepth` hops, + /// capped at `maxPaths`. + public func findAllPaths(_ graph: KnowledgeGraph, from source: EntityID, to target: EntityID) -> [[EntityID]] { + var paths: [[EntityID]] = [] + guard graph.contains(source), graph.contains(target) else { return paths } + var visited: Set = [] + var current: [EntityID] = [source] + pathDFS(graph, current: source, target: target, remaining: config.maxDepth, + path: ¤t, visited: &visited, paths: &paths) + return paths + } + + private func pathDFS( + _ graph: KnowledgeGraph, + current: EntityID, + target: EntityID, + remaining: Int, + path: inout [EntityID], + visited: inout Set, + paths: inout [[EntityID]] + ) { + if paths.count >= config.maxPaths { return } + if current == target { + paths.append(path) + return + } + // `<= 0` (not `== 0`) so a negative configured maxDepth yields no expansion. + if remaining <= 0 { return } + visited.insert(current) + // Paths are entity-only, so collapse parallel edges (multiple relation + // types between the same nodes) to one neighbor to avoid duplicate paths. + var uniqueNeighbors: [EntityID] = [] + var seenNeighbors: Set = [] + for (neighbor, relationship) in graph.neighbors(of: current) where passesFilter(relationship) { + if seenNeighbors.insert(neighbor).inserted { uniqueNeighbors.append(neighbor) } + } + for neighbor in uniqueNeighbors where !visited.contains(neighbor) { + path.append(neighbor) + pathDFS(graph, current: neighbor, target: target, remaining: remaining - 1, + path: &path, visited: &visited, paths: &paths) + path.removeLast() + } + visited.remove(current) + } +} diff --git a/Sources/GraphRAG/GraphRAG.swift b/Sources/GraphRAG/GraphRAG.swift index 08b22b8..87791e7 100644 --- a/Sources/GraphRAG/GraphRAG.swift +++ b/Sources/GraphRAG/GraphRAG.swift @@ -1,2 +1,40 @@ -// The Swift Programming Language -// https://docs.swift.org/swift-book +// GraphRAG.swift +// Umbrella documentation for the GraphRAG Swift package — a port of the Rust +// crate graphrag-rs (https://github.com/automataIA/graphrag-rs). +// +// GraphRAG builds a knowledge graph from documents and answers natural-language +// questions using graph-based context retrieval. +// +// Quick start: +// ```swift +// import GraphRAG +// +// let rag = try GraphRAGBuilder() +// .withChunkSize(800) +// .withChunkOverlap(100) +// .withTopK(5) +// .build() +// +// await rag.addDocument(text: "Ada Lovelace worked with Charles Babbage ...") +// try await rag.build() +// let answer = try await rag.ask("Who did Ada Lovelace work with?") +// print(answer.text) +// ``` +// +// Everything in this package is `public`. The principal entry points are: +// - `GraphRAG` the orchestrating actor (ingest → build → ask) +// - `GraphRAGBuilder` fluent configuration +// - `Config` tunable defaults +// - `KnowledgeGraph` the entity/relationship graph + documents/chunks +// - `HybridRetriever` BM25 + vector fusion retrieval +// - `PageRank`, `GraphTraversal`, `GraphAnalytics` graph algorithms +// +// Pluggable backends conform to `EmbeddingModel`, `LanguageModel`, and +// `EntityExtracting`. Offline defaults (`HashEmbedder`, `PatternEntityExtractor`) +// require no network or model download; `OllamaClient` / `OllamaEmbedder` enable +// local LLM-backed extraction and generation. + +/// The semantic version of this GraphRAG port. +public enum GraphRAGVersion { + public static let current = "0.2.0" +} diff --git a/Sources/GraphRAG/GraphRAG/Builder.swift b/Sources/GraphRAG/GraphRAG/Builder.swift new file mode 100644 index 0000000..532c3ee --- /dev/null +++ b/Sources/GraphRAG/GraphRAG/Builder.swift @@ -0,0 +1,139 @@ +// Builder.swift +// Ported from graphrag-rs `builder::mod` (the fluent GraphRAGBuilder). + +import Foundation + +/// Fluent builder for assembling a configured `GraphRAG` instance. +/// +/// ```swift +/// let rag = try GraphRAGBuilder() +/// .withChunkSize(800) +/// .withTopK(5) +/// .build() +/// ``` +public struct GraphRAGBuilder: Sendable { + private var config: Config + private var ollamaConfig: OllamaConfig + private var useOllamaChat: Bool = false + + public init(config: Config = .default) { + self.config = config + self.ollamaConfig = OllamaConfig() + } + + // MARK: - General config + + public func withOutputDir(_ dir: String) -> Self { + var copy = self + copy.config.outputDir = dir + return copy + } + + public func withChunkSize(_ size: Int) -> Self { + var copy = self + copy.config.chunkSize = size + return copy + } + + public func withChunkOverlap(_ overlap: Int) -> Self { + var copy = self + copy.config.chunkOverlap = overlap + return copy + } + + public func withTopK(_ k: Int) -> Self { + var copy = self + copy.config.topKResults = k + return copy + } + + public func withSimilarityThreshold(_ threshold: Float) -> Self { + var copy = self + copy.config.similarityThreshold = threshold + return copy + } + + public func withApproach(_ approach: String) -> Self { + var copy = self + copy.config.approach = approach + return copy + } + + public func withEmbeddingDimension(_ dimension: Int) -> Self { + var copy = self + copy.config.embedding.dimension = dimension + return copy + } + + // MARK: - Backend selection + + /// Use the offline, deterministic hash embedder (the default). + public func withHashEmbeddings() -> Self { + var copy = self + copy.config.embedding.backend = "hash" + return copy + } + + /// Enable a local Ollama chat model (also used for LLM-based extraction). + public func withOllama( + host: String = "http://localhost", port: Int = 11434, chatModel: String = "llama3.2:3b" + ) -> Self { + var copy = self + copy.ollamaConfig.host = host + copy.ollamaConfig.port = port + copy.ollamaConfig.chatModel = chatModel + copy.useOllamaChat = true + return copy + } + + /// Use Ollama for embeddings instead of the hash embedder. + public func withOllamaEmbeddings(model: String = "nomic-embed-text", dimension: Int = 1024) -> Self { + var copy = self + copy.ollamaConfig.embeddingModel = model + copy.ollamaConfig.embeddingDimension = dimension + copy.config.embedding.backend = "ollama" + copy.config.embedding.dimension = dimension + return copy + } + + /// Preconfigure for a fully local Ollama setup (chat + embeddings). + public func withLocalDefaults() -> Self { + self.withOllama().withOllamaEmbeddings() + } + + public func withConfig(_ config: Config) -> Self { + var copy = self + copy.config = config + return copy + } + + // MARK: - Build + + /// Construct the configured `GraphRAG` engine. + public func build() throws -> GraphRAG { + // The embedding backend is driven solely by `config.embedding.backend`, + // so a later `withConfig(...)` can switch it back to hash (no sticky + // flag). Sync the Ollama embedder's dimension from the config. + let embedder: any EmbeddingModel + if config.embedding.backend.lowercased() == "ollama" { + var oc = ollamaConfig + oc.embeddingDimension = config.embedding.dimension + embedder = OllamaEmbedder(config: oc) + } else { + embedder = HashEmbedder(dimension: config.embedding.dimension) + } + + let languageModel: (any LanguageModel)? = + useOllamaChat ? OllamaClient(config: ollamaConfig) : nil + + let extractor: any EntityExtracting + if useOllamaChat { + extractor = LLMEntityExtractor(model: OllamaClient(config: ollamaConfig)) + } else { + extractor = PatternEntityExtractor(minConfidence: config.entity.minConfidence) + } + + return try GraphRAG( + config: config, embedder: embedder, languageModel: languageModel, extractor: extractor) + } +} diff --git a/Sources/GraphRAG/GraphRAG/Config.swift b/Sources/GraphRAG/GraphRAG/Config.swift new file mode 100644 index 0000000..f940465 --- /dev/null +++ b/Sources/GraphRAG/GraphRAG/Config.swift @@ -0,0 +1,75 @@ +// Config.swift +// Ported from graphrag-rs `config::mod`. Defaults mirror the Rust crate. + +import Foundation + +public struct EmbeddingConfig: Sendable { + public var dimension: Int + /// "hash" (offline, deterministic) or "ollama". + public var backend: String + + public init(dimension: Int = 384, backend: String = "hash") { + self.dimension = dimension + self.backend = backend + } +} + +public struct TextConfig: Sendable { + public var languages: [String] + + public init(languages: [String] = ["en"]) { + self.languages = languages + } +} + +public struct EntityConfig: Sendable { + public var minConfidence: Float + public var extractRelationships: Bool + + public init(minConfidence: Float = 0.7, extractRelationships: Bool = true) { + self.minConfidence = minConfidence + self.extractRelationships = extractRelationships + } +} + +/// Top-level GraphRAG configuration. +public struct Config: Sendable { + public var outputDir: String + public var chunkSize: Int + public var chunkOverlap: Int + public var maxEntitiesPerChunk: Int + public var topKResults: Int + public var similarityThreshold: Float + /// "semantic", "keyword", or "hybrid". + public var approach: String + + public var embedding: EmbeddingConfig + public var text: TextConfig + public var entity: EntityConfig + + public init( + outputDir: String = "./output", + chunkSize: Int = 1000, + chunkOverlap: Int = 200, + maxEntitiesPerChunk: Int = 10, + topKResults: Int = 10, + similarityThreshold: Float = 0.8, + approach: String = "hybrid", + embedding: EmbeddingConfig = EmbeddingConfig(), + text: TextConfig = TextConfig(), + entity: EntityConfig = EntityConfig() + ) { + self.outputDir = outputDir + self.chunkSize = chunkSize + self.chunkOverlap = chunkOverlap + self.maxEntitiesPerChunk = maxEntitiesPerChunk + self.topKResults = topKResults + self.similarityThreshold = similarityThreshold + self.approach = approach + self.embedding = embedding + self.text = text + self.entity = entity + } + + public static let `default` = Config() +} diff --git a/Sources/GraphRAG/GraphRAG/Engine.swift b/Sources/GraphRAG/GraphRAG/Engine.swift new file mode 100644 index 0000000..7f312eb --- /dev/null +++ b/Sources/GraphRAG/GraphRAG/Engine.swift @@ -0,0 +1,264 @@ +// Engine.swift +// Ported from graphrag-rs `graphrag::mod` / `build` / `ask`. +// +// `GraphRAG` is the high-level orchestrator. It is an `actor` so its mutable +// graph/index state is safe to share across tasks. Pluggable backends (embedder, +// optional LLM, entity extractor) are injected as existentials. + +import Foundation + +public actor GraphRAG { + public let config: Config + + private var graph: KnowledgeGraph + private let embedder: any EmbeddingModel + private let languageModel: (any LanguageModel)? + private let extractor: any EntityExtracting + private let textProcessor: TextProcessor + private var retriever: HybridRetriever + private var isBuilt: Bool = false + private var isBuilding: Bool = false + /// Bumped on every ingestion so a `build()` can detect documents added while + /// it was suspended at an `await` (actors are reentrant). + private var ingestionVersion: Int = 0 + + /// Designated initializer. + public init( + config: Config = .default, + embedder: (any EmbeddingModel)? = nil, + languageModel: (any LanguageModel)? = nil, + extractor: (any EntityExtracting)? = nil + ) throws { + self.config = config + self.graph = KnowledgeGraph() + self.embedder = embedder ?? GraphRAG.defaultEmbedder(for: config) + self.languageModel = languageModel + self.extractor = extractor ?? PatternEntityExtractor(minConfidence: config.entity.minConfidence) + self.textProcessor = try TextProcessor( + chunkSize: config.chunkSize, chunkOverlap: config.chunkOverlap) + self.retriever = HybridRetriever( + config: HybridConfig(maxCandidates: max(100, config.topKResults * 10))) + } + + /// Pick the default embedder honoring `config.embedding.backend` when no + /// embedder was injected. + private static func defaultEmbedder(for config: Config) -> any EmbeddingModel { + if config.embedding.backend.lowercased() == "ollama" { + return OllamaEmbedder( + config: OllamaConfig(embeddingDimension: config.embedding.dimension)) + } + return HashEmbedder(dimension: config.embedding.dimension) + } + + // MARK: - Ingestion + + /// Add raw text as a new document (auto-titled, UUID id) and chunk it. + @discardableResult + public func addDocument(text: String, title: String? = nil) -> DocumentID { + let id = DocumentID(UUID().uuidString) + let document = Document( + id: id, title: title ?? "Document \(graph.documentCount + 1)", content: text) + addDocument(document) + return id + } + + /// Add a pre-built document, chunking it if it has no chunks yet. Replacing a + /// document with the same id drops the previous version's chunks first, so + /// stale text can't linger in the index. + public func addDocument(_ document: Document) { + var doc = document + if doc.chunks.isEmpty { + doc.chunks = textProcessor.chunk(doc) + } + graph.removeChunks(forDocument: doc.id) + graph.addDocument(doc) + for chunk in doc.chunks { graph.addChunk(chunk) } + isBuilt = false + ingestionVersion += 1 + } + + // MARK: - Build + + /// Run the full indexing pipeline: extract entities/relationships, embed + /// chunks, and build the retrieval index. + public func build() async throws { + guard graph.documentCount > 0 else { throw GraphRAGError.noDocuments } + // Actors are reentrant at `await`, so refuse overlapping builds. + guard !isBuilding else { + throw GraphRAGError.validation(message: "A build is already in progress") + } + isBuilding = true + // Any failure below leaves the system unbuilt: ask() must require a fresh, + // successful build rather than querying half-rebuilt state. + isBuilt = false + defer { isBuilding = false } + + let startVersion = ingestionVersion + graph.clearEntitiesAndRelationships() + + // Operate on a fixed snapshot of chunk ids so documents ingested mid-build + // (which bump ingestionVersion) don't get half-processed this round. + let chunkIDs = graph.chunks.map(\.id) + + // Stage 1: entity & relationship extraction per chunk. + for id in chunkIDs { + guard let chunk = graph.chunk(id) else { continue } + let extracted = try await extractor.extract(from: chunk) + // Apply the configured confidence threshold uniformly — injected/LLM + // extractors don't self-filter the way the pattern extractor does. + var entities = extracted.entities.filter { + $0.confidence >= config.entity.minConfidence + } + let relationships = extracted.relationships + + // Honor the per-chunk entity cap, keeping the highest-confidence ones. + if config.maxEntitiesPerChunk > 0, entities.count > config.maxEntitiesPerChunk { + entities = Array( + entities.sorted { $0.confidence > $1.confidence } + .prefix(config.maxEntitiesPerChunk)) + } + + let keptIDs = Set(entities.map(\.id)) + for entity in entities { graph.addEntity(entity) } + // `extractRelationships == false` gates insertion here. We don't try + // to also suppress the extractor's own relationship work: the + // EntityExtracting protocol returns entities and relationships from a + // single call, so skipping the LLM's relationship prompting would + // require a protocol/prompt change. Deferred deliberately. + if config.entity.extractRelationships { + // Scope to entities that survived THIS chunk's cap — not global + // graph state — so the cap can't be defeated by an id that + // already exists from an earlier chunk. + for relationship in relationships + where keptIDs.contains(relationship.source) && keptIDs.contains(relationship.target) { + graph.addRelationship(relationship) + } + } + + // Only annotate the chunk if it wasn't replaced during the await + // (content still matches what we extracted from). A replacement bumps + // ingestionVersion, so the build is already marked unbuilt and will + // redo this next round rather than tagging new text with stale ids. + if var current = graph.chunk(id), current.content == chunk.content { + current.entities = entities.map(\.id) + graph.addChunk(current) + } + } + + // Stage 2: embed chunks — skipped entirely for keyword-only retrieval, + // which never uses embeddings (avoids embedder latency/failure, e.g. a + // remote Ollama embedder, when only BM25 is used). + if config.approach.lowercased() != "keyword" { + for id in chunkIDs { + guard let chunk = graph.chunk(id) else { continue } + let embedding = try await embedder.embed(chunk.content) + // Skip if the chunk was replaced during the embedding await + // (content changed), so we never attach an old-content embedding. + if var current = graph.chunk(id), current.content == chunk.content { + current.embedding = embedding + graph.addChunk(current) + } + } + } + + // Stage 3: build the hybrid retrieval index (index(graph:) clears first). + retriever.index(graph: graph) + + // Only declare success if no new documents arrived during the build; + // otherwise the index is already stale and a rebuild is required. + isBuilt = (ingestionVersion == startVersion) + } + + // MARK: - Query + + /// Answer a natural-language question over the indexed corpus. + public func ask(_ query: String) async throws -> Answer { + guard isBuilt else { throw GraphRAGError.notInitialized } + + let results = try await runRetrieval(query, limit: config.topKResults) + + guard !results.isEmpty else { + return Answer( + text: "I don't have enough information to answer this question.", + confidence: 0) + } + + let context = assembleContext(results) + let sources = results.map { ChunkID($0.id) } + let confidence = min(1.0, Float(results.count) / Float(max(1, config.topKResults))) + + // If an LLM is configured, synthesize a natural-language answer. + if let languageModel, await languageModel.isAvailable() { + let prompt = Prompts.fill( + Prompts.answerGeneration, ["context": context, "query": query]) + let raw = try await languageModel.complete(prompt) + return Answer( + text: GraphRAG.stripThinkingTags(raw), confidence: confidence, sources: sources) + } + + // Otherwise return an extractive summary of the top chunks. + let extractive = results.prefix(3).map(\.content).joined(separator: "\n\n") + return Answer( + text: "Based on the retrieved context:\n\n\(extractive)", + confidence: confidence, sources: sources) + } + + /// Hybrid search without answer synthesis. + public func search(_ query: String, limit: Int? = nil) async throws -> [HybridSearchResult] { + guard isBuilt else { throw GraphRAGError.notInitialized } + return try await runRetrieval(query, limit: limit ?? config.topKResults) + } + + /// Run retrieval honoring the configured `approach` (hybrid / keyword / + /// semantic) and the top-level `similarityThreshold`. + private func runRetrieval(_ query: String, limit: Int) async throws -> [HybridSearchResult] { + let approach = config.approach.lowercased() + let includeKeyword = approach != "semantic" + let includeSemantic = approach != "keyword" + let queryEmbedding = includeSemantic ? try await embedder.embed(query) : nil + return retriever.search( + query: query, + queryEmbedding: queryEmbedding, + limit: limit, + semanticThreshold: config.similarityThreshold, + includeKeyword: includeKeyword, + includeSemantic: includeSemantic) + } + + // MARK: - Introspection + + public func stats() -> Stats { + Stats( + documentCount: graph.documentCount, + chunkCount: graph.chunkCount, + entityCount: graph.entityCount, + relationshipCount: graph.relationshipCount) + } + + /// Direct access to the underlying knowledge graph (a value-type snapshot). + public func knowledgeGraph() -> KnowledgeGraph { graph } + + /// Persist the knowledge graph to JSON. + public func save(toJSON path: String) throws { try graph.save(toJSON: path) } + + // MARK: - Helpers + + private func assembleContext(_ results: [HybridSearchResult]) -> String { + results.map { result in + let score = String(format: "%.3f", result.score) + return "[Chunk | Relevance: \(score)]\n\(result.content)" + }.joined(separator: "\n\n---\n\n") + } + + /// Remove `...` blocks emitted by some reasoning models. + static func stripThinkingTags(_ text: String) -> String { + var result = text + while let open = result.range(of: ""), + let close = result.range(of: ""), + open.lowerBound < close.lowerBound + { + result.removeSubrange(open.lowerBound.. Bool { + guard let entry = entries.removeValue(forKey: id) else { return false } + order.removeAll { $0 == id } + totalLength -= entry.length + for term in entry.termCounts.keys { + if let df = documentFrequency[term] { + if df <= 1 { documentFrequency.removeValue(forKey: term) } + else { documentFrequency[term] = df - 1 } + } + } + return true + } + + public mutating func clear() { + entries.removeAll() + order.removeAll() + documentFrequency.removeAll() + totalLength = 0 + } + + public func content(for id: String) -> String? { entries[id]?.content } + + /// Score and rank documents against `query`, returning the top `limit`. + public func search(_ query: String, limit: Int) -> [BM25Result] { + guard !entries.isEmpty, limit > 0 else { return [] } + let queryTerms = Set(BM25Retriever.tokenize(query)) + guard !queryTerms.isEmpty else { return [] } + + let n = Float(entries.count) + let avgdl = averageDocumentLength + + var results: [BM25Result] = [] + for id in order { + guard let entry = entries[id] else { continue } + var score: Float = 0 + for term in queryTerms { + guard let rawCount = entry.termCounts[term], rawCount > 0 else { continue } + let df = Float(documentFrequency[term] ?? 1) + let idf = log(n / df) + 1.0 + // Standard BM25 uses the raw term count; document-length + // normalization is handled solely by the `|D| / avgdl` factor in + // the denominator (normalizing tf here as well would penalize + // length twice). + let tf = Float(rawCount) + let denom = tf + k1 * (1 - b + b * (Float(entry.length) / max(avgdl, 1))) + score += idf * (tf * (k1 + 1)) / max(denom, 0.0001) + } + if score > 0 { + results.append(BM25Result(id: id, score: score, content: entry.content)) + } + } + + results.sort { lhs, rhs in + if lhs.score == rhs.score { return lhs.id < rhs.id } + return lhs.score > rhs.score + } + return Array(results.prefix(limit)) + } + + // MARK: - Tokenization + + static func tokenize(_ text: String) -> [String] { + var tokens: [String] = [] + var current = "" + func flush() { + // Keep 2-letter acronyms (AI, ML, EU); only single chars are dropped. + // Common short words are already removed by the stopword filter. + if current.count >= 2, !TfIdfKeywordExtractor.defaultStopwords.contains(current) { + tokens.append(current) + } + current = "" + } + // Split on any non-alphanumeric so punctuation separates terms + // ("graph-based" -> "graph", "based") instead of concatenating them. + for ch in text { + if ch.isLetter || ch.isNumber { + current.append(contentsOf: ch.lowercased()) + } else { + flush() + } + } + flush() + return tokens + } +} diff --git a/Sources/GraphRAG/Retrieval/Hybrid.swift b/Sources/GraphRAG/Retrieval/Hybrid.swift new file mode 100644 index 0000000..9aac413 --- /dev/null +++ b/Sources/GraphRAG/Retrieval/Hybrid.swift @@ -0,0 +1,211 @@ +// Hybrid.swift +// Ported from graphrag-rs `retrieval::hybrid`. + +import Foundation + +/// Strategy used to merge ranked lists from different retrievers. +public enum FusionMethod: Sendable, Equatable { + /// Reciprocal Rank Fusion (default). + case rrf + /// Weighted sum of max-normalized scores. + case weighted + /// Raw sum of scores. + case combSum + /// Maximum of the per-method scores. + case maxScore +} + +/// Configuration for `HybridRetriever`. +public struct HybridConfig: Sendable { + public var semanticWeight: Float + public var keywordWeight: Float + public var fusionMethod: FusionMethod + public var rrfK: Float + public var maxCandidates: Int + public var minScoreThreshold: Float + + public init( + semanticWeight: Float = 0.7, + keywordWeight: Float = 0.3, + fusionMethod: FusionMethod = .rrf, + rrfK: Float = 60.0, + maxCandidates: Int = 100, + minScoreThreshold: Float = 0.1 + ) { + self.semanticWeight = semanticWeight + self.keywordWeight = keywordWeight + self.fusionMethod = fusionMethod + self.rrfK = rrfK + self.maxCandidates = maxCandidates + self.minScoreThreshold = minScoreThreshold + } +} + +/// A fused search hit combining keyword and semantic signals. +public struct HybridSearchResult: Sendable, Equatable { + public var id: String + public var content: String + public var score: Float + public var semanticScore: Float + public var keywordScore: Float + + public init(id: String, content: String, score: Float, semanticScore: Float, keywordScore: Float) { + self.id = id + self.content = content + self.score = score + self.semanticScore = semanticScore + self.keywordScore = keywordScore + } +} + +/// Combines BM25 keyword search with cosine vector search over a chunk corpus. +public struct HybridRetriever: Sendable { + public var config: HybridConfig + private var bm25: BM25Retriever + private var vectors: InMemoryVectorStore + private var contents: [String: String] = [:] + + public init(config: HybridConfig = HybridConfig()) { + self.config = config + self.bm25 = BM25Retriever() + self.vectors = InMemoryVectorStore() + } + + public var isInitialized: Bool { !contents.isEmpty } + public var documentCount: Int { contents.count } + + /// Index a chunk for keyword search, and for semantic search if it carries + /// an embedding. + public mutating func index(id: String, content: String, embedding: [Float]?) { + contents[id] = content + bm25.index(id: id, content: content) + if let embedding { + vectors.add(id: id, vector: embedding) + } else { + // Drop any vector from a previous version so semantic search can't + // return this id using a stale embedding. + vectors.remove(id: id) + } + } + + /// Index all chunks of a knowledge graph as a full (re)index. Clears any + /// previously indexed content first, so ids removed since the last index + /// can't linger in `contents`, BM25, or the vector store. + public mutating func index(graph: KnowledgeGraph) { + clear() + for chunk in graph.chunks { + index(id: chunk.id.raw, content: chunk.content, embedding: chunk.embedding) + } + } + + public mutating func clear() { + bm25.clear() + vectors.clear() + contents.removeAll() + } + + /// Run both retrievers and fuse the results. + /// + /// - Parameters: + /// - query: The raw query text (for BM25). + /// - queryEmbedding: Optional query vector (for semantic search). + /// - limit: Number of fused results to return. + /// - semanticThreshold: Minimum cosine similarity for a semantic hit. + /// - includeKeyword: Include BM25 results (false for a semantic-only approach). + /// - includeSemantic: Include vector results (false for a keyword-only approach). + public func search( + query: String, + queryEmbedding: [Float]?, + limit: Int, + semanticThreshold: Float = 0, + includeKeyword: Bool = true, + includeSemantic: Bool = true + ) -> [HybridSearchResult] { + // A negative limit would trap in `prefix`; treat anything <= 0 as empty. + guard limit > 0 else { return [] } + let keyword: [(id: String, score: Float)] = + includeKeyword + ? bm25.search(query, limit: config.maxCandidates).map { (id: $0.id, score: $0.score) } + : [] + // Drop non-positive cosine hits (off-topic protection: the vector store + // always returns its nearest `maxCandidates`) and anything below the + // caller's similarity threshold. + let semantic: [(id: String, score: Float)] = + (includeSemantic ? queryEmbedding : nil).map { + vectors.search($0, k: config.maxCandidates) + .filter { $0.score > 0 && $0.score >= semanticThreshold } + } ?? [] + + let fused = fuse(semantic: semantic, keyword: keyword) + // RRF scores are rank-based and inherently small (≈ 1/(k+rank)); the + // absolute `minScoreThreshold` only makes sense for magnitude-based + // fusion (weighted / CombSUM / MaxScore). + let applyThreshold = config.fusionMethod != .rrf + return Array( + fused + .filter { !applyThreshold || $0.score >= config.minScoreThreshold } + .prefix(limit) + ) + } + + // MARK: - Fusion + + private func fuse( + semantic: [(id: String, score: Float)], + keyword: [(id: String, score: Float)] + ) -> [HybridSearchResult] { + var semScore: [String: Float] = [:] + var kwScore: [String: Float] = [:] + var semRank: [String: Int] = [:] + var kwRank: [String: Int] = [:] + for (rank, item) in semantic.enumerated() { + semScore[item.id] = item.score + semRank[item.id] = rank + } + for (rank, item) in keyword.enumerated() { + kwScore[item.id] = item.score + kwRank[item.id] = rank + } + + let maxSem = semantic.map(\.score).max() ?? 0 + let maxKw = keyword.map(\.score).max() ?? 0 + let allIDs = Set(semScore.keys).union(kwScore.keys) + + var results: [HybridSearchResult] = [] + for id in allIDs { + let sem = semScore[id] ?? 0 + let kw = kwScore[id] ?? 0 + let combined: Float + switch config.fusionMethod { + case .rrf: + var s: Float = 0 + if let r = semRank[id] { + s += (1.0 / (config.rrfK + Float(r) + 1.0)) * config.semanticWeight + } + if let r = kwRank[id] { + s += (1.0 / (config.rrfK + Float(r) + 1.0)) * config.keywordWeight + } + combined = s + case .weighted: + let nSem = maxSem > 0 ? sem / maxSem : 0 + let nKw = maxKw > 0 ? kw / maxKw : 0 + combined = nSem * config.semanticWeight + nKw * config.keywordWeight + case .combSum: + combined = sem + kw + case .maxScore: + combined = max(sem, kw) + } + results.append( + HybridSearchResult( + id: id, content: contents[id] ?? "", + score: combined, semanticScore: sem, keywordScore: kw)) + } + + results.sort { lhs, rhs in + if lhs.score == rhs.score { return lhs.id < rhs.id } + return lhs.score > rhs.score + } + return results + } +} + diff --git a/Sources/GraphRAG/Retrieval/VectorStore.swift b/Sources/GraphRAG/Retrieval/VectorStore.swift new file mode 100644 index 0000000..639f24f --- /dev/null +++ b/Sources/GraphRAG/Retrieval/VectorStore.swift @@ -0,0 +1,84 @@ +// VectorStore.swift +// Ported from graphrag-rs `storage` in-memory vector store. + +import Foundation + +/// Cosine similarity between two equal-length vectors. Returns 0 if either is +/// zero-length or dimensions mismatch. +public func cosineSimilarity(_ a: [Float], _ b: [Float]) -> Float { + guard a.count == b.count, !a.isEmpty else { return 0 } + var dot: Float = 0 + var normA: Float = 0 + var normB: Float = 0 + for i in 0.. 0 ? dot / denom : 0 +} + +/// A brute-force, cosine-similarity in-memory vector store. +public struct InMemoryVectorStore: Sendable { + private var vectors: [String: [Float]] = [:] + private var order: [String] = [] + + public init() {} + + public var count: Int { vectors.count } + public var isEmpty: Bool { vectors.isEmpty } + public var ids: [String] { order } + public var dimension: Int? { order.first.flatMap { vectors[$0]?.count } } + + public func contains(_ id: String) -> Bool { vectors[id] != nil } + public func embedding(for id: String) -> [Float]? { vectors[id] } + + /// Insert or replace a vector. + public mutating func add(id: String, vector: [Float]) { + if vectors[id] == nil { order.append(id) } + vectors[id] = vector + } + + public mutating func addBatch(_ items: [(id: String, vector: [Float])]) { + for item in items { add(id: item.id, vector: item.vector) } + } + + @discardableResult + public mutating func remove(id: String) -> Bool { + guard vectors.removeValue(forKey: id) != nil else { return false } + order.removeAll { $0 == id } + return true + } + + public mutating func clear() { + vectors.removeAll() + order.removeAll() + } + + /// Top-`k` ids by descending cosine similarity to `query`. + public func search(_ query: [Float], k: Int) -> [(id: String, score: Float)] { + guard !vectors.isEmpty, k > 0 else { return [] } + var scored: [(id: String, score: Float)] = [] + scored.reserveCapacity(order.count) + for id in order { + guard let v = vectors[id] else { continue } + scored.append((id, cosineSimilarity(query, v))) + } + scored.sort { lhs, rhs in + if lhs.score == rhs.score { return lhs.id < rhs.id } + return lhs.score > rhs.score + } + return Array(scored.prefix(k)) + } + + /// Like `search`, but discards results below `threshold`. + public func search(_ query: [Float], k: Int, threshold: Float) -> [(id: String, score: Float)] { + search(query, k: k).filter { $0.score >= threshold } + } + + /// All vectors whose similarity to `query` is at least `threshold`. + public func findSimilar(_ query: [Float], threshold: Float) -> [(id: String, score: Float)] { + search(query, k: order.count, threshold: threshold) + } +} diff --git a/Sources/GraphRAG/Text/Chunking.swift b/Sources/GraphRAG/Text/Chunking.swift new file mode 100644 index 0000000..22c3f00 --- /dev/null +++ b/Sources/GraphRAG/Text/Chunking.swift @@ -0,0 +1,243 @@ +// Chunking.swift +// Ported from graphrag-rs `text::chunking` (HierarchicalChunker) and the +// `TextProcessor` API in `text::mod`. +// +// The Rust implementation works on UTF-8 byte indices and guards every slice +// with `is_char_boundary`. Swift's `Character` (extended grapheme cluster) is +// always a valid boundary, so this port operates over a `[Character]` array and +// measures sizes/offsets in characters. For typical text this matches the byte +// behaviour while remaining Unicode-safe by construction. + +import Foundation + +/// A chunk's content together with its character offsets in the source text. +public struct ChunkSpan: Sendable, Equatable { + public var content: String + public var startOffset: Int + public var endOffset: Int + + public init(content: String, startOffset: Int, endOffset: Int) { + self.content = content + self.startOffset = startOffset + self.endOffset = endOffset + } +} + +/// Recursive, separator-aware chunker. +/// +/// Splits on a hierarchy of separators (paragraph → line → sentence → clause → +/// word), preferring the "highest" separator that yields a boundary past the +/// first quarter of the window. +public struct HierarchicalChunker: Sendable { + /// Ordered, most-significant-first list of separators. + public var separators: [String] + /// Chunks whose trimmed length is below this are discarded. + public var minChunkSize: Int + + public static let defaultSeparators: [String] = [ + "\n\n", "\n", ". ", "! ", "? ", "; ", ": ", " ", "", + ] + + public init(separators: [String] = HierarchicalChunker.defaultSeparators, minChunkSize: Int = 50) { + self.separators = separators + self.minChunkSize = minChunkSize + } + + public func withSeparators(_ separators: [String]) -> HierarchicalChunker { + HierarchicalChunker(separators: separators, minChunkSize: minChunkSize) + } + + public func withMinSize(_ size: Int) -> HierarchicalChunker { + HierarchicalChunker(separators: separators, minChunkSize: size) + } + + /// Split `text` into chunk strings of approximately `chunkSize` characters, + /// overlapping consecutive chunks by `overlap` characters. + public func chunkText(_ text: String, chunkSize: Int, overlap: Int) -> [String] { + chunkSpans(text, chunkSize: chunkSize, overlap: overlap).map(\.content) + } + + /// Like `chunkText` but also returns character offsets for each chunk. + public func chunkSpans(_ text: String, chunkSize: Int, overlap: Int) -> [ChunkSpan] { + let chars = Array(text) + let n = chars.count + guard n > 0, chunkSize > 0 else { return [] } + // Clamp: a negative overlap would advance past the chunk end and skip + // text. (TextProcessor rejects it, but this public API must be safe too.) + let overlap = max(0, overlap) + + var spans: [ChunkSpan] = [] + var start = 0 + + while start < n { + var end = min(start + chunkSize, n) + + // Final chunk: take the remainder. + if end >= n { + let slice = chars[start..= minChunkSize || spans.isEmpty { + spans.append(makeSpan(slice, start: start, end: n)) + } + break + } + + let optimalEnd = findOptimalBoundary(chars, start: start, maxEnd: end) + if optimalEnd > start { end = optimalEnd } + + let slice = chars[start..= minChunkSize { + spans.append(makeSpan(slice, start: start, end: end)) + } + + // Advance with overlap, snapped back to a word boundary. + var nextStart = max(0, end - overlap) + nextStart = findWordBoundaryBackward(chars, pos: nextStart) + // Guarantee forward progress. + if nextStart <= start { nextStart = end } + start = nextStart + } + + return spans + } + + // MARK: - Boundary helpers + + private func makeSpan(_ slice: ArraySlice, start: Int, end: Int) -> ChunkSpan { + ChunkSpan(content: String(slice), startOffset: start, endOffset: end) + } + + private func trimmedCount(_ slice: ArraySlice) -> Int { + String(slice).trimmingCharacters(in: .whitespacesAndNewlines).count + } + + /// Find the best split point in `chars[start.. Int { + let rangeLen = maxEnd - start + guard rangeLen > 0 else { return maxEnd } + let quarter = rangeLen / 4 + + for separator in separators where !separator.isEmpty { + let sep = Array(separator) + if let matchStart = lastRange(of: sep, in: chars, start: start, end: maxEnd) { + let boundary = matchStart + sep.count + if boundary > start + quarter { + return boundary + } + } + } + return findWordBoundaryBackward(chars, pos: maxEnd) + } + + /// Largest index `p <= pos` such that the character before `p` is whitespace. + func findWordBoundaryBackward(_ chars: [Character], pos: Int) -> Int { + var p = min(pos, chars.count) + while p > 0 { + if chars[p - 1].isWhitespace { return p } + p -= 1 + } + return 0 + } + + /// Last start-index of `needle` within `chars[start.. Int? { + guard !needle.isEmpty, end - start >= needle.count else { return nil } + var i = end - needle.count + while i >= start { + var matched = true + for j in 0.. 0 else { + throw GraphRAGError.config(message: "chunk_size must be > 0") + } + guard chunkOverlap >= 0 else { + throw GraphRAGError.config(message: "chunk_overlap must be >= 0") + } + guard chunkOverlap < chunkSize else { + throw GraphRAGError.config(message: "chunk_overlap must be < chunk_size") + } + self.chunkSize = chunkSize + self.chunkOverlap = chunkOverlap + // For shorter documents the 50-char minimum can drop everything; scale + // the floor down for small chunk sizes. + let minSize = min(50, max(1, chunkSize / 4)) + self.chunker = HierarchicalChunker(minChunkSize: minSize) + self.keywordExtractor = TfIdfKeywordExtractor() + } + + /// Hierarchically chunk a document into `TextChunk`s with offsets and metadata. + public func chunk(_ document: Document) -> [TextChunk] { + let spans = chunker.chunkSpans(document.content, chunkSize: chunkSize, overlap: chunkOverlap) + var chunks: [TextChunk] = [] + chunks.reserveCapacity(spans.count) + for (index, span) in spans.enumerated() { + let id = ChunkID("\(document.id.raw)_\(index)") + let metadata = ChunkMetadata( + index: index, + wordCount: wordCount(span.content), + keywords: extractKeywords(span.content, maxKeywords: 5) + ) + chunks.append( + TextChunk( + id: id, + documentID: document.id, + content: span.content, + startOffset: span.startOffset, + endOffset: span.endOffset, + metadata: metadata + ) + ) + } + return chunks + } + + /// Extract up to `maxKeywords` keywords from `text`. + public func extractKeywords(_ text: String, maxKeywords: Int) -> [String] { + keywordExtractor.extractKeywordStrings(text, topK: maxKeywords) + } + + /// Naive sentence splitter on `.`, `!`, `?`, and newlines. + public func extractSentences(_ text: String) -> [String] { + var sentences: [String] = [] + var current = "" + for ch in text { + current.append(ch) + if ch == "." || ch == "!" || ch == "?" || ch == "\n" { + let trimmed = current.trimmingCharacters(in: .whitespacesAndNewlines) + if !trimmed.isEmpty { sentences.append(trimmed) } + current = "" + } + } + let tail = current.trimmingCharacters(in: .whitespacesAndNewlines) + if !tail.isEmpty { sentences.append(tail) } + return sentences + } + + /// Collapse runs of whitespace and trim. + public func cleanText(_ text: String) -> String { + let collapsed = text.split(whereSeparator: { $0.isWhitespace }).joined(separator: " ") + return collapsed.trimmingCharacters(in: .whitespacesAndNewlines) + } + + public func wordCount(_ text: String) -> Int { + text.split(whereSeparator: { $0.isWhitespace }).count + } +} diff --git a/Sources/GraphRAG/Text/KeywordExtraction.swift b/Sources/GraphRAG/Text/KeywordExtraction.swift new file mode 100644 index 0000000..727bccc --- /dev/null +++ b/Sources/GraphRAG/Text/KeywordExtraction.swift @@ -0,0 +1,105 @@ +// KeywordExtraction.swift +// Ported from graphrag-rs `text::keyword_extraction` (TfIdfKeywordExtractor). + +import Foundation + +/// TF-IDF keyword extractor. +/// +/// Maintains corpus document frequencies so IDF can be computed across a growing +/// collection. With an empty corpus every term has an assumed document frequency +/// of 1 (treated as rare), so scoring degrades gracefully to plain TF weighting. +public struct TfIdfKeywordExtractor: Sendable { + public private(set) var documentFrequencies: [String: Int] + public private(set) var totalDocuments: Int + public let stopwords: Set + + public init(documentFrequencies: [String: Int] = [:], totalDocuments: Int = 0) { + self.documentFrequencies = documentFrequencies + // Start at the true count (0 for a fresh corpus). The smoothed IDF below + // handles an empty corpus without a phantom document. + self.totalDocuments = max(0, totalDocuments) + self.stopwords = TfIdfKeywordExtractor.defaultStopwords + } + + /// Extract the top-`topK` `(term, score)` pairs, sorted by descending score. + public func extractKeywords(_ text: String, topK: Int) -> [(term: String, score: Float)] { + let tokens = tokenize(text) + guard !tokens.isEmpty, topK > 0 else { return [] } + + // Term frequency (normalized by document length). + var counts: [String: Int] = [:] + for token in tokens { counts[token, default: 0] += 1 } + let totalTerms = Float(tokens.count) + + var scored: [(term: String, score: Float)] = [] + scored.reserveCapacity(counts.count) + for (term, count) in counts { + let tf = Float(count) / totalTerms + let idf = inverseDocumentFrequency(term) + scored.append((term, tf * idf)) + } + + scored.sort { lhs, rhs in + if lhs.score == rhs.score { return lhs.term < rhs.term } + return lhs.score > rhs.score + } + return Array(scored.prefix(topK)) + } + + /// Extract just the top-`topK` keyword strings. + public func extractKeywordStrings(_ text: String, topK: Int) -> [String] { + extractKeywords(text, topK: topK).map(\.term) + } + + /// Add a document's terms to the corpus statistics (for IDF). + public mutating func addDocumentToCorpus(_ text: String) { + let unique = Set(tokenize(text)) + for term in unique { documentFrequencies[term, default: 0] += 1 } + totalDocuments += 1 + } + + public func corpusStats() -> (totalDocuments: Int, uniqueTerms: Int) { + (totalDocuments, documentFrequencies.count) + } + + // MARK: - Internals + + private func inverseDocumentFrequency(_ term: String) -> Float { + let df = documentFrequencies[term] ?? 1 + // Smoothed IDF: stays strictly positive even for an empty corpus + // (N = 1, df = 1 -> 1.0), so ranking falls back to term frequency rather + // than collapsing every score to zero. + let idf = log(Float(totalDocuments + 1) / Float(df + 1)) + 1.0 + return max(idf, 0.0) + } + + /// Lowercase, keep alphanumerics/`-`/`_`, drop short / numeric / stopword tokens. + func tokenize(_ text: String) -> [String] { + var tokens: [String] = [] + for rawWord in text.split(whereSeparator: { $0.isWhitespace }) { + var cleaned = "" + for ch in rawWord where ch.isLetter || ch.isNumber || ch == "-" || ch == "_" { + cleaned.append(contentsOf: ch.lowercased()) + } + if cleaned.count <= 2 { continue } + if stopwords.contains(cleaned) { continue } + if cleaned.allSatisfy({ $0.isNumber }) { continue } + tokens.append(cleaned) + } + return tokens + } + + public static let defaultStopwords: Set = [ + "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", + "for", "not", "on", "with", "he", "as", "you", "do", "at", "this", "but", + "his", "by", "from", "they", "we", "say", "her", "she", "or", "an", "will", + "my", "one", "all", "would", "there", "their", "what", "so", "up", "out", + "if", "about", "who", "get", "which", "go", "me", "when", "make", "can", + "like", "time", "no", "just", "him", "know", "take", "people", "into", + "year", "your", "good", "some", "could", "them", "see", "other", "than", + "then", "now", "look", "only", "come", "its", "over", "think", "also", + "back", "after", "use", "two", "how", "our", "work", "first", "well", + "way", "even", "new", "want", "because", "any", "these", "give", "day", + "most", "us", "is", "was", "are", "been", "has", "had", "were", "said", "did", + ] +} diff --git a/Tests/GraphRAGTests/GraphRAGTests.swift b/Tests/GraphRAGTests/GraphRAGTests.swift index 5616821..ddd73af 100644 --- a/Tests/GraphRAGTests/GraphRAGTests.swift +++ b/Tests/GraphRAGTests/GraphRAGTests.swift @@ -1,8 +1,274 @@ import Testing + @testable import GraphRAG -@Test func example() async throws { - // Write your test here and use APIs like `#expect(...)` to check expected conditions. - // Swift Testing Documentation - // https://developer.apple.com/documentation/testing +// MARK: - Text chunking + +@Test func chunkerProducesOverlappingChunks() throws { + let text = String( + repeating: + "The quick brown fox jumps over the lazy dog. Knowledge graphs connect entities. ", + count: 20) + let chunker = HierarchicalChunker(minChunkSize: 10) + let spans = chunker.chunkSpans(text, chunkSize: 200, overlap: 50) + + #expect(spans.count > 1) + // Offsets are ordered and within bounds. + for span in spans { + #expect(span.startOffset >= 0) + #expect(span.endOffset <= text.count) + #expect(span.startOffset < span.endOffset) + } + // Consecutive chunks overlap (next start is before previous end). + for i in 1.. 0) +} + +// MARK: - Keyword extraction + +@Test func tfidfExtractsContentKeywords() { + let extractor = TfIdfKeywordExtractor() + let keywords = extractor.extractKeywordStrings( + "Knowledge graphs represent entities and relationships between entities.", topK: 3) + #expect(!keywords.isEmpty) + // Stopwords like "and"/"between" must be filtered out. + #expect(!keywords.contains("and")) +} + +// MARK: - BM25 + +@Test func bm25RanksRelevantDocumentFirst() { + var bm25 = BM25Retriever() + bm25.index(id: "a", content: "Graph databases store nodes and edges efficiently.") + bm25.index(id: "b", content: "Cooking recipes for delicious pasta dishes.") + bm25.index(id: "c", content: "Knowledge graphs use nodes edges and graph traversal.") + + let results = bm25.search("graph nodes edges", limit: 3) + #expect(!results.isEmpty) + // A graph-related doc should outrank the cooking doc. + #expect(results.first?.id == "a" || results.first?.id == "c") + #expect(!results.contains { $0.id == "b" && $0.score > (results.first?.score ?? 0) }) +} + +// MARK: - Vector store & embeddings + +@Test func cosineSimilarityBasics() { + #expect(abs(cosineSimilarity([1, 0], [1, 0]) - 1.0) < 1e-6) + #expect(abs(cosineSimilarity([1, 0], [0, 1])) < 1e-6) +} + +@Test func hashEmbedderIsDeterministicAndDimensioned() { + let embedder = HashEmbedder(dimension: 64) + let a = embedder.embedSync("knowledge graph retrieval") + let b = embedder.embedSync("knowledge graph retrieval") + #expect(a == b) + #expect(a.count == 64) +} + +@Test func vectorStoreReturnsNearestNeighbor() { + let embedder = HashEmbedder(dimension: 128) + var store = InMemoryVectorStore() + store.add(id: "graphs", vector: embedder.embedSync("graphs nodes edges entities")) + store.add(id: "cooking", vector: embedder.embedSync("cooking pasta tomato recipe")) + + let query = embedder.embedSync("entities and nodes in graphs") + let results = store.search(query, k: 2) + #expect(results.first?.id == "graphs") +} + +// MARK: - Knowledge graph + +@Test func knowledgeGraphStoresEntitiesAndNeighbors() { + var graph = KnowledgeGraph() + let ada = Entity(id: "person_ada", name: "Ada Lovelace", entityType: "PERSON") + let babbage = Entity(id: "person_babbage", name: "Charles Babbage", entityType: "PERSON") + graph.addEntity(ada) + graph.addEntity(babbage) + graph.addRelationship( + Relationship(source: ada.id, target: babbage.id, relationType: "COLLEAGUE_OF")) + + #expect(graph.entityCount == 2) + #expect(graph.relationshipCount == 1) + let neighbors = graph.neighbors(of: ada.id) + #expect(neighbors.contains { $0.neighbor == babbage.id }) + // Bidirectional lookup. + #expect(graph.neighbors(of: babbage.id).contains { $0.neighbor == ada.id }) +} + +@Test func knowledgeGraphMergesDuplicateRelationships() { + var graph = KnowledgeGraph() + graph.addEntity(Entity(id: "a", name: "A", entityType: "X")) + graph.addEntity(Entity(id: "b", name: "B", entityType: "X")) + graph.addRelationship(Relationship(source: "a", target: "b", relationType: "R", confidence: 0.5)) + graph.addRelationship(Relationship(source: "a", target: "b", relationType: "R", confidence: 0.9)) + #expect(graph.relationshipCount == 1) + #expect(graph.relationships[0].confidence == 0.9) +} + +// MARK: - Graph algorithms + +@Test func pageRankScoresSumToOneAndRankHub() { + var graph = KnowledgeGraph() + for name in ["a", "b", "c", "hub"] { + graph.addEntity(Entity(id: EntityID(name), name: name, entityType: "X")) + } + // Everyone points to the hub. + graph.addRelationship(Relationship(source: "a", target: "hub", relationType: "R")) + graph.addRelationship(Relationship(source: "b", target: "hub", relationType: "R")) + graph.addRelationship(Relationship(source: "c", target: "hub", relationType: "R")) + + let scores = PageRank().compute(graph) + let total = scores.values.reduce(0, +) + #expect(abs(total - 1.0) < 1e-6) + let hub = scores[EntityID("hub")] ?? 0 + #expect(hub > (scores[EntityID("a")] ?? 0)) +} + +@Test func bfsTraversalRespectsDepth() { + var graph = KnowledgeGraph() + for name in ["a", "b", "c", "d"] { + graph.addEntity(Entity(id: EntityID(name), name: name, entityType: "X")) + } + graph.addRelationship(Relationship(source: "a", target: "b", relationType: "R", confidence: 1)) + graph.addRelationship(Relationship(source: "b", target: "c", relationType: "R", confidence: 1)) + graph.addRelationship(Relationship(source: "c", target: "d", relationType: "R", confidence: 1)) + + let traversal = GraphTraversal(config: TraversalConfig(maxDepth: 2, minRelationshipStrength: 0.5)) + let result = traversal.bfs(graph, from: "a") + #expect(result.distances[EntityID("a")] == 0) + #expect(result.distances[EntityID("b")] == 1) + #expect(result.distances[EntityID("c")] == 2) + // 'd' is at depth 3, beyond maxDepth. + #expect(result.distances[EntityID("d")] == nil) +} + +@Test func analyticsDegreeAndComponents() { + var graph = KnowledgeGraph() + for name in ["a", "b", "c"] { + graph.addEntity(Entity(id: EntityID(name), name: name, entityType: "X")) + } + graph.addRelationship(Relationship(source: "a", target: "b", relationType: "R")) + let analytics = GraphAnalytics(graph) + // 'a' connects to 'b' out of 2 possible -> 0.5. + #expect(abs(analytics.degreeCentrality("a") - 0.5) < 1e-6) + // 'a'+'b' connected, 'c' isolated -> 2 components. + #expect(analytics.connectedComponents().count == 2) +} + +// MARK: - Pattern extraction + +@Test func patternExtractorFindsPeople() async throws { + let extractor = PatternEntityExtractor(minConfidence: 0.5) + let chunk = TextChunk( + id: "c0", documentID: "d0", + content: "Ada Lovelace worked with Charles Babbage in London.", + startOffset: 0, endOffset: 0) + let (entities, _) = try await extractor.extract(from: chunk) + let names = Set(entities.map(\.name)) + #expect(names.contains("Ada Lovelace")) + #expect(names.contains("Charles Babbage")) +} + +// MARK: - End-to-end pipeline + +@Test func endToEndBuildAndAskWithoutLLM() async throws { + let rag = try GraphRAGBuilder() + .withChunkSize(400) + .withChunkOverlap(50) + .withTopK(3) + .build() + + await rag.addDocument( + text: """ + Ada Lovelace was an English mathematician. She collaborated with Charles Babbage + on the Analytical Engine, an early mechanical general-purpose computer. Ada is + often regarded as the first computer programmer. + """) + await rag.addDocument( + text: "Pasta is cooked in boiling water with salt. Tomato sauce is a common topping.") + + try await rag.build() + + let stats = await rag.stats() + #expect(stats.documentCount == 2) + #expect(stats.chunkCount >= 2) + #expect(stats.entityCount > 0) + + let answer = try await rag.ask("Who worked on the Analytical Engine?") + #expect(!answer.text.isEmpty) + #expect(!answer.sources.isEmpty) + // The relevant (computing) chunk should be retrieved over the pasta chunk. + #expect(answer.text.lowercased().contains("babbage") || answer.text.lowercased().contains("ada")) +} + +@Test func askBeforeBuildThrows() async throws { + let rag = try GraphRAGBuilder().build() + await rag.addDocument(text: "Some content about graphs and entities.") + await #expect(throws: GraphRAGError.self) { + _ = try await rag.ask("anything") + } +} + +// MARK: - Review regressions + +@Test func negativeChunkOverlapRejected() { + #expect(throws: GraphRAGError.self) { + _ = try TextProcessor(chunkSize: 100, chunkOverlap: -10) + } +} + +@Test func replacingDocumentRemovesStaleChunks() async throws { + let rag = try GraphRAGBuilder().withChunkSize(500).withChunkOverlap(50).build() + let id = DocumentID("fixed-id") + await rag.addDocument(Document(id: id, title: "v1", content: "First version about apples.")) + await rag.addDocument(Document(id: id, title: "v2", content: "Second version about oranges.")) + try await rag.build() + let stats = await rag.stats() + // Only the replacement's chunk(s) should remain, not both versions'. + #expect(stats.documentCount == 1) + #expect(stats.chunkCount == 1) + let answer = try await rag.ask("oranges") + #expect(!answer.text.lowercased().contains("apples")) +} + +@Test func dfsDoesNotRecordEdgesBeyondMaxDepth() { + var graph = KnowledgeGraph() + for name in ["a", "b", "c", "d"] { + graph.addEntity(Entity(id: EntityID(name), name: name, entityType: "X")) + } + graph.addRelationship(Relationship(source: "a", target: "b", relationType: "R", confidence: 1)) + graph.addRelationship(Relationship(source: "b", target: "c", relationType: "R", confidence: 1)) + graph.addRelationship(Relationship(source: "c", target: "d", relationType: "R", confidence: 1)) + + let traversal = GraphTraversal(config: TraversalConfig(maxDepth: 2, minRelationshipStrength: 0.5)) + let result = traversal.dfs(graph, from: "a") + let visited = Set(result.entities) + #expect(!visited.contains(EntityID("d"))) + // Every recorded edge must connect two visited nodes. + for rel in result.relationships { + #expect(visited.contains(rel.source)) + #expect(visited.contains(rel.target)) + } +} + +@Test func negativeTopKDoesNotCrashSearch() async throws { + let config = Config(topKResults: -5) + let rag = try GraphRAGBuilder().withConfig(config).build() + await rag.addDocument(text: "Graphs connect entities and relationships.") + try await rag.build() + let answer = try await rag.ask("graphs") + // Should degrade to the no-results answer rather than trapping. + #expect(answer.sources.isEmpty) }