Skip to content

feat: add euclidean one2many#188

Open
richyreachy wants to merge 4 commits intomainfrom
feat/euclidean_one2many
Open

feat: add euclidean one2many#188
richyreachy wants to merge 4 commits intomainfrom
feat/euclidean_one2many

Conversation

@richyreachy
Copy link
Collaborator

add euclidean one2many

@richyreachy richyreachy requested a review from iaojnh March 1, 2026 14:47
@greptile-apps
Copy link

greptile-apps bot commented Mar 1, 2026

Greptile Summary

This PR adds optimized one-to-many Euclidean distance batch computation for fp32, fp16, and int8 data types, mirroring the existing inner product distance batch implementation.

Key changes:

  • Added new files implementing SIMD-optimized Euclidean distance batch operations (AVX2, AVX512F, AVX512FP16)
  • Integrated Euclidean distance batch functions into the distance computation routing logic
  • Refactored inner product functions to add _inner_product_ prefix to avoid naming conflicts
  • Extracted common math utility functions to distance_batch_math.h
  • Fixed incorrect Hamming distance handling in SquaredEuclideanMetric

Critical issues found:

  • Multiple logical errors in remainder element handling that compute incorrect mathematical operations
  • Array indexing bugs that would cause out-of-bounds reads or process wrong elements
  • Pointer arithmetic errors that would read from incorrect memory locations

These bugs would cause incorrect distance computations and potential memory corruption.

Confidence Score: 0/5

  • This PR has critical bugs that will cause incorrect results and should not be merged
  • Found 5 critical logical errors in the core distance computation code that would produce mathematically incorrect results (computing multiplication instead of squared differences, wrong array indices, incorrect pointer arithmetic)
  • Pay close attention to euclidean_distance_batch_impl_int8.h (most critical), euclidean_distance_batch_impl.h, and euclidean_distance_batch_impl_fp16.h - all contain critical bugs in their SIMD implementations

Important Files Changed

Filename Overview
src/ailego/math_batch/euclidean_distance_batch_impl.h Added AVX512F and AVX2 implementations for fp32 squared euclidean distance with critical bugs in remainder handling
src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h Added fp16 euclidean distance implementations with a critical bug in AVX512F mid-size remainder case
src/ailego/math_batch/euclidean_distance_batch_impl_int8.h Added int8 euclidean distance implementation with multiple critical bugs in pointer arithmetic and remainder handling
src/ailego/math_batch/euclidean_distance_batch.h Added template structs for SquaredEuclideanDistanceBatch and EuclideanDistanceBatch with architecture-specific dispatching
src/core/metric/euclidean_metric.cc Added batch_distance implementation for EuclideanMetric and removed incorrect Hamming distance handling from SquaredEuclideanMetric

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[BaseDistance::ComputeBatch] --> B{Distance Type?}
    B -->|InnerProduct| C[InnerProductDistanceBatch]
    B -->|SquaredEuclidean| D[SquaredEuclideanDistanceBatch]
    B -->|Euclidean| E[EuclideanDistanceBatch]
    
    D --> D1{Data Type?}
    D1 -->|float| D2[SquaredEuclideanDistanceBatchImpl float]
    D1 -->|int8_t| D3[SquaredEuclideanDistanceBatchImpl int8_t]
    D1 -->|Float16| D4[SquaredEuclideanDistanceBatchImpl Float16]
    
    D2 --> D2A{CPU Features?}
    D2A -->|AVX512F| D2B[compute_one_to_many_squared_euclidean_avx512f_fp32]
    D2A -->|AVX2| D2C[compute_one_to_many_squared_euclidean_avx2_fp32]
    D2A -->|Fallback| D2D[compute_one_to_many_squared_euclidean_fallback]
    
    D3 --> D3A{CPU Features?}
    D3A -->|AVX2| D3B[compute_one_to_many_squared_euclidean_avx2_int8]
    D3A -->|Fallback| D3C[compute_one_to_many_squared_euclidean_fallback]
    
    D4 --> D4A{CPU Features?}
    D4A -->|AVX512FP16| D4B[compute_one_to_many_squared_euclidean_avx512fp16_fp16]
    D4A -->|AVX512F| D4C[compute_one_to_many_squared_euclidean_avx512f_fp16]
    D4A -->|AVX2| D4D[compute_one_to_many_squared_euclidean_avx2_fp16]
    D4A -->|Fallback| D4E[compute_one_to_many_squared_euclidean_fallback]
    
    E --> E1[Call SquaredEuclideanDistanceBatch]
    E1 --> E2[Apply sqrt to results]
