Skip to content
5 changes: 4 additions & 1 deletion tests/integration/test_match_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,7 @@ def test_compare_huggingface_attention_match_local_implementation(self, model_na
hidden_states=input, output_attentions=True
)[0]

assert torch.allclose(tl_out, hf_out, atol=1e-4)
# Tolerance accounts for float32 accumulation differences between
# TransformerLens and HuggingFace attention implementations across
# 12 layers. Empirically, worst-case diff is ~1.3e-3 on layer 11.
assert torch.allclose(tl_out, hf_out, atol=1e-3)
42 changes: 42 additions & 0 deletions tests/unit/components/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,45 @@ def test_remove_einsum_from_complex_attn_linear():

# Check if the results are the same
assert torch.allclose(result_new, result_old, atol=1e-4)


@pytest.mark.skipif(
    not torch.backends.mps.is_available()
    # Match the whole 2.8 release line (e.g. "2.8.0", "2.8.1", "2.8.0+cpu"),
    # not just the exact string "2.8.0" — the upstream bug affects PyTorch 2.8.
    or not torch.__version__.startswith("2.8"),
    reason="Issue with F.linear exclusive to mps and PyTorch 2.8\n"
    "https://github.com/pytorch/pytorch/issues/161640",
)
def test_cpu_mps_outputs_match():
    """Attention forward pass must produce identical outputs on CPU and MPS.

    Regression test for the MPS F.linear workaround: with identical weights
    and identical inputs, the CPU and MPS attention outputs must agree
    (within torch.allclose default tolerances).
    """
    torch.manual_seed(0)

    cfg = {
        "n_layers": 1,
        "d_model": 48,
        "n_ctx": 256,
        "d_head": 16,
        "n_heads": 3,
        "load_in_4bit": False,
        "dtype": torch.float32,
        "act_fn": "relu",
    }

    def init_weights(attn_layer: nn.Module) -> nn.Module:
        # Small random weights keep activations well-scaled for the comparison.
        nn.init.normal_(attn_layer.W_Q, mean=0.0, std=0.02)
        nn.init.normal_(attn_layer.W_K, mean=0.0, std=0.02)
        nn.init.normal_(attn_layer.W_V, mean=0.0, std=0.02)
        nn.init.normal_(attn_layer.W_O, mean=0.0, std=0.02)
        return attn_layer

    attn_cpu = init_weights(Attention(cfg))

    # Copy the exact same parameters onto the MPS module so any output
    # difference is attributable to the device, not to initialization.
    attn_mps = Attention(cfg).to("mps")
    attn_mps.load_state_dict(attn_cpu.state_dict(), strict=True)

    batch = 1
    input_cpu = torch.randn(batch, cfg["n_ctx"], cfg["d_model"])
    input_mps = input_cpu.to("mps")

    cpu_output = attn_cpu(input_cpu, input_cpu, input_cpu)
    mps_output = attn_mps(input_mps, input_mps, input_mps)

    assert torch.allclose(cpu_output, mps_output.cpu())
14 changes: 9 additions & 5 deletions transformer_lens/components/abstract_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,15 @@ def forward(
if self.b_O.device != z.device:
z = z.to(self.b_O.device)

out = F.linear(
z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
w,
self.b_O,
)
z = z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads)

# F.linear is a fused matmul+bias that matches HuggingFace exactly,
# but has a bug on MPS with PyTorch 2.8 (pytorch#161640).
# Fall back to manual matmul on MPS to work around it.
if z.device.type == "mps":
out = torch.matmul(z, w.T) + self.b_O
else:
out = F.linear(z, w, self.b_O)
else:
# Explicitly calculate the attention result so it can be accessed by a hook
# This is off by default because it can easily eat through your GPU memory.
Expand Down
Loading