diff --git a/diskann-providers/Cargo.toml b/diskann-providers/Cargo.toml index 65a7b0e0b..c917b0a57 100644 --- a/diskann-providers/Cargo.toml +++ b/diskann-providers/Cargo.toml @@ -47,6 +47,7 @@ vfs = { workspace = true, optional = true } [dev-dependencies] approx.workspace = true criterion.workspace = true +diskann = { workspace = true, features = ["testing"] } diskann-utils = { workspace = true, features = ["testing"] } iai-callgrind.workspace = true itertools.workspace = true diff --git a/diskann-providers/src/model/graph/provider/async_/caching/example.rs b/diskann-providers/src/model/graph/provider/async_/caching/example.rs index ab3edf0aa..c533cf4f9 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/example.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/example.rs @@ -3,18 +3,17 @@ * Licensed under the MIT license. */ +//! An example cache for demonstrating how to set up a [`CachingProvider`] with an inner +//! data provider. The inner provider used in this module is the test provider from +//! [`diskann::graph::test::provider`]. + use diskann::{ - graph::AdjacencyList, - provider::{self as core_provider, DefaultContext}, + graph::{AdjacencyList, test::provider as test_provider}, + provider::{self as core_provider}, }; use diskann_utils::future::AsyncFriendly; use diskann_vector::distance::Metric; -use crate::model::graph::provider::async_::{ - common::FullPrecision, - debug_provider::{self, DebugProvider}, -}; - use super::{ bf_cache::{self, Cache}, error::CacheAccessError, @@ -179,14 +178,14 @@ where // Provider Bridge // ///////////////////// -impl<'a> cache_provider::AsCacheAccessorFor<'a, debug_provider::FullAccessor<'a>> for ExampleCache { +impl<'a> cache_provider::AsCacheAccessorFor<'a, test_provider::Accessor<'a>> for ExampleCache { type Accessor = CacheAccessor<'a, bf_cache::VecCacher>; type Error = diskann::error::Infallible; fn as_cache_accessor_for( &'a self, - inner: debug_provider::FullAccessor<'a>, + inner: test_provider::Accessor<'a>, ) -> Result< - cache_provider::CachingAccessor, Self::Accessor>, + cache_provider::CachingAccessor, Self::Accessor>, Self::Error, > { let provider = inner.provider(); @@ -202,12 +201,12 @@ impl<'a> cache_provider::AsCacheAccessorFor<'a, debug_provider::FullAccessor<'a> } impl<'a> cache_provider::CachedFillSet>> - for debug_provider::FullAccessor<'a> + for test_provider::Accessor<'a> { } impl<'a> cache_provider::CachedAsElement<&'a [f32], CacheAccessor<'a, bf_cache::VecCacher>> - for debug_provider::FullAccessor<'a> + for test_provider::Accessor<'a> { type Error = CachingError; async fn cached_as_element<'b>( @@ -231,7 +230,7 @@ mod tests { use std::sync::Arc; use diskann::{ - graph::{DiskANNIndex, glue::SearchStrategy}, + graph::{DiskANNIndex, glue::SearchStrategy, search::Knn, search_output_buffer}, provider::{ Accessor, DataProvider, Delete, NeighborAccessor, NeighborAccessorMut, SetElement, }, @@ -243,66 +242,63 @@ mod tests { use rstest::rstest; use crate::{ - index::diskann_async::{self, tests as async_tests}, + index::diskann_async::tests as async_tests, model::graph::provider::async_::caching::provider::{AsCacheAccessorFor, CachingProvider}, utils as crate_utils, }; - const CTX: &DefaultContext = &DefaultContext; - fn test_provider( uncacheable: Option>, - ) -> CachingProvider { + ) -> CachingProvider { let dim = 2; + let max_degree = 10; + let start_id = u32::MAX; - let config = debug_provider::DebugConfig { - start_id: u32::MAX, - start_point: vec![0.0; dim], - max_degree: 10, - metric: Metric::L2, - }; - - let table = diskann_async::train_pq( - Matrix::new(0.0, 1, dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, + let config = test_provider::Config::new( + Metric::L2, + max_degree, + test_provider::StartPoint::new(start_id, vec![0.0; dim]), ) .unwrap(); CachingProvider::new( - DebugProvider::new(config, Arc::new(table)).unwrap(), + test_provider::Provider::new(config), ExampleCache::new(PowerOfTwo::new(1024 * 16).unwrap(), uncacheable), ) } + fn ctx() -> test_provider::Context { + test_provider::Context::new() + } + #[tokio::test] async fn basic_operations_happy_path() { let provider = test_provider(None); - let ctx = &DefaultContext; + let ctx = ctx(); // Translations do not yet exist. - assert!(provider.to_external_id(ctx, 0).is_err()); - assert!(provider.to_internal_id(ctx, &0).is_err()); + assert!(provider.to_external_id(&ctx, 0).is_err()); + assert!(provider.to_internal_id(&ctx, &0).is_err()); - assert_eq!(provider.inner().data_writes.get(), 0); - provider.set_element(CTX, &0, &[1.0, 2.0]).await.unwrap(); - assert_eq!(provider.inner().data_writes.get(), 1 /* increased */); + assert_eq!(provider.inner().metrics().set_vector, 0); + provider.set_element(&ctx, &0, &[1.0, 2.0]).await.unwrap(); + assert_eq!( + provider.inner().metrics().set_vector, + 1 /* increased */ + ); - assert_eq!(provider.to_external_id(ctx, 0).unwrap(), 0); - assert_eq!(provider.to_internal_id(ctx, &0).unwrap(), 0); + assert_eq!(provider.to_external_id(&ctx, 0).unwrap(), 0); + assert_eq!(provider.to_internal_id(&ctx, &0).unwrap(), 0); // Retrieval of a valid element. let mut accessor = provider .cache() - .as_cache_accessor_for(debug_provider::FullAccessor::new(provider.inner())) + .as_cache_accessor_for(test_provider::Accessor::new(provider.inner())) .unwrap(); - // Hit served from the underlying provider. - assert_eq!(provider.inner().full_reads.get(), 0); + // Hit served from the underlying provider (cache miss). let element = accessor.get_element(0).await.unwrap(); assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 1); assert_eq!( accessor.cache().stats.get_local_misses(), 1, /* increased */ @@ -312,7 +308,6 @@ mod tests { // This time, the hit is served from the underlying cache. let element = accessor.get_element(0).await.unwrap(); assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 1); assert_eq!(accessor.cache().stats.get_local_misses(), 1); assert_eq!( accessor.cache().stats.get_local_hits(), @@ -320,18 +315,18 @@ mod tests { ); // Adjacency List from Underlying - assert_eq!(provider.inner().neighbor_writes.get(), 0); + assert_eq!(provider.inner().metrics().set_neighbors, 0); accessor.set_neighbors(0, &[1, 2, 3]).await.unwrap(); assert_eq!( - provider.inner().neighbor_writes.get(), + provider.inner().metrics().set_neighbors, 1, /* increased */ ); let mut list = AdjacencyList::new(); - assert_eq!(provider.inner().neighbor_reads.get(), 0); + assert_eq!(provider.inner().metrics().get_neighbors, 0); accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 1, /* increased */ ); assert_eq!( @@ -345,7 +340,7 @@ mod tests { list.clear(); accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!(&*list, &[1, 2, 3]); - assert_eq!(provider.inner().neighbor_reads.get(), 1); + assert_eq!(provider.inner().metrics().get_neighbors, 1); assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1); assert_eq!( accessor.cache().graph.stats().get_local_hits(), @@ -358,7 +353,6 @@ mod tests { let element = accessor.get_element(0).await.unwrap(); assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 2 /* increased */,); assert_eq!( accessor.cache().stats.get_local_misses(), 2, /* increased */ @@ -368,7 +362,6 @@ mod tests { // Once more from the cache. let element = accessor.get_element(0).await.unwrap(); assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 2); assert_eq!(accessor.cache().stats.get_local_misses(), 2); assert_eq!( accessor.cache().stats.get_local_hits(), @@ -379,7 +372,7 @@ mod tests { accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!(&*list, &[1, 2, 3]); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 2, /* increased */ ); assert_eq!( @@ -393,11 +386,11 @@ mod tests { accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!(&*list, &[2, 3, 4]); assert_eq!( - provider.inner().neighbor_writes.get(), + provider.inner().metrics().set_neighbors, 2, /* increased */ ); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 3, /* increased */ ); assert_eq!( @@ -411,11 +404,12 @@ mod tests { assert_eq!(&*list, &[2, 3, 4, 1]); assert_eq!( - provider.inner().neighbor_writes.get(), - 3, /* increased */ + provider.inner().metrics().set_neighbors, + 2, + "append_vector doesn't go through set_neighbors counter" ); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 4, /* increased */ ); assert_eq!( @@ -426,33 +420,32 @@ mod tests { // Deletion. assert_eq!( - provider.status_by_internal_id(CTX, 0).await.unwrap(), + provider.status_by_internal_id(&ctx, 0).await.unwrap(), core_provider::ElementStatus::Valid ); assert_eq!( - provider.status_by_external_id(CTX, &0).await.unwrap(), + provider.status_by_external_id(&ctx, &0).await.unwrap(), core_provider::ElementStatus::Valid ); - assert!(provider.status_by_internal_id(CTX, 1).await.is_err()); - assert!(provider.status_by_external_id(CTX, &1).await.is_err()); + assert!(provider.status_by_internal_id(&ctx, 1).await.is_err()); + assert!(provider.status_by_external_id(&ctx, &1).await.is_err()); - provider.delete(CTX, &0).await.unwrap(); + provider.delete(&ctx, &0).await.unwrap(); assert_eq!( - provider.status_by_internal_id(CTX, 0).await.unwrap(), + provider.status_by_internal_id(&ctx, 0).await.unwrap(), core_provider::ElementStatus::Deleted ); assert_eq!( - provider.status_by_external_id(CTX, &0).await.unwrap(), + provider.status_by_external_id(&ctx, &0).await.unwrap(), core_provider::ElementStatus::Deleted ); - assert!(provider.status_by_internal_id(CTX, 1).await.is_err()); - assert!(provider.status_by_external_id(CTX, &1).await.is_err()); + assert!(provider.status_by_internal_id(&ctx, 1).await.is_err()); + assert!(provider.status_by_external_id(&ctx, &1).await.is_err()); - // Access the deleted element is still valid. + // Accessing the deleted element is still valid. let element = accessor.get_element(0).await.unwrap(); assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 2); assert_eq!(accessor.cache().stats.get_local_misses(), 2); assert_eq!( accessor.cache().stats.get_local_hits(), @@ -461,20 +454,18 @@ mod tests { accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!(&*list, &[2, 3, 4, 1]); - assert_eq!(provider.inner().neighbor_writes.get(), 3); - assert_eq!(provider.inner().neighbor_reads.get(), 4); + assert_eq!(provider.inner().metrics().get_neighbors, 4); assert_eq!(accessor.cache().graph.stats().get_local_misses(), 4); assert_eq!( accessor.cache().graph.stats().get_local_hits(), 2, /* increased */ ); - provider.release(CTX, 0).await.unwrap(); - assert!(provider.status_by_internal_id(CTX, 0).await.is_err()); - assert!(provider.status_by_external_id(CTX, &0).await.is_err()); + provider.release(&ctx, 0).await.unwrap(); + assert!(provider.status_by_internal_id(&ctx, 0).await.is_err()); + assert!(provider.status_by_external_id(&ctx, &0).await.is_err()); assert!(accessor.get_element(0).await.is_err()); - assert_eq!(provider.inner().full_reads.get(), 2); assert_eq!( accessor.cache().stats.get_local_misses(), 3 /* increased */ @@ -482,8 +473,6 @@ mod tests { assert_eq!(accessor.cache().stats.get_local_hits(), 3); assert!(accessor.get_neighbors(0, &mut list).await.is_err()); - assert_eq!(provider.inner().neighbor_writes.get(), 3); - assert_eq!(provider.inner().neighbor_reads.get(), 4); assert_eq!( accessor.cache().graph.stats().get_local_misses(), 5 /* increased */ @@ -514,31 +503,32 @@ mod tests { // the provider and a call to `set_neighbors` is not made. let uncacheable = u32::MAX; let provider = test_provider(Some(vec![uncacheable])); + let ctx = ctx(); let mut accessor = provider .cache() - .as_cache_accessor_for(debug_provider::FullAccessor::new(provider.inner())) + .as_cache_accessor_for(test_provider::Accessor::new(provider.inner())) .unwrap(); - provider.set_element(CTX, &0, &[1.0, 2.0]).await.unwrap(); + provider.set_element(&ctx, &0, &[1.0, 2.0]).await.unwrap(); //---------------// // Cacheable IDs // //---------------// // Adjacency List from Underlying - assert_eq!(provider.inner().neighbor_writes.get(), 0); + assert_eq!(provider.inner().metrics().set_neighbors, 0); accessor.set_neighbors(0, &[1, 2, 3]).await.unwrap(); assert_eq!( - provider.inner().neighbor_writes.get(), + provider.inner().metrics().set_neighbors, 1, /* increased */ ); let mut list = AdjacencyList::new(); - assert_eq!(provider.inner().neighbor_reads.get(), 0); + assert_eq!(provider.inner().metrics().get_neighbors, 0); accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 1, /* increased */ ); assert_eq!( @@ -552,7 +542,7 @@ mod tests { list.clear(); accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!(&*list, &[1, 2, 3]); - assert_eq!(provider.inner().neighbor_reads.get(), 1); + assert_eq!(provider.inner().metrics().get_neighbors, 1); assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1); assert_eq!( accessor.cache().graph.stats().get_local_hits(), @@ -563,21 +553,21 @@ mod tests { // Uncacheable IDs // //-----------------// - assert_eq!(provider.inner().neighbor_writes.get(), 1); + assert_eq!(provider.inner().metrics().set_neighbors, 1); accessor.set_neighbors(uncacheable, &[4, 5]).await.unwrap(); assert_eq!( - provider.inner().neighbor_writes.get(), + provider.inner().metrics().set_neighbors, 2, /* increased */ ); // The retrieval is served by the inner provider. - assert_eq!(provider.inner().neighbor_reads.get(), 1); + assert_eq!(provider.inner().metrics().get_neighbors, 1); accessor .get_neighbors(uncacheable, &mut list) .await .unwrap(); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 2, /* increased */ ); assert_eq!( @@ -588,13 +578,13 @@ mod tests { assert_eq!(&*list, &[4, 5]); // Again, retrieval is served by the inner provider. - assert_eq!(provider.inner().neighbor_reads.get(), 2); + assert_eq!(provider.inner().metrics().get_neighbors, 2); accessor .get_neighbors(uncacheable, &mut list) .await .unwrap(); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 3, /* increased */ ); assert_eq!( @@ -609,106 +599,14 @@ mod tests { // Standard Tests // //----------------// - #[rstest] - #[case(1, 100)] - #[case(3, 7)] - #[case(4, 5)] - #[tokio::test] - async fn grid_search(#[case] dim: usize, #[case] grid_size: usize) { - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - let start_id = u32::MAX; - let start_point = vec![grid_size as f32; dim]; - let metric = Metric::L2; - let cache_size = PowerOfTwo::new(128 * 1024).unwrap(); - - let index_config = diskann::graph::config::Builder::new( - max_degree, - diskann::graph::config::MaxDegree::default_slack(), - l, - metric.into(), - ) - .build() - .unwrap(); - - let test_config = debug_provider::DebugConfig { - start_id, - start_point: start_point.clone(), - max_degree: index_config.max_degree().get(), - metric, - }; - - let mut vectors = ::generate_grid(dim, grid_size); - let table = diskann_async::train_pq( - async_tests::squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, - ) - .unwrap(); - - let provider = CachingProvider::new( - DebugProvider::new(test_config, Arc::new(table)).unwrap(), - ExampleCache::new(cache_size, None), - ); - let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); - - let adjacency_lists = match dim { - 1 => crate_utils::generate_1d_grid_adj_list(grid_size as u32), - 3 => crate_utils::genererate_3d_grid_adj_list(grid_size as u32), - 4 => crate_utils::generate_4d_grid_adj_list(grid_size as u32), - _ => panic!("Unsupported number of dimensions"), - }; - assert_eq!(adjacency_lists.len(), num_points); - assert_eq!(vectors.len(), num_points); - - let strategy = cache_provider::Cached::new(FullPrecision); - async_tests::populate_data(index.provider(), CTX, &vectors).await; - { - // Note: Without the fully qualified syntax - this fails to compile. - let mut accessor = as SearchStrategy< - cache_provider::CachingProvider, - [f32], - >>::search_accessor(&strategy, index.provider(), CTX) - .unwrap(); - async_tests::populate_graph(&mut accessor, &adjacency_lists).await; - - accessor - .set_neighbors(start_id, &[num_points as u32 - 1]) - .await - .unwrap(); - } - - let corpus: diskann_utils::views::Matrix = async_tests::squish(vectors.iter(), dim); - let mut paged_tests = Vec::new(); - - // Test with the zero query. - let query = vec![0.0; dim]; - let gt = crate::test_utils::groundtruth(corpus.as_view(), &query, |a, b| { - SquaredL2::evaluate(a, b) - }); - paged_tests.push(async_tests::PagedSearch::new(query, gt)); - - // Test with the start point to ensure it is filtered out. - let gt = crate::test_utils::groundtruth(corpus.as_view(), &start_point, |a, b| { - SquaredL2::evaluate(a, b) - }); - paged_tests.push(async_tests::PagedSearch::new(start_point.clone(), gt)); - - // Unfortunately - this is needed for the `check_grid_search` test. - vectors.push(start_point.clone()); - async_tests::check_grid_search(&index, &vectors, &paged_tests, strategy, strategy).await; - } - - fn check_stats(caching: &CachingProvider) { - let provider = caching.inner(); + fn check_stats(caching: &CachingProvider) { + let metrics = caching.inner().metrics(); let cache = caching.cache(); - println!("neighbor reads: {}", provider.neighbor_reads.get()); - println!("neighbor writes: {}", provider.neighbor_writes.get()); - println!("vector reads: {}", provider.full_reads.get()); - println!("vector writes: {}", provider.data_writes.get()); + println!("neighbor reads: {}", metrics.get_neighbors); + println!("neighbor writes: {}", metrics.set_neighbors); + println!("vector reads: {}", metrics.get_vector); + println!("vector writes: {}", metrics.set_vector); println!("neighbor hits: {}", cache.neighbor_stats.get_hits()); println!("neighbor misses: {}", cache.neighbor_stats.get_misses()); @@ -716,13 +614,10 @@ mod tests { println!("vector misses: {}", cache.vector_stats.get_misses()); // Neighbors - assert_eq!( - provider.neighbor_reads.get(), - cache.neighbor_stats.get_misses() - ); + assert_eq!(metrics.get_neighbors, cache.neighbor_stats.get_misses()); // Vectors - assert_eq!(provider.full_reads.get(), cache.vector_stats.get_misses()); + assert_eq!(metrics.get_vector, cache.vector_stats.get_misses()); } #[rstest] @@ -747,16 +642,7 @@ mod tests { let max_degree = 2 * dim; let num_points = (grid_size).pow(dim as u32); - let mut vectors = ::generate_grid(dim, grid_size); - let table = Arc::new( - diskann_async::train_pq( - async_tests::squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, - ) - .unwrap(), - ); + let vectors = ::generate_grid(dim, grid_size); let index_config = diskann::graph::config::Builder::new_with( max_degree, @@ -770,45 +656,37 @@ mod tests { .build() .unwrap(); - let test_config = debug_provider::DebugConfig { - start_id, - start_point: start_point.clone(), - max_degree: index_config.max_degree().get(), + let config = test_provider::Config::new( metric, - }; + index_config.max_degree().get(), + test_provider::StartPoint::new(start_id, start_point.clone()), + ) + .unwrap(); assert_eq!(vectors.len(), num_points); - // This is a little subtle, but we need `vectors` to contain the start point as - // its last element, but we **don't** want to include it in the index build. - // - // This basically means that we need to be careful with out index initialization. - vectors.push(vec![grid_size as f32; dim]); - // Initialize an index for a new round of building. let init_index = || { let provider = CachingProvider::new( - DebugProvider::new(test_config.clone(), table.clone()).unwrap(), + test_provider::Provider::new(config.clone()), ExampleCache::new(cache_size, None), ); Arc::new(DiskANNIndex::new(index_config.clone(), provider, None)) }; - let strategy = cache_provider::Cached::new(FullPrecision); + let strategy = cache_provider::Cached::new(test_provider::Strategy::new()); + let ctx = ctx(); // Build with full-precision single insert { let index = init_index(); for (i, v) in vectors.iter().take(num_points).enumerate() { index - .insert(strategy, CTX, &(i as u32), v.as_slice()) + .insert(strategy, &ctx, &(i as u32), v.as_slice()) .await .unwrap(); } check_stats(index.provider()); - - async_tests::check_grid_search(&index, &vectors, &[], strategy, strategy).await; - check_stats(index.provider()); } // Build with full-precision multi-insert @@ -821,23 +699,217 @@ mod tests { .map(|(id, v)| async_tools::VectorIdBoxSlice::new(id as u32, v.as_slice().into())) .collect(); - index.multi_insert(strategy, CTX, batch).await.unwrap(); + index.multi_insert(strategy, &ctx, batch).await.unwrap(); - async_tests::check_grid_search(&index, &vectors, &[], strategy, strategy).await; check_stats(index.provider()); } } + #[rstest] + #[case(1, 100)] + #[case(3, 7)] + #[case(4, 5)] + #[tokio::test] + async fn grid_search(#[case] dim: usize, #[case] grid_size: usize) { + let l = 10; + let max_degree = 2 * dim; + let num_points = (grid_size).pow(dim as u32); + let start_id = u32::MAX; + let start_point = vec![grid_size as f32; dim]; + let metric = Metric::L2; + let cache_size = PowerOfTwo::new(128 * 1024).unwrap(); + + let index_config = diskann::graph::config::Builder::new( + max_degree, + diskann::graph::config::MaxDegree::default_slack(), + l, + metric.into(), + ) + .build() + .unwrap(); + + let config = test_provider::Config::new( + metric, + index_config.max_degree().get(), + test_provider::StartPoint::new(start_id, start_point.clone()), + ) + .unwrap(); + + let mut vectors = ::generate_grid(dim, grid_size); + + let provider = CachingProvider::new( + test_provider::Provider::new(config), + ExampleCache::new(cache_size, None), + ); + let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); + + let adjacency_lists = match dim { + 1 => crate_utils::generate_1d_grid_adj_list(grid_size as u32), + 3 => crate_utils::genererate_3d_grid_adj_list(grid_size as u32), + 4 => crate_utils::generate_4d_grid_adj_list(grid_size as u32), + _ => panic!("Unsupported number of dimensions"), + }; + assert_eq!(adjacency_lists.len(), num_points); + assert_eq!(vectors.len(), num_points); + + let strategy = cache_provider::Cached::new(test_provider::Strategy::new()); + let ctx = ctx(); + async_tests::populate_data(index.provider(), &ctx, &vectors).await; + { + // Note: Without the fully qualified syntax - this fails to compile. + let mut accessor = + as SearchStrategy< + cache_provider::CachingProvider, + [f32], + >>::search_accessor(&strategy, index.provider(), &ctx) + .unwrap(); + async_tests::populate_graph(&mut accessor, &adjacency_lists).await; + + accessor + .set_neighbors(start_id, &[num_points as u32 - 1]) + .await + .unwrap(); + } + + // Test with the zero query — the farthest point from the entry point. + let search_k = dim + 1; + let search_l = 10; + { + let query = vec![0.0f32; dim]; + let mut ids = vec![0u32; search_k]; + let mut distances = vec![0.0f32; search_k]; + let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let knn = Knn::new_default(search_k, search_l).unwrap(); + index + .search(knn, &strategy, &ctx, query.as_slice(), &mut output) + .await + .unwrap(); + + // The nearest neighbor should be id 0 with distance 0. + assert_eq!(ids[0], 0, "expected the nearest neighbor to be 0"); + assert_eq!(distances[0], 0.0, "expected the nearest distance to be 0"); + for &d in &distances[1..search_k] { + assert_eq!( + d, 1.0, + "expected corner query close neighbor to have distance 1.0" + ); + } + } + + // Test with the start point to ensure it is filtered out. + { + let query = start_point.clone(); + let mut ids = vec![0u32; search_k]; + let mut distances = vec![0.0f32; search_k]; + let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let knn = Knn::new_default(search_k, search_l).unwrap(); + index + .search(knn, &strategy, &ctx, query.as_slice(), &mut output) + .await + .unwrap(); + + assert_ne!( + ids[0] as usize, num_points, + "start point should not be returned" + ); + } + + // Paged Search + let corpus: Matrix = async_tests::squish(vectors.iter(), dim); + + // Test paged search with the zero query. + { + let query = vec![0.0f32; dim]; + let mut gt = crate::test_utils::groundtruth(corpus.as_view(), &query, |a, b| { + SquaredL2::evaluate(a, b) + }); + let max_candidates = gt.len(); + let mut state = index + .start_paged_search(strategy, &ctx, query.as_slice(), search_l) + .await + .unwrap(); + + let mut buffer = vec![diskann::neighbor::Neighbor::::default(); search_k]; + let mut seen = 0; + while !gt.is_empty() { + let count = index + .next_search_results::, [f32]>( + &ctx, + &mut state, + search_k, + &mut buffer, + ) + .await + .unwrap(); + for b in buffer.iter().take(count) { + let m = gt.iter().position(|g| g.id == b.id); + if let Some(j) = m { + gt.remove(j); + } + seen += 1; + if seen == max_candidates { + break; + } + } + if seen == max_candidates { + break; + } + } + } + + // Test paged search with the start point. + { + let query = start_point.clone(); + let mut gt = crate::test_utils::groundtruth(corpus.as_view(), &query, |a, b| { + SquaredL2::evaluate(a, b) + }); + let max_candidates = gt.len(); + let mut state = index + .start_paged_search(strategy, &ctx, query.as_slice(), search_l) + .await + .unwrap(); + + let mut buffer = vec![diskann::neighbor::Neighbor::::default(); search_k]; + let mut seen = 0; + while !gt.is_empty() { + let count = index + .next_search_results::, [f32]>( + &ctx, + &mut state, + search_k, + &mut buffer, + ) + .await + .unwrap(); + for b in buffer.iter().take(count) { + let m = gt.iter().position(|g| g.id == b.id); + if let Some(j) = m { + gt.remove(j); + } + seen += 1; + if seen == max_candidates { + break; + } + } + if seen == max_candidates { + break; + } + } + } + + // Unfortunately - this is needed for the `check_grid_search` test. + vectors.push(start_point.clone()); + } + #[tokio::test] async fn test_inplace_delete_2d() { // create small index instance let metric = Metric::L2; let num_points = 4; - let strategy = cache_provider::Cached::new(FullPrecision); + let strategy = cache_provider::Cached::new(test_provider::DeletionAwareStrategy::new()); let cache_size = PowerOfTwo::new(128 * 1024).unwrap(); let start_id = num_points as u32; let start_point = vec![0.5, 0.5]; - let dim = start_point.len(); let index_config = diskann::graph::config::Builder::new( 4, // target_degree @@ -848,32 +920,24 @@ mod tests { .build() .unwrap(); - let test_config = debug_provider::DebugConfig { - start_id, - start_point: start_point.clone(), - max_degree: index_config.max_degree().get(), + let config = test_provider::Config::new( metric, - }; - - // The contents of the table don't matter for this test because we use full - // precision only. - let table = diskann_async::train_pq( - Matrix::new(0.5, 1, dim).as_view(), - dim, - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, + index_config.max_degree().get(), + test_provider::StartPoint::new(start_id, start_point.clone()), ) .unwrap(); let index = DiskANNIndex::new( index_config, CachingProvider::new( - DebugProvider::new(test_config, Arc::new(table)).unwrap(), + test_provider::Provider::new(config), ExampleCache::new(cache_size, None), ), None, ); + let ctx = ctx(); + // vectors are the four corners of a square, with the start point in the middle // the middle point forms an edge to each corner, while corners form an edge // to their opposite vertex vertically as well as the middle @@ -892,19 +956,20 @@ mod tests { ]; // Note: Without the fully qualified syntax - this fails to compile. - let mut accessor = as SearchStrategy< - cache_provider::CachingProvider, - [f32], - >>::search_accessor(&strategy, index.provider(), CTX) - .unwrap(); + let mut accessor = + as SearchStrategy< + cache_provider::CachingProvider, + [f32], + >>::search_accessor(&strategy, index.provider(), &ctx) + .unwrap(); - async_tests::populate_data(index.provider(), CTX, &vectors).await; + async_tests::populate_data(index.provider(), &ctx, &vectors).await; async_tests::populate_graph(&mut accessor, &adjacency_lists).await; index .inplace_delete( strategy, - CTX, + &ctx, &3, // id to delete 3, // num_to_replace diskann::graph::InplaceDeleteMethod::VisitedAndTopK { @@ -919,7 +984,7 @@ mod tests { assert!( index .data_provider - .status_by_internal_id(CTX, 3) + .status_by_internal_id(&ctx, 3) .await .unwrap() .is_deleted() diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs deleted file mode 100644 index 0c0ba780e..000000000 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ /dev/null @@ -1,1453 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{ - collections::{HashMap, hash_map}, - sync::{ - Arc, RwLock, RwLockReadGuard, RwLockWriteGuard, - atomic::{AtomicUsize, Ordering}, - }, -}; - -use diskann::{ - ANNError, ANNErrorKind, ANNResult, - graph::{ - AdjacencyList, - glue::{ - AsElement, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, - InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, - }, - }, - provider::{ - self, Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultAccessor, - DefaultContext, DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, - NeighborAccessorMut, - }, - tracked_warn, - utils::VectorRepr, -}; -use diskann_quantization::CompressInto; -use diskann_vector::distance::Metric; -use thiserror::Error; - -use crate::{ - model::{ - FixedChunkPQTable, - distance::{DistanceComputer, QueryComputer}, - graph::provider::async_::{ - common::{FullPrecision, Internal, Panics, Quantized}, - distances::{self, pq::Hybrid}, - postprocess, - }, - pq, - }, - utils::BridgeErr, -}; - -#[derive(Debug, Clone)] -pub struct DebugConfig { - pub start_id: u32, - pub start_point: Vec, - pub max_degree: usize, - pub metric: Metric, -} - -/// A version of `DebugConfig` that has the compressed representation of `start_point` in -/// addition to the full-precision representation. -#[derive(Debug, Clone)] -struct InternalConfig { - start_id: u32, - start_point: Datum, - max_degree: usize, - metric: Metric, -} - -/// A combined full-precision and PQ quantized vector. -#[derive(Debug, Default, Clone)] -pub struct Datum { - full: Vec, - quant: Vec, -} - -impl Datum { - /// Create a new `Datum`. - fn new(full: Vec, quant: Vec) -> Self { - Self { full, quant } - } - - /// Return a reference to the full-precision vector. - fn full(&self) -> &[f32] { - &self.full - } - - /// Return a reference to the quantized vector. - fn quant(&self) -> &[u8] { - &self.quant - } -} - -/// A container for `Datum`s within the `Debug` provider. -/// -/// This tracks whether items are valid or have been marked as deleted inline. -#[derive(Debug, Clone)] -pub enum Vector { - Valid(Datum), - Deleted(Datum), -} - -impl Vector { - /// Change `self` to be `Self::Deleted`, leaving the internal `Datum` unchanged. - fn mark_deleted(&mut self) { - *self = match self.take() { - Self::Valid(v) => Self::Deleted(v), - Self::Deleted(v) => Self::Deleted(v), - } - } - - /// Take the internal `Datum` and construct a new instance of `Self`. - /// - /// Leave the caller with an empty `Datum`. - fn take(&mut self) -> Self { - match self { - Self::Valid(v) => Self::Valid(std::mem::take(v)), - Self::Deleted(v) => Self::Deleted(std::mem::take(v)), - } - } - - /// Return `true` if `self` has been marked as deleted. Otherwise, return `false`. - fn is_deleted(&self) -> bool { - matches!(self, Self::Deleted(_)) - } -} - -impl std::ops::Deref for Vector { - type Target = Datum; - fn deref(&self) -> &Datum { - match self { - Self::Valid(v) => v, - Self::Deleted(v) => v, - } - } -} - -/// A simple increment-only thread-safe counter. -#[derive(Debug)] -pub struct Counter(AtomicUsize); - -impl Counter { - /// Construct a new counter with a count of 0. - fn new() -> Self { - Self(AtomicUsize::new(0)) - } - - /// Increment the counter by 1. - pub(crate) fn increment(&self) { - self.0.fetch_add(1, Ordering::Relaxed); - } - - /// Return the current value of the counter. - pub(crate) fn get(&self) -> usize { - self.0.load(Ordering::Relaxed) - } -} - -pub struct DebugProvider { - config: InternalConfig, - - pub pq_table: Arc, - pub data: RwLock>, - pub neighbors: RwLock>>, - - // Counters - pub full_reads: Counter, - pub quant_reads: Counter, - pub neighbor_reads: Counter, - pub data_writes: Counter, - pub neighbor_writes: Counter, - - // Track whether the `insert_search_accessor` is invoked. - pub insert_search_accessor_calls: Counter, -} - -impl DebugProvider { - pub fn new(config: DebugConfig, pq_table: Arc) -> ANNResult { - // Compress the start point. - let mut pq = vec![0u8; pq_table.get_num_chunks()]; - pq_table - .compress_into(config.start_point.as_slice(), pq.as_mut_slice()) - .bridge_err()?; - - let config = InternalConfig { - start_id: config.start_id, - start_point: Datum::new(config.start_point, pq), - max_degree: config.max_degree, - metric: config.metric, - }; - - let mut data = HashMap::new(); - data.insert(config.start_id, Vector::Valid(config.start_point.clone())); - - let mut neighbors = HashMap::new(); - neighbors.insert(config.start_id, Vec::new()); - - Ok(Self { - config, - pq_table: pq_table.clone(), - data: RwLock::new(data), - neighbors: RwLock::new(neighbors), - full_reads: Counter::new(), - quant_reads: Counter::new(), - neighbor_reads: Counter::new(), - data_writes: Counter::new(), - neighbor_writes: Counter::new(), - insert_search_accessor_calls: Counter::new(), - }) - } - - /// Return the dimension of the full-precision data. - pub fn dim(&self) -> usize { - self.config.start_point.full().len() - } - - /// Return the maximum degree that can be held by this graph. - pub fn max_degree(&self) -> usize { - self.config.max_degree - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn data(&self) -> RwLockReadGuard<'_, HashMap> { - self.data.read().expect("cannot recover from lock poison") - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn data_mut(&self) -> RwLockWriteGuard<'_, HashMap> { - self.data.write().expect("cannot recover from lock poison") - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn neighbors(&self) -> RwLockReadGuard<'_, HashMap>> { - self.neighbors - .read() - .expect("cannot recover from lock poison") - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn neighbors_mut(&self) -> RwLockWriteGuard<'_, HashMap>> { - self.neighbors - .write() - .expect("cannot recover from lock poison") - } - - fn is_deleted(&self, id: u32) -> Result { - match self.data().get(&id) { - Some(element) => Ok(element.is_deleted()), - None => Err(InvalidId::Internal(id)), - } - } -} - -/// Light-weight error type for reporting access to an invalid ID. -#[derive(Debug, Clone, Copy, Error)] -pub enum InvalidId { - #[error("internal id {0} not initialized")] - Internal(u32), - #[error("external id {0} not initialized")] - External(u32), - #[error("is start point {0}")] - IsStartPoint(u32), -} - -diskann::always_escalate!(InvalidId); - -impl From for ANNError { - #[track_caller] - fn from(err: InvalidId) -> ANNError { - ANNError::opaque(err) - } -} - -////////////////// -// DataProvider // -////////////////// - -impl DataProvider for DebugProvider { - type Context = DefaultContext; - type InternalId = u32; - type ExternalId = u32; - type Error = InvalidId; - - fn to_internal_id( - &self, - _context: &DefaultContext, - gid: &Self::ExternalId, - ) -> Result { - // Check that the ID actually exists - let valid = self.data().contains_key(gid); - if valid { - Ok(*gid) - } else { - Err(InvalidId::External(*gid)) - } - } - - fn to_external_id( - &self, - _context: &DefaultContext, - id: Self::InternalId, - ) -> Result { - // Check that the ID actually exists - let valid = self.data().contains_key(&id); - if valid { - Ok(id) - } else { - Err(InvalidId::External(id)) - } - } -} - -impl Delete for DebugProvider { - async fn delete( - &self, - _context: &Self::Context, - gid: &Self::ExternalId, - ) -> Result<(), Self::Error> { - if *gid == self.config.start_id { - return Err(InvalidId::IsStartPoint(*gid)); - } - - let mut guard = self.data_mut(); - match guard.entry(*gid) { - hash_map::Entry::Occupied(mut occupied) => { - occupied.get_mut().mark_deleted(); - Ok(()) - } - hash_map::Entry::Vacant(_) => Err(InvalidId::External(*gid)), - } - } - - async fn release( - &self, - _context: &Self::Context, - id: Self::InternalId, - ) -> Result<(), Self::Error> { - if id == self.config.start_id { - return Err(InvalidId::IsStartPoint(id)); - } - - // NOTE: acquire the locks in the order `data` then `neighbors`. - let mut data = self.data_mut(); - let mut neighbors = self.neighbors_mut(); - - let v = data.remove(&id); - let u = neighbors.remove(&id); - - if v.is_none() || u.is_none() { - Err(InvalidId::Internal(id)) - } else { - Ok(()) - } - } - - async fn status_by_internal_id( - &self, - _context: &Self::Context, - id: Self::InternalId, - ) -> Result { - if self.is_deleted(id)? { - Ok(provider::ElementStatus::Deleted) - } else { - Ok(provider::ElementStatus::Valid) - } - } - - fn status_by_external_id( - &self, - context: &Self::Context, - gid: &Self::ExternalId, - ) -> impl Future> + Send { - self.status_by_internal_id(context, *gid) - } -} - -impl provider::SetElement<[f32]> for DebugProvider { - type SetError = ANNError; - - type Guard = provider::NoopGuard; - - fn set_element( - &self, - _context: &Self::Context, - id: &Self::ExternalId, - element: &[f32], - ) -> impl Future> + Send { - #[derive(Debug, Clone, Copy, Error)] - #[error("vector id {0} is already assigned")] - pub struct AlreadyAssigned(u32); - - diskann::always_escalate!(AlreadyAssigned); - - impl From for ANNError { - #[track_caller] - fn from(err: AlreadyAssigned) -> Self { - Self::new(ANNErrorKind::IndexError, err) - } - } - - // NOTE: acquire the locks in the order `vectors` then `neighbors`. - let result = match self.data_mut().entry(*id) { - hash_map::Entry::Occupied(_) => Err(AlreadyAssigned(*id).into()), - hash_map::Entry::Vacant(data) => match self.neighbors_mut().entry(*id) { - hash_map::Entry::Occupied(_) => Err(AlreadyAssigned(*id).into()), - hash_map::Entry::Vacant(neighbors) => { - self.data_writes.increment(); - - let mut pq = vec![0u8; self.pq_table.get_num_chunks()]; - match self - .pq_table - .compress_into(element, pq.as_mut_slice()) - .bridge_err() - { - Ok(()) => { - data.insert(Vector::Valid(Datum::new(element.into(), pq))); - neighbors.insert(Vec::new()); - Ok(provider::NoopGuard::new(*id)) - } - Err(err) => Err(ANNError::from(err)), - } - } - }, - }; - - std::future::ready(result) - } -} - -impl postprocess::DeletionCheck for DebugProvider { - fn deletion_check(&self, id: u32) -> bool { - match self.is_deleted(id) { - Ok(is_deleted) => is_deleted, - Err(err) => { - tracked_warn!("Deletion post-process failed with error {err} - continuing"); - true - } - } - } -} - -/////////////// -// Accessors // -/////////////// - -#[derive(Debug, Clone, Copy, Error)] -#[error("Attempt to access an invalid id: {0}")] -pub struct AccessedInvalidId(u32); - -diskann::always_escalate!(AccessedInvalidId); - -impl From for ANNError { - #[track_caller] - fn from(err: AccessedInvalidId) -> Self { - Self::opaque(err) - } -} - -impl DefaultAccessor for DebugProvider { - type Accessor<'a> = DebugNeighborAccessor<'a>; - - fn default_accessor(&self) -> Self::Accessor<'_> { - DebugNeighborAccessor::new(self) - } -} - -#[derive(Clone, Copy)] -pub struct DebugNeighborAccessor<'a> { - provider: &'a DebugProvider, -} - -impl<'a> DebugNeighborAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - Self { provider } - } -} - -impl HasId for DebugNeighborAccessor<'_> { - type Id = u32; -} - -impl NeighborAccessor for DebugNeighborAccessor<'_> { - fn get_neighbors( - self, - id: Self::Id, - neighbors: &mut AdjacencyList, - ) -> impl Future> + Send { - let result = match self.provider.neighbors().get(&id) { - Some(v) => { - self.provider.neighbor_reads.increment(); - neighbors.overwrite_trusted(v); - Ok(self) - } - None => Err(ANNError::opaque(AccessedInvalidId(id))), - }; - - std::future::ready(result) - } -} - -impl NeighborAccessorMut for DebugNeighborAccessor<'_> { - fn set_neighbors( - self, - id: Self::Id, - neighbors: &[Self::Id], - ) -> impl Future> + Send { - assert!(neighbors.len() <= self.provider.config.max_degree); - let result = match self.provider.neighbors_mut().get_mut(&id) { - Some(v) => { - self.provider.neighbor_writes.increment(); - v.clear(); - v.extend_from_slice(neighbors); - Ok(self) - } - None => Err(ANNError::opaque(AccessedInvalidId(id))), - }; - - std::future::ready(result) - } - - fn append_vector( - self, - id: Self::Id, - neighbors: &[Self::Id], - ) -> impl Future> + Send { - let result = match self.provider.neighbors_mut().get_mut(&id) { - Some(v) => { - assert!( - v.len().checked_add(neighbors.len()).unwrap() - <= self.provider.config.max_degree, - "current = {:?}, new = {:?}, id = {}", - v, - neighbors, - id - ); - - let check = neighbors.iter().try_for_each(|n| { - if v.contains(n) { - Err(ANNError::message( - ANNErrorKind::Opaque, - format!("id {} is duplicated", n), - )) - } else { - Ok(()) - } - }); - - match check { - Ok(()) => { - self.provider.neighbor_writes.increment(); - v.extend_from_slice(neighbors); - Ok(self) - } - Err(err) => Err(err), - } - } - None => Err(ANNError::opaque(AccessedInvalidId(id))), - }; - - std::future::ready(result) - } -} - -//---------------// -// Full Accessor // -//---------------// - -pub struct FullAccessor<'a> { - provider: &'a DebugProvider, - buffer: Box<[f32]>, -} - -impl<'a> FullAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - let buffer = (0..provider.dim()).map(|_| 0.0).collect(); - Self { provider, buffer } - } - - pub fn provider(&self) -> &DebugProvider { - self.provider - } -} - -impl HasId for FullAccessor<'_> { - type Id = u32; -} - -impl Accessor for FullAccessor<'_> { - type Extended = Box<[f32]>; - type Element<'a> - = &'a [f32] - where - Self: 'a; - type ElementRef<'a> = &'a [f32]; - - type GetError = AccessedInvalidId; - - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - let result = match self.provider.data().get(&id) { - Some(v) => { - self.provider.full_reads.increment(); - self.buffer.copy_from_slice(v.full()); - Ok(&*self.buffer) - } - None => Err(AccessedInvalidId(id)), - }; - - std::future::ready(result) - } -} - -impl diskann::provider::CacheableAccessor for FullAccessor<'_> { - type Map = diskann_utils::lifetime::Slice; - - fn as_cached<'a, 'b>(element: &'a &'b [f32]) -> &'a &'b [f32] - where - Self: 'a + 'b, - { - element - } - - fn from_cached<'a>(element: &'a [f32]) -> &'a [f32] - where - Self: 'a, - { - element - } -} - -impl SearchExt for FullAccessor<'_> { - fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(vec![self.provider.config.start_id]) - } -} - -impl<'a> DelegateNeighbor<'a> for FullAccessor<'_> { - type Delegate = DebugNeighborAccessor<'a>; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - DebugNeighborAccessor::new(self.provider) - } -} - -impl BuildDistanceComputer for FullAccessor<'_> { - type DistanceComputerError = Panics; - type DistanceComputer = ::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(f32::distance( - self.provider.config.metric, - Some(self.provider.dim()), - )) - } -} - -impl BuildQueryComputer<[f32]> for FullAccessor<'_> { - type QueryComputerError = Panics; - type QueryComputer = ::QueryDistance; - - fn build_query_computer( - &self, - from: &[f32], - ) -> Result { - Ok(f32::query_distance(from, self.provider.config.metric)) - } -} - -impl ExpandBeam<[f32]> for FullAccessor<'_> {} -impl FillSet for FullAccessor<'_> {} - -impl postprocess::AsDeletionCheck for FullAccessor<'_> { - type Checker = DebugProvider; - fn as_deletion_check(&self) -> &Self::Checker { - self.provider - } -} - -//----------------// -// Quant Accessor // -//----------------// - -pub struct QuantAccessor<'a> { - provider: &'a DebugProvider, -} - -impl<'a> QuantAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - Self { provider } - } -} - -impl HasId for QuantAccessor<'_> { - type Id = u32; -} - -impl Accessor for QuantAccessor<'_> { - type Extended = Vec; - type Element<'a> - = Vec - where - Self: 'a; - type ElementRef<'a> = &'a [u8]; - - type GetError = AccessedInvalidId; - - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - let result = match self.provider.data().get(&id) { - Some(v) => { - self.provider.quant_reads.increment(); - Ok(v.quant().to_owned()) - } - None => Err(AccessedInvalidId(id)), - }; - - std::future::ready(result) - } -} - -impl SearchExt for QuantAccessor<'_> { - fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(vec![self.provider.config.start_id]) - } -} - -impl<'a> DelegateNeighbor<'a> for QuantAccessor<'_> { - type Delegate = DebugNeighborAccessor<'a>; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - DebugNeighborAccessor::new(self.provider) - } -} - -impl BuildQueryComputer<[f32]> for QuantAccessor<'_> { - type QueryComputerError = Panics; - type QueryComputer = pq::distance::QueryComputer>; - - fn build_query_computer( - &self, - from: &[f32], - ) -> Result { - Ok(QueryComputer::new( - self.provider.pq_table.clone(), - self.provider.config.metric, - from, - None, - ) - .unwrap()) - } -} - -impl ExpandBeam<[f32]> for QuantAccessor<'_> {} - -impl postprocess::AsDeletionCheck for QuantAccessor<'_> { - type Checker = DebugProvider; - fn as_deletion_check(&self) -> &Self::Checker { - self.provider - } -} - -//-----------------// -// Hybrid Accessor // -//-----------------// - -pub struct HybridAccessor<'a> { - provider: &'a DebugProvider, -} - -impl<'a> HybridAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - Self { provider } - } -} - -impl HasId for HybridAccessor<'_> { - type Id = u32; -} - -impl Accessor for HybridAccessor<'_> { - type Extended = Hybrid, Vec>; - type Element<'a> - = Hybrid, Vec> - where - Self: 'a; - type ElementRef<'a> = Hybrid<&'a [f32], &'a [u8]>; - - type GetError = AccessedInvalidId; - - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - let result = match self.provider.data().get(&id) { - Some(v) => { - self.provider.full_reads.increment(); - Ok(Hybrid::Full(v.full().to_owned())) - } - None => Err(AccessedInvalidId(id)), - }; - - std::future::ready(result) - } -} - -impl SearchExt for HybridAccessor<'_> { - fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(vec![self.provider.config.start_id]) - } -} - -impl<'a> DelegateNeighbor<'a> for HybridAccessor<'_> { - type Delegate = DebugNeighborAccessor<'a>; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - DebugNeighborAccessor::new(self.provider) - } -} - -impl BuildDistanceComputer for HybridAccessor<'_> { - type DistanceComputerError = Panics; - type DistanceComputer = distances::pq::HybridComputer; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(distances::pq::HybridComputer::new( - DistanceComputer::new(self.provider.pq_table.clone(), self.provider.config.metric) - .unwrap(), - f32::distance(self.provider.config.metric, Some(self.provider.dim())), - )) - } -} - -impl FillSet for HybridAccessor<'_> { - async fn fill_set( - &mut self, - set: &mut HashMap, - itr: Itr, - ) -> Result<(), Self::GetError> - where - Itr: Iterator + Send + Sync, - { - let threshold = 1; // one full vec per fill - let data = self.provider.data(); - itr.enumerate().for_each(|(i, id)| { - let e = set.entry(id); - if i < threshold { - e.and_modify(|v| { - if !v.is_full() { - let element = data.get(&id).unwrap(); - *v = Hybrid::Full(element.full().to_owned()); - } - }) - .or_insert_with(|| { - let element = data.get(&id).unwrap(); - Hybrid::Full(element.full().to_owned()) - }); - } else { - e.or_insert_with(|| { - let element = data.get(&id).unwrap(); - Hybrid::Quant(element.quant().to_owned()) - }); - } - }); - Ok(()) - } -} - -//////////////// -// Strategies // -//////////////// - -impl SearchStrategy for Internal { - type QueryComputer = ::QueryDistance; - type PostProcessor = postprocess::RemoveDeletedIdsAndCopy; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = FullAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -impl SearchStrategy for FullPrecision { - type QueryComputer = ::QueryDistance; - type PostProcessor = Pipeline; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = FullAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -impl SearchStrategy for Internal { - type QueryComputer = pq::distance::QueryComputer>; - type PostProcessor = postprocess::RemoveDeletedIdsAndCopy; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = QuantAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -impl SearchStrategy for Quantized { - type QueryComputer = pq::distance::QueryComputer>; - type PostProcessor = Pipeline; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = QuantAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -impl PruneStrategy for FullPrecision { - type DistanceComputer = ::Distance; - type PruneAccessor<'a> = FullAccessor<'a>; - type PruneAccessorError = diskann::error::Infallible; - - fn prune_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) - } -} - -impl<'a> AsElement<&'a [f32]> for FullAccessor<'a> { - type Error = Panics; - fn as_element( - &mut self, - vector: &'a [f32], - _id: Self::Id, - ) -> impl Future, Self::Error>> + Send { - std::future::ready(Ok(vector)) - } -} - -impl PruneStrategy for Quantized { - type DistanceComputer = distances::pq::HybridComputer; - type PruneAccessor<'a> = HybridAccessor<'a>; - type PruneAccessorError = diskann::error::Infallible; - - fn prune_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::PruneAccessorError> { - Ok(HybridAccessor::new(provider)) - } -} - -impl<'a> AsElement<&'a [f32]> for HybridAccessor<'a> { - type Error = Panics; - fn as_element( - &mut self, - vector: &'a [f32], - _id: Self::Id, - ) -> impl Future, Self::Error>> + Send { - std::future::ready(Ok(Hybrid::Full(vector.to_vec()))) - } -} - -impl InsertStrategy for FullPrecision { - type PruneStrategy = Self; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn insert_search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - provider.insert_search_accessor_calls.increment(); - self.search_accessor(provider, context) - } -} - -impl InsertStrategy for Quantized { - type PruneStrategy = Self; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn insert_search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - provider.insert_search_accessor_calls.increment(); - self.search_accessor(provider, context) - } -} - -impl InplaceDeleteStrategy for FullPrecision { - type DeleteElement<'a> = [f32]; - type DeleteElementGuard = Vec; - type DeleteElementError = Panics; - type PruneStrategy = Self; - type SearchStrategy = Internal; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn search_strategy(&self) -> Self::SearchStrategy { - Internal(*self) - } - - fn get_delete_element<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - id: ::InternalId, - ) -> impl Future> + Send - { - futures_util::future::ok(provider.data().get(&id).unwrap().full().to_vec()) - } -} - -impl InplaceDeleteStrategy for Quantized { - type DeleteElement<'a> = [f32]; - type DeleteElementGuard = Vec; - type DeleteElementError = Panics; - type PruneStrategy = Self; - type SearchStrategy = Internal; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn search_strategy(&self) -> Self::SearchStrategy { - Internal(*self) - } - - fn get_delete_element<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - id: ::InternalId, - ) -> impl Future> + Send - { - futures_util::future::ok(provider.data().get(&id).unwrap().full().to_vec()) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use diskann::{ - graph::{self, DiskANNIndex}, - provider::{Guard, SetElement}, - utils::async_tools::VectorIdBoxSlice, - }; - use diskann_vector::{PureDistanceFunction, distance::SquaredL2}; - use rstest::rstest; - - use super::*; - use crate::{ - index::diskann_async::{ - tests::{ - GenerateGrid, PagedSearch, check_grid_search, populate_data, populate_graph, squish, - }, - train_pq, - }, - test_utils::groundtruth, - utils, - }; - - #[tokio::test] - async fn basic_operations() { - let dim = 2; - let ctx = &DefaultContext; - - let debug_config = DebugConfig { - start_id: u32::MAX, - start_point: vec![0.0; dim], - max_degree: 10, - metric: Metric::L2, - }; - - let vectors = [vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]]; - let pq_table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut crate::utils::create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, - ) - .unwrap(); - let provider = DebugProvider::new(debug_config, Arc::new(pq_table)).unwrap(); - - provider - .set_element(ctx, &0, &[1.0, 1.0]) - .await - .unwrap() - .complete() - .await; - - // internal id = external id - assert_eq!(provider.to_internal_id(ctx, &0).unwrap(), 0); - assert_eq!(provider.to_external_id(ctx, 0).unwrap(), 0); - - let mut accessor = FullAccessor::new(&provider); - - let res = accessor.get_element(0).await; - assert!(res.is_ok()); - assert_eq!(provider.full_reads.get(), 1); - - let mut neighbors = AdjacencyList::new(); - - let accessor = provider.default_accessor(); - let res = accessor.get_neighbors(0, &mut neighbors).await; - assert!(res.is_ok()); - assert_eq!(provider.neighbor_reads.get(), 1); - - let accessor = provider.default_accessor(); - let res = accessor.set_neighbors(0, &[1, 2, 3]).await; - assert!(res.is_ok()); - assert_eq!(provider.neighbor_writes.get(), 1); - - // delete and release vector 0 - let res = provider.delete(&DefaultContext, &0).await; - assert!(res.is_ok()); - assert_eq!( - ElementStatus::Deleted, - provider - .status_by_external_id(&DefaultContext, &0) - .await - .unwrap() - ); - - let mut accessor = FullAccessor::new(&provider); - let res = accessor.get_element(0).await; - assert!(res.is_ok()); - assert_eq!(provider.full_reads.get(), 2); - - let mut accessor = HybridAccessor::new(&provider); - let res = accessor.get_element(0).await; - assert!(res.is_ok()); - assert_eq!(provider.full_reads.get(), 3); - - // Releasing should make the element unreachable. - let res = provider.release(&DefaultContext, 0).await; - assert!(res.is_ok()); - assert!( - provider - .status_by_external_id(&DefaultContext, &0) - .await - .is_err() - ); - } - - pub fn new_quant_index( - index_config: graph::Config, - debug_config: DebugConfig, - pq_table: FixedChunkPQTable, - ) -> Arc> { - let data = DebugProvider::new(debug_config, Arc::new(pq_table)).unwrap(); - Arc::new(DiskANNIndex::new(index_config, data, None)) - } - - #[rstest] - #[case(1, 100)] - #[case(3, 7)] - #[case(4, 5)] - #[tokio::test] - async fn grid_search(#[case] dim: usize, #[case] grid_size: usize) { - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - let start_id = u32::MAX; - - let index_config = graph::config::Builder::new( - max_degree, - graph::config::MaxDegree::default_slack(), - l, - (Metric::L2).into(), - ) - .build() - .unwrap(); - - let debug_config = DebugConfig { - start_id, - start_point: vec![grid_size as f32; dim], - max_degree, - metric: Metric::L2, - }; - - let adjacency_lists = match dim { - 1 => utils::generate_1d_grid_adj_list(grid_size as u32), - 3 => utils::genererate_3d_grid_adj_list(grid_size as u32), - 4 => utils::generate_4d_grid_adj_list(grid_size as u32), - _ => panic!("Unsupported number of dimensions"), - }; - let mut vectors = f32::generate_grid(dim, grid_size); - - assert_eq!(adjacency_lists.len(), num_points); - assert_eq!(vectors.len(), num_points); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut crate::utils::create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, - ) - .unwrap(); - - let index = new_quant_index(index_config, debug_config, table); - { - let mut neighbor_accessor = index.provider().default_accessor(); - populate_data(index.provider(), &DefaultContext, &vectors).await; - populate_graph(&mut neighbor_accessor, &adjacency_lists).await; - - // Set the adjacency list for the start point. - neighbor_accessor - .set_neighbors(start_id, &[num_points as u32 - 1]) - .await - .unwrap(); - } - - // The corpus of actual vectors consists of all but the last point, which we use - // as the start point. - // - // So, when we compute the corpus used during groundtruth generation, we take all - // but this last point. - let corpus: diskann_utils::views::Matrix = - squish(vectors.iter().take(num_points), dim); - - let mut paged_tests = Vec::new(); - - // Test with the zero query. - let query = vec![0.0; dim]; - let gt = groundtruth(corpus.as_view(), &query, |a, b| SquaredL2::evaluate(a, b)); - paged_tests.push(PagedSearch::new(query, gt)); - - // Test with the start point to ensure it is filtered out. - let query = vectors.last().unwrap(); - let gt = groundtruth(corpus.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); - paged_tests.push(PagedSearch::new(query.clone(), gt)); - - // Unfortunately - this is needed for the `check_grid_search` test. - vectors.push(index.provider().config.start_point.full().to_owned()); - check_grid_search(&index, &vectors, &paged_tests, FullPrecision, Quantized).await; - } - - #[rstest] - #[tokio::test] - async fn grid_search_with_build( - #[values((1, 100), (3, 7), (4, 5))] dim_and_size: (usize, usize), - ) { - let dim = dim_and_size.0; - let grid_size = dim_and_size.1; - let start_id = u32::MAX; - - let l = 10; - - // NOTE: Be careful changing `max_degree`. It needs to be high enough that the - // graph is navigable, but low enough that the batch parallel handling inside - // `multi_insert` is needed for the multi-insert graph to be navigable. - // - // With the current configured values, removing the other elements in the batch - // from the visited set during `multi_insert` results in a graph failure. - let max_degree = 2 * dim; - - let num_points = (grid_size).pow(dim as u32); - - let index_config = graph::config::Builder::new_with( - max_degree, - graph::config::MaxDegree::default_slack(), - l, - (Metric::L2).into(), - |b| { - b.max_minibatch_par(10); - }, - ) - .build() - .unwrap(); - - let debug_config = DebugConfig { - start_id, - start_point: vec![grid_size as f32; dim], - max_degree: index_config.max_degree().into(), - metric: Metric::L2, - }; - - let mut vectors = f32::generate_grid(dim, grid_size); - assert_eq!(vectors.len(), num_points); - - // This is a little subtle, but we need `vectors` to contain the start point as - // its last element, but we **don't** want to include it in the index build. - // - // This basically means that we need to be careful with index initialization. - vectors.push(vec![grid_size as f32; dim]); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut crate::utils::create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, - ) - .unwrap(); - - // Initialize an index for a new round of building. - let init_index = - || new_quant_index(index_config.clone(), debug_config.clone(), table.clone()); - - // Build with full-precision single insert - { - let index = init_index(); - let ctx = DefaultContext; - for (i, v) in vectors.iter().take(num_points).enumerate() { - index - .insert(FullPrecision, &ctx, &(i as u32), v.as_slice()) - .await - .unwrap(); - } - - // Ensure the `insert_search_accessor` API is invoked. - assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - - // Build with quantized single insert - { - let index = init_index(); - let ctx = DefaultContext; - for (i, v) in vectors.iter().take(num_points).enumerate() { - index - .insert(Quantized, &ctx, &(i as u32), v.as_slice()) - .await - .unwrap(); - } - - // Ensure the `insert_search_accessor` API is invoked. - assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - - // Build with full-precision multi-insert - { - let index = init_index(); - let ctx = DefaultContext; - let batch: Box<[_]> = vectors - .iter() - .take(num_points) - .enumerate() - .map(|(id, v)| VectorIdBoxSlice::new(id as u32, v.as_slice().into())) - .collect(); - - index - .multi_insert(FullPrecision, &ctx, batch) - .await - .unwrap(); - - // Ensure the `insert_search_accessor` API is invoked. - assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "multi-insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - - // Build with quantized multi-insert - { - let index = init_index(); - let ctx = DefaultContext; - let batch: Box<[_]> = vectors - .iter() - .take(num_points) - .enumerate() - .map(|(id, v)| VectorIdBoxSlice::new(id as u32, v.as_slice().into())) - .collect(); - - index.multi_insert(Quantized, &ctx, batch).await.unwrap(); - - // Ensure the `insert_search_accessor` API is invoked. - assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "multi-insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - } -} diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index 3d89359e2..cf719e730 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -39,7 +39,3 @@ pub mod bf_tree; // Caching proxy provider to accelerate slow providers. #[cfg(feature = "bf_tree")] pub mod caching; - -// Debug provider for testing. -#[cfg(test)] -pub mod debug_provider; diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 6a1857fac..6f2a00683 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -18,9 +18,10 @@ use thiserror::Error; use crate::{ ANNError, ANNResult, error::{Infallible, message}, - graph::{AdjacencyList, glue, test::synthetic}, + graph::{AdjacencyList, SearchOutputBuffer, glue, test::synthetic}, internal::counter::{Counter, LocalCounter}, - provider, + neighbor::Neighbor, + provider::{self, BuildQueryComputer}, utils::VectorRepr, }; @@ -893,6 +894,11 @@ impl<'a> Accessor<'a> { get_vector: provider.get_vector.local(), } } + + /// Return a reference to the underlying provider. + pub fn provider(&self) -> &Provider { + self.provider + } } impl provider::HasId for Accessor<'_> { @@ -966,6 +972,24 @@ impl glue::SearchExt for Accessor<'_> { impl glue::ExpandBeam<[f32]> for Accessor<'_> {} impl glue::FillSet for Accessor<'_> {} +impl provider::CacheableAccessor for Accessor<'_> { + type Map = diskann_utils::lifetime::Slice; + + fn as_cached<'a, 'b>(element: &'a &'b [f32]) -> &'a &'b [f32] + where + Self: 'a + 'b, + { + element + } + + fn from_cached<'a>(element: &'a [f32]) -> &'a [f32] + where + Self: 'a, + { + element + } +} + #[derive(Debug, Default, Clone, Copy)] pub struct Strategy { _phantom: (), @@ -1066,6 +1090,143 @@ impl glue::InplaceDeleteStrategy for Strategy { } } +//--------------------------------------// +// Deletion-Aware Strategy and PostProc // +//--------------------------------------// + +/// A [`glue::SearchPostProcess`] that filters out deleted IDs before copying results. +/// +/// This is analogous to the `RemoveDeletedIdsAndCopy` post-processor in `diskann-providers`, +/// tailored for the test provider. +#[derive(Debug, Default, Clone, Copy)] +pub struct FilterDeletedCopyIds; + +impl<'a, T> glue::SearchPostProcess, T> for FilterDeletedCopyIds +where + T: ?Sized, + Accessor<'a>: BuildQueryComputer, +{ + type Error = std::convert::Infallible; + + fn post_process( + &self, + accessor: &mut Accessor<'a>, + _query: &T, + _computer: & as BuildQueryComputer>::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl std::future::Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let provider = accessor.provider; + let count = output.extend(candidates.filter_map(|n| { + let is_deleted = provider.is_deleted(n.id).unwrap_or(true); + if is_deleted { + None + } else { + Some((n.id, n.distance)) + } + })); + std::future::ready(Ok(count)) + } +} + +/// A strategy variant that filters deleted IDs during search post-processing. +/// +/// This is needed for `inplace_delete` which relies on the post-processor to exclude +/// deleted items from search results. The base [`Strategy`] uses [`glue::CopyIds`] which +/// does not filter deletions. +#[derive(Debug, Default, Clone, Copy)] +pub struct DeletionAwareStrategy { + _phantom: (), +} + +impl DeletionAwareStrategy { + pub fn new() -> Self { + Self { _phantom: () } + } +} + +impl glue::SearchStrategy for DeletionAwareStrategy { + type QueryComputer = ::QueryDistance; + type PostProcessor = FilterDeletedCopyIds; + type SearchAccessorError = Infallible; + type SearchAccessor<'a> = Accessor<'a>; + + fn search_accessor<'a>( + &'a self, + provider: &'a Provider, + _context: &'a Context, + ) -> Result, Infallible> { + Ok(Accessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + FilterDeletedCopyIds + } +} + +impl glue::PruneStrategy for DeletionAwareStrategy { + type DistanceComputer = ::Distance; + type PruneAccessor<'a> = Accessor<'a>; + type PruneAccessorError = Infallible; + + fn prune_accessor<'a>( + &'a self, + provider: &'a Provider, + _context: &'a Context, + ) -> Result, Self::PruneAccessorError> { + Ok(Accessor::new(provider)) + } +} + +impl glue::InsertStrategy for DeletionAwareStrategy { + type PruneStrategy = Self; + + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } + + fn insert_search_accessor<'a>( + &'a self, + provider: &'a Provider, + _context: &'a Context, + ) -> Result, Self::SearchAccessorError> { + Ok(Accessor::new(provider)) + } +} + +impl glue::InplaceDeleteStrategy for DeletionAwareStrategy { + type DeleteElement<'a> = [f32]; + type DeleteElementGuard = Box<[f32]>; + type DeleteElementError = AccessedInvalidId; + type PruneStrategy = Self; + type SearchStrategy = Self; + + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } + + fn search_strategy(&self) -> Self::SearchStrategy { + *self + } + + async fn get_delete_element<'a>( + &'a self, + provider: &'a Provider, + _context: &'a ::Context, + id: ::InternalId, + ) -> Result { + provider + .terms + .get(&id) + .map(|v| (*v.data).into()) + .ok_or(AccessedInvalidId(id)) + } +} + /////////// // Tests // ///////////