Loading

Last reviewed commit: 9c07ff9

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

13 files reviewed, 5 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +81 to +84
accs[i] = _mm512_mask3_fmadd_ps(
_mm512_mask_loadu_ps(zmm_undefined, mask, query + dim),
_mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim), accs[i],
mask);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Computing query * ptrs[i] + accs[i] instead of (query - ptrs[i])^2 + accs[i]

Suggested change
accs[i] = _mm512_mask3_fmadd_ps(
_mm512_mask_loadu_ps(zmm_undefined, mask, query + dim),
_mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim), accs[i],
mask);
__m512 q_vals = _mm512_mask_loadu_ps(zmm_undefined, mask, query + dim);
__m512 d_vals = _mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim);
__m512 diff = _mm512_sub_ps(q_vals, d_vals);
accs[i] = _mm512_mask_fmadd_ps(diff, mask, diff, accs[i]);

Comment on lines +128 to +148
switch (dimensionality - dim) {
case 7:
SSD_FP32_GENERAL(query[6], ptrs[i][6], results[i]);
/* FALLTHRU */
case 6:
SSD_FP32_GENERAL(query[5], ptrs[i][5], results[i]);
/* FALLTHRU */
case 5:
SSD_FP32_GENERAL(query[4], ptrs[i][4], results[i]);
/* FALLTHRU */
case 4:
SSD_FP32_GENERAL(query[3], ptrs[i][3], results[i]);
/* FALLTHRU */
case 3:
SSD_FP32_GENERAL(query[2], ptrs[i][2], results[i]);
/* FALLTHRU */
case 2:
SSD_FP32_GENERAL(query[1], ptrs[i][1], results[i]);
/* FALLTHRU */
case 1:
SSD_FP32_GENERAL(query[0], ptrs[i][0], results[i]);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Array indices should be offset by dim to process remaining elements correctly

Suggested change
switch (dimensionality - dim) {
case 7:
SSD_FP32_GENERAL(query[6], ptrs[i][6], results[i]);
/* FALLTHRU */
case 6:
SSD_FP32_GENERAL(query[5], ptrs[i][5], results[i]);
/* FALLTHRU */
case 5:
SSD_FP32_GENERAL(query[4], ptrs[i][4], results[i]);
/* FALLTHRU */
case 4:
SSD_FP32_GENERAL(query[3], ptrs[i][3], results[i]);
/* FALLTHRU */
case 3:
SSD_FP32_GENERAL(query[2], ptrs[i][2], results[i]);
/* FALLTHRU */
case 2:
SSD_FP32_GENERAL(query[1], ptrs[i][1], results[i]);
/* FALLTHRU */
case 1:
SSD_FP32_GENERAL(query[0], ptrs[i][0], results[i]);
switch (dimensionality - dim) {
case 7:
SSD_FP32_GENERAL(query[dim + 6], ptrs[i][dim + 6], results[i]);
/* FALLTHRU */
case 6:
SSD_FP32_GENERAL(query[dim + 5], ptrs[i][dim + 5], results[i]);
/* FALLTHRU */
case 5:
SSD_FP32_GENERAL(query[dim + 4], ptrs[i][dim + 4], results[i]);
/* FALLTHRU */
case 4:
SSD_FP32_GENERAL(query[dim + 3], ptrs[i][dim + 3], results[i]);
/* FALLTHRU */
case 3:
SSD_FP32_GENERAL(query[dim + 2], ptrs[i][dim + 2], results[i]);
/* FALLTHRU */
case 2:
SSD_FP32_GENERAL(query[dim + 1], ptrs[i][dim + 1], results[i]);
/* FALLTHRU */
case 1:
SSD_FP32_GENERAL(query[dim + 0], ptrs[i][dim + 0], results[i]);
}

Comment on lines +136 to +140
for (size_t i = 0; i < dp_batch; ++i) {
data_regs[i] = _mm512_cvtph_ps(
_mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim)));
accs[i] = _mm512_fmadd_ps(q, data_regs[i], accs[i]);
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Computing q * data_regs[i] + accs[i] instead of (q - data_regs[i])^2 + accs[i]

Suggested change
for (size_t i = 0; i < dp_batch; ++i) {
data_regs[i] = _mm512_cvtph_ps(
_mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim)));
accs[i] = _mm512_fmadd_ps(q, data_regs[i], accs[i]);
}
for (size_t i = 0; i < dp_batch; ++i) {
data_regs[i] = _mm512_cvtph_ps(
_mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim)));
__m512 diff = _mm512_sub_ps(q, data_regs[i]);
accs[i] = _mm512_fmadd_ps(diff, diff, accs[i]);
}


