[BUG] matmul produces incorrect results for batched / N-D inputs #87

@SwayamInSync

Description

Summary

np.matmul on QuadPrecision arrays returns incorrect results whenever the inputs have leading (broadcast/batch) dimensions larger than 1. Only the first batch is computed; every subsequent output cell comes back as zero.

Reproducer

import numpy as np
from numpy_quaddtype import QuadPrecision, QuadPrecDType

dt = QuadPrecDType(backend='sleef')
def Q(v): return QuadPrecision(str(float(v)), backend='sleef')
def Qa(arr):
    a = np.asarray(arr, dtype=np.float64)
    return np.array([Q(v) for v in a.ravel()], dtype=dt).reshape(a.shape)

# 2 stacked matmuls of (3,4) @ (4,5)
A = Qa(np.arange(2*3*4).reshape(2, 3, 4))
B = Qa(np.arange(2*4*5).reshape(2, 4, 5) + 100)

got = np.matmul(A, B)
ref = np.matmul(np.arange(2*3*4, dtype=np.float64).reshape(2,3,4),
                np.arange(2*4*5, dtype=np.float64).reshape(2,4,5) + 100)

print("Batch 0 matches:", np.allclose([float(v) for v in got[0].ravel()], ref[0].ravel()))
# True
print("Batch 1 all zeros:", all(float(v) == 0.0 for v in got[1].ravel()))
# True   ← bug: should match ref[1]
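Since the 2-D case (outer loop length 1) is not affected, a possible stopgap is to call np.matmul once per 2-D slice. The sketch below is only an illustration: matmul_per_batch is a hypothetical helper name, it assumes both inputs have at least two dimensions and share a dtype, and I have only checked the logic against float64, not against QuadPrecDType arrays.

```python
import numpy as np

def matmul_per_batch(a, b):
    """Hypothetical workaround: run np.matmul on each 2-D slice separately."""
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim < 2 or b.ndim < 2:
        raise ValueError("this sketch only handles inputs with ndim >= 2")
    # Broadcast the leading (batch) dimensions the same way np.matmul would.
    batch_shape = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a_b = np.broadcast_to(a, batch_shape + a.shape[-2:])
    b_b = np.broadcast_to(b, batch_shape + b.shape[-2:])
    # Assumption: the output dtype is the same as the dtype of `a`.
    out = np.empty(batch_shape + (a.shape[-2], b.shape[-1]), dtype=a.dtype)
    for idx in np.ndindex(*batch_shape):
        # Each call sees plain 2-D operands, so the outer loop length is 1.
        out[idx] = np.matmul(a_b[idx], b_b[idx])
    return out
```

With the reproducer above, matmul_per_batch(A, B) would be the call to try in place of np.matmul(A, B).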

Root cause

In src/csrc/umath/matmul.cpp, all three strided-loop functions read N = dimensions[0] (the outer broadcast
loop length) but never iterate over it.
They also never advance the A, B, and C pointers by their batch strides, so only batch 0 is ever computed.

  • The single memset zeroes the whole output once at the start, so batches 1..N-1 just stay zero.
  • My misunderstanding also left a misleading comment: // Batch size, this remains always 1 for matmul afaik
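To make the missing step concrete, here is a small pure-Python model of what the loop should do. It is not the code in matmul.cpp: the function name, the flat-list buffers, and the use of element offsets instead of byte strides are all assumptions made for illustration. It follows the usual gufunc convention for a (m,n),(n,p)->(m,p) signature, where dimensions = [N, m, n, p] and steps starts with one outer stride per operand, followed by the core strides.

```python
def matmul_strided_loop(data_a, data_b, data_c, dimensions, steps):
    """Toy model of a batched matmul inner loop over flat Python lists.

    steps = [batch_a, batch_b, batch_c,
             a_row, a_col, b_row, b_col, c_row, c_col]
    All steps are element offsets here (real loops use byte strides).
    """
    N, m, n, p = dimensions
    s_a, s_b, s_c, a_r, a_c, b_r, b_c, c_r, c_c = steps
    off_a = off_b = off_c = 0
    for _ in range(N):  # the outer loop the buggy version never runs
        for i in range(m):
            for j in range(p):
                acc = 0.0
                for k in range(n):
                    acc += (data_a[off_a + i * a_r + k * a_c]
                            * data_b[off_b + k * b_r + j * b_c])
                data_c[off_c + i * c_r + j * c_c] = acc
        # Advance every operand by its batch stride before the next batch.
        off_a += s_a
        off_b += s_b
        off_c += s_c


# Same shapes as the reproducer: 2 stacked (3,4) @ (4,5) products.
a = [float(v) for v in range(2 * 3 * 4)]
b = [float(v + 100) for v in range(2 * 4 * 5)]
c = [0.0] * (2 * 3 * 5)
matmul_strided_loop(a, b, c, [2, 3, 4, 5], [12, 20, 15, 4, 1, 5, 1, 5, 1])
```

If the outer loop and the three offset updates are removed, only c[0:15] is written and c[15:30] keeps its initial zeros, which matches the behaviour described above.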
