diff --git a/Documentation/TTS/PocketTTS.md b/Documentation/TTS/PocketTTS.md index 27659910c..bfd7dad32 100644 --- a/Documentation/TTS/PocketTTS.md +++ b/Documentation/TTS/PocketTTS.md @@ -145,6 +145,29 @@ fluidaudio tts "Hello world" --backend pocket --voice-file my_voice.bin - The `mimi_encoder.mlmodelc` model is downloaded automatically on first use - Supports any audio format that AVFoundation can read +### Cloning Across Languages + +The Mimi encoder is language-agnostic — voice cloning produces a generic +acoustic embedding that any language pack's `cond_step` model can consume. +You can: + +- Clone a voice once and reuse the same `PocketTtsVoiceData` across managers + configured with different languages. +- Clone a voice with a Spanish-only manager without pulling in the English + language pack — only the encoder subtree is downloaded. + +```swift +// Clone with a Spanish manager +let esManager = PocketTtsManager(language: .spanish) +try await esManager.initialize() +let voiceData = try await esManager.cloneVoice(from: speakerAudioURL) + +// Use the same cloned voice with a French manager +let frManager = PocketTtsManager(language: .french24L) +try await frManager.initialize() +let frAudio = try await frManager.synthesize(text: "Bonjour", voiceData: voiceData) +``` + ## Pipeline and Pronunciation Control ``` @@ -214,6 +237,57 @@ for try await frame in session.frames { | Streaming playback | `synthesizeStreaming()` | | Streaming text or custom chunking | `makeSession()` | +## Languages + +PocketTTS ships with multiple language packs converted from +[kyutai/pocket-tts](https://huggingface.co/kyutai/pocket-tts). Pick the one +that matches your input text — there is no automatic language detection. 
+ +| ID | Layers | HF Path | +|----|--------|---------| +| `english` | 6 | repo root (legacy layout) | +| `german` | 6 | `v2/german/` | +| `german_24l` | 24 | `v2/german_24l/` | +| `italian` | 6 | `v2/italian/` | +| `italian_24l` | 24 | `v2/italian_24l/` | +| `portuguese` | 6 | `v2/portuguese/` | +| `portuguese_24l` | 24 | `v2/portuguese_24l/` | +| `spanish` | 6 | `v2/spanish/` | +| `spanish_24l` | 24 | `v2/spanish_24l/` | +| `french_24l` | 24 | `v2/french_24l/` | + +Notes: +- French only ships a 24-layer pack upstream (no 6-layer variant). +- 24-layer packs are higher quality but slower and larger. +- The 21 voice names (alba, anna, eve, michael, …) are shared across + languages, but the underlying acoustic embeddings are per-language. +- Mimi encoder weights (used for voice cloning) are language-agnostic and + always live at the repo root. + +### Swift API + +```swift +let manager = PocketTtsManager(language: .spanish) +try await manager.initialize() +let audio = try await manager.synthesize(text: "Hola mundo") +``` + +`PocketTtsManager.language` is immutable per instance. To support multiple +languages in one app, instantiate one manager per language. + +### CLI Usage + +```bash +# Default (English) +fluidaudio tts "Hello world" --backend pocket --output en.wav + +# Spanish (6L) +fluidaudio tts "Hola mundo" --backend pocket --language spanish --output es.wav + +# French (24L only) +fluidaudio tts "Bonjour" --backend pocket --language french_24l --output fr.wav +``` + ## Usage PocketTTS is part of core `FluidAudio` - no GPL dependencies required. diff --git a/README.md b/README.md index d49573303..002c046c5 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Want to convert your own model? 
Check [möbius](https://github.com/FluidInferenc - **Automatic Speech Recognition (ASR)**: [Parakeet TDT v3](Documentation/Models.md#batch-transcription-near-real-time) (0.6b) and other TDT/CTC models for batch transcription supporting 25 European languages, Japanese, and Chinese; [Parakeet EOU](Documentation/Models.md#streaming-transcription-true-real-time) (120m) for streaming ASR with end-of-utterance detection (English only). See all [ASR models](Documentation/Models.md#asr-models). - **Inverse Text Normalization (ITN)**: Post-process ASR output to convert spoken-form to written-form ("two hundred" → "200"). See [text-processing-rs](https://github.com/FluidInference/text-processing-rs) -- **Text-to-Speech (TTS)**: Kokoro (82m) for parallel synthesis with SSML and pronunciation control across 9 languages (EN, ES, FR, HI, IT, JA, PT, ZH); PocketTTS for streaming TTS with voice cloning support (English only) +- **Text-to-Speech (TTS)**: Kokoro (82m) for parallel synthesis with SSML and pronunciation control across 9 languages (EN, ES, FR, HI, IT, JA, PT, ZH); PocketTTS for streaming TTS with voice cloning support (EN, DE, ES, FR, IT, PT — 6L and 24L variants) - **Speaker Diarization (Online + Offline)**: Speaker separation and identification across audio streams. Streaming pipeline for real-time processing and offline batch pipeline with advanced clustering. - **Speaker Embedding Extraction**: Generate speaker embeddings for voice comparison and clustering, you can use this for speaker identification - **Voice Activity Detection (VAD)**: Voice activity detection with Silero models @@ -556,25 +556,36 @@ FluidAudio ships two TTS backends: ### PocketTTS Streaming-friendly TTS with voice cloning support from short audio samples. +Available language packs: `english` (default), `german`, `german_24l`, +`italian`, `italian_24l`, `portuguese`, `portuguese_24l`, `spanish`, +`spanish_24l`, `french_24l` (24-layer only — no 6-layer French upstream). 
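The pack IDs above are plain strings on the CLI side. If an app wants to choose a pack from the user's language code, a lookup along these lines works — a sketch only: `pocketPackID` is a hypothetical helper, not part of the FluidAudio API, and the IDs come from the list above:

```swift
import Foundation

// Hypothetical helper (not part of FluidAudio): map an ISO 639-1 language
// code to a PocketTTS pack ID from the list above. `prefer24L` upgrades to
// the 24-layer variant where one exists; French ships only the 24L pack and
// English ships only the legacy 6-layer root pack.
func pocketPackID(forLanguageCode code: String, prefer24L: Bool = false) -> String? {
    let base: [String: String] = [
        "en": "english", "de": "german", "it": "italian",
        "pt": "portuguese", "es": "spanish", "fr": "french_24l",
    ]
    guard let id = base[code.lowercased()] else { return nil }
    if prefer24L, id != "english", !id.hasSuffix("_24l") {
        return id + "_24l"
    }
    return id
}
```

Unknown codes return `nil`, so callers can fall back to the English default.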
```swift import FluidAudio Task { - let manager = try await PocketTtsManager() - let audioData = try await manager.synthesize("Hello from FluidAudio.") + let manager = PocketTtsManager(language: .spanish) + try await manager.initialize() + let audioData = try await manager.synthesize(text: "Hola, mundo.") try audioData.write(to: URL(fileURLWithPath: "out.wav")) } ``` ```bash -# Synthesize with default voice +# English (default) swift run fluidaudiocli tts "Hello from FluidAudio." --output out.wav --backend pocket -# Clone a voice from an audio sample +# Other languages +swift run fluidaudiocli tts "Hola mundo" --backend pocket --language spanish --output es.wav +swift run fluidaudiocli tts "Bonjour" --backend pocket --language french_24l --output fr.wav + +# Clone a voice from an audio sample (works with any language pack) swift run fluidaudiocli tts "Hello world." --output out.wav --backend pocket --clone-voice speaker.wav ``` +See [Documentation/TTS/PocketTTS.md](Documentation/TTS/PocketTTS.md#languages) +for the full language table. + ### Kokoro High-quality parallel TTS with SSML and phoneme-level pronunciation control. Uses a CoreML G2P (grapheme-to-phoneme) model for out-of-vocabulary words — no external dependencies required. diff --git a/Sources/FluidAudio/DownloadUtils.swift b/Sources/FluidAudio/DownloadUtils.swift index 5191280bb..15b6b9250 100644 --- a/Sources/FluidAudio/DownloadUtils.swift +++ b/Sources/FluidAudio/DownloadUtils.swift @@ -575,8 +575,10 @@ public class DownloadUtils { public static func downloadSubdirectory( _ repo: Repo, subdirectory: String, - to repoDirectory: URL + to repoDirectory: URL, + progressHandler: ProgressHandler? 
= nil ) async throws { + progressHandler?(DownloadProgress(fractionCompleted: 0.0, phase: .listing)) var filesToDownload: [(path: String, size: Int)] = [] func listFiles(at path: String) async throws { @@ -611,12 +613,22 @@ public class DownloadUtils { } try await listFiles(at: subdirectory) - logger.info("Found \(filesToDownload.count) files in \(subdirectory)") + let totalFiles = filesToDownload.count + logger.info("Found \(totalFiles) files in \(subdirectory)") + progressHandler?( + DownloadProgress( + fractionCompleted: totalFiles == 0 ? 1.0 : 0.0, + phase: .downloading(completedFiles: 0, totalFiles: totalFiles))) for (index, file) in filesToDownload.enumerated() { let destPath = repoDirectory.appendingPathComponent(file.path) if FileManager.default.fileExists(atPath: destPath.path) { + progressHandler?( + DownloadProgress( + fractionCompleted: Double(index + 1) / Double(totalFiles), + phase: .downloading( + completedFiles: index + 1, totalFiles: totalFiles))) continue } @@ -658,8 +670,14 @@ public class DownloadUtils { } try FileManager.default.moveItem(at: tempURL, to: destPath) - if (index + 1) % 5 == 0 || index == filesToDownload.count - 1 { - logger.info("Downloaded \(index + 1)/\(filesToDownload.count) \(subdirectory) files") + progressHandler?( + DownloadProgress( + fractionCompleted: Double(index + 1) / Double(totalFiles), + phase: .downloading( + completedFiles: index + 1, totalFiles: totalFiles))) + + if (index + 1) % 5 == 0 || index == totalFiles - 1 { + logger.info("Downloaded \(index + 1)/\(totalFiles) \(subdirectory) files") } } diff --git a/Sources/FluidAudio/ModelNames.swift b/Sources/FluidAudio/ModelNames.swift index 69264524c..50018f671 100644 --- a/Sources/FluidAudio/ModelNames.swift +++ b/Sources/FluidAudio/ModelNames.swift @@ -565,25 +565,41 @@ public enum ModelNames { public static let condStep = "cond_step" public static let flowlmStep = "flowlm_step" public static let flowDecoder = "flow_decoder" - public static let mimiDecoder = 
"mimi_decoder_v2"
+    /// Legacy English (root-of-repo) Mimi decoder file basename.
+    public static let mimiDecoderLegacy = "mimi_decoder_v2"
+    /// New per-language `v2/<language>/` Mimi decoder file basename.
+    public static let mimiDecoderV2 = "mimi_decoder"
     public static let mimiEncoder = "mimi_encoder"

     public static let condStepFile = condStep + ".mlmodelc"
     public static let flowlmStepFile = flowlmStep + ".mlmodelc"
     public static let flowDecoderFile = flowDecoder + ".mlmodelc"
-    public static let mimiDecoderFile = mimiDecoder + ".mlmodelc"
+    public static let mimiDecoderLegacyFile = mimiDecoderLegacy + ".mlmodelc"
+    public static let mimiDecoderV2File = mimiDecoderV2 + ".mlmodelc"
     public static let mimiEncoderFile = mimiEncoder + ".mlmodelc"

     /// Directory containing binary constants, tokenizer, and voice data.
     public static let constantsBinDir = "constants_bin"

-    public static let requiredModels: Set<String> = [
-        condStepFile,
-        flowlmStepFile,
-        flowDecoderFile,
-        mimiDecoderFile,
-        constantsBinDir,
-    ]
+    /// Returns the Mimi decoder filename used inside this language's pack.
+    public static func mimiDecoderFile(for language: PocketTtsLanguage) -> String {
+        language == .english ? mimiDecoderLegacyFile : mimiDecoderV2File
+    }
+
+    /// Required models inside the language root for the given language.
+    ///
+    /// English (legacy root) and other languages use different Mimi
+    /// decoder filenames, but all four model directories plus the
+    /// `constants_bin/` directory must be present.
+    public static func requiredModels(for language: PocketTtsLanguage) -> Set<String> {
+        [
+            condStepFile,
+            flowlmStepFile,
+            flowDecoderFile,
+            mimiDecoderFile(for: language),
+            constantsBinDir,
+        ]
+    }

     /// Models required for voice cloning (optional feature).
     public static let voiceCloningModels: Set<String> = [
@@ -743,7 +759,7 @@ public enum ModelNames {
             return ttsModels.union(ModelNames.G2P.requiredModels)
                 .union(ModelNames.MultilingualG2P.requiredModels)
         case .pocketTts:
-            return ModelNames.PocketTTS.requiredModels
+            return ModelNames.PocketTTS.requiredModels(for: .english)
         case .sortformer:
             if let variant = variant {
                 return [variant]
diff --git a/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsConstantsLoader.swift b/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsConstantsLoader.swift
index a5d98495c..cdcff0e4e 100644
--- a/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsConstantsLoader.swift
+++ b/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsConstantsLoader.swift
@@ -9,11 +9,55 @@ public struct PocketTtsConstantsBundle: Sendable {
 }

 /// Pre-loaded voice conditioning data.
+///
+/// Two formats are supported:
+/// - **Flat audio prompt** (legacy English `<voice>_audio_prompt.bin`):
+///   a `[1, promptLength, 1024]` Float32 tensor that the runtime feeds
+///   through `cond_step` token-by-token to populate the LM transformer
+///   KV cache.
+/// - **Pre-baked KV cache snapshot** (v2 packs `<voice>.safetensors`):
+///   per-layer K/V tensors already shaped to drop directly into the
+///   LM transformer's `cache{i}` slots. Skips `cond_step` voice prefill.
+///
+/// At most one of the two is populated. The synthesizer's `prefillKVCache`
+/// branches on `cacheSnapshot != nil` to choose the path.
 public struct PocketTtsVoiceData: Sendable {
-    /// Flattened audio prompt: [1, promptLength, 1024]
+    /// Flattened audio prompt: [1, promptLength, 1024]. Empty when
+    /// `cacheSnapshot` is non-nil.
     public let audioPrompt: [Float]
-    /// Number of voice conditioning tokens (typically 125).
+    /// Number of voice conditioning tokens (typically 125). Zero when
+    /// `cacheSnapshot` is non-nil.
     public let promptLength: Int
+    /// Pre-baked LM transformer KV cache (v2 packs only).
+    public let cacheSnapshot: PocketTtsVoiceCacheSnapshot?
+ + public init( + audioPrompt: [Float], + promptLength: Int, + cacheSnapshot: PocketTtsVoiceCacheSnapshot? = nil + ) { + self.audioPrompt = audioPrompt + self.promptLength = promptLength + self.cacheSnapshot = cacheSnapshot + } +} + +/// Pre-baked KV cache snapshot for v2 voice packs. +/// +/// One `LayerCache` per LM transformer layer, in layer order. Each +/// `cache` is a flat Float32 array shaped `[2, 1, seqLen, 16, 64]` +/// (K and V interleaved at outer dim) matching the model's `cache{i}` +/// input layout, and `offset` is the number of tokens already filled +/// (= seqLen for a fully-prebaked snapshot). +public struct PocketTtsVoiceCacheSnapshot: Sendable { + public struct LayerCache: Sendable { + public let cache: [Float] + public let offset: Int + } + + public let layers: [LayerCache] + /// Sequence length per layer in the source snapshot (e.g. 126). + public let cacheSeqLen: Int } /// Loads PocketTTS constants from raw `.bin` Float32 files on disk. @@ -36,9 +80,13 @@ public enum PocketTtsConstantsLoader { expectedCount: PocketTtsConstants.latentDim, name: "bos_emb" ) - let embedTable = try loadFloatArray( + // text_embed_table vocab size is language-dependent — English ships + // 4001 rows, other languages may differ. Validate only that the file + // is a clean multiple of embeddingDim; the synthesizer indexes into + // this table by token ID at runtime. + let embedTable = try loadFloatArrayMultipleOf( from: constantsDir.appendingPathComponent("text_embed_table.bin"), - expectedCount: PocketTtsConstants.vocabSize * PocketTtsConstants.embeddingDim, + rowSize: PocketTtsConstants.embeddingDim, name: "text_embed_table" ) @@ -65,10 +113,15 @@ public enum PocketTtsConstantsLoader { /// Load voice conditioning data from the given directory. /// - /// Supports variable-length voice prompts — the prompt length is derived - /// from the file size (`floatCount / embeddingDim`). + /// Resolution order: + /// 1. 
`<voice>.safetensors` — v2 packs ship pre-baked LM KV cache snapshots
+    ///    (per-layer `[2,1,seqLen,16,64]` F32 + `[1]` I64 offset).
+    /// 2. `<voice>_audio_prompt.bin` — legacy English flat `[1,promptLen,1024]` F32.
     ///
-    /// HuggingFace layout: `constants_bin/<voice>_audio_prompt.bin`
+    /// Both legacy English (root-level pack) and v2 language packs are supported
+    /// without code branching at the call site — the loader picks the right
+    /// format and the synthesizer's `prefillKVCache` dispatches on
+    /// `cacheSnapshot`.
     public static func loadVoice(
         _ voice: String, from directory: URL
     ) throws -> PocketTtsVoiceData {
@@ -79,13 +132,26 @@ public enum PocketTtsConstantsLoader {
         }

         let constantsDir = directory.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir)
-        let voiceURL = constantsDir.appendingPathComponent("\(sanitized)_audio_prompt.bin")
+        let safetensorsURL = constantsDir.appendingPathComponent("\(sanitized).safetensors")
+        let binURL = constantsDir.appendingPathComponent("\(sanitized)_audio_prompt.bin")
+
+        if FileManager.default.fileExists(atPath: safetensorsURL.path) {
+            let snapshot = try loadVoiceSnapshot(from: safetensorsURL, voiceName: sanitized)
+            logger.info(
+                "Loaded PocketTTS voice '\(sanitized)' from safetensors (\(snapshot.layers.count) layers, seq=\(snapshot.cacheSeqLen))"
+            )
+            return PocketTtsVoiceData(
+                audioPrompt: [],
+                promptLength: 0,
+                cacheSnapshot: snapshot
+            )
+        }

-        guard FileManager.default.fileExists(atPath: voiceURL.path) else {
-            throw LoadError.fileNotFound("\(sanitized)_audio_prompt")
+        guard FileManager.default.fileExists(atPath: binURL.path) else {
+            throw LoadError.fileNotFound("\(sanitized) (no .safetensors or _audio_prompt.bin)")
         }

-        let data = try Data(contentsOf: voiceURL)
+        let data = try Data(contentsOf: binURL)
         let embDim = PocketTtsConstants.embeddingDim
         let floatCount = data.count / MemoryLayout<Float>.size
@@ -115,6 +181,130 @@ public enum PocketTtsConstantsLoader {
         return PocketTtsVoiceData(audioPrompt: audioPrompt,
promptLength: promptLength)
     }

+    /// Parse a v2 voice safetensors file into a `PocketTtsVoiceCacheSnapshot`.
+    ///
+    /// Expected schema (kyutai pocket-tts v2):
+    /// - `transformer.layers.{N}.self_attn/cache`  F32 `[2, 1, seqLen, 16, 64]`
+    /// - `transformer.layers.{N}.self_attn/offset` I64 `[1]`
+    ///
+    /// All layers share the same `seqLen`. Layer count is auto-detected (6
+    /// for non-24L packs, 24 for `*_24l` packs).
+    private static func loadVoiceSnapshot(
+        from url: URL, voiceName: String
+    ) throws -> PocketTtsVoiceCacheSnapshot {
+        let raw = try Data(contentsOf: url)
+        let parsed = try parseSafetensors(raw, fileLabel: voiceName)
+
+        // Collect cache + offset entries by layer index.
+        var caches: [Int: SafetensorsTensor] = [:]
+        var offsets: [Int: SafetensorsTensor] = [:]
+        let cachePrefix = "transformer.layers."
+        let cacheSuffix = ".self_attn/cache"
+        let offsetSuffix = ".self_attn/offset"
+
+        for (key, tensor) in parsed.tensors {
+            guard key.hasPrefix(cachePrefix) else { continue }
+            let stripped = String(key.dropFirst(cachePrefix.count))
+            if let dotRange = stripped.firstIndex(of: ".") {
+                let idxStr = String(stripped[..<dotRange])
+                guard let layerIdx = Int(idxStr) else { continue }
+                if key.hasSuffix(cacheSuffix) {
+                    caches[layerIdx] = tensor
+                } else if key.hasSuffix(offsetSuffix) {
+                    offsets[layerIdx] = tensor
+                }
+            }
+        }
+
+        let layerCount = caches.count
+        guard layerCount > 0, offsets.count == layerCount else {
+            throw LoadError.invalidSize(
+                "\(voiceName).safetensors: mismatched cache/offset tensors",
+                expected: layerCount,
+                actual: offsets.count
+            )
+        }
+
+        var layers: [PocketTtsVoiceCacheSnapshot.LayerCache] = []
+        layers.reserveCapacity(layerCount)
+        var seqLen = 0
+
+        for layerIdx in 0..<layerCount {
+            guard let cacheTensor = caches[layerIdx], let offsetTensor = offsets[layerIdx] else {
+                throw LoadError.fileNotFound(
+                    "\(voiceName).safetensors: missing tensors for layer \(layerIdx)")
+            }
+
+            // cache is F32 [2, 1, seqLen, 16, 64]
+            guard
+                cacheTensor.dtype == "F32",
+                cacheTensor.shape.count == 5,
+                cacheTensor.shape[0] == 2, cacheTensor.shape[1] == 1,
+                cacheTensor.shape[3] == 16, cacheTensor.shape[4] == 64
+            else {
+                throw LoadError.invalidSize(
+                    "\(voiceName).safetensors: layer \(layerIdx) cache shape/dtype unexpected",
+                    expected: 0,
+                    actual: 0
+                )
+            }
+            let layerSeqLen = cacheTensor.shape[2]
+            if layerIdx == 0 {
+                seqLen = layerSeqLen
+            } else if layerSeqLen != seqLen {
+                throw LoadError.invalidSize(
+                    "\(voiceName).safetensors: layer \(layerIdx) seqLen mismatch",
+                    expected: seqLen,
+                    actual: layerSeqLen
+                )
+            }
+
+            let floatCount = cacheTensor.byteCount / MemoryLayout<Float>.size
+            let cacheFloats: [Float] = raw.withUnsafeBytes { rawBuf -> [Float] in
+                let base = rawBuf.baseAddress!.advanced(by: cacheTensor.byteOffset)
+                let typed = base.assumingMemoryBound(to: Float.self)
+                return Array(UnsafeBufferPointer(start: typed, count: floatCount))
+            }
+
+            // offset is I64 [1]
+            guard offsetTensor.dtype == "I64", offsetTensor.shape == [1] else {
+                throw LoadError.invalidSize(
+                    "\(voiceName).safetensors: layer \(layerIdx) offset shape/dtype unexpected",
+                    expected: 0,
+                    actual: 0
+                )
+            }
+            let offsetVal: Int = raw.withUnsafeBytes { rawBuf -> Int in
+                let base = rawBuf.baseAddress!.advanced(by: offsetTensor.byteOffset)
+                let typed = base.assumingMemoryBound(to: Int64.self)
+                return Int(typed.pointee)
+            }
+
+            layers.append(.init(cache: cacheFloats, offset: offsetVal))
+        }
+
+        return
PocketTtsVoiceCacheSnapshot(layers: layers, cacheSeqLen: seqLen) + } + // MARK: - Private /// Load a raw Float32 binary file into a [Float] array. @@ -137,4 +327,150 @@ public enum PocketTtsConstantsLoader { return Array(floatBuffer) } } + + // MARK: - safetensors + + /// One tensor entry from a safetensors file. + fileprivate struct SafetensorsTensor { + let dtype: String + let shape: [Int] + /// Absolute byte offset into the original file (already includes + /// the 8-byte size prefix and JSON header length). + let byteOffset: Int + let byteCount: Int + } + + fileprivate struct SafetensorsParsed { + let tensors: [String: SafetensorsTensor] + } + + /// Minimal safetensors reader: 8-byte LE u64 header length, then JSON + /// header, then raw tensor bytes. Only used for v2 voice prebakes — no + /// need to support arbitrary safetensors features. + fileprivate static func parseSafetensors( + _ data: Data, fileLabel: String + ) throws -> SafetensorsParsed { + guard data.count >= 8 else { + throw LoadError.invalidSize( + "\(fileLabel).safetensors: too small", + expected: 8, + actual: data.count + ) + } + // Header length: little-endian u64 at byte 0. 
+        let headerLen: UInt64 = data.withUnsafeBytes { rawBuf -> UInt64 in
+            let typed = rawBuf.baseAddress!.assumingMemoryBound(to: UInt64.self)
+            return UInt64(littleEndian: typed.pointee)
+        }
+        let headerStart = 8
+        let headerEnd = headerStart + Int(headerLen)
+        guard headerEnd <= data.count else {
+            throw LoadError.invalidSize(
+                "\(fileLabel).safetensors: header overflow",
+                expected: headerEnd,
+                actual: data.count
+            )
+        }
+        let headerBytes = data.subdata(in: headerStart..<headerEnd)
+        guard
+            let headerJSON = try? JSONSerialization.jsonObject(with: headerBytes)
+                as? [String: Any]
+        else {
+            throw LoadError.invalidSize(
+                "\(fileLabel).safetensors: header is not valid JSON",
+                expected: Int(headerLen),
+                actual: 0
+            )
+        }
+
+        let dataStart = headerEnd
+        var tensors: [String: SafetensorsTensor] = [:]
+        for (key, value) in headerJSON {
+            // "__metadata__" holds string metadata, not a tensor entry.
+            if key == "__metadata__" { continue }
+            guard
+                let entry = value as? [String: Any],
+                let dtype = entry["dtype"] as? String,
+                let shape = entry["shape"] as? [Int],
+                let dataOffsets = entry["data_offsets"] as? [Int],
+                dataOffsets.count == 2
+            else { continue }
+            let start = dataOffsets[0]
+            let end = dataOffsets[1]
+            guard end >= start else {
+                throw LoadError.invalidSize(
+                    "\(fileLabel).safetensors: negative span for '\(key)'",
+                    expected: start,
+                    actual: end
+                )
+            }
+            let absStart = dataStart + start
+            let span = end - start
+            guard absStart + span <= data.count else {
+                throw LoadError.invalidSize(
+                    "\(fileLabel).safetensors: tensor '\(key)' overflows file",
+                    expected: data.count,
+                    actual: absStart + span
+                )
+            }
+            tensors[key] = SafetensorsTensor(
+                dtype: dtype,
+                shape: shape,
+                byteOffset: absStart,
+                byteCount: span
+            )
+        }
+        return SafetensorsParsed(tensors: tensors)
+    }
+
+    /// Load a raw Float32 binary file whose length must be a non-zero multiple
+    /// of `rowSize`. Used for tensors whose leading dimension varies per
+    /// language (e.g. `text_embed_table` vocab size).
+    private static func loadFloatArrayMultipleOf(
+        from url: URL, rowSize: Int, name: String
+    ) throws -> [Float] {
+        guard FileManager.default.fileExists(atPath: url.path) else {
+            throw LoadError.fileNotFound(name)
+        }
+
+        let data = try Data(contentsOf: url)
+        let actualCount = data.count / MemoryLayout<Float>.size
+
+        guard actualCount > 0, actualCount % rowSize == 0 else {
+            throw LoadError.invalidSize(name, expected: rowSize, actual: actualCount)
+        }
+
+        return data.withUnsafeBytes { rawBuffer in
+            let floatBuffer = rawBuffer.bindMemory(to: Float.self)
+            return Array(floatBuffer)
+        }
+    }
 }
diff --git a/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift b/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift
index db44911e4..5ef4a79fe 100644
--- a/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift
+++ b/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift
@@ -6,13 +6,20 @@ public enum PocketTtsResourceDownloader {

     private static let logger = AppLogger(category: "PocketTtsResourceDownloader")

-    /// Ensure all PocketTTS models are downloaded and return the cache directory.
+    /// Ensure all PocketTTS models for the given language are downloaded and
+    /// return the **language root** directory.
     ///
     /// - Parameters:
+    ///   - language: Which upstream language pack to fetch.
     ///   - directory: Optional override for the base cache directory.
     ///     When `nil`, uses the default platform cache location.
     ///   - progressHandler: Optional callback for download progress updates.
+    /// - Returns: The directory that contains the four `.mlmodelc` packages
+    ///   plus `constants_bin/` for the requested language. For English this
+    ///   is the legacy repo root; for other languages it's
+    ///   `<repoDir>/v2/<language>/`.
     public static func ensureModels(
+        language: PocketTtsLanguage = .english,
         directory: URL? = nil,
         progressHandler: DownloadUtils.ProgressHandler?
= nil ) async throws -> URL { @@ -22,31 +29,58 @@ public enum PocketTtsResourceDownloader { let repoDir = modelsDirectory.appendingPathComponent(Repo.pocketTts.folderName) - // Check that all required directories exist (models + constants_bin) - let requiredModels = ModelNames.PocketTTS.requiredModels + let languageRoot: URL + if let subdir = language.repoSubdirectory { + languageRoot = repoDir.appendingPathComponent(subdir) + } else { + languageRoot = repoDir + } + + let requiredModels = ModelNames.PocketTTS.requiredModels(for: language) let allPresent = requiredModels.allSatisfy { model in FileManager.default.fileExists( - atPath: repoDir.appendingPathComponent(model).path) + atPath: languageRoot.appendingPathComponent(model).path) } if !allPresent { - logger.info("Downloading PocketTTS models from HuggingFace...") - try await DownloadUtils.downloadRepo(.pocketTts, to: modelsDirectory, progressHandler: progressHandler) + if let subdir = language.repoSubdirectory { + logger.info( + "Downloading PocketTTS \(language.rawValue) language pack from HuggingFace (\(subdir))..." + ) + try await DownloadUtils.downloadSubdirectory( + .pocketTts, + subdirectory: subdir, + to: repoDir, + progressHandler: progressHandler + ) + } else { + logger.info("Downloading PocketTTS English models from HuggingFace...") + try await DownloadUtils.downloadRepo( + .pocketTts, to: modelsDirectory, progressHandler: progressHandler) + } } else { - logger.info("PocketTTS models found in cache") + logger.info( + "PocketTTS \(language.rawValue) models found in cache") } - return repoDir + return languageRoot } /// Ensure the Mimi encoder model is downloaded for voice cloning. /// /// This is an optional model that's only needed for voice cloning functionality. /// It's downloaded separately from the main models to reduce initial download size. 
+ /// The encoder is shared across all language packs and lives at the legacy + /// repo root regardless of which language is currently loaded — so a Spanish + /// (or any non-English) user can clone a voice without pulling in the + /// English language pack. /// - Parameter directory: Optional override for the base cache directory. /// When `nil`, uses the default platform cache location. public static func ensureMimiEncoder(directory: URL? = nil) async throws -> URL { - let repoDir = try await ensureModels(directory: directory) + let targetDir = try directory ?? cacheDirectory() + let modelsDirectory = targetDir.appendingPathComponent( + PocketTtsConstants.defaultModelsSubdirectory) + let repoDir = modelsDirectory.appendingPathComponent(Repo.pocketTts.folderName) let encoderPath = repoDir.appendingPathComponent(ModelNames.PocketTTS.mimiEncoderFile) if FileManager.default.fileExists(atPath: encoderPath.path) { @@ -54,6 +88,11 @@ public enum PocketTtsResourceDownloader { return encoderPath } + // Make sure the parent directory exists — the user may not have + // downloaded any language pack yet. + try FileManager.default.createDirectory( + at: repoDir, withIntermediateDirectories: true) + logger.info("Downloading Mimi encoder for voice cloning...") try await downloadMimiEncoder(to: repoDir) @@ -74,36 +113,70 @@ public enum PocketTtsResourceDownloader { } /// Ensure constants (binary blobs + tokenizer) are available. - public static func ensureConstants(repoDirectory: URL) throws -> PocketTtsConstantsBundle { - try PocketTtsConstantsLoader.load(from: repoDirectory) + /// + /// - Parameter languageRoot: The directory returned by `ensureModels(...)`, + /// which contains the language-specific `constants_bin/`. + public static func ensureConstants(languageRoot: URL) throws -> PocketTtsConstantsBundle { + try PocketTtsConstantsLoader.load(from: languageRoot) } - /// Ensure voice conditioning data is available, downloading from HuggingFace if missing. 
+ /// Ensure voice conditioning data for the given language is available, + /// downloading from HuggingFace if missing. + /// + /// - Parameters: + /// - voice: Voice name (e.g. `"alba"`, `"michael"`). + /// - language: Language pack the voice belongs to. Voice files are + /// per-language (same names, different acoustic embeddings). + /// - languageRoot: The directory returned by `ensureModels(language:)`. public static func ensureVoice( - _ voice: String, repoDirectory: URL + _ voice: String, + language: PocketTtsLanguage = .english, + languageRoot: URL ) async throws -> PocketTtsVoiceData { let sanitized = voice.filter { $0.isLetter || $0.isNumber || $0 == "_" } guard !sanitized.isEmpty else { throw PocketTTSError.processingFailed("Invalid voice name: \(voice)") } - let constantsDir = repoDirectory.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir) - let voiceFile = "\(sanitized)_audio_prompt.bin" - let voiceURL = constantsDir.appendingPathComponent(voiceFile) - - if !FileManager.default.fileExists(atPath: voiceURL.path) { - logger.info("Downloading voice '\(sanitized)' from HuggingFace...") - let remotePath = "constants_bin/\(voiceFile)" + let constantsDir = languageRoot.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir) + let safetensorsFile = "\(sanitized).safetensors" + let binFile = "\(sanitized)_audio_prompt.bin" + let safetensorsURL = constantsDir.appendingPathComponent(safetensorsFile) + let binURL = constantsDir.appendingPathComponent(binFile) + + let safetensorsExists = FileManager.default.fileExists(atPath: safetensorsURL.path) + let binExists = FileManager.default.fileExists(atPath: binURL.path) + + if !safetensorsExists && !binExists { + // For non-English (v2) packs, voices ship as `.safetensors`. For + // legacy English (root pack), they ship as flat `.bin`. Pick the + // expected format based on language and download just that file. 
+ let remotePrefix: String + if let subdir = language.repoSubdirectory { + remotePrefix = "\(subdir)/" + } else { + remotePrefix = "" + } + let preferredFile = (language == .english) ? binFile : safetensorsFile + let preferredLocalURL = (language == .english) ? binURL : safetensorsURL + let remotePath = "\(remotePrefix)constants_bin/\(preferredFile)" let remoteURL = try ModelRegistry.resolveModel(Repo.pocketTts.remotePath, remotePath) + logger.info( + "Downloading voice '\(sanitized)' for \(language.rawValue) from HuggingFace (\(preferredFile))..." + ) let data = try await AssetDownloader.fetchData( from: remoteURL, - description: "\(sanitized) voice prompt", + description: "\(sanitized) voice prompt (\(language.rawValue))", logger: logger ) - try data.write(to: voiceURL, options: [.atomic]) + // Make sure the parent directory exists in case this is a fresh + // language pack that hasn't materialized constants_bin/ yet. + try FileManager.default.createDirectory( + at: constantsDir, withIntermediateDirectories: true) + try data.write(to: preferredLocalURL, options: [.atomic]) logger.info("Downloaded voice '\(sanitized)' (\(data.count / 1024) KB)") } - return try PocketTtsConstantsLoader.loadVoice(voice, from: repoDirectory) + return try PocketTtsConstantsLoader.loadVoice(voice, from: languageRoot) } // MARK: - Private diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsLayerKeys.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsLayerKeys.swift new file mode 100644 index 000000000..dc37702a7 --- /dev/null +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsLayerKeys.swift @@ -0,0 +1,187 @@ +@preconcurrency import CoreML +import Foundation + +/// Discovered CoreML output names for one transformer model (cond_step or +/// flowlm_step). +/// +/// CoreML auto-generates output names during tracing (`new_cache_N_internal_tensor_assign_2`, +/// `var_NNN`) and the exact numeric suffixes differ between 6L and 24L packs. 
+/// Rather than hardcoding the names per pack, we scan the model's output +/// description at load time and group outputs by tensor shape: +/// +/// - `[2, 1, kvCacheMaxLen, 16, 64]` → KV cache (one per layer) +/// - `[1]` → position scalar (one per layer) +/// - `[1, 1, transformerDim]` → transformer hidden state (flowlm_step only) +/// - `[1, 1, 1]` → EOS logit (flowlm_step only) +/// +/// Within each group we order by the numeric suffix in the name. Cache names +/// follow the closed form `new_cache_{2*i+1}_internal_tensor_assign_2` for +/// layers 0..N-2 with the last layer being `new_cache_internal_tensor_assign_2` +/// (no number — sorted last). Position names use `var_NNN` with irregular +/// strides that nevertheless increase monotonically per layer. +struct PocketTtsLayerKeys: Sendable { + /// One cache output name per transformer layer, ordered by layer index. + let cacheKeys: [String] + /// One position output name per transformer layer, ordered by layer index. + let positionKeys: [String] + /// Hidden-state output name (flowlm_step only). `nil` for cond_step. + let transformerOut: String? + /// EOS logit output name (flowlm_step only). `nil` for cond_step. + let eosLogit: String? + + var layerCount: Int { cacheKeys.count } + + enum DiscoveryError: Error, LocalizedError { + case shapeMismatch(modelName: String, expectedLayers: Int, actualCaches: Int) + case missingFlowLMOutputs(modelName: String, hasTransformer: Bool, hasEos: Bool) + + var errorDescription: String? { + switch self { + case .shapeMismatch(let modelName, let expected, let actual): + return + "PocketTTS layer-key discovery on \(modelName): expected \(expected) cache outputs, found \(actual)" + case .missingFlowLMOutputs(let modelName, let hasTransformer, let hasEos): + return + "PocketTTS \(modelName) missing flowlm outputs (transformer=\(hasTransformer), eos=\(hasEos))" + } + } + } + + /// Discover the output keys for a `cond_step` or `flowlm_step` CoreML model. 
+ /// + /// - Parameters: + /// - model: The compiled CoreML model. + /// - kind: Which model this is — affects whether transformer/eos + /// outputs are required. + /// - expectedLayers: Optional sanity check for the layer count. + static func discover( + from model: MLModel, + kind: ModelKind, + expectedLayers: Int? = nil, + modelName: String + ) throws -> PocketTtsLayerKeys { + let outputs = model.modelDescription.outputDescriptionsByName + + // Bucket outputs by shape. + var cacheCandidates: [String] = [] + var positionCandidates: [String] = [] + var transformerCandidate: String? + var eosCandidate: String? + + let cacheShape = [ + 2, 1, PocketTtsConstants.kvCacheMaxLen, 16, 64, + ] + let transformerShape = [1, 1, PocketTtsConstants.transformerDim] + let eosShape = [1, 1, 1] + let positionShape = [1] + + for (name, desc) in outputs { + guard let constraint = desc.multiArrayConstraint else { continue } + let shape = constraint.shape.map { $0.intValue } + + if shape == cacheShape { + cacheCandidates.append(name) + } else if shape == positionShape { + positionCandidates.append(name) + } else if shape == transformerShape { + transformerCandidate = name + } else if shape == eosShape { + eosCandidate = name + } + } + + // Sort caches by extracted numeric suffix; "new_cache_internal_..." + // (no number) sorts as "last" (largest layer index). + cacheCandidates.sort { lhs, rhs in + let li = cacheLayerIndex(from: lhs) ?? Int.max + let ri = cacheLayerIndex(from: rhs) ?? Int.max + if li != ri { return li < ri } + return lhs < rhs + } + + // Sort positions by trailing numeric suffix. + positionCandidates.sort { lhs, rhs in + let li = trailingNumber(in: lhs) ?? Int.max + let ri = trailingNumber(in: rhs) ?? 
Int.max
+            if li != ri { return li < ri }
+            return lhs < rhs
+        }
+
+        if let expected = expectedLayers, cacheCandidates.count != expected {
+            throw DiscoveryError.shapeMismatch(
+                modelName: modelName,
+                expectedLayers: expected,
+                actualCaches: cacheCandidates.count
+            )
+        }
+
+        if positionCandidates.count != cacheCandidates.count {
+            throw DiscoveryError.shapeMismatch(
+                modelName: modelName,
+                expectedLayers: cacheCandidates.count,
+                actualCaches: positionCandidates.count
+            )
+        }
+
+        switch kind {
+        case .condStep:
+            return PocketTtsLayerKeys(
+                cacheKeys: cacheCandidates,
+                positionKeys: positionCandidates,
+                transformerOut: nil,
+                eosLogit: nil
+            )
+        case .flowlmStep:
+            guard let transformerOut = transformerCandidate, let eosLogit = eosCandidate else {
+                throw DiscoveryError.missingFlowLMOutputs(
+                    modelName: modelName,
+                    hasTransformer: transformerCandidate != nil,
+                    hasEos: eosCandidate != nil
+                )
+            }
+            return PocketTtsLayerKeys(
+                cacheKeys: cacheCandidates,
+                positionKeys: positionCandidates,
+                transformerOut: transformerOut,
+                eosLogit: eosLogit
+            )
+        }
+    }
+
+    enum ModelKind {
+        case condStep
+        case flowlmStep
+    }
+
+    // MARK: - Name parsing
+
+    /// Extract the layer index from a cache output name.
+    ///
+    /// Pattern:
+    /// - `new_cache_{N}_internal_tensor_assign_2` → returns `(N - 1) / 2`
+    /// - `new_cache_internal_tensor_assign_2` → returns `nil` (sorts last)
+    private static func cacheLayerIndex(from name: String) -> Int? {
+        // Strip the "new_cache_" prefix, then take everything up to the next "_".
+        guard name.hasPrefix("new_cache_") else { return nil }
+        let after = name.dropFirst("new_cache_".count)
+        guard let underscore = after.firstIndex(of: "_") else { return nil }
+        let head = after[..<underscore]
+        guard let n = Int(head) else { return nil }
+        return (n - 1) / 2
+    }
+
+    /// Extract the trailing run of digits from a name like `var_445`.
+    private static func trailingNumber(in name: String) -> Int? 
{ + var digits = "" + for char in name.reversed() { + if char.isNumber { + digits.append(char) + } else { + break + } + } + guard !digits.isEmpty else { return nil } + return Int(String(digits.reversed())) + } +} diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsMimiKeys.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsMimiKeys.swift new file mode 100644 index 000000000..f38bc67da --- /dev/null +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsMimiKeys.swift @@ -0,0 +1,218 @@ +@preconcurrency import CoreML +import Foundation + +/// Discovered CoreML input/output schema for the Mimi audio decoder. +/// +/// The Mimi decoder ships in two upstream variants that differ in attention +/// cache layout and (auto-generated) output names: +/// - **Legacy English** — `attn{0,1}_cache` shape `[2, 1, 8, 256, 64]` +/// (heads-first) plus `attn{0,1}_end_offset` inputs. +/// - **v2 multi-language packs** — `attn{0,1}_cache` shape +/// `[2, 1, 256, 8, 64]` (seq-first), no `_end_offset` inputs. +/// +/// CoreML auto-generates non-passthrough output names (`var_NNN`) at +/// conversion time so they differ between packs. To keep one Swift runtime +/// path we discover both the streaming-state mapping and the audio output +/// name from the loaded model. +struct PocketTtsMimiKeys: Sendable { + + /// Output name for the `[1, 1, 1920]` audio frame. + let audioOutput: String + + /// Ordered streaming-state input → output mapping. Only contains state + /// inputs the loaded model actually accepts (so packs missing + /// `attn*_end_offset` simply omit those entries). + let stateMapping: [(input: String, output: String)] + + /// Declared shape per state input (as integers). Used by + /// `loadMimiInitialState` to allocate tensors matching the model. 
+ let stateShapes: [String: [Int]] + + enum DiscoveryError: Error, LocalizedError { + case missingAudioOutput + case unmatchedStateInput(name: String, shape: [Int]) + case ambiguousMatch(name: String) + + var errorDescription: String? { + switch self { + case .missingAudioOutput: + return "PocketTTS Mimi decoder is missing a [1, 1, 1920] audio output" + case .unmatchedStateInput(let name, let shape): + return "PocketTTS Mimi decoder: no output of shape \(shape) for state input '\(name)'" + case .ambiguousMatch(let name): + return "PocketTTS Mimi decoder: could not deterministically pair state input '\(name)'" + } + } + } + + /// Canonical streaming-state input order. Used to disambiguate + /// shape-bucket pairing (e.g. `attn0_cache` before `attn1_cache`). + /// Inputs absent in a given pack (e.g. `attn{0,1}_end_offset` for v2) + /// are simply skipped. + private static let canonicalStateOrder: [String] = [ + "upsample_partial", + "attn0_cache", "attn0_offset", "attn0_end_offset", + "attn1_cache", "attn1_offset", "attn1_end_offset", + "conv0_prev", "conv0_first", + "convtr0_partial", + "res0_conv0_prev", "res0_conv0_first", + "res0_conv1_prev", "res0_conv1_first", + "convtr1_partial", + "res1_conv0_prev", "res1_conv0_first", + "res1_conv1_prev", "res1_conv1_first", + "convtr2_partial", + "res2_conv0_prev", "res2_conv0_first", + "res2_conv1_prev", "res2_conv1_first", + "conv_final_prev", "conv_final_first", + ] + + /// Discover the Mimi schema from a loaded `MLModel`. + static func discover(from model: MLModel) throws -> PocketTtsMimiKeys { + let inputs = model.modelDescription.inputDescriptionsByName + let outputs = model.modelDescription.outputDescriptionsByName + + // 1. Audio output is the only `[1, 1, 1920]` tensor. 
+ let audioShape = [1, 1, PocketTtsConstants.samplesPerFrame] + let audioOutput = outputs.first { _, desc in + guard let constraint = desc.multiArrayConstraint else { return false } + return constraint.shape.map { $0.intValue } == audioShape + }?.key + + guard let audio = audioOutput else { + throw DiscoveryError.missingAudioOutput + } + + // 2. Build state input set + shapes (everything except `latent`). + var stateShapes: [String: [Int]] = [:] + for (name, desc) in inputs where name != "latent" { + guard let constraint = desc.multiArrayConstraint else { continue } + stateShapes[name] = constraint.shape.map { $0.intValue } + } + + // 3. Pair inputs to outputs. + // - Pass-through: output name equals input name (e.g. `conv0_first`, + // `res*_conv1_prev` zero-shape carry-throughs). + // - Semantic-named: outputs containing `end_offset` are reserved + // for inputs containing `end_offset` (legacy English mimi has + // `new_end_offset` / `new_end_offset_1` outputs that share shape + // `[1]` with `var_NNN` offset outputs). + // - Otherwise: match by shape, then disambiguate within a shape + // bucket by sorting outputs by trailing `var_NNN` and inputs + // in canonical order. + var availableOutputs = outputs + availableOutputs.removeValue(forKey: audio) + + // Remove pass-throughs first (cheap to identify). + var passThroughMap: [String: String] = [:] + for inputName in stateShapes.keys { + if availableOutputs[inputName] != nil { + passThroughMap[inputName] = inputName + availableOutputs.removeValue(forKey: inputName) + } + } + + // Reserve `*end_offset*` outputs exclusively for `*end_offset*` inputs. + // Pair them in canonical order, sorted by trailing digits ascending so + // `attn0_end_offset → new_end_offset_1`, `attn1_end_offset → new_end_offset`. 
+ let endOffsetInputs = canonicalStateOrder.filter { + $0.contains("end_offset") && stateShapes[$0] != nil && passThroughMap[$0] == nil + } + var endOffsetOutputs = availableOutputs.keys.filter { $0.contains("end_offset") }.sorted { + let li = trailingNumber(in: $0) ?? Int.max + let ri = trailingNumber(in: $1) ?? Int.max + if li != ri { return li < ri } + return $0 < $1 + } + var endOffsetMap: [String: String] = [:] + for inputName in endOffsetInputs { + guard !endOffsetOutputs.isEmpty else { + throw DiscoveryError.unmatchedStateInput( + name: inputName, shape: stateShapes[inputName] ?? []) + } + let chosen = endOffsetOutputs.removeFirst() + endOffsetMap[inputName] = chosen + availableOutputs.removeValue(forKey: chosen) + } + + // Bucket remaining outputs by shape, sorted by var-number ascending. + var outputsByShape: [[Int]: [String]] = [:] + for (name, desc) in availableOutputs { + guard let constraint = desc.multiArrayConstraint else { continue } + let shape = constraint.shape.map { $0.intValue } + outputsByShape[shape, default: []].append(name) + } + for key in outputsByShape.keys { + outputsByShape[key]?.sort { lhs, rhs in + let li = trailingNumber(in: lhs) ?? Int.max + let ri = trailingNumber(in: rhs) ?? Int.max + if li != ri { return li < ri } + return lhs < rhs + } + } + + // Walk canonical order, taking outputs from each shape bucket. Skip + // inputs already resolved via pass-through or end-offset reservation. + var nonPassThroughInputs: [String] = [] + for name in canonicalStateOrder + where stateShapes[name] != nil + && passThroughMap[name] == nil + && endOffsetMap[name] == nil + { + nonPassThroughInputs.append(name) + } + // Any inputs not in canonical list (defensive) appended in name order. 
+ for name in stateShapes.keys.sorted() + where !canonicalStateOrder.contains(name) + && passThroughMap[name] == nil + && endOffsetMap[name] == nil + { + nonPassThroughInputs.append(name) + } + + var resolvedMapping: [String: String] = passThroughMap + for (k, v) in endOffsetMap { resolvedMapping[k] = v } + for inputName in nonPassThroughInputs { + guard let shape = stateShapes[inputName] else { continue } + guard var bucket = outputsByShape[shape], !bucket.isEmpty else { + throw DiscoveryError.unmatchedStateInput(name: inputName, shape: shape) + } + let chosen = bucket.removeFirst() + outputsByShape[shape] = bucket + resolvedMapping[inputName] = chosen + } + + // Emit mapping in canonical order so iteration is deterministic. + var orderedMapping: [(input: String, output: String)] = [] + for name in canonicalStateOrder { + if let out = resolvedMapping[name] { + orderedMapping.append((input: name, output: out)) + } + } + // Append any non-canonical inputs at the end (defensive). + for name in stateShapes.keys.sorted() where !canonicalStateOrder.contains(name) { + if let out = resolvedMapping[name] { + orderedMapping.append((input: name, output: out)) + } + } + + return PocketTtsMimiKeys( + audioOutput: audio, + stateMapping: orderedMapping, + stateShapes: stateShapes + ) + } + + /// Extract the trailing run of digits from a name like `var_445`. + private static func trailingNumber(in name: String) -> Int? 
{ + var digits = "" + for char in name.reversed() { + if char.isNumber { + digits.append(char) + } else { + break + } + } + guard !digits.isEmpty else { return nil } + return Int(String(digits.reversed())) + } +} diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsModelStore.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsModelStore.swift index 3f93f10a3..efbeb823f 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsModelStore.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsModelStore.swift @@ -7,6 +7,9 @@ import OSLog /// Manages loading and storing of the four CoreML models /// (cond_step, flowlm_step, flow_decoder, mimi_decoder), /// the binary constants bundle, and voice conditioning data. +/// +/// A store is bound to a single `PocketTtsLanguage` for its lifetime; switch +/// languages by creating a new store/manager. public actor PocketTtsModelStore { private let logger = AppLogger(subsystem: "com.fluidaudio.tts", category: "PocketTtsModelStore") @@ -18,12 +21,20 @@ public actor PocketTtsModelStore { private var mimiEncoderModel: MLModel? private var constantsBundle: PocketTtsConstantsBundle? private var voiceCache: [String: PocketTtsVoiceData] = [:] - private var repoDirectory: URL? + private var languageRootDirectory: URL? + private var condLayerKeys: PocketTtsLayerKeys? + private var flowlmLayerKeys: PocketTtsLayerKeys? + private var mimiDecoderKeysCache: PocketTtsMimiKeys? private let directory: URL? + public let language: PocketTtsLanguage - /// - Parameter directory: Optional override for the base cache directory. - /// When `nil`, uses the default platform cache location. - public init(directory: URL? = nil) { + /// - Parameters: + /// - language: Which upstream language pack to load. Defaults to + /// `.english` for backward compatibility. + /// - directory: Optional override for the base cache directory. When + /// `nil`, uses the default platform cache location. 
+ public init(language: PocketTtsLanguage = .english, directory: URL? = nil) { + self.language = language self.directory = directory } @@ -31,10 +42,15 @@ public actor PocketTtsModelStore { public func loadIfNeeded() async throws { guard condStepModel == nil else { return } - let repoDir = try await PocketTtsResourceDownloader.ensureModels(directory: directory) - self.repoDirectory = repoDir + let languageRoot = try await PocketTtsResourceDownloader.ensureModels( + language: language, + directory: directory + ) + self.languageRootDirectory = languageRoot - logger.info("Loading PocketTTS CoreML models...") + logger.info( + "Loading PocketTTS CoreML models (language=\(self.language.rawValue))..." + ) // Use CPU+GPU for all models to avoid ANE float16 precision loss. // The ANE processes in native float16, which causes audible artifacts @@ -46,16 +62,16 @@ public actor PocketTtsModelStore { let loadStart = Date() - let modelFiles = [ + let modelFiles: [String] = [ ModelNames.PocketTTS.condStepFile, ModelNames.PocketTTS.flowlmStepFile, ModelNames.PocketTTS.flowDecoderFile, - ModelNames.PocketTTS.mimiDecoderFile, + ModelNames.PocketTTS.mimiDecoderFile(for: language), ] var loadedModels: [MLModel] = [] for file in modelFiles { - let modelURL = repoDir.appendingPathComponent(file) + let modelURL = languageRoot.appendingPathComponent(file) let model = try MLModel(contentsOf: modelURL, configuration: config) loadedModels.append(model) logger.info("Loaded \(file)") @@ -66,12 +82,34 @@ public actor PocketTtsModelStore { flowDecoderModel = loadedModels[2] mimiDecoderModel = loadedModels[3] + // Discover per-model output names. Names differ between 6L and 24L + // packs because CoreML auto-generates them during tracing. 
+ let expectedLayers = language.transformerLayers + condLayerKeys = try PocketTtsLayerKeys.discover( + from: loadedModels[0], + kind: .condStep, + expectedLayers: expectedLayers, + modelName: "cond_step" + ) + flowlmLayerKeys = try PocketTtsLayerKeys.discover( + from: loadedModels[1], + kind: .flowlmStep, + expectedLayers: expectedLayers, + modelName: "flowlm_step" + ) + + // Discover Mimi decoder schema. Legacy English and v2 packs differ in + // attention cache layout, presence of `attn*_end_offset` inputs, and + // auto-generated `var_NNN` output names. Discovery makes both work + // through one runtime path. + mimiDecoderKeysCache = try PocketTtsMimiKeys.discover(from: loadedModels[3]) + let elapsed = Date().timeIntervalSince(loadStart) logger.info("All PocketTTS models loaded in \(String(format: "%.2f", elapsed))s") // Load constants constantsBundle = try PocketTtsResourceDownloader.ensureConstants( - repoDirectory: repoDir) + languageRoot: languageRoot) logger.info("PocketTTS constants loaded") } @@ -115,9 +153,36 @@ public actor PocketTtsModelStore { return bundle } - /// The repository directory containing models and constants. + /// Discovered output names for the cond_step transformer model. + func condStepLayerKeys() throws -> PocketTtsLayerKeys { + guard let keys = condLayerKeys else { + throw PocketTTSError.modelNotFound("PocketTTS cond_step layer keys not discovered") + } + return keys + } + + /// Discovered output names for the flowlm_step transformer model. + func flowLMStepLayerKeys() throws -> PocketTtsLayerKeys { + guard let keys = flowlmLayerKeys else { + throw PocketTTSError.modelNotFound("PocketTTS flowlm_step layer keys not discovered") + } + return keys + } + + /// Discovered I/O schema for the Mimi audio decoder model (state mapping, + /// audio output name, declared state shapes). 
+    func mimiDecoderKeys() throws -> PocketTtsMimiKeys {
+        guard let keys = mimiDecoderKeysCache else {
+            throw PocketTTSError.modelNotFound("PocketTTS mimi_decoder keys not discovered")
+        }
+        return keys
+    }
+
+    /// The language root directory (legacy repo root for English, or
+    /// `v2/{language}/` under the repo root otherwise) — contains the four
+    /// model files, `constants_bin/`, and is the right base for
+    /// `loadMimiInitialState`.
     public func repoDir() throws -> URL {
-        guard let dir = repoDirectory else {
+        guard let dir = languageRootDirectory else {
             throw PocketTTSError.modelNotFound("PocketTTS repository not loaded")
         }
         return dir
@@ -128,10 +193,14 @@
         if let cached = voiceCache[voice] {
             return cached
         }
-        guard let repoDir = repoDirectory else {
+        guard let languageRoot = languageRootDirectory else {
             throw PocketTTSError.modelNotFound("PocketTTS repository not loaded")
         }
-        let data = try await PocketTtsResourceDownloader.ensureVoice(voice, repoDirectory: repoDir)
+        let data = try await PocketTtsResourceDownloader.ensureVoice(
+            voice,
+            language: language,
+            languageRoot: languageRoot
+        )
         voiceCache[voice] = data
         return data
     }
@@ -140,18 +209,15 @@
     /// Load the Mimi encoder model for voice cloning (lazy, on-demand).
     ///
-    /// Downloads the model from HuggingFace if not already cached.
+    /// Downloads the model from HuggingFace if not already cached. The Mimi
+    /// encoder is shared across all language packs and lives at the legacy
+    /// repo root. 
public func loadMimiEncoderIfNeeded() async throws { guard mimiEncoderModel == nil else { return } // Ensure the mimi_encoder is downloaded (downloads if needed) let modelURL = try await PocketTtsResourceDownloader.ensureMimiEncoder(directory: directory) - // Update repoDirectory if not set - if repoDirectory == nil { - repoDirectory = modelURL.deletingLastPathComponent() - } - let config = MLModelConfiguration() config.computeUnits = .cpuAndGPU @@ -174,8 +240,18 @@ public actor PocketTtsModelStore { /// Check if the Mimi encoder model is available. public func isMimiEncoderAvailable() -> Bool { - guard let repoDir = repoDirectory else { return false } - let modelURL = repoDir.appendingPathComponent(ModelNames.PocketTTS.mimiEncoderFile) + // The Mimi encoder always lives at the repo root regardless of the + // currently selected language pack. + let repoRoot: URL + if let langRoot = languageRootDirectory { + repoRoot = + (language.repoSubdirectory == nil) + ? langRoot + : langRoot.deletingLastPathComponent().deletingLastPathComponent() + } else { + return false + } + let modelURL = repoRoot.appendingPathComponent(ModelNames.PocketTTS.mimiEncoderFile) return FileManager.default.fileExists(atPath: modelURL.path) } diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSession.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSession.swift index b96f12dc7..6f9c81032 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSession.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSession.swift @@ -58,6 +58,9 @@ public actor PocketTtsSession { private let stepModel: MLModel private let flowModel: MLModel private let mimiModel: MLModel + private let condLayerKeys: PocketTtsLayerKeys + private let flowlmLayerKeys: PocketTtsLayerKeys + private let mimiKeys: PocketTtsMimiKeys // Persistent state private let voiceKVSnapshot: PocketTtsSynthesizer.KVCacheState @@ -80,6 +83,9 @@ public actor PocketTtsSession { stepModel: MLModel, flowModel: 
MLModel, mimiModel: MLModel, + condLayerKeys: PocketTtsLayerKeys, + flowlmLayerKeys: PocketTtsLayerKeys, + mimiKeys: PocketTtsMimiKeys, bosEmb: MLMultiArray, temperature: Float, seed: UInt64 @@ -91,6 +97,9 @@ public actor PocketTtsSession { self.stepModel = stepModel self.flowModel = flowModel self.mimiModel = mimiModel + self.condLayerKeys = condLayerKeys + self.flowlmLayerKeys = flowlmLayerKeys + self.mimiKeys = mimiKeys self.bosEmb = bosEmb self.temperature = temperature self.rng = SeededRNG(seed: seed) @@ -172,7 +181,8 @@ public actor PocketTtsSession { // Clone voice KV snapshot and prefill text tokens only var kvState = try PocketTtsSynthesizer.cloneKVCacheState(voiceKVSnapshot) kvState = try await PocketTtsSynthesizer.prefillKVCacheText( - state: kvState, textEmbeddings: textEmbeddings, model: condModel + state: kvState, textEmbeddings: textEmbeddings, model: condModel, + layerKeys: condLayerKeys ) // Generation loop @@ -190,7 +200,8 @@ public actor PocketTtsSession { sequence: sequence, bosEmb: bosEmb, state: &localKV, - model: stepModel + model: stepModel, + layerKeys: flowlmLayerKeys ) kvState = localKV @@ -219,7 +230,8 @@ public actor PocketTtsSession { let frameSamples = try await PocketTtsSynthesizer.runMimiDecoder( latent: latent, state: &localMimi, - model: mimiModel + model: mimiModel, + mimiKeys: mimiKeys ) mimiState = localMimi diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift index f41d55347..3f16ba365 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift @@ -9,20 +9,24 @@ extension PocketTtsSynthesizer { /// for every processed token. This avoids recomputing K/V for past tokens — /// each new step only computes its own K/V, then reads all cached K/V via attention. 
struct KVCacheState {
-        /// 6 KV cache arrays, each shaped `[2, 1, kvCacheMaxLen, 16, 64]`:
+        /// `N` KV cache arrays (one per transformer layer), each shaped
+        /// `[2, 1, kvCacheMaxLen, 16, 64]`:
         /// - `2`: K and V tensors (index 0 = keys, index 1 = values)
         /// - `1`: batch size
         /// - `kvCacheMaxLen` (512): pre-allocated position slots
         /// - `16`: attention heads
         /// - `64`: dims per head (16 × 64 = 1024 total)
+        ///
+        /// `N` is 6 for the legacy English / 6L packs, and 24 for `*_24l`
+        /// packs.
         var caches: [MLMultiArray]
-        /// 6 position counters (one per layer) tracking the next write slot in the cache.
+        /// `N` position counters (one per layer) tracking the next write slot
+        /// in each cache.
         var positions: [MLMultiArray]
     }

     /// Create an empty KV cache state (all zeros, positions at 0).
-    static func emptyKVCacheState() throws -> KVCacheState {
-        let layers = PocketTtsConstants.kvCacheLayers
+    static func emptyKVCacheState(layers: Int) throws -> KVCacheState {
         let shape: [NSNumber] = [
             2, 1, NSNumber(value: PocketTtsConstants.kvCacheMaxLen), 16, 64,
         ]
@@ -96,13 +100,15 @@
     static func runCondStep(
         conditioning: MLMultiArray,
         state: inout KVCacheState,
-        model: MLModel
+        model: MLModel,
+        layerKeys: PocketTtsLayerKeys
     ) async throws {
+        let layers = layerKeys.layerCount
         var inputDict: [String: Any] = [
             "conditioning": conditioning
         ]
-        for i in 0.. 
KVCacheState { var state = state let dim = PocketTtsConstants.embeddingDim @@ -145,7 +152,8 @@ extension PocketTtsSynthesizer { offset: tokenIdx * dim, dim: dim ) - try await runCondStep(conditioning: token, state: &state, model: model) + try await runCondStep( + conditioning: token, state: &state, model: model, layerKeys: layerKeys) } return state @@ -158,36 +166,126 @@ extension PocketTtsSynthesizer { static func prefillKVCacheText( state: KVCacheState, textEmbeddings: [[Float]], - model: MLModel + model: MLModel, + layerKeys: PocketTtsLayerKeys ) async throws -> KVCacheState { var state = state let dim = PocketTtsConstants.embeddingDim for embedding in textEmbeddings { let token = try createConditioningToken(from: embedding, offset: 0, dim: dim) - try await runCondStep(conditioning: token, state: &state, model: model) + try await runCondStep( + conditioning: token, state: &state, model: model, layerKeys: layerKeys) } return state } + /// Build a `KVCacheState` from a pre-baked v2 voice snapshot. + /// + /// Each layer's source cache `[2, 1, seqLen, 16, 64]` is copied into the + /// first `seqLen` positions of a fresh `[2, 1, kvCacheMaxLen, 16, 64]` + /// allocation. The K block (outer dim 0) and V block (outer dim 1) are + /// copied independently because the dest seq capacity is larger than + /// the source — they don't lie at adjacent offsets in the dest. + /// `position{i}` is initialized from the snapshot's per-layer offset + /// (typically equal to `seqLen`). 
+    static func kvCacheStateFromSnapshot(
+        _ snapshot: PocketTtsVoiceCacheSnapshot,
+        layers: Int
+    ) throws -> KVCacheState {
+        guard snapshot.layers.count == layers else {
+            throw PocketTTSError.processingFailed(
+                "voice snapshot layer count \(snapshot.layers.count) != model layer count \(layers)"
+            )
+        }
+        let destSeq = PocketTtsConstants.kvCacheMaxLen
+        let srcSeq = snapshot.cacheSeqLen
+        guard srcSeq <= destSeq else {
+            throw PocketTTSError.processingFailed(
+                "voice snapshot seqLen \(srcSeq) exceeds model capacity \(destSeq)"
+            )
+        }
+
+        // For shape [2, 1, seq, 16, 64] row-major:
+        //   K block size = 1 * seq * 16 * 64 floats
+        //   V block size = same
+        let perKVFloats = 1 * srcSeq * 16 * 64
+        let destPerKVFloats = 1 * destSeq * 16 * 64
+        let copyBytes = perKVFloats * MemoryLayout<Float>.size
+
+        let shape: [NSNumber] = [
+            2, 1, NSNumber(value: destSeq), 16, 64,
+        ]
+
+        var caches: [MLMultiArray] = []
+        var positions: [MLMultiArray] = []
+        caches.reserveCapacity(layers)
+        positions.reserveCapacity(layers)
+
+        for layerIdx in 0.. 
KVCacheState { - let emptyState = try emptyKVCacheState() - var state = try await prefillKVCacheVoice( - state: emptyState, voiceData: voiceData, model: model - ) + var state: KVCacheState + if let snapshot = voiceData.cacheSnapshot { + state = try kvCacheStateFromSnapshot(snapshot, layers: layerKeys.layerCount) + } else { + let emptyState = try emptyKVCacheState(layers: layerKeys.layerCount) + state = try await prefillKVCacheVoice( + state: emptyState, voiceData: voiceData, model: model, layerKeys: layerKeys + ) + } state = try await prefillKVCacheText( - state: state, textEmbeddings: textEmbeddings, model: model + state: state, textEmbeddings: textEmbeddings, model: model, layerKeys: layerKeys ) let finalPos = state.positions[0][0].floatValue @@ -223,14 +321,22 @@ extension PocketTtsSynthesizer { sequence: MLMultiArray, bosEmb: MLMultiArray, state: inout KVCacheState, - model: MLModel + model: MLModel, + layerKeys: PocketTtsLayerKeys ) async throws -> (transformerOut: MLMultiArray, eosLogit: Float) { + guard let transformerKey = layerKeys.transformerOut, let eosKey = layerKeys.eosLogit + else { + throw PocketTTSError.processingFailed( + "flowlm_step layer keys missing transformer/eos outputs") + } + + let layers = layerKeys.layerCount var inputDict: [String: Any] = [ "sequence": sequence, "bos_emb": bosEmb, ] - for i in 0.. MimiState { + /// Tensor shapes come from the loaded model's input descriptions, not a + /// manifest, so legacy English (`attn*_cache: [2,1,8,256,64]`) and v2 + /// multi-language packs (`attn*_cache: [2,1,256,8,64]`, no + /// `attn*_end_offset` inputs) both load through one path. `.bin` files + /// must be Float32 with element count matching the model's declared + /// shape; missing files mean a zero-initialized tensor. 
+ static func loadMimiInitialState( + from repoDirectory: URL, + mimiKeys: PocketTtsMimiKeys + ) throws -> MimiState { let constantsDir = repoDirectory.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir) let stateDir = constantsDir.appendingPathComponent("mimi_init_state") - let manifestURL = constantsDir.appendingPathComponent("manifest.json") - - // Parse manifest for mimi_init_state shapes - let manifestData = try Data(contentsOf: manifestURL) - guard let manifest = try JSONSerialization.jsonObject(with: manifestData) as? [String: Any], - let mimiManifest = manifest["mimi_init_state"] as? [String: Any] - else { - throw PocketTTSError.processingFailed("Failed to parse mimi_init_state from manifest.json") - } var tensors: [String: MLMultiArray] = [:] - for (name, info) in mimiManifest { - guard let infoDict = info as? [String: Any], - let shapeArray = infoDict["shape"] as? [Int], - let byteCount = infoDict["bytes"] as? Int - else { + for (name, shapeInts) in mimiKeys.stateShapes { + let shape = shapeInts.map { NSNumber(value: $0) } + let array = try MLMultiArray(shape: shape, dataType: .float32) + + // Zero-length tensors (e.g. `res*_conv1_prev: [1, 128, 0]`) are + // empty pass-throughs — nothing to load. + if shapeInts.contains(0) { + tensors[name] = array continue } - let shape = shapeArray.map { NSNumber(value: $0) } - let array = try MLMultiArray(shape: shape, dataType: .float32) + let elementCount = shapeInts.reduce(1, *) + let dstPtr = array.dataPointer.bindMemory(to: Float.self, capacity: elementCount) + // Default to zero in case the .bin file is absent (offset scalars, + // attention caches in zero-init packs, etc.). + dstPtr.initialize(repeating: 0, count: elementCount) - // Some tensors (e.g., res{0,1,2}_conv1_prev) have zero-length shapes - // and are empty pass-throughs — skip loading binary data for those. 
-        if byteCount > 0 && !shapeArray.contains(0) {
-            let binURL = stateDir.appendingPathComponent("\(name).bin")
+            let binURL = stateDir.appendingPathComponent("\(name).bin")
+            if FileManager.default.fileExists(atPath: binURL.path) {
                 let data = try Data(contentsOf: binURL)
-                let floatCount = byteCount / MemoryLayout<Float>.size
-                let dstPtr = array.dataPointer.bindMemory(to: Float.self, capacity: floatCount)
-                data.withUnsafeBytes { rawBuffer in
-                    let srcPtr = rawBuffer.bindMemory(to: Float.self)
-                    dstPtr.update(from: srcPtr.baseAddress!, count: floatCount)
+                let expectedBytes = elementCount * MemoryLayout<Float>.size
+                if data.count == expectedBytes {
+                    data.withUnsafeBytes { rawBuffer in
+                        let srcPtr = rawBuffer.bindMemory(to: Float.self)
+                        dstPtr.update(from: srcPtr.baseAddress!, count: elementCount)
+                    }
                 }
+                // Mismatched bin sizes (e.g. English-packed `attn*_cache.bin`
+                // for a v2 model with the same byte count but different shape)
+                // fall back to zero-init, which is the correct empty-cache
+                // initial value anyway.
             }

             tensors[name] = array
         }

-        // Ensure offset scalars exist
-        for key in ["attn0_offset", "attn0_end_offset", "attn1_offset", "attn1_end_offset"] {
-            if tensors[key] == nil {
-                let scalar = try MLMultiArray(shape: [1], dataType: .float32)
-                scalar[0] = NSNumber(value: Float(0))
-                tensors[key] = scalar
-            }
-        }
-
         return MimiState(tensors: tensors)
     }

@@ -104,7 +102,8 @@
     static func runMimiDecoder(
         latent: [Float],
         state: inout MimiState,
-        model: MLModel
+        model: MLModel,
+        mimiKeys: PocketTtsMimiKeys
     ) async throws -> [Float] {
         // Create latent input: [1, 32]
         let latentDim = PocketTtsConstants.latentDim
@@ -116,17 +115,23 @@
             latentPtr.update(from: base, count: latentDim)
         }

-        // Build input dictionary
+        // Build input dictionary — only include keys the model actually accepts
+        // so that legacy-English-only tensors (e.g. `attn*_end_offset`) don't
+        // surface to v2 packs that omit those inputs. 
         var inputDict: [String: Any] = ["latent": latentArray]
-        for (key, array) in state.tensors {
-            inputDict[key] = array
+        for (inputName, _) in mimiKeys.stateMapping {
+            guard let array = state.tensors[inputName] else {
+                throw PocketTTSError.processingFailed(
+                    "Mimi state missing tensor '\(inputName)'")
+            }
+            inputDict[inputName] = array
         }
 
         let input = try MLDictionaryFeatureProvider(dictionary: inputDict)
         let output = try await model.compatPrediction(from: input, options: MLPredictionOptions())
 
         // Extract audio output [1, 1, 1920]
-        guard let audioArray = output.featureValue(for: MimiKeys.audioOutput)?.multiArrayValue else {
+        guard let audioArray = output.featureValue(for: mimiKeys.audioOutput)?.multiArrayValue else {
             throw PocketTTSError.processingFailed("Missing Mimi audio output")
         }
 
@@ -134,7 +139,7 @@ extension PocketTtsSynthesizer {
         let samples = readFloatArray(from: audioArray, count: sampleCount)
 
         // Update streaming state
-        for (inputName, outputName) in mimiStateMapping {
+        for (inputName, outputName) in mimiKeys.stateMapping {
             guard let updated = output.featureValue(for: outputName)?.multiArrayValue else {
                 throw PocketTTSError.processingFailed(
                     "Missing Mimi state output: \(outputName) (for \(inputName))")
diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+Types.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+Types.swift
index 9032e8d1d..4112e055f 100644
--- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+Types.swift
+++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+Types.swift
@@ -14,87 +14,9 @@ extension PocketTtsSynthesizer {
         public let eosStep: Int?
     }
 
-    /// CoreML output key names for the conditioning step model.
-    ///
-    /// These names are auto-generated during CoreML model tracing and must match
-    /// the compiled `.mlmodelc` exactly. They only change when models are re-converted.
-    enum CondStepKeys {
-        static let cacheKeys: [String] = [
-            "new_cache_1_internal_tensor_assign_2",
-            "new_cache_3_internal_tensor_assign_2",
-            "new_cache_5_internal_tensor_assign_2",
-            "new_cache_7_internal_tensor_assign_2",
-            "new_cache_9_internal_tensor_assign_2",
-            "new_cache_internal_tensor_assign_2",
-        ]
-        static let positionKeys: [String] = [
-            "var_445", "var_864", "var_1283", "var_1702", "var_2121", "var_2365",
-        ]
-    }
-
-    /// CoreML output key names for the generation step model.
-    ///
-    /// Auto-generated during CoreML model tracing. Must match the compiled model.
-    enum FlowLMStepKeys {
-        /// CoreML assigned this output the name "input" during model tracing —
-        /// it is the transformer hidden state output, not an input tensor.
-        static let transformerOut = "input"
-        static let eosLogit = "var_2582"
-        static let cacheKeys: [String] = [
-            "new_cache_1_internal_tensor_assign_2",
-            "new_cache_3_internal_tensor_assign_2",
-            "new_cache_5_internal_tensor_assign_2",
-            "new_cache_7_internal_tensor_assign_2",
-            "new_cache_9_internal_tensor_assign_2",
-            "new_cache_internal_tensor_assign_2",
-        ]
-        static let positionKeys: [String] = [
-            "var_458", "var_877", "var_1296", "var_1715", "var_2134", "var_2553",
-        ]
-    }
-
-    /// CoreML output key names for the Mimi decoder model.
-    enum MimiKeys {
-        static let audioOutput = "var_821"
-    }
-
-    /// Mimi decoder streaming state key mappings (input name → output name).
-    ///
-    /// 26 state tensors that carry the decoder's streaming context across frames:
-    /// - Upsampling: `upsample_partial` — partial output buffer for upsampling layers
-    /// - Attention: `attn{0,1}_cache/offset/end_offset` — causal attention KV caches
-    /// - Convolutions: `conv*_prev/first` — causal conv padding buffers
-    /// - Residual blocks: `res{0,1,2}_conv{0,1}_prev/first` — residual conv state
-    /// - Transposed convs: `convtr{0,1,2}_partial` — transposed conv overlap buffers
-    ///
-    /// 3 zero-length tensors (`res{0,1,2}_conv1_prev`) are pass-throughs where
-    /// input and output names are identical.
-    static let mimiStateMapping: [(input: String, output: String)] = [
-        ("upsample_partial", "var_82"),
-        ("attn0_cache", "var_262"),
-        ("attn0_offset", "var_840"),
-        ("attn0_end_offset", "new_end_offset_1"),
-        ("attn1_cache", "var_479"),
-        ("attn1_offset", "var_843"),
-        ("attn1_end_offset", "new_end_offset"),
-        ("conv0_prev", "var_607"),
-        ("conv0_first", "conv0_first"),
-        ("convtr0_partial", "var_634"),
-        ("res0_conv0_prev", "var_660"),
-        ("res0_conv0_first", "res0_conv0_first"),
-        ("res0_conv1_prev", "res0_conv1_prev"),
-        ("res0_conv1_first", "res0_conv1_first"),
-        ("convtr1_partial", "var_700"),
-        ("res1_conv0_prev", "var_726"),
-        ("res1_conv0_first", "res1_conv0_first"),
-        ("res1_conv1_prev", "res1_conv1_prev"),
-        ("res1_conv1_first", "res1_conv1_first"),
-        ("convtr2_partial", "var_766"),
-        ("res2_conv0_prev", "var_792"),
-        ("res2_conv0_first", "res2_conv0_first"),
-        ("res2_conv1_prev", "res2_conv1_prev"),
-        ("res2_conv1_first", "res2_conv1_first"),
-        ("conv_final_prev", "var_824"),
-        ("conv_final_first", "conv_final_first"),
-    ]
+    /// CoreML output key names for the conditioning and generation step models
+    /// are discovered at model-load time via `PocketTtsLayerKeys.discover(...)`.
+    /// Mimi decoder I/O is discovered at model-load via `PocketTtsMimiKeys.discover(...)`.
+    /// Both use discovery because CoreML auto-generates `var_NNN` names that
+    /// differ between English/v2 packs and 6L/24L variants.
 }
diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer.swift
index 94537e523..ef221caff 100644
--- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer.swift
+++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer.swift
@@ -74,10 +74,13 @@ public struct PocketTtsSynthesizer {
         let stepModel = try await store.flowlmStep()
         let flowModel = try await store.flowDecoder()
         let mimiModel = try await store.mimiDecoder()
+        let condLayerKeys = try await store.condStepLayerKeys()
+        let flowlmLayerKeys = try await store.flowLMStepLayerKeys()
+        let mimiKeys = try await store.mimiDecoderKeys()
 
         // 5. Load Mimi initial state (continuous across chunks)
         let repoDir = try await store.repoDir()
-        var mimiState = try loadMimiInitialState(from: repoDir)
+        var mimiState = try loadMimiInitialState(from: repoDir, mimiKeys: mimiKeys)
 
         // 6. Create BOS embedding
         let bosEmb = try createBosEmbedding(constants.bosEmbedding)
@@ -101,7 +104,8 @@ public struct PocketTtsSynthesizer {
         var kvState = try await prefillKVCache(
             voiceData: voiceData,
             textEmbeddings: textEmbeddings,
-            model: condModel
+            model: condModel,
+            layerKeys: condLayerKeys
         )
         let prefillElapsed = Date().timeIntervalSince(prefillStart)
         logger.info(
@@ -120,7 +124,8 @@ public struct PocketTtsSynthesizer {
             sequence: sequence,
             bosEmb: bosEmb,
             state: &kvState,
-            model: stepModel
+            model: stepModel,
+            layerKeys: flowlmLayerKeys
         )
 
         if eosLogit > PocketTtsConstants.eosThreshold && eosStep == nil {
@@ -144,7 +149,8 @@ public struct PocketTtsSynthesizer {
         let frameSamples = try await runMimiDecoder(
             latent: latent,
             state: &mimiState,
-            model: mimiModel
+            model: mimiModel,
+            mimiKeys: mimiKeys
         )
         audioChunks.append(frameSamples)
@@ -229,10 +235,13 @@ public struct PocketTtsSynthesizer {
         let stepModel = try await store.flowlmStep()
         let flowModel = try await store.flowDecoder()
         let mimiModel = try await store.mimiDecoder()
+        let condLayerKeys = try await store.condStepLayerKeys()
+        let flowlmLayerKeys = try await store.flowLMStepLayerKeys()
+        let mimiKeys = try await store.mimiDecoderKeys()
 
         // 5. Load Mimi initial state (continuous across chunks)
         let repoDir = try await store.repoDir()
-        var mimiState = try loadMimiInitialState(from: repoDir)
+        var mimiState = try loadMimiInitialState(from: repoDir, mimiKeys: mimiKeys)
 
         // 6. Create BOS embedding
         let bosEmb = try createBosEmbedding(constants.bosEmbedding)
@@ -256,7 +265,8 @@ public struct PocketTtsSynthesizer {
         var kvState = try await prefillKVCache(
             voiceData: voiceData,
             textEmbeddings: textEmbeddings,
-            model: condModel
+            model: condModel,
+            layerKeys: condLayerKeys
         )
         let prefillElapsed = Date().timeIntervalSince(prefillStart)
         logger.info(
@@ -275,7 +285,8 @@ public struct PocketTtsSynthesizer {
             sequence: sequence,
             bosEmb: bosEmb,
             state: &kvState,
-            model: stepModel
+            model: stepModel,
+            layerKeys: flowlmLayerKeys
         )
 
         if eosLogit > PocketTtsConstants.eosThreshold && eosStep == nil {
@@ -300,7 +311,8 @@ public struct PocketTtsSynthesizer {
         let frameSamples = try await runMimiDecoder(
             latent: latent,
             state: &mimiState,
-            model: mimiModel
+            model: mimiModel,
+            mimiKeys: mimiKeys
         )
         audioChunks.append(frameSamples)
@@ -405,8 +417,11 @@ public struct PocketTtsSynthesizer {
         let stepModel = try await store.flowlmStep()
         let flowModel = try await store.flowDecoder()
         let mimiModel = try await store.mimiDecoder()
+        let condLayerKeys = try await store.condStepLayerKeys()
+        let flowlmLayerKeys = try await store.flowLMStepLayerKeys()
+        let mimiKeys = try await store.mimiDecoderKeys()
         let repoDir = try await store.repoDir()
-        let mimiInitialState = try loadMimiInitialState(from: repoDir)
+        let mimiInitialState = try loadMimiInitialState(from: repoDir, mimiKeys: mimiKeys)
         let bosEmb = try createBosEmbedding(constants.bosEmbedding)
         let seedValue = seed ?? UInt64.random(in: 0...UInt64.max)
         let chunkCount = chunks.count
@@ -421,6 +436,9 @@ public struct PocketTtsSynthesizer {
             stepModel: stepModel,
             flowModel: flowModel,
             mimiModel: mimiModel,
+            condLayerKeys: condLayerKeys,
+            flowlmLayerKeys: flowlmLayerKeys,
+            mimiKeys: mimiKeys,
             mimiInitialState: mimiInitialState,
             bosEmb: bosEmb,
             seedValue: seedValue,
@@ -458,8 +476,11 @@ public struct PocketTtsSynthesizer {
         let stepModel = try await store.flowlmStep()
         let flowModel = try await store.flowDecoder()
         let mimiModel = try await store.mimiDecoder()
+        let condLayerKeys = try await store.condStepLayerKeys()
+        let flowlmLayerKeys = try await store.flowLMStepLayerKeys()
+        let mimiKeys = try await store.mimiDecoderKeys()
         let repoDir = try await store.repoDir()
-        let mimiInitialState = try loadMimiInitialState(from: repoDir)
+        let mimiInitialState = try loadMimiInitialState(from: repoDir, mimiKeys: mimiKeys)
         let bosEmb = try createBosEmbedding(constants.bosEmbedding)
         let seedValue = seed ?? UInt64.random(in: 0...UInt64.max)
         let chunkCount = chunks.count
@@ -472,6 +493,9 @@ public struct PocketTtsSynthesizer {
             stepModel: stepModel,
             flowModel: flowModel,
             mimiModel: mimiModel,
+            condLayerKeys: condLayerKeys,
+            flowlmLayerKeys: flowlmLayerKeys,
+            mimiKeys: mimiKeys,
             mimiInitialState: mimiInitialState,
             bosEmb: bosEmb,
             seedValue: seedValue,
@@ -502,16 +526,31 @@ public struct PocketTtsSynthesizer {
         let stepModel = try await store.flowlmStep()
         let flowModel = try await store.flowDecoder()
         let mimiModel = try await store.mimiDecoder()
+        let condLayerKeys = try await store.condStepLayerKeys()
+        let flowlmLayerKeys = try await store.flowLMStepLayerKeys()
+        let mimiKeys = try await store.mimiDecoderKeys()
         let repoDir = try await store.repoDir()
-        let mimiState = try loadMimiInitialState(from: repoDir)
+        let mimiState = try loadMimiInitialState(from: repoDir, mimiKeys: mimiKeys)
         let bosEmb = try createBosEmbedding(constants.bosEmbedding)
         let seedValue = seed ?? UInt64.random(in: 0...UInt64.max)
 
-        // One-time voice prefill
-        let emptyState = try emptyKVCacheState()
-        let voiceKVSnapshot = try await prefillKVCacheVoice(
-            state: emptyState, voiceData: voiceData, model: condModel
-        )
+        // One-time voice prefill. Two paths matching `prefillKVCache`:
+        // - v2 packs (cacheSnapshot != nil): drop pre-baked K/V into cache,
+        //   skip cond_step entirely (`promptLength == 0`, so the loop in
+        //   `prefillKVCacheVoice` would be a no-op anyway).
+        // - Flat audio prompt (legacy English): feed every voice token
+        //   through cond_step.
+        let voiceKVSnapshot: KVCacheState
+        if let snapshot = voiceData.cacheSnapshot {
+            voiceKVSnapshot = try kvCacheStateFromSnapshot(
+                snapshot, layers: condLayerKeys.layerCount)
+        } else {
+            let emptyState = try emptyKVCacheState(layers: condLayerKeys.layerCount)
+            voiceKVSnapshot = try await prefillKVCacheVoice(
+                state: emptyState, voiceData: voiceData, model: condModel,
+                layerKeys: condLayerKeys
+            )
+        }
 
         logger.info(
             "Session voice prefill at position \(Int(voiceKVSnapshot.positions[0][0].floatValue))"
         )
@@ -525,6 +564,9 @@ public struct PocketTtsSynthesizer {
             stepModel: stepModel,
             flowModel: flowModel,
             mimiModel: mimiModel,
+            condLayerKeys: condLayerKeys,
+            flowlmLayerKeys: flowlmLayerKeys,
+            mimiKeys: mimiKeys,
             bosEmb: bosEmb,
             temperature: temperature,
             seed: seedValue
@@ -548,6 +590,9 @@ public struct PocketTtsSynthesizer {
         let stepModel: MLModel
         let flowModel: MLModel
         let mimiModel: MLModel
+        let condLayerKeys: PocketTtsLayerKeys
+        let flowlmLayerKeys: PocketTtsLayerKeys
+        let mimiKeys: PocketTtsMimiKeys
         var mimiState: MimiState
         let bosEmb: MLMultiArray
         var rng: SeededRNG
@@ -562,6 +607,9 @@ public struct PocketTtsSynthesizer {
             stepModel: MLModel,
             flowModel: MLModel,
             mimiModel: MLModel,
+            condLayerKeys: PocketTtsLayerKeys,
+            flowlmLayerKeys: PocketTtsLayerKeys,
+            mimiKeys: PocketTtsMimiKeys,
             mimiInitialState: MimiState,
             bosEmb: MLMultiArray,
             seedValue: UInt64,
@@ -575,6 +623,9 @@ public struct PocketTtsSynthesizer {
             self.stepModel = stepModel
             self.flowModel = flowModel
             self.mimiModel = mimiModel
+            self.condLayerKeys = condLayerKeys
+            self.flowlmLayerKeys = flowlmLayerKeys
+            self.mimiKeys = mimiKeys
             self.mimiState = mimiInitialState
             self.bosEmb = bosEmb
             self.rng = SeededRNG(seed: seedValue)
@@ -610,7 +661,8 @@ public struct PocketTtsSynthesizer {
             let result = try await PocketTtsSynthesizer.runMimiDecoder(
                 latent: latent,
                 state: &localState,
-                model: mimiModel
+                model: mimiModel,
+                mimiKeys: mimiKeys
             )
             mimiState = localState
             return result
@@ -626,7 +678,8 @@ public struct PocketTtsSynthesizer {
                 sequence: sequence,
                 bosEmb: bosEmb,
                 state: &localState,
-                model: stepModel
+                model: stepModel,
+                layerKeys: flowlmLayerKeys
             )
             kvState = localState
             return result
@@ -650,7 +703,8 @@ public struct PocketTtsSynthesizer {
             var kvState = try await PocketTtsSynthesizer.prefillKVCache(
                 voiceData: voiceData,
                 textEmbeddings: textEmbeddings,
-                model: condModel
+                model: condModel,
+                layerKeys: condLayerKeys
             )
 
             let maxGenLen = PocketTtsSynthesizer.estimateMaxFrames(text: chunkText)
@@ -996,11 +1050,16 @@ public struct PocketTtsSynthesizer {
     // MARK: - Embedding
 
     /// Look up text token embeddings from the embedding table.
+    ///
+    /// Vocab size is derived from the actual loaded table because v2 language
+    /// packs ship per-language `text_embed_table` with potentially different
+    /// row counts (`PocketTtsConstants.vocabSize` only matches the legacy
+    /// English pack).
     static func embedTokens(
         _ tokenIds: [Int],
         constants: PocketTtsConstantsBundle
     ) -> [[Float]] {
         let dim = PocketTtsConstants.embeddingDim
-        let vocabSize = PocketTtsConstants.vocabSize
+        let vocabSize = constants.textEmbedTable.count / dim
         return tokenIds.map { id in
             guard id >= 0, id < vocabSize else {
                 logger.warning("Token ID \(id) out of range [0, \(vocabSize)), clamping")
diff --git a/Sources/FluidAudio/TTS/PocketTTS/PocketTtsConstants.swift b/Sources/FluidAudio/TTS/PocketTTS/PocketTtsConstants.swift
index abb4ef0e7..ec625c2d9 100644
--- a/Sources/FluidAudio/TTS/PocketTTS/PocketTtsConstants.swift
+++ b/Sources/FluidAudio/TTS/PocketTTS/PocketTtsConstants.swift
@@ -37,8 +37,6 @@ public enum PocketTtsConstants {
 
     // MARK: - KV cache
 
-    /// Number of transformer layers, each with its own KV cache.
-    public static let kvCacheLayers: Int = 6
     /// Max KV cache positions: voice (~125) + text (≤50) + generated frames.
     public static let kvCacheMaxLen: Int = 512
@@ -52,3 +50,39 @@ public enum PocketTtsConstants {
     public static let defaultModelsSubdirectory: String = "Models"
 }
+
+/// Supported PocketTTS language packs (matches upstream
+/// `kyutai/pocket-tts/languages/<id>/` folder names exactly).
+///
+/// File layout on `FluidInference/pocket-tts-coreml`:
+/// - `english`: legacy root layout (`mimi_decoder_v2.mlmodelc` etc.)
+/// - other languages: `v2/<id>/` subtree with `mimi_decoder.mlmodelc`
+public enum PocketTtsLanguage: String, Sendable, CaseIterable {
+    case english
+    case french24L = "french_24l"
+    case german
+    case german24L = "german_24l"
+    case italian
+    case italian24L = "italian_24l"
+    case portuguese
+    case portuguese24L = "portuguese_24l"
+    case spanish
+    case spanish24L = "spanish_24l"
+
+    /// Number of transformer layers in this language pack (6 or 24).
+    public var transformerLayers: Int {
+        switch self {
+        case .english, .german, .italian, .portuguese, .spanish:
+            return 6
+        case .french24L, .german24L, .italian24L, .portuguese24L, .spanish24L:
+            return 24
+        }
+    }
+
+    /// HF subdirectory under the pocket-tts-coreml repo root.
+    /// English returns `nil` to preserve the legacy root-level layout
+    /// (avoids forcing existing caches to re-download).
+    public var repoSubdirectory: String? {
+        self == .english ? nil : "v2/\(rawValue)"
+    }
+}
diff --git a/Sources/FluidAudio/TTS/PocketTTS/PocketTtsManager.swift b/Sources/FluidAudio/TTS/PocketTTS/PocketTtsManager.swift
index 0fcda7239..b0184a0b1 100644
--- a/Sources/FluidAudio/TTS/PocketTTS/PocketTtsManager.swift
+++ b/Sources/FluidAudio/TTS/PocketTTS/PocketTtsManager.swift
@@ -20,18 +20,26 @@ public actor PocketTtsManager {
     private var defaultVoice: String
     private var isInitialized = false
 
+    /// The language pack this manager loads. Immutable for the lifetime of the
+    /// manager — to switch languages, create a new `PocketTtsManager`.
+    public nonisolated let language: PocketTtsLanguage
+
     /// Creates a new PocketTTS manager.
     ///
     /// - Parameters:
     ///   - defaultVoice: Default voice identifier (default: "alba").
+    ///   - language: Which upstream language pack to load. Defaults to
+    ///     `.english` for backward compatibility.
    ///   - directory: Optional override for the base cache directory.
    ///     When `nil`, uses the default platform cache location.
     public init(
         defaultVoice: String = PocketTtsConstants.defaultVoice,
+        language: PocketTtsLanguage = .english,
         directory: URL? = nil
     ) {
-        self.modelStore = PocketTtsModelStore(directory: directory)
+        self.modelStore = PocketTtsModelStore(language: language, directory: directory)
         self.defaultVoice = defaultVoice
+        self.language = language
     }
 
     public var isAvailable: Bool {
diff --git a/Sources/FluidAudioCLI/Commands/TTSCommand.swift b/Sources/FluidAudioCLI/Commands/TTSCommand.swift
index 2037a94d2..5e6845f90 100644
--- a/Sources/FluidAudioCLI/Commands/TTSCommand.swift
+++ b/Sources/FluidAudioCLI/Commands/TTSCommand.swift
@@ -146,6 +146,7 @@ public struct TTS {
         var cloneVoicePath: String? = nil
         var voiceFilePath: String? = nil
         var saveVoicePath: String? = nil
+        var pocketLanguage: PocketTtsLanguage = .english
 
         var i = 0
         while i < arguments.count {
@@ -227,6 +228,21 @@ public struct TTS {
                     saveVoicePath = arguments[i + 1]
                     i += 1
                 }
+            case "--language":
+                if i + 1 < arguments.count {
+                    let raw = arguments[i + 1]
+                    if let parsed = PocketTtsLanguage(rawValue: raw) {
+                        pocketLanguage = parsed
+                    } else {
+                        let supported = PocketTtsLanguage.allCases
+                            .map { $0.rawValue }
+                            .joined(separator: ", ")
+                        logger.warning(
+                            "Unknown PocketTTS language '\(raw)'. Supported: \(supported). Falling back to english."
+                        )
+                    }
+                    i += 1
+                }
             default:
                 if text == nil {
                     text = argument
@@ -258,7 +274,8 @@ public struct TTS {
             await runPocketTts(
                 text: text, output: output, voice: voice, deEss: deEss,
                 metricsPath: metricsPath, cloneVoicePath: cloneVoicePath,
-                voiceFilePath: voiceFilePath, saveVoicePath: saveVoicePath)
+                voiceFilePath: voiceFilePath, saveVoicePath: saveVoicePath,
+                language: pocketLanguage)
             return
         }
@@ -501,14 +518,17 @@ public struct TTS {
     private static func runPocketTts(
         text: String, output: String, voice: String, deEss: Bool,
         metricsPath: String?, cloneVoicePath: String?,
-        voiceFilePath: String?, saveVoicePath: String?
+        voiceFilePath: String?, saveVoicePath: String?,
+        language: PocketTtsLanguage
     ) async {
         do {
             let tStart = Date()
             let pocketVoice =
                 voice == TtsConstants.recommendedVoice ? PocketTtsConstants.defaultVoice : voice
-            let manager = PocketTtsManager(defaultVoice: pocketVoice)
+            let manager = PocketTtsManager(
+                defaultVoice: pocketVoice, language: language)
+            logger.info("PocketTTS language: \(language.rawValue)")
 
             let tLoad0 = Date()
             try await manager.initialize()
@@ -665,6 +685,13 @@ public struct TTS {
               --voice-file FILE     Load previously saved voice .bin file
               --save-voice FILE     Save cloned voice to .bin file for later use
 
+            PocketTTS Language Packs:
+              --language ID         Language pack (default: english)
+                                    Supported: english, french_24l,
+                                    german, german_24l, italian, italian_24l,
+                                    portuguese, portuguese_24l, spanish, spanish_24l
+                                    Note: French is 24-layer only (no 6-layer pack upstream)
+
             Lexicon file format:
               # Comments start with #
               kokoro=kəkˈɔɹO
diff --git a/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsLanguageTests.swift b/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsLanguageTests.swift
new file mode 100644
index 000000000..141e0abec
--- /dev/null
+++ b/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsLanguageTests.swift
@@ -0,0 +1,116 @@
+import Foundation
+import XCTest
+
+@testable import FluidAudio
+
+/// Pure-logic unit tests for PocketTTS multi-language plumbing.
+///
+/// These tests exercise the path/filename/layer-count derivation that drives
+/// HuggingFace downloads and CoreML model selection. They do not require any
+/// model files or network access.
+final class PocketTtsLanguageTests: XCTestCase {
+
+    // MARK: - PocketTtsLanguage.repoSubdirectory
+
+    func testEnglishHasNoRepoSubdirectory() {
+        // English keeps the legacy root-level layout so existing caches stay
+        // valid without re-downloading into a `v2/english/` folder.
+        XCTAssertNil(PocketTtsLanguage.english.repoSubdirectory)
+    }
+
+    func testNonEnglishLanguagesUseV2Subdirectory() {
+        let expected: [(PocketTtsLanguage, String)] = [
+            (.french24L, "v2/french_24l"),
+            (.german, "v2/german"),
+            (.german24L, "v2/german_24l"),
+            (.italian, "v2/italian"),
+            (.italian24L, "v2/italian_24l"),
+            (.portuguese, "v2/portuguese"),
+            (.portuguese24L, "v2/portuguese_24l"),
+            (.spanish, "v2/spanish"),
+            (.spanish24L, "v2/spanish_24l"),
+        ]
+        for (lang, path) in expected {
+            XCTAssertEqual(
+                lang.repoSubdirectory, path,
+                "Unexpected repoSubdirectory for \(lang.rawValue)")
+        }
+    }
+
+    func testAllNonEnglishLanguagesAreCovered() {
+        // Guard against silent additions to the enum that forget to update
+        // the v2/<id> mapping above.
+        let nonEnglish = PocketTtsLanguage.allCases.filter { $0 != .english }
+        XCTAssertEqual(nonEnglish.count, 9)
+        for lang in nonEnglish {
+            XCTAssertEqual(
+                lang.repoSubdirectory, "v2/\(lang.rawValue)",
+                "Language \(lang.rawValue) does not follow the v2/<id> convention")
+        }
+    }
+
+    // MARK: - PocketTtsLanguage.transformerLayers
+
+    func testTransformerLayerCounts() {
+        // 6L variants (English plus 4 base non-English packs)
+        let sixLayer: [PocketTtsLanguage] = [
+            .english, .german, .italian, .portuguese, .spanish,
+        ]
+        for lang in sixLayer {
+            XCTAssertEqual(
+                lang.transformerLayers, 6,
+                "\(lang.rawValue) should be a 6-layer pack")
+        }
+
+        // 24L variants (note: French ships only the 24L variant upstream)
+        let twentyFourLayer: [PocketTtsLanguage] = [
+            .french24L, .german24L, .italian24L, .portuguese24L, .spanish24L,
+        ]
+        for lang in twentyFourLayer {
+            XCTAssertEqual(
+                lang.transformerLayers, 24,
+                "\(lang.rawValue) should be a 24-layer pack")
+        }
+    }
+
+    // MARK: - ModelNames.PocketTTS.requiredModels(for:)
+
+    func testEnglishRequiredModelsUsesLegacyMimi() {
+        // English keeps `mimi_decoder_v2.mlmodelc` for backward-compat with
+        // the original repo layout on HuggingFace.
+        let models = ModelNames.PocketTTS.requiredModels(for: .english)
+        XCTAssertTrue(models.contains(ModelNames.PocketTTS.mimiDecoderLegacyFile))
+        XCTAssertFalse(models.contains(ModelNames.PocketTTS.mimiDecoderV2File))
+        XCTAssertTrue(models.contains(ModelNames.PocketTTS.condStepFile))
+        XCTAssertTrue(models.contains(ModelNames.PocketTTS.flowlmStepFile))
+        XCTAssertTrue(models.contains(ModelNames.PocketTTS.flowDecoderFile))
+        XCTAssertTrue(models.contains(ModelNames.PocketTTS.constantsBinDir))
+    }
+
+    func testNonEnglishRequiredModelsUsesNewMimi() {
+        // Other languages ship `mimi_decoder.mlmodelc` (not the legacy `_v2`).
+        for lang in PocketTtsLanguage.allCases where lang != .english {
+            let models = ModelNames.PocketTTS.requiredModels(for: lang)
+            XCTAssertTrue(
+                models.contains(ModelNames.PocketTTS.mimiDecoderV2File),
+                "\(lang.rawValue) should require mimi_decoder.mlmodelc")
+            XCTAssertFalse(
+                models.contains(ModelNames.PocketTTS.mimiDecoderLegacyFile),
+                "\(lang.rawValue) should NOT require legacy mimi_decoder_v2.mlmodelc")
+        }
+    }
+
+    // MARK: - ModelNames.PocketTTS.mimiDecoderFile(for:)
+
+    func testMimiDecoderFilenameDispatch() {
+        XCTAssertEqual(
+            ModelNames.PocketTTS.mimiDecoderFile(for: .english),
+            ModelNames.PocketTTS.mimiDecoderLegacyFile)
+        for lang in PocketTtsLanguage.allCases where lang != .english {
+            XCTAssertEqual(
+                ModelNames.PocketTTS.mimiDecoderFile(for: lang),
+                ModelNames.PocketTTS.mimiDecoderV2File,
+                "\(lang.rawValue) should use mimi_decoder.mlmodelc")
+        }
+    }
+}
diff --git a/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsSessionTests.swift b/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsSessionTests.swift
index 91a208351..78625bbc6 100644
--- a/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsSessionTests.swift
+++ b/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsSessionTests.swift
@@ -45,7 +45,8 @@ final class PocketTtsSessionTests: XCTestCase {
 
     // MARK: - KV Cache Clone Tests
 
     func testCloneKVCacheStateProducesIndependentCopy() throws {
-        let original = try PocketTtsSynthesizer.emptyKVCacheState()
+        let original = try PocketTtsSynthesizer.emptyKVCacheState(
+            layers: PocketTtsLanguage.english.transformerLayers)
 
         // Write a known value into the original
         let ptr = original.caches[0].dataPointer.bindMemory(to: Float.self, capacity: 1)
@@ -70,7 +71,8 @@ final class PocketTtsSessionTests: XCTestCase {
     }
 
     func testCloneKVCacheStatePreservesShape() throws {
-        let original = try PocketTtsSynthesizer.emptyKVCacheState()
+        let original = try PocketTtsSynthesizer.emptyKVCacheState(
+            layers: PocketTtsLanguage.english.transformerLayers)
 
         let clone = try PocketTtsSynthesizer.cloneKVCacheState(original)
 
         XCTAssertEqual(clone.caches.count, original.caches.count)
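The language-pack plumbing in this patch reduces to two pure derivations: layer count (6 vs 24) and repo subpath (`nil` for English, `v2/<id>` otherwise). A standalone sketch of those rules, using a hypothetical stand-in enum rather than the library's `PocketTtsLanguage` — note it collapses the explicit `switch` in the real code into a `_24l`-suffix check, which happens to coincide for the current ten packs:

```swift
// Hypothetical stand-in mirroring the rules from the diff: ids ending in
// "_24l" are 24-layer packs, and only English keeps the legacy root layout.
enum LanguagePackSketch: String, CaseIterable {
    case english, german, german_24l, italian, italian_24l
    case portuguese, portuguese_24l, spanish, spanish_24l, french_24l

    // 24-layer variants all carry the "_24l" suffix in their raw id.
    var transformerLayers: Int { rawValue.hasSuffix("_24l") ? 24 : 6 }

    // English stays at the repo root; everything else lives under v2/<id>.
    var repoSubdirectory: String? { self == .english ? nil : "v2/\(rawValue)" }
}
```

The real enum spells the mapping out with an explicit `switch` (which the new unit tests lock down case by case); the suffix rule here is just a compact way to see both derivations at once.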