diff --git a/native/Cargo.lock b/native/Cargo.lock
index 0eed66b2ba..c922c5f218 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -2111,6 +2111,7 @@ version = "0.16.0"
 dependencies = [
  "arrow",
  "base64",
+ "bigdecimal",
  "chrono",
  "chrono-tz",
  "criterion",
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 061f287dcf..778e8aa1da 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -41,6 +41,7 @@ twox-hash = "2.1.2"
 rand = { workspace = true }
 hex = "0.4.3"
 base64 = "0.22.1"
+bigdecimal = "0.4"
 
 [dev-dependencies]
 arrow = {workspace = true}
diff --git a/native/spark-expr/src/math_funcs/round.rs b/native/spark-expr/src/math_funcs/round.rs
index 9605f93f17..6ea4fec35c 100644
--- a/native/spark-expr/src/math_funcs/round.rs
+++ b/native/spark-expr/src/math_funcs/round.rs
@@ -17,14 +17,12 @@
 use crate::arithmetic_overflow_error;
 use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
-use arrow::array::{Array, ArrowNativeTypeOp};
+use arrow::array::{Array, ArrowNativeTypeOp, Float32Array, Float64Array};
 use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
-use arrow::datatypes::{DataType, Field};
+use arrow::datatypes::DataType;
 use arrow::error::ArrowError;
-use datafusion::common::config::ConfigOptions;
+use bigdecimal::{BigDecimal, RoundingMode};
 use datafusion::common::{exec_err, internal_err, DataFusionError, ScalarValue};
-use datafusion::functions::math::round::RoundFunc;
-use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
 use datafusion::physical_plan::ColumnarValue;
 use std::{cmp::min, sync::Arc};
@@ -110,8 +108,6 @@ pub fn spark_round(
     let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else {
         return internal_err!("Invalid point argument for Round(): {:#?}", point);
     };
-    // DataFusion's RoundFunc expects Int32 for decimal_places
-    let point_i32 = ColumnarValue::Scalar(ScalarValue::Int32(Some(*point as i32)));
     match value {
         ColumnarValue::Array(array) => match array.data_type() {
             DataType::Int64 if *point < 0 => {
@@ -131,17 +127,19 @@ pub fn spark_round(
                 let (precision, scale) = get_precision_scale(data_type);
                 make_decimal_array(array, precision, scale, &f)
             }
-            DataType::Float32 | DataType::Float64 => {
-                let round_udf = RoundFunc::new();
-                let return_field = Arc::new(Field::new("round", array.data_type().clone(), true));
-                let args_for_round = ScalarFunctionArgs {
-                    args: vec![ColumnarValue::Array(Arc::clone(array)), point_i32.clone()],
-                    number_rows: array.len(),
-                    return_field,
-                    arg_fields: vec![],
-                    config_options: Arc::new(ConfigOptions::default()),
-                };
-                round_udf.invoke_with_args(args_for_round)
+            DataType::Float64 => {
+                let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
+                let result: Float64Array = arrow::compute::kernels::arity::unary(array, |v| {
+                    spark_round_via_bigdecimal_f64(v, *point)
+                });
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Float32 => {
+                let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
+                let result: Float32Array = arrow::compute::kernels::arity::unary(array, |v| {
+                    spark_round_via_bigdecimal_f32(v, *point)
+                });
+                Ok(ColumnarValue::Array(Arc::new(result)))
             }
             dt => exec_err!("Not supported datatype for ROUND: {dt}"),
         },
@@ -163,19 +161,14 @@ pub fn spark_round(
                 let (precision, scale) = get_precision_scale(data_type);
                 make_decimal_scalar(a, precision, scale, &f)
             }
-            ScalarValue::Float32(_) | ScalarValue::Float64(_) => {
-                let round_udf = RoundFunc::new();
-                let data_type = a.data_type();
-                let return_field = Arc::new(Field::new("round", data_type, true));
-                let args_for_round = ScalarFunctionArgs {
-                    args: vec![ColumnarValue::Scalar(a.clone()), point_i32.clone()],
-                    number_rows: 1,
-                    return_field,
-                    arg_fields: vec![],
-                    config_options: Arc::new(ConfigOptions::default()),
-                };
-                round_udf.invoke_with_args(args_for_round)
-            }
+            ScalarValue::Float64(Some(v)) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(
+                spark_round_via_bigdecimal_f64(*v, *point),
+            )))),
+            ScalarValue::Float64(None) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))),
+            ScalarValue::Float32(Some(v)) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(
+                spark_round_via_bigdecimal_f32(*v, *point),
+            )))),
+            ScalarValue::Float32(None) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(None))),
             dt => exec_err!("Not supported datatype for ROUND: {dt}"),
         },
     }
 }