if (dimensionality >= dim + 16) {
for (size_t i = 0; i < dp_batch; ++i) {
__m128i q = _mm_loadu_si128((const __m128i *)query + dim);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pointer arithmetic on __m128i* reads from query + dim * 16 instead of query + dim

Suggested change
__m128i q = _mm_loadu_si128((const __m128i *)query + dim);
__m128i q = _mm_loadu_si128((const __m128i *)(query + dim));

Comment on lines +91 to +137
for (size_t i = 0; i < dp_batch; ++i) {
switch (dimensionality - dim) {
case 15:
SSD_INT8_GENERAL(query + dim, ptrs[14] + dim, results[i]);
/* FALLTHRU */
case 14:
SSD_INT8_GENERAL(query + dim, ptrs[13 + dim], results[i]);
/* FALLTHRU */
case 13:
SSD_INT8_GENERAL(query + dim, ptrs[12] + dim, results[i]);
/* FALLTHRU */
case 12:
SSD_INT8_GENERAL(query + dim, ptrs[11] + dim, results[i]);
/* FALLTHRU */
case 11:
SSD_INT8_GENERAL(query + dim, ptrs[10 + dim], results[i]);
/* FALLTHRU */
case 10:
SSD_INT8_GENERAL(query + dim, ptrs[9] + dim, results[i]);
/* FALLTHRU */
case 9:
SSD_INT8_GENERAL(query + dim, ptrs[8] + dim, results[i]);
/* FALLTHRU */
case 8:
SSD_INT8_GENERAL(query + dim, ptrs[7] + dim, results[i]);
/* FALLTHRU */
case 7:
SSD_INT8_GENERAL(query + dim, ptrs[6] + dim, results[i]);
/* FALLTHRU */
case 6:
SSD_INT8_GENERAL(query + dim, ptrs[5] + dim, results[i]);
/* FALLTHRU */
case 5:
SSD_INT8_GENERAL(query + dim, ptrs[4] + dim, results[i]);
/* FALLTHRU */
case 4:
SSD_INT8_GENERAL(query + dim, ptrs[3] + dim, results[i]);
/* FALLTHRU */
case 3:
SSD_INT8_GENERAL(query + dim, ptrs[2] + dim, results[i]);
/* FALLTHRU */
case 2:
SSD_INT8_GENERAL(query + dim, ptrs[1] + dim, results[i]);
/* FALLTHRU */
case 1:
SSD_INT8_GENERAL(query + dim, ptrs[0] + dim, results[i]);
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing pointers instead of scalar values to SSD_INT8_GENERAL macro, and using wrong array indices (ptrs[14] instead of ptrs[i])

Suggested change
for (size_t i = 0; i < dp_batch; ++i) {
switch (dimensionality - dim) {
case 15:
SSD_INT8_GENERAL(query + dim, ptrs[14] + dim, results[i]);
/* FALLTHRU */
case 14:
SSD_INT8_GENERAL(query + dim, ptrs[13 + dim], results[i]);
/* FALLTHRU */
case 13:
SSD_INT8_GENERAL(query + dim, ptrs[12] + dim, results[i]);
/* FALLTHRU */
case 12:
SSD_INT8_GENERAL(query + dim, ptrs[11] + dim, results[i]);
/* FALLTHRU */
case 11:
SSD_INT8_GENERAL(query + dim, ptrs[10 + dim], results[i]);
/* FALLTHRU */
case 10:
SSD_INT8_GENERAL(query + dim, ptrs[9] + dim, results[i]);
/* FALLTHRU */
case 9:
SSD_INT8_GENERAL(query + dim, ptrs[8] + dim, results[i]);
/* FALLTHRU */
case 8:
SSD_INT8_GENERAL(query + dim, ptrs[7] + dim, results[i]);
/* FALLTHRU */
case 7:
SSD_INT8_GENERAL(query + dim, ptrs[6] + dim, results[i]);
/* FALLTHRU */
case 6:
SSD_INT8_GENERAL(query + dim, ptrs[5] + dim, results[i]);
/* FALLTHRU */
case 5:
SSD_INT8_GENERAL(query + dim, ptrs[4] + dim, results[i]);
/* FALLTHRU */
case 4:
SSD_INT8_GENERAL(query + dim, ptrs[3] + dim, results[i]);
/* FALLTHRU */
case 3:
SSD_INT8_GENERAL(query + dim, ptrs[2] + dim, results[i]);
/* FALLTHRU */
case 2:
SSD_INT8_GENERAL(query + dim, ptrs[1] + dim, results[i]);
/* FALLTHRU */
case 1:
SSD_INT8_GENERAL(query + dim, ptrs[0] + dim, results[i]);
}
for (size_t i = 0; i < dp_batch; ++i) {
switch (dimensionality - dim) {
case 15:
SSD_INT8_GENERAL(query[dim + 14], ptrs[i][dim + 14], results[i]);
/* FALLTHRU */
case 14:
SSD_INT8_GENERAL(query[dim + 13], ptrs[i][dim + 13], results[i]);
/* FALLTHRU */
case 13:
SSD_INT8_GENERAL(query[dim + 12], ptrs[i][dim + 12], results[i]);
/* FALLTHRU */
case 12:
SSD_INT8_GENERAL(query[dim + 11], ptrs[i][dim + 11], results[i]);
/* FALLTHRU */
case 11:
SSD_INT8_GENERAL(query[dim + 10], ptrs[i][dim + 10], results[i]);
/* FALLTHRU */
case 10:
SSD_INT8_GENERAL(query[dim + 9], ptrs[i][dim + 9], results[i]);
/* FALLTHRU */
case 9:
SSD_INT8_GENERAL(query[dim + 8], ptrs[i][dim + 8], results[i]);
/* FALLTHRU */
case 8:
SSD_INT8_GENERAL(query[dim + 7], ptrs[i][dim + 7], results[i]);
/* FALLTHRU */
case 7:
SSD_INT8_GENERAL(query[dim + 6], ptrs[i][dim + 6], results[i]);
/* FALLTHRU */
case 6:
SSD_INT8_GENERAL(query[dim + 5], ptrs[i][dim + 5], results[i]);
/* FALLTHRU */
case 5:
SSD_INT8_GENERAL(query[dim + 4], ptrs[i][dim + 4], results[i]);
/* FALLTHRU */
case 4:
SSD_INT8_GENERAL(query[dim + 3], ptrs[i][dim + 3], results[i]);
/* FALLTHRU */
case 3:
SSD_INT8_GENERAL(query[dim + 2], ptrs[i][dim + 2], results[i]);
/* FALLTHRU */
case 2:
SSD_INT8_GENERAL(query[dim + 1], ptrs[i][dim + 1], results[i]);
/* FALLTHRU */
case 1:
SSD_INT8_GENERAL(query[dim + 0], ptrs[i][dim + 0], results[i]);
}
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant