Conversation
Greptile Summary

This PR adds optimized one-to-many Euclidean distance batch computation for fp32, fp16, and int8 data types, mirroring the existing inner product distance batch implementation. Key changes:
Critical issues found:
These bugs would cause incorrect distance computations and potential memory corruption.

Confidence Score: 0/5
Important Files Changed
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[BaseDistance::ComputeBatch] --> B{Distance Type?}
B -->|InnerProduct| C[InnerProductDistanceBatch]
B -->|SquaredEuclidean| D[SquaredEuclideanDistanceBatch]
B -->|Euclidean| E[EuclideanDistanceBatch]
D --> D1{Data Type?}
D1 -->|float| D2[SquaredEuclideanDistanceBatchImpl float]
D1 -->|int8_t| D3[SquaredEuclideanDistanceBatchImpl int8_t]
D1 -->|Float16| D4[SquaredEuclideanDistanceBatchImpl Float16]
D2 --> D2A{CPU Features?}
D2A -->|AVX512F| D2B[compute_one_to_many_squared_euclidean_avx512f_fp32]
D2A -->|AVX2| D2C[compute_one_to_many_squared_euclidean_avx2_fp32]
D2A -->|Fallback| D2D[compute_one_to_many_squared_euclidean_fallback]
D3 --> D3A{CPU Features?}
D3A -->|AVX2| D3B[compute_one_to_many_squared_euclidean_avx2_int8]
D3A -->|Fallback| D3C[compute_one_to_many_squared_euclidean_fallback]
D4 --> D4A{CPU Features?}
D4A -->|AVX512FP16| D4B[compute_one_to_many_squared_euclidean_avx512fp16_fp16]
D4A -->|AVX512F| D4C[compute_one_to_many_squared_euclidean_avx512f_fp16]
D4A -->|AVX2| D4D[compute_one_to_many_squared_euclidean_avx2_fp16]
D4A -->|Fallback| D4E[compute_one_to_many_squared_euclidean_fallback]
E --> E1[Call SquaredEuclideanDistanceBatch]
E1 --> E2[Apply sqrt to results]
Last reviewed commit: 9c07ff9
| accs[i] = _mm512_mask3_fmadd_ps( | ||
| _mm512_mask_loadu_ps(zmm_undefined, mask, query + dim), | ||
| _mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim), accs[i], | ||
| mask); |
There was a problem hiding this comment.
Computing query * ptrs[i] + accs[i] instead of (query - ptrs[i])^2 + accs[i]
| accs[i] = _mm512_mask3_fmadd_ps( | |
| _mm512_mask_loadu_ps(zmm_undefined, mask, query + dim), | |
| _mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim), accs[i], | |
| mask); | |
| __m512 q_vals = _mm512_mask_loadu_ps(zmm_undefined, mask, query + dim); | |
| __m512 d_vals = _mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim); | |
| __m512 diff = _mm512_sub_ps(q_vals, d_vals); | |
| accs[i] = _mm512_mask_fmadd_ps(diff, mask, diff, accs[i]); |
| switch (dimensionality - dim) { | ||
| case 7: | ||
| SSD_FP32_GENERAL(query[6], ptrs[i][6], results[i]); | ||
| /* FALLTHRU */ | ||
| case 6: | ||
| SSD_FP32_GENERAL(query[5], ptrs[i][5], results[i]); | ||
| /* FALLTHRU */ | ||
| case 5: | ||
| SSD_FP32_GENERAL(query[4], ptrs[i][4], results[i]); | ||
| /* FALLTHRU */ | ||
| case 4: | ||
| SSD_FP32_GENERAL(query[3], ptrs[i][3], results[i]); | ||
| /* FALLTHRU */ | ||
| case 3: | ||
| SSD_FP32_GENERAL(query[2], ptrs[i][2], results[i]); | ||
| /* FALLTHRU */ | ||
| case 2: | ||
| SSD_FP32_GENERAL(query[1], ptrs[i][1], results[i]); | ||
| /* FALLTHRU */ | ||
| case 1: | ||
| SSD_FP32_GENERAL(query[0], ptrs[i][0], results[i]); |
There was a problem hiding this comment.
Array indices should be offset by dim to process remaining elements correctly
| switch (dimensionality - dim) { | |
| case 7: | |
| SSD_FP32_GENERAL(query[6], ptrs[i][6], results[i]); | |
| /* FALLTHRU */ | |
| case 6: | |
| SSD_FP32_GENERAL(query[5], ptrs[i][5], results[i]); | |
| /* FALLTHRU */ | |
| case 5: | |
| SSD_FP32_GENERAL(query[4], ptrs[i][4], results[i]); | |
| /* FALLTHRU */ | |
| case 4: | |
| SSD_FP32_GENERAL(query[3], ptrs[i][3], results[i]); | |
| /* FALLTHRU */ | |
| case 3: | |
| SSD_FP32_GENERAL(query[2], ptrs[i][2], results[i]); | |
| /* FALLTHRU */ | |
| case 2: | |
| SSD_FP32_GENERAL(query[1], ptrs[i][1], results[i]); | |
| /* FALLTHRU */ | |
| case 1: | |
| SSD_FP32_GENERAL(query[0], ptrs[i][0], results[i]); | |
| switch (dimensionality - dim) { | |
| case 7: | |
| SSD_FP32_GENERAL(query[dim + 6], ptrs[i][dim + 6], results[i]); | |
| /* FALLTHRU */ | |
| case 6: | |
| SSD_FP32_GENERAL(query[dim + 5], ptrs[i][dim + 5], results[i]); | |
| /* FALLTHRU */ | |
| case 5: | |
| SSD_FP32_GENERAL(query[dim + 4], ptrs[i][dim + 4], results[i]); | |
| /* FALLTHRU */ | |
| case 4: | |
| SSD_FP32_GENERAL(query[dim + 3], ptrs[i][dim + 3], results[i]); | |
| /* FALLTHRU */ | |
| case 3: | |
| SSD_FP32_GENERAL(query[dim + 2], ptrs[i][dim + 2], results[i]); | |
| /* FALLTHRU */ | |
| case 2: | |
| SSD_FP32_GENERAL(query[dim + 1], ptrs[i][dim + 1], results[i]); | |
| /* FALLTHRU */ | |
| case 1: | |
| SSD_FP32_GENERAL(query[dim + 0], ptrs[i][dim + 0], results[i]); | |
| } |
| for (size_t i = 0; i < dp_batch; ++i) { | ||
| data_regs[i] = _mm512_cvtph_ps( | ||
| _mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim))); | ||
| accs[i] = _mm512_fmadd_ps(q, data_regs[i], accs[i]); | ||
| } |
There was a problem hiding this comment.
Computing q * data_regs[i] + accs[i] instead of (q - data_regs[i])^2 + accs[i]
| for (size_t i = 0; i < dp_batch; ++i) { | |
| data_regs[i] = _mm512_cvtph_ps( | |
| _mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim))); | |
| accs[i] = _mm512_fmadd_ps(q, data_regs[i], accs[i]); | |
| } | |
| for (size_t i = 0; i < dp_batch; ++i) { | |
| data_regs[i] = _mm512_cvtph_ps( | |
| _mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim))); | |
| __m512 diff = _mm512_sub_ps(q, data_regs[i]); | |
| accs[i] = _mm512_fmadd_ps(diff, diff, accs[i]); | |
| } |
|
|
||
| if (dimensionality >= dim + 16) { | ||
| for (size_t i = 0; i < dp_batch; ++i) { | ||
| __m128i q = _mm_loadu_si128((const __m128i *)query + dim); |
There was a problem hiding this comment.
Pointer arithmetic on __m128i* reads from query + dim * 16 instead of query + dim
| __m128i q = _mm_loadu_si128((const __m128i *)query + dim); | |
| __m128i q = _mm_loadu_si128((const __m128i *)(query + dim)); |
| for (size_t i = 0; i < dp_batch; ++i) { | ||
| switch (dimensionality - dim) { | ||
| case 15: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[14] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 14: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[13 + dim], results[i]); | ||
| /* FALLTHRU */ | ||
| case 13: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[12] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 12: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[11] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 11: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[10 + dim], results[i]); | ||
| /* FALLTHRU */ | ||
| case 10: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[9] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 9: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[8] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 8: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[7] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 7: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[6] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 6: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[5] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 5: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[4] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 4: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[3] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 3: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[2] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 2: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[1] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 1: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[0] + dim, results[i]); | ||
| } |
There was a problem hiding this comment.
Passing pointers instead of scalar values to SSD_INT8_GENERAL macro, and using wrong array indices (ptrs[14] instead of ptrs[i])
| for (size_t i = 0; i < dp_batch; ++i) { | |
| switch (dimensionality - dim) { | |
| case 15: | |
| SSD_INT8_GENERAL(query + dim, ptrs[14] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 14: | |
| SSD_INT8_GENERAL(query + dim, ptrs[13 + dim], results[i]); | |
| /* FALLTHRU */ | |
| case 13: | |
| SSD_INT8_GENERAL(query + dim, ptrs[12] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 12: | |
| SSD_INT8_GENERAL(query + dim, ptrs[11] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 11: | |
| SSD_INT8_GENERAL(query + dim, ptrs[10 + dim], results[i]); | |
| /* FALLTHRU */ | |
| case 10: | |
| SSD_INT8_GENERAL(query + dim, ptrs[9] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 9: | |
| SSD_INT8_GENERAL(query + dim, ptrs[8] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 8: | |
| SSD_INT8_GENERAL(query + dim, ptrs[7] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 7: | |
| SSD_INT8_GENERAL(query + dim, ptrs[6] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 6: | |
| SSD_INT8_GENERAL(query + dim, ptrs[5] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 5: | |
| SSD_INT8_GENERAL(query + dim, ptrs[4] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 4: | |
| SSD_INT8_GENERAL(query + dim, ptrs[3] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 3: | |
| SSD_INT8_GENERAL(query + dim, ptrs[2] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 2: | |
| SSD_INT8_GENERAL(query + dim, ptrs[1] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 1: | |
| SSD_INT8_GENERAL(query + dim, ptrs[0] + dim, results[i]); | |
| } | |
| for (size_t i = 0; i < dp_batch; ++i) { | |
| switch (dimensionality - dim) { | |
| case 15: | |
| SSD_INT8_GENERAL(query[dim + 14], ptrs[i][dim + 14], results[i]); | |
| /* FALLTHRU */ | |
| case 14: | |
| SSD_INT8_GENERAL(query[dim + 13], ptrs[i][dim + 13], results[i]); | |
| /* FALLTHRU */ | |
| case 13: | |
| SSD_INT8_GENERAL(query[dim + 12], ptrs[i][dim + 12], results[i]); | |
| /* FALLTHRU */ | |
| case 12: | |
| SSD_INT8_GENERAL(query[dim + 11], ptrs[i][dim + 11], results[i]); | |
| /* FALLTHRU */ | |
| case 11: | |
| SSD_INT8_GENERAL(query[dim + 10], ptrs[i][dim + 10], results[i]); | |
| /* FALLTHRU */ | |
| case 10: | |
| SSD_INT8_GENERAL(query[dim + 9], ptrs[i][dim + 9], results[i]); | |
| /* FALLTHRU */ | |
| case 9: | |
| SSD_INT8_GENERAL(query[dim + 8], ptrs[i][dim + 8], results[i]); | |
| /* FALLTHRU */ | |
| case 8: | |
| SSD_INT8_GENERAL(query[dim + 7], ptrs[i][dim + 7], results[i]); | |
| /* FALLTHRU */ | |
| case 7: | |
| SSD_INT8_GENERAL(query[dim + 6], ptrs[i][dim + 6], results[i]); | |
| /* FALLTHRU */ | |
| case 6: | |
| SSD_INT8_GENERAL(query[dim + 5], ptrs[i][dim + 5], results[i]); | |
| /* FALLTHRU */ | |
| case 5: | |
| SSD_INT8_GENERAL(query[dim + 4], ptrs[i][dim + 4], results[i]); | |
| /* FALLTHRU */ | |
| case 4: | |
| SSD_INT8_GENERAL(query[dim + 3], ptrs[i][dim + 3], results[i]); | |
| /* FALLTHRU */ | |
| case 3: | |
| SSD_INT8_GENERAL(query[dim + 2], ptrs[i][dim + 2], results[i]); | |
| /* FALLTHRU */ | |
| case 2: | |
| SSD_INT8_GENERAL(query[dim + 1], ptrs[i][dim + 1], results[i]); | |
| /* FALLTHRU */ | |
| case 1: | |
| SSD_INT8_GENERAL(query[dim + 0], ptrs[i][dim + 0], results[i]); | |
| } | |
| } |
add euclidean one2many