Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 77 additions & 48 deletions meeteval/der/md_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,49 @@ def restore(self, filename):
raise ValueError(f'Cannot find {filename} as value in {self.cache}')


def _parse_md_eval_22_output(output: str) -> (DiaErrorRate, dict[str, DiaErrorRate]):
"""
Parses all output blocks from the md_eval_22 output. Each block has the format:

*** Performance analysis for Speaker Diarization for f=utt_9 ***

SCORED SPEAKER TIME =10.160000 secs
MISSED SPEAKER TIME =0.960000 secs
FALARM SPEAKER TIME =0.000000 secs
SPEAKER ERROR TIME =0.180000 secs
OVERALL SPEAKER DIARIZATION ERROR = 11.22 percent of scored speaker time `(f=utt_9)

Each block is parsed into a `DiaErrorRate`. Returns a `DiaErrorRate` object
for the overall error rate (named "ALL" in md-eval-22) and a dict of
`DiaErrorRate` objects for the individual files (f=???).
"""
# Pattern for each performance block
block_pattern = re.compile(
r"\*\*\* Performance analysis for Speaker Diarization for (?P<file>[^ ]+) \*\*\*"
r".*?SCORED SPEAKER TIME\s*=\s*(?P<scored>[\d.]+)"
r".*?MISSED SPEAKER TIME\s*=\s*(?P<missed>[\d.]+)"
r".*?FALARM SPEAKER TIME\s*=\s*(?P<falarm>[\d.]+)"
r".*?SPEAKER ERROR TIME\s*=\s*(?P<serror>[\d.]+)"
r".*?OVERALL SPEAKER DIARIZATION ERROR\s*=\s*(?P<error_rate>[\d.]+)",
re.DOTALL
)

results = {}

for match in block_pattern.finditer(output):
file_name = match.group("file")
results[file_name] = DiaErrorRate(
error_rate=decimal.Decimal(match.group("error_rate")) / 100,
scored_speaker_time=decimal.Decimal(match.group("scored")),
missed_speaker_time=decimal.Decimal(match.group("missed")),
falarm_speaker_time=decimal.Decimal(match.group("falarm")),
speaker_error_time=decimal.Decimal(match.group("serror")),
)

summary = results.pop('ALL')
results = {k[2:]: v for k, v in results.items() if k.startswith('f=')}

return summary, results

def md_eval_22_multifile(
reference, hypothesis, collar=0, regions='all',
Expand Down Expand Up @@ -247,13 +290,13 @@ def md_eval_22_multifile(
urllib.request.urlretrieve(url, md_eval_22)
logging.info(f'Wrote {md_eval_22}')

warned = False

def get_details(r, h, key, tmpdir, uem):
nonlocal warned
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)

r_file = tmpdir / f'{key}.ref.rttm'
h_file = tmpdir / f'{key}.hyp.rttm'
r = meeteval.io.RTTM([line for key in keys for line in r[key]])
h = meeteval.io.RTTM([line for key in keys for line in h[key]])
r_file = tmpdir / f'ref.rttm'
h_file = tmpdir / f'hyp.rttm'
r.dump(r_file)
h.dump(h_file)

Expand All @@ -262,69 +305,55 @@ def get_details(r, h, key, tmpdir, uem):
'-c', f'{collar}',
'-r', f'{r_file}',
'-s', f'{h_file}',
'-a', 'f', # Per-file details
]

if regions == 'nooverlap':
cmd.append('-1')

if uem:
uem_file = tmpdir / f'{key}.uem'
uem_file = tmpdir / f'uem.uem'
uem = escaper.escape_uem(uem)
uem.dump(uem_file)
cmd.extend(['-u', f'{uem_file}'])
elif not warned:
else:
warned = True
logging.warning(f'No UEM file provided. See https://github.com/fgnt/meeteval/issues/97#issuecomment-2508140402 for details.')
cp = subprocess.run(cmd, stdout=subprocess.PIPE,
check=True, universal_newlines=True)

# SCORED SPEAKER TIME =4309.340250 secs
# MISSED SPEAKER TIME =4309.340250 secs
# FALARM SPEAKER TIME =0.000000 secs
# SPEAKER ERROR TIME =0.000000 secs
# OVERALL SPEAKER DIARIZATION ERROR = 100.00 percent of scored speaker time `(ALL)
md_eval, per_reco = _parse_md_eval_22_output(cp.stdout)

error_rate, = re.findall(r'OVERALL SPEAKER DIARIZATION ERROR = ([\d.]+) percent of scored speaker time',
cp.stdout)
length, = re.findall(r'SCORED SPEAKER TIME =([\d.]+) secs', cp.stdout)
deletions, = re.findall(r'MISSED SPEAKER TIME =([\d.]+) secs', cp.stdout)
insertions, = re.findall(r'FALARM SPEAKER TIME =([\d.]+) secs', cp.stdout)
substitutions, = re.findall(r'SPEAKER ERROR TIME =([\d.]+) secs', cp.stdout)
assert per_reco.keys() == keys, (per_reco.keys(), keys)

def convert(string):
return decimal.Decimal(string)

return DiaErrorRate(
scored_speaker_time=convert(length),
missed_speaker_time=convert(deletions),
falarm_speaker_time=convert(insertions),
speaker_error_time=convert(substitutions),
error_rate=convert(error_rate) / 100,
)

per_reco = {}
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
for key in keys:
per_reco[escaper.restore(key)] = get_details(r[key], h[key], key, tmpdir, uem)

md_eval = get_details(
meeteval.io.RTTM([line for key in keys for line in r[key]]),
meeteval.io.RTTM([line for key in keys for line in h[key]]),
'',
tmpdir,
uem,
)
summary = sum(per_reco.values())
error_rate = summary.error_rate.quantize(md_eval.error_rate)
if error_rate != md_eval.error_rate:

# Due to floating point precision, the output of md-eval-22.pl is not
# always reproduced exactly by average across the per-recording numbers.
# We'll raise an error if the difference is large and print a warning
# when it only differs slightly.
if abs(summary.error_rate - md_eval.error_rate) > 0.00007:
raise RuntimeError(
f'The error rate of md-eval-22.pl on all recordings '
f'({summary.error_rate})\n'
f'does not match the average error rate of md-eval-22.pl '
f'applied to each recording ({md_eval.error_rate}).'
f'({md_eval.error_rate})\n'
f'differs from the the averaged error rate across '
f'all sessions ({summary.error_rate}) by more than 0.00007 '
f'({abs(summary.error_rate - md_eval.error_rate)}.'
)

quantized_error_rate = summary.error_rate.quantize(
md_eval.error_rate, rounding='ROUND_HALF_UP'
)
if quantized_error_rate != md_eval.error_rate:
logging.warning(
f'The error rate of md-eval-22.pl on all recordings '
f'({md_eval.error_rate}) does not match the averaged error '
f'rate across all sessions ({quantized_error_rate}). This can '
f'happen due to floating point inaccuracies.'
)

per_reco = {escaper.restore(k): v for k, v in per_reco.items()}

return per_reco


Expand Down