diff --git a/arrow-schema/src/datatype_parse.rs b/arrow-schema/src/datatype_parse.rs index 721bbda11a09..1be24a574517 100644 --- a/arrow-schema/src/datatype_parse.rs +++ b/arrow-schema/src/datatype_parse.rs @@ -17,7 +17,11 @@ use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; -use crate::{ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, UnionFields, UnionMode}; +use crate::{ + ArrowError, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, + DECIMAL256_MAX_PRECISION, DataType, Field, Fields, IntervalUnit, TimeUnit, UnionFields, + UnionMode, +}; /// Parses a DataType from a string representation /// @@ -181,6 +185,12 @@ impl<'a> Parser<'a> { fn parse_fixed_size_list(&mut self) -> ArrowResult { self.expect_token(Token::LParen)?; let length = self.parse_i32("FixedSizeList")?; + if length < 0 { + return Err(make_error( + self.val, + &format!("FixedSizeList length must be non-negative, got {length}"), + )); + } match self.next_token()? { // `FixedSizeList(5 x non-null Int64, field: 'foo')` format Token::X => { @@ -335,6 +345,15 @@ impl<'a> Parser<'a> { fn parse_time32(&mut self) -> ArrowResult { self.expect_token(Token::LParen)?; let time_unit = self.parse_time_unit("Time32")?; + match time_unit { + TimeUnit::Second | TimeUnit::Millisecond => (), + TimeUnit::Microsecond | TimeUnit::Nanosecond => { + return Err(make_error( + self.val, + &format!("Time32 time unit must be 's' or 'ms', got '{time_unit}'"), + )); + } + }; self.expect_token(Token::RParen)?; Ok(DataType::Time32(time_unit)) } @@ -343,6 +362,15 @@ impl<'a> Parser<'a> { fn parse_time64(&mut self) -> ArrowResult { self.expect_token(Token::LParen)?; let time_unit = self.parse_time_unit("Time64")?; + match time_unit { + TimeUnit::Microsecond | TimeUnit::Nanosecond => (), + TimeUnit::Second | TimeUnit::Millisecond => { + return Err(make_error( + self.val, + &format!("Time64 time unit must be 'µs' or 'ns', got '{time_unit}'"), + )); + } + }; self.expect_token(Token::RParen)?; Ok(DataType::Time64(time_unit)) } @@ -385,6 +413,32 @@ impl<'a> Parser<'a> { Ok(DataType::FixedSizeBinary(length)) } + fn validate_decimal( + &self, + precision: u8, + scale: i8, + type_name: &str, + max_precision: u8, + ) -> ArrowResult<()> { + if precision == 0 || precision > max_precision { + return Err(make_error( + self.val, + &format!( + "{type_name} precision must be in range [1, {max_precision}], got '{precision}'" + ), + )); + } + if scale > 0 && scale as u8 > precision { + return Err(make_error( + self.val, + &format!( + "{type_name} scale '{scale}' cannot be greater than precision '{precision}'" + ), + )); + } + Ok(()) + } + /// Parses the next Decimal32 (called after `Decimal32` has been consumed) fn parse_decimal_32(&mut self) -> ArrowResult { self.expect_token(Token::LParen)?; @@ -392,6 +446,7 @@ impl<'a> Parser<'a> { self.expect_token(Token::Comma)?; let scale = self.parse_i8("Decimal32")?; self.expect_token(Token::RParen)?; + self.validate_decimal(precision, scale, "Decimal32", DECIMAL32_MAX_PRECISION)?; Ok(DataType::Decimal32(precision, scale)) } @@ -402,6 +457,7 @@ impl<'a> Parser<'a> { self.expect_token(Token::Comma)?; let scale = self.parse_i8("Decimal64")?; self.expect_token(Token::RParen)?; + self.validate_decimal(precision, scale, "Decimal64", DECIMAL64_MAX_PRECISION)?; Ok(DataType::Decimal64(precision, scale)) } @@ -412,6 +468,7 @@ impl<'a> Parser<'a> { self.expect_token(Token::Comma)?; let scale = self.parse_i8("Decimal128")?; self.expect_token(Token::RParen)?; + self.validate_decimal(precision, scale, "Decimal128", DECIMAL128_MAX_PRECISION)?; Ok(DataType::Decimal128(precision, scale)) } @@ -422,6 +479,7 @@ impl<'a> Parser<'a> { self.expect_token(Token::Comma)?; let scale = self.parse_i8("Decimal256")?; self.expect_token(Token::RParen)?; + self.validate_decimal(precision, scale, "Decimal256", DECIMAL256_MAX_PRECISION)?; Ok(DataType::Decimal256(precision, scale)) } @@ -986,10 +1044,6 @@ mod test { DataType::Date64, DataType::Time32(TimeUnit::Second), DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Microsecond), - DataType::Time32(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Second), - DataType::Time64(TimeUnit::Millisecond), DataType::Time64(TimeUnit::Microsecond), DataType::Time64(TimeUnit::Nanosecond), DataType::Duration(TimeUnit::Second), @@ -1007,10 +1061,10 @@ mod test { DataType::Utf8, DataType::Utf8View, DataType::LargeUtf8, - DataType::Decimal32(7, 8), - DataType::Decimal64(6, 9), - DataType::Decimal128(7, 12), - DataType::Decimal256(6, 13), + DataType::Decimal32(7, 6), + DataType::Decimal64(6, 5), + DataType::Decimal128(7, 6), + DataType::Decimal256(6, 5), // --------- // Nested types // --------- @@ -1300,10 +1354,6 @@ mod test { ("Date64", Date64), ("Time32(s)", Time32(Second)), ("Time32(ms)", Time32(Millisecond)), - ("Time32(µs)", Time32(Microsecond)), - ("Time32(ns)", Time32(Nanosecond)), - ("Time64(s)", Time64(Second)), - ("Time64(ms)", Time64(Millisecond)), ("Time64(µs)", Time64(Microsecond)), ("Time64(ns)", Time64(Nanosecond)), ("Duration(s)", Duration(Second)), @@ -1321,10 +1371,10 @@ mod test { ("Utf8", Utf8), ("Utf8View", Utf8View), ("LargeUtf8", LargeUtf8), - ("Decimal32(7, 8)", Decimal32(7, 8)), - ("Decimal64(6, 9)", Decimal64(6, 9)), - ("Decimal128(7, 12)", Decimal128(7, 12)), - ("Decimal256(6, 13)", Decimal256(6, 13)), + ("Decimal32(7, 6)", Decimal32(7, 6)), + ("Decimal64(6, 5)", Decimal64(6, 5)), + ("Decimal128(7, 6)", Decimal128(7, 6)), + ("Decimal256(6, 5)", Decimal256(6, 5)), ( "Dictionary(Int32, Utf8)", Dictionary(Box::new(Int32), Box::new(Utf8)), @@ -1446,6 +1496,10 @@ mod test { "FixedSizeBinary(-1), ", "FixedSizeBinary length must be non-negative, got -1", ), + ( + "FixedSizeList(-1, Int64), ", + "FixedSizeList length must be non-negative, got -1", + ), // can't have negative precision ( "Decimal32(-3, 5)", @@ -1485,6 +1539,74 @@ mod test { "Struct(\"f1\": )", "Error finding next type, got unexpected ')'", ), + // Invalid time combinations + ( + "Time32(µs)", + "Error Time32 time unit must be 's' or 'ms', got 'µs'", + ), + ( + "Time32(ns)", + "Error Time32 time unit must be 's' or 'ms', got 'ns'", + ), + ( + "Time64(s)", + "Error Time64 time unit must be 'µs' or 'ns', got 's'", + ), + ( + "Time64(ms)", + "Error Time64 time unit must be 'µs' or 'ns', got 'ms'", + ), + // Decimals can't have scale exceeding precision + ( + "Decimal32(5, 6)", + "Error Decimal32 scale '6' cannot be greater than precision '5'", + ), + ( + "Decimal64(5, 6)", + "Error Decimal64 scale '6' cannot be greater than precision '5'", + ), + ( + "Decimal128(5, 6)", + "Error Decimal128 scale '6' cannot be greater than precision '5'", + ), + ( + "Decimal256(5, 6)", + "Error Decimal256 scale '6' cannot be greater than precision '5'", + ), + // Decimals have a max supported precision + ( + "Decimal32(10, 0)", + "Error Decimal32 precision must be in range [1, 9], got '10'", + ), + ( + "Decimal64(19, 0)", + "Error Decimal64 precision must be in range [1, 18], got '19'", + ), + ( + "Decimal128(39, 0)", + "Error Decimal128 precision must be in range [1, 38], got '39'", + ), + ( + "Decimal256(77, 0)", + "Error Decimal256 precision must be in range [1, 76], got '77'", + ), + // Decimals precision can't be 0 + ( + "Decimal32(0, 0)", + "Error Decimal32 precision must be in range [1, 9], got '0'", + ), + ( + "Decimal64(0, 0)", + "Error Decimal64 precision must be in range [1, 18], got '0'", + ), + ( + "Decimal128(0, 0)", + "Error Decimal128 precision must be in range [1, 38], got '0'", + ), + ( + "Decimal256(0, 0)", + "Error Decimal256 precision must be in range [1, 76], got '0'", + ), ]; for (data_type_string, expected_message) in cases {