Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions arrow-array/src/array/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,70 @@ impl StructArray {
fields,
}
}

/// Returns the children of this [`StructArray`] with the struct's validity
/// bitmap AND'd into each child's validity bitmap.
///
/// This ensures that positions where the struct itself is null are also
/// null in each returned child array. Fields that were non-nullable are
/// marked nullable in the returned [`Fields`] when the struct has nulls.
///
/// If the struct has no nulls, children and fields are returned as-is. This
/// includes the case where a validity buffer is present but contains no
/// null entries (for example after slicing away every null row).
///
/// Children of type `DataType::Null` are returned unchanged: every slot of
/// such an array is already logically null and the array carries no
/// validity buffer of its own.
///
/// This mirrors the semantics of C++ Arrow's `StructArray::Flatten`.
///
/// # Example
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{Array, ArrayRef, Int32Array, StructArray};
/// # use arrow_buffer::{BooleanBuffer, NullBuffer};
/// # use arrow_schema::{DataType, Field, Fields};
/// let child = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
/// let struct_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true]));
/// let sa = StructArray::new(
///     Fields::from(vec![Field::new("a", DataType::Int32, false)]),
///     vec![child],
///     Some(struct_nulls),
/// );
/// let (fields, columns) = sa.flatten();
/// assert!(fields[0].is_nullable());
/// assert!(columns[0].is_null(1));
/// ```
pub fn flatten(&self) -> (Fields, Vec<ArrayRef>) {
    let schema_fields = self.fields();

    // Only a buffer that actually masks at least one row changes anything.
    // A present-but-all-valid buffer is treated like an absent one so that
    // field nullability is not widened and children are not rebuilt for
    // nothing.
    let struct_nulls = match self.nulls.as_ref().filter(|n| n.null_count() > 0) {
        Some(n) => n,
        None => return (schema_fields.clone(), self.fields.clone()),
    };

    // Every child may now contain nulls, so every field must be nullable.
    // Already-nullable fields are shared rather than copied.
    let new_fields: Fields = schema_fields
        .iter()
        .map(|f| {
            if f.is_nullable() {
                Arc::clone(f)
            } else {
                Arc::new(f.as_ref().clone().with_nullable(true))
            }
        })
        .collect();

    let new_columns = self
        .fields
        .iter()
        .map(|child| {
            // A Null-typed array has no validity buffer and is already
            // entirely (logically) null; attaching a buffer to it would be
            // rejected when the array is rebuilt from its data.
            if child.data_type().is_null() {
                return Arc::clone(child);
            }

            let merged = NullBuffer::union(Some(struct_nulls), child.nulls());
            // SAFETY: We only make the null buffer more restrictive (adding nulls).
            // All data buffers and child data remain unchanged.
            let data = child.to_data().into_builder().nulls(merged);
            make_array(unsafe { data.build_unchecked() })
        })
        .collect();

    (new_fields, new_columns)
}
}

impl From<ArrayData> for StructArray {
Expand Down Expand Up @@ -958,4 +1022,140 @@ mod tests {

StructArray::try_new(fields, arrays, nulls).expect("should not error");
}

#[test]
fn test_flatten_no_nulls() {
    // Without a struct-level validity buffer the child and its field must
    // come back exactly as they went in.
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let field = Arc::new(Field::new("a", DataType::Int32, false));
    let array = StructArray::from(vec![(field, values)]);

    let (out_fields, out_columns) = array.flatten();

    assert_eq!(out_columns.len(), 1);
    assert!(!out_fields[0].is_nullable());

    let column = &out_columns[0];
    assert_eq!(column.null_count(), 0);
    assert_eq!(column.len(), 3);
}

#[test]
fn test_flatten_struct_nulls_child_no_nulls() {
    // The struct masks row 1; the child itself declares no nulls at all.
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true]));
    let schema = Fields::from(vec![Field::new("a", DataType::Int32, false)]);
    let array = StructArray::new(schema, vec![values], Some(validity));

    let (out_fields, out_columns) = array.flatten();
    let column = &out_columns[0];

    // The non-nullable field is widened because the struct contributes a null.
    assert!(out_fields[0].is_nullable());

    let observed: Vec<bool> = (0..3).map(|row| column.is_valid(row)).collect();
    assert_eq!(observed, vec![true, false, true]);
    assert_eq!(column.null_count(), 1);
}

#[test]
fn test_flatten_both_have_nulls() {
    // struct validity: [valid, null,  valid, valid]
    // child validity:  [valid, valid, null,  valid]
    // expected:        [valid, null,  null,  valid]
    let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)]));
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true, true]));
    let schema = Fields::from(vec![Field::new("a", DataType::Int32, true)]);
    let array = StructArray::new(schema, vec![values], Some(validity));

    let (out_fields, out_columns) = array.flatten();
    let column = &out_columns[0];

    assert!(out_fields[0].is_nullable());

    let observed: Vec<bool> = (0..4).map(|row| column.is_valid(row)).collect();
    assert_eq!(observed, vec![true, false, false, true]);
    assert_eq!(column.null_count(), 2);
}

#[test]
fn test_flatten_sliced_struct() {
    // Full struct validity is [valid, null, valid, null]; the window taken
    // below covers rows 1..3, i.e. [null, valid].
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true, false]));
    let schema = Fields::from(vec![Field::new("a", DataType::Int32, false)]);
    let window = StructArray::new(schema, vec![values], Some(validity)).slice(1, 2);

    let (out_fields, out_columns) = window.flatten();
    let column = &out_columns[0];

    assert!(out_fields[0].is_nullable());
    assert_eq!(column.len(), 2);
    assert!(column.is_null(0));
    assert!(column.is_valid(1));
}

