diff --git a/Cargo.lock b/Cargo.lock index fd817b46a..3cef3a3f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -658,6 +658,7 @@ dependencies = [ "diskann-benchmark-runner", "diskann-disk", "diskann-label-filter", + "diskann-pipnn", "diskann-providers", "diskann-quantization", "diskann-tools", @@ -740,6 +741,7 @@ dependencies = [ "criterion", "diskann", "diskann-linalg", + "diskann-pipnn", "diskann-platform", "diskann-providers", "diskann-quantization", @@ -756,6 +758,7 @@ dependencies = [ "rayon", "rstest", "serde", + "serde_json", "tempfile", "thiserror 2.0.17", "tokio", @@ -816,6 +819,28 @@ dependencies = [ "thiserror 2.0.17", ] +[[package]] +name = "diskann-pipnn" +version = "0.49.1" +dependencies = [ + "bytemuck", + "criterion", + "diskann", + "diskann-linalg", + "diskann-quantization", + "diskann-utils", + "diskann-vector", + "half", + "num-traits", + "parking_lot", + "rand 0.9.2", + "rand_distr", + "rayon", + "serde", + "thiserror 2.0.17", + "tracing", +] + [[package]] name = "diskann-platform" version = "0.49.1" diff --git a/Cargo.toml b/Cargo.toml index 91cb564af..e54f4bcf1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,8 @@ members = [ "diskann-benchmark", "diskann-tools", "vectorset", + # PiPNN + "diskann-pipnn", ] default-members = [ @@ -60,6 +62,8 @@ diskann = { path = "diskann", version = "0.49.1" } diskann-providers = { path = "diskann-providers", default-features = false, version = "0.49.1" } diskann-disk = { path = "diskann-disk", version = "0.49.1" } diskann-label-filter = { path = "diskann-label-filter", version = "0.49.1" } +# PiPNN +diskann-pipnn = { path = "diskann-pipnn", version = "0.49.1" } # Infra diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.49.1" } diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.49.1" } diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index bebaf4b8e..0af385c45 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ 
-30,7 +30,8 @@ diskann-vector.workspace = true diskann-wide.workspace = true diskann-label-filter.workspace = true diskann-tools = { workspace = true } -diskann-disk = { workspace = true, optional = true } +diskann-disk = { workspace = true, optional = true, features = ["pipnn"] } +diskann-pipnn = { workspace = true } cfg-if.workspace = true diskann-benchmark-runner = { workspace = true } opentelemetry = { workspace = true, optional = true } diff --git a/diskann-benchmark/src/backend/disk_index/build.rs b/diskann-benchmark/src/backend/disk_index/build.rs index b6ebf3b83..8cff002c4 100644 --- a/diskann-benchmark/src/backend/disk_index/build.rs +++ b/diskann-benchmark/src/backend/disk_index/build.rs @@ -91,10 +91,11 @@ where let metadata = load_metadata_from_file(storage_provider, &data_path)?; - let build_parameters = DiskIndexBuildParameters::new( + let build_parameters = DiskIndexBuildParameters::new_with_algorithm( MemoryBudget::try_from_gb(params.build_ram_limit_gb)?, params.quantization_type, NumPQChunks::new_with(params.num_pq_chunks.get(), metadata.ndims())?, + params.build_algorithm.clone(), ); let index_configuration = IndexConfiguration::new( diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index bf843d72f..a757130cf 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -10,6 +10,8 @@ use diskann_benchmark_runner::{ files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, }; #[cfg(feature = "disk-index")] +use diskann_disk::BuildAlgorithm; +#[cfg(feature = "disk-index")] use diskann_disk::QuantizationType; use diskann_providers::storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file}; use serde::{Deserialize, Serialize}; @@ -68,6 +70,10 @@ pub(crate) struct DiskIndexBuild { pub(crate) num_pq_chunks: NonZeroUsize, #[cfg(feature = "disk-index")] pub(crate) quantization_type: QuantizationType, + /// Build algorithm: "Vamana" (default) or 
"PiPNN" with config params. + #[cfg(feature = "disk-index")] + #[serde(default)] + pub(crate) build_algorithm: BuildAlgorithm, pub(crate) save_path: String, } @@ -257,6 +263,8 @@ impl Example for DiskIndexOperation { num_pq_chunks: NonZeroUsize::new(16).unwrap(), #[cfg(feature = "disk-index")] quantization_type: QuantizationType::PQ { num_chunks: 16 }, + #[cfg(feature = "disk-index")] + build_algorithm: BuildAlgorithm::default(), save_path: "sample_index_l50_r32".to_string(), }; @@ -351,6 +359,8 @@ impl DiskIndexBuild { } } } + #[cfg(feature = "disk-index")] + write_field!(f, "Build Algorithm", self.build_algorithm)?; write_field!(f, "Save Path", self.save_path)?; Ok(()) } diff --git a/diskann-disk/Cargo.toml b/diskann-disk/Cargo.toml index c68d65769..e81ab0a7b 100644 --- a/diskann-disk/Cargo.toml +++ b/diskann-disk/Cargo.toml @@ -45,6 +45,7 @@ vfs = { workspace = true } # Optional dependencies opentelemetry = { workspace = true, optional = true } +diskann-pipnn = { workspace = true, optional = true } [target.'cfg(target_os = "linux")'.dependencies] io-uring = "0.6.4" @@ -54,6 +55,7 @@ libc = "0.2.148" rstest.workspace = true tempfile.workspace = true vfs.workspace = true +serde_json.workspace = true diskann-providers = { workspace = true, default-features = false, features = [ "testing", "virtual_storage", @@ -66,6 +68,7 @@ diskann = { workspace = true } [features] default = [] perf_test = ["dep:opentelemetry"] +pipnn = ["dep:diskann-pipnn"] virtual_storage = ["diskann-providers/virtual_storage"] experimental_diversity_search = [ "diskann/experimental_diversity_search", @@ -82,4 +85,4 @@ harness = false # Some 'cfg's in the source tree will be flagged by `cargo clippy -j 2 --workspace --no-deps --all-targets -- -D warnings` [lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ['cfg(coverage)'] } +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(coverage)', 'cfg(feature, values("pipnn"))'] } diff --git a/diskann-disk/src/build/builder/build.rs 
b/diskann-disk/src/build/builder/build.rs index 8eabad038..2dbd048b5 100644 --- a/diskann-disk/src/build/builder/build.rs +++ b/diskann-disk/src/build/builder/build.rs @@ -53,6 +53,7 @@ use crate::{ }, continuation::{process_while_resource_is_available_async, ChunkingConfig}, }, + configuration::build_algorithm::BuildAlgorithm, }, storage::{ quant::{GeneratorContext, PQGeneration, PQGenerationContext, QuantDataGenerator}, @@ -235,15 +236,35 @@ where self.index_configuration.num_threads ); + let t_pq = std::time::Instant::now(); self.generate_compressed_data(&pool).await?; logger.log_checkpoint(DiskIndexBuildCheckpoint::PqConstruction); + let pq_secs = t_pq.elapsed().as_secs_f64(); + let t_index = std::time::Instant::now(); self.build_inmem_index(&pool).await?; logger.log_checkpoint(DiskIndexBuildCheckpoint::InmemIndexBuild); + let index_secs = t_index.elapsed().as_secs_f64(); + + // Return freed memory (f32 data, graph, PiPNN internals) to the OS + // before disk layout starts. Without this, ~1.7 GB of freed-but-retained + // memory inflates peak RSS during the disk layout phase. + #[cfg(target_os = "linux")] + unsafe { + extern "C" { fn malloc_trim(pad: usize) -> i32; } + malloc_trim(0); + } // Use physical file to pass the memory index to the disk writer + let t_layout = std::time::Instant::now(); self.create_disk_layout()?; logger.log_checkpoint(DiskIndexBuildCheckpoint::DiskLayout); + let layout_secs = t_layout.elapsed().as_secs_f64(); + + println!("Disk Index Build Phases"); + println!(" PQ compression: {:.3}s", pq_secs); + println!(" Graph build: {:.3}s", index_secs); + println!(" Disk layout: {:.3}s", layout_secs); Ok(()) } @@ -313,6 +334,22 @@ where } async fn build_inmem_index(&mut self, pool: &RayonThreadPool) -> ANNResult<()> { + // Check for PiPNN algorithm + #[cfg(feature = "pipnn")] + if let BuildAlgorithm::PiPNN { .. 
} = self.disk_build_param.build_algorithm() { + return self.build_pipnn_index().await; + } + + #[cfg(not(feature = "pipnn"))] + if !matches!( + self.disk_build_param.build_algorithm(), + BuildAlgorithm::Vamana + ) { + return Err(ANNError::log_index_error( + "PiPNN build algorithm requires the 'pipnn' feature to be enabled", + )); + } + match determine_build_strategy::( &self.index_configuration, self.disk_build_param.build_memory_limit().in_bytes() as f64, @@ -326,6 +363,91 @@ where } } + #[cfg(feature = "pipnn")] + async fn build_pipnn_index(&mut self) -> ANNResult<()> { + use diskann_pipnn::builder; + + let config = self.disk_build_param.build_algorithm() + .to_pipnn_config( + self.index_configuration.config.pruned_degree().get(), + self.index_configuration.dist_metric, + self.index_configuration.config.alpha(), + ) + .ok_or_else(|| ANNError::log_index_error( + "build_pipnn_index called but build algorithm is not PiPNN" + ))?; + + config.validate().map_err(|e| { + ANNError::log_index_error(format!("PiPNN config error: {}", e)) + })?; + + info!( + "Building PiPNN index: max_degree={}", + config.max_degree + ); + + let data_path = self.index_writer.get_dataset_file(); + + // Build the PiPNN graph, using pre-trained SQ if available. + let graph = match &self.build_quantizer { + BuildQuantizer::Scalar1Bit(with_bits) => { + // SQ path needs f32 data for quantize_1bit. + let (npoints, ndims, data) = load_data_as_f32::( + &data_path, + self.storage_provider, + )?; + // Use the DiskANN-trained ScalarQuantizer for 1-bit quantization. + // This ensures identical quantization between Vamana and PiPNN builds. 
+ let sq = with_bits.quantizer(); + let scale = sq.scale(); + let inverse_scale = if scale == 0.0 { 1.0 } else { 1.0 / scale }; + let sq_params = builder::SQParams { + shift: sq.shift().to_vec(), + inverse_scale, + }; + info!("Using pre-trained SQ quantizer for PiPNN 1-bit build"); + builder::build_with_sq(&data, npoints, ndims, &config, &sq_params) + .map_err(|e| ANNError::log_index_error(format!("PiPNN build failed: {}", e)))? + } + _ => { + // Full precision or PQ build quantization — load data in native type + // and use build_typed to avoid upfront f32 conversion (saves ~793 MB + // peak RSS for f16 data). + let (npoints, ndims, data) = load_data_typed::( + &data_path, + self.storage_provider, + )?; + builder::build_typed(&data, npoints, ndims, &config) + .map_err(|e| ANNError::log_index_error(format!("PiPNN build failed: {}", e)))? + } + }; + + let save_path = self.index_writer.get_mem_index_file(); + graph.save_graph(std::path::Path::new(&save_path)) + .map_err(|e| ANNError::log_index_error(format!("PiPNN graph save failed: {}", e)))?; + + info!( + "PiPNN build complete: avg_degree={:.1}, max_degree={}, isolated={}, total={:.3}s", + graph.avg_degree(), + graph.max_degree(), + graph.num_isolated(), + graph.build_stats.total_secs + ); + // Print timing breakdown to stdout (tracing goes to OpenTelemetry spans, + // not stdout, so use print! for user-visible output like Vamana does). + print!("{}", graph.build_stats); + + // Mark checkpoint stages as complete so the checkpoint system is consistent. 
+ self.checkpoint_record_manager.execute_stage( + WorkStage::InMemIndexBuild, + WorkStage::WriteDiskLayout, + || Ok(()), + || Ok(()), + )?; + + Ok(()) + } + async fn build_merged_vamana_index(&mut self, pool: &RayonThreadPool) -> ANNResult<()> { let mut logger = PerfLogger::new_disk_index_build_logger(); let mut workflow = MergedVamanaIndexWorkflow::new(self, pool); @@ -480,6 +602,51 @@ where } } +#[cfg(feature = "pipnn")] +fn load_data_as_f32( + data_path: &str, + storage_provider: &SP, +) -> ANNResult<(usize, usize, Vec)> +where + T: VectorRepr, + SP: StorageReadProvider, +{ + let matrix = read_bin::(&mut storage_provider.open_reader(data_path)?)?; + let npoints = matrix.nrows(); + let ndims = matrix.ncols(); + + // Convert to f32 + let mut f32_data = vec![0.0f32; npoints * ndims]; + for i in 0..npoints { + let src = matrix.row(i); + let dst = &mut f32_data[i * ndims..(i + 1) * ndims]; + T::as_f32_into(src, dst) + .map_err(|e| ANNError::log_index_error(format!("Data conversion error: {}", e)))?; + } + + Ok((npoints, ndims, f32_data)) +} + +/// Load data in its native type T without converting to f32. +/// This avoids doubling memory for f16 data by keeping it as f16 in memory +/// and converting to f32 on-the-fly at each access point inside PiPNN. 
+#[cfg(feature = "pipnn")] +fn load_data_typed( + data_path: &str, + storage_provider: &SP, +) -> ANNResult<(usize, usize, Vec)> +where + T: VectorRepr, + SP: StorageReadProvider, +{ + let matrix = read_bin::(&mut storage_provider.open_reader(data_path)?)?; + let npoints = matrix.nrows(); + let ndims = matrix.ncols(); + let data: Vec = matrix.into_inner().into_vec(); + + Ok((npoints, ndims, data)) +} + #[allow(clippy::too_many_arguments)] async fn build_inmem_index( config: IndexConfiguration, diff --git a/diskann-disk/src/build/configuration/build_algorithm.rs b/diskann-disk/src/build/configuration/build_algorithm.rs new file mode 100644 index 000000000..158849b0d --- /dev/null +++ b/diskann-disk/src/build/configuration/build_algorithm.rs @@ -0,0 +1,283 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Build algorithm selection for graph index construction. + +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// Selects the graph construction algorithm for index building. +/// +/// - `Vamana`: The default incremental insert + prune algorithm. +/// - `PiPNN`: Partition-based batch builder (arXiv:2602.21247). +/// Significantly faster build times at comparable graph quality. +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +#[serde(tag = "algorithm")] +pub enum BuildAlgorithm { + /// Default Vamana graph construction. + #[default] + Vamana, + + /// PiPNN: Pick-in-Partitions Nearest Neighbors. + PiPNN { + /// Maximum leaf partition size. + #[serde(default = "default_c_max")] + c_max: usize, + /// Minimum cluster size before merging. + #[serde(default = "default_c_min")] + c_min: usize, + /// Sampling fraction for RBC leaders. + #[serde(default = "default_p_samp")] + p_samp: f64, + /// Fanout at each partitioning level. + #[serde(default = "default_fanout")] + fanout: Vec, + /// k-NN within each leaf. 
+ #[serde(default = "default_leaf_k")] + leaf_k: usize, + /// Number of independent partitioning passes. + #[serde(default = "default_replicas")] + replicas: usize, + /// Maximum reservoir size per node in HashPrune. + #[serde(default = "default_l_max")] + l_max: usize, + /// Number of LSH hyperplanes for HashPrune. + #[serde(default = "default_num_hash_planes")] + num_hash_planes: usize, + /// Whether to apply a final RobustPrune pass. + #[serde(default)] + final_prune: bool, + }, +} + +impl fmt::Display for BuildAlgorithm { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BuildAlgorithm::Vamana => write!(f, "Vamana"), + BuildAlgorithm::PiPNN { + c_max, + leaf_k, + replicas, + .. + } => { + write!( + f, + "PiPNN(c_max={}, leaf_k={}, replicas={})", + c_max, leaf_k, replicas + ) + } + } + } +} + +impl BuildAlgorithm { + /// Convert PiPNN build parameters to a PiPNNConfig. + /// `max_degree`, `metric`, and `alpha` come from the DiskANN index configuration. + #[cfg(feature = "pipnn")] + pub fn to_pipnn_config( + &self, + max_degree: usize, + metric: diskann_vector::distance::Metric, + alpha: f32, + ) -> Option { + match self { + BuildAlgorithm::PiPNN { + c_max, c_min, p_samp, fanout, leaf_k, replicas, + l_max, num_hash_planes, final_prune, + } => Some(diskann_pipnn::PiPNNConfig { + c_max: *c_max, + c_min: *c_min, + p_samp: *p_samp, + fanout: fanout.clone(), + k: *leaf_k, + max_degree, + replicas: *replicas, + l_max: *l_max, + num_hash_planes: *num_hash_planes, + metric, + final_prune: *final_prune, + alpha, + }), + _ => None, + } + } +} + +fn default_c_max() -> usize { + 1024 +} +fn default_c_min() -> usize { + 256 +} +fn default_p_samp() -> f64 { + 0.005 +} +fn default_fanout() -> Vec { + vec![10, 3] +} +fn default_leaf_k() -> usize { + 3 +} +fn default_replicas() -> usize { + 1 +} +fn default_l_max() -> usize { + 128 +} +fn default_num_hash_planes() -> usize { + 12 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn 
test_build_algorithm_default_is_vamana() { + let algo = BuildAlgorithm::default(); + assert_eq!(algo, BuildAlgorithm::Vamana, "default BuildAlgorithm should be Vamana"); + } + + #[test] + fn test_build_algorithm_display_vamana() { + let algo = BuildAlgorithm::Vamana; + let display = format!("{}", algo); + assert_eq!(display, "Vamana", "Vamana display should be 'Vamana'"); + } + + #[test] + fn test_build_algorithm_display_pipnn() { + let algo = BuildAlgorithm::PiPNN { + c_max: 2048, + c_min: 512, + p_samp: 0.1, + fanout: vec![5, 3], + leaf_k: 4, + replicas: 2, + l_max: 256, + num_hash_planes: 12, + final_prune: false, + }; + let display = format!("{}", algo); + assert_eq!( + display, + "PiPNN(c_max=2048, leaf_k=4, replicas=2)", + "PiPNN display should include c_max, leaf_k, and replicas" + ); + } + + #[test] + fn test_build_algorithm_serde_roundtrip_vamana() { + let algo = BuildAlgorithm::Vamana; + let json = serde_json::to_string(&algo).expect("serialize Vamana should succeed"); + let deserialized: BuildAlgorithm = + serde_json::from_str(&json).expect("deserialize Vamana should succeed"); + assert_eq!(algo, deserialized, "Vamana should roundtrip through serde_json"); + } + + #[test] + fn test_build_algorithm_serde_roundtrip_pipnn() { + let algo = BuildAlgorithm::PiPNN { + c_max: 2048, + c_min: 512, + p_samp: 0.1, + fanout: vec![5, 3], + leaf_k: 4, + replicas: 2, + l_max: 256, + num_hash_planes: 8, + final_prune: true, + }; + let json = serde_json::to_string(&algo).expect("serialize PiPNN should succeed"); + let deserialized: BuildAlgorithm = + serde_json::from_str(&json).expect("deserialize PiPNN should succeed"); + assert_eq!(algo, deserialized, "PiPNN with all fields should roundtrip through serde_json"); + } + + #[test] + fn test_build_algorithm_serde_pipnn_defaults() { + // Deserialize PiPNN with only the algorithm tag -- all fields should use defaults. 
+ let json = r#"{"algorithm":"PiPNN"}"#; + let deserialized: BuildAlgorithm = + serde_json::from_str(json).expect("PiPNN with defaults should deserialize"); + + let expected = BuildAlgorithm::PiPNN { + c_max: default_c_max(), + c_min: default_c_min(), + p_samp: default_p_samp(), + fanout: default_fanout(), + leaf_k: default_leaf_k(), + replicas: default_replicas(), + l_max: default_l_max(), + num_hash_planes: default_num_hash_planes(), + final_prune: false, + }; + assert_eq!( + deserialized, expected, + "deserializing PiPNN with missing fields should use default values" + ); + } + + #[test] + fn test_build_algorithm_partial_eq() { + let v1 = BuildAlgorithm::Vamana; + let v2 = BuildAlgorithm::Vamana; + assert_eq!(v1, v2, "two Vamana instances should be equal"); + + let p1 = BuildAlgorithm::PiPNN { + c_max: 1024, + c_min: 256, + p_samp: 0.05, + fanout: vec![10, 3], + leaf_k: 3, + replicas: 1, + l_max: 128, + num_hash_planes: 12, + final_prune: false, + }; + let p2 = p1.clone(); + assert_eq!(p1, p2, "cloned PiPNN should equal original"); + + assert_ne!(v1, p1, "Vamana and PiPNN should not be equal"); + + let p3 = BuildAlgorithm::PiPNN { + c_max: 2048, // different + c_min: 256, + p_samp: 0.05, + fanout: vec![10, 3], + leaf_k: 3, + replicas: 1, + l_max: 128, + num_hash_planes: 12, + final_prune: false, + }; + assert_ne!(p1, p3, "PiPNN with different c_max should not be equal"); + } + + #[test] + #[cfg(feature = "pipnn")] + fn test_to_pipnn_config_vamana_returns_none() { + let algo = BuildAlgorithm::Vamana; + assert!(algo.to_pipnn_config(64, diskann_vector::distance::Metric::L2, 1.2).is_none()); + } + + #[test] + #[cfg(feature = "pipnn")] + fn test_to_pipnn_config_pipnn_returns_some() { + let algo = BuildAlgorithm::PiPNN { + c_max: 512, c_min: 128, p_samp: 0.01, fanout: vec![8], + leaf_k: 5, replicas: 1, l_max: 128, num_hash_planes: 12, + final_prune: true, + }; + let config = algo.to_pipnn_config(64, diskann_vector::distance::Metric::L2, 1.2); + 
assert!(config.is_some()); + let config = config.unwrap(); + assert_eq!(config.c_max, 512); + assert_eq!(config.k, 5); // leaf_k maps to k + assert_eq!(config.max_degree, 64); + assert_eq!(config.alpha, 1.2); + } +} diff --git a/diskann-disk/src/build/configuration/disk_index_build_parameter.rs b/diskann-disk/src/build/configuration/disk_index_build_parameter.rs index 07dfe3c51..c068bf794 100644 --- a/diskann-disk/src/build/configuration/disk_index_build_parameter.rs +++ b/diskann-disk/src/build/configuration/disk_index_build_parameter.rs @@ -10,6 +10,7 @@ use std::num::NonZeroUsize; use diskann::ANNError; use thiserror::Error; +use super::build_algorithm::BuildAlgorithm; use super::QuantizationType; /// GB to bytes ratio. @@ -103,7 +104,7 @@ impl NumPQChunks { } /// Parameters specific for disk index construction. -#[derive(Clone, Copy, PartialEq, Debug)] +#[derive(Clone, PartialEq, Debug)] pub struct DiskIndexBuildParameters { /// Limit on the memory allowed for building the index. build_memory_limit: MemoryBudget, @@ -113,10 +114,14 @@ pub struct DiskIndexBuildParameters { /// QuantizationType used to instantiate quantized DataProvider for DiskANN Index during build. build_quantization: QuantizationType, + + /// Which graph construction algorithm to use (Vamana or PiPNN). + build_algorithm: BuildAlgorithm, } impl DiskIndexBuildParameters { /// Create new build parameters from already validated components. + /// Uses the default Vamana build algorithm. pub fn new( build_memory_limit: MemoryBudget, build_quantization: QuantizationType, @@ -126,6 +131,22 @@ impl DiskIndexBuildParameters { build_memory_limit, search_pq_chunks, build_quantization, + build_algorithm: BuildAlgorithm::default(), + } + } + + /// Create new build parameters with a specific build algorithm. 
+ pub fn new_with_algorithm( + build_memory_limit: MemoryBudget, + build_quantization: QuantizationType, + search_pq_chunks: NumPQChunks, + build_algorithm: BuildAlgorithm, + ) -> Self { + Self { + build_memory_limit, + search_pq_chunks, + build_quantization, + build_algorithm, } } @@ -143,6 +164,11 @@ impl DiskIndexBuildParameters { pub fn search_pq_chunks(&self) -> NumPQChunks { self.search_pq_chunks } + + /// Get the build algorithm to use for graph construction. + pub fn build_algorithm(&self) -> &BuildAlgorithm { + &self.build_algorithm + } } #[cfg(test)] diff --git a/diskann-disk/src/build/configuration/mod.rs b/diskann-disk/src/build/configuration/mod.rs index 25453abd0..7a0e6816c 100644 --- a/diskann-disk/src/build/configuration/mod.rs +++ b/diskann-disk/src/build/configuration/mod.rs @@ -2,6 +2,9 @@ * Copyright (c) Microsoft Corporation. * Licensed under the MIT license. */ +pub mod build_algorithm; +pub use build_algorithm::BuildAlgorithm; + pub mod disk_index_build_parameter; pub use disk_index_build_parameter::{DiskIndexBuildParameters, MemoryBudget, NumPQChunks}; diff --git a/diskann-disk/src/build/mod.rs b/diskann-disk/src/build/mod.rs index c04e1a9c8..2e5df9ae9 100644 --- a/diskann-disk/src/build/mod.rs +++ b/diskann-disk/src/build/mod.rs @@ -14,5 +14,6 @@ pub mod configuration; // Re-export key types for convenience pub use configuration::{ - disk_index_build_parameter, filter_parameter, DiskIndexBuildParameters, QuantizationType, + disk_index_build_parameter, filter_parameter, BuildAlgorithm, DiskIndexBuildParameters, + QuantizationType, }; diff --git a/diskann-disk/src/lib.rs b/diskann-disk/src/lib.rs index 0da6938a3..b936ec00d 100644 --- a/diskann-disk/src/lib.rs +++ b/diskann-disk/src/lib.rs @@ -10,7 +10,8 @@ pub mod build; pub use build::{ - disk_index_build_parameter, filter_parameter, DiskIndexBuildParameters, QuantizationType, + disk_index_build_parameter, filter_parameter, BuildAlgorithm, DiskIndexBuildParameters, + QuantizationType, }; pub 
mod data_model; diff --git a/diskann-pipnn/Cargo.toml b/diskann-pipnn/Cargo.toml new file mode 100644 index 000000000..2197ce88a --- /dev/null +++ b/diskann-pipnn/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "diskann-pipnn" +version.workspace = true +description = "PiPNN (Pick-in-Partitions Nearest Neighbors) index builder for DiskANN" +authors.workspace = true +documentation.workspace = true +license.workspace = true +edition = "2021" + +[dependencies] +diskann = { workspace = true } +diskann-vector = { workspace = true } +diskann-utils = { workspace = true, default-features = false, features = ["rayon"] } +rayon = { workspace = true } +rand = { workspace = true } +rand_distr = { workspace = true } +bytemuck = { workspace = true, features = ["must_cast"] } +num-traits = { workspace = true } +diskann-linalg = { workspace = true } +half = { workspace = true } +diskann-quantization = { workspace = true } +serde = { workspace = true, features = ["derive"] } +thiserror = { workspace = true } +tracing = { workspace = true } +parking_lot = "0.12" + +[dev-dependencies] +criterion = { workspace = true } +rand = { workspace = true } + +[lints] +workspace = true + +[[bench]] +name = "pipnn_bench" +harness = false diff --git a/diskann-pipnn/README.md b/diskann-pipnn/README.md new file mode 100644 index 000000000..712e4477d --- /dev/null +++ b/diskann-pipnn/README.md @@ -0,0 +1,145 @@ +# PiPNN: Pick-in-Partitions Nearest Neighbors for DiskANN + +A fast graph index builder for [DiskANN](https://github.com/microsoft/DiskANN) based on the [PiPNN algorithm](https://arxiv.org/abs/2602.21247) (Rubel et al., 2026). + +PiPNN replaces Vamana's incremental beam-search insertion with a partition-then-build approach: + +1. **Partition** the dataset into overlapping clusters via Randomized Ball Carving (RBC) +2. **Build** local k-NN graphs within each cluster using GEMM-based all-pairs distance +3. **Merge** edges from overlapping clusters using HashPrune (LSH-based online pruning) +4. 
**Prune** (optional) with RobustPrune for diversity + +The output is a standard DiskANN graph file that can be loaded and searched by the existing DiskANN infrastructure. + +## Results + +### SIFT-1M (128d, L2, R=64, 1M vectors) + +| Builder | Build Time | Speedup | Recall@10 (L=100) | +|---------|-----------|---------|-------------------| +| DiskANN Vamana | 81.7s | 1.0x | 0.997 | +| **PiPNN** | **8.0s** | **10.2x** | **0.985** | + +### Enron (384d, fp16, cosine_normalized, R=59, 1.09M vectors) + +| Builder | Build Time | Speedup | Recall@1000 (L=2000) | +|---------|-----------|---------|---------------------| +| DiskANN Vamana | 78.1s | 1.0x | 0.950 | +| **PiPNN** | **15.2s** | **5.1x** | **0.949** | + +Speedup scales with dataset size and is highest on lower-dimensional data where GEMM throughput dominates. Hardware: AMD EPYC 7763, 16 cores. + +### Prerequisites + +```bash +sudo apt install libopenblas-dev # Required for GEMM acceleration +``` + +## Build + +```bash +cargo build --release -p diskann-pipnn +``` + +For best performance on your CPU: + +```bash +RUSTFLAGS="-C target-cpu=native" cargo build --release -p diskann-pipnn +``` + +## Usage + +### Build a PiPNN index and save as DiskANN graph + +```bash +./target/release/pipnn-bench \ + --data <data.fbin> \ + --max-degree 64 \ + --c-max 2048 --c-min 1024 \ + --leaf-k 4 --fanout "8" \ + --replicas 1 --final-prune \ + --save-path <graph-path> +``` + +The output graph is written in DiskANN's canonical format at `<graph-path>`. Copy or symlink your data file to `<graph-path>.data` for the DiskANN benchmark loader. 
+ +### Key parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `--max-degree` | 64 | Maximum graph degree (R) | +| `--c-max` | 1024 | Maximum leaf partition size | +| `--c-min` | c_max/4 | Minimum cluster size before merging | +| `--leaf-k` | 3 | k-NN within each leaf | +| `--fanout` | "10,3" | Overlap factor per partition level (comma-separated) | +| `--replicas` | 1 | Independent partitioning passes | +| `--l-max` | 128 | HashPrune reservoir size per node | +| `--p-samp` | 0.05 | Leader sampling fraction | +| `--final-prune` | false | Apply RobustPrune after HashPrune | +| `--fp16` | false | Read input as fp16 (auto-converts to f32) | +| `--cosine` | false | Use cosine distance (for normalized vectors) | +| `--save-path` | none | Save graph in DiskANN format | + +### Recommended configurations + +**Low-dimensional (d <= 128):** +```bash +--c-max 2048 --c-min 1024 --leaf-k 4 --fanout "8" --p-samp 0.01 --final-prune +``` + +**High-dimensional (d >= 256):** +```bash +--c-max 2048 --c-min 1024 --leaf-k 5 --fanout "8" --p-samp 0.01 --final-prune +``` + +### Search with DiskANN benchmark + +After building, search the graph using the standard DiskANN benchmark: + +```bash +# Symlink your data file +ln -s <data-file> <graph-path>.data + +# Create a search config (JSON) +cat > search.json << 'EOF' +{ + "search_directories": ["."], + "jobs": [{ + "type": "async-index-build", + "content": { + "source": { + "index-source": "Load", + "data_type": "float32", + "distance": "squared_l2", + "load_path": "<graph-path>", + "max_degree": 64 + }, + "search_phase": { + "search-type": "topk", + "queries": "<queries.fbin>", + "groundtruth": "<groundtruth.bin>", + "reps": 1, + "num_threads": [1], + "runs": [{"search_n": 10, "search_l": [100, 200], "recall_k": 10}] + } + } + }] +} +EOF + +cargo run --release -p diskann-benchmark -- run --input-file search.json --output-file results.json +``` + +## Architecture + +``` +diskann-pipnn/ + src/ + lib.rs - Config and module structure + partition.rs - Randomized Ball 
Carving with fused GEMM + assignment + leaf_build.rs - GEMM-based all-pairs distance + bi-directed k-NN + hash_prune.rs - LSH-based online pruning with per-point reservoirs + builder.rs - Main PiPNN orchestrator + bin/ + pipnn_bench.rs - CLI benchmark and index writer +``` diff --git a/diskann-pipnn/benches/pipnn_bench.rs b/diskann-pipnn/benches/pipnn_bench.rs new file mode 100644 index 000000000..3d63733f0 --- /dev/null +++ b/diskann-pipnn/benches/pipnn_bench.rs @@ -0,0 +1,305 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Criterion benchmarks for PiPNN hot-path components. +//! +//! Run with: cargo bench -p diskann-pipnn + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use rand::{Rng, SeedableRng}; + +use diskann_pipnn::gemm; +use diskann_pipnn::hash_prune::HashPrune; +use diskann_pipnn::leaf_build; +use diskann_pipnn::partition::{self, PartitionConfig}; +use diskann_pipnn::quantize; + +/// Generate random f32 data for benchmarking. 
+fn random_data(npoints: usize, ndims: usize, seed: u64) -> Vec { + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + (0..npoints * ndims) + .map(|_| rng.random_range(-1.0f32..1.0f32)) + .collect() +} + +// ================== +// GEMM benchmarks +// ================== + +fn bench_sgemm_aat(c: &mut Criterion) { + let mut group = c.benchmark_group("gemm/sgemm_aat"); + + for &(m, k) in &[(256, 128), (512, 128), (1024, 128), (512, 384)] { + let a = random_data(m, k, 42); + let mut result = vec![0.0f32; m * m]; + + group.throughput(Throughput::Elements((m * m) as u64)); + group.bench_with_input( + BenchmarkId::new("m_x_k", format!("{}x{}", m, k)), + &(m, k), + |b, &(m, k)| { + b.iter(|| { + gemm::sgemm_aat(&a, m, k, &mut result); + }); + }, + ); + } + group.finish(); +} + +fn bench_sgemm_abt(c: &mut Criterion) { + let mut group = c.benchmark_group("gemm/sgemm_abt"); + + for &(m, n, k) in &[(1000, 100, 128), (1000, 100, 384), (10000, 100, 128)] { + let a = random_data(m, k, 42); + let b = random_data(n, k, 99); + let mut result = vec![0.0f32; m * n]; + + group.throughput(Throughput::Elements((m * n) as u64)); + group.bench_with_input( + BenchmarkId::new("m_n_k", format!("{}x{}x{}", m, n, k)), + &(m, n, k), + |b_iter, &(m, k, _)| { + b_iter.iter(|| { + gemm::sgemm_abt(&a, m, k, &b, n, &mut result); + }); + }, + ); + } + group.finish(); +} + +// ======================== +// Quantization benchmarks +// ======================== + +/// Train SQ parameters and quantize. Benchmark helper. 
// NOTE(review): generic type parameters appear stripped throughout this
// extract (e.g. `Vec =`, `-> quantize::QuantizedData` with no type args);
// restore the element types from version control before compiling.
fn train_and_quantize(data: &[f32], npoints: usize, ndims: usize) -> quantize::QuantizedData {
    use diskann_quantization::scalar::train::ScalarQuantizationParameters;
    use diskann_utils::views::MatrixView;

    let data_matrix = MatrixView::try_from(data, npoints, ndims)
        .expect("data length must equal npoints * ndims");
    let quantizer = ScalarQuantizationParameters::default().train(data_matrix);
    let shift = quantizer.shift().to_vec();
    let scale = quantizer.scale();
    // Guard against a degenerate (constant) training set where scale == 0.
    let inverse_scale = if scale == 0.0 { 1.0 } else { 1.0 / scale };
    quantize::quantize_1bit(data, npoints, ndims, &shift, inverse_scale)
}

/// Benchmark the all-pairs Hamming distance matrix over n quantized points
/// (ndims fixed at 128). Throughput is per matrix entry (n * n).
fn bench_hamming_distance_matrix(c: &mut Criterion) {
    let mut group = c.benchmark_group("quantize/hamming_matrix");

    for &n in &[64, 256, 512, 1024] {
        let ndims = 128;
        let data = random_data(n, ndims, 42);
        let qd = train_and_quantize(&data, n, ndims);
        let indices: Vec = (0..n).collect();

        group.throughput(Throughput::Elements((n * n) as u64));
        group.bench_with_input(
            BenchmarkId::new("n_points", n),
            &n,
            |b, _| {
                b.iter(|| {
                    qd.compute_distance_matrix(&indices);
                });
            },
        );
    }
    group.finish();
}

// ========================
// Leaf build benchmarks
// ========================

/// Benchmark full-precision leaf k-NN construction for various
/// (leaf size, dims, k) shapes. BLAS is pinned to one thread so the
/// measurement reflects single-leaf cost (leaves run in parallel in the
/// real build).
fn bench_build_leaf(c: &mut Criterion) {
    let mut group = c.benchmark_group("leaf_build/build_leaf");
    gemm::set_blas_threads(1);

    for &(n, ndims, k) in &[(128, 128, 3), (512, 128, 4), (1024, 128, 4), (512, 384, 5)] {
        let data = random_data(n, ndims, 42);
        let indices: Vec = (0..n).collect();

        group.throughput(Throughput::Elements(n as u64));
        group.bench_with_input(
            BenchmarkId::new("n_d_k", format!("{}x{}x{}", n, ndims, k)),
            &(),
            |b, _| {
                b.iter(|| {
                    // NOTE(review): builder.rs passes `config.metric` as the
                    // final argument here, not a bool — confirm build_leaf's
                    // current signature before relying on these numbers.
                    leaf_build::build_leaf(&data, ndims, &indices, k, false);
                });
            },
        );
    }
    group.finish();
}

/// Benchmark 1-bit quantized leaf k-NN construction (Hamming-space).
fn bench_build_leaf_quantized(c: &mut Criterion) {
    let mut group = c.benchmark_group("leaf_build/build_leaf_quantized");

    for &(n, ndims, k) in &[(128, 128, 3), (512, 128, 4), (1024, 128, 4)] {
        let data = random_data(n, ndims, 42);
        let qd = train_and_quantize(&data, n, ndims);
        let indices: Vec = (0..n).collect();

        group.throughput(Throughput::Elements(n as u64));
        group.bench_with_input(
            BenchmarkId::new("n_d_k", format!("{}x{}x{}", n, ndims, k)),
            &(),
            |b, _| {
                b.iter(|| {
                    leaf_build::build_leaf_quantized(&qd, &indices, k);
                });
            },
        );
    }
    group.finish();
}

// ========================
// HashPrune benchmarks
// ========================

/// Benchmark streaming one leaf's worth of edges into HashPrune.
/// npoints scales the HashPrune structure itself; the edge batch is a
/// fixed 512-point / k=4 synthetic leaf. Throughput is per edge.
fn bench_hash_prune_add_edges(c: &mut Criterion) {
    let mut group = c.benchmark_group("hash_prune/add_edges_batched");

    for &npoints in &[10_000, 100_000] {
        let ndims = 128;
        let data = random_data(npoints, ndims, 42);
        let hp = HashPrune::new(&data, npoints, ndims, 12, 128, 64, 42);

        // Simulate edges from a single leaf
        let leaf_size = 512;
        let k = 4;
        let leaf_data = random_data(leaf_size, ndims, 99);
        let leaf_indices: Vec = (0..leaf_size).collect();
        let edges = leaf_build::build_leaf(&leaf_data, ndims, &leaf_indices, k, false);

        group.throughput(Throughput::Elements(edges.len() as u64));
        group.bench_with_input(
            BenchmarkId::new("npoints", npoints),
            &(),
            |b, _| {
                b.iter(|| {
                    hp.add_edges_batched(&edges);
                });
            },
        );
    }
    group.finish();
}

// ==========================
// Partition benchmarks
// ==========================

/// Benchmark the recursive RBC partition step. sample_size is reduced to
/// 10 because each iteration partitions the whole dataset.
fn bench_partition(c: &mut Criterion) {
    let mut group = c.benchmark_group("partition/parallel_partition");
    gemm::set_blas_threads(1);
    group.sample_size(10);

    for &(npoints, ndims) in &[(10_000, 128), (50_000, 128), (10_000, 384)] {
        let data = random_data(npoints, ndims, 42);
        let indices: Vec = (0..npoints).collect();
        // NOTE(review): builder.rs also initializes a `metric` field on
        // PartitionConfig — this literal would be a missing-field compile
        // error against that struct definition; confirm which version of
        // PartitionConfig this benchmark targets.
        let config = PartitionConfig {
            c_max: 1024,
            c_min: 256,
            p_samp: 0.05,
            fanout: vec![8],
        };

        group.throughput(Throughput::Elements(npoints as u64));
        group.bench_with_input(
            BenchmarkId::new("n_d", format!("{}x{}", npoints, ndims)),
            &(),
            |b, _| {
                b.iter(|| {
                    partition::parallel_partition(&data, ndims, &indices, &config, 42);
                });
            },
        );
    }
    group.finish();
}

// ==========================
// End-to-end build benchmark
// ==========================

/// Benchmark the full PiPNN build pipeline (partition + leaf build +
/// HashPrune merge + extraction) at small scales. Throughput is per point.
fn bench_full_build(c: &mut Criterion) {
    let mut group = c.benchmark_group("build/full");
    gemm::set_blas_threads(1);
    group.sample_size(10);

    for &(npoints, ndims) in &[(1_000, 128), (10_000, 128), (10_000, 384)] {
        let data = random_data(npoints, ndims, 42);
        let config = diskann_pipnn::PiPNNConfig {
            c_max: 512,
            c_min: 128,
            k: 3,
            max_degree: 32,
            replicas: 1,
            l_max: 64,
            p_samp: 0.05,
            fanout: vec![8],
            ..Default::default()
        };

        group.throughput(Throughput::Elements(npoints as u64));
        group.bench_with_input(
            BenchmarkId::new("n_d", format!("{}x{}", npoints, ndims)),
            &(),
            |b, _| {
                b.iter(|| {
                    diskann_pipnn::builder::build(&data, npoints, ndims, &config).unwrap();
                });
            },
        );
    }
    group.finish();
}

criterion_group!(
    gemm_benches,
    bench_sgemm_aat,
    bench_sgemm_abt,
);

criterion_group!(
    quantize_benches,
    bench_hamming_distance_matrix,
);

criterion_group!(
    leaf_benches,
    bench_build_leaf,
    bench_build_leaf_quantized,
);

criterion_group!(
    hash_prune_benches,
    bench_hash_prune_add_edges,
);

criterion_group!(
    partition_benches,
    bench_partition,
);

criterion_group!(
    build_benches,
    bench_full_build,
);

criterion_main!(
    gemm_benches,
    quantize_benches,
    leaf_benches,
    hash_prune_benches,
    partition_benches,
    build_benches,
);
diff --git a/diskann-pipnn/examples/sift_1m_pipnn.json b/diskann-pipnn/examples/sift_1m_pipnn.json
new file mode 100644
index 000000000..1f3fbad68
--- /dev/null
+++ b/diskann-pipnn/examples/sift_1m_pipnn.json
@@ -0,0 +1,46 @@
{
    "search_directories": ["datasets"],
    "jobs": [
        {
            "type": "disk-index",
            "content": {
                "source": {
                    "disk-index-source": "Build",
                    "data_type": "float32",
                    "data": "sift1M_base.fbin",
"distance": "squared_l2", + "dim": 128, + "max_degree": 64, + "l_build": 64, + "num_threads": 8, + "build_ram_limit_gb": 16.0, + "num_pq_chunks": 32, + "quantization_type": "FP", + "build_algorithm": { + "algorithm": "PiPNN", + "c_max": 512, + "c_min": 128, + "leaf_k": 5, + "fanout": [8], + "p_samp": 0.001, + "replicas": 1, + "l_max": 128, + "num_hash_planes": 12, + "final_prune": true + }, + "save_path": "target/tmp/sift_1m_pipnn_index" + }, + "search_phase": { + "queries": "sift1M_query.fbin", + "groundtruth": "sift1M_groundtruth.bin", + "num_threads": 4, + "beam_width": 4, + "search_list": [100, 200], + "recall_at": 10, + "is_flat_search": false, + "distance": "squared_l2" + } + } + } + ] +} diff --git a/diskann-pipnn/src/builder.rs b/diskann-pipnn/src/builder.rs new file mode 100644 index 000000000..9e4f4b8d3 --- /dev/null +++ b/diskann-pipnn/src/builder.rs @@ -0,0 +1,1573 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Main PiPNN builder: orchestrates partitioning, leaf building, and edge merging. +//! +//! Algorithm (from arXiv:2602.21247): +//! 1. G <- empty graph +//! 2. B <- Partition(X) via RBC +//! 3. For each leaf b_i in B (in parallel): +//! edges <- Pick(b_i) // GEMM + bi-directed k-NN +//! G.Prune_And_Add_Edges(edges) // stream to HashPrune +//! 4. Optional: final RobustPrune on each node +//! 5. return G + +use std::time::Instant; + +use diskann::utils::VectorRepr; +use rayon::prelude::*; + +use crate::hash_prune::HashPrune; +use crate::leaf_build; +use crate::partition::{self, PartitionConfig}; +use crate::{PiPNNConfig, PiPNNError, PiPNNResult}; + +/// Ask glibc to return freed pages to the OS. +/// Without this, RSS stays inflated after large temporary allocations +/// (e.g. partition GEMM buffers) even though the memory is freed. 
#[cfg(target_os = "linux")]
fn trim_heap() {
    // SAFETY: malloc_trim is a glibc extension; with pad = 0 it only asks
    // the allocator to release free heap pages back to the OS and takes no
    // pointers from us, so there is no memory-safety obligation on our side.
    unsafe {
        extern "C" { fn malloc_trim(pad: usize) -> i32; }
        malloc_trim(0);
    }
}

// No-op on non-glibc targets; RSS trimming is best-effort only.
#[cfg(not(target_os = "linux"))]
fn trim_heap() {}


use diskann_vector::distance::{Distance, DistanceProvider, Metric};

/// Create a DiskANN distance functor for the given metric.
///
/// Uses the exact same SIMD-accelerated distance implementations as DiskANN:
/// - `L2` → `SquaredL2` (squared euclidean)
/// - `Cosine` → `Cosine` (normalizes + 1 - dot)
/// - `CosineNormalized` → `CosineNormalized` (1 - dot, assumes pre-normalized)
/// - `InnerProduct` → `InnerProduct` (-dot)
fn make_dist_fn(metric: Metric) -> Distance {
    // NOTE(review): the qualified-path type arguments were garbled in this
    // extract (`>::distance_comparer`); restore the full
    // `<Provider as DistanceProvider<...>>` path from version control.
    >::distance_comparer(metric, None)
}

/// Timing breakdown for the PiPNN build phases.
#[derive(Debug, Clone, Default)]
pub struct PiPNNBuildStats {
    // Wall-clock end-to-end build time.
    pub total_secs: f64,
    // Time spent initializing HashPrune (LSH sketch computation).
    pub sketch_secs: f64,
    // Cumulative partition time across all replicas.
    pub partition_secs: f64,
    // Cumulative leaf build + edge merge time across all replicas.
    pub leaf_build_secs: f64,
    // Time to extract the adjacency lists from HashPrune.
    pub extract_secs: f64,
    // Time spent in the optional final alpha-prune pass (0 if disabled).
    pub final_prune_secs: f64,
    // Total leaves produced, summed over replicas.
    pub num_leaves: usize,
    // Total edges streamed into HashPrune, summed over replicas.
    pub total_edges: usize,
}

impl std::fmt::Display for PiPNNBuildStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "PiPNN Build Timing")?;
        writeln!(f, " LSH sketches: {:.3}s", self.sketch_secs)?;
        writeln!(f, " Partition: {:.3}s ({} leaves)", self.partition_secs, self.num_leaves)?;
        writeln!(f, " Leaf build: {:.3}s ({} edges)", self.leaf_build_secs, self.total_edges)?;
        writeln!(f, " Graph extract: {:.3}s", self.extract_secs)?;
        writeln!(f, " Final prune: {:.3}s", self.final_prune_secs)?;
        // Last line intentionally uses write! semantics via writeln! return:
        // no trailing newline chaining after this.
        writeln!(f, " Total: {:.3}s", self.total_secs)
    }
}

/// The result of building a PiPNN index.
#[derive(Debug)]
pub struct PiPNNGraph {
    /// Adjacency lists: graph[i] contains the neighbor indices for point i.
    // NOTE(review): element type garbled in this extract (`Vec>`); neighbors
    // are u32 per `neighbors()` below — presumably `Vec<Vec<u32>>`.
    pub adjacency: Vec>,
    /// Number of points.
    pub npoints: usize,
    /// Number of dimensions.
    pub ndims: usize,
    /// Cached medoid (entry point for search).
    pub medoid: usize,
    /// Distance metric used to build this graph.
    pub metric: Metric,
    /// Build timing breakdown.
    pub build_stats: PiPNNBuildStats,
}

impl PiPNNGraph {
    /// Get neighbors of a point.
    pub fn neighbors(&self, idx: usize) -> &[u32] {
        &self.adjacency[idx]
    }

    /// Get the average out-degree.
    pub fn avg_degree(&self) -> f64 {
        let total: usize = self.adjacency.iter().map(|adj| adj.len()).sum();
        total as f64 / self.npoints as f64
    }

    /// Get the max out-degree.
    pub fn max_degree(&self) -> usize {
        self.adjacency.iter().map(|adj| adj.len()).max().unwrap_or(0)
    }

    /// Count the number of points with zero out-degree.
    pub fn num_isolated(&self) -> usize {
        self.adjacency.iter().filter(|adj| adj.is_empty()).count()
    }

    /// Save the graph in DiskANN's canonical graph format.
    ///
    /// Format:
    ///   Header (24 bytes):
    ///   - u64 LE: total file size (header + data)
    ///   - u32 LE: max degree (observed)
    ///   - u32 LE: start point ID (medoid)
    ///   - u64 LE: number of additional/frozen points
    ///   Per node:
    ///   - u32 LE: number of neighbors
    ///   - N x u32 LE: neighbor IDs
    pub fn save_graph(&self, path: &std::path::Path) -> PiPNNResult<()> {
        use std::io::{Write, Seek, SeekFrom, BufWriter};
        use std::fs::File;

        let mut f = BufWriter::new(File::create(path)?);

        // index_size starts at the 24-byte header and grows as nodes are
        // written; the real value is patched into the header at the end.
        let mut index_size: u64 = 24;
        let mut observed_max_degree: u32 = 0;
        let start_point = self.medoid as u32;

        // Write placeholder header
        f.write_all(&index_size.to_le_bytes())?;
        f.write_all(&observed_max_degree.to_le_bytes())?;
        f.write_all(&start_point.to_le_bytes())?;
        // Must be 1 to indicate the medoid is a frozen/start point.
        // The disk layout writer uses this to record the frozen point location.
        let num_additional: u64 = 1;
        f.write_all(&num_additional.to_le_bytes())?;

        // Write per-node adjacency lists
        for adj in &self.adjacency {
            let num_neighbors = adj.len() as u32;
            f.write_all(&num_neighbors.to_le_bytes())?;
            for &neighbor in adj {
                f.write_all(&neighbor.to_le_bytes())?;
            }
            observed_max_degree = observed_max_degree.max(num_neighbors);
            index_size += (4 + adj.len() * 4) as u64;
        }

        // Seek back and write correct header.
        // Only the first 12 bytes (file size + max degree) are rewritten;
        // start_point and num_additional were already correct in the
        // placeholder header.
        f.seek(SeekFrom::Start(0))?;
        f.write_all(&index_size.to_le_bytes())?;
        f.write_all(&observed_max_degree.to_le_bytes())?;
        f.flush()?;

        tracing::info!(
            path = %path.display(),
            npoints = self.npoints,
            max_degree = observed_max_degree,
            start_point = start_point,
            "Saved PiPNN graph in DiskANN format"
        );

        Ok(())
    }

}

/// Search is only available for testing.
/// Production search goes through DiskANN's disk-based search pipeline.
#[cfg(test)]
impl PiPNNGraph {
    /// Perform greedy graph search starting from the cached medoid.
    ///
    /// This method is for testing and benchmarking only. Production search
    /// should use DiskANN's disk-based search pipeline which operates on the
    /// saved graph format.
    ///
    /// Returns the indices and distances of the `k` approximate nearest neighbors.
    pub fn search(
        &self,
        data: &[f32],
        query: &[f32],
        k: usize,
        search_list_size: usize,
    ) -> Vec<(usize, f32)> {
        let ndims = self.ndims;
        let npoints = self.npoints;

        if npoints == 0 {
            return Vec::new();
        }

        let dist_fn = make_dist_fn(self.metric);

        let start = self.medoid;

        // Greedy beam search.
        // `candidates` is kept sorted by distance, bounded to length l;
        // `pointer` marks the next unexpanded entry.
        let l = search_list_size.max(k);
        let mut visited = vec![false; npoints];
        let mut candidates: Vec<(usize, f32)> = Vec::with_capacity(l + 1);

        let start_dist = dist_fn.call(
            &data[start * ndims..(start + 1) * ndims],
            query,
        );
        candidates.push((start, start_dist));
        visited[start] = true;

        let mut pointer = 0;

        while pointer < candidates.len() {
            let (current, _) = candidates[pointer];
            pointer += 1;

            for &neighbor in &self.adjacency[current] {
                let neighbor = neighbor as usize;
                if neighbor >= npoints || visited[neighbor] {
                    continue;
                }
                visited[neighbor] = true;

                let dist = dist_fn.call(
                    &data[neighbor * ndims..(neighbor + 1) * ndims],
                    query,
                );

                // Insert only if the beam isn't full yet or the new point
                // beats the current worst candidate.
                if candidates.len() < l || dist < candidates.last().map(|c| c.1).unwrap_or(f32::MAX) {
                    let pos = candidates
                        .binary_search_by(|c| {
                            c.1.partial_cmp(&dist).unwrap_or(std::cmp::Ordering::Equal)
                        })
                        .unwrap_or_else(|e| e);
                    candidates.insert(pos, (neighbor, dist));
                    if candidates.len() > l {
                        candidates.truncate(l);
                    }

                    // Rewind the scan pointer so the newly inserted closer
                    // candidate gets expanded before farther ones.
                    if pos < pointer {
                        pointer = pos;
                    }
                }
            }
        }

        candidates.truncate(k);
        candidates
    }
}

/// Find the medoid: the point closest to the centroid.
///
/// Uses squared L2 distance to find the nearest point to the centroid,
/// matching DiskANN's `find_medoid_with_sampling` behavior. The centroid
/// is a geometric center, so L2 is the natural metric regardless of the
/// build distance metric.
// NOTE(review): the generic parameter list was stripped in this extract —
// `T` is constrained by the `T::as_f32_into` calls below (VectorRepr).
fn find_medoid(data: &[T], npoints: usize, ndims: usize) -> usize {
    let dist_fn = make_dist_fn(Metric::L2);

    // Compute centroid.
    // Two full passes over the data: one to accumulate the mean, one to
    // find the nearest point; point_buf is reused for f32 conversion.
    let mut centroid = vec![0.0f32; ndims];
    let mut point_buf = vec![0.0f32; ndims];
    for i in 0..npoints {
        T::as_f32_into(&data[i * ndims..(i + 1) * ndims], &mut point_buf).expect("f32 conversion");
        for d in 0..ndims {
            centroid[d] += point_buf[d];
        }
    }
    let inv_n = 1.0 / npoints as f32;
    for d in 0..ndims {
        centroid[d] *= inv_n;
    }

    let mut best_idx = 0;
    let mut best_dist = f32::MAX;
    for i in 0..npoints {
        T::as_f32_into(&data[i * ndims..(i + 1) * ndims], &mut point_buf).expect("f32 conversion");
        let dist = dist_fn.call(&point_buf, &centroid);
        if dist < best_dist {
            best_dist = dist;
            best_idx = i;
        }
    }

    best_idx
}

/// Build a PiPNN index from typed vector data.
///
/// Keeps data in its native type T and converts to f32 on-the-fly at each access point.
/// For f16 data this saves ~793 MB peak RSS compared to upfront conversion.
/// `data` is a flat slice of `T` in row-major order: npoints x ndims.
pub fn build_typed(
    data: &[T],
    npoints: usize,
    ndims: usize,
    config: &PiPNNConfig,
) -> PiPNNResult {
    config.validate()?;

    let expected_len = npoints * ndims;
    if data.len() != expected_len {
        return Err(PiPNNError::DataLengthMismatch {
            expected: expected_len,
            actual: data.len(),
            npoints,
            ndims,
        });
    }

    if npoints == 0 || ndims == 0 {
        return Err(PiPNNError::Config(
            "npoints and ndims must be > 0".into(),
        ));
    }

    tracing::info!(
        npoints = npoints,
        ndims = ndims,
        k = config.k,
        max_degree = config.max_degree,
        c_max = config.c_max,
        replicas = config.replicas,
        "PiPNN build started (typed)"
    );

    build_internal(data, npoints, ndims, config, None)
}

/// Build a PiPNN index.
///
/// `data` is row-major: npoints x ndims.
pub fn build(data: &[f32], npoints: usize, ndims: usize, config: &PiPNNConfig) -> PiPNNResult {
    config.validate()?;

    if npoints == 0 || ndims == 0 {
        return Err(PiPNNError::Config(
            "npoints and ndims must be > 0".into(),
        ));
    }

    if data.len() != npoints * ndims {
        return Err(PiPNNError::DataLengthMismatch {
            expected: npoints * ndims,
            actual: data.len(),
            npoints,
            ndims,
        });
    }

    tracing::info!(
        npoints = npoints,
        ndims = ndims,
        k = config.k,
        max_degree = config.max_degree,
        c_max = config.c_max,
        replicas = config.replicas,
        "PiPNN build started"
    );

    // The build() path always builds at full precision with f32 data.
    // For quantized builds, use build_with_sq() which accepts pre-trained SQ params.
    // NOTE(review): the turbofish type argument was stripped in this extract
    // (`build_internal::(…)`); restore from version control.
    build_internal::(data, npoints, ndims, config, None)
}

/// Pre-trained scalar quantizer parameters for 1-bit quantization.
///
/// These can be extracted from DiskANN's trained `ScalarQuantizer` to ensure
/// identical quantization between Vamana and PiPNN builds.
pub struct SQParams {
    /// Per-dimension shift (length = ndims).
    // NOTE(review): element type stripped in this extract — presumably
    // `Vec<f32>`, matching `quantize_1bit(&sq_params.shift, …)` usage below.
    pub shift: Vec,
    /// Global inverse scale (1.0 / scale).
    pub inverse_scale: f32,
}

/// Build a PiPNN index using a pre-trained scalar quantizer for 1-bit mode.
///
/// When DiskANN's build pipeline has already trained a `ScalarQuantizer`,
/// this function reuses those parameters instead of training from scratch.
/// This ensures identical quantization between Vamana and PiPNN builds.
///
/// `data` is row-major f32: npoints x ndims.
pub fn build_with_sq(
    data: &[f32],
    npoints: usize,
    ndims: usize,
    config: &PiPNNConfig,
    sq_params: &SQParams,
) -> PiPNNResult {
    config.validate()?;

    if data.len() != npoints * ndims {
        return Err(PiPNNError::DataLengthMismatch {
            expected: npoints * ndims,
            actual: data.len(),
            npoints,
            ndims,
        });
    }
    if npoints == 0 || ndims == 0 {
        return Err(PiPNNError::Config("npoints and ndims must be > 0".into()));
    }
    // The shift vector is per-dimension; reject a quantizer trained for a
    // different dimensionality.
    if sq_params.shift.len() != ndims {
        return Err(PiPNNError::DimensionMismatch {
            expected: ndims,
            actual: sq_params.shift.len(),
        });
    }

    tracing::info!(
        npoints = npoints,
        ndims = ndims,
        k = config.k,
        max_degree = config.max_degree,
        c_max = config.c_max,
        replicas = config.replicas,
        "PiPNN build started (with pre-trained SQ)"
    );

    // Quantize using pre-trained parameters.
    tracing::info!("Quantizing to 1-bit with pre-trained SQ params");
    let t = Instant::now();
    let qdata = crate::quantize::quantize_1bit(
        data,
        npoints,
        ndims,
        &sq_params.shift,
        sq_params.inverse_scale,
    );
    tracing::info!(
        elapsed_secs = t.elapsed().as_secs_f64(),
        bytes_per_vec = qdata.bytes_per_vec,
        "Quantization complete (pre-trained SQ)"
    );

    // Build using the internal build loop with pre-quantized data.
    build_internal::(data, npoints, ndims, config, Some(qdata))
}

/// Internal build logic shared between `build()`, `build_typed()`, and `build_with_sq()`.
///
/// When `qdata` is Some, partitioning and leaf construction run on the
/// 1-bit quantized representation; otherwise they use the raw typed data.
// NOTE(review): generic parameter list and the `qdata: Option<…>` /
// `-> PiPNNResult<…>` type args were stripped in this extract.
fn build_internal(
    data: &[T],
    npoints: usize,
    ndims: usize,
    config: &PiPNNConfig,
    qdata: Option,
) -> PiPNNResult {
    let t_total = Instant::now();

    // Compute medoid once upfront.
    let medoid = find_medoid(data, npoints, ndims);

    // Initialize HashPrune for edge merging.
    let t0 = Instant::now();
    let hash_prune = HashPrune::new(
        data,
        npoints,
        ndims,
        config.num_hash_planes,
        config.l_max,
        config.max_degree,
        42,
    );
    let sketch_secs = t0.elapsed().as_secs_f64();
    tracing::info!(elapsed_secs = sketch_secs, "HashPrune init complete");

    // Run multiple replicas of partitioning + leaf building.
    let mut partition_secs = 0.0f64;
    let mut leaf_build_secs = 0.0f64;
    let mut total_leaves = 0usize;
    let mut total_edges_count = 0usize;

    for replica in 0..config.replicas {
        // Distinct deterministic seed per replica (7919 is prime) so each
        // replica produces a different partitioning of the same data.
        let seed = 1000 + replica as u64 * 7919;

        let t1 = Instant::now();
        let partition_config = PartitionConfig {
            c_max: config.c_max,
            c_min: config.c_min,
            p_samp: config.p_samp,
            fanout: config.fanout.clone(),
            metric: config.metric,
        };

        let indices: Vec = (0..npoints).collect();
        let leaves = if let Some(ref q) = qdata {
            partition::parallel_partition_quantized(q, &indices, &partition_config, seed)
        } else {
            partition::parallel_partition(data, ndims, &indices, &partition_config, seed)
        };
        partition_secs += t1.elapsed().as_secs_f64();

        // Leaf-size distribution diagnostics (info + debug logs below).
        let total_pts: usize = leaves.iter().map(|l| l.indices.len()).sum();
        let leaf_sizes: Vec = leaves.iter().map(|l| l.indices.len()).collect();
        total_leaves += leaves.len();
        let small_leaves = leaf_sizes.iter().filter(|&&s| s < 64).count();
        let med_leaves = leaf_sizes.iter().filter(|&&s| s >= 64 && s < 512).count();
        let big_leaves = leaf_sizes.iter().filter(|&&s| s >= 512).count();
        tracing::info!(
            replica = replica,
            partition_secs = t1.elapsed().as_secs_f64(),
            num_leaves = leaves.len(),
            avg_leaf_size = total_pts as f64 / leaves.len().max(1) as f64,
            max_leaf_size = leaf_sizes.iter().max().unwrap_or(&0),
            total_pts = total_pts,
            "Partition complete"
        );
        // Return freed partition GEMM buffers to the OS so they don't inflate
        // peak RSS during the subsequent leaf build + reservoir filling phase.
        trim_heap();
        tracing::debug!(
            small_leaves = small_leaves,
            med_leaves = med_leaves,
            big_leaves = big_leaves,
            overlap = total_pts as f64 / npoints as f64,
            "Leaf size distribution"
        );

        // Build leaves in parallel, streaming edges to HashPrune per-leaf.
        let t2 = Instant::now();

        use std::sync::atomic::{AtomicUsize, Ordering};
        let total_edges = AtomicUsize::new(0);

        leaves.par_iter().for_each(|leaf| {
            let edges = if let Some(ref q) = qdata {
                leaf_build::build_leaf_quantized(q, &leaf.indices, config.k)
            } else {
                leaf_build::build_leaf(data, ndims, &leaf.indices, config.k, config.metric)
            };
            // Relaxed suffices: the counter is only read after the parallel
            // loop completes.
            total_edges.fetch_add(edges.len(), Ordering::Relaxed);
            hash_prune.add_edges_batched(&edges);
        });

        let replica_edges = total_edges.load(Ordering::Relaxed);
        total_edges_count += replica_edges;
        leaf_build_secs += t2.elapsed().as_secs_f64();

        tracing::info!(
            replica = replica,
            elapsed_secs = t2.elapsed().as_secs_f64(),
            total_edges = replica_edges,
            "Leaf build and merge complete"
        );
    }

    // Release thread-local leaf buffers so their arena pages can be reclaimed.
    // NOTE(review): rayon's work stealing does not guarantee each pool thread
    // runs exactly one of these tasks, so some threads' buffers may survive —
    // confirm this best-effort release is acceptable.
    (0..rayon::current_num_threads()).into_par_iter().for_each(|_| {
        leaf_build::release_thread_buffers();
    });

    // Extract final graph from HashPrune.
    let t3 = Instant::now();
    let adjacency = hash_prune.extract_graph();
    let extract_secs = t3.elapsed().as_secs_f64();
    tracing::info!(elapsed_secs = extract_secs, "Graph extraction complete");
    trim_heap();

    // Optional final prune pass.
    let t4 = Instant::now();
    let adjacency = if config.final_prune {
        tracing::info!("Applying final prune");
        final_prune(data, ndims, &adjacency, config.max_degree, config.metric, config.alpha)
    } else {
        adjacency
    };
    let final_prune_secs = t4.elapsed().as_secs_f64();

    let total_secs = t_total.elapsed().as_secs_f64();

    let build_stats = PiPNNBuildStats {
        total_secs,
        sketch_secs,
        partition_secs,
        leaf_build_secs,
        extract_secs,
        final_prune_secs,
        num_leaves: total_leaves,
        total_edges: total_edges_count,
    };

    let graph = PiPNNGraph {
        adjacency,
        npoints,
        ndims,
        medoid,
        metric: config.metric,
        build_stats,
    };

    // Return all freed memory (reservoirs, sketches, partition buffers, leaf buffers)
    // to the OS before handing off to the disk layout phase.
    trim_heap();

    tracing::info!(
        avg_degree = graph.avg_degree(),
        max_degree = graph.max_degree(),
        isolated = graph.num_isolated(),
        "PiPNN build complete"
    );

    Ok(graph)
}

/// RobustPrune-like final pass: diversity-aware pruning via alpha-pruning.
/// Uses the same occlusion factor (alpha) as DiskANN's RobustPrune.
///
/// Nodes already within `max_degree` are returned unchanged; over-degree
/// nodes get a greedy diversity-aware selection of at most `max_degree`
/// neighbors, processed in parallel per node.
// NOTE(review): generic parameter list and `Vec<…>` element types were
// stripped in this extract.
fn final_prune(
    data: &[T],
    ndims: usize,
    adjacency: &[Vec],
    max_degree: usize,
    metric: Metric,
    alpha: f32,
) -> Vec> {
    let dist_fn = make_dist_fn(metric);

    adjacency
        .par_iter()
        .enumerate()
        .map(|(i, neighbors)| {
            if neighbors.len() <= max_degree {
                return neighbors.clone();
            }

            let mut point_i = vec![0.0f32; ndims];
            T::as_f32_into(&data[i * ndims..(i + 1) * ndims], &mut point_i).expect("f32 conversion");

            // Compute distances from i to all its current neighbors.
            let mut point_buf = vec![0.0f32; ndims];
            let mut candidates: Vec<(u32, f32)> = neighbors
                .iter()
                .map(|&j| {
                    T::as_f32_into(&data[j as usize * ndims..(j as usize + 1) * ndims], &mut point_buf).expect("f32 conversion");
                    let dist = dist_fn.call(&point_i, &point_buf);
                    (j, dist)
                })
                .collect();

            candidates.sort_unstable_by(|a, b| {
                a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
            });

            // Greedy diversity-aware selection.
            // A candidate is occluded if some already-selected neighbor is
            // alpha-times closer to it than the candidate is to node i.
            // Note: each selected point is re-converted to f32 for every
            // candidate check, so this inner test is O(|selected| * ndims).
            let mut selected: Vec = Vec::with_capacity(max_degree);

            let mut point_sel = vec![0.0f32; ndims];
            let mut point_cand = vec![0.0f32; ndims];
            for &(cand_id, cand_dist) in &candidates {
                if selected.len() >= max_degree {
                    break;
                }

                T::as_f32_into(&data[cand_id as usize * ndims..(cand_id as usize + 1) * ndims], &mut point_cand).expect("f32 conversion");
                let is_pruned = selected.iter().any(|&sel_id| {
                    T::as_f32_into(&data[sel_id as usize * ndims..(sel_id as usize + 1) * ndims], &mut point_sel).expect("f32 conversion");
                    let dist_sel_cand = dist_fn.call(&point_sel, &point_cand);
                    dist_sel_cand * alpha < cand_dist
                });

                if !is_pruned {
                    selected.push(cand_id);
                }
            }

            // Fill remaining from sorted list.
            // Top up with the closest occluded candidates so the node keeps
            // max_degree edges even when alpha-pruning was aggressive.
            if selected.len() < max_degree {
                let selected_set: std::collections::HashSet = selected.iter().copied().collect();
                for &(cand_id, _) in &candidates {
                    if selected.len() >= max_degree {
                        break;
                    }
                    if !selected_set.contains(&cand_id) {
                        selected.push(cand_id);
                    }
                }
            }

            selected
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    // Deterministic uniform data in [-1, 1); mirrors the benchmark helper.
    fn generate_random_data(npoints: usize, ndims: usize, seed: u64) -> Vec {
        use rand::{Rng, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..npoints * ndims)
            .map(|_| rng.random_range(-1.0f32..1.0f32))
            .collect()
    }

    #[test]
    fn test_build_small() {
        let npoints = 100;
        let ndims = 8;
        let data = generate_random_data(npoints, ndims, 42);

        let config = PiPNNConfig {
            c_max: 32,
            c_min: 8,
            k: 3,
            max_degree: 16,
            replicas: 1,
            l_max: 32,
            ..Default::default()
        };

        let graph = build(&data, npoints, ndims, &config).unwrap();

        assert_eq!(graph.npoints, npoints);
        assert!(graph.avg_degree() > 0.0);
        assert!(graph.num_isolated() < npoints);
    }

    #[test]
    fn test_build_data_length_mismatch() {
        // 10 floats cannot be a 5 x 3 matrix.
        let data = vec![0.0f32; 10];
        let config = PiPNNConfig::default();

        let result = build(&data, 5, 3, &config);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, PiPNNError::DataLengthMismatch { ..
        }));
    }

    #[test]
    fn test_search_basic() {
        let npoints = 200;
        let ndims = 8;
        let data = generate_random_data(npoints, ndims, 42);

        let config = PiPNNConfig {
            c_max: 64,
            c_min: 16,
            k: 4,
            max_degree: 32,
            replicas: 2,
            l_max: 64,
            ..Default::default()
        };

        let graph = build(&data, npoints, ndims, &config).unwrap();

        // Query with point 0 itself: it must come back first with ~zero distance.
        let query = &data[0..ndims];
        let results = graph.search(&data, query, 10, 50);

        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 < 1e-6);
    }

    #[test]
    fn test_recall() {
        use crate::leaf_build::brute_force_knn;

        let npoints = 500;
        let ndims = 16;
        let data = generate_random_data(npoints, ndims, 42);

        let config = PiPNNConfig {
            c_max: 128,
            c_min: 32,
            k: 4,
            max_degree: 32,
            replicas: 2,
            l_max: 64,
            ..Default::default()
        };

        let graph = build(&data, npoints, ndims, &config).unwrap();

        let k = 10;
        let search_l = 100;
        let num_queries = 20;

        use rand::{Rng, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(999);
        let mut total_recall = 0.0;

        for _ in 0..num_queries {
            let query: Vec = (0..ndims).map(|_| rng.random_range(-1.0f32..1.0f32)).collect();

            let approx = graph.search(&data, &query, k, search_l);
            let exact = brute_force_knn(&data, ndims, npoints, &query, k);

            let exact_set: std::collections::HashSet =
                exact.iter().map(|&(id, _)| id).collect();
            let recall = approx
                .iter()
                .filter(|&&(id, _)| exact_set.contains(&id))
                .count() as f64
                / k as f64;

            total_recall += recall;
        }

        let avg_recall = total_recall / num_queries as f64;
        eprintln!("Average recall@{}: {:.4}", k, avg_recall);

        // Deliberately loose threshold: this is a smoke test at tiny scale,
        // not a quality benchmark.
        assert!(
            avg_recall > 0.2,
            "recall too low: {:.4}",
            avg_recall
        );
    }

    #[test]
    fn test_config_validate() {
        let config = PiPNNConfig::default();
        assert!(config.validate().is_ok());

        let bad = PiPNNConfig { c_max: 0, ..Default::default() };
        assert!(bad.validate().is_err());

        let bad = PiPNNConfig { c_min: 0, ..Default::default() };
        assert!(bad.validate().is_err());

        // c_min must not exceed c_max.
        let bad = PiPNNConfig { c_min: 2048, c_max: 1024, ..Default::default() };
        assert!(bad.validate().is_err());

        let bad = PiPNNConfig { p_samp: 0.0, ..Default::default() };
        assert!(bad.validate().is_err());

        let bad = PiPNNConfig { p_samp: 1.5, ..Default::default() };
        assert!(bad.validate().is_err());

        let bad = PiPNNConfig { fanout: vec![], ..Default::default() };
        assert!(bad.validate().is_err());

        let bad = PiPNNConfig { fanout: vec![0], ..Default::default() };
        assert!(bad.validate().is_err());

        let bad = PiPNNConfig { num_hash_planes: 0, ..Default::default() };
        assert!(bad.validate().is_err());

        let bad = PiPNNConfig { num_hash_planes: 17, ..Default::default() };
        assert!(bad.validate().is_err());

    }

    #[test]
    fn test_config_validate_failures() {
        // max_degree = 0
        let bad = PiPNNConfig { max_degree: 0, ..Default::default() };
        assert!(bad.validate().is_err());

        // k = 0
        let bad = PiPNNConfig { k: 0, ..Default::default() };
        assert!(bad.validate().is_err());

        // replicas = 0
        let bad = PiPNNConfig { replicas: 0, ..Default::default() };
        assert!(bad.validate().is_err());

        // l_max = 0
        let bad = PiPNNConfig { l_max: 0, ..Default::default() };
        assert!(bad.validate().is_err());

        // p_samp exactly 1.0 is valid
        let ok = PiPNNConfig { p_samp: 1.0, ..Default::default() };
        assert!(ok.validate().is_ok());

        // num_hash_planes = 1 (boundary) is valid
        let ok = PiPNNConfig { num_hash_planes: 1, ..Default::default() };
        assert!(ok.validate().is_ok());

        // num_hash_planes = 16 (boundary) is valid
        let ok = PiPNNConfig { num_hash_planes: 16, ..Default::default() };
        assert!(ok.validate().is_ok());
    }

    #[test]
    fn test_build_cosine() {
        let npoints = 100;
        let ndims = 8;
        // Generate random data and normalize each vector for cosine.
        let mut data = generate_random_data(npoints, ndims, 42);
        for i in 0..npoints {
            let row = &mut data[i * ndims..(i + 1) * ndims];
            let norm: f32 = row.iter().map(|v| v * v).sum::().sqrt();
            if norm > 0.0 {
                for v in row.iter_mut() { *v /= norm; }
            }
        }

        let config = PiPNNConfig {
            c_max: 32,
            c_min: 8,
            k: 3,
            max_degree: 16,
            replicas: 1,
            l_max: 32,
            metric: diskann_vector::distance::Metric::Cosine,
            ..Default::default()
        };

        let graph = build(&data, npoints, ndims, &config).unwrap();
        assert!(matches!(graph.metric, Metric::Cosine));
        assert_eq!(graph.npoints, npoints);
        assert!(graph.avg_degree() > 0.0);
    }

    /// Train SQ parameters from data. Test-only helper.
    fn train_sq_params(data: &[f32], npoints: usize, ndims: usize) -> SQParams {
        use diskann_quantization::scalar::train::ScalarQuantizationParameters;
        use diskann_utils::views::MatrixView;

        let data_matrix = MatrixView::try_from(data, npoints, ndims)
            .expect("data length must equal npoints * ndims");
        let quantizer = ScalarQuantizationParameters::default().train(data_matrix);
        let shift = quantizer.shift().to_vec();
        let scale = quantizer.scale();
        let inverse_scale = if scale == 0.0 { 1.0 } else { 1.0 / scale };
        SQParams { shift, inverse_scale }
    }

    #[test]
    fn test_build_with_sq() {
        let npoints = 100;
        let ndims = 64; // must be multiple of 64 for u64 alignment in quantize
        let data = generate_random_data(npoints, ndims, 42);

        let config = PiPNNConfig {
            c_max: 32,
            c_min: 8,
            k: 3,
            max_degree: 16,
            replicas: 1,
            l_max: 32,
            ..Default::default()
        };

        let sq_params = train_sq_params(&data, npoints, ndims);

        let graph = super::build_with_sq(&data, npoints, ndims, &config, &sq_params).unwrap();
        assert_eq!(graph.npoints, npoints);
        assert!(graph.avg_degree() > 0.0);
    }

    #[test]
    fn test_build_typed_f32() {
        let npoints = 60;
        let ndims = 8;
        let data = generate_random_data(npoints, ndims, 42);

        let config = PiPNNConfig {
            c_max: 32,
            c_min: 8,
            k: 3,
            max_degree: 16,
            replicas: 1,
            l_max: 32,
            ..Default::default()
        };

        let graph_direct = build(&data, npoints, ndims, &config).unwrap();
        let graph_typed = build_typed::(&data, npoints, ndims, &config).unwrap();

        // Both should produce the same npoints and medoid.
        assert_eq!(graph_direct.npoints, graph_typed.npoints);
        assert_eq!(graph_direct.medoid, graph_typed.medoid);
    }

    #[test]
    fn test_save_graph_format() {
        let npoints = 50;
        let ndims = 8;
        let data = generate_random_data(npoints, ndims, 42);

        let config = PiPNNConfig {
            c_max: 32,
            c_min: 8,
            k: 3,
            max_degree: 16,
            replicas: 1,
            l_max: 32,
            ..Default::default()
        };

        let graph = build(&data, npoints, ndims, &config).unwrap();

        let dir = std::env::temp_dir().join("pipnn_test_save_graph");
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("test_graph.bin");
        graph.save_graph(&path).unwrap();

        // Read back and verify the header.
        let bytes = std::fs::read(&path).unwrap();
        assert!(bytes.len() >= 24, "file too small: {} bytes", bytes.len());

        // First 8 bytes: u64 LE file size.
        let file_size = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
        assert_eq!(file_size as usize, bytes.len(), "header file_size mismatch");

        // Bytes 8..12: u32 LE max degree.
        let max_deg = u32::from_le_bytes(bytes[8..12].try_into().unwrap());
        assert_eq!(max_deg as usize, graph.max_degree());

        // Bytes 12..16: u32 LE start point (medoid).
        let start_pt = u32::from_le_bytes(bytes[12..16].try_into().unwrap());
        assert_eq!(start_pt as usize, graph.medoid);

        // Clean up.
        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn test_medoid_is_valid() {
        let npoints = 100;
        let ndims = 8;
        let data = generate_random_data(npoints, ndims, 42);

        let config = PiPNNConfig {
            c_max: 32,
            c_min: 8,
            k: 3,
            max_degree: 16,
            replicas: 1,
            l_max: 32,
            ..Default::default()
        };

        let graph = build(&data, npoints, ndims, &config).unwrap();
        assert!(
            graph.medoid < npoints,
            "medoid {} is out of range [0, {})",
            graph.medoid,
            npoints
        );
    }

    #[test]
    fn test_graph_connectivity() {
        // With sufficient replicas and params, no nodes should be isolated.
        let npoints = 200;
        let ndims = 8;
        let data = generate_random_data(npoints, ndims, 42);

        let config = PiPNNConfig {
            c_max: 64,
            c_min: 16,
            k: 4,
            max_degree: 32,
            replicas: 2,
            l_max: 64,
            ..Default::default()
        };

        let graph = build(&data, npoints, ndims, &config).unwrap();

        // With these settings no node should be completely isolated.
        assert_eq!(
            graph.num_isolated(), 0,
            "found {} isolated nodes with replicas=2",
            graph.num_isolated()
        );
    }

    #[test]
    fn test_build_zero_npoints() {
        let data: Vec = vec![];
        let config = PiPNNConfig::default();
        let result = build(&data, 0, 8, &config);
        assert!(result.is_err(), "npoints=0 should error");
    }

    #[test]
    fn test_build_zero_ndims() {
        let data: Vec = vec![];
        let config = PiPNNConfig::default();
        let result = build(&data, 10, 0, &config);
        assert!(result.is_err(), "ndims=0 should error");
    }

    #[test]
    fn test_build_single_point() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let config = PiPNNConfig {
            c_max: 32,
            c_min: 1,
            k: 3,
            max_degree: 16,
            replicas: 1,
            l_max: 32,
            ..Default::default()
        };
        let graph = build(&data, 1, 4, &config).unwrap();
        assert_eq!(graph.npoints, 1, "should have 1 point");
        assert_eq!(graph.adjacency[0].len(), 0, "single point should have 0 edges");
    }

    #[test]
    fn test_build_two_points() {
        let data = vec![0.0f32, 0.0, 1.0,
0.0]; + let config = PiPNNConfig { + c_max: 32, + c_min: 1, + k: 3, + max_degree: 16, + replicas: 1, + l_max: 32, + ..Default::default() + }; + let graph = build(&data, 2, 2, &config).unwrap(); + assert_eq!(graph.npoints, 2, "should have 2 points"); + // With 2 points, they should connect to each other. + let total_edges: usize = graph.adjacency.iter().map(|a| a.len()).sum(); + assert!(total_edges > 0, "two points should have at least one edge between them"); + } + + #[test] + fn test_build_duplicate_points() { + // All identical points; build should still succeed. + let npoints = 20; + let ndims = 4; + let data = vec![1.0f32; npoints * ndims]; + let config = PiPNNConfig { + c_max: 32, + c_min: 4, + k: 3, + max_degree: 16, + replicas: 1, + l_max: 32, + ..Default::default() + }; + let graph = build(&data, npoints, ndims, &config).unwrap(); + assert_eq!(graph.npoints, npoints, "should build successfully with duplicate points"); + } + + #[test] + fn test_build_very_small_k() { + let npoints = 50; + let ndims = 4; + let data = generate_random_data(npoints, ndims, 42); + let config = PiPNNConfig { + c_max: 32, + c_min: 8, + k: 1, + max_degree: 16, + replicas: 1, + l_max: 32, + ..Default::default() + }; + let graph = build(&data, npoints, ndims, &config).unwrap(); + assert_eq!(graph.npoints, npoints, "k=1 should produce valid graph"); + assert!(graph.avg_degree() > 0.0, "k=1 should still produce some edges"); + } + + #[test] + fn test_build_k_larger_than_leaf() { + // k > c_max should still work (clamped inside extract_knn). 
+ let npoints = 50; + let ndims = 4; + let data = generate_random_data(npoints, ndims, 42); + let config = PiPNNConfig { + c_max: 32, + c_min: 8, + k: 100, // larger than c_max + max_degree: 16, + replicas: 1, + l_max: 32, + ..Default::default() + }; + let graph = build(&data, npoints, ndims, &config).unwrap(); + assert_eq!(graph.npoints, npoints, "k > c_max should still produce valid graph"); + } + + #[test] + fn test_search_empty_graph() { + let graph = PiPNNGraph { + adjacency: vec![], + npoints: 0, + ndims: 4, + medoid: 0, + metric: Metric::L2, + build_stats: Default::default(), + }; + let query = vec![1.0f32, 2.0, 3.0, 4.0]; + let results = graph.search(&[], &query, 5, 10); + assert!(results.is_empty(), "search on empty graph should return empty results"); + } + + #[test] + fn test_search_k_larger_than_npoints() { + let npoints = 10; + let ndims = 4; + let data = generate_random_data(npoints, ndims, 42); + let config = PiPNNConfig { + c_max: 32, + c_min: 4, + k: 3, + max_degree: 16, + replicas: 1, + l_max: 32, + ..Default::default() + }; + let graph = build(&data, npoints, ndims, &config).unwrap(); + let query = &data[0..ndims]; + // Request more neighbors than points exist. + let results = graph.search(&data, query, 100, 200); + assert!( + results.len() <= npoints, + "should not return more results than npoints, got {}", + results.len() + ); + } + + #[test] + fn test_search_with_self_query() { + let npoints = 100; + let ndims = 8; + let data = generate_random_data(npoints, ndims, 42); + let config = PiPNNConfig { + c_max: 64, + c_min: 16, + k: 4, + max_degree: 32, + replicas: 2, + l_max: 64, + ..Default::default() + }; + let graph = build(&data, npoints, ndims, &config).unwrap(); + // Query with the medoid point itself. 
+ let medoid = graph.medoid; + let query = &data[medoid * ndims..(medoid + 1) * ndims]; + let results = graph.search(&data, query, 5, 50); + assert!(!results.is_empty(), "search should return at least one result"); + assert_eq!( + results[0].0, medoid, + "searching with a data point should find itself first" + ); + assert!( + results[0].1 < 1e-6, + "self-distance should be near zero, got {}", + results[0].1 + ); + } + + #[test] + fn test_search_different_l_values() { + use crate::leaf_build::brute_force_knn; + + let npoints = 300; + let ndims = 8; + let data = generate_random_data(npoints, ndims, 42); + let config = PiPNNConfig { + c_max: 64, + c_min: 16, + k: 4, + max_degree: 32, + replicas: 2, + l_max: 64, + ..Default::default() + }; + let graph = build(&data, npoints, ndims, &config).unwrap(); + + let k = 10; + let query = &data[0..ndims]; + let exact = brute_force_knn(&data, ndims, npoints, query, k); + let exact_set: std::collections::HashSet = + exact.iter().map(|&(id, _)| id).collect(); + + // Compare recall for small L vs large L. + let results_small_l = graph.search(&data, query, k, k); + let recall_small: f64 = results_small_l + .iter() + .filter(|&&(id, _)| exact_set.contains(&id)) + .count() as f64 + / k as f64; + + let results_large_l = graph.search(&data, query, k, 200); + let recall_large: f64 = results_large_l + .iter() + .filter(|&&(id, _)| exact_set.contains(&id)) + .count() as f64 + / k as f64; + + assert!( + recall_large >= recall_small, + "larger L ({:.4}) should give recall >= smaller L ({:.4})", + recall_large, + recall_small + ); + } + + #[test] + fn test_build_with_sq_wrong_shift_dims() { + let npoints = 50; + let ndims = 64; + let data = generate_random_data(npoints, ndims, 42); + let config = PiPNNConfig { + c_max: 32, + c_min: 8, + k: 3, + max_degree: 16, + replicas: 1, + l_max: 32, + ..Default::default() + }; + // Shift length != ndims. 
+ let sq_params = SQParams { + shift: vec![0.0f32; ndims + 5], // wrong length + inverse_scale: 1.0, + }; + let result = build_with_sq(&data, npoints, ndims, &config, &sq_params); + assert!( + result.is_err(), + "shift length != ndims should produce an error" + ); + assert!( + matches!(result.unwrap_err(), PiPNNError::DimensionMismatch { .. }), + "should be a DimensionMismatch error" + ); + } + + #[test] + fn test_build_with_sq_produces_connected_graph() { + let npoints = 100; + let ndims = 64; + let data = generate_random_data(npoints, ndims, 42); + let config = PiPNNConfig { + c_max: 64, + c_min: 16, + k: 4, + max_degree: 32, + replicas: 2, + l_max: 64, + ..Default::default() + }; + let sq_params = train_sq_params(&data, npoints, ndims); + let graph = build_with_sq(&data, npoints, ndims, &config, &sq_params).unwrap(); + assert_eq!( + graph.num_isolated(), 0, + "build_with_sq should produce a connected graph with sufficient replicas, found {} isolated nodes", + graph.num_isolated() + ); + } + + #[test] + fn test_build_typed_data_length_mismatch() { + let data = vec![1.0f32; 30]; // 30 elements + let config = PiPNNConfig::default(); + // npoints=5, ndims=8 expects 40 elements but data has 30. 
+ let result = build_typed::(&data, 5, 8, &config); + assert!( + result.is_err(), + "data length mismatch should produce an error" + ); + } + + #[test] + fn test_save_graph_single_node() { + let graph = PiPNNGraph { + adjacency: vec![vec![]], + npoints: 1, + ndims: 4, + medoid: 0, + metric: Metric::L2, + build_stats: Default::default(), + }; + let dir = std::env::temp_dir().join("pipnn_test_save_single"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("single.bin"); + graph.save_graph(&path).unwrap(); + + let bytes = std::fs::read(&path).unwrap(); + assert!(bytes.len() >= 24, "file too small for single node graph"); + let file_size = u64::from_le_bytes(bytes[0..8].try_into().unwrap()); + assert_eq!(file_size as usize, bytes.len(), "header file_size mismatch for single node"); + + // Max degree should be 0 for single node with no edges. + let max_deg = u32::from_le_bytes(bytes[8..12].try_into().unwrap()); + assert_eq!(max_deg, 0, "single node with no edges should have max_degree=0"); + + // Read back neighbor count for the single node. + let num_neighbors = u32::from_le_bytes(bytes[24..28].try_into().unwrap()); + assert_eq!(num_neighbors, 0, "single node should have 0 neighbors"); + + std::fs::remove_dir_all(&dir).ok(); + } + + #[test] + fn test_save_graph_large() { + let npoints = 1000; + let ndims = 8; + let data = generate_random_data(npoints, ndims, 42); + let config = PiPNNConfig { + c_max: 128, + c_min: 32, + k: 4, + max_degree: 32, + replicas: 1, + l_max: 64, + ..Default::default() + }; + let graph = build(&data, npoints, ndims, &config).unwrap(); + + let dir = std::env::temp_dir().join("pipnn_test_save_large"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("large.bin"); + graph.save_graph(&path).unwrap(); + + // Read back and verify we can parse all adjacency lists. 
+ let bytes = std::fs::read(&path).unwrap(); + let file_size = u64::from_le_bytes(bytes[0..8].try_into().unwrap()); + assert_eq!(file_size as usize, bytes.len(), "header file_size mismatch for large graph"); + + let mut offset = 24usize; + let mut total_parsed_nodes = 0usize; + while offset < bytes.len() { + let num_neighbors = u32::from_le_bytes( + bytes[offset..offset + 4].try_into().unwrap() + ) as usize; + offset += 4; + for _ in 0..num_neighbors { + let neighbor = u32::from_le_bytes( + bytes[offset..offset + 4].try_into().unwrap() + ) as usize; + assert!( + neighbor < npoints, + "neighbor index {} out of range for node {}", + neighbor, total_parsed_nodes + ); + offset += 4; + } + total_parsed_nodes += 1; + } + assert_eq!( + total_parsed_nodes, npoints, + "expected to parse {} nodes but got {}", + npoints, total_parsed_nodes + ); + + std::fs::remove_dir_all(&dir).ok(); + } + + #[test] + fn test_config_c_min_greater_than_c_max() { + let config = PiPNNConfig { + c_min: 2048, + c_max: 1024, + ..Default::default() + }; + assert!( + config.validate().is_err(), + "c_min > c_max should fail validation" + ); + } + + #[test] + fn test_config_empty_fanout() { + let config = PiPNNConfig { + fanout: vec![], + ..Default::default() + }; + assert!( + config.validate().is_err(), + "empty fanout should fail validation" + ); + } + + #[test] + fn test_config_zero_fanout_element() { + let config = PiPNNConfig { + fanout: vec![5, 0, 2], + ..Default::default() + }; + assert!( + config.validate().is_err(), + "fanout containing 0 should fail validation" + ); + } + + #[test] + fn test_config_p_samp_zero() { + let config = PiPNNConfig { + p_samp: 0.0, + ..Default::default() + }; + assert!( + config.validate().is_err(), + "p_samp=0.0 should fail validation" + ); + } + + #[test] + fn test_config_p_samp_negative() { + let config = PiPNNConfig { + p_samp: -0.5, + ..Default::default() + }; + assert!( + config.validate().is_err(), + "p_samp < 0 should fail validation" + ); + } + + #[test] + 
fn test_config_hash_planes_zero() { + let config = PiPNNConfig { + num_hash_planes: 0, + ..Default::default() + }; + assert!( + config.validate().is_err(), + "num_hash_planes=0 should fail validation" + ); + } + + #[test] + fn test_config_hash_planes_17() { + let config = PiPNNConfig { + num_hash_planes: 17, + ..Default::default() + }; + assert!( + config.validate().is_err(), + "num_hash_planes=17 (> 16) should fail validation" + ); + } + + #[test] + fn test_final_prune_reduces_degree() { + let npoints = 200; + let ndims = 8; + let data = generate_random_data(npoints, ndims, 42); + + // Build without final prune, then build with, and compare max degree. + let config_no_prune = PiPNNConfig { + c_max: 64, + c_min: 16, + k: 6, + max_degree: 16, + replicas: 2, + l_max: 64, + final_prune: false, + ..Default::default() + }; + let config_with_prune = PiPNNConfig { + final_prune: true, + ..config_no_prune.clone() + }; + + let graph_no = build(&data, npoints, ndims, &config_no_prune).unwrap(); + let graph_yes = build(&data, npoints, ndims, &config_with_prune).unwrap(); + + // Final prune should not increase max degree beyond max_degree. + assert!( + graph_yes.max_degree() <= config_with_prune.max_degree, + "final_prune max_degree {} > config max_degree {}", + graph_yes.max_degree(), + config_with_prune.max_degree + ); + + // Both should be valid graphs. + assert!(graph_no.avg_degree() > 0.0); + assert!(graph_yes.avg_degree() > 0.0); + } +} diff --git a/diskann-pipnn/src/gemm.rs b/diskann-pipnn/src/gemm.rs new file mode 100644 index 000000000..9c805a7f2 --- /dev/null +++ b/diskann-pipnn/src/gemm.rs @@ -0,0 +1,146 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! GEMM abstraction using diskann-linalg (faer backend), consistent with DiskANN. + +use diskann_linalg::Transpose; + +/// Compute C = A * B^T where A is m x k and B is n x k (both row-major). +/// Result C is m x n (row-major). 
+/// +/// Uses diskann-linalg's sgemm backed by faer, the same GEMM DiskANN uses internally. +#[inline] +pub fn sgemm_abt( + a: &[f32], m: usize, k: usize, + b: &[f32], n: usize, + c: &mut [f32], +) { + debug_assert_eq!(a.len(), m * k); + debug_assert_eq!(b.len(), n * k); + debug_assert_eq!(c.len(), m * n); + + diskann_linalg::sgemm( + Transpose::None, + Transpose::Ordinary, + m, n, k, + 1.0, + a, b, + None, + c, + ); +} + +/// Compute C = A * A^T where A is m x k (row-major). +/// Result C is m x m (row-major). +#[inline] +pub fn sgemm_aat(a: &[f32], m: usize, k: usize, c: &mut [f32]) { + sgemm_abt(a, m, k, a, m, c); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sgemm_abt_identity() { + let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let identity = vec![ + 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, + ]; + let mut c = vec![0.0f32; 6]; + sgemm_abt(&a, 2, 3, &identity, 3, &mut c); + for i in 0..6 { + assert!((c[i] - a[i]).abs() < 1e-6, "A*I^T != A at {}: got {}, expected {}", i, c[i], a[i]); + } + } + + #[test] + fn test_sgemm_abt_known() { + let a = vec![1.0, 2.0, 3.0, 4.0]; + let b = vec![5.0, 6.0, 7.0, 8.0]; + let mut c = vec![0.0f32; 4]; + sgemm_abt(&a, 2, 2, &b, 2, &mut c); + let expected = vec![17.0, 23.0, 39.0, 53.0]; + for i in 0..4 { + assert!((c[i] - expected[i]).abs() < 1e-5, "mismatch at {}: got {}, expected {}", i, c[i], expected[i]); + } + } + + #[test] + fn test_sgemm_aat_symmetric() { + let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let mut c = vec![0.0f32; 9]; + sgemm_aat(&a, 3, 3, &mut c); + for i in 0..3 { + for j in (i + 1)..3 { + assert!((c[i * 3 + j] - c[j * 3 + i]).abs() < 1e-5, "not symmetric at ({},{})", i, j); + } + } + assert!((c[0] - 14.0).abs() < 1e-5, "got {}", c[0]); + } + + #[test] + fn test_sgemm_abt_rectangular() { + let a = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]; + let b = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0]; + let mut c = vec![0.0f32; 8]; + sgemm_abt(&a, 2, 3, &b, 
4, &mut c); + let expected = vec![1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]; + for i in 0..8 { + assert!((c[i] - expected[i]).abs() < 1e-6, "rectangular mismatch at {}", i); + } + } + + #[test] + fn test_sgemm_abt_large() { + let m = 64; + let k = 128; + let n = 64; + let a = vec![1.0f32; m * k]; + let b = vec![1.0f32; n * k]; + let mut c = vec![0.0f32; m * n]; + sgemm_abt(&a, m, k, &b, n, &mut c); + for i in 0..(m * n) { + assert!((c[i] - k as f32).abs() < 1e-3, "large mismatch at {}", i); + } + } + + #[test] + fn test_sgemm_abt_zeros() { + let m = 4; + let k = 8; + let n = 3; + let a = vec![0.0f32; m * k]; + let b = vec![0.0f32; n * k]; + let mut c = vec![99.0f32; m * n]; + sgemm_abt(&a, m, k, &b, n, &mut c); + for i in 0..(m * n) { + assert!(c[i].abs() < 1e-6, "zeros mismatch at {}", i); + } + } + + #[test] + fn test_sgemm_abt_negative() { + let a = vec![-1.0, -2.0, -3.0, -4.0]; + let b = vec![5.0, 6.0, 7.0, 8.0]; + let mut c = vec![0.0f32; 4]; + sgemm_abt(&a, 2, 2, &b, 2, &mut c); + let expected = vec![-17.0, -23.0, -39.0, -53.0]; + for i in 0..4 { + assert!((c[i] - expected[i]).abs() < 1e-5, "negative mismatch at {}", i); + } + } + + #[test] + fn test_sgemm_abt_single_element() { + let a = vec![3.0f32]; + let b = vec![5.0f32]; + let mut c = vec![0.0f32; 1]; + sgemm_abt(&a, 1, 1, &b, 1, &mut c); + assert!((c[0] - 15.0).abs() < 1e-6); + } +} diff --git a/diskann-pipnn/src/hash_prune.rs b/diskann-pipnn/src/hash_prune.rs new file mode 100644 index 000000000..babf1876f --- /dev/null +++ b/diskann-pipnn/src/hash_prune.rs @@ -0,0 +1,598 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! HashPrune: LSH-based online pruning for merging edges from overlapping partitions. +//! +//! Uses random hyperplanes to hash candidate neighbors relative to each point. +//! Maintains a reservoir of l_max entries per point, keyed by hash bucket. +//! This is history-independent (order of insertion does not matter). 
+ +use std::cell::RefCell; + +use parking_lot::Mutex; + +use diskann::utils::VectorRepr; +use rand::SeedableRng; +use rand_distr::{Distribution, StandardNormal}; +use rayon::prelude::*; + +/// Precomputed LSH sketches for a set of vectors. +/// +/// For each vector v, Sketch(v) = [v . H_i for i=0..m] where H_i are random hyperplanes. +/// Sketches are computed via parallel dot products. +pub struct LshSketches { + /// Number of hyperplanes (m). + num_planes: usize, + /// Precomputed sketches: npoints x m, stored row-major. + /// sketch[i * m + j] = dot(point_i, hyperplane_j) + sketches: Vec, + /// Number of points. + npoints: usize, +} + +impl LshSketches { + /// Create new LSH sketches for the given data using parallel dot products. + /// + /// `data` is row-major: npoints x ndims. + pub fn new(data: &[T], npoints: usize, ndims: usize, num_planes: usize, seed: u64) -> Self { + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + + // Generate random hyperplanes from standard normal distribution. + // Stored as num_planes x ndims (row-major). + let hyperplanes: Vec = (0..num_planes * ndims) + .map(|_| StandardNormal.sample(&mut rng)) + .collect(); + + // Compute sketches in parallel using direct dot products. + // For tall-thin output (npoints x 12), this is faster than GEMM. + let mut sketches = vec![0.0f32; npoints * num_planes]; + + sketches + .par_chunks_mut(num_planes) + .enumerate() + .for_each(|(i, sketch_row)| { + // Thread-local buffer for T -> f32 conversion. + thread_local! 
{ + static SKETCH_BUF: RefCell> = RefCell::new(Vec::new()); + } + SKETCH_BUF.with(|cell| { + let mut buf = cell.borrow_mut(); + buf.resize(ndims, 0.0); + T::as_f32_into(&data[i * ndims..(i + 1) * ndims], &mut buf).expect("f32 conversion"); + for j in 0..num_planes { + let plane = &hyperplanes[j * ndims..(j + 1) * ndims]; + let mut dot = 0.0f32; + for d in 0..ndims { + unsafe { + dot += *buf.get_unchecked(d) * *plane.get_unchecked(d); + } + } + sketch_row[j] = dot; + } + }); + }); + + Self { + num_planes, + sketches, + npoints, + } + } + + /// Compute the hash of candidate c relative to point p. + /// + /// h_p(c) = concat of sign bits of (Sketch(c) - Sketch(p)) + /// Returns a u16 hash (supports up to 16 hyperplanes, matching paper's 8-byte entry). + #[inline(always)] + pub fn relative_hash(&self, p: usize, c: usize) -> u16 { + debug_assert!(p < self.npoints); + debug_assert!(c < self.npoints); + debug_assert!(self.num_planes <= 16); + + let m = self.num_planes; + let p_sketch = &self.sketches[p * m..(p + 1) * m]; + let c_sketch = &self.sketches[c * m..(c + 1) * m]; + + let mut hash: u16 = 0; + for j in 0..m { + let diff = c_sketch[j] - p_sketch[j]; + if diff >= 0.0 { + hash |= 1u16 << j; + } + } + hash + } +} + + + +/// Convert f32 distance to bf16 (truncate lower 16 mantissa bits). +/// For non-negative values, bf16 bit ordering matches f32 ordering, +/// so u16 comparison gives correct distance ordering. +#[inline(always)] +fn f32_to_bf16(v: f32) -> u16 { + (v.to_bits() >> 16) as u16 +} + +/// Convert bf16 back to f32 (zero-fill lower mantissa bits). +#[inline(always)] +fn bf16_to_f32(v: u16) -> f32 { + f32::from_bits((v as u32) << 16) +} + +/// A single entry in the HashPrune reservoir. +/// Packed to exactly 8 bytes: 4 (neighbor) + 2 (hash) + 2 (distance as bf16). +#[derive(Debug, Clone, Copy)] +#[repr(C)] +struct ReservoirEntry { + /// The candidate neighbor index. + neighbor: u32, + /// Hash bucket (16-bit). 
+ hash: u16, + /// Distance stored as bf16 (raw u16 bits). Non-negative bf16 values + /// are monotonically ordered as u16, enabling integer comparison. + distance: u16, +} + +/// HashPrune reservoir for a single point. +/// +/// Uses a flat sorted Vec for O(log l) hash lookups instead of HashMap. +/// Caches the farthest entry for O(1) eviction checks. +/// Insertion is O(l) due to element shifting, but cache-friendly at typical l_max ~128. +pub struct HashPruneReservoir { + /// Entries sorted by hash for binary search. + entries: Vec, + /// Maximum reservoir size. + l_max: usize, + /// Cached farthest distance (bf16) and its index in entries. + farthest_dist: u16, + farthest_idx: usize, +} + +impl HashPruneReservoir { + pub fn new(l_max: usize) -> Self { + Self { + entries: Vec::with_capacity(l_max), + l_max, + farthest_dist: 0, + farthest_idx: 0, + } + } + + /// Create a reservoir without pre-allocating capacity. + pub fn new_lazy(l_max: usize) -> Self { + Self { + entries: Vec::new(), + l_max, + farthest_dist: 0, + farthest_idx: 0, + } + } + + /// Create a reservoir with a specific initial capacity hint. + /// Avoids Vec doubling when the expected fill is known. + pub fn new_with_capacity(l_max: usize, initial_capacity: usize) -> Self { + Self { + entries: Vec::with_capacity(initial_capacity), + l_max, + farthest_dist: 0, + farthest_idx: 0, + } + } + + /// Find entry with matching hash using binary search. + #[inline] + fn find_hash(&self, hash: u16) -> Option { + self.entries + .binary_search_by_key(&hash, |e| e.hash) + .ok() + } + + /// Update the cached farthest entry. 
+ #[inline] + fn update_farthest(&mut self) { + if self.entries.is_empty() { + self.farthest_dist = 0; + self.farthest_idx = 0; + return; + } + let mut max_dist: u16 = 0; + let mut max_idx = 0; + for (idx, entry) in self.entries.iter().enumerate() { + if entry.distance > max_dist { + max_dist = entry.distance; + max_idx = idx; + } + } + self.farthest_dist = max_dist; + self.farthest_idx = max_idx; + } + + /// Try to insert a candidate neighbor with the given hash and distance. + /// Distance is converted to bf16 at the boundary for compact storage. + #[inline] + pub fn insert(&mut self, hash: u16, neighbor: u32, distance: f32) -> bool { + let dist_bf16 = f32_to_bf16(distance); + + // If the hash bucket already exists, keep the closer point. + if let Some(idx) = self.find_hash(hash) { + if dist_bf16 < self.entries[idx].distance { + let was_farthest = idx == self.farthest_idx; + self.entries[idx].neighbor = neighbor; + self.entries[idx].distance = dist_bf16; + if was_farthest { + self.update_farthest(); + } + return true; + } + return false; + } + + // If reservoir is not full, insert in sorted position. + if self.entries.len() < self.l_max { + let pos = self.entries + .binary_search_by_key(&hash, |e| e.hash) + .unwrap_or_else(|e| e); + self.entries.insert(pos, ReservoirEntry { neighbor, distance: dist_bf16, hash }); + if dist_bf16 > self.farthest_dist { + self.farthest_dist = dist_bf16; + // Position may have shifted + self.update_farthest(); + } else if self.entries.len() == 1 { + self.farthest_dist = dist_bf16; + self.farthest_idx = 0; + } + return true; + } + + // Reservoir is full: evict farthest if new is closer. 
+ if dist_bf16 < self.farthest_dist { + self.entries.remove(self.farthest_idx); + let pos = self.entries + .binary_search_by_key(&hash, |e| e.hash) + .unwrap_or_else(|e| e); + self.entries.insert(pos, ReservoirEntry { neighbor, distance: dist_bf16, hash }); + self.update_farthest(); + return true; + } + + false + } + + /// Get all neighbors in the reservoir, sorted by distance. + pub fn get_neighbors_sorted(&self) -> Vec<(u32, f32)> { + let mut neighbors: Vec<(u32, u16)> = self + .entries + .iter() + .map(|e| (e.neighbor, e.distance)) + .collect(); + // u16 comparison is correct for non-negative bf16 values. + neighbors.sort_unstable_by_key(|&(_, d)| d); + neighbors.into_iter().map(|(id, d)| (id, bf16_to_f32(d))).collect() + } + + /// Get the number of entries in the reservoir. + pub fn len(&self) -> usize { + self.entries.len() + } + + /// Check if the reservoir is empty. + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } +} + +/// The global HashPrune state managing reservoirs for all points. +/// Uses per-point Mutex for thread-safe parallel edge insertion. +pub struct HashPrune { + /// One reservoir per point, each behind a Mutex for parallel access. + reservoirs: Vec>, + /// LSH sketches. + sketches: LshSketches, + /// Maximum degree for the final graph. + max_degree: usize, +} + +impl HashPrune { + /// Create a new HashPrune instance. + /// + /// `data` is row-major: npoints x ndims. + pub fn new( + data: &[T], + npoints: usize, + ndims: usize, + num_planes: usize, + l_max: usize, + max_degree: usize, + seed: u64, + ) -> Self { + let t0 = std::time::Instant::now(); + let sketches = LshSketches::new(data, npoints, ndims, num_planes, seed); + tracing::debug!(elapsed_secs = t0.elapsed().as_secs_f64(), "sketch computation"); + let t1 = std::time::Instant::now(); + // Use lazy allocation: reservoirs grow on demand as edges are inserted. + // Pre-allocating 64×8B×1M = 512 MB upfront is worse because it spikes + // before any leaf data is freed. 
Lazy growth + malloc_trim between + // phases keeps peak RSS lower despite realloc fragmentation. + let reservoirs = (0..npoints) + .map(|_| Mutex::new(HashPruneReservoir::new_lazy(l_max))) + .collect(); + tracing::debug!(elapsed_secs = t1.elapsed().as_secs_f64(), "reservoir allocation"); + + Self { + reservoirs, + sketches, + max_degree, + } + } + + /// Add an edge from point `p` to candidate `c` with the given distance. + /// Thread-safe: acquires lock on p's reservoir only. + #[inline] + pub fn add_edge(&self, p: usize, c: usize, distance: f32) { + let hash = self.sketches.relative_hash(p, c); + self.reservoirs[p].lock().insert(hash, c as u32, distance); + } + + /// Add a batch of edges in parallel. Each edge is (point_idx, neighbor_idx, distance). + pub fn add_edges_parallel(&self, edges: &[(usize, usize, f32)]) { + edges.par_iter().for_each(|&(p, c, dist)| { + self.add_edge(p, c, dist); + }); + } + + /// Add edges from a leaf build result, batching by source point. + /// Sorts edges by source to acquire each lock once per unique source. + pub fn add_edges_batched(&self, edges: &[crate::leaf_build::Edge]) { + if edges.is_empty() { + return; + } + + let mut sorted: Vec<&crate::leaf_build::Edge> = edges.iter().collect(); + sorted.sort_unstable_by_key(|e| e.src); + + let mut i = 0; + while i < sorted.len() { + let src = sorted[i].src; + let mut reservoir = self.reservoirs[src].lock(); + while i < sorted.len() && sorted[i].src == src { + let edge = sorted[i]; + let hash = self.sketches.relative_hash(src, edge.dst); + reservoir.insert(hash, edge.dst as u32, edge.distance); + i += 1; + } + } + } + + /// Extract the final graph as adjacency lists, consuming the HashPrune. + /// + /// Consumes self so that reservoirs and sketches are freed as extraction proceeds, + /// rather than staying alive until the caller drops HashPrune. + /// Each reservoir is dropped immediately after its neighbors are extracted. 
+ pub fn extract_graph(self) -> Vec> { + let max_degree = self.max_degree; + // Drop sketches first (~50 MB for 1M points × 12 planes). + drop(self.sketches); + self.reservoirs + .into_par_iter() + .map(|mutex| { + let res = mutex.into_inner(); + let mut neighbors = res.get_neighbors_sorted(); + neighbors.truncate(max_degree); + neighbors.into_iter().map(|(id, _)| id).collect() + }) + .collect() + } + + /// Get the number of points. + pub fn num_points(&self) -> usize { + self.reservoirs.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_reservoir_basic() { + let mut reservoir = HashPruneReservoir::new(3); + assert!(reservoir.is_empty()); + + // Insert three entries with different hashes. + assert!(reservoir.insert(0, 1, 1.0)); + assert!(reservoir.insert(1, 2, 2.0)); + assert!(reservoir.insert(2, 3, 3.0)); + assert_eq!(reservoir.len(), 3); + + // Reservoir is full. New closer entry should evict the farthest. + assert!(reservoir.insert(3, 4, 0.5)); + assert_eq!(reservoir.len(), 3); + + let neighbors = reservoir.get_neighbors_sorted(); + // Should not contain the farthest entry (neighbor 3, distance 3.0). + assert!(!neighbors.iter().any(|(id, _)| *id == 3)); + // Should contain the new closer entry. + assert!(neighbors.iter().any(|(id, _)| *id == 4)); + } + + #[test] + fn test_reservoir_same_hash_keeps_closer() { + let mut reservoir = HashPruneReservoir::new(10); + + assert!(reservoir.insert(0, 1, 2.0)); + assert_eq!(reservoir.len(), 1); + + // Same hash, closer distance: should update. + assert!(reservoir.insert(0, 2, 1.0)); + assert_eq!(reservoir.len(), 1); + + let neighbors = reservoir.get_neighbors_sorted(); + assert_eq!(neighbors[0].0, 2); + assert_eq!(neighbors[0].1, 1.0); + + // Same hash, farther distance: should not update. + assert!(!reservoir.insert(0, 3, 5.0)); + assert_eq!(reservoir.len(), 1); + } + + #[test] + fn test_lsh_sketches() { + // Simple test with 4 points in 2D. 
+ let data = vec![ + 1.0, 0.0, // point 0 + 0.0, 1.0, // point 1 + -1.0, 0.0, // point 2 + 0.0, -1.0, // point 3 + ]; + let sketches = LshSketches::new(&data, 4, 2, 4, 42); + + // Relative hash of a point with itself: all diffs are 0, 0.0 >= 0.0 is true. + let h00 = sketches.relative_hash(0, 0); + assert_eq!(h00, (1u16 << 4) - 1); + + // Different points should generally have different hashes. + let h01 = sketches.relative_hash(0, 1); + let h02 = sketches.relative_hash(0, 2); + let _ = (h01, h02); + } + + #[test] + fn test_hash_prune_end_to_end() { + // 4 points in 2D. + let data = vec![ + 0.0, 0.0, // point 0 + 1.0, 0.0, // point 1 + 0.0, 1.0, // point 2 + 1.0, 1.0, // point 3 + ]; + + let hp = HashPrune::new(&data, 4, 2, 4, 10, 3, 42); + + // Add some edges. + hp.add_edge(0, 1, 1.0); + hp.add_edge(0, 2, 1.0); + hp.add_edge(0, 3, 1.414); + hp.add_edge(1, 0, 1.0); + hp.add_edge(1, 3, 1.0); + hp.add_edge(2, 0, 1.0); + hp.add_edge(2, 3, 1.0); + hp.add_edge(3, 1, 1.0); + hp.add_edge(3, 2, 1.0); + + let graph = hp.extract_graph(); + assert_eq!(graph.len(), 4); + + for (i, neighbors) in graph.iter().enumerate() { + assert!( + !neighbors.is_empty(), + "point {} has no neighbors", + i + ); + } + } + + #[test] + fn test_reservoir_lazy_allocation() { + let mut res = HashPruneReservoir::new_lazy(5); + assert!(res.is_empty()); + assert!(res.insert(0, 1, 1.0)); + assert_eq!(res.len(), 1); + } + + #[test] + fn test_reservoir_insert_then_evict_cycle() { + let mut res = HashPruneReservoir::new(3); + res.insert(0, 10, 3.0); + res.insert(1, 11, 2.0); + res.insert(2, 12, 1.0); + assert_eq!(res.len(), 3); + assert!(res.insert(3, 13, 0.5)); + assert_eq!(res.len(), 3); + let neighbors = res.get_neighbors_sorted(); + assert!(neighbors.iter().all(|&(_, d)| d <= 2.0)); + } + + #[test] + fn test_reservoir_all_same_hash() { + let mut res = HashPruneReservoir::new(5); + res.insert(0, 1, 3.0); + res.insert(0, 2, 2.0); + res.insert(0, 3, 1.0); + assert_eq!(res.len(), 1); + let neighbors = 
res.get_neighbors_sorted(); + assert_eq!(neighbors[0].0, 3); + assert_eq!(neighbors[0].1, 1.0); + } + + #[test] + fn test_reservoir_all_same_distance() { + let mut res = HashPruneReservoir::new(5); + res.insert(0, 1, 1.0); + res.insert(1, 2, 1.0); + res.insert(2, 3, 1.0); + assert_eq!(res.len(), 3); + } + + #[test] + fn test_hash_prune_parallel_safety() { + use rayon::prelude::*; + let data = vec![0.0f32; 100 * 4]; + let hp = HashPrune::new(&data, 100, 4, 4, 10, 5, 42); + (0..50).into_par_iter().for_each(|i| { + hp.add_edge(i, (i + 1) % 100, 1.0); + hp.add_edge((i + 1) % 100, i, 1.0); + }); + let graph = hp.extract_graph(); + assert_eq!(graph.len(), 100); + } + + #[test] + fn test_hash_prune_high_degree_limit() { + let data = vec![0.0f32; 10 * 2]; + let hp = HashPrune::new(&data, 10, 2, 4, 10, 1, 42); + for i in 0..10 { + for j in 0..10 { + if i != j { hp.add_edge(i, j, (i as f32 - j as f32).abs()); } + } + } + let graph = hp.extract_graph(); + for neighbors in &graph { + assert!(neighbors.len() <= 1, "max_degree=1 should limit to 1 neighbor"); + } + } + + #[test] + fn test_hash_prune_extract_sorted() { + let data = vec![0.0f32; 4 * 2]; + let hp = HashPrune::new(&data, 4, 2, 4, 10, 3, 42); + hp.add_edge(0, 1, 3.0); + hp.add_edge(0, 2, 1.0); + hp.add_edge(0, 3, 2.0); + let graph = hp.extract_graph(); + assert!(!graph[0].is_empty()); + } + + #[test] + fn test_lsh_sketches_different_seeds() { + let data = vec![1.0f32, 0.0, 0.0, 1.0]; + let s1 = LshSketches::new(&data, 2, 2, 4, 42); + let s2 = LshSketches::new(&data, 2, 2, 4, 99); + let h1 = s1.relative_hash(0, 1); + let h2 = s2.relative_hash(0, 1); + // Different seeds should generally produce different hashes (not guaranteed but very likely) + let _ = (h1, h2); // Just verify they compile and don't panic + } + + #[test] + fn test_relative_hash_symmetry_broken() { + let data = vec![1.0f32, 0.0, 0.0, 1.0, -1.0, 0.0]; + let sketches = LshSketches::new(&data, 3, 2, 4, 42); + let h01 = sketches.relative_hash(0, 1); + let 
h10 = sketches.relative_hash(1, 0); + // h_p(c) != h_c(p) in general because relative_hash is asymmetric + let _ = (h01, h10); + } +} diff --git a/diskann-pipnn/src/leaf_build.rs b/diskann-pipnn/src/leaf_build.rs new file mode 100644 index 000000000..c75cf2f66 --- /dev/null +++ b/diskann-pipnn/src/leaf_build.rs @@ -0,0 +1,633 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Leaf building: GEMM-based all-pairs distance computation and bi-directed k-NN extraction. +//! +//! For each leaf partition (bounded by C_max, typically 1024-2048): +//! 1. Compute all-pairs distance matrix via GEMM +//! For L2: ||a-b||^2 = ||a||^2 + ||b||^2 - 2*(a.b) +//! The dot product matrix A * A^T is computed as a GEMM operation. +//! 2. Extract k nearest neighbors per point using partial sort +//! 3. Create bi-directed edges (both forward and reverse k-NN) + +use std::cell::RefCell; + +use diskann::utils::VectorRepr; +use diskann_vector::PureDistanceFunction; +use diskann_vector::distance::SquaredL2; + +/// Thread-local reusable buffers for leaf building. +/// Avoids repeated allocation/deallocation of large matrices. +pub struct LeafBuffers { + pub local_data: Vec, + pub norms_sq: Vec, + pub dot_matrix: Vec, + pub dist_matrix: Vec, + pub seen: Vec, +} + +impl LeafBuffers { + pub fn new() -> Self { + Self { + local_data: Vec::new(), + norms_sq: Vec::new(), + dot_matrix: Vec::new(), + dist_matrix: Vec::new(), + seen: Vec::new(), + } + } + + /// Ensure all buffers are large enough for a leaf of size n x ndims. 
+ fn ensure_capacity(&mut self, n: usize, ndims: usize) { + let nd = n * ndims; + let nn = n * n; + if self.local_data.len() < nd { self.local_data.resize(nd, 0.0); } + if self.norms_sq.len() < n { self.norms_sq.resize(n, 0.0); } + if self.dot_matrix.len() < nn { self.dot_matrix.resize(nn, 0.0); } + if self.dist_matrix.len() < nn { self.dist_matrix.resize(nn, 0.0); } + if self.seen.len() < nn { self.seen.resize(nn, false); } + } +} + +thread_local! { + static LEAF_BUFFERS: RefCell = RefCell::new(LeafBuffers::new()); + static QUANT_SEEN: RefCell> = RefCell::new(Vec::new()); +} + +/// Release thread-local leaf build buffers on the calling thread. +/// +/// After leaf building is complete, these buffers pin pages in glibc's +/// per-thread arenas, preventing `malloc_trim` from returning freed +/// reservoir memory to the OS. Calling this from each rayon thread +/// allows the arena heaps to be reclaimed. +pub fn release_thread_buffers() { + LEAF_BUFFERS.with(|cell| { + let mut bufs = cell.borrow_mut(); + bufs.local_data = Vec::new(); + bufs.norms_sq = Vec::new(); + bufs.dot_matrix = Vec::new(); + bufs.dist_matrix = Vec::new(); + bufs.seen = Vec::new(); + }); + QUANT_SEEN.with(|cell| { + *cell.borrow_mut() = Vec::new(); + }); +} + +/// An edge produced by leaf building: (source, destination, distance). +#[derive(Debug, Clone, Copy)] +pub struct Edge { + pub src: usize, + pub dst: usize, + pub distance: f32, +} + + +/// Extract k nearest neighbors for each point from the distance matrix. +/// +/// Uses index-sort: partitions a u32 index array by indirect distance comparison. +/// Sorting 4-byte indices instead of 8-byte (index, distance) pairs reduces memory +/// movement during quickselect, yielding ~1.5x speedup over the pair-based approach. 
+fn extract_knn(dist_matrix: &[f32], n: usize, k: usize) -> Vec<(usize, usize, f32)> { + let actual_k = k.min(n - 1); + let mut edges = Vec::with_capacity(n * actual_k); + + // Reuse index buffer across all rows (4 bytes per element vs 8 for pairs). + let mut indices: Vec = (0..n as u32).collect(); + + for i in 0..n { + let row = &dist_matrix[i * n..(i + 1) * n]; + + // Reset indices for this row. + for j in 0..n { + unsafe { *indices.get_unchecked_mut(j) = j as u32; } + } + + if actual_k < n { + indices.select_nth_unstable_by(actual_k, |&a, &b| { + let da = unsafe { *row.get_unchecked(a as usize) }; + let db = unsafe { *row.get_unchecked(b as usize) }; + da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal) + }); + } + + for idx in 0..actual_k { + let j = unsafe { *indices.get_unchecked(idx) } as usize; + edges.push((i, j, row[j])); + } + } + + edges +} + +/// Build a leaf partition: compute all-pairs distances and extract bi-directed k-NN edges. +/// +/// Returns edges as (global_src, global_dst, distance). +pub fn build_leaf( + data: &[T], + ndims: usize, + indices: &[usize], + k: usize, + metric: diskann_vector::distance::Metric, +) -> Vec { + let n = indices.len(); + if n <= 1 { + return Vec::new(); + } + + LEAF_BUFFERS.with(|cell| { + let mut bufs = cell.borrow_mut(); + build_leaf_with_buffers(data, ndims, indices, k, metric, &mut bufs) + }) +} + +fn build_leaf_with_buffers( + data: &[T], + ndims: usize, + indices: &[usize], + k: usize, + metric: diskann_vector::distance::Metric, + bufs: &mut LeafBuffers, +) -> Vec { + let n = indices.len(); + bufs.ensure_capacity(n, ndims); + + // Extract local data into reused buffer, converting T -> f32 on the fly. + let local_data = &mut bufs.local_data[..n * ndims]; + for (i, &idx) in indices.iter().enumerate() { + let src = &data[idx * ndims..(idx + 1) * ndims]; + let dst = &mut local_data[i * ndims..(i + 1) * ndims]; + T::as_f32_into(src, dst).expect("f32 conversion"); + } + + // Compute norms into reused buffer. 
+ let norms_sq = &mut bufs.norms_sq[..n]; + for i in 0..n { + let row = &local_data[i * ndims..(i + 1) * ndims]; + let mut norm = 0.0f32; + for &v in row.iter() { norm += v * v; } + norms_sq[i] = norm; + } + + // GEMM: dots = local_data * local_data^T + // Computes all n² dot products at once via BLAS — much faster than n² individual + // distance calls. The dot-to-distance conversion below is O(n²) scalar ops. + let dot_matrix = &mut bufs.dot_matrix[..n * n]; + crate::gemm::sgemm_aat(local_data, n, ndims, dot_matrix); + + // Convert to distance matrix using the target metric. + use diskann_vector::distance::Metric; + let dist_matrix = match metric { + Metric::CosineNormalized => { + // Pre-normalized: dist = 1 - dot(a, b) + for i in 0..n { + let row = &mut dot_matrix[i * n..(i + 1) * n]; + for j in 0..n { row[j] = (1.0 - row[j]).max(0.0); } + row[i] = f32::MAX; + } + &mut bufs.dot_matrix[..n * n] + } + Metric::Cosine => { + // Unnormalized: dist = 1 - dot(a,b)/(||a||*||b||) + let dist = &mut bufs.dist_matrix[..n * n]; + for i in 0..n { + let ni_sqrt = norms_sq[i].sqrt(); + for j in 0..n { + let denom = ni_sqrt * norms_sq[j].sqrt(); + let cos_sim = if denom > 0.0 { dot_matrix[i * n + j] / denom } else { 0.0 }; + dist[i * n + j] = (1.0 - cos_sim).max(0.0); + } + dist[i * n + i] = f32::MAX; + } + dist + } + Metric::L2 => { + let dist = &mut bufs.dist_matrix[..n * n]; + for i in 0..n { + let ni = norms_sq[i]; + for j in 0..n { + dist[i * n + j] = (ni + norms_sq[j] - 2.0 * dot_matrix[i * n + j]).max(0.0); + } + dist[i * n + i] = f32::MAX; + } + dist + } + Metric::InnerProduct => { + for i in 0..n { + let row = &mut dot_matrix[i * n..(i + 1) * n]; + for j in 0..n { row[j] = -row[j]; } + row[i] = f32::MAX; + } + &mut bufs.dot_matrix[..n * n] + } + }; + + // Extract k-NN edges. + let local_edges = extract_knn(dist_matrix, n, k); + + // Create bi-directed edges using reused seen buffer. 
+ let seen = &mut bufs.seen[..n * n]; + seen.fill(false); + + let mut global_edges = Vec::with_capacity(local_edges.len() * 2); + + for &(src, dst, dist) in &local_edges { + if !seen[src * n + dst] { + seen[src * n + dst] = true; + global_edges.push(Edge { src: indices[src], dst: indices[dst], distance: dist }); + } + if !seen[dst * n + src] { + seen[dst * n + src] = true; + global_edges.push(Edge { src: indices[dst], dst: indices[src], distance: dist_matrix[dst * n + src] }); + } + } + + global_edges +} + +/// Build a leaf using 1-bit quantized vectors with Hamming distance. +pub fn build_leaf_quantized( + qdata: &crate::quantize::QuantizedData, + indices: &[usize], + k: usize, +) -> Vec { + let n = indices.len(); + if n <= 1 { + return Vec::new(); + } + + let dist_matrix = qdata.compute_distance_matrix(indices); + let local_edges = extract_knn(&dist_matrix, n, k); + + QUANT_SEEN.with(|cell| { + let mut seen = cell.borrow_mut(); + seen.resize(n * n, false); + seen.fill(false); + + let mut global_edges = Vec::with_capacity(local_edges.len() * 2); + + for &(src, dst, dist) in &local_edges { + if !seen[src * n + dst] { + seen[src * n + dst] = true; + global_edges.push(Edge { src: indices[src], dst: indices[dst], distance: dist }); + } + if !seen[dst * n + src] { + seen[dst * n + src] = true; + global_edges.push(Edge { src: indices[dst], dst: indices[src], distance: dist_matrix[dst * n + src] }); + } + } + + global_edges + }) +} + +/// Brute-force search the dataset using L2 distance. +/// +/// Returns the `k` nearest neighbor indices and distances for the query. 
+pub fn brute_force_knn( + data: &[f32], + ndims: usize, + npoints: usize, + query: &[f32], + k: usize, +) -> Vec<(usize, f32)> { + let mut dists: Vec<(usize, f32)> = (0..npoints) + .map(|i| { + let point = &data[i * ndims..(i + 1) * ndims]; + let dist = SquaredL2::evaluate(point, query); + (i, dist) + }) + .collect(); + + let actual_k = k.min(npoints); + if actual_k < dists.len() { + dists.select_nth_unstable_by(actual_k, |a, b| { + a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) + }); + dists.truncate(actual_k); + } + dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + dists +} + +#[cfg(test)] +mod tests { + use super::*; + use diskann_vector::distance::{DistanceProvider, Metric}; + + #[test] + fn test_gemm_aat() { + // 2x3 matrix: + // [1 2 3] + // [4 5 6] + // A * A^T should be: + // [14 32] + // [32 77] + let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let mut result = vec![0.0; 4]; + crate::gemm::sgemm_aat(&a, 2, 3, &mut result); + + assert!((result[0] - 14.0).abs() < 1e-6); + assert!((result[1] - 32.0).abs() < 1e-6); + assert!((result[2] - 32.0).abs() < 1e-6); + assert!((result[3] - 77.0).abs() < 1e-6); + } + + #[test] + fn test_distance_l2() { + let dist_fn = >::distance_comparer(Metric::L2, Some(2)); + let p0 = [0.0f32, 0.0]; + let p1 = [1.0f32, 0.0]; + let p2 = [0.0f32, 1.0]; + // dist(0,1) = 1 + assert!((dist_fn.call(&p0, &p1) - 1.0).abs() < 1e-6); + // dist(0,2) = 1 + assert!((dist_fn.call(&p0, &p2) - 1.0).abs() < 1e-6); + // dist(1,2) = 2 + assert!((dist_fn.call(&p1, &p2) - 2.0).abs() < 1e-6); + } + + #[test] + fn test_build_leaf() { + let data = vec![ + 0.0, 0.0, // point 0 + 1.0, 0.0, // point 1 + 0.0, 1.0, // point 2 + 1.0, 1.0, // point 3 + ]; + let indices = vec![0, 1, 2, 3]; + + let edges = build_leaf(&data, 2, &indices, 2, Metric::L2); + + assert!(!edges.is_empty()); + + for edge in &edges { + assert!(edge.src < 4); + assert!(edge.dst < 4); + assert!(edge.src != edge.dst); + assert!(edge.distance 
>= 0.0); + } + } + + #[test] + fn test_extract_knn() { + let dist = vec![ + f32::MAX, 1.0, 4.0, + 1.0, f32::MAX, 1.0, + 4.0, 1.0, f32::MAX, + ]; + let edges = extract_knn(&dist, 3, 1); + + assert_eq!(edges.len(), 3); + + let p0_edges: Vec<_> = edges.iter().filter(|e| e.0 == 0).collect(); + assert_eq!(p0_edges.len(), 1); + assert_eq!(p0_edges[0].1, 1); + + let p2_edges: Vec<_> = edges.iter().filter(|e| e.0 == 2).collect(); + assert_eq!(p2_edges.len(), 1); + assert_eq!(p2_edges[0].1, 1); + } + + #[test] + fn test_brute_force_knn() { + let data = vec![ + 0.0, 0.0, // point 0 + 1.0, 0.0, // point 1 + 0.0, 1.0, // point 2 + 1.0, 1.0, // point 3 + ]; + let query = vec![0.1, 0.1]; + let results = brute_force_knn(&data, 2, 4, &query, 2); + + assert_eq!(results.len(), 2); + assert_eq!(results[0].0, 0); + } + + #[test] + fn test_build_leaf_cosine() { + // Verify that cosine distance path works correctly with normalized vectors. + let mut data = vec![ + 1.0, 0.0, // point 0: along x + 0.0, 1.0, // point 1: along y + 0.707, 0.707, // point 2: 45 degrees + -1.0, 0.0, // point 3: negative x + ]; + // Normalize all vectors. + for i in 0..4 { + let row = &mut data[i * 2..(i + 1) * 2]; + let norm: f32 = row.iter().map(|v| v * v).sum::().sqrt(); + if norm > 0.0 { + for v in row.iter_mut() { *v /= norm; } + } + } + + let indices = vec![0, 1, 2, 3]; + let edges = build_leaf(&data, 2, &indices, 2, Metric::CosineNormalized); + + assert!(!edges.is_empty(), "cosine leaf should produce edges"); + + for edge in &edges { + assert!(edge.src < 4); + assert!(edge.dst < 4); + assert_ne!(edge.src, edge.dst); + // Cosine distance for normalized vectors is in [0, 2]. + assert!(edge.distance >= 0.0, "negative cosine distance"); + } + } + + #[test] + fn test_build_leaf_quantized() { + // Build a leaf using quantized data and verify basic correctness. 
+ let ndims = 64; + let npoints = 10; + use rand::{Rng, SeedableRng}; + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let data: Vec = (0..npoints * ndims) + .map(|_| rng.random_range(-1.0..1.0)) + .collect(); + + let (shift, inverse_scale) = { + use diskann_quantization::scalar::train::ScalarQuantizationParameters; + use diskann_utils::views::MatrixView; + let dm = MatrixView::try_from(data.as_slice(), npoints, ndims).unwrap(); + let q = ScalarQuantizationParameters::default().train(dm); + let s = q.scale(); + (q.shift().to_vec(), if s == 0.0 { 1.0 } else { 1.0 / s }) + }; + let qdata = crate::quantize::quantize_1bit(&data, npoints, ndims, &shift, inverse_scale); + let indices: Vec = (0..npoints).collect(); + let edges = build_leaf_quantized(&qdata, &indices, 3); + + assert!(!edges.is_empty(), "quantized leaf should produce edges"); + + for edge in &edges { + assert!(edge.src < npoints, "src {} out of range", edge.src); + assert!(edge.dst < npoints, "dst {} out of range", edge.dst); + assert_ne!(edge.src, edge.dst); + assert!(edge.distance >= 0.0); + } + } + + #[test] + fn test_build_leaf_single_point() { + // A leaf with 1 point should produce no edges. + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let indices = vec![0]; + let edges = build_leaf(&data, 4, &indices, 3, Metric::L2); + assert!( + edges.is_empty(), + "single point leaf should produce 0 edges, got {}", + edges.len() + ); + } + + #[test] + fn test_build_leaf_two_points() { + // A leaf with 2 points should produce bidirectional edges. + let data = vec![0.0f32, 0.0, 1.0, 0.0]; + let indices = vec![0, 1]; + let edges = build_leaf(&data, 2, &indices, 3, Metric::L2); + assert!(!edges.is_empty(), "two point leaf should produce edges"); + + // Should have both directions: 0->1 and 1->0. 
+ let has_0_to_1 = edges.iter().any(|e| e.src == 0 && e.dst == 1); + let has_1_to_0 = edges.iter().any(|e| e.src == 1 && e.dst == 0); + assert!(has_0_to_1, "should have edge 0 -> 1"); + assert!(has_1_to_0, "should have edge 1 -> 0"); + } + + #[test] + fn test_build_leaf_k_equals_n() { + // k >= n, every point should connect to every other. + let data = vec![ + 0.0, 0.0, + 1.0, 0.0, + 0.0, 1.0, + 1.0, 1.0, + ]; + let indices = vec![0, 1, 2, 3]; + let n = indices.len(); + // k = n means each point gets n-1 nearest neighbors = all others. + let edges = build_leaf(&data, 2, &indices, n, Metric::L2); + + // Collect directed edges. + let edge_set: std::collections::HashSet<(usize, usize)> = + edges.iter().map(|e| (e.src, e.dst)).collect(); + + // Every pair (i, j) with i != j should be present. + for i in 0..n { + for j in 0..n { + if i != j { + assert!( + edge_set.contains(&(i, j)), + "k >= n: edge ({} -> {}) should exist", + i, j + ); + } + } + } + } + + #[test] + fn test_build_leaf_with_buffers_reuse() { + // Call build_leaf_with_buffers twice and verify buffers are reused. + let data = vec![ + 0.0, 0.0, + 1.0, 0.0, + 0.0, 1.0, + 1.0, 1.0, + ]; + let indices = vec![0, 1, 2, 3]; + let mut bufs = LeafBuffers::new(); + + let edges1 = build_leaf_with_buffers(&data, 2, &indices, 2, Metric::L2, &mut bufs); + assert!(!edges1.is_empty(), "first call should produce edges"); + + // Verify buffers are allocated. + assert!(!bufs.local_data.is_empty(), "buffers should be allocated after first call"); + + // Second call with same data should still work. + let edges2 = build_leaf_with_buffers(&data, 2, &indices, 2, Metric::L2, &mut bufs); + assert_eq!( + edges1.len(), edges2.len(), + "same input should produce same number of edges with reused buffers" + ); + } + + #[test] + fn test_extract_knn_k_larger_than_n() { + // k > n-1 should be clamped. 
+ let dist = vec![ + f32::MAX, 1.0, + 1.0, f32::MAX, + ]; + let edges = extract_knn(&dist, 2, 100); // k=100 but only 2 points + assert_eq!( + edges.len(), 2, + "k > n-1 should be clamped, each point gets 1 neighbor, total 2 edges" + ); + } + + #[test] + fn test_brute_force_knn_single_point() { + let data = vec![5.0f32, 10.0]; + let query = vec![5.0, 10.0]; + let results = brute_force_knn(&data, 2, 1, &query, 5); + assert_eq!(results.len(), 1, "brute force on 1 point should return 1 result"); + assert_eq!(results[0].0, 0, "should return the only point (index 0)"); + assert!( + results[0].1 < 1e-6, + "distance to identical query should be near zero" + ); + } + + #[test] + fn test_brute_force_knn_identity() { + // query = data point, first result should be self with distance 0. + let data = vec![ + 0.0, 0.0, + 1.0, 0.0, + 0.0, 1.0, + 1.0, 1.0, + ]; + let query = vec![1.0, 0.0]; // same as point 1 + let results = brute_force_knn(&data, 2, 4, &query, 3); + assert_eq!(results[0].0, 1, "query identical to point 1 should find it first"); + assert!( + results[0].1 < 1e-6, + "self-distance should be 0, got {}", + results[0].1 + ); + } + + #[test] + fn test_edge_symmetry() { + // Verify that build_leaf produces bi-directed edges: + // if (a -> b) exists, then (b -> a) should also exist. + let data = vec![ + 0.0, 0.0, + 1.0, 0.0, + 0.0, 1.0, + 1.0, 1.0, + 0.5, 0.5, + ]; + let indices = vec![0, 1, 2, 3, 4]; + let edges = build_leaf(&data, 2, &indices, 2, Metric::L2); + + // Collect directed edges as a set. + let edge_set: std::collections::HashSet<(usize, usize)> = + edges.iter().map(|e| (e.src, e.dst)).collect(); + + // For every edge (a, b), (b, a) should also exist. 
+ for edge in &edges { + assert!( + edge_set.contains(&(edge.dst, edge.src)), + "edge ({} -> {}) exists but reverse ({} -> {}) does not", + edge.src, edge.dst, edge.dst, edge.src + ); + } + } +} diff --git a/diskann-pipnn/src/lib.rs b/diskann-pipnn/src/lib.rs new file mode 100644 index 000000000..b87879ec5 --- /dev/null +++ b/diskann-pipnn/src/lib.rs @@ -0,0 +1,185 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! PiPNN (Pick-in-Partitions Nearest Neighbors) index builder. +//! +//! Implements the PiPNN algorithm from arXiv:2602.21247, which builds graph-based +//! ANN indexes significantly faster than Vamana/HNSW by: +//! 1. Partitioning the dataset into overlapping clusters via Randomized Ball Carving +//! 2. Building local graphs within each leaf cluster using GEMM-based all-pairs distance +//! 3. Merging edges from overlapping partitions using HashPrune (LSH-based online pruning) + +pub mod builder; +pub mod gemm; +pub mod hash_prune; +pub mod leaf_build; +pub mod partition; +pub mod quantize; + +use diskann_vector::distance::Metric; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +/// Errors that can occur during PiPNN index construction. +#[derive(Debug, Error)] +pub enum PiPNNError { + #[error("configuration error: {0}")] + Config(String), + + #[error("data dimension mismatch: expected {expected}, got {actual}")] + DimensionMismatch { expected: usize, actual: usize }, + + #[error("data length mismatch: expected {expected} elements ({npoints} x {ndims}), got {actual}")] + DataLengthMismatch { + expected: usize, + actual: usize, + npoints: usize, + ndims: usize, + }, + + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), +} + +/// Result type for PiPNN operations. +pub type PiPNNResult = Result; + +/// Custom serde module for `Metric`, which does not derive Serialize/Deserialize. +/// Serializes as a string representation (e.g. "l2", "cosine"). 
+mod metric_serde { + use diskann_vector::distance::Metric; + use serde::{self, Deserialize, Deserializer, Serializer}; + + pub fn serialize(metric: &Metric, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(metric.as_str()) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + s.parse::().map_err(serde::de::Error::custom) + } +} + +/// Configuration for the PiPNN index builder. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PiPNNConfig { + /// Number of LSH hyperplanes for HashPrune. + pub num_hash_planes: usize, + /// Maximum leaf partition size. + pub c_max: usize, + /// Minimum cluster size before merging. + pub c_min: usize, + /// Sampling fraction for RBC leaders. + pub p_samp: f64, + /// Fanout at each partitioning level (overlap factor). + pub fanout: Vec, + /// k for k-NN in leaf building. + pub k: usize, + /// Maximum graph degree (R). + pub max_degree: usize, + /// Number of independent partitioning passes (replicas). + pub replicas: usize, + /// Maximum reservoir size per node in HashPrune. + pub l_max: usize, + /// Distance metric. + #[serde(with = "metric_serde")] + pub metric: Metric, + /// Whether to apply a final RobustPrune pass. + pub final_prune: bool, + /// Alpha (occlusion factor) for final RobustPrune. Same as DiskANN's `alpha` parameter. + /// Higher values yield sparser graphs. Default: 1.2 (matches DiskANN default). + pub alpha: f32, +} + +impl PiPNNConfig { + /// Validate the configuration, returning an error if any parameter is invalid. 
+ pub fn validate(&self) -> PiPNNResult<()> { + if self.c_max == 0 { + return Err(PiPNNError::Config("c_max must be > 0".into())); + } + if self.c_min == 0 { + return Err(PiPNNError::Config("c_min must be > 0".into())); + } + if self.c_min > self.c_max { + return Err(PiPNNError::Config(format!( + "c_min ({}) must be <= c_max ({})", + self.c_min, self.c_max + ))); + } + if self.max_degree == 0 { + return Err(PiPNNError::Config("max_degree must be > 0".into())); + } + if self.k == 0 { + return Err(PiPNNError::Config("k must be > 0".into())); + } + if self.replicas == 0 { + return Err(PiPNNError::Config("replicas must be > 0".into())); + } + if self.l_max == 0 { + return Err(PiPNNError::Config("l_max must be > 0".into())); + } + if self.p_samp <= 0.0 || self.p_samp > 1.0 { + return Err(PiPNNError::Config(format!( + "p_samp ({}) must be in (0.0, 1.0]", + self.p_samp + ))); + } + if !self.p_samp.is_finite() { + return Err(PiPNNError::Config("p_samp must be finite".into())); + } + if self.fanout.is_empty() { + return Err(PiPNNError::Config("fanout must not be empty".into())); + } + if self.fanout.iter().any(|&f| f == 0) { + return Err(PiPNNError::Config("all fanout values must be > 0".into())); + } + if self.num_hash_planes == 0 || self.num_hash_planes > 16 { + return Err(PiPNNError::Config(format!( + "num_hash_planes ({}) must be in [1, 16]", + self.num_hash_planes + ))); + } + if self.alpha < 1.0 { + return Err(PiPNNError::Config(format!( + "alpha ({}) must be >= 1.0", + self.alpha + ))); + } + if !self.alpha.is_finite() { + return Err(PiPNNError::Config("alpha must be finite".into())); + } + if self.metric == Metric::InnerProduct { + return Err(PiPNNError::Config( + "InnerProduct metric is not supported by PiPNN; use L2, Cosine, or CosineNormalized".into(), + )); + } + Ok(()) + } +} + +impl Default for PiPNNConfig { + fn default() -> Self { + Self { + num_hash_planes: 12, + c_max: 1024, + c_min: 256, + p_samp: 0.005, + fanout: vec![10, 3], + k: 3, + max_degree: 64, + 
replicas: 1, + l_max: 128, + metric: Metric::L2, + final_prune: false, + alpha: 1.2, + } + } +} diff --git a/diskann-pipnn/src/partition.rs b/diskann-pipnn/src/partition.rs new file mode 100644 index 000000000..47f858746 --- /dev/null +++ b/diskann-pipnn/src/partition.rs @@ -0,0 +1,1036 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Randomized Ball Carving (RBC) partitioning. +//! +//! Recursively partitions the dataset into overlapping clusters: +//! - Sample a fraction of points as leaders +//! - Assign each point to its `fanout` nearest leaders (creating overlap) +//! - Merge undersized clusters +//! - Recurse on oversized clusters + +use diskann::utils::VectorRepr; +use rand::prelude::IndexedRandom; +use rand::{Rng, SeedableRng}; +use rayon::prelude::*; + +/// Maximum recursion depth to prevent stack overflow. +const MAX_DEPTH: usize = 30; + +/// A leaf partition containing indices into the original dataset. +#[derive(Debug, Clone)] +pub struct Leaf { + pub indices: Vec, +} + +/// Configuration for RBC partitioning. +#[derive(Debug, Clone)] +pub struct PartitionConfig { + pub c_max: usize, + pub c_min: usize, + pub p_samp: f64, + pub fanout: Vec, + /// Distance metric for partition assignment. + pub metric: diskann_vector::distance::Metric, +} + +/// Compute squared L2 distance between two f32 slices using manual loop +/// (auto-vectorized by the compiler). +#[allow(dead_code)] // Alternative implementation kept for benchmarking/debugging. +#[inline] +fn l2_distance_inline(a: &[f32], b: &[f32]) -> f32 { + debug_assert_eq!(a.len(), b.len()); + let mut sum = 0.0f32; + for i in 0..a.len() { + let d = unsafe { *a.get_unchecked(i) - *b.get_unchecked(i) }; + sum += d * d; + } + sum +} + +/// Quantized version of partition_assign using Hamming distance on 1-bit data. +/// Pre-extracts leader u64 data for cache locality. 
+fn partition_assign_quantized( + qdata: &crate::quantize::QuantizedData, + points: &[usize], + leaders: &[usize], + fanout: usize, +) -> Vec> { + let np = points.len(); + let nl = leaders.len(); + let num_assign = fanout.min(nl); + let u64s = qdata.u64s_per_vec(); + + // Pre-extract leader data into contiguous cache-friendly array. + let mut leader_data: Vec = vec![0u64; nl * u64s]; + for (i, &idx) in leaders.iter().enumerate() { + leader_data[i * u64s..(i + 1) * u64s].copy_from_slice(qdata.get_u64(idx)); + } + + let mut assignments = vec![0u32; np * num_assign]; + + const STRIPE: usize = 32_768; + assignments + .par_chunks_mut(STRIPE * num_assign) + .enumerate() + .for_each(|(stripe_idx, assign_chunk)| { + let start = stripe_idx * STRIPE; + let end = (start + STRIPE).min(np); + let sn = end - start; + + // Pre-extract point data for this stripe. + let mut point_data: Vec = vec![0u64; sn * u64s]; + for i in 0..sn { + let src = qdata.get_u64(points[start + i]); + point_data[i * u64s..(i + 1) * u64s].copy_from_slice(src); + } + + let mut buf: Vec<(u32, f32)> = Vec::with_capacity(nl); + let ld_ptr = leader_data.as_ptr(); + let pd_ptr = point_data.as_ptr(); + + for i in 0..sn { + let pt_base = unsafe { pd_ptr.add(i * u64s) }; + + // Compute Hamming distance to all leaders + build buf in one pass. 
+ buf.clear(); + for j in 0..nl { + let ld_base = unsafe { ld_ptr.add(j * u64s) }; + let mut h = 0u32; + for k in 0..u64s { + unsafe { h += (*pt_base.add(k) ^ *ld_base.add(k)).count_ones(); } + } + buf.push((j as u32, h as f32)); + } + if num_assign < buf.len() { + buf.select_nth_unstable_by(num_assign, |a, b| { + a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) + }); + } + + let out = &mut assign_chunk[i * num_assign..(i + 1) * num_assign]; + for k in 0..num_assign { + out[k] = buf[k].0; + } + } + }); + + let mut clusters: Vec> = vec![Vec::new(); nl]; + for i in 0..np { + let row = &assignments[i * num_assign..(i + 1) * num_assign]; + for &li in row { + clusters[li as usize].push(i); + } + } + clusters +} + +/// Fused GEMM + assignment: compute distances to leaders in stripes and immediately +/// extract top-k assignments without materializing the full N x L distance matrix. +/// Peak memory: stripe * L * 4 bytes (~64MB) instead of N * L * 4 bytes. +fn partition_assign( + data: &[T], + ndims: usize, + points: &[usize], + leaders: &[usize], + fanout: usize, + metric: diskann_vector::distance::Metric, +) -> Vec> { + partition_assign_impl(data, ndims, points, leaders, fanout, metric) +} + +/// Core implementation: fused GEMM + distance + top-k assignment in parallel stripes. +fn partition_assign_impl( + data: &[T], + ndims: usize, + points: &[usize], + leaders: &[usize], + fanout: usize, + metric: diskann_vector::distance::Metric, +) -> Vec> { + let np = points.len(); + let nl = leaders.len(); + let num_assign = fanout.min(nl); + + use diskann_vector::distance::Metric; + + // Extract leader data (shared, stays in cache), converting T -> f32. + let mut l_data = vec![0.0f32; nl * ndims]; + for (i, &idx) in leaders.iter().enumerate() { + let src = &data[idx * ndims..(idx + 1) * ndims]; + let dst = &mut l_data[i * ndims..(i + 1) * ndims]; + T::as_f32_into(src, dst).expect("f32 conversion"); + } + // Precompute leader norms. 
+ // L2 needs squared norms; Cosine needs sqrt norms; CosineNormalized/IP need none. + let l_norms: Vec = match metric { + Metric::L2 => { + let mut norms = vec![0.0f32; nl]; + for i in 0..nl { + let row = &l_data[i * ndims..(i + 1) * ndims]; + let mut norm = 0.0f32; + for &v in row { norm += v * v; } + norms[i] = norm; + } + norms + } + Metric::Cosine => { + let mut norms = vec![0.0f32; nl]; + for i in 0..nl { + let row = &l_data[i * ndims..(i + 1) * ndims]; + let mut norm = 0.0f32; + for &v in row { norm += v * v; } + norms[i] = norm.sqrt(); + } + norms + } + Metric::CosineNormalized | Metric::InnerProduct => Vec::new(), + }; + + // Flat assignments: assignments[i * num_assign .. (i+1) * num_assign] + let mut assignments = vec![0u32; np * num_assign]; + + // Fused parallel stripes: GEMM + distance + top-k in one pass. + // Adaptive stripe size: limit per-stripe GEMM output to ~16 MB. + // Smaller stripes reduce concurrent memory from ~1.4 GB (8 threads × 90 MB) + // to ~350 MB (8 threads × 22 MB), cutting partition peak RSS by ~1 GB. + // Partition is <5% of total build time, so the throughput cost is negligible. 
+ let stripe: usize = ((16 * 1024 * 1024) / (nl.max(1) * std::mem::size_of::())) + .clamp(256, 16_384); + assignments + .par_chunks_mut(stripe * num_assign) + .enumerate() + .for_each(|(stripe_idx, assign_chunk)| { + let start = stripe_idx * stripe; + let end = (start + stripe).min(np); + let sn = end - start; + let stripe_points = &points[start..end]; + + let mut p_data = vec![0.0f32; sn * ndims]; + for (i, &idx) in stripe_points.iter().enumerate() { + let src = &data[idx * ndims..(idx + 1) * ndims]; + let dst = &mut p_data[i * ndims..(i + 1) * ndims]; + T::as_f32_into(src, dst).expect("f32 conversion"); + } + + let mut dots = vec![0.0f32; sn * nl]; + crate::gemm::sgemm_abt(&p_data, sn, ndims, &l_data, nl, &mut dots); + + let mut buf: Vec<(u32, f32)> = Vec::with_capacity(nl); + for i in 0..sn { + let dot_row = &dots[i * nl..(i + 1) * nl]; + + buf.clear(); + match metric { + Metric::CosineNormalized => { + // Pre-normalized: dist = 1 - dot(a, b) + for j in 0..nl { + buf.push((j as u32, (1.0 - dot_row[j]).max(0.0))); + } + } + Metric::Cosine => { + // Unnormalized: dist = 1 - dot(a,b)/(||a||*||b||) + let mut pi = 0.0f32; + let row = &p_data[i * ndims..(i + 1) * ndims]; + for &v in row { pi += v * v; } + let pi_sqrt = pi.sqrt(); + for j in 0..nl { + let denom = pi_sqrt * l_norms[j]; + let cos_sim = if denom > 0.0 { dot_row[j] / denom } else { 0.0 }; + buf.push((j as u32, (1.0 - cos_sim).max(0.0))); + } + } + Metric::L2 => { + let mut pi = 0.0f32; + let row = &p_data[i * ndims..(i + 1) * ndims]; + for &v in row { pi += v * v; } + for j in 0..nl { + let d = (pi + l_norms[j] - 2.0 * dot_row[j]).max(0.0); + buf.push((j as u32, d)); + } + } + Metric::InnerProduct => { + for j in 0..nl { + buf.push((j as u32, -dot_row[j])); + } + } + } + + if num_assign < buf.len() { + buf.select_nth_unstable_by(num_assign, |a, b| { + a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) + }); + } + + let out = &mut assign_chunk[i * num_assign..(i + 1) * num_assign]; + for k in 
0..num_assign { + out[k] = buf[k].0; + } + } + }); + + // Aggregate into per-leader clusters. + let mut clusters: Vec> = vec![Vec::new(); nl]; + for i in 0..np { + let row = &assignments[i * num_assign..(i + 1) * num_assign]; + for &li in row { + clusters[li as usize].push(i); + } + } + clusters +} + +/// Force-split a set of indices into chunks of at most c_max, used as fallback. +fn force_split(indices: &[usize], c_max: usize) -> Vec { + indices + .chunks(c_max) + .map(|chunk| Leaf { + indices: chunk.to_vec(), + }) + .collect() +} + +/// Merge undersized clusters into the nearest large cluster by centroid distance. +/// +/// Paper (arXiv:2602.21247): "Merge undersized clusters into the nearest +/// (by centroid) appropriately-sized cluster." +fn merge_small_into_nearest( + data: &[T], + ndims: usize, + mut clusters: Vec>, + c_min: usize, +) -> Vec> { + let mut large: Vec> = Vec::new(); + let mut smalls: Vec> = Vec::new(); + + for c in clusters.drain(..) { + if c.len() < c_min && !c.is_empty() { + smalls.push(c); + } else if !c.is_empty() { + large.push(c); + } + } + + if smalls.is_empty() || large.is_empty() { + if large.is_empty() { return smalls; } + return large; + } + + // Compute centroids for large clusters, converting T -> f32 per point. + let centroids: Vec> = large.iter() + .map(|c| { + let mut centroid = vec![0.0f32; ndims]; + let inv = 1.0 / c.len() as f32; + let mut point_buf = vec![0.0f32; ndims]; + for &idx in c { + T::as_f32_into(&data[idx * ndims..(idx + 1) * ndims], &mut point_buf).expect("f32 conversion"); + for d in 0..ndims { centroid[d] += point_buf[d]; } + } + for d in 0..ndims { centroid[d] *= inv; } + centroid + }) + .collect(); + + // For each small cluster, find nearest large cluster by L2 distance + // from the small cluster's representative point to each large centroid. 
+ for small in smalls { + let mut rep_buf = vec![0.0f32; ndims]; + T::as_f32_into(&data[small[0] * ndims..(small[0] + 1) * ndims], &mut rep_buf).expect("f32 conversion"); + let nearest = centroids.iter().enumerate() + .map(|(i, c)| { + let mut dist = 0.0f32; + for d in 0..ndims { + let diff = unsafe { *rep_buf.get_unchecked(d) - *c.get_unchecked(d) }; + dist += diff * diff; + } + (i, dist) + }) + .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(i, _)| i) + .unwrap_or(0); + large[nearest].extend(small); + } + + large +} + +/// Partition the dataset using Randomized Ball Carving. +/// +/// `data` is row-major: npoints_global x ndims. +/// `indices` are the global indices of the points to partition. +pub fn partition( + data: &[T], + ndims: usize, + indices: &[usize], + config: &PartitionConfig, + level: usize, + rng: &mut impl Rng, +) -> Vec { + let n = indices.len(); + + if n <= config.c_max { + return vec![Leaf { + indices: indices.to_vec(), + }]; + } + + // For clusters at deep recursion levels or only marginally over c_max, + // force-split is cheaper than doing another full GEMM + assignment. + if level >= MAX_DEPTH || (level >= 2 && n <= config.c_max * 3) { + return force_split(indices, config.c_max); + } + + let fanout = if level < config.fanout.len() { + config.fanout[level] + } else { + 1 + }; + + // Sample leaders. + let num_leaders = ((n as f64 * config.p_samp).ceil() as usize) + .max(2) + .min(n); + + let leaders: Vec = indices.choose_multiple(rng, num_leaders).copied().collect(); + + // Fused GEMM + assignment (avoids materializing full distance matrix). + let clusters_local = partition_assign(data, ndims, indices, &leaders, fanout, config.metric); + + // Map local indices back to global. 
+ let clusters: Vec> = clusters_local + .into_iter() + .map(|local_cluster| { + local_cluster.into_iter().map(|li| indices[li]).collect() + }) + .collect(); + + // Merge undersized clusters into nearest large cluster by centroid proximity. + let merged_clusters = merge_small_into_nearest(data, ndims, clusters, config.c_min); + + if merged_clusters.len() == 1 && merged_clusters[0].len() > config.c_max { + return force_split(&merged_clusters[0], config.c_max); + } + + let mut leaves = Vec::new(); + for cluster in merged_clusters { + if cluster.len() <= config.c_max { + leaves.push(Leaf { indices: cluster }); + } else { + let sub_seed: u64 = rng.random(); + let mut sub_rng = rand::rngs::StdRng::seed_from_u64(sub_seed); + let sub_leaves = partition(data, ndims, &cluster, config, level + 1, &mut sub_rng); + leaves.extend(sub_leaves); + } + } + + leaves +} + +/// Partition using parallelism at the top level. +/// Prints timing breakdown for the top-level operations. +pub fn parallel_partition( + data: &[T], + ndims: usize, + indices: &[usize], + config: &PartitionConfig, + seed: u64, +) -> Vec { + let n = indices.len(); + + if n <= config.c_max { + return vec![Leaf { + indices: indices.to_vec(), + }]; + } + + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let fanout = if !config.fanout.is_empty() { + config.fanout[0] + } else { + 3 + }; + + // Sample leaders. + let num_leaders = ((n as f64 * config.p_samp).ceil() as usize) + .max(2) + .min(n); + + let leaders: Vec = indices.choose_multiple(&mut rng, num_leaders).copied().collect(); + + // Fused GEMM + assignment. 
+ let t0 = std::time::Instant::now(); + let clusters_local = partition_assign(data, ndims, indices, &leaders, fanout, config.metric); + let assign_time = t0.elapsed(); + + let t1 = std::time::Instant::now(); + let clusters: Vec> = clusters_local + .into_iter() + .map(|local_cluster| { + local_cluster.into_iter().map(|li| indices[li]).collect() + }) + .collect(); + let map_time = t1.elapsed(); + + tracing::debug!( + assign_secs = assign_time.as_secs_f64(), + map_secs = map_time.as_secs_f64(), + num_leaders = num_leaders, + fanout = fanout, + "top-level partition assign" + ); + + // Merge undersized clusters into nearest large cluster by centroid proximity. + let merged_clusters = merge_small_into_nearest(data, ndims, clusters, config.c_min); + + let need_recurse = merged_clusters.iter().filter(|c| c.len() > config.c_max).count(); + let total_in_recurse: usize = merged_clusters.iter().filter(|c| c.len() > config.c_max).map(|c| c.len()).sum(); + tracing::debug!( + num_clusters = merged_clusters.len(), + need_recurse = need_recurse, + total_in_recurse = total_in_recurse, + "partition merge" + ); + + // Generate sub-seeds for parallel recursion. + let sub_seeds: Vec = (0..merged_clusters.len()) + .map(|_| rng.random()) + .collect(); + + // Recurse in parallel. Each cluster is either a leaf or needs further splitting. + let t2 = std::time::Instant::now(); + let results: Vec> = merged_clusters + .par_iter() + .zip(sub_seeds.par_iter()) + .map(|(cluster, sub_seed)| { + if cluster.len() <= config.c_max { + vec![Leaf { + indices: cluster.clone(), + }] + } else { + let mut sub_rng = rand::rngs::StdRng::seed_from_u64(*sub_seed); + partition(data, ndims, cluster, config, 1, &mut sub_rng) + } + }) + .collect(); + + tracing::debug!(recursion_secs = t2.elapsed().as_secs_f64(), "partition recursion complete"); + results.into_iter().flatten().collect() +} + +/// Quantized version of parallel_partition using Hamming distance on 1-bit data. 
+pub fn parallel_partition_quantized( + qdata: &crate::quantize::QuantizedData, + indices: &[usize], + config: &PartitionConfig, + seed: u64, +) -> Vec { + let n = indices.len(); + if n <= config.c_max { + return vec![Leaf { indices: indices.to_vec() }]; + } + + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let fanout = if !config.fanout.is_empty() { config.fanout[0] } else { 3 }; + + let num_leaders = ((n as f64 * config.p_samp).ceil() as usize) + .max(2).min(n); + + let leaders: Vec = indices.choose_multiple(&mut rng, num_leaders).copied().collect(); + + let t0 = std::time::Instant::now(); + let clusters_local = partition_assign_quantized(qdata, indices, &leaders, fanout); + let assign_time = t0.elapsed(); + + let t1 = std::time::Instant::now(); + let mut clusters: Vec> = clusters_local + .into_iter() + .map(|local_cluster| local_cluster.into_iter().map(|li| indices[li]).collect()) + .collect(); + let map_time = t1.elapsed(); + + tracing::debug!( + assign_secs = assign_time.as_secs_f64(), + map_secs = map_time.as_secs_f64(), + num_leaders = num_leaders, + fanout = fanout, + "top-level partition assign (quantized)" + ); + + // Merge undersized clusters. + let mut merged_clusters: Vec> = Vec::new(); + let mut small_clusters: Vec> = Vec::new(); + for cluster in clusters.drain(..) 
{ + if cluster.len() < config.c_min && !cluster.is_empty() { + small_clusters.push(cluster); + } else if !cluster.is_empty() { + merged_clusters.push(cluster); + } + } + if !small_clusters.is_empty() && !merged_clusters.is_empty() { + for small in small_clusters { + let min_idx = merged_clusters.iter().enumerate() + .min_by_key(|(_, c)| c.len()).map(|(i, _)| i).unwrap_or(0); + merged_clusters[min_idx].extend(small); + } + } else if merged_clusters.is_empty() { + merged_clusters = small_clusters; + } + + let need_recurse = merged_clusters.iter().filter(|c| c.len() > config.c_max).count(); + tracing::debug!( + num_clusters = merged_clusters.len(), + need_recurse = need_recurse, + "partition merge (quantized)" + ); + + let sub_seeds: Vec = (0..merged_clusters.len()).map(|_| rng.random()).collect(); + + let t2 = std::time::Instant::now(); + let results: Vec> = merged_clusters + .par_iter() + .zip(sub_seeds.par_iter()) + .map(|(cluster, sub_seed)| { + if cluster.len() <= config.c_max { + vec![Leaf { indices: cluster.clone() }] + } else if cluster.len() <= config.c_max * 3 { + force_split(cluster, config.c_max) + } else { + // Recursive quantized partition. 
+ let mut sub_rng = rand::rngs::StdRng::seed_from_u64(*sub_seed); + partition_quantized_recursive(qdata, cluster, config, 1, &mut sub_rng) + } + }) + .collect(); + + tracing::debug!(recursion_secs = t2.elapsed().as_secs_f64(), "partition recursion complete (quantized)"); + results.into_iter().flatten().collect() +} + +fn partition_quantized_recursive( + qdata: &crate::quantize::QuantizedData, + indices: &[usize], + config: &PartitionConfig, + level: usize, + rng: &mut impl Rng, +) -> Vec { + let n = indices.len(); + if n <= config.c_max { return vec![Leaf { indices: indices.to_vec() }]; } + if level >= MAX_DEPTH || (level >= 2 && n <= config.c_max * 3) { + return force_split(indices, config.c_max); + } + + let fanout = if level < config.fanout.len() { config.fanout[level] } else { 1 }; + let num_leaders = ((n as f64 * config.p_samp).ceil() as usize).max(2).min(n); + + let leaders: Vec = indices.choose_multiple(rng, num_leaders).copied().collect(); + + let clusters_local = partition_assign_quantized(qdata, indices, &leaders, fanout); + let mut clusters: Vec> = clusters_local + .into_iter() + .map(|lc| lc.into_iter().map(|li| indices[li]).collect()) + .collect(); + + // Merge small clusters. + let mut merged: Vec> = Vec::new(); + let mut smalls: Vec> = Vec::new(); + for c in clusters.drain(..) 
{ + if c.len() < config.c_min && !c.is_empty() { smalls.push(c); } + else if !c.is_empty() { merged.push(c); } + } + if !smalls.is_empty() && !merged.is_empty() { + for s in smalls { + let mi = merged.iter().enumerate().min_by_key(|(_, c)| c.len()).map(|(i,_)| i).unwrap_or(0); + merged[mi].extend(s); + } + } else if merged.is_empty() { merged = smalls; } + + if merged.len() == 1 && merged[0].len() > config.c_max { + return force_split(&merged[0], config.c_max); + } + + let mut leaves = Vec::new(); + for cluster in merged { + if cluster.len() <= config.c_max { + leaves.push(Leaf { indices: cluster }); + } else { + let sub_seed: u64 = rng.random(); + let mut sub_rng = rand::rngs::StdRng::seed_from_u64(sub_seed); + leaves.extend(partition_quantized_recursive(qdata, &cluster, config, level + 1, &mut sub_rng)); + } + } + leaves +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::SeedableRng; + + #[test] + fn test_partition_small_dataset() { + let data: Vec = (0..20).map(|i| i as f32).collect(); + let indices: Vec = (0..10).collect(); + let config = PartitionConfig { + c_max: 10, + c_min: 3, + p_samp: 0.5, + fanout: vec![3], + metric: diskann_vector::distance::Metric::L2, + }; + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let leaves = partition(&data, 2, &indices, &config, 0, &mut rng); + + assert_eq!(leaves.len(), 1); + assert_eq!(leaves[0].indices.len(), 10); + } + + #[test] + fn test_partition_needs_splitting() { + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let data: Vec = (0..200) + .map(|_| rand::Rng::random_range(&mut rng, -10.0..10.0)) + .collect(); + let indices: Vec = (0..100).collect(); + let config = PartitionConfig { + c_max: 20, + c_min: 5, + p_samp: 0.1, + fanout: vec![3, 2], + metric: diskann_vector::distance::Metric::L2, + }; + + let mut rng2 = rand::rngs::StdRng::seed_from_u64(123); + let leaves = partition(&data, 2, &indices, &config, 0, &mut rng2); + + assert!(leaves.len() > 1, "expected multiple leaves, got {}", 
leaves.len()); + + for leaf in &leaves { + assert!( + leaf.indices.len() <= config.c_max, + "leaf too large: {}", + leaf.indices.len() + ); + } + + let total: usize = leaves.iter().map(|l| l.indices.len()).sum(); + assert!( + total >= indices.len(), + "total assignments {} < original count {}", + total, + indices.len() + ); + } + + #[test] + fn test_parallel_partition() { + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let data: Vec = (0..2000) + .map(|_| rand::Rng::random_range(&mut rng, -10.0..10.0)) + .collect(); + let indices: Vec = (0..1000).collect(); + let config = PartitionConfig { + c_max: 50, + c_min: 10, + p_samp: 0.05, + fanout: vec![5, 3], + metric: diskann_vector::distance::Metric::L2, + }; + + let leaves = parallel_partition(&data, 2, &indices, &config, 42); + + assert!(leaves.len() > 1); + for leaf in &leaves { + assert!( + leaf.indices.len() <= config.c_max, + "leaf too large: {}", + leaf.indices.len() + ); + } + } + + #[test] + fn test_partition_overlap() { + // With fanout > 1, each point is assigned to multiple leaders, + // so the total assignments across all leaves should exceed the + // original point count (overlap). + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let npoints = 500; + let ndims = 4; + let data: Vec = (0..npoints * ndims) + .map(|_| rand::Rng::random_range(&mut rng, -5.0..5.0)) + .collect(); + let indices: Vec = (0..npoints).collect(); + let config = PartitionConfig { + c_max: 100, + c_min: 20, + p_samp: 0.05, + fanout: vec![3, 2], // fanout > 1 creates overlap + metric: diskann_vector::distance::Metric::L2, + }; + + let leaves = parallel_partition(&data, ndims, &indices, &config, 42); + + let total_in_leaves: usize = leaves.iter().map(|l| l.indices.len()).sum(); + assert!( + total_in_leaves >= npoints, + "total in leaves ({}) should be >= npoints ({})", + total_in_leaves, + npoints + ); + } + + #[test] + fn test_partition_respects_c_max() { + // All leaves must have at most c_max elements. 
+ let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let npoints = 300; + let ndims = 4; + let data: Vec = (0..npoints * ndims) + .map(|_| rand::Rng::random_range(&mut rng, -5.0..5.0)) + .collect(); + let indices: Vec = (0..npoints).collect(); + let config = PartitionConfig { + c_max: 40, + c_min: 10, + p_samp: 0.1, + fanout: vec![5, 2], + metric: diskann_vector::distance::Metric::L2, + }; + + let leaves = parallel_partition(&data, ndims, &indices, &config, 99); + for (i, leaf) in leaves.iter().enumerate() { + assert!( + leaf.indices.len() <= config.c_max, + "leaf {} has size {} > c_max {}", + i, + leaf.indices.len(), + config.c_max + ); + } + } + + #[test] + fn test_partition_single_point() { + let data = vec![1.0f32, 2.0]; + let indices = vec![0usize]; + let config = PartitionConfig { + c_max: 10, + c_min: 1, + p_samp: 0.5, + fanout: vec![3], + metric: diskann_vector::distance::Metric::L2, + }; + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let leaves = partition(&data, 2, &indices, &config, 0, &mut rng); + assert_eq!(leaves.len(), 1, "single point should produce 1 leaf"); + assert_eq!(leaves[0].indices.len(), 1, "leaf should contain exactly 1 point"); + assert_eq!(leaves[0].indices[0], 0, "leaf should contain index 0"); + } + + #[test] + fn test_partition_two_points() { + let data = vec![0.0f32, 0.0, 10.0, 10.0]; + let indices = vec![0, 1]; + let config = PartitionConfig { + c_max: 5, + c_min: 1, + p_samp: 0.5, + fanout: vec![3], + metric: diskann_vector::distance::Metric::L2, + }; + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let leaves = partition(&data, 2, &indices, &config, 0, &mut rng); + assert_eq!(leaves.len(), 1, "two points with c_max=5 should produce 1 leaf"); + assert_eq!(leaves[0].indices.len(), 2, "leaf should contain both points"); + } + + #[test] + fn test_partition_all_identical() { + // All identical vectors should still partition without crashing. 
+ let npoints = 100; + let ndims = 4; + let data = vec![42.0f32; npoints * ndims]; + let indices: Vec = (0..npoints).collect(); + let config = PartitionConfig { + c_max: 20, + c_min: 5, + p_samp: 0.1, + fanout: vec![3], + metric: diskann_vector::distance::Metric::L2, + }; + let leaves = parallel_partition(&data, ndims, &indices, &config, 42); + assert!(!leaves.is_empty(), "should produce at least one leaf"); + let total: usize = leaves.iter().map(|l| l.indices.len()).sum(); + assert!( + total >= npoints, + "total in leaves ({}) should be >= npoints ({})", + total, npoints + ); + for (i, leaf) in leaves.iter().enumerate() { + assert!( + leaf.indices.len() <= config.c_max, + "leaf {} has {} elements > c_max={}", + i, leaf.indices.len(), config.c_max + ); + } + } + + #[test] + fn test_partition_high_fanout() { + // fanout > npoints should still work (clamped to num_leaders). + let npoints = 20; + let ndims = 4; + let mut rng_data = rand::rngs::StdRng::seed_from_u64(42); + let data: Vec = (0..npoints * ndims) + .map(|_| rand::Rng::random_range(&mut rng_data, -10.0..10.0)) + .collect(); + let indices: Vec = (0..npoints).collect(); + let config = PartitionConfig { + c_max: 5, + c_min: 2, + p_samp: 0.5, + fanout: vec![100], // much larger than npoints + metric: diskann_vector::distance::Metric::L2, + }; + let leaves = parallel_partition(&data, ndims, &indices, &config, 42); + assert!(!leaves.is_empty(), "high fanout should still produce leaves"); + for (i, leaf) in leaves.iter().enumerate() { + assert!( + leaf.indices.len() <= config.c_max, + "leaf {} has {} elements > c_max={}", + i, leaf.indices.len(), config.c_max + ); + } + } + + #[test] + fn test_partition_multi_level_fanout() { + // Multi-level fanout vec![4,2] should work and produce valid leaves. 
+ let npoints = 200; + let ndims = 4; + let mut rng_data = rand::rngs::StdRng::seed_from_u64(42); + let data: Vec = (0..npoints * ndims) + .map(|_| rand::Rng::random_range(&mut rng_data, -10.0..10.0)) + .collect(); + let indices: Vec = (0..npoints).collect(); + let config = PartitionConfig { + c_max: 30, + c_min: 8, + p_samp: 0.1, + fanout: vec![4, 2], + metric: diskann_vector::distance::Metric::L2, + }; + let leaves = parallel_partition(&data, ndims, &indices, &config, 42); + assert!(leaves.len() > 1, "multi-level fanout should produce multiple leaves"); + for (i, leaf) in leaves.iter().enumerate() { + assert!( + leaf.indices.len() <= config.c_max, + "leaf {} has {} elements > c_max={}", + i, leaf.indices.len(), config.c_max + ); + } + } + + #[test] + fn test_partition_c_min_equals_c_max() { + // c_min == c_max is a valid (if unusual) configuration. + let npoints = 100; + let ndims = 4; + let mut rng_data = rand::rngs::StdRng::seed_from_u64(42); + let data: Vec = (0..npoints * ndims) + .map(|_| rand::Rng::random_range(&mut rng_data, -10.0..10.0)) + .collect(); + let indices: Vec = (0..npoints).collect(); + let config = PartitionConfig { + c_max: 30, + c_min: 30, + p_samp: 0.1, + fanout: vec![3], + metric: diskann_vector::distance::Metric::L2, + }; + let leaves = parallel_partition(&data, ndims, &indices, &config, 42); + assert!(!leaves.is_empty(), "c_min == c_max should produce leaves"); + for (i, leaf) in leaves.iter().enumerate() { + assert!( + leaf.indices.len() <= config.c_max, + "leaf {} has {} elements > c_max={}", + i, leaf.indices.len(), config.c_max + ); + } + } + + #[test] + fn test_partition_large_p_samp() { + // p_samp=1.0 means sample all points as leaders. 
+ let npoints = 50; + let ndims = 4; + let mut rng_data = rand::rngs::StdRng::seed_from_u64(42); + let data: Vec = (0..npoints * ndims) + .map(|_| rand::Rng::random_range(&mut rng_data, -10.0..10.0)) + .collect(); + let indices: Vec = (0..npoints).collect(); + let config = PartitionConfig { + c_max: 10, + c_min: 3, + p_samp: 1.0, + fanout: vec![3], + metric: diskann_vector::distance::Metric::L2, + }; + let leaves = parallel_partition(&data, ndims, &indices, &config, 42); + assert!(!leaves.is_empty(), "p_samp=1.0 should produce leaves"); + for (i, leaf) in leaves.iter().enumerate() { + assert!( + leaf.indices.len() <= config.c_max, + "leaf {} has {} elements > c_max={}", + i, leaf.indices.len(), config.c_max + ); + } + } + + #[test] + fn test_partition_quantized() { + // Quantized partition should produce valid leaves with same constraints. + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let npoints = 300; + let ndims = 64; // must be multiple of 64 for u64 alignment + let data: Vec = (0..npoints * ndims) + .map(|_| rand::Rng::random_range(&mut rng, -5.0..5.0)) + .collect(); + let indices: Vec = (0..npoints).collect(); + let config = PartitionConfig { + c_max: 80, + c_min: 20, + p_samp: 0.05, + fanout: vec![3, 2], + metric: diskann_vector::distance::Metric::L2, + }; + + let (shift, inverse_scale) = { + use diskann_quantization::scalar::train::ScalarQuantizationParameters; + use diskann_utils::views::MatrixView; + let dm = MatrixView::try_from(data.as_slice(), npoints, ndims).unwrap(); + let q = ScalarQuantizationParameters::default().train(dm); + let s = q.scale(); + (q.shift().to_vec(), if s == 0.0 { 1.0 } else { 1.0 / s }) + }; + let qdata = crate::quantize::quantize_1bit(&data, npoints, ndims, &shift, inverse_scale); + let leaves = parallel_partition_quantized(&qdata, &indices, &config, 42); + + assert!(!leaves.is_empty(), "no leaves produced"); + for (i, leaf) in leaves.iter().enumerate() { + assert!( + leaf.indices.len() <= config.c_max, + "quantized 
leaf {} has size {} > c_max {}", + i, + leaf.indices.len(), + config.c_max + ); + // All indices should be valid. + for &idx in &leaf.indices { + assert!(idx < npoints, "index {} out of range", idx); + } + } + } +} diff --git a/diskann-pipnn/src/quantize.rs b/diskann-pipnn/src/quantize.rs new file mode 100644 index 000000000..32c157abf --- /dev/null +++ b/diskann-pipnn/src/quantize.rs @@ -0,0 +1,530 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! 1-bit scalar quantization for PiPNN. +//! +//! Reuses diskann-quantization's ScalarQuantizer for training (shift/scale), +//! then packs vectors into compact bit arrays for fast Hamming distance. + +use rayon::prelude::*; + +/// Result of 1-bit quantization. +pub struct QuantizedData { + /// Packed bit vectors: each vector is `bytes_per_vec` bytes. + /// Layout: npoints * bytes_per_vec, row-major. + pub bits: Vec, + /// Number of bytes per vector (ceil(ndims / 8)). + pub bytes_per_vec: usize, + /// Original dimensionality. + pub ndims: usize, + /// Number of points. + pub npoints: usize, +} + +/// Quantize data to 1-bit using pre-trained shift and inverse_scale parameters. +/// +/// Uses parameters from a `ScalarQuantizer` trained by DiskANN's build pipeline, +/// ensuring identical quantization regardless of build algorithm (Vamana vs PiPNN). +/// +/// Each dimension is packed to 1 bit: +/// bit = 1 if (value - shift[d]) * inverse_scale >= 0.5, else 0 +/// +/// # Arguments +/// * `shift` - Per-dimension shift from ScalarQuantizer (length = ndims) +/// * `inverse_scale` - 1.0 / scale from ScalarQuantizer +pub fn quantize_1bit( + data: &[f32], + npoints: usize, + ndims: usize, + shift: &[f32], + inverse_scale: f32, +) -> QuantizedData { + // Round up to a multiple of 8 bytes (64 bits) so that get_u64() is always aligned. 
+ let bytes_per_vec = ((ndims + 63) / 64) * 8; + let total_bytes = npoints * bytes_per_vec; + // Allocate as Vec for guaranteed u64 alignment, then reinterpret as Vec. + let u64s_total = total_bytes / 8; + let mut bits_u64 = vec![0u64; u64s_total]; + // SAFETY: Vec has stricter alignment than Vec. We reinterpret the allocation + // in-place, preserving the original alignment. Length and capacity are scaled by 8. + let mut bits = unsafe { + let ptr = bits_u64.as_mut_ptr() as *mut u8; + let len = bits_u64.len() * 8; + let cap = bits_u64.capacity() * 8; + std::mem::forget(bits_u64); + Vec::from_raw_parts(ptr, len, cap) + }; + + // Parallel quantization. + bits.par_chunks_mut(bytes_per_vec) + .enumerate() + .for_each(|(i, out)| { + let vec = &data[i * ndims..(i + 1) * ndims]; + for d in 0..ndims { + let code = ((vec[d] - shift[d]) * inverse_scale).clamp(0.0, 1.0).round() as u8; + if code > 0 { + out[d / 8] |= 1 << (d % 8); + } + } + }); + + QuantizedData { + bits, + bytes_per_vec, + ndims, + npoints, + } +} + +impl QuantizedData { + /// Number of points. + pub fn npoints(&self) -> usize { + self.npoints + } + + /// Get the packed bit vector for point i. + #[inline(always)] + pub fn get(&self, i: usize) -> &[u8] { + debug_assert!(i < self.npoints, "QuantizedData::get index {} out of range (npoints={})", i, self.npoints); + let start = i * self.bytes_per_vec; + unsafe { self.bits.get_unchecked(start..start + self.bytes_per_vec) } + } + + /// Get the packed bit vector as u64 slice for point i (fast path). + #[inline(always)] + pub fn get_u64(&self, i: usize) -> &[u64] { + debug_assert!(i < self.npoints, "QuantizedData::get index {} out of range (npoints={})", i, self.npoints); + let start = i * self.bytes_per_vec; + let u64s = self.bytes_per_vec / 8; + // SAFETY: bits buffer was allocated as Vec, guaranteeing u64 alignment. bytes_per_vec is always a multiple of 8. 
+ unsafe { + let ptr = self.bits.as_ptr().add(start) as *const u64; + std::slice::from_raw_parts(ptr, u64s) + } + } + + /// Number of u64s per vector. + #[inline] + pub fn u64s_per_vec(&self) -> usize { + self.bytes_per_vec / 8 + } + + /// Compute Hamming distance between two quantized vectors (u64 fast path). + #[inline(always)] + pub fn hamming_u64(a: &[u64], b: &[u64]) -> u32 { + let mut dist = 0u32; + for i in 0..a.len() { + unsafe { + dist += (*a.get_unchecked(i) ^ *b.get_unchecked(i)).count_ones(); + } + } + dist + } + + /// Compute Hamming distance between two byte slices. + #[inline] + pub fn hamming(a: &[u8], b: &[u8]) -> u32 { + let chunks = a.len() / 8; + let a64 = a.as_ptr() as *const u64; + let b64 = b.as_ptr() as *const u64; + let mut dist = 0u32; + for i in 0..chunks { + unsafe { + let va = std::ptr::read_unaligned(a64.add(i)); + let vb = std::ptr::read_unaligned(b64.add(i)); + dist += (va ^ vb).count_ones(); + } + } + for i in (chunks * 8)..a.len() { + dist += (a[i] ^ b[i]).count_ones(); + } + dist + } + + /// Compute all-pairs Hamming distance matrix for a set of points. + /// Returns flat n x n matrix (row-major) with f32::MAX on diagonal. + /// Inlines the Hamming computation and uses unchecked indexing for speed. + pub fn compute_distance_matrix(&self, indices: &[usize]) -> Vec { + let n = indices.len(); + let u64s = self.u64s_per_vec(); + + // Extract contiguous u64 data for cache locality. + let mut local: Vec = vec![0u64; n * u64s]; + for (i, &idx) in indices.iter().enumerate() { + let src = self.get_u64(idx); + local[i * u64s..(i + 1) * u64s].copy_from_slice(src); + } + + let mut dist = vec![f32::MAX; n * n]; + let local_ptr = local.as_ptr(); + let dist_ptr = dist.as_mut_ptr(); + + // Flat loop with inlined Hamming — avoids function call + slice bounds overhead. 
+ for i in 0..n { + let a_base = unsafe { local_ptr.add(i * u64s) }; + for j in (i + 1)..n { + let b_base = unsafe { local_ptr.add(j * u64s) }; + + // Inline Hamming: XOR + popcount over u64s. + let mut h = 0u32; + for k in 0..u64s { + unsafe { + h += (*a_base.add(k) ^ *b_base.add(k)).count_ones(); + } + } + + let d = h as f32; + unsafe { + *dist_ptr.add(i * n + j) = d; + *dist_ptr.add(j * n + i) = d; + } + } + } + dist + } + + /// Compute Hamming distances from one point to many leaders. + /// Returns distances as f32 slice. + pub fn distances_to_leaders( + &self, + point_idx: usize, + leader_indices: &[usize], + out: &mut [f32], + ) { + let pt = self.get_u64(point_idx); + for (j, &leader_idx) in leader_indices.iter().enumerate() { + let ld = self.get_u64(leader_idx); + out[j] = Self::hamming_u64(pt, ld) as f32; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper: create data with known bit patterns for predictable quantization. + /// All values are either -1.0 or 1.0 so 1-bit quantization is unambiguous. + fn make_binary_data(npoints: usize, ndims: usize, seed: u64) -> Vec { + use rand::{Rng, SeedableRng}; + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + (0..npoints * ndims) + .map(|_| if rng.random_bool(0.5) { 1.0f32 } else { -1.0f32 }) + .collect() + } + + /// Train SQ parameters and quantize to 1-bit. Test-only convenience wrapper + /// that uses DiskANN's ScalarQuantizer to compute shift/scale, then calls + /// `quantize_1bit()`. + fn train_and_quantize(data: &[f32], npoints: usize, ndims: usize) -> QuantizedData { + let (shift, inverse_scale) = train_sq_params(data, npoints, ndims); + quantize_1bit(data, npoints, ndims, &shift, inverse_scale) + } + + /// Train SQ parameters (shift, inverse_scale) from data. Test-only helper. 
+ fn train_sq_params(data: &[f32], npoints: usize, ndims: usize) -> (Vec, f32) { + use diskann_quantization::scalar::train::ScalarQuantizationParameters; + use diskann_utils::views::MatrixView; + + let data_matrix = MatrixView::try_from(data, npoints, ndims) + .expect("data length must equal npoints * ndims"); + let quantizer = ScalarQuantizationParameters::default().train(data_matrix); + let shift = quantizer.shift().to_vec(); + let scale = quantizer.scale(); + let inverse_scale = if scale == 0.0 { 1.0 } else { 1.0 / scale }; + (shift, inverse_scale) + } + + #[test] + fn test_quantize_1bit_basic() { + // 4 points, 16 dims -- check packing correctness. + // bytes_per_vec is rounded up to a multiple of 8 for u64 alignment. + let ndims = 16; + let npoints = 4; + // All dimensions positive -> all bits should be 1. + let data: Vec = vec![1.0; npoints * ndims]; + let qd = train_and_quantize(&data, npoints, ndims); + + assert_eq!(qd.npoints, npoints); + assert_eq!(qd.ndims, ndims); + assert_eq!(qd.bytes_per_vec, 8); // ((16 + 63) / 64) * 8 = 8 + + // With all-identical positive values, after training shift/scale the + // quantization is deterministic. All bits for every point should be + // the same since all values are identical. + for i in 0..npoints { + let v = qd.get(i); + // All points should be identical. + assert_eq!(v, qd.get(0), "point {} differs from point 0", i); + } + } + + #[test] + fn test_quantize_1bit_roundtrip() { + // Verify that get() and get_u64() return consistent data for the same point. + let ndims = 64; // 8 bytes per vec -> exactly 1 u64 + let npoints = 10; + let data = make_binary_data(npoints, ndims, 42); + let qd = train_and_quantize(&data, npoints, ndims); + + assert_eq!(qd.bytes_per_vec, 8); + assert_eq!(qd.u64s_per_vec(), 1); + + for i in 0..npoints { + let bytes = qd.get(i); + let u64s = qd.get_u64(i); + + // Convert the byte slice to a u64 (little-endian) and compare. 
+ let from_bytes = u64::from_le_bytes(bytes.try_into().unwrap()); + assert_eq!( + from_bytes, u64s[0], + "get() and get_u64() disagree for point {}", + i + ); + } + } + + #[test] + fn test_hamming_u64_identity() { + // Hamming distance of a vector with itself is always 0. + let a: Vec = vec![0xDEAD_BEEF_CAFE_BABE, 0x0123_4567_89AB_CDEF]; + assert_eq!(QuantizedData::hamming_u64(&a, &a), 0); + + let zeros: Vec = vec![0, 0, 0, 0]; + assert_eq!(QuantizedData::hamming_u64(&zeros, &zeros), 0); + + let ones: Vec = vec![u64::MAX, u64::MAX]; + assert_eq!(QuantizedData::hamming_u64(&ones, &ones), 0); + } + + #[test] + fn test_hamming_u64_all_different() { + // XOR of all-zeros and all-ones gives all-ones, popcount = 64 per word. + let a: Vec = vec![0u64; 3]; + let b: Vec = vec![u64::MAX; 3]; + assert_eq!(QuantizedData::hamming_u64(&a, &b), 64 * 3); + } + + #[test] + fn test_hamming_byte_matches_u64() { + // The byte-based and u64-based Hamming distance should agree. + let ndims = 128; // 16 bytes = 2 u64s + let npoints = 5; + let data = make_binary_data(npoints, ndims, 77); + let qd = train_and_quantize(&data, npoints, ndims); + + for i in 0..npoints { + for j in 0..npoints { + let d_byte = QuantizedData::hamming(qd.get(i), qd.get(j)); + let d_u64 = QuantizedData::hamming_u64(qd.get_u64(i), qd.get_u64(j)); + assert_eq!( + d_byte, d_u64, + "hamming mismatch for ({}, {}): byte={} u64={}", + i, j, d_byte, d_u64 + ); + } + } + } + + #[test] + fn test_compute_distance_matrix() { + // Verify symmetry and diagonal (f32::MAX) of the distance matrix. + let ndims = 64; + let npoints = 8; + let data = make_binary_data(npoints, ndims, 55); + let qd = train_and_quantize(&data, npoints, ndims); + + let indices: Vec = (0..npoints).collect(); + let dist = qd.compute_distance_matrix(&indices); + + let n = npoints; + // Diagonal must be f32::MAX. 
+ for i in 0..n { + assert_eq!( + dist[i * n + i], + f32::MAX, + "diagonal at ({},{}) is not f32::MAX", + i, + i + ); + } + // Symmetry: dist[i][j] == dist[j][i] + for i in 0..n { + for j in (i + 1)..n { + assert_eq!( + dist[i * n + j], dist[j * n + i], + "asymmetry at ({}, {})", + i, j + ); + } + } + // Non-negative off-diagonal. + for i in 0..n { + for j in 0..n { + if i != j { + assert!( + dist[i * n + j] >= 0.0, + "negative distance at ({}, {}): {}", + i, + j, + dist[i * n + j] + ); + } + } + } + } + + #[test] + fn test_quantize_1bit_single_point() { + let ndims = 8; + let data = vec![1.0f32; ndims]; + let qd = train_and_quantize(&data, 1, ndims); + assert_eq!(qd.npoints, 1, "should have 1 point"); + assert_eq!(qd.ndims, ndims, "should preserve ndims"); + assert_eq!(qd.bytes_per_vec, 8, "bytes_per_vec should be 8 (rounded up to multiple of 8)"); + } + + #[test] + fn test_quantize_1bit_single_dim() { + // ndims=1, bytes_per_vec should round up to 8 (multiple of 8 for u64 alignment). + let npoints = 5; + let data: Vec<f32> = vec![1.0, -1.0, 0.5, -0.5, 0.0]; + let qd = train_and_quantize(&data, npoints, 1); + assert_eq!(qd.ndims, 1, "should preserve ndims=1"); + assert_eq!(qd.bytes_per_vec, 8, "bytes_per_vec for ndims=1 should be 8"); + assert_eq!(qd.npoints, npoints, "should have correct npoints"); + } + + #[test] + fn test_quantize_1bit_large_ndims() { + // ndims=1024, bytes_per_vec = ceil(1024/64) * 8 = 16 * 8 = 128. + let ndims = 1024; + let npoints = 3; + let data = make_binary_data(npoints, ndims, 42); + let qd = train_and_quantize(&data, npoints, ndims); + assert_eq!(qd.bytes_per_vec, 128, "bytes_per_vec for ndims=1024 should be 128"); + assert_eq!(qd.u64s_per_vec(), 16, "u64s_per_vec for ndims=1024 should be 16"); + } + + #[test] + fn test_quantize_zero_variance() { + // All identical data -- should not crash due to zero-variance guard. 
+ let npoints = 10; + let ndims = 8; + let data = vec![42.0f32; npoints * ndims]; + let qd = train_and_quantize(&data, npoints, ndims); + assert_eq!(qd.npoints, npoints, "should succeed with zero-variance data"); + // All points should have identical bit patterns. + for i in 1..npoints { + assert_eq!( + qd.get(i), qd.get(0), + "zero-variance data should produce identical quantized vectors" + ); + } + } + + #[test] + fn test_quantize_negative_data() { + // All negative values should produce valid quantized data. + let npoints = 5; + let ndims = 16; + let data = vec![-5.0f32; npoints * ndims]; + let qd = train_and_quantize(&data, npoints, ndims); + assert_eq!(qd.npoints, npoints, "should succeed with all-negative data"); + } + + #[test] + fn test_hamming_single_bit_diff() { + // XOR with exactly 1 bit different should give distance 1. + let a: Vec<u64> = vec![0b0000_0000]; + let b: Vec<u64> = vec![0b0000_0001]; + assert_eq!( + QuantizedData::hamming_u64(&a, &b), 1, + "single bit difference should yield Hamming distance 1" + ); + } + + #[test] + fn test_compute_distance_matrix_single_point() { + let ndims = 64; + let data = make_binary_data(1, ndims, 42); + let qd = train_and_quantize(&data, 1, ndims); + let indices = vec![0]; + let dist = qd.compute_distance_matrix(&indices); + assert_eq!(dist.len(), 1, "1x1 matrix should have 1 element"); + assert_eq!(dist[0], f32::MAX, "diagonal for single point should be f32::MAX"); + } + + #[test] + fn test_compute_distance_matrix_two_identical() { + // Two identical points should have distance 0. 
+ let ndims = 64; + let npoints = 2; + let data = vec![1.0f32; npoints * ndims]; // identical + let qd = train_and_quantize(&data, npoints, ndims); + let indices = vec![0, 1]; + let dist = qd.compute_distance_matrix(&indices); + assert_eq!( + dist[0 * 2 + 1], 0.0, + "two identical quantized vectors should have Hamming distance 0" + ); + assert_eq!( + dist[1 * 2 + 0], 0.0, + "symmetric: two identical quantized vectors should have Hamming distance 0" + ); + } + + #[test] + fn test_distances_to_leaders_empty() { + let ndims = 64; + let data = make_binary_data(3, ndims, 42); + let qd = train_and_quantize(&data, 3, ndims); + let leader_indices: Vec<usize> = vec![]; + let mut out: Vec<f32> = vec![]; + // Should not crash with empty leader list. + qd.distances_to_leaders(0, &leader_indices, &mut out); + assert!(out.is_empty(), "empty leader list should produce empty output"); + } + + #[test] + fn test_bytes_per_vec_alignment() { + // Verify bytes_per_vec is always a multiple of 8 for various ndims. + for ndims in [1, 7, 8, 9, 63, 64, 65, 127, 128, 129, 255, 256, 512, 1024] { + let data = vec![0.0f32; ndims]; + let qd = train_and_quantize(&data, 1, ndims); + assert_eq!( + qd.bytes_per_vec % 8, 0, + "bytes_per_vec ({}) should be a multiple of 8 for ndims={}", + qd.bytes_per_vec, ndims + ); + } + } + + #[test] + fn test_distances_to_leaders() { + // Verify distances_to_leaders matches manual pairwise computation. + let ndims = 64; + let npoints = 6; + let data = make_binary_data(npoints, ndims, 33); + let qd = train_and_quantize(&data, npoints, ndims); + + let point_idx = 0; + let leader_indices: Vec<usize> = vec![1, 3, 5]; + let mut out = vec![0.0f32; leader_indices.len()]; + qd.distances_to_leaders(point_idx, &leader_indices, &mut out); + + // Compare with direct hamming_u64 computation. 
+ let pt = qd.get_u64(point_idx); + for (j, &leader_idx) in leader_indices.iter().enumerate() { + let ld = qd.get_u64(leader_idx); + let expected = QuantizedData::hamming_u64(pt, ld) as f32; + assert_eq!( + out[j], expected, + "distance to leader {} mismatch: got {}, expected {}", + leader_idx, out[j], expected + ); + } + } +} diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 9d3fd9c32..6682edc46 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -66,6 +66,11 @@ impl WithBits { pub fn new(quantizer: ScalarQuantizer) -> Self { Self { quantizer } } + + /// Access the underlying scalar quantizer. + pub fn quantizer(&self) -> &ScalarQuantizer { + &self.quantizer + } } ////////////// diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index ae987dca9..369dc41b2 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -14,7 +14,7 @@ byteorder.workspace = true clap = { workspace = true, features = ["derive"] } diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` diskann-vector = { workspace = true } -diskann-disk = { workspace = true } +diskann-disk = { workspace = true, features = ["pipnn"] } diskann-utils = { workspace = true } bytemuck.workspace = true num_cpus.workspace = true diff --git a/diskann-tools/src/utils/build_disk_index.rs b/diskann-tools/src/utils/build_disk_index.rs index 916c24c8a..34162571b 100644 --- a/diskann-tools/src/utils/build_disk_index.rs +++ b/diskann-tools/src/utils/build_disk_index.rs @@ -79,6 +79,7 @@ pub struct BuildDiskIndexParameters<'a> { pub build_quantization_type: QuantizationType, pub chunking_parameters: Option, pub dim_values: DimensionValues, + pub build_algorithm: diskann_disk::BuildAlgorithm, } /// The main function to build a disk index @@ -91,13 +92,14 @@ 
where StorageProviderType: StorageReadProvider + StorageWriteProvider + 'static, ::Reader: std::marker::Send, { - let build_parameters = DiskIndexBuildParameters::new( + let build_parameters = DiskIndexBuildParameters::new_with_algorithm( MemoryBudget::try_from_gb(parameters.index_build_ram_limit_gb)?, parameters.build_quantization_type, NumPQChunks::new_with( parameters.num_of_pq_chunks, parameters.dim_values.full_dim(), )?, + parameters.build_algorithm.clone(), ); let config = config::Builder::new_with( @@ -208,6 +210,7 @@ mod tests { build_quantization_type: QuantizationType::FP, chunking_parameters: None, dim_values: DimensionValues::new(128, 128), + build_algorithm: diskann_disk::BuildAlgorithm::default(), }; let result = build_disk_index::>( @@ -233,6 +236,7 @@ mod tests { build_quantization_type: QuantizationType::FP, chunking_parameters: None, dim_values: DimensionValues::new(128, 128), + build_algorithm: diskann_disk::BuildAlgorithm::default(), }; let result = build_disk_index::>(