diff --git a/encodings/alp/public-api.lock b/encodings/alp/public-api.lock index 31111f5dc7d..79d7c8f61bb 100644 --- a/encodings/alp/public-api.lock +++ b/encodings/alp/public-api.lock @@ -262,7 +262,7 @@ pub fn vortex_alp::ALPVTable::between(array: &vortex_alp::ALPArray, lower: &dyn impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_alp::ALPVTable -pub fn vortex_alp::ALPVTable::compare(lhs: &vortex_alp::ALPArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_alp::ALPVTable::compare(lhs: &vortex_alp::ALPArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::expr::exprs::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_alp::ALPVTable diff --git a/encodings/alp/src/alp/compute/compare.rs b/encodings/alp/src/alp/compute/compare.rs index d96b4bd2375..83441c4da1c 100644 --- a/encodings/alp/src/alp/compute/compare.rs +++ b/encodings/alp/src/alp/compute/compare.rs @@ -8,10 +8,11 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::NativePType; use vortex_array::expr::CompareKernel; +use vortex_array::expr::CompareOperator; +use vortex_array::expr::Operator; use vortex_array::scalar::Scalar; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -28,7 +29,7 @@ impl CompareKernel for ALPVTable { fn compare( lhs: &ALPArray, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { if lhs.patches().is_some() { @@ -71,7 +72,7 @@ impl CompareKernel for ALPVTable { fn alp_scalar_compare>( alp: &ALPArray, value: F, - operator: Operator, + operator: CompareOperator, ) -> VortexResult> where F::ALPInt: Into, @@ -89,16 +90,19 @@ where match encoded { Some(encoded) => { let s = ConstantArray::new(encoded, alp.len()); - Ok(Some(compare(alp.encoded(), s.as_ref(), operator)?)) + Ok(Some( + alp.encoded() + .binary(s.into_array(), Operator::from(operator))?, + )) } None => match operator { // Since this value is not encodable it cannot be equal to any value in the encoded // array. - Operator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())), + CompareOperator::Eq => Ok(Some(ConstantArray::new(false, alp.len()).into_array())), // Since this value is not encodable it cannot be equal to any value in the encoded // array, hence != to all values in the encoded array. - Operator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())), - Operator::Gt | Operator::Gte => { + CompareOperator::NotEq => Ok(Some(ConstantArray::new(true, alp.len()).into_array())), + CompareOperator::Gt | CompareOperator::Gte => { // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan. // All values in the encoded array are definitely finite let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value); @@ -107,17 +111,19 @@ where ConstantArray::new(value.is_sign_negative(), alp.len()).into_array(), )) } else { - Ok(Some(compare( - alp.encoded(), - ConstantArray::new(F::encode_above(value, exponents), alp.len()).as_ref(), - // Since the encoded value is unencodable gte is equivalent to gt. - // Consider a value v, between two encodable values v_l (just less) and - // v_a (just above), then for all encodable values (u), v > u <=> v_g >= u - Operator::Gte, - )?)) + Ok(Some( + alp.encoded().binary( + ConstantArray::new(F::encode_above(value, exponents), alp.len()) + .into_array(), + // Since the encoded value is unencodable gte is equivalent to gt. + // Consider a value v, between two encodable values v_l (just less) and + // v_a (just above), then for all encodable values (u), v > u <=> v_g >= u + Operator::Gte, + )?, + )) } } - Operator::Lt | Operator::Lte => { + CompareOperator::Lt | CompareOperator::Lte => { // Per IEEE 754 totalOrder semantics the ordering is -Nan < -Inf < Inf < Nan. // All values in the encoded array are definitely finite let is_not_finite = NativePType::is_infinite(value) || NativePType::is_nan(value); @@ -126,13 +132,15 @@ where ConstantArray::new(value.is_sign_positive(), alp.len()).into_array(), )) } else { - Ok(Some(compare( - alp.encoded(), - ConstantArray::new(F::encode_below(value, exponents), alp.len()).as_ref(), - // Since the encoded values unencodable lt is equivalent to lte. - // See Gt | Gte for further explanation. - Operator::Lte, - )?)) + Ok(Some( + alp.encoded().binary( + ConstantArray::new(F::encode_below(value, exponents), alp.len()) + .into_array(), + // Since the encoded values unencodable lt is equivalent to lte. + // See Gt | Gte for further explanation. + Operator::Lte, + )?, + )) } } }, @@ -148,11 +156,12 @@ mod tests { use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; - use vortex_array::compute::Operator; - use vortex_array::compute::compare; + use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; + use vortex_array::expr::CompareOperator; + use vortex_array::expr::Operator; use vortex_array::scalar::Scalar; use super::*; @@ -161,7 +170,7 @@ mod tests { fn test_alp_compare>( alp: &ALPArray, value: F, - operator: Operator, + operator: CompareOperator, ) -> Option where F::ALPInt: Into, @@ -180,13 +189,13 @@ mod tests { vec![1234; 1025] ); - let r = alp_scalar_compare(&encoded, 1.3_f32, Operator::Eq) + let r = alp_scalar_compare(&encoded, 1.3_f32, CompareOperator::Eq) .unwrap() .unwrap(); let expected = BoolArray::from_iter([false; 1025]); assert_arrays_eq!(r, expected); - let r = alp_scalar_compare(&encoded, 1.234f32, Operator::Eq) + let r = alp_scalar_compare(&encoded, 1.234f32, CompareOperator::Eq) .unwrap() .unwrap(); let expected = BoolArray::from_iter([true; 1025]); @@ -204,14 +213,14 @@ mod tests { ); #[allow(clippy::excessive_precision)] - let r_eq = alp_scalar_compare(&encoded, 1.234444_f32, Operator::Eq) + let r_eq = alp_scalar_compare(&encoded, 1.234444_f32, CompareOperator::Eq) .unwrap() .unwrap(); let expected = BoolArray::from_iter([false; 1025]); assert_arrays_eq!(r_eq, expected); #[allow(clippy::excessive_precision)] - let r_neq = alp_scalar_compare(&encoded, 1.234444f32, Operator::NotEq) + let r_neq = alp_scalar_compare(&encoded, 1.234444f32, CompareOperator::NotEq) .unwrap() .unwrap(); let expected = BoolArray::from_iter([true; 1025]); @@ -229,28 +238,28 @@ mod tests { ); // !(0.0605_f32 >= 0.06051_f32); - let r_gte = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Gte) + let r_gte = alp_scalar_compare(&encoded, 0.06051_f32, CompareOperator::Gte) .unwrap() .unwrap(); let expected = BoolArray::from_iter([false; 10]); assert_arrays_eq!(r_gte, expected); // (0.0605_f32 > 0.06051_f32); - let r_gt = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Gt) + let r_gt = alp_scalar_compare(&encoded, 0.06051_f32, CompareOperator::Gt) .unwrap() .unwrap(); let expected = BoolArray::from_iter([false; 10]); assert_arrays_eq!(r_gt, expected); // 0.0605_f32 <= 0.06051_f32; - let r_lte = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Lte) + let r_lte = alp_scalar_compare(&encoded, 0.06051_f32, CompareOperator::Lte) .unwrap() .unwrap(); let expected = BoolArray::from_iter([true; 10]); assert_arrays_eq!(r_lte, expected); //0.0605_f32 < 0.06051_f32; - let r_lt = alp_scalar_compare(&encoded, 0.06051_f32, Operator::Lt) + let r_lt = alp_scalar_compare(&encoded, 0.06051_f32, CompareOperator::Lt) .unwrap() .unwrap(); let expected = BoolArray::from_iter([true; 10]); @@ -267,31 +276,31 @@ mod tests { vec![0; 10] ); - let r_gte = test_alp_compare(&encoded, -0.00000001_f32, Operator::Gte).unwrap(); + let r_gte = test_alp_compare(&encoded, -0.00000001_f32, CompareOperator::Gte).unwrap(); let expected = BoolArray::from_iter([true; 10]); assert_arrays_eq!(r_gte, expected); - let r_gte = test_alp_compare(&encoded, -0.0_f32, Operator::Gte).unwrap(); + let r_gte = test_alp_compare(&encoded, -0.0_f32, CompareOperator::Gte).unwrap(); let expected = BoolArray::from_iter([true; 10]); assert_arrays_eq!(r_gte, expected); - let r_gt = test_alp_compare(&encoded, -0.0000000001f32, Operator::Gt).unwrap(); + let r_gt = test_alp_compare(&encoded, -0.0000000001f32, CompareOperator::Gt).unwrap(); let expected = BoolArray::from_iter([true; 10]); assert_arrays_eq!(r_gt, expected); - let r_gte = test_alp_compare(&encoded, -0.0_f32, Operator::Gt).unwrap(); + let r_gte = test_alp_compare(&encoded, -0.0_f32, CompareOperator::Gt).unwrap(); let expected = BoolArray::from_iter([true; 10]); assert_arrays_eq!(r_gte, expected); - let r_lte = test_alp_compare(&encoded, 0.06051_f32, Operator::Lte).unwrap(); + let r_lte = test_alp_compare(&encoded, 0.06051_f32, CompareOperator::Lte).unwrap(); let expected = BoolArray::from_iter([true; 10]); assert_arrays_eq!(r_lte, expected); - let r_lt = test_alp_compare(&encoded, 0.06051_f32, Operator::Lt).unwrap(); + let r_lt = test_alp_compare(&encoded, 0.06051_f32, CompareOperator::Lt).unwrap(); let expected = BoolArray::from_iter([true; 10]); assert_arrays_eq!(r_lt, expected); - let r_lt = test_alp_compare(&encoded, -0.00001_f32, Operator::Lt).unwrap(); + let r_lt = test_alp_compare(&encoded, -0.00001_f32, CompareOperator::Lt).unwrap(); let expected = BoolArray::from_iter([false; 10]); assert_arrays_eq!(r_lt, expected); } @@ -305,7 +314,7 @@ mod tests { // Not supported! assert!( - alp_scalar_compare(&encoded, 1_000_000.9_f32, Operator::Eq) + alp_scalar_compare(&encoded, 1_000_000.9_f32, CompareOperator::Eq) .unwrap() .is_none() ) @@ -321,7 +330,10 @@ mod tests { array.len(), ); - let r = compare(encoded.as_ref(), other.as_ref(), Operator::Eq).unwrap(); + let r = encoded + .into_array() + .binary(other.into_array(), Operator::Eq) + .unwrap(); // Comparing to null yields null results let expected = BoolArray::from_iter([None::; 10]); assert_arrays_eq!(r, expected); @@ -336,7 +348,7 @@ mod tests { let array = PrimitiveArray::from_iter([1.234f32; 10]); let encoded = alp_encode(&array, None).unwrap(); - let r = test_alp_compare(&encoded, value, Operator::Gt).unwrap(); + let r = test_alp_compare(&encoded, value, CompareOperator::Gt).unwrap(); let expected = BoolArray::from_iter([result; 10]); assert_arrays_eq!(r, expected); } @@ -350,7 +362,7 @@ mod tests { let array = PrimitiveArray::from_iter([1.234f32; 10]); let encoded = alp_encode(&array, None).unwrap(); - let r = test_alp_compare(&encoded, value, Operator::Lt).unwrap(); + let r = test_alp_compare(&encoded, value, CompareOperator::Lt).unwrap(); let expected = BoolArray::from_iter([result; 10]); assert_arrays_eq!(r, expected); } diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index 1447e132ce0..c7068dc4774 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -91,8 +91,6 @@ mod tests { use rstest::rstest; use vortex_array::assert_arrays_eq; use vortex_array::builtins::ArrayBuiltins; - use vortex_array::compute::Operator; - use vortex_array::compute::compare; use vortex_array::compute::conformance::cast::test_cast_conformance; use vortex_array::compute::conformance::consistency::test_array_consistency; use vortex_array::compute::conformance::filter::test_filter_conformance; @@ -100,6 +98,7 @@ mod tests { use vortex_array::compute::conformance::take::test_take_conformance; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; + use vortex_array::expr::Operator; use super::*; @@ -119,7 +118,7 @@ mod tests { let lhs = ByteBoolArray::from(vec![true; 5]); let rhs = ByteBoolArray::from(vec![true; 5]); - let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let arr = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap(); let expected = ByteBoolArray::from(vec![true; 5]); assert_arrays_eq!(arr, expected.to_array()); @@ -130,7 +129,7 @@ mod tests { let lhs = ByteBoolArray::from(vec![false; 5]); let rhs = ByteBoolArray::from(vec![true; 5]); - let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let arr = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap(); let expected = ByteBoolArray::from(vec![false; 5]); assert_arrays_eq!(arr, expected.to_array()); @@ -141,7 +140,7 @@ mod tests { let lhs = ByteBoolArray::from(vec![true; 5]); let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]); - let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let arr = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap(); let expected = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]); diff --git a/encodings/datetime-parts/public-api.lock b/encodings/datetime-parts/public-api.lock index 65e22e9996e..3a05f07f8b6 100644 --- a/encodings/datetime-parts/public-api.lock +++ b/encodings/datetime-parts/public-api.lock @@ -140,7 +140,7 @@ pub fn vortex_datetime_parts::DateTimePartsVTable::is_constant(&self, array: &vo impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_datetime_parts::DateTimePartsVTable -pub fn vortex_datetime_parts::DateTimePartsVTable::compare(lhs: &vortex_datetime_parts::DateTimePartsArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_datetime_parts::DateTimePartsVTable::compare(lhs: &vortex_datetime_parts::DateTimePartsArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::expr::exprs::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_datetime_parts::DateTimePartsVTable diff --git a/encodings/datetime-parts/src/compute/compare.rs b/encodings/datetime-parts/src/compute/compare.rs index 85277383cf5..e6bd3091e18 100644 --- a/encodings/datetime-parts/src/compute/compare.rs +++ b/encodings/datetime-parts/src/compute/compare.rs @@ -7,14 +7,12 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; use vortex_array::builtins::ArrayBuiltins; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::datetime::Timestamp; use vortex_array::expr::CompareKernel; -use vortex_array::expr::and_kleene; -use vortex_array::expr::or_kleene; +use vortex_array::expr::CompareOperator; +use vortex_array::expr::Operator; use vortex_array::scalar::Scalar; use vortex_error::VortexResult; @@ -26,7 +24,7 @@ impl CompareKernel for DateTimePartsVTable { fn compare( lhs: &DateTimePartsArray, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { let Some(rhs_const) = rhs.as_constant() else { @@ -53,17 +51,17 @@ impl CompareKernel for DateTimePartsVTable { let ts_parts = timestamp::split(timestamp, options.unit)?; match operator { - Operator::Eq => compare_eq(lhs, &ts_parts, nullability), - Operator::NotEq => compare_ne(lhs, &ts_parts, nullability), + CompareOperator::Eq => compare_eq(lhs, &ts_parts, nullability), + CompareOperator::NotEq => compare_ne(lhs, &ts_parts, nullability), // lt and lte have identical behavior, as we optimize // for the case that all days on the lhs are smaller. // If that special case is not hit, we return `Ok(None)` to // signal that the comparison wasn't handled within dtp. - Operator::Lt => compare_lt(lhs, &ts_parts, nullability), - Operator::Lte => compare_lt(lhs, &ts_parts, nullability), + CompareOperator::Lt => compare_lt(lhs, &ts_parts, nullability), + CompareOperator::Lte => compare_lt(lhs, &ts_parts, nullability), // (Like for lt, lte) - Operator::Gt => compare_gt(lhs, &ts_parts, nullability), - Operator::Gte => compare_gt(lhs, &ts_parts, nullability), + CompareOperator::Gt => compare_gt(lhs, &ts_parts, nullability), + CompareOperator::Gte => compare_gt(lhs, &ts_parts, nullability), } } } @@ -73,31 +71,32 @@ fn compare_eq( ts_parts: ×tamp::TimestampParts, nullability: Nullability, ) -> VortexResult> { - let mut comparison = compare_dtp(lhs.days(), ts_parts.days, Operator::Eq, nullability)?; + let mut comparison = compare_dtp(lhs.days(), ts_parts.days, CompareOperator::Eq, nullability)?; if comparison.statistics().compute_max::() == Some(false) { // All values are different. return Ok(Some(comparison)); } - comparison = and_kleene( - &compare_dtp(lhs.seconds(), ts_parts.seconds, Operator::Eq, nullability)?, - &comparison, - )?; + comparison = compare_dtp( + lhs.seconds(), + ts_parts.seconds, + CompareOperator::Eq, + nullability, + )? + .binary(comparison, Operator::And)?; if comparison.statistics().compute_max::() == Some(false) { // All values are different. return Ok(Some(comparison)); } - comparison = and_kleene( - &compare_dtp( - lhs.subseconds(), - ts_parts.subseconds, - Operator::Eq, - nullability, - )?, - &comparison, - )?; + comparison = compare_dtp( + lhs.subseconds(), + ts_parts.subseconds, + CompareOperator::Eq, + nullability, + )? + .binary(comparison, Operator::And)?; Ok(Some(comparison)) } @@ -107,36 +106,37 @@ fn compare_ne( ts_parts: ×tamp::TimestampParts, nullability: Nullability, ) -> VortexResult> { - let mut comparison = compare_dtp(lhs.days(), ts_parts.days, Operator::NotEq, nullability)?; + let mut comparison = compare_dtp( + lhs.days(), + ts_parts.days, + CompareOperator::NotEq, + nullability, + )?; if comparison.statistics().compute_min::() == Some(true) { // All values are different. return Ok(Some(comparison)); } - comparison = or_kleene( - &compare_dtp( - lhs.seconds(), - ts_parts.seconds, - Operator::NotEq, - nullability, - )?, - &comparison, - )?; + comparison = compare_dtp( + lhs.seconds(), + ts_parts.seconds, + CompareOperator::NotEq, + nullability, + )? + .binary(comparison, Operator::Or)?; if comparison.statistics().compute_min::() == Some(true) { // All values are different. return Ok(Some(comparison)); } - comparison = or_kleene( - &compare_dtp( - lhs.subseconds(), - ts_parts.subseconds, - Operator::NotEq, - nullability, - )?, - &comparison, - )?; + comparison = compare_dtp( + lhs.subseconds(), + ts_parts.subseconds, + CompareOperator::NotEq, + nullability, + )? + .binary(comparison, Operator::Or)?; Ok(Some(comparison)) } @@ -146,7 +146,7 @@ fn compare_lt( ts_parts: ×tamp::TimestampParts, nullability: Nullability, ) -> VortexResult> { - let days_lt = compare_dtp(lhs.days(), ts_parts.days, Operator::Lt, nullability)?; + let days_lt = compare_dtp(lhs.days(), ts_parts.days, CompareOperator::Lt, nullability)?; if days_lt.statistics().compute_min::() == Some(true) { // All values on the lhs are smaller. return Ok(Some(days_lt)); @@ -160,7 +160,7 @@ fn compare_gt( ts_parts: ×tamp::TimestampParts, nullability: Nullability, ) -> VortexResult> { - let days_gt = compare_dtp(lhs.days(), ts_parts.days, Operator::Gt, nullability)?; + let days_gt = compare_dtp(lhs.days(), ts_parts.days, CompareOperator::Gt, nullability)?; if days_gt.statistics().compute_min::() == Some(true) { // All values on the lhs are larger. return Ok(Some(days_gt)); @@ -172,7 +172,7 @@ fn compare_gt( fn compare_dtp( lhs: &dyn Array, rhs: i64, - operator: Operator, + operator: CompareOperator, nullability: Nullability, ) -> VortexResult { // Since nullability is stripped from RHS and carried forward through nullability argument we want to incorporate it into lhs.dtype() that we cast rhs into @@ -180,12 +180,12 @@ fn compare_dtp( .into_array() .cast(lhs.dtype().with_nullability(nullability)) { - Ok(casted) => compare(lhs, &casted, operator), + Ok(casted) => lhs.to_array().binary(casted, Operator::from(operator)), // The narrowing cast failed. Therefore, we know lhs < rhs. _ => { let constant_value = match operator { - Operator::Eq | Operator::Gte | Operator::Gt => false, - Operator::NotEq | Operator::Lte | Operator::Lt => true, + CompareOperator::Eq | CompareOperator::Gte | CompareOperator::Gt => false, + CompareOperator::NotEq | CompareOperator::Lte | CompareOperator::Lt => true, }; Ok( ConstantArray::new(Scalar::bool(constant_value, nullability), lhs.len()) @@ -200,7 +200,6 @@ mod test { use rstest::rstest; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::TemporalArray; - use vortex_array::compute::Operator; use vortex_array::dtype::IntegerPType; use vortex_array::dtype::datetime::TimeUnit; use vortex_array::validity::Validity; @@ -228,11 +227,11 @@ mod test { fn compare_date_time_parts_eq(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) { let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); // January 2, 1970, 00:00:00 UTC let rhs = dtp_array_from_timestamp(86400i64, rhs_validity.clone()); // January 2, 1970, 00:00:00 UTC - let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let comparison = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap(); assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1); let rhs = dtp_array_from_timestamp(0i64, rhs_validity); // January 1, 1970, 00:00:00 UTC - let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let comparison = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap(); assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0); } @@ -244,11 +243,17 @@ mod test { fn compare_date_time_parts_ne(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) { let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); // January 2, 1970, 00:00:00 UTC let rhs = dtp_array_from_timestamp(86401i64, rhs_validity.clone()); // January 2, 1970, 00:00:01 UTC - let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq).unwrap(); + let comparison = lhs + .to_array() + .binary(rhs.to_array(), Operator::NotEq) + .unwrap(); assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1); let rhs = dtp_array_from_timestamp(86400i64, rhs_validity); // January 2, 1970, 00:00:00 UTC - let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq).unwrap(); + let comparison = lhs + .to_array() + .binary(rhs.to_array(), Operator::NotEq) + .unwrap(); assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0); } @@ -261,7 +266,7 @@ mod test { let lhs = dtp_array_from_timestamp(0i64, lhs_validity); // January 1, 1970, 01:00:00 UTC let rhs = dtp_array_from_timestamp(86400i64, rhs_validity); // January 2, 1970, 00:00:00 UTC - let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lt).unwrap(); + let comparison = lhs.to_array().binary(rhs.to_array(), Operator::Lt).unwrap(); assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1); } @@ -274,7 +279,7 @@ mod test { let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); // January 2, 1970, 02:00:00 UTC let rhs = dtp_array_from_timestamp(0i64, rhs_validity); // January 1, 1970, 01:00:00 UTC - let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Gt).unwrap(); + let comparison = lhs.to_array().binary(rhs.to_array(), Operator::Gt).unwrap(); assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1); } @@ -304,19 +309,25 @@ mod test { // Timestamp with a value larger than i32::MAX. let rhs = dtp_array_from_timestamp(i64::MAX, rhs_validity); - let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let comparison = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap(); assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0); - let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq).unwrap(); + let comparison = lhs + .to_array() + .binary(rhs.to_array(), Operator::NotEq) + .unwrap(); assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1); - let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lt).unwrap(); + let comparison = lhs.to_array().binary(rhs.to_array(), Operator::Lt).unwrap(); assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1); - let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lte).unwrap(); + let comparison = lhs + .to_array() + .binary(rhs.to_array(), Operator::Lte) + .unwrap(); assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1); - // `Operator::Gt` and `Operator::Gte` only cover the case of all lhs values + // `CompareOperator::Gt` and `CompareOperator::Gte` only cover the case of all lhs values // being larger. Therefore, these cases are not covered by unit tests. } } diff --git a/encodings/datetime-parts/src/compute/rules.rs b/encodings/datetime-parts/src/compute/rules.rs index ef0849c7f2b..30e7abcc713 100644 --- a/encodings/datetime-parts/src/compute/rules.rs +++ b/encodings/datetime-parts/src/compute/rules.rs @@ -102,7 +102,7 @@ impl ArrayParentReduceRule for DTPComparisonPushDownRule { if parent .scalar_fn() .as_opt::() - .is_none_or(|c| c.maybe_cmp_operator().is_none()) + .is_none_or(|c| !c.is_comparison()) && !parent.scalar_fn().is::() { return Ok(None); diff --git a/encodings/decimal-byte-parts/public-api.lock b/encodings/decimal-byte-parts/public-api.lock index c7091df535f..74c20a39cc4 100644 --- a/encodings/decimal-byte-parts/public-api.lock +++ b/encodings/decimal-byte-parts/public-api.lock @@ -70,7 +70,7 @@ pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::is_constant(&self, arr impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_decimal_byte_parts::DecimalBytePartsVTable -pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::compare(lhs: &Self::Array, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::compare(lhs: &Self::Array, rhs: &dyn vortex_array::array::Array, operator: vortex_array::expr::exprs::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_decimal_byte_parts::DecimalBytePartsVTable diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs index e075ab6255a..fa83085a9fe 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs @@ -7,13 +7,14 @@ use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::IntegerPType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::dtype::ToI256; use vortex_array::expr::CompareKernel; +use vortex_array::expr::CompareOperator; +use vortex_array::expr::Operator; use vortex_array::match_each_decimal_value; use vortex_array::match_each_integer_ptype; use vortex_array::scalar::DecimalValue; @@ -29,7 +30,7 @@ impl CompareKernel for DecimalBytePartsVTable { fn compare( lhs: &Self::Array, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { let Some(rhs_const) = rhs.as_constant() else { @@ -49,7 +50,9 @@ impl CompareKernel for DecimalBytePartsVTable { Ok(value) => { let encoded_scalar = Scalar::try_new(scalar_type, Some(value))?; let encoded_const = ConstantArray::new(encoded_scalar, rhs.len()); - compare(&lhs.msp, &encoded_const.to_array(), operator).map(Some) + lhs.msp + .binary(encoded_const.to_array(), Operator::from(operator)) + .map(Some) } Err(sign) => { @@ -82,12 +85,16 @@ enum Sign { Negative, } -fn unconvertible_value(sign: Sign, operator: Operator, nullability: Nullability) -> Scalar { +fn unconvertible_value(sign: Sign, operator: CompareOperator, nullability: Nullability) -> Scalar { match operator { - Operator::Eq => Scalar::bool(false, nullability), - Operator::NotEq => Scalar::bool(true, nullability), - Operator::Gt | Operator::Gte => Scalar::bool(matches!(sign, Negative), nullability), - Operator::Lt | Operator::Lte => Scalar::bool(matches!(sign, Positive), nullability), + CompareOperator::Eq => Scalar::bool(false, nullability), + CompareOperator::NotEq => Scalar::bool(true, nullability), + CompareOperator::Gt | CompareOperator::Gte => { + Scalar::bool(matches!(sign, Negative), nullability) + } + CompareOperator::Lt | CompareOperator::Lte => { + Scalar::bool(matches!(sign, Positive), nullability) + } } } @@ -139,11 +146,11 @@ mod tests { use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; - use vortex_array::compute::Operator; - use vortex_array::compute::compare; + use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::DecimalDType; use vortex_array::dtype::Nullability; + use vortex_array::expr::Operator; use vortex_array::scalar::DecimalValue; use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; @@ -167,7 +174,7 @@ mod tests { lhs.len(), ); - let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let res = lhs.binary(rhs.to_array(), Operator::Eq).unwrap(); let expected = BoolArray::from_iter([Some(false), Some(false), Some(true)]).into_array(); assert_arrays_eq!(res, expected); @@ -195,7 +202,7 @@ mod tests { ) .into_array(); - let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lte)?; + let res = lhs.to_array().binary(rhs, Operator::Lte)?; let expected = BoolArray::from_iter([None, Some(true), Some(true), Some(true)]).into_array(); assert_arrays_eq!(res, expected); @@ -223,15 +230,15 @@ mod tests { lhs.len(), ); - let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let res = lhs.binary(rhs.to_array(), Operator::Eq).unwrap(); let expected = BoolArray::from_iter([Some(false), Some(false), Some(false)]).into_array(); assert_arrays_eq!(res, expected); - let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Gt).unwrap(); + let res = lhs.binary(rhs.to_array(), Operator::Gt).unwrap(); let expected = BoolArray::from_iter([Some(true), Some(true), Some(true)]).into_array(); assert_arrays_eq!(res, expected); - let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lt).unwrap(); + let res = lhs.binary(rhs.to_array(), Operator::Lt).unwrap(); let expected = BoolArray::from_iter([Some(false), Some(false), Some(false)]).into_array(); assert_arrays_eq!(res, expected); @@ -241,15 +248,15 @@ mod tests { lhs.len(), ); - let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let res = lhs.binary(rhs.to_array(), Operator::Eq).unwrap(); let expected = BoolArray::from_iter([Some(false), Some(false), Some(false)]).into_array(); assert_arrays_eq!(res, expected); - let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Gt).unwrap(); + let res = lhs.binary(rhs.to_array(), Operator::Gt).unwrap(); let expected = BoolArray::from_iter([Some(false), Some(false), Some(false)]).into_array(); assert_arrays_eq!(res, expected); - let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lt).unwrap(); + let res = lhs.binary(rhs.to_array(), Operator::Lt).unwrap(); let expected = BoolArray::from_iter([Some(true), Some(true), Some(true)]).into_array(); assert_arrays_eq!(res, expected); } diff --git a/encodings/fastlanes/benches/compute_between.rs b/encodings/fastlanes/benches/compute_between.rs index 8df05d5c2c3..a476e850074 100644 --- a/encodings/fastlanes/benches/compute_between.rs +++ b/encodings/fastlanes/benches/compute_between.rs @@ -72,12 +72,10 @@ mod primitive { use vortex_array::VortexSessionExecute; use vortex_array::arrays::ConstantArray; use vortex_array::builtins::ArrayBuiltins; - use vortex_array::compute::Operator; - use vortex_array::compute::compare; use vortex_array::dtype::NativePType; use vortex_array::expr::BetweenOptions; + use vortex_array::expr::Operator; use vortex_array::expr::StrictComparison::NonStrict; - use vortex_array::expr::and_kleene; use vortex_error::VortexExpect; use crate::BENCH_ARGS; @@ -100,22 +98,23 @@ mod primitive { bencher .with_inputs(|| (&arr, LEGACY_SESSION.create_execution_ctx())) .bench_refs(|(arr, ctx)| { - and_kleene( - &compare( - arr.as_ref(), - ConstantArray::new(min, arr.len()).as_ref(), + let gte = arr + .to_array() + .binary( + ConstantArray::new(min, arr.len()).into_array(), Operator::Gte, ) - .vortex_expect(""), - &compare( - arr.as_ref(), - ConstantArray::new(max, arr.len()).as_ref(), + .vortex_expect(""); + let lt = arr + .to_array() + .binary( + ConstantArray::new(max, arr.len()).into_array(), Operator::Lt, ) - .vortex_expect(""), - ) - .vortex_expect("") - .execute::(ctx) + .vortex_expect(""); + gte.binary(lt, Operator::And) + .vortex_expect("") + .execute::(ctx) }) } @@ -163,12 +162,10 @@ mod bitpack { use vortex_array::VortexSessionExecute; use vortex_array::arrays::ConstantArray; use vortex_array::builtins::ArrayBuiltins; - use vortex_array::compute::Operator; - use vortex_array::compute::compare; use vortex_array::dtype::NativePType; use vortex_array::expr::BetweenOptions; + use vortex_array::expr::Operator; use vortex_array::expr::StrictComparison::NonStrict; - use vortex_array::expr::and_kleene; use vortex_error::VortexExpect; use crate::BENCH_ARGS; @@ -191,23 +188,24 @@ mod bitpack { bencher .with_inputs(|| (&arr, LEGACY_SESSION.create_execution_ctx())) .bench_refs(|(arr, ctx)| { - and_kleene( - &compare( - arr.as_ref(), - ConstantArray::new(min, arr.len()).as_ref(), + let gte = arr + .to_array() + .binary( + ConstantArray::new(min, arr.len()).into_array(), Operator::Gte, ) - .vortex_expect(""), - &compare( - arr.as_ref(), - ConstantArray::new(max, arr.len()).as_ref(), + .vortex_expect(""); + let lt = arr + .to_array() + .binary( + ConstantArray::new(max, arr.len()).into_array(), Operator::Lt, ) - .vortex_expect(""), - ) - .unwrap() - .execute::(ctx) - .unwrap() + .vortex_expect(""); + gte.binary(lt, Operator::And) + .unwrap() + .execute::(ctx) + .unwrap() }) } @@ -255,12 +253,10 @@ mod alp { use vortex_array::VortexSessionExecute; use vortex_array::arrays::ConstantArray; use vortex_array::builtins::ArrayBuiltins; - use vortex_array::compute::Operator; - use vortex_array::compute::compare; use vortex_array::dtype::NativePType; use vortex_array::expr::BetweenOptions; + use vortex_array::expr::Operator; use vortex_array::expr::StrictComparison::NonStrict; - use vortex_array::expr::and_kleene; use vortex_error::VortexExpect; use crate::BENCH_ARGS; @@ -283,23 +279,24 @@ mod alp { bencher .with_inputs(|| (&arr, LEGACY_SESSION.create_execution_ctx())) .bench_refs(|(arr, ctx)| { - and_kleene( - &compare( - arr.as_ref(), - ConstantArray::new(min, arr.len()).as_ref(), + let gte = arr + .to_array() + .binary( + ConstantArray::new(min, arr.len()).into_array(), Operator::Gte, ) - .vortex_expect(""), - &compare( - arr.as_ref(), - ConstantArray::new(max, arr.len()).as_ref(), + .vortex_expect(""); + let lt = arr + .to_array() + .binary( + ConstantArray::new(max, arr.len()).into_array(), Operator::Lt, ) - .vortex_expect(""), - ) - .unwrap() - .execute::(ctx) - .unwrap() + .vortex_expect(""); + gte.binary(lt, Operator::And) + .unwrap() + .execute::(ctx) + .unwrap() }) } diff --git a/encodings/fastlanes/public-api.lock b/encodings/fastlanes/public-api.lock index aa6e29cff14..ad87af1e7a0 100644 --- a/encodings/fastlanes/public-api.lock +++ b/encodings/fastlanes/public-api.lock @@ -472,7 +472,7 @@ pub fn vortex_fastlanes::FoRVTable::is_strict_sorted(&self, array: &vortex_fastl impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_fastlanes::FoRVTable -pub fn vortex_fastlanes::FoRVTable::compare(lhs: &vortex_fastlanes::FoRArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_fastlanes::FoRVTable::compare(lhs: &vortex_fastlanes::FoRArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::expr::exprs::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_fastlanes::FoRVTable diff --git a/encodings/fastlanes/src/for/compute/compare.rs b/encodings/fastlanes/src/for/compute/compare.rs index 37ae0845ba1..2e0159f5ee5 100644 --- a/encodings/fastlanes/src/for/compute/compare.rs +++ b/encodings/fastlanes/src/for/compute/compare.rs @@ -7,12 +7,14 @@ use num_traits::WrappingSub; use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::expr::CompareKernel; +use vortex_array::expr::CompareOperator; +use vortex_array::expr::Operator; use vortex_array::match_each_integer_ptype; use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; @@ -27,7 +29,7 @@ impl CompareKernel for FoRVTable { fn compare( lhs: &FoRArray, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { if let Some(constant) = rhs.as_constant() @@ -53,7 +55,7 @@ fn compare_constant( lhs: &FoRArray, mut rhs: T, nullability: Nullability, - operator: Operator, + operator: CompareOperator, ) -> VortexResult> where T: NativePType + WrappingSub + Shr, @@ -62,7 +64,7 @@ where { // For now, we only support equals and not equals. Comparisons are a little more fiddly to // get right regarding how to handle overflow and the wrapping subtraction. - if !matches!(operator, Operator::Eq | Operator::NotEq) { + if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) { return Ok(None); } @@ -78,12 +80,12 @@ where // unsigned integer type). let rhs = Scalar::primitive(rhs, nullability); - compare( - lhs.encoded(), - ConstantArray::new(rhs, lhs.len()).as_ref(), - operator, - ) - .map(Some) + lhs.encoded() + .binary( + ConstantArray::new(rhs, lhs.len()).into_array(), + Operator::from(operator), + ) + .map(Some) } #[cfg(test)] @@ -108,17 +110,27 @@ mod tests { ) .unwrap(); - let result = compare_constant(&lhs, 30i32, Nullability::NonNullable, Operator::Eq) + let result = compare_constant(&lhs, 30i32, Nullability::NonNullable, CompareOperator::Eq) .unwrap() .unwrap(); assert_arrays_eq!(result, BoolArray::from_iter([false, true, false].map(Some))); - let result = compare_constant(&lhs, 12i32, Nullability::NonNullable, Operator::NotEq) - .unwrap() - .unwrap(); + let result = compare_constant( + &lhs, + 12i32, + Nullability::NonNullable, + CompareOperator::NotEq, + ) + .unwrap() + .unwrap(); assert_arrays_eq!(result, BoolArray::from_iter([true, true, false].map(Some))); - for op in [Operator::Lt, Operator::Lte, Operator::Gt, Operator::Gte] { + for op in [ + CompareOperator::Lt, + CompareOperator::Lte, + CompareOperator::Gt, + CompareOperator::Gte, + ] { assert!( compare_constant(&lhs, 30i32, Nullability::NonNullable, op) .unwrap() @@ -138,14 +150,14 @@ mod tests { .unwrap(); assert_eq!( - compare_constant(&lhs, 30i32, Nullability::Nullable, Operator::Eq) + compare_constant(&lhs, 30i32, Nullability::Nullable, CompareOperator::Eq) .unwrap() .unwrap() .dtype(), &DType::Bool(Nullability::Nullable) ); assert_eq!( - compare_constant(&lhs, 30i32, Nullability::NonNullable, Operator::Eq) + compare_constant(&lhs, 30i32, Nullability::NonNullable, CompareOperator::Eq) .unwrap() .unwrap() .dtype(), @@ -163,7 +175,7 @@ mod tests { ) .unwrap(); - let result = compare_constant(&lhs, -1i32, Nullability::NonNullable, Operator::Eq) + let result = compare_constant(&lhs, -1i32, Nullability::NonNullable, CompareOperator::Eq) .unwrap() .unwrap(); assert_arrays_eq!( @@ -171,9 +183,14 @@ mod tests { BoolArray::from_iter([false, false, false].map(Some)) ); - let result = compare_constant(&lhs, -1i32, Nullability::NonNullable, Operator::NotEq) - .unwrap() - .unwrap(); + let result = compare_constant( + &lhs, + -1i32, + Nullability::NonNullable, + CompareOperator::NotEq, + ) + .unwrap() + .unwrap(); assert_arrays_eq!(result, BoolArray::from_iter([true, true, true].map(Some))); } @@ -195,7 +212,7 @@ mod tests { &lhs, 435090932899640449i64, Nullability::Nullable, - Operator::Eq, + CompareOperator::Eq, ) .unwrap() .unwrap(); @@ -205,7 +222,7 @@ mod tests { &lhs, 435090932899640449i64, Nullability::Nullable, - Operator::NotEq, + CompareOperator::NotEq, ) .unwrap() .unwrap(); diff --git a/encodings/fsst/benches/fsst_compress.rs b/encodings/fsst/benches/fsst_compress.rs index da4f951a140..8bb2387bb23 100644 --- a/encodings/fsst/benches/fsst_compress.rs +++ b/encodings/fsst/benches/fsst_compress.rs @@ -18,11 +18,11 @@ use vortex_array::arrays::ConstantArray; use vortex_array::arrays::VarBinArray; use vortex_array::builders::ArrayBuilder; use vortex_array::builders::VarBinViewBuilder; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::compute::warm_up_vtables; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; +use vortex_array::expr::Operator; use vortex_array::scalar::Scalar; use vortex_array::session::ArraySession; use vortex_fsst::fsst_compress; @@ -97,7 +97,9 @@ fn pushdown_compare(bencher: Bencher, (string_count, avg_len, unique_chars): (us ) }) .bench_refs(|(fsst_array, constant, ctx)| { - compare(fsst_array.as_ref(), constant.as_ref(), Operator::Eq) + fsst_array + .to_array() + .binary(constant.to_array(), Operator::Eq) .unwrap() .execute::(ctx) .unwrap(); @@ -123,14 +125,15 @@ fn canonicalize_compare( ) }) .bench_refs(|(fsst_array, constant, ctx)| { - compare( - fsst_array.to_canonical().unwrap().as_ref(), - constant.as_ref(), - Operator::Eq, - ) - .unwrap() - .execute::(ctx) - .unwrap(); + fsst_array + .to_canonical() + .unwrap() + .as_ref() + .to_array() + .binary(constant.to_array(), Operator::Eq) + .unwrap() + .execute::(ctx) + .unwrap(); }); } diff --git a/encodings/fsst/public-api.lock b/encodings/fsst/public-api.lock index 692582aab67..786b35c4baf 100644 --- a/encodings/fsst/public-api.lock +++ b/encodings/fsst/public-api.lock @@ -106,7 +106,7 @@ pub fn vortex_fsst::FSSTVTable::slice(array: &Self::Array, range: core::ops::ran impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_fsst::FSSTVTable -pub fn vortex_fsst::FSSTVTable::compare(lhs: &vortex_fsst::FSSTArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_fsst::FSSTVTable::compare(lhs: &vortex_fsst::FSSTArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::expr::exprs::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_fsst::FSSTVTable diff --git a/encodings/fsst/src/compute/compare.rs b/encodings/fsst/src/compute/compare.rs index 27b7e15cd88..53bc8ffd5ea 100644 --- a/encodings/fsst/src/compute/compare.rs +++ b/encodings/fsst/src/compute/compare.rs @@ -8,11 +8,12 @@ use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::BoolArray; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::compute::compare_lengths_to_empty; use vortex_array::dtype::DType; use vortex_array::expr::CompareKernel; +use vortex_array::expr::CompareOperator; +use vortex_array::expr::Operator; use vortex_array::match_each_integer_ptype; use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; @@ -29,7 +30,7 @@ impl CompareKernel for FSSTVTable { fn compare( lhs: &FSSTArray, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { match rhs.as_constant() { @@ -44,7 +45,7 @@ impl CompareKernel for FSSTVTable { fn compare_fsst_constant( left: &FSSTArray, right: &Scalar, - operator: Operator, + operator: CompareOperator, ) -> VortexResult> { let is_rhs_empty = match right.dtype() { DType::Binary(_) => right @@ -60,9 +61,9 @@ fn compare_fsst_constant( if is_rhs_empty { let buffer = match operator { // Every possible value is gte "" - Operator::Gte => BitBuffer::new_set(left.len()), + CompareOperator::Gte => BitBuffer::new_set(left.len()), // No value is lt "" - Operator::Lt => BitBuffer::new_unset(left.len()), + CompareOperator::Lt => BitBuffer::new_unset(left.len()), _ => { let uncompressed_lengths = left.uncompressed_lengths().to_primitive(); match_each_integer_ptype!(uncompressed_lengths.ptype(), |P| { @@ -85,7 +86,7 @@ fn compare_fsst_constant( } // The following section only supports Eq/NotEq - if !matches!(operator, Operator::Eq | Operator::NotEq) { + if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) { return Ok(None); } @@ -114,7 +115,10 @@ fn compare_fsst_constant( ); let rhs = ConstantArray::new(encoded_scalar, left.len()); - compare(left.codes().as_ref(), rhs.as_ref(), operator).map(Some) + left.codes() + .to_array() + .binary(rhs.into_array(), Operator::from(operator)) + .map(Some) } #[cfg(test)] @@ -125,10 +129,10 @@ mod tests { use vortex_array::arrays::ConstantArray; use vortex_array::arrays::VarBinArray; use vortex_array::assert_arrays_eq; - use vortex_array::compute::Operator; - use vortex_array::compute::compare; + use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; + use vortex_array::expr::Operator; use vortex_array::scalar::Scalar; use crate::fsst_compress; @@ -153,7 +157,9 @@ mod tests { let rhs = ConstantArray::new("world", lhs.len()); // Ensure fastpath for Eq exists, and returns correct answer - let equals = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq) + let equals = lhs + .to_array() + .binary(rhs.to_array(), Operator::Eq) .unwrap() .to_bool(); @@ -165,7 +171,9 @@ mod tests { ); // Ensure fastpath for Eq exists, and returns correct answer - let not_equals = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq) + let not_equals = lhs + .to_array() + .binary(rhs.to_array(), Operator::NotEq) .unwrap() .to_bool(); @@ -178,13 +186,19 @@ mod tests { // Ensure null constants are handled correctly. let null_rhs = ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len()); - let equals_null = compare(lhs.as_ref(), null_rhs.as_ref(), Operator::Eq).unwrap(); + let equals_null = lhs + .to_array() + .binary(null_rhs.to_array(), Operator::Eq) + .unwrap(); assert_arrays_eq!( &equals_null, &BoolArray::from_iter([None::, None, None, None, None]) ); - let noteq_null = compare(lhs.as_ref(), null_rhs.as_ref(), Operator::NotEq).unwrap(); + let noteq_null = lhs + .to_array() + .binary(null_rhs.to_array(), Operator::NotEq) + .unwrap(); assert_arrays_eq!( ¬eq_null, &BoolArray::from_iter([None::, None, None, None, None]) diff --git a/encodings/runend/public-api.lock b/encodings/runend/public-api.lock index 58465e90ab5..7378b0fa69e 100644 --- a/encodings/runend/public-api.lock +++ b/encodings/runend/public-api.lock @@ -138,7 +138,7 @@ pub fn vortex_runend::RunEndVTable::min_max(&self, array: &vortex_runend::RunEnd impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_runend::RunEndVTable -pub fn vortex_runend::RunEndVTable::compare(lhs: &vortex_runend::RunEndArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_runend::RunEndVTable::compare(lhs: &vortex_runend::RunEndArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::expr::exprs::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_runend::RunEndVTable diff --git a/encodings/runend/src/compute/compare.rs b/encodings/runend/src/compute/compare.rs index 1d95bf11fa5..c17c1f723b7 100644 --- a/encodings/runend/src/compute/compare.rs +++ b/encodings/runend/src/compute/compare.rs @@ -7,9 +7,10 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::expr::CompareKernel; +use vortex_array::expr::CompareOperator; +use vortex_array::expr::Operator; use vortex_error::VortexResult; use crate::RunEndArray; @@ -20,15 +21,14 @@ impl CompareKernel for RunEndVTable { fn compare( lhs: &RunEndArray, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { // If the RHS is constant, then we just need to compare against our encoded values. if let Some(const_scalar) = rhs.as_constant() { - let values = compare( - lhs.values(), - ConstantArray::new(const_scalar, lhs.values().len()).as_ref(), - operator, + let values = lhs.values().binary( + ConstantArray::new(const_scalar, lhs.values().len()).into_array(), + Operator::from(operator), )?; let decoded = runend_decode_bools( lhs.ends().to_primitive(), @@ -51,8 +51,8 @@ mod test { use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; - use vortex_array::compute::Operator; - use vortex_array::compute::compare; + use vortex_array::builtins::ArrayBuiltins; + use vortex_array::expr::Operator; use crate::RunEndArray; @@ -66,12 +66,10 @@ mod test { #[test] fn compare_run_end() { let arr = ree_array(); - let res = compare( - arr.as_ref(), - ConstantArray::new(5, 12).as_ref(), - Operator::Eq, - ) - .unwrap(); + let res = arr + .to_array() + .binary(ConstantArray::new(5, 12).into_array(), Operator::Eq) + .unwrap(); let expected = BoolArray::from_iter([ false, false, false, false, false, false, false, false, true, true, true, true, ]); diff --git a/encodings/sequence/public-api.lock b/encodings/sequence/public-api.lock index 870232fe528..0ddbb005f67 100644 --- a/encodings/sequence/public-api.lock +++ b/encodings/sequence/public-api.lock @@ -90,7 +90,7 @@ pub fn vortex_sequence::SequenceVTable::min_max(&self, array: &vortex_sequence:: impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_sequence::SequenceVTable -pub fn vortex_sequence::SequenceVTable::compare(lhs: &vortex_sequence::SequenceArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_sequence::SequenceVTable::compare(lhs: &vortex_sequence::SequenceArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::expr::exprs::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_sequence::SequenceVTable diff --git a/encodings/sequence/src/compute/compare.rs b/encodings/sequence/src/compute/compare.rs index 6da5952f8b5..0f31a21ad54 100644 --- a/encodings/sequence/src/compute/compare.rs +++ b/encodings/sequence/src/compute/compare.rs @@ -6,10 +6,10 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::arrays::BoolArray; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::Operator; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::expr::CompareKernel; +use vortex_array::expr::CompareOperator; use vortex_array::match_each_integer_ptype; use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; @@ -26,11 +26,11 @@ impl CompareKernel for SequenceVTable { fn compare( lhs: &SequenceArray, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { // TODO(joe): support other operators (NotEq, Lt, Lte, Gt, Gte) in encoded space. - if operator != Operator::Eq { + if operator != CompareOperator::Eq { return Ok(None); } @@ -138,10 +138,10 @@ mod tests { use vortex_array::arrays::BoolArray; use vortex_array::arrays::ConstantArray; use vortex_array::assert_arrays_eq; - use vortex_array::compute::Operator; - use vortex_array::compute::compare; + use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::Nullability::NonNullable; use vortex_array::dtype::Nullability::Nullable; + use vortex_array::expr::Operator; use crate::SequenceArray; @@ -149,7 +149,7 @@ mod tests { fn test_compare_match() { let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap(); let rhs = ConstantArray::new(4i64, lhs.len()); - let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let result = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap(); let expected = BoolArray::from_iter([false, false, true, false]); assert_arrays_eq!(result, expected); } @@ -158,7 +158,7 @@ mod tests { fn test_compare_match_scale() { let lhs = SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap(); let rhs = ConstantArray::new(8i64, lhs.len()); - let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let result = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap(); let expected = BoolArray::from_iter([Some(false), Some(false), Some(true), Some(false)]); assert_arrays_eq!(result, expected); } @@ -167,7 +167,7 @@ mod tests { fn test_compare_no_match() { let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap(); let rhs = ConstantArray::new(1i64, lhs.len()); - let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); + let result = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap(); let expected = BoolArray::from_iter([false, false, false, false]); assert_arrays_eq!(result, expected); } diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index 74d389e81fc..7349ca1ce9c 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -20,12 +20,10 @@ use vortex_array::ToCanonical; use vortex_array::arrays::ConstantArray; use vortex_array::buffer::BufferHandle; use vortex_array::builtins::ArrayBuiltins; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; use vortex_array::compute::filter; -use vortex_array::compute::sub_scalar; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; +use vortex_array::expr::Operator; use vortex_array::patches::Patches; use vortex_array::patches::PatchesMetadata; use vortex_array::scalar::Scalar; @@ -272,7 +270,10 @@ impl SparseArray { pub fn resolved_patches(&self) -> VortexResult { let patches = self.patches(); let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?; - let indices = sub_scalar(patches.indices(), indices_offset)?; + let indices = patches.indices().to_array().binary( + ConstantArray::new(indices_offset, patches.indices().len()).into_array(), + Operator::Sub, + )?; Patches::new( patches.array_len(), @@ -354,7 +355,9 @@ impl SparseArray { let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array(); let non_top_mask = Mask::from_buffer( - compare(array, &fill_array, Operator::NotEq)? + array + .to_array() + .binary(fill_array.clone(), Operator::NotEq)? .fill_null(Scalar::bool(true, Nullability::NonNullable))? .to_bool() .to_bit_buffer(), diff --git a/fuzz/fuzz_targets/file_io.rs b/fuzz/fuzz_targets/file_io.rs index 53e82294d38..1eb7af002d0 100644 --- a/fuzz/fuzz_targets/file_io.rs +++ b/fuzz/fuzz_targets/file_io.rs @@ -12,10 +12,10 @@ use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::ChunkedArray; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::StructFields; +use vortex_array::expr::Operator; use vortex_array::expr::lit; use vortex_array::expr::root; use vortex_buffer::ByteBufferMut; @@ -107,7 +107,8 @@ fuzz_target!(|fuzz: FuzzFileAction| -> Corpus { output_array.dtype() ); - let bool_result = compare(&expected_array, &output_array, Operator::Eq) + let bool_result = expected_array + .binary(output_array.clone(), Operator::Eq) .vortex_expect("compare operation should succeed in fuzz test") .to_bool(); let true_count = bool_result.to_bit_buffer().true_count(); diff --git a/fuzz/src/array/compare.rs b/fuzz/src/array/compare.rs index c2734b95289..ab026c8a9e6 100644 --- a/fuzz/src/array/compare.rs +++ b/fuzz/src/array/compare.rs @@ -8,10 +8,10 @@ use vortex_array::ToCanonical; use vortex_array::accessor::ArrayAccessor; use vortex_array::arrays::BoolArray; use vortex_array::arrays::NativeValue; -use vortex_array::compute::Operator; -use vortex_array::compute::scalar_cmp; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; +use vortex_array::expr::CompareOperator; +use vortex_array::expr::scalar_cmp; use vortex_array::match_each_decimal_value_type; use vortex_array::match_each_native_ptype; use vortex_array::scalar::Scalar; @@ -20,7 +20,11 @@ use vortex_buffer::BitBuffer; use vortex_error::VortexExpect; use vortex_error::vortex_panic; -pub fn compare_canonical_array(array: &dyn Array, value: &Scalar, operator: Operator) -> ArrayRef { +pub fn compare_canonical_array( + array: &dyn Array, + value: &Scalar, + operator: CompareOperator, +) -> ArrayRef { if value.is_null() { return BoolArray::new(BitBuffer::new_unset(array.len()), Validity::AllInvalid) .into_array(); @@ -146,16 +150,16 @@ pub fn compare_canonical_array(array: &dyn Array, value: &Scalar, operator: Oper fn compare_to( values: impl Iterator>, cmp_value: T, - operator: Operator, + operator: CompareOperator, nullability: Nullability, ) -> ArrayRef { let eval_fn = |v| match operator { - Operator::Eq => v == cmp_value, - Operator::NotEq => v != cmp_value, - Operator::Gt => v > cmp_value, - Operator::Gte => v >= cmp_value, - Operator::Lt => v < cmp_value, - Operator::Lte => v <= cmp_value, + CompareOperator::Eq => v == cmp_value, + CompareOperator::NotEq => v != cmp_value, + CompareOperator::Gt => v > cmp_value, + CompareOperator::Gte => v >= cmp_value, + CompareOperator::Lt => v < cmp_value, + CompareOperator::Lte => v <= cmp_value, }; if !nullability.is_nullable() { diff --git a/fuzz/src/array/mod.rs b/fuzz/src/array/mod.rs index 008d961e008..17f0da5a7ab 100644 --- a/fuzz/src/array/mod.rs +++ b/fuzz/src/array/mod.rs @@ -43,9 +43,9 @@ use vortex_array::IntoArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::arbitrary::ArbitraryArray; use vortex_array::compute::MinMaxResult; -use vortex_array::compute::Operator; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; +use vortex_array::expr::CompareOperator; use vortex_array::scalar::Scalar; use vortex_array::scalar::arbitrary::random_scalar; use vortex_array::search_sorted::SearchResult; @@ -96,7 +96,7 @@ pub enum Action { Take(ArrayRef), SearchSorted(Scalar, SearchSortedSide), Filter(Mask), - Compare(Scalar, Operator), + Compare(Scalar, CompareOperator), Cast(DType), Sum, MinMax, @@ -554,9 +554,9 @@ pub fn compress_array(array: &dyn Array, _strategy: CompressorStrategy) -> Array pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> crate::error::VortexFuzzResult { use vortex_array::arrays::ConstantArray; use vortex_array::builtins::ArrayBuiltins; - use vortex_array::compute::compare; use vortex_array::compute::min_max; use vortex_array::compute::sum; + use vortex_array::expr::Operator; let FuzzArrayAction { array, actions } = fuzz_action; let mut current_array = array.to_array(); @@ -600,12 +600,12 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> crate::error::VortexFuzz assert_array_eq(&expected.array(), ¤t_array, i)?; } Action::Compare(v, op) => { - let compare_result = compare( - ¤t_array, - &ConstantArray::new(v.clone(), current_array.len()).into_array(), - op, - ) - .vortex_expect("compare operation should succeed in fuzz test"); + let compare_result = current_array + .binary( + ConstantArray::new(v.clone(), current_array.len()).into_array(), + Operator::from(op), + ) + .vortex_expect("compare operation should succeed in fuzz test"); if let Err(e) = assert_array_eq(&expected.array(), &compare_result, i) { vortex_panic!( "Failed to compare {}with {op} {v}\nError: {e}", diff --git a/vortex-array/benches/compare.rs b/vortex-array/benches/compare.rs index 7c20cff6f39..98725a9242b 100644 --- a/vortex-array/benches/compare.rs +++ b/vortex-array/benches/compare.rs @@ -12,8 +12,8 @@ use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::BoolArray; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::expr::Operator; use vortex_buffer::Buffer; use vortex_session::VortexSession; @@ -35,7 +35,10 @@ fn compare_bool(bencher: Bencher) { bencher .with_inputs(|| (&arr1, &arr2, session.create_execution_ctx())) .bench_refs(|input| { - compare(input.0, input.1, Operator::Gte) + input + .0 + .to_array() + .binary(input.1.to_array(), Operator::Gte) .unwrap() .execute::(&mut input.2) }); @@ -60,7 +63,10 @@ fn compare_int(bencher: Bencher) { bencher .with_inputs(|| (&arr1, &arr2, session.create_execution_ctx())) .bench_refs(|input| { - compare(input.0, input.1, Operator::Gte) + input + .0 + .to_array() + .binary(input.1.to_array(), Operator::Gte) .unwrap() .execute::(&mut input.2) }); diff --git a/vortex-array/benches/dict_compare.rs b/vortex-array/benches/dict_compare.rs index e5bdcede59c..b4d9e4508e5 100644 --- a/vortex-array/benches/dict_compare.rs +++ b/vortex-array/benches/dict_compare.rs @@ -15,9 +15,9 @@ use vortex_array::arrays::VarBinViewArray; use vortex_array::arrays::dict_test::gen_primitive_for_dict; use vortex_array::arrays::dict_test::gen_varbin_words; use vortex_array::builders::dict::dict_encode; -use vortex_array::compute::Operator; -use vortex_array::compute::compare; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::compute::warm_up_vtables; +use vortex_array::expr::Operator; use vortex_array::expr::eq; use vortex_array::expr::lit; use vortex_array::expr::root; @@ -56,14 +56,11 @@ fn bench_compare_primitive(bencher: divan::Bencher, (len, uniqueness): (usize, u bencher .with_inputs(|| (&dict, session.create_execution_ctx())) .bench_refs(|(dict, ctx)| { - compare( - dict.as_ref(), - ConstantArray::new(value, len).as_ref(), - Operator::Eq, - ) - .unwrap() - .execute::(ctx) - .unwrap() + dict.to_array() + .binary(ConstantArray::new(value, len).to_array(), Operator::Eq) + .unwrap() + .execute::(ctx) + .unwrap() }) } @@ -78,14 +75,11 @@ fn bench_compare_varbin(bencher: divan::Bencher, (len, uniqueness): (usize, usiz bencher .with_inputs(|| (&dict, session.create_execution_ctx())) .bench_refs(|(dict, ctx)| { - compare( - dict.as_ref(), - ConstantArray::new(value, len).as_ref(), - Operator::Eq, - ) - .unwrap() - .execute::(ctx) - .unwrap() + dict.to_array() + .binary(ConstantArray::new(value, len).to_array(), Operator::Eq) + .unwrap() + .execute::(ctx) + .unwrap() }) } @@ -100,14 +94,11 @@ fn bench_compare_varbinview(bencher: divan::Bencher, (len, uniqueness): (usize, bencher .with_inputs(|| (&dict, session.create_execution_ctx())) .bench_refs(|(dict, ctx)| { - compare( - dict.as_ref(), - ConstantArray::new(value, len).as_ref(), - Operator::Eq, - ) - .unwrap() - .execute::(ctx) - .unwrap() + dict.to_array() + .binary(ConstantArray::new(value, len).to_array(), Operator::Eq) + .unwrap() + .execute::(ctx) + .unwrap() }) } diff --git a/vortex-array/benches/scalar_subtract.rs b/vortex-array/benches/scalar_subtract.rs index 744d076c902..7a55502072c 100644 --- a/vortex-array/benches/scalar_subtract.rs +++ b/vortex-array/benches/scalar_subtract.rs @@ -9,7 +9,13 @@ use rand::SeedableRng; use rand::distr::Uniform; use rand::rngs::StdRng; use vortex_array::IntoArray; +use vortex_array::LEGACY_SESSION; +use vortex_array::RecursiveCanonical; +use vortex_array::VortexSessionExecute; use vortex_array::arrays::ChunkedArray; +use vortex_array::arrays::ConstantArray; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::expr::Operator; use vortex_buffer::Buffer; fn main() { @@ -34,7 +40,20 @@ fn scalar_subtract(bencher: Bencher) { let chunked = ChunkedArray::from_iter([data1, data2]).into_array(); - bencher.with_inputs(|| &chunked).bench_refs(|chunked| { - vortex_array::compute::sub_scalar(*chunked, to_subtract.into()).unwrap() - }); + bencher + .with_inputs(|| (&chunked, LEGACY_SESSION.create_execution_ctx())) + .bench_refs(|(chunked, ctx)| { + chunked + .to_array() + .binary( + ConstantArray::new( + vortex_array::scalar::Scalar::from(to_subtract), + chunked.len(), + ) + .into_array(), + Operator::Sub, + ) + .unwrap() + .execute::(ctx) + }); } diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index e9d42e6ed82..cf958bf17c4 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -164,7 +164,7 @@ pub fn vortex_array::arrays::DictVTable::cast(array: &vortex_array::arrays::Dict impl vortex_array::expr::CompareKernel for vortex_array::arrays::DictVTable -pub fn vortex_array::arrays::DictVTable::compare(lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_array::arrays::DictVTable::compare(lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::expr::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::FillNullKernel for vortex_array::arrays::DictVTable @@ -1222,7 +1222,7 @@ pub fn vortex_array::arrays::DictVTable::cast(array: &vortex_array::arrays::Dict impl vortex_array::expr::CompareKernel for vortex_array::arrays::DictVTable -pub fn vortex_array::arrays::DictVTable::compare(lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_array::arrays::DictVTable::compare(lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::expr::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::FillNullKernel for vortex_array::arrays::DictVTable @@ -1420,7 +1420,7 @@ pub fn vortex_array::arrays::ExtensionVTable::cast(array: &vortex_array::arrays: impl vortex_array::expr::CompareKernel for vortex_array::arrays::ExtensionVTable -pub fn vortex_array::arrays::ExtensionVTable::compare(lhs: &vortex_array::arrays::ExtensionArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_array::arrays::ExtensionVTable::compare(lhs: &vortex_array::arrays::ExtensionArray, rhs: &dyn vortex_array::Array, operator: vortex_array::expr::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::MaskReduce for vortex_array::arrays::ExtensionVTable @@ -3632,7 +3632,7 @@ pub fn vortex_array::arrays::VarBinVTable::cast(array: &vortex_array::arrays::Va impl vortex_array::expr::CompareKernel for vortex_array::arrays::VarBinVTable -pub fn vortex_array::arrays::VarBinVTable::compare(lhs: &vortex_array::arrays::VarBinArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_array::arrays::VarBinVTable::compare(lhs: &vortex_array::arrays::VarBinArray, rhs: &dyn vortex_array::Array, operator: vortex_array::expr::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::MaskReduce for vortex_array::arrays::VarBinVTable @@ -5412,6 +5412,8 @@ pub trait vortex_array::builtins::ArrayBuiltins: core::marker::Sized pub fn vortex_array::builtins::ArrayBuiltins::between(self, lower: vortex_array::ArrayRef, upper: vortex_array::ArrayRef, options: vortex_array::expr::BetweenOptions) -> vortex_error::VortexResult +pub fn vortex_array::builtins::ArrayBuiltins::binary(&self, rhs: vortex_array::ArrayRef, op: vortex_array::expr::Operator) -> vortex_error::VortexResult + pub fn vortex_array::builtins::ArrayBuiltins::cast(&self, dtype: vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::builtins::ArrayBuiltins::fill_null(&self, fill_value: impl core::convert::Into) -> vortex_error::VortexResult @@ -5432,6 +5434,8 @@ impl vortex_array::builtins::ArrayBuiltins for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::between(self, lower: vortex_array::ArrayRef, upper: vortex_array::ArrayRef, options: vortex_array::expr::BetweenOptions) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::binary(&self, rhs: vortex_array::ArrayRef, op: vortex_array::expr::Operator) -> vortex_error::VortexResult + pub fn vortex_array::ArrayRef::cast(&self, dtype: vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::ArrayRef::fill_null(&self, fill_value: impl core::convert::Into) -> vortex_error::VortexResult @@ -5450,6 +5454,8 @@ pub fn vortex_array::ArrayRef::zip(&self, if_false: vortex_array::ArrayRef, mask pub trait vortex_array::builtins::ExprBuiltins: core::marker::Sized +pub fn vortex_array::builtins::ExprBuiltins::binary(&self, rhs: vortex_array::expr::Expression, op: vortex_array::expr::Operator) -> vortex_error::VortexResult + pub fn vortex_array::builtins::ExprBuiltins::cast(&self, dtype: vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::builtins::ExprBuiltins::fill_null(&self, fill_value: vortex_array::expr::Expression) -> vortex_error::VortexResult @@ -5468,6 +5474,8 @@ pub fn vortex_array::builtins::ExprBuiltins::zip(&self, if_false: vortex_array:: impl vortex_array::builtins::ExprBuiltins for vortex_array::expr::Expression +pub fn vortex_array::expr::Expression::binary(&self, rhs: vortex_array::expr::Expression, op: vortex_array::expr::Operator) -> vortex_error::VortexResult + pub fn vortex_array::expr::Expression::cast(&self, dtype: vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::expr::Expression::fill_null(&self, fill_value: vortex_array::expr::Expression) -> vortex_error::VortexResult @@ -5556,66 +5564,6 @@ impl<'a> core::convert::From<&'a vortex_mask::Mask> for vortex_array::compute::I pub fn vortex_array::compute::Input<'a>::from(value: &'a vortex_mask::Mask) -> Self -pub enum vortex_array::compute::Operator - -pub vortex_array::compute::Operator::Eq - -pub vortex_array::compute::Operator::Gt - -pub vortex_array::compute::Operator::Gte - -pub vortex_array::compute::Operator::Lt - -pub vortex_array::compute::Operator::Lte - -pub vortex_array::compute::Operator::NotEq - -impl vortex_array::compute::Operator - -pub fn vortex_array::compute::Operator::inverse(self) -> Self - -pub fn vortex_array::compute::Operator::swap(self) -> Self - -impl core::clone::Clone for vortex_array::compute::Operator - -pub fn vortex_array::compute::Operator::clone(&self) -> vortex_array::compute::Operator - -impl core::cmp::Eq for vortex_array::compute::Operator - -impl core::cmp::PartialEq for vortex_array::compute::Operator - -pub fn vortex_array::compute::Operator::eq(&self, other: &vortex_array::compute::Operator) -> bool - -impl core::cmp::PartialOrd for vortex_array::compute::Operator - -pub fn vortex_array::compute::Operator::partial_cmp(&self, other: &vortex_array::compute::Operator) -> core::option::Option - -impl core::convert::From for vortex_array::expr::Operator - -pub fn vortex_array::expr::Operator::from(cmp_operator: vortex_array::compute::Operator) -> Self - -impl core::convert::TryInto for vortex_array::expr::Operator - -pub type vortex_array::expr::Operator::Error = vortex_error::VortexError - -pub fn vortex_array::expr::Operator::try_into(self) -> vortex_error::VortexResult - -impl core::fmt::Debug for vortex_array::compute::Operator - -pub fn vortex_array::compute::Operator::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - -impl core::fmt::Display for vortex_array::compute::Operator - -pub fn vortex_array::compute::Operator::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - -impl core::hash::Hash for vortex_array::compute::Operator - -pub fn vortex_array::compute::Operator::hash<__H: core::hash::Hasher>(&self, state: &mut __H) - -impl core::marker::Copy for vortex_array::compute::Operator - -impl core::marker::StructuralPartialEq for vortex_array::compute::Operator - pub enum vortex_array::compute::Output pub vortex_array::compute::Output::Array(vortex_array::ArrayRef) @@ -6186,9 +6134,7 @@ pub fn vortex_array::compute::arrow_filter_fn(array: &dyn vortex_array::Array, m pub fn vortex_array::compute::cast(array: &dyn vortex_array::Array, dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult -pub fn vortex_array::compute::compare(left: &dyn vortex_array::Array, right: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult - -pub fn vortex_array::compute::compare_lengths_to_empty(lengths: I, op: vortex_array::compute::Operator) -> vortex_buffer::bit::buf::BitBuffer where P: vortex_array::dtype::IntegerPType, I: core::iter::traits::iterator::Iterator +pub fn vortex_array::compute::compare_lengths_to_empty(lengths: I, op: vortex_array::expr::CompareOperator) -> vortex_buffer::bit::buf::BitBuffer where P: vortex_array::dtype::IntegerPType, I: core::iter::traits::iterator::Iterator pub fn vortex_array::compute::div(lhs: &dyn vortex_array::Array, rhs: &dyn vortex_array::Array) -> vortex_error::VortexResult @@ -6226,8 +6172,6 @@ pub fn vortex_array::compute::numeric(lhs: &dyn vortex_array::Array, rhs: &dyn v pub fn vortex_array::compute::or_kleene(lhs: &dyn vortex_array::Array, rhs: &dyn vortex_array::Array) -> vortex_error::VortexResult -pub fn vortex_array::compute::scalar_cmp(lhs: &vortex_array::scalar::Scalar, rhs: &vortex_array::scalar::Scalar, operator: vortex_array::compute::Operator) -> vortex_array::scalar::Scalar - pub fn vortex_array::compute::sub(lhs: &dyn vortex_array::Array, rhs: &dyn vortex_array::Array) -> vortex_error::VortexResult pub fn vortex_array::compute::sub_scalar(lhs: &dyn vortex_array::Array, rhs: vortex_array::scalar::Scalar) -> vortex_error::VortexResult @@ -10138,6 +10082,66 @@ impl core::marker::Copy for vortex_array::expr::Arity impl core::marker::StructuralPartialEq for vortex_array::expr::Arity +pub enum vortex_array::expr::CompareOperator + +pub vortex_array::expr::CompareOperator::Eq + +pub vortex_array::expr::CompareOperator::Gt + +pub vortex_array::expr::CompareOperator::Gte + +pub vortex_array::expr::CompareOperator::Lt + +pub vortex_array::expr::CompareOperator::Lte + +pub vortex_array::expr::CompareOperator::NotEq + +impl vortex_array::expr::CompareOperator + +pub fn vortex_array::expr::CompareOperator::inverse(self) -> Self + +pub fn vortex_array::expr::CompareOperator::swap(self) -> Self + +impl core::clone::Clone for vortex_array::expr::CompareOperator + +pub fn vortex_array::expr::CompareOperator::clone(&self) -> vortex_array::expr::CompareOperator + +impl core::cmp::Eq for vortex_array::expr::CompareOperator + +impl core::cmp::PartialEq for vortex_array::expr::CompareOperator + +pub fn vortex_array::expr::CompareOperator::eq(&self, other: &vortex_array::expr::CompareOperator) -> bool + +impl core::cmp::PartialOrd for vortex_array::expr::CompareOperator + +pub fn vortex_array::expr::CompareOperator::partial_cmp(&self, other: &vortex_array::expr::CompareOperator) -> core::option::Option + +impl core::convert::From for vortex_array::expr::Operator + +pub fn vortex_array::expr::Operator::from(value: vortex_array::expr::CompareOperator) -> Self + +impl core::convert::TryFrom for vortex_array::expr::CompareOperator + +pub type vortex_array::expr::CompareOperator::Error = vortex_error::VortexError + +pub fn vortex_array::expr::CompareOperator::try_from(value: vortex_array::expr::Operator) -> core::result::Result + +impl core::fmt::Debug for vortex_array::expr::CompareOperator + +pub fn vortex_array::expr::CompareOperator::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::expr::CompareOperator + +pub fn vortex_array::expr::CompareOperator::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::expr::CompareOperator + +pub fn vortex_array::expr::CompareOperator::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::Copy for vortex_array::expr::CompareOperator + +impl core::marker::StructuralPartialEq for vortex_array::expr::CompareOperator + pub enum vortex_array::expr::DuplicateHandling pub vortex_array::expr::DuplicateHandling::Error @@ -10254,8 +10258,6 @@ pub fn vortex_array::expr::Operator::is_comparison(&self) -> bool pub fn vortex_array::expr::Operator::logical_inverse(self) -> core::option::Option -pub fn vortex_array::expr::Operator::maybe_cmp_operator(self) -> core::option::Option - pub fn vortex_array::expr::Operator::swap(self) -> core::option::Option impl core::clone::Clone for vortex_array::expr::Operator @@ -10272,9 +10274,9 @@ impl core::cmp::PartialOrd for vortex_array::expr::Operator pub fn vortex_array::expr::Operator::partial_cmp(&self, other: &vortex_array::expr::Operator) -> core::option::Option -impl core::convert::From for vortex_array::expr::Operator +impl core::convert::From for vortex_array::expr::Operator -pub fn vortex_array::expr::Operator::from(cmp_operator: vortex_array::compute::Operator) -> Self +pub fn vortex_array::expr::Operator::from(value: vortex_array::expr::CompareOperator) -> Self impl core::convert::From for i32 @@ -10284,6 +10286,10 @@ impl core::convert::From for vortex_proto::expr::b pub fn vortex_proto::expr::binary_opts::BinaryOp::from(value: vortex_array::expr::Operator) -> Self +impl core::convert::From for vortex_array::expr::Operator + +pub fn vortex_array::expr::Operator::from(op: vortex_array::scalar::NumericOperator) -> Self + impl core::convert::From for vortex_array::expr::Operator pub fn vortex_array::expr::Operator::from(value: vortex_proto::expr::binary_opts::BinaryOp) -> Self @@ -10294,11 +10300,11 @@ pub type vortex_array::expr::Operator::Error = vortex_error::VortexError pub fn vortex_array::expr::Operator::try_from(value: i32) -> core::result::Result -impl core::convert::TryInto for vortex_array::expr::Operator +impl core::convert::TryFrom for vortex_array::expr::CompareOperator -pub type vortex_array::expr::Operator::Error = vortex_error::VortexError +pub type vortex_array::expr::CompareOperator::Error = vortex_error::VortexError -pub fn vortex_array::expr::Operator::try_into(self) -> vortex_error::VortexResult +pub fn vortex_array::expr::CompareOperator::try_from(value: vortex_array::expr::Operator) -> core::result::Result impl core::fmt::Debug for vortex_array::expr::Operator @@ -10326,7 +10332,9 @@ impl vortex_array::expr::StrictComparison pub const fn vortex_array::expr::StrictComparison::is_strict(&self) -> bool -pub const fn vortex_array::expr::StrictComparison::to_operator(&self) -> vortex_array::compute::Operator +pub const fn vortex_array::expr::StrictComparison::to_compare_operator(&self) -> vortex_array::expr::CompareOperator + +pub const fn vortex_array::expr::StrictComparison::to_operator(&self) -> vortex_array::expr::Operator impl core::clone::Clone for vortex_array::expr::StrictComparison @@ -10782,6 +10790,8 @@ pub fn vortex_array::expr::Expression::drop(&mut self) impl vortex_array::builtins::ExprBuiltins for vortex_array::expr::Expression +pub fn vortex_array::expr::Expression::binary(&self, rhs: vortex_array::expr::Expression, op: vortex_array::expr::Operator) -> vortex_error::VortexResult + pub fn vortex_array::expr::Expression::cast(&self, dtype: vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_array::expr::Expression::fill_null(&self, fill_value: vortex_array::expr::Expression) -> vortex_error::VortexResult @@ -11646,19 +11656,19 @@ pub fn vortex_array::arrays::VarBinViewVTable::cast(array: &vortex_array::arrays pub trait vortex_array::expr::CompareKernel: vortex_array::vtable::VTable -pub fn vortex_array::expr::CompareKernel::compare(lhs: &Self::Array, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_array::expr::CompareKernel::compare(lhs: &Self::Array, rhs: &dyn vortex_array::Array, operator: vortex_array::expr::CompareOperator, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::CompareKernel for vortex_array::arrays::DictVTable -pub fn vortex_array::arrays::DictVTable::compare(lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_array::arrays::DictVTable::compare(lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::expr::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::CompareKernel for vortex_array::arrays::ExtensionVTable -pub fn vortex_array::arrays::ExtensionVTable::compare(lhs: &vortex_array::arrays::ExtensionArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_array::arrays::ExtensionVTable::compare(lhs: &vortex_array::arrays::ExtensionArray, rhs: &dyn vortex_array::Array, operator: vortex_array::expr::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::CompareKernel for vortex_array::arrays::VarBinVTable -pub fn vortex_array::arrays::VarBinVTable::compare(lhs: &vortex_array::arrays::VarBinArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_array::arrays::VarBinVTable::compare(lhs: &vortex_array::arrays::VarBinArray, rhs: &dyn vortex_array::Array, operator: vortex_array::expr::CompareOperator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> pub trait vortex_array::expr::DynExprVTable: 'static + core::marker::Send + core::marker::Sync + vortex_array::expr::vtable::private::Sealed @@ -12476,7 +12486,7 @@ pub fn vortex_array::expr::col(field: impl core::convert::Into(expr: &vortex_array::expr::Expression, annotate: A) -> vortex_array::expr::Annotations<'_, ::Annotation> -pub fn vortex_array::expr::dynamic(operator: vortex_array::compute::Operator, rhs_value: impl core::ops::function::Fn() -> core::option::Option + core::marker::Send + core::marker::Sync + 'static, rhs_dtype: vortex_array::dtype::DType, default: bool, lhs: vortex_array::expr::Expression) -> vortex_array::expr::Expression +pub fn vortex_array::expr::dynamic(operator: vortex_array::expr::CompareOperator, rhs_value: impl core::ops::function::Fn() -> core::option::Option + core::marker::Send + core::marker::Sync + 'static, rhs_dtype: vortex_array::dtype::DType, default: bool, lhs: vortex_array::expr::Expression) -> vortex_array::expr::Expression pub fn vortex_array::expr::eq(lhs: vortex_array::expr::Expression, rhs: vortex_array::expr::Expression) -> vortex_array::expr::Expression @@ -12540,6 +12550,8 @@ pub fn vortex_array::expr::pack(elements: impl core::iter::traits::collect::Into pub fn vortex_array::expr::root() -> vortex_array::expr::Expression +pub fn vortex_array::expr::scalar_cmp(lhs: &vortex_array::scalar::Scalar, rhs: &vortex_array::scalar::Scalar, operator: vortex_array::expr::CompareOperator) -> vortex_array::scalar::Scalar + pub fn vortex_array::expr::select(field_names: impl core::convert::Into, child: vortex_array::expr::Expression) -> vortex_array::expr::Expression pub fn vortex_array::expr::select_exclude(fields: impl core::convert::Into, child: vortex_array::expr::Expression) -> vortex_array::expr::Expression @@ -13224,16 +13236,8 @@ pub vortex_array::scalar::NumericOperator::Div pub vortex_array::scalar::NumericOperator::Mul -pub vortex_array::scalar::NumericOperator::RDiv - -pub vortex_array::scalar::NumericOperator::RSub - pub vortex_array::scalar::NumericOperator::Sub -impl vortex_array::scalar::NumericOperator - -pub fn vortex_array::scalar::NumericOperator::swap(self) -> Self - impl core::clone::Clone for vortex_array::scalar::NumericOperator pub fn vortex_array::scalar::NumericOperator::clone(&self) -> vortex_array::scalar::NumericOperator @@ -13244,6 +13248,10 @@ impl core::cmp::PartialEq for vortex_array::scalar::NumericOperator pub fn vortex_array::scalar::NumericOperator::eq(&self, other: &vortex_array::scalar::NumericOperator) -> bool +impl core::convert::From for vortex_array::expr::Operator + +pub fn vortex_array::expr::Operator::from(op: vortex_array::scalar::NumericOperator) -> Self + impl core::fmt::Debug for vortex_array::scalar::NumericOperator pub fn vortex_array::scalar::NumericOperator::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result diff --git a/vortex-array/src/arrays/dict/compute/compare.rs b/vortex-array/src/arrays/dict/compute/compare.rs index 2fbba2c29e1..147bb1627e6 100644 --- a/vortex-array/src/arrays/dict/compute/compare.rs +++ b/vortex-array/src/arrays/dict/compute/compare.rs @@ -10,15 +10,16 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::ConstantArray; -use crate::compute::Operator; -use crate::compute::compare; +use crate::builtins::ArrayBuiltins; use crate::expr::CompareKernel; +use crate::expr::CompareOperator; +use crate::expr::Operator; impl CompareKernel for DictVTable { fn compare( lhs: &DictArray, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { // if we have more values than codes, it is faster to canonicalise first. @@ -28,10 +29,9 @@ impl CompareKernel for DictVTable { // If the RHS is constant, then we just need to compare against our encoded values. if let Some(rhs) = rhs.as_constant() { - let compare_result = compare( - lhs.values(), - ConstantArray::new(rhs, lhs.values().len()).as_ref(), - operator, + let compare_result = lhs.values().to_array().binary( + ConstantArray::new(rhs, lhs.values().len()).to_array(), + Operator::from(operator), )?; // SAFETY: values len preserved, codes all still point to valid values diff --git a/vortex-array/src/arrays/dict/compute/fill_null.rs b/vortex-array/src/arrays/dict/compute/fill_null.rs index 9bf3c9d3611..490b4ff1b13 100644 --- a/vortex-array/src/arrays/dict/compute/fill_null.rs +++ b/vortex-array/src/arrays/dict/compute/fill_null.rs @@ -12,9 +12,8 @@ use crate::IntoArray; use crate::ToCanonical; use crate::arrays::ConstantArray; use crate::builtins::ArrayBuiltins; -use crate::compute::Operator; -use crate::compute::compare; use crate::expr::FillNullKernel; +use crate::expr::Operator; use crate::match_each_integer_ptype; use crate::scalar::Scalar; use crate::scalar::ScalarValue; @@ -27,12 +26,14 @@ impl FillNullKernel for DictVTable { ) -> VortexResult> { // If the fill value already exists in the dictionary, we can simply rewrite the null codes // to point to the value. - let found_fill_values = compare( - array.values(), - ConstantArray::new(fill_value.clone(), array.values().len()).as_ref(), - Operator::Eq, - )? - .to_bool(); + let found_fill_values = array + .values() + .to_array() + .binary( + ConstantArray::new(fill_value.clone(), array.values().len()).to_array(), + Operator::Eq, + )? + .to_bool(); // We found the fill value already in the values at this given index. let Some(existing_fill_value_index) = diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index e87dee05f7c..138eafcd8fa 100644 --- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -72,14 +72,14 @@ mod test { use crate::arrays::VarBinViewArray; use crate::assert_arrays_eq; use crate::builders::dict::dict_encode; - use crate::compute::Operator; - use crate::compute::compare; + use crate::builtins::ArrayBuiltins; use crate::compute::conformance::filter::test_filter_conformance; use crate::compute::conformance::mask::test_mask_conformance; use crate::compute::conformance::take::test_take_conformance; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType::I32; + use crate::expr::Operator; #[test] fn canonicalise_nullable_primitive() { let values: Vec> = (0..65) @@ -157,7 +157,9 @@ mod test { fn compare_sliced_dict() { use crate::arrays::BoolArray; let sliced = sliced_dict_array(); - let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap(); + let compared = sliced + .binary(ConstantArray::new(42, 3).to_array(), Operator::Eq) + .unwrap(); let expected = BoolArray::from_iter([Some(false), None, Some(true)]); assert_arrays_eq!(compared, expected.to_array()); diff --git a/vortex-array/src/arrays/extension/compute/compare.rs b/vortex-array/src/arrays/extension/compute/compare.rs index aac0226bae2..c9782452dd1 100644 --- a/vortex-array/src/arrays/extension/compute/compare.rs +++ b/vortex-array/src/arrays/extension/compute/compare.rs @@ -9,31 +9,38 @@ use crate::ExecutionCtx; use crate::arrays::ConstantArray; use crate::arrays::ExtensionArray; use crate::arrays::ExtensionVTable; -use crate::compute; -use crate::compute::Operator; +use crate::builtins::ArrayBuiltins; use crate::expr::CompareKernel; +use crate::expr::CompareOperator; +use crate::expr::Operator; impl CompareKernel for ExtensionVTable { fn compare( lhs: &ExtensionArray, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { // If the RHS is a constant, we can extract the storage scalar. if let Some(const_ext) = rhs.as_constant() { let storage_scalar = const_ext.as_extension().to_storage_scalar(); - return compute::compare( - lhs.storage(), - ConstantArray::new(storage_scalar, lhs.len()).as_ref(), - operator, - ) - .map(Some); + return lhs + .storage() + .to_array() + .binary( + ConstantArray::new(storage_scalar, lhs.len()).to_array(), + Operator::from(operator), + ) + .map(Some); } // If the RHS is an extension array matching ours, we can extract the storage. if let Some(rhs_ext) = rhs.as_opt::() { - return compute::compare(lhs.storage(), rhs_ext.storage(), operator).map(Some); + return lhs + .storage() + .to_array() + .binary(rhs_ext.storage().to_array(), Operator::from(operator)) + .map(Some); } // Otherwise, we need the RHS to handle this comparison. diff --git a/vortex-array/src/arrays/list/array.rs b/vortex-array/src/arrays/list/array.rs index 68684d2d83f..21d4b430bcf 100644 --- a/vortex-array/src/arrays/list/array.rs +++ b/vortex-array/src/arrays/list/array.rs @@ -13,12 +13,14 @@ use vortex_error::vortex_panic; use crate::Array; use crate::ArrayRef; use crate::IntoArray; +use crate::arrays::ConstantArray; use crate::arrays::ListVTable; use crate::arrays::PrimitiveVTable; +use crate::builtins::ArrayBuiltins; use crate::compute::min_max; -use crate::compute::sub_scalar; use crate::dtype::DType; use crate::dtype::NativePType; +use crate::expr::Operator; use crate::match_each_integer_ptype; use crate::match_each_native_ptype; use crate::stats::ArrayStats; @@ -321,7 +323,10 @@ impl ListArray { let offsets = self.offsets(); let first_offset = offsets.scalar_at(0)?; - let adjusted_offsets = sub_scalar(offsets, first_offset)?; + let adjusted_offsets = offsets.to_array().binary( + ConstantArray::new(first_offset, offsets.len()).into_array(), + Operator::Sub, + )?; Self::try_new(elements, adjusted_offsets, self.validity.clone()) } diff --git a/vortex-array/src/arrays/list/compute/is_constant.rs b/vortex-array/src/arrays/list/compute/is_constant.rs index 3de9c6fbf8c..a70a0681321 100644 --- a/vortex-array/src/arrays/list/compute/is_constant.rs +++ b/vortex-array/src/arrays/list/compute/is_constant.rs @@ -5,13 +5,13 @@ use vortex_error::VortexResult; use crate::arrays::ListArray; use crate::arrays::ListVTable; +use crate::builtins::ArrayBuiltins; use crate::compute::IsConstantKernel; use crate::compute::IsConstantKernelAdapter; use crate::compute::IsConstantOpts; use crate::compute::is_constant; -use crate::compute::numeric; +use crate::expr::Operator; use crate::register_kernel; -use crate::scalar::NumericOperator; const SMALL_ARRAY_THRESHOLD: usize = 64; @@ -46,7 +46,7 @@ impl IsConstantKernel for ListVTable { let end_offsets = array .offsets() .slice(SMALL_ARRAY_THRESHOLD + 1..array.len() + 1)?; - let list_lengths = numeric(&end_offsets, &start_offsets, NumericOperator::Sub)?; + let list_lengths = end_offsets.binary(start_offsets, Operator::Sub)?; if !is_constant(&list_lengths)?.unwrap_or_default() { return Ok(Some(false)); diff --git a/vortex-array/src/arrays/listview/rebuild.rs b/vortex-array/src/arrays/listview/rebuild.rs index 3d12be0c27f..073f29cf741 100644 --- a/vortex-array/src/arrays/listview/rebuild.rs +++ b/vortex-array/src/arrays/listview/rebuild.rs @@ -9,11 +9,14 @@ use vortex_error::VortexResult; use crate::Array; use crate::IntoArray; use crate::ToCanonical; +use crate::arrays::ConstantArray; use crate::arrays::ListViewArray; use crate::builders::builder_with_capacity; +use crate::builtins::ArrayBuiltins; use crate::compute; use crate::dtype::IntegerPType; use crate::dtype::Nullability; +use crate::expr::Operator; use crate::match_each_integer_ptype; use crate::scalar::Scalar; use crate::vtable::ValidityHelper; @@ -211,7 +214,10 @@ impl ListViewArray { last_offset + last_size } else { let min_max = compute::min_max( - &compute::add(self.offsets(), self.sizes()) + &self + .offsets() + .clone() + .binary(self.sizes().clone(), Operator::Add) .vortex_expect("`offsets + sizes` somehow overflowed"), ) .vortex_expect("Something went wrong while computing min and max") @@ -229,7 +235,12 @@ impl ListViewArray { .vortex_expect("unable to convert the min offset `start` into a `usize`"); let scalar = Scalar::primitive(offset, Nullability::NonNullable); - compute::sub_scalar(self.offsets(), scalar) + self.offsets() + .to_array() + .binary( + ConstantArray::new(scalar, self.offsets().len()).into_array(), + Operator::Sub, + ) .vortex_expect("was somehow unable to adjust offsets down by their minimum") }); diff --git a/vortex-array/src/arrays/varbin/compute/compare.rs b/vortex-array/src/arrays/varbin/compute/compare.rs index dfdbfc8805d..5682c0e7959 100644 --- a/vortex-array/src/arrays/varbin/compute/compare.rs +++ b/vortex-array/src/arrays/varbin/compute/compare.rs @@ -22,12 +22,13 @@ use crate::arrays::VarBinArray; use crate::arrays::VarBinVTable; use crate::arrow::Datum; use crate::arrow::from_arrow_array_with_len; -use crate::compute::Operator; -use crate::compute::compare; +use crate::builtins::ArrayBuiltins; use crate::compute::compare_lengths_to_empty; use crate::dtype::DType; use crate::dtype::IntegerPType; use crate::expr::CompareKernel; +use crate::expr::CompareOperator; +use crate::expr::Operator; use crate::match_each_integer_ptype; use crate::vtable::ValidityHelper; @@ -36,7 +37,7 @@ impl CompareKernel for VarBinVTable { fn compare( lhs: &VarBinArray, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { if let Some(rhs_const) = rhs.as_constant() { @@ -57,9 +58,12 @@ impl CompareKernel for VarBinVTable { if rhs_is_empty { let buffer = match operator { - Operator::Gte => BitBuffer::new_set(len), // Every possible value is >= "" - Operator::Lt => BitBuffer::new_unset(len), // No value is < "" - Operator::Eq | Operator::NotEq | Operator::Gt | Operator::Lte => { + CompareOperator::Gte => BitBuffer::new_set(len), // Every possible value is >= "" + CompareOperator::Lt => BitBuffer::new_unset(len), // No value is < "" + CompareOperator::Eq + | CompareOperator::NotEq + | CompareOperator::Gt + | CompareOperator::Lte => { let lhs_offsets = lhs.offsets().to_primitive(); match_each_integer_ptype!(lhs_offsets.ptype(), |P| { compare_offsets_to_empty::

(lhs_offsets, operator) @@ -100,12 +104,12 @@ impl CompareKernel for VarBinVTable { }; let array = match operator { - Operator::Eq => cmp::eq(&lhs, arrow_rhs), - Operator::NotEq => cmp::neq(&lhs, arrow_rhs), - Operator::Gt => cmp::gt(&lhs, arrow_rhs), - Operator::Gte => cmp::gt_eq(&lhs, arrow_rhs), - Operator::Lt => cmp::lt(&lhs, arrow_rhs), - Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs), + CompareOperator::Eq => cmp::eq(&lhs, arrow_rhs), + CompareOperator::NotEq => cmp::neq(&lhs, arrow_rhs), + CompareOperator::Gt => cmp::gt(&lhs, arrow_rhs), + CompareOperator::Gte => cmp::gt_eq(&lhs, arrow_rhs), + CompareOperator::Lt => cmp::lt(&lhs, arrow_rhs), + CompareOperator::Lte => cmp::lt_eq(&lhs, arrow_rhs), } .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?; @@ -114,7 +118,11 @@ impl CompareKernel for VarBinVTable { // NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView // Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves // to VarBinView and re-invoke. - return Ok(Some(compare(lhs.to_varbinview().as_ref(), rhs, operator)?)); + return Ok(Some( + lhs.to_varbinview() + .to_array() + .binary(rhs.to_array(), Operator::from(operator))?, + )); } else { Ok(None) } @@ -123,7 +131,7 @@ impl CompareKernel for VarBinVTable { fn compare_offsets_to_empty( offsets: PrimitiveArray, - operator: Operator, + operator: CompareOperator, ) -> BitBuffer { let lengths_iter = offsets .as_slice::

() @@ -142,10 +150,10 @@ mod test { use crate::arrays::ConstantArray; use crate::arrays::VarBinArray; use crate::arrays::VarBinViewArray; - use crate::compute::Operator; - use crate::compute::compare; + use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; + use crate::expr::Operator; use crate::scalar::Scalar; #[test] @@ -154,17 +162,18 @@ mod test { [Some(b"abc".to_vec()), None, Some(b"def".to_vec())], DType::Binary(Nullability::Nullable), ); - let result = compare( - array.as_ref(), - ConstantArray::new( - Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable), - 3, + let result = array + .to_array() + .binary( + ConstantArray::new( + Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable), + 3, + ) + .to_array(), + Operator::Eq, ) - .as_ref(), - Operator::Eq, - ) - .unwrap() - .to_bool(); + .unwrap() + .to_bool(); assert_eq!( &result.validity_mask().unwrap().to_bit_buffer(), @@ -186,7 +195,9 @@ mod test { [None, None, Some(b"def".to_vec())], DType::Binary(Nullability::Nullable), ); - let result = compare(array.as_ref(), vbv.as_ref(), Operator::Eq) + let result = array + .to_array() + .binary(vbv.to_array(), Operator::Eq) .unwrap() .to_bool(); @@ -206,10 +217,10 @@ mod tests { use crate::Array; use crate::arrays::ConstantArray; use crate::arrays::VarBinArray; - use crate::compute::Operator; - use crate::compute::compare; + use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; + use crate::expr::Operator; use crate::scalar::Scalar; #[test] @@ -219,7 +230,8 @@ mod tests { let const_ = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1); assert_eq!( - compare(arr.as_ref(), const_.as_ref(), Operator::Eq) + arr.to_array() + .binary(const_.to_array(), Operator::Eq) .unwrap() .dtype(), &DType::Bool(Nullability::Nullable) diff --git a/vortex-array/src/builtins.rs b/vortex-array/src/builtins.rs index c7a0d94cbca..8591aafa6be 100644 --- a/vortex-array/src/builtins.rs +++ b/vortex-array/src/builtins.rs @@ -22,6 +22,7 @@ use crate::dtype::DType; use crate::dtype::FieldName; use crate::expr::Between; use crate::expr::BetweenOptions; +use crate::expr::Binary; use crate::expr::Cast; use crate::expr::EmptyOptions; use crate::expr::Expression; @@ -31,6 +32,7 @@ use crate::expr::IsNull; use crate::expr::ListContains; use crate::expr::Mask; use crate::expr::Not; +use crate::expr::Operator; use crate::expr::VTableExt; use crate::expr::Zip; use crate::optimizer::ArrayOptimizer; @@ -63,6 +65,9 @@ pub trait ExprBuiltins: Sized { /// Conditional selection: `result[i] = if mask[i] then self[i] else if_false[i]`. fn zip(&self, if_false: Expression, mask: Expression) -> VortexResult; + + /// Apply a binary operator to this expression and another. + fn binary(&self, rhs: Expression, op: Operator) -> VortexResult; } impl ExprBuiltins for Expression { @@ -97,6 +102,10 @@ impl ExprBuiltins for Expression { fn zip(&self, if_false: Expression, mask: Expression) -> VortexResult { Zip.try_new_expr(EmptyOptions, [self.clone(), if_false, mask]) } + + fn binary(&self, rhs: Expression, op: Operator) -> VortexResult { + Binary.try_new_expr(op, [self.clone(), rhs]) + } } pub trait ArrayBuiltins: Sized { @@ -126,6 +135,9 @@ pub trait ArrayBuiltins: Sized { /// Check if a list contains a value. fn list_contains(&self, value: ArrayRef) -> VortexResult; + /// Apply a binary operator to this array and another. + fn binary(&self, rhs: ArrayRef, op: Operator) -> VortexResult; + /// Compare a values between lower VortexResult { + Binary + .try_new_array(self.len(), op, [self.clone(), rhs])? + .optimize() + } + fn between( self, lower: ArrayRef, diff --git a/vortex-array/src/compute/arbitrary.rs b/vortex-array/src/compute/arbitrary.rs index 9f235b6eceb..6910c813910 100644 --- a/vortex-array/src/compute/arbitrary.rs +++ b/vortex-array/src/compute/arbitrary.rs @@ -4,17 +4,17 @@ use arbitrary::Arbitrary; use arbitrary::Unstructured; -use crate::compute::Operator; +use crate::expr::CompareOperator; -impl<'a> Arbitrary<'a> for Operator { +impl<'a> Arbitrary<'a> for CompareOperator { fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { Ok(match u.int_in_range(0..=5)? { - 0 => Operator::Eq, - 1 => Operator::NotEq, - 2 => Operator::Gt, - 3 => Operator::Gte, - 4 => Operator::Lt, - 5 => Operator::Lte, + 0 => CompareOperator::Eq, + 1 => CompareOperator::NotEq, + 2 => CompareOperator::Gt, + 3 => CompareOperator::Gte, + 4 => CompareOperator::Lt, + 5 => CompareOperator::Lte, _ => unreachable!(), }) } diff --git a/vortex-array/src/compute/boolean.rs b/vortex-array/src/compute/boolean.rs index ad5045fbf46..a78aa8e3af3 100644 --- a/vortex-array/src/compute/boolean.rs +++ b/vortex-array/src/compute/boolean.rs @@ -5,15 +5,17 @@ use vortex_error::VortexResult; use crate::Array; use crate::ArrayRef; +use crate::builtins::ArrayBuiltins; +use crate::expr::Operator; /// Point-wise Kleene logical _and_ between two Boolean arrays. -#[deprecated(note = "use expr::and_kleene instead")] +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn and_kleene(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult { - crate::expr::and_kleene(lhs, rhs) + lhs.to_array().binary(rhs.to_array(), Operator::And) } /// Point-wise Kleene logical _or_ between two Boolean arrays. -#[deprecated(note = "use expr::or_kleene instead")] +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn or_kleene(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult { - crate::expr::or_kleene(lhs, rhs) + lhs.to_array().binary(rhs.to_array(), Operator::Or) } diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index d9ccc3717c9..bcecf577baf 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -1,10 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use core::fmt; use std::cmp::Ordering; -use std::fmt::Display; -use std::fmt::Formatter; use arrow_array::BooleanArray; use arrow_buffer::NullBuffer; @@ -13,99 +10,22 @@ use arrow_schema::SortOptions; use vortex_buffer::BitBuffer; use vortex_error::VortexResult; -use crate::Array; -use crate::ArrayRef; -use crate::IntoArray; -use crate::arrays::ScalarFnArray; -use crate::dtype::DType; use crate::dtype::IntegerPType; -use crate::dtype::Nullability; -use crate::expr::Binary; -use crate::expr::ScalarFn; -use crate::expr::operators; -use crate::scalar::Scalar; - -/// Compares two arrays and returns a new boolean array with the result of the comparison. -/// -/// The returned array is lazy (a [`ScalarFnArray`]) and will be evaluated on demand. -pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult { - let expr_op: operators::Operator = operator.into(); - Ok(ScalarFnArray::try_new( - ScalarFn::new(Binary, expr_op), - vec![left.to_array(), right.to_array()], - left.len(), - )? - .into_array()) -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)] -pub enum Operator { - /// Equality (`=`) - Eq, - /// Inequality (`!=`) - NotEq, - /// Greater than (`>`) - Gt, - /// Greater than or equal (`>=`) - Gte, - /// Less than (`<`) - Lt, - /// Less than or equal (`<=`) - Lte, -} - -impl Display for Operator { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let display = match &self { - Operator::Eq => "=", - Operator::NotEq => "!=", - Operator::Gt => ">", - Operator::Gte => ">=", - Operator::Lt => "<", - Operator::Lte => "<=", - }; - Display::fmt(display, f) - } -} - -impl Operator { - pub fn inverse(self) -> Self { - match self { - Operator::Eq => Operator::NotEq, - Operator::NotEq => Operator::Eq, - Operator::Gt => Operator::Lte, - Operator::Gte => Operator::Lt, - Operator::Lt => Operator::Gte, - Operator::Lte => Operator::Gt, - } - } - - /// Change the sides of the operator, where changing lhs and rhs won't change the result of the operation - pub fn swap(self) -> Self { - match self { - Operator::Eq => Operator::Eq, - Operator::NotEq => Operator::NotEq, - Operator::Gt => Operator::Lt, - Operator::Gte => Operator::Lte, - Operator::Lt => Operator::Gt, - Operator::Lte => Operator::Gte, - } - } -} +use crate::expr::CompareOperator; /// Helper function to compare empty values with arrays that have external value length information /// like `VarBin`. -pub fn compare_lengths_to_empty(lengths: I, op: Operator) -> BitBuffer +pub fn compare_lengths_to_empty(lengths: I, op: CompareOperator) -> BitBuffer where P: IntegerPType, I: Iterator, { // All comparison can be expressed in terms of equality. "" is the absolute min of possible value. let cmp_fn = match op { - Operator::Eq | Operator::Lte => |v| v == P::zero(), - Operator::NotEq | Operator::Gt => |v| v != P::zero(), - Operator::Gte => |_| true, - Operator::Lt => |_| false, + CompareOperator::Eq | CompareOperator::Lte => |v| v == P::zero(), + CompareOperator::NotEq | CompareOperator::Gt => |v| v != P::zero(), + CompareOperator::Gte => |_| true, + CompareOperator::Lt => |_| false, }; lengths.map(cmp_fn).collect() @@ -121,17 +41,17 @@ where pub(crate) fn compare_nested_arrow_arrays( lhs: &dyn arrow_array::Array, rhs: &dyn arrow_array::Array, - operator: Operator, + operator: CompareOperator, ) -> VortexResult { let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?; let cmp_fn = match operator { - Operator::Eq => Ordering::is_eq, - Operator::NotEq => Ordering::is_ne, - Operator::Gt => Ordering::is_gt, - Operator::Gte => Ordering::is_ge, - Operator::Lt => Ordering::is_lt, - Operator::Lte => Ordering::is_le, + CompareOperator::Eq => Ordering::is_eq, + CompareOperator::NotEq => Ordering::is_ne, + CompareOperator::Gt => Ordering::is_gt, + CompareOperator::Gte => Ordering::is_ge, + CompareOperator::Lt => Ordering::is_lt, + CompareOperator::Lte => Ordering::is_le, }; let values = (0..lhs.len()) @@ -142,263 +62,23 @@ pub(crate) fn compare_nested_arrow_arrays( Ok(BooleanArray::new(values, nulls)) } -pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar { - if lhs.is_null() | rhs.is_null() { - Scalar::null(DType::Bool(Nullability::Nullable)) - } else { - let b = match operator { - Operator::Eq => lhs == rhs, - Operator::NotEq => lhs != rhs, - Operator::Gt => lhs > rhs, - Operator::Gte => lhs >= rhs, - Operator::Lt => lhs < rhs, - Operator::Lte => lhs <= rhs, - }; - - Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability()) - } -} - #[cfg(test)] mod tests { use rstest::rstest; - use vortex_buffer::buffer; use super::*; - use crate::ToCanonical; - use crate::arrays::BoolArray; - use crate::arrays::ConstantArray; - use crate::arrays::ListArray; - use crate::arrays::ListViewArray; - use crate::arrays::PrimitiveArray; - use crate::arrays::StructArray; - use crate::arrays::VarBinArray; - use crate::arrays::VarBinViewArray; - use crate::assert_arrays_eq; - use crate::dtype::FieldName; - use crate::dtype::FieldNames; - use crate::test_harness::to_int_indices; - use crate::validity::Validity; - - #[test] - fn test_bool_basic_comparisons() { - let arr = BoolArray::new( - BitBuffer::from_iter([true, true, false, true, false]), - Validity::from_iter([false, true, true, true, true]), - ); - - let matches = compare(arr.as_ref(), arr.as_ref(), Operator::Eq) - .unwrap() - .to_bool(); - - assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]); - - let matches = compare(arr.as_ref(), arr.as_ref(), Operator::NotEq) - .unwrap() - .to_bool(); - let empty: [u64; 0] = []; - assert_eq!(to_int_indices(matches).unwrap(), empty); - - let other = BoolArray::new( - BitBuffer::from_iter([false, false, false, true, true]), - Validity::from_iter([false, true, true, true, true]), - ); - - let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lte) - .unwrap() - .to_bool(); - assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]); - - let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lt) - .unwrap() - .to_bool(); - assert_eq!(to_int_indices(matches).unwrap(), [4u64]); - - let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gte) - .unwrap() - .to_bool(); - assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]); - - let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gt) - .unwrap() - .to_bool(); - assert_eq!(to_int_indices(matches).unwrap(), [4u64]); - } - - #[test] - fn constant_compare() { - let left = ConstantArray::new(Scalar::from(2u32), 10); - let right = ConstantArray::new(Scalar::from(10u32), 10); - - let result = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap(); - assert_eq!(result.len(), 10); - let scalar = result.scalar_at(0).unwrap(); - assert_eq!(scalar.as_bool().value(), Some(false)); - } #[rstest] - #[case(Operator::Eq, vec![false, false, false, true])] - #[case(Operator::NotEq, vec![true, true, true, false])] - #[case(Operator::Gt, vec![true, true, true, false])] - #[case(Operator::Gte, vec![true, true, true, true])] - #[case(Operator::Lt, vec![false, false, false, false])] - #[case(Operator::Lte, vec![false, false, false, true])] - fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec) { + #[case(CompareOperator::Eq, vec![false, false, false, true])] + #[case(CompareOperator::NotEq, vec![true, true, true, false])] + #[case(CompareOperator::Gt, vec![true, true, true, false])] + #[case(CompareOperator::Gte, vec![true, true, true, true])] + #[case(CompareOperator::Lt, vec![false, false, false, false])] + #[case(CompareOperator::Lte, vec![false, false, false, true])] + fn test_cmp_to_empty(#[case] op: CompareOperator, #[case] expected: Vec) { let lengths: Vec = vec![1, 5, 7, 0]; let output = compare_lengths_to_empty(lengths.iter().copied(), op); assert_eq!(Vec::from_iter(output.iter()), expected); } - - #[rstest] - #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())] - #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())] - #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())] - #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())] - fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) { - let res = compare(&left, &right, Operator::Eq).unwrap(); - let expected = BoolArray::from_iter([true, true]); - assert_arrays_eq!(res, expected); - } - - #[ignore = "Arrow's ListView cannot be compared"] - #[test] - fn test_list_array_comparison() { - // Create two simple list arrays with integers - let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]); - let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]); - let list1 = ListArray::try_new( - values1.into_array(), - offsets1.into_array(), - Validity::NonNullable, - ) - .unwrap(); - - let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]); - let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]); - let list2 = ListArray::try_new( - values2.into_array(), - offsets2.into_array(), - Validity::NonNullable, - ) - .unwrap(); - - // Test equality - first two lists should be equal, third should be different - let result = compare(list1.as_ref(), list2.as_ref(), Operator::Eq).unwrap(); - let expected = BoolArray::from_iter([true, true, false]); - assert_arrays_eq!(result, expected); - - // Test inequality - let result = compare(list1.as_ref(), list2.as_ref(), Operator::NotEq).unwrap(); - let expected = BoolArray::from_iter([false, false, true]); - assert_arrays_eq!(result, expected); - - // Test less than - let result = compare(list1.as_ref(), list2.as_ref(), Operator::Lt).unwrap(); - let expected = BoolArray::from_iter([false, false, true]); - assert_arrays_eq!(result, expected); - } - - #[ignore = "Arrow's ListView cannot be compared"] - #[test] - fn test_list_array_constant_comparison() { - use std::sync::Arc; - - use crate::dtype::DType; - use crate::dtype::PType; - - // Create a list array - let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]); - let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]); - let list = ListArray::try_new( - values.into_array(), - offsets.into_array(), - Validity::NonNullable, - ) - .unwrap(); - - // Create a constant list scalar [3,4] that will be broadcasted - let list_scalar = Scalar::list( - Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)), - vec![3i32.into(), 4i32.into()], - Nullability::NonNullable, - ); - let constant = ConstantArray::new(list_scalar, 3); - - // Compare list with constant - all should be compared to [3,4] - let result = compare(list.as_ref(), constant.as_ref(), Operator::Eq).unwrap(); - let expected = BoolArray::from_iter([false, true, false]); - assert_arrays_eq!(result, expected); - } - - #[test] - fn test_struct_array_comparison() { - // Create two struct arrays with bool and int fields - let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]); - let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]); - - let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]); - let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]); - - let struct1 = StructArray::from_fields(&[ - ("bool_col", bool_field1.into_array()), - ("int_col", int_field1.into_array()), - ]) - .unwrap(); - - let struct2 = StructArray::from_fields(&[ - ("bool_col", bool_field2.into_array()), - ("int_col", int_field2.into_array()), - ]) - .unwrap(); - - // Test equality - let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Eq).unwrap(); - let expected = BoolArray::from_iter([true, true, false]); - assert_arrays_eq!(result, expected); - - // Test greater than - let result = compare(struct1.as_ref(), struct2.as_ref(), Operator::Gt).unwrap(); - let expected = BoolArray::from_iter([false, false, true]); - assert_arrays_eq!(result, expected); - } - - #[test] - fn test_empty_struct_compare() { - let empty1 = StructArray::try_new( - FieldNames::from(Vec::::new()), - Vec::new(), - 5, - Validity::NonNullable, - ) - .unwrap(); - - let empty2 = StructArray::try_new( - FieldNames::from(Vec::::new()), - Vec::new(), - 5, - Validity::NonNullable, - ) - .unwrap(); - - let result = compare(empty1.as_ref(), empty2.as_ref(), Operator::Eq).unwrap(); - let expected = BoolArray::from_iter([true, true, true, true, true]); - assert_arrays_eq!(result, expected); - } - - #[test] - fn test_empty_list() { - let list = ListViewArray::new( - BoolArray::from_iter(Vec::::new()).into_array(), - buffer![0i32, 0i32, 0i32].into_array(), - buffer![0i32, 0i32, 0i32].into_array(), - Validity::AllValid, - ); - - // Compare two lists together - let result = compare(list.as_ref(), list.as_ref(), Operator::Eq).unwrap(); - assert!(result.scalar_at(0).unwrap().is_valid()); - assert!(result.scalar_at(1).unwrap().is_valid()); - assert!(result.scalar_at(2).unwrap().is_valid()); - } } diff --git a/vortex-array/src/compute/conformance/binary_numeric.rs b/vortex-array/src/compute/conformance/binary_numeric.rs index 1b030e400cf..1be9ffa7af9 100644 --- a/vortex-array/src/compute/conformance/binary_numeric.rs +++ b/vortex-array/src/compute/conformance/binary_numeric.rs @@ -32,9 +32,12 @@ use vortex_error::vortex_panic; use crate::Array; use crate::ArrayRef; use crate::IntoArray; +use crate::LEGACY_SESSION; +use crate::RecursiveCanonical; use crate::ToCanonical; +use crate::VortexSessionExecute; use crate::arrays::ConstantArray; -use crate::compute::numeric::numeric; +use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::NativePType; use crate::dtype::PType; @@ -102,22 +105,23 @@ where .cast(array.dtype()) .vortex_expect("operation should succeed in conformance test"); - let operators: [NumericOperator; 6] = [ + let operators: [NumericOperator; 4] = [ NumericOperator::Add, NumericOperator::Sub, - NumericOperator::RSub, NumericOperator::Mul, NumericOperator::Div, - NumericOperator::RDiv, ]; for operator in operators { + let op = operator.into(); + let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array(); + // Test array operator scalar (e.g., array + 1) - let result = numeric( - &array, - &ConstantArray::new(scalar_one.clone(), array.len()).into_array(), - operator, - ); + let result = array + .binary(rhs_const.clone(), op) + .vortex_expect("apply shouldn't fail") + .execute::(&mut LEGACY_SESSION.create_execution_ctx()) + .map(|c| c.0.into_array()); // Skip this operator if the entire operation fails // This can happen for some edge cases in specific encodings @@ -125,6 +129,9 @@ where continue; }; + println!("result {}", result.display_tree()); + println!("result {}", result.display_values()); + let actual_values = to_vec_of_scalar(&result); // Check each element for overflow/underflow @@ -153,11 +160,10 @@ where } // Test scalar operator array (e.g., 1 + array) - let result = numeric( - &ConstantArray::new(scalar_one.clone(), array.len()).into_array(), - &array, - operator, - ); + let result = rhs_const.binary(array.clone(), op).and_then(|a| { + a.execute::(&mut LEGACY_SESSION.create_execution_ctx()) + .map(|c| c.0.into_array()) + }); // Skip this operator if the entire operation fails let Ok(result) = result else { @@ -341,29 +347,30 @@ where vec![ NumericOperator::Add, NumericOperator::Sub, - NumericOperator::RSub, NumericOperator::Mul, ] } else { vec![ NumericOperator::Add, NumericOperator::Sub, - NumericOperator::RSub, NumericOperator::Mul, NumericOperator::Div, - NumericOperator::RDiv, ] }; for operator in operators { + let op = operator.into(); + let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array(); + // Test array operator scalar - let result = numeric( - &array, - &ConstantArray::new(scalar.clone(), array.len()).into_array(), - operator, - ); + let result = array + .binary(rhs_const, op) + .vortex_expect("apply failed") + .execute::(&mut LEGACY_SESSION.create_execution_ctx()) + .map(|x| x.0.into_array()); // Skip if the entire operation fails + // TODO(joe): this is odd. if result.is_err() { continue; } diff --git a/vortex-array/src/compute/conformance/consistency.rs b/vortex-array/src/compute/conformance/consistency.rs index 3c8c9fe08b1..898e766dc5f 100644 --- a/vortex-array/src/compute/conformance/consistency.rs +++ b/vortex-array/src/compute/conformance/consistency.rs @@ -29,15 +29,12 @@ use crate::IntoArray; use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::builtins::ArrayBuiltins; -use crate::compute::Operator; -use crate::compute::compare; use crate::compute::invert; use crate::compute::mask; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; -use crate::expr::and_kleene; -use crate::expr::or_kleene; +use crate::expr::Operator; /// Tests that filter and take operations produce consistent results. /// @@ -712,8 +709,12 @@ fn test_comparison_inverse_consistency(array: &dyn Array) { // Test Eq vs NotEq let const_array = crate::arrays::ConstantArray::new(test_scalar, len); if let (Ok(eq_result), Ok(neq_result)) = ( - compare(array, const_array.as_ref(), Operator::Eq), - compare(array, const_array.as_ref(), Operator::NotEq), + array + .to_array() + .binary(const_array.to_array(), Operator::Eq), + array + .to_array() + .binary(const_array.to_array(), Operator::NotEq), ) { let inverted_eq = invert(&eq_result).vortex_expect("invert should succeed in conformance test"); @@ -741,8 +742,12 @@ fn test_comparison_inverse_consistency(array: &dyn Array) { // Test Gt vs Lte if let (Ok(gt_result), Ok(lte_result)) = ( - compare(array, const_array.as_ref(), Operator::Gt), - compare(array, const_array.as_ref(), Operator::Lte), + array + .to_array() + .binary(const_array.to_array(), Operator::Gt), + array + .to_array() + .binary(const_array.to_array(), Operator::Lte), ) { let inverted_gt = invert(>_result).vortex_expect("invert should succeed in conformance test"); @@ -764,8 +769,12 @@ fn test_comparison_inverse_consistency(array: &dyn Array) { // Test Lt vs Gte if let (Ok(lt_result), Ok(gte_result)) = ( - compare(array, const_array.as_ref(), Operator::Lt), - compare(array, const_array.as_ref(), Operator::Gte), + array + .to_array() + .binary(const_array.to_array(), Operator::Lt), + array + .to_array() + .binary(const_array.to_array(), Operator::Gte), ) { let inverted_lt = invert(<_result).vortex_expect("invert should succeed in conformance test"); @@ -827,8 +836,12 @@ fn test_comparison_symmetry_consistency(array: &dyn Array) { // Test Gt vs Lt symmetry if let (Ok(arr_gt_scalar), Ok(scalar_lt_arr)) = ( - compare(array, const_array.as_ref(), Operator::Gt), - compare(const_array.as_ref(), array, Operator::Lt), + array + .to_array() + .binary(const_array.to_array(), Operator::Gt), + const_array + .to_array() + .binary(array.to_array(), Operator::Lt), ) { assert_eq!( arr_gt_scalar.len(), @@ -853,8 +866,12 @@ fn test_comparison_symmetry_consistency(array: &dyn Array) { // Test Eq symmetry if let (Ok(arr_eq_scalar), Ok(scalar_eq_arr)) = ( - compare(array, const_array.as_ref(), Operator::Eq), - compare(const_array.as_ref(), array, Operator::Eq), + array + .to_array() + .binary(const_array.to_array(), Operator::Eq), + const_array + .to_array() + .binary(array.to_array(), Operator::Eq), ) { for i in 0..arr_eq_scalar.len() { let arr_eq = arr_eq_scalar @@ -901,13 +918,16 @@ fn test_boolean_demorgan_consistency(array: &dyn Array) { let mask = mask.as_ref(); // Test first De Morgan's law: NOT(A AND B) = (NOT A) OR (NOT B) - if let (Ok(a_and_b), Ok(not_a), Ok(not_b)) = - (and_kleene(array, mask), invert(array), invert(mask)) - { + if let (Ok(a_and_b), Ok(not_a), Ok(not_b)) = ( + array.to_array().binary(mask.to_array(), Operator::And), + invert(array), + invert(mask), + ) { let not_a_and_b = invert(&a_and_b).vortex_expect("invert should succeed in conformance test"); - let not_a_or_not_b = - or_kleene(¬_a, ¬_b).vortex_expect("or should succeed in conformance test"); + let not_a_or_not_b = not_a + .binary(not_b.clone(), Operator::Or) + .vortex_expect("or should succeed in conformance test"); assert_eq!( not_a_and_b.len(), @@ -931,12 +951,15 @@ fn test_boolean_demorgan_consistency(array: &dyn Array) { } // Test second De Morgan's law: NOT(A OR B) = (NOT A) AND (NOT B) - if let (Ok(a_or_b), Ok(not_a), Ok(not_b)) = - (or_kleene(array, mask), invert(array), invert(mask)) - { + if let (Ok(a_or_b), Ok(not_a), Ok(not_b)) = ( + array.to_array().binary(mask.to_array(), Operator::Or), + invert(array), + invert(mask), + ) { let not_a_or_b = invert(&a_or_b).vortex_expect("invert should succeed in conformance test"); - let not_a_and_not_b = - and_kleene(¬_a, ¬_b).vortex_expect("and should succeed in conformance test"); + let not_a_and_not_b = not_a + .binary(not_b.clone(), Operator::And) + .vortex_expect("and should succeed in conformance test"); for i in 0..not_a_or_b.len() { let left = not_a_or_b diff --git a/vortex-array/src/compute/numeric.rs b/vortex-array/src/compute/numeric.rs index ddecba4370f..7916de00fc3 100644 --- a/vortex-array/src/compute/numeric.rs +++ b/vortex-array/src/compute/numeric.rs @@ -7,11 +7,12 @@ use vortex_error::VortexResult; use crate::Array; use crate::ArrayRef; -use crate::IntoArray; use crate::arrays::ConstantArray; use crate::arrow::Datum; use crate::arrow::from_arrow_array_with_len; +use crate::builtins::ArrayBuiltins; use crate::compute::Options; +use crate::expr::Operator; use crate::scalar::NumericOperator; use crate::scalar::Scalar; @@ -20,62 +21,59 @@ use crate::scalar::Scalar; /// Errs at runtime if the sum would overflow or underflow. /// /// The result is null at any index that either input is null. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult { - numeric(lhs, rhs, NumericOperator::Add) + lhs.to_array().binary(rhs.to_array(), Operator::Add) } /// Point-wise add a scalar value to this array on the right-hand-side. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult { - numeric( - lhs, - &ConstantArray::new(rhs, lhs.len()).into_array(), - NumericOperator::Add, - ) + lhs.to_array() + .binary(ConstantArray::new(rhs, lhs.len()).to_array(), Operator::Add) } /// Point-wise subtract two numeric arrays. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult { - numeric(lhs, rhs, NumericOperator::Sub) + lhs.to_array().binary(rhs.to_array(), Operator::Sub) } /// Point-wise subtract a scalar value from this array on the right-hand-side. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult { - numeric( - lhs, - &ConstantArray::new(rhs, lhs.len()).into_array(), - NumericOperator::Sub, - ) + lhs.to_array() + .binary(ConstantArray::new(rhs, lhs.len()).to_array(), Operator::Sub) } /// Point-wise multiply two numeric arrays. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult { - numeric(lhs, rhs, NumericOperator::Mul) + lhs.to_array().binary(rhs.to_array(), Operator::Mul) } /// Point-wise multiply a scalar value into this array on the right-hand-side. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult { - numeric( - lhs, - &ConstantArray::new(rhs, lhs.len()).into_array(), - NumericOperator::Mul, - ) + lhs.to_array() + .binary(ConstantArray::new(rhs, lhs.len()).to_array(), Operator::Mul) } /// Point-wise divide two numeric arrays. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult { - numeric(lhs, rhs, NumericOperator::Div) + lhs.to_array().binary(rhs.to_array(), Operator::Div) } /// Point-wise divide a scalar value into this array on the right-hand-side. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult { - numeric( - lhs, - &ConstantArray::new(rhs, lhs.len()).into_array(), - NumericOperator::Mul, - ) + lhs.to_array() + .binary(ConstantArray::new(rhs, lhs.len()).to_array(), Operator::Div) } /// Point-wise numeric operation between two arrays of the same type and length. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult { arrow_numeric(lhs, rhs, op) } @@ -101,16 +99,15 @@ pub(crate) fn arrow_numeric( let array = match operator { NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?, NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?, - NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?, NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?, NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?, - NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?, }; from_arrow_array_with_len(array.as_ref(), len, nullable) } #[cfg(test)] +#[allow(deprecated)] mod test { use vortex_buffer::buffer; diff --git a/vortex-array/src/expr/exprs/between/mod.rs b/vortex-array/src/expr/exprs/between/mod.rs index 2bd8ee32b94..c31ff78ac5c 100644 --- a/vortex-array/src/expr/exprs/between/mod.rs +++ b/vortex-array/src/expr/exprs/between/mod.rs @@ -26,7 +26,6 @@ use crate::arrays::DecimalVTable; use crate::arrays::PrimitiveVTable; use crate::builtins::ArrayBuiltins; use crate::compute::Options; -use crate::compute::compare; use crate::dtype::DType; use crate::dtype::DType::Bool; use crate::expr::Arity; @@ -39,6 +38,7 @@ use crate::expr::VTableExt; use crate::expr::execute_boolean; use crate::expr::expression::Expression; use crate::expr::exprs::binary::Binary; +use crate::expr::exprs::operators::CompareOperator; use crate::expr::exprs::operators::Operator; use crate::scalar::Scalar; @@ -80,10 +80,17 @@ pub enum StrictComparison { } impl StrictComparison { - pub const fn to_operator(&self) -> crate::compute::Operator { + pub const fn to_compare_operator(&self) -> CompareOperator { match self { - StrictComparison::Strict => crate::compute::Operator::Lt, - StrictComparison::NonStrict => crate::compute::Operator::Lte, + StrictComparison::Strict => CompareOperator::Lt, + StrictComparison::NonStrict => CompareOperator::Lte, + } + } + + pub const fn to_operator(&self) -> Operator { + match self { + StrictComparison::Strict => Operator::Lt, + StrictComparison::NonStrict => Operator::Lte, } } @@ -161,11 +168,15 @@ fn between_canonical( // TODO(joe): return lazy compare once the executor supports this // Fall back to compare + boolean and - execute_boolean( - &compare(lower, arr, options.lower_strict.to_operator())?, - &compare(arr, upper, options.upper_strict.to_operator())?, - Operator::And, - ) + let lower_cmp = lower.to_array().binary( + arr.to_array(), + Operator::from(options.lower_strict.to_compare_operator()), + )?; + let upper_cmp = arr.to_array().binary( + upper.to_array(), + Operator::from(options.upper_strict.to_compare_operator()), + )?; + execute_boolean(&lower_cmp, &upper_cmp, Operator::And) } /// An optimized scalar expression to compute whether values fall between two bounds. @@ -317,11 +328,8 @@ impl VTable for Between { let lower = expr.child(1).clone(); let upper = expr.child(2).clone(); - let lhs = Binary.new_expr( - options.lower_strict.to_operator().into(), - [lower, arr.clone()], - ); - let rhs = Binary.new_expr(options.upper_strict.to_operator().into(), [arr, upper]); + let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]); + let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]); Binary .new_expr(Operator::And, [lhs, rhs]) diff --git a/vortex-array/src/expr/exprs/binary/boolean.rs b/vortex-array/src/expr/exprs/binary/boolean.rs index 5e215d2dec2..d3cf5604bbe 100644 --- a/vortex-array/src/expr/exprs/binary/boolean.rs +++ b/vortex-array/src/expr/exprs/binary/boolean.rs @@ -8,40 +8,25 @@ use vortex_error::vortex_err; use crate::Array; use crate::ArrayRef; -use crate::IntoArray; use crate::arrays::ConstantArray; use crate::arrays::ConstantVTable; -use crate::arrays::ScalarFnArray; use crate::arrow::FromArrowArray; use crate::arrow::IntoArrowArray; +use crate::builtins::ArrayBuiltins; use crate::dtype::DType; -use crate::expr::Binary; -use crate::expr::ScalarFn; use crate::expr::operators::Operator; use crate::scalar::Scalar; /// Point-wise Kleene logical _and_ between two Boolean arrays. -/// -/// Returns a lazy [`ScalarFnArray`] wrapping the [`Binary`] expression. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn and_kleene(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult { - Ok(ScalarFnArray::try_new( - ScalarFn::new(Binary, Operator::And), - vec![lhs.to_array(), rhs.to_array()], - lhs.len(), - )? - .into_array()) + lhs.to_array().binary(rhs.to_array(), Operator::And) } /// Point-wise Kleene logical _or_ between two Boolean arrays. -/// -/// Returns a lazy [`ScalarFnArray`] wrapping the [`Binary`] expression. +#[deprecated(note = "Use `ArrayBuiltins::binary` instead")] pub fn or_kleene(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult { - Ok(ScalarFnArray::try_new( - ScalarFn::new(Binary, Operator::Or), - vec![lhs.to_array(), rhs.to_array()], - lhs.len(), - )? - .into_array()) + lhs.to_array().binary(rhs.to_array(), Operator::Or) } /// Execute a Kleene boolean operation between two arrays. @@ -115,19 +100,19 @@ fn constant_boolean( .map(|b| Scalar::bool(b, nullable.into())) .unwrap_or_else(|| Scalar::null(DType::Bool(nullable.into()))); - Ok(Some(ConstantArray::new(scalar, length).into_array())) + Ok(Some(ConstantArray::new(scalar, length).to_array())) } #[cfg(test)] mod tests { use rstest::rstest; - use super::and_kleene; - use super::or_kleene; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::BoolArray; + use crate::builtins::ArrayBuiltins; use crate::canonical::ToCanonical; + use crate::expr::operators::Operator; #[rstest] #[case( @@ -139,7 +124,7 @@ mod tests { BoolArray::from_iter([Some(true), Some(true), Some(false), Some(false)]).into_array(), )] fn test_or(#[case] lhs: ArrayRef, #[case] rhs: ArrayRef) { - let r = or_kleene(&lhs, &rhs).unwrap(); + let r = lhs.binary(rhs, Operator::Or).unwrap(); let r = r.to_bool().into_array(); let v0 = r.scalar_at(0).unwrap().as_bool().value(); @@ -163,7 +148,11 @@ mod tests { BoolArray::from_iter([Some(true), Some(true), Some(false), Some(false)]).into_array(), )] fn test_and(#[case] lhs: ArrayRef, #[case] rhs: ArrayRef) { - let r = and_kleene(&lhs, &rhs).unwrap().to_bool().into_array(); + let r = lhs + .binary(rhs, Operator::And) + .unwrap() + .to_bool() + .into_array(); let v0 = r.scalar_at(0).unwrap().as_bool().value(); let v1 = r.scalar_at(1).unwrap().as_bool().value(); diff --git a/vortex-array/src/expr/exprs/binary/compare.rs b/vortex-array/src/expr/exprs/binary/compare.rs index 67f0d822a80..1d17534fe52 100644 --- a/vortex-array/src/expr/exprs/binary/compare.rs +++ b/vortex-array/src/expr/exprs/binary/compare.rs @@ -18,10 +18,11 @@ use crate::arrays::ScalarFnVTable; use crate::arrow::Datum; use crate::arrow::IntoArrowArray; use crate::arrow::from_arrow_array_with_len; -use crate::compute::Operator; use crate::compute::compare_nested_arrow_arrays; -use crate::compute::scalar_cmp; +use crate::dtype::DType; +use crate::dtype::Nullability; use crate::expr::Binary; +use crate::expr::CompareOperator; use crate::kernel::ExecuteParentKernel; use crate::scalar::Scalar; use crate::vtable::VTable; @@ -35,7 +36,7 @@ pub trait CompareKernel: VTable { fn compare( lhs: &Self::Array, rhs: &dyn Array, - operator: Operator, + operator: CompareOperator, ctx: &mut ExecutionCtx, ) -> VortexResult>; } @@ -62,7 +63,7 @@ where ctx: &mut ExecutionCtx, ) -> VortexResult> { // Only handle comparison operators - let Some(cmp_op) = parent.options.maybe_cmp_operator() else { + let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else { return Ok(None); }; @@ -86,18 +87,14 @@ where // Empty array → empty bool result if len == 0 { return Ok(Some( - Canonical::empty(&crate::dtype::DType::Bool(nullable.into())).into_array(), + Canonical::empty(&DType::Bool(nullable.into())).into_array(), )); } // Null constant on either side → all-null bool result if other.as_constant().is_some_and(|s| s.is_null()) { return Ok(Some( - ConstantArray::new( - Scalar::null(crate::dtype::DType::Bool(nullable.into())), - len, - ) - .into_array(), + ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(), )); } @@ -112,22 +109,20 @@ where pub(crate) fn execute_compare( lhs: &dyn Array, rhs: &dyn Array, - op: Operator, + op: CompareOperator, ) -> VortexResult { let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable(); if lhs.is_empty() { - return Ok(Canonical::empty(&crate::dtype::DType::Bool(nullable.into())).into_array()); + return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array()); } let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false); let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false); if left_constant_null || right_constant_null { - return Ok(ConstantArray::new( - Scalar::null(crate::dtype::DType::Bool(nullable.into())), - lhs.len(), - ) - .into_array()); + return Ok( + ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(), + ); } // Constant-constant fast path @@ -146,7 +141,7 @@ pub(crate) fn execute_compare( fn arrow_compare_arrays( left: &dyn Array, right: &dyn Array, - operator: Operator, + operator: CompareOperator, ) -> VortexResult { assert_eq!(left.len(), right.len()); @@ -172,13 +167,293 @@ fn arrow_compare_arrays( let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?; match operator { - Operator::Eq => cmp::eq(&lhs, &rhs)?, - Operator::NotEq => cmp::neq(&lhs, &rhs)?, - Operator::Gt => cmp::gt(&lhs, &rhs)?, - Operator::Gte => cmp::gt_eq(&lhs, &rhs)?, - Operator::Lt => cmp::lt(&lhs, &rhs)?, - Operator::Lte => cmp::lt_eq(&lhs, &rhs)?, + CompareOperator::Eq => cmp::eq(&lhs, &rhs)?, + CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?, + CompareOperator::Gt => cmp::gt(&lhs, &rhs)?, + CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?, + CompareOperator::Lt => cmp::lt(&lhs, &rhs)?, + CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?, } }; from_arrow_array_with_len(&array, left.len(), nullable) } + +pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> Scalar { + if lhs.is_null() | rhs.is_null() { + Scalar::null(DType::Bool(Nullability::Nullable)) + } else { + let b = match operator { + CompareOperator::Eq => lhs == rhs, + CompareOperator::NotEq => lhs != rhs, + CompareOperator::Gt => lhs > rhs, + CompareOperator::Gte => lhs >= rhs, + CompareOperator::Lt => lhs < rhs, + CompareOperator::Lte => lhs <= rhs, + }; + + Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use rstest::rstest; + use vortex_buffer::buffer; + + use crate::ArrayRef; + use crate::IntoArray; + use crate::ToCanonical; + use crate::arrays::BoolArray; + use crate::arrays::ConstantArray; + use crate::arrays::ListArray; + use crate::arrays::ListViewArray; + use crate::arrays::PrimitiveArray; + use crate::arrays::StructArray; + use crate::arrays::VarBinArray; + use crate::arrays::VarBinViewArray; + use crate::assert_arrays_eq; + use crate::builtins::ArrayBuiltins; + use crate::dtype::DType; + use crate::dtype::FieldName; + use crate::dtype::FieldNames; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::expr::Operator; + use crate::scalar::Scalar; + use crate::test_harness::to_int_indices; + use crate::validity::Validity; + + #[test] + fn test_bool_basic_comparisons() { + use vortex_buffer::BitBuffer; + + let arr = BoolArray::new( + BitBuffer::from_iter([true, true, false, true, false]), + Validity::from_iter([false, true, true, true, true]), + ); + + let matches = arr + .to_array() + .binary(arr.to_array(), Operator::Eq) + .unwrap() + .to_bool(); + assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]); + + let matches = arr + .to_array() + .binary(arr.to_array(), Operator::NotEq) + .unwrap() + .to_bool(); + let empty: [u64; 0] = []; + assert_eq!(to_int_indices(matches).unwrap(), empty); + + let other = BoolArray::new( + BitBuffer::from_iter([false, false, false, true, true]), + Validity::from_iter([false, true, true, true, true]), + ); + + let matches = arr + .to_array() + .binary(other.to_array(), Operator::Lte) + .unwrap() + .to_bool(); + assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]); + + let matches = arr + .to_array() + .binary(other.to_array(), Operator::Lt) + .unwrap() + .to_bool(); + assert_eq!(to_int_indices(matches).unwrap(), [4u64]); + + let matches = other + .to_array() + .binary(arr.to_array(), Operator::Gte) + .unwrap() + .to_bool(); + assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]); + + let matches = other + .to_array() + .binary(arr.to_array(), Operator::Gt) + .unwrap() + .to_bool(); + assert_eq!(to_int_indices(matches).unwrap(), [4u64]); + } + + #[test] + fn constant_compare() { + let left = ConstantArray::new(Scalar::from(2u32), 10); + let right = ConstantArray::new(Scalar::from(10u32), 10); + + let result = left + .to_array() + .binary(right.to_array(), Operator::Gt) + .unwrap(); + assert_eq!(result.len(), 10); + let scalar = result.scalar_at(0).unwrap(); + assert_eq!(scalar.as_bool().value(), Some(false)); + } + + #[rstest] + #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())] + #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())] + #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())] + #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())] + fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) { + let res = left.binary(right, Operator::Eq).unwrap(); + let expected = BoolArray::from_iter([true, true]); + assert_arrays_eq!(res, expected); + } + + #[ignore = "Arrow's ListView cannot be compared"] + #[test] + fn test_list_array_comparison() { + let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]); + let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]); + let list1 = ListArray::try_new( + values1.into_array(), + offsets1.into_array(), + Validity::NonNullable, + ) + .unwrap(); + + let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]); + let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]); + let list2 = ListArray::try_new( + values2.into_array(), + offsets2.into_array(), + Validity::NonNullable, + ) + .unwrap(); + + let result = list1 + .to_array() + .binary(list2.to_array(), Operator::Eq) + .unwrap(); + let expected = BoolArray::from_iter([true, true, false]); + assert_arrays_eq!(result, expected); + + let result = list1 + .to_array() + .binary(list2.to_array(), Operator::NotEq) + .unwrap(); + let expected = BoolArray::from_iter([false, false, true]); + assert_arrays_eq!(result, expected); + + let result = list1 + .to_array() + .binary(list2.to_array(), Operator::Lt) + .unwrap(); + let expected = BoolArray::from_iter([false, false, true]); + assert_arrays_eq!(result, expected); + } + + #[ignore = "Arrow's ListView cannot be compared"] + #[test] + fn test_list_array_constant_comparison() { + let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]); + let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]); + let list = ListArray::try_new( + values.into_array(), + offsets.into_array(), + Validity::NonNullable, + ) + .unwrap(); + + let list_scalar = Scalar::list( + Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)), + vec![3i32.into(), 4i32.into()], + Nullability::NonNullable, + ); + let constant = ConstantArray::new(list_scalar, 3); + + let result = list + .to_array() + .binary(constant.to_array(), Operator::Eq) + .unwrap(); + let expected = BoolArray::from_iter([false, true, false]); + assert_arrays_eq!(result, expected); + } + + #[test] + fn test_struct_array_comparison() { + let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]); + let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]); + + let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]); + let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]); + + let struct1 = StructArray::from_fields(&[ + ("bool_col", bool_field1.into_array()), + ("int_col", int_field1.into_array()), + ]) + .unwrap(); + + let struct2 = StructArray::from_fields(&[ + ("bool_col", bool_field2.into_array()), + ("int_col", int_field2.into_array()), + ]) + .unwrap(); + + let result = struct1 + .to_array() + .binary(struct2.to_array(), Operator::Eq) + .unwrap(); + let expected = BoolArray::from_iter([true, true, false]); + assert_arrays_eq!(result, expected); + + let result = struct1 + .to_array() + .binary(struct2.to_array(), Operator::Gt) + .unwrap(); + let expected = BoolArray::from_iter([false, false, true]); + assert_arrays_eq!(result, expected); + } + + #[test] + fn test_empty_struct_compare() { + let empty1 = StructArray::try_new( + FieldNames::from(Vec::::new()), + Vec::new(), + 5, + Validity::NonNullable, + ) + .unwrap(); + + let empty2 = StructArray::try_new( + FieldNames::from(Vec::::new()), + Vec::new(), + 5, + Validity::NonNullable, + ) + .unwrap(); + + let result = empty1 + .to_array() + .binary(empty2.to_array(), Operator::Eq) + .unwrap(); + let expected = BoolArray::from_iter([true, true, true, true, true]); + assert_arrays_eq!(result, expected); + } + + #[test] + fn test_empty_list() { + let list = ListViewArray::new( + BoolArray::from_iter(Vec::::new()).into_array(), + buffer![0i32, 0i32, 0i32].into_array(), + buffer![0i32, 0i32, 0i32].into_array(), + Validity::AllValid, + ); + + let result = list + .to_array() + .binary(list.to_array(), Operator::Eq) + .unwrap(); + assert!(result.scalar_at(0).unwrap().is_valid()); + assert!(result.scalar_at(1).unwrap().is_valid()); + assert!(result.scalar_at(2).unwrap().is_valid()); + } +} diff --git a/vortex-array/src/expr/exprs/binary/mod.rs b/vortex-array/src/expr/exprs/binary/mod.rs index 8d4d29d1c5d..e2c102d88c7 100644 --- a/vortex-array/src/expr/exprs/binary/mod.rs +++ b/vortex-array/src/expr/exprs/binary/mod.rs @@ -3,7 +3,9 @@ use std::fmt::Formatter; +#[expect(deprecated)] pub use boolean::and_kleene; +#[expect(deprecated)] pub use boolean::or_kleene; use prost::Message; use vortex_error::VortexExpect; @@ -13,7 +15,6 @@ use vortex_proto::expr as pb; use vortex_session::VortexSession; use crate::ArrayRef; -use crate::compute; use crate::dtype::DType; use crate::expr::Arity; use crate::expr::ChildName; @@ -24,6 +25,7 @@ use crate::expr::VTable; use crate::expr::VTableExt; use crate::expr::expression::Expression; use crate::expr::exprs::literal::lit; +use crate::expr::exprs::operators::CompareOperator; use crate::expr::exprs::operators::Operator; use crate::expr::stats::Stat; @@ -118,12 +120,12 @@ impl VTable for Binary { }; match op { - Operator::Eq => execute_compare(lhs, rhs, compute::Operator::Eq), - Operator::NotEq => execute_compare(lhs, rhs, compute::Operator::NotEq), - Operator::Lt => execute_compare(lhs, rhs, compute::Operator::Lt), - Operator::Lte => execute_compare(lhs, rhs, compute::Operator::Lte), - Operator::Gt => execute_compare(lhs, rhs, compute::Operator::Gt), - Operator::Gte => execute_compare(lhs, rhs, compute::Operator::Gte), + Operator::Eq => execute_compare(lhs, rhs, CompareOperator::Eq), + Operator::NotEq => execute_compare(lhs, rhs, CompareOperator::NotEq), + Operator::Lt => execute_compare(lhs, rhs, CompareOperator::Lt), + Operator::Lte => execute_compare(lhs, rhs, CompareOperator::Lte), + Operator::Gt => execute_compare(lhs, rhs, CompareOperator::Gt), + Operator::Gte => execute_compare(lhs, rhs, CompareOperator::Gte), Operator::And => execute_boolean(lhs, rhs, Operator::And), Operator::Or => execute_boolean(lhs, rhs, Operator::Or), Operator::Add => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Add), @@ -560,7 +562,7 @@ pub fn checked_add(lhs: Expression, rhs: Expression) -> Expression { mod tests { use super::*; use crate::assert_arrays_eq; - use crate::compute::compare; + use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::Expression; @@ -732,16 +734,17 @@ mod tests { .unwrap() .into_array(); - // Test using compare compute function directly - let result_equal = compare(&lhs_struct, &rhs_struct_equal, compute::Operator::Eq).unwrap(); + // Test using binary method directly + let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap(); assert_eq!( result_equal.scalar_at(0).vortex_expect("value"), Scalar::bool(true, Nullability::NonNullable), "Equal structs should be equal" ); - let result_different = - compare(&lhs_struct, &rhs_struct_different, compute::Operator::Eq).unwrap(); + let result_different = lhs_struct + .binary(rhs_struct_different, Operator::Eq) + .unwrap(); assert_eq!( result_different.scalar_at(0).vortex_expect("value"), Scalar::bool(false, Nullability::NonNullable), diff --git a/vortex-array/src/expr/exprs/binary/numeric.rs b/vortex-array/src/expr/exprs/binary/numeric.rs index aefd54a8de8..6452930c035 100644 --- a/vortex-array/src/expr/exprs/binary/numeric.rs +++ b/vortex-array/src/expr/exprs/binary/numeric.rs @@ -2,7 +2,6 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_error::vortex_err; use crate::Array; use crate::ArrayRef; @@ -39,14 +38,14 @@ fn constant_numeric( return Ok(None); }; - Ok(Some( - ConstantArray::new( - lhs.scalar() - .as_primitive() - .checked_binary_numeric(&rhs.scalar().as_primitive(), op) - .ok_or_else(|| vortex_err!("numeric overflow"))?, - lhs.len(), - ) - .into_array(), - )) + let Some(result) = lhs + .scalar() + .as_primitive() + .checked_binary_numeric(&rhs.scalar().as_primitive(), op) + else { + // Overflow detected — fall through to arrow_numeric which uses wrapping arithmetic. + return Ok(None); + }; + + Ok(Some(ConstantArray::new(result, lhs.len()).into_array())) } diff --git a/vortex-array/src/expr/exprs/dynamic.rs b/vortex-array/src/expr/exprs/dynamic.rs index 93d3eed0120..b6896630079 100644 --- a/vortex-array/src/expr/exprs/dynamic.rs +++ b/vortex-array/src/expr/exprs/dynamic.rs @@ -17,14 +17,15 @@ use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ConstantArray; -use crate::compute::Operator; use crate::dtype::DType; use crate::expr::Arity; use crate::expr::Binary; use crate::expr::ChildName; +use crate::expr::CompareOperator; use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; +use crate::expr::Operator; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -98,11 +99,13 @@ impl VTable for DynamicComparison { .map_err(|_| vortex_error::vortex_err!("Wrong arg count for DynamicComparison"))?; let rhs = ConstantArray::new(scalar, args.row_count).into_array(); - return Binary.bind(data.operator.into()).execute(ExecutionArgs { - inputs: vec![lhs, rhs], - row_count: args.row_count, - ctx: args.ctx, - }); + return Binary + .bind(Operator::from(data.operator)) + .execute(ExecutionArgs { + inputs: vec![lhs, rhs], + row_count: args.row_count, + ctx: args.ctx, + }); } let ret_dtype = DType::Bool(args.inputs[0].dtype().nullability() | data.rhs.dtype.nullability()); @@ -122,39 +125,39 @@ impl VTable for DynamicComparison { ) -> Option { let lhs = expr.child(0); match dynamic.operator { - Operator::Gt => Some(DynamicComparison.new_expr( + CompareOperator::Eq | CompareOperator::NotEq => None, + CompareOperator::Gt => Some(DynamicComparison.new_expr( DynamicComparisonExpr { - operator: Operator::Lte, + operator: CompareOperator::Lte, rhs: dynamic.rhs.clone(), default: !dynamic.default, }, vec![lhs.stat_max(catalog)?], )), - Operator::Gte => Some(DynamicComparison.new_expr( + CompareOperator::Gte => Some(DynamicComparison.new_expr( DynamicComparisonExpr { - operator: Operator::Lt, + operator: CompareOperator::Lt, rhs: dynamic.rhs.clone(), default: !dynamic.default, }, vec![lhs.stat_max(catalog)?], )), - Operator::Lt => Some(DynamicComparison.new_expr( + CompareOperator::Lt => Some(DynamicComparison.new_expr( DynamicComparisonExpr { - operator: Operator::Gte, + operator: CompareOperator::Gte, rhs: dynamic.rhs.clone(), default: !dynamic.default, }, vec![lhs.stat_min(catalog)?], )), - Operator::Lte => Some(DynamicComparison.new_expr( + CompareOperator::Lte => Some(DynamicComparison.new_expr( DynamicComparisonExpr { - operator: Operator::Gt, + operator: CompareOperator::Gt, rhs: dynamic.rhs.clone(), default: !dynamic.default, }, vec![lhs.stat_min(catalog)?], )), - _ => None, } } @@ -165,7 +168,7 @@ impl VTable for DynamicComparison { } pub fn dynamic( - operator: Operator, + operator: CompareOperator, rhs_value: impl Fn() -> Option + Send + Sync + 'static, rhs_dtype: DType, default: bool, @@ -186,7 +189,7 @@ pub fn dynamic( #[derive(Clone, Debug)] pub struct DynamicComparisonExpr { - operator: Operator, + operator: CompareOperator, rhs: Arc, // Default value for the dynamic comparison. default: bool, @@ -346,7 +349,7 @@ mod tests { #[test] fn return_dtype_bool() -> VortexResult<()> { let expr = dynamic( - Operator::Lt, + CompareOperator::Lt, || Some(5i32.into()), DType::Primitive(PType::I32, Nullability::NonNullable), true, @@ -364,7 +367,7 @@ mod tests { fn execute_with_value() -> VortexResult<()> { let input = buffer![1i32, 5, 10].into_array(); let expr = dynamic( - Operator::Lt, + CompareOperator::Lt, || Some(5i32.into()), DType::Primitive(PType::I32, Nullability::NonNullable), true, @@ -379,7 +382,7 @@ mod tests { fn execute_without_value_default_true() -> VortexResult<()> { let input = buffer![1i32, 5, 10].into_array(); let expr = dynamic( - Operator::Lt, + CompareOperator::Lt, || None, DType::Primitive(PType::I32, Nullability::NonNullable), true, @@ -394,7 +397,7 @@ mod tests { fn execute_without_value_default_false() -> VortexResult<()> { let input = buffer![1i32, 5, 10].into_array(); let expr = dynamic( - Operator::Lt, + CompareOperator::Lt, || None, DType::Primitive(PType::I32, Nullability::NonNullable), false, @@ -410,7 +413,7 @@ mod tests { let threshold = Arc::new(AtomicI32::new(5)); let threshold_clone = threshold.clone(); let expr = dynamic( - Operator::Lt, + CompareOperator::Lt, move || Some(threshold_clone.load(Ordering::SeqCst).into()), DType::Primitive(PType::I32, Nullability::NonNullable), true, diff --git a/vortex-array/src/expr/exprs/list_contains/mod.rs b/vortex-array/src/expr/exprs/list_contains/mod.rs index a0776a00b09..2d70d355037 100644 --- a/vortex-array/src/expr/exprs/list_contains/mod.rs +++ b/vortex-array/src/expr/exprs/list_contains/mod.rs @@ -25,19 +25,19 @@ use crate::arrays::ConstantArray; use crate::arrays::ConstantVTable; use crate::arrays::ListViewArray; use crate::arrays::PrimitiveArray; +use crate::arrays::ScalarFnArrayExt; use crate::builtins::ArrayBuiltins; -use crate::compute; -use crate::compute::Operator; use crate::dtype::DType; use crate::dtype::IntegerPType; use crate::dtype::Nullability; -use crate::expr; use crate::expr::Arity; +use crate::expr::Binary; use crate::expr::ChildName; use crate::expr::EmptyOptions; use crate::expr::ExecutionArgs; use crate::expr::ExprId; use crate::expr::Expression; +use crate::expr::Operator; use crate::expr::StatsCatalog; use crate::expr::VTable; use crate::expr::VTableExt; @@ -243,14 +243,18 @@ fn constant_list_scalar_contains( let false_scalar = Scalar::bool(false, nullability); for element in elements { - let res = compute::compare( - ConstantArray::new(element, len).as_ref(), - values, - Operator::Eq, - )? - .fill_null(false_scalar.clone())?; + let res = Binary + .try_new_array( + len, + Operator::Eq, + [ + ConstantArray::new(element, len).into_array(), + values.to_array(), + ], + )? + .fill_null(false_scalar.clone())?; if let Some(acc) = result { - result = Some(expr::or_kleene(&acc, &res)?) + result = Some(acc.binary(res, Operator::Or)?) } else { result = Some(res); } @@ -279,7 +283,8 @@ fn list_contains_scalar( } let rhs = ConstantArray::new(value.clone(), elems.len()); - let matching_elements = compute::compare(elems, rhs.as_ref(), Operator::Eq)?; + let matching_elements = + Binary.try_new_array(elems.len(), Operator::Eq, &[elems.clone(), rhs.to_array()])?; let matches = matching_elements.to_bool(); // Fast path: no elements match. diff --git a/vortex-array/src/expr/exprs/operators.rs b/vortex-array/src/expr/exprs/operators.rs index 9c23624aeaa..2b0502f08f3 100644 --- a/vortex-array/src/expr/exprs/operators.rs +++ b/vortex-array/src/expr/exprs/operators.rs @@ -6,12 +6,8 @@ use std::fmt::Display; use std::fmt::Formatter; use vortex_error::VortexError; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_proto::expr::binary_opts::BinaryOp; -use crate::compute; - /// Equalities, inequalities, and boolean operations over possibly null values. /// /// For most operations, if either side is null, the result is null. @@ -33,6 +29,7 @@ pub enum Operator { /// Expression is less or equal to another Lte, /// Boolean AND (∧). + // TODO(joe): rename to KleeneAnd And, /// Boolean OR (∨). // TODO(joe): rename to KleeneOr @@ -40,7 +37,6 @@ pub enum Operator { /// The sum of the arguments. /// /// Errs at runtime if the sum would overflow or underflow. - // TODO(joe): rename to KleeneAnd Add, /// The difference between the arguments. /// @@ -170,18 +166,6 @@ impl Operator { } } - pub fn maybe_cmp_operator(self) -> Option { - match self { - Operator::Eq => Some(compute::Operator::Eq), - Operator::NotEq => Some(compute::Operator::NotEq), - Operator::Lt => Some(compute::Operator::Lt), - Operator::Lte => Some(compute::Operator::Lte), - Operator::Gt => Some(compute::Operator::Gt), - Operator::Gte => Some(compute::Operator::Gte), - _ => None, - } - } - pub fn is_arithmetic(&self) -> bool { matches!(self, Self::Add | Self::Sub | Self::Mul | Self::Div) } @@ -194,31 +178,92 @@ impl Operator { } } -impl From for Operator { - fn from(cmp_operator: compute::Operator) -> Self { - match cmp_operator { - compute::Operator::Eq => Operator::Eq, - compute::Operator::NotEq => Operator::NotEq, - compute::Operator::Gt => Operator::Gt, - compute::Operator::Gte => Operator::Gte, - compute::Operator::Lt => Operator::Lt, - compute::Operator::Lte => Operator::Lte, +/// The six comparison operators, providing compile-time guarantees that only +/// comparison variants are used where comparisons are expected. +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum CompareOperator { + /// Expressions are equal. + Eq, + /// Expressions are not equal. + NotEq, + /// Expression is greater than another. + Gt, + /// Expression is greater or equal to another. + Gte, + /// Expression is less than another. + Lt, + /// Expression is less or equal to another. + Lte, +} + +impl CompareOperator { + /// Return the logical inverse of this comparison operator. + pub fn inverse(self) -> Self { + match self { + CompareOperator::Eq => CompareOperator::NotEq, + CompareOperator::NotEq => CompareOperator::Eq, + CompareOperator::Gt => CompareOperator::Lte, + CompareOperator::Gte => CompareOperator::Lt, + CompareOperator::Lt => CompareOperator::Gte, + CompareOperator::Lte => CompareOperator::Gt, + } + } + + /// Swap the sides of the operator so that swapping lhs and rhs preserves the result. + pub fn swap(self) -> Self { + match self { + CompareOperator::Eq => CompareOperator::Eq, + CompareOperator::NotEq => CompareOperator::NotEq, + CompareOperator::Gt => CompareOperator::Lt, + CompareOperator::Gte => CompareOperator::Lte, + CompareOperator::Lt => CompareOperator::Gt, + CompareOperator::Lte => CompareOperator::Gte, + } + } +} + +impl Display for CompareOperator { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let display = match self { + CompareOperator::Eq => "=", + CompareOperator::NotEq => "!=", + CompareOperator::Gt => ">", + CompareOperator::Gte => ">=", + CompareOperator::Lt => "<", + CompareOperator::Lte => "<=", + }; + Display::fmt(display, f) + } +} + +impl From for Operator { + fn from(value: CompareOperator) -> Self { + match value { + CompareOperator::Eq => Operator::Eq, + CompareOperator::NotEq => Operator::NotEq, + CompareOperator::Gt => Operator::Gt, + CompareOperator::Gte => Operator::Gte, + CompareOperator::Lt => Operator::Lt, + CompareOperator::Lte => Operator::Lte, } } } -impl TryInto for Operator { +impl TryFrom for CompareOperator { type Error = VortexError; - fn try_into(self) -> VortexResult { - Ok(match self { - Operator::Eq => compute::Operator::Eq, - Operator::NotEq => compute::Operator::NotEq, - Operator::Gt => compute::Operator::Gt, - Operator::Gte => compute::Operator::Gte, - Operator::Lt => compute::Operator::Lt, - Operator::Lte => compute::Operator::Lte, - _ => vortex_bail!("Not a compute operator: {}", self), - }) + fn try_from(value: Operator) -> Result { + match value { + Operator::Eq => Ok(CompareOperator::Eq), + Operator::NotEq => Ok(CompareOperator::NotEq), + Operator::Gt => Ok(CompareOperator::Gt), + Operator::Gte => Ok(CompareOperator::Gte), + Operator::Lt => Ok(CompareOperator::Lt), + Operator::Lte => Ok(CompareOperator::Lte), + other => Err(vortex_error::vortex_err!( + InvalidArgument: "{other} is not a comparison operator" + )), + } } } diff --git a/vortex-array/src/scalar/typed_view/decimal/scalar.rs b/vortex-array/src/scalar/typed_view/decimal/scalar.rs index 2dbc1458794..2edf696e509 100644 --- a/vortex-array/src/scalar/typed_view/decimal/scalar.rs +++ b/vortex-array/src/scalar/typed_view/decimal/scalar.rs @@ -227,10 +227,8 @@ impl<'a> DecimalScalar<'a> { let operation_result = match op { NumericOperator::Add => lhs.checked_add(&rhs), NumericOperator::Sub => lhs.checked_sub(&rhs), - NumericOperator::RSub => rhs.checked_sub(&lhs), NumericOperator::Mul => lhs.checked_mul(&rhs), NumericOperator::Div => lhs.checked_div(&rhs), - NumericOperator::RDiv => rhs.checked_div(&lhs), }?; // Check if the result fits within the precision constraints diff --git a/vortex-array/src/scalar/typed_view/decimal/tests.rs b/vortex-array/src/scalar/typed_view/decimal/tests.rs index f28d845f29d..a0486b254f3 100644 --- a/vortex-array/src/scalar/typed_view/decimal/tests.rs +++ b/vortex-array/src/scalar/typed_view/decimal/tests.rs @@ -934,43 +934,6 @@ fn test_decimal_scalar_precision_overflow() { assert_eq!(result, None); } -#[test] -fn test_decimal_scalar_rsub_and_rdiv() { - use crate::scalar::NumericOperator; - - let decimal1 = Scalar::decimal( - DecimalValue::I64(100), - DecimalDType::new(10, 2), - Nullability::NonNullable, - ); - let scalar1 = decimal1.as_decimal(); - - let decimal2 = Scalar::decimal( - DecimalValue::I64(300), - DecimalDType::new(10, 2), - Nullability::NonNullable, - ); - let scalar2 = decimal2.as_decimal(); - - // RSub: 300 - 100 = 200 - let result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::RSub) - .unwrap(); - assert_eq!( - result.decimal_value(), - Some(DecimalValue::I256(i256::from_i128(200))) - ); - - // RDiv: 300 / 100 = 3 - let result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::RDiv) - .unwrap(); - assert_eq!( - result.decimal_value(), - Some(DecimalValue::I256(i256::from_i128(3))) - ); -} - #[test] fn test_decimal_value_from_scalar() { let value = DecimalValue::I32(12345); diff --git a/vortex-array/src/scalar/typed_view/primitive/numeric_operator.rs b/vortex-array/src/scalar/typed_view/primitive/numeric_operator.rs index 03df9da8a59..cfaa7de63b4 100644 --- a/vortex-array/src/scalar/typed_view/primitive/numeric_operator.rs +++ b/vortex-array/src/scalar/typed_view/primitive/numeric_operator.rs @@ -14,18 +14,10 @@ pub enum NumericOperator { Add, /// Binary element-wise subtraction of two arrays or of two scalars. Sub, - /// Same as [NumericOperator::Sub] but with the parameters flipped: `right - left`. - RSub, /// Binary element-wise multiplication of two arrays or of two scalars. Mul, /// Binary element-wise division of two arrays or of two scalars. Div, - /// Same as [NumericOperator::Div] but with the parameters flipped: `right / left`. - RDiv, - // Missing from arrow-rs: - // Min, - // Max, - // Pow, } impl fmt::Display for NumericOperator { @@ -34,16 +26,13 @@ impl fmt::Display for NumericOperator { } } -impl NumericOperator { - /// Returns the operator with swapped operands (e.g., Sub becomes RSub). - pub fn swap(self) -> Self { - match self { - NumericOperator::Add => NumericOperator::Add, - NumericOperator::Sub => NumericOperator::RSub, - NumericOperator::RSub => NumericOperator::Sub, - NumericOperator::Mul => NumericOperator::Mul, - NumericOperator::Div => NumericOperator::RDiv, - NumericOperator::RDiv => NumericOperator::Div, +impl From for crate::expr::Operator { + fn from(op: NumericOperator) -> Self { + match op { + NumericOperator::Add => crate::expr::Operator::Add, + NumericOperator::Sub => crate::expr::Operator::Sub, + NumericOperator::Mul => crate::expr::Operator::Mul, + NumericOperator::Div => crate::expr::Operator::Div, } } } diff --git a/vortex-array/src/scalar/typed_view/primitive/scalar.rs b/vortex-array/src/scalar/typed_view/primitive/scalar.rs index 7bbf4c43307..e3981de8dd3 100644 --- a/vortex-array/src/scalar/typed_view/primitive/scalar.rs +++ b/vortex-array/src/scalar/typed_view/primitive/scalar.rs @@ -337,10 +337,8 @@ impl<'a> PrimitiveScalar<'a> { (Some(lhs), Some(rhs)) => match op { NumericOperator::Add => Some(lhs + rhs), NumericOperator::Sub => Some(lhs - rhs), - NumericOperator::RSub => Some(rhs - lhs), NumericOperator::Mul => Some(lhs * rhs), NumericOperator::Div => Some(lhs / rhs), - NumericOperator::RDiv => Some(rhs / lhs), } }; Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) }) @@ -373,10 +371,8 @@ impl<'a> PrimitiveScalar<'a> { (Some(lhs), Some(rhs)) => match op { NumericOperator::Add => lhs.checked_add(&rhs).map(Some), NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some), - NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some), NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some), NumericOperator::Div => lhs.checked_div(&rhs).map(Some), - NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some), }, }; diff --git a/vortex-array/src/scalar/typed_view/primitive/tests.rs b/vortex-array/src/scalar/typed_view/primitive/tests.rs index 1dacae1f304..b0c0063a3df 100644 --- a/vortex-array/src/scalar/typed_view/primitive/tests.rs +++ b/vortex-array/src/scalar/typed_view/primitive/tests.rs @@ -195,16 +195,6 @@ fn test_as_conversion_null() { assert_eq!(scalar.as_::(), None); } -#[test] -fn test_numeric_operator_swap() { - assert_eq!(NumericOperator::Add.swap(), NumericOperator::Add); - assert_eq!(NumericOperator::Sub.swap(), NumericOperator::RSub); - assert_eq!(NumericOperator::RSub.swap(), NumericOperator::Sub); - assert_eq!(NumericOperator::Mul.swap(), NumericOperator::Mul); - assert_eq!(NumericOperator::Div.swap(), NumericOperator::RDiv); - assert_eq!(NumericOperator::RDiv.swap(), NumericOperator::Div); -} - #[test] fn test_checked_binary_numeric_add() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); @@ -274,21 +264,6 @@ fn test_checked_binary_numeric_div() { assert_eq!(result.typed_value::(), Some(5)); } -#[test] -fn test_checked_binary_numeric_rdiv() { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let value1 = ScalarValue::Primitive(PValue::I32(4)); - let value2 = ScalarValue::Primitive(PValue::I32(20)); - let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); - let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); - - // RDiv means right / left, so 20 / 4 = 5 - let result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::RDiv) - .unwrap(); - assert_eq!(result.typed_value::(), Some(5)); -} - #[test] fn test_checked_binary_numeric_div_by_zero() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); diff --git a/vortex-duckdb/src/convert/table_filter.rs b/vortex-duckdb/src/convert/table_filter.rs index a3a42a4e3b6..f1c816a10ad 100644 --- a/vortex-duckdb/src/convert/table_filter.rs +++ b/vortex-duckdb/src/convert/table_filter.rs @@ -4,13 +4,13 @@ use std::sync::Arc; use itertools::Itertools; -use vortex::compute::Operator; use vortex::dtype::DType; use vortex::dtype::Nullability; use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_bail; use vortex::expr::Binary; +use vortex::expr::CompareOperator; use vortex::expr::Expression; use vortex::expr::VTableExt; use vortex::expr::and_collect; @@ -86,13 +86,15 @@ pub fn try_from_table_filter( } TableFilterClass::Dynamic(dynamic) => { let op = match dynamic.operator { - DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_EQUAL => Operator::Eq, - DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_NOTEQUAL => Operator::NotEq, - DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_LESSTHAN => Operator::Lt, - DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_GREATERTHAN => Operator::Gt, - DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_LESSTHANOREQUALTO => Operator::Lte, + DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_EQUAL => CompareOperator::Eq, + DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_NOTEQUAL => CompareOperator::NotEq, + DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_LESSTHAN => CompareOperator::Lt, + DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_GREATERTHAN => CompareOperator::Gt, + DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_LESSTHANOREQUALTO => { + CompareOperator::Lte + } DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_COMPARE_GREATERTHANOREQUALTO => { - Operator::Gte + CompareOperator::Gte } _ => vortex_bail!( "unsupported dynamic filter operator: {:?}", diff --git a/vortex-python/src/arrays/mod.rs b/vortex-python/src/arrays/mod.rs index ffa0bdf7daf..aec8aec5627 100644 --- a/vortex-python/src/arrays/mod.rs +++ b/vortex-python/src/arrays/mod.rs @@ -26,12 +26,12 @@ use vortex::array::ArrayRef; use vortex::array::ToCanonical; use vortex::array::arrays::ChunkedVTable; use vortex::array::arrow::IntoArrowArray; +use vortex::array::builtins::ArrayBuiltins; use vortex::array::match_each_integer_ptype; -use vortex::compute::Operator; -use vortex::compute::compare; use vortex::dtype::DType; use vortex::dtype::Nullability; use vortex::dtype::PType; +use vortex::expr::Operator; use vortex::ipc::messages::EncoderMessage; use vortex::ipc::messages::MessageEncoder; @@ -454,42 +454,42 @@ impl PyArray { ///Rust docs are *not* copied into Python for __lt__: https://github.com/PyO3/pyo3/issues/4326 fn __lt__(slf: Bound, other: PyArrayRef) -> PyVortexResult { let slf = PyArrayRef::extract(slf.as_any().as_borrowed())?.into_inner(); - let inner = compare(&slf, &*other, Operator::Lt)?; + let inner = slf.binary(other.into_inner(), Operator::Lt)?; Ok(PyArrayRef::from(inner)) } ///Rust docs are *not* copied into Python for __le__: https://github.com/PyO3/pyo3/issues/4326 fn __le__(slf: Bound, other: PyArrayRef) -> PyVortexResult { let slf = PyArrayRef::extract(slf.as_any().as_borrowed())?.into_inner(); - let inner = compare(&*slf, &*other, Operator::Lte)?; + let inner = slf.binary(other.into_inner(), Operator::Lte)?; Ok(PyArrayRef::from(inner)) } ///Rust docs are *not* copied into Python for __eq__: https://github.com/PyO3/pyo3/issues/4326 fn __eq__(slf: Bound, other: PyArrayRef) -> PyVortexResult { let slf = PyArrayRef::extract(slf.as_any().as_borrowed())?.into_inner(); - let inner = compare(&*slf, &*other, Operator::Eq)?; + let inner = slf.binary(other.into_inner(), Operator::Eq)?; Ok(PyArrayRef::from(inner)) } ///Rust docs are *not* copied into Python for __ne__: https://github.com/PyO3/pyo3/issues/4326 fn __ne__(slf: Bound, other: PyArrayRef) -> PyVortexResult { let slf = PyArrayRef::extract(slf.as_any().as_borrowed())?.into_inner(); - let inner = compare(&*slf, &*other, Operator::NotEq)?; + let inner = slf.binary(other.into_inner(), Operator::NotEq)?; Ok(PyArrayRef::from(inner)) } ///Rust docs are *not* copied into Python for __ge__: https://github.com/PyO3/pyo3/issues/4326 fn __ge__(slf: Bound, other: PyArrayRef) -> PyVortexResult { let slf = PyArrayRef::extract(slf.as_any().as_borrowed())?.into_inner(); - let inner = compare(&*slf, &*other, Operator::Gte)?; + let inner = slf.binary(other.into_inner(), Operator::Gte)?; Ok(PyArrayRef::from(inner)) } ///Rust docs are *not* copied into Python for __gt__: https://github.com/PyO3/pyo3/issues/4326 fn __gt__(slf: Bound, other: PyArrayRef) -> PyVortexResult { let slf = PyArrayRef::extract(slf.as_any().as_borrowed())?.into_inner(); - let inner = compare(&*slf, &*other, Operator::Gt)?; + let inner = slf.binary(other.into_inner(), Operator::Gt)?; Ok(PyArrayRef::from(inner)) }