Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 33 additions & 66 deletions src/bin/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use clap::{Args as ClapArgs, Parser};

use serde::de::DeserializeOwned;

use eggshell::distance::{EGraph, Expr, Label, TreeNode, UnitCost, find_min_struct, find_min_zs};
use eggshell::distance::{EGraph, Expr, Label, TreeNode, UnitCost, find_min_sampling_zs};

#[derive(Parser)]
#[command(about = "Find the closest tree in an e-graph to a reference tree")]
Expand All @@ -22,17 +22,20 @@ Examples:
# With revisits and quiet mode
extract graph.json -e '(foo bar)' -r 2 -q
")]
#[expect(clippy::struct_excessive_bools)]
struct Args {
/// Path to the serialized e-graph JSON file
egraph: String,

#[command(flatten)]
reference: RefSource,

/// Maximum number of times a node may be revisited (for cycles)
#[arg(short = 'r', long, default_value_t = 0)]
max_revisits: usize,
/// Target weight
#[arg(short, long, default_value_t = 0)]
target_weight: usize,

/// Number of samples
#[arg(short = 'n', long, default_value_t = 10000)]
samples: usize,

/// Include the types in the comparison
#[arg(short, long)]
Expand All @@ -41,14 +44,13 @@ struct Args {
/// Use raw string labels instead of Rise-typed labels (for regression testing)
#[arg(long)]
raw_strings: bool,
// /// Use structural distance instead of Zhang-Shasha tree edit distance
// #[arg(short, long)]
// structural: bool,

/// Use structural distance instead of Zhang-Shasha tree edit distance
#[arg(short, long)]
structural: bool,

/// Ignore the labels when using the structural option
#[arg(short, long, requires_all = ["structural"])]
ignore_labels: bool,
// /// Ignore the labels when using the structural option
// #[arg(short, long, requires_all = ["structural"])]
// ignore_labels: bool,
}

#[derive(ClapArgs)]
Expand Down Expand Up @@ -94,6 +96,16 @@ where

println!(" Root e-class: {:?}", graph.root());

let ref_tree = parse_ref(args, parse_tree);

run_extraction(&graph, &ref_tree, args);
}

fn parse_ref<L, F>(args: &Args, parse_tree: F) -> TreeNode<L>
where
L: Label + std::fmt::Display + DeserializeOwned,
F: Fn(&str) -> TreeNode<L>,
{
let ref_tree: TreeNode<L> = if let Some(expr) = &args.reference.expr {
println!("Parsing reference tree from command line...");
parse_tree(expr)
Expand All @@ -116,50 +128,25 @@ where
})
.unwrap_or_else(|| panic!("No tree with name {name} found"))
};

run_extraction(&graph, &ref_tree, args);
ref_tree
}

fn run_extraction<L: Label + std::fmt::Display>(
graph: &EGraph<L>,
ref_tree: &TreeNode<L>,
args: &Args,
) {
#[expect(clippy::cast_precision_loss)]
fn run_extraction<L: Label>(graph: &EGraph<L>, ref_tree: &TreeNode<L>, args: &Args) {
let ref_node_count = ref_tree.size();
println!(" Reference tree has {ref_node_count} nodes");

// Count trees in the e-graph
println!(
"\nCounting trees in e-graph (max_revisits={})...",
args.max_revisits
);
let count_start = Instant::now();
let tree_count = graph.count_trees(args.max_revisits);
let count_time = count_start.elapsed();
println!(" Found {tree_count} trees in {count_time:.2?}");

if tree_count == 0 {
println!("No trees found in e-graph!");
return;
}
let ref_stripped_count = ref_tree.strip_types().size();
println!(" Reference tree has {ref_node_count} nodes ({ref_stripped_count} without types)");

if args.structural {
run_structural(graph, ref_tree, args);
} else {
run_zs(graph, ref_tree, args);
}
}

#[expect(clippy::cast_precision_loss)]
fn run_zs<L: Label>(graph: &EGraph<L>, ref_tree: &TreeNode<L>, args: &Args) {
let start = Instant::now();
println!("\n--- Zhang-Shasha extraction (with lower-bound pruning) ---");
if let (Some(result), stats) = find_min_zs(
if let (Some(result), stats) = find_min_sampling_zs(
graph,
ref_tree,
&UnitCost,
args.max_revisits,
args.with_types,
args.samples,
args.target_weight,
100,
) {
println!(" Best distance: {}", result.1);
println!(" Time: {:.2?}", start.elapsed());
Expand All @@ -186,23 +173,3 @@ fn run_zs<L: Label>(graph: &EGraph<L>, ref_tree: &TreeNode<L>, args: &Args) {
println!(" No result found!");
}
}

fn run_structural<L: Label>(graph: &EGraph<L>, ref_tree: &TreeNode<L>, args: &Args) {
let start = Instant::now();
println!("\n--- Structural distance extraction ---");
if let Some((tree, distance)) = find_min_struct(
graph,
ref_tree,
&UnitCost,
args.max_revisits,
args.with_types,
args.ignore_labels,
) {
println!(" Best distance: {distance}");
println!(" Time: {:.2?}", start.elapsed());
println!("\n Best tree:");
println!("{tree}");
} else {
println!(" No result found!");
}
}
Loading
Loading