diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index da837ba16b7..ba187bdef27 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -343,6 +343,70 @@ impl StructArray { fields, } } + + /// Returns the children of this [`StructArray`] with the struct's validity + /// bitmap AND'd into each child's validity bitmap. + /// + /// This ensures that positions where the struct itself is null are also + /// null in each returned child array. Fields that were non-nullable are + /// marked nullable in the returned [`Fields`] when the struct has nulls. + /// + /// If the struct has no nulls, children and fields are returned as-is. + /// + /// This mirrors the semantics of C++ Arrow's `StructArray::Flatten`. + /// + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::{Array, ArrayRef, Int32Array, StructArray}; + /// # use arrow_buffer::{BooleanBuffer, NullBuffer}; + /// # use arrow_schema::{DataType, Field, Fields}; + /// let child = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + /// let struct_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true])); + /// let sa = StructArray::new( + /// Fields::from(vec![Field::new("a", DataType::Int32, false)]), + /// vec![child], + /// Some(struct_nulls), + /// ); + /// let (fields, columns) = sa.flatten(); + /// assert!(fields[0].is_nullable()); + /// assert!(columns[0].is_null(1)); + /// ``` + pub fn flatten(&self) -> (Fields, Vec) { + let schema_fields = self.fields(); + + let struct_nulls = match &self.nulls { + Some(n) => n, + None => return (schema_fields.clone(), self.fields.clone()), + }; + + let new_fields: Fields = schema_fields + .iter() + .map(|f| { + if f.is_nullable() { + Arc::clone(f) + } else { + Arc::new(f.as_ref().clone().with_nullable(true)) + } + }) + .collect::>() + .into(); + + let new_columns = self + .fields + .iter() + .map(|child| { + let merged = NullBuffer::union(Some(struct_nulls), child.nulls()); + // SAFETY: We only make the null buffer more restrictive (adding nulls). + // All data buffers and child data remain unchanged. + let data = child.to_data().into_builder().nulls(merged); + make_array(unsafe { data.build_unchecked() }) + }) + .collect(); + + (new_fields, new_columns) + } } impl From for StructArray { @@ -958,4 +1022,140 @@ mod tests { StructArray::try_new(fields, arrays, nulls).expect("should not error"); } + + #[test] + fn test_flatten_no_nulls() { + let child = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let sa = StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Int32, false)), + child, + )]); + + let (fields, columns) = sa.flatten(); + + assert_eq!(columns.len(), 1); + assert!(!fields[0].is_nullable()); + assert_eq!(columns[0].null_count(), 0); + assert_eq!(columns[0].len(), 3); + } + + #[test] + fn test_flatten_struct_nulls_child_no_nulls() { + let child = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let struct_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true])); + let sa = StructArray::new( + Fields::from(vec![Field::new("a", DataType::Int32, false)]), + vec![child], + Some(struct_nulls), + ); + + let (fields, columns) = sa.flatten(); + + assert!(fields[0].is_nullable()); + assert!(columns[0].is_valid(0)); + assert!(columns[0].is_null(1)); + assert!(columns[0].is_valid(2)); + assert_eq!(columns[0].null_count(), 1); + } + + #[test] + fn test_flatten_both_have_nulls() { + // struct validity: [valid, null, valid, valid] + // child validity: [valid, valid, null, valid] + // expected: [valid, null, null, valid] + let child = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])) as ArrayRef; + let struct_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true, true])); + let sa = StructArray::new( + Fields::from(vec![Field::new("a", DataType::Int32, true)]), + vec![child], + Some(struct_nulls), + ); + + let (fields, columns) = sa.flatten(); + + assert!(fields[0].is_nullable()); + assert!(columns[0].is_valid(0)); + assert!(columns[0].is_null(1)); + assert!(columns[0].is_null(2)); + assert!(columns[0].is_valid(3)); + assert_eq!(columns[0].null_count(), 2); + } + + #[test] + fn test_flatten_sliced_struct() { + let child = Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as ArrayRef; + let struct_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true, false])); + let sa = StructArray::new( + Fields::from(vec![Field::new("a", DataType::Int32, false)]), + vec![child], + Some(struct_nulls), + ); + let sliced = sa.slice(1, 2); + + let (fields, columns) = sliced.flatten(); + + assert!(fields[0].is_nullable()); + assert_eq!(columns[0].len(), 2); + assert!(columns[0].is_null(0)); + assert!(columns[0].is_valid(1)); + } + + #[test] + fn test_flatten_multiple_children() { + let int_child = Arc::new(Int32Array::from(vec![Some(1), Some(2), None])) as ArrayRef; + let str_child = Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])) as ArrayRef; + let struct_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true])); + let sa = StructArray::new( + Fields::from(vec![ + Field::new("ints", DataType::Int32, true), + Field::new("strs", DataType::Utf8, true), + ]), + vec![int_child, str_child], + Some(struct_nulls), + ); + + let (fields, columns) = sa.flatten(); + + assert_eq!(fields.len(), 2); + // int: [valid, null(struct), null(child)] => null_count=2 + assert_eq!(columns[0].null_count(), 2); + assert!(columns[0].is_valid(0)); + assert!(columns[0].is_null(1)); + assert!(columns[0].is_null(2)); + // str: [valid, null(struct+child), valid] => null_count=1 + assert_eq!(columns[1].null_count(), 1); + assert!(columns[1].is_valid(0)); + assert!(columns[1].is_null(1)); + assert!(columns[1].is_valid(2)); + } + + #[test] + fn test_flatten_empty_struct() { + let sa = StructArray::new_empty_fields(5, Some(NullBuffer::new_null(5))); + + let (fields, columns) = sa.flatten(); + + assert_eq!(fields.len(), 0); + assert_eq!(columns.len(), 0); + } + + #[test] + fn test_flatten_field_nullability_update() { + let non_null_child = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let nullable_child = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef; + let struct_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, true, false])); + let sa = StructArray::new( + Fields::from(vec![ + Field::new("non_null", DataType::Int32, false), + Field::new("nullable", DataType::Int32, true), + ]), + vec![non_null_child, nullable_child], + Some(struct_nulls), + ); + + let (fields, _columns) = sa.flatten(); + + assert!(fields[0].is_nullable()); // was false, now true + assert!(fields[1].is_nullable()); // was true, stays true + } } diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index f400ac4d0de..e05450a97f7 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -543,37 +543,29 @@ impl RecordBatch { 0 => usize::MAX, val => val, }; - let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self + let mut stack: Vec<(usize, ArrayRef, String, FieldRef)> = self .columns .iter() .zip(self.schema.fields()) .rev() - .map(|(c, f)| { - let name_vec: Vec<&str> = vec![f.name()]; - (0, c, name_vec, f) - }) + .map(|(c, f)| (0, c.clone(), f.name().clone(), Arc::clone(f))) .collect(); let mut columns: Vec = Vec::new(); let mut fields: Vec = Vec::new(); while let Some((depth, c, name, field_ref)) = stack.pop() { match field_ref.data_type() { - DataType::Struct(ff) if depth < max_level => { - // Need to zip these in reverse to maintain original order - for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() { - let mut name = name.clone(); - name.push(separator); - name.push(fff.name()); - stack.push((depth + 1, cff, name, fff)) + DataType::Struct(_) if depth < max_level => { + let (flat_fields, flat_cols) = c.as_struct().flatten(); + for (cff, fff) in flat_cols.into_iter().zip(flat_fields.iter()).rev() { + let child_name = format!("{name}{separator}{}", fff.name()); + stack.push((depth + 1, cff, child_name, Arc::clone(fff))) } } _ => { - let updated_field = Field::new( - name.concat(), - field_ref.data_type().clone(), - field_ref.is_nullable(), - ); - columns.push(c.clone()); + let updated_field = + Field::new(name, field_ref.data_type().clone(), field_ref.is_nullable()); + columns.push(c); fields.push(Arc::new(updated_field)); } } @@ -973,7 +965,7 @@ mod tests { use crate::{ BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray, }; - use arrow_buffer::{Buffer, ToByteSlice}; + use arrow_buffer::{Buffer, NullBuffer, ToByteSlice}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::Fields; use std::collections::HashMap; @@ -1771,4 +1763,33 @@ mod tests { "bar" ); } + + #[test] + fn test_normalize_nullable_struct() { + let child = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let struct_nulls = + NullBuffer::new(arrow_buffer::BooleanBuffer::from(vec![true, false, true])); + let struct_array = Arc::new(StructArray::new( + Fields::from(vec![Field::new("x", DataType::Int32, false)]), + vec![child], + Some(struct_nulls), + )) as ArrayRef; + + let schema = Schema::new(vec![Field::new( + "s", + DataType::Struct(Fields::from(vec![Field::new("x", DataType::Int32, false)])), + true, + )]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![struct_array]).unwrap(); + + let normalized = batch.normalize(".", None).unwrap(); + + assert_eq!(normalized.num_columns(), 1); + assert_eq!(normalized.schema().field(0).name(), "s.x"); + assert!(normalized.schema().field(0).is_nullable()); + let col = normalized.column(0); + assert!(col.is_valid(0)); + assert!(col.is_null(1)); + assert!(col.is_valid(2)); + } }