Description
Profiling the cold start of the Canary ASR model (canary-1b-v2) reveals significant redundant I/O and initialization overhead. Total execution time is 55.5s, with 52.4s spent on Model.from_pretrained and only 3.1s on inference.
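The split above can also be checked without a tracer by timing the two phases separately. A minimal sketch, assuming the same model and `transcribe()` arguments as the full repro script below (the audio path here is just a placeholder):

```python
# Minimal sketch: time model load and inference separately (placeholder audio path).
import time

from nemo.collections.asr.models import EncDecMultiTaskModel

t0 = time.perf_counter()
model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b-v2', map_location='cpu')
t1 = time.perf_counter()
model.transcribe(audio=["sample.wav"], batch_size=1, source_lang='en', target_lang='en')
t2 = time.perf_counter()
print(f"from_pretrained: {t1 - t0:.1f}s, transcribe: {t2 - t1:.1f}s")
```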
Repro
```python
# To run this script:
# 1. Have 'uv' installed
# 2. Install dependencies manually (I had to, to bypass project-level resolution issues):
#    uv pip install torch viztracer "nemo-toolkit[asr]"
# 3. Run the script without project context:
#    uv run --no-project trace_canary.py
import os
import time

import torch
from nemo.collections.asr.models import EncDecMultiTaskModel
from viztracer import VizTracer


def main():
    ENABLE_TRACER = True
    if ENABLE_TRACER:
        tracer = VizTracer(output_file="canary_trace.json", log_torch=True, min_duration=100)
        tracer.start()

    start_time = time.perf_counter()

    # Check if GPU is available
    map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {map_location}")

    # Load model
    # We use canary-1b-v2 as it is the latest and most comprehensive model
    try:
        print("Loading canary-1b-v2 model...")
        canary_model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b-v2', map_location=map_location)
        print("Model loaded.")
    except Exception as e:
        print(f"Failed to load canary-1b-v2. Error: {e}")
        return

    audio_path = "tutorials/tts/audio_samples/phonemes_as_input.wav"
    if not os.path.exists(audio_path):
        print(f"Audio file {audio_path} not found.")
        return

    # Transcribe
    print("Starting transcription...")
    # The return value of transcribe is a list of objects (Hypothesis or similar) that have a .text attribute
    transcript = canary_model.transcribe(
        audio=[audio_path],
        batch_size=1,
        source_lang='en',
        target_lang='en',
    )

    end_time = time.perf_counter()
    print(f"Total time: {end_time - start_time:.2f} seconds")

    # Access the text of the first transcript
    if transcript:
        print(f"Transcript: {transcript[0].text}")
    else:
        print("No transcript generated.")

    if ENABLE_TRACER:
        tracer.stop()
        tracer.save()
        print("Trace saved to canary_trace.json")


if __name__ == "__main__":
    main()
```

Observations:
Interactive flame chart: https://trace.cafe/t/g5uMmDWHwI
- Redundant Archive Extraction (~22s overhead)
  The `.nemo` tarball is extracted to a temporary directory twice during a single load operation (see the call-counting sketch after this list):
  - Root Load: `SaveRestoreConnector._unpack_nemo_file` (10.5s).
  - Nested Load: inside `EncDecMultiTaskModel.__restore_timestamps_asr_model`, a second `restore_from` call triggers `_unpack_nemo_file` again (11.5s).
- Heavy Recursive Initialization (~34s overhead)
  `__restore_timestamps_asr_model` consumes 61% of total execution time.
  - It invokes a full `ModelPT.restore_from` cycle for the internal timestamp model.
  - Beyond the I/O redundancy, this triggers a heavy initialization chain (`EncDecCTCModelBPE.__init__`), incurring significant overhead from hydra/OmegaConf object creation (~15s) inside the nested scope. It looks like several very expensive `ListConfig` objects are repeatedly invalidated by interleaved calls to `maybe_update_config_version`.
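To double-check the duplicate extraction outside the trace, one can count calls to `_unpack_nemo_file` during a single load. A rough sketch, assuming the import path `nemo.core.connectors.save_restore_connector` and that `_unpack_nemo_file` is a staticmethod (both from my reading of the current NeMo source):

```python
# Rough sketch: count how often the .nemo archive is unpacked during one from_pretrained call.
# Assumes _unpack_nemo_file is a staticmethod on SaveRestoreConnector (current NeMo source);
# if it is a regular method instead, drop the staticmethod() wrapper below.
from nemo.collections.asr.models import EncDecMultiTaskModel
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector

unpack_calls = 0
_original_unpack = SaveRestoreConnector._unpack_nemo_file

def _counting_unpack(*args, **kwargs):
    global unpack_calls
    unpack_calls += 1
    return _original_unpack(*args, **kwargs)

SaveRestoreConnector._unpack_nemo_file = staticmethod(_counting_unpack)

EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b-v2', map_location='cpu')
print(f"_unpack_nemo_file called {unpack_calls} time(s)")  # the trace suggests this prints 2
```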
Thoughts for investigation:
(I'm a web performance guy, so I'm entirely out of my depth here. Apologies…)
- Artifact Sharing: Can the nested timestamp model utilize the artifacts already unpacked by the parent Canary model to eliminate the second 11.5s extraction cost? (A rough sketch of the idea follows this list.)
- Lazy/Light Loading: The overhead for initializing the timestamp helper (`EncDecCTCModelBPE`) seems disproportionately high compared to the parent model. Is a lighter initialization path possible for this sub-module?
- FYI: I did the same experiment with parakeet (`parakeet-tdt-0.6b-v3`). Trace here: https://trace.cafe/t/iLn7mU0Hx9. It also has ~20s of OmegaConf thrashing, but no duplicate tar extractions; that problem looks like it's specific to `EncDecMultiTaskModel`.
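On the artifact-sharing point: `SaveRestoreConnector` appears to expose a `model_extracted_dir` attribute that, when set, makes `restore_from` read from an already-extracted directory instead of unpacking the tarball again (my reading of the source; treat it as an assumption). If that holds, the nested timestamp-model restore could perhaps be pointed at the directory the parent load has already unpacked, roughly:

```python
# Hypothetical sketch of the artifact-sharing idea, not how NeMo currently wires it up.
# Assumes SaveRestoreConnector.model_extracted_dir exists and short-circuits the unpack step.
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector

def restore_nested_model(model_cls, nemo_path, extracted_dir, map_location):
    # extracted_dir: the temp directory the parent Canary load has already unpacked.
    connector = SaveRestoreConnector()
    connector.model_extracted_dir = extracted_dir  # assumption: setter exists
    return model_cls.restore_from(
        restore_path=nemo_path,
        save_restore_connector=connector,
        map_location=map_location,
    )
```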
Environment details
- OS Version: macOS 15.7.3 (Build 24G419)
- PyTorch Version: 2.9.1
- Python Version: 3.11.11
- CPU/GPU: Apple M1 Max
- Memory: 64 GB RAM
- MPS Support: Available and Built (though AFAIK unused here)