
Performance: EncDecMultiTaskModel (Canary) initialization triggers double .nemo extraction and recursive heavy reload (~52s cold start) #15240

@paulirish

Description


Profiling the cold start of the Canary ASR model (canary-1b-v2) reveals significant redundant I/O and initialization overhead. Total execution time is 55.5s, with 52.4s spent on Model.from_pretrained and only 3.1s on inference.

Repro

# To run this script:
# 1. Have 'uv' installed
# 2. Install dependencies manually (I had to, to bypass project-level resolution issues):
#    uv pip install torch viztracer "nemo-toolkit[asr]"
# 3. Run the script without project context:
#    uv run --no-project trace_canary.py

import os
import time
import torch
from nemo.collections.asr.models import EncDecMultiTaskModel
from viztracer import VizTracer

def main():
    ENABLE_TRACER = True

    if ENABLE_TRACER:
        tracer = VizTracer(output_file="canary_trace.json", log_torch=True, min_duration=100)
        tracer.start()

    start_time = time.perf_counter()

    # Check if GPU is available
    map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {map_location}")

    # Load model
    # We use canary-1b-v2 as it is the latest and most comprehensive model
    try:
        print("Loading canary-1b-v2 model...")
        canary_model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b-v2', map_location=map_location)
        print("Model loaded.")
    except Exception as e:
        print(f"Failed to load canary-1b-v2. Error: {e}")
        return

    audio_path = "tutorials/tts/audio_samples/phonemes_as_input.wav"

    if not os.path.exists(audio_path):
        print(f"Audio file {audio_path} not found.")
        return

    # Transcribe
    print("Starting transcription...")
    # The return value of transcribe is a list of objects (Hypothesis or similar) that have a .text attribute
    transcript = canary_model.transcribe(
        audio=[audio_path],
        batch_size=1,
        source_lang='en',
        target_lang='en',
    )

    end_time = time.perf_counter()
    print(f"Total time: {end_time - start_time:.2f} seconds")

    # Access the text of the first transcript
    if transcript:
        print(f"Transcript: {transcript[0].text}")
    else:
        print("No transcript generated.")

    if ENABLE_TRACER:
        tracer.stop()
        tracer.save()
        print("Trace saved to canary_trace.json")

if __name__ == "__main__":
    main()

Observations:

Interactive flame chart: https://trace.cafe/t/g5uMmDWHwI

  1. Redundant Archive Extraction (~22s overhead)
    The .nemo tarball is extracted to a temporary directory twice during a single load operation:
  • Root Load: SaveRestoreConnector._unpack_nemo_file (10.5s).
  • Nested Load: Inside EncDecMultiTaskModel.__restore_timestamps_asr_model, a second restore_from call triggers _unpack_nemo_file again (11.5s). A quick verification sketch follows below.
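One way to confirm the double extraction without reading the trace is to count calls to SaveRestoreConnector._unpack_nemo_file directly. A minimal sketch, assuming _unpack_nemo_file is still a staticmethod on the connector (as in current NeMo source):

import functools
from nemo.collections.asr.models import EncDecMultiTaskModel
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector

unpack_calls = []
_orig_unpack = SaveRestoreConnector._unpack_nemo_file  # the plain function, if staticmethod

@functools.wraps(_orig_unpack)
def counting_unpack(*args, **kwargs):
    unpack_calls.append(args)  # record every tarball extraction
    return _orig_unpack(*args, **kwargs)

# Re-wrap as staticmethod so internal self._unpack_nemo_file(...) call sites keep working.
SaveRestoreConnector._unpack_nemo_file = staticmethod(counting_unpack)

EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b-v2', map_location='cpu')
print(f"_unpack_nemo_file calls: {len(unpack_calls)}")  # the trace suggests this prints 2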
  2. Heavy Recursive Initialization (~34s overhead)
    __restore_timestamps_asr_model consumes 61% of total execution time.
  • It invokes a full ModelPT.restore_from cycle for the internal timestamp model.
  • Beyond the I/O redundancy, this triggers a heavy initialization chain (EncDecCTCModelBPE.__init__), incurring significant overhead from Hydra / OmegaConf object creation (~15s) inside the nested scope. It looks like several very expensive ListConfig objects are invalidated and rebuilt by repeated, interleaved calls to maybe_update_config_version (a rough micro-benchmark sketch follows).
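To get a feel for how expensive repeated migration passes are on a large config, here is a rough micro-benchmark sketch. The synthetic config shape is invented, and the nemo.utils.model_utils import path for maybe_update_config_version is assumed from current NeMo source:

import time
from omegaconf import OmegaConf
from nemo.utils.model_utils import maybe_update_config_version

# Synthetic config with many nested list entries, loosely mimicking a model config.
cfg = OmegaConf.create({"model": {"layers": [{"dim": i, "name": f"l{i}"} for i in range(2000)]}})

t0 = time.perf_counter()
for _ in range(10):  # the trace suggests the migration runs repeatedly for nested sub-configs
    cfg = maybe_update_config_version(cfg)
print(f"10 migration passes: {time.perf_counter() - t0:.2f}s")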

Thoughts for investigation:

(I'm a web performance guy, so I'm entirely out of my depth here. Apologies…)

  • Artifact Sharing: Can the nested timestamp model reuse the artifacts already unpacked by the parent Canary model, eliminating the second 11.5s extraction cost? (See the sketch after this list.)
  • Lazy/Light Loading: The overhead for initializing the timestamp helper (EncDecCTCModelBPE) seems disproportionately high compared to the parent model. Is a lighter initialization path possible for this sub-module?
  • FYI: I ran the same experiment with Parakeet (parakeet-tdt-0.6b-v3). Trace here: https://trace.cafe/t/iLn7mU0Hx9. It also shows ~20s of OmegaConf thrashing, but no duplicate tar extraction, so that problem looks specific to EncDecMultiTaskModel.
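On the artifact-sharing question: NeMo's SaveRestoreConnector exposes a model_extracted_dir attribute that makes restore_from reuse an already-unpacked directory instead of extracting the tarball again (the documented pattern for pre-extracted checkpoints). Whether the nested timestamp-model restore inside EncDecMultiTaskModel could be routed through such a connector is the open question; the sketch below only shows the mechanism on a single model, with a hypothetical extraction path:

import os
from nemo.collections.asr.models import EncDecMultiTaskModel
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector

extracted_dir = "/tmp/canary_extracted"  # hypothetical: .nemo contents already unpacked here

connector = SaveRestoreConnector()
if os.path.isdir(extracted_dir):
    connector.model_extracted_dir = extracted_dir  # restore_from then skips _unpack_nemo_file

model = EncDecMultiTaskModel.restore_from(
    restore_path=extracted_dir,
    save_restore_connector=connector,
    map_location='cpu',
)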

Environment details

  • OS Version: macOS 15.7.3 (Build 24G419)
  • PyTorch Version: 2.9.1
  • Python Version: 3.11.11
  • CPU/GPU: Apple M1 Max
  • Memory: 64 GB RAM
  • MPS Support: Available and Built (though AFAIK unused here)
