diff --git a/src/bin/extract.rs b/src/bin/extract.rs index 4b4a107..0157f81 100644 --- a/src/bin/extract.rs +++ b/src/bin/extract.rs @@ -7,7 +7,7 @@ use clap::{Args as ClapArgs, Parser}; use serde::de::DeserializeOwned; -use eggshell::distance::{EGraph, Expr, Label, TreeNode, UnitCost, find_min_struct, find_min_zs}; +use eggshell::distance::{EGraph, Expr, Label, TreeNode, UnitCost, find_min_sampling_zs}; #[derive(Parser)] #[command(about = "Find the closest tree in an e-graph to a reference tree")] @@ -22,7 +22,6 @@ Examples: # With revisits and quiet mode extract graph.json -e '(foo bar)' -r 2 -q ")] -#[expect(clippy::struct_excessive_bools)] struct Args { /// Path to the serialized e-graph JSON file egraph: String, @@ -30,9 +29,13 @@ struct Args { #[command(flatten)] reference: RefSource, - /// Maximum number of times a node may be revisited (for cycles) - #[arg(short = 'r', long, default_value_t = 0)] - max_revisits: usize, + /// Target weight + #[arg(short, long, default_value_t = 0)] + target_weight: usize, + + /// Number of samples + #[arg(short = 'n', long, default_value_t = 10000)] + samples: usize, /// Include the types in the comparison #[arg(short, long)] @@ -41,14 +44,13 @@ struct Args { /// Use raw string labels instead of Rise-typed labels (for regression testing) #[arg(long)] raw_strings: bool, + // /// Use structural distance instead of Zhang-Shasha tree edit distance + // #[arg(short, long)] + // structural: bool, - /// Use structural distance instead of Zhang-Shasha tree edit distance - #[arg(short, long)] - structural: bool, - - /// Ignore the labels when using the structural option - #[arg(short, long, requires_all = ["structural"])] - ignore_labels: bool, + // /// Ignore the labels when using the structural option + // #[arg(short, long, requires_all = ["structural"])] + // ignore_labels: bool, } #[derive(ClapArgs)] @@ -94,6 +96,16 @@ where println!(" Root e-class: {:?}", graph.root()); + let ref_tree = parse_ref(args, parse_tree); + + 
run_extraction(&graph, &ref_tree, args); +} + +fn parse_ref(args: &Args, parse_tree: F) -> TreeNode +where + L: Label + std::fmt::Display + DeserializeOwned, + F: Fn(&str) -> TreeNode, +{ let ref_tree: TreeNode = if let Some(expr) = &args.reference.expr { println!("Parsing reference tree from command line..."); parse_tree(expr) @@ -116,50 +128,25 @@ where }) .unwrap_or_else(|| panic!("No tree with name {name} found")) }; - - run_extraction(&graph, &ref_tree, args); + ref_tree } -fn run_extraction( - graph: &EGraph, - ref_tree: &TreeNode, - args: &Args, -) { +#[expect(clippy::cast_precision_loss)] +fn run_extraction(graph: &EGraph, ref_tree: &TreeNode, args: &Args) { let ref_node_count = ref_tree.size(); - println!(" Reference tree has {ref_node_count} nodes"); - - // Count trees in the e-graph - println!( - "\nCounting trees in e-graph (max_revisits={})...", - args.max_revisits - ); - let count_start = Instant::now(); - let tree_count = graph.count_trees(args.max_revisits); - let count_time = count_start.elapsed(); - println!(" Found {tree_count} trees in {count_time:.2?}"); - - if tree_count == 0 { - println!("No trees found in e-graph!"); - return; - } + let ref_stripped_count = ref_tree.strip_types().size(); + println!(" Reference tree has {ref_node_count} nodes ({ref_stripped_count} without types)"); - if args.structural { - run_structural(graph, ref_tree, args); - } else { - run_zs(graph, ref_tree, args); - } -} - -#[expect(clippy::cast_precision_loss)] -fn run_zs(graph: &EGraph, ref_tree: &TreeNode, args: &Args) { let start = Instant::now(); println!("\n--- Zhang-Shasha extraction (with lower-bound pruning) ---"); - if let (Some(result), stats) = find_min_zs( + if let (Some(result), stats) = find_min_sampling_zs( graph, ref_tree, &UnitCost, - args.max_revisits, args.with_types, + args.samples, + args.target_weight, + 100, ) { println!(" Best distance: {}", result.1); println!(" Time: {:.2?}", start.elapsed()); @@ -186,23 +173,3 @@ fn run_zs(graph: &EGraph, 
ref_tree: &TreeNode, args: &Args) { println!(" No result found!"); } } - -fn run_structural(graph: &EGraph, ref_tree: &TreeNode, args: &Args) { - let start = Instant::now(); - println!("\n--- Structural distance extraction ---"); - if let Some((tree, distance)) = find_min_struct( - graph, - ref_tree, - &UnitCost, - args.max_revisits, - args.with_types, - args.ignore_labels, - ) { - println!(" Best distance: {distance}"); - println!(" Time: {:.2?}", start.elapsed()); - println!("\n Best tree:"); - println!("{tree}"); - } else { - println!(" No result found!"); - } -} diff --git a/src/distance/count.rs b/src/distance/count.rs new file mode 100644 index 0000000..46625a9 --- /dev/null +++ b/src/distance/count.rs @@ -0,0 +1,607 @@ +//! Term counting analysis for e-graphs. +//! +//! Counts the number of terms up to a given size that can be extracted from each e-class. + +use std::sync::{Arc, Mutex, RwLock}; +use std::thread; + +use dashmap::DashMap; +use hashbrown::{HashMap, HashSet}; +use log::debug; +use num_traits::{NumAssignRef, NumRef}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; + +use super::graph::EGraph; +use super::ids::{DataChildId, DataId, EClassId, ExprChildId, FunId, NatId, TypeChildId}; +use super::nodes::Label; +use crate::utils::UniqueQueue; + +/// Counter trait for counting terms. +pub trait Counter: Clone + Send + Sync + NumRef + NumAssignRef + Default + std::fmt::Debug {} + +impl Counter for T {} + +/// Configuration for the term counting analysis. +#[derive(Debug, Clone)] +pub struct TermCount { + limit: usize, + with_types: bool, +} + +impl TermCount { + /// Create a new term counting configuration. + /// + /// # Arguments + /// * `limit` - Maximum term size to count + /// * `with_types` - If true, include type annotations in size calculations + #[must_use] + pub fn new(limit: usize, with_types: bool) -> Self { + Self { limit, with_types } + } + + /// Run the term counting analysis on an e-graph. 
+ /// + /// Returns a map from e-class ID to a map of (size -> count). + /// + /// # Panics + /// Panics if threads fail to join (should not happen in practice). + #[must_use] + pub fn analyze( + &self, + egraph: &EGraph, + ) -> HashMap> { + // Build parent map and type size cache + let parents = build_parent_map(egraph); + let type_cache = Arc::new(RwLock::new(TypeSizeCache::default())); + + // Find leaf classes (classes with at least one leaf node) + let leaves: UniqueQueue = egraph + .class_ids() + .filter(|&id| { + egraph + .class(id) + .nodes() + .iter() + .any(|n| n.children().is_empty()) + }) + .collect(); + + let analysis_pending = Arc::new(Mutex::new(leaves)); + let data: Arc>> = Arc::new(DashMap::new()); + + // Run parallel analysis + let result_data = thread::scope(|scope| { + for i in 0..thread::available_parallelism().map_or(1, |p| p.get()) { + let thread_data = data.clone(); + let thread_pending = analysis_pending.clone(); + let thread_type_cache = type_cache.clone(); + let thread_parents = &parents; + + scope.spawn(move || { + debug!("Thread #{i} started!"); + self.resolve_pending_analysis( + egraph, + &thread_data, + &thread_pending, + &thread_type_cache, + thread_parents, + ); + debug!("Thread #{i} finished!"); + }); + } + data + }); + + Arc::into_inner(result_data) + .unwrap() + .into_par_iter() + .collect() + } + + /// Process pending e-classes from the work queue. 
+ fn resolve_pending_analysis( + &self, + egraph: &EGraph, + data: &Arc>>, + analysis_pending: &Arc>>, + type_cache: &Arc>, + parents: &HashMap>, + ) { + // Mirrors the structure of CommutativeSemigroupAnalysis::one_shot_analysis + while let Some(id) = { analysis_pending.lock().unwrap().pop() } { + let canonical_id = egraph.canonicalize(id); + let eclass = egraph.class(canonical_id); + + // Get the type overhead for this e-class + let type_overhead = if self.with_types { + let ty_size = TypeSizeCache::get_type_size(type_cache, egraph, eclass.ty()); + 1 + ty_size + } else { + 0 + }; + + // Check if we can calculate the analysis for any enode + let available_data = eclass.nodes().iter().filter_map(|node| { + // If all the eclass children have data, we can calculate it + let all_ready = node.children().iter().all(|child_id| match child_id { + ExprChildId::Nat(_) | ExprChildId::Data(_) => true, + ExprChildId::EClass(eclass_id) => { + data.contains_key(&egraph.canonicalize(*eclass_id)) + } + }); + all_ready.then(|| { + self.make_node_data(egraph, node.children(), data, type_cache, type_overhead) + }) + }); + + // If we have some info, we add that info to our storage. + // Otherwise put back onto the queue. + if let Some(computed_data) = available_data.reduce(|mut a, b| { + Self::merge(&mut a, b); + a + }) { + // If we have gained new information, put the parents onto the queue. + // Only once we have reached a fixpoint we can stop updating the parents. + if !(data.get(&canonical_id).is_some_and(|v| *v == computed_data)) { + if let Some(parent_set) = parents.get(&canonical_id) { + analysis_pending + .lock() + .unwrap() + .extend(parent_set.iter().copied()); + } + data.insert(canonical_id, computed_data); + } + } else { + assert!(!eclass.nodes().is_empty()); + analysis_pending.lock().unwrap().insert(canonical_id); + } + } + } + + /// Merge two term count data maps. 
+ fn merge(a: &mut HashMap, b: HashMap) { + for (size, count) in b { + a.entry(size).and_modify(|c| *c += &count).or_insert(count); + } + } + + /// Compute term counts for a single e-node. + fn make_node_data( + &self, + egraph: &EGraph, + children: &[ExprChildId], + data: &Arc>>, + type_cache: &Arc>, + type_overhead: usize, + ) -> HashMap { + // Base size: 1 for the node itself + type overhead + let base_size = 1 + type_overhead; + + if children.is_empty() { + // Leaf node + if base_size <= self.limit { + return HashMap::from([(base_size, C::one())]); + } + return HashMap::new(); + } + + // For nodes with children, combine their counts + let mut tmp = Vec::new(); + + children.iter().fold( + HashMap::from([(base_size, C::one())]), + |mut acc, child_id| { + let child_data = Self::get_child_data::(egraph, *child_id, data, type_cache); + + tmp.extend(acc.drain()); + + for (acc_size, acc_count) in &tmp { + for (child_size, child_count) in &child_data { + let combined_size = acc_size + child_size; + if combined_size > self.limit { + continue; + } + let combined_count = acc_count.to_owned() * child_count; + acc.entry(combined_size) + .and_modify(|c| *c += &combined_count) + .or_insert(combined_count); + } + } + + tmp.clear(); + acc + }, + ) + } + + /// Get the count data for a child, handling Nat/Data/EClass variants. 
+ fn get_child_data( + egraph: &EGraph, + child_id: ExprChildId, + data: &Arc>>, + type_cache: &Arc>, + ) -> HashMap { + match child_id { + ExprChildId::Nat(nat_id) => { + // Nat nodes have a fixed size (no choices) + let size = TypeSizeCache::get_nat_size(type_cache, egraph, nat_id); + let mut result = HashMap::new(); + result.insert(size, C::one()); + result + } + ExprChildId::Data(data_id) => { + // Data type nodes have a fixed size (no choices) + let size = TypeSizeCache::get_data_size(type_cache, egraph, data_id); + let mut result = HashMap::new(); + result.insert(size, C::one()); + result + } + ExprChildId::EClass(eclass_id) => { + // E-class children use the precomputed data + let canonical_id = egraph.canonicalize(eclass_id); + data.get(&canonical_id) + .map(|r| r.clone()) + .unwrap_or_default() + } + } + } +} + +/// Build a map from child e-class to parent e-classes. +fn build_parent_map(egraph: &EGraph) -> HashMap> { + let mut parents: HashMap> = HashMap::new(); + + for class_id in egraph.class_ids() { + let canonical_id = egraph.canonicalize(class_id); + for node in egraph.class(canonical_id).nodes() { + for child_id in node.children() { + if let ExprChildId::EClass(child_eclass_id) = child_id { + let canonical_child = egraph.canonicalize(*child_eclass_id); + parents + .entry(canonical_child) + .or_default() + .insert(canonical_id); + } + } + } + } + + parents +} + +/// Cache for type node sizes to avoid repeated computation. +#[derive(Debug, Default)] +struct TypeSizeCache { + nats: HashMap, + data: HashMap, + funs: HashMap, +} + +impl TypeSizeCache { + /// Get the size of a type (`TypeChildId`), dispatching to the appropriate cache. 
+ fn get_type_size(cache: &RwLock, egraph: &EGraph, id: TypeChildId) -> usize { + match id { + TypeChildId::Nat(nat_id) => Self::get_nat_size(cache, egraph, nat_id), + TypeChildId::Type(fun_id) => Self::get_fun_size(cache, egraph, fun_id), + TypeChildId::Data(data_id) => Self::get_data_size(cache, egraph, data_id), + } + } + + /// Get the size of a nat node, using cache with read-preferring access. + fn get_nat_size(cache: &RwLock, egraph: &EGraph, id: NatId) -> usize { + // Try read lock first (fast path for cache hits) + if let Some(&size) = cache.read().unwrap().nats.get(&id) { + return size; + } + + // Cache miss: compute and insert with write lock + let size = Self::compute_nat_size(cache, egraph, id); + cache.write().unwrap().nats.insert(id, size); + size + } + + /// Get the size of a data type node, using cache with read-preferring access. + fn get_data_size(cache: &RwLock, egraph: &EGraph, id: DataId) -> usize { + // Try read lock first (fast path for cache hits) + if let Some(&size) = cache.read().unwrap().data.get(&id) { + return size; + } + + // Cache miss: compute and insert with write lock + let size = Self::compute_data_size(cache, egraph, id); + cache.write().unwrap().data.insert(id, size); + size + } + + /// Get the size of a fun type node, using cache with read-preferring access. 
+ fn get_fun_size(cache: &RwLock, egraph: &EGraph, id: FunId) -> usize { + // Try read lock first (fast path for cache hits) + if let Some(&size) = cache.read().unwrap().funs.get(&id) { + return size; + } + + // Cache miss: compute and insert with write lock + let size = Self::compute_fun_size(cache, egraph, id); + cache.write().unwrap().funs.insert(id, size); + size + } + + fn compute_nat_size(cache: &RwLock, egraph: &EGraph, id: NatId) -> usize { + let node = egraph.nat(id); + let children_size: usize = node + .children() + .iter() + .map(|&child_id| Self::get_nat_size(cache, egraph, child_id)) + .sum(); + 1 + children_size + } + + fn compute_data_size(cache: &RwLock, egraph: &EGraph, id: DataId) -> usize { + let node = egraph.data_ty(id); + let children_size: usize = node + .children() + .iter() + .map(|&child_id| match child_id { + DataChildId::Nat(nat_id) => Self::get_nat_size(cache, egraph, nat_id), + DataChildId::DataType(data_id) => Self::get_data_size(cache, egraph, data_id), + }) + .sum(); + 1 + children_size + } + + fn compute_fun_size(cache: &RwLock, egraph: &EGraph, id: FunId) -> usize { + let node = egraph.fun_ty(id); + let children_size: usize = node + .children() + .iter() + .map(|&child_id| Self::get_type_size(cache, egraph, child_id)) + .sum(); + 1 + children_size + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::graph::EClass; + use crate::distance::nodes::{ENode, NatNode}; + use num::BigUint; + + fn eid(i: usize) -> ExprChildId { + ExprChildId::EClass(EClassId::new(i)) + } + + fn dummy_ty() -> TypeChildId { + TypeChildId::Nat(NatId::new(0)) + } + + fn dummy_nat_nodes() -> HashMap> { + let mut nats = HashMap::new(); + nats.insert(NatId::new(0), NatNode::leaf("0".to_owned())); + nats + } + + fn cfv(classes: Vec>) -> HashMap> { + classes + .into_iter() + .enumerate() + .map(|(i, c)| (EClassId::new(i), c)) + .collect() + } + + #[test] + fn single_leaf_no_types() { + let graph = EGraph::new( + cfv(vec![EClass::new( + 
vec![ENode::leaf("a".to_owned())], + dummy_ty(), + )]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let counter = TermCount::new(10, false); + let data: HashMap> = counter.analyze(&graph); + + let root_data = &data[&EClassId::new(0)]; + assert_eq!(root_data.len(), 1); + assert_eq!(root_data[&1], BigUint::from(1u32)); + } + + #[test] + fn single_leaf_with_types() { + let graph = EGraph::new( + cfv(vec![EClass::new( + vec![ENode::leaf("a".to_owned())], + dummy_ty(), + )]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let counter = TermCount::new(10, true); + let data: HashMap> = counter.analyze(&graph); + + let root_data = &data[&EClassId::new(0)]; + // Size = 1 (node) + 1 (typeOf) + 1 (type "0") = 3 + assert_eq!(root_data.len(), 1); + assert_eq!(root_data[&3], BigUint::from(1u32)); + } + + #[test] + fn two_choices_no_types() { + let graph = EGraph::new( + cfv(vec![EClass::new( + vec![ENode::leaf("a".to_owned()), ENode::leaf("b".to_owned())], + dummy_ty(), + )]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let counter = TermCount::new(10, false); + let data: HashMap> = counter.analyze(&graph); + + let root_data = &data[&EClassId::new(0)]; + // Two terms of size 1 + assert_eq!(root_data[&1], BigUint::from(2u32)); + } + + #[test] + fn parent_child_no_types() { + // Class 0: has node "f" pointing to class 1 + // Class 1: has leaf "a" + let graph = EGraph::new( + cfv(vec![ + EClass::new(vec![ENode::new("f".to_owned(), vec![eid(1)])], dummy_ty()), + EClass::new(vec![ENode::leaf("a".to_owned())], dummy_ty()), + ]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let counter = TermCount::new(10, false); + let data: HashMap> = counter.analyze(&graph); + + // Class 1: one term of size 1 + assert_eq!(data[&EClassId::new(1)][&1], BigUint::from(1u32)); + + // 
Class 0: one term of size 2 (f + a) + assert_eq!(data[&EClassId::new(0)][&2], BigUint::from(1u32)); + } + + #[test] + fn parent_with_multiple_child_choices() { + // Class 0: has node "f" pointing to class 1 + // Class 1: has two leaves "a" and "b" + let graph = EGraph::new( + cfv(vec![ + EClass::new(vec![ENode::new("f".to_owned(), vec![eid(1)])], dummy_ty()), + EClass::new( + vec![ENode::leaf("a".to_owned()), ENode::leaf("b".to_owned())], + dummy_ty(), + ), + ]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let counter = TermCount::new(10, false); + let data: HashMap> = counter.analyze(&graph); + + // Class 1: two terms of size 1 + assert_eq!(data[&EClassId::new(1)][&1], BigUint::from(2u32)); + + // Class 0: two terms of size 2 (f(a), f(b)) + assert_eq!(data[&EClassId::new(0)][&2], BigUint::from(2u32)); + } + + #[test] + fn two_children() { + // Class 0: has node "f" pointing to classes 1 and 2 + // Class 1: leaf "a" + // Class 2: leaf "b" + let graph = EGraph::new( + cfv(vec![ + EClass::new( + vec![ENode::new("f".to_owned(), vec![eid(1), eid(2)])], + dummy_ty(), + ), + EClass::new(vec![ENode::leaf("a".to_owned())], dummy_ty()), + EClass::new(vec![ENode::leaf("b".to_owned())], dummy_ty()), + ]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let counter = TermCount::new(10, false); + let data: HashMap> = counter.analyze(&graph); + + // Class 0: one term of size 3 (f + a + b) + assert_eq!(data[&EClassId::new(0)][&3], BigUint::from(1u32)); + } + + #[test] + fn combinatorial_explosion() { + // Class 0: has node "f" pointing to classes 1 and 2 + // Class 1: two leaves "a1", "a2" + // Class 2: three leaves "b1", "b2", "b3" + let graph = EGraph::new( + cfv(vec![ + EClass::new( + vec![ENode::new("f".to_owned(), vec![eid(1), eid(2)])], + dummy_ty(), + ), + EClass::new( + vec![ENode::leaf("a1".to_owned()), ENode::leaf("a2".to_owned())], + dummy_ty(), + ), + 
EClass::new( + vec![ + ENode::leaf("b1".to_owned()), + ENode::leaf("b2".to_owned()), + ENode::leaf("b3".to_owned()), + ], + dummy_ty(), + ), + ]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let counter = TermCount::new(10, false); + let data: HashMap> = counter.analyze(&graph); + + // Class 0: 2 * 3 = 6 terms of size 3 + assert_eq!(data[&EClassId::new(0)][&3], BigUint::from(6u32)); + } + + #[test] + fn size_limit_filters() { + // Class 0: has node "f" pointing to class 1 + // Class 1: leaf "a" + let graph = EGraph::new( + cfv(vec![ + EClass::new(vec![ENode::new("f".to_owned(), vec![eid(1)])], dummy_ty()), + EClass::new(vec![ENode::leaf("a".to_owned())], dummy_ty()), + ]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + // Limit = 1, so f(a) with size 2 should be filtered out + let counter = TermCount::new(1, false); + let data: HashMap> = counter.analyze(&graph); + + // Class 1 should have data (size 1) + assert!(data.contains_key(&EClassId::new(1))); + assert_eq!(data[&EClassId::new(1)][&1], BigUint::from(1u32)); + + // Class 0 should be empty (size 2 exceeds limit) + assert!(data.get(&EClassId::new(0)).is_none_or(|d| d.is_empty())); + } +} diff --git a/src/distance/extract.rs b/src/distance/extract.rs new file mode 100644 index 0000000..6fb0c2f --- /dev/null +++ b/src/distance/extract.rs @@ -0,0 +1,288 @@ +//! Tree extraction for E-Graphs. +//! +//! This module provides iterators for enumerating trees from an e-graph. + +use hashbrown::HashMap; + +use super::graph::EGraph; +use super::ids::{EClassId, ExprChildId}; +use super::nodes::Label; + +/// Iterator that yields choice vectors without materializing trees. +/// Each choice vector can later be used with `tree_from_choices` to get the actual tree. 
+#[derive(Debug)] +pub struct ChoiceIter<'a, L: Label> { + choices: Vec, + path: PathTracker, + egraph: &'a EGraph, +} + +impl<'a, L: Label> ChoiceIter<'a, L> { + #[must_use] + pub fn new(egraph: &'a EGraph, max_revisits: usize) -> Self { + Self { + choices: Vec::new(), + path: PathTracker::new(max_revisits), + egraph, + } + } + + /// Find the next valid choice vector, modifying `choices` in place. + /// + /// If `choices` is empty or shorter than needed, finds the first valid tree. + /// If `choices` already represents a tree and `advance` is true, finds the + /// lexicographically next one. + /// + /// Returns `Some(last_idx)` on success, `None` if no more trees exist. + fn next_choices(&mut self, id: EClassId, choice_idx: usize, advance: bool) -> Option { + if !self.path.can_visit(id) { + return None; + } + + let class = self.egraph.class(id); + + // Determine starting node and whether to advance children + let (start_node, advance_children) = if let Some(&c) = self.choices.get(choice_idx) { + (c, advance) + } else { + self.choices.push(0); + (0, false) + }; + + self.path.enter(id); + + let result = + class + .nodes() + .iter() + .enumerate() + .skip(start_node) + .find_map(|(node_idx, node)| { + self.choices[choice_idx] = node_idx; + let should_advance = advance_children && node_idx == start_node; + + self.next_choices_children(node.children(), choice_idx, should_advance) + .or_else(|| { + self.choices.truncate(choice_idx + 1); + None + }) + }); + + self.path.leave(id); + result + } + + /// Process children, optionally advancing to find the next combination. 
+ fn next_choices_children( + &mut self, + children: &[ExprChildId], + parent_idx: usize, + advance: bool, + ) -> Option { + let eclass_children: Vec<_> = children + .iter() + .filter_map(|c| match c { + ExprChildId::EClass(id) => Some(*id), + _ => None, + }) + .collect(); + + match (eclass_children.is_empty(), advance) { + (true, true) => None, // No children to advance + (true, false) => Some(parent_idx), // Leaf node, nothing to do + (false, false) => eclass_children + .iter() + .try_fold(parent_idx, |curr_idx, &child_id| { + self.next_choices(child_id, curr_idx + 1, false) + }), + (false, true) => self.advance_children(&eclass_children, parent_idx), + } + } + + /// Advance to the next combination by trying to advance rightmost child first. + fn advance_children(&mut self, children: &[EClassId], parent_idx: usize) -> Option { + // Try advancing each child from right to left + (0..children.len()).rev().find_map(|advance_idx| { + // Rebuild prefix (children before advance_idx) + let prefix_idx = children[..advance_idx] + .iter() + .try_fold(parent_idx, |curr_idx, &child_id| { + self.next_choices(child_id, curr_idx + 1, false) + })?; + + // Try to advance child at advance_idx + let advanced_idx = self.next_choices(children[advance_idx], prefix_idx + 1, true)?; + + // Rebuild suffix (children after advance_idx) + children[advance_idx + 1..] + .iter() + .try_fold(advanced_idx, |curr_idx, &child_id| { + self.next_choices(child_id, curr_idx + 1, false) + }) + }) + } +} + +impl Iterator for ChoiceIter<'_, L> { + type Item = Vec; + + fn next(&mut self) -> Option { + // On first call, choices is empty, so advance=false finds the first tree. + // On subsequent calls, choices contains the previous result, so advance=true + // finds the next tree. + let advance = !self.choices.is_empty(); + let root = self.egraph.root(); + + self.next_choices(root, 0, advance)?; + + Some(self.choices.clone()) + } +} + +/// Count the number of trees in an e-graph with the given revisit limit. 
+#[must_use] +pub fn count_trees(egraph: &EGraph, max_revisits: usize) -> usize { + let mut path = PathTracker::new(max_revisits); + count_trees_rec(egraph, egraph.root(), &mut path) +} + +fn count_trees_rec(egraph: &EGraph, id: EClassId, path: &mut PathTracker) -> usize { + // Cycle detection + if !path.can_visit(id) { + return 0; + } + + path.enter(id); + let count = egraph + .class(id) + .nodes() + .iter() + .map(|node| { + node.children() + .iter() + .map(|child_id| { + if let ExprChildId::EClass(inner_id) = child_id { + count_trees_rec(egraph, *inner_id, path) + } else { + 1 + } + }) + .product::() // product for children (and-choices) + }) + .sum::(); // sum for nodes (or-choices) + path.leave(id); + count +} + +/// Path tracker for cycle detection in the `EGraph`. +/// Tracks how many times each class has been visited on the current path +/// and allows configurable revisit limits. +#[derive(Debug, Clone)] +struct PathTracker { + /// Visit counts for classes on the current path + visits: HashMap, + /// Maximum number of times any node may be revisited (0 = no revisits allowed) + max_revisits: usize, +} + +impl PathTracker { + fn new(max_revisits: usize) -> Self { + PathTracker { + visits: HashMap::new(), + max_revisits, + } + } + + /// Check if visiting this OR node would exceed the revisit limit. + /// Returns true if the visit is allowed. + fn can_visit(&self, id: EClassId) -> bool { + let count = self.visits.get(&id).copied().unwrap_or(0); + count <= self.max_revisits + } + + /// Mark an OR node as visited on the current path. + fn enter(&mut self, id: EClassId) { + *self.visits.entry(id).or_insert(0) += 1; + } + + /// Unmark an OR node when leaving the current path. 
+ fn leave(&mut self, id: EClassId) { + if let Some(count) = self.visits.get_mut(&id) { + *count = count.saturating_sub(1); + if *count == 0 { + self.visits.remove(&id); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::graph::EClass; + use crate::distance::ids::{ExprChildId, NatId, TypeChildId}; + use crate::distance::nodes::{ENode, NatNode}; + + fn eid(i: usize) -> ExprChildId { + ExprChildId::EClass(EClassId::new(i)) + } + + fn dummy_ty() -> TypeChildId { + TypeChildId::Nat(NatId::new(0)) + } + + fn dummy_nat_nodes() -> HashMap> { + let mut nats = HashMap::new(); + nats.insert(NatId::new(0), NatNode::leaf("0".to_owned())); + nats + } + + fn cfv(classes: Vec>) -> HashMap> { + classes + .into_iter() + .enumerate() + .map(|(i, c)| (EClassId::new(i), c)) + .collect() + } + + #[test] + fn choice_iter_enumerates_all_trees_diamond_cycle() { + let graph = EGraph::new( + cfv(vec![ + EClass::new( + vec![ENode::new("a".to_owned(), vec![eid(1), eid(2)])], + dummy_ty(), + ), + EClass::new(vec![ENode::new("b".to_owned(), vec![eid(3)])], dummy_ty()), + EClass::new(vec![ENode::new("c".to_owned(), vec![eid(3)])], dummy_ty()), + EClass::new( + vec![ + ENode::new("rec".to_owned(), vec![eid(3)]), + ENode::leaf("d".to_owned()), + ], + dummy_ty(), + ), + ]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + assert_eq!(graph.choice_iter(1).count(), 4); + assert_eq!(graph.count_trees(1), graph.choice_iter(1).count()); + + let trees = graph + .choice_iter(1) + .map(|c| graph.tree_from_choices(graph.root(), &c, false).to_string()) + .collect::>(); + assert!(trees.contains(&"(a (b d) (c d))".to_owned())); + assert!(trees.contains(&"(a (b d) (c (rec d)))".to_owned())); + assert!(trees.contains(&"(a (b (rec d)) (c d))".to_owned())); + assert!(trees.contains(&"(a (b (rec d)) (c (rec d)))".to_owned())); + + assert_eq!(graph.choice_iter(0).count(), 1); + assert_eq!(graph.count_trees(0), 
graph.choice_iter(0).count()); + } +} diff --git a/src/distance/graph.rs b/src/distance/graph.rs index 4edb663..b77ccbd 100644 --- a/src/distance/graph.rs +++ b/src/distance/graph.rs @@ -1,31 +1,22 @@ -//! `EGraph` Extension for Zhang-Shasha Tree Edit Distance +//! E-Graph data structure with tree extraction support. //! -//! Finds the solution tree in a bounded `EGraph` with minimum edit distance -//! to a target tree. Assumes bounded maximum number of nodes in an `EClass` (N) and bounded depth (d). -//! -//! With strict alternation (`EClass` -> `ENode` -> `EClass` ->...), -//! complexity is O(N^(d/2) * |T|^2) for single-path graphs +//! This module provides the core `EGraph` and `EClass` types for representing +//! equivalence graphs with type annotations. use std::fs::File; use std::io::BufReader; use std::path::Path; -use std::sync::atomic::{AtomicUsize, Ordering}; use hashbrown::HashMap; -use indicatif::ParallelProgressIterator; -use rayon::iter::{ParallelBridge, ParallelIterator}; use serde::{Deserialize, Serialize}; -use crate::distance::structural::structural_diff; - +use super::extract::ChoiceIter; use super::ids::{ DataId, EClassId, ExprChildId, FunId, NatId, NumericId, TypeChildId, eclass_id_vec, numeric_key_map, }; use super::nodes::{DataTyNode, ENode, FunTyNode, Label, NatNode}; -use super::str::EulerString; use super::tree::TreeNode; -use super::zs::{EditCosts, PreprocessedTree, tree_distance_with_ref}; /// `EClass`: choose exactly one child (`ENode`) /// Children are `ENode` instances directly @@ -159,131 +150,9 @@ impl EGraph { &self.nat_nodes[&id] } - /// Find the next valid choice vector, modifying `choices` in place. - /// - /// If `choices` is empty or shorter than needed, finds the first valid tree. - /// If `choices` already represents a tree and `advance` is true, finds the - /// lexicographically next one. - /// - /// Returns `Some(last_idx)` on success, `None` if no more trees exist. 
- fn next_choices( - &self, - id: EClassId, - choice_idx: usize, - choices: &mut Vec, - path: &mut PathTracker, - advance: bool, - ) -> Option { - if !path.can_visit(id) { - return None; - } - - let class = self.class(id); - - // Determine starting node and whether to advance children - let (start_node, advance_children) = if let Some(&c) = choices.get(choice_idx) { - (c, advance) - } else { - choices.push(0); - (0, false) - }; - - path.enter(id); - - let result = - class - .nodes() - .iter() - .enumerate() - .skip(start_node) - .find_map(|(node_idx, node)| { - choices[choice_idx] = node_idx; - let should_advance = advance_children && node_idx == start_node; - - self.next_choices_children( - node.children(), - choice_idx, - choices, - path, - should_advance, - ) - .or_else(|| { - choices.truncate(choice_idx + 1); - None - }) - }); - - path.leave(id); - result - } - - /// Process children, optionally advancing to find the next combination. - fn next_choices_children( - &self, - children: &[ExprChildId], - parent_idx: usize, - choices: &mut Vec, - path: &mut PathTracker, - advance: bool, - ) -> Option { - let eclass_children = children - .iter() - .filter_map(|c| match c { - ExprChildId::EClass(id) => Some(*id), - _ => None, - }) - .collect::>(); - - match (eclass_children.is_empty(), advance) { - (true, true) => None, // No children to advance - (true, false) => Some(parent_idx), // Leaf node, nothing to do - (false, false) => self.first_children(&eclass_children, parent_idx, choices, path), - (false, true) => self.advance_children(&eclass_children, parent_idx, choices, path), - } - } - - /// Find the first valid combination for all children using fold. 
- fn first_children( - &self, - children: &[EClassId], - parent_idx: usize, - choices: &mut Vec, - path: &mut PathTracker, - ) -> Option { - children.iter().try_fold(parent_idx, |curr_idx, &child_id| { - self.next_choices(child_id, curr_idx + 1, choices, path, false) - }) - } - - /// Advance to the next combination by trying to advance rightmost child first. - fn advance_children( - &self, - children: &[EClassId], - parent_idx: usize, - choices: &mut Vec, - path: &mut PathTracker, - ) -> Option { - // Try advancing each child from right to left - (0..children.len()).rev().find_map(|advance_idx| { - // Rebuild prefix (children before advance_idx) - let prefix_idx = - children[..advance_idx] - .iter() - .try_fold(parent_idx, |curr_idx, &child_id| { - self.next_choices(child_id, curr_idx + 1, choices, path, false) - })?; - - // Try to advance child at advance_idx - let advanced_idx = - self.next_choices(children[advance_idx], prefix_idx + 1, choices, path, true)?; - - // Rebuild suffix (children after advance_idx) - children[advance_idx + 1..] - .iter() - .try_fold(advanced_idx, |curr_idx, &child_id| { - self.next_choices(child_id, curr_idx + 1, choices, path, false) - }) - }) + /// Returns an iterator over all e-class IDs in the graph. 
+ pub fn class_ids(&self) -> impl Iterator + '_ { + self.classes.keys().map(|k| self.canonicalize(*k)) } #[must_use] @@ -333,340 +202,24 @@ impl EGraph { (eclass_tree, curr_idx) } - #[must_use] - pub fn enumerate_trees(&self, max_revisits: usize, with_type: bool) -> Vec> { - TreeIter::new(self, max_revisits, with_type).collect() - } - #[must_use] pub fn count_trees(&self, max_revisits: usize) -> usize { - let mut p = PathTracker::new(max_revisits); - self.count_rec(self.root(), &mut p) - } - - fn count_rec(&self, id: EClassId, path: &mut PathTracker) -> usize { - // Cycle detection - if !path.can_visit(id) { - return 0; - } - - path.enter(id); - let c = self - .class(id) - .nodes() - .iter() - .map(|node| { - node.children() - .iter() - .map(|child_id| { - if let ExprChildId::EClass(inner_id) = child_id { - self.count_rec(*inner_id, path) - } else { - 1 - } - }) - .product::() // product for children (and-choices) - }) - .sum::(); // sum for nodes (or-choices) - path.leave(id); - c + super::extract::count_trees(self, max_revisits) } #[must_use] pub fn choice_iter(&self, max_revisits: usize) -> ChoiceIter<'_, L> { ChoiceIter::new(self, max_revisits) } - - #[must_use] - pub fn tree_iter(&self, max_revisits: usize, with_type: bool) -> TreeIter<'_, L> { - TreeIter::new(self, max_revisits, with_type) - } -} - -/// If `quiet` is true, hides the progress bar. 
-/// If `strip_types`, ignores the types -#[must_use] -pub fn find_min_zs>( - graph: &EGraph, - reference: &TreeNode, - costs: &C, - max_revisits: usize, - with_types: bool, -) -> (Option<(TreeNode, usize)>, Stats) { - let ref_tree = if with_types { - reference - } else { - &reference.strip_types() - }; - - let ref_size = ref_tree.size(); - let ref_euler = EulerString::new(ref_tree); - let ref_pp = PreprocessedTree::new(ref_tree); - let running_best = AtomicUsize::new(usize::MAX); - - let (result, stats) = graph - .choice_iter(max_revisits) - .par_bridge() - .progress_count(graph.count_trees(max_revisits) as u64) - .map(|choices| { - { - let stripped_candidated = - graph.tree_from_choices(graph.root(), &choices, with_types); - let best = running_best.load(Ordering::Relaxed); - - // Fast pruning: size difference is a lower bound on edit distance - // (need at least |n1 - n2| insertions or deletions) - if stripped_candidated.size().abs_diff(ref_size) > best { - return (None, Stats::size_pruned()); - } - - // Euler string heuristic: EDS(s(T1), s(T2)) ≤ 2 · EDT(T1, T2) - // Therefore EDT ≥ EDS / 2, giving us a tighter lower bound - if ref_euler.lower_bound(&stripped_candidated, costs) > best { - return (None, Stats::euler_pruned()); - } - - let distance = tree_distance_with_ref(&stripped_candidated, &ref_pp, costs); - running_best.fetch_min(distance, Ordering::Relaxed); - - let tree = graph.tree_from_choices(graph.root(), &choices, true); - (Some((tree, distance)), Stats::compared()) - } - }) - .reduce( - || (None, Stats::default()), - |a, b| { - let best = [a.0, b.0].into_iter().flatten().min_by_key(|v| v.1); - (best, a.1 + b.1) - }, - ); - - (result, stats) -} - -/// If `quiet` is true, hides the progress bar. 
-/// If `strip_types`, ignores the types -#[must_use] -pub fn find_min_struct>( - graph: &EGraph, - reference: &TreeNode, - costs: &C, - max_revisits: usize, - with_types: bool, - ignore_labels: bool, -) -> Option<(TreeNode, usize)> { - let ref_tree = if with_types { - reference - } else { - &reference.strip_types() - }; - graph - .choice_iter(max_revisits) - .par_bridge() - .progress_count(graph.count_trees(max_revisits) as u64) - .map(|choices| { - let stripped_candidated = graph.tree_from_choices(graph.root(), &choices, with_types); - let distance = structural_diff(ref_tree, &stripped_candidated, costs, ignore_labels); - let tree = graph.tree_from_choices(graph.root(), &choices, true); - (tree, distance) - }) - .min_by_key(|(_, d)| *d) -} - -#[derive(Debug)] -pub struct TreeIter<'a, L: Label> { - choices: Vec, - path: PathTracker, - egraph: &'a EGraph, - with_type: bool, -} - -impl<'a, L: Label> TreeIter<'a, L> { - pub fn new(egraph: &'a EGraph, max_revisits: usize, with_type: bool) -> Self { - Self { - choices: Vec::new(), - path: PathTracker::new(max_revisits), - egraph, - with_type, - } - } -} - -impl Iterator for TreeIter<'_, L> { - type Item = TreeNode; - - fn next(&mut self) -> Option { - // Use choice_iter logic to get next valid choices, then materialize the tree - let advance = !self.choices.is_empty(); - - self.egraph.next_choices( - self.egraph.root(), - 0, - &mut self.choices, - &mut self.path, - advance, - )?; - - Some( - self.egraph - .tree_from_choices(self.egraph.root(), &self.choices, self.with_type), - ) - } -} - -/// Iterator that yields choice vectors without materializing trees. -/// Each choice vector can later be used with `tree_from_choices` to get the actual tree. 
-#[derive(Debug)] -pub struct ChoiceIter<'a, L: Label> { - choices: Vec, - path: PathTracker, - egraph: &'a EGraph, -} - -impl<'a, L: Label> ChoiceIter<'a, L> { - pub fn new(egraph: &'a EGraph, max_revisits: usize) -> Self { - Self { - choices: Vec::new(), - path: PathTracker::new(max_revisits), - egraph, - } - } -} - -impl Iterator for ChoiceIter<'_, L> { - type Item = Vec; - - fn next(&mut self) -> Option { - // On first call, choices is empty, so advance=false finds the first tree. - // On subsequent calls, choices contains the previous result, so advance=true - // finds the next tree. - let advance = !self.choices.is_empty(); - - self.egraph.next_choices( - self.egraph.root(), - 0, - &mut self.choices, - &mut self.path, - advance, - )?; - - Some(self.choices.clone()) - } -} - -/// Path tracker for cycle detection in the `EGraph`. -/// Tracks how many times each class has been visited on the current path -/// and allows configurable revisit limits. -#[derive(Debug, Clone)] -struct PathTracker { - /// Visit counts for classes on the current path - visits: HashMap, - /// Maximum number of times any node may be revisited (0 = no revisits allowed) - max_revisits: usize, -} - -impl PathTracker { - fn new(max_revisits: usize) -> Self { - PathTracker { - visits: HashMap::new(), - max_revisits, - } - } - - /// Check if visiting this OR node would exceed the revisit limit. - /// Returns true if the visit is allowed. - fn can_visit(&self, id: EClassId) -> bool { - let count = self.visits.get(&id).copied().unwrap_or(0); - count <= self.max_revisits - } - - /// Mark an OR node as visited on the current path. - fn enter(&mut self, id: EClassId) { - *self.visits.entry(id).or_insert(0) += 1; - } - - /// Unmark an OR node when leaving the current path. 
- fn leave(&mut self, id: EClassId) { - if let Some(count) = self.visits.get_mut(&id) { - *count = count.saturating_sub(1); - if *count == 0 { - self.visits.remove(&id); - } - } - } -} - -/// Statistics from filtered extraction -#[derive(Debug, Clone, Default)] -pub struct Stats { - /// Total number of trees enumerated - pub trees_enumerated: usize, - /// Trees pruned by simple metric - pub size_pruned: usize, - /// Number of trees pruned by euler string filter - pub euler_pruned: usize, - /// Number of trees for which full distance was computed - pub full_comparisons: usize, -} - -impl Stats { - fn size_pruned() -> Self { - Self { - trees_enumerated: 1, - size_pruned: 1, - euler_pruned: 0, - full_comparisons: 0, - } - } - - fn euler_pruned() -> Self { - Self { - trees_enumerated: 1, - size_pruned: 0, - euler_pruned: 1, - full_comparisons: 0, - } - } - - fn compared() -> Self { - Self { - trees_enumerated: 1, - size_pruned: 0, - euler_pruned: 0, - full_comparisons: 1, - } - } -} - -impl std::ops::Add for Stats { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - Self { - trees_enumerated: self.trees_enumerated + rhs.trees_enumerated, - size_pruned: self.size_pruned + rhs.size_pruned, - euler_pruned: self.euler_pruned + rhs.euler_pruned, - full_comparisons: self.full_comparisons + rhs.full_comparisons, - } - } } #[cfg(test)] mod tests { use crate::distance::ids::NumericId; - use crate::distance::zs::UnitCost; use super::*; - fn leaf(label: impl Into) -> TreeNode { - TreeNode::leaf(label.into()) - } - - fn node(label: L, children: Vec>) -> TreeNode { - TreeNode::new(label, children) - } - fn eid(i: usize) -> ExprChildId { ExprChildId::EClass(EClassId::new(i)) } @@ -708,9 +261,12 @@ mod tests { } #[test] - fn enumerate_single_leaf() { + fn choice_iter_single_leaf() { let graph = single_node_graph("a"); - let trees = graph.enumerate_trees(0, true); + let trees: Vec<_> = graph + .choice_iter(0) + .map(|c| graph.tree_from_choices(graph.root(), &c, true)) + 
.collect(); assert_eq!(trees.len(), 1); // with_type=true wraps in typeOf(expr, type) @@ -721,7 +277,7 @@ mod tests { } #[test] - fn enumerate_with_or_choice() { + fn choice_iter_with_or_choice() { // Graph with one class containing two node choices let graph = EGraph::new( cfv(vec![EClass::new( @@ -735,21 +291,24 @@ mod tests { HashMap::new(), ); - let trees = graph.enumerate_trees(0, true); + let trees: Vec<_> = graph + .choice_iter(0) + .map(|c| graph.tree_from_choices(graph.root(), &c, true)) + .collect(); assert_eq!(trees.len(), 2); // with_type=true wraps in typeOf(expr, type) // Extract expr labels from inside the typeOf wrapper - let labels = trees + let labels: Vec<_> = trees .iter() .map(|t| t.children()[0].label().as_str()) - .collect::>(); + .collect(); assert!(labels.contains(&"a")); assert!(labels.contains(&"b")); } #[test] - fn enumerate_with_and_children() { + fn choice_iter_with_and_children() { // Graph: root class -> node with two child classes (each has one leaf) // Class 0: root, has node "a" pointing to classes 1 and 2 // Class 1: leaf "b" @@ -770,7 +329,10 @@ mod tests { HashMap::new(), ); - let trees = graph.enumerate_trees(0, true); + let trees: Vec<_> = graph + .choice_iter(0) + .map(|c| graph.tree_from_choices(graph.root(), &c, true)) + .collect(); assert_eq!(trees.len(), 1); // with_type=true wraps in typeOf(expr, type) // typeOf(a(typeOf(b, type), typeOf(c, type)), type) @@ -784,7 +346,7 @@ mod tests { } #[test] - fn enumerate_with_cycle_no_revisits() { + fn choice_iter_with_cycle_no_revisits() { // Graph with a cycle: class 0 -> node -> class 0 let graph = EGraph::new( cfv(vec![EClass::new( @@ -802,7 +364,10 @@ mod tests { ); // With 0 revisits, we can only take the leaf option - let trees = graph.enumerate_trees(0, true); + let trees: Vec<_> = graph + .choice_iter(0) + .map(|c| graph.tree_from_choices(graph.root(), &c, true)) + .collect(); assert_eq!(trees.len(), 1); // with_type=true wraps in typeOf(expr, type) 
assert_eq!(trees[0].label(), "typeOf"); @@ -810,7 +375,7 @@ mod tests { } #[test] - fn enumerate_with_cycle_one_revisit() { + fn choice_iter_with_cycle_one_revisit() { // Graph with a cycle: class 0 -> node -> class 0 let graph = EGraph::new( cfv(vec![EClass::new( @@ -828,7 +393,10 @@ mod tests { ); // With 1 revisit, we can go one level deep - let trees = graph.enumerate_trees(1, true); + let trees: Vec<_> = graph + .choice_iter(1) + .map(|c| graph.tree_from_choices(graph.root(), &c, true)) + .collect(); // Should have: "leaf", "rec(leaf)", "rec(rec(leaf))" // Actually: at depth 0 we have 2 choices, at depth 1 we have 2 choices... @@ -843,136 +411,6 @@ mod tests { assert!(has_recursive); } - #[test] - fn min_distance_exact_match() { - // Graph contains the exact reference tree - let graph = EGraph::new( - cfv(vec![ - EClass::new( - vec![ENode::new("a".to_owned(), vec![eid(1), eid(2)])], - dummy_ty(), - ), - EClass::new(vec![ENode::leaf("b".to_owned())], dummy_ty()), - EClass::new(vec![ENode::leaf("c".to_owned())], dummy_ty()), - ]), - EClassId::new(0), - Vec::new(), - HashMap::new(), - dummy_nat_nodes(), - HashMap::new(), - ); - - // Reference: with_type=true wraps in typeOf(expr, type) - // typeOf(a(typeOf(b, type), typeOf(c, type)), type) - let reference = node( - "typeOf".to_owned(), - vec![ - node( - "a".to_owned(), - vec![ - node( - "typeOf".to_owned(), - vec![leaf("b".to_owned()), leaf("0".to_owned())], - ), - node( - "typeOf".to_owned(), - vec![leaf("c".to_owned()), leaf("0".to_owned())], - ), - ], - ), - leaf("0".to_owned()), // a's type - ], - ); - let result = find_min_zs(&graph, &reference, &UnitCost, 0, true) - .0 - .unwrap(); - - assert_eq!(result.1, 0); - } - - #[test] - fn min_distance_chooses_best() { - // Graph with OR choice: "a" or "x" - // Reference is "a", so should choose "a" with distance 0 - let graph = EGraph::new( - cfv(vec![EClass::new( - vec![ENode::leaf("a".to_owned()), ENode::leaf("x".to_owned())], - dummy_ty(), - )]), - 
EClassId::new(0), - Vec::new(), - HashMap::new(), - dummy_nat_nodes(), - HashMap::new(), - ); - - // Reference: with_type=true wraps in typeOf(expr, type) - let reference = node( - "typeOf".to_owned(), - vec![leaf("a".to_owned()), leaf("0".to_owned())], - ); - let result = find_min_zs(&graph, &reference, &UnitCost, 0, true) - .0 - .unwrap(); - - assert_eq!(result.1, 0); - // Result is wrapped in typeOf - assert_eq!(result.0.label(), "typeOf"); - assert_eq!(result.0.children()[0].label(), "a"); - } - - #[test] - fn min_distance_with_structure_choice() { - // Graph offers two structures: - // Option 1: a(b) - // Option 2: a(b, c) - // Reference: a(b) - // Should choose option 1 with distance 0 - let graph = EGraph::new( - cfv(vec![ - EClass::new( - vec![ - ENode::new("a".to_owned(), vec![eid(1)]), // a(b) - ENode::new("a".to_owned(), vec![eid(1), eid(2)]), // a(b, c) - ], - dummy_ty(), - ), - EClass::new(vec![ENode::leaf("b".to_owned())], dummy_ty()), - EClass::new(vec![ENode::leaf("c".to_owned())], dummy_ty()), - ]), - EClassId::new(0), - Vec::new(), - HashMap::new(), - dummy_nat_nodes(), - HashMap::new(), - ); - - // Reference: with_type=true wraps in typeOf(expr, type) - // typeOf(a(typeOf(b, type)), type) - let reference = node( - "typeOf".to_owned(), - vec![ - node( - "a".to_owned(), - vec![node( - "typeOf".to_owned(), - vec![leaf("b".to_owned()), leaf("0".to_owned())], - )], - ), - leaf("0".to_owned()), // a's type - ], - ); - let result = find_min_zs(&graph, &reference, &UnitCost, 0, true) - .0 - .unwrap(); - - assert_eq!(result.1, 0); - // Result is typeOf(a(...), type), so outer node has 2 children - assert_eq!(result.0.children().len(), 2); - // Inner 'a' has 1 child (b wrapped in typeOf) - assert_eq!(result.0.children()[0].children().len(), 1); - } - #[test] fn tree_from_choices_single_leaf() { let graph = single_node_graph("a"); @@ -1309,19 +747,8 @@ mod tests { } #[test] - fn tree_from_choices_matches_enumeration() { - // Helper to check if two trees are 
structurally equal - fn trees_equal(a: &TreeNode, b: &TreeNode) -> bool { - if a.label() != b.label() || a.children().len() != b.children().len() { - return false; - } - a.children() - .iter() - .zip(b.children().iter()) - .all(|(x, y)| trees_equal(x, y)) - } - - // Verify that tree_from_choices produces the same trees as enumeration + fn tree_from_choices_matches_choice_iter() { + // Verify that tree_from_choices produces correct trees for each choice vector let graph = EGraph::new( cfv(vec![ EClass::new( @@ -1343,23 +770,23 @@ mod tests { HashMap::new(), ); - let enumerated = graph.enumerate_trees(0, true); + let choices: Vec<_> = graph.choice_iter(0).collect(); // Should produce: x(a), x(b), y - assert_eq!(enumerated.len(), 3); + assert_eq!(choices.len(), 3); + assert_eq!(choices[0], vec![0, 0]); + assert_eq!(choices[1], vec![0, 1]); + assert_eq!(choices[2], vec![1]); - // Reconstruct using tree_from_choices - let choices1 = vec![0, 0]; - let tree1 = graph.tree_from_choices(EClassId::new(0), &choices1, true); - assert!(trees_equal(&tree1, &enumerated[0])); + // Verify tree_from_choices produces expected trees + let tree1 = graph.tree_from_choices(graph.root(), &choices[0], true); + assert_eq!(tree1.children()[0].label(), "x"); - let choices2 = vec![0, 1]; - let tree2 = graph.tree_from_choices(EClassId::new(0), &choices2, true); - assert!(trees_equal(&tree2, &enumerated[1])); + let tree2 = graph.tree_from_choices(graph.root(), &choices[1], true); + assert_eq!(tree2.children()[0].label(), "x"); - let choices3 = vec![1]; - let tree3 = graph.tree_from_choices(EClassId::new(0), &choices3, true); - assert!(trees_equal(&tree3, &enumerated[2])); + let tree3 = graph.tree_from_choices(graph.root(), &choices[2], true); + assert_eq!(tree3.children()[0].label(), "y"); } #[test] @@ -1381,179 +808,4 @@ mod tests { // Verify nat nodes exist assert!(!graph.nat_nodes.is_empty()); } - - #[test] - fn min_distance_extract_fast_exact_match() { - // Graph contains the exact reference 
tree - let graph = EGraph::new( - cfv(vec![ - EClass::new( - vec![ENode::new("a".to_owned(), vec![eid(1), eid(2)])], - dummy_ty(), - ), - EClass::new(vec![ENode::leaf("b".to_owned())], dummy_ty()), - EClass::new(vec![ENode::leaf("c".to_owned())], dummy_ty()), - ]), - EClassId::new(0), - Vec::new(), - HashMap::new(), - dummy_nat_nodes(), - HashMap::new(), - ); - - // Reference: with_type=true wraps in typeOf(expr, type) - // typeOf(a(typeOf(b, type), typeOf(c, type)), type) - let reference = node( - "typeOf".to_owned(), - vec![ - node( - "a".to_owned(), - vec![ - node( - "typeOf".to_owned(), - vec![leaf("b".to_owned()), leaf("0".to_owned())], - ), - node( - "typeOf".to_owned(), - vec![leaf("c".to_owned()), leaf("0".to_owned())], - ), - ], - ), - leaf("0".to_owned()), - ], - ); - - let result = find_min_zs(&graph, &reference, &UnitCost, 0, true) - .0 - .unwrap(); - assert_eq!(result.1, 0); - } - - #[test] - fn min_distance_extract_fast_chooses_best() { - // Graph with OR choice: "a" or "x" - let graph = EGraph::new( - cfv(vec![EClass::new( - vec![ENode::leaf("a".to_owned()), ENode::leaf("x".to_owned())], - dummy_ty(), - )]), - EClassId::new(0), - Vec::new(), - HashMap::new(), - dummy_nat_nodes(), - HashMap::new(), - ); - - // Reference: with_type=true wraps in typeOf(expr, type) - let reference = node( - "typeOf".to_owned(), - vec![leaf("a".to_owned()), leaf("0".to_owned())], - ); - - let result = find_min_zs(&graph, &reference, &UnitCost, 0, true) - .0 - .unwrap(); - assert_eq!(result.1, 0); - // Result is wrapped in typeOf - assert_eq!(result.0.label(), "typeOf"); - assert_eq!(result.0.children()[0].label(), "a"); - } - - #[test] - fn min_distance_extract_filtered_prunes_bad_trees() { - // Create a graph where one option is clearly worse - // Option 1: typeOf(a, type) - matches reference exactly - // Option 2: typeOf(x, type) - has different label, lower bound >= 1 - let graph = EGraph::new( - cfv(vec![EClass::new( - vec![ENode::leaf("a".to_owned()), 
ENode::leaf("x".to_owned())], - dummy_ty(), - )]), - EClassId::new(0), - Vec::new(), - HashMap::new(), - dummy_nat_nodes(), - HashMap::new(), - ); - - // Reference: typeOf(a, type) - tree_from_choices wraps in typeOf - let reference = node( - "typeOf".to_owned(), - vec![leaf("a".to_owned()), leaf("0".to_owned())], - ); - - let (result, stats) = find_min_zs(&graph, &reference, &UnitCost, 0, true); - - assert_eq!(result.unwrap().1, 0); - assert_eq!(stats.trees_enumerated, 2); - // With parallel execution, pruning is non-deterministic since both trees - // may be processed before best_distance is updated. We just verify the - // invariant that pruned + full_comparisons == trees_enumerated. - assert_eq!( - stats.size_pruned + stats.euler_pruned + stats.full_comparisons, - stats.trees_enumerated - ); - } - - #[test] - fn choice_iter_enumerates_all_trees_diamond_cycle() { - // Diamond with cycle at shared node - this pattern previously caused - // the iterator to miss some valid trees due to incomplete backtracking - // Class 0 -> a(1, 2), Class 1 -> b(3), Class 2 -> c(3), Class 3 -> [rec(3), d] - let graph = EGraph::new( - cfv(vec![ - EClass::new( - vec![ENode::new("a".to_owned(), vec![eid(1), eid(2)])], - dummy_ty(), - ), - EClass::new(vec![ENode::new("b".to_owned(), vec![eid(3)])], dummy_ty()), - EClass::new(vec![ENode::new("c".to_owned(), vec![eid(3)])], dummy_ty()), - EClass::new( - vec![ - ENode::new("rec".to_owned(), vec![eid(3)]), - ENode::leaf("d".to_owned()), - ], - dummy_ty(), - ), - ]), - EClassId::new(0), - Vec::new(), - HashMap::new(), - dummy_nat_nodes(), - HashMap::new(), - ); - - // With revisits=1, valid trees are (siblings are independent): - // 1. a(b(d), c(d)) - no revisits used - // 2. a(b(d), c(rec(d))) - c's branch revisits class 3 once - // 3. a(b(rec(d)), c(d)) - b's branch revisits class 3 once - // 4. 
a(b(rec(d)), c(rec(d))) - both branches revisit class 3 once each - assert_eq!(graph.choice_iter(1).count(), 4); - assert_eq!(graph.count_trees(1), graph.choice_iter(1).count()); - - // Verify specific trees are found - let trees = graph - .choice_iter(1) - .map(|c| graph.tree_from_choices(graph.root(), &c, false).to_string()) - .collect::>(); - assert!(trees.contains(&"(a (b d) (c d))".to_owned())); - assert!(trees.contains(&"(a (b d) (c (rec d)))".to_owned())); - assert!(trees.contains(&"(a (b (rec d)) (c d))".to_owned())); - assert!(trees.contains(&"(a (b (rec d)) (c (rec d)))".to_owned())); - - // With revisits=0, only one tree is valid (no revisits allowed) - assert_eq!(graph.choice_iter(0).count(), 1); - assert_eq!(graph.count_trees(0), graph.choice_iter(0).count()); - - // Verify TreeIter produces the same trees as ChoiceIter - let tree_iter_trees = graph - .tree_iter(1, false) - .map(|t| t.to_string()) - .collect::>(); - assert_eq!(tree_iter_trees.len(), 4); - assert!(tree_iter_trees.contains(&"(a (b d) (c d))".to_owned())); - assert!(tree_iter_trees.contains(&"(a (b d) (c (rec d)))".to_owned())); - assert!(tree_iter_trees.contains(&"(a (b (rec d)) (c d))".to_owned())); - assert!(tree_iter_trees.contains(&"(a (b (rec d)) (c (rec d)))".to_owned())); - } } diff --git a/src/distance/min.rs b/src/distance/min.rs new file mode 100644 index 0000000..dbfe641 --- /dev/null +++ b/src/distance/min.rs @@ -0,0 +1,544 @@ +//! Minimum distance search for E-Graphs. +//! +//! This module provides functions for finding the tree with minimum edit distance to a reference. 
+ +use std::sync::atomic::{AtomicUsize, Ordering}; + +use indicatif::ParallelProgressIterator; +use rand::SeedableRng; +use rand::rngs::StdRng; +use rayon::iter::{ParallelBridge, ParallelIterator}; + +use crate::distance::sampling::find_lambda_for_target_size; +use crate::distance::{FixpointSampler, FixpointSamplerConfig, Sampler}; + +use super::graph::EGraph; +use super::nodes::Label; +use super::str::EulerString; +use super::structural::structural_diff; +use super::tree::TreeNode; +use super::zs::{EditCosts, PreprocessedTree, tree_distance_with_ref}; + +/// Statistics from filtered extraction +#[derive(Debug, Clone, Default)] +pub struct Stats { + /// Total number of trees enumerated + pub trees_enumerated: usize, + /// Trees pruned by simple metric + pub size_pruned: usize, + /// Number of trees pruned by euler string filter + pub euler_pruned: usize, + /// Number of trees for which full distance was computed + pub full_comparisons: usize, +} + +impl Stats { + pub(crate) fn size_pruned() -> Self { + Self { + trees_enumerated: 1, + size_pruned: 1, + euler_pruned: 0, + full_comparisons: 0, + } + } + + pub(crate) fn euler_pruned() -> Self { + Self { + trees_enumerated: 1, + size_pruned: 0, + euler_pruned: 1, + full_comparisons: 0, + } + } + + pub(crate) fn compared() -> Self { + Self { + trees_enumerated: 1, + size_pruned: 0, + euler_pruned: 0, + full_comparisons: 1, + } + } +} + +impl std::ops::Add for Stats { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Self { + trees_enumerated: self.trees_enumerated + rhs.trees_enumerated, + size_pruned: self.size_pruned + rhs.size_pruned, + euler_pruned: self.euler_pruned + rhs.euler_pruned, + full_comparisons: self.full_comparisons + rhs.full_comparisons, + } + } +} + +/// Find the tree in the e-graph with minimum Zhang-Shasha edit distance to the reference. 
+/// +/// Uses parallel enumeration with pruning heuristics: +/// - Size difference lower bound +/// - Euler string distance lower bound +/// +/// # Arguments +/// * `graph` - The e-graph to search +/// * `reference` - The target tree to match +/// * `costs` - Edit cost function +/// * `max_revisits` - Maximum allowed revisits for cycle handling +/// * `with_types` - Whether to include type annotations in comparison +/// +/// # Returns +/// A tuple of (`best_result`, statistics) where `best_result` is `Some((tree, distance))` +/// if a tree was found. +#[must_use] +pub fn find_min_exhaustive_zs>( + graph: &EGraph, + reference: &TreeNode, + costs: &C, + max_revisits: usize, + with_types: bool, +) -> (Option<(TreeNode, usize)>, Stats) { + let ref_tree = if with_types { + reference + } else { + &reference.strip_types() + }; + + let ref_size = ref_tree.size(); + let ref_euler = EulerString::new(ref_tree); + let ref_pp = PreprocessedTree::new(ref_tree); + let running_best = AtomicUsize::new(usize::MAX); + + let (result, stats) = graph + .choice_iter(max_revisits) + .par_bridge() + .progress_count(graph.count_trees(max_revisits) as u64) + .map(|choices| { + { + let stripped_candidated = + graph.tree_from_choices(graph.root(), &choices, with_types); + let best = running_best.load(Ordering::Relaxed); + + // Fast pruning: size difference is a lower bound on edit distance + // (need at least |n1 - n2| insertions or deletions) + if stripped_candidated.size().abs_diff(ref_size) > best { + return (None, Stats::size_pruned()); + } + + // Euler string heuristic: EDS(s(T1), s(T2)) ≤ 2 · EDT(T1, T2) + // Therefore EDT ≥ EDS / 2, giving us a tighter lower bound + if ref_euler.lower_bound(&stripped_candidated, costs) > best { + return (None, Stats::euler_pruned()); + } + + let distance = tree_distance_with_ref(&stripped_candidated, &ref_pp, costs); + running_best.fetch_min(distance, Ordering::Relaxed); + + let tree = graph.tree_from_choices(graph.root(), &choices, true); + 
(Some((tree, distance)), Stats::compared()) + } + }) + .reduce( + || (None, Stats::default()), + |a, b| { + let best = [a.0, b.0].into_iter().flatten().min_by_key(|v| v.1); + (best, a.1 + b.1) + }, + ); + + (result, stats) +} + +/// Find the tree in the e-graph with minimum Zhang-Shasha edit distance to the reference. +/// +/// Uses fixpoint-based sampling instead of exhaustive enumeration. This is more efficient +/// for large e-graphs where exhaustive search is infeasible. Samples are drawn according to +/// a Boltzmann distribution that favors trees of a target weight. +/// +/// Uses the same parallel pruning heuristics as exhaustive search: +/// - Size difference lower bound +/// - Euler string distance lower bound +/// +/// # Arguments +/// * `graph` - The e-graph to search +/// * `reference` - The target tree to match +/// * `costs` - Edit cost function +/// * `with_types` - Whether to include type annotations in comparison +/// * `n_samples` - Number of trees to sample from the e-graph +/// * `target_weight` - Target tree size for the Boltzmann distribution (controls sampling bias) +/// * `seed` - Random seed for reproducible sampling +/// +/// # Returns +/// A tuple of (`best_result`, statistics) where `best_result` is `Some((tree, distance))` +/// if a tree was found. 
+/// +/// # Panics +/// +/// Panics if no sampler can be built +/// +/// # Note +/// The critical lambda parameter is automatically computed to target trees of the specified +#[must_use] +pub fn find_min_sampling_zs>( + graph: &EGraph, + reference: &TreeNode, + costs: &C, + with_types: bool, + n_samples: usize, + target_weight: usize, + seed: u64, +) -> (Option<(TreeNode, usize)>, Stats) { + let ref_tree = if with_types { + reference + } else { + &reference.strip_types() + }; + + let ref_size = ref_tree.size(); + let ref_euler = EulerString::new(ref_tree); + let ref_pp = PreprocessedTree::new(ref_tree); + let running_best = AtomicUsize::new(usize::MAX); + + let mut rng = StdRng::seed_from_u64(seed); + let config = FixpointSamplerConfig::builder().build(); + let (lambda, expected_size) = + find_lambda_for_target_size(graph, target_weight, &config, with_types, &mut rng).unwrap(); + eprintln!("LAMBDA IS {lambda}"); + eprintln!("EXPECTED SIZE IS {expected_size}"); + let (result, stats) = FixpointSampler::new(graph, lambda, &config, rng) + .unwrap() + .into_sample_iter(with_types) + .take(n_samples) + .par_bridge() + .progress_count(n_samples as u64) + .map(|candidate| { + { + let candidate_tree = if with_types { + &candidate + } else { + &candidate.strip_types() + }; + let best = running_best.load(Ordering::Relaxed); + + // Fast pruning: size difference is a lower bound on edit distance + // (need at least |n1 - n2| insertions or deletions) + if candidate_tree.size().abs_diff(ref_size) > best { + return (None, Stats::size_pruned()); + } + + // Euler string heuristic: EDS(s(T1), s(T2)) ≤ 2 · EDT(T1, T2) + // Therefore EDT ≥ EDS / 2, giving us a tighter lower bound + if ref_euler.lower_bound(candidate_tree, costs) > best { + return (None, Stats::euler_pruned()); + } + + let distance = tree_distance_with_ref(candidate_tree, &ref_pp, costs); + running_best.fetch_min(distance, Ordering::Relaxed); + + (Some((candidate, distance)), Stats::compared()) + } + }) + .reduce( + || 
(None, Stats::default()), + |a, b| { + let best = [a.0, b.0].into_iter().flatten().min_by_key(|v| v.1); + (best, a.1 + b.1) + }, + ); + + (result, stats) +} + +/// Find the tree in the e-graph with minimum structural difference to the reference. +/// +/// # Arguments +/// * `graph` - The e-graph to search +/// * `reference` - The target tree to match +/// * `costs` - Edit cost function +/// * `max_revisits` - Maximum allowed revisits for cycle handling +/// * `with_types` - Whether to include type annotations in comparison +/// * `ignore_labels` - Whether to ignore label differences (structure only) +/// +/// # Returns +/// `Some((tree, distance))` if a tree was found. +#[must_use] +pub fn find_min_struct>( + graph: &EGraph, + reference: &TreeNode, + costs: &C, + max_revisits: usize, + with_types: bool, + ignore_labels: bool, +) -> Option<(TreeNode, usize)> { + let ref_tree = if with_types { + reference + } else { + &reference.strip_types() + }; + graph + .choice_iter(max_revisits) + .par_bridge() + .progress_count(graph.count_trees(max_revisits) as u64) + .map(|choices| { + let stripped_candidated = graph.tree_from_choices(graph.root(), &choices, with_types); + let distance = structural_diff(ref_tree, &stripped_candidated, costs, ignore_labels); + let tree = graph.tree_from_choices(graph.root(), &choices, true); + (tree, distance) + }) + .min_by_key(|(_, d)| *d) +} + +#[cfg(test)] +mod tests { + use hashbrown::HashMap; + + use super::*; + use crate::distance::graph::EClass; + use crate::distance::ids::{EClassId, ExprChildId, NatId, TypeChildId}; + use crate::distance::nodes::{ENode, NatNode}; + use crate::distance::zs::UnitCost; + + fn leaf(label: impl Into) -> TreeNode { + TreeNode::leaf(label.into()) + } + + fn node(label: L, children: Vec>) -> TreeNode { + TreeNode::new(label, children) + } + + fn eid(i: usize) -> ExprChildId { + ExprChildId::EClass(EClassId::new(i)) + } + + fn dummy_ty() -> TypeChildId { + TypeChildId::Nat(NatId::new(0)) + } + + fn 
dummy_nat_nodes() -> HashMap> { + let mut nats = HashMap::new(); + nats.insert(NatId::new(0), NatNode::leaf("0".to_owned())); + nats + } + + fn cfv(classes: Vec>) -> HashMap> { + classes + .into_iter() + .enumerate() + .map(|(i, c)| (EClassId::new(i), c)) + .collect() + } + + #[test] + fn min_distance_exact_match() { + let graph = EGraph::new( + cfv(vec![ + EClass::new( + vec![ENode::new("a".to_owned(), vec![eid(1), eid(2)])], + dummy_ty(), + ), + EClass::new(vec![ENode::leaf("b".to_owned())], dummy_ty()), + EClass::new(vec![ENode::leaf("c".to_owned())], dummy_ty()), + ]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let reference = node( + "typeOf".to_owned(), + vec![ + node( + "a".to_owned(), + vec![ + node( + "typeOf".to_owned(), + vec![leaf("b".to_owned()), leaf("0".to_owned())], + ), + node( + "typeOf".to_owned(), + vec![leaf("c".to_owned()), leaf("0".to_owned())], + ), + ], + ), + leaf("0".to_owned()), + ], + ); + let result = find_min_exhaustive_zs(&graph, &reference, &UnitCost, 0, true) + .0 + .unwrap(); + + assert_eq!(result.1, 0); + } + + #[test] + fn min_distance_chooses_best() { + let graph = EGraph::new( + cfv(vec![EClass::new( + vec![ENode::leaf("a".to_owned()), ENode::leaf("x".to_owned())], + dummy_ty(), + )]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let reference = node( + "typeOf".to_owned(), + vec![leaf("a".to_owned()), leaf("0".to_owned())], + ); + let result = find_min_exhaustive_zs(&graph, &reference, &UnitCost, 0, true) + .0 + .unwrap(); + + assert_eq!(result.1, 0); + assert_eq!(result.0.label(), "typeOf"); + assert_eq!(result.0.children()[0].label(), "a"); + } + + #[test] + fn min_distance_with_structure_choice() { + let graph = EGraph::new( + cfv(vec![ + EClass::new( + vec![ + ENode::new("a".to_owned(), vec![eid(1)]), + ENode::new("a".to_owned(), vec![eid(1), eid(2)]), + ], + dummy_ty(), + ), + 
EClass::new(vec![ENode::leaf("b".to_owned())], dummy_ty()), + EClass::new(vec![ENode::leaf("c".to_owned())], dummy_ty()), + ]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let reference = node( + "typeOf".to_owned(), + vec![ + node( + "a".to_owned(), + vec![node( + "typeOf".to_owned(), + vec![leaf("b".to_owned()), leaf("0".to_owned())], + )], + ), + leaf("0".to_owned()), + ], + ); + let result = find_min_exhaustive_zs(&graph, &reference, &UnitCost, 0, true) + .0 + .unwrap(); + + assert_eq!(result.1, 0); + assert_eq!(result.0.children().len(), 2); + assert_eq!(result.0.children()[0].children().len(), 1); + } + + #[test] + fn min_distance_extract_fast_exact_match() { + let graph = EGraph::new( + cfv(vec![ + EClass::new( + vec![ENode::new("a".to_owned(), vec![eid(1), eid(2)])], + dummy_ty(), + ), + EClass::new(vec![ENode::leaf("b".to_owned())], dummy_ty()), + EClass::new(vec![ENode::leaf("c".to_owned())], dummy_ty()), + ]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let reference = node( + "typeOf".to_owned(), + vec![ + node( + "a".to_owned(), + vec![ + node( + "typeOf".to_owned(), + vec![leaf("b".to_owned()), leaf("0".to_owned())], + ), + node( + "typeOf".to_owned(), + vec![leaf("c".to_owned()), leaf("0".to_owned())], + ), + ], + ), + leaf("0".to_owned()), + ], + ); + + let result = find_min_exhaustive_zs(&graph, &reference, &UnitCost, 0, true) + .0 + .unwrap(); + assert_eq!(result.1, 0); + } + + #[test] + fn min_distance_extract_fast_chooses_best() { + let graph = EGraph::new( + cfv(vec![EClass::new( + vec![ENode::leaf("a".to_owned()), ENode::leaf("x".to_owned())], + dummy_ty(), + )]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let reference = node( + "typeOf".to_owned(), + vec![leaf("a".to_owned()), leaf("0".to_owned())], + ); + + let result = find_min_exhaustive_zs(&graph, &reference, &UnitCost, 
0, true) + .0 + .unwrap(); + assert_eq!(result.1, 0); + assert_eq!(result.0.label(), "typeOf"); + assert_eq!(result.0.children()[0].label(), "a"); + } + + #[test] + fn min_distance_extract_filtered_prunes_bad_trees() { + let graph = EGraph::new( + cfv(vec![EClass::new( + vec![ENode::leaf("a".to_owned()), ENode::leaf("x".to_owned())], + dummy_ty(), + )]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ); + + let reference = node( + "typeOf".to_owned(), + vec![leaf("a".to_owned()), leaf("0".to_owned())], + ); + + let (result, stats) = find_min_exhaustive_zs(&graph, &reference, &UnitCost, 0, true); + + assert_eq!(result.unwrap().1, 0); + assert_eq!(stats.trees_enumerated, 2); + assert_eq!( + stats.size_pruned + stats.euler_pruned + stats.full_comparisons, + stats.trees_enumerated + ); + } +} diff --git a/src/distance/mod.rs b/src/distance/mod.rs index 654e600..91d5c54 100644 --- a/src/distance/mod.rs +++ b/src/distance/mod.rs @@ -1,7 +1,11 @@ +mod count; +mod extract; mod graph; mod ids; +mod min; mod nodes; pub mod rise; +mod sampling; mod str; mod structural; mod tree; @@ -10,8 +14,15 @@ mod zs; // Re-export rise types at this level for convenience pub use rise::{Expr, Nat, RiseLabel, Type}; -pub use graph::{EClass, EGraph, Stats, find_min_struct, find_min_zs}; +pub use count::TermCount; +pub use extract::ChoiceIter; +pub use graph::{EClass, EGraph}; +pub use min::{Stats, find_min_exhaustive_zs, find_min_sampling_zs, find_min_struct}; pub use nodes::Label; +pub use sampling::{ + DiverseSampler, DiverseSamplerConfig, FixpointSampler, FixpointSamplerConfig, Sampler, + SamplingIter, structural_hash, +}; pub use str::tree_distance_euler_bound; pub use tree::TreeNode; pub use zs::{EditCosts, UnitCost, tree_distance, tree_distance_unit}; diff --git a/src/distance/sampling.rs b/src/distance/sampling.rs new file mode 100644 index 0000000..3327946 --- /dev/null +++ b/src/distance/sampling.rs @@ -0,0 +1,823 @@ +//! 
Boltzmann Sampling for E-Graphs
+//!
+//! Provides methods for sampling terms from an e-graph with control over:
+//! - Term size distribution (via Boltzmann parameter λ)
+//! - Diversity (via structural hashing and deduplication)
+//!
+//! # Boltzmann Sampling
+//!
+//! Each term is weighted by `λ^size` where `λ ∈ (0, 1]`. Smaller λ values
+//! bias toward smaller terms. The "critical" λ is the largest value for which
+//! the weight computation still converges; a target expected size is reached
+//! by searching for a suitable λ below it.
+//!
+//! Iterative fixed-point computation handles cycles
+//! correctly by converging to stable weights.
+
+use std::collections::hash_map::DefaultHasher;
+use std::hash::{Hash, Hasher};
+
+use hashbrown::{HashMap, HashSet};
+use ordered_float::OrderedFloat;
+use rand::prelude::*;
+
+use super::graph::EGraph;
+use super::ids::{DataChildId, DataId, EClassId, ExprChildId, NatId};
+use super::nodes::{ENode, Label};
+use super::tree::TreeNode;
+
+/// Trait for samplers that can draw terms from an e-graph.
+pub trait Sampler: Sized {
+    /// Label type of the trees this sampler produces.
+    type Label: Label;
+
+    /// Sample a single term, returning `None` if sampling fails.
+    fn sample(&mut self, with_types: bool) -> Option<TreeNode<Self::Label>>;
+
+    /// Sample multiple terms.
+    ///
+    /// Makes `count` attempts and keeps the successful ones, so the result
+    /// may hold fewer than `count` trees.
+    fn sample_many(&mut self, count: usize, with_types: bool) -> Vec<TreeNode<Self::Label>> {
+        (0..count).filter_map(|_| self.sample(with_types)).collect()
+    }
+
+    /// Convert this sampler into an iterator that yields samples.
+    ///
+    /// The iterator will call `sample()` on each `next()` and terminate
+    /// when the sampler returns `None`. Use `.take(n)` to limit samples.
+    ///
+    /// # Example
+    /// ```ignore
+    /// let sampler = FixpointSampler::new(&graph, lambda, &config, rng).unwrap();
+    /// for tree in sampler.into_sample_iter(true).take(100) {
+    ///     println!("{:?}", tree);
+    /// }
+    /// ```
+    fn into_sample_iter(self, with_types: bool) -> SamplingIter<Self> {
+        SamplingIter::new(self, with_types)
+    }
+}
+
+/// Sampler using fixed-point iteration for weight computation.
+///
+/// This sampler handles cyclic e-graphs correctly by computing weights
+/// via iterative convergence rather than recursion. Use this when your
+/// e-graph may contain cycles.
+///
+/// Internally uses log-space arithmetic for numerical stability when
+/// λ is small or trees are deep.
+pub struct FixpointSampler<'a, L: Label, R: Rng> {
+    /// The e-graph terms are drawn from.
+    graph: &'a EGraph<L>,
+    /// log(λ), added once per sampled node.
+    log_lambda: OrderedFloat<f64>,
+    /// Log-weights for each e-class: log(W[id])
+    log_weights: HashMap<EClassId, OrderedFloat<f64>>,
+    rng: R,
+    /// Depth at which a sample is abandoned.
+    max_depth: usize,
+}
+
+/// Configuration for the fixed-point sampler.
+#[derive(Debug, Clone, bon::Builder)]
+#[builder(derive(Clone, Debug))]
+pub struct FixpointSamplerConfig {
+    /// Convergence threshold for weight computation.
+    ///
+    /// `find_lambda_for_target_size` additionally uses it as the interval
+    /// width at which its binary search over lambda stops.
+    #[builder(default = 1e-3)]
+    pub epsilon: f64,
+    /// Maximum iterations for weight convergence.
+    #[builder(default = 1000)]
+    pub max_iterations: usize,
+    /// Maximum depth during sampling (prevents infinite loops on cycles).
+    /// A sample that reaches this depth is abandoned and yields `None`.
+    #[builder(default = 1000)]
+    pub max_depth: usize,
+}
+
+impl FixpointSamplerConfig {
+    /// Config for graphs with cycles.
+    ///
+    /// Identical to the defaults except that `max_depth` is lowered to 100, so
+    /// runaway samples on cycles are abandoned sooner. Lambda is not part of the
+    /// configuration and has to be chosen separately.
+    #[must_use]
+    pub fn for_cyclic() -> Self {
+        Self::builder().max_depth(100).build()
+    }
+}
+
+/// Find the lambda value that produces trees of approximately the target size.
+///
+/// Uses binary search over lambda, estimating average tree size at each point
+/// by sampling. Larger lambda values produce larger trees (monotonic relationship).
+/// +/// # Arguments +/// * `graph` - The e-graph to sample from +/// * `target_size` - The desired average tree size (number of nodes) +/// * `config` - Configuration for the search +/// * `rng` - Random number generator (will be used to create seedable sub-rngs) +/// +/// # Returns +/// * `Ok((lambda, actual_avg_size))` - The found lambda and the actual average size achieved +/// * `Err(FindLambdaError)` - If the search fails (e.g., target unreachable) +pub fn find_lambda_for_target_size( + graph: &EGraph, + target_size: usize, + fixpoint_config: &FixpointSamplerConfig, + with_types: bool, + rng: &mut R, +) -> Result<(f64, f64), FindLambdaError> { + let mut lo = 0.001; + + // For cyclic graphs, we need to find the maximum lambda that converges. + // Binary search to find the critical lambda (highest that still converges). + let mut hi = find_critical_lambda(graph, fixpoint_config, rng)?; + eprintln!("Critical lambda found: {hi}"); + + // First, check bounds to ensure target is achievable + // Use a relative epsilon for convergence that scales with expected weight magnitude + let size_at_min = estimate_avg_size( + graph, + fixpoint_config, + lo, + with_types, + R::seed_from_u64(rng.next_u64()), + )?; + // eprintln!("size_min is doable"); + let size_at_max = estimate_avg_size( + graph, + fixpoint_config, + hi, + with_types, + R::seed_from_u64(rng.next_u64()), + )?; + // eprintln!("size_max is doable"); + eprintln!("Achievable size range: {size_at_min:.1} - {size_at_max:.1} (target: {target_size})"); + + #[expect(clippy::cast_precision_loss)] + let target = target_size as f64; + + if target < size_at_min { + return Err(FindLambdaError::TargetTooSmall { + target: target_size, + min_achievable: size_at_min, + }); + } + if target > size_at_max { + return Err(FindLambdaError::TargetTooLarge { + target: target_size, + max_achievable: size_at_max, + }); + } + + // Binary search + let mut best_lambda = f64::midpoint(lo, hi); + let mut best_size = 0.0; + + while hi - lo > 
fixpoint_config.epsilon { + // eprintln!("Current best lambda: {best_lambda}"); + // eprintln!("current best size: {best_size}"); + + let mid = f64::midpoint(lo, hi); + let avg_size = estimate_avg_size( + graph, + fixpoint_config, + mid, + with_types, + R::seed_from_u64(rng.next_u64()), + )?; + + best_lambda = mid; + best_size = avg_size; + + if avg_size < target { + lo = mid; + } else { + hi = mid; + } + } + + Ok((best_lambda, best_size)) +} + +/// Find the critical lambda - the highest value where the sampler still converges. +/// +/// For cyclic e-graphs, there's a threshold lambda above which the fixed-point +/// iteration diverges. This function uses binary search to find that threshold. +fn find_critical_lambda( + graph: &EGraph, + config: &FixpointSamplerConfig, + rng: &mut R, +) -> Result { + let mut lo = 0.001; // Known to converge (very small lambda) + let mut hi = 1.0; // Upper bound to search + + // First verify that lo converges + if FixpointSampler::new(graph, lo, config, R::seed_from_u64(rng.next_u64())).is_none() { + return Err(FindLambdaError::SamplerDidNotConverge { lambda: lo }); + } + + // Binary search for the critical lambda + while hi - lo > 0.001 { + let mid = f64::midpoint(lo, hi); + if FixpointSampler::new(graph, mid, config, R::seed_from_u64(rng.next_u64())).is_some() { + lo = mid; // mid converges, try higher + } else { + hi = mid; // mid diverges, try lower + } + } + + // Return slightly below the critical point for safety margin + Ok(lo * 0.99) +} + +/// Estimate the average tree size for a given lambda by sampling. +fn estimate_avg_size( + graph: &EGraph, + config: &FixpointSamplerConfig, + lambda: f64, + with_types: bool, + rng: R, +) -> Result { + let samples = FixpointSampler::new(graph, lambda, config, rng) + .ok_or(FindLambdaError::SamplerDidNotConverge { lambda })? 
+ .sample_many(1000, with_types); + + if samples.is_empty() { + return Err(FindLambdaError::SamplerDidNotConverge { lambda }); + } + + #[expect(clippy::cast_precision_loss)] + let avg = samples.iter().map(|t| t.size()).sum::() as f64 / samples.len() as f64; + + Ok(avg) +} + +/// Errors that can occur when finding lambda for a target size. +#[derive(Debug, Clone, thiserror::Error)] +pub enum FindLambdaError { + /// The target size is smaller than achievable with the minimum lambda. + #[error("target size {target} is too small; minimum achievable is {min_achievable:.1}")] + TargetTooSmall { target: usize, min_achievable: f64 }, + /// The target size is larger than achievable with the maximum lambda. + #[error("target size {target} is too large; maximum achievable is {max_achievable:.1}")] + TargetTooLarge { target: usize, max_achievable: f64 }, + /// The sampler failed to converge for a given lambda. + #[error("sampler did not converge for lambda={lambda}")] + SamplerDidNotConverge { lambda: f64 }, +} + +impl<'a, L: Label, R: Rng> FixpointSampler<'a, L, R> { + /// Create a new fixed-point sampler. + /// + /// Computes weights via fixed-point iteration until convergence. + /// Uses log-space arithmetic internally for numerical stability. + /// + /// # Returns + /// `None` if weight computation does not converge, otherwise the sampler. + /// + /// # Panics + /// Panics if `lambda` is not in the range (0, 1]. 
+ pub fn new( + graph: &'a EGraph, + lambda: f64, + config: &FixpointSamplerConfig, + rng: R, + ) -> Option { + assert!(lambda > 0.0 && lambda <= 1.0, "λ must be in (0, 1]"); + + let log_lambda = OrderedFloat(lambda.ln()); + + // Collect canonical class IDs and initialize log-weights to log(λ) + let mut log_weights = graph + .class_ids() + .map(|id| (graph.canonicalize(id), log_lambda)) + .collect::>(); + let class_ids = log_weights.keys().copied().collect::>(); + + let mut prev_max_delta = f64::INFINITY; + let mut divergence_count = 0; + + for _ in 0..config.max_iterations { + let mut max_delta: f64 = 0.0; + + for &id in &class_ids { + // new_log_weight = log(sum over nodes of (λ × product of child weights)) + // = log_sum_exp(log(λ) + sum of log(child weights)) + let node_log_weights = graph.class(id).nodes().iter().map(|node| { + node.children() + .iter() + .map(|child| match child { + ExprChildId::EClass(eid) => log_weights[&graph.canonicalize(*eid)], + ExprChildId::Nat(nat_id) => { + Self::nat_log_weight(graph, *nat_id, log_lambda) + } + ExprChildId::Data(dt_id) => { + Self::dt_log_weight(graph, *dt_id, log_lambda) + } + }) + .sum::>() + + log_lambda + }); + + let max = node_log_weights + .clone() + .max() + .expect("log_sum_exp requires non-empty input"); + let new_log_weight = + max + node_log_weights.map(|v| (v - max).exp()).sum::().ln(); + + // Convergence check: |Δ log w| < ε means relative change |w_new/w_old - 1| ≈ ε + let delta = (new_log_weight - log_weights[&id]).abs(); + max_delta = max_delta.max(delta); + log_weights.insert(id, new_log_weight); + } + // eprintln!("max_delta at {max_delta}"); + if max_delta < config.epsilon { + return Some(Self { + graph, + log_lambda, + log_weights, + rng, + max_depth: config.max_depth, + }); + } + + // Detect divergence: if max_delta is consistently increasing, we're diverging + if max_delta > prev_max_delta * 1.5 { + divergence_count += 1; + if divergence_count >= 3 { + // Weights are diverging - lambda is too 
large for this cyclic graph + eprintln!("Divergence detected: lambda={lambda} is too large"); + return None; + } + } else { + divergence_count = 0; + } + prev_max_delta = max_delta; + } + None + } + + fn nat_log_weight( + graph: &EGraph, + id: NatId, + log_lambda: OrderedFloat, + ) -> OrderedFloat { + graph + .nat(id) + .children() + .iter() + .map(|child_id| Self::nat_log_weight(graph, *child_id, log_lambda)) + .sum::>() + + log_lambda + } + + fn dt_log_weight( + graph: &EGraph, + id: DataId, + log_lambda: OrderedFloat, + ) -> OrderedFloat { + graph + .data_ty(id) + .children() + .iter() + .map(|child_id| match child_id { + DataChildId::Nat(nat_id) => Self::nat_log_weight(graph, *nat_id, log_lambda), + DataChildId::DataType(data_id) => Self::dt_log_weight(graph, *data_id, log_lambda), + }) + .sum::>() + + log_lambda + } + + /// Compute log-weight for a node (sum of children log-weights + log(λ)). + fn node_log_weight(&self, node: &ENode) -> OrderedFloat { + node.children() + .iter() + .map(|child| match child { + ExprChildId::EClass(eid) => { + let c = self.graph.canonicalize(*eid); + self.log_weights[&c] + } + ExprChildId::Nat(nat_id) => { + Self::nat_log_weight(self.graph, *nat_id, self.log_lambda) + } + ExprChildId::Data(dt_id) => { + Self::dt_log_weight(self.graph, *dt_id, self.log_lambda) + } + }) + .sum::>() + + self.log_lambda + } + + /// Sample a term rooted at the given e-class. 
+ fn sample_from(&mut self, id: EClassId, depth: usize, with_types: bool) -> Option> { + if depth >= self.max_depth { + return None; + } + + let nodes = self.graph.class(id).nodes(); + + // Compute log-weights for each node + let log_weights = nodes + .iter() + .map(|node| (node, self.node_log_weight(node))) + .collect::>(); + + // Convert to probabilities using softmax (numerically stable) + let max_log = log_weights.values().max().expect("e-class has no nodes"); + + let chosen_node = nodes + // Compute exp(log_w - max) for numerical stability + .choose_weighted(&mut self.rng, |node| (log_weights[node] - max_log).exp()) + .expect("weights are always positive"); + + let expr_tree = TreeNode::new( + chosen_node.label().clone(), + chosen_node + .children() + .iter() + .map(|c| self.sample_child(c, depth + 1, with_types)) + .collect::>()?, + ); + if with_types { + let type_tree = TreeNode::from_eclass(self.graph, id); + Some(TreeNode::new(L::type_of(), vec![expr_tree, type_tree])) + } else { + Some(expr_tree) + } + } + + /// Sample a child, dispatching on child type. + fn sample_child( + &mut self, + child: &ExprChildId, + depth: usize, + with_types: bool, + ) -> Option> { + match child { + ExprChildId::EClass(eid) => self.sample_from(*eid, depth, with_types), + ExprChildId::Nat(nid) => Some(TreeNode::from_nat(self.graph, *nid)), + ExprChildId::Data(did) => Some(TreeNode::from_data(self.graph, *did)), + } + } +} + +impl Sampler for FixpointSampler<'_, L, R> { + type Label = L; + + fn sample(&mut self, with_types: bool) -> Option> { + self.sample_from(self.graph.root(), 0, with_types) + } +} + +/// Configuration for diverse sampling. +#[derive(Debug, Clone, bon::Builder)] +pub struct DiverseSamplerConfig { + /// Maximum attempts to find a novel sample before giving up. + #[builder(default = 100)] + pub max_attempts_per_sample: usize, + /// Minimum novelty ratio (0, 1]. Sample accepted if this fraction of features are new. 
+    #[builder(default = 0.3)]
+    pub min_novelty_ratio: f64,
+}
+
+/// Sampler that produces diverse terms using structural deduplication.
+/// Generic over any underlying sampler.
+pub struct DiverseSampler<S: Sampler> {
+    /// The wrapped sampler that produces the candidate terms.
+    sampler: S,
+    config: DiverseSamplerConfig,
+    /// Structural hashes of all accepted terms.
+    seen_hashes: HashSet<u64>,
+    /// `(parent_label, child_index, child_label)` features of all accepted terms.
+    seen_features: HashSet<(S::Label, usize, S::Label)>,
+}
+
+impl<S: Sampler> DiverseSampler<S> {
+    /// Create a new diverse sampler wrapping an existing sampler.
+    pub fn new(sampler: S, config: DiverseSamplerConfig) -> Self {
+        Self {
+            sampler,
+            config,
+            seen_hashes: HashSet::new(),
+            seen_features: HashSet::new(),
+        }
+    }
+
+    /// Check if a term is novel enough to accept.
+    ///
+    /// Returns the verdict together with the term's structural hash and its
+    /// feature set, so the caller can record them without recomputing.
+    #[expect(clippy::type_complexity)]
+    fn is_novel(
+        &self,
+        term: &TreeNode<S::Label>,
+    ) -> (bool, u64, HashSet<(S::Label, usize, S::Label)>) {
+        let hash = structural_hash(term);
+        let features = extract_features(term);
+
+        // Novel if we've never seen this exact structure
+        if !self.seen_hashes.contains(&hash) {
+            return (true, hash, features);
+        }
+
+        // Otherwise check feature novelty ratio
+        // NOTE(review): `accept` records every feature of an accepted term, so a
+        // term whose hash is already known normally has no new features left and
+        // this branch can only accept on a hash collision - confirm whether the
+        // ratio was meant to apply to unseen structures instead.
+        let novel_count = features
+            .iter()
+            .filter(|f| !self.seen_features.contains(*f))
+            .count();
+
+        #[expect(clippy::cast_precision_loss)]
+        let novelty_ratio = if features.is_empty() {
+            0.0
+        } else {
+            novel_count as f64 / features.len() as f64
+        };
+
+        (
+            novelty_ratio >= self.config.min_novelty_ratio,
+            hash,
+            features,
+        )
+    }
+
+    /// Accept a term, updating seen hashes and features.
+    fn accept(&mut self, hash: u64, features: HashSet<(S::Label, usize, S::Label)>) {
+        self.seen_hashes.insert(hash);
+        self.seen_features.extend(features);
+    }
+
+    /// Reset the diversity tracking state.
+    pub fn reset(&mut self) {
+        self.seen_hashes.clear();
+        self.seen_features.clear();
+    }
+
+    /// Structural hashes of all terms accepted since the last reset.
+    pub fn seen_hashes(&self) -> &HashSet<u64> {
+        &self.seen_hashes
+    }
+
+    /// Features of all terms accepted since the last reset.
+    pub fn seen_features(&self) -> &HashSet<(S::Label, usize, S::Label)> {
+        &self.seen_features
+    }
+}
+
+impl<S: Sampler> Sampler for DiverseSampler<S> {
+    type Label = S::Label;
+
+    /// Draw from the wrapped sampler until a novel term turns up.
+    ///
+    /// Returns `None` if the wrapped sampler fails or if none of the
+    /// `max_attempts_per_sample` draws is novel.
+    fn sample(&mut self, with_types: bool) -> Option<TreeNode<Self::Label>> {
+        for _ in 0..self.config.max_attempts_per_sample {
+            let term = self.sampler.sample(with_types)?;
+            let (is_novel, hash, features) = self.is_novel(&term);
+            if is_novel {
+                self.accept(hash, features);
+                return Some(term);
+            }
+        }
+        None
+    }
+}
+
+/// Compute a structural hash of a tree for diversity checking.
+/// Trees with the same structure and labels will have the same hash.
+#[must_use]
+pub fn structural_hash<L: Label>(tree: &TreeNode<L>) -> u64 {
+    let mut hasher = DefaultHasher::new();
+    hash_tree_rec(tree, &mut hasher);
+    hasher.finish()
+}
+
+/// Feed `tree` into `hasher` in pre-order: label, child count, then children.
+fn hash_tree_rec<L: Label, H: Hasher>(tree: &TreeNode<L>, hasher: &mut H) {
+    tree.label().hash(hasher);
+    // The child count keeps differently shaped trees with the same label
+    // sequence apart.
+    tree.children().len().hash(hasher);
+    for child in tree.children() {
+        hash_tree_rec(child, hasher);
+    }
+}
+
+/// Extract structural features from a tree for diversity measurement.
+/// Returns bigrams of `(parent_label, child_index, child_label)`.
+#[must_use]
+pub fn extract_features<L: Label>(tree: &TreeNode<L>) -> HashSet<(L, usize, L)> {
+    let mut features = HashSet::new();
+    collect_features(tree, &mut features);
+    features
+}
+
+/// Add the bigrams of `tree` and of all its descendants to `features`.
+fn collect_features<L: Label>(tree: &TreeNode<L>, features: &mut HashSet<(L, usize, L)>) {
+    let parent = tree.label().clone();
+    for (i, child) in tree.children().iter().enumerate() {
+        features.insert((parent.clone(), i, child.label().clone()));
+        collect_features(child, features);
+    }
+}
+
+/// Iterator adapter that yields samples from any `Sampler`.
+///
+/// This iterator wraps a sampler and calls `sample()` on each `next()`.
+/// The iterator terminates when the sampler returns `None`, or you can +/// use methods like `.take(n)` to limit the number of samples. +/// +/// # Example +/// ```ignore +/// let sampler = FixpointSampler::new(&graph, &config, rng).unwrap(); +/// let iter = SamplingIter::new(sampler); +/// for tree in iter.take(100) { +/// println!("{:?}", tree); +/// } +/// ``` +pub struct SamplingIter { + sampler: S, + with_types: bool, +} + +impl SamplingIter { + /// Create a new sampling iterator from a sampler. + pub fn new(sampler: S, with_types: bool) -> Self { + Self { + sampler, + with_types, + } + } + + /// Consume the iterator and return the underlying sampler. + pub fn into_inner(self) -> S { + self.sampler + } + + /// Get a reference to the underlying sampler. + pub fn sampler(&self) -> &S { + &self.sampler + } + + /// Get a mutable reference to the underlying sampler. + pub fn sampler_mut(&mut self) -> &mut S { + &mut self.sampler + } + + pub fn with_types(&self) -> bool { + self.with_types + } +} + +impl Iterator for SamplingIter { + type Item = TreeNode; + + fn next(&mut self) -> Option { + self.sampler.sample(self.with_types) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::graph::EClass; + use crate::distance::ids::NatId; + use crate::distance::ids::TypeChildId; + use crate::distance::nodes::{ENode, NatNode}; + use rand::SeedableRng; + use rand::rngs::StdRng; + + fn dummy_ty() -> TypeChildId { + TypeChildId::Nat(NatId::new(0)) + } + + fn dummy_nat_nodes() -> HashMap> { + let mut map = HashMap::new(); + map.insert(NatId::new(0), NatNode::leaf("0".to_owned())); + map + } + + fn eid(i: usize) -> ExprChildId { + ExprChildId::EClass(EClassId::new(i)) + } + + fn cfv(classes: Vec>) -> HashMap> { + classes + .into_iter() + .enumerate() + .map(|(i, c)| (EClassId::new(i), c)) + .collect() + } + + fn cyclic_graph() -> EGraph { + // Class 0: "f" with child Class 0 (cycle!), or leaf "x" + // This represents: x, f(x), f(f(x)), f(f(f(x))), ... 
+ EGraph::new( + cfv(vec![EClass::new( + vec![ + ENode::new("f".to_owned(), vec![eid(0)]), // cycle back to self + ENode::leaf("x".to_owned()), + ], + dummy_ty(), + )]), + EClassId::new(0), + Vec::new(), + HashMap::new(), + dummy_nat_nodes(), + HashMap::new(), + ) + } + + #[test] + fn fixpoint_handles_cycles() { + let graph = cyclic_graph(); + let config = FixpointSamplerConfig::for_cyclic(); + let rng = StdRng::seed_from_u64(42); + + let mut sampler = FixpointSampler::new(&graph, 0.5, &config, rng) + .expect("Should converge with λ < 1 on cyclic graph"); + + // Should be able to sample without infinite loop + let terms = sampler.sample_many(100, false); + assert!(!terms.is_empty(), "Should produce some terms"); + + // With small lambda, most terms should be small + let small_terms = terms.iter().filter(|t| t.label() == "x").count(); + assert!( + small_terms > 30, + "With λ=0.5, should prefer leaf 'x', got {small_terms}/100" + ); + } + + #[test] + fn fixpoint_weights_converge() { + let graph = cyclic_graph(); + + // With λ = 0.5, the weight equation for the cyclic class is: + // W = λ × W + λ (from "f" with child + "x" leaf) + // W = 0.5W + 0.5 + // 0.5W = 0.5 + // W = 1.0 + // + // This means P(x) = 0.5/1.0 = 50%, P(f(...)) = 50% + let config = FixpointSamplerConfig::builder().max_depth(100).build(); + let rng = StdRng::seed_from_u64(42); + let mut sampler = FixpointSampler::new(&graph, 0.5, &config, rng).unwrap(); + + // Verify the probability distribution matches theory + let leaf_count = sampler + .sample_many(1000, false) + .iter() + .filter(|t| t.label() == "x") + .count(); + + // Should be close to 50% (allowing for variance) + assert!( + (400..600).contains(&leaf_count), + "Expected ~50% leaves, got {leaf_count}/1000" + ); + } + + #[test] + fn sampling_iter_yields_samples() { + let graph = cyclic_graph(); + let config = FixpointSamplerConfig::for_cyclic(); + let rng = StdRng::seed_from_u64(42); + + // Use the iterator interface + let samples = 
FixpointSampler::new(&graph, 0.5, &config, rng) + .expect("Should converge with λ < 1 on cyclic graph") + .into_sample_iter(false) + .take(50) + .collect::>(); + + assert_eq!(samples.len(), 50, "Should yield exactly 50 samples"); + + // All samples should have valid root labels + for sample in &samples { + assert!( + sample.label() == "f" || sample.label() == "x", + "Unexpected root label: {}", + sample.label() + ); + } + } + + #[test] + fn sampling_iter_can_access_sampler() { + let graph = cyclic_graph(); + let config = FixpointSamplerConfig::for_cyclic(); + let rng = StdRng::seed_from_u64(42); + + let sampler = FixpointSampler::new(&graph, 0.5, &config, rng).unwrap(); + let mut iter = SamplingIter::new(sampler, false); + + // Take some samples + let _ = iter.next(); + let _ = iter.next(); + + // We can still access the sampler through the iterator + let _sampler_ref = iter.sampler(); + + // And recover the sampler when done + let _sampler = iter.into_inner(); + } + + #[test] + fn find_lambda_for_target_size_works() { + let graph = cyclic_graph(); + + let mut rng = StdRng::seed_from_u64(42); + let config = FixpointSamplerConfig::builder().build(); + + // Target size of 2 (e.g., f(x) has 2 nodes) + let result = find_lambda_for_target_size(&graph, 2, &config, false, &mut rng); + assert!( + result.is_ok(), + "Should find a lambda for target size 2: {:?}", + result.err() + ); + + let (lambda, avg_size) = result.unwrap(); + assert!(lambda > 0.0 && lambda < 1.0, "Lambda should be in (0, 1)"); + // Allow some variance in the achieved size + assert!( + (1.5..4.0).contains(&avg_size), + "Average size {avg_size} should be roughly near target 2" + ); + } +}