diff --git a/Package.swift b/Package.swift index 38737110f..a34912889 100644 --- a/Package.swift +++ b/Package.swift @@ -55,6 +55,7 @@ let package = Package( name: "FluidAudioTests", dependencies: [ "FluidAudio", + "FluidAudioCLI", ] ), ], diff --git a/Sources/FluidAudio/Diarizer/DiarizationDER.swift b/Sources/FluidAudio/Diarizer/DiarizationDER.swift new file mode 100644 index 000000000..89c471081 --- /dev/null +++ b/Sources/FluidAudio/Diarizer/DiarizationDER.swift @@ -0,0 +1,292 @@ +// +// DiarizationDER.swift +// LS-EEND-Test +// +// Frame-wise Diarization Error Rate with optimal (Hungarian) speaker +// mapping. Matches the NIST md-eval / pyannote DER definition: +// +// DER = (miss + false_alarm + confusion) / total_ref_speech +// +// where per 10 ms frame t with Nref speakers in the reference and Nsys +// speakers in the hypothesis: +// miss = max(0, Nref − Nsys) +// false_alarm= max(0, Nsys − Nref) +// confusion = min(Nref, Nsys) − Ncorrect_t +// and Ncorrect_t counts hyp speakers whose globally-mapped label is also +// active in the reference at t. The global mapping is the one-to-one +// hyp→ref assignment that maximises total overlap time (Hungarian). +// +// Collar support: `collar` is the full width around each reference +// speaker-change boundary that is excluded from scoring. A frame whose +// midpoint lies within ±collar/2 of any boundary in the reference is +// dropped from every accumulator (matches pyannote.metrics). + +import Foundation + +public struct DERSpeakerSegment: Sendable, Hashable { + public let speaker: String + public let start: Double + public let end: Double + public init(speaker: String, start: Double, end: Double) { + self.speaker = speaker + self.start = start + self.end = end + } +} + +public struct DERResult: Sendable { + public let der: Double + public let confusion: Double + public let falseAlarm: Double + public let miss: Double + public let totalRefSpeech: Double + /// Flat `hypLabel → refLabel` mapping found by Hungarian. 
Hyp labels + /// that drew no ref partner are omitted. + public let mapping: [String: String] +} + +public enum DiarizationDER { + + /// Compute frame-wise DER. `frameStep` is the analysis grid; segments + /// are discretised by midpoint test. + public static func compute( + ref: [DERSpeakerSegment], + hyp: [DERSpeakerSegment], + frameStep: Double = 0.01, + collar: Double = 0 + ) -> DERResult { + precondition(frameStep > 0) + precondition(collar >= 0) + + // Discover label sets + total duration. + var refLabels: [String] = [] + var hypLabels: [String] = [] + var refIdx: [String: Int] = [:] + var hypIdx: [String: Int] = [:] + var maxEnd: Double = 0 + for s in ref { + if refIdx[s.speaker] == nil { + refIdx[s.speaker] = refLabels.count + refLabels.append(s.speaker) + } + maxEnd = max(maxEnd, s.end) + } + for s in hyp { + if hypIdx[s.speaker] == nil { + hypIdx[s.speaker] = hypLabels.count + hypLabels.append(s.speaker) + } + maxEnd = max(maxEnd, s.end) + } + let numFrames = Int(ceil(maxEnd / frameStep)) + 1 + if numFrames <= 0 || (refLabels.isEmpty && hypLabels.isEmpty) { + return DERResult( + der: 0, confusion: 0, falseAlarm: 0, + miss: 0, totalRefSpeech: 0, mapping: [:]) + } + + // Rasterise each label to a BitSet per frame. + let refMask = rasterise( + ref, labelIdx: refIdx, numLabels: refLabels.count, + numFrames: numFrames, frameStep: frameStep) + let hypMask = rasterise( + hyp, labelIdx: hypIdx, numLabels: hypLabels.count, + numFrames: numFrames, frameStep: frameStep) + + // Overlap matrix O[h][r] = #frames both active. + let H = hypLabels.count + let R = refLabels.count + var overlap = [Int](repeating: 0, count: H * R) + if H > 0 && R > 0 { + for t in 0.. 0 { + let maxO = overlap.max() ?? 0 + var cost = [Int](repeating: maxO, count: n * n) + for h in 0.. 0 { + mapping[h] = r + } + } + } + + // Build collar mask — frame is scorable iff its midpoint is not + // within collar/2 of any reference speaker-change boundary. 
+ let scorable = collarMask( + ref: ref, numFrames: numFrames, frameStep: frameStep, collar: collar) + + // Frame-wise error accumulation under the global mapping. + var sumMiss = 0 + var sumFA = 0 + var sumConf = 0 + var sumRef = 0 + for t in 0..= 0 && refMask[rRow + rm] { nCorrect += 1 } + } + sumMiss += max(0, nRef - nSys) + sumFA += max(0, nSys - nRef) + sumConf += min(nRef, nSys) - nCorrect + sumRef += nRef + } + let missS = Double(sumMiss) * frameStep + let faS = Double(sumFA) * frameStep + let confS = Double(sumConf) * frameStep + let refS = Double(sumRef) * frameStep + let der = refS > 0 ? (missS + faS + confS) / refS : 0 + + var mapOut: [String: String] = [:] + for h in 0..= 0 { mapOut[hypLabels[h]] = refLabels[r] } + } + return DERResult( + der: der, confusion: confS, falseAlarm: faS, miss: missS, + totalRefSpeech: refS, mapping: mapOut + ) + } + + /// Scorability mask: `scorable[t] == false` ⇒ frame is inside a + /// collar around some reference boundary and must be excluded from + /// every accumulator (incl. the `totalRefSpeech` denominator). + /// `collar == 0` returns an all-true mask. + private static func collarMask( + ref: [DERSpeakerSegment], + numFrames: Int, + frameStep: Double, + collar: Double + ) -> [Bool] { + var mask = [Bool](repeating: true, count: numFrames) + if collar <= 0 { return mask } + let half = collar / 2.0 + // Every segment endpoint is a boundary; start + end contribute. + var boundaries: [Double] = [] + boundaries.reserveCapacity(ref.count * 2) + for s in ref where s.end > s.start { + boundaries.append(s.start) + boundaries.append(s.end) + } + for b in boundaries { + let lo = max(0, Int(floor((b - half) / frameStep))) + let hi = min(numFrames, Int(ceil((b + half) / frameStep))) + if hi <= lo { continue } + for t in lo.. 
[Bool] { + var mask = [Bool](repeating: false, count: numFrames * numLabels) + if numLabels == 0 { return mask } + for seg in segs { + guard let li = labelIdx[seg.speaker], seg.end > seg.start else { continue } + // midpoint-test range: smallest t with (t+0.5)*step >= start, + // largest with (t+0.5)*step < end. + let tStart = max(0, Int(ceil(seg.start / frameStep - 0.5))) + let tEndEx = min(numFrames, Int(ceil(seg.end / frameStep - 0.5))) + if tEndEx <= tStart { continue } + for t in tStart.. [Int] { + if n == 0 { return [] } + // Classic Jonker-Volgenant style implementation adapted for + // square n×n. 1-based indexing in arrays of size n+1 for the + // canonical algorithm — cost is read as `cost[(i-1)*n + (j-1)]`. + let INF = Int.max / 4 + var u = [Int](repeating: 0, count: n + 1) + var v = [Int](repeating: 0, count: n + 1) + var p = [Int](repeating: 0, count: n + 1) + var way = [Int](repeating: 0, count: n + 1) + + for i in 1...n { + p[0] = i + var j0 = 0 + var minv = [Int](repeating: INF, count: n + 1) + var used = [Bool](repeating: false, count: n + 1) + repeat { + used[j0] = true + let i0 = p[j0] + var delta = INF + var j1 = 0 + for j in 1...n where !used[j] { + let cur = cost[(i0 - 1) * n + (j - 1)] - u[i0] - v[j] + if cur < minv[j] { + minv[j] = cur + way[j] = j0 + } + if minv[j] < delta { + delta = minv[j] + j1 = j + } + } + for j in 0...n { + if used[j] { + u[p[j]] += delta + v[j] -= delta + } else { + minv[j] -= delta + } + } + j0 = j1 + } while p[j0] != 0 + repeat { + let j1 = way[j0] + p[j0] = p[j1] + j0 = j1 + } while j0 != 0 + } + var assign = [Int](repeating: -1, count: n) + for j in 1...n { + if p[j] != 0 { assign[p[j] - 1] = j - 1 } + } + return assign + } +} diff --git a/Sources/FluidAudio/Diarizer/DiarizerProtocol.swift b/Sources/FluidAudio/Diarizer/DiarizerProtocol.swift new file mode 100644 index 000000000..e4176e110 --- /dev/null +++ b/Sources/FluidAudio/Diarizer/DiarizerProtocol.swift @@ -0,0 +1,112 @@ +import Foundation + +// MARK: - 
Diarizer Protocol + +/// Protocol for frame-based end-to-end neural diarization pipelines. +public protocol Diarizer: AnyObject { + /// Whether the processor is initialized and ready + var isAvailable: Bool { get } + + /// Number of confirmed frames processed so far + var numFramesProcessed: Int { get } + + /// Model's target sample rate in Hz + var targetSampleRate: Int? { get } + + /// Output frame rate in Hz + var modelFrameHz: Double? { get } + + /// Number of real speaker output tracks + var numSpeakers: Int? { get } + + /// Diarization timeline + var timeline: DiarizerTimeline { get } + + // MARK: Streaming + + /// Add audio samples to the processing buffer. + /// + /// Implementations may resample the input when `sourceSampleRate` differs from + /// the model's target sample rate. + /// + /// - Parameters: + /// - samples: Mono audio samples to enqueue for diarization. + /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the target rate. + func addAudio(_ samples: C, sourceSampleRate: Double?) throws + where C.Element == Float + + /// Process buffered audio and return any newly available diarization output. + func process() throws -> DiarizerTimelineUpdate? + + /// Add audio and process it in one call. + /// + /// - Parameters: + /// - samples: Mono audio samples to process. + /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the target rate. + /// - Returns: A timeline update containing finalized and tentative output, or `nil` + /// if not enough buffered audio was available to emit frames. + func process(samples: C, sourceSampleRate: Double?) throws -> DiarizerTimelineUpdate? + where C.Element == Float + + // MARK: Offline + + /// Process a complete audio buffer and return the resulting timeline. + /// + /// - Parameters: + /// - samples: Complete mono audio buffer to diarize. + /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the target rate. 
+ /// - keepSpeakers: Whether to keep pre-enrolled speakers. If `nil`, it will keep the speakers if no audio has been added. + /// - finalizeOnCompletion: Whether to finalize the timeline before returning it. + /// - progressCallback: Optional callback receiving `(processedSamples, totalSamples, chunksProcessed)`. + /// - Returns: The diarization timeline for the provided audio. + func processComplete( + _ samples: C, + sourceSampleRate: Double?, + keepingEnrolledSpeakers keepSpeakers: Bool?, + finalizeOnCompletion: Bool, + progressCallback: ((Int, Int, Int) -> Void)? + ) throws -> DiarizerTimeline + where C.Element == Float + + /// Process a complete audio file from a URL. + /// + /// Reads and resamples the file to ``targetSampleRate``, then delegates to + /// ``processComplete(_:finalizeOnCompletion:progressCallback:)``. + /// + /// - Parameters: + /// - audioFileURL: Path to a WAV, CAF, or other audio file. + /// - keepSpeakers: Whether to keep pre-enrolled speakers. + /// - finalizeOnCompletion: Whether to finalize the timeline after processing + /// - progressCallback: Optional callback (processedSamples, totalSamples, chunksProcessed). + /// - Returns: Finalized timeline with segments. + func processComplete( + audioFileURL: URL, + keepingEnrolledSpeakers keepSpeakers: Bool?, + finalizeOnCompletion: Bool, + progressCallback: ((Int, Int, Int) -> Void)? + ) throws -> DiarizerTimeline + + // MARK: Lifecycle + + /// Reset streaming state while keeping model loaded + func reset() + + /// Clean up all resources + func cleanup() + + /// Pre-enroll a speaker before running the diarizer. + /// + /// - Paramters: + /// - samples: Enrollment audio samples. + /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the target rate. + /// - name: The speaker's name. + /// - overwriteAssignedSpeakerName: Whether enrollment may overwrite the name of an already-named slot + /// if the diarizer assigns the audio to that speaker. 
+ /// - Returns: The enrolled speaker. + func enrollSpeaker( + withAudio samples: C, + sourceSampleRate: Double?, + named name: String?, + overwritingAssignedSpeakerName overwriteAssignedSpeakerName: Bool + ) throws -> DiarizerSpeaker? where C.Element == Float +} diff --git a/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift b/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift index b52fa938e..320adbe6c 100644 --- a/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift +++ b/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift @@ -1,120 +1,6 @@ import Foundation -// MARK: - Diarizer Protocol - -/// Protocol for frame-based speaker diarization processors. -/// -/// Both SortformerDiarizer and LS-EEND processors conform to this protocol, -/// providing a unified streaming and offline diarization API. -public protocol Diarizer: AnyObject { - /// Whether the processor is initialized and ready - var isAvailable: Bool { get } - - /// Number of confirmed frames processed so far - var numFramesProcessed: Int { get } - - /// Model's target sample rate in Hz - var targetSampleRate: Int? { get } - - /// Output frame rate in Hz - var modelFrameHz: Double? { get } - - /// Number of real speaker output tracks - var numSpeakers: Int? { get } - - /// Diarization timeline - var timeline: DiarizerTimeline { get } - - // MARK: Streaming - - /// Add audio samples to the processing buffer. - /// - /// Implementations may resample the input when `sourceSampleRate` differs from - /// the model's target sample rate. - /// - /// - Parameters: - /// - samples: Mono audio samples to enqueue for diarization. - /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the target rate. - func addAudio(_ samples: C, sourceSampleRate: Double?) throws - where C.Element == Float - - /// Process buffered audio and return any newly available diarization output. - func process() throws -> DiarizerTimelineUpdate? - - /// Add audio and process it in one call. 
- /// - /// - Parameters: - /// - samples: Mono audio samples to process. - /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the target rate. - /// - Returns: A timeline update containing finalized and tentative output, or `nil` - /// if not enough buffered audio was available to emit frames. - func process(samples: C, sourceSampleRate: Double?) throws -> DiarizerTimelineUpdate? - where C.Element == Float - - // MARK: Offline - - /// Process a complete audio buffer and return the resulting timeline. - /// - /// - Parameters: - /// - samples: Complete mono audio buffer to diarize. - /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the target rate. - /// - keepSpeakers: Whether to keep pre-enrolled speakers. If `nil`, it will keep the speakers if no audio has been added. - /// - finalizeOnCompletion: Whether to finalize the timeline before returning it. - /// - progressCallback: Optional callback receiving `(processedSamples, totalSamples, chunksProcessed)`. - /// - Returns: The diarization timeline for the provided audio. - func processComplete( - _ samples: C, - sourceSampleRate: Double?, - keepingEnrolledSpeakers keepSpeakers: Bool?, - finalizeOnCompletion: Bool, - progressCallback: ((Int, Int, Int) -> Void)? - ) throws -> DiarizerTimeline - where C.Element == Float - - /// Process a complete audio file from a URL. - /// - /// Reads and resamples the file to ``targetSampleRate``, then delegates to - /// ``processComplete(_:finalizeOnCompletion:progressCallback:)``. - /// - /// - Parameters: - /// - audioFileURL: Path to a WAV, CAF, or other audio file. - /// - keepSpeakers: Whether to keep pre-enrolled speakers. - /// - finalizeOnCompletion: Whether to finalize the timeline after processing - /// - progressCallback: Optional callback (processedSamples, totalSamples, chunksProcessed). - /// - Returns: Finalized timeline with segments. 
- func processComplete( - audioFileURL: URL, - keepingEnrolledSpeakers keepSpeakers: Bool?, - finalizeOnCompletion: Bool, - progressCallback: ((Int, Int, Int) -> Void)? - ) throws -> DiarizerTimeline - - // MARK: Lifecycle - - /// Reset streaming state while keeping model loaded - func reset() - - /// Clean up all resources - func cleanup() - - /// Pre-enroll a speaker before running the diarizer. - /// - /// - Parameters: - /// - samples: Enrollment audio samples. - /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the target rate. - /// - name: The speaker's name. - /// - overwriteAssignedSpeakerName: Whether enrollment may overwrite the name of an already-named slot - /// if the diarizer assigns the audio to that speaker. - /// - Returns: The enrolled speaker. - func enrollSpeaker( - withAudio samples: C, - sourceSampleRate: Double?, - named name: String?, - overwritingAssignedSpeakerName overwriteAssignedSpeakerName: Bool - ) throws -> DiarizerSpeaker? where C.Element == Float -} - -// MARK: - Post-Processing Configuration +// MARK: - Timeline Configuration /// Configuration for post-processing diarizer predictions into segments. /// @@ -268,116 +154,171 @@ public struct DiarizerTimelineConfig: Sendable { } } -// MARK: - Speaker +// MARK: - Chunk Result -public final class DiarizerSpeaker: Identifiable, CustomStringConvertible { - /// Speaker ID - public let id: UUID +/// Result from a single streaming diarization step (works with any diarizer). +public struct DiarizerChunkResult: Sendable { + /// Speaker probabilities for finalized frames. + /// Flat array of shape [frameCount, numSpeakers]. + public let finalizedPredictions: [Float] - /// Speaker's string representation - public var description: String { - queue.sync { _name ?? 
"Speaker \(_index)" } + /// Number of finalized frames in this result + public let finalizedFrameCount: Int + + /// Frame index of the first confirmed frame + public let startFrame: Int + + /// Tentative/preview predictions (may change with future data). + /// Flat array of shape [tentativeFrameCount, numSpeakers]. + public let tentativePredictions: [Float] + + /// Number of tentative frames + public let tentativeFrameCount: Int + + /// Frame index of first tentative frame + public var tentativeStartFrame: Int { startFrame + finalizedFrameCount } + + public init( + startFrame: Int = 0, + finalizedPredictions: [Float], + finalizedFrameCount: Int, + tentativePredictions: [Float] = [], + tentativeFrameCount: Int = 0 + ) { + self.startFrame = startFrame + self.finalizedPredictions = finalizedPredictions + self.finalizedFrameCount = finalizedFrameCount + self.tentativePredictions = tentativePredictions + self.tentativeFrameCount = tentativeFrameCount } - /// Display name - public var name: String? 
{ - get { queue.sync { _name } } - set { queue.sync(flags: .barrier) { _name = newValue } } + /// Get probability for a specific speaker at a confirmed frame + public func probability(speaker: Int, frame: Int, numSpeakers: Int) -> Float { + guard frame < finalizedFrameCount, speaker < numSpeakers else { return 0 } + return finalizedPredictions[frame * numSpeakers + speaker] } - /// Slot in the diarizer predictions - public var index: Int { - get { queue.sync { _index } } - set { queue.sync(flags: .barrier) { _index = newValue } } + /// Get probability for a specific speaker at a tentative frame + public func tentativeProbability(speaker: Int, frame: Int, numSpeakers: Int) -> Float { + guard frame < tentativeFrameCount, speaker < numSpeakers else { return 0 } + return tentativePredictions[frame * numSpeakers + speaker] } +} + +// MARK: - Speaker - /// Confirmed/finalized speech segments that belong to this speaker - public var finalizedSegments: [DiarizerSegment] { - get { queue.sync { _finalizedSegments } } - set { queue.sync(flags: .barrier) { _finalizedSegments = newValue } } +public class DiarizerSpeaker: Identifiable { + public struct Snapshot { + public let name: String? + public let index: Int + public let finalizedSegments: [DiarizerSegment] + public let tentativeSegments: [DiarizerSegment] } - /// Tentative speech segments that belong to this speaker - public var tentativeSegments: [DiarizerSegment] { - get { queue.sync { _tentativeSegments } } - set { queue.sync(flags: .barrier) { _tentativeSegments = newValue } } + /// Serializes mutation of this speaker's segment arrays, index, and name. + private let lock = NSLock() + + /// Speaker ID + public let id: UUID + + /// Display name + public var name: String? 
+ + /// Diarizer output slot + public var index: Int + + /// Finalized speech segments + public var finalizedSegments: [DiarizerSegment] = [] + + /// Preview speech segments + public var tentativeSegments: [DiarizerSegment] = [] + + /// Speaker's string representation + public var description: String { + lock.lock() + defer { lock.unlock() } + return name ?? "Speaker \(index)" } /// Whether this speaker has any segments public var hasSegments: Bool { - queue.sync { !(_finalizedSegments.isEmpty && _tentativeSegments.isEmpty) } + lock.lock() + defer { lock.unlock() } + return !(finalizedSegments.isEmpty && tentativeSegments.isEmpty) } /// Number of segments (finalized + tentative) public var segmentCount: Int { - queue.sync { _finalizedSegments.count + _tentativeSegments.count } + lock.lock() + defer { lock.unlock() } + return finalizedSegments.count + tentativeSegments.count } /// Number of confirmed segments public var finalizedSegmentCount: Int { - queue.sync { _finalizedSegments.count } + lock.lock() + defer { lock.unlock() } + return finalizedSegments.count } /// Number of tentative segments public var tentativeSegmentCount: Int { - queue.sync { _tentativeSegments.count } + lock.lock() + defer { lock.unlock() } + return tentativeSegments.count } /// Last segment (tentative or finalized). Checks tentative segments first, falls back to finalized if none found. public var lastSegment: DiarizerSegment? { - queue.sync { _tentativeSegments.last ?? _finalizedSegments.last } + lock.lock() + defer { lock.unlock() } + return tentativeSegments.last ?? 
finalizedSegments.last } /// Total duration of segments in seconds (finalized + tentative) public var speechDuration: Float { - queue.sync { - return - (_finalizedSegments.reduce(0.0) { $0 + $1.duration } - + _tentativeSegments.reduce(0.0) { $0 + $1.duration }) - } + lock.lock() + defer { lock.unlock() } + return finalizedSegments.reduce(0.0) { $0 + $1.duration } + + tentativeSegments.reduce(0.0) { $0 + $1.duration } } /// Duration of all finalized segments in seconds public var finalizedSpeechDuration: Float { - queue.sync { - _finalizedSegments.reduce(0.0) { $0 + $1.duration } - } + lock.lock() + defer { lock.unlock() } + return finalizedSegments.reduce(0.0) { $0 + $1.duration } } /// Duration of all tentative segments in seconds public var tentativeSpeechDuration: Float { - queue.sync { - _tentativeSegments.reduce(0.0) { $0 + $1.duration } - } + lock.lock() + defer { lock.unlock() } + return tentativeSegments.reduce(0.0) { $0 + $1.duration } } /// Total number of frames spanned by all segments (finalized + tentative) public var numSpeechFrames: Int { - queue.sync { - return (_finalizedSegments.reduce(0) { $0 + $1.length } + _tentativeSegments.reduce(0) { $0 + $1.length }) - } + lock.lock() + defer { lock.unlock() } + return finalizedSegments.reduce(0) { $0 + $1.length } + + tentativeSegments.reduce(0) { $0 + $1.length } } /// Number of frames in all finalized segments public var numFinalizedSpeechFrames: Int { - queue.sync { - _finalizedSegments.reduce(0) { $0 + $1.length } - } + lock.lock() + defer { lock.unlock() } + return finalizedSegments.reduce(0) { $0 + $1.length } } /// Number of frames in all tentative segments public var numTentativeSpeechFrames: Int { - queue.sync { - _tentativeSegments.reduce(0) { $0 + $1.length } - } + lock.lock() + defer { lock.unlock() } + return tentativeSegments.reduce(0) { $0 + $1.length } } - private var _name: String? 
- private var _index: Int - private var _finalizedSegments: [DiarizerSegment] = [] - private var _tentativeSegments: [DiarizerSegment] = [] - private let queue = DispatchQueue(label: "FluidAudio.Diarization.DiarizerSpeaker") - /// - Parameters: /// - id: Speaker UUID /// - index: Index in diarizer output @@ -388,63 +329,103 @@ public final class DiarizerSpeaker: Identifiable, CustomStringConvertible { name: String? = nil ) { self.id = id - self._index = index - self._name = name + self.index = index + self.name = name + } + + /// Initialize from a snapshot of a diarizer speaker + public init(from snapshot: consuming Snapshot) { + self.id = UUID() + self.index = snapshot.index + self.name = snapshot.name + self.finalizedSegments = snapshot.finalizedSegments + self.tentativeSegments = snapshot.tentativeSegments + } + + /// Rename the speaker + public func rename(to name: String?) { + lock.lock() + defer { lock.unlock() } + self.name = name + } + + /// Reassign diarizer output slot + public func reassign(toSlot slot: Int) { + lock.lock() + defer { lock.unlock() } + self.index = slot } /// Finalize all segments /// - Parameter minFramesOn: Minimum segment length - public func finalize(enforcingMinFramesOn minFramesOn: Int? 
= nil) { - queue.sync(flags: .barrier) { - if let minFramesOn { - _tentativeSegments.removeAll { $0.length < minFramesOn } - } - _finalizedSegments.append(contentsOf: _tentativeSegments) - _tentativeSegments.removeAll() - } + public func finalize() { + lock.lock() + defer { lock.unlock() } + finalizedSegments.append(contentsOf: tentativeSegments) + tentativeSegments.removeAll() } /// Clear segments public func reset() { - queue.sync(flags: .barrier) { - _tentativeSegments.removeAll() - _finalizedSegments.removeAll() - } + lock.lock() + defer { lock.unlock() } + tentativeSegments.removeAll() + finalizedSegments.removeAll() + } + + public func rollback(to snapshot: consuming Snapshot, keepingName: Bool = false) { + lock.lock() + defer { lock.unlock() } + if !keepingName { self.name = snapshot.name } + self.index = snapshot.index + self.finalizedSegments = snapshot.finalizedSegments + self.tentativeSegments = snapshot.tentativeSegments + } + + public func takeSnapshot() -> Snapshot { + lock.lock() + defer { lock.unlock() } + return Snapshot( + name: name, + index: index, + finalizedSegments: finalizedSegments, + tentativeSegments: tentativeSegments + ) } /// Clear all tentative segments /// - Parameter keepingCapacity: Whether to keep the reserved capacity in the tentative segments list. 
- public func removeAllTentative(keepingCapacity: Bool = false) { - queue.sync(flags: .barrier) { - _tentativeSegments.removeAll(keepingCapacity: keepingCapacity) - } + public func clearTentative(keepingCapacity: Bool = false) { + lock.lock() + defer { lock.unlock() } + tentativeSegments.removeAll(keepingCapacity: keepingCapacity) } /// Append a tentative segment /// - Parameter segment: The segment to append public func appendTentative(_ segment: DiarizerSegment) { - queue.sync(flags: .barrier) { - _tentativeSegments.append(segment) - } + lock.lock() + defer { lock.unlock() } + tentativeSegments.append(segment) } /// Append a finalized segment /// - Parameter segment: The segment to append public func appendFinalized(_ segment: DiarizerSegment) { - queue.sync(flags: .barrier) { - _finalizedSegments.append(segment) - } + lock.lock() + defer { lock.unlock() } + finalizedSegments.append(segment) } /// Append a segment, automatically detecting if it's finalized or tentative /// - Parameter segment: The segment to append public func append(_ segment: DiarizerSegment) { - queue.sync(flags: .barrier) { - if segment.isFinalized { - _finalizedSegments.append(segment) - } else { - _tentativeSegments.append(segment) - } + lock.lock() + defer { lock.unlock() } + if segment.isFinalized { + finalizedSegments.append(segment) + } else { + tentativeSegments.append(segment) } } @@ -452,18 +433,18 @@ public final class DiarizerSpeaker: Identifiable, CustomStringConvertible { /// - Returns: The popped segment @discardableResult public func popLastTentative() -> DiarizerSegment? { - queue.sync(flags: .barrier) { - _tentativeSegments.popLast() - } + lock.lock() + defer { lock.unlock() } + return tentativeSegments.popLast() } /// Pop last finalized segment /// - Returns: The popped segment @discardableResult public func popLastFinalized() -> DiarizerSegment? 
{ - queue.sync(flags: .barrier) { - return _finalizedSegments.popLast() - } + lock.lock() + defer { lock.unlock() } + return finalizedSegments.popLast() } /// Pop last tentative or finalized segment @@ -471,12 +452,12 @@ public final class DiarizerSpeaker: Identifiable, CustomStringConvertible { /// - Returns: The popped segment @discardableResult public func popLast(fromFinalized: Bool) -> DiarizerSegment? { - queue.sync(flags: .barrier) { - return - (fromFinalized - ? _finalizedSegments.popLast() - : _tentativeSegments.popLast()) - } + lock.lock() + defer { lock.unlock() } + return + (fromFinalized + ? finalizedSegments.popLast() + : tentativeSegments.popLast()) } /// Pop last segment. Pops the last tentative segment first. Falls back to the last finalized segment if no @@ -484,9 +465,25 @@ public final class DiarizerSpeaker: Identifiable, CustomStringConvertible { /// - Returns: The popped segment @discardableResult public func popLast() -> DiarizerSegment? { - queue.sync(flags: .barrier) { - return _tentativeSegments.popLast() ?? _finalizedSegments.popLast() + lock.lock() + defer { lock.unlock() } + return tentativeSegments.popLast() ?? finalizedSegments.popLast() + } + + /// Pop last segment. Pops the last tentative segment first. Falls back to the last finalized segment if no + /// tentative segments are found. + /// - Returns: The popped segment + @discardableResult + public func popLast( + if predicate: @Sendable (DiarizerSegment) throws -> Bool + ) rethrows -> DiarizerSegment? { + lock.lock() + defer { lock.unlock() } + let last = tentativeSegments.last ?? finalizedSegments.last + guard let last, try predicate(last) else { + return nil } + return tentativeSegments.popLast() ?? finalizedSegments.popLast() } } @@ -600,60 +597,6 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { } } -// MARK: - Chunk Result - -/// Result from a single streaming diarization step (works with any diarizer). 
-/// -/// Maps directly to `SortformerChunkResult` for Sortformer, -/// and wraps `LSEENDStreamingUpdate` for LS-EEND. -public struct DiarizerChunkResult: Sendable { - /// Speaker probabilities for finalized frames. - /// Flat array of shape [frameCount, numSpeakers]. - public let finalizedPredictions: [Float] - - /// Number of finalized frames in this result - public let finalizedFrameCount: Int - - /// Frame index of the first confirmed frame - public let startFrame: Int - - /// Tentative/preview predictions (may change with future data). - /// Flat array of shape [tentativeFrameCount, numSpeakers]. - public let tentativePredictions: [Float] - - /// Number of tentative frames - public let tentativeFrameCount: Int - - /// Frame index of first tentative frame - public var tentativeStartFrame: Int { startFrame + finalizedFrameCount } - - public init( - startFrame: Int, - finalizedPredictions: [Float], - finalizedFrameCount: Int, - tentativePredictions: [Float] = [], - tentativeFrameCount: Int = 0 - ) { - self.startFrame = startFrame - self.finalizedPredictions = finalizedPredictions - self.finalizedFrameCount = finalizedFrameCount - self.tentativePredictions = tentativePredictions - self.tentativeFrameCount = tentativeFrameCount - } - - /// Get probability for a specific speaker at a confirmed frame - public func probability(speaker: Int, frame: Int, numSpeakers: Int) -> Float { - guard frame < finalizedFrameCount, speaker < numSpeakers else { return 0 } - return finalizedPredictions[frame * numSpeakers + speaker] - } - - /// Get probability for a specific speaker at a tentative frame - public func tentativeProbability(speaker: Int, frame: Int, numSpeakers: Int) -> Float { - guard frame < tentativeFrameCount, speaker < numSpeakers else { return 0 } - return tentativePredictions[frame * numSpeakers + speaker] - } -} - // MARK: - Activity Type /// Methods to measure speech activity for segment activity @@ -682,12 +625,18 @@ public enum DiarizerActivityType: Sendable { 
/// /// Generalizes `SortformerTimeline` for any frame-based diarizer. Works with /// both Sortformer (fixed 4 speakers) and LS-EEND (variable speaker count). -public final class DiarizerTimeline { - private struct ClosedSegmentStats { - var start: Int - var end: Int - var activitySum: Float - var activeFrameCount: Int +public class DiarizerTimeline { + public struct ConfiguredSnapshot { + let config: DiarizerTimelineConfig + let snapshot: Snapshot + } + + public struct Snapshot { + public let speakers: [Int: DiarizerSpeaker.Snapshot] + public let finalizedPredictions: [Float] + public let tentativePredictions: [Float] + public let numFinalizedFrames: Int + internal let scratches: [SegmentScratch] } public enum KeptOnReset { @@ -698,94 +647,83 @@ public final class DiarizerTimeline { case speakersWithSegments } - private struct StreamingState { - var startFrame: Int - var isSpeaking: Bool - var activitySum: Float - var activeFrameCount: Int - var lastSegment: ClosedSegmentStats? - - init( - startFrame: Int = 0, - isSpeaking: Bool = false, - activitySum: Float = 0, - activeFrameCount: Int = 0, - lastSegment: ClosedSegmentStats? = nil - ) { - self.startFrame = startFrame - self.isSpeaking = isSpeaking - self.activitySum = activitySum - self.activeFrameCount = activeFrameCount - self.lastSegment = lastSegment - } + internal struct SegmentScratch { + var speaking: Bool = false + var hasSegment: Bool = false + var startFrame: Int = .min + var endFrame: Int = .min + var activitySum: Float = 0 + var activeFrameCount: Int = 0 } + /// Serializes mutation of `speakers`, `scratches`, prediction buffers, and + /// finalized cursor across threads. NSLock is non-recursive, so public + /// mutating entry points acquire the lock once and delegate to private + /// `_unlocked` helpers when they need to call other mutating logic. + private let lock = NSLock() + /// Post-processing configuration public let config: DiarizerTimelineConfig /// Finalized frame-wise speaker predictions. 
/// Flat array of shape [numFrames, numSpeakers]. - public var finalizedPredictions: [Float] { - queue.sync { _finalizedPredictions } - } + public var finalizedPredictions: [Float] = [] /// Tentative predictions. /// Flat array of shape [numTentative, numSpeakers]. - public var tentativePredictions: [Float] { - queue.sync { _tentativePredictions } - } + public var tentativePredictions: [Float] = [] /// Total number of finalized frames public var numFinalizedFrames: Int { - queue.sync { _numFinalizedFrames } + lock.lock() + defer { lock.unlock() } + return finalizedCursorFrame } /// Number of tentative frames public var numTentativeFrames: Int { - queue.sync { _tentativePredictions.count / speakerCapacity } + lock.lock() + defer { lock.unlock() } + return tentativePredictions.count / speakerCapacity } /// Total number of frames (finalized + tentative) public var numFrames: Int { - queue.sync { _numFinalizedFrames + _tentativePredictions.count / speakerCapacity } + lock.lock() + defer { lock.unlock() } + return finalizedCursorFrame + tentativePredictions.count / speakerCapacity } /// Speakers in the timeline - public var speakers: [Int: DiarizerSpeaker] { - get { queue.sync { _speakers } } - set { - queue.sync(flags: .barrier) { - let maxSpeakers = speakerCapacity - - _speakers = newValue.filter { key, _ in - key >= 0 && key < maxSpeakers - } - - for (index, speaker) in _speakers { - speaker.index = index - } - } - } - } + public private(set) var speakers: [Int: DiarizerSpeaker] /// Whether the timeline has any segments public var hasSegments: Bool { - speakers.values.contains(where: \.hasSegments) + lock.lock() + defer { lock.unlock() } + return speakers.values.contains(where: \.hasSegments) } /// Duration of finalized predictions in seconds public var finalizedDuration: Float { - Float(numFinalizedFrames) * config.frameDurationSeconds + lock.lock() + defer { lock.unlock() } + return Float(finalizedCursorFrame) * config.frameDurationSeconds } /// Duration of 
tentative predictions in seconds public var tentativeDuration: Float { - Float(numTentativeFrames) * config.frameDurationSeconds + lock.lock() + defer { lock.unlock() } + return Float(tentativePredictions.count / speakerCapacity) * config.frameDurationSeconds } /// Duration of all predictions (finalized + tentative) in seconds public var duration: Float { - Float(numFrames) * config.frameDurationSeconds + lock.lock() + defer { lock.unlock() } + return Float(finalizedCursorFrame + tentativePredictions.count / speakerCapacity) + * config.frameDurationSeconds } /// Maximum number of speakers @@ -793,16 +731,8 @@ public final class DiarizerTimeline { config.numSpeakers } - private var _finalizedPredictions: [Float] = [] - private var _tentativePredictions: [Float] = [] - private var _speakers: [Int: DiarizerSpeaker] = [:] - private var _numFinalizedFrames: Int = 0 - - // Segment builder state - private var states: [StreamingState] - - private let queue = DispatchQueue(label: "FluidAudio.Diarizer.DiarizerTimeline") - + private var finalizedCursorFrame: Int = 0 + private var scratches: [SegmentScratch] private static let logger = AppLogger(category: "DiarizerTimeline") // MARK: - Init @@ -810,8 +740,8 @@ public final class DiarizerTimeline { /// Initialize for streaming usage public init(config: DiarizerTimelineConfig) { self.config = config - states = Array(repeating: .init(), count: config.numSpeakers) - _speakers = [:] + scratches = Array(repeating: .init(), count: config.numSpeakers) + speakers = [:] } /// Initialize with existing probabilities (batch processing or restored state) @@ -844,6 +774,26 @@ public final class DiarizerTimeline { ) } + /// Initialize from a snapshot + public init(from snapshot: consuming Snapshot, withConfig config: DiarizerTimelineConfig) { + self.config = config + self.finalizedPredictions = snapshot.finalizedPredictions + self.tentativePredictions = snapshot.tentativePredictions + self.finalizedCursorFrame = snapshot.numFinalizedFrames + 
self.scratches = snapshot.scratches + self.speakers = [:] + self.speakers.reserveCapacity(snapshot.speakers.count) + + for (slot, speakerSnapshot) in snapshot.speakers { + self.speakers[slot] = DiarizerSpeaker(from: speakerSnapshot) + } + } + + /// Initialize from a snapshot + public convenience init(from snapshot: consuming ConfiguredSnapshot) { + self.init(from: snapshot.snapshot, withConfig: snapshot.config) + } + // MARK: - Streaming API /// Add new predictions from the diarizer @@ -852,133 +802,130 @@ public final class DiarizerTimeline { finalizedPredictions: [Float], tentativePredictions: [Float] ) throws -> DiarizerTimelineUpdate { - let numFinalized = finalizedPredictions.count / speakerCapacity - let numTentative = tentativePredictions.count / speakerCapacity - + lock.lock() + defer { lock.unlock() } + let finalizedCount = finalizedPredictions.count / speakerCapacity + let tentativeCount = tentativePredictions.count / speakerCapacity let chunk = DiarizerChunkResult( - startFrame: self.numFinalizedFrames, + startFrame: finalizedCursorFrame, finalizedPredictions: finalizedPredictions, - finalizedFrameCount: numFinalized, + finalizedFrameCount: finalizedCount, tentativePredictions: tentativePredictions, - tentativeFrameCount: numTentative + tentativeFrameCount: tentativeCount ) - - return try addChunk(chunk) + return try _addChunkUnlocked(consume chunk) } /// Add a new chunk of predictions from the diarizer @discardableResult public func addChunk(_ chunk: DiarizerChunkResult) throws -> DiarizerTimelineUpdate { - try queue.sync(flags: .barrier) { - try verifyPredictionCounts( - finalized: chunk.finalizedPredictions, - tentative: chunk.tentativePredictions - ) - - _finalizedPredictions.append(contentsOf: chunk.finalizedPredictions) - _tentativePredictions = chunk.tentativePredictions - - for speaker in _speakers.values { - speaker.removeAllTentative(keepingCapacity: true) - } + lock.lock() + defer { lock.unlock() } + return try _addChunkUnlocked(chunk) + } - 
var confirmedCounts = [Int](repeating: 0, count: speakerCapacity) - for (index, speaker) in _speakers { - confirmedCounts[index] = speaker.finalizedSegmentCount - } + private func _addChunkUnlocked(_ chunk: DiarizerChunkResult) throws -> DiarizerTimelineUpdate { + try verifyPredictionCounts( + finalized: chunk.finalizedPredictions, + tentative: chunk.tentativePredictions + ) - updateSegments( - predictions: chunk.finalizedPredictions, - numFrames: chunk.finalizedFrameCount, - isFinalized: true, - addTrailingTentative: false - ) + // Update predictions + if config.maxStoredFrames != 0 { + finalizedPredictions.append(contentsOf: chunk.finalizedPredictions) + trimPredictions() + } + tentativePredictions = chunk.tentativePredictions - _numFinalizedFrames += chunk.finalizedFrameCount + // Clear tentative segments + for speaker in speakers.values { + speaker.clearTentative(keepingCapacity: true) + } - updateSegments( - predictions: chunk.tentativePredictions, - numFrames: chunk.tentativeFrameCount, - isFinalized: false, - addTrailingTentative: true - ) + // Extract new segments + var newFinalized: [DiarizerSegment] = [] + var newTentative: [DiarizerSegment] = [] - trimPredictions() + updateSegments( + predictions: chunk.finalizedPredictions, + isFinalized: true, + addTrailingTentative: false, + emittingFinalizedTo: &newFinalized, + emittingTentativeTo: &newTentative + ) - let newConfirmed = _speakers.flatMap { (index, speaker) in - let startIndex = confirmedCounts[index] - guard startIndex < speaker.finalizedSegmentCount else { - return ArraySlice() - } - return speaker.finalizedSegments.suffix(from: startIndex) - } + finalizedCursorFrame += chunk.finalizedFrameCount - let newTentative = _speakers.values.flatMap(\.tentativeSegments) + updateSegments( + predictions: chunk.tentativePredictions, + isFinalized: false, + addTrailingTentative: true, + emittingFinalizedTo: &newFinalized, + emittingTentativeTo: &newTentative + ) - return DiarizerTimelineUpdate( - 
finalizedSegments: newConfirmed, - tentativeSegments: newTentative, - chunkResult: chunk - ) - } + return DiarizerTimelineUpdate( + finalizedSegments: consume newFinalized, + tentativeSegments: consume newTentative, + chunkResult: consume chunk + ) } /// Finalize all tentative data at end of recording public func finalize() { - queue.sync(flags: .barrier) { finalizeLocked() } + lock.lock() + defer { lock.unlock() } + _finalizeUnlocked() } - private func finalizeLocked() { - _finalizedPredictions.append(contentsOf: _tentativePredictions) - _numFinalizedFrames += _tentativePredictions.count / speakerCapacity - _tentativePredictions.removeAll() - for speaker in _speakers.values { - speaker.finalize(enforcingMinFramesOn: config.minFramesOn) + private func _finalizeUnlocked() { + finalizedPredictions.append(contentsOf: tentativePredictions) + finalizedCursorFrame += tentativePredictions.count / speakerCapacity + tentativePredictions.removeAll() + for speaker in speakers.values { + speaker.finalize() } trimPredictions() } /// Reset to initial state /// - Parameter condition: Condition for when to keep a speaker. All speakers still have their segments reset. - public func reset(keepingSpeakersWhere condition: (DiarizerSpeaker) -> Bool) { - queue.sync(flags: .barrier) { resetLocked(keepingSpeakersWhere: condition) } + public func reset( + keepingSpeakersWhere condition: (DiarizerSpeaker) -> Bool + ) { + lock.lock() + defer { lock.unlock() } + finalizedPredictions.removeAll() + tentativePredictions.removeAll() + finalizedCursorFrame = 0 + scratches = Array(repeating: .init(), count: speakerCapacity) + + speakers = speakers.filter { _, speaker in condition(speaker) } + for speaker in speakers.values { + speaker.reset() + } } /// Reset to initial state /// - Parameter keepingSpeakers: Whether to keep existing speakers enrolled. Their segments are still reset. 
public func reset(keepingSpeakers: Bool = false) { - queue.sync(flags: .barrier) { - resetLocked(keepingSpeakers: keepingSpeakers) - } + lock.lock() + defer { lock.unlock() } + _resetUnlocked(keepingSpeakers: keepingSpeakers) } - private func resetLocked(keepingSpeakersWhere condition: (DiarizerSpeaker) -> Bool) { - _finalizedPredictions.removeAll() - _tentativePredictions.removeAll() - _numFinalizedFrames = 0 - states = Array(repeating: .init(), count: speakerCapacity) - - _speakers = _speakers.filter { - condition($0.value) - } - - for speaker in _speakers.values { - speaker.reset() - } - } - - private func resetLocked(keepingSpeakers: Bool) { - _finalizedPredictions.removeAll() - _tentativePredictions.removeAll() - _numFinalizedFrames = 0 - states = Array(repeating: .init(), count: speakerCapacity) + private func _resetUnlocked(keepingSpeakers: Bool) { + finalizedPredictions.removeAll() + tentativePredictions.removeAll() + finalizedCursorFrame = 0 + scratches = Array(repeating: .init(), count: speakerCapacity) if keepingSpeakers { - for speaker in _speakers.values { + for speaker in speakers.values { speaker.reset() } } else { - _speakers = [:] + speakers.removeAll(keepingCapacity: true) } } @@ -988,43 +935,98 @@ public final class DiarizerTimeline { /// - tentativePredictions: Tentative prediction matrix `[numFrames, numSpeakers]` flattened /// - keepingSpeakers: Whether to keep the old speaker names and slots. /// - isComplete: Whether to finalize the timeline afterward. 
+ @discardableResult public func rebuild( finalizedPredictions: [Float], tentativePredictions: [Float], keepingSpeakers: Bool = false, isComplete: Bool = true - ) throws { - try verifyPredictionCounts(finalized: finalizedPredictions, tentative: tentativePredictions) + ) throws -> DiarizerTimelineUpdate { + lock.lock() + defer { lock.unlock() } - queue.sync(flags: .barrier) { - resetLocked(keepingSpeakers: keepingSpeakers) - _finalizedPredictions = finalizedPredictions - _tentativePredictions = tentativePredictions + try verifyPredictionCounts( + finalized: finalizedPredictions, + tentative: tentativePredictions + ) - let numFinalizedFrames = finalizedPredictions.count / speakerCapacity - let numTentativeFrames = tentativePredictions.count / speakerCapacity + var newFinalized: [DiarizerSegment] = [] + var newTentative: [DiarizerSegment] = [] - updateSegments( - predictions: finalizedPredictions, - numFrames: numFinalizedFrames, - isFinalized: true, - addTrailingTentative: false - ) + let chunk = DiarizerChunkResult( + startFrame: 0, + finalizedPredictions: finalizedPredictions, + finalizedFrameCount: finalizedPredictions.count / speakerCapacity, + tentativePredictions: tentativePredictions, + tentativeFrameCount: tentativePredictions.count / speakerCapacity + ) - _numFinalizedFrames = numFinalizedFrames + _resetUnlocked(keepingSpeakers: keepingSpeakers) + self.finalizedPredictions = finalizedPredictions + self.tentativePredictions = tentativePredictions - updateSegments( - predictions: tentativePredictions, - numFrames: numTentativeFrames, - isFinalized: false, - addTrailingTentative: true - ) - if isComplete { - finalizeLocked() - } else { - trimPredictions() - } + updateSegments( + predictions: finalizedPredictions, + isFinalized: true, + addTrailingTentative: false, + emittingFinalizedTo: &newFinalized, + emittingTentativeTo: &newTentative + ) + + finalizedCursorFrame = finalizedPredictions.count / speakerCapacity + + updateSegments( + predictions: 
tentativePredictions, + isFinalized: false, + addTrailingTentative: true, + emittingFinalizedTo: &newFinalized, + emittingTentativeTo: &newTentative + ) + + if isComplete { + _finalizeUnlocked() + } else { + trimPredictions() + } + + return DiarizerTimelineUpdate( + finalizedSegments: consume newFinalized, + tentativeSegments: consume newTentative, + chunkResult: consume chunk + ) + } + + public func rollback(to snapshot: consuming Snapshot, keepingSpeakers: Bool = false) { + lock.lock() + defer { lock.unlock() } + self.finalizedPredictions = snapshot.finalizedPredictions + self.tentativePredictions = snapshot.tentativePredictions + self.finalizedCursorFrame = snapshot.numFinalizedFrames + self.scratches = snapshot.scratches + + for (slot, speakerSnapshot) in snapshot.speakers { + speakers[slot]?.rollback(to: speakerSnapshot, keepingName: keepingSpeakers) } + + guard !keepingSpeakers else { return } + speakers = speakers.filter { slot, _ in snapshot.speakers[slot] != nil } + } + + public func takeSnapshot() -> Snapshot { + lock.lock() + defer { lock.unlock() } + var speakersSnapshots: [Int: DiarizerSpeaker.Snapshot] = [:] + for (slot, speaker) in speakers { + speakersSnapshots[slot] = speaker.takeSnapshot() + } + + return Snapshot( + speakers: speakersSnapshots, + finalizedPredictions: finalizedPredictions, + tentativePredictions: tentativePredictions, + numFinalizedFrames: finalizedCursorFrame, + scratches: scratches + ) } // MARK: Speaker Management @@ -1039,23 +1041,23 @@ public final class DiarizerTimeline { named name: String? = nil, atIndex index: Int? = nil ) -> DiarizerSpeaker? { - queue.sync(flags: .barrier) { - let index = index ?? 
(0..= 0, index < speakerCapacity else { return nil } + // Ensure index is within bounds + guard let index, index >= 0, index < speakerCapacity else { return nil } - if let speaker = _speakers[index] { - // Update old speaker - speaker.name = name - return speaker - } - - // New speaker - let speaker = DiarizerSpeaker(index: index, name: name) - _speakers[index] = speaker + if let speaker = speakers[index] { + // Update old speaker + speaker.rename(to: name) return speaker } + + // New speaker + let speaker = DiarizerSpeaker(index: index, name: name) + speakers[index] = speaker + return speaker } /// Add a speaker to the timeline at a given slot, or replace the old one if it's already occupied @@ -1070,33 +1072,35 @@ public final class DiarizerTimeline { atIndex index: Int? = nil, transferCurrentSegment: Bool = true ) -> DiarizerSpeaker? { - queue.sync(flags: .barrier) { - // Ensure index is within bounds - let index = index ?? (0..= 0, index < speakerCapacity else { - return nil - } + guard let index, index >= 0, index < speakerCapacity else { + return nil + } - if transferCurrentSegment, - states[index].isSpeaking, - let oldSpeaker = _speakers[index], - let oldStartFrame = oldSpeaker.lastSegment?.startFrame, - oldStartFrame >= states[index].startFrame, - let segment = oldSpeaker.popLast() - { - speaker.append(segment) - } + if transferCurrentSegment, + scratches[index].speaking, + let oldSpeaker = speakers[index], + let segment = oldSpeaker.popLast( + if: { [startFrame = scratches[index].startFrame] in + $0.startFrame >= startFrame + }) + { + speaker.append(segment) + } - if transferCurrentSegment { - states[index] = StreamingState() - } + // Clear current segment if we don't want to transfer it + if !transferCurrentSegment { + scratches[index] = SegmentScratch() + } - _speakers[index] = speaker - speaker.index = index + speakers[index] = speaker + speaker.reassign(toSlot: index) - return speaker - } + return speaker } /// Remove speaker at a given index @@ 
-1109,177 +1113,194 @@ public final class DiarizerTimeline { atIndex index: Int, clearCurrentSegment: Bool = false ) -> DiarizerSpeaker? { + lock.lock() + defer { lock.unlock() } guard index >= 0, index < speakerCapacity else { return nil } - - return queue.sync(flags: .barrier) { - if clearCurrentSegment { - states[index] = StreamingState() - } - - return _speakers.removeValue(forKey: index) + if clearCurrentSegment { + scratches[index] = SegmentScratch() } + + return speakers.removeValue(forKey: index) } // MARK: - Query /// Get probability for a specific speaker at a finalized frame public func probability(speaker: Int, frame: Int) -> Float { - queue.sync { - let frameOffset = (frame - _numFinalizedFrames) * speakerCapacity + _finalizedPredictions.count - guard frameOffset >= 0, - frameOffset < _finalizedPredictions.count, - speaker < speakerCapacity - else { return .nan } - return _finalizedPredictions[frameOffset + speaker] - } + lock.lock() + defer { lock.unlock() } + let frameOffset = (frame - finalizedCursorFrame) * speakerCapacity + finalizedPredictions.count + guard frameOffset >= 0, + frameOffset < finalizedPredictions.count, + speaker < speakerCapacity + else { return .nan } + return finalizedPredictions[frameOffset + speaker] } /// Get probability for a specific speaker at a tentative frame public func tentativeProbability(speaker: Int, frame: Int) -> Float { - queue.sync { - let frameOffset = (frame - _numFinalizedFrames) * speakerCapacity - guard frameOffset >= 0, - frameOffset < _tentativePredictions.count, - speaker < speakerCapacity - else { return .nan } - return _tentativePredictions[frameOffset + speaker] - } + lock.lock() + defer { lock.unlock() } + let frameOffset = (frame - finalizedCursorFrame) * speakerCapacity + guard frameOffset >= 0, + frameOffset < tentativePredictions.count, + speaker < speakerCapacity + else { return .nan } + return tentativePredictions[frameOffset + speaker] } // MARK: - Segment Detection private func 
updateSegments( - predictions: [Float], - numFrames: Int, + predictions: borrowing [Float], isFinalized: Bool, - addTrailingTentative: Bool + addTrailingTentative: Bool, + emittingFinalizedTo finalizedResult: inout [DiarizerSegment], + emittingTentativeTo tentativeResult: inout [DiarizerSegment] ) { - guard numFrames > 0 || addTrailingTentative else { return } + guard !predictions.isEmpty || addTrailingTentative else { return } - let frameOffset = _numFinalizedFrames - let numSpeakers = speakerCapacity + let frameOffset = finalizedCursorFrame let onset = config.onsetThreshold let offset = config.offsetThreshold let padOnset = config.onsetPadFrames let padOffset = config.offsetPadFrames let minFramesOn = config.minFramesOn let minFramesOff = config.minFramesOff - let frameDuration = config.frameDurationSeconds - let tentativeBuffer = padOnset + padOffset + minFramesOff - let tentativeStartFrame = isFinalized ? (frameOffset + numFrames) - tentativeBuffer : 0 + let numNewFrames = predictions.count / speakerCapacity + let endFrame = frameOffset + numNewFrames + let pad = padOnset + padOffset + let minSegmentLength = pad + minFramesOn + let finalizedEndFrame = isFinalized ? endFrame - minFramesOff - pad : .min let activityFunc = config.activityType.evaluationFunction - for speakerIndex in 0..= offset { - activitySum += activityFunc(activity) - activeFrameCount += 1 + aux.activitySum += activityFunc(activity) + aux.activeFrameCount += 1 continue } - speaking = false - let end = frameOffset + i + padOffset + aux.speaking = false + let end = frame + padOffset - guard end - start > minFramesOn else { - activitySum = 0 - activeFrameCount = 0 + guard end - aux.startFrame >= minSegmentLength else { continue } - wasLastSegmentFinal = isFinalized && (end < tentativeStartFrame) - let meanActivity = activeFrameCount > 0 ? 
(activitySum / Float(activeFrameCount)) : 0 - - let newSegment = DiarizerSegment( - speakerIndex: speakerIndex, - startFrame: start, - endFrame: end, - finalized: wasLastSegmentFinal, - frameDurationSeconds: frameDuration, - activity: meanActivity - ) + aux.endFrame = end + aux.hasSegment = true + } else if activity > onset { + let start = frame - padOnset + aux.speaking = true - provideSpeaker(forSlot: speakerIndex).append(newSegment) + guard !aux.hasSegment || start - aux.endFrame > minFramesOff else { + aux.activitySum += activityFunc(activity) + aux.activeFrameCount += 1 + aux.hasSegment = false + continue + } - lastSegment = ClosedSegmentStats( - start: start, - end: end, - activitySum: activitySum, - activeFrameCount: activeFrameCount + commitSegment( + from: &aux, + toSlot: speakerIndex, + isFinalized: aux.endFrame < finalizedEndFrame, + emittingIfFinalizedTo: &finalizedResult, + emittingIfTentativeTo: &tentativeResult ) - activitySum = 0 - activeFrameCount = 0 - } else if activity > onset { - start = max(0, frameOffset + i - padOnset) - speaking = true - activitySum = activityFunc(activity) - activeFrameCount = 1 - - if let lastSegment, start - lastSegment.end <= minFramesOff { - start = lastSegment.start - activitySum += lastSegment.activitySum - activeFrameCount += lastSegment.activeFrameCount - _speakers[speakerIndex]?.popLast(fromFinalized: wasLastSegmentFinal) - } + aux.startFrame = start + aux.activitySum = activityFunc(activity) + aux.activeFrameCount = 1 } } + // Commit final pending segment + commitSegment( + from: &aux, + toSlot: speakerIndex, + isFinalized: aux.endFrame < finalizedEndFrame, + emittingIfFinalizedTo: &finalizedResult, + emittingIfTentativeTo: &tentativeResult + ) + if isFinalized { - states[speakerIndex].startFrame = start - states[speakerIndex].isSpeaking = speaking - states[speakerIndex].activitySum = activitySum - states[speakerIndex].activeFrameCount = activeFrameCount - states[speakerIndex].lastSegment = lastSegment + 
scratches[speakerIndex] = aux + continue } - if addTrailingTentative { - let end = frameOffset + numFrames + padOffset - if speaking && (end > start) { - let meanActivity = activeFrameCount > 0 ? (activitySum / Float(activeFrameCount)) : 0 - let newSegment = DiarizerSegment( - speakerIndex: speakerIndex, - startFrame: start, - endFrame: end, - finalized: false, - frameDurationSeconds: frameDuration, - activity: meanActivity - ) - provideSpeaker(forSlot: speakerIndex).appendTentative(newSegment) - } - } + // Add trailing segment (tentative-path only) + guard addTrailingTentative, aux.speaking else { continue } + aux.endFrame = endFrame + padOffset + guard aux.endFrame - aux.startFrame >= minSegmentLength else { continue } + aux.hasSegment = true + + commitSegment( + from: &aux, + toSlot: speakerIndex, + isFinalized: false, + emittingIfFinalizedTo: &finalizedResult, + emittingIfTentativeTo: &tentativeResult + ) } } - private func provideSpeaker(forSlot speakerIndex: Int) -> DiarizerSpeaker { - if let speaker = _speakers[speakerIndex] { return speaker } + @inline(__always) + private func commitSegment( + from aux: inout SegmentScratch, + toSlot slot: Int, + isFinalized: Bool, + emittingIfFinalizedTo finalizedResult: inout [DiarizerSegment], + emittingIfTentativeTo tentativeResult: inout [DiarizerSegment] + ) { + guard aux.hasSegment else { return } + + let segment = DiarizerSegment( + speakerIndex: slot, + startFrame: aux.startFrame, + endFrame: aux.endFrame, + finalized: isFinalized, + frameDurationSeconds: config.frameDurationSeconds, + activity: aux.activeFrameCount > 0 ? 
aux.activitySum / Float(aux.activeFrameCount) : 0 + ) + + let speaker: DiarizerSpeaker + if let spk = speakers[slot] { + speaker = consume spk + } else { + let spk = DiarizerSpeaker(index: slot) + speakers[slot] = spk + speaker = consume spk + } + + speaker.append(segment) + + if isFinalized { + finalizedResult.append(consume segment) + } else { + tentativeResult.append(consume segment) + } - let newSpeaker = DiarizerSpeaker(index: speakerIndex) - _speakers[speakerIndex] = newSpeaker - return newSpeaker + aux.hasSegment = false } private func trimPredictions() { guard let maxStoredFrames = config.maxStoredFrames else { return } - let numToRemove = _finalizedPredictions.count - maxStoredFrames * speakerCapacity + let numToRemove = finalizedPredictions.count - maxStoredFrames * speakerCapacity if numToRemove > 0 { - _finalizedPredictions.removeFirst(numToRemove) + finalizedPredictions.removeFirst(numToRemove) } } diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDatatypes.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDatatypes.swift deleted file mode 100644 index cb8457340..000000000 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDatatypes.swift +++ /dev/null @@ -1,533 +0,0 @@ -import CoreML -import Foundation - -/// Errors thrown by the LS-EEND inference pipeline. -public enum LSEENDError: Error, LocalizedError { - /// The model metadata JSON is malformed or contains invalid values. - case invalidMetadata(String) - /// A matrix operation received dimensions that don't match. - case invalidMatrixShape(String) - /// The audio input format is unsupported (e.g. wrong sample rate or empty buffer). - case unsupportedAudio(String) - /// CoreML prediction failed during a model forward pass. - case modelPredictionFailed(String) - /// A required output feature is missing from the CoreML prediction result. - case missingFeature(String) - /// A file path could not be resolved. - case invalidPath(String) - /// The CoreML model could not be loaded or compiled. 
- case modelLoadFailed(String) - - public var errorDescription: String? { - switch self { - case .invalidMetadata(let message): - return "Invalid LS-EEND metadata: \(message)" - case .invalidMatrixShape(let message): - return "Invalid LS-EEND matrix shape: \(message)" - case .unsupportedAudio(let message): - return "Unsupported LS-EEND audio input: \(message)" - case .modelPredictionFailed(let message): - return "LS-EEND CoreML prediction failed: \(message)" - case .missingFeature(let message): - return "Missing CoreML feature: \(message)" - case .invalidPath(let message): - return "Invalid LS-EEND path: \(message)" - case .modelLoadFailed(let message): - return "Failed to load LS-EEND model: \(message)" - } - } -} - -/// A row-major 2D matrix of `Float` values used throughout the LS-EEND pipeline. -/// -/// Rows typically represent time frames and columns represent speaker channels or feature dimensions. -/// All operations return new matrices (value semantics); the underlying `values` array is stored flat -/// in row-major order. -public struct LSEENDMatrix: Sendable, Equatable { - /// The number of rows (typically time frames). - public let rows: Int - /// The number of columns (typically speakers or feature dimensions). - public let columns: Int - /// Flat row-major storage. Element at `(row, col)` is at index `row * columns + col`. - public var values: [Float] - - /// Creates a matrix with validated dimensions. - /// - /// - Parameters: - /// - rows: Number of rows (must be non-negative). - /// - columns: Number of columns (must be non-negative). - /// - values: Flat row-major values. Count must equal `rows * columns`. - /// - Throws: ``LSEENDError/invalidMatrixShape(_:)`` if dimensions are negative or values count doesn't match. 
- public init(rows: Int, columns: Int, values: [Float]) throws { - guard rows >= 0, columns >= 0 else { - throw LSEENDError.invalidMatrixShape("Negative dimensions are not supported.") - } - guard values.count == rows * columns else { - throw LSEENDError.invalidMatrixShape( - "Expected \(rows * columns) values, received \(values.count)." - ) - } - self.rows = rows - self.columns = columns - self.values = values - } - - /// Creates a matrix without validating that `values.count == rows * columns`. - /// - /// Use this initializer only when the caller has already guaranteed dimensional consistency. - public init(validatingRows rows: Int, columns: Int, values: [Float]) { - self.rows = rows - self.columns = columns - self.values = values - } - - /// Creates a zero-filled matrix with the given dimensions. - public static func zeros(rows: Int, columns: Int) -> LSEENDMatrix { - LSEENDMatrix( - validatingRows: rows, columns: columns, values: [Float](repeating: 0, count: max(0, rows * columns))) - } - - /// Creates an empty matrix (zero rows) with the given column count. - public static func empty(columns: Int) -> LSEENDMatrix { - zeros(rows: 0, columns: columns) - } - - /// Whether the matrix contains no data (zero rows, zero columns, or empty values). - public var isEmpty: Bool { - rows == 0 || columns == 0 || values.isEmpty - } - - /// Accesses the element at the given row and column. - public subscript(row: Int, column: Int) -> Float { - get { - values[(row * columns) + column] - } - set { - values[(row * columns) + column] = newValue - } - } - - /// Returns the values of a single row as an `ArraySlice`. - public func row(_ index: Int) -> ArraySlice { - let start = index * columns - return values[start..<(start + columns)] - } - - /// Returns a new matrix containing only the first `count` columns of each row. 
- public func prefixingColumns(_ count: Int) -> LSEENDMatrix { - let clipped = max(0, min(count, columns)) - guard clipped < columns else { return self } - guard rows > 0 else { return .empty(columns: clipped) } - var out = [Float](repeating: 0, count: rows * clipped) - for rowIndex in 0.. [[Float]] { - guard rows > 0, columns > 0 else { return [] } - return (0.. LSEENDMatrix { - if isEmpty { return other } - if other.isEmpty { return self } - precondition(columns == other.columns, "Column count mismatch") - return LSEENDMatrix(validatingRows: rows + other.rows, columns: columns, values: values + other.values) - } - - /// Returns a new matrix with the first `count` rows removed. - public func droppingFirstRows(_ count: Int) -> LSEENDMatrix { - let clipped = max(0, min(count, rows)) - guard clipped > 0 else { return self } - let start = clipped * columns - return LSEENDMatrix( - validatingRows: rows - clipped, columns: columns, values: Array(values[start.. LSEENDMatrix { - let lower = max(0, min(start, rows)) - let upper = max(lower, min(end, rows)) - guard lower < upper else { return .empty(columns: columns) } - let slice = Array(values[(lower * columns)..<(upper * columns)]) - return LSEENDMatrix(validatingRows: upper - lower, columns: columns, values: slice) - } - - /// Returns a new matrix with the element-wise sigmoid function applied to all values. - /// - /// Converts logits to probabilities: `σ(x) = 1 / (1 + exp(-x))`. - public func applyingSigmoid() -> LSEENDMatrix { - guard !values.isEmpty else { return self } - var output = values - for index in output.indices { - output[index] = 1.0 / (1.0 + expf(-values[index])) - } - return LSEENDMatrix(validatingRows: rows, columns: columns, values: output) - } -} - -/// The result of a complete (offline) LS-EEND inference pass. -/// -/// Contains both "real" outputs (speaker tracks only) and "full" outputs -/// (including the two boundary tracks the model uses internally). 
-public struct LSEENDInferenceResult: Sendable { - /// Speaker logits with boundary tracks removed. Shape: `[frames, realOutputDim]`. - public let logits: LSEENDMatrix - /// Speaker probabilities (sigmoid of ``logits``). Shape: `[frames, realOutputDim]`. - public let probabilities: LSEENDMatrix - /// Raw model logits including boundary tracks. Shape: `[frames, fullOutputDim]`. - public let fullLogits: LSEENDMatrix - /// Probabilities including boundary tracks (sigmoid of ``fullLogits``). - public let fullProbabilities: LSEENDMatrix - /// Output frame rate in Hz (e.g. 10.0 means one frame per 100 ms). - public let frameHz: Double - /// Duration of the input audio in seconds. - public let durationSeconds: Double -} - -/// An incremental update from a streaming LS-EEND session. -/// -/// Each update contains two regions: -/// - **Committed** (`logits` / `probabilities`): frames that have passed through the -/// full encoder and are final. -/// - **Preview** (`previewLogits` / `previewProbabilities`): speculative frames decoded -/// by flushing pending state with zero-padded input. These will be refined by future audio. -public struct LSEENDStreamingUpdate: Sendable { - /// Frame index where the committed region begins. - public var startFrame: Int - /// Committed speaker logits (boundary tracks removed). - public var logits: LSEENDMatrix - /// Committed speaker probabilities (sigmoid of ``logits``). - public var probabilities: LSEENDMatrix - /// Frame index where the preview region begins (equal to ``totalEmittedFrames``). - public var previewStartFrame: Int - /// Speculative speaker logits for frames not yet fully committed. - public var previewLogits: LSEENDMatrix - /// Speculative speaker probabilities (sigmoid of ``previewLogits``). - public var previewProbabilities: LSEENDMatrix - /// Output frame rate in Hz. - public var frameHz: Double - /// Total audio duration processed so far, in seconds. 
- public var durationSeconds: Double - /// Running total of committed frames emitted across all updates. - public var totalEmittedFrames: Int -} - -/// Progress information for a single chunk in a streaming simulation. -/// -/// Used by ``LSEENDInferenceEngine/simulateStreaming(audioFileURL:chunkSeconds:)`` -/// to report per-chunk statistics. -public struct LSEENDStreamingProgress: Sendable, Codable { - /// One-based index of the chunk being processed. - public let chunkIndex: Int - /// Cumulative audio duration fed to the session, in seconds. - public let bufferSeconds: Double - /// Number of new committed frames emitted by this chunk. - public let numFramesEmitted: Int - /// Running total of committed frames across all chunks so far. - public let totalFramesEmitted: Int - /// Whether this entry represents the final flush (finalization). - public let flush: Bool -} - -/// Combined result of a streaming simulation, pairing the final inference output -/// with per-chunk progress entries. -public struct LSEENDStreamingSimulationResult: Sendable { - /// The complete inference result after all chunks have been processed and finalized. - public let result: LSEENDInferenceResult - /// Per-chunk progress entries logged during the simulation. - public let updates: [LSEENDStreamingProgress] -} - -/// The LS-EEND model variant (dataset the model was trained on). -/// -/// Maps directly to ``ModelNames/LSEEND/Variant``. Each variant corresponds -/// to a different training dataset and produces slightly different diarization behavior. -public typealias LSEENDVariant = ModelNames.LSEEND.Variant - -extension LSEENDVariant: Identifiable { - /// The dataset name used as a stable identifier (e.g. `"DIHARD III"`). - public var id: String { rawValue } -} - -/// Locates the CoreML model and metadata files for a specific LS-EEND variant. 
-/// -/// Pass the descriptor to ``LSEENDInferenceEngine/init(descriptor:computeUnits:)`` -/// or ``LSEENDDiarizer/initialize(descriptor:)`` to load the model. -public struct LSEENDModelDescriptor: Sendable { - /// The model variant (training dataset). - public let variant: LSEENDVariant - /// URL of the compiled CoreML model (`.mlmodelc`) or model package (`.mlpackage`). - public let modelURL: URL - /// URL of the JSON metadata file describing model dimensions and audio parameters. - public let metadataURL: URL - - private static let logger = AppLogger(category: "LSEENDModelDescriptor") - - /// Creates a descriptor from explicit file paths. - /// - /// - Parameters: - /// - variant: The model variant. - /// - modelURL: Path to the `.mlmodelc` or `.mlpackage` file. - /// - metadataURL: Path to the JSON metadata file. - public init( - variant: LSEENDVariant, - modelURL: URL, - metadataURL: URL - ) { - self.variant = variant - self.modelURL = modelURL - self.metadataURL = metadataURL - } - - /// Download LS-EEND models from HuggingFace and construct a descriptor. - /// - /// Downloads all variant files on first call; subsequent calls use the cache. - /// The returned descriptor points at the cached `.mlmodelc` and `.json` files. - /// - /// - Parameters: - /// - variant: The model variant to load (default: `.dihard3`). - /// - cacheDirectory: Directory to cache downloaded models (defaults to app support) - /// - computeUnits: Model compute units (.cpuOnly seems to be fastest for this model) - /// - Returns: A descriptor ready for ``LSEENDInferenceEngine/init(descriptor:computeUnits:)``. - public static func loadFromHuggingFace( - variant: LSEENDVariant = .dihard3, - cacheDirectory: URL? = nil, - computeUnits: MLComputeUnits = .cpuOnly, - progressHandler: DownloadUtils.ProgressHandler? = nil - ) async throws -> LSEENDModelDescriptor { - await SystemInfo.logOnce(using: logger) - - let directory = - cacheDirectory - ?? 
FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0] - .appendingPathComponent("FluidAudio/Models") - - try FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true) - - let repo = Repo.lseend - let repoPath = directory.appendingPathComponent(repo.folderName) - let requiredModels = ModelNames.getRequiredModelNames(for: repo, variant: variant.stem) - - let allModelsExist = requiredModels.allSatisfy { model in - let modelPath = repoPath.appendingPathComponent(model) - return FileManager.default.fileExists(atPath: modelPath.path) - } - - if !allModelsExist { - logger.info("Models not found in cache at \(repoPath.path)") - try await DownloadUtils.downloadRepo( - repo, - to: directory, - variant: variant.stem, - progressHandler: progressHandler - ) - } - - let modelURL = repoPath.appendingPathComponent(variant.modelFile) - let metadataURL = repoPath.appendingPathComponent(variant.configFile) - - return LSEENDModelDescriptor( - variant: variant, - modelURL: modelURL, - metadataURL: metadataURL - ) - } -} - -/// Tensor shapes for the six recurrent state buffers carried between LS-EEND inference steps. -/// -/// Each property is an array of dimension sizes (e.g. `[layers, heads, keyDim, bufferLen]`). -/// These shapes are read from the model metadata JSON and used to allocate -/// zero-initialized `MLMultiArray` tensors at session start. -public struct LSEENDStateShapes: Decodable, Sendable { - /// Encoder retention key-value cache shape. - public let encRetKv: [Int] - /// Encoder retention scale buffer shape. - public let encRetScale: [Int] - /// Encoder convolutional cache shape. - public let encConvCache: [Int] - /// Decoder retention key-value cache shape. - public let decRetKv: [Int] - /// Decoder retention scale buffer shape. - public let decRetScale: [Int] - /// Top-level buffer shape (used for cross-attention between encoder and decoder). 
- public let topBuffer: [Int] - - enum CodingKeys: String, CodingKey { - case encRetKv = "enc_ret_kv" - case encRetScale = "enc_ret_scale" - case encConvCache = "enc_conv_cache" - case decRetKv = "dec_ret_kv" - case decRetScale = "dec_ret_scale" - case topBuffer = "top_buffer" - } -} - -/// Model configuration decoded from the LS-EEND metadata JSON file. -/// -/// Contains all architectural parameters (layer counts, dimensions, state shapes) -/// and audio processing parameters (sample rate, FFT settings, mel bands). -/// Optional audio fields (`sampleRate`, `winLength`, etc.) fall back to defaults -/// via the `resolved*` computed properties. -public struct LSEENDModelMetadata: Decodable, Sendable { - /// Input feature dimension per frame (nMels × splice window width). - public let inputDim: Int - /// Total output dimension including boundary tracks. - public let fullOutputDim: Int - /// Number of real speaker output tracks (excludes boundary tracks). - public let realOutputDim: Int - /// Number of encoder transformer layers. - public let encoderLayers: Int - /// Number of decoder transformer layers. - public let decoderLayers: Int - /// Hidden dimension of the encoder. - public let encoderDim: Int - /// Number of attention heads. - public let numHeads: Int - /// Key dimension per attention head. - public let keyDim: Int - /// Value dimension per attention head. - public let headDim: Int - /// Length of the encoder convolutional cache (number of frames buffered). - public let encoderConvCacheLen: Int - /// Length of the top-level cross-attention buffer. - public let topBufferLen: Int - /// Number of initial frames consumed before the decoder begins producing output. - public let convDelay: Int - /// Maximum number of speaker slots in the model output. - public let maxNspks: Int - /// Output frame rate in Hz (frames per second). - public let frameHz: Double - /// Target audio sample rate the model expects. 
- public let targetSampleRate: Int - /// Compute precision used during export (informational, e.g. `"float32"`). - public let computePrecision: String? - /// Tensor shapes for the six recurrent state buffers. - public let stateShapes: LSEENDStateShapes - /// Explicit sample rate override (falls back to ``targetSampleRate`` if nil). - public let sampleRate: Int? - /// STFT window length in samples (defaults to 200 if nil). - public let winLength: Int? - /// STFT hop length in samples (defaults to 80 if nil). - public let hopLength: Int? - /// FFT size (defaults to next power of 2 ≥ ``resolvedWinLength`` if nil). - public let nFFT: Int? - /// Number of mel filterbank channels (inferred from ``inputDim`` if nil). - public let nMels: Int? - /// Context receptive field half-width for splice-and-subsample (inferred if nil). - public let contextRecp: Int? - /// Subsampling factor for feature frames (inferred from frame rate if nil). - public let subsampling: Int? - /// Feature type identifier (informational, e.g. `"logmel_cmvn"`). - public let featType: String? 
- - enum CodingKeys: String, CodingKey { - case inputDim = "input_dim" - case fullOutputDim = "full_output_dim" - case realOutputDim = "real_output_dim" - case encoderLayers = "encoder_layers" - case decoderLayers = "decoder_layers" - case encoderDim = "encoder_dim" - case numHeads = "num_heads" - case keyDim = "key_dim" - case headDim = "head_dim" - case encoderConvCacheLen = "encoder_conv_cache_len" - case topBufferLen = "top_buffer_len" - case convDelay = "conv_delay" - case maxNspks = "max_nspks" - case frameHz = "frame_hz" - case targetSampleRate = "target_sample_rate" - case computePrecision = "compute_precision" - case stateShapes = "state_shapes" - case sampleRate = "sample_rate" - case winLength = "win_length" - case hopLength = "hop_length" - case nFFT = "n_fft" - case nMels = "n_mels" - case contextRecp = "context_recp" - case subsampling - case featType = "feat_type" - } - - /// Effective sample rate: uses ``sampleRate`` if present, otherwise ``targetSampleRate``. - public var resolvedSampleRate: Int { - sampleRate ?? targetSampleRate - } - - /// Effective STFT window length in samples (defaults to 200). - public var resolvedWinLength: Int { - winLength ?? 200 - } - - /// Effective STFT hop length in samples (defaults to 80). - public var resolvedHopLength: Int { - hopLength ?? 80 - } - - /// Effective FFT size. Uses ``nFFT`` if present, otherwise the smallest power of 2 - /// that is ≥ ``resolvedWinLength``. - public var resolvedFFTSize: Int { - if let nFFT { - return nFFT - } - var fft = 1 - while fft < resolvedWinLength { - fft <<= 1 - } - return fft - } - - /// Effective number of mel filterbank channels. Uses ``nMels`` if present, - /// otherwise inferred from ``inputDim`` and ``resolvedContextRecp``. 
- public var resolvedMelCount: Int { - if let nMels { - return nMels - } - let inferred = inputDim / max(1, (2 * resolvedContextRecp) + 1) - return max(1, inferred) - } - - /// Effective context receptive field half-width for the splice-and-subsample step. - /// Uses ``contextRecp`` if present, otherwise inferred from ``inputDim`` and mel count. - public var resolvedContextRecp: Int { - if let contextRecp { - return contextRecp - } - let melCount = max(1, nMels ?? 23) - return max(0, ((inputDim / melCount) - 1) / 2) - } - - /// Effective subsampling factor (how many STFT frames map to one model frame). - /// Uses ``subsampling`` if present, otherwise derived from ``frameHz`` and hop length. - public var resolvedSubsampling: Int { - if let subsampling { - return subsampling - } - let denominator = Int(round(frameHz * Double(resolvedHopLength))) - return max(1, resolvedSampleRate / max(1, denominator)) - } - - /// Minimum streaming latency in seconds before the model can produce its first output frame. - /// - /// Accounts for the FFT center padding, context receptive field, and convolutional delay. - public var streamingLatencySeconds: Double { - let fftSize = resolvedFFTSize - return Double( - (fftSize / 2) + (resolvedContextRecp * resolvedHopLength) - + (convDelay * resolvedSubsampling * resolvedHopLength)) - / Double(max(resolvedSampleRate, 1)) - } -} diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizer.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizer.swift index b74affa83..21fb85521 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizer.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizer.swift @@ -1,737 +1,383 @@ +// +// LSEENDDiarizer.swift +// LS-EEND-Test +// +// Streaming LS-EEND (Long-form Streaming End-to-End Neural Diarization) +// implementation. Mirrors the Python CoreMLPipelineDiarizer's per-frame +// semantics (STFT → log10-mel → CMN → subsample+context → T-block CoreML +// call → finalize with silence flush). 
CPU-optimized: preallocated scratch, +// vDSP for CMN, MLMultiArray reference swapping for state updates. +// + import AVFoundation +import Accelerate import CoreML import Foundation -/// Speaker diarization using LS-EEND (Linear Streaming End-to-End Neural Diarization). -/// -/// Supports both streaming and offline processing, matching the `SortformerDiarizer` API -/// - Important: This class is **not** thread-safe. public final class LSEENDDiarizer: Diarizer { - private let lock = NSLock() - private let logger = AppLogger(category: "LSEENDDiarizer") - - // MARK: - Diarizer Protocol Properties - - /// Accumulated results - public var timeline: DiarizerTimeline { - lock.withLock { return _timeline } - } - - /// Whether the processor is ready for processing - public var isAvailable: Bool { - lock.withLock { return _engine != nil } - } - - /// Number of confirmed frames processed so far - public var numFramesProcessed: Int { - lock.withLock { return _numFramesProcessed } - } - - /// Model's target sample rate in Hz (e.g., 8000) - public var targetSampleRate: Int? { - lock.withLock { return _engine?.targetSampleRate } - } - - /// Output frame rate in Hz (e.g., 10.0) - public var modelFrameHz: Double? { - lock.withLock { return _engine?.modelFrameHz } - } - - /// Number of real speaker tracks (excluding boundary tracks) - public var numSpeakers: Int? { - lock.withLock { return _engine?.metadata.realOutputDim } - } - // MARK: - Additional Properties + // MARK: - Dependencies + private var model: LSEENDModel? = nil + private var session: LSEENDFeatureProvider? = nil - /// Compute units for CoreML inference - public let computeUnits: MLComputeUnits + public var timeline: DiarizerTimeline - /// Post-processing configuration - public var timelineConfig: DiarizerTimelineConfig { - lock.withLock { return _timeline.config } - } + // MARK: - Protocol properties - /// Streaming latency in seconds - public var streamingLatencySeconds: Double? 
{ - lock.withLock { return _engine?.streamingLatencySeconds } - } + public private(set) var isAvailable: Bool = false + /// Number of finalized output frames emitted to the timeline. + /// Tracks `timeline.numFinalizedFrames` (warmup frames stripped by + /// the model are excluded), matching the `Diarizer` protocol contract. + public var numFramesProcessed: Int { timeline.numFinalizedFrames } + /// Input frames fed to the model (including warmup). Used internally + /// to drive the per-chunk warmup calculation. + private var framesFedToModel: Int = 0 + public private(set) var targetSampleRate: Int? + public private(set) var modelFrameHz: Double? + public private(set) var numSpeakers: Int? - /// Total speaker slots in model output (including boundary tracks) - public var decodeMaxSpeakers: Int? { - lock.withLock { return _engine?.decodeMaxSpeakers } - } - - /// Whether a streaming session is currently active. - var hasActiveSession: Bool { - lock.withLock { return _session != nil } - } + private var finalized: Bool = false - // MARK: - Private State - - private var _engine: LSEENDInferenceHelper? - private var _session: LSEENDStreamingSession? - private var _melSpectrogram: AudioMelSpectrogram? - private var _timeline: DiarizerTimeline - private var _numFramesProcessed: Int = 0 - private var _timelineConfig: DiarizerTimelineConfig - private var _visibleStartFrameOffset: Int = 0 - - // Audio buffering - private var pendingAudio: [Float] = [] + private let logger = AppLogger(category: "LSEENDDiarizer") // MARK: - Init - /// Create a processor with default settings. - /// - /// Call `initialize(descriptor:)` before processing audio. 
- /// - /// - Parameters: - /// - computeUnits: CoreML compute units (default: `.cpuOnly`) - /// - onsetThreshold: Onset threshold for segment detection - /// - offsetThreshold: Offset threshold for segment detection - /// - onsetPadFrames: Padding frames added before each speech segment - /// - offsetPadFrames: Padding frames added after each speech segment - /// - minFramesOn: Minimum segment length in frames (shorter segments are discarded) - /// - minFramesOff: Minimum gap length in frames (shorter gaps are closed) - /// - maxStoredFrames: Maximum number of finalized prediction frames to retain (`nil` = unlimited) - public init( - computeUnits: MLComputeUnits = .cpuOnly, - onsetThreshold: Float = 0.5, - offsetThreshold: Float = 0.5, - onsetPadFrames: Int = 0, - offsetPadFrames: Int = 0, - minFramesOn: Int = 0, - minFramesOff: Int = 0, - maxStoredFrames: Int? = nil - ) { - self.computeUnits = computeUnits - // Placeholder timeline until model is loaded and numSpeakers/frameHz are known - self._timelineConfig = .init( - numSpeakers: 1, - frameDurationSeconds: 0.1, - onsetThreshold: onsetThreshold, - offsetThreshold: offsetThreshold, - onsetPadFrames: onsetPadFrames, - offsetPadFrames: offsetPadFrames, - minFramesOn: minFramesOn, - minFramesOff: minFramesOff, - maxStoredFrames: maxStoredFrames + public init(model: LSEENDModel) throws { + let metadata = model.metadata + self.model = model + self.session = try LSEENDFeatureProvider(from: metadata) + self.timeline = DiarizerTimeline( + config: .default( + numSpeakers: metadata.maxSpeakers, + frameDurationSeconds: metadata.frameDurationSeconds + ) ) - self._timeline = DiarizerTimeline(config: _timelineConfig) - } - - /// Create a processor with default settings. - /// - /// Call `initialize(descriptor:)` before processing audio. 
- /// - /// - Parameters: - /// - computeUnits: CoreML compute units (default: `.cpuOnly`) - /// - onsetThreshold: Onset threshold for segment detection - /// - offsetThreshold: Offset threshold for segment detection - public init( - computeUnits: MLComputeUnits = .cpuOnly, - timelineConfig: DiarizerTimelineConfig - ) { - self.computeUnits = computeUnits - self._timelineConfig = timelineConfig - // Placeholder timeline until model is loaded and numSpeakers/frameHz are known - self._timeline = DiarizerTimeline(config: timelineConfig) - } - - // MARK: - Initialization - - /// Initialize with a model descriptor. Loads the CoreML model. - /// - /// - Parameter descriptor: Model descriptor specifying variant and file paths - public func initialize(variant: LSEENDVariant = .dihard3) async throws { - let descriptor = try await LSEENDModelDescriptor.loadFromHuggingFace(variant: variant) - try initialize(descriptor: descriptor) - } - - /// Initialize with a model descriptor. Loads the CoreML model. 
- /// - /// - Parameter descriptor: Model descriptor specifying variant and file paths - public func initialize(descriptor: LSEENDModelDescriptor) throws { - let engine = try LSEENDInferenceHelper(descriptor: descriptor, computeUnits: computeUnits) - let melSpectrogram = Self.createMelSpectrogram(featureConfig: engine.featureConfig) - - lock.withLock { - updateTimelineConfig(engine: engine) - _engine = engine - _melSpectrogram = melSpectrogram - _timeline = DiarizerTimeline(config: _timelineConfig) - _session = nil - resetBuffersLocked() - - logger.info( - "Initialized LS-EEND \(descriptor.variant.rawValue): " - + "\(engine.metadata.realOutputDim) speakers, " - + "\(String(format: "%.1f", engine.modelFrameHz)) Hz, " - + "\(String(format: "%.2f", engine.streamingLatencySeconds))s latency" + self.targetSampleRate = metadata.sampleRate + self.modelFrameHz = Double(metadata.sampleRate) / Double(metadata.hopLength * metadata.subsampling) + self.numSpeakers = metadata.maxSpeakers + self.isAvailable = true + } + + /// Replace model + timeline + derived metadata. Used by init and hot-swap + /// paths so every metadata-derived property stays in lockstep with the + /// currently loaded model. + private func adopt(model: LSEENDModel) throws { + let metadata = model.metadata + self.model = model + self.session = try LSEENDFeatureProvider(from: metadata) + self.timeline = DiarizerTimeline( + config: .default( + numSpeakers: metadata.maxSpeakers, + frameDurationSeconds: metadata.frameDurationSeconds ) - } - } - - /// Initialize with a pre-loaded engine. 
- public func initialize(engine: LSEENDInferenceHelper) { - let melSpectrogram = Self.createMelSpectrogram(featureConfig: engine.featureConfig) - - lock.withLock { - updateTimelineConfig(engine: engine) - _engine = engine - _melSpectrogram = melSpectrogram - _timeline = DiarizerTimeline(config: _timelineConfig) - _session = nil - resetBuffersLocked() - - logger.info("Initialized LS-EEND with pre-loaded engine") - } - } - - // MARK: - Speaker Priming - - /// Prime the diarizer with enrollment audio to warm the streaming state. - /// - /// This feeds audio through the active streaming session, discards any emitted - /// predictions, and resets the visible timeline so subsequent calls to - /// `process()` start again from frame 0 while keeping the warmed model state. - /// - /// - Parameters: - /// - samples: Audio samples to use for priming. - /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the model rate. - /// - name: The speaker's name. - /// - overwriteAssignedSpeakerName: Whether enrollment may overwrite the name on an already-named slot - /// if the diarizer assigns the audio to that speaker. - /// - Throws: ``LSEENDError/modelPredictionFailed(_:)`` if the diarizer is not initialized. - public func enrollSpeaker( - withSamples samples: [Float], - sourceSampleRate: Double? = nil, - named name: String? = nil, - overwritingAssignedSpeakerName overwriteAssignedSpeakerName: Bool = true - ) throws -> DiarizerSpeaker? { - try enrollSpeakerInternal( - withAudio: samples, - sourceSampleRate: sourceSampleRate, - named: name, - overwritingAssignedSpeakerName: overwriteAssignedSpeakerName ) + self.targetSampleRate = metadata.sampleRate + self.modelFrameHz = Double(metadata.sampleRate) / Double(metadata.hopLength * metadata.subsampling) + self.numSpeakers = metadata.maxSpeakers + self.isAvailable = true } - /// Prime the diarizer with enrollment audio to warm the streaming state. 
- /// - /// - Parameters: - /// - samples: Audio samples to use for priming. - /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the model rate. - /// - name: The speaker's name. - /// - overwriteAssignedSpeakerName: Whether enrollment may overwrite the name on an already-named slot - /// if the diarizer assigns the audio to that speaker. - public func enrollSpeaker( - withAudio samples: C, - sourceSampleRate: Double? = nil, - named name: String? = nil, - overwritingAssignedSpeakerName overwriteAssignedSpeakerName: Bool = true - ) throws -> DiarizerSpeaker? where C.Element == Float { - try enrollSpeakerInternal( - withAudio: Array(samples), - sourceSampleRate: sourceSampleRate, - named: name, - overwritingAssignedSpeakerName: overwriteAssignedSpeakerName + public func loadFromHuggingFace( + variant: LSEENDVariant = .dihard3, + stepSize: LSEENDStepSize = .step100ms, + cacheDirectory: URL? = nil, + computeUnits: MLComputeUnits = .cpuOnly, + progressHandler: DownloadUtils.ProgressHandler? = nil + ) async throws { + let model = try await LSEENDModel.loadFromHuggingFace( + variant: variant, + stepSize: stepSize, + cacheDirectory: cacheDirectory, + computeUnits: computeUnits, + progressHandler: progressHandler ) - } - - private func enrollSpeakerInternal( - withAudio samples: [Float], - sourceSampleRate: Double?, - named name: String?, - overwritingAssignedSpeakerName overwriteAssignedSpeakerName: Bool - ) throws -> DiarizerSpeaker? { - try lock.withLock { - let description: String = name.map { "named '\($0)'" } ?? "(no name)" - guard let engine = _engine else { - throw LSEENDError.modelPredictionFailed("LS-EEND processor not initialized. Call initialize() first.") - } - - let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) ?? 
samples - guard !normalized.isEmpty else { - logger.warning("Failed to enroll speaker \(description) because no speech detected") - return nil - } - - if _timeline.hasSegments { - logger.warning("Trying to enroll a speaker while timeline has segments; timeline will be reset") + // New model may have different convDelay / maxSpeakers / sampleRate, + // so rebuild session + timeline + derived metadata and clear any + // prior streaming state before the next chunk runs. + try adopt(model: model) + resetStreamingState() + } + + // MARK: - Debug helpers (parity tests) + + #if DEBUG + /// Drive `samples` through session → STFT → log10-mel → CMN → + /// subsample+context stack, and return the flat `[N × featDim]` + /// stacked features that would be fed to CoreML. Used by + /// `testFeat345Parity` to byte-compare against the Python fixture + /// without running inference. + internal func debugExtractFeatures( + _ samples: C, sourceSampleRate: Double? + ) throws -> [Float] where C.Element == Float { + guard let session else { throw LSEENDError.notInitialized } + session.reset() + try session.enqueueAudio(samples, withSampleRate: sourceSampleRate) + try session.drainRightContextWithSilence() + + var out: [Float] = [] + while let input = try session.emitNextChunk() { + // `input.melFeatures` is preallocated + reused — copy out each + // pass. Caller-allocated input arrays have tight strides, so a + // flat read is safe (unlike model *output* arrays, which get + // tile-padded strides; see CLAUDE.md gotcha #2). + input.melFeatures.withUnsafeBufferPointer(ofType: Float.self) { buf in + out.append(contentsOf: buf) } - - _timeline.reset(keepingSpeakers: true) - var occupiedIndices = Set(_timeline.speakers.keys) - _numFramesProcessed = 0 - _visibleStartFrameOffset = 0 - pendingAudio.removeAll(keepingCapacity: true) - - if _session == nil { - _session = try engine.createSession( - inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!) 
- } - guard let session = _session else { - return nil - } - - let update = try session.pushAudio(normalized) - let didProcess = update.map { !$0.probabilities.isEmpty || !$0.previewProbabilities.isEmpty } ?? false - - guard didProcess else { - let minimumSeconds = engine.streamingLatencySeconds - logger.warning( - "Failed to enroll speaker \(description): not enough audio was provided. " - + "Please provide at least \(String(format: "%.2f", minimumSeconds)) seconds of speech." - ) - return nil - } - - if let update { - let numSpeakers = engine.metadata.realOutputDim - let result = DiarizerChunkResult( - startFrame: max(0, update.startFrame - _visibleStartFrameOffset), - finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), - finalizedFrameCount: update.probabilities.rows, - tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers), - tentativeFrameCount: update.previewProbabilities.rows - ) - _numFramesProcessed += result.finalizedFrameCount - _ = try _timeline.addChunk(result) - } - - let speaker = _timeline.speakers.values.max { $0.numSpeechFrames < $1.numSpeechFrames } - let enrolledSpeaker: DiarizerSpeaker? 
- if let speaker, speaker.hasSegments { - if let oldName = speaker.name { - guard overwriteAssignedSpeakerName else { - logger.warning( - "Failed to enroll speaker \(description): diarizer matched existing speaker '\(oldName)' " - + "at index \(speaker.index) and overwritingAssignedSpeakerName=false" - ) - _visibleStartFrameOffset = session.snapshot().probabilities.rows - _numFramesProcessed = 0 - _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) - pendingAudio.removeAll(keepingCapacity: true) - return nil - } - logger.warning( - "Newly-enrolled speaker \(description) will overwrite the old one named \(oldName) at index \(speaker.index)" - ) - } - speaker.name = name - occupiedIndices.insert(speaker.index) - enrolledSpeaker = speaker - } else { - logger.warning("Failed to enroll speaker \(description) because no speech detected") - enrolledSpeaker = nil - } - - _visibleStartFrameOffset = session.snapshot().probabilities.rows - _numFramesProcessed = 0 - _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) - pendingAudio.removeAll(keepingCapacity: true) - - logger.info( - "Enrolled speaker \(description) with \(normalized.count) samples " - + "(\(String(format: "%.1f", Float(normalized.count) / Float(engine.targetSampleRate)))s), " - + "visible offset=\(_visibleStartFrameOffset)" - ) - - return enrolledSpeaker } + return out } + #endif - // MARK: - Streaming (Diarizer Protocol) + // MARK: - Streaming API - /// Add audio samples to the processing buffer. - /// - /// Audio must be at the model's target sample rate (typically 8000 Hz). - /// Call `process()` after adding audio to run inference. - public func addAudio(_ samples: [Float]) { - try? addAudio(samples, sourceSampleRate: nil) - } - - /// Add audio samples to the processing buffer, resampling when needed. - /// - /// - Parameters: - /// - samples: Mono audio samples to enqueue. 
- /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the model rate. - /// Add audio samples from any `Collection` of `Float` to the processing buffer. - public func addAudio( - _ samples: C, - sourceSampleRate: Double? = nil - ) throws where C.Element == Float { - try lock.withLock { - if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { - pendingAudio.append(contentsOf: normalized) - } else { - pendingAudio.append(contentsOf: samples) - } + public func addAudio(_ samples: C, sourceSampleRate: Double?) throws + where C.Element == Float { + guard !samples.isEmpty else { return } + guard let session else { + throw LSEENDError.notInitialized } + try session.enqueueAudio(samples, withSampleRate: sourceSampleRate) } - /// Process buffered audio and return any new results. - /// - /// - Returns: New chunk result if inference produced frames, nil otherwise public func process() throws -> DiarizerTimelineUpdate? { - try lock.withLock { return try processLocked() } + try flush(progressCallback: nil) } - /// Add and process a chunk of audio in one call. - /// - Parameters: - /// - samples: Audio samples to process. - /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the model rate. - /// - Returns: New chunk result if inference produced frames, nil otherwise. public func process( - samples: C, - sourceSampleRate: Double? = nil + samples: C, sourceSampleRate: Double? ) throws -> DiarizerTimelineUpdate? where C.Element == Float { - try lock.withLock { - if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { - pendingAudio.append(contentsOf: normalized) - } else { - pendingAudio.append(contentsOf: samples) - } - - return try processLocked() - } - } - - /// Internal process — caller must hold lock. - private func processLocked() throws -> DiarizerTimelineUpdate? 
{ - guard let engine = _engine else { - throw LSEENDError.modelPredictionFailed("LS-EEND processor not initialized. Call initialize() first.") - } - - guard !pendingAudio.isEmpty else { return nil } - - // Lazily create session on first process call - if _session == nil { - _session = try engine.createSession( - inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!) - } - guard let session = _session else { return nil } - - // Clear unconditionally (even on throw) so failed audio isn't re-fed. - // Using defer + direct pass avoids a CoW copy — pushAudio receives a - // temporary reference, and removeAll runs after it returns (refcount == 1). - defer { pendingAudio.removeAll(keepingCapacity: true) } - - guard let update = try session.pushAudio(pendingAudio) else { - return nil - } - - let numSpeakers = engine.metadata.realOutputDim - let result = DiarizerChunkResult( - startFrame: max(0, update.startFrame - _visibleStartFrameOffset), - finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), - finalizedFrameCount: update.probabilities.rows, - tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers), - tentativeFrameCount: update.previewProbabilities.rows - ) - - _numFramesProcessed += result.finalizedFrameCount - return try _timeline.addChunk(result) + try addAudio(samples, sourceSampleRate: sourceSampleRate) + return try process() } - // MARK: - Offline (Diarizer Protocol) - - /// Progress callback: (processedSamples, totalSamples, chunksProcessed) - public typealias ProgressCallback = (Int, Int, Int) -> Void - - /// Process a complete audio buffer. - /// - /// Resets state (unless pre-enrolled speakers are kept) and pushes all audio at once, then finalizes. - /// - /// - Parameters: - /// - samples: Complete audio samples at the model's target sample rate. - /// - sourceSampleRate: Source audio sample rate (if nil, assumes that it matches the engine's sample rate). 
- /// - keepSpeakers: Whether to keep pre-enrolled speakers. If `nil`, it keeps the speakers if no more audio was added. - /// - finalizeOnCompletion: Whether to finalize the timeline after processing. - /// - progressCallback: Optional callback (processedSamples, totalSamples, chunksProcessed). - /// - Returns: Finalized timeline with segments. - public func processComplete( - _ samples: [Float], - sourceSampleRate: Double? = nil, - keepingEnrolledSpeakers keepSpeakers: Bool? = nil, - finalizeOnCompletion: Bool = true, - progressCallback: ((Int, Int, Int) -> Void)? = nil - ) throws -> DiarizerTimeline { - try lock.withLock { - try processCompleteLocked( - samples, - sourceSampleRate: sourceSampleRate, - keepingEnrolledSpeakers: keepSpeakers, - finalizeOnCompletion: finalizeOnCompletion, - progressCallback: progressCallback - ) - } - } + // MARK: - Offline API - /// Process a complete audio buffer. - /// - /// Resets state (unless pre-enrolled speakers are kept) and pushes all audio at once, then finalizes. - /// - /// - Parameters: - /// - samples: Complete audio samples. - /// - sourceSampleRate: Source audio sample rate (if `nil`, assumes the model rate). - /// - keepSpeakers: Whether to keep pre-enrolled speakers. If `nil`, it keeps the speakers if no more audio was added. - /// - finalizeOnCompletion: Whether to finalize the timeline after processing. - /// - progressCallback: Optional callback `(processedSamples, totalSamples, chunksProcessed)`. - /// - Returns: Finalized timeline with segments. public func processComplete( _ samples: C, - sourceSampleRate: Double? = nil, - keepingEnrolledSpeakers keepSpeakers: Bool? = nil, + sourceSampleRate: Double?, + keepingEnrolledSpeakers keepSpeakers: Bool?, finalizeOnCompletion: Bool, progressCallback: ((Int, Int, Int) -> Void)? 
) throws -> DiarizerTimeline where C.Element == Float { - try lock.withLock { - try processCompleteLocked( - Array(samples), - sourceSampleRate: sourceSampleRate, - keepingEnrolledSpeakers: keepSpeakers, - finalizeOnCompletion: finalizeOnCompletion, - progressCallback: progressCallback - ) + guard session != nil, model != nil else { + throw LSEENDError.notInitialized } + let keep = keepSpeakers ?? !timeline.hasSegments + resetStreamingState() + timeline.reset(keepingSpeakers: keep) + + try addAudio(samples, sourceSampleRate: sourceSampleRate) + try flush( + finalizeOnCompletion: finalizeOnCompletion, + progressCallback: progressCallback + ) + return timeline } - /// Process a complete audio file from a URL. - /// - /// Reads and resamples the file to ``targetSampleRate``, then delegates to - /// ``processComplete(_:finalizeOnCompletion:progressCallback:)``. - /// - /// - Parameters: - /// - audioFileURL: Path to a WAV, CAF, or other audio file. - /// - keepSpeakers: Whether to keep pre-enrolled speakers. If `nil`, it keeps the speakers if no more audio was added. - /// - finalizeOnCompletion: Whether to finalize the timeline after processing - /// - progressCallback: Optional callback (processedSamples, totalSamples, chunksProcessed). - /// - Returns: Finalized timeline with segments. public func processComplete( audioFileURL: URL, - keepingEnrolledSpeakers keepSpeakers: Bool? = nil, - finalizeOnCompletion: Bool = true, - progressCallback: ((Int, Int, Int) -> Void)? = nil + keepingEnrolledSpeakers keepSpeakers: Bool?, + finalizeOnCompletion: Bool, + progressCallback: ((Int, Int, Int) -> Void)? ) throws -> DiarizerTimeline { - try lock.withLock { - guard let engine = _engine else { - throw LSEENDError.modelPredictionFailed("LS-EEND processor not initialized. 
Call initialize() first.") - } - - let converter = AudioConverter(sampleRate: Double(engine.targetSampleRate)) - let audio = try converter.resampleAudioFile(audioFileURL) - - return try processCompleteLocked( - audio, - sourceSampleRate: nil, - keepingEnrolledSpeakers: keepSpeakers, - finalizeOnCompletion: finalizeOnCompletion, - progressCallback: progressCallback - ) + guard let session, model != nil else { + throw LSEENDError.notInitialized } + let keep = keepSpeakers ?? !timeline.hasSegments + resetStreamingState() + timeline.reset(keepingSpeakers: keep) + + try session.enqueueAudioFile(at: audioFileURL) + try flush( + finalizeOnCompletion: finalizeOnCompletion, + progressCallback: progressCallback + ) + return timeline } - private func processCompleteLocked( - _ samples: [Float], - sourceSampleRate: Double?, - keepingEnrolledSpeakers keepSpeakers: Bool? = nil, + /// Shared drain path for both `processComplete` overloads. Runs + /// session → model → timeline, optionally finalizing the stream. + private func flush( + recordFrames: Bool = true, finalizeOnCompletion: Bool, progressCallback: ((Int, Int, Int) -> Void)? - ) throws -> DiarizerTimeline { - let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) ?? samples + ) throws { + guard let session else { + throw LSEENDError.notInitialized + } - guard let engine = _engine else { - throw LSEENDError.modelPredictionFailed("LS-EEND processor not initialized. Call initialize() first.") + if finalizeOnCompletion { + try session.drainRightContextWithSilence() } - let keepSpeakers = keepSpeakers ?? 
(_numFramesProcessed == 0 && pendingAudio.isEmpty) + _ = try flush(recordFrames: recordFrames, progressCallback: progressCallback) - _timeline.reset(keepingSpeakers: keepSpeakers) - _numFramesProcessed = 0 - pendingAudio.removeAll(keepingCapacity: true) - let useRetainedSession = keepSpeakers && _session != nil - if !keepSpeakers { - _visibleStartFrameOffset = 0 - _session = nil + if finalizeOnCompletion { + timeline.finalize() + finalized = true } + } - let session = - if let retainedSession = _session, useRetainedSession { - retainedSession - } else { - try engine.createSession( - inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!) - } - let numSpeakers = engine.metadata.realOutputDim - - // Push all audio at once - if let update = try session.pushAudio(normalized) { - let chunk = DiarizerChunkResult( - startFrame: max(0, update.startFrame - _visibleStartFrameOffset), - finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), - finalizedFrameCount: update.probabilities.rows, - tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers), - tentativeFrameCount: update.previewProbabilities.rows - ) - _numFramesProcessed += chunk.finalizedFrameCount - try _timeline.addChunk(chunk) + /// Drain all ready chunks through `model.predict` → timeline. Returns the + /// timeline update, or nil if the drain produced no frames. Warmup rows + /// are stripped per-chunk inside `model.predict` (`input.warmupFrames`), + /// so the accumulated stream is already 1:1 with real audio time. + private func flush( + recordFrames: Bool = true, + progressCallback: ((Int, Int, Int) -> Void)? + ) throws -> DiarizerTimelineUpdate? 
{ + guard let session, let model else { + throw LSEENDError.notInitialized } - progressCallback?(normalized.count, normalized.count, 1) + let chunkSize = model.metadata.chunkSize + let numSpeakers = model.metadata.maxSpeakers + let rightContext = model.metadata.convDelay + let totalChunks = session.readyChunks - // Finalize remaining frames - if let finalUpdate = try session.finalize() { - let chunk = DiarizerChunkResult( - startFrame: max(0, finalUpdate.startFrame - _visibleStartFrameOffset), - finalizedPredictions: flattenRowMajor(finalUpdate.probabilities, numSpeakers: numSpeakers), - finalizedFrameCount: finalUpdate.probabilities.rows, - tentativePredictions: [], - tentativeFrameCount: 0 - ) - _numFramesProcessed += chunk.finalizedFrameCount - try _timeline.addChunk(chunk) - } - if useRetainedSession { - _session = nil - } + var processed = 0 + var newPreds: [Float] = [] + newPreds.reserveCapacity(totalChunks * numSpeakers * chunkSize) - if finalizeOnCompletion { - _timeline.finalize() + while let input = try session.emitNextChunk() { + if recordFrames { + input.warmupFrames = max(min(rightContext - framesFedToModel, chunkSize), 0) + framesFedToModel += chunkSize + } + newPreds.append(contentsOf: try model.predict(from: input)) + processed += 1 + progressCallback?(processed, totalChunks, 1) } - return _timeline + + guard !newPreds.isEmpty else { return nil } + + return try timeline.addPredictions( + finalizedPredictions: newPreds, + tentativePredictions: [] + ) } - // MARK: - Lifecycle (Diarizer Protocol) + // MARK: - Lifecycle - /// Reset all streaming state for a new audio stream. - /// - /// Preserves the loaded model. Call `initialize()` again to change models. public func reset() { - lock.withLock { - _session = nil - _timeline.reset() - resetBuffersLocked() - logger.debug("LS-EEND state reset") - } + resetStreamingState() + timeline.reset(keepingSpeakers: false) } - /// Clean up all resources including the loaded model. 
public func cleanup() { - lock.withLock { - _engine = nil - _session = nil - _melSpectrogram = nil - _timeline.reset() - resetBuffersLocked() - logger.info("LS-EEND resources cleaned up") - } + resetStreamingState() + self.model = nil + self.session = nil + isAvailable = false } - // MARK: - LS-EEND Specific + public func enrollSpeaker( + withAudio samples: C, + sourceSampleRate: Double?, + named name: String?, + overwritingAssignedSpeakerName overwriteAssignedSpeakerName: Bool + ) throws -> DiarizerSpeaker? where C.Element == Float { + guard let session else { + throw LSEENDError.notInitialized + } + + let sessionSnapshot = try session.takeSnapshot() + let timelineSnapshot = timeline.takeSnapshot() + let isNamed = name != nil - /// Finalize the current streaming session. - /// - /// Flushes any remaining frames and finalizes the timeline. - /// After calling this, `process()` will no longer produce results - /// until `reset()` is called. - /// - /// - Returns: Final chunk result if any remaining frames were flushed, nil otherwise - @discardableResult - public func finalizeSession() throws -> DiarizerChunkResult? { - lock.lock() - defer { lock.unlock() } - - guard let engine = _engine, let session = _session else { return nil } - let numSpeakers = engine.metadata.realOutputDim - var lastResult: DiarizerChunkResult? - - // Flush pending audio first — clear unconditionally so failed audio isn't retained. - // Using defer + direct pass avoids a CoW copy. 
- if !pendingAudio.isEmpty { - defer { pendingAudio.removeAll(keepingCapacity: true) } - let pushedUpdate = try session.pushAudio(pendingAudio) - if let update = pushedUpdate { - let flushedResult = DiarizerChunkResult( - startFrame: _numFramesProcessed, - finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), - finalizedFrameCount: update.probabilities.rows, - tentativePredictions: [], - tentativeFrameCount: 0 - ) - _numFramesProcessed += flushedResult.finalizedFrameCount - try _timeline.addChunk(flushedResult) - lastResult = flushedResult - } + let requireNewSpeaker = isNamed && !overwriteAssignedSpeakerName + + if timeline.hasSegments { + logger.warning("Enrolling speaker mid session. The timeline will be reset if successful.") } - if let finalUpdate = try session.finalize() { - let finalResult = DiarizerChunkResult( - startFrame: _numFramesProcessed, - finalizedPredictions: flattenRowMajor(finalUpdate.probabilities, numSpeakers: numSpeakers), - finalizedFrameCount: finalUpdate.probabilities.rows, - tentativePredictions: [], - tentativeFrameCount: 0 - ) - _numFramesProcessed += finalResult.finalizedFrameCount - try _timeline.addChunk(finalResult) - lastResult = finalResult + // Flush queued audio, including right context. 
+ try session.drainRightContextWithSilence() + _ = try flush(progressCallback: nil) + + // Snapshot old speakers starting here after old audio has been flushed + let oldSlots: Set + + if isNamed { + oldSlots = Set(timeline.speakers.filter { $0.value.name != nil }.keys) + } else { + oldSlots = Set(timeline.speakers.keys) } - _timeline.finalize() - _session = nil - return lastResult - } + try session.enqueueAudio( + samples, + withSampleRate: sourceSampleRate, + eagerPreprocessing: false + ) - // MARK: - Private + // Flush enrollment audio queued in the right context + try session.drainRightContextWithSilence() - private func resetBuffersLocked() { - pendingAudio.removeAll(keepingCapacity: true) - _numFramesProcessed = 0 - _visibleStartFrameOffset = 0 - } + // Process enrollment audio. The new speaker will be extracted from this timeline update. + guard let update = try flush(recordFrames: false, progressCallback: nil), + !update.finalizedSegments.isEmpty + else { + session.rollback(to: sessionSnapshot) + timeline.rollback(to: timelineSnapshot) + return nil + } - private func normalizeSamplesLocked( - _ samples: C, - sourceSampleRate: Double? - ) throws -> [Float]? where C.Element == Float { - guard let engine = _engine, - let sourceSampleRate, - sourceSampleRate != Double(engine.targetSampleRate) + // Get the new/unnamed speaker with the most speech if any exist. + // Fallback to old speaker with the most speech if overwrites are allowed. 
+ var speechActivities: [Int: Float] = [:] + for segment in update.finalizedSegments { + speechActivities[segment.speakerIndex, default: 0] += segment.activity * Float(segment.length) + } + + // Prioritized unnamed speakers; speech activity is secondary + let bestSlot = speechActivities.max { + let isFirstOld = oldSlots.contains($0.key) + let isSecondOld = oldSlots.contains($1.key) + if isFirstOld == isSecondOld { + return $0.value < $1.value + } + return isFirstOld + }?.key + + guard let bestSlot, + let enrolledSpeaker = timeline.speakers[bestSlot], + !requireNewSpeaker || !oldSlots.contains(bestSlot) else { + session.rollback(to: sessionSnapshot) + timeline.rollback(to: timelineSnapshot) return nil } - return try AudioConverter(sampleRate: Double(engine.targetSampleRate)) - .resample(Array(samples), from: sourceSampleRate) - } + // Rename speaker and report success + enrolledSpeaker.name = name + timeline.reset(keepingSpeakers: true) - /// Create a new mel spectrogram instance owned by this diarizer. - private static func createMelSpectrogram(featureConfig: LSEENDFeatureConfig) -> AudioMelSpectrogram { - AudioMelSpectrogram( - sampleRate: featureConfig.sampleRate, - nMels: featureConfig.nMels, - nFFT: featureConfig.nFFT, - hopLength: featureConfig.hopLength, - winLength: featureConfig.winLength, - preemph: 0, - padTo: 1, - logFloor: 1e-10, - logFloorMode: .clamped, - windowPeriodic: true - ) + return enrolledSpeaker } - private func updateTimelineConfig(engine: LSEENDInferenceHelper) { - self._timelineConfig.numSpeakers = engine.metadata.realOutputDim - self._timelineConfig.frameDurationSeconds = Float(1.0 / engine.modelFrameHz) + // MARK: - Private: state + + private func resetStreamingState() { + session?.reset() + framesFedToModel = 0 + finalized = false } - /// Convert an LSEENDMatrix to a flat [Float] in row-major layout. 
- private func flattenRowMajor(_ matrix: LSEENDMatrix, numSpeakers: Int) -> [Float] { - guard matrix.rows > 0, matrix.columns > 0 else { return [] } - return matrix.values + // MARK: - Private: finalize + + @discardableResult + public func finalize() throws -> DiarizerTimelineUpdate? { + guard !finalized else { return nil } + guard let session else { + throw LSEENDError.notInitialized + } + + // Drain pending real audio, capture real-frame target. + try session.drainRightContextWithSilence() + let update = try process() + + timeline.finalize() + finalized = true + return update } } diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDInference.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDInference.swift new file mode 100644 index 000000000..b353162a0 --- /dev/null +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDInference.swift @@ -0,0 +1,214 @@ +import Foundation +import CoreML +import Accelerate + +public class LSEENDModel { + public let metadata: LSEENDMetadata + + private let model: MLModel + + private let lock = NSLock() + + private static let logger = AppLogger(category: "LS-EEND Model") + + // MARK: - Init + + public init(modelURL: URL, computeUnits: MLComputeUnits = .cpuOnly) throws { + // Load the model from the URL + let modelConfig = MLModelConfiguration() + modelConfig.computeUnits = computeUnits + self.model = try MLModel(contentsOf: modelURL, configuration: modelConfig) + + // Load the config from metadata + guard let userMetadata = self.model.modelDescription.metadata[.creatorDefinedKey] as? [String: Any], + let json = userMetadata["config"] as? String + else { + throw LSEENDError.initializationFailed("No `config` found in model metadata") + } + + let decoder = JSONDecoder() + decoder.keyDecodingStrategy = .convertFromSnakeCase + self.metadata = try decoder.decode(LSEENDMetadata.self, from: Data(json.utf8)) + } + + /// Download LS-EEND models from HuggingFace. 
+ /// + /// - Parameters: + /// - variant: The model variant to load (default: `.dihard3`). + /// - stepSize: The model step size to load (default: `.step100ms`). + /// - cacheDirectory: Directory to cache downloaded models (defaults to app support) + /// - computeUnits: Model compute units (`.cpuOnly` seems to be fastest for this model) + /// - Returns: LS-EEND Model Wrapper + public static func loadFromHuggingFace( + variant: LSEENDVariant = .dihard3, + stepSize: LSEENDStepSize = .step100ms, + cacheDirectory: URL? = nil, + computeUnits: MLComputeUnits = .cpuOnly, + progressHandler: DownloadUtils.ProgressHandler? = nil + ) async throws -> LSEENDModel { + let directory = + cacheDirectory + ?? FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0] + .appendingPathComponent("FluidAudio/Models") + + try FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true) + + let repo = variant.repo + let repoPath = directory.appendingPathComponent(repo.folderName) + let modelRelPath = variant.fileName(forStep: stepSize) + let fullRelPath = repo.subPath.map { "\($0)/\(modelRelPath)" } ?? modelRelPath + let modelURL = repoPath.appendingPathComponent(fullRelPath) + + let modelExists = FileManager.default.fileExists(atPath: modelURL.path) + + if !modelExists { + // Narrow to just the one mlmodelc — listing the whole step dir + // is fine here since each step dir contains only its own mlmodelc. + logger.info("Models not found in cache at \(modelURL.path); downloading \(fullRelPath)…") + try await DownloadUtils.downloadSubdirectory( + repo, subdirectory: fullRelPath, to: repoPath + ) + } + + guard FileManager.default.fileExists(atPath: modelURL.path) else { + throw LSEENDError.initializationFailed( + "HF download completed but mlmodelc missing at \(modelURL.path). 
" + + "Expected HF path: \(modelRelPath)" + ) + } + + return try LSEENDModel(modelURL: modelURL, computeUnits: computeUnits) + } + + // MARK: - Inference + + public func predict(from input: LSEENDInput) throws -> [Float] { + try autoreleasepool { + lock.lock() + defer { lock.unlock() } + + let prediction = try model.prediction(from: input) + + guard let probsMA = prediction.featureValue(for: "probs")?.multiArrayValue, + let encKvMA = prediction.featureValue(for: "enc_kv_new")?.multiArrayValue, + let encScaleMA = prediction.featureValue(for: "enc_scale_new")?.multiArrayValue, + let encConvCacheMA = prediction.featureValue(for: "enc_conv_cache_new")?.multiArrayValue, + let cnnWindowMA = prediction.featureValue(for: "cnn_window_new")?.multiArrayValue, + let decKvMA = prediction.featureValue(for: "dec_kv_new")?.multiArrayValue, + let decScaleMA = prediction.featureValue(for: "dec_scale_new")?.multiArrayValue + else { + throw LSEENDError.inferenceFailed("Failed to extract predictions from CoreML model.") + } + + // Update state + input.state.encRetKv = encKvMA + input.state.encRetScale = encScaleMA + input.state.encConvCache = encConvCacheMA + input.state.cnnWindow = cnnWindowMA + input.state.decRetKv = decKvMA + input.state.decRetScale = decScaleMA + + // Copy speaker sigmoids and skip warmup frames + let warmup = input.warmupFrames + let outputFrames = metadata.chunkSize - warmup + let outputSpeakers = metadata.maxSpeakers + guard outputFrames > 0, outputSpeakers > 0 else { return [] } + + guard probsMA.strides.last?.intValue == 1 else { + throw LSEENDError.inferenceFailed( + "Probs innermost stride must be 1. 
CoreML model produced strides: \(probsMA.strides).") + } + let frameStride = probsMA.strides[1].intValue + + var probsOut = [Float](repeating: 0, count: outputFrames * outputSpeakers) + let maBase = probsMA.dataPointer.assumingMemoryBound(to: Float.self) + + probsOut.withUnsafeMutableBufferPointer { flatPtr in + vDSP_mmov( + maBase + warmup * frameStride, + flatPtr.baseAddress!, + vDSP_Length(outputSpeakers), + vDSP_Length(outputFrames), + vDSP_Length(frameStride), + vDSP_Length(outputSpeakers) + ) + } + + return probsOut + } + } +} + +public class LSEENDInput: MLFeatureProvider { + public var state: LSEENDState + public let melFeatures: MLMultiArray + public let decoderMask: MLMultiArray + public var warmupFrames: Int = 0 + + public var featureNames: Set { + [ + "features", + "enc_kv", "enc_scale", + "enc_conv_cache", "cnn_window", + "dec_kv", "dec_scale", + "valid_mask", + ] + } + + public init(from metadata: LSEENDMetadata, state: consuming LSEENDState? = nil) throws { + self.state = try state ?? LSEENDState(from: metadata) + let T = NSNumber(value: metadata.chunkSize) + let M = NSNumber(value: metadata.melFrames) + let N = NSNumber(value: metadata.nMels) + self.melFeatures = try MLMultiArray(shape: [1, M, N], dataType: .float32) + self.decoderMask = try MLMultiArray(shape: [T], dataType: .float32) + } + + /// Reset state + @inline(__always) + public func resetState() { + state.reset() + } + + @inline(__always) + public func loadInputs( + melFeatures newMelFeatures: C, + decoderMask newDecoderMask: C, + warmupFrames: Int? = nil + ) throws where C.Element == Float { + try Self.load(decoderMask, from: newDecoderMask) + try Self.load(melFeatures, from: newMelFeatures) + self.warmupFrames = warmupFrames ?? newDecoderMask.withUnsafeBufferPointer { $0.count(where: \.isZero) } + } + + public func featureValue(for featureName: String) -> MLFeatureValue? 
{ + switch featureName { + case "features": return MLFeatureValue(multiArray: melFeatures) + case "enc_kv": return MLFeatureValue(multiArray: state.encRetKv) + case "enc_scale": return MLFeatureValue(multiArray: state.encRetScale) + case "enc_conv_cache": return MLFeatureValue(multiArray: state.encConvCache) + case "cnn_window": return MLFeatureValue(multiArray: state.cnnWindow) + case "dec_kv": return MLFeatureValue(multiArray: state.decRetKv) + case "dec_scale": return MLFeatureValue(multiArray: state.decRetScale) + case "valid_mask": return MLFeatureValue(multiArray: decoderMask) + default: return nil + } + } + + @inline(__always) + private static func load( + _ multiArray: MLMultiArray, + from buffer: C + ) throws { + guard buffer.count == multiArray.count else { + throw LSEENDError.invalidInputSize( + "Input size mismatch: new=\(buffer.count) expected=\(multiArray.count)") + } + + _ = buffer.withUnsafeBufferPointer { buf in + memcpy( + multiArray.dataPointer, buf.baseAddress, + buf.count * MemoryLayout.stride) + } + } +} diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift deleted file mode 100644 index f978b9006..000000000 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift +++ /dev/null @@ -1,708 +0,0 @@ -import AVFoundation -import CoreML -import CryptoKit -import Foundation - -private final class LSEENDModelState { - var encRetKv: MLMultiArray - var encRetScale: MLMultiArray - var encConvCache: MLMultiArray - var decRetKv: MLMultiArray - var decRetScale: MLMultiArray - var topBuffer: MLMultiArray - - init( - encRetKv: MLMultiArray, - encRetScale: MLMultiArray, - encConvCache: MLMultiArray, - decRetKv: MLMultiArray, - decRetScale: MLMultiArray, - topBuffer: MLMultiArray - ) { - self.encRetKv = encRetKv - self.encRetScale = encRetScale - self.encConvCache = encConvCache - self.decRetKv = decRetKv - self.decRetScale = decRetScale - self.topBuffer = 
topBuffer - } - - func copy() throws -> LSEENDModelState { - try LSEENDModelState( - encRetKv: cloneAlignedMultiArray(encRetKv), - encRetScale: cloneAlignedMultiArray(encRetScale), - encConvCache: cloneAlignedMultiArray(encConvCache), - decRetKv: cloneAlignedMultiArray(decRetKv), - decRetScale: cloneAlignedMultiArray(decRetScale), - topBuffer: cloneAlignedMultiArray(topBuffer) - ) - } -} - -private struct LSEENDStepOutput { - let fullLogits: [Float] - let nextState: LSEENDModelState -} - -private final class LSEENDInferenceSharedResources { - let descriptor: LSEENDModelDescriptor - let computeUnits: MLComputeUnits - let metadata: LSEENDModelMetadata - let featureConfig: LSEENDFeatureConfig - let model: MLModel - let targetSampleRate: Int - let modelFrameHz: Double - let streamingLatencySeconds: Double - let decodeMaxSpeakers: Int - let melSpectrogram: AudioMelSpectrogram - let offlineFeatureExtractor: LSEENDOfflineFeatureExtractor - - // Preallocated ANE-aligned input arrays reused across predictStep calls - let memoryOptimizer: ANEMemoryOptimizer - let frameArray: MLMultiArray // [1, 1, inputDim] - let ingestArray: MLMultiArray // [1] - let decodeArray: MLMultiArray // [1] - - init( - descriptor: LSEENDModelDescriptor, - computeUnits: MLComputeUnits - ) throws { - self.descriptor = descriptor - self.computeUnits = computeUnits - - let metadataData = try Data(contentsOf: descriptor.metadataURL) - metadata = try JSONDecoder().decode(LSEENDModelMetadata.self, from: metadataData) - featureConfig = LSEENDFeatureConfig(metadata: metadata) - targetSampleRate = metadata.resolvedSampleRate - modelFrameHz = metadata.frameHz - streamingLatencySeconds = metadata.streamingLatencySeconds - decodeMaxSpeakers = metadata.maxNspks - melSpectrogram = AudioMelSpectrogram( - sampleRate: featureConfig.sampleRate, - nMels: featureConfig.nMels, - nFFT: featureConfig.nFFT, - hopLength: featureConfig.hopLength, - winLength: featureConfig.winLength, - preemph: 0, - padTo: 1, - logFloor: 
1e-10, - logFloorMode: .clamped, - windowPeriodic: true - ) - - let configuration = MLModelConfiguration() - configuration.computeUnits = computeUnits - configuration.allowLowPrecisionAccumulationOnGPU = true - model = try MLModel( - contentsOf: try LSEENDInferenceHelper.compiledModelURL(for: descriptor.modelURL), - configuration: configuration - ) - offlineFeatureExtractor = LSEENDOfflineFeatureExtractor(metadata: metadata, spectrogram: melSpectrogram) - - // Preallocate ANE-aligned input arrays - memoryOptimizer = ANEMemoryOptimizer() - frameArray = try memoryOptimizer.createAlignedArray( - shape: [1, 1, NSNumber(value: metadata.inputDim)], - dataType: .float32 - ) - ingestArray = try memoryOptimizer.createAlignedArray( - shape: [1], - dataType: .float32 - ) - decodeArray = try memoryOptimizer.createAlignedArray( - shape: [1], - dataType: .float32 - ) - } -} - -/// Low level CoreML inference engine for LS-EEND speaker diarization. -/// -/// Each engine instance owns its own compiled model, mel spectrogram, and feature extractor. -/// There are no shared singletons — multiple engines can run concurrently without interference. -/// -/// The engine supports three usage modes: -/// - **Offline**: ``infer(samples:sampleRate:)`` or ``infer(audioFileURL:)`` for batch processing. -/// - **Streaming**: ``createSession(inputSampleRate:)`` to get an ``LSEENDStreamingSession`` for -/// incremental audio processing. -/// - **Simulation**: ``simulateStreaming(audioFileURL:chunkSeconds:)`` to replay a file through the -/// streaming pipeline with fixed-size chunks. -public final class LSEENDInferenceHelper { - private let logger = AppLogger(category: "LSEENDInference") - private let sharedResources: LSEENDInferenceSharedResources - - /// The descriptor used to create this engine. - public var descriptor: LSEENDModelDescriptor { sharedResources.descriptor } - /// The CoreML compute units this engine was configured with. 
- public var computeUnits: MLComputeUnits { sharedResources.computeUnits } - /// Model metadata decoded from the JSON configuration file. - public var metadata: LSEENDModelMetadata { sharedResources.metadata } - /// Derived feature extraction parameters. - public var featureConfig: LSEENDFeatureConfig { sharedResources.featureConfig } - /// The loaded CoreML model. - public var model: MLModel { sharedResources.model } - /// Audio sample rate the model expects (e.g. 8000 Hz). - public var targetSampleRate: Int { sharedResources.targetSampleRate } - /// Output frame rate in Hz (e.g. 10.0). - public var modelFrameHz: Double { sharedResources.modelFrameHz } - /// Minimum latency in seconds before the first output frame can be produced. - public var streamingLatencySeconds: Double { sharedResources.streamingLatencySeconds } - /// Maximum number of speaker slots in the model output (including boundary tracks). - public var decodeMaxSpeakers: Int { sharedResources.decodeMaxSpeakers } - - fileprivate var melSpectrogram: AudioMelSpectrogram { sharedResources.melSpectrogram } - private var offlineFeatureExtractor: LSEENDOfflineFeatureExtractor { sharedResources.offlineFeatureExtractor } - - private let lock = NSLock() - - /// Creates an inference engine by loading and compiling the CoreML model. - /// - /// - Parameters: - /// - descriptor: Locates the model and metadata files. - /// - computeUnits: CoreML compute units to use (default: `.cpuOnly`, - /// which is typically fastest for this model's architecture). - /// - Throws: ``LSEENDError`` if the model or metadata cannot be loaded. 
- public init( - descriptor: LSEENDModelDescriptor, - computeUnits: MLComputeUnits = .cpuOnly - ) throws { - sharedResources = try LSEENDInferenceSharedResources( - descriptor: descriptor, - computeUnits: computeUnits - ) - logger.info("Loaded LS-EEND variant \(descriptor.variant.rawValue) @ \(descriptor.modelURL.path)") - } - - /// Creates a new streaming session for incremental audio processing. - /// - /// - Parameter inputSampleRate: Must match ``targetSampleRate``. - /// - Returns: A session that accepts audio via ``LSEENDStreamingSession/pushAudio(_:)``. - /// - Throws: ``LSEENDError/unsupportedAudio(_:)`` if the sample rate doesn't match. - public func createSession(inputSampleRate: Int) throws -> LSEENDStreamingSession { - try LSEENDStreamingSession(engine: self, inputSampleRate: inputSampleRate) - } - - /// Creates a streaming session with a caller-owned mel spectrogram instance. - /// - /// Use this overload when thread-safety requires the session to have its own - /// isolated spectrogram rather than sharing the engine's instance. - /// - /// - Parameters: - /// - inputSampleRate: Must match ``targetSampleRate``. - /// - melSpectrogram: A mel spectrogram instance owned by the caller. - /// - Returns: A session that accepts audio via ``LSEENDStreamingSession/pushAudio(_:)``. - public func createSession( - inputSampleRate: Int, melSpectrogram: AudioMelSpectrogram - ) throws -> LSEENDStreamingSession { - try LSEENDStreamingSession(engine: self, inputSampleRate: inputSampleRate, melSpectrogram: melSpectrogram) - } - - /// Runs offline inference on raw audio samples. - /// - /// Resamples to ``targetSampleRate`` if needed, extracts features, runs the full - /// model, and returns speaker probabilities for every frame. - /// - /// - Parameters: - /// - samples: Mono audio samples. - /// - sampleRate: Sample rate of the input audio. - /// - Returns: Complete inference result with logits and probabilities. 
- public func infer(samples: [Float], sampleRate: Int) throws -> LSEENDInferenceResult { - let normalizedAudio = try resampleIfNeeded(samples: samples, sampleRate: sampleRate) - let session = try createSession(inputSampleRate: targetSampleRate) - if !normalizedAudio.isEmpty { - _ = try session.pushAudio(normalizedAudio) - } - _ = try session.finalize() - return session.snapshot() - } - - /// Runs offline inference on an audio file. - /// - /// Reads the file, resamples to ``targetSampleRate``, and runs full inference. - /// - /// - Parameter audioFileURL: Path to a WAV, CAF, or other audio file. - /// - Returns: Complete inference result with logits and probabilities. - public func infer(audioFileURL: URL) throws -> LSEENDInferenceResult { - let converter = AudioConverter( - targetFormat: AVAudioFormat( - commonFormat: .pcmFormatFloat32, - sampleRate: Double(targetSampleRate), - channels: 1, - interleaved: false - )! - ) - let audio = try converter.resampleAudioFile(audioFileURL) - return try infer(samples: audio, sampleRate: targetSampleRate) - } - - /// Simulates streaming inference by processing an audio file in fixed-size chunks. - /// - /// Useful for testing and benchmarking the streaming pipeline against offline results. - /// - /// - Parameters: - /// - audioFileURL: Path to the audio file. - /// - chunkSeconds: Duration of each simulated audio chunk in seconds. - /// - Returns: The final inference result along with per-chunk progress entries. - public func simulateStreaming(audioFileURL: URL, chunkSeconds: Double) throws -> LSEENDStreamingSimulationResult { - let converter = AudioConverter( - targetFormat: AVAudioFormat( - commonFormat: .pcmFormatFloat32, - sampleRate: Double(targetSampleRate), - channels: 1, - interleaved: false - )! 
- ) - let audio = try converter.resampleAudioFile(audioFileURL) - let chunkSize = max(1, Int(round(chunkSeconds * Double(targetSampleRate)))) - let session = try createSession(inputSampleRate: targetSampleRate) - var updates: [LSEENDStreamingProgress] = [] - var chunkIndex = 1 - var start = 0 - while start < audio.count { - let stop = min(audio.count, start + chunkSize) - let update = try session.pushAudio(Array(audio[start.. LSEENDStepOutput { - try lock.withLock { - // Write into preallocated ANE-aligned arrays instead of allocating new ones - sharedResources.memoryOptimizer.optimizedCopy(from: frame, to: sharedResources.frameArray) - sharedResources.ingestArray[0] = NSNumber(value: ingest) - sharedResources.decodeArray[0] = NSNumber(value: decode) - - let provider = try MLDictionaryFeatureProvider(dictionary: [ - "frame": MLFeatureValue(multiArray: sharedResources.frameArray), - "enc_ret_kv": MLFeatureValue(multiArray: state.encRetKv), - "enc_ret_scale": MLFeatureValue(multiArray: state.encRetScale), - "enc_conv_cache": MLFeatureValue(multiArray: state.encConvCache), - "dec_ret_kv": MLFeatureValue(multiArray: state.decRetKv), - "dec_ret_scale": MLFeatureValue(multiArray: state.decRetScale), - "top_buffer": MLFeatureValue(multiArray: state.topBuffer), - "ingest": MLFeatureValue(multiArray: sharedResources.ingestArray), - "decode": MLFeatureValue(multiArray: sharedResources.decodeArray), - ]) - let prediction = try model.prediction(from: provider) - let fullLogitsArray = try feature(named: "full_logits", from: prediction) - let nextState = LSEENDModelState( - encRetKv: try cloneAligned(feature(named: "enc_ret_kv_out", from: prediction)), - encRetScale: try cloneAligned(feature(named: "enc_ret_scale_out", from: prediction)), - encConvCache: try cloneAligned(feature(named: "enc_conv_cache_out", from: prediction)), - decRetKv: try cloneAligned(feature(named: "dec_ret_kv_out", from: prediction)), - decRetScale: try cloneAligned(feature(named: "dec_ret_scale_out", 
from: prediction)), - topBuffer: try cloneAligned(feature(named: "top_buffer_out", from: prediction)) - ) - return LSEENDStepOutput( - fullLogits: floatValues(from: fullLogitsArray, count: metadata.fullOutputDim), - nextState: nextState - ) - } - } - - fileprivate func initialState() throws -> LSEENDModelState { - let optimizer = sharedResources.memoryOptimizer - - return try LSEENDModelState( - encRetKv: optimizer.createAlignedArray( - shape: metadata.stateShapes.encRetKv.map(NSNumber.init(value:)), dataType: .float32), - encRetScale: optimizer.createAlignedArray( - shape: metadata.stateShapes.encRetScale.map(NSNumber.init(value:)), dataType: .float32), - encConvCache: optimizer.createAlignedArray( - shape: metadata.stateShapes.encConvCache.map(NSNumber.init(value:)), dataType: .float32), - decRetKv: optimizer.createAlignedArray( - shape: metadata.stateShapes.decRetKv.map(NSNumber.init(value:)), dataType: .float32), - decRetScale: optimizer.createAlignedArray( - shape: metadata.stateShapes.decRetScale.map(NSNumber.init(value:)), dataType: .float32), - topBuffer: optimizer.createAlignedArray( - shape: metadata.stateShapes.topBuffer.map(NSNumber.init(value:)), dataType: .float32) - ) - } - - private func resampleIfNeeded(samples: [Float], sampleRate: Int) throws -> [Float] { - guard sampleRate > 0 else { - throw LSEENDError.unsupportedAudio("Invalid sample rate \(sampleRate).") - } - guard sampleRate != targetSampleRate else { - return samples - } - let converter = AudioConverter( - targetFormat: AVAudioFormat( - commonFormat: .pcmFormatFloat32, - sampleRate: Double(targetSampleRate), - channels: 1, - interleaved: false - )! - ) - return try converter.resample(samples, from: Double(sampleRate)) - } - - fileprivate static func compiledModelURL(for modelURL: URL) throws -> URL { - if modelURL.pathExtension == "mlmodelc" { - return modelURL - } - let caches = - FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first - ?? 
URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true) - let compiledRoot = caches.appendingPathComponent("LS-EENDCompiledModels", isDirectory: true) - try FileManager.default.createDirectory(at: compiledRoot, withIntermediateDirectories: true) - let fingerprint = try cacheFingerprint(for: modelURL) - let compiledName = - modelURL.deletingPathExtension().lastPathComponent - + "-" - + fingerprint - + ".mlmodelc" - let destination = compiledRoot.appendingPathComponent(compiledName, isDirectory: true) - if FileManager.default.fileExists(atPath: destination.path) { - return destination - } - let compiled = try MLModel.compileModel(at: modelURL) - try FileManager.default.copyItem(at: compiled, to: destination) - return destination - } - - private func feature(named name: String, from provider: MLFeatureProvider) throws -> MLMultiArray { - guard let value = provider.featureValue(for: name)?.multiArrayValue else { - throw LSEENDError.missingFeature(name) - } - return value - } - - /// Clone an MLMultiArray into a new ANE-aligned allocation using stride-aware copy. - private func cloneAligned(_ source: MLMultiArray) throws -> MLMultiArray { - let copy = try sharedResources.memoryOptimizer.createAlignedArray( - shape: source.shape, - dataType: .float32 - ) - ANEMemoryUtils.strideAwareCopy(from: source, to: copy) - return copy - } - - private static func cacheFingerprint(for url: URL) throws -> String { - let fileManager = FileManager.default - let standardizedURL = url.standardizedFileURL - var hasher = SHA256() - - func update(for fileURL: URL, relativePath: String) throws { - let attributes = try fileManager.attributesOfItem(atPath: fileURL.path) - let fileType = attributes[.type] as? FileAttributeType ?? .typeUnknown - let size = (attributes[.size] as? NSNumber)?.uint64Value ?? 0 - let modification = (attributes[.modificationDate] as? Date)?.timeIntervalSinceReferenceDate ?? 
0 - let record = "\(relativePath)|\(fileType.rawValue)|\(size)|\(modification)\n" - hasher.update(data: Data(record.utf8)) - } - - let rootAttributes = try fileManager.attributesOfItem(atPath: standardizedURL.path) - let rootType = rootAttributes[.type] as? FileAttributeType ?? .typeUnknown - - if rootType == .typeDirectory { - let enumerator = fileManager.enumerator( - at: standardizedURL, - includingPropertiesForKeys: [.isRegularFileKey], - options: [.skipsHiddenFiles] - ) - var paths: [String] = [] - while let childURL = enumerator?.nextObject() as? URL { - let relativePath = childURL.path.replacingOccurrences(of: standardizedURL.path + "/", with: "") - paths.append(relativePath) - } - for relativePath in paths.sorted() { - try update(for: standardizedURL.appendingPathComponent(relativePath), relativePath: relativePath) - } - } else { - try update(for: standardizedURL, relativePath: standardizedURL.lastPathComponent) - } - - let digest = hasher.finalize() - return digest.prefix(8).map { String(format: "%02x", $0) }.joined() - } -} - -/// A stateful streaming session that incrementally processes audio and emits diarization frames. -/// -/// Created via ``LSEENDInferenceEngine/createSession(inputSampleRate:)``. -/// The session maintains internal RNN state across calls to ``pushAudio(_:)``. -/// -/// - Important: This class is **not** thread-safe. All calls must be serialized externally. -public final class LSEENDStreamingSession { - fileprivate let engine: LSEENDInferenceHelper - /// The sample rate of audio being fed to this session. 
- public let inputSampleRate: Int - fileprivate let featureExtractor: LSEENDStreamingFeatureExtractor - fileprivate var state: LSEENDModelState - fileprivate let zeroFrame: [Float] - fileprivate var fullLogitChunks: [LSEENDMatrix] = [] - fileprivate var finalized = false - - fileprivate var totalInputSamples: Int = 0 - fileprivate var totalFeatureFrames = 0 - fileprivate var emittedFrames = 0 - - fileprivate init( - engine: LSEENDInferenceHelper, inputSampleRate: Int, melSpectrogram: AudioMelSpectrogram? = nil - ) throws { - guard inputSampleRate == engine.targetSampleRate else { - throw LSEENDError.unsupportedAudio( - "Stateful LS-EEND streaming expects \(engine.targetSampleRate) Hz audio, received \(inputSampleRate) Hz." - ) - } - self.engine = engine - self.inputSampleRate = inputSampleRate - featureExtractor = LSEENDStreamingFeatureExtractor( - metadata: engine.metadata, spectrogram: melSpectrogram ?? engine.melSpectrogram) - state = try engine.initialState() - zeroFrame = [Float](repeating: 0, count: engine.metadata.inputDim) - } - - /// Feeds audio samples into the session and returns any newly committed frames. - /// - /// The returned update contains both committed (final) frames and a speculative preview - /// of pending frames decoded by zero-padding the remaining state. - /// - /// - Parameter chunk: Mono audio samples at ``inputSampleRate``. - /// - Returns: An update with new committed and preview frames, or `nil` if no frames were produced. - /// - Throws: ``LSEENDError/unsupportedAudio(_:)`` if the session has already been finalized. - public func pushAudio(_ chunk: [Float]) throws -> LSEENDStreamingUpdate? 
{ - guard !finalized else { - throw LSEENDError.unsupportedAudio("Streaming session already finalized.") - } - guard !chunk.isEmpty else { - return nil - } - totalInputSamples += chunk.count - let features = try featureExtractor.pushAudio(chunk) - let committed = try ingestFeatures(features) - return try buildUpdate(committedFullLogits: committed, includePreview: true) - } - - /// Flushes remaining buffered features and marks the session as complete. - /// - /// After finalization, ``pushAudio(_:)`` will throw. Calling `finalize()` again returns `nil`. - /// - /// - Returns: A final update with any remaining frames, or `nil` if no frames were pending. - public func finalize() throws -> LSEENDStreamingUpdate? { - guard !finalized else { - return nil - } - - var committedFullLogits = LSEENDMatrix.empty(columns: engine.decodeMaxSpeakers) - let targetEndFrame = Int( - round(Double(totalInputSamples) / Double(max(inputSampleRate, 1)) * engine.modelFrameHz)) - let exactPaddingSamples = try exactFinalizationPaddingSamples(targetEndFrame: targetEndFrame) - if exactPaddingSamples > 0 { - let features = try featureExtractor.pushAudio([Float](repeating: 0, count: exactPaddingSamples)) - let committed = try ingestFeatures(features) - if committed.rows > 0 { - committedFullLogits = committedFullLogits.appendingRows(committed) - } - } - - let finalFeatures = try featureExtractor.finalize() - let finalCommitted = try ingestFeatures(finalFeatures) - if finalCommitted.rows > 0 { - committedFullLogits = committedFullLogits.appendingRows(finalCommitted) - } - - let pending = totalFeatureFrames - emittedFrames - let tail = - try pending > 0 ? 
flushTail(from: state, pendingFrames: pending) : .empty(columns: engine.decodeMaxSpeakers) - emittedFrames += tail.rows - finalized = true - return try buildUpdate(committedFullLogits: committedFullLogits.appendingRows(tail), includePreview: false) - } - - private func exactFinalizationPaddingSamples(targetEndFrame: Int) throws -> Int { - guard targetEndFrame > 0 else { - return 0 - } - let stableBlockSize = engine.metadata.resolvedHopLength * engine.metadata.resolvedSubsampling - let (requiredTotalSamples, overflow) = targetEndFrame.multipliedReportingOverflow(by: stableBlockSize) - guard !overflow else { - throw LSEENDError.unsupportedAudio( - "Finalization padding overflowed for \(targetEndFrame) frames at block size \(stableBlockSize)." - ) - } - return max(0, requiredTotalSamples - totalInputSamples) - } - - /// Assembles the full inference result from all committed frames emitted so far. - /// - /// Can be called at any time (before or after finalization) to get a complete - /// ``LSEENDInferenceResult`` covering all frames produced up to this point. - public func snapshot() -> LSEENDInferenceResult { - let fullLogits = fullLogitChunks.reduce(LSEENDMatrix.empty(columns: engine.decodeMaxSpeakers)) { - partial, matrix in - partial.appendingRows(matrix) - } - let fullProbabilities = fullLogits.applyingSigmoid() - let logits = cropRealTracks(from: fullLogits) - let probabilities = cropRealTracks(from: fullProbabilities) - return LSEENDInferenceResult( - logits: logits, - probabilities: probabilities, - fullLogits: fullLogits, - fullProbabilities: fullProbabilities, - frameHz: engine.modelFrameHz, - durationSeconds: Double(totalInputSamples) / Double(max(inputSampleRate, 1)) - ) - } - - fileprivate func ingestFeatures(_ features: LSEENDMatrix) throws -> LSEENDMatrix { - guard features.rows > 0 else { - return .empty(columns: engine.decodeMaxSpeakers) - } - var output: [Float] = [] - for rowIndex in 0..= engine.metadata.convDelay ? 
Float(1) : Float(0) - let step = try autoreleasepool { - try engine.predictStep( - frame: Array(features.row(rowIndex)), - state: state, - ingest: 1, - decode: decode - ) - } - state = step.nextState - totalFeatureFrames += 1 - if decode > 0 { - output.append(contentsOf: step.fullLogits) - emittedFrames += 1 - } - } - return LSEENDMatrix( - validatingRows: output.isEmpty ? 0 : output.count / engine.decodeMaxSpeakers, - columns: engine.decodeMaxSpeakers, - values: output - ) - } - - fileprivate func flushTail(from state: LSEENDModelState, pendingFrames: Int) throws -> LSEENDMatrix { - guard pendingFrames > 0 else { - return .empty(columns: engine.decodeMaxSpeakers) - } - var previewState = state - var output: [Float] = [] - for _ in 0.. LSEENDStreamingUpdate? { - let startFrame = emittedFrames - committedFullLogits.rows - let committedProbabilities = committedFullLogits.applyingSigmoid() - let committedLogits = cropRealTracks(from: committedFullLogits) - let committedRealProbabilities = cropRealTracks(from: committedProbabilities) - - if committedFullLogits.rows > 0 { - fullLogitChunks.append(committedFullLogits) - } - - let previewFullLogits: LSEENDMatrix - if includePreview { - let previewState = try state.copy() - let pending = totalFeatureFrames - emittedFrames - previewFullLogits = try flushTail(from: previewState, pendingFrames: pending) - } else { - previewFullLogits = .empty(columns: engine.decodeMaxSpeakers) - } - let previewLogits = cropRealTracks(from: previewFullLogits) - let previewProbabilities = cropRealTracks(from: previewFullLogits.applyingSigmoid()) - - if committedLogits.isEmpty && previewLogits.isEmpty { - return nil - } - - return LSEENDStreamingUpdate( - startFrame: startFrame, - logits: committedLogits, - probabilities: committedRealProbabilities, - previewStartFrame: emittedFrames, - previewLogits: previewLogits, - previewProbabilities: previewProbabilities, - frameHz: engine.modelFrameHz, - durationSeconds: Double(totalInputSamples) / 
Double(max(inputSampleRate, 1)), - totalEmittedFrames: emittedFrames - ) - } - - private func cropRealTracks(from matrix: LSEENDMatrix) -> LSEENDMatrix { - guard matrix.columns > 2 else { - return .empty(columns: 0) - } - let realColumns = matrix.columns - 2 - var output = [Float](repeating: 0, count: matrix.rows * realColumns) - for rowIndex in 0.. Double { - (value * 1000).rounded() / 1000 -} - -/// Clone an MLMultiArray into a new ANE-aligned allocation. -private func cloneAlignedMultiArray(_ source: MLMultiArray) throws -> MLMultiArray { - let copy = try ANEMemoryUtils.createAlignedArray( - shape: source.shape, - dataType: .float32, - zeroClear: false - ) - ANEMemoryUtils.strideAwareCopy(from: source, to: copy) - return copy -} - -private func floatValues(from array: MLMultiArray, count: Int) -> [Float] { - let pointer = array.dataPointer.bindMemory(to: Float.self, capacity: max(count, array.count)) - return Array(UnsafeBufferPointer(start: pointer, count: count)) -} diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDPreprocessor.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDPreprocessor.swift index 24c1ce09e..04f16b5e2 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDPreprocessor.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDPreprocessor.swift @@ -1,414 +1,384 @@ import Foundation +import CoreML +import Accelerate -private let lseendLogConversionFactor = Float(1.0 / Foundation.log(10.0)) - -/// Resolved feature extraction parameters derived from ``LSEENDModelMetadata``. -/// -/// Captures the concrete STFT and splice-and-subsample settings needed by the -/// feature extractors, resolving any optional fields in the metadata to their defaults. -public struct LSEENDFeatureConfig: Sendable, Hashable { - /// Audio sample rate in Hz (e.g. 8000). - public let sampleRate: Int - /// STFT window length in samples. - public let winLength: Int - /// STFT hop length in samples. - public let hopLength: Int - /// FFT size (a power of 2 ≥ ``winLength``). 
- public let nFFT: Int - /// Number of mel filterbank channels. - public let nMels: Int - /// Context receptive field half-width for the splice step. - public let contextRecp: Int - /// Subsampling factor (how many STFT frames per model frame). - public let subsampling: Int - /// Total input feature dimension per model frame (`nMels × (2 × contextRecp + 1)`). - public let inputDim: Int - - /// Creates a feature config by resolving all parameters from the given metadata. - public init(metadata: LSEENDModelMetadata) { - sampleRate = metadata.resolvedSampleRate - winLength = metadata.resolvedWinLength - hopLength = metadata.resolvedHopLength - nFFT = metadata.resolvedFFTSize - nMels = metadata.resolvedMelCount - contextRecp = metadata.resolvedContextRecp - subsampling = metadata.resolvedSubsampling - inputDim = metadata.inputDim +// MARK: - Feature Provider +public class LSEENDFeatureProvider { + public struct Snapshot: ~Copyable { + let state: LSEENDState + let melQueue: StreamingChunkQueue + let audioQueue: StreamingChunkQueue + let cmnMean: [Float] + let cmnCount: Int + let decoderMaskEnd: Int } - /// Minimum audio chunk size in samples that produces an integer number of model frames. - /// - /// Equal to `hopLength × subsampling`. Audio buffers should be multiples of this - /// size for consistent streaming behavior. - public var stableBlockSize: Int { - hopLength * subsampling - } -} + /// Number of mel chunks currently ready for `emitNextChunk()`. + public var readyChunks: Int { lock.withLock { melQueue.readyChunks } } -private func createMelSpectrogram(for config: LSEENDFeatureConfig) -> AudioMelSpectrogram { - AudioMelSpectrogram( - sampleRate: config.sampleRate, - nMels: config.nMels, - nFFT: config.nFFT, - hopLength: config.hopLength, - winLength: config.winLength, - preemph: 0, - padTo: 1, - logFloor: 1e-10, - logFloorMode: .clamped, - windowPeriodic: true - ) -} + // MARK: Private Attributes -/// Batch feature extractor for offline LS-EEND inference. 
-/// -/// Converts a complete audio buffer into model input features in one pass: -/// 1. STFT → mel spectrogram -/// 2. Log-mel with cumulative mean normalization -/// 3. Splice-and-subsample context windowing -/// -/// For incremental processing, use ``LSEENDStreamingFeatureExtractor`` instead. -public final class LSEENDOfflineFeatureExtractor { - private let config: LSEENDFeatureConfig - private let spectrogram: AudioMelSpectrogram - - /// Creates an offline feature extractor. - /// - /// - Parameters: - /// - metadata: Model metadata from which feature parameters are derived. - /// - spectrogram: Optional pre-configured mel spectrogram; one is created if `nil`. - public init(metadata: LSEENDModelMetadata, spectrogram: AudioMelSpectrogram? = nil) { - let featureConfig = LSEENDFeatureConfig(metadata: metadata) - config = featureConfig - self.spectrogram = spectrogram ?? createMelSpectrogram(for: featureConfig) - } + private let melSpectrogram: AudioMelSpectrogram + private let converter: AudioConverter + private let input: LSEENDInput - /// Extracts model input features from a complete audio buffer. - /// - /// - Parameter audio: Mono audio samples at the model's target sample rate. - /// - Returns: Feature matrix with shape `[frames, inputDim]`, or an empty matrix - /// if the audio is too short to produce any frames. 
- public func extractFeatures(audio: [Float]) throws -> LSEENDMatrix { - let usableSamples = (audio.count / config.stableBlockSize) * config.stableBlockSize - guard usableSamples > 0 else { - return .empty(columns: config.inputDim) - } - let trimmedAudio = Array(audio.prefix(usableSamples)) - let stftFrameCount = max(0, usableSamples / config.hopLength - 1) - guard stftFrameCount > 0 else { - return .empty(columns: config.inputDim) - } - let mel = spectrogram.computeFlatTransposed( - audio: trimmedAudio, - lastAudioSample: 0, - paddingMode: .center, - expectedFrameCount: stftFrameCount - ).mel - let normalized = Self.applyLogMelCumMeanNormalization( - mel, - rowCount: stftFrameCount, - nMels: config.nMels, - frameStart: 0, - cumulativeFeatureSum: nil - ) - let base = LSEENDMatrix(validatingRows: stftFrameCount, columns: config.nMels, values: normalized.values) - return Self.spliceAndSubsample( - baseFeatures: base, - contextSize: config.contextRecp, - subsampling: config.subsampling + private var melQueue: StreamingChunkQueue + private var audioQueue: StreamingChunkQueue + + private var cmnMean: [Float] + private var cmnCount: Int + + private var isRightContextEmpty: Bool = true + + private var decoderMaskEnd: Int + + private let lock = NSLock() + private let log10Scale: Float = 1.0 / log(10.0) + private let decoderMask: [Float] + + /// Audio samples required past the last real sample to flush every + /// buffered real frame through STFT + mel ±context + CNN right-lookahead. + private let flushSampleCount: Int + private let chunkFrames: Int + private let nMels: Int + + // MARK: - Init + public init( + from metadata: borrowing LSEENDMetadata, + restoringFrom snapshot: consuming Snapshot? 
= nil + ) throws { + self.nMels = metadata.nMels + + let contextMels = metadata.contextSize + let contextSamples = metadata.nFFT / 2 + let chunkMels = metadata.subsampling * metadata.chunkSize + let chunkSamples = metadata.hopLength * chunkMels + let rightSamples = metadata.nFFT / 2 - metadata.hopLength + + // (mel ±context + CNN right-lookahead) mels × hop + STFT last-window halfNfft + self.flushSampleCount = + (contextMels + metadata.convDelay * metadata.subsampling) * metadata.hopLength + + contextSamples + + self.chunkFrames = metadata.chunkSize + + var decoderMaskTemp = [Float](repeating: 1, count: metadata.convDelay + metadata.chunkSize) + vDSP_vclr(&decoderMaskTemp, 1, vDSP_Length(metadata.convDelay)) + self.decoderMask = decoderMaskTemp + + // Initialize processors + self.melSpectrogram = AudioMelSpectrogram( + sampleRate: metadata.sampleRate, + nMels: metadata.nMels, + nFFT: metadata.nFFT, + hopLength: metadata.hopLength, + winLength: metadata.winLength, + preemph: 0, + padTo: 0, + logFloor: 1e-10, + logFloorMode: .clamped, + windowPeriodic: true ) - } - fileprivate static func applyLogMelCumMeanNormalization( - _ mel: [Float], - rowCount: Int, - nMels: Int, - frameStart: Int, - cumulativeFeatureSum: [Double]? - ) -> (values: [Float], cumulativeFeatureSum: [Double]) { - var cumulative = cumulativeFeatureSum ?? [Double](repeating: 0, count: nMels) - var output = [Float](repeating: 0, count: mel.count) - for rowIndex in 0.. LSEENDMatrix { - guard baseFeatures.rows > 0 else { - return .empty(columns: baseFeatures.columns * ((2 * contextSize) + 1)) + // MARK: - Push Audio + + /// Add audio to the processing queue + /// - Parameters: + /// - samples: Audio samples to enqueue + /// - sourceSampleRate: Sample rate of audio input + /// - eagerPreprocessing: Whether to eagerly feed audio chunks to the mel spectrogram + public func enqueueAudio( + _ samples: C, + withSampleRate sourceSampleRate: Double? 
= nil,
+        eagerPreprocessing: Bool = true
+    ) throws where C.Element == Float {
+        lock.lock()
+        defer { lock.unlock() }
+
+        if let sourceSampleRate {
+            let array = (samples as? [Float]) ?? Array(samples)
+            try audioQueue.append(converter.resample(array, from: sourceSampleRate))
+        } else {
+            audioQueue.append(samples)
         }
-        let outputRows = (baseFeatures.rows + subsampling - 1) / subsampling
-        let outputColumns = baseFeatures.columns * ((2 * contextSize) + 1)
-        var output = [Float](repeating: 0, count: outputRows * outputColumns)
-        for outputRow in 0..= 0, sourceRow < baseFeatures.rows else {
-                continue
-            }
-            let destinationOffset = destinationRowOffset + (contextOffset + contextSize) * baseFeatures.columns
-            let sourceOffset = sourceRow * baseFeatures.columns
-            for featureIndex in 0.. Int {
+        let samples = try converter.resampleAudioFile(url)
+        lock.lock()
+        defer { lock.unlock() }
+        audioQueue.append(samples)
+        processAudioQueue()
+        return samples.count
     }

-    /// Feeds audio samples and returns any new model input frames.
-    ///
-    /// - Parameter chunk: Mono audio samples at the model's target sample rate.
-    /// - Returns: Feature matrix with shape `[newFrames, inputDim]`, or an empty matrix
-    ///   if no new frames could be produced from the available audio.
-    public func pushAudio(_ chunk: [Float]) throws -> LSEENDMatrix {
-        guard !chunk.isEmpty else {
-            return .empty(columns: config.inputDim)
+    /// Add silence to push all queued audio out of the right context
+    /// - Parameter flush: Whether to flush the queued audio into the mel spectrogram preprocessor
+    public func drainRightContextWithSilence(flush: Bool = true) throws {
+        lock.lock()
+        defer { lock.unlock() }
+
+        // 1. Trailing silence covering STFT + mel ±context + CNN right-lookahead.
+        audioQueue.append(repeatElement(0, count: flushSampleCount))
+
+        // 2. Round up to the next audio-chunk boundary so popAllChunks consumes
+        //    every real sample plus the silence we just pushed.
+ let unread = audioQueue.unreadFloats + let chunk = audioQueue.chunkFloats + let ctx = audioQueue.contextFloats + let overCtx = max(0, unread - ctx) + let shortfall = (chunk - overCtx % chunk) % chunk + if shortfall > 0 { + audioQueue.append(repeatElement(0, count: shortfall)) + } + + // 3. Drain audioQueue → STFT → log10 → CMN → melQueue. + if flush { + processAudioQueue() } - audioBuffer.append(contentsOf: chunk) - totalSamples += chunk.count - try appendSTFTFrames(targetFrameCount: stableSTFTFrameCount(), allowRightPad: false, effectiveTotalSamples: nil) - return try emitModelFrames(final: false, totalSTFTFrames: nil) } - /// Flushes remaining buffered audio and returns any final model input frames. - /// - /// Should be called exactly once after the last ``pushAudio(_:)`` call. - /// Applies right-padding to extract any remaining STFT frames that couldn't - /// be emitted during streaming. - /// - /// - Returns: Feature matrix with any remaining frames, or an empty matrix. - public func finalize() throws -> LSEENDMatrix { - let usableSamples = usableSampleCount(totalSamples) - let totalSTFTFrames = offlineSTFTFrameCount(usableSamples) - try appendSTFTFrames( - targetFrameCount: totalSTFTFrames, - allowRightPad: true, - effectiveTotalSamples: usableSamples + // MARK: - Read Chunk + + /// Read the next chunk from the mel + public func emitNextChunk() throws -> LSEENDInput? { + lock.lock() + defer { lock.unlock() } + + processAudioQueue() + guard let rawChunk = melQueue.popNextChunk() else { return nil } + + // Advance decoder mask + decoderMaskEnd = min(decoderMaskEnd + chunkFrames, decoderMask.count) + + try input.loadInputs( + melFeatures: rawChunk, + decoderMask: decoderMask[decoderMaskEnd - chunkFrames.. 
Int {
-        (sampleCount / config.stableBlockSize) * config.stableBlockSize
+        return input
     }

-    private func stableSTFTFrameCount() -> Int {
-        let leftPad = config.nFFT / 2
-        guard totalSamples > leftPad else {
-            return 0
-        }
-        return max(0, ((totalSamples - leftPad) / config.hopLength) + 1)
+    // MARK: - Snapshot and Rollback
+
+    public func takeSnapshot() throws -> Snapshot {
+        lock.lock()
+        defer { lock.unlock() }
+        let result = Snapshot(
+            state: try input.state.copy(),
+            melQueue: melQueue,
+            audioQueue: audioQueue,
+            cmnMean: cmnMean,
+            cmnCount: cmnCount,
+            decoderMaskEnd: decoderMaskEnd
+        )
+        return result
     }

-    private func offlineSTFTFrameCount(_ usableSamples: Int) -> Int {
-        guard usableSamples > 0 else {
-            return 0
-        }
-        return max(0, usableSamples / config.hopLength - 1)
+    /// Rollback to a previous snapshot.
+    /// - Parameters:
+    ///   - snapshot: Snapshot to revert to
+    ///   - keepingState: Whether to preserve the current recurrent state
+    public func rollback(to snapshot: consuming Snapshot, keepingState: Bool = false) {
+        lock.lock()
+        defer { lock.unlock() }
+        if !keepingState { self.input.state = snapshot.state }
+        self.melQueue = snapshot.melQueue
+        self.audioQueue = snapshot.audioQueue
+        self.cmnMean = snapshot.cmnMean
+        self.cmnCount = snapshot.cmnCount
+        self.decoderMaskEnd = snapshot.decoderMaskEnd
     }

-    private func totalModelFrameCount(_ totalSTFTFrames: Int) -> Int {
-        guard totalSTFTFrames > 0 else {
-            return 0
-        }
-        return (totalSTFTFrames + config.subsampling - 1) / config.subsampling
+    /// Clear preprocessor buffers + model recurrence state + frame counter.
+    public func reset() {
+        lock.lock()
+        defer { lock.unlock() }
+        vDSP.fill(&cmnMean, with: 0)
+        cmnCount = 0
+        decoderMaskEnd = 0
+        audioQueue.reset()
+        melQueue.reset()
+        input.resetState()
     }

-    private func appendSTFTFrames(
-        targetFrameCount: Int,
-        allowRightPad: Bool,
-        effectiveTotalSamples: Int?
- ) throws { - guard targetFrameCount > nextSTFTFrame else { - return - } - let frameStart = nextSTFTFrame - let frameStop = targetFrameCount - let expectedFrames = frameStop - frameStart - let segment = try stftSegment( - frameStart: frameStart, - frameStop: frameStop, - allowRightPad: allowRightPad, - effectiveTotalSamples: effectiveTotalSamples - ) - let mel = spectrogram.computeFlatTransposed( - audio: segment, + // MARK: - Helpers + + private func processAudioQueue() { + guard let audioChunk = audioQueue.popAllChunks() else { return } + + var (melFeats, melFrames, _) = melSpectrogram.computeFlatTransposed( + audio: audioChunk, lastAudioSample: 0, paddingMode: .prePadded, - expectedFrameCount: expectedFrames - ).mel - let normalized = LSEENDOfflineFeatureExtractor.applyLogMelCumMeanNormalization( - mel, - rowCount: expectedFrames, - nMels: config.nMels, - frameStart: frameStart, - cumulativeFeatureSum: cumulativeFeatureSum + expectedFrameCount: nil ) - cumulativeFeatureSum = normalized.cumulativeFeatureSum - baseFeatureBuffer.append(contentsOf: normalized.values) - baseFeatureRows += expectedFrames - nextSTFTFrame = frameStop - dropConsumedAudio() - } - private func stftSegment( - frameStart: Int, - frameStop: Int, - allowRightPad: Bool, - effectiveTotalSamples: Int? - ) throws -> [Float] { - guard frameStop > frameStart else { - return [] - } - let leftPad = config.nFFT / 2 - let total = effectiveTotalSamples ?? totalSamples - let globalStart = frameStart * config.hopLength - leftPad - let globalStop = (frameStop - 1) * config.hopLength - leftPad + config.nFFT - - let prefixCount = max(0, -globalStart) - let suffixCount = allowRightPad ? max(0, globalStop - total) : 0 - let rawStart = max(0, globalStart) - let rawStop = min(total, globalStop) - guard rawStart >= audioStartSample else { - throw LSEENDError.unsupportedAudio( - "Audio buffer underflow. Need sample \(rawStart) but buffer starts at \(audioStartSample)." 
- ) - } - let localStart = rawStart - audioStartSample - let localStop = rawStop - audioStartSample - var segment = [Float](repeating: 0, count: prefixCount + (localStop - localStart) + suffixCount) - let coreCount = max(0, localStop - localStart) - if coreCount > 0 { - for index in 0.. LSEENDMatrix { - var output = [Float]() - let latestFrame = nextSTFTFrame - 1 - let totalModelFrames = final ? self.totalModelFrameCount(totalSTFTFrames ?? 0) : nil - - while true { - let centerIndex = nextModelFrame * config.subsampling - let maxIndex: Int - if final { - guard let totalModelFrames, let totalSTFTFrames else { break } - if nextModelFrame >= totalModelFrames { - break - } - maxIndex = totalSTFTFrames - 1 - } else { - if centerIndex + config.contextRecp > latestFrame { - break - } - maxIndex = latestFrame - } - output.append(contentsOf: try spliceFrame(centerIndex: centerIndex, maxIndex: maxIndex)) - nextModelFrame += 1 - dropConsumedBaseFeatures() - } +// MARK: - Streaming Chunk Queue + +public struct StreamingChunkQueue { + /// Stride between frames if features are n-dimensional arrays + public let stride: Int + + /// Total context size in floats (`leftContextFloats + rightContextFloats`). + public let contextFloats: Int + + /// Unpadded chunk size + public let chunkFloats: Int + + /// Padded chunk size — width of a `popNextChunk` / `popAllChunks` slice. + public let paddedChunkFloats: Int - let outputRows = output.isEmpty ? 0 : output.count / config.inputDim - return LSEENDMatrix(validatingRows: outputRows, columns: config.inputDim, values: output) + /// Whether the buffer is empty + public var isEmpty: Bool { buffer.isEmpty } + + /// Number of unread floats + public var unreadFloats: Int { buffer.count - head } + + /// Number of full chunks currently poppable via `popNextChunk` / `popAllChunks`. 
+ public var readyChunks: Int { max(0, (unreadFloats - contextFloats) / chunkFloats) } + + // MARK: - Private attributes + + /// Pre-pad width in floats — how many leading zeros are seeded at init + private let leftContextFloats: Int + + /// Next index at which to start processing + private var head: Int + + /// Data buffer + private var buffer: [Float] = [] + + /// Whether a chunk is ready + public var hasChunk: Bool { + buffer.count - head >= paddedChunkFloats } - private func spliceFrame(centerIndex: Int, maxIndex: Int) throws -> [Float] { - var frame = [Float](repeating: 0, count: config.inputDim) - for frameIndex in (centerIndex - config.contextRecp)...(centerIndex + config.contextRecp) { - let destinationBase = (frameIndex - (centerIndex - config.contextRecp)) * config.nMels - guard frameIndex >= 0, frameIndex <= maxIndex else { - continue - } - let localIndex = frameIndex - baseFeatureStart - guard localIndex >= 0, localIndex < baseFeatureRows else { - throw LSEENDError.unsupportedAudio( - "Feature buffer underflow. Need frame \(frameIndex), buffer covers [\(baseFeatureStart), \(baseFeatureStart + baseFeatureRows - 1)]." - ) - } - let sourceBase = localIndex * config.nMels - for melIndex in 0.. 
0 else { - return + // MARK: - Append and Pop + + public mutating func append(_ newElements: C) + where C.Element == Float { + // Lazy trimming + if buffer.count + newElements.count > buffer.capacity { + buffer.removeFirst(head) + head = 0 } - audioBuffer.removeFirst(dropCount) - audioStartSample += dropCount + + // Allow Swift to reserve more memory if needed after the trimming + buffer.append(contentsOf: newElements) } - private func dropConsumedBaseFeatures() { - let keepFrom = max(0, nextModelFrame * config.subsampling - config.contextRecp) - let dropRows = keepFrom - baseFeatureStart - guard dropRows > 0 else { - return - } - let dropCount = dropRows * config.nMels - baseFeatureBuffer.removeFirst(dropCount) - baseFeatureRows -= dropRows - baseFeatureStart += dropRows + /// Pop the last chunk + public mutating func popNextChunk() -> ArraySlice? { + guard hasChunk else { return nil } + let result = buffer[head.. ArraySlice? { + guard hasChunk else { return nil } + let newHead = head + (buffer.count - head - contextFloats) / chunkFloats * chunkFloats + let result = buffer[head.. 
prediction subsampling + public let subsampling: Int + + /// Number of right context output frames + public let convDelay: Int + + // Model structure + public let nUnits: Int + public let nHeads: Int + public let encNLayers: Int + public let decNLayers: Int + public let convKernelSize: Int + public var headDim: Int { nUnits / nHeads } + + /// Mel frames per chunk + public var melFrames: Int { (chunkSize - 1) * subsampling + 2 * contextSize + 1 } + + public var nFFT: Int { + 1 << (Int.bitWidth - (winLength - 1).leadingZeroBitCount) + } +} + +// MARK: - Recurrent State + +public struct LSEENDState: ~Copyable { + public var encRetKv: MLMultiArray + public var encRetScale: MLMultiArray + public var encConvCache: MLMultiArray + public var cnnWindow: MLMultiArray + public var decRetKv: MLMultiArray + public var decRetScale: MLMultiArray + + public init( + encRetKv: MLMultiArray, + encRetScale: MLMultiArray, + encConvCache: MLMultiArray, + cnnWindow: MLMultiArray, + decRetKv: MLMultiArray, + decRetScale: MLMultiArray + ) { + self.encRetKv = encRetKv + self.encRetScale = encRetScale + self.encConvCache = encConvCache + self.cnnWindow = cnnWindow + self.decRetKv = decRetKv + self.decRetScale = decRetScale + } + + public init(from metadata: borrowing LSEENDMetadata) throws { + let Lenc = NSNumber(value: metadata.encNLayers) + let Ldec = NSNumber(value: metadata.decNLayers) + let H = NSNumber(value: metadata.nHeads) + let hd = NSNumber(value: metadata.headDim) + let D = NSNumber(value: metadata.nUnits) + let K = NSNumber(value: metadata.convKernelSize) + let Kcnn = NSNumber(value: 2 * metadata.convDelay) + let nSpk = NSNumber(value: metadata.maxNspks) + + func makeArray(shape: [NSNumber]) throws -> MLMultiArray { + try ANEMemoryUtils.createAlignedArray(shape: shape, dataType: .float32) + } + + self.init( + encRetKv: try makeArray(shape: [Lenc, 1, H, hd, hd]), + encRetScale: try makeArray(shape: [Lenc, 1]), + encConvCache: try makeArray(shape: [Lenc, 1, K, D]), + cnnWindow: 
try makeArray(shape: [1, D, Kcnn]), + decRetKv: try makeArray(shape: [Ldec, nSpk, H, hd, hd]), + decRetScale: try makeArray(shape: [Ldec, 1]) + ) + + self.reset() + } + + public func copy() throws -> LSEENDState { + func clone(_ src: MLMultiArray) throws -> MLMultiArray { + let dst = try ANEMemoryUtils.createAlignedArray( + shape: src.shape, dataType: src.dataType + ) + ANEMemoryUtils.strideAwareCopy(from: src, to: dst) + return dst + } + return LSEENDState( + encRetKv: try clone(encRetKv), + encRetScale: try clone(encRetScale), + encConvCache: try clone(encConvCache), + cnnWindow: try clone(cnnWindow), + decRetKv: try clone(decRetKv), + decRetScale: try clone(decRetScale) + ) + } + + public func copy(to dst: borrowing LSEENDState) { + ANEMemoryUtils.strideAwareCopy(from: encRetKv, to: dst.encRetKv) + ANEMemoryUtils.strideAwareCopy(from: encRetScale, to: dst.encRetScale) + ANEMemoryUtils.strideAwareCopy(from: encConvCache, to: dst.encConvCache) + ANEMemoryUtils.strideAwareCopy(from: cnnWindow, to: dst.cnnWindow) + ANEMemoryUtils.strideAwareCopy(from: decRetKv, to: dst.decRetKv) + ANEMemoryUtils.strideAwareCopy(from: decRetScale, to: dst.decRetScale) + } + + public func reset() { + func clear(_ buffer: MLMultiArray) { + buffer.withUnsafeMutableBufferPointer(ofType: Float.self) { buf, _ in + guard let base = buf.baseAddress else { return } + memset(base, 0, buf.count * MemoryLayout.stride) + } + } + + clear(encRetKv) + clear(encRetScale) + clear(encConvCache) + clear(cnnWindow) + clear(decRetKv) + clear(decRetScale) + } +} + +// MARK: - Errors + +public enum LSEENDError: Error, LocalizedError { + case initializationFailed(String) + case inferenceFailed(String) + case invalidInputSize(String) + case notInitialized +} diff --git a/Sources/FluidAudio/ModelNames.swift b/Sources/FluidAudio/ModelNames.swift index 69264524c..9fc9bac7c 100644 --- a/Sources/FluidAudio/ModelNames.swift +++ b/Sources/FluidAudio/ModelNames.swift @@ -22,7 +22,10 @@ public enum Repo: String, 
CaseIterable, Sendable { case diarizer = "FluidInference/speaker-diarization-coreml" case kokoro = "FluidInference/kokoro-82m-coreml" case sortformer = "FluidInference/diar-streaming-sortformer-coreml" - case lseend = "FluidInference/ls-eend-coreml" + case lseendAmi = "FluidInference/ls-eend-coreml/optimized/ami" + case lseendCallHome = "FluidInference/ls-eend-coreml/optimized/ch" + case lseendDihard2 = "FluidInference/ls-eend-coreml/optimized/dih2" + case lseendDihard3 = "FluidInference/ls-eend-coreml/optimized/dih3" case pocketTts = "FluidInference/pocket-tts-coreml" case qwen3Asr = "FluidInference/qwen3-asr-0.6b-coreml/f32" case qwen3AsrInt8 = "FluidInference/qwen3-asr-0.6b-coreml/int8" @@ -67,8 +70,14 @@ public enum Repo: String, CaseIterable, Sendable { return "kokoro-82m-coreml" case .sortformer: return "diar-streaming-sortformer-coreml" - case .lseend: - return "ls-eend-coreml" + case .lseendAmi: + return "ls-eend-coreml/optimized/ami" + case .lseendCallHome: + return "ls-eend-coreml/optimized/ch" + case .lseendDihard2: + return "ls-eend-coreml/optimized/dih2" + case .lseendDihard3: + return "ls-eend-coreml/optimized/dih3" case .pocketTts: return "pocket-tts-coreml" case .qwen3Asr: @@ -97,7 +106,7 @@ public enum Repo: String, CaseIterable, Sendable { return "FluidInference/nemotron-speech-streaming-en-0.6b-coreml" case .sortformer: return "FluidInference/diar-streaming-sortformer-coreml" - case .lseend: + case .lseendAmi, .lseendCallHome, .lseendDihard2, .lseendDihard3: return "FluidInference/ls-eend-coreml" case .qwen3Asr, .qwen3AsrInt8: return "FluidInference/qwen3-asr-0.6b-coreml" @@ -131,6 +140,14 @@ public enum Repo: String, CaseIterable, Sendable { return "nemotron_coreml_160ms" case .nemotronStreaming80: return "nemotron_coreml_80ms" + case .lseendAmi: + return "optimized/ami" + case .lseendCallHome: + return "optimized/ch" + case .lseendDihard2: + return "optimized/dih2" + case .lseendDihard3: + return "optimized/dih3" case .cohereTranscribeCoreml: 
return "q8" default: @@ -169,6 +186,14 @@ public enum Repo: String, CaseIterable, Sendable { return "parakeet-ja" case .parakeetTdtCtc110m: return "parakeet-tdt-ctc-110m" + case .lseendAmi: + return "ls-eend/ami" + case .lseendCallHome: + return "ls-eend/ch" + case .lseendDihard2: + return "ls-eend/dih2" + case .lseendDihard3: + return "ls-eend/dih3" case .cohereTranscribeCoreml: return "cohere-transcribe/q8" default: @@ -483,52 +508,80 @@ public enum ModelNames { /// LS-EEND streaming diarization model names public enum LSEEND { - public enum Variant: String, CaseIterable, Sendable, CustomStringConvertible { - case ami = "AMI" - case callhome = "CALLHOME" - case dihard2 = "DIHARD II" - case dihard3 = "DIHARD III" + public enum Variant: CaseIterable, Sendable, CustomStringConvertible { + case ami + case callhome + case dihard2 + case dihard3 + + public var repo: Repo { + switch self { + case .ami: return .lseendAmi + case .callhome: return .lseendCallHome + case .dihard2: return .lseendDihard2 + case .dihard3: return .lseendDihard3 + } + } public var name: String { switch self { case .ami: - return "ls_eend_ami_step" + return "ls_eend_ami" case .callhome: - return "ls_eend_callhome_step" + return "ls_eend_ch" case .dihard2: - return "ls_eend_dih2_step" + return "ls_eend_dih2" case .dihard3: - return "ls_eend_dih3_step" + return "ls_eend_dih3" } } - public var description: String { rawValue } + public var description: String { name } - public var stem: String { "\(rawValue)/\(name)" } + public func name(forStep step: StepSize) -> String { + "\(name)_\(step)" + } - public var modelFile: String { "\(stem).mlmodelc" } + public func fileName(forStep step: StepSize) -> String { + "\(step)/\(name)_\(step).mlmodelc" + } + } - public var configFile: String { "\(stem).json" } + public enum StepSize: Int, CaseIterable, Sendable, CustomStringConvertible { + case step100ms = 1 + case step200ms = 2 + case step300ms = 3 + case step400ms = 4 + case step500ms = 5 - public var 
fileNames: [String] { [modelFile, configFile] } + public var description: String { + switch self { + case .step100ms: return "100ms" + case .step200ms: return "200ms" + case .step300ms: return "300ms" + case .step400ms: return "400ms" + case .step500ms: return "500ms" + } + } } /// Lowest latency for streaming public static let defaultVariant: Variant = .dihard3 + public static let defaultStep: StepSize = .step100ms /// Bundle name for a specific variant - public static func bundle(for variant: Variant) -> [String] { - return variant.fileNames + public static func bundle(for variant: Variant, withStep step: StepSize) -> [String] { + return [variant.fileName(forStep: step)] } /// Default bundle name public static var defaultBundle: [String] { - return defaultVariant.fileNames + return [defaultVariant.fileName(forStep: defaultStep)] } - /// All Sortformer bundle models required by the downloader + /// All LS-EEND bundle models required by the downloader public static var requiredModels: Set { - Set(Variant.allCases.flatMap(\.fileNames)) + Set(Variant.allCases.flatMap { StepSize.allCases.map($0.fileName) }) } } @@ -749,9 +802,9 @@ public enum ModelNames { return [variant] } return ModelNames.Sortformer.requiredModels - case .lseend: + case .lseendAmi, .lseendCallHome, .lseendDihard2, .lseendDihard3: if let variant = variant { - return [variant + ".mlmodelc", variant + ".json"] + return [variant + ".mlmodelc"] } return ModelNames.LSEEND.requiredModels case .qwen3Asr, .qwen3AsrInt8: diff --git a/Sources/FluidAudioCLI/Commands/DiarizationBenchmarkUtils.swift b/Sources/FluidAudioCLI/Commands/DiarizationBenchmarkUtils.swift index e41f500d7..7bab553b5 100644 --- a/Sources/FluidAudioCLI/Commands/DiarizationBenchmarkUtils.swift +++ b/Sources/FluidAudioCLI/Commands/DiarizationBenchmarkUtils.swift @@ -5,6 +5,12 @@ import Foundation /// Shared utilities for diarization benchmark commands (LS-EEND and Sortformer). 
enum DiarizationBenchmarkUtils { + enum AMISplit: String { + case train + case dev + case test + } + /// Dataset corpora supported by diarization benchmarks. enum Dataset: String { case ami = "ami" @@ -30,7 +36,9 @@ enum DiarizationBenchmarkUtils { // MARK: - File Paths - static func getAMIFiles(maxFiles: Int?) -> [String] { + static func getAMIFiles(split: AMISplit = .test, maxFiles: Int?) -> [String] { + let allMeetings = getAMIMeetings(split: split) + var availableMeetings: [String] = [] for meeting in DatasetDownloader.officialAMITestSet { let path = getAudioPath(for: meeting, dataset: .ami) @@ -45,6 +53,48 @@ enum DiarizationBenchmarkUtils { return availableMeetings } + static func getAMIMeetings(split: AMISplit) -> [String] { + switch split { + case .train: + return [ + "EN2001a", "EN2001d", "EN2001e", "EN2002a", "EN2002b", "EN2002c", "EN2002d", + "EN2003a", "EN2004a", "EN2005a", "EN2006a", "EN2006b", "EN2009b", "EN2009c", + "EN2009d", "ES2002a", "ES2002b", "ES2002c", "ES2002d", "ES2003a", "ES2003b", + "ES2003c", "ES2003d", "ES2005a", "ES2005b", "ES2005c", "ES2005d", "ES2006a", + "ES2006b", "ES2006c", "ES2006d", "ES2007a", "ES2007b", "ES2007c", "ES2007d", + "ES2008a", "ES2008b", "ES2008c", "ES2008d", "ES2009a", "ES2009b", "ES2009c", + "ES2009d", "ES2010a", "ES2010b", "ES2010c", "ES2010d", "ES2012a", "ES2012b", + "ES2012c", "ES2012d", "ES2013a", "ES2013b", "ES2013c", "ES2013d", "ES2014a", + "ES2014b", "ES2014c", "ES2014d", "ES2015a", "ES2015b", "ES2015c", "ES2015d", + "ES2016a", "ES2016b", "ES2016c", "ES2016d", "IB4005", "IN1001", "IN1002", + "IN1005", "IN1007", "IN1008", "IN1009", "IN1012", "IN1013", "IN1014", "IN1016", + "IS1000a", "IS1000b", "IS1000c", "IS1000d", "IS1001a", "IS1001b", "IS1001c", + "IS1001d", "IS1002b", "IS1002c", "IS1002d", "IS1003a", "IS1003b", "IS1003c", + "IS1003d", "IS1004a", "IS1004b", "IS1004c", "IS1004d", "IS1005a", "IS1005b", + "IS1005c", "IS1006a", "IS1006b", "IS1006c", "IS1006d", "IS1007a", "IS1007b", + "IS1007c", "IS1007d", 
"TS3005a", "TS3005b", "TS3005c", "TS3005d", "TS3006a", + "TS3006b", "TS3006c", "TS3006d", "TS3007a", "TS3007b", "TS3007c", "TS3007d", + "TS3008a", "TS3008b", "TS3008c", "TS3008d", "TS3009a", "TS3009b", "TS3009c", + "TS3009d", "TS3010a", "TS3010b", "TS3010c", "TS3010d", "TS3011a", "TS3011b", + "TS3011c", "TS3011d", "TS3012a", "TS3012b", "TS3012c", "TS3012d", + ] + case .dev: + return [ + "ES2011a", "ES2011b", "ES2011c", "ES2011d", + "IB4001", "IB4002", "IB4003", "IB4004", "IB4010", "IB4011", + "IS1008a", "IS1008b", "IS1008c", "IS1008d", + "TS3004a", "TS3004b", "TS3004c", "TS3004d", + ] + case .test: + return [ + "EN2002a", "EN2002b", "EN2002c", "EN2002d", + "ES2004a", "ES2004b", "ES2004c", "ES2004d", + "IS1009a", "IS1009b", "IS1009c", "IS1009d", + "TS3003a", "TS3003b", "TS3003c", "TS3003d", + ] + } + } + static func getAudioPath(for meeting: String, dataset: Dataset) -> String { let homeDir = FileManager.default.homeDirectoryForCurrentUser switch dataset { @@ -64,23 +114,54 @@ enum DiarizationBenchmarkUtils { } static func getRTTMURL(for meeting: String, dataset: Dataset) -> URL? { - let homeDir = FileManager.default.homeDirectoryForCurrentUser switch dataset { case .ami: - return homeDir.appendingPathComponent( - "FluidAudioDatasets/ami_official/rttm/\(meeting).rttm" - ) + return getAMIRTTMURL(for: meeting) case .voxconverse: + let homeDir = FileManager.default.homeDirectoryForCurrentUser return homeDir.appendingPathComponent( "FluidAudioDatasets/voxconverse/rttm_repo/test/\(meeting).rttm" ) case .callhome: + let homeDir = FileManager.default.homeDirectoryForCurrentUser return homeDir.appendingPathComponent( "FluidAudioDatasets/callhome_eng/rttm/\(meeting).rttm" ) } } + static func getAMIRTTMURL(for meeting: String) -> URL? 
{ + let fileManager = FileManager.default + let homeDir = fileManager.homeDirectoryForCurrentUser + let workingDir = URL(fileURLWithPath: fileManager.currentDirectoryPath) + return getAMIRTTMURL( + for: meeting, + workingDir: workingDir, + homeDir: homeDir, + fileManager: fileManager + ) + } + + static func getAMIRTTMURL( + for meeting: String, + workingDir: URL, + homeDir: URL, + fileManager: FileManager = .default + ) -> URL? { + let candidateURLs = [ + homeDir.appendingPathComponent("FluidAudioDatasets/ami_official/rttm/\(meeting).rttm"), + workingDir.appendingPathComponent("Datasets/diar-forced-alignment/AMI/test/\(meeting).rttm"), + workingDir.appendingPathComponent("Datasets/diar-forced-alignment/AMI/dev/\(meeting).rttm"), + workingDir.appendingPathComponent("Datasets/diar-forced-alignment/AMI/train/\(meeting).rttm"), + ] + + for candidateURL in candidateURLs where fileManager.fileExists(atPath: candidateURL.path) { + return candidateURL + } + + return candidateURLs.first + } + static func getVoxConverseFiles(maxFiles: Int?) 
-> [String] { let homeDir = FileManager.default.homeDirectoryForCurrentUser let voxDir = homeDir.appendingPathComponent( diff --git a/Sources/FluidAudioCLI/Commands/LSEENDBenchmark.swift b/Sources/FluidAudioCLI/Commands/LSEENDBenchmark.swift index a49847835..3bbbd4ac2 100644 --- a/Sources/FluidAudioCLI/Commands/LSEENDBenchmark.swift +++ b/Sources/FluidAudioCLI/Commands/LSEENDBenchmark.swift @@ -5,6 +5,7 @@ import Foundation /// LS-EEND diarization benchmark for evaluating performance on standard corpora enum LSEENDBenchmark { private static let logger = AppLogger(category: "LSEENDBench") + private static let frameStepForDER: Double = 0.01 typealias Dataset = DiarizationBenchmarkUtils.Dataset typealias BenchmarkResult = DiarizationBenchmarkUtils.BenchmarkResult @@ -20,12 +21,14 @@ enum LSEENDBenchmark { Options: --dataset Dataset to use: ami, voxconverse, callhome (default: ami) - --variant Model variant: ami, callhome, dihard2, dihard3 (default: dihard3) + --ami-split AMI split: dev, test, train (default: test) + --variant Model variant: ami, callhome, dihard2, dihard3 (default: ami) + --step-size Model step size: 100ms, 200ms, 300ms, 400ms, 500ms (default: 500ms) --single-file Process a specific meeting (e.g., ES2004a) --max-files Maximum number of files to process --threshold Speaker activity threshold (default: 0.5) --median-width Median filter width for post-processing (default: 1) - --collar Collar duration in seconds (default: 0.25) + --collar Collar duration in seconds (default: 0.0 for AMI, 0.25 otherwise) --onset Onset threshold for speech detection (default: 0.5) --offset Offset threshold for speech detection (default: 0.5) --pad-onset Padding before speech segments in seconds @@ -46,8 +49,8 @@ enum LSEENDBenchmark { # Full AMI benchmark with auto-download fluidaudio lseend-benchmark --auto-download --output results.json - # Benchmark with CALLHOME variant on CALLHOME dataset - fluidaudio lseend-benchmark --dataset callhome --variant callhome + # 
Benchmark with AMI 500ms model + fluidaudio lseend-benchmark --variant ami --step-size 500ms """) } @@ -58,6 +61,7 @@ enum LSEENDBenchmark { var threshold: Float = 0.5 var medianWidth: Int = 1 var collarSeconds: Double = 0.25 + var collarWasProvided = false var outputFile: String? var verbose = false var autoDownload = false @@ -72,7 +76,9 @@ enum LSEENDBenchmark { var progressFile: String = ".lseend_progress.json" var resumeFromProgress = false var dataset: Dataset = .ami - var variant: LSEENDVariant = .dihard3 + var amiSplit: DiarizationBenchmarkUtils.AMISplit = .test + var variant: LSEENDVariant = .ami + var stepSize: LSEENDStepSize = .step500ms var i = 0 while i < arguments.count { @@ -86,6 +92,15 @@ enum LSEENDBenchmark { } i += 1 } + case "--ami-split": + if i + 1 < arguments.count { + if let split = DiarizationBenchmarkUtils.AMISplit(rawValue: arguments[i + 1].lowercased()) { + amiSplit = split + } else { + print("Unknown AMI split: \(arguments[i + 1]). Using test.") + } + i += 1 + } case "--variant": if i + 1 < arguments.count { let v = arguments[i + 1].lowercased() @@ -103,6 +118,24 @@ enum LSEENDBenchmark { } i += 1 } + case "--step-size": + if i + 1 < arguments.count { + switch arguments[i + 1].lowercased() { + case "100", "100ms": + stepSize = .step100ms + case "200", "200ms": + stepSize = .step200ms + case "300", "300ms": + stepSize = .step300ms + case "400", "400ms": + stepSize = .step400ms + case "500", "500ms": + stepSize = .step500ms + default: + print("Unknown step size: \(arguments[i + 1]). Using 500ms.") + } + i += 1 + } case "--single-file": if i + 1 < arguments.count { singleFile = arguments[i + 1] @@ -126,6 +159,7 @@ enum LSEENDBenchmark { case "--collar": if i + 1 < arguments.count { collarSeconds = Double(arguments[i + 1]) ?? 
0.25 + collarWasProvided = true i += 1 } case "--output": @@ -183,10 +217,18 @@ enum LSEENDBenchmark { i += 1 } + if dataset == .ami && !collarWasProvided { + collarSeconds = 0.0 + } + print("Starting LS-EEND Benchmark") fflush(stdout) print(" Dataset: \(dataset.rawValue)") - print(" Variant: \(variant.rawValue)") + if dataset == .ami { + print(" AMI split: \(amiSplit.rawValue)") + } + print(" Variant: \(variant.description)") + print(" Step size: \(stepSize.description)") print(" Threshold: \(threshold)") print(" Median width: \(medianWidth)") print(" Collar: \(collarSeconds)s") @@ -194,18 +236,57 @@ enum LSEENDBenchmark { // Download dataset if needed if autoDownload && dataset == .ami { print("Downloading AMI dataset if needed...") + let meetingsToDownload = + singleFile.map { [$0] } ?? DiarizationBenchmarkUtils.getAMIMeetings(split: amiSplit) await DatasetDownloader.downloadAMIDataset( variant: .sdm, force: false, - singleFile: singleFile + singleFile: singleFile, + meetingIds: meetingsToDownload ) await DatasetDownloader.downloadAMIAnnotations(force: false) } + let amiSplitDirectory: URL? + if dataset == .ami { + let splitDirectory = AMIKaldiData.splitDirectory(split: amiSplit) + + do { + if autoDownload { + try AMIKaldiData.ensureSplitExists(split: amiSplit) + } else if !AMIKaldiData.splitExists(split: amiSplit) { + print("AMI Kaldi split not found at \(splitDirectory.path)") + print( + "Run `fluidaudio lseend-benchmark --auto-download` to build Datasets/ami/mhs/data/\(amiSplit.rawValue)." 
+ ) + return + } + } catch { + print("Failed to prepare AMI Kaldi data: \(error)") + return + } + + amiSplitDirectory = splitDirectory + } else { + amiSplitDirectory = nil + } + // Get list of files to process let filesToProcess: [String] if let meeting = singleFile { filesToProcess = [meeting] + } else if dataset == .ami { + guard let amiSplitDirectory else { + print("AMI Kaldi split directory was not initialized.") + return + } + + do { + filesToProcess = try AMIKaldiData.recordingIDs(in: amiSplitDirectory, maxFiles: maxFiles) + } catch { + print("Failed to enumerate AMI Kaldi recordings: \(error)") + return + } } else { filesToProcess = DiarizationBenchmarkUtils.getFiles(for: dataset, maxFiles: maxFiles) } @@ -251,10 +332,21 @@ enum LSEENDBenchmark { if let v = minDurationOn { timelineConfig.minDurationOn = v } if let v = minDurationOff { timelineConfig.minDurationOff = v } - let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly, timelineConfig: timelineConfig) + let diarizer: LSEENDDiarizer do { - try await diarizer.initialize(variant: variant) + let model = try await LSEENDModel.loadFromHuggingFace( + variant: variant, + stepSize: stepSize, + computeUnits: .cpuOnly + ) + diarizer = try LSEENDDiarizer(model: model) + diarizer.timeline = DiarizerTimeline( + config: configuredTimelineConfig( + base: timelineConfig, + diarizer: diarizer + ) + ) } catch { print("Failed to initialize LS-EEND: \(error)") return @@ -292,6 +384,7 @@ enum LSEENDBenchmark { let result = await processMeeting( meetingName: meetingName, dataset: dataset, + amiSplitDirectory: amiSplitDirectory, diarizer: diarizer, modelLoadTime: modelLoadTime, threshold: threshold, @@ -336,6 +429,7 @@ enum LSEENDBenchmark { private static func processMeeting( meetingName: String, dataset: Dataset, + amiSplitDirectory: URL?, diarizer: LSEENDDiarizer, modelLoadTime: Double, threshold: Float, @@ -345,18 +439,37 @@ enum LSEENDBenchmark { numSpeakers: Int, verbose: Bool ) async -> BenchmarkResult? 
{ - let audioPath = DiarizationBenchmarkUtils.getAudioPath(for: meetingName, dataset: dataset) - guard FileManager.default.fileExists(atPath: audioPath) else { - print("Audio file not found: \(audioPath)") - fflush(stdout) - return nil - } - do { + let audioPath: String + if dataset == .ami { + guard let amiSplitDirectory else { + print("AMI Kaldi split directory was not initialized.") + return nil + } + guard let path = try AMIKaldiData.audioPath(for: meetingName, in: amiSplitDirectory) else { + print("AMI Kaldi wav.scp has no entry for \(meetingName)") + return nil + } + audioPath = path + } else { + audioPath = DiarizationBenchmarkUtils.getAudioPath(for: meetingName, dataset: dataset) + } + + guard FileManager.default.fileExists(atPath: audioPath) else { + print("Audio file not found: \(audioPath)") + fflush(stdout) + return nil + } + // Load and process audio let audioURL = URL(fileURLWithPath: audioPath) let startTime = Date() - let timeline = try diarizer.processComplete(audioFileURL: audioURL) + let timeline = try diarizer.processComplete( + audioFileURL: audioURL, + keepingEnrolledSpeakers: nil, + finalizeOnCompletion: true, + progressCallback: nil + ) let processingTime = Date().timeIntervalSince(startTime) let duration = timeline.finalizedDuration @@ -369,89 +482,56 @@ enum LSEENDBenchmark { print(" Total frames: \(numFrames)") } - // Load ground truth RTTM (or fall back to AMI XML annotations) - let rttmEntries: [LSEENDRTTMEntry] - let rttmSpeakers: [String] - - let rttmURL = DiarizationBenchmarkUtils.getRTTMURL(for: meetingName, dataset: dataset) - if let rttmURL = rttmURL, FileManager.default.fileExists(atPath: rttmURL.path) { - let parsed = try LSEENDEvaluation.parseRTTM(url: rttmURL) - rttmEntries = parsed.entries - rttmSpeakers = parsed.speakers - } else if dataset == .ami { - // Fall back to AMI XML annotations (same as SortformerBenchmark) - print(" [RTTM] No RTTM file, falling back to AMI annotations") - let groundTruth = await 
AMIParser.loadAMIGroundTruth( + let referenceSegments: [DERSpeakerSegment] + let groundTruthSpeakers: Int + + if dataset == .ami { + print(" [REF] Using AMI word-aligned annotations") + referenceSegments = await AMIParser.loadWordAlignedDERReference( for: meetingName, duration: duration ) - guard !groundTruth.isEmpty else { - print("No ground truth found for \(meetingName)") - return nil - } - // Convert TimedSpeakerSegment to LSEENDRTTMEntry - var speakers: [String] = [] - var entries: [LSEENDRTTMEntry] = [] - for segment in groundTruth { - if !speakers.contains(segment.speakerId) { - speakers.append(segment.speakerId) - } - entries.append( - LSEENDRTTMEntry( - recordingID: meetingName, - start: Double(segment.startTimeSeconds), - duration: Double(segment.endTimeSeconds - segment.startTimeSeconds), - speaker: segment.speakerId - ) + groundTruthSpeakers = Set(referenceSegments.map(\.speaker)).count + } else if let rttmURL = DiarizationBenchmarkUtils.getRTTMURL(for: meetingName, dataset: dataset), + FileManager.default.fileExists(atPath: rttmURL.path) + { + let groundTruth = try RTTMParser.loadSegments(from: rttmURL.path) + referenceSegments = groundTruth.map { + DERSpeakerSegment( + speaker: $0.speakerId, + start: Double($0.startTimeSeconds), + end: Double($0.endTimeSeconds) ) } - rttmEntries = entries - rttmSpeakers = speakers + groundTruthSpeakers = Set(groundTruth.map(\.speakerId)).count } else { print("No RTTM ground truth found for \(meetingName)") return nil } - let referenceBinary = LSEENDEvaluation.rttmToFrameMatrix( - entries: rttmEntries, - speakers: rttmSpeakers, - numFrames: numFrames, - frameRate: frameHz - ) - - print(" [RTTM] Loaded \(rttmEntries.count) segments, speakers: \(rttmSpeakers)") - - // Build probability matrix from timeline predictions - let predictions = timeline.finalizedPredictions - let probMatrix = LSEENDMatrix( - validatingRows: numFrames, - columns: numSpeakers, - values: predictions + print( + " [REF] Loaded 
\(referenceSegments.count) segments, speakers: \(groundTruthSpeakers)" ) - // Compute DER using the built-in evaluation - let settings = LSEENDEvaluationSettings( + let hypothesisSegments = timelineToDERSegments( + timeline, + numSpeakers: numSpeakers, threshold: threshold, - medianWidth: medianWidth, - collarSeconds: collarSeconds, - frameRate: frameHz + medianWidth: medianWidth ) - let evalResult = LSEENDEvaluation.computeDER( - probabilities: probMatrix, - referenceBinary: referenceBinary, - settings: settings + + let evalResult = DiarizationDER.compute( + ref: referenceSegments, + hyp: hypothesisSegments, + frameStep: frameStepForDER, + collar: collarSeconds ) + let totalRefSpeech = max(evalResult.totalRefSpeech, .leastNonzeroMagnitude) let derPercent = Float(evalResult.der * 100) - let missPercent = - evalResult.speakerScored > 0 - ? Float(evalResult.speakerMiss / evalResult.speakerScored * 100) : 0 - let faPercent = - evalResult.speakerScored > 0 - ? Float(evalResult.speakerFalseAlarm / evalResult.speakerScored * 100) : 0 - let sePercent = - evalResult.speakerScored > 0 - ? Float(evalResult.speakerError / evalResult.speakerScored * 100) : 0 + let missPercent = Float(evalResult.miss / totalRefSpeech * 100) + let faPercent = Float(evalResult.falseAlarm / totalRefSpeech * 100) + let sePercent = Float(evalResult.confusion / totalRefSpeech * 100) print( " DER breakdown: miss=\(String(format: "%.1f", missPercent))%, " @@ -478,7 +558,7 @@ enum LSEENDBenchmark { processingTime: processingTime, totalFrames: numFrames, detectedSpeakers: detectedSpeakerIndices.count, - groundTruthSpeakers: rttmSpeakers.count, + groundTruthSpeakers: groundTruthSpeakers, modelLoadTime: modelLoadTime, audioLoadTime: nil ) @@ -489,5 +569,108 @@ enum LSEENDBenchmark { } } + private static func configuredTimelineConfig( + base: DiarizerTimelineConfig, + diarizer: LSEENDDiarizer + ) -> DiarizerTimelineConfig { + var config = base + config.numSpeakers = diarizer.numSpeakers ?? 
config.numSpeakers + config.frameDurationSeconds = Float(1.0 / (diarizer.modelFrameHz ?? Double(config.frameDurationSeconds))) + return config + } + + private static func timelineToDERSegments( + _ timeline: DiarizerTimeline, + numSpeakers: Int, + threshold: Float, + medianWidth: Int + ) -> [DERSpeakerSegment] { + let binary = probabilitiesToBinary( + timeline.finalizedPredictions, + numFrames: timeline.numFinalizedFrames, + numSpeakers: numSpeakers, + threshold: threshold, + medianWidth: medianWidth + ) + return binaryToSegments( + binary, + numFrames: timeline.numFinalizedFrames, + numSpeakers: numSpeakers, + frameStep: Double(timeline.config.frameDurationSeconds) + ) + } + + private static func probabilitiesToBinary( + _ predictions: [Float], + numFrames: Int, + numSpeakers: Int, + threshold: Float, + medianWidth: Int + ) -> [Bool] { + var out = [Bool](repeating: false, count: numFrames * numSpeakers) + for frame in 0.. threshold + } + } + + guard medianWidth > 1 else { return out } + var filtered = out + let halfWindow = medianWidth / 2 + for speaker in 0..= (end - start) + } + } + return filtered + } + + private static func binaryToSegments( + _ binary: [Bool], + numFrames: Int, + numSpeakers: Int, + frameStep: Double + ) -> [DERSpeakerSegment] { + var segments: [DERSpeakerSegment] = [] + for speaker in 0.. 
[options] Options: - --variant Model variant: ami, callhome, dihard2, dihard3 (default: dihard3) + --variant Model variant: ami, callhome, dihard2, dihard3 (default: ami) + --step-size Model step size: 100ms, 200ms, 300ms, 400ms, 500ms (default: 500ms) --threshold Speaker activity threshold (default: 0.5) --onset Onset threshold for speech detection (default: 0.5) --offset Offset threshold for speech detection (default: 0.5) @@ -256,9 +293,22 @@ enum LSEENDCommand { # With specific variant fluidaudio lseend audio.wav --variant ami + # With explicit step size + fluidaudio lseend audio.wav --variant ami --step-size 500ms + # Save results to file fluidaudio lseend audio.wav --output results.json """) } + + private static func configuredTimelineConfig( + base: DiarizerTimelineConfig, + diarizer: LSEENDDiarizer + ) -> DiarizerTimelineConfig { + var config = base + config.numSpeakers = diarizer.numSpeakers ?? config.numSpeakers + config.frameDurationSeconds = Float(1.0 / (diarizer.modelFrameHz ?? Double(config.frameDurationSeconds))) + return config + } } #endif diff --git a/Sources/FluidAudioCLI/Commands/LSEENDEvaluation.swift b/Sources/FluidAudioCLI/Commands/LSEENDEvaluation.swift deleted file mode 100644 index 7d1079926..000000000 --- a/Sources/FluidAudioCLI/Commands/LSEENDEvaluation.swift +++ /dev/null @@ -1,513 +0,0 @@ -import FluidAudio -import Foundation - -/// A single speaker turn entry from an RTTM (Rich Transcription Time Marked) file. -public struct LSEENDRTTMEntry: Sendable, Codable { - /// The recording or file identifier. - public let recordingID: String - /// Start time of the speaker turn in seconds. - public let start: Double - /// Duration of the speaker turn in seconds. - public let duration: Double - /// Speaker label (e.g. `"spk0"`, `"speaker_A"`). 
- public let speaker: String - - public init(recordingID: String, start: Double, duration: Double, speaker: String) { - self.recordingID = recordingID - self.start = start - self.duration = duration - self.speaker = speaker - } -} - -/// Configuration for DER (Diarization Error Rate) evaluation. -public struct LSEENDEvaluationSettings: Sendable, Codable { - /// Probability threshold for binarizing speaker predictions (e.g. 0.5). - public let threshold: Float - /// Median filter kernel width applied after thresholding (0 or 1 to disable). - public let medianWidth: Int - /// Collar duration in seconds around reference speaker transitions to exclude from scoring. - public let collarSeconds: Double - /// Frame rate in Hz used to convert between time and frame indices. - public let frameRate: Double - - /// Creates evaluation settings. - public init(threshold: Float, medianWidth: Int, collarSeconds: Double, frameRate: Double) { - self.threshold = threshold - self.medianWidth = medianWidth - self.collarSeconds = collarSeconds - self.frameRate = frameRate - } -} - -/// Detailed results of a DER evaluation, including error breakdown and speaker mapping. -public struct LSEENDEvaluationResult: Sendable { - /// Overall Diarization Error Rate: `(miss + falseAlarm + speakerError) / speakerScored`. - public let der: Double - /// Total number of reference speaker-active frames scored (after collar exclusion). - public let speakerScored: Double - /// Missed speech: reference-active frames with no corresponding prediction. - public let speakerMiss: Double - /// False alarm: predicted-active frames with no corresponding reference. - public let speakerFalseAlarm: Double - /// Speaker confusion: frames where both reference and prediction are active but mapped to different speakers. - public let speakerError: Double - /// The probability threshold used for binarization. - public let threshold: Float - /// The median filter width applied after thresholding. 
- public let medianWidth: Int - /// The collar duration in seconds used during scoring. - public let collarSeconds: Double - /// Binary predictions remapped to reference speaker order via optimal assignment. - public let mappedBinary: LSEENDMatrix - /// Continuous probabilities remapped to reference speaker order. - public let mappedProbabilities: LSEENDMatrix - /// Per-frame mask: `true` for frames included in scoring, `false` for collar-excluded frames. - public let validMask: [Bool] - /// Optimal speaker assignment mapping: `[referenceIndex: predictionIndex]`. - public let assignment: [Int: Int] - /// Prediction column indices that were not matched to any reference speaker. - public let unmatchedPredictionIndices: [Int] -} - -/// Utilities for evaluating LS-EEND diarization output against reference annotations. -/// -/// Provides RTTM parsing/writing, post-processing (threshold + median filter), -/// and DER computation with collar masking and optimal speaker assignment. -public enum LSEENDEvaluation { - /// Parses an RTTM file into speaker turn entries. - /// - /// - Parameter url: Path to the RTTM file. - /// - Returns: A tuple of parsed entries and an ordered list of unique speaker labels. - public static func parseRTTM(url: URL) throws -> (entries: [LSEENDRTTMEntry], speakers: [String]) { - let text = try String(contentsOf: url, encoding: .utf8) - var entries: [LSEENDRTTMEntry] = [] - var speakers: [String] = [] - for line in text.split(whereSeparator: \.isNewline) { - let parts = line.split(separator: " ") - guard parts.count >= 8, parts[0] == "SPEAKER" else { continue } - let speaker = String(parts[7]) - if !speakers.contains(speaker) { - speakers.append(speaker) - } - entries.append( - LSEENDRTTMEntry( - recordingID: String(parts[1]), - start: Double(parts[3]) ?? 0, - duration: Double(parts[4]) ?? 0, - speaker: speaker - ) - ) - } - return (entries, speakers) - } - - /// Converts RTTM entries into a binary frame-level matrix. 
- /// - /// - Parameters: - /// - entries: Speaker turn entries from ``parseRTTM(url:)``. - /// - speakers: Ordered speaker labels defining column order. - /// - numFrames: Total number of output frames. - /// - frameRate: Frame rate in Hz for time-to-frame conversion. - /// - Returns: A binary matrix of shape `[numFrames, speakers.count]` where 1 indicates active speech. - public static func rttmToFrameMatrix( - entries: [LSEENDRTTMEntry], - speakers: [String], - numFrames: Int, - frameRate: Double - ) -> LSEENDMatrix { - var matrix = LSEENDMatrix.zeros(rows: numFrames, columns: speakers.count) - let speakerToIndex = Dictionary(uniqueKeysWithValues: speakers.enumerated().map { ($1, $0) }) - for entry in entries { - guard let speakerIndex = speakerToIndex[entry.speaker] else { continue } - let start = pythonRoundedInt(entry.start * frameRate) - let stop = pythonRoundedInt((entry.start + entry.duration) * frameRate) - guard stop > start else { continue } - for rowIndex in max(0, start).. 0 { - startIndex = rowIndex - } else if previous > 0, value == 0, let activeStart = startIndex { - let startSeconds = Double(activeStart) / frameRate - let durationSeconds = Double(rowIndex - activeStart) / frameRate - lines.append( - String( - format: "SPEAKER %@ 1 %.3f %.3f %@ ", - recordingID, - startSeconds, - durationSeconds, - labels[speakerIndex] - ) - ) - startIndex = nil - } - previous = value - } - if previous > 0, let activeStart = startIndex { - let startSeconds = Double(activeStart) / frameRate - let durationSeconds = Double(binaryPrediction.rows - activeStart) / frameRate - lines.append( - String( - format: "SPEAKER %@ 1 %.3f %.3f %@ ", - recordingID, - startSeconds, - durationSeconds, - labels[speakerIndex] - ) - ) - } - } - try lines.joined(separator: "\n").appending("\n").write(to: outputURL, atomically: true, encoding: .utf8) - } - - /// Computes a validity mask that excludes frames near speaker transitions. 
- /// - /// Frames within `collarFrames` of any speaker onset or offset in the reference - /// are marked `false` (excluded from DER scoring). - /// - /// - Parameters: - /// - reference: Binary reference matrix of shape `[frames, speakers]`. - /// - collarFrames: Number of frames on each side of a transition to exclude. - /// - Returns: Boolean mask of length `reference.rows`. - public static func collarMask(reference: LSEENDMatrix, collarFrames: Int) -> [Bool] { - guard collarFrames > 0 else { - return [Bool](repeating: true, count: reference.rows) - } - var mask = [Bool](repeating: true, count: reference.rows) - for columnIndex in 0.. 0 { - let start = max(0, reference.rows - collarFrames) - for maskedIndex in start.. LSEENDMatrix { - var binary = probabilities - for index in binary.values.indices { - binary.values[index] = binary.values[index] > value ? 1 : 0 - } - return binary - } - - /// Applies a 1D median filter along the time axis of each speaker column. - /// - /// Smooths binary predictions to remove brief spurious activations or gaps. - /// Even widths are rounded up to the next odd number. - /// - /// - Parameters: - /// - binary: Binary matrix to filter. - /// - width: Kernel width in frames (1 or 0 to skip filtering). - /// - Returns: Filtered binary matrix with the same shape. - public static func medianFilter(binary: LSEENDMatrix, width: Int) -> LSEENDMatrix { - guard width > 1, binary.rows > 0, binary.columns > 0 else { - return binary - } - let kernel = width % 2 == 0 ? width + 1 : width - let radius = kernel / 2 - var output = binary - for columnIndex in 0.. 0 { - ones += 1 - } - } - output[rowIndex, columnIndex] = ones * 2 >= count ? 1 : 0 - } - } - return output - } - - /// Computes the Diarization Error Rate (DER) between predictions and a reference. - /// - /// Applies thresholding, median filtering, collar masking, and optimal speaker - /// assignment (Hungarian-style) before computing miss, false alarm, and speaker error rates. 
- /// - /// - Parameters: - /// - probabilities: Continuous prediction matrix of shape `[frames, predSpeakers]`. - /// - referenceBinary: Binary reference matrix of shape `[frames, refSpeakers]`. - /// - settings: Evaluation parameters (threshold, median width, collar, frame rate). - /// - Returns: Detailed evaluation result including DER, error breakdown, and speaker mapping. - public static func computeDER( - probabilities: LSEENDMatrix, - referenceBinary: LSEENDMatrix, - settings: LSEENDEvaluationSettings - ) -> LSEENDEvaluationResult { - let predictionBinary = medianFilter( - binary: threshold(probabilities: probabilities, value: settings.threshold), - width: settings.medianWidth - ) - let validMask = collarMask( - reference: referenceBinary, - collarFrames: pythonRoundedInt(settings.collarSeconds * settings.frameRate) - ) - let mapping = mapPredictions( - predictionBinary: predictionBinary, - referenceBinary: referenceBinary, - validMask: validMask - ) - var mappedProbabilities = LSEENDMatrix.zeros(rows: probabilities.rows, columns: referenceBinary.columns) - for (referenceIndex, predictionIndex) in mapping.assignment { - for rowIndex in 0.. 0 - let predValue = scoredPrediction[rowIndex, columnIndex] > 0 - if refValue { referenceActive += 1 } - if predValue { predictionActive += 1 } - if refValue && predValue { mappedOverlap += 1 } - } - miss += Double(max(referenceActive - predictionActive, 0)) - falseAlarm += Double(max(predictionActive - referenceActive, 0)) - speakerError += Double(min(referenceActive, predictionActive) - mappedOverlap) - speakerScored += Double(referenceActive) - } - let der = speakerScored > 0 ? 
(miss + falseAlarm + speakerError) / speakerScored : 0 - return LSEENDEvaluationResult( - der: der, - speakerScored: speakerScored, - speakerMiss: miss, - speakerFalseAlarm: falseAlarm, - speakerError: speakerError, - threshold: settings.threshold, - medianWidth: settings.medianWidth, - collarSeconds: settings.collarSeconds, - mappedBinary: mapping.mappedBinary, - mappedProbabilities: mappedProbabilities, - validMask: validMask, - assignment: mapping.assignment, - unmatchedPredictionIndices: mapping.unmatchedPredictionIndices - ) - } - - private static func mapPredictions( - predictionBinary: LSEENDMatrix, - referenceBinary: LSEENDMatrix, - validMask: [Bool] - ) -> (mappedBinary: LSEENDMatrix, assignment: [Int: Int], unmatchedPredictionIndices: [Int]) { - let numPred = predictionBinary.columns - let numRef = referenceBinary.columns - var mapped = LSEENDMatrix.zeros(rows: predictionBinary.rows, columns: numRef) - guard numPred > 0, numRef > 0 else { - return (mapped, [:], Array(0..() - for (predIndex, refIndex) in assignment { - matchedPredictions.insert(predIndex) - mappedAssignment[refIndex] = predIndex - for rowIndex in 0.. Float { - var refCount = 0 - var predCount = 0 - var overlap = 0 - for rowIndex in 0.. 
0 - let ref = referenceBinary[rowIndex, referenceIndex] > 0 - if ref { refCount += 1 } - if pred { predCount += 1 } - if ref && pred { overlap += 1 } - } - let miss = max(refCount - predCount, 0) - let falseAlarm = max(predCount - refCount, 0) - let speakerError = min(refCount, predCount) - overlap - return Float(miss + falseAlarm + speakerError) - } - - private static func solveRectangularAssignment(cost: [Float], rows: Int, columns: Int) -> [(Int, Int)] { - if rows <= columns { - let solution = solveAssignmentRowsToColumns(cost: cost, rows: rows, columns: columns) - return solution.enumerated().map { ($0.offset, $0.element) } - } - let transposed = transpose(cost: cost, rows: rows, columns: columns) - let solution = solveAssignmentRowsToColumns(cost: transposed, rows: columns, columns: rows) - return solution.enumerated().map { ($0.element, $0.offset) } - } - - private static func solveAssignmentRowsToColumns(cost: [Float], rows: Int, columns: Int) -> [Int] { - precondition(columns <= 20, "Assignment solver is O(2^columns); columns=\(columns) is too large") - let stateCount = 1 << columns - var dp = [Float](repeating: .greatestFiniteMagnitude, count: stateCount) - var parent = [Int](repeating: -1, count: stateCount) - var parentColumn = [Int](repeating: -1, count: stateCount) - dp[0] = 0 - - for mask in 0.. [Float] { - var output = [Float](repeating: 0, count: cost.count) - for rowIndex in 0.. LSEENDMatrix { - guard !indices.isEmpty else { return .empty(columns: 0) } - var output = [Float](repeating: 0, count: matrix.rows * indices.count) - for rowIndex in 0.. 
Int { - Int(value.rounded(.toNearestOrEven)) - } -} diff --git a/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift b/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift index 50e59ba46..f9a083a40 100644 --- a/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift +++ b/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift @@ -5,6 +5,7 @@ import Foundation /// Sortformer streaming diarization benchmark for evaluating real-time performance enum SortformerBenchmark { private static let logger = AppLogger(category: "SortformerBench") + private static let derFrameStepSeconds: Double = 0.01 typealias Dataset = DiarizationBenchmarkUtils.Dataset typealias BenchmarkResult = DiarizationBenchmarkUtils.BenchmarkResult @@ -26,9 +27,9 @@ enum SortformerBenchmark { --model Path to Sortformer.mlpackage --nvidia-low-latency Use NVIDIA 1.04s latency config (20.57% DER target) --nvidia-high-latency Use NVIDIA 30.4s latency config (20.57% DER target) - --gradient-descent Use Gradient Descent config (downloads from HuggingFace by default) - --hf Download models from HuggingFace (clears cache first) - --local Use local models instead of HuggingFace (for --gradient-descent) + --gradient-descent Use Gradient Descent config + --hf Use HuggingFace/cache-backed model loading + --local Use local mlpackage loading instead of HuggingFace/cache-backed loading --output Output JSON file for results --progress Progress file for resuming (default: .sortformer_progress.json) --resume Resume from previous progress file @@ -67,8 +68,7 @@ enum SortformerBenchmark { var autoDownload = false var useNvidiaLowLatency = false var useNvidiaHighLatency = false - var useGradientDescent = false - var useHuggingFace = false + var useHuggingFace = true var useLocalModels = false var progressFile: String = ".sortformer_progress.json" var resumeFromProgress = false @@ -129,7 +129,7 @@ enum SortformerBenchmark { case "--nvidia-low-latency": useNvidiaLowLatency = true case "--gradient-descent": - 
useGradientDescent = true + break case "--hf": useHuggingFace = true case "--local": @@ -143,9 +143,10 @@ enum SortformerBenchmark { i += 1 } - // Gradient descent uses HuggingFace by default unless --local is specified - if useGradientDescent && !useLocalModels { - useHuggingFace = true + // Benchmarks should prefer the cache-backed Hugging Face loader by default. + // `--local` explicitly opts out to local mlpackage loading. + if useLocalModels { + useHuggingFace = false } print("Starting Sortformer Benchmark") @@ -159,7 +160,8 @@ enum SortformerBenchmark { let modeDesc = useHuggingFace - ? "HuggingFace models" : "Combined Pipeline" + ? "HuggingFace/cache-backed models" + : "Local mlpackage" print(" Mode: \(modeDesc)") print(" Preprocessing: Native Swift mel spectrogram") @@ -179,12 +181,6 @@ enum SortformerBenchmark { print(" Pipeline: \(pipelineURL.path)") - // Check models exist - guard useHuggingFace || FileManager.default.fileExists(atPath: pipelineURL.path) else { - print("ERROR: Pipeline model not found: \(pipelineURL.path)") - return - } - // Download dataset if needed if autoDownload && dataset == .ami { print("Downloading AMI dataset if needed...") @@ -253,6 +249,10 @@ enum SortformerBenchmark { let models = try await SortformerModels.loadFromHuggingFace(config: config) diarizer.initialize(models: models) } else { + guard FileManager.default.fileExists(atPath: pipelineURL.path) else { + print("ERROR: Local pipeline model not found: \(pipelineURL.path)") + return + } try await diarizer.initialize( mainModelPath: pipelineURL ) @@ -406,10 +406,10 @@ enum SortformerBenchmark { // Load ground truth from RTTM file (matches Python's approach) var groundTruth = loadRTTMGroundTruth(for: meetingName, dataset: dataset) - // Fall back to AMI XML annotations if no RTTM available (AMI only) + // Fall back to AMI word-aligned annotations if no RTTM available (AMI only) if groundTruth.isEmpty && dataset == .ami { - print(" [RTTM] No RTTM file, falling back to AMI 
annotations") - groundTruth = await AMIParser.loadAMIGroundTruth( + print(" [RTTM] No RTTM file, falling back to AMI word-aligned annotations") + groundTruth = await AMIParser.loadWordAlignedGroundTruth( for: meetingName, duration: duration ) @@ -420,19 +420,25 @@ enum SortformerBenchmark { return nil } - // Get filtered predictions for simple DER calculation (matches Python/NeMo) - let filteredPredictions = result.finalizedPredictions - - // Calculate DER using simple frame-level approach (matches NeMo evaluation) - // Frame shift is 0.08s (80ms) to match NeMo's subsampling_factor * window_stride - let simpleMetrics = calculateSimpleDER( - predictions: filteredPredictions, - numFrames: result.numFinalizedFrames, - numSpeakers: result.config.numSpeakers, - groundTruth: groundTruth, - threshold: threshold, - frameShift: 0.08 // 80ms frames like NeMo + let referenceSegments = groundTruth.map { + DERSpeakerSegment( + speaker: $0.speakerId, + start: Double($0.startTimeSeconds), + end: Double($0.endTimeSeconds) + ) + } + let hypothesisSegments = segmentsToDERSegments(segments) + let derResult = DiarizationDER.compute( + ref: referenceSegments, + hyp: hypothesisSegments, + frameStep: derFrameStepSeconds, + collar: 0 ) + let totalRefSpeech = max(derResult.totalRefSpeech, .leastNonzeroMagnitude) + let derPercent = Float(derResult.der * 100) + let missPercent = Float(derResult.miss / totalRefSpeech * 100) + let faPercent = Float(derResult.falseAlarm / totalRefSpeech * 100) + let sePercent = Float(derResult.confusion / totalRefSpeech * 100) // Count detected speakers let detectedSpeakers = segments.reduce(into: Set()) { @@ -451,10 +457,10 @@ enum SortformerBenchmark { return BenchmarkResult( meetingName: meetingName, - der: simpleMetrics.der, - missRate: simpleMetrics.miss, - falseAlarmRate: simpleMetrics.fa, - speakerErrorRate: simpleMetrics.se, + der: derPercent, + missRate: missPercent, + falseAlarmRate: faPercent, + speakerErrorRate: sePercent, rtfx: rtfx, 
processingTime: processingTime, totalFrames: result.numFinalizedFrames, @@ -527,135 +533,18 @@ enum SortformerBenchmark { return segments } - // MARK: - Simple Frame-Level DER (matches Python's calculation) - - /// Calculate DER using simple frame-level binary comparison like Python - /// This matches the NeMo evaluation approach without collar or complex segment overlap - private static func calculateSimpleDER( - predictions: [Float], - numFrames: Int, - numSpeakers: Int, - groundTruth: [TimedSpeakerSegment], - threshold: Float, - frameShift: Float // 0.08 for 80ms frames - ) -> (der: Float, miss: Float, fa: Float, se: Float) { - // Create reference binary matrix [numFrames, numSpeakers] - var refBinary = [[Float]](repeating: [Float](repeating: 0.0, count: numSpeakers), count: numFrames) - - // Map ground truth speakers to indices - let speakerLabels = Array(Set(groundTruth.map { $0.speakerId })).sorted() - var speakerMap = [String: Int]() - for (idx, label) in speakerLabels.enumerated() { - if idx < numSpeakers { - speakerMap[label] = idx - } - } - - // Fill reference binary from ground truth segments - for segment in groundTruth { - guard let spkIdx = speakerMap[segment.speakerId] else { continue } - let startFrame = max(0, min(Int(segment.startTimeSeconds / frameShift), numFrames)) - let endFrame = max(0, min(Int(segment.endTimeSeconds / frameShift), numFrames)) - for frame in startFrame.. threshold ? 1.0 : 0.0 - } - } - } - - // Try all permutations to find best DER - let permutations = generatePermutations(numSpeakers) - var bestDER: Float = .infinity - var bestMiss: Float = 0 - var bestFA: Float = 0 - var bestSE: Float = 0 - - for perm in permutations { - var missFrames: Float = 0 - var faFrames: Float = 0 - var seFrames: Float = 0 - var totalRefSpeech: Float = 0 - - for frame in 0.. 0 }) - var predSpeechPermuted = false - for spk in 0.. 
0 { - predSpeechPermuted = true - break - } - } - - if refSpeech { - totalRefSpeech += 1 - } - - if refSpeech && !predSpeechPermuted { - missFrames += 1 - } else if !refSpeech && predSpeechPermuted { - faFrames += 1 - } else if refSpeech && predSpeechPermuted { - // Calculate speaker error - var refSpks = Set() - var predSpks = Set() - for spk in 0.. 0 { - refSpks.insert(spk) - } - if predBinary[frame][perm[spk]] > 0 { - predSpks.insert(spk) - } - } - let symDiff = refSpks.symmetricDifference(predSpks) - seFrames += Float(symDiff.count) / 2.0 - } - } - - if totalRefSpeech > 0 { - let der = (missFrames + faFrames + seFrames) / totalRefSpeech * 100 - if der < bestDER { - bestDER = der - bestMiss = missFrames / totalRefSpeech * 100 - bestFA = faFrames / totalRefSpeech * 100 - bestSE = seFrames / totalRefSpeech * 100 - } - } - } - - return (bestDER, bestMiss, bestFA, bestSE) - } - - /// Generate all permutations of 0.. [[Int]] { - if n == 0 { return [[]] } - if n == 1 { return [[0]] } - - var result: [[Int]] = [] - var arr = Array(0.. 
[DERSpeakerSegment] { + segments.flatMap { speakerSegments in + speakerSegments.map { segment in + DERSpeakerSegment( + speaker: segment.speakerLabel, + start: Double(segment.startTime), + end: Double(segment.endTime) + ) } } - - permute(0) - return result } } #endif diff --git a/Sources/FluidAudioCLI/DatasetParsers/AMIKaldiData.swift b/Sources/FluidAudioCLI/DatasetParsers/AMIKaldiData.swift new file mode 100644 index 000000000..f61f3d636 --- /dev/null +++ b/Sources/FluidAudioCLI/DatasetParsers/AMIKaldiData.swift @@ -0,0 +1,459 @@ +#if os(macOS) +import AVFoundation +import FluidAudio +import Foundation + +enum AMIKaldiData { + private static let logger = AppLogger(category: "AMIKaldiData") + private static let requiredKaldiFiles = [ + "wav.scp", "segments", "utt2spk", "spk2utt", "reco2dur", "reco2num_spk", "utt2timestamp", + ] + private static let sampleRate = 8_000.0 + private static let frameShiftSamples = 80.0 + private static let defaultFrameStep = frameShiftSamples / sampleRate + + struct SegmentEntry { + let utteranceId: String + let recordingId: String + let speakerId: String + let startTime: Double + let endTime: Double + } + + enum Error: LocalizedError { + case annotationsNotFound + case missingAudio(String) + case missingReference(String) + case invalidKaldiData(String) + + var errorDescription: String? { + switch self { + case .annotationsNotFound: + return "AMI annotations were not found. Expected Datasets/ami_public_1.6.2." + case .missingAudio(let meetingId): + return "AMI Kaldi data has no audio entry for \(meetingId)." + case .missingReference(let meetingId): + return "AMI Kaldi data has no reference segments for \(meetingId)." 
+ case .invalidKaldiData(let message): + return message + } + } + } + + static func splitDirectory( + split: DiarizationBenchmarkUtils.AMISplit, + workingDirectory: URL = URL(fileURLWithPath: FileManager.default.currentDirectoryPath) + ) -> URL { + datasetsRoot(workingDirectory: workingDirectory) + .appendingPathComponent("ami/mhs/data/\(split.rawValue)", isDirectory: true) + } + + static func splitExists( + split: DiarizationBenchmarkUtils.AMISplit, + workingDirectory: URL = URL(fileURLWithPath: FileManager.default.currentDirectoryPath), + fileManager: FileManager = .default + ) -> Bool { + splitExists( + splitDirectory: splitDirectory(split: split, workingDirectory: workingDirectory), fileManager: fileManager) + } + + static func ensureSplitExists( + split: DiarizationBenchmarkUtils.AMISplit, + force: Bool = false, + workingDirectory: URL = URL(fileURLWithPath: FileManager.default.currentDirectoryPath), + homeDirectory: URL = FileManager.default.homeDirectoryForCurrentUser, + fileManager: FileManager = .default + ) throws { + let outputDirectory = splitDirectory(split: split, workingDirectory: workingDirectory) + if !force && splitExists(splitDirectory: outputDirectory, fileManager: fileManager) { + return + } + + guard + let annotationsRoot = findAnnotationsRoot( + workingDirectory: workingDirectory, + fileManager: fileManager + ) + else { + throw Error.annotationsNotFound + } + + try buildSplit( + split: split, + annotationsRoot: annotationsRoot, + audioRoot: homeDirectory.appendingPathComponent("FluidAudioDatasets/ami_official/sdm", isDirectory: true), + outputDirectory: outputDirectory, + fileManager: fileManager + ) + } + + static func buildSplit( + split: DiarizationBenchmarkUtils.AMISplit, + annotationsRoot: URL, + audioRoot: URL, + outputDirectory: URL, + fileManager: FileManager = .default + ) throws { + try buildSplit( + meetingIds: DiarizationBenchmarkUtils.getAMIMeetings(split: split), + annotationsRoot: annotationsRoot, + audioRoot: audioRoot, + 
outputDirectory: outputDirectory, + fileManager: fileManager + ) + } + + static func buildSplit( + meetingIds: [String], + annotationsRoot: URL, + audioRoot: URL, + outputDirectory: URL, + fileManager: FileManager = .default + ) throws { + let parser = AMIAnnotationParser() + let meetingsFile = annotationsRoot.appendingPathComponent("corpusResources/meetings.xml") + let segmentsDirectory = annotationsRoot.appendingPathComponent("segments", isDirectory: true) + + try fileManager.createDirectory(at: outputDirectory, withIntermediateDirectories: true) + + var wavLines: [String] = [] + var segmentLines: [String] = [] + var utt2spkLines: [String] = [] + var utt2timestampLines: [String] = [] + var reco2durLines: [String] = [] + var reco2numSpkLines: [String] = [] + var spkToUtterances: [String: [String]] = [:] + var generatedMeetings = 0 + + for meetingId in meetingIds.sorted() { + let audioURL = audioRoot.appendingPathComponent("\(meetingId).Mix-Headset.wav") + guard fileManager.fileExists(atPath: audioURL.path) else { + logger.warning("Skipping \(meetingId): audio not found at \(audioURL.path)") + continue + } + + guard + let mapping = try parser.parseSpeakerMapping( + for: meetingId, + from: meetingsFile + ) + else { + logger.warning("Skipping \(meetingId): no AMI speaker mapping found") + continue + } + + let segments = try loadSegments( + for: meetingId, + mapping: mapping, + parser: parser, + segmentsDirectory: segmentsDirectory, + fileManager: fileManager + ) + + guard !segments.isEmpty else { + logger.warning("Skipping \(meetingId): no AMI segments found") + continue + } + + let duration = try audioDuration(for: audioURL) + let speakers = Array(Set(segments.map(\.speakerId))).sorted() + + wavLines.append("\(meetingId) \(audioURL.path)") + reco2durLines.append("\(meetingId) \(formatSeconds(duration))") + reco2numSpkLines.append("\(meetingId) \(speakers.count)") + + for segment in segments { + segmentLines.append( + "\(segment.utteranceId) \(segment.recordingId) 
\(formatSeconds(segment.startTime)) \(formatSeconds(segment.endTime))" + ) + utt2spkLines.append("\(segment.utteranceId) \(segment.speakerId)") + utt2timestampLines.append( + "\(segment.utteranceId) \(formatSeconds(segment.startTime)) \(formatSeconds(segment.endTime))" + ) + spkToUtterances[segment.speakerId, default: []].append(segment.utteranceId) + } + + generatedMeetings += 1 + } + + guard generatedMeetings > 0 else { + throw Error.invalidKaldiData("Failed to build AMI Kaldi data: no meetings had both audio and annotations.") + } + + let spk2uttLines = spkToUtterances.keys.sorted().map { speakerId in + let utterances = spkToUtterances[speakerId, default: []].sorted() + return ([speakerId] + utterances).joined(separator: " ") + } + + try write(lines: wavLines.sorted(), to: outputDirectory.appendingPathComponent("wav.scp")) + try write(lines: segmentLines.sorted(), to: outputDirectory.appendingPathComponent("segments")) + try write(lines: utt2spkLines.sorted(), to: outputDirectory.appendingPathComponent("utt2spk")) + try write(lines: spk2uttLines, to: outputDirectory.appendingPathComponent("spk2utt")) + try write(lines: reco2durLines.sorted(), to: outputDirectory.appendingPathComponent("reco2dur")) + try write(lines: reco2numSpkLines.sorted(), to: outputDirectory.appendingPathComponent("reco2num_spk")) + try write(lines: utt2timestampLines.sorted(), to: outputDirectory.appendingPathComponent("utt2timestamp")) + } + + static func recordingIDs( + in splitDirectory: URL, + maxFiles: Int? = nil + ) throws -> [String] { + let recordings = try wavEntries(in: splitDirectory).keys.sorted() + guard let maxFiles else { return recordings } + return Array(recordings.prefix(maxFiles)) + } + + static func audioPath(for meetingId: String, in splitDirectory: URL) throws -> String? { + try wavEntries(in: splitDirectory)[meetingId] + } + + static func recordingDuration(for meetingId: String, in splitDirectory: URL) throws -> Double? 
{ + try durationEntries(in: splitDirectory)[meetingId] + } + + static func loadDERReference( + for meetingId: String, + in splitDirectory: URL, + frameStep: Double = defaultFrameStep + ) throws -> [DERSpeakerSegment] { + let segments = try segmentEntries(in: splitDirectory) + .filter { $0.recordingId == meetingId } + + guard !segments.isEmpty else { + throw Error.missingReference(meetingId) + } + + var intervalsBySpeaker: [String: [(startFrame: Int, endFrame: Int)]] = [:] + for segment in segments { + let startFrame = Int((segment.startTime / frameStep).rounded(.toNearestOrEven)) + let endFrame = Int((segment.endTime / frameStep).rounded(.toNearestOrEven)) + guard endFrame > startFrame else { continue } + intervalsBySpeaker[segment.speakerId, default: []].append((startFrame, endFrame)) + } + + var references: [DERSpeakerSegment] = [] + for (speaker, intervals) in intervalsBySpeaker { + let sortedIntervals = intervals.sorted { + if $0.startFrame == $1.startFrame { + return $0.endFrame < $1.endFrame + } + return $0.startFrame < $1.startFrame + } + guard var current = sortedIntervals.first else { continue } + + for next in sortedIntervals.dropFirst() { + if next.startFrame <= current.endFrame { + current.endFrame = max(current.endFrame, next.endFrame) + continue + } + + references.append( + DERSpeakerSegment( + speaker: speaker, + start: Double(current.startFrame) * frameStep, + end: Double(current.endFrame) * frameStep + ) + ) + current = next + } + + references.append( + DERSpeakerSegment( + speaker: speaker, + start: Double(current.startFrame) * frameStep, + end: Double(current.endFrame) * frameStep + ) + ) + } + + return references.sorted { + if $0.start == $1.start { + if $0.end == $1.end { + return $0.speaker < $1.speaker + } + return $0.end < $1.end + } + return $0.start < $1.start + } + } + + private static func splitExists( + splitDirectory: URL, + fileManager: FileManager + ) -> Bool { + requiredKaldiFiles.allSatisfy { + fileManager.fileExists(atPath: 
splitDirectory.appendingPathComponent($0).path) + } + } + + private static func repositoryRoot() -> URL { + URL(fileURLWithPath: #filePath) + .deletingLastPathComponent() + .deletingLastPathComponent() + .deletingLastPathComponent() + .deletingLastPathComponent() + } + + private static func datasetsRoot( + workingDirectory: URL, + fileManager: FileManager = .default + ) -> URL { + let workingDatasets = workingDirectory.appendingPathComponent("Datasets", isDirectory: true) + if fileManager.fileExists(atPath: workingDatasets.path) { + return workingDatasets + } + return repositoryRoot().appendingPathComponent("Datasets", isDirectory: true) + } + + private static func findAnnotationsRoot( + workingDirectory: URL, + fileManager: FileManager + ) -> URL? { + let candidates = [ + datasetsRoot(workingDirectory: workingDirectory, fileManager: fileManager) + .appendingPathComponent("ami_public_1.6.2", isDirectory: true), + repositoryRoot().appendingPathComponent("Datasets/ami_public_1.6.2", isDirectory: true), + ] + + for candidate in candidates { + let segmentsDir = candidate.appendingPathComponent("segments", isDirectory: true) + let meetingsFile = candidate.appendingPathComponent("corpusResources/meetings.xml") + guard fileManager.fileExists(atPath: segmentsDir.path) else { continue } + guard fileManager.fileExists(atPath: meetingsFile.path) else { continue } + return candidate + } + + return nil + } + + private static func loadSegments( + for meetingId: String, + mapping: AMISpeakerMapping, + parser: AMIAnnotationParser, + segmentsDirectory: URL, + fileManager: FileManager + ) throws -> [SegmentEntry] { + var segments: [SegmentEntry] = [] + + for speakerCode in ["A", "B", "C", "D"] { + let fileURL = segmentsDirectory.appendingPathComponent("\(meetingId).\(speakerCode).segments.xml") + guard fileManager.fileExists(atPath: fileURL.path) else { continue } + guard let participantId = mapping.participantId(for: speakerCode) else { continue } + + let parsedSegments = try 
parser.parseSegmentsFile(fileURL) + for (index, segment) in parsedSegments.enumerated() where segment.duration > 0 { + segments.append( + SegmentEntry( + utteranceId: utteranceId(meetingId: meetingId, speakerCode: speakerCode, ordinal: index + 1), + recordingId: meetingId, + speakerId: participantId, + startTime: segment.startTime, + endTime: segment.endTime + ) + ) + } + } + + return segments.sorted { + if $0.recordingId == $1.recordingId { + if $0.startTime == $1.startTime { + if $0.endTime == $1.endTime { + return $0.utteranceId < $1.utteranceId + } + return $0.endTime < $1.endTime + } + return $0.startTime < $1.startTime + } + return $0.recordingId < $1.recordingId + } + } + + private static func utteranceId(meetingId: String, speakerCode: String, ordinal: Int) -> String { + "\(meetingId)_\(speakerCode.lowercased())_\(String(format: "%05d", ordinal))" + } + + private static func audioDuration(for audioURL: URL) throws -> Double { + let audioFile = try AVAudioFile(forReading: audioURL) + return Double(audioFile.length) / audioFile.processingFormat.sampleRate + } + + private static func write(lines: [String], to fileURL: URL) throws { + let contents = lines.joined(separator: "\n") + "\n" + try contents.write(to: fileURL, atomically: true, encoding: .utf8) + } + + private static func wavEntries(in splitDirectory: URL) throws -> [String: String] { + try parseKeyValueFile(splitDirectory.appendingPathComponent("wav.scp")) + } + + private static func durationEntries(in splitDirectory: URL) throws -> [String: Double] { + let lines = try String(contentsOf: splitDirectory.appendingPathComponent("reco2dur"), encoding: .utf8) + .split(whereSeparator: \.isNewline) + var result: [String: Double] = [:] + for line in lines { + let parts = line.split(maxSplits: 1, whereSeparator: \.isWhitespace) + guard parts.count == 2, let value = Double(parts[1]) else { + throw Error.invalidKaldiData("Invalid reco2dur line: \(line)") + } + result[String(parts[0])] = value + } + return result 
+ } + + private static func segmentEntries(in splitDirectory: URL) throws -> [SegmentEntry] { + let utt2spk = try parseKeyValueFile(splitDirectory.appendingPathComponent("utt2spk")) + let lines = try String(contentsOf: splitDirectory.appendingPathComponent("segments"), encoding: .utf8) + .split(whereSeparator: \.isNewline) + + var entries: [SegmentEntry] = [] + entries.reserveCapacity(lines.count) + + for line in lines { + let parts = line.split(whereSeparator: \.isWhitespace) + guard parts.count == 4 else { + throw Error.invalidKaldiData("Invalid segments line: \(line)") + } + + let utteranceId = String(parts[0]) + guard let speakerId = utt2spk[utteranceId] else { + throw Error.invalidKaldiData("utt2spk missing entry for \(utteranceId)") + } + guard let startTime = Double(parts[2]), let endTime = Double(parts[3]) else { + throw Error.invalidKaldiData("Invalid segment timestamps for \(utteranceId)") + } + + entries.append( + SegmentEntry( + utteranceId: utteranceId, + recordingId: String(parts[1]), + speakerId: speakerId, + startTime: startTime, + endTime: endTime + ) + ) + } + + return entries + } + + private static func parseKeyValueFile(_ fileURL: URL) throws -> [String: String] { + let lines = try String(contentsOf: fileURL, encoding: .utf8) + .split(whereSeparator: \.isNewline) + var result: [String: String] = [:] + + for line in lines { + let parts = line.split(maxSplits: 1, whereSeparator: \.isWhitespace) + guard parts.count == 2 else { + throw Error.invalidKaldiData("Invalid key-value line in \(fileURL.lastPathComponent): \(line)") + } + result[String(parts[0])] = String(parts[1]) + } + + return result + } + + private static func formatSeconds(_ value: Double) -> String { + String(format: "%.6f", value) + } +} +#endif diff --git a/Sources/FluidAudioCLI/DatasetParsers/AMIParser.swift b/Sources/FluidAudioCLI/DatasetParsers/AMIParser.swift index e8fa8ebd2..74adb1f21 100644 --- a/Sources/FluidAudioCLI/DatasetParsers/AMIParser.swift +++ 
b/Sources/FluidAudioCLI/DatasetParsers/AMIParser.swift @@ -5,32 +5,12 @@ import Foundation /// AMI annotation parser and ground truth handling struct AMIParser { private static let logger = AppLogger(category: "AMIParser") + private static let defaultMergeGapSeconds = 0.5 + private static let defaultReferenceFrameStepSeconds = 0.01 /// Get ground truth speaker count from AMI meetings.xml static func getGroundTruthSpeakerCount(for meetingId: String) -> Int { - // Use the same path resolution logic as loadAMIGroundTruth for consistency - let possiblePaths = [ - // Current working directory - NEW Datasets location (after PR #19) - URL(fileURLWithPath: FileManager.default.currentDirectoryPath) - .appendingPathComponent( - "Datasets/ami_public_1.6.2"), - // Relative to source file - NEW Datasets location - URL(fileURLWithPath: #file).deletingLastPathComponent().deletingLastPathComponent() - .deletingLastPathComponent().appendingPathComponent( - "Datasets/ami_public_1.6.2"), - // OLD: Current working directory - Tests location (backward compatibility) - URL(fileURLWithPath: FileManager.default.currentDirectoryPath) - .appendingPathComponent( - "Tests/ami_public_1.6.2"), - // OLD: Relative to source file - Tests location (backward compatibility) - URL(fileURLWithPath: #file).deletingLastPathComponent().deletingLastPathComponent() - .deletingLastPathComponent().appendingPathComponent("Tests/ami_public_1.6.2"), - // OLD: Home directory - Tests location (backward compatibility) - FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent( - "code/FluidAudio/Tests/ami_public_1.6.2"), - ] - - for location in possiblePaths { + for location in possibleAnnotationRoots() { let meetingsFile = location.appendingPathComponent("corpusResources/meetings.xml") if FileManager.default.fileExists(atPath: meetingsFile.path) { do { @@ -66,45 +46,7 @@ struct AMIParser { ) async -> [TimedSpeakerSegment] { - // Try to find the AMI annotations directory in several possible locations 
- let possiblePaths = [ - // Current working directory - NEW Datasets location (after PR #19) - URL(fileURLWithPath: FileManager.default.currentDirectoryPath) - .appendingPathComponent( - "Datasets/ami_public_1.6.2"), - // Relative to source file - NEW Datasets location - URL(fileURLWithPath: #file).deletingLastPathComponent().deletingLastPathComponent() - .deletingLastPathComponent().appendingPathComponent( - "Datasets/ami_public_1.6.2"), - // OLD: Current working directory - Tests location (backward compatibility) - URL(fileURLWithPath: FileManager.default.currentDirectoryPath) - .appendingPathComponent( - "Tests/ami_public_1.6.2"), - // OLD: Relative to source file - Tests location (backward compatibility) - URL(fileURLWithPath: #file).deletingLastPathComponent().deletingLastPathComponent() - .deletingLastPathComponent().appendingPathComponent("Tests/ami_public_1.6.2"), - // OLD: Home directory - Tests location (backward compatibility) - FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent( - "code/FluidAudio/Tests/ami_public_1.6.2"), - ] - - // Add comprehensive debug logging for path resolution - var amiDir: URL? 
- for (_, path) in possiblePaths.enumerated() { - let segmentsDir = path.appendingPathComponent("segments") - let meetingsFile = path.appendingPathComponent("corpusResources/meetings.xml") - - let segmentsExists = FileManager.default.fileExists(atPath: segmentsDir.path) - let meetingsExists = FileManager.default.fileExists(atPath: meetingsFile.path) - - if segmentsExists && meetingsExists { - logger.info(" - 🎯 SELECTED: This path will be used") - amiDir = path - break - } - } - - guard let validAmiDir = amiDir else { + guard let validAmiDir = findAnnotationRoot(requiringSubdirectory: "segments") else { logger.warning(" AMI annotations not found in any expected location") logger.warning( " 📁 Expected structure: [path]/segments/ AND [path]/corpusResources/meetings.xml" @@ -118,76 +60,273 @@ struct AMIParser { return generateSimplifiedGroundTruth(duration: duration, speakerCount: 4) } - let segmentsDir = validAmiDir.appendingPathComponent("segments") - let meetingsFile = validAmiDir.appendingPathComponent("corpusResources/meetings.xml") - logger.info(" 📖 Loading AMI annotations for meeting: \(meetingId)") do { - let parser = AMIAnnotationParser() + let allSegments = try loadAMIGroundTruth( + for: meetingId, + in: validAmiDir, + duration: duration + ) + logger.info(" Total segments loaded: \(allSegments.count)") + return allSegments + } catch { + logger.warning(" Failed to parse AMI annotations: \(error)") + logger.warning(" Using simplified placeholder instead") + return generateSimplifiedGroundTruth(duration: duration, speakerCount: 4) + } + } - // Get speaker mapping for this meeting - guard - let speakerMapping = try parser.parseSpeakerMapping( - for: meetingId, from: meetingsFile) - else { - logger.warning( - " ⚠️ No speaker mapping found for meeting: \(meetingId), using placeholder" + /// Internal hook for tests and benchmark helpers that need deterministic parsing + /// from a specific AMI annotation root. 
+ static func loadAMIGroundTruth( + for meetingId: String, + in amiDirectory: URL, + duration: Float + ) throws -> [TimedSpeakerSegment] { + _ = duration + return try loadOfficialGroundTruth( + for: meetingId, + in: amiDirectory, + filterShortSegments: true + ) + } + + private static func loadOfficialGroundTruth( + for meetingId: String, + in amiDirectory: URL, + filterShortSegments: Bool + ) throws -> [TimedSpeakerSegment] { + let segmentsDir = amiDirectory.appendingPathComponent("segments") + let meetingsFile = amiDirectory.appendingPathComponent("corpusResources/meetings.xml") + let parser = AMIAnnotationParser() + + guard + let speakerMapping = try parser.parseSpeakerMapping( + for: meetingId, + from: meetingsFile + ) + else { + throw NSError( + domain: "AMIParser", + code: 5, + userInfo: [NSLocalizedDescriptionKey: "No speaker mapping found for \(meetingId)"] + ) + } + + logger.info( + " Speaker mapping: A=\(speakerMapping.speakerA), B=\(speakerMapping.speakerB), C=\(speakerMapping.speakerC), D=\(speakerMapping.speakerD)" + ) + + var allSegments: [TimedSpeakerSegment] = [] + + for speakerCode in ["A", "B", "C", "D"] { + let segmentFile = segmentsDir.appendingPathComponent("\(meetingId).\(speakerCode).segments.xml") + guard FileManager.default.fileExists(atPath: segmentFile.path) else { continue } + guard let participantId = speakerMapping.participantId(for: speakerCode) else { continue } + + let segments = try parser.parseSegmentsFile(segmentFile) + for segment in segments where segment.duration > 0 { + if filterShortSegments && segment.duration < 0.5 { + continue + } + allSegments.append( + TimedSpeakerSegment( + speakerId: participantId, + embedding: generatePlaceholderEmbedding(for: participantId), + startTimeSeconds: Float(segment.startTime), + endTimeSeconds: Float(segment.endTime), + qualityScore: 1.0 + ) ) - return generateSimplifiedGroundTruth(duration: duration, speakerCount: 4) } logger.info( - " Speaker mapping: A=\(speakerMapping.speakerA), 
B=\(speakerMapping.speakerB), C=\(speakerMapping.speakerC), D=\(speakerMapping.speakerD)" + " Loaded \(segments.count) segments for speaker \(speakerCode) (\(participantId))" ) + } - var allSegments: [TimedSpeakerSegment] = [] + allSegments.sort { + if $0.startTimeSeconds == $1.startTimeSeconds { + if $0.endTimeSeconds == $1.endTimeSeconds { + return $0.speakerId < $1.speakerId + } + return $0.endTimeSeconds < $1.endTimeSeconds + } + return $0.startTimeSeconds < $1.startTimeSeconds + } + return allSegments + } - // Parse segments for each speaker (A, B, C, D) - for speakerCode in ["A", "B", "C", "D"] { - let segmentFile = segmentsDir.appendingPathComponent( - "\(meetingId).\(speakerCode).segments.xml") + /// Load AMI annotations as a 10 ms frame-quantized DER reference, matching the + /// original Kaldi-style label construction used by the LS-EEND repo. + static func loadFrameAlignedDERReference( + for meetingId: String, + duration: Float, + frameStep: Double = defaultReferenceFrameStepSeconds + ) async -> [DERSpeakerSegment] { + guard let validAmiDir = findAnnotationRoot(requiringSubdirectory: "segments") else { + logger.warning(" AMI annotations not found in any expected location") + logger.warning( + " 📁 Expected structure: [path]/segments/ AND [path]/corpusResources/meetings.xml" + ) + logger.warning(" 📋 Falling back to simplified placeholder ground truth") + return frameAlignedDERReference( + from: generateSimplifiedGroundTruth(duration: duration, speakerCount: 4), + frameStep: frameStep + ) + } - if FileManager.default.fileExists(atPath: segmentFile.path) { - let segments = try parser.parseSegmentsFile(segmentFile) + do { + return try loadFrameAlignedDERReference( + for: meetingId, + in: validAmiDir, + duration: duration, + frameStep: frameStep + ) + } catch { + logger.warning(" Failed to parse AMI annotations: \(error)") + logger.warning(" Falling back to simplified placeholder ground truth") + return frameAlignedDERReference( + from: 
generateSimplifiedGroundTruth(duration: duration, speakerCount: 4), + frameStep: frameStep + ) + } + } - // Map to TimedSpeakerSegment with real participant ID - guard let participantId = speakerMapping.participantId(for: speakerCode) - else { - continue - } + static func loadFrameAlignedDERReference( + for meetingId: String, + in amiDirectory: URL, + duration: Float, + frameStep: Double = defaultReferenceFrameStepSeconds + ) throws -> [DERSpeakerSegment] { + _ = duration + let segments = try loadOfficialGroundTruth( + for: meetingId, + in: amiDirectory, + filterShortSegments: false + ) + return frameAlignedDERReference(from: segments, frameStep: frameStep) + } - for segment in segments { - // Filter out very short segments (< 0.5 seconds) as done in research - guard segment.duration >= 0.5 else { continue } + /// Load AMI word-aligned ground truth annotations for a specific meeting. + /// + /// Uses forced-alignment `{meeting}.{A|B|C|D}.words.xml` files and merges + /// adjacent same-speaker words with gaps up to `mergeGap`. 
+ static func loadWordAlignedGroundTruth( + for meetingId: String, + duration: Float, + mergeGap: Double = defaultMergeGapSeconds + ) async -> [TimedSpeakerSegment] { + guard let validAmiDir = findAnnotationRoot(requiringSubdirectory: "words") else { + logger.warning(" AMI word annotations not found in any expected location") + logger.warning( + " 📁 Expected structure: [path]/words/ AND [path]/corpusResources/meetings.xml" + ) + logger.warning(" 📋 Falling back to simplified placeholder ground truth") + return generateSimplifiedGroundTruth(duration: duration, speakerCount: 4) + } - let timedSegment = TimedSpeakerSegment( - speakerId: participantId, // Use real AMI participant ID - embedding: generatePlaceholderEmbedding(for: participantId), - startTimeSeconds: Float(segment.startTime), - endTimeSeconds: Float(segment.endTime), - qualityScore: 1.0 - ) + do { + return try loadWordAlignedGroundTruth( + for: meetingId, + in: validAmiDir, + duration: duration, + mergeGap: mergeGap + ) + } catch { + logger.warning(" Failed to parse AMI word annotations: \(error)") + logger.warning(" Falling back to simplified placeholder ground truth") + return generateSimplifiedGroundTruth(duration: duration, speakerCount: 4) + } + } - allSegments.append(timedSegment) - } + /// Internal hook for tests and benchmark helpers that need deterministic parsing + /// from a specific AMI annotation root. 
+ static func loadWordAlignedGroundTruth( + for meetingId: String, + in amiDirectory: URL, + duration: Float, + mergeGap: Double = defaultMergeGapSeconds + ) throws -> [TimedSpeakerSegment] { + let wordsDir = amiDirectory.appendingPathComponent("words") + let meetingsFile = amiDirectory.appendingPathComponent("corpusResources/meetings.xml") + + let parser = AMIAnnotationParser() + guard + let speakerMapping = try parser.parseSpeakerMapping( + for: meetingId, + from: meetingsFile + ) + else { + throw NSError( + domain: "AMIParser", + code: 3, + userInfo: [NSLocalizedDescriptionKey: "No speaker mapping found for \(meetingId)"] + ) + } - logger.info( - " Loaded \(segments.count) segments for speaker \(speakerCode) (\(participantId))" + var allSegments: [TimedSpeakerSegment] = [] + for speakerCode in ["A", "B", "C", "D"] { + let wordsFile = wordsDir.appendingPathComponent("\(meetingId).\(speakerCode).words.xml") + guard FileManager.default.fileExists(atPath: wordsFile.path) else { continue } + guard let participantId = speakerMapping.participantId(for: speakerCode) else { continue } + + let words = try parser.parseWordsFile(wordsFile) + for segment in mergeSegments(words, mergeGap: mergeGap) { + allSegments.append( + TimedSpeakerSegment( + speakerId: participantId, + embedding: generatePlaceholderEmbedding(for: participantId), + startTimeSeconds: Float(segment.startTime), + endTimeSeconds: Float(segment.endTime), + qualityScore: 1.0 ) - } + ) } + } - // Sort by start time - allSegments.sort { $0.startTimeSeconds < $1.startTimeSeconds } + allSegments.sort { $0.startTimeSeconds < $1.startTimeSeconds } + return allSegments + } - logger.info(" Total segments loaded: \(allSegments.count)") - return allSegments + static func loadWordAlignedDERReference( + for meetingId: String, + duration: Float, + mergeGap: Double = defaultMergeGapSeconds + ) async -> [DERSpeakerSegment] { + let segments = await loadWordAlignedGroundTruth( + for: meetingId, + duration: duration, + mergeGap: 
mergeGap + ) + return segments.map { + DERSpeakerSegment( + speaker: $0.speakerId, + start: Double($0.startTimeSeconds), + end: Double($0.endTimeSeconds) + ) + } + } - } catch { - logger.warning(" Failed to parse AMI annotations: \(error)") - logger.warning(" Using simplified placeholder instead") - return generateSimplifiedGroundTruth(duration: duration, speakerCount: 4) + static func loadWordAlignedDERReference( + for meetingId: String, + in amiDirectory: URL, + duration: Float, + mergeGap: Double = defaultMergeGapSeconds + ) throws -> [DERSpeakerSegment] { + let segments = try loadWordAlignedGroundTruth( + for: meetingId, + in: amiDirectory, + duration: duration, + mergeGap: mergeGap + ) + return segments.map { + DERSpeakerSegment( + speaker: $0.speakerId, + start: Double($0.startTimeSeconds), + end: Double($0.endTimeSeconds) + ) } } @@ -232,6 +371,127 @@ struct AMIParser { } return embedding } + + private static func possibleAnnotationRoots() -> [URL] { + [ + URL(fileURLWithPath: FileManager.default.currentDirectoryPath) + .appendingPathComponent("Datasets/ami_public_1.6.2"), + URL(fileURLWithPath: #file).deletingLastPathComponent().deletingLastPathComponent() + .deletingLastPathComponent().appendingPathComponent("Datasets/ami_public_1.6.2"), + URL(fileURLWithPath: FileManager.default.currentDirectoryPath) + .appendingPathComponent("Tests/ami_public_1.6.2"), + URL(fileURLWithPath: #file).deletingLastPathComponent().deletingLastPathComponent() + .deletingLastPathComponent().appendingPathComponent("Tests/ami_public_1.6.2"), + FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent( + "code/FluidAudio/Tests/ami_public_1.6.2" + ), + ] + } + + private static func findAnnotationRoot(requiringSubdirectory subdirectory: String) -> URL? 
{ + for path in possibleAnnotationRoots() { + let requiredDir = path.appendingPathComponent(subdirectory) + let meetingsFile = path.appendingPathComponent("corpusResources/meetings.xml") + let hasRequiredDir = FileManager.default.fileExists(atPath: requiredDir.path) + let hasMeetings = FileManager.default.fileExists(atPath: meetingsFile.path) + if hasRequiredDir && hasMeetings { + logger.info(" - 🎯 SELECTED: \(path.path)") + return path + } + } + return nil + } + + private static func mergeSegments( + _ segments: [AMISpeakerSegment], + mergeGap: Double + ) -> [AMISpeakerSegment] { + let sorted = segments.sorted { $0.startTime < $1.startTime } + guard var current = sorted.first else { return [] } + + var merged: [AMISpeakerSegment] = [] + for next in sorted.dropFirst() { + if next.startTime - current.endTime <= mergeGap { + current = AMISpeakerSegment( + segmentId: current.segmentId, + participantId: current.participantId, + startTime: current.startTime, + endTime: max(current.endTime, next.endTime) + ) + continue + } + merged.append(current) + current = next + } + + merged.append(current) + return merged + } + + private static func frameAlignedDERReference( + from segments: [TimedSpeakerSegment], + frameStep: Double + ) -> [DERSpeakerSegment] { + precondition(frameStep > 0) + + var intervalsBySpeaker: [String: [(startFrame: Int, endFrame: Int)]] = [:] + for segment in segments { + let startFrame = Int( + (Double(segment.startTimeSeconds) / frameStep).rounded(.toNearestOrEven) + ) + let endFrame = Int( + (Double(segment.endTimeSeconds) / frameStep).rounded(.toNearestOrEven) + ) + guard endFrame > startFrame else { continue } + intervalsBySpeaker[segment.speakerId, default: []].append((startFrame, endFrame)) + } + + var derSegments: [DERSpeakerSegment] = [] + for (speaker, intervals) in intervalsBySpeaker { + let sortedIntervals = intervals.sorted { + if $0.startFrame == $1.startFrame { + return $0.endFrame < $1.endFrame + } + return $0.startFrame < $1.startFrame + } 
+ guard var current = sortedIntervals.first else { continue } + + for next in sortedIntervals.dropFirst() { + guard next.startFrame > current.endFrame else { + current.endFrame = max(current.endFrame, next.endFrame) + continue + } + + derSegments.append( + DERSpeakerSegment( + speaker: speaker, + start: Double(current.startFrame) * frameStep, + end: Double(current.endFrame) * frameStep + ) + ) + current = next + } + + derSegments.append( + DERSpeakerSegment( + speaker: speaker, + start: Double(current.startFrame) * frameStep, + end: Double(current.endFrame) * frameStep + ) + ) + } + + derSegments.sort { + if $0.start == $1.start { + if $0.end == $1.end { + return $0.speaker < $1.speaker + } + return $0.end < $1.end + } + return $0.start < $1.start + } + return derSegments + } } // MARK: - AMI Annotation Data Structures @@ -297,6 +557,33 @@ class AMIAnnotationParser: NSObject { return delegate.segments } + /// Parse words.xml file and return word-level speaker segments. + func parseWordsFile(_ xmlFile: URL) throws -> [AMISpeakerSegment] { + let data = try Data(contentsOf: xmlFile) + let speakerCode = extractSpeakerCodeFromFilename(xmlFile.lastPathComponent) + + let parser = XMLParser(data: data) + let delegate = AMIWordsXMLDelegate(speakerCode: speakerCode) + parser.delegate = delegate + + guard parser.parse() else { + throw NSError( + domain: "AMIParser", + code: 4, + userInfo: [ + NSLocalizedDescriptionKey: + "Failed to parse XML file: \(xmlFile.lastPathComponent)" + ] + ) + } + + if let error = delegate.parsingError { + throw error + } + + return delegate.segments + } + /// Extract speaker code from AMI filename private func extractSpeakerCodeFromFilename(_ filename: String) -> String { // Filename format: "EN2001a.A.segments.xml" -> extract "A" @@ -333,6 +620,48 @@ class AMIAnnotationParser: NSObject { } } +/// XML parser delegate for AMI words files +private class AMIWordsXMLDelegate: NSObject, XMLParserDelegate { + var segments: [AMISpeakerSegment] = [] + var 
parsingError: Error? + + private let speakerCode: String + + init(speakerCode: String) { + self.speakerCode = speakerCode + } + + func parser( + _ parser: XMLParser, didStartElement elementName: String, namespaceURI: String?, + qualifiedName qName: String?, attributes attributeDict: [String: String] = [:] + ) { + let tag = elementName.split(separator: ":").last.map(String.init) ?? elementName + guard tag == "w", + attributeDict["punc"] != "true", + let startTimeString = attributeDict["starttime"], + let endTimeString = attributeDict["endtime"], + let startTime = Double(startTimeString), + let endTime = Double(endTimeString), + endTime > startTime + else { + return + } + + segments.append( + AMISpeakerSegment( + segmentId: attributeDict["nite:id"] ?? UUID().uuidString, + participantId: speakerCode, + startTime: startTime, + endTime: endTime + ) + ) + } + + func parser(_ parser: XMLParser, parseErrorOccurred parseError: Error) { + parsingError = parseError + } +} + /// XML parser delegate for AMI segments files private class AMISegmentsXMLDelegate: NSObject, XMLParserDelegate { var segments: [AMISpeakerSegment] = [] diff --git a/Sources/FluidAudioCLI/DatasetParsers/DatasetDownloader.swift b/Sources/FluidAudioCLI/DatasetParsers/DatasetDownloader.swift index 9095adbb5..a5e3f19e2 100644 --- a/Sources/FluidAudioCLI/DatasetParsers/DatasetDownloader.swift +++ b/Sources/FluidAudioCLI/DatasetParsers/DatasetDownloader.swift @@ -37,7 +37,7 @@ struct DatasetDownloader { } static func downloadAMIDataset( - variant: AMIVariant, force: Bool, singleFile: String? = nil + variant: AMIVariant, force: Bool, singleFile: String? = nil, meetingIds: [String]? 
= nil ) async { @@ -59,11 +59,14 @@ struct DatasetDownloader { // Download AMI annotations first (required for proper benchmarking) await downloadAMIAnnotations(force: force) + await downloadAMIRTTMs(force: force, singleFile: singleFile, meetingIds: meetingIds) // Official AMI SDM test set (16 meetings) - matches NeMo evaluation let commonMeetings: [String] if let singleFile = singleFile { commonMeetings = [singleFile] + } else if let meetingIds { + commonMeetings = meetingIds } else { commonMeetings = Self.officialAMITestSet logger.info("📋 Downloading official AMI SDM test set (16 meetings)") @@ -259,6 +262,106 @@ struct DatasetDownloader { return false } + /// Sync AMI forced-alignment RTTMs from the local diar-forced-alignment repo into the standard cache path. + static func downloadAMIRTTMs( + force: Bool = false, + singleFile: String? = nil, + meetingIds: [String]? = nil + ) async { + let fileManager = FileManager.default + let homeDir = fileManager.homeDirectoryForCurrentUser + let workingDir = URL(fileURLWithPath: fileManager.currentDirectoryPath) + let sourceRoot = workingDir.appendingPathComponent("Datasets/diar-forced-alignment/AMI") + let destinationDir = homeDir.appendingPathComponent("FluidAudioDatasets/ami_official/rttm") + await downloadAMIRTTMs( + force: force, + singleFile: singleFile, + meetingIds: meetingIds, + sourceRoot: sourceRoot, + destinationDir: destinationDir, + fileManager: fileManager + ) + } + + static func downloadAMIRTTMs( + force: Bool = false, + singleFile: String? = nil, + meetingIds: [String]? 
= nil, + sourceRoot: URL, + destinationDir: URL, + fileManager: FileManager = .default + ) async { + guard fileManager.fileExists(atPath: sourceRoot.path) else { + logger.warning("AMI forced-alignment RTTM repo not found at \(sourceRoot.path)") + return + } + + do { + try fileManager.createDirectory(at: destinationDir, withIntermediateDirectories: true) + } catch { + logger.error("Failed to create AMI RTTM directory: \(error)") + return + } + + let selectedMeetingIds: [String] + if let singleFile { + selectedMeetingIds = [singleFile] + } else if let meetingIds { + selectedMeetingIds = meetingIds + } else { + selectedMeetingIds = [ + "EN2002a", "EN2002b", "EN2002c", "EN2002d", + "ES2004a", "ES2004b", "ES2004c", "ES2004d", + "IS1009a", "IS1009b", "IS1009c", "IS1009d", + "TS3003a", "TS3003b", "TS3003c", "TS3003d", + ] + } + + var copiedFiles = 0 + var skippedFiles = 0 + var missingFiles: [String] = [] + + for meetingId in selectedMeetingIds { + let destinationURL = destinationDir.appendingPathComponent("\(meetingId).rttm") + if !force && fileManager.fileExists(atPath: destinationURL.path) { + skippedFiles += 1 + continue + } + + guard let sourceURL = findAMIRTTMSource(meetingId: meetingId, sourceRoot: sourceRoot) else { + missingFiles.append(meetingId) + continue + } + + if fileManager.fileExists(atPath: destinationURL.path) { + try? fileManager.removeItem(at: destinationURL) + } + + do { + try fileManager.copyItem(at: sourceURL, to: destinationURL) + copiedFiles += 1 + } catch { + logger.error("Failed to copy RTTM for \(meetingId): \(error)") + } + } + + logger.info("AMI RTTMs: \(copiedFiles) copied, \(skippedFiles) skipped") + if !missingFiles.isEmpty { + logger.warning("Missing AMI RTTMs for: \(missingFiles.sorted().joined(separator: ", "))") + } + } + + private static func findAMIRTTMSource(meetingId: String, sourceRoot: URL) -> URL? 
{ + let fileManager = FileManager.default + let candidateURLs = [ + sourceRoot.appendingPathComponent("test/\(meetingId).rttm"), + sourceRoot.appendingPathComponent("dev/\(meetingId).rttm"), + sourceRoot.appendingPathComponent("train/\(meetingId).rttm"), + ] + + return candidateURLs.first { fileManager.fileExists(atPath: $0.path) } + } + /// Extract ZIP file using system unzip command static func extractZipFile(_ zipFile: URL, to targetDir: URL) async -> Bool { let process = Process() diff --git a/Tests/FluidAudioTests/CLI/AMIKaldiDataTests.swift b/Tests/FluidAudioTests/CLI/AMIKaldiDataTests.swift new file mode 100644 index 000000000..ee6d38037 --- /dev/null +++ b/Tests/FluidAudioTests/CLI/AMIKaldiDataTests.swift @@ -0,0 +1,185 @@ +#if os(macOS) +import AVFoundation +import Foundation +import XCTest + +@testable import FluidAudioCLI + +final class AMIKaldiDataTests: XCTestCase { + + func testBuildSplitWritesExpectedKaldiFiles() throws { + let fixture = try makeFixture(meetingId: "ES2004a") + + try AMIKaldiData.buildSplit( + meetingIds: ["ES2004a"], + annotationsRoot: fixture.annotationsRoot, + audioRoot: fixture.audioRoot, + outputDirectory: fixture.outputDirectory + ) + + for fileName in ["wav.scp", "segments", "utt2spk", "spk2utt", "reco2dur", "reco2num_spk", "utt2timestamp"] { + let fileURL = fixture.outputDirectory.appendingPathComponent(fileName) + XCTAssertTrue(FileManager.default.fileExists(atPath: fileURL.path), "\(fileName) should exist") + } + + let segments = try String(contentsOf: fixture.outputDirectory.appendingPathComponent("segments")) + XCTAssertTrue(segments.contains("ES2004a_a_00001 ES2004a 0.004000 0.126000")) + XCTAssertTrue(segments.contains("ES2004a_b_00001 ES2004a 1.001000 1.019000")) + + let utt2spk = try String(contentsOf: fixture.outputDirectory.appendingPathComponent("utt2spk")) + XCTAssertTrue(utt2spk.contains("ES2004a_a_00001 SpeakerA")) + XCTAssertTrue(utt2spk.contains("ES2004a_b_00001 SpeakerB")) + + let spk2utt = try 
String(contentsOf: fixture.outputDirectory.appendingPathComponent("spk2utt")) + XCTAssertTrue(spk2utt.contains("SpeakerA ES2004a_a_00001 ES2004a_a_00002 ES2004a_a_00003")) + XCTAssertTrue(spk2utt.contains("SpeakerB ES2004a_b_00001")) + + let reco2dur = try String(contentsOf: fixture.outputDirectory.appendingPathComponent("reco2dur")) + XCTAssertTrue(reco2dur.contains("ES2004a 2.000000")) + + let reco2numSpk = try String(contentsOf: fixture.outputDirectory.appendingPathComponent("reco2num_spk")) + XCTAssertTrue(reco2numSpk.contains("ES2004a 2")) + + let utt2timestamp = try String(contentsOf: fixture.outputDirectory.appendingPathComponent("utt2timestamp")) + XCTAssertTrue(utt2timestamp.contains("ES2004a_a_00003 0.601000 0.799000")) + XCTAssertTrue(utt2timestamp.contains("ES2004a_b_00001 1.001000 1.019000")) + } + + func testLoadDERReferenceMatchesOriginalKaldiQuantization() throws { + let fixture = try makeFixture(meetingId: "ZZ0001") + + try AMIKaldiData.buildSplit( + meetingIds: ["ZZ0001"], + annotationsRoot: fixture.annotationsRoot, + audioRoot: fixture.audioRoot, + outputDirectory: fixture.outputDirectory + ) + + let meetingIds = try AMIKaldiData.recordingIDs(in: fixture.outputDirectory) + XCTAssertEqual(meetingIds, ["ZZ0001"]) + XCTAssertEqual( + try AMIKaldiData.audioPath(for: "ZZ0001", in: fixture.outputDirectory), + fixture.audioRoot.appendingPathComponent("ZZ0001.Mix-Headset.wav").path + ) + XCTAssertEqual( + try XCTUnwrap(AMIKaldiData.recordingDuration(for: "ZZ0001", in: fixture.outputDirectory)), + 2.0, + accuracy: 0.0001 + ) + + let segments = try AMIKaldiData.loadDERReference(for: "ZZ0001", in: fixture.outputDirectory) + XCTAssertEqual(segments.count, 3) + + XCTAssertEqual(segments[0].speaker, "SpeakerA") + XCTAssertEqual(segments[0].start, 0.00, accuracy: 0.0001) + XCTAssertEqual(segments[0].end, 0.25, accuracy: 0.0001) + + XCTAssertEqual(segments[1].speaker, "SpeakerA") + XCTAssertEqual(segments[1].start, 0.60, accuracy: 0.0001) + 
XCTAssertEqual(segments[1].end, 0.80, accuracy: 0.0001) + + XCTAssertEqual(segments[2].speaker, "SpeakerB") + XCTAssertEqual(segments[2].start, 1.00, accuracy: 0.0001) + XCTAssertEqual(segments[2].end, 1.02, accuracy: 0.0001) + } + + private func makeFixture( + meetingId: String + ) throws -> ( + root: URL, annotationsRoot: URL, audioRoot: URL, outputDirectory: URL + ) { + let root = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString, isDirectory: true) + let annotationsRoot = root.appendingPathComponent("ami_public_1.6.2", isDirectory: true) + let segmentsRoot = annotationsRoot.appendingPathComponent("segments", isDirectory: true) + let corpusRoot = annotationsRoot.appendingPathComponent("corpusResources", isDirectory: true) + let audioRoot = root.appendingPathComponent("audio", isDirectory: true) + let outputDirectory = root.appendingPathComponent("ami/mhs/data/test", isDirectory: true) + + try FileManager.default.createDirectory(at: segmentsRoot, withIntermediateDirectories: true) + try FileManager.default.createDirectory(at: corpusRoot, withIntermediateDirectories: true) + try FileManager.default.createDirectory(at: audioRoot, withIntermediateDirectories: true) + try FileManager.default.createDirectory(at: outputDirectory, withIntermediateDirectories: true) + + let meetingsXML = """ + + + + + + + + + """ + try meetingsXML.write( + to: corpusRoot.appendingPathComponent("meetings.xml"), + atomically: true, + encoding: .utf8 + ) + + let speakerASegments = """ + + + + + + """ + try speakerASegments.write( + to: segmentsRoot.appendingPathComponent("\(meetingId).A.segments.xml"), + atomically: true, + encoding: .utf8 + ) + + let speakerBSegments = """ + + + + """ + try speakerBSegments.write( + to: segmentsRoot.appendingPathComponent("\(meetingId).B.segments.xml"), + atomically: true, + encoding: .utf8 + ) + + let emptySegments = "" + try emptySegments.write( + to: segmentsRoot.appendingPathComponent("\(meetingId).C.segments.xml"), + 
atomically: true, + encoding: .utf8 + ) + try emptySegments.write( + to: segmentsRoot.appendingPathComponent("\(meetingId).D.segments.xml"), + atomically: true, + encoding: .utf8 + ) + + try writeAudio( + to: audioRoot.appendingPathComponent("\(meetingId).Mix-Headset.wav"), + durationSeconds: 2.0 + ) + + return (root, annotationsRoot, audioRoot, outputDirectory) + } + + private func writeAudio(to url: URL, durationSeconds: Double) throws { + let format = AVAudioFormat( + commonFormat: .pcmFormatFloat32, + sampleRate: 8_000, + channels: 1, + interleaved: false + ) + let resolvedFormat = try XCTUnwrap(format) + let totalFrames = AVAudioFrameCount(durationSeconds * resolvedFormat.sampleRate) + let buffer = try XCTUnwrap(AVAudioPCMBuffer(pcmFormat: resolvedFormat, frameCapacity: totalFrames)) + buffer.frameLength = totalFrames + + let channelData = try XCTUnwrap(buffer.floatChannelData?[0]) + for frame in 0.. URL { + let baseURL = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString, isDirectory: true) + let wordsURL = baseURL.appendingPathComponent("words", isDirectory: true) + let segmentsURL = baseURL.appendingPathComponent("segments", isDirectory: true) + let corpusURL = baseURL.appendingPathComponent("corpusResources", isDirectory: true) + + try FileManager.default.createDirectory(at: wordsURL, withIntermediateDirectories: true) + try FileManager.default.createDirectory(at: segmentsURL, withIntermediateDirectories: true) + try FileManager.default.createDirectory(at: corpusURL, withIntermediateDirectories: true) + + let meetingsXML = """ + + + + + + + + + """ + try meetingsXML.write( + to: corpusURL.appendingPathComponent("meetings.xml"), + atomically: true, + encoding: .utf8 + ) + + let speakerAWords = """ + + hello + world + . 
+ + + """ + try speakerAWords.write( + to: wordsURL.appendingPathComponent("ES2004a.A.words.xml"), + atomically: true, + encoding: .utf8 + ) + + let speakerBWords = """ + + second + speaker + + """ + try speakerBWords.write( + to: wordsURL.appendingPathComponent("ES2004a.B.words.xml"), + atomically: true, + encoding: .utf8 + ) + + let emptyWords = "" + try emptyWords.write( + to: wordsURL.appendingPathComponent("ES2004a.C.words.xml"), + atomically: true, + encoding: .utf8 + ) + try emptyWords.write( + to: wordsURL.appendingPathComponent("ES2004a.D.words.xml"), + atomically: true, + encoding: .utf8 + ) + + let speakerASegments = """ + + + + + + """ + try speakerASegments.write( + to: segmentsURL.appendingPathComponent("ES2004a.A.segments.xml"), + atomically: true, + encoding: .utf8 + ) + + let speakerBSegments = """ + + + + """ + try speakerBSegments.write( + to: segmentsURL.appendingPathComponent("ES2004a.B.segments.xml"), + atomically: true, + encoding: .utf8 + ) + + let emptySegments = "" + try emptySegments.write( + to: segmentsURL.appendingPathComponent("ES2004a.C.segments.xml"), + atomically: true, + encoding: .utf8 + ) + try emptySegments.write( + to: segmentsURL.appendingPathComponent("ES2004a.D.segments.xml"), + atomically: true, + encoding: .utf8 + ) + + return baseURL + } +} +#endif diff --git a/Tests/FluidAudioTests/CLI/AMIRTTMTests.swift b/Tests/FluidAudioTests/CLI/AMIRTTMTests.swift new file mode 100644 index 000000000..a72890e65 --- /dev/null +++ b/Tests/FluidAudioTests/CLI/AMIRTTMTests.swift @@ -0,0 +1,62 @@ +#if os(macOS) +import Foundation +import XCTest + +@testable import FluidAudioCLI + +final class AMIRTTMTests: XCTestCase { + + func testAMIRTTMLookupPrefersCachedHomePath() throws { + let root = try makeFixtureRoot() + let homeDir = root.appendingPathComponent("home", isDirectory: true) + let workingDir = root.appendingPathComponent("workspace", isDirectory: true) + let cachedRTTM = 
homeDir.appendingPathComponent("FluidAudioDatasets/ami_official/rttm/ES2004a.rttm") + + try FileManager.default.createDirectory( + at: cachedRTTM.deletingLastPathComponent(), withIntermediateDirectories: true) + try "SPEAKER ES2004a 1 0.00 1.00 speaker0 \n".write( + to: cachedRTTM, + atomically: true, + encoding: .utf8 + ) + + let resolvedURL = DiarizationBenchmarkUtils.getAMIRTTMURL( + for: "ES2004a", + workingDir: workingDir, + homeDir: homeDir + ) + + XCTAssertEqual(resolvedURL, cachedRTTM) + } + + func testDownloadAMIRTTMsCopiesFromForcedAlignmentRepo() async throws { + let root = try makeFixtureRoot() + let sourceRoot = root.appendingPathComponent("Datasets/diar-forced-alignment/AMI", isDirectory: true) + let destinationDir = root.appendingPathComponent("cache/rttm", isDirectory: true) + let sourceRTTM = sourceRoot.appendingPathComponent("test/ES2004a.rttm") + + try FileManager.default.createDirectory( + at: sourceRTTM.deletingLastPathComponent(), withIntermediateDirectories: true) + let expectedContents = "SPEAKER ES2004a 1 0.00 1.00 speaker0 \n" + try expectedContents.write(to: sourceRTTM, atomically: true, encoding: .utf8) + + await DatasetDownloader.downloadAMIRTTMs( + force: false, + singleFile: "ES2004a", + sourceRoot: sourceRoot, + destinationDir: destinationDir + ) + + let copiedRTTM = destinationDir.appendingPathComponent("ES2004a.rttm") + XCTAssertTrue(FileManager.default.fileExists(atPath: copiedRTTM.path)) + XCTAssertEqual(try String(contentsOf: copiedRTTM), expectedContents) + } + + private func makeFixtureRoot() throws -> URL { + let root = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString, isDirectory: true) + try FileManager.default.createDirectory(at: root, withIntermediateDirectories: true) + return root + } +} +#endif diff --git a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDFeatureProvider.swift b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDFeatureProvider.swift new file mode 100644 index 
000000000..d9364b553 --- /dev/null +++ b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDFeatureProvider.swift @@ -0,0 +1,183 @@ +import Accelerate +import CoreML +import Foundation +import XCTest + +@testable import FluidAudio + +final class LSEENDFeatureProviderTests: XCTestCase { + + func testFeatureProviderRequiresExactMinimumAudioForFirstChunk() throws { + let metadata = makeMetadata() + let provider = try LSEENDFeatureProvider(from: metadata) + let minimumSamples = minimumSamplesForFirstChunk(metadata: metadata) + + try provider.enqueueAudio(makeAudio(count: minimumSamples - 1), withSampleRate: nil) + XCTAssertEqual(provider.readyChunks, 0) + XCTAssertNil(try provider.emitNextChunk()) + + try provider.enqueueAudio(makeAudio(count: 1), withSampleRate: nil) + + let firstChunk = try provider.emitNextChunk() + XCTAssertNotNil(firstChunk) + XCTAssertEqual(firstChunk?.melFeatures.count, metadata.melFrames * metadata.nMels) + XCTAssertEqual(firstChunk?.warmupFrames, metadata.convDelay) + } + + func testChunkedMelMatchesNonChunkedPipelineExactly() throws { + let metadata = makeMetadata() + let provider = try LSEENDFeatureProvider(from: metadata) + let audio = makeAudio(count: minimumSamplesForFirstChunk(metadata: metadata) * 3 + 37) + + try provider.enqueueAudio(audio, withSampleRate: nil) + try provider.drainRightContextWithSilence() + + var emittedChunks: [[Float]] = [] + while let input = try provider.emitNextChunk() { + emittedChunks.append(allValues(in: input.melFeatures)) + } + + let expectedChunks = try buildExpectedChunks(audio: audio, metadata: metadata) + + XCTAssertEqual(emittedChunks.count, expectedChunks.count) + XCTAssertFalse(emittedChunks.isEmpty) + + for (actual, expected) in zip(emittedChunks, expectedChunks) { + XCTAssertEqual(actual.count, expected.count) + for (index, pair) in zip(actual.indices, zip(actual, expected)) { + XCTAssertEqual( + pair.0, + pair.1, + accuracy: 1e-6, + "Mismatch at chunk sample \(index)" + ) + } + } + } + + private func 
buildExpectedChunks(audio: [Float], metadata: LSEENDMetadata) throws -> [[Float]] { + let audioBuffer = buildProviderAudioBuffer(audio: audio, metadata: metadata) + let spectrogram = AudioMelSpectrogram( + sampleRate: metadata.sampleRate, + nMels: metadata.nMels, + nFFT: metadata.nFFT, + hopLength: metadata.hopLength, + winLength: metadata.winLength, + preemph: 0, + padTo: 0, + logFloor: 1e-10, + logFloorMode: .clamped, + windowPeriodic: true + ) + + var (melFeatures, _, _) = spectrogram.computeFlatTransposed( + audio: audioBuffer, + lastAudioSample: 0, + paddingMode: .prePadded, + expectedFrameCount: nil + ) + + let scale: Float = 1.0 / log(10.0) + melFeatures = melFeatures.map { $0 * scale } + applyCumulativeMeanNormalization(to: &melFeatures, nMels: metadata.nMels) + + var melQueue = StreamingChunkQueue( + chunkLength: metadata.subsampling * metadata.chunkSize, + leftContextLength: metadata.contextSize, + rightContextLength: metadata.contextSize + 1 - metadata.subsampling, + stride: metadata.nMels + ) + melQueue.append(melFeatures) + + var chunks: [[Float]] = [] + while let chunk = melQueue.popNextChunk() { + chunks.append(Array(chunk)) + } + + return chunks + } + + private func buildProviderAudioBuffer(audio: [Float], metadata: LSEENDMetadata) -> [Float] { + let contextSamples = metadata.nFFT / 2 + let chunkSamples = metadata.hopLength * metadata.subsampling * metadata.chunkSize + let rightSamples = metadata.nFFT / 2 - metadata.hopLength + let flushSampleCount = + (metadata.contextSize + metadata.convDelay * metadata.subsampling) * metadata.hopLength + + contextSamples + + var buffer = [Float](repeating: 0, count: contextSamples) + buffer.append(contentsOf: audio) + buffer.append(contentsOf: repeatElement(0, count: flushSampleCount)) + + let contextFloats = contextSamples + rightSamples + let unread = buffer.count + let overContext = max(0, unread - contextFloats) + let shortfall = (chunkSamples - overContext % chunkSamples) % chunkSamples + if shortfall > 0 { 
+ buffer.append(contentsOf: repeatElement(0, count: shortfall)) + } + + return buffer + } + + private func applyCumulativeMeanNormalization(to melFeatures: inout [Float], nMels: Int) { + var cmnMean = [Float](repeating: 0, count: nMels) + var cmnCount = 0 + + melFeatures.withUnsafeMutableBufferPointer { buffer in + guard let base = buffer.baseAddress else { return } + + for frame in stride(from: 0, to: buffer.count, by: nMels) { + cmnCount += 1 + var alpha = 1.0 / Float(cmnCount) + + vDSP_vintb(cmnMean, 1, base + frame, 1, &alpha, &cmnMean, 1, vDSP_Length(nMels)) + vDSP_vsub(cmnMean, 1, base + frame, 1, base + frame, 1, vDSP_Length(nMels)) + } + } + } + + private func minimumSamplesForFirstChunk(metadata: LSEENDMetadata) -> Int { + let chunkSamples = metadata.hopLength * metadata.subsampling * metadata.chunkSize + let rightSamples = metadata.nFFT / 2 - metadata.hopLength + return chunkSamples + rightSamples + } + + private func makeAudio(count: Int) -> [Float] { + (0.. [Float] { + var values: [Float] = [] + array.withUnsafeBufferPointer(ofType: Float.self) { buffer in + values = Array(buffer) + } + return values + } + + private func makeMetadata() -> LSEENDMetadata { + LSEENDMetadata( + chunkSize: 4, + frameDurationSeconds: 0.1, + maxSpeakers: 4, + sampleRate: 16_000, + maxNspks: 5, + hopLength: 4, + winLength: 16, + nMels: 6, + contextSize: 7, + subsampling: 8, + convDelay: 1, + nUnits: 32, + nHeads: 4, + encNLayers: 2, + decNLayers: 2, + convKernelSize: 3 + ) + } +} diff --git a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift deleted file mode 100644 index bae07e8a9..000000000 --- a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift +++ /dev/null @@ -1,457 +0,0 @@ -import CoreML -import Foundation -import XCTest - -@testable import FluidAudio - -@MainActor -final class LSEENDIntegrationTests: XCTestCase { - private struct ErrorStats { - let maxAbs: Double - 
let meanAbs: Double - } - - private static var cachedEngines: [LSEENDVariant: LSEENDInferenceHelper] = [:] - - func testVariantRegistryResolvesAllExportedArtifacts() async throws { - let expectedColumns: [LSEENDVariant: Int] = [ - .ami: 4, - .callhome: 7, - .dihard2: 10, - .dihard3: 10, - ] - - for variant in LSEENDVariant.allCases { - let descriptor = try await LSEENDModelDescriptor.loadFromHuggingFace(variant: variant) - XCTAssertTrue(FileManager.default.fileExists(atPath: descriptor.modelURL.path)) - XCTAssertTrue(FileManager.default.fileExists(atPath: descriptor.metadataURL.path)) - - let engine = try await makeEngine(variant: variant) - XCTAssertEqual(engine.metadata.realOutputDim, expectedColumns[variant]) - XCTAssertEqual(engine.metadata.fullOutputDim, (expectedColumns[variant] ?? 0) + 2) - XCTAssertGreaterThan(engine.streamingLatencySeconds, 0) - XCTAssertGreaterThan(engine.modelFrameHz, 0) - } - } - - func testOfflineInferenceProducesConsistentShapesAcrossVariants() async throws { - for variant in LSEENDVariant.allCases { - let engine = try await makeEngine(variant: variant) - let samples = try DiarizationTestFixtures.fixtureAudio( - sampleRate: engine.targetSampleRate, limitSeconds: 2.0) - let result = try engine.infer(samples: samples, sampleRate: engine.targetSampleRate) - - try assertResultInvariants( - result, engine: engine, - expectedDurationSeconds: duration(of: samples, sampleRate: engine.targetSampleRate)) - assertMatrixClose(result.probabilities, result.logits.applyingSigmoid(), maxAbs: 1e-7, meanAbs: 1e-8) - assertMatrixClose( - result.fullProbabilities, result.fullLogits.applyingSigmoid(), maxAbs: 1e-7, meanAbs: 1e-8) - } - } - - func testAudioFileInferenceMatchesInferenceOnResampledFixtureSamples() async throws { - let engine = try await makeEngine(variant: .dihard3) - let fileResult = try engine.infer(audioFileURL: try DiarizationTestFixtures.fixtureAudioFileURL()) - let resampled = try DiarizationTestFixtures.fixtureAudio(sampleRate: 
engine.targetSampleRate) - let sampleResult = try engine.infer(samples: resampled, sampleRate: engine.targetSampleRate) - - assertMatrixClose(fileResult.logits, sampleResult.logits, maxAbs: 1e-6, meanAbs: 1e-7) - assertMatrixClose(fileResult.probabilities, sampleResult.probabilities, maxAbs: 1e-6, meanAbs: 1e-7) - assertMatrixClose(fileResult.fullLogits, sampleResult.fullLogits, maxAbs: 1e-6, meanAbs: 1e-7) - assertMatrixClose(fileResult.fullProbabilities, sampleResult.fullProbabilities, maxAbs: 1e-6, meanAbs: 1e-7) - } - - func testStreamingSessionMatchesOfflineInferenceOnRealFixtureAudio() async throws { - let engine = try await makeEngine(variant: .dihard3) - let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0) - let offline = try engine.infer(samples: samples, sampleRate: engine.targetSampleRate) - let session = try engine.createSession(inputSampleRate: engine.targetSampleRate) - - var totalEmitted = 0 - let chunkSizes = [617, 911, 1283, 743] - var sawUpdate = false - var start = 0 - var chunkIndex = 0 - while start < samples.count { - let chunkSize = chunkSizes[chunkIndex % chunkSizes.count] - let stop = min(samples.count, start + chunkSize) - if let update = try session.pushAudio(Array(samples[start.. 
LSEENDInferenceHelper { - if let cached = Self.cachedEngines[variant] { - return cached - } - let descriptor = try await LSEENDModelDescriptor.loadFromHuggingFace(variant: variant) - let engine = try LSEENDInferenceHelper(descriptor: descriptor, computeUnits: .cpuOnly) - Self.cachedEngines[variant] = engine - return engine - } - - private func duration(of samples: [Float], sampleRate: Int) -> Double { - Double(samples.count) / Double(sampleRate) - } - - private func assertResultInvariants( - _ result: LSEENDInferenceResult, - engine: LSEENDInferenceHelper, - expectedDurationSeconds: Double, - file: StaticString = #filePath, - line: UInt = #line - ) throws { - XCTAssertGreaterThan(result.logits.rows, 0, file: file, line: line) - XCTAssertEqual(result.logits.rows, result.probabilities.rows, file: file, line: line) - XCTAssertEqual(result.logits.rows, result.fullLogits.rows, file: file, line: line) - XCTAssertEqual(result.logits.columns, engine.metadata.realOutputDim, file: file, line: line) - XCTAssertEqual(result.probabilities.columns, engine.metadata.realOutputDim, file: file, line: line) - XCTAssertEqual(result.fullLogits.columns, engine.metadata.fullOutputDim, file: file, line: line) - XCTAssertEqual(result.fullProbabilities.columns, engine.metadata.fullOutputDim, file: file, line: line) - XCTAssertEqual(result.frameHz, engine.modelFrameHz, accuracy: 1e-9, file: file, line: line) - XCTAssertEqual(result.durationSeconds, expectedDurationSeconds, accuracy: 1e-6, file: file, line: line) - } - - private func assertMatrixClose( - _ actual: LSEENDMatrix, - _ expected: LSEENDMatrix, - maxAbs: Double, - meanAbs: Double, - file: StaticString = #filePath, - line: UInt = #line - ) { - XCTAssertEqual(actual.rows, expected.rows, file: file, line: line) - XCTAssertEqual(actual.columns, expected.columns, file: file, line: line) - XCTAssertEqual(actual.values.count, expected.values.count, file: file, line: line) - let stats = compare(actual.values, expected.values) - 
XCTAssertLessThanOrEqual(stats.maxAbs, maxAbs, file: file, line: line) - XCTAssertLessThanOrEqual(stats.meanAbs, meanAbs, file: file, line: line) - } - - private func assertArrayClose( - _ actual: [Float], - _ expected: [Float], - maxAbs: Double, - meanAbs: Double, - file: StaticString = #filePath, - line: UInt = #line - ) { - XCTAssertEqual(actual.count, expected.count, file: file, line: line) - let stats = compare(actual, expected) - XCTAssertLessThanOrEqual(stats.maxAbs, maxAbs, file: file, line: line) - XCTAssertLessThanOrEqual(stats.meanAbs, meanAbs, file: file, line: line) - } - - private func compare(_ actual: [Float], _ expected: [Float]) -> ErrorStats { - guard actual.count == expected.count else { - return ErrorStats(maxAbs: .infinity, meanAbs: .infinity) - } - var maxAbs = 0.0 - var sumAbs = 0.0 - for (lhs, rhs) in zip(actual, expected) { - let diff = abs(Double(lhs - rhs)) - maxAbs = max(maxAbs, diff) - sumAbs += diff - } - return ErrorStats( - maxAbs: maxAbs, - meanAbs: actual.isEmpty ? 
0 : sumAbs / Double(actual.count) - ) - } - -} diff --git a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDMatrixTests.swift b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDMatrixTests.swift deleted file mode 100644 index 1667413e8..000000000 --- a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDMatrixTests.swift +++ /dev/null @@ -1,326 +0,0 @@ -import XCTest - -@testable import FluidAudio - -final class LSEENDMatrixTests: XCTestCase { - - // MARK: - Init (validated) - - func testInitWithMatchingDimensions() throws { - let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6]) - XCTAssertEqual(m.rows, 2) - XCTAssertEqual(m.columns, 3) - XCTAssertEqual(m.values, [1, 2, 3, 4, 5, 6]) - } - - func testInitThrowsOnCountMismatch() { - XCTAssertThrowsError(try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3])) { error in - guard case LSEENDError.invalidMatrixShape = error else { - return XCTFail("Expected invalidMatrixShape, got \(error)") - } - } - } - - func testInitThrowsOnNegativeRows() { - XCTAssertThrowsError(try LSEENDMatrix(rows: -1, columns: 3, values: [])) { error in - guard case LSEENDError.invalidMatrixShape = error else { - return XCTFail("Expected invalidMatrixShape, got \(error)") - } - } - } - - func testInitThrowsOnNegativeColumns() { - XCTAssertThrowsError(try LSEENDMatrix(rows: 2, columns: -1, values: [])) { error in - guard case LSEENDError.invalidMatrixShape = error else { - return XCTFail("Expected invalidMatrixShape, got \(error)") - } - } - } - - func testInitWithZeroDimensions() throws { - let m = try LSEENDMatrix(rows: 0, columns: 5, values: []) - XCTAssertEqual(m.rows, 0) - XCTAssertEqual(m.columns, 5) - XCTAssertTrue(m.isEmpty) - } - - // MARK: - Factory Methods - - func testZeros() { - let m = LSEENDMatrix.zeros(rows: 3, columns: 2) - XCTAssertEqual(m.rows, 3) - XCTAssertEqual(m.columns, 2) - XCTAssertEqual(m.values, [Float](repeating: 0, count: 6)) - } - - func testEmpty() { - let m = LSEENDMatrix.empty(columns: 4) - 
XCTAssertEqual(m.rows, 0) - XCTAssertEqual(m.columns, 4) - XCTAssertTrue(m.isEmpty) - } - - // MARK: - isEmpty - - func testIsEmptyZeroRows() { - XCTAssertTrue(LSEENDMatrix.empty(columns: 3).isEmpty) - } - - func testIsEmptyZeroColumns() { - let m = LSEENDMatrix(validatingRows: 3, columns: 0, values: []) - XCTAssertTrue(m.isEmpty) - } - - func testIsEmptyFalseForPopulatedMatrix() throws { - let m = try LSEENDMatrix(rows: 1, columns: 1, values: [42]) - XCTAssertFalse(m.isEmpty) - } - - // MARK: - Subscript - - func testSubscriptGet() throws { - let m = try LSEENDMatrix(rows: 2, columns: 3, values: [10, 20, 30, 40, 50, 60]) - XCTAssertEqual(m[0, 0], 10) - XCTAssertEqual(m[0, 2], 30) - XCTAssertEqual(m[1, 0], 40) - XCTAssertEqual(m[1, 2], 60) - } - - func testSubscriptSet() throws { - var m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) - m[1, 0] = 99 - XCTAssertEqual(m[1, 0], 99) - XCTAssertEqual(m.values, [1, 2, 99, 4]) - } - - // MARK: - row() - - func testRow() throws { - let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) - XCTAssertEqual(Array(m.row(0)), [1, 2]) - XCTAssertEqual(Array(m.row(1)), [3, 4]) - XCTAssertEqual(Array(m.row(2)), [5, 6]) - } - - // MARK: - prefixingColumns - - func testPrefixingColumns() throws { - let m = try LSEENDMatrix(rows: 2, columns: 4, values: [1, 2, 3, 4, 5, 6, 7, 8]) - let prefix = m.prefixingColumns(2) - XCTAssertEqual(prefix.rows, 2) - XCTAssertEqual(prefix.columns, 2) - XCTAssertEqual(prefix.values, [1, 2, 5, 6]) - } - - func testPrefixingColumnsEqualToWidth() throws { - let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6]) - let same = m.prefixingColumns(3) - XCTAssertEqual(same, m) - } - - func testPrefixingColumnsGreaterThanWidth() throws { - let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6]) - let same = m.prefixingColumns(10) - XCTAssertEqual(same, m) - } - - func testPrefixingColumnsZero() throws { - let m = try LSEENDMatrix(rows: 2, 
columns: 3, values: [1, 2, 3, 4, 5, 6]) - let empty = m.prefixingColumns(0) - XCTAssertTrue(empty.isEmpty) - } - - func testPrefixingColumnsOnEmptyMatrix() { - let m = LSEENDMatrix.empty(columns: 4) - let result = m.prefixingColumns(2) - XCTAssertTrue(result.isEmpty) - XCTAssertEqual(result.columns, 2) - } - - // MARK: - rowMajorRows - - func testRowMajorRows() throws { - let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6]) - let rows = m.rowMajorRows() - XCTAssertEqual(rows, [[1, 2, 3], [4, 5, 6]]) - } - - func testRowMajorRowsEmpty() { - let m = LSEENDMatrix.empty(columns: 3) - XCTAssertEqual(m.rowMajorRows(), []) - } - - // MARK: - appendingRows - - func testAppendingRows() throws { - let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) - let b = try LSEENDMatrix(rows: 1, columns: 2, values: [5, 6]) - let result = a.appendingRows(b) - XCTAssertEqual(result.rows, 3) - XCTAssertEqual(result.columns, 2) - XCTAssertEqual(result.values, [1, 2, 3, 4, 5, 6]) - } - - func testAppendingRowsToEmpty() throws { - let a = LSEENDMatrix.empty(columns: 2) - let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) - XCTAssertEqual(a.appendingRows(b), b) - } - - func testAppendingEmptyRows() throws { - let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) - let b = LSEENDMatrix.empty(columns: 2) - XCTAssertEqual(a.appendingRows(b), a) - } - - // MARK: - droppingFirstRows - - func testDroppingFirstRows() throws { - let m = try LSEENDMatrix(rows: 4, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8]) - let dropped = m.droppingFirstRows(2) - XCTAssertEqual(dropped.rows, 2) - XCTAssertEqual(dropped.columns, 2) - XCTAssertEqual(dropped.values, [5, 6, 7, 8]) - } - - func testDroppingAllRows() throws { - let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) - let dropped = m.droppingFirstRows(3) - XCTAssertEqual(dropped.rows, 0) - XCTAssertTrue(dropped.isEmpty) - } - - func testDroppingMoreThanTotalRows() throws 
{ - let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) - let dropped = m.droppingFirstRows(100) - XCTAssertEqual(dropped.rows, 0) - XCTAssertTrue(dropped.isEmpty) - } - - func testDroppingZeroRows() throws { - let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) - let same = m.droppingFirstRows(0) - XCTAssertEqual(same, m) - } - - func testDroppingNegativeCount() throws { - let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) - let same = m.droppingFirstRows(-5) - XCTAssertEqual(same, m) - } - - // MARK: - slicingRows - - func testSlicingRows() throws { - let m = try LSEENDMatrix(rows: 5, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - let slice = m.slicingRows(start: 1, end: 4) - XCTAssertEqual(slice.rows, 3) - XCTAssertEqual(slice.values, [3, 4, 5, 6, 7, 8]) - } - - func testSlicingRowsFullRange() throws { - let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) - let slice = m.slicingRows(start: 0, end: 3) - XCTAssertEqual(slice, m) - } - - func testSlicingRowsEmptyRange() throws { - let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) - let slice = m.slicingRows(start: 2, end: 2) - XCTAssertTrue(slice.isEmpty) - XCTAssertEqual(slice.columns, 2) - } - - func testSlicingRowsClampsOutOfBounds() throws { - let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) - let slice = m.slicingRows(start: -5, end: 100) - XCTAssertEqual(slice, m) - } - - func testSlicingRowsInvertedRange() throws { - let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) - let slice = m.slicingRows(start: 3, end: 1) - XCTAssertTrue(slice.isEmpty) - } - - // MARK: - applyingSigmoid - - func testSigmoidZero() throws { - let m = try LSEENDMatrix(rows: 1, columns: 1, values: [0]) - let s = m.applyingSigmoid() - XCTAssertEqual(s[0, 0], 0.5, accuracy: 1e-6) - } - - func testSigmoidLargePositive() throws { - let m = try LSEENDMatrix(rows: 1, columns: 1, values: [20]) - 
let s = m.applyingSigmoid() - XCTAssertEqual(s[0, 0], 1.0, accuracy: 1e-5) - } - - func testSigmoidLargeNegative() throws { - let m = try LSEENDMatrix(rows: 1, columns: 1, values: [-20]) - let s = m.applyingSigmoid() - XCTAssertEqual(s[0, 0], 0.0, accuracy: 1e-5) - } - - func testSigmoidPreservesShape() throws { - let m = try LSEENDMatrix(rows: 3, columns: 4, values: [Float](repeating: 0, count: 12)) - let s = m.applyingSigmoid() - XCTAssertEqual(s.rows, 3) - XCTAssertEqual(s.columns, 4) - XCTAssertEqual(s.values.count, 12) - } - - func testSigmoidDoesNotMutateOriginal() throws { - let m = try LSEENDMatrix(rows: 1, columns: 2, values: [0, 0]) - _ = m.applyingSigmoid() - XCTAssertEqual(m.values, [0, 0]) - } - - func testSigmoidOnEmpty() { - let m = LSEENDMatrix.empty(columns: 3) - let s = m.applyingSigmoid() - XCTAssertTrue(s.isEmpty) - } - - // MARK: - Equatable - - func testEqualMatrices() throws { - let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) - let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) - XCTAssertEqual(a, b) - } - - func testUnequalValues() throws { - let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) - let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 5]) - XCTAssertNotEqual(a, b) - } - - // MARK: - Roundtrip: append then drop - - func testAppendThenDropRecoversOriginal() throws { - let original = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) - let extra = try LSEENDMatrix(rows: 2, columns: 2, values: [7, 8, 9, 10]) - let combined = original.appendingRows(extra) - let recovered = combined.slicingRows(start: 0, end: 3) - XCTAssertEqual(recovered, original) - } - - func testSliceThenAppendRecombines() throws { - let m = try LSEENDMatrix(rows: 4, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8]) - let head = m.slicingRows(start: 0, end: 2) - let tail = m.slicingRows(start: 2, end: 4) - let recombined = head.appendingRows(tail) - XCTAssertEqual(recombined, m) - } - - 
func testDropThenPrefixColumnsCommutes() throws { - let m = try LSEENDMatrix(rows: 4, columns: 4, values: (0..<16).map { Float($0) }) - - let dropFirst = m.droppingFirstRows(2).prefixingColumns(2) - let prefixFirst = m.prefixingColumns(2).droppingFirstRows(2) - - XCTAssertEqual(dropFirst, prefixFirst) - } -} diff --git a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDQueueTests.swift b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDQueueTests.swift new file mode 100644 index 000000000..bc7b052e5 --- /dev/null +++ b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDQueueTests.swift @@ -0,0 +1,49 @@ +import XCTest + +@testable import FluidAudio + +final class LSEENDQueueTests: XCTestCase { + + func testStreamingChunkQueueRequiresExactMinimumElementsForFirstChunk() { + var queue = StreamingChunkQueue( + chunkLength: 8, + leftContextLength: 3, + rightContextLength: 2, + stride: 1 + ) + + XCTAssertFalse(queue.hasChunk) + XCTAssertEqual(queue.readyChunks, 0) + + queue.append(repeatElement(1, count: 9)) + XCTAssertFalse(queue.hasChunk) + XCTAssertEqual(queue.readyChunks, 0) + + queue.append([1]) + XCTAssertTrue(queue.hasChunk) + XCTAssertEqual(queue.readyChunks, 1) + + let firstChunk = queue.popNextChunk() + XCTAssertEqual(firstChunk.map(Array.init), [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + XCTAssertEqual(queue.readyChunks, 0) + } + + func testPopAllChunksConsumesOnlyWholeChunksAndPreservesTrailingContext() { + var queue = StreamingChunkQueue( + chunkLength: 4, + leftContextLength: 2, + rightContextLength: 1, + stride: 1 + ) + + queue.append(Array(1...10).map(Float.init)) + + let combined = queue.popAllChunks() + XCTAssertEqual(combined.map(Array.init), [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + XCTAssertEqual(queue.readyChunks, 0) + + queue.append([11, 12, 13]) + let nextChunk = queue.popNextChunk() + XCTAssertEqual(nextChunk.map(Array.init), [7, 8, 9, 10, 11, 12, 13]) + } +} diff --git a/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift 
b/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift index eb4cc3ca2..13016b48d 100644 --- a/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift +++ b/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift @@ -8,22 +8,32 @@ import XCTest /// - `SortformerDiarizer.enrollSpeaker(withAudio:named:)` /// - `LSEENDDiarizer.enrollSpeaker(withSamples:named:)` final class SpeakerEnrollmentTests: XCTestCase { - nonisolated(unsafe) private static var cachedLseendEngine: LSEENDInferenceHelper? + nonisolated(unsafe) private static var cachedLseendModel: LSEENDModel? private func loadSortformerModelsForTest(config: SortformerConfig) async throws -> SortformerModels { // These tests validate Sortformer behavior after initialization, not accelerator selection. try await SortformerModels.loadFromHuggingFace(config: config, computeUnits: .cpuOnly) } - private func loadLseendEngineForTest(variant: LSEENDVariant = .dihard3) async throws -> LSEENDInferenceHelper { - if let cached = Self.cachedLseendEngine { + private func loadLseendModelForTest( + variant: LSEENDVariant = .ami, + stepSize: LSEENDStepSize = .step500ms + ) async throws -> LSEENDModel { + if let cached = Self.cachedLseendModel { return cached } - let descriptor = try await LSEENDModelDescriptor.loadFromHuggingFace(variant: variant) - let engine = try LSEENDInferenceHelper(descriptor: descriptor, computeUnits: .cpuOnly) - Self.cachedLseendEngine = engine - return engine + do { + let model = try await LSEENDModel.loadFromHuggingFace( + variant: variant, + stepSize: stepSize, + computeUnits: .cpuOnly + ) + Self.cachedLseendModel = model + return model + } catch { + throw XCTSkip("Unable to load LS-EEND test model: \(error)") + } } // MARK: - extractSpeakerEmbedding: Error Cases @@ -321,30 +331,38 @@ final class SpeakerEnrollmentTests: XCTestCase { // MARK: - LS-EEND enrollSpeaker: Error Cases func testLseendEnrollSpeakerThrowsWhenNotInitialized() { - let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly) 
+ let diarizer = DummyUnavailableLSEENDDiarizer() let audio = [Float](repeating: 0.1, count: 16000) - XCTAssertThrowsError(try diarizer.enrollSpeaker(withSamples: audio)) { error in - guard case LSEENDError.modelPredictionFailed(let message) = error else { - XCTFail("Expected modelPredictionFailed but got \(error)") + XCTAssertThrowsError( + try diarizer.enrollSpeaker( + withAudio: audio, + sourceSampleRate: nil, + named: nil, + overwritingAssignedSpeakerName: true + ) + ) { error in + guard case LSEENDError.notInitialized = error else { + XCTFail("Expected notInitialized but got \(error)") return } - XCTAssertTrue(message.contains("not initialized")) } } // MARK: - LS-EEND enrollSpeaker: Integration (requires model download) func testLseendEnrollSpeakerResetsTimelineAndWarmsSession() async throws { - XCTExpectFailure("Download might fail in CI environment", strict: false) - - let engine = try await loadLseendEngineForTest() - let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly) - diarizer.initialize(engine: engine) + let model = try await loadLseendModelForTest() + let diarizer = try LSEENDDiarizer(model: model) let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( - sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0) + sampleRate: diarizer.targetSampleRate ?? 
16_000, startSeconds: 0.0, durationSeconds: 3.0) - let speaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice") + let speaker = try diarizer.enrollSpeaker( + withAudio: enrollmentAudio, + sourceSampleRate: nil, + named: "Alice", + overwritingAssignedSpeakerName: true + ) if let speaker { XCTAssertEqual(speaker.name, "Alice") @@ -352,30 +370,32 @@ final class SpeakerEnrollmentTests: XCTestCase { XCTAssertEqual(diarizer.numFramesProcessed, 0) XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0) XCTAssertEqual(namedSpeakerIndices(in: diarizer.timeline), [speaker?.index].compactMap { $0 }) - XCTAssertTrue(diarizer.hasActiveSession) + XCTAssertTrue(diarizer.isAvailable) } func testLseendEnrollSpeakerFollowedByStreamingProcessingStartsAtFrameZero() async throws { - XCTExpectFailure("Download might fail in CI environment", strict: false) - - let engine = try await loadLseendEngineForTest() - let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly) - diarizer.initialize(engine: engine) + let model = try await loadLseendModelForTest() + let diarizer = try LSEENDDiarizer(model: model) let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( - sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0) + sampleRate: diarizer.targetSampleRate ?? 16_000, startSeconds: 0.0, durationSeconds: 3.0) let liveAudio = try DiarizationTestFixtures.fixtureAudio( - sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0) + sampleRate: diarizer.targetSampleRate ?? 16_000, startSeconds: 3.0, durationSeconds: 3.0) - let speaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice") + let speaker = try diarizer.enrollSpeaker( + withAudio: enrollmentAudio, + sourceSampleRate: nil, + named: "Alice", + overwritingAssignedSpeakerName: true + ) var firstUpdate: DiarizerTimelineUpdate? 
for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [977, 1231, 1607]) { - if let update = try diarizer.process(samples: chunk) { + if let update = try diarizer.process(samples: chunk, sourceSampleRate: nil) { firstUpdate = update break } } - let finalChunk = try diarizer.finalizeSession() + let finalChunk = try diarizer.finalize() XCTAssertTrue(firstUpdate != nil || finalChunk != nil) if let firstUpdate { @@ -388,40 +408,50 @@ final class SpeakerEnrollmentTests: XCTestCase { } func testLseendMultipleEnrollmentsRetainNamedSpeakersAndSession() async throws { - XCTExpectFailure("Download might fail in CI environment", strict: false) - - let engine = try await loadLseendEngineForTest() - let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly) - diarizer.initialize(engine: engine) + let model = try await loadLseendModelForTest() + let diarizer = try LSEENDDiarizer(model: model) let speakerAAudio = try DiarizationTestFixtures.fixtureAudio( - sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0) + sampleRate: diarizer.targetSampleRate ?? 16_000, startSeconds: 0.0, durationSeconds: 3.0) let speakerBAudio = try DiarizationTestFixtures.fixtureAudio( - sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0) + sampleRate: diarizer.targetSampleRate ?? 
16_000, startSeconds: 3.0, durationSeconds: 3.0) - let speakerA = try diarizer.enrollSpeaker(withSamples: speakerAAudio, named: "Alice") - let speakerB = try diarizer.enrollSpeaker(withSamples: speakerBAudio, named: "Bob") + let speakerA = try diarizer.enrollSpeaker( + withAudio: speakerAAudio, + sourceSampleRate: nil, + named: "Alice", + overwritingAssignedSpeakerName: true + ) + let speakerB = try diarizer.enrollSpeaker( + withAudio: speakerBAudio, + sourceSampleRate: nil, + named: "Bob", + overwritingAssignedSpeakerName: true + ) XCTAssertEqual(diarizer.numFramesProcessed, 0) XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0) - XCTAssertTrue(diarizer.hasActiveSession) + XCTAssertTrue(diarizer.isAvailable) let expectedNames = Set([speakerA?.name, speakerB?.name].compactMap { $0 }) XCTAssertEqual(Set(namedSpeakerNames(in: diarizer.timeline)), expectedNames) } func testLseendEnrollmentCanRefuseToOverwriteNamedSpeaker() async throws { - XCTExpectFailure("Download might fail in CI environment", strict: false) - - let engine = try await loadLseendEngineForTest() - let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly) - diarizer.initialize(engine: engine) + let model = try await loadLseendModelForTest() + let diarizer = try LSEENDDiarizer(model: model) let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( - sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0) + sampleRate: diarizer.targetSampleRate ?? 
16_000, startSeconds: 0.0, durationSeconds: 3.0) - let firstSpeaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice") + let firstSpeaker = try diarizer.enrollSpeaker( + withAudio: enrollmentAudio, + sourceSampleRate: nil, + named: "Alice", + overwritingAssignedSpeakerName: true + ) try XCTSkipIf( firstSpeaker == nil, "Fixture did not produce a confident LS-EEND speaker segment on this host.") let secondSpeaker = try diarizer.enrollSpeaker( - withSamples: enrollmentAudio, + withAudio: enrollmentAudio, + sourceSampleRate: nil, named: "Bob", overwritingAssignedSpeakerName: false ) @@ -444,3 +474,47 @@ final class SpeakerEnrollmentTests: XCTestCase { .sorted() } } + +private final class DummyUnavailableLSEENDDiarizer: Diarizer { + var isAvailable: Bool = false + var numFramesProcessed: Int = 0 + var targetSampleRate: Int? = nil + var modelFrameHz: Double? = nil + var numSpeakers: Int? = nil + var timeline = DiarizerTimeline(config: .default(numSpeakers: 1, frameDurationSeconds: 0.1)) + + func addAudio(_ samples: C, sourceSampleRate: Double?) throws where C: Collection, C.Element == Float {} + func process() throws -> DiarizerTimelineUpdate? { nil } + func process(samples: C, sourceSampleRate: Double?) throws -> DiarizerTimelineUpdate? + where + C: Collection, + C.Element == Float + { nil } + func processComplete( + _ samples: C, + sourceSampleRate: Double?, + keepingEnrolledSpeakers keepSpeakers: Bool?, + finalizeOnCompletion: Bool, + progressCallback: ((Int, Int, Int) -> Void)? + ) throws -> DiarizerTimeline where C: Collection, C.Element == Float { + throw LSEENDError.notInitialized + } + func processComplete( + audioFileURL: URL, + keepingEnrolledSpeakers keepSpeakers: Bool?, + finalizeOnCompletion: Bool, + progressCallback: ((Int, Int, Int) -> Void)? 
+ ) throws -> DiarizerTimeline { + throw LSEENDError.notInitialized + } + func reset() {} + func cleanup() {} + func enrollSpeaker( + withAudio samples: C, + sourceSampleRate: Double?, + named name: String?, + overwritingAssignedSpeakerName overwriteAssignedSpeakerName: Bool + ) throws -> DiarizerSpeaker? where C: Collection, C.Element == Float { + throw LSEENDError.notInitialized + } +}