diff --git a/tests/unit/Hubert_audio_head_ablation_test.ipynb b/tests/unit/Hubert_audio_head_ablation_test.ipynb new file mode 100644 index 000000000..c9eff9f0c --- /dev/null +++ b/tests/unit/Hubert_audio_head_ablation_test.ipynb @@ -0,0 +1,1888 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5wXO1Udg5FeI", + "metadata": { + "id": "5wXO1Udg5FeI" + }, + "source": [ + "# Audio Head Ablation Test\n", + "\n", + "This notebook tests a simple audio forward pass, cache inspection, and a head ablation hook using the forked TransformerLens repo." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "KnoBcUvV5FeJ", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KnoBcUvV5FeJ", + "outputId": "2d65dcb7-d56d-4f4c-a92a-939b146ab5f7" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/1.8 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.7/1.8 MB\u001b[0m \u001b[31m22.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m24.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/1.0 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m81.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m27.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into 
account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "ipython 7.34.0 requires jedi>=0.16, which is not installed.\u001b[0m\u001b[31m\n", + "\u001b[0mCloning into 'TransformerLens'...\n", + "remote: Enumerating objects: 5328, done.\u001b[K\n", + "remote: Counting objects: 100% (190/190), done.\u001b[K\n", + "remote: Compressing objects: 100% (123/123), done.\u001b[K\n", + "remote: Total 5328 (delta 126), reused 78 (delta 67), pack-reused 5138 (from 2)\u001b[K\n", + "Receiving objects: 100% (5328/5328), 25.07 MiB | 19.58 MiB/s, done.\n", + "Resolving deltas: 100% (3580/3580), done.\n", + "/content/TransformerLens\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Checking if build backend supports build_editable ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build editable ... \u001b[?25l\u001b[?25hdone\n", + " Preparing editable metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Building editable for transformer-lens (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for transformers-stream-generator (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "shap 0.51.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n", + "jaxlib 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "jax 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "cupy-cuda12x 14.0.1 requires numpy<2.6,>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "rasterio 1.5.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n", + "plum-dispatch 2.7.1 requires beartype>=0.16.2, but you have beartype 0.14.1 which is incompatible.\n", + "xarray-einstats 0.10.0 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "opencv-python 4.13.0.92 requires numpy>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", + "pytensor 2.38.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "opencv-python-headless 4.13.0.92 requires numpy>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", + "opencv-contrib-python 4.13.0.92 requires numpy>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", + "tobler 0.13.0 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip -q install -U pip setuptools wheel\n", + "!git clone https://github.com/david-wei-01001/TransformerLens.git\n", + "%cd TransformerLens\n", + "!pip -q install -e .\n", + "print(\"\\n⚠️ IMPORTANT: Restart runtime now, then run the next cell.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "WcdfVuX05FeJ", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "WcdfVuX05FeJ", + "outputId": "e28304f4-e562-4a23-abee-bd19ea4b4237" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Device: cuda\n" + ] + } + ], + "source": [ + "import 
math\n", + "from typing import Optional, Tuple\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "import transformer_lens.utils as utils\n", + "from transformer_lens import HookedAudioEncoder\n", + "\n", + "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "SAMPLE_RATE = 16000\n", + "DURATION_S = 1.0\n", + "torch.set_grad_enabled(False)\n", + "print('Device:', DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "cBeHX9ze5FeK", + "metadata": { + "id": "cBeHX9ze5FeK" + }, + "outputs": [], + "source": [ + "def make_sine(\n", + " sr: int = SAMPLE_RATE,\n", + " duration: float = DURATION_S,\n", + " freq: float = 440.0,\n", + " amp: float = 0.1,\n", + ") -> np.ndarray:\n", + " t = np.linspace(0, duration, int(sr * duration), endpoint=False, dtype=np.float32)\n", + " return amp * np.sin(2 * math.pi * freq * t)\n", + "\n", + "\n", + "def get_output_repr(\n", + " model: HookedAudioEncoder,\n", + " frames: torch.Tensor,\n", + " frame_mask: Optional[torch.Tensor],\n", + " hooks=None,\n", + ") -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict]]:\n", + " \"\"\"Return a pooled representation plus optional logits/cache.\"\"\"\n", + " if hooks is None:\n", + " try:\n", + " out = model(frames, one_zero_attention_mask=frame_mask)\n", + " except TypeError:\n", + " out = model(frames)\n", + " else:\n", + " try:\n", + " out = model.run_with_hooks(frames, fwd_hooks=hooks, one_zero_attention_mask=frame_mask)\n", + " except TypeError:\n", + " out = model.run_with_hooks(frames, fwd_hooks=hooks)\n", + "\n", + " if isinstance(out, torch.Tensor):\n", + " return out.mean(dim=1) if out.ndim >= 2 else out.unsqueeze(0), out, None\n", + "\n", + " if isinstance(out, dict):\n", + " for key in ('logits', 'ctc_logits', 'predictions'):\n", + " if key in out and isinstance(out[key], torch.Tensor):\n", + " logits = out[key]\n", + " pooled = logits.mean(dim=1) if logits.ndim == 3 else logits\n", + " return pooled, logits, None\n", + "\n", + 
" raise RuntimeError(f'Unknown output type from model: {type(out)}')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "z__ODEfG5FeK", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 379, + "referenced_widgets": [ + "e57f174475d44c469d23d5b5b5d0c55e", + "8eaedcbace42414cbd3c14bf7ccece15", + "7965d943b8ab4df3ab784bc20683f5bc", + "b3ed809df81b4de7a2150abb7c51c460", + "ff70c5fe173d4f3ab02004211b6cb4e7", + "f055ddd334234fec8624390c3cd9eba9", + "329e30cc33a049ec8cac39749a4eea61", + "12a605fbd56f4ec994335bc45e1b50cb", + "7e2cd6f588f94b7087e1f03f2f53945f", + "141576adb9bd44d787c90a74628702d0", + "cb9fd856e5c145c7afaf2c7df13fc8c4", + "f65752526ea74c6a949e867a3e537c67", + "04207c4747704febb2471d993b17dba4", + "7332b3fcb4964482972779508a4fc8bf", + "24d779f9cd2c45cf9ae4e703f3b63ad7", + "d237b086aceb4798a4a338c2046dd95b", + "ad379c5dcf1741f3aaf45cb9ed64fa63", + "95cac2792a9d42598b010df67e1be087", + "bf14d8339de54391afa6f424f7286b74", + "524b7f3c8ac6466383cec1cfb3dc5eb5", + "3a179e9033a747b8ac6af03834216ed7", + "818c616cd75d4988aba0fad91d212811", + "627425e2f0b141ae8d4a669d594f39b1", + "1eb2f8c5068042018c2e872e6c1239ed", + "e621984ad12945ed8263589baefae8f4", + "f1f51620485345868724ad29c8cf48a1", + "cd412ef3db1b4f6b9e33e853de517af4", + "702801acb15a41c0b916a2e9ac37825a", + "177a3663f1334373ad3d97e62c290d76", + "9a95864dd1824d2f93d5b7d95b79962d", + "7b73808d2e654d089e188b4ea5423528", + "7324a9eb88084b56aca9ddbb1e784fbe", + "f2da718b147e45c592cb1cdaf1e72b9c", + "49bb43e9d60c4e70a1408366e58a6e00", + "936e0a15a86a4dd0a64f7a48ba0ee892", + "5153752e2b574cc09713440bdb5a2576", + "72ba70c1757347f2b6b9dc4b6442dcd7", + "34717283782e4c67937a83367250f00a", + "8d634d15848c47f08f5590664c50f0e3", + "3b6608c167894b1996fe8fd558e18087", + "185aa53f588144aab071d36b86008067", + "df94015640794724a51fc3f4e1854cdb", + "b2d5b15a02d74cf490371d523da34960", + "d8b7b6d18a0d46628f11038820457320" + ] + }, + "id": "z__ODEfG5FeK", + "outputId": 
"207f3a54-4591-4ed4-cba6-c5399d8780af" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Support for HuBERT in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n", + "If using HuBERT for interpretability research, keep in mind that HuBERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.\n", + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e57f174475d44c469d23d5b5b5d0c55e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0.00B [00:00, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`torch_dtype` is deprecated! 
Use `dtype` instead!\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f65752526ea74c6a949e867a3e537c67", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "pytorch_model.bin: 0%| | 0.00/378M [00:00=0.16, which is not installed.\u001b[0m\u001b[31m\n", + "\u001b[0mCloning into 'TransformerLens'...\n", + "remote: Enumerating objects: 5328, done.\u001b[K\n", + "remote: Counting objects: 100% (190/190), done.\u001b[K\n", + "remote: Compressing objects: 100% (123/123), done.\u001b[K\n", + "remote: Total 5328 (delta 126), reused 78 (delta 67), pack-reused 5138 (from 2)\u001b[K\n", + "Receiving objects: 100% (5328/5328), 25.07 MiB | 6.86 MiB/s, done.\n", + "Resolving deltas: 100% (3580/3580), done.\n", + "/content/TransformerLens\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Checking if build backend supports build_editable ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build editable ... \u001b[?25l\u001b[?25hdone\n", + " Preparing editable metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Building editable for transformer-lens (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for transformers-stream-generator (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "shap 0.51.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n", + "jaxlib 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "jax 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "cupy-cuda12x 14.0.1 requires numpy<2.6,>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "rasterio 1.5.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n", + "plum-dispatch 2.7.1 requires beartype>=0.16.2, but you have beartype 0.14.1 which is incompatible.\n", + "xarray-einstats 0.10.0 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "opencv-python 4.13.0.92 requires numpy>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", + "pytensor 2.38.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "opencv-python-headless 4.13.0.92 requires numpy>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", + "opencv-contrib-python 4.13.0.92 requires numpy>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", + "tobler 0.13.0 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "# Install dependencies and the forked repo\n", + "!pip -q install -U pip setuptools wheel\n", + "!git clone https://github.com/david-wei-01001/TransformerLens.git\n", + "%cd TransformerLens\n", + "!pip -q install -e .\n", + "print(\"\\n⚠️ IMPORTANT: Restart runtime now, then run the next cell.\")" + ], + "id": "BrXhdfsO79g-" + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Xr6qP-Rd79g-", + "outputId": "250160a9-f75b-479a-c32e-d64559a8bf8a" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + 
"Device: cuda\n" + ] + } + ], + "source": [ + "import math\n", + "from typing import Any\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from transformer_lens import HookedAudioEncoder\n", + "\n", + "SAMPLE_RATE = 16000\n", + "DURATION_S = 1.0\n", + "BATCH_SIZE = 1\n", + "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "HF_CHECKPOINT = 'facebook/hubert-base-ls960'\n", + "\n", + "torch.set_grad_enabled(True)\n", + "print('Device:', DEVICE)" + ], + "id": "Xr6qP-Rd79g-" + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "YJxuD5Sw79g_" + }, + "outputs": [], + "source": [ + "def make_sine(\n", + " frequency: float = 440.0,\n", + " sr: int = SAMPLE_RATE,\n", + " duration: float = DURATION_S,\n", + " amplitude: float = 0.1,\n", + ") -> np.ndarray:\n", + " t = np.linspace(0, duration, int(sr * duration), endpoint=False, dtype=np.float32)\n", + " wav = amplitude * np.sin(2 * math.pi * frequency * t)\n", + " return wav\n", + "\n", + "\n", + "def extract_tensor(output: Any) -> torch.Tensor:\n", + " if isinstance(output, torch.Tensor):\n", + " return output\n", + " if isinstance(output, dict):\n", + " for key in ('predictions', 'logits', 'ctc_logits'):\n", + " if key in output and isinstance(output[key], torch.Tensor):\n", + " return output[key]\n", + " raise TypeError(f'Could not extract tensor from output type {type(output)}')" + ], + "id": "YJxuD5Sw79g_" + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "1F8xAakm79g_" + }, + "outputs": [], + "source": [ + "def run_basic_sanity_tests(model, waveform_np):\n", + " \"\"\"Run quick checks: forward pass, shape, finite, deterministic, grad flow.\"\"\"\n", + " model.to(DEVICE)\n", + "\n", + " x = torch.from_numpy(waveform_np).unsqueeze(0).to(DEVICE)\n", + "\n", + " # 1) Eval forward: no grad\n", + " model.eval()\n", + " with torch.no_grad():\n", + " out1 = model(x)\n", + " print('Forward (eval) output type:', type(out1))\n", + " out_tensor 
= extract_tensor(out1)\n", + "\n", + " print('Output shape:', tuple(out_tensor.shape))\n", + " print(\n", + " 'Output stats: min=%.6g max=%.6g mean=%.6g'\n", + " % (out_tensor.min().item(), out_tensor.max().item(), out_tensor.mean().item())\n", + " )\n", + " assert torch.isfinite(out_tensor).all(), 'Found NaNs or Infs in forward output!'\n", + "\n", + " # 2) Determinism in eval\n", + " with torch.no_grad():\n", + " out2 = model(x)\n", + " out2_tensor = extract_tensor(out2)\n", + " if not torch.allclose(out_tensor, out2_tensor, atol=1e-6):\n", + " print(\n", + " 'Warning: outputs differ between two eval runs (non-deterministic?), max diff:',\n", + " (out_tensor - out2_tensor).abs().max().item(),\n", + " )\n", + " else:\n", + " print('Determinism test passed (eval mode).')\n", + "\n", + " # 3) Gradient flow test in train mode\n", + " model.train()\n", + " for p in model.parameters():\n", + " if p.grad is not None:\n", + " p.grad.detach_()\n", + " p.grad.zero_()\n", + " out_train = model(x)\n", + " out_train_tensor = extract_tensor(out_train)\n", + " loss = out_train_tensor.mean()\n", + " loss.backward()\n", + "\n", + " grads_found = any(\n", + " (p.grad is not None and torch.isfinite(p.grad).all())\n", + " for p in model.parameters()\n", + " if p.requires_grad\n", + " )\n", + " assert grads_found, 'No finite gradients found on any parameter after backward()'\n", + " print('Gradient check passed: some parameters have finite gradients.')" + ], + "id": "1F8xAakm79g_" + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "VBckm3gn79hA" + }, + "outputs": [], + "source": [ + "def optional_compare_to_hf(your_model, waveform_np, sr: int = SAMPLE_RATE):\n", + " \"\"\"\n", + " OPTIONAL: compare your_model outputs to Hugging Face's HubertModel outputs.\n", + " Requires transformers and internet access.\n", + " \"\"\"\n", + " try:\n", + " from transformers import HubertModel, Wav2Vec2FeatureExtractor\n", + " except Exception as e:\n", + " 
print('Transformers or feature extractor not available:', e)\n", + " return\n", + "\n", + " print('Loading Hugging Face HubertModel for optional comparison...')\n", + " hf_feat = Wav2Vec2FeatureExtractor(sampling_rate=sr, do_normalize=True)\n", + " hf_model = HubertModel.from_pretrained(HF_CHECKPOINT).to(DEVICE).eval()\n", + "\n", + " input_values = hf_feat(waveform_np, sampling_rate=sr, return_tensors='pt').get('input_values')\n", + " input_values = input_values.to(DEVICE)\n", + "\n", + " with torch.no_grad():\n", + " hf_outputs = hf_model(input_values).last_hidden_state\n", + " hf_embedding = hf_outputs.mean(dim=1)\n", + "\n", + " your_model.eval()\n", + " with torch.no_grad():\n", + " your_out = your_model(torch.from_numpy(waveform_np).unsqueeze(0).to(DEVICE))\n", + " your_tensor = extract_tensor(your_out)\n", + "\n", + " if your_tensor.ndim == 3:\n", + " your_emb = your_tensor.mean(dim=1)\n", + " else:\n", + " your_emb = your_tensor\n", + "\n", + " if hf_embedding.shape[1] != your_emb.shape[1]:\n", + " print(\n", + " f'Dimension mismatch (HF {hf_embedding.shape[1]} vs your {your_emb.shape[1]}). 
'\n", + " 'Compare after projecting to a common dimension if needed.'\n", + " )\n", + " return\n", + "\n", + " cos = torch.nn.functional.cosine_similarity(hf_embedding, your_emb, dim=1)\n", + " print('Cosine similarity between HF pooled embedding and your model:', cos.cpu().numpy())" + ], + "id": "VBckm3gn79hA" + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 486, + "referenced_widgets": [ + "2d580bdb43e6427ba1d694fdf0b8551f", + "4742020e227243a088e9632bce7c79f5", + "ca1f06b6b6a541808fd7260f0207d7b4", + "d84c90c0041c4f7f91507409428b1b62", + "53520cb5a77147c8917c235cff5e0f71", + "b4b7b0e59f8d4c8096eda22c2a61a53c", + "fb2dd39ea1d14f9690c18bea155803df", + "70788fbe7d8b4e3ba9b787797d83a4f3", + "c2702a14bf354a58af26e1728bdf5bac", + "bdaa150960dd41d5b7f85d0b5de78f4c", + "1b4185f72f5d4806b435e9cb704b4917", + "484a01ea330642f9add68ce0ada0f9e7", + "fcc8c03313264b7283535684d30ad1b2", + "846e06d913424109a42822543f4b63a3", + "eace9841b56f4c448c44ef5cc3188a06", + "32f7f0749718436b829eab1fdbe6a331", + "10d5065800a74c439181dc76f02597cb", + "b54d10ce3d73462ca6f330c928c5feb7", + "dd584a60722c43bdad38854225a5ea4f", + "5b52c86a251d4733ab0647d65b808da8", + "240fb903e892414d8a783d2e5393149e", + "9903a8278bca42e5a01274519fd89698", + "e56d6f447fe946a8a6892022d992a5d7", + "4e2d94f0c7a44f5d8f69833de35688f7", + "18cf277137ce45d28f88e70445bdca24", + "d64a0f3f02c2486192058e98b7c2727a", + "c11d75c9b274428abccb6acf8236f1fe", + "ec281a2c615d40e2a60b21797e2887dd", + "5b9eeb862506442c8e4ff411c5d2ea5f", + "d419a6ae1d1246648d91f0ddda453433", + "9adadc0a37294ae58be10d7cc5fa0306", + "bf1ded4a734a42a68e159e7a2ea65a17", + "04b77427038a4e20bc0c4db3439e2e52", + "db7d151862104aec8fa492229335b0f8", + "991434e9d20d4fa6bd74e140a8436791", + "de7825b2bbc4450fbe71b6db3ca72116", + "adbf883f7ebd45eea2c7600a52fc96f6", + "48806f023af74a5289cc920c4da237ec", + "40dfe4a8237f47d7baa80dbcc576cc48", + 
"86d76c0566344a669b0c3f45b044878d", + "90559e1d7fb243e3ae48ea9807ea5b59", + "278477013af841c682618f0d5aa03e18", + "9502b18116cc458883724c789c0deb42", + "223c58d60dfe463b86fdb4d4faacbca7" + ] + }, + "id": "YRG_br4B79hA", + "outputId": "f07bfc9d-053e-4040-dee0-c7f1289aaeeb" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:root:Support for HuBERT in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n", + "If using HuBERT for interpretability research, keep in mind that HuBERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.\n", + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "config.json: 0.00B [00:00, ?B/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "2d580bdb43e6427ba1d694fdf0b8551f" + } + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "`torch_dtype` is deprecated! 
Use `dtype` instead!\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "pytorch_model.bin: 0%| | 0.00/378M [00:00\n", + "Output shape: (1, 49, 768)\n", + "Output stats: min=-2.88404 max=2.90499 mean=-0.00641759\n", + "Determinism test passed (eval mode).\n", + "Gradient check passed: some parameters have finite gradients.\n" + ] + } + ], + "source": [ + "# Create sample waveform\n", + "wav = make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S)\n", + "\n", + "# Instantiate your model\n", + "model = HookedAudioEncoder.from_pretrained('facebook/hubert-base-ls960').to(DEVICE)\n", + "\n", + "# Run tests\n", + "run_basic_sanity_tests(model, wav)" + ], + "id": "YRG_br4B79hA" + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LdpVdGYz79hB", + "outputId": "5402f971-781b-4634-e258-f164719880b5" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Loading Hugging Face HubertModel for optional comparison...\n", + "Cosine similarity between HF pooled embedding and your model: [0.99999994]\n" + ] + } + ], + "source": [ + "# Optional comparison to Hugging Face (requires transformers + internet)\n", + "optional_compare_to_hf(model, wav, sr=SAMPLE_RATE)" + ], + "id": "LdpVdGYz79hB" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "2d580bdb43e6427ba1d694fdf0b8551f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_4742020e227243a088e9632bce7c79f5", + "IPY_MODEL_ca1f06b6b6a541808fd7260f0207d7b4", + "IPY_MODEL_d84c90c0041c4f7f91507409428b1b62" + ], + "layout": "IPY_MODEL_53520cb5a77147c8917c235cff5e0f71" + } + }, + "4742020e227243a088e9632bce7c79f5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b4b7b0e59f8d4c8096eda22c2a61a53c", + "placeholder": "​", + "style": "IPY_MODEL_fb2dd39ea1d14f9690c18bea155803df", + "value": "config.json: " + } + }, + "ca1f06b6b6a541808fd7260f0207d7b4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_70788fbe7d8b4e3ba9b787797d83a4f3", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c2702a14bf354a58af26e1728bdf5bac", + "value": 1 + } + }, + "d84c90c0041c4f7f91507409428b1b62": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + 
"_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bdaa150960dd41d5b7f85d0b5de78f4c", + "placeholder": "​", + "style": "IPY_MODEL_1b4185f72f5d4806b435e9cb704b4917", + "value": " 1.39k/? [00:00<00:00, 114kB/s]" + } + }, + "53520cb5a77147c8917c235cff5e0f71": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b4b7b0e59f8d4c8096eda22c2a61a53c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fb2dd39ea1d14f9690c18bea155803df": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "70788fbe7d8b4e3ba9b787797d83a4f3": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + 
"grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "20px" + } + }, + "c2702a14bf354a58af26e1728bdf5bac": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "bdaa150960dd41d5b7f85d0b5de78f4c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": 
null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1b4185f72f5d4806b435e9cb704b4917": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "484a01ea330642f9add68ce0ada0f9e7": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_fcc8c03313264b7283535684d30ad1b2", + "IPY_MODEL_846e06d913424109a42822543f4b63a3", + "IPY_MODEL_eace9841b56f4c448c44ef5cc3188a06" + ], + "layout": "IPY_MODEL_32f7f0749718436b829eab1fdbe6a331" + } + }, + "fcc8c03313264b7283535684d30ad1b2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": 
"IPY_MODEL_10d5065800a74c439181dc76f02597cb", + "placeholder": "​", + "style": "IPY_MODEL_b54d10ce3d73462ca6f330c928c5feb7", + "value": "pytorch_model.bin: 100%" + } + }, + "846e06d913424109a42822543f4b63a3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_dd584a60722c43bdad38854225a5ea4f", + "max": 377569754, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_5b52c86a251d4733ab0647d65b808da8", + "value": 377569754 + } + }, + "eace9841b56f4c448c44ef5cc3188a06": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_240fb903e892414d8a783d2e5393149e", + "placeholder": "​", + "style": "IPY_MODEL_9903a8278bca42e5a01274519fd89698", + "value": " 378M/378M [00:04<00:00, 138MB/s]" + } + }, + "32f7f0749718436b829eab1fdbe6a331": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + 
"align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "10d5065800a74c439181dc76f02597cb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": 
null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b54d10ce3d73462ca6f330c928c5feb7": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "dd584a60722c43bdad38854225a5ea4f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5b52c86a251d4733ab0647d65b808da8": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "240fb903e892414d8a783d2e5393149e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9903a8278bca42e5a01274519fd89698": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + 
}, + "e56d6f447fe946a8a6892022d992a5d7": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_4e2d94f0c7a44f5d8f69833de35688f7", + "IPY_MODEL_18cf277137ce45d28f88e70445bdca24", + "IPY_MODEL_d64a0f3f02c2486192058e98b7c2727a" + ], + "layout": "IPY_MODEL_c11d75c9b274428abccb6acf8236f1fe" + } + }, + "4e2d94f0c7a44f5d8f69833de35688f7": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ec281a2c615d40e2a60b21797e2887dd", + "placeholder": "​", + "style": "IPY_MODEL_5b9eeb862506442c8e4ff411c5d2ea5f", + "value": "model.safetensors: 100%" + } + }, + "18cf277137ce45d28f88e70445bdca24": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d419a6ae1d1246648d91f0ddda453433", + "max": 377510580, + "min": 0, + "orientation": 
"horizontal", + "style": "IPY_MODEL_9adadc0a37294ae58be10d7cc5fa0306", + "value": 377510580 + } + }, + "d64a0f3f02c2486192058e98b7c2727a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bf1ded4a734a42a68e159e7a2ea65a17", + "placeholder": "​", + "style": "IPY_MODEL_04b77427038a4e20bc0c4db3439e2e52", + "value": " 378M/378M [00:02<00:00, 236MB/s]" + } + }, + "c11d75c9b274428abccb6acf8236f1fe": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + 
"width": null + } + }, + "ec281a2c615d40e2a60b21797e2887dd": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5b9eeb862506442c8e4ff411c5d2ea5f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "d419a6ae1d1246648d91f0ddda453433": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": 
"LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9adadc0a37294ae58be10d7cc5fa0306": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "bf1ded4a734a42a68e159e7a2ea65a17": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + 
"flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "04b77427038a4e20bc0c4db3439e2e52": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "db7d151862104aec8fa492229335b0f8": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_991434e9d20d4fa6bd74e140a8436791", + "IPY_MODEL_de7825b2bbc4450fbe71b6db3ca72116", + "IPY_MODEL_adbf883f7ebd45eea2c7600a52fc96f6" + ], + "layout": "IPY_MODEL_48806f023af74a5289cc920c4da237ec" + } + }, + "991434e9d20d4fa6bd74e140a8436791": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + 
"model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_40dfe4a8237f47d7baa80dbcc576cc48", + "placeholder": "​", + "style": "IPY_MODEL_86d76c0566344a669b0c3f45b044878d", + "value": "preprocessor_config.json: 100%" + } + }, + "de7825b2bbc4450fbe71b6db3ca72116": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_90559e1d7fb243e3ae48ea9807ea5b59", + "max": 213, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_278477013af841c682618f0d5aa03e18", + "value": 213 + } + }, + "adbf883f7ebd45eea2c7600a52fc96f6": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9502b18116cc458883724c789c0deb42", + "placeholder": "​", + "style": "IPY_MODEL_223c58d60dfe463b86fdb4d4faacbca7", + "value": " 213/213 [00:00<00:00, 15.2kB/s]" + } + }, + 
"48806f023af74a5289cc920c4da237ec": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "40dfe4a8237f47d7baa80dbcc576cc48": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": 
null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "86d76c0566344a669b0c3f45b044878d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "90559e1d7fb243e3ae48ea9807ea5b59": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + 
"object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "278477013af841c682618f0d5aa03e18": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "9502b18116cc458883724c789c0deb42": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "223c58d60dfe463b86fdb4d4faacbca7": { + "model_module": 
"@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py new file mode 100644 index 000000000..dde2b31a5 --- /dev/null +++ b/transformer_lens/HookedAudioEncoder.py @@ -0,0 +1,526 @@ +"""Hooked Audio Encoder. + +Contains a HuBERT style model. This is separate from :class:`transformer_lens.HookedTransformer` +because it has a significantly different architecture to e.g. GPT style transformers. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload + +import numpy as np +import torch +import torch.nn as nn +from einops import repeat +from jaxtyping import Float, Int +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + HubertForCTC, + HubertModel, + Wav2Vec2Model, +) +from typing_extensions import Literal + +from transformer_lens import loading_from_pretrained as loading +from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.components import MLP, Attention, BertBlock +from transformer_lens.FactoredMatrix import FactoredMatrix +from transformer_lens.hook_points import HookedRootModule +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.utilities import devices + +T = TypeVar("T", bound="HookedEncoder") + + +class HookedAudioEncoder(HookedRootModule): + """ + This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every 
interesting activation. It inherits from HookedRootModule. + + Limitations: + - The model does not include dropouts, which may lead to inconsistent results from training or fine-tuning. + + Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: + - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model + """ + + def __init__( + self, + cfg: Union[HookedTransformerConfig, Dict], + move_to_device: bool = True, + model_name: str = "facebook/hubert-base-ls960", + **kwargs: Any, + ): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig(**cfg) + elif isinstance(cfg, str): + raise ValueError( + "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedAudioEncoder.from_pretrained() instead." + ) + self.cfg = cfg + + assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" + + self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) + + if move_to_device: + if self.cfg.device is None: + raise ValueError("Cannot move to device when device is None") + self.to(self.cfg.device) + + self.setup() + + def _ensure_numpy(self, wave): + """ + Convert torch.Tensor / np.ndarray / list -> 1D np.float32 array on CPU. 
+ """ + if isinstance(wave, torch.Tensor): + arr = wave.detach().cpu().numpy() + elif isinstance(wave, np.ndarray): + arr = wave + elif isinstance(wave, list): + arr = np.asarray(wave) + else: + raise TypeError("wave must be torch.Tensor, np.ndarray or list of floats") + + # force 1-D (if stereo or shape (N,1) etc) + if arr.ndim > 1: + # if shape (n_samples, n_channels) average channels -> mono + if arr.shape[1] <= arr.shape[0]: + arr = arr.mean(axis=1) + else: + arr = arr.reshape(-1) + + return arr.astype(np.float32, copy=False) + + def to_frames( + self, + raw_inputs: Union[torch.Tensor, List[torch.Tensor], List[np.ndarray]], + sampling_rate: int = 16000, + move_to_device: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert raw audio batch -> (projected frames, frame_attention_mask) + + Args: + raw_inputs: one of: + - a 1D torch.Tensor or numpy array (single waveform) + - a list of 1D torch.Tensors / numpy arrays (batch) + self.processor: HF AutoProcessor (creates input_values + sample-level attention_mask) + self.model: pretrained HubertModel (provides feature_extractor and feature_projection) + sampling_rate: sample rate of the audio (default 16k) + move_to_device: move outputs to model.device + + Returns: + frames: torch.Tensor of shape (batch, frames, hidden_size) <- after feature_projection + frame_attention_mask: torch.LongTensor of shape (batch, frames) with 1 for real frames, 0 for padding + """ + # AutoFeatureExtractor works better onnumpy array where it pads automatically. 
If passing in tensors, it does not pad properly, giving inhomogeneous arts error + if isinstance(raw_inputs, (torch.Tensor, np.ndarray)): + waves = [self._ensure_numpy(raw_inputs)] + elif isinstance(raw_inputs, list): + waves = [self._ensure_numpy(w) for w in raw_inputs] + else: + raise TypeError("Unsupported raw_inputs type") + + # Use HF processor to create input_values (padded) + sample-level attention_mask + # Processor will do padding so we can pass a variable-length batch + proc_out = self.processor( + waves, + sampling_rate=sampling_rate, + return_tensors="pt", + padding=True, + return_attention_mask=True, + ) + input_values = proc_out["input_values"] # (batch, samples), float + sample_attention_mask = proc_out.get( + "attention_mask" + ) # (batch, samples), 1 for valid, 0 for padding; may be None + + # move to device + device = self.cfg.device + if move_to_device: + input_values = input_values.to(device) + if sample_attention_mask is not None: + sample_attention_mask = sample_attention_mask.to(device) + + # 1) convolutional frontend -> (batch, conv_dim, conv_time) + if input_values.ndim > 2: + input_values = input_values.squeeze() + if input_values.ndim == 1: + input_values = input_values.unsqueeze(0) # (1, T) + with torch.no_grad(): + conv_feats = self.hubert_model.feature_extractor(input_values) # (B, C, T_conv) + + # 2) transpose to (batch, T_conv, C) + extract_features = conv_feats.transpose(1, 2) + + # 3) compute reduced frame-level attention mask (if sample mask provided) + frame_attention_mask = None + if sample_attention_mask is not None: + # model should provide helper _get_feature_vector_attention_mask + try: + frame_attention_mask = self.hubert_model._get_feature_vector_attention_mask( + extract_features.shape[1], sample_attention_mask + ) + except AttributeError: + # fallback: compute output lengths and create mask similarly to HF implementation + # compute output lengths (downsampled lengths) from sample attention mask (sums per example) + 
input_lengths = sample_attention_mask.sum(dim=-1) # (batch,) + # compute output lengths through conv layers using model._get_feat_extract_output_lengths if exists + if hasattr(model, "_get_feat_extract_output_lengths"): + output_lengths = self.hubert_model._get_feat_extract_output_lengths( + input_lengths + ).to(torch.long) + else: + # fallback to naive downsample ratio: output_frames = extract_features.shape[1] + output_lengths = torch.full( + (sample_attention_mask.shape[0],), + extract_features.shape[1], + device=device, + dtype=torch.long, + ) + + batch_size = sample_attention_mask.shape[0] + feat_len = extract_features.shape[1] + frame_attention_mask = torch.zeros( + (batch_size, feat_len), dtype=sample_attention_mask.dtype, device=device + ) + # mark the last valid index for each example and then cumsum trick to fill ones before it + idx = (torch.arange(batch_size, device=device), (output_lengths - 1).clamp(min=0)) + frame_attention_mask[idx] = 1 + frame_attention_mask = ( + frame_attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool().long() + ) + + # 4) feature projection -> (batch, frames, hidden_size) + with torch.no_grad(): + hidden_states = self.hubert_model.feature_projection( + extract_features + ) # typically returns (B, T, hidden) + # In HF's hubert, feature_projection is a module that returns a tensor (not tuple). If it returns tuple, adjust. 
+ + # convert bool mask to long (1/0) if needed + if frame_attention_mask is not None: + frame_attention_mask = frame_attention_mask.to(dtype=torch.long) + + return hidden_states, frame_attention_mask + + def encoder_output( + self, + frames: torch.Tensor, # (batch, frames, d_model) <-- precomputed conv features + one_zero_attention_mask: Optional[torch.Tensor] = None, # (batch, frames) + ): + # Ensure device + if frames.device.type != self.cfg.device: + frames = frames.to(self.cfg.device) + if one_zero_attention_mask is not None: + one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) + + position_embeddings = self.hubert_model.encoder.pos_conv_embed(frames) + resid = frames + position_embeddings + resid = self.hubert_model.encoder.layer_norm(resid) + + large_negative_number = -torch.inf + mask = ( + repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") + if one_zero_attention_mask is not None + else None + ) + additive_attention_mask = ( + torch.where(mask == 1, large_negative_number, 0) if mask is not None else None + ) + for block in self.blocks: + resid = block(resid, additive_attention_mask) + + return resid + + def forward( + self, + inputs: Union[ + torch.Tensor, # waveform (1D) OR precomputed frames (3D) + List[Union[torch.Tensor, np.ndarray]], # list of waveforms + Tuple[torch.Tensor, torch.Tensor], # (frames, frame_mask) + ], + one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, + sampling_rate: int = 16000, + move_to_device: bool = True, + ) -> Optional[torch.Tensor]: + """ + HuBERT-like forward (Transformer-Lens style). + + Args: + input: one of: + - 1D torch.Tensor or numpy array (single waveform) OR list of 1D waveforms -> will call self.to_frames(...) + - 3D torch.Tensor shaped (batch, frames, d_model) -> treated as precomputed frames (skip to_frames) + - tuple (frames, frame_mask) -> use directly + sampling_rate: sampling rate for to_frames when converting raw audio. 
+ use_proj: Whether to use the final head of HubertCTC + move_to_device: move tensors to self.cfg.device (to match your other code). + + Returns: + Depending on return_type: + - "hidden": (batch, frames, d_model) final encoder hidden states + """ + # ---------- 1) Normalize input: get (frames, frame_mask) ---------- + frames = None + frame_mask = None # one_zero_attention_mask: 1 = valid, 0 = padding + # print(type(inputs)) + # If user passed (frames, mask) tuple + if isinstance(inputs, tuple) and len(inputs) == 2 and isinstance(inputs[0], torch.Tensor): + frames, frame_mask = inputs + + # If user passed a 3D tensor -> assume (B, T, D) frames (pre-projected) + elif isinstance(inputs, torch.Tensor) and inputs.ndim == 3: + frames = inputs + # frame_mask stays whatever was passed as separate argument (None here) + + # Else treat as raw waveform(s) -> call to_frames + else: + # allow single 1D tensor or numpy array or list of tensors/arrays + frames, frame_mask = self.to_frames(inputs) + # to_frames should already place tensors on device if move_to_device=True + if isinstance(frames, tuple): + frames = frames[0] + frame_mask = frame_mask if one_zero_attention_mask is None else one_zero_attention_mask + # ---------- 2) Ensure device & dtype consistency ---------- + device = self.cfg.device + if frames.device.type != device: + frames = frames.to(device) + if frame_mask is not None: + frame_mask = frame_mask.to(device) + + # ---------- 3) Run encoder (respects pos_conv_embed / layer_norm / dropout inside encoder_output) ---------- + resid = self.encoder_output(frames, frame_mask) # (B, T, d_model) + + return resid + + @overload + def run_with_cache( + self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any + ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: + ... 
+ + @overload + def run_with_cache( + self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any + ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: + ... + + def run_with_cache( + self, + *model_args: Any, + return_cache_object: bool = True, + remove_batch_dim: bool = False, + **kwargs: Any, + ) -> Tuple[ + Float[torch.Tensor, "batch pos d_vocab"], + Union[ActivationCache, Dict[str, torch.Tensor]], + ]: + """ + Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. + """ + out, cache_dict = super().run_with_cache( + *model_args, remove_batch_dim=remove_batch_dim, **kwargs + ) + if return_cache_object: + cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) + return out, cache + else: + return out, cache_dict + + def to( # type: ignore + self, + device_or_dtype: Union[torch.device, str, torch.dtype], + print_details: bool = True, + ): + return devices.move_to_and_update_config(self, device_or_dtype, print_details) + + def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T: + if isinstance(device, int): + return self.to(f"cuda:{device}") + elif device is None: + return self.to("cuda") + else: + return self.to(device) + + def cpu(self: T) -> T: + return self.to("cpu") + + def mps(self: T) -> T: + return self.to(torch.device("mps")) + + @classmethod + def from_pretrained( + cls, + model_name: str, + checkpoint_index: Optional[int] = None, + checkpoint_value: Optional[int] = None, + hf_model: Optional[Any] = None, + device: Optional[str] = None, + move_to_device: bool = True, + dtype: torch.dtype = torch.float32, + **from_pretrained_kwargs: Any, + ) -> HookedEncoder: + """Loads in the pretrained weights from huggingface. 
Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" + logging.warning( + "Support for HuBERT in TransformerLens is currently experimental, until such a time when it has feature " + "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " + "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " + "implementation." + "\n" + "If using HuBERT for interpretability research, keep in mind that HuBERT has some significant architectural " + "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning " + "that the last LayerNorm in a block cannot be folded." + ) + + assert not ( + from_pretrained_kwargs.get("load_in_8bit", False) + or from_pretrained_kwargs.get("load_in_4bit", False) + ), "Quantization not supported" + + if "torch_dtype" in from_pretrained_kwargs: + dtype = from_pretrained_kwargs["torch_dtype"] + + official_model_name = loading.get_official_model_name(model_name) + + cfg = loading.get_pretrained_model_config( + official_model_name, + checkpoint_index=checkpoint_index, + checkpoint_value=checkpoint_value, + fold_ln=False, + device=device, + n_devices=1, + dtype=dtype, + **from_pretrained_kwargs, + ) + + state_dict = loading.get_pretrained_state_dict( + official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs + ) + + model = cls(cfg, move_to_device=False, model_name=official_model_name) + model.load_state_dict(state_dict, strict=False) + + model.processor = AutoFeatureExtractor.from_pretrained(official_model_name) + + if "wav2vec2" in model_name: + hubert_model = Wav2Vec2Model.from_pretrained(official_model_name) + else: + hubert_model = HubertModel.from_pretrained(official_model_name) + + if move_to_device: + if cfg.device is None: + raise ValueError("Cannot move to device when device is None") + 
hubert_model.to(cfg.device) + + hubert_model.eval() + model.hubert_model = hubert_model + + if move_to_device: + model.to(cfg.device) + + print(f"Loaded pretrained model {model_name} into HookedEncoder") + + return model + + @property + def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the key weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_K for block in self.blocks], dim=0) + + @property + def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the query weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) + + @property + def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the value weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_V for block in self.blocks], dim=0) + + @property + def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: + """Stacks the attn output weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_O for block in self.blocks], dim=0) + + @property + def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: + """Stacks the MLP input weights across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) + + @property + def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: + """Stacks the MLP output weights across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) + + @property + def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the key 
biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_K for block in self.blocks], dim=0) + + @property + def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the query biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) + + @property + def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the value biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_V for block in self.blocks], dim=0) + + @property + def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the attn output biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_O for block in self.blocks], dim=0) + + @property + def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: + """Stacks the MLP input biases across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) + + @property + def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the MLP output biases across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) + + @property + def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head. 
+ Useful for visualizing attention patterns.""" + return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) + + @property + def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head.""" + return FactoredMatrix(self.W_V, self.W_O) + + def all_head_labels(self) -> List[str]: + """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" + return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index 02e5f2561..479ab9fb2 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -13,6 +13,7 @@ from .HookedTransformer import HookedTransformer from .SVDInterpreter import SVDInterpreter from .HookedEncoder import HookedEncoder +from .HookedAudioEncoder import HookedAudioEncoder from .HookedEncoderDecoder import HookedEncoderDecoder from .BertNextSentencePrediction import BertNextSentencePrediction from . 
import head_detector diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8b3b64241..34832ed6b 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -18,7 +18,9 @@ AutoConfig, AutoModelForCausalLM, BertForPreTraining, + HubertModel, T5ForConditionalGeneration, + Wav2Vec2Model, ) import transformer_lens.utils as utils @@ -32,6 +34,7 @@ convert_gpt2_weights, convert_gpt_oss_weights, convert_gptj_weights, + convert_hubert_weights, convert_llama_weights, convert_mingpt_weights, convert_mistral_weights, @@ -64,6 +67,9 @@ "facebook/opt-13b", "facebook/opt-30b", "facebook/opt-66b", + "facebook/hubert-base-ls960", + "facebook/wav2vec2-base", + "facebook/wav2vec2-large", "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", @@ -633,6 +639,9 @@ "google-bert/bert-base-uncased": ["bert-base-uncased"], "google-bert/bert-large-cased": ["bert-large-cased"], "google-bert/bert-large-uncased": ["bert-large-uncased"], + "facebook/hubert-base-ls960": ["facebook/hubert-base-ls960", "hubert-base-ls960"], + "facebook/wav2vec2-base": ["facebook/wav2vec2-base", "wav2vec2-base", "w2v2-base"], + "facebook/wav2vec2-large": ["facebook/wav2vec2-large", "wav2vec2-large", "w2v2-large"], "roneneldan/TinyStories-1M": ["tiny-stories-1M"], "roneneldan/TinyStories-3M": ["tiny-stories-3M"], "roneneldan/TinyStories-8M": ["tiny-stories-8M"], @@ -1230,6 +1239,51 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): } rotary_pct = hf_config.rotary_pct cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) + elif architecture == "HubertModel": + # Basic transformer configuration + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + # HuBERT operates on audio 
frames, not tokens — n_ctx is flexible + "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), + "eps": hf_config.layer_norm_eps, + "act_fn": getattr(hf_config, "hidden_act", "gelu"), + "attention_dir": "bidirectional", + "d_vocab": -1, # no text vocabulary + } + elif "wav2vec2-base" in official_model_name or "wav2vec2-large" in official_model_name: + # Basic transformer configuration + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + # HuBERT operates on audio frames, not tokens — n_ctx is flexible + "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), + "eps": hf_config.layer_norm_eps, + "act_fn": getattr(hf_config, "hidden_act", "gelu"), + "attention_dir": "bidirectional", + "d_vocab": -1, # no text vocabulary + } + elif architecture == "HubertForCTC": + # Basic transformer configuration + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), + "eps": hf_config.layer_norm_eps, + "act_fn": getattr(hf_config, "hidden_act", "gelu"), + "attention_dir": "bidirectional", + # For CTC models: + "d_vocab": hf_config.vocab_size, # text vocab from tokenizer + } elif architecture == "BertForMaskedLM": # All supported Bert architectures have the same config, # so we can use the BertForMaskedLM config for all of them @@ -2402,6 +2456,20 @@ def get_pretrained_state_dict( huggingface_token = os.environ.get("HF_TOKEN", "") if official_model_name in NON_HF_HOSTED_MODEL_NAMES: raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") + elif "hubert" in official_model_name: + hf_model = 
HubertModel.from_pretrained( + official_model_name, + torch_dtype=dtype, + token=huggingface_token if len(huggingface_token) > 0 else None, + **kwargs, + ) + elif "wav2vec2" in official_model_name: + hf_model = Wav2Vec2Model.from_pretrained( + official_model_name, + torch_dtype=dtype, + token=huggingface_token if len(huggingface_token) > 0 else None, + **kwargs, + ) elif "bert" in official_model_name: hf_model = BertForPreTraining.from_pretrained( official_model_name, @@ -2451,6 +2519,15 @@ def get_pretrained_state_dict( state_dict = convert_neox_weights(hf_model, cfg) elif cfg.original_architecture == "LlamaForCausalLM": state_dict = convert_llama_weights(hf_model, cfg) + elif cfg.original_architecture == "HubertModel": + state_dict = convert_hubert_weights(hf_model, cfg) + elif ( + cfg.original_architecture == "Wav2Vec2Model" + or cfg.original_architecture == "Wav2Vec2ForPreTraining" + ): + state_dict = convert_hubert_weights(hf_model, cfg) + elif cfg.original_architecture == "HubertForCTC": + state_dict = convert_hubert_weights(hf_model, cfg) elif cfg.original_architecture == "BertForMaskedLM": state_dict = convert_bert_weights(hf_model, cfg) elif cfg.original_architecture == "T5ForConditionalGeneration": diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index b8d940f62..aa0d0a553 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -19,5 +19,6 @@ from .nanogpt import convert_nanogpt_weights from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights +from .hubert import convert_hubert_weights from .apertus import convert_apertus_weights from .openai import convert_gpt_oss_weights diff --git a/transformer_lens/pretrained/weight_conversions/hubert.py b/transformer_lens/pretrained/weight_conversions/hubert.py new file mode 100644 index 000000000..d1b2d4cb4 --- 
"""Weight conversion from HuggingFace HuBERT / Wav2Vec2 encoders to Transformer-Lens format."""

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Annotation-only import: keeps the module importable without transformer_lens installed.
    from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_hubert_weights(hf_model, cfg: "HookedTransformerConfig"):
    """
    Convert transformer encoder weights from a HuggingFace HuBERT model
    into the state_dict expected by Transformer-Lens' HookedEncoder.

    Notes:
        - This intentionally skips the convolutional frontend and feature_projection.
          Those are used directly from the HF model (hf_model.feature_extractor,
          hf_model.feature_projection).
        - Use model.load_state_dict(state_dict, strict=False) to load these.

    Args:
        hf_model: HF HubertModel / Wav2Vec2Model (or compatible) with an ``.encoder``.
        cfg: config providing ``d_model`` and ``n_heads``.

    Returns:
        dict mapping Transformer-Lens parameter names to tensors.

    Raises:
        ValueError: if hf_model has no encoder / encoder layer list.
        AttributeError: if a layer is missing attention or feed-forward submodules.
    """
    state_dict = {}

    encoder = getattr(hf_model, "encoder", None)
    if encoder is None:
        raise ValueError("hf_model has no .encoder attribute")

    # Different HF variants use .layers or .layer. Use explicit `is None` checks:
    # `a or b` would wrongly treat an *empty* ModuleList as "not found".
    encoder_layers = getattr(encoder, "layers", None)
    if encoder_layers is None:
        encoder_layers = getattr(encoder, "layer", None)
    if encoder_layers is None:
        raise ValueError("Couldn't find encoder.layers or encoder.layer on hf_model.encoder")

    n_heads = cfg.n_heads
    d_head = cfg.d_model // n_heads

    for l, layer in enumerate(encoder_layers):
        # --- Attention module (HubertAttention is `attention`; some variants `self_attn`) ---
        att = getattr(layer, "attention", None)
        if att is None:
            att = getattr(layer, "self_attn", None)
        if att is None:
            raise AttributeError(f"Encoder layer {l} has no 'attention' or 'self_attn' attribute")

        # q/k/v/out projections: q_proj, k_proj, v_proj, out_proj in HF HuBERT,
        # with fallbacks for alternative naming schemes.
        q_w = getattr(att, "q_proj", None)
        k_w = getattr(att, "k_proj", None)
        v_w = getattr(att, "v_proj", None)
        o_w = getattr(att, "out_proj", None) or getattr(att, "proj", None)

        if any(x is None for x in (q_w, k_w, v_w, o_w)):
            q_w = q_w or getattr(att, "q", None)
            k_w = k_w or getattr(att, "k", None)
            v_w = v_w or getattr(att, "v", None)
            o_w = o_w or getattr(att, "o", None)

        if any(x is None for x in (q_w, k_w, v_w, o_w)):
            raise AttributeError(f"Could not find q/k/v/out projections in layer {l}. Found: {att}")

        # nn.Linear weight is (out_features, in_features) = (n_heads*d_head, d_model).
        # Transformer-Lens wants (n_heads, d_model, d_head), i.e. einops "(i h) m -> i m h".
        for name, lin in (("Q", q_w), ("K", k_w), ("V", v_w)):
            state_dict[f"blocks.{l}.attn.W_{name}"] = lin.weight.reshape(
                n_heads, d_head, -1
            ).permute(0, 2, 1)
            if lin.bias is not None:
                # "(i h) -> i h"
                state_dict[f"blocks.{l}.attn.b_{name}"] = lin.bias.reshape(n_heads, d_head)

        # out_proj.weight is (d_model, n_heads*d_head); TL wants (n_heads, d_head, d_model),
        # i.e. einops "m (i h) -> i h m".
        state_dict[f"blocks.{l}.attn.W_O"] = o_w.weight.reshape(-1, n_heads, d_head).permute(
            1, 2, 0
        )
        if o_w.bias is not None:
            state_dict[f"blocks.{l}.attn.b_O"] = o_w.bias

        # --- Layer norms inside the layer ---
        # HuBERT layers use `layer_norm` (post-attention) and `final_layer_norm` (post-MLP).
        ln1 = getattr(layer, "layer_norm", None)
        ln2 = getattr(layer, "final_layer_norm", None)
        if ln1 is None or ln2 is None:
            ln1 = ln1 or getattr(layer, "attention_norm", None)
            ln2 = ln2 or getattr(layer, "output_layer_norm", None)

        if ln1 is not None:
            state_dict[f"blocks.{l}.ln1.w"] = ln1.weight
            state_dict[f"blocks.{l}.ln1.b"] = ln1.bias
        if ln2 is not None:
            state_dict[f"blocks.{l}.ln2.w"] = ln2.weight
            state_dict[f"blocks.{l}.ln2.b"] = ln2.bias

        # --- Feed-forward / MLP ---
        # HuBERT uses `feed_forward` containing intermediate_dense and output_dense.
        ff = (
            getattr(layer, "feed_forward", None)
            or getattr(layer, "feedforward", None)
            or getattr(layer, "ff", None)
        )
        if ff is None:
            raise AttributeError(f"Layer {l} has no feed_forward/ff attribute")

        fc1 = (
            getattr(ff, "intermediate_dense", None)
            or getattr(ff, "fc1", None)
            or getattr(ff, "linear1", None)
        )
        fc2 = (
            getattr(ff, "output_dense", None)
            or getattr(ff, "fc2", None)
            or getattr(ff, "linear2", None)
        )

        if fc1 is None or fc2 is None:
            raise AttributeError(f"Could not find FFN dense layers in layer {l}: {ff}")

        # fc1.weight: (d_mlp, d_model) -> TL expects (d_model, d_mlp); plain transpose.
        state_dict[f"blocks.{l}.mlp.W_in"] = fc1.weight.t()
        if fc1.bias is not None:
            state_dict[f"blocks.{l}.mlp.b_in"] = fc1.bias

        # fc2.weight: (d_model, d_mlp) -> TL expects (d_mlp, d_model); plain transpose.
        state_dict[f"blocks.{l}.mlp.W_out"] = fc2.weight.t()
        if fc2.bias is not None:
            state_dict[f"blocks.{l}.mlp.b_out"] = fc2.bias

    # --- Optional: encoder-level layer_norm (HubertModel.encoder.layer_norm) ---
    if hasattr(hf_model.encoder, "layer_norm"):
        ln_final = hf_model.encoder.layer_norm
        state_dict["ln_final.w"] = ln_final.weight
        state_dict["ln_final.b"] = ln_final.bias

    return state_dict