Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions diskann-tools/src/utils/gen_associated_data_from_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,29 @@ use std::io::Write;
use diskann_providers::storage::StorageWriteProvider;
use diskann_utils::io::Metadata;

use super::CMDResult;
use super::{CMDResult, CMDToolError};

pub fn gen_associated_data_from_range<S: StorageWriteProvider>(
storage_provider: &S,
associated_data_path: &str,
start: u32,
end: u32,
) -> CMDResult<()> {
if end < start {
return Err(CMDToolError {
details: format!(
"invalid range: end ({end}) must be greater than or equal to start ({start})"
),
});
}

let mut file = storage_provider.create_for_write(associated_data_path)?;

// Calculate the number of integers in the inclusive range [start, end] and the length of each integer in the associated data
let num_ints = end - start + 1;
// Use checked arithmetic to avoid overflow when end == u32::MAX and start == 0
let num_ints = (end - start).checked_add(1).ok_or_else(|| CMDToolError {
details: format!("range [{start}, {end}] is too large: count overflows u32"),
})?;
let int_length: u32 = 1;

// Write the number of integers and the length of each integer as little endian
Expand Down Expand Up @@ -105,4 +116,37 @@ mod tests {
assert_eq!(actual, expected);
}
}

/// An inverted range (`end` below `start`) must be rejected, and the error
/// text must name both bounds along with the offending values.
#[test]
fn test_gen_associated_data_from_range_end_less_than_start() {
    let provider = VirtualStorageProvider::new_memory();

    // start = 10, end = 5: the range is inverted, so the call has to fail.
    let err = {
        let outcome = gen_associated_data_from_range(
            &provider,
            "/test_gen_associated_data_invalid.bin",
            10,
            5,
        );
        assert!(outcome.is_err());
        outcome.unwrap_err()
    };

    // The message should identify which arguments were at fault...
    let names_both_bounds = ["end", "start"]
        .iter()
        .all(|&word| err.details.contains(word));
    assert!(
        names_both_bounds,
        "error message should mention end and start: {err}"
    );

    // ...and echo the concrete values that were passed in.
    let echoes_both_values = ["5", "10"]
        .iter()
        .all(|&value| err.details.contains(value));
    assert!(
        echoes_both_values,
        "error message should include the specific values 5 and 10: {err}"
    );
}

/// The full `u32` range `[0, u32::MAX]` holds `u32::MAX + 1` values, which
/// cannot be represented as a `u32` count; the call must fail with an error
/// rather than wrap around.
#[test]
fn test_gen_associated_data_from_range_max_overflow() {
    let provider = VirtualStorageProvider::new_memory();

    // start = 0, end = u32::MAX: the element count would be u32::MAX + 1.
    let outcome = gen_associated_data_from_range(
        &provider,
        "/test_gen_associated_data_overflow.bin",
        0,
        u32::MAX,
    );
    assert!(outcome.is_err());
    let err = outcome.unwrap_err();

    // Either wording is acceptable as long as the cause is communicated.
    let explains_cause = ["overflow", "too large"]
        .iter()
        .any(|&phrase| err.details.contains(phrase));
    assert!(
        explains_cause,
        "error message should mention overflow or too large: {err}"
    );
}
}