diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 2d69728bda919..d6a6693d862cc 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -247,6 +247,10 @@ harness = false name = "replace" required-features = ["string_expressions"] +[[bench]] +harness = false +name = "overlay" + [[bench]] harness = false name = "random" diff --git a/datafusion/functions/benches/overlay.rs b/datafusion/functions/benches/overlay.rs new file mode 100644 index 0000000000000..4554cc435e738 --- /dev/null +++ b/datafusion/functions/benches/overlay.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod helper; + +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use helper::gen_string_array; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + const N_ROWS: usize = 8192; + const STR_LEN: usize = 128; + + let overlay = datafusion_functions::core::overlay(); + let config_options = Arc::new(ConfigOptions::default()); + + let mut args = gen_string_array(N_ROWS, STR_LEN, 0.1, 0.5, false); + args.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "DataFusion".to_string(), + )))); + args.push(ColumnarValue::Scalar(ScalarValue::Int64(Some(32)))); + args.push(ColumnarValue::Scalar(ScalarValue::Int64(Some(8)))); + + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Utf8, true)); + + c.bench_function("overlay_StringArray_utf8_scalar_args", |b| { + b.iter(|| { + black_box( + overlay + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/core/overlay.rs b/datafusion/functions/src/core/overlay.rs index 58f22caf52e32..6f48c403f338d 100644 --- a/datafusion/functions/src/core/overlay.rs +++ b/datafusion/functions/src/core/overlay.rs @@ -112,84 +112,138 @@ impl ScalarUDFImpl for OverlayFunc { } } +/// Converts a 0-based character index into a byte index suitable for UTF-8 +/// slicing. 
/// Converts a 0-based character index into a byte index suitable for UTF-8
/// slicing.
///
/// A `char_idx` at or beyond the end of `string` maps to `string.len()`, so the
/// result is always a valid slice boundary. When `is_ascii` is true the caller
/// asserts that `string` is pure ASCII; character and byte indexes then
/// coincide and no scan is needed.
fn byte_index_for_char(string: &str, char_idx: usize, is_ascii: bool) -> usize {
    if is_ascii {
        char_idx.min(string.len())
    } else {
        string
            .char_indices()
            .nth(char_idx)
            .map_or(string.len(), |(byte_idx, _)| byte_idx)
    }
}

/// Builds the OVERLAY result for a single (non-null) row.
///
/// `start_pos` is a 1-based character position; `replace_len` is the number
/// of characters of `string` to replace with `characters`.
///
/// * A start past the end of `string` keeps the whole string as the prefix, so
///   the replacement is appended (PostgreSQL-compatible "insert past end").
/// * A span that runs past the end of `string` replaces through the end; the
///   caller does not need to clamp `replace_len`.
/// * A negative `replace_len` moves the suffix start *before* the prefix end,
///   so part of the input is repeated after `characters`.
///
/// Callers must reject `start_pos < 1` before calling; this is only checked
/// with a `debug_assert!`.
fn overlay_one(
    string: &str,
    characters: &str,
    start_pos: i64,
    replace_len: i64,
) -> String {
    debug_assert!(start_pos >= 1);

    // Convert SQL's 1-based character position into 0-based character indexes.
    // `start_char` is the first replaced character; `end_char` is the first
    // character after the replaced span. Saturating arithmetic keeps extreme
    // arguments from overflowing, and negative results are pinned to the start
    // of the string.
    let start_char_idx = start_pos.saturating_sub(1).max(0);
    let end_char_idx = start_char_idx.saturating_add(replace_len).max(0);
    // After the clamps above the conversions can only fail on targets where
    // `usize` is narrower than `i64`; "past the end" is the right fallback.
    let start_char = usize::try_from(start_char_idx).unwrap_or(usize::MAX);
    let end_char = usize::try_from(end_char_idx).unwrap_or(usize::MAX);

    let is_ascii = string.is_ascii();
    let prefix_end_byte = byte_index_for_char(string, start_char, is_ascii);
    let suffix_start_byte = if end_char >= start_char {
        // Usual case: resume the scan from the prefix boundary instead of
        // walking the string from the beginning again. Running off the end
        // yields `string.len()`, i.e. an empty suffix.
        let rest = &string[prefix_end_byte..];
        prefix_end_byte + byte_index_for_char(rest, end_char - start_char, is_ascii)
    } else {
        // Negative `replace_len`: the suffix begins before the prefix ends.
        byte_index_for_char(string, end_char, is_ascii)
    };

    let mut res = String::with_capacity(string.len() + characters.len());
    res.push_str(&string[..prefix_end_byte]);
    res.push_str(characters);
    res.push_str(&string[suffix_start_byte..]);
    res
}

/// Applies [`overlay_one`] row by row and collects the results into a
/// `GenericStringArray<T>`; a row is NULL if any of its inputs is NULL.
///
/// The expansion refers to a type parameter `T`, which must be in scope at the
/// call site. A start position below 1 aborts the whole batch with an
/// execution error.
macro_rules! process_overlay {
    // Three-argument form: the replaced span is as long as `characters`.
    ($string_array:expr, $characters_array:expr, $pos_array:expr) => {{
        $string_array
            .iter()
            .zip($characters_array.iter())
            .zip($pos_array.iter())
            .map(|((string, characters), start_pos)| {
                match (string, characters, start_pos) {
                    (Some(string), Some(characters), Some(start_pos)) => {
                        if start_pos < 1 {
                            return exec_err!("negative substring length not allowed");
                        }
                        let replace_len = characters.chars().count() as i64;
                        Ok(Some(overlay_one(
                            string,
                            characters,
                            start_pos,
                            replace_len,
                        )))
                    }
                    _ => Ok(None),
                }
            })
            .collect::<Result<GenericStringArray<T>>>()
    }};
    // Four-argument form: the replaced span length is given explicitly.
    ($string_array:expr, $characters_array:expr, $pos_array:expr, $len_array:expr) => {{
        $string_array
            .iter()
            .zip($characters_array.iter())
            .zip($pos_array.iter())
            .zip($len_array.iter())
            .map(|(((string, characters), start_pos), len)| {
                match (string, characters, start_pos, len) {
                    (Some(string), Some(characters), Some(start_pos), Some(len)) => {
                        if start_pos < 1 {
                            return exec_err!("negative substring length not allowed");
                        }
                        // `len` is passed through unclamped: `overlay_one`
                        // already stops at the end of `string`, so counting
                        // the characters here would only add a scan per row.
                        Ok(Some(overlay_one(string, characters, start_pos, len)))
                    }
                    _ => Ok(None),
                }
            })
            .collect::<Result<GenericStringArray<T>>>()
    }};
}
If `count` is supplied, that many characters +/// of `string` are replaced; otherwise `count` defaults to the character +/// length of `substring`. +/// +/// ```text +/// overlay('Txxxxas' placing 'hom' from 2 for 4) → 'Thomas' +/// overlay('Txxxxas' placing 'hom' from 2) → 'Thomxas' +/// ``` fn overlay(args: &[ArrayRef]) -> Result { - let use_string_view = args[0].data_type() == &DataType::Utf8View; - if use_string_view { + if !matches!(args.len(), 3 | 4) { + return exec_err!( + "overlay was called with {} arguments. It requires 3 or 4.", + args.len() + ); + } + if args[0].data_type() == &DataType::Utf8View { string_view_overlay::(args) } else { string_overlay::(args) @@ -197,55 +251,31 @@ fn overlay(args: &[ArrayRef]) -> Result { } fn string_overlay(args: &[ArrayRef]) -> Result { - match args.len() { - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - let pos_num = as_int64_array(&args[2])?; + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_array = as_int64_array(&args[2])?; - let result = process_overlay!(string_array, characters_array, pos_num)?; - Ok(Arc::new(result) as ArrayRef) - } - 4 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - let len_num = as_int64_array(&args[3])?; - - let result = - process_overlay!(string_array, characters_array, pos_num, len_num)?; - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!("overlay was called with {other} arguments. It requires 3 or 4.") - } - } + let result = if args.len() == 4 { + let len_array = as_int64_array(&args[3])?; + process_overlay!(string_array, characters_array, pos_array, len_array)? + } else { + process_overlay!(string_array, characters_array, pos_array)? 
+ }; + Ok(Arc::new(result) as ArrayRef) } fn string_view_overlay(args: &[ArrayRef]) -> Result { - match args.len() { - 3 => { - let string_array = as_string_view_array(&args[0])?; - let characters_array = as_string_view_array(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - - let result = process_overlay!(string_array, characters_array, pos_num)?; - Ok(Arc::new(result) as ArrayRef) - } - 4 => { - let string_array = as_string_view_array(&args[0])?; - let characters_array = as_string_view_array(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - let len_num = as_int64_array(&args[3])?; + let string_array = as_string_view_array(&args[0])?; + let characters_array = as_string_view_array(&args[1])?; + let pos_array = as_int64_array(&args[2])?; - let result = - process_overlay!(string_array, characters_array, pos_num, len_num)?; - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!("overlay was called with {other} arguments. It requires 3 or 4.") - } - } + let result = if args.len() == 4 { + let len_array = as_int64_array(&args[3])?; + process_overlay!(string_array, characters_array, pos_array, len_array)? + } else { + process_overlay!(string_array, characters_array, pos_array)? + }; + Ok(Arc::new(result) as ArrayRef) } #[cfg(test)] @@ -265,7 +295,9 @@ mod tests { let res = overlay::(&[string, replace_string, start, end]).unwrap(); let result = as_generic_string_array::(&res).unwrap(); - let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); + // First row: start=4 is past the end of "123" (len 3). PostgreSQL + // takes the whole string as prefix and appends the replacement. 
+ let expected = StringArray::from(vec!["123abc", "qwertyasdfg", "ijkz", "Thomas"]); assert_eq!(&expected, result); Ok(()) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index ee11dc973bbd7..49e2c27f90806 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -707,10 +707,11 @@ CREATE TABLE over_test( ('Txxxxas', NULL, 2, 4) ; +# If the start position is past the end, OVERLAY appends the replacement. query T SELECT overlay(str placing characters from pos for len) from over_test ---- -abc +123abc qwertyasdfg ijkz Thomas @@ -722,7 +723,7 @@ NULL query T SELECT overlay(str placing characters from pos) from over_test ---- -abc +123abc qwertyasdfg ijk Thomxas @@ -735,7 +736,7 @@ NULL query T SELECT overlay(arrow_cast(str, 'Utf8View') placing arrow_cast(characters, 'Utf8View') from pos for len) from over_test ---- -abc +123abc qwertyasdfg ijkz Thomas @@ -747,7 +748,7 @@ NULL query T SELECT overlay(arrow_cast(str, 'Utf8View') placing arrow_cast(characters, 'Utf8View') from pos) from over_test ---- -abc +123abc qwertyasdfg ijk Thomxas @@ -756,6 +757,84 @@ NULL Thomxas NULL +# overlay uses character positions for non-ASCII input. +statement ok +CREATE TABLE over_unicode_test( + str TEXT, + characters TEXT, + pos INT, + len INT +) as VALUES + ('aébc', 'X', 3, 1), + ('aébc', 'ZZ', 2, 2), + ('αβγδ', 'XY', 3, 2) +; + +query T +SELECT overlay(str placing characters from pos for len) FROM over_unicode_test +---- +aéXc +aZZc +αβXY + +# Same non-ASCII cases with Utf8View inputs. +query T +SELECT overlay(arrow_cast(str, 'Utf8View') placing arrow_cast(characters, 'Utf8View') from pos for len) FROM over_unicode_test +---- +aéXc +aZZc +αβXY + +# Same non-ASCII cases with LargeUtf8 inputs. 
+query T +SELECT overlay(arrow_cast(str, 'LargeUtf8') placing arrow_cast(characters, 'LargeUtf8') from pos for len) FROM over_unicode_test +---- +aéXc +aZZc +αβXY + +statement ok +DROP TABLE over_unicode_test + +# overlay edge cases that match PostgreSQL behavior. + +# Start past the end appends the replacement. +query T +SELECT overlay('abc' placing 'X' from 5 for 1) +---- +abcX + +# Start positions must be positive. +statement error negative substring length not allowed +SELECT overlay('abc' placing 'X' from 0 for 1) + +statement error negative substring length not allowed +SELECT overlay('abc' placing 'X' from -1 for 1) + +# Negative count keeps the suffix from before the start position. +query T +SELECT overlay('abc' placing 'XY' from 2 for -1) +---- +aXYabc + +# Count 0 inserts without deleting input characters. +query T +SELECT overlay('abc' placing 'XY' from 2 for 0) +---- +aXYbc + +# Count larger than the remaining input replaces through the end. +query T +SELECT overlay('abc' placing 'XY' from 2 for 100) +---- +aXY + +# Empty replacement deletes the selected input characters. 
+query T +SELECT overlay('abc' placing '' from 2 for 1) +---- +ac + # Verify that multiple calls to volatile functions like `random()` are not combined / optimized away query B SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0) diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part index fb0901ebf8e37..9e5b8f91e7d8e 100644 --- a/datafusion/sqllogictest/test_files/string/string_query.slt.part +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -860,8 +860,8 @@ ufoor_score ufoość core u🔥der_score u🔥 iść core NULL NULL pfooent pfooTadeusz ma iść w kąt p🔥rcent p🔥n Tadeusz ma iść w kąt NULL NULL foo foo 🔥 🔥 NULL NULL foo foo 🔥 🔥 NULL NULL -foo foo 🔥 🔥 NULL NULL -foo foo 🔥 🔥 NULL NULL +%foo foo %🔥 🔥 NULL NULL +_foo foo _🔥 🔥 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL