Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 139 additions & 17 deletions arrow-schema/src/datatype_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@

use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc};

use crate::{ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, UnionFields, UnionMode};
use crate::{
ArrowError, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION,
DECIMAL256_MAX_PRECISION, DataType, Field, Fields, IntervalUnit, TimeUnit, UnionFields,
UnionMode,
};

/// Parses a DataType from a string representation
///
Expand Down Expand Up @@ -181,6 +185,12 @@ impl<'a> Parser<'a> {
fn parse_fixed_size_list(&mut self) -> ArrowResult<DataType> {
self.expect_token(Token::LParen)?;
let length = self.parse_i32("FixedSizeList")?;
if length < 0 {
return Err(make_error(
self.val,
&format!("FixedSizeList length must be non-negative, got {length}"),
));
}
match self.next_token()? {
// `FixedSizeList(5 x non-null Int64, field: 'foo')` format
Token::X => {
Expand Down Expand Up @@ -335,6 +345,15 @@ impl<'a> Parser<'a> {
fn parse_time32(&mut self) -> ArrowResult<DataType> {
    self.expect_token(Token::LParen)?;
    let time_unit = self.parse_time_unit("Time32")?;
    // Only second and millisecond units are accepted for Time32. The check
    // runs before the closing parenthesis is consumed, so an unsupported
    // unit is reported ahead of any later syntax problem.
    if matches!(time_unit, TimeUnit::Microsecond | TimeUnit::Nanosecond) {
        let message = format!("Time32 time unit must be 's' or 'ms', got '{time_unit}'");
        return Err(make_error(self.val, &message));
    }
    self.expect_token(Token::RParen)?;
    Ok(DataType::Time32(time_unit))
}
Expand All @@ -343,6 +362,15 @@ impl<'a> Parser<'a> {
fn parse_time64(&mut self) -> ArrowResult<DataType> {
    self.expect_token(Token::LParen)?;
    let time_unit = self.parse_time_unit("Time64")?;
    // Only microsecond and nanosecond units are accepted for Time64. The
    // check runs before the closing parenthesis is consumed, so an
    // unsupported unit is reported ahead of any later syntax problem.
    if matches!(time_unit, TimeUnit::Second | TimeUnit::Millisecond) {
        let message = format!("Time64 time unit must be 'µs' or 'ns', got '{time_unit}'");
        return Err(make_error(self.val, &message));
    }
    self.expect_token(Token::RParen)?;
    Ok(DataType::Time64(time_unit))
}
Expand Down Expand Up @@ -385,13 +413,40 @@ impl<'a> Parser<'a> {
Ok(DataType::FixedSizeBinary(length))
}

/// Checks that a parsed `precision` / `scale` pair is acceptable for the
/// decimal type named `type_name`.
///
/// `max_precision` is the inclusive upper bound for `precision`;
/// `type_name` is only used to build the error text.
///
/// # Errors
///
/// Returns an error when `precision` is zero or exceeds `max_precision`,
/// or when a positive `scale` is larger than `precision`. Zero and
/// negative scales pass this check.
fn validate_decimal(
    &self,
    precision: u8,
    scale: i8,
    type_name: &str,
    max_precision: u8,
) -> ArrowResult<()> {
    // The precision range is checked first so that its message wins when
    // both constraints are violated.
    let message = if !(1..=max_precision).contains(&precision) {
        format!("{type_name} precision must be in range [1, {max_precision}], got '{precision}'")
    } else if matches!(u8::try_from(scale), Ok(unsigned_scale) if unsigned_scale > precision) {
        // `try_from` fails for negative scales, which are never compared
        // against the precision.
        format!("{type_name} scale '{scale}' cannot be greater than precision '{precision}'")
    } else {
        return Ok(());
    };
    Err(make_error(self.val, &message))
}

/// Parses the `(precision, scale)` arguments of a Decimal32 (called after
/// `Decimal32` has been consumed) and validates them before building the type
fn parse_decimal_32(&mut self) -> ArrowResult<DataType> {
    // Single source for the name passed to the number parsers and the validator.
    const TYPE_NAME: &str = "Decimal32";

    self.expect_token(Token::LParen)?;
    let parsed_precision = self.parse_u8(TYPE_NAME)?;
    self.expect_token(Token::Comma)?;
    let parsed_scale = self.parse_i8(TYPE_NAME)?;
    self.expect_token(Token::RParen)?;

    // Validation happens only once the full argument list has been consumed.
    self.validate_decimal(
        parsed_precision,
        parsed_scale,
        TYPE_NAME,
        DECIMAL32_MAX_PRECISION,
    )
    .map(|()| DataType::Decimal32(parsed_precision, parsed_scale))
}

Expand All @@ -402,6 +457,7 @@ impl<'a> Parser<'a> {
self.expect_token(Token::Comma)?;
let scale = self.parse_i8("Decimal64")?;
self.expect_token(Token::RParen)?;
self.validate_decimal(precision, scale, "Decimal64", DECIMAL64_MAX_PRECISION)?;
Ok(DataType::Decimal64(precision, scale))
}

Expand All @@ -412,6 +468,7 @@ impl<'a> Parser<'a> {
self.expect_token(Token::Comma)?;
let scale = self.parse_i8("Decimal128")?;
self.expect_token(Token::RParen)?;
self.validate_decimal(precision, scale, "Decimal128", DECIMAL128_MAX_PRECISION)?;
Ok(DataType::Decimal128(precision, scale))
}

Expand All @@ -422,6 +479,7 @@ impl<'a> Parser<'a> {
self.expect_token(Token::Comma)?;
let scale = self.parse_i8("Decimal256")?;
self.expect_token(Token::RParen)?;
self.validate_decimal(precision, scale, "Decimal256", DECIMAL256_MAX_PRECISION)?;
Ok(DataType::Decimal256(precision, scale))
}

Expand Down Expand Up @@ -986,10 +1044,6 @@ mod test {
DataType::Date64,
DataType::Time32(TimeUnit::Second),
DataType::Time32(TimeUnit::Millisecond),
DataType::Time32(TimeUnit::Microsecond),
DataType::Time32(TimeUnit::Nanosecond),
DataType::Time64(TimeUnit::Second),
DataType::Time64(TimeUnit::Millisecond),
DataType::Time64(TimeUnit::Microsecond),
DataType::Time64(TimeUnit::Nanosecond),
DataType::Duration(TimeUnit::Second),
Expand All @@ -1007,10 +1061,10 @@ mod test {
DataType::Utf8,
DataType::Utf8View,
DataType::LargeUtf8,
DataType::Decimal32(7, 8),
DataType::Decimal64(6, 9),
DataType::Decimal128(7, 12),
DataType::Decimal256(6, 13),
DataType::Decimal32(7, 6),
DataType::Decimal64(6, 5),
DataType::Decimal128(7, 6),
DataType::Decimal256(6, 5),
// ---------
// Nested types
// ---------
Expand Down Expand Up @@ -1300,10 +1354,6 @@ mod test {
("Date64", Date64),
("Time32(s)", Time32(Second)),
("Time32(ms)", Time32(Millisecond)),
("Time32(µs)", Time32(Microsecond)),
("Time32(ns)", Time32(Nanosecond)),
("Time64(s)", Time64(Second)),
("Time64(ms)", Time64(Millisecond)),
("Time64(µs)", Time64(Microsecond)),
("Time64(ns)", Time64(Nanosecond)),
("Duration(s)", Duration(Second)),
Expand All @@ -1321,10 +1371,10 @@ mod test {
("Utf8", Utf8),
("Utf8View", Utf8View),
("LargeUtf8", LargeUtf8),
("Decimal32(7, 8)", Decimal32(7, 8)),
("Decimal64(6, 9)", Decimal64(6, 9)),
("Decimal128(7, 12)", Decimal128(7, 12)),
("Decimal256(6, 13)", Decimal256(6, 13)),
("Decimal32(7, 6)", Decimal32(7, 6)),
("Decimal64(6, 5)", Decimal64(6, 5)),
("Decimal128(7, 6)", Decimal128(7, 6)),
("Decimal256(6, 5)", Decimal256(6, 5)),
(
"Dictionary(Int32, Utf8)",
Dictionary(Box::new(Int32), Box::new(Utf8)),
Expand Down Expand Up @@ -1446,6 +1496,10 @@ mod test {
"FixedSizeBinary(-1), ",
"FixedSizeBinary length must be non-negative, got -1",
),
(
"FixedSizeList(-1, Int64), ",
"FixedSizeList length must be non-negative, got -1",
),
// can't have negative precision
(
"Decimal32(-3, 5)",
Expand Down Expand Up @@ -1485,6 +1539,74 @@ mod test {
"Struct(\"f1\": )",
"Error finding next type, got unexpected ')'",
),
// Invalid time combinations
(
"Time32(µs)",
"Error Time32 time unit must be 's' or 'ms', got 'µs'",
),
(
"Time32(ns)",
"Error Time32 time unit must be 's' or 'ms', got 'ns'",
),
(
"Time64(s)",
"Error Time64 time unit must be 'µs' or 'ns', got 's'",
),
(
"Time64(ms)",
"Error Time64 time unit must be 'µs' or 'ns', got 'ms'",
),
// Decimals can't have scale exceeding precision
(
"Decimal32(5, 6)",
"Error Decimal32 scale '6' cannot be greater than precision '5'",
),
(
"Decimal64(5, 6)",
"Error Decimal64 scale '6' cannot be greater than precision '5'",
),
(
"Decimal128(5, 6)",
"Error Decimal128 scale '6' cannot be greater than precision '5'",
),
(
"Decimal256(5, 6)",
"Error Decimal256 scale '6' cannot be greater than precision '5'",
),
// Decimals have a max supported precision
(
"Decimal32(10, 0)",
"Error Decimal32 precision must be in range [1, 9], got '10'",
),
(
"Decimal64(19, 0)",
"Error Decimal64 precision must be in range [1, 18], got '19'",
),
(
"Decimal128(39, 0)",
"Error Decimal128 precision must be in range [1, 38], got '39'",
),
(
"Decimal256(77, 0)",
"Error Decimal256 precision must be in range [1, 76], got '77'",
),
// Decimals precision can't be 0
(
"Decimal32(0, 0)",
"Error Decimal32 precision must be in range [1, 9], got '0'",
),
(
"Decimal64(0, 0)",
"Error Decimal64 precision must be in range [1, 18], got '0'",
),
(
"Decimal128(0, 0)",
"Error Decimal128 precision must be in range [1, 38], got '0'",
),
(
"Decimal256(0, 0)",
"Error Decimal256 precision must be in range [1, 76], got '0'",
),
];

for (data_type_string, expected_message) in cases {
Expand Down
Loading