163 changes: 94 additions & 69 deletions native/spark-expr/src/conversion_funcs/cast.rs
@@ -2317,8 +2317,8 @@ fn cast_string_to_decimal256_impl(
}

/// Parse a string to decimal following Spark's behavior
fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
let string_bytes = s.as_bytes();
fn parse_string_to_decimal(input_str: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
let string_bytes = input_str.as_bytes();
let mut start = 0;
let mut end = string_bytes.len();

@@ -2330,7 +2330,7 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
end -= 1;
}

let trimmed = &s[start..end];
let trimmed = &input_str[start..end];

if trimmed.is_empty() {
return Ok(None);
@@ -2347,73 +2347,98 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
return Ok(None);
}

// validate and parse mantissa and exponent
match parse_decimal_str(trimmed) {
Ok((mantissa, exponent)) => {
// Convert to target scale
let target_scale = scale as i32;
let scale_adjustment = target_scale - exponent;
// Validate and parse the mantissa and exponent, or bubble up the error
let (mantissa, exponent) = parse_decimal_str(
trimmed,
input_str,
"STRING",
&format!("DECIMAL({},{})", precision, scale),
)?;

let scaled_value = if scale_adjustment >= 0 {
// Need to multiply (increase scale) but return None if scale is too high to fit i128
if scale_adjustment > 38 {
return Ok(None);
}
mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
} else {
// Need to multiply (increase scale) but return None if scale is too high to fit i128
let abs_scale_adjustment = (-scale_adjustment) as u32;
if abs_scale_adjustment > 38 {
return Ok(Some(0));
}
// Early return for a zero mantissa; Spark still checks that the scale fits the supported range and throws in ANSI mode
if mantissa == 0 {
if exponent < -37 {
return Err(SparkError::NumericOutOfRange {
value: input_str.to_string(),
});
}
return Ok(Some(0));
}

let divisor = 10_i128.pow(abs_scale_adjustment);
let quotient_opt = mantissa.checked_div(divisor);
// Check if divisor is 0
if quotient_opt.is_none() {
return Ok(None);
}
let quotient = quotient_opt.unwrap();
let remainder = mantissa % divisor;

// Round half up: if abs(remainder) >= divisor/2, round away from zero
let half_divisor = divisor / 2;
let rounded = if remainder.abs() >= half_divisor {
if mantissa >= 0 {
quotient + 1
} else {
quotient - 1
}
} else {
quotient
};
Some(rounded)
};
// scale adjustment
let target_scale = scale as i32;
let scale_adjustment = target_scale - exponent;

match scaled_value {
Some(value) => {
// Check if it fits target precision
if is_validate_decimal_precision(value, precision) {
Ok(Some(value))
} else {
Ok(None)
}
}
None => {
// Overflow while scaling
Ok(None)
}
let scaled_value = if scale_adjustment >= 0 {
// Need to multiply (increase scale) but return None if scale is too high to fit i128
if scale_adjustment > 38 {
return Ok(None);
}
mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
} else {
// Need to divide (decrease scale)
let abs_scale_adjustment = (-scale_adjustment) as u32;
if abs_scale_adjustment > 38 {
return Ok(Some(0));
}

let divisor = 10_i128.pow(abs_scale_adjustment);
let quotient_opt = mantissa.checked_div(divisor);
// Defensive: checked_div fails only on division by zero, and the divisor is a positive power of ten
if quotient_opt.is_none() {
return Ok(None);
}
let quotient = quotient_opt.unwrap();
let remainder = mantissa % divisor;

// Round half up: if abs(remainder) >= divisor/2, round away from zero
let half_divisor = divisor / 2;
let rounded = if remainder.abs() >= half_divisor {
if mantissa >= 0 {
quotient + 1
} else {
quotient - 1
}
} else {
quotient
};
Some(rounded)
};

match scaled_value {
Some(value) => {
if is_validate_decimal_precision(value, precision) {
Ok(Some(value))
} else {
// Value parsed successfully but exceeds the target precision; throw an error
Err(SparkError::NumericValueOutOfRange {
value: trimmed.to_string(),
precision,
scale,
})
}
}
None => {
// Overflow while scaling; raise an exception
Err(SparkError::NumericValueOutOfRange {
value: trimmed.to_string(),
precision,
scale,
})
}
Err(_) => Ok(None),
}
}
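
Editorial note: the rescaling above multiplies to raise the scale and divides with round-half-up (away from zero) to lower it. The standalone sketch below is illustrative only; it omits the PR's 38-digit guards and error types, and the function name is hypothetical, not from the diff.

fn rescale_round_half_up(mantissa: i128, parsed_scale: i32, target_scale: i32) -> Option<i128> {
    let adjustment = target_scale - parsed_scale;
    if adjustment >= 0 {
        // Increase the scale: multiply, bailing out on overflow.
        mantissa.checked_mul(10_i128.checked_pow(adjustment as u32)?)
    } else {
        // Decrease the scale: divide, then round half up (away from zero).
        let divisor = 10_i128.checked_pow((-adjustment) as u32)?;
        let quotient = mantissa / divisor; // i128 division truncates toward zero
        let remainder = mantissa % divisor;
        if remainder.abs() >= divisor / 2 {
            Some(if mantissa >= 0 { quotient + 1 } else { quotient - 1 })
        } else {
            Some(quotient)
        }
    }
}

fn main() {
    // "1.005" parses to (1005, 3); rescaling to 2 rounds up to 101, i.e. 1.01.
    assert_eq!(rescale_round_half_up(1005, 3, 2), Some(101));
    // "-1.005" rounds away from zero to -101, i.e. -1.01.
    assert_eq!(rescale_round_half_up(-1005, 3, 2), Some(-101));
    // "1.23e10" parses to (123, -8); rescaling to 2 multiplies by 10^10.
    assert_eq!(rescale_round_half_up(123, -8, 2), Some(1_230_000_000_000));
}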

/// Parse a decimal string into mantissa and scale
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3) , 0e50 -> (0,50) etc
fn parse_decimal_str(
s: &str,
original_str: &str,
from_type: &str,
to_type: &str,
) -> SparkResult<(i128, i32)> {
if s.is_empty() {
return Err("Empty string".to_string());
return Err(invalid_value(original_str, from_type, to_type));
}

let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
@@ -2422,7 +2447,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
// Parse exponent
let exp: i32 = exponent_part
.parse()
.map_err(|e| format!("Invalid exponent: {}", e))?;
.map_err(|_| invalid_value(original_str, from_type, to_type))?;

(mantissa_part, exp)
} else {
@@ -2437,29 +2462,29 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
};

if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
return Err("Invalid sign format".to_string());
return Err(invalid_value(original_str, from_type, to_type));
}

let (integral_part, fractional_part) = match mantissa_str.find('.') {
Some(dot_pos) => {
if mantissa_str[dot_pos + 1..].contains('.') {
return Err("Multiple decimal points".to_string());
return Err(invalid_value(original_str, from_type, to_type));
}
(&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
}
None => (mantissa_str, ""),
};

if integral_part.is_empty() && fractional_part.is_empty() {
return Err("No digits found".to_string());
return Err(invalid_value(original_str, from_type, to_type));
}

if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
return Err("Invalid integral part".to_string());
return Err(invalid_value(original_str, from_type, to_type));
}

if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
return Err("Invalid fractional part".to_string());
return Err(invalid_value(original_str, from_type, to_type));
}

// Parse integral part
@@ -2469,7 +2494,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
} else {
integral_part
.parse()
.map_err(|_| "Invalid integral part".to_string())?
.map_err(|_| invalid_value(original_str, from_type, to_type))?
};

// Parse fractional part
@@ -2479,14 +2504,14 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
} else {
fractional_part
.parse()
.map_err(|_| "Invalid fractional part".to_string())?
.map_err(|_| invalid_value(original_str, from_type, to_type))?
};

// Combine: value = integral * 10^fractional_scale + fractional
let mantissa = integral_value
.checked_mul(10_i128.pow(fractional_scale as u32))
.and_then(|v| v.checked_add(fractional_value))
.ok_or("Overflow in mantissa calculation")?;
.ok_or_else(|| invalid_value(original_str, from_type, to_type))?;

let final_mantissa = if negative { -mantissa } else { mantissa };
// final scale = fractional_scale - exponent
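
Editorial note: the 'e'/'E' split above is what drives the final scale (fractional_scale - exponent). A minimal standalone sketch of that split, with a hypothetical name and no mantissa validation:

fn split_scientific(s: &str) -> Option<(&str, i32)> {
    match s.find(|c| c == 'e' || c == 'E') {
        Some(pos) => {
            // Everything after the marker is the exponent, parsed as i32.
            let exp: i32 = s[pos + 1..].parse().ok()?;
            Some((&s[..pos], exp))
        }
        None => Some((s, 0)),
    }
}

fn main() {
    assert_eq!(split_scientific("1.23E-5"), Some(("1.23", -5)));
    // "1.23" has fractional scale 2, so the final scale is 2 - (-5) = 7.
    assert_eq!(split_scientific("42"), Some(("42", 0)));
}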
3 changes: 3 additions & 0 deletions native/spark-expr/src/error.rs
@@ -39,6 +39,9 @@ pub enum SparkError {
scale: i8,
},

#[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")]
NumericOutOfRange { value: String },

#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
35 changes: 24 additions & 11 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}

import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.rules.CometScanTypeChecker
import org.apache.comet.serde.Compatible
@@ -709,8 +708,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("cast StringType to DecimalType(10,2) (does not support fullwidth unicode digits)") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
Expand All @@ -719,18 +716,38 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("cast StringType to DecimalType(2,2)") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
}
}

test("cast StringType to DecimalType check if right exception message is thrown") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
val values = Seq("d11307\n").toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
}
}

test("cast StringType to DecimalType(2,2) check if right exception is being thrown") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
val values = gen.generateInts(10000).map(" " + _).toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
}
}

test("cast StringType to DecimalType(38,10) high precision - check 0 mantissa") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
val values = Seq("0e31", "000e3375", "0e40", "0E+695", "0e5887677").toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
}
}
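
Editorial note: these zero-mantissa inputs exercise the native-side guard added in cast.rs, which rejects a zero mantissa whose parsed scale falls below -37. A tiny Rust sketch of that predicate, with a hypothetical name and the -37 bound copied from the diff:

fn zero_mantissa_fits(parsed_scale: i32) -> bool {
    // Mirrors the `exponent < -37` check in parse_string_to_decimal.
    parsed_scale >= -37
}

fn main() {
    assert!(zero_mantissa_fits(-31)); // "0e31" -> (0, -31): still representable as 0
    assert!(!zero_mantissa_fits(-40)); // "0e40" -> (0, -40): out of supported range
}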

test("cast StringType to DecimalType(38,10) high precision") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
@@ -739,8 +756,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("cast StringType to DecimalType(10,2) basic values") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
val values = Seq(
"123.45",
"-67.89",
@@ -766,8 +781,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("cast StringType to Decimal type scientific notation") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
val values = Seq(
"1.23E-5",
"1.23e10",