#[test]
fn test_flatten_multiple_children() {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None]));
    let strs: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None, Some("c")]));
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true]));
    let schema = Fields::from(vec![
        Field::new("ints", DataType::Int32, true),
        Field::new("strs", DataType::Utf8, true),
    ]);
    let array = StructArray::new(schema, vec![ints, strs], Some(validity));

    let (out_fields, out_columns) = array.flatten();
    assert_eq!(out_fields.len(), 2);

    // Per-row validity of one flattened column, rows 0..3.
    let validity_of = |col: &ArrayRef| (0..3).map(|row| col.is_valid(row)).collect::<Vec<bool>>();

    // int: [valid, null(struct), null(child)] => null_count=2
    let int_col = &out_columns[0];
    assert_eq!(int_col.null_count(), 2);
    assert_eq!(validity_of(int_col), vec![true, false, false]);

    // str: [valid, null(struct+child), valid] => null_count=1
    let str_col = &out_columns[1];
    assert_eq!(str_col.null_count(), 1);
    assert_eq!(validity_of(str_col), vec![true, false, true]);
}

#[test]
fn test_flatten_empty_struct() {
    // Five all-null rows and no child fields: nothing to flatten into.
    let all_null = NullBuffer::new_null(5);
    let array = StructArray::new_empty_fields(5, Some(all_null));

    let (out_fields, out_columns) = array.flatten();

    assert_eq!(out_fields.len(), 0);
    assert_eq!(out_columns.len(), 0);
}

#[test]
fn test_flatten_field_nullability_update() {
    let strict: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let relaxed: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let validity = NullBuffer::new(BooleanBuffer::from(vec![true, true, false]));
    let schema = Fields::from(vec![
        Field::new("non_null", DataType::Int32, false),
        Field::new("nullable", DataType::Int32, true),
    ]);
    let array = StructArray::new(schema, vec![strict, relaxed], Some(validity));

    // Only the schema side of the result is under test here.
    let (out_fields, _) = array.flatten();

    // "non_null" was declared non-nullable and must have been widened.
    assert!(out_fields[0].is_nullable());
    // "nullable" was already nullable and must remain so.
    assert!(out_fields[1].is_nullable());
}
}
59 changes: 40 additions & 19 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,37 +543,29 @@ impl RecordBatch {
0 => usize::MAX,
val => val,
};
let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
let mut stack: Vec<(usize, ArrayRef, String, FieldRef)> = self
.columns
.iter()
.zip(self.schema.fields())
.rev()
.map(|(c, f)| {
let name_vec: Vec<&str> = vec![f.name()];
(0, c, name_vec, f)
})
.map(|(c, f)| (0, c.clone(), f.name().clone(), Arc::clone(f)))
.collect();
let mut columns: Vec<ArrayRef> = Vec::new();
let mut fields: Vec<FieldRef> = Vec::new();

while let Some((depth, c, name, field_ref)) = stack.pop() {
match field_ref.data_type() {
DataType::Struct(ff) if depth < max_level => {
// Need to zip these in reverse to maintain original order
for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
let mut name = name.clone();
name.push(separator);
name.push(fff.name());
stack.push((depth + 1, cff, name, fff))
DataType::Struct(_) if depth < max_level => {
let (flat_fields, flat_cols) = c.as_struct().flatten();
for (cff, fff) in flat_cols.into_iter().zip(flat_fields.iter()).rev() {
let child_name = format!("{name}{separator}{}", fff.name());
stack.push((depth + 1, cff, child_name, Arc::clone(fff)))
}
}
_ => {
let updated_field = Field::new(
name.concat(),
field_ref.data_type().clone(),
field_ref.is_nullable(),
);
columns.push(c.clone());
let updated_field =
Field::new(name, field_ref.data_type().clone(), field_ref.is_nullable());
columns.push(c);
fields.push(Arc::new(updated_field));
}
}
Expand Down Expand Up @@ -973,7 +965,7 @@ mod tests {
use crate::{
BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray,
};
use arrow_buffer::{Buffer, ToByteSlice};
use arrow_buffer::{Buffer, NullBuffer, ToByteSlice};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::Fields;
use std::collections::HashMap;
Expand Down Expand Up @@ -1771,4 +1763,33 @@ mod tests {
"bar"
);
}

#[test]
fn test_normalize_nullable_struct() {
    // The struct column "s" is null at row 1 while its only child "x" is
    // declared non-nullable and holds no nulls of its own.
    let inner = Fields::from(vec![Field::new("x", DataType::Int32, false)]);
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let validity = NullBuffer::new(arrow_buffer::BooleanBuffer::from(vec![true, false, true]));
    let column: ArrayRef = Arc::new(StructArray::new(
        inner.clone(),
        vec![values],
        Some(validity),
    ));

    let schema = Arc::new(Schema::new(vec![Field::new(
        "s",
        DataType::Struct(inner),
        true,
    )]));
    let batch = RecordBatch::try_new(schema, vec![column]).unwrap();

    let normalized = batch.normalize(".", None).unwrap();

    assert_eq!(normalized.num_columns(), 1);

    // The flattened field takes the dotted path name and becomes nullable.
    let out_schema = normalized.schema();
    let out_field = out_schema.field(0);
    assert_eq!(out_field.name(), "s.x");
    assert!(out_field.is_nullable());

    // The struct-level null at row 1 must show up in the flattened child.
    let flattened = normalized.column(0);
    assert!(flattened.is_valid(0));
    assert!(flattened.is_null(1));
    assert!(flattened.is_valid(2));
}
}