HuBERT support rollout #1111
Open
david-wei-01001 wants to merge 89 commits into TransformerLensOrg:dev from david-wei-01001:hubert
Commits (89):
f7e76b6  Create hubert_block.py (david-wei-01001)
926c1c4  Delete transformer_lens/components/hubert_block.py (david-wei-01001)
95d50bc  Create HookedAudioEncoder.py (david-wei-01001)
9a29295  Update HookedAudioEncoder.py (david-wei-01001)
48c6efe  Update HookedAudioEncoder.py (david-wei-01001)
1b7559c  Update HookedAudioEncoder.py (david-wei-01001)
94fa33e  Update HookedAudioEncoder.py (david-wei-01001)
6e93a5b  Update loading_from_pretrained.py (david-wei-01001)
4edde8d  Update HookedAudioEncoder.py (david-wei-01001)
a5ef321  Update loading_from_pretrained.py (david-wei-01001)
cd930f3  Update loading_from_pretrained.py (david-wei-01001)
4621730  Update loading_from_pretrained.py (david-wei-01001)
548e693  Create hubert.py (david-wei-01001)
5dc88a1  Update HookedAudioEncoder.py (david-wei-01001)
8282805  Update HookedAudioEncoder.py (david-wei-01001)
7f0c373  Create hubert_test.py (david-wei-01001)
e8bbf84  Update hubert_test.py (david-wei-01001)
86ac1d9  Update HookedAudioEncoder.py (david-wei-01001)
8f1b889  Create hubert_ctc_test.py (david-wei-01001)
afc2a35  Update HookedAudioEncoder.py (david-wei-01001)
f94fa40  Create hubert_hook_test.py (david-wei-01001)
cff50b3  Update hubert_hook_test.py (david-wei-01001)
764810a  done (david-wei-01001)
7e844a3  done (david-wei-01001)
9a6bc7a  done (david-wei-01001)
1ddbf7f  done (david-wei-01001)
c646ee5  done (david-wei-01001)
7d5fe2a  Rename hubert_ctc_test.py to demos/HuBERT_test/hubert_ctc_test.py (david-wei-01001)
21a0256  Rename hubert_hook_test.py to demos/HuBERT_test /hubert_hook_test.py (david-wei-01001)
c9f7c68  Rename hubert_hook_test.py to hubert_hook_test.py (david-wei-01001)
2f578ce  Rename hubert_test.py to demos/HuBERT_test/hubert_test.py (david-wei-01001)
f76c2ee  Update HookedAudioEncoder.py (david-wei-01001)
69345b1  Update HookedAudioEncoder.py (david-wei-01001)
7be3d4e  Update hubert.py (david-wei-01001)
7e177c4  Update hubert_ctc_test.py (david-wei-01001)
6737ccd  Update hubert_hook_test.py (david-wei-01001)
e062f38  Update hubert_hook_test.py (david-wei-01001)
340260f  Update HookedAudioEncoder.py (david-wei-01001)
3c44076  Update loading_from_pretrained.py (david-wei-01001)
64aeb4c  Update HookedAudioEncoder.py (david-wei-01001)
71a4f51  Update hubert.py (david-wei-01001)
f0207ca  Update hubert_ctc_test.py (david-wei-01001)
98f6eac  Update hubert_hook_test.py (david-wei-01001)
da84180  Update hubert_hook_test.py (david-wei-01001)
ede04f8  Update hubert_test.py (david-wei-01001)
305509a  Update loading_from_pretrained.py (david-wei-01001)
6461e2e  Update hubert.py (david-wei-01001)
5344612  Update HookedAudioEncoder.py (david-wei-01001)
32db5d2  Update HookedAudioEncoder.py (david-wei-01001)
560ffb9  Update hubert_hook_test.py (david-wei-01001)
dda10e5  Update HookedAudioEncoder.py (david-wei-01001)
219defb  Update hubert_hook_test.py (david-wei-01001)
2df2d27  Update HookedAudioEncoder.py (david-wei-01001)
46c3344  Update HookedAudioEncoder.py (david-wei-01001)
6272b9f  Update loading_from_pretrained.py (david-wei-01001)
af6163d  Update loading_from_pretrained.py (david-wei-01001)
0b5a860  Update HookedAudioEncoder.py (david-wei-01001)
817c97f  Update HookedAudioEncoder.py (david-wei-01001)
48920e1  Update hubert_hook_test.py (david-wei-01001)
6dcffb2  Update hubert_hook_test.py (david-wei-01001)
5f7af85  Update HookedAudioEncoder.py (david-wei-01001)
fefcea2  Update HookedAudioEncoder.py (david-wei-01001)
b1414e0  Update HookedAudioEncoder.py (david-wei-01001)
94bd3d7  Update HookedAudioEncoder.py (david-wei-01001)
fbae9c1  Update HookedAudioEncoder.py (david-wei-01001)
f23d0d9  Update HookedAudioEncoder.py (david-wei-01001)
d20ee07  Update HookedAudioEncoder.py (david-wei-01001)
00c12cb  Update HookedAudioEncoder.py (david-wei-01001)
14ab5bb  Update HookedAudioEncoder.py (david-wei-01001)
41402ba  Update loading_from_pretrained.py (david-wei-01001)
b5cb2e1  Update loading_from_pretrained.py (david-wei-01001)
f8200bc  Update loading_from_pretrained.py (david-wei-01001)
6926e2b  Update loading_from_pretrained.py (david-wei-01001)
e8e958c  Update loading_from_pretrained.py (david-wei-01001)
fa89321  Update requirements.txt (david-wei-01001)
5a7c5c7  Update requirements.txt (david-wei-01001)
cd8e922  Update loading_from_pretrained.py (david-wei-01001)
77285ba  Update loading_from_pretrained.py (david-wei-01001)
9fa6464  Update loading_from_pretrained.py (david-wei-01001)
fc9327e  Update HookedAudioEncoder.py (david-wei-01001)
c6a43a7  Update bert_pooler.py (david-wei-01001)
9427068  Update bert_pooler.py (david-wei-01001)
e72aee1  Merge remote-tracking branch 'origin/dev' into hubert (jlarson4)
e0a649d  Apply isort and black formatting to HuBERT files (jlarson4)
fc6ec7e  Merge branch 'dev' into hubert (jlarson4)
e6e0a95  Delete requirements.txt (david-wei-01001)
0cd6b3a  Update bert_pooler.py (david-wei-01001)
e54c098  Update HookedAudioEncoder.py (david-wei-01001)
22f972e  Update HookedAudioEncoder.py (david-wei-01001)
demos/HuBERT_test/hubert_ctc_test.py (new file, 257 lines)

```python
# test_hubert_ctc_lmhead.py
"""
Test script to verify that HookedAudioEncoder.forward(..., use_ctc=True)
loads/uses an lm_head and produces CTC logits.

Usage:
    python test_hubert_ctc_lmhead.py
Change the import to point at your HookedAudioEncoder implementation.
"""

import math

import numpy as np
import torch

from transformer_lens import HookedAudioEncoder

# ----- CONFIG -----
SAMPLE_RATE = 16000
DURATION_S = 1.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1
# To attempt optional decoding with a HF tokenizer, set TOKENIZER_NAME to a
# valid tokenizer (e.g. "facebook/wav2vec2-base-960h"), or set it to None to
# skip tokenizer decoding.
TOKENIZER_NAME = "facebook/hubert-large-ls960-ft"
# ------------------


def make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S, amplitude=0.1):
    """Generate a mono sine-wave test signal as a float32 numpy array."""
    t = np.linspace(0, duration, int(sr * duration), endpoint=False, dtype=np.float32)
    return amplitude * np.sin(2 * math.pi * frequency * t)


def has_lm_head(model):
    """Return True if the model exposes an lm_head child module or attribute."""
    return any(
        name.endswith("lm_head") or name == "lm_head" for name, _ in model.named_children()
    ) or hasattr(model, "lm_head")


def try_get_lm_head(model):
    """Return the model's lm_head module if one can be found, else None."""
    if hasattr(model, "lm_head"):
        return model.lm_head
    # try common nested names
    for name, module in model.named_modules():
        if name.endswith("lm_head") or name == "lm_head":
            return module
    return None


def print_param_info(module, prefix=""):
    """Print a module's type, parameter count, and weight shape (if any)."""
    if module is None:
        print(prefix + "None")
        return
    params = list(module.parameters())
    print(prefix + f"module type: {type(module)}, #params: {sum(p.numel() for p in params)}")
    # print weight shape if available
    if hasattr(module, "weight"):
        try:
            print(prefix + f"  weight.shape = {tuple(module.weight.shape)}")
        except Exception:
            pass


if __name__ == "__main__":
    model = HookedAudioEncoder.from_pretrained("facebook/hubert-large-ls960-ft")
    model.to(DEVICE)

    # sample waveform
    wav = make_sine(frequency=440.0)
    x = torch.from_numpy(wav).unsqueeze(0).to(DEVICE)  # shape (1, T)

    print("=== lm_head presence BEFORE forward() ===")
    print("has_lm_head():", has_lm_head(model))
    print("try_get_lm_head():")
    print_param_info(try_get_lm_head(model), prefix="  ")

    # Forward pass with use_ctc=True (some model APIs accept it directly, some do not).
    print(
        "\nCalling forward(..., use_ctc=True) -- if that fails, will set the attribute and call without the arg"
    )
    logits = None
    forward_exc = None
    try:
        # try a direct call with the argument
        out = model(x, use_ctc=True)
    except TypeError as e:
        # forward signature may not accept a use_ctc param; set the attribute on the model and call
        forward_exc = e
        print(
            "Direct forward(..., use_ctc=True) failed with TypeError - trying model.use_ctc = True and forward(x)."
        )
        try:
            # set (or create) the flag on the model, then call forward without the arg
            model.use_ctc = True
            out = model(x)
        except Exception as e2:
            print("Forward still failed after setting model.use_ctc:", e2)
            raise

    # Normalize out to a logits tensor if possible
    def extract_logits(out):
        if out is None:
            return None
        if isinstance(out, torch.Tensor):
            return out  # assume logits
        # dict-like outputs: look for common keys
        if isinstance(out, dict):
            for key in ("logits", "ctc_logits", "predictions", "hidden_states"):
                if key in out:
                    t = out[key]
                    # hidden_states of shape (batch, seq, dim) is also fine to inspect
                    if isinstance(t, torch.Tensor):
                        return t
            # if no known keys were found, pick the first tensor value
            for v in out.values():
                if isinstance(v, torch.Tensor):
                    return v
        # fallback: nothing usable
        return None

    logits = extract_logits(out)
    print("\n=== Post-forward lm_head presence ===")
    print("has_lm_head():", has_lm_head(model))
    lm = try_get_lm_head(model)
    print("try_get_lm_head():")
    print_param_info(lm, prefix="  ")

    if logits is None:
        print(
            "\nCould not automatically extract logits from the model output. The model returned:",
            type(out),
        )
        # if out is tensor-like but not a torch tensor, attempt conversion
        if hasattr(out, "numpy"):
            try:
                logits = torch.from_numpy(out.numpy()).to(DEVICE)
            except Exception:
                pass

    if logits is not None:
        print("\n=== Logits / CTC output info ===")
        print("logits type:", type(logits))
        print("logits shape:", tuple(logits.shape))
        # typical CTC logits shape: (batch, time, vocab_size)
        try:
            print(
                "stats: min=%.6g max=%.6g mean=%.6g"
                % (logits.min().item(), logits.max().item(), logits.mean().item())
            )
        except Exception:
            pass
        assert torch.isfinite(logits).all(), "Found NaNs/Infs in logits!"

        # simple decode: argmax over the last dim -> token ids
        if logits.ndim >= 2:
            token_dim = -1
            token_ids = logits.argmax(dim=token_dim)  # shape: (batch, time)
            token_ids_cpu = token_ids.detach().cpu().numpy()
            print("Sample argmax token ids (first batch, up to first 40 frames):")
            print(token_ids_cpu[0][:40].tolist())

            # Optional: try to decode token ids to text if a tokenizer is available
            if TOKENIZER_NAME is not None:
                try:
                    from transformers import AutoTokenizer

                    tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
                    # For many CTC tokenizers you must collapse repeats and remove the
                    # blank token; here we naively assume the blank is tok.pad_token_id.
                    blank_id = getattr(tok, "pad_token_id", None)
                    seq = token_ids_cpu[0].tolist()
                    # collapse repeats and remove blanks
                    collapsed = []
                    prev = None
                    for t in seq:
                        if t == prev:
                            continue
                        prev = t
                        if blank_id is not None and t == blank_id:
                            continue
                        collapsed.append(t)
                    decoded = tok.decode(collapsed, skip_special_tokens=True)
                    print("Decoded (naive collapse) text:", decoded)
                except Exception as e:
                    print("Optional decoding failed:", e)
    else:
        print("No logits found -- cannot run CTC-specific checks.")

    # Gradient test specifically for the transformer encoder (since lm_head is frozen)
    print("\nRunning gradient propagation test through transformer encoder...")

    model.train()
    for p in model.parameters():
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()

    try:
        out2 = model(x, use_ctc=True)
    except TypeError:
        model.use_ctc = True
        out2 = model(x)

    logits2 = extract_logits(out2)
    if logits2 is None:
        print("Could not extract logits for gradient test; aborting gradient check.")
    else:
        loss = logits2.mean()
        loss.backward()

        # --- Check that lm_head is frozen ---
        lm = try_get_lm_head(model)
        if lm is not None:
            lm_params = list(lm.parameters())
            grads = [p.grad for p in lm_params if p.grad is not None]
            if len(grads) > 0:
                print("Warning: lm_head has gradients, but it should be frozen (eval mode).")
            else:
                print("✅ lm_head correctly frozen (no gradients).")

        # --- Check that transformer block parameters have gradients ---
        has_transformer_grad = False
        for name, p in model.named_parameters():
            if "transformer" in name or "encoder" in name or "block" in name:
                print(name)
                if p.grad is not None and torch.isfinite(p.grad).all():
                    has_transformer_grad = True
                    break

        if has_transformer_grad:
            print("✅ Gradient test PASSED: transformer block parameters have finite gradients.")
        else:
            print("❌ Gradient test FAILED: no gradients found in transformer blocks.")

    print("\n=== DONE ===")
    print("Interpretation notes:")
    print(
        "  - If lm_head appears AFTER calling forward(use_ctc=True) and the logits shape looks like (B, T, V),"
    )
    print(
        "    then your forward path is constructing/attaching an lm_head and producing CTC logits."
    )
    print(
        "  - If lm_head parameters have finite gradients after loss.backward(), the head is hooked into the graph."
    )
    print(
        "  - If you want a numeric golden-check, instantiate a HF Hubert/Wav2Vec2 CTC model and compare pooled logits/ids (optional)."
    )
    # list parameter names (printing the raw generator object is not informative)
    print([name for name, _ in model.named_parameters()])
```
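The final interpretation note suggests an optional numeric golden-check against a reference implementation. A minimal sketch of what that could look like, using the real `transformers.HubertForCTC` API; the assumption (hedged, per the demo above) is that the hooked model returns plain `(batch, frames, vocab)` CTC logits when called with `use_ctc=True`, and frame counts only line up if the conv feature extractors match, so treat the agreement score as a heuristic rather than a strict equality test:

```python
# Hypothetical golden-check: compare greedy CTC ids from the hooked model
# against the reference transformers implementation on the same waveform.
import math

import numpy as np
import torch
from transformers import HubertForCTC

from transformer_lens import HookedAudioEncoder

# one second of a 440 Hz sine at 16 kHz, shape (1, T)
sr = 16000
t = np.linspace(0, 1.0, sr, endpoint=False, dtype=np.float32)
x = torch.from_numpy(0.1 * np.sin(2 * math.pi * 440.0 * t)).unsqueeze(0)

hooked = HookedAudioEncoder.from_pretrained("facebook/hubert-large-ls960-ft")
reference = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").eval()

with torch.no_grad():
    hooked_ids = hooked(x, use_ctc=True).argmax(dim=-1)   # assumes tensor output
    reference_ids = reference(x).logits.argmax(dim=-1)    # (batch, frames, vocab)

# Expect near-total frame-level agreement if the hooked forward pass
# reproduces the reference model; shapes must match for the comparison.
if hooked_ids.shape == reference_ids.shape:
    print(f"argmax agreement: {(hooked_ids == reference_ids).float().mean():.1%}")
else:
    print("frame counts differ:", hooked_ids.shape, reference_ids.shape)
```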
Review comment:
In general, we build all our demos as Jupyter notebooks, which can be run in an IDE or in Google Colab, for ease of use. Please reformat these demos/HuBERT_test/ files into a notebook, and create a set of tests for your new functionality in tests/.
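For the tests/ half of that request, a minimal pytest sketch of the shape such tests could take. The file path, the `use_ctc=True` kwarg, and a plain tensor return value are assumptions carried over from the demo script, and `run_with_cache` is assumed to be inherited from HookedRootModule as in other TransformerLens models; names should be adjusted to the final API:

```python
# tests/acceptance/test_hooked_audio_encoder.py (hypothetical location)
import math

import numpy as np
import pytest
import torch

from transformer_lens import HookedAudioEncoder

MODEL_NAME = "facebook/hubert-large-ls960-ft"


@pytest.fixture(scope="module")
def model():
    # load once per test module; this download is large
    return HookedAudioEncoder.from_pretrained(MODEL_NAME)


def sine_wave(sr=16000, duration=1.0, freq=440.0):
    """Deterministic (1, T) float32 test waveform."""
    t = np.linspace(0, duration, int(sr * duration), endpoint=False, dtype=np.float32)
    return torch.from_numpy(0.1 * np.sin(2 * math.pi * freq * t)).unsqueeze(0)


def test_ctc_logits_shape_and_finiteness(model):
    # CTC logits should be (batch, frames, vocab) with no NaNs/Infs.
    logits = model(sine_wave(), use_ctc=True)
    assert logits.ndim == 3
    assert logits.shape[0] == 1
    assert torch.isfinite(logits).all()


def test_run_with_cache_populates_hooks(model):
    # run_with_cache should return activations for at least one hook point.
    _, cache = model.run_with_cache(sine_wave())
    assert len(cache) > 0
```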