@@ -201,6 +194,335 @@ fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
     }
 }
 
+/// Replicate JDK 17's `Double.toString` (Gay/dtoa algorithm) for Spark-compatible rounding.
+///
+/// The Gay algorithm extracts decimal digits one at a time, stopping when the remainder
+/// is small enough that the output uniquely identifies the double. We implement the core
+/// stopping criterion using BigDecimal arithmetic.
+fn double_to_bigdecimal_like_java(v: f64) -> BigDecimal {
+    let abs_v = v.abs();
+    let bits = abs_v.to_bits();
+
+    if bits == 0 {
+        return BigDecimal::from(0);
+    }
+
+    // Extract significand and exponent (matching JDK's convention)
+    // JDK: binExp = unbiased exponent, fractBits has hidden bit at position 52
+    let t = bits & 0x000F_FFFF_FFFF_FFFF;
+    let bq = ((bits >> 52) & 0x7FF) as i32;
+    let (fract_bits, bin_exp) = if bq == 0 {
+        // Subnormal: normalize
+        let lz = t.leading_zeros() as i32 - 11;
+        (
+            (t << lz) & 0x000F_FFFF_FFFF_FFFF | (1u64 << 52),
+            -1023 + 1 - lz,
+        )
+    } else {
+        // Normal
+        (t | (1u64 << 52), bq - 1023)
+    };
+
+    let n_fract_bits = 53 - fract_bits.trailing_zeros() as i32;
+    let n_sig_bits = if bq != 0 {
+        53
+    } else {
+        64 - (t.leading_zeros() as i32)
+    };
+    let n_tiny = (n_fract_bits - bin_exp - 1).max(0);
+
+    // JDK fast path: for small exponents where the value fits in a long integer
+    if (-52..=62).contains(&bin_exp) && n_tiny == 0 {
+        // Value is an exact integer (no fractional bits)
+        let long_val = if bin_exp >= 52 {
+            fract_bits << (bin_exp - 52)
+        } else {
+            fract_bits >> (52 - bin_exp)
+        };
+        // Determine insignificant trailing digits
+        let insignificant = if bin_exp > n_sig_bits {
+            let p2 = (bin_exp - n_sig_bits - 1) as usize;
+            if p2 > 1 && p2 < 64 {
+                [
+                    0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 6, 7, 7,
+                    7, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 12, 13, 13, 13, 14,
+                    14, 14, 15, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19,
+                ][p2]
+            } else {
+                0
+            }
+        } else {
+            0
+        };
+        // Convert integer to BigDecimal, zeroing out insignificant trailing digits
+        let mut long_str = long_val.to_string();
+        if insignificant > 0 && insignificant < long_str.len() {
+            let sig_len = long_str.len() - insignificant;
+            // SAFETY: we only overwrite existing ASCII digits with the ASCII byte
+            // b'0', so the string remains valid UTF-8.
+            let bytes = unsafe { long_str.as_bytes_mut() };
+            for b in &mut bytes[sig_len..] {
+                *b = b'0';
+            }
+        }
+        let bd: BigDecimal = long_str.parse().unwrap();
+        return if v < 0.0 { -bd } else { bd };
+    }
+
+    let dec_exp_est = estimate_dec_exp(fract_bits, bin_exp);
+
+    let b5 = (-dec_exp_est).max(0);
+    let mut b2 = b5 + n_tiny + bin_exp;
+    let s5 = dec_exp_est.max(0);
+    let mut s2 = s5 + n_tiny;
+    let m5 = b5;
+    let mut m2 = b2 - n_sig_bits;
+
+    // Remove trailing zeros from fract_bits and adjust B2
+    let tail_zeros = fract_bits.trailing_zeros() as i32;
+    let fract_reduced = fract_bits >> tail_zeros;
+    b2 -= n_fract_bits - 1;
+
+    // Remove common factor of 2
+    let common2 = b2.min(s2).min(m2);
+    b2 -= common2;
+    s2 -= common2;
+    m2 -= common2;
+
+    // For exact powers of 2, halve M
+    if n_fract_bits == 1 {
+        m2 -= 1;
+    }
+
+    // If M2 < 0, scale everything up
+    if m2 < 0 {
+        b2 -= m2;
+        s2 -= m2;
+        m2 = 0;
+    }
+
+    use num::bigint::BigUint;
+    use num::ToPrimitive;
+
+    let b5u = b5.max(0) as u32;
+    let s5u = s5.max(0) as u32;
+    let m5u = m5.max(0) as u32;
+    let b2u = b2.max(0) as u32;
+    let s2u = s2.max(0) as u32;
+    let m2u = m2.max(0) as u32;
+
+    // Determine whether to use FDBigInteger-style comparison (>=) or int/long-style (>)
+    // for the 'high' check. The JDK uses >= for the FDBigInteger path (large values)
+    // and > for the int/long path (small values).
+    let n_fract_bits_b2 = n_fract_bits + b2;
+    let n5_b5 = if (b5u as usize) < 25 {
+        [
+            0, 3, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28, 31, 33, 35, 38, 40, 42, 45, 47, 49, 52,
+            54, 56,
+        ][b5u as usize]
+    } else {
+        b5 * 3
+    };
+    let bbits = n_fract_bits_b2 + n5_b5;
+    let n5_s5p1 = if ((s5u + 1) as usize) < 25 {
+        [
+            0, 3, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28, 31, 33, 35, 38, 40, 42, 45, 47, 49, 52,
+            54, 56,
+        ][(s5u + 1) as usize]
+    } else {
+        (s5 + 1) * 3
+    };
+    let ten_sbits = s2 + 1 + n5_s5p1;
+    let use_bigint_path = bbits >= 64 || ten_sbits >= 64;
+
+    let pow5 = |n: u32| -> BigUint { BigUint::from(5u32).pow(n) };
+
+    // Normalize S to improve division accuracy (matching JDK's shiftBias)
+    let s_base = pow5(s5u) << s2u;
+    let shift_bias = if use_bigint_path {
+        // getNormalizationBias: shift S so its highest bit fills the MSB of a u32 word
+        let s_bits = s_base.bits() as u32;
+        let word_bits = s_bits.div_ceil(32) * 32;
+        word_bits - s_bits
+    } else {
+        0u32
+    };
+
+    let mut b_val = (BigUint::from(fract_reduced) * pow5(b5u)) << (b2u + shift_bias);
+    let s_val = &s_base << shift_bias;
+    let mut m_val = pow5(m5u) << (m2u + shift_bias);
+    let tens = &s_val * BigUint::from(10u32);
+
+    let mut digits = Vec::with_capacity(20);
+    let mut dec_exp = dec_exp_est;
+
+    // First digit
+    let q = (&b_val / &s_val).to_u32().unwrap_or(0);
+    b_val = (&b_val % &s_val) * BigUint::from(10u32);
+    m_val = &m_val * BigUint::from(10u32);
+
+    let mut low = b_val < m_val;
+    let mut high = if use_bigint_path {
+        &b_val + &m_val >= tens
+    } else {
+        &b_val + &m_val > tens
+    };
+
+    #[cfg(test)]
+    if abs_v > 1e18 {
+        eprintln!(
+            "DTOA: q={} low={} high={} dec_exp={} s_bits={} tens_bits={} m_bits={}",
+            q,
+            low,
+            high,
+            dec_exp,
+            s_val.bits(),
+            tens.bits(),
+            m_val.bits()
+        );
+    }
+
+    if q == 0 && !high {
+        dec_exp -= 1;
+    } else {
+        digits.push(q as u8);
+    }
+
+    // HACK: for E-form, require more digits
+    if !(-3..8).contains(&dec_exp) {
+        low = false;
+        high = false;
+    }
+
+    // Extract remaining digits
+    let mut iter_count = 0;
+    while !low && !high {
+        let q = (&b_val / &s_val).to_u32().unwrap_or(0);
+        b_val = (&b_val % &s_val) * BigUint::from(10u32);
+        m_val = &m_val * BigUint::from(10u32);
+
+        if m_val > BigUint::from(0u32) {
+            low = b_val < m_val;
+            high = if use_bigint_path {
+                &b_val + &m_val >= tens
+            } else {
+                &b_val + &m_val > tens
+            };
+        } else {
+            low = true;
+            high = true;
+        }
+        digits.push(q as u8);
+        iter_count += 1;
+        #[cfg(test)]
+        if abs_v > 1e18 {
+            eprintln!(
+                "  iter {}: q={} ndigits={} low={} high={} b_bits={} m_bits={}",
+                iter_count,
+                q,
+                digits.len(),
+                low,
+                high,
+                b_val.bits(),
+                m_val.bits()
+            );
+        }
+        if iter_count > 20 {
+            break;
+        } // safety
+    }
+
+    // Final rounding
+    if high {
+        if low {
+            let b2 = &b_val << 1u32;
+            let cmp = b2.cmp(&tens);
+            if cmp == std::cmp::Ordering::Equal {
+                // Tie: round to even
+                if let Some(&last) = digits.last() {
+                    if last & 1 != 0 {
+                        round_up_digits(&mut digits, &mut dec_exp);
+                    }
+                }
+            } else if cmp == std::cmp::Ordering::Greater {
+                round_up_digits(&mut digits, &mut dec_exp);
+            }
+        } else {
+            round_up_digits(&mut digits, &mut dec_exp);
+        }
+    }
+
+    // Convert digits + decExp to BigDecimal
+    // The value is 0.d1d2d3...dn * 10^(dec_exp+1)
+    // Accumulate in i128: with the safety cap above, `digits` can hold more than the
+    // 19 decimal digits that fit in an i64.
+    let mut sig = 0i128;
+    for &d in &digits {
+        sig = sig * 10 + d as i128;
+    }
+    let n_digits = digits.len() as i32;
+    let scale = n_digits - (dec_exp + 1);
+    let bd = BigDecimal::new(num::BigInt::from(sig), scale as i64);
+    if v < 0.0 {
+        -bd
+    } else {
+        bd
+    }
+}
+
+fn round_up_digits(digits: &mut Vec<u8>, dec_exp: &mut i32) {
+    if let Some(last) = digits.last_mut() {
+        if *last < 9 {
+            *last += 1;
+            return;
+        }
+    }
+    // Carry propagation
+    let mut i = digits.len();
+    while i > 0 {
+        i -= 1;
+        if digits[i] < 9 {
+            digits[i] += 1;
+            digits.truncate(i + 1);
+            return;
+        }
+        digits[i] = 0;
+    }
+    // All 9s: e.g., 999 -> 1000
+    digits.clear();
+    digits.push(1);
+    *dec_exp += 1;
+}
+
+fn estimate_dec_exp(fract_bits: u64, bin_exp: i32) -> i32 {
+    let d2_bits = 0x3FF0_0000_0000_0000u64 | (fract_bits & 0x000F_FFFF_FFFF_FFFF);
+    let d2 = f64::from_bits(d2_bits);
+    // These constants are from JDK's estimateDecExp and must match exactly
+    #[allow(clippy::approx_constant)]
+    let d = (d2 - 1.5) * 0.289529654 + 0.176091259 + (bin_exp as f64) * 0.301029995663981;
+    d.floor() as i32
+}
+
+/// Spark-compatible round for f64.
+fn spark_round_via_bigdecimal_f64(v: f64, scale: i64) -> f64 {
+    if !v.is_finite() {
+        return v;
+    }
+    let bd = double_to_bigdecimal_like_java(v);
+    bd.with_scale_round(scale, RoundingMode::HalfUp)
+        .to_string()
+        .parse::<f64>()
+        .unwrap()
+}
+
+/// Spark-compatible round for f32.
+fn spark_round_via_bigdecimal_f32(v: f32, scale: i64) -> f32 {
+    if !v.is_finite() {
+        return v;
+    }
+    let bd = double_to_bigdecimal_like_java(f64::from(v));
+    bd.with_scale_round(scale, RoundingMode::HalfUp)
+        .to_string()
+        .parse::<f32>()
+        .unwrap()
+}
+
 #[cfg(test)]
 mod test {
     use std::sync::Arc;
@@ -280,4 +602,104 @@ mod test {
         assert_eq!(result, 125.23);
         Ok(())
     }
+
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn test_round_f64_spark_bigdecimal_edge_case() {
+        use super::spark_round_via_bigdecimal_f64;
+        // -5.81855622136895E8: Java toString = "-5.81855622136895E8" (15 sig digits).
+        // At 15 sig digits, the closest representation is "-5.81855622136895e8"
+        // which matches Java. The 6th fractional digit is '5' -> rounds up at scale=5.
+        let v = -5.81855622136895E8_f64;
+        let result = spark_round_via_bigdecimal_f64(v, 5);
+        assert_eq!(result, -5.8185562213690E8_f64);
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn test_round_f64_spark_bigdecimal_tostring_roundtrip() {
+        use super::spark_round_via_bigdecimal_f64;
+        // 6.1317116247283497E18: JDK 17 fast path gives integer 6131711624728349600
+        // (last 2 digits insignificant, zeroed). The deciding digit at the 10^4 place
+        // is '4' -> rounds DOWN at scale=-5.
+        let v = 6.131_711_624_728_35E18_f64;
+        let result = spark_round_via_bigdecimal_f64(v, -5);
+        let expected: f64 = "6.1317116247282995E18".parse().unwrap();
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn test_round_f64_large_integer_string() {
+        use super::spark_round_via_bigdecimal_f64;
+        // cast("-8316362075006449156" as double): exact = -8316362075006449664.
+        // Rust's Display (shortest repr) gives "-8316362075006450000" (15 sig digits),
+        // whose deciding digit at the 10^4 place is '5' and would round UP at scale=-5.
+        // Java 17's Schubfach gives "-8.3163620750064497E18" (17 sig digits, closer to
+        // the exact value), whose deciding digit is '4' and rounds DOWN. Both reprs are
+        // valid round-trips, but we must match Java/Spark and round DOWN here.
+        let v: f64 = "-8316362075006449156".parse().unwrap();
+        let result = spark_round_via_bigdecimal_f64(v, -5);
+        let expected: f64 = "-8.3163620750064005E18".parse().unwrap();
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn test_round_f64_half_up() {
+        use super::spark_round_via_bigdecimal_f64;
+        assert_eq!(spark_round_via_bigdecimal_f64(2.5, 0), 3.0);
+        assert_eq!(spark_round_via_bigdecimal_f64(3.5, 0), 4.0);
+        assert_eq!(spark_round_via_bigdecimal_f64(-2.5, 0), -3.0);
+        assert_eq!(spark_round_via_bigdecimal_f64(-3.5, 0), -4.0);
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn test_round_f64_special_values() {
+        use super::spark_round_via_bigdecimal_f64;
+        assert!(spark_round_via_bigdecimal_f64(f64::NAN, 2).is_nan());
+        assert_eq!(
+            spark_round_via_bigdecimal_f64(f64::INFINITY, 2),
+            f64::INFINITY
+        );
+        assert_eq!(
+            spark_round_via_bigdecimal_f64(f64::NEG_INFINITY, 2),
+            f64::NEG_INFINITY
+        );
+        assert_eq!(spark_round_via_bigdecimal_f64(0.0, 2), 0.0);
+        assert_eq!(spark_round_via_bigdecimal_f64(-0.0, 2), 0.0);
+        assert_eq!(spark_round_via_bigdecimal_f64(f64::MIN_POSITIVE, 2), 0.0);
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn test_round_f64_negative_scale() {
+        use super::spark_round_via_bigdecimal_f64;
+        assert_eq!(spark_round_via_bigdecimal_f64(123.456, -1), 120.0);
+        assert_eq!(spark_round_via_bigdecimal_f64(155.0, -2), 200.0);
+        assert_eq!(spark_round_via_bigdecimal_f64(-155.0, -2), -200.0);
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn test_round_f32_spark_compatible() {
+        use super::spark_round_via_bigdecimal_f32;
+        assert_eq!(spark_round_via_bigdecimal_f32(2.5_f32, 0), 3.0_f32);
+        assert_eq!(spark_round_via_bigdecimal_f32(-2.5_f32, 0), -3.0_f32);
+        assert_eq!(spark_round_via_bigdecimal_f32(0.125_f32, 2), 0.13_f32);
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn test_round_f64_null_scalar() -> Result<()> {
+        let args = vec![
+            ColumnarValue::Scalar(ScalarValue::Float64(None)),
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
+        ];
+        let ColumnarValue::Scalar(ScalarValue::Float64(None)) =
+            spark_round(&args, &DataType::Float64, false)?
+        else {
+            unreachable!()
+        };
+        Ok(())
+    }
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
index 85574fbab7..c98b3d96ca 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
@@ -303,23 +303,6 @@ object CometRound extends CometExpressionSerde[Round] {
         exprToProtoInternal(Literal(null), inputs, binding)
       case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
         childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark
-      case _: FloatType | DoubleType =>
-        // We cannot properly match with the Spark behavior for floating-point numbers.
-        // Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a
-        // double to string internally in order to create its own internal representation.
-        // The problem is BigDecimal uses java.lang.Double.toString() and it has complicated
-        // rounding algorithm. E.g. -5.81855622136895E8 is actually
-        // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
-        // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
-        // difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be
-        // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that
-        // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can
-        // be rounded up to 6.13171162472835E18 that still represents the same double number.
-        // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not.
-        // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead
-        // of 6.1317116247283999E18.
-        withInfo(r, "Comet does not support Spark's BigDecimal rounding")
-        None
       case _ =>
         // `scale` must be Int64 type in DataFusion
         val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding)
diff --git a/spark/src/test/resources/sql-tests/expressions/math/round.sql b/spark/src/test/resources/sql-tests/expressions/math/round.sql
index 7a821b8027..dfd00188d4 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/round.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/round.sql
@@ -21,18 +21,86 @@ CREATE TABLE test_round(d double, i int) USING parquet
 statement
 INSERT INTO test_round VALUES (2.5, 0), (3.5, 0), (-2.5, 0), (123.456, 2), (123.456, -1), (NULL, 0), (cast('NaN' as double), 0), (cast('Infinity' as double), 0), (0.0, 0)
 
-query expect_fallback(BigDecimal rounding)
+query
 SELECT round(d, 0) FROM test_round WHERE i = 0
 
-query expect_fallback(BigDecimal rounding)
+query
 SELECT round(d, 2) FROM test_round WHERE i = 2
 
-query expect_fallback(BigDecimal rounding)
+query
 SELECT round(d, -1) FROM test_round WHERE i = -1
 
-query expect_fallback(BigDecimal rounding)
+query
 SELECT round(d) FROM test_round
 
 -- literal + literal
-query expect_fallback(BigDecimal rounding)
+query
 SELECT round(123.456, 2), round(2.5, 0), round(3.5, 0), round(-2.5, 0), round(NULL, 0)
+
+-- HALF_UP semantics: .5 always rounds away from zero
+statement
+CREATE TABLE test_round_half_up(d double) USING parquet
+
+statement
+INSERT INTO test_round_half_up VALUES (0.5), (1.5), (2.5), (-0.5), (-1.5), (-2.5)
+
+query
+SELECT d, round(d, 0) FROM test_round_half_up
+
+-- various scales on a single value
+query
+SELECT round(123.456, 0), round(123.456, 1), round(123.456, 2), round(123.456, 3), round(123.456, 5)
+
+query
+SELECT round(123.456, -1), round(123.456, -2), round(123.456, -3)
+
+-- special values
+query
+SELECT round(cast('NaN' as double), 2), round(cast('Infinity' as double), 2), round(cast('-Infinity' as double), 2)
+
+query
+SELECT round(0.0, 5), round(-0.0, 5)
+
+-- very small values
+query
+SELECT round(1.0E-10, 15), round(1.0E-10, 10), round(1.0E-10, 5)
+
+-- negative scale on doubles
+query
+SELECT round(9999.9, -1), round(9999.9, -2), round(9999.9, -3), round(9999.9, -4)
+
+query
+SELECT round(-9999.9, -1), round(-9999.9, -2), round(-9999.9, -3), round(-9999.9, -4)
+
+-- float type
+statement
+CREATE TABLE test_round_float(f float) USING parquet
+
+statement
+INSERT INTO test_round_float VALUES (cast(2.5 as float)), (cast(3.5 as float)), (cast(-2.5 as float)), (cast(0.125 as float)), (cast(0.785 as float)), (cast(123.456 as float)), (cast('NaN' as float)), (cast('Infinity' as float)), (NULL)
+
+query
+SELECT round(f, 0) FROM test_round_float
+
+query
+SELECT round(f, 2) FROM test_round_float
+
+query
+SELECT round(f, -1) FROM test_round_float
+
+-- BigDecimal rounding edge case from Spark
+statement
+CREATE TABLE test_round_edge(d double) USING parquet
+
+statement
+INSERT INTO test_round_edge VALUES (-5.81855622136895E8), (6.1317116247283497E18), (6.13171162472835E18)
+
+query
+SELECT round(d, 4), round(d, 5), round(d, 6) FROM test_round_edge
+
+query
+SELECT round('-8316362075006449156', -5)
+
+-- round with column from table (not literals)
+query
+SELECT d, round(d, 0), round(d, 2), round(d, -1) FROM test_round
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q78.native_datafusion/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q78.native_datafusion/extended.txt
index d29dbc13e5..d344860519 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q78.native_datafusion/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q78.native_datafusion/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
          :- CometProject
          :  +- CometSortMergeJoin
@@ -76,4 +76,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometNativeScan parquet spark_catalog.default.date_dim
-Comet accelerated 71 out of 76 eligible operators (93%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 73 out of 76 eligible operators (96%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q78.native_iceberg_compat/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q78.native_iceberg_compat/extended.txt
index 1b1e6d0cde..02c33f69e2 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q78.native_iceberg_compat/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark3_5/q78.native_iceberg_compat/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
          :- CometProject
          :  +- CometSortMergeJoin
@@ -77,4 +77,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometScan [native_iceberg_compat] parquet spark_catalog.default.date_dim
-Comet accelerated 70 out of 76 eligible operators (92%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 72 out of 76 eligible operators (94%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark4_0/q78.native_datafusion/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark4_0/q78.native_datafusion/extended.txt
index d29dbc13e5..d344860519 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark4_0/q78.native_datafusion/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark4_0/q78.native_datafusion/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
          :- CometProject
          :  +- CometSortMergeJoin
@@ -76,4 +76,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometNativeScan parquet spark_catalog.default.date_dim
-Comet accelerated 71 out of 76 eligible operators (93%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 73 out of 76 eligible operators (96%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark4_0/q78.native_iceberg_compat/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark4_0/q78.native_iceberg_compat/extended.txt
index 1b1e6d0cde..02c33f69e2 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark4_0/q78.native_iceberg_compat/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4-spark4_0/q78.native_iceberg_compat/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
          :- CometProject
          :  +- CometSortMergeJoin
@@ -77,4 +77,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometScan [native_iceberg_compat] parquet spark_catalog.default.date_dim
-Comet accelerated 70 out of 76 eligible operators (92%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 72 out of 76 eligible operators (94%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78.native_datafusion/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78.native_datafusion/extended.txt
index d29dbc13e5..d344860519 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78.native_datafusion/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78.native_datafusion/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
         :- CometProject
         :  +- CometSortMergeJoin
@@ -76,4 +76,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometNativeScan parquet spark_catalog.default.date_dim
-Comet accelerated 71 out of 76 eligible operators (93%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 73 out of 76 eligible operators (96%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78.native_iceberg_compat/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78.native_iceberg_compat/extended.txt
index 1b1e6d0cde..02c33f69e2 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78.native_iceberg_compat/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78.native_iceberg_compat/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
         :- CometProject
        :  +- CometSortMergeJoin
@@ -77,4 +77,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometScan [native_iceberg_compat] parquet spark_catalog.default.date_dim
-Comet accelerated 70 out of 76 eligible operators (92%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 72 out of 76 eligible operators (94%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q78.native_datafusion/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q78.native_datafusion/extended.txt
index d29dbc13e5..d344860519 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q78.native_datafusion/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q78.native_datafusion/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
         :- CometProject
        :  +- CometSortMergeJoin
@@ -76,4 +76,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometNativeScan parquet spark_catalog.default.date_dim
-Comet accelerated 71 out of 76 eligible operators (93%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 73 out of 76 eligible operators (96%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q78.native_iceberg_compat/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q78.native_iceberg_compat/extended.txt
index 1b1e6d0cde..02c33f69e2 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q78.native_iceberg_compat/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark3_5/q78.native_iceberg_compat/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
         :- CometProject
        :  +- CometSortMergeJoin
@@ -77,4 +77,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometScan [native_iceberg_compat] parquet spark_catalog.default.date_dim
-Comet accelerated 70 out of 76 eligible operators (92%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 72 out of 76 eligible operators (94%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark4_0/q78.native_datafusion/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark4_0/q78.native_datafusion/extended.txt
index d29dbc13e5..d344860519 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark4_0/q78.native_datafusion/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark4_0/q78.native_datafusion/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
         :- CometProject
        :  +- CometSortMergeJoin
@@ -76,4 +76,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometNativeScan parquet spark_catalog.default.date_dim
-Comet accelerated 71 out of 76 eligible operators (93%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 73 out of 76 eligible operators (96%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark4_0/q78.native_iceberg_compat/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark4_0/q78.native_iceberg_compat/extended.txt
index 1b1e6d0cde..02c33f69e2 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark4_0/q78.native_iceberg_compat/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7-spark4_0/q78.native_iceberg_compat/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
         :- CometProject
        :  +- CometSortMergeJoin
@@ -77,4 +77,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometScan [native_iceberg_compat] parquet spark_catalog.default.date_dim
-Comet accelerated 70 out of 76 eligible operators (92%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 72 out of 76 eligible operators (94%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.native_datafusion/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.native_datafusion/extended.txt
index d29dbc13e5..d344860519 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.native_datafusion/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.native_datafusion/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
         :- CometProject
        :  +- CometSortMergeJoin
@@ -76,4 +76,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometNativeScan parquet spark_catalog.default.date_dim
-Comet accelerated 71 out of 76 eligible operators (93%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 73 out of 76 eligible operators (96%). Final plan contains 1 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.native_iceberg_compat/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.native_iceberg_compat/extended.txt
index 1b1e6d0cde..02c33f69e2 100644
--- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.native_iceberg_compat/extended.txt
+++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78.native_iceberg_compat/extended.txt
@@ -1,6 +1,6 @@
-TakeOrderedAndProject
-+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
-   +- CometNativeColumnarToRow
+CometNativeColumnarToRow
++- CometTakeOrderedAndProject
+   +- CometProject
       +- CometSortMergeJoin
        :- CometProject
        :  +- CometSortMergeJoin
@@ -77,4 +77,4 @@ TakeOrderedAndProject
    +- CometFilter
       +- CometScan [native_iceberg_compat] parquet spark_catalog.default.date_dim
-Comet accelerated 70 out of 76 eligible operators (92%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
+Comet accelerated 72 out of 76 eligible operators (94%). Final plan contains 2 transitions between Spark and Comet.
\ No newline at end of file
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 3f3d6773a0..34a92a4902 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2849,10 +2849,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         Byte.MinValue,
         Byte.MaxValue,
         Short.MinValue,
-        Short.MaxValue)).foreach { value =>
+        Short.MaxValue,
+        Float.MinPositiveValue,
+        Float.MaxValue,
+        Float.NaN,
+        Float.MinValue,
+        Float.NegativeInfinity,
+        Float.PositiveInfinity,
+        Double.MinPositiveValue,
+        Double.MaxValue,
+        Double.NaN,
+        Double.MinValue,
+        Double.NegativeInfinity,
+        Double.PositiveInfinity,
+        -5.81855622136895e8,
+        6.1317116247283497e18)).foreach { value =>
       val data = Seq(value)
       withParquetTable(data, "tbl") {
-        Seq(-1000, -100, -10, -1, 0, 1, 10, 100, 1000).foreach { scale =>
+        Seq(-1000, -100, -10, -5, -1, 0, 1, 5, 10, 100, 1000).foreach { scale =>
           Seq(true, false).foreach { ansi =>
             withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi.toString) {
               val res = spark.sql(s"SELECT round(_1, $scale) from tbl")
@@ -2874,6 +2888,54 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("round") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(
+          path,
+          dictionaryEnabled = dictionaryEnabled,
+          -128,
+          128,
+          randomSize = 100)
+        withParquetTable(path.toString, "tbl") {
+          for (s <- Seq(-5, -1, 0, 1, 5, -1000, 1000, -323, -308, 308, -15, 15, -16, 16, null)) {
+            // array tests
+            // TODO: enable test for unsigned ints (_9, _10, _11, _12)
+            for (c <- Seq(2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17)) {
+              checkSparkAnswerAndOperator(s"select _${c}, round(_${c}, ${s}) FROM tbl")
+            }
+            // scalar tests
+            // Exclude the constant folding optimizer in order to actually execute the native round
+            // operations for scalar (literal) values.
+            withSQLConf(
+              "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+              for (n <- Seq("0.0", "-0.0", "0.5", "-0.5", "1.2", "-1.2")) {
+                checkSparkAnswerAndOperator(s"select round(cast(${n} as tinyint), ${s}) FROM tbl")
+                checkSparkAnswerAndOperator(s"select round(cast(${n} as float), ${s}) FROM tbl")
+                checkSparkAnswerAndOperator(
+                  s"select round(cast(${n} as decimal(38, 18)), ${s}) FROM tbl")
+                checkSparkAnswerAndOperator(
+                  s"select round(cast(${n} as decimal(20, 0)), ${s}) FROM tbl")
+              }
+              checkSparkAnswerAndOperator(s"select round(double('infinity'), ${s}) FROM tbl")
+              checkSparkAnswerAndOperator(s"select round(double('-infinity'), ${s}) FROM tbl")
+              checkSparkAnswerAndOperator(s"select round(double('NaN'), ${s}) FROM tbl")
+              checkSparkAnswerAndOperator(
+                s"select round(double('0.000000000000000000000000000000000001'), ${s}) FROM tbl")
+            }
+          }
+        }
+      }
+    }
+  }
+
+  test("round double from large integer string") {
+    withParquetTable(Seq(Tuple1("-8316362075006449156")), "tbl") {
+      checkSparkAnswerAndOperator("SELECT round(cast(_1 as double), -5) FROM tbl")
+    }
+  }
+
   test("test integral divide overflow for decimal") {
     // All inserted values produce a quotient > Decimal(38,0).max (~1e38), so they overflow
     // the intermediate decimal result type. In legacy/try mode both Spark and Comet return