diff --git a/vortex-array/src/expr/analysis/immediate_access.rs b/vortex-array/src/expr/analysis/immediate_access.rs index 6dc26a7c950..f68cd3567dc 100644 --- a/vortex-array/src/expr/analysis/immediate_access.rs +++ b/vortex-array/src/expr/analysis/immediate_access.rs @@ -16,15 +16,44 @@ use crate::expr::exprs::select::Select; pub type FieldAccesses<'a> = Annotations<'a, FieldName>; -/// An [`AnnotationFn`] for annotating scope accesses. -pub fn annotate_scope_access(scope: &StructFields) -> impl AnnotationFn { +/// Returns the "free fields" for this expression node. +/// +/// A "free field" is a top-level field from the root scope that this expression references—not +/// nested fields within those top-level fields. For example, `root().a.b` has free field `{a}`, +/// not `{b}`, because `a` is the top-level field being accessed from root. +/// +/// The term "free" is borrowed from PL theory's "free variables"—variables that reference an +/// outer scope rather than being introduced locally. +/// +/// This is useful for column pruning, where we only need to read the top-level fields that an +/// expression actually touches. +/// +/// # Annotation Rules +/// +/// - **[`Select`]**: Returns the included field names if the child is [`Root`]. +/// - **[`GetItem`] on [`Root`]**: Returns `[field_name]` if the child is [`Root`]. +/// - **[`Root`]**: Returns all field names from `scope` (conservative over-approximation). +/// - **Everything else**: Returns empty (annotations aggregate from children automatically). +/// +/// # Example +/// +/// Given `scope = {a: {b: .., c: ..}, d: ..}` and `expr = root().a.b + root().d`: +/// - `root().a` has free fields `{a}`. +/// - `root().d` has free fields `{d}`. +/// - The full expression has free fields `{a, d}` (not `b`, only top-level fields are tracked). +pub fn make_free_field_annotator( + scope: &StructFields, +) -> impl AnnotationFn { move |expr: &Expression| { - assert!( - !expr.is::() { + if expr.child(0).is::() { + return selection + .normalize_to_included_fields(scope.names()) + .vortex_expect("Select fields must be valid for scope") + .into_iter() + .collect(); + } + } else if let Some(field_name) = expr.as_opt::() { if expr.child(0).is::() { return vec![field_name.clone()]; } @@ -47,7 +76,7 @@ pub fn immediate_scope_accesses<'a>( expr: &'a Expression, scope: &'a StructFields, ) -> FieldAccesses<'a> { - descendent_annotations(expr, annotate_scope_access(scope)) + descendent_annotations(expr, make_free_field_annotator(scope)) } /// This returns the immediate scope_access (as explained `immediate_scope_accesses`) for `expr`. diff --git a/vortex-array/src/expr/exprs/select.rs b/vortex-array/src/expr/exprs/select.rs index 5782806c530..f3ed0c41b49 100644 --- a/vortex-array/src/expr/exprs/select.rs +++ b/vortex-array/src/expr/exprs/select.rs @@ -19,18 +19,18 @@ use vortex_session::VortexSession; use crate::IntoArray; use crate::arrays::StructArray; +use crate::expr; use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::ExecutionArgs; use crate::expr::ExecutionResult; use crate::expr::ExprId; +use crate::expr::Pack; use crate::expr::SimplifyCtx; use crate::expr::VTable; use crate::expr::VTableExt; use crate::expr::expression::Expression; use crate::expr::field::DisplayFieldNames; -use crate::expr::get_item; -use crate::expr::pack; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum FieldSelection { @@ -47,7 +47,7 @@ impl VTable for Select { ExprId::new_ref("vortex.select") } - fn serialize(&self, instance: &Self::Options) -> VortexResult>> { + fn serialize(&self, instance: &FieldSelection) -> VortexResult>> { let opts = match instance { FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames { names: fields.iter().map(|f| f.to_string()).collect(), @@ -65,7 +65,7 @@ impl VTable for Select { &self, _metadata: &[u8], _session: &VortexSession, - ) -> VortexResult { + ) -> VortexResult { let prost_metadata = SelectOpts::decode(_metadata)?; let select_opts = prost_metadata @@ -84,11 +84,11 @@ impl VTable for Select { Ok(field_selection) } - fn arity(&self, _options: &Self::Options) -> Arity { + fn arity(&self, _options: &FieldSelection) -> Arity { Arity::Exact(1) } - fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { + fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::new_ref("child"), _ => unreachable!(), @@ -165,44 +165,66 @@ impl VTable for Select { fn simplify( &self, - options: &Self::Options, + selection: &FieldSelection, expr: &Expression, ctx: &dyn SimplifyCtx, ) -> VortexResult> { - let child = expr.child(0); - let child_dtype = ctx.return_dtype(child)?; - let child_nullability = child_dtype.nullability(); + let child_struct = expr.child(0); + let struct_dtype = ctx.return_dtype(child_struct)?; + let struct_nullability = struct_dtype.nullability(); - let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| { + let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| { vortex_err!( "Select child must return a struct dtype, however it was a {}", - child_dtype + struct_dtype ) })?; - let expr = pack( - options - .as_include_names(child_dtype.names()) - .map_err(|e| { - e.with_context(format!( - "Select fields {:?} must be a subset of child fields {:?}", - options, - child_dtype.names() - )) - })? - .iter() - .map(|name| (name.clone(), get_item(name.clone(), child.clone()))), - child_nullability, - ); + // "Mask" out the unwanted fields of the child struct `DType`. + let included_fields = selection.normalize_to_included_fields(struct_fields.names())?; + let all_included_fields_are_nullable = included_fields.iter().all(|name| { + struct_fields + .field(name) + .vortex_expect( + "`normalize_to_included_fields` checks that the included fields already exist \ + in `struct_fields`", + ) + .is_nullable() + }); + + // We cannot always convert a `select` into a `pack(get_item(f1), get_item(f2), ...)`. + // This is because `get_item` does a validity intersection of the struct validity with its + // fields, which is not the same as just "masking" out the unwanted fields (a selection). + // + // We can, however, make this simplification when the child of the `select` is already a + // `pack` and we know that `get_item` will do no validity intersections. + let child_is_pack = child_struct.is::(); + + // `get_item` only performs validity intersection when the struct is nullable but the field + // is not. This would change the semantics of a `select`, so we can only simplify when this + // won't happen. + let would_intersect_validity = + struct_nullability.is_nullable() && !all_included_fields_are_nullable; + + if child_is_pack && !would_intersect_validity { + let pack_expr = expr::pack( + included_fields + .into_iter() + .map(|name| (name.clone(), expr::get_item(name, child_struct.clone()))), + struct_nullability, + ); - Ok(Some(expr)) + return Ok(Some(pack_expr)); + } + + Ok(None) } - fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { + fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool { true } - fn is_fallible(&self, _instance: &Self::Options) -> bool { + fn is_fallible(&self, _instance: &FieldSelection) -> bool { // If this type-checks its infallible. false } @@ -260,21 +282,26 @@ impl FieldSelection { fields } - pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult { + pub fn normalize_to_included_fields( + &self, + available_fields: &FieldNames, + ) -> VortexResult { + // Check that all of the field names exist in the available fields. if self .field_names() .iter() - .any(|f| !field_names.iter().contains(f)) + .any(|f| !available_fields.iter().contains(f)) { vortex_bail!( - "Field {:?} in select not in field names {:?}", + "Select fields {:?} must be a subset of child fields {:?}", self, - field_names + available_fields ); } + match self { FieldSelection::Include(fields) => Ok(fields.clone()), - FieldSelection::Exclude(exc_fields) => Ok(field_names + FieldSelection::Exclude(exc_fields) => Ok(available_fields .iter() .filter(|f| !exc_fields.iter().contains(f)) .cloned() @@ -308,7 +335,6 @@ mod tests { use crate::IntoArray; use crate::ToCanonical; use crate::arrays::StructArray; - use crate::expr::exprs::pack::Pack; use crate::expr::exprs::root::root; use crate::expr::exprs::select::Select; use crate::expr::test_harness; @@ -393,11 +419,11 @@ mod tests { assert_eq!( &include .as_::() - .as_include_names(&field_names) + .normalize_to_included_fields(&field_names) .unwrap() ); } @@ -412,7 +438,6 @@ mod tests { let result = e.optimize_recursive(&dtype).unwrap(); - assert!(result.is::()); assert!(result.return_dtype(&dtype).unwrap().is_nullable()); } @@ -431,8 +456,6 @@ mod tests { let result = e.optimize_recursive(&dtype).unwrap(); - assert!(result.is::()); - // Should exclude "c" and include "a" and "b" let result_dtype = result.return_dtype(&dtype).unwrap(); assert!(result_dtype.is_nullable()); diff --git a/vortex-array/src/expr/transform/partition.rs b/vortex-array/src/expr/transform/partition.rs index 4994ae55f9d..44c40986da9 100644 --- a/vortex-array/src/expr/transform/partition.rs +++ b/vortex-array/src/expr/transform/partition.rs @@ -211,7 +211,7 @@ mod tests { use vortex_dtype::StructFields; use super::*; - use crate::expr::analysis::annotate_scope_access; + use crate::expr::analysis::make_free_field_annotator; use crate::expr::exprs::binary::and; use crate::expr::exprs::get_item::col; use crate::expr::exprs::get_item::get_item; @@ -219,7 +219,6 @@ mod tests { use crate::expr::exprs::merge::merge; use crate::expr::exprs::pack::pack; use crate::expr::exprs::root::root; - use crate::expr::exprs::select::select; use crate::expr::transform::replace::replace_root_fields; #[fixture] @@ -245,7 +244,8 @@ mod tests { let fields = dtype.as_struct_fields_opt().unwrap(); let expr = root(); - let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap(); + let partitioned = + partition(expr.clone(), &dtype, make_free_field_annotator(fields)).unwrap(); // An un-expanded root expression is annotated by all fields, but since it is a single node assert_eq!(partitioned.partitions.len(), 0); @@ -253,7 +253,7 @@ mod tests { // Instead, callers must expand the root expression themselves. let expr = replace_root_fields(expr, fields); - let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap(); + let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap(); assert_eq!(partitioned.partitions.len(), fields.names().len()); } @@ -264,7 +264,7 @@ mod tests { let expr = get_item("y", get_item("a", root())); - let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap(); + let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap(); assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root()))); } @@ -280,7 +280,7 @@ mod tests { ], NonNullable, ); - let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap(); + let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap(); let split_a = partitioned.find_partition(&"a".into()).unwrap(); assert_eq!( @@ -300,7 +300,7 @@ mod tests { let fields = dtype.as_struct_fields_opt().unwrap(); let expr = and(get_item("y", get_item("a", root())), lit(1)); - let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap(); + let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap(); // Whole expr is a single split assert_eq!(partitioned.partitions.len(), 1); @@ -311,60 +311,19 @@ mod tests { let fields = dtype.as_struct_fields_opt().unwrap(); let expr = and(get_item("y", get_item("a", root())), get_item("b", root())); - let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap(); + let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap(); // One for id.a and id.b assert_eq!(partitioned.partitions.len(), 2); } - // Test that typed_simplify removes select and partition precise - #[rstest] - fn test_expr_partition_many_occurrences_of_field(dtype: DType) { - let fields = dtype.as_struct_fields_opt().unwrap(); - - let expr = and( - get_item("y", get_item("a", root())), - select(["a", "b"], root()), - ); - let expr = expr.optimize_recursive(&dtype).unwrap(); - let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap(); - - // One for id.a and id.b - assert_eq!(partitioned.partitions.len(), 2); - - // This fetches [].$c which is unused, however a previous optimisation should replace select - // with get_item and pack removing this field. - assert_eq!( - &partitioned.root, - &and( - get_item("a_0", get_item("a", root())), - pack( - [ - ( - "a", - get_item( - StructFieldExpressionSplitter::::field_name( - &"a".into(), - 1 - ), - get_item("a", root()) - ) - ), - ("b", get_item("b_0", get_item("b", root()))) - ], - NonNullable - ) - ) - ) - } - #[rstest] fn test_expr_merge(dtype: DType) { let fields = dtype.as_struct_fields_opt().unwrap(); let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]); - let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap(); + let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap(); let expected = pack( [ ("x", get_item("x", get_item("a_0", col("a")))), diff --git a/vortex-layout/src/layouts/struct_/reader.rs b/vortex-layout/src/layouts/struct_/reader.rs index 66c06505ff5..acc9f07f1b8 100644 --- a/vortex-layout/src/layouts/struct_/reader.rs +++ b/vortex-layout/src/layouts/struct_/reader.rs @@ -17,8 +17,8 @@ use vortex_array::expr::ExactExpr; use vortex_array::expr::Expression; use vortex_array::expr::Merge; use vortex_array::expr::Pack; -use vortex_array::expr::annotate_scope_access; use vortex_array::expr::col; +use vortex_array::expr::make_free_field_annotator; use vortex_array::expr::root; use vortex_array::expr::transform::PartitionedExpr; use vortex_array::expr::transform::partition; @@ -166,7 +166,7 @@ impl StructReader { let mut partitioned = partition( expr.clone(), self.dtype(), - annotate_scope_access( + make_free_field_annotator( self.dtype() .as_struct_fields_opt() .vortex_expect("We know it's a struct DType"), @@ -405,6 +405,7 @@ mod tests { use vortex_dtype::FieldName; use vortex_dtype::Nullability; use vortex_dtype::PType; + use vortex_dtype::StructFields; use vortex_io::runtime::single::block_on; use vortex_mask::Mask; use vortex_scalar::Scalar; @@ -699,7 +700,7 @@ mod tests { #[from(nested_struct_layout)] (segments, layout): (Arc, LayoutRef), ) { // Project out the nested struct field. - // The projection should preserve the nulls of the `a` column when we select out the + // The projection should preserve the nulls of the `b` struct when we select out the // child column `c`. let reader = layout.new_reader("".into(), segments, &SESSION).unwrap(); let expr = select( @@ -712,9 +713,21 @@ mod tests { .unwrap(); let result = block_on(move |_| project).unwrap(); - assert!(result.dtype().is_struct()); - // Struct scalars holding the "c" field value scalars + // The result is a nullable struct (because root.a.b is nullable) with a non-nullable + // field "c" (because the original field was non-nullable). + assert_eq!( + result.dtype(), + &DType::Struct( + StructFields::from_iter([( + "c", + DType::Primitive(PType::I32, Nullability::NonNullable) + )]), + Nullability::Nullable, + ) + ); + + // Row 0: struct is valid, field "c" is 4. assert_eq!( result .scalar_at(0) @@ -722,17 +735,13 @@ mod tests { .as_struct() .field_by_idx(0) .unwrap(), - Scalar::primitive(4, Nullability::Nullable) - ); - assert!( - result - .scalar_at(1) - .unwrap() - .as_struct() - .field_by_idx(0) - .unwrap() - .is_null(), + Scalar::primitive(4, Nullability::NonNullable) ); + + // Row 1: struct is null (because root.a.b was null at this row). + assert!(result.scalar_at(1).unwrap().as_struct().is_null()); + + // Row 2: struct is valid, field "c" is 6. assert_eq!( result .scalar_at(2) @@ -740,7 +749,7 @@ mod tests { .as_struct() .field_by_idx(0) .unwrap(), - Scalar::primitive(6, Nullability::Nullable) + Scalar::primitive(6, Nullability::NonNullable) ); }