diff --git a/crates/fluss/Cargo.toml b/crates/fluss/Cargo.toml index 8942ffc7..c3bdd447 100644 --- a/crates/fluss/Cargo.toml +++ b/crates/fluss/Cargo.toml @@ -48,7 +48,7 @@ tokio = { workspace = true } parking_lot = "0.12" bytes = "1.10.1" dashmap = "6.1.0" -rust_decimal = "1" +bigdecimal = { version = "0.4", features = ["serde"] } ordered-float = { version = "5", features = ["serde"] } parse-display = "0.10" ref-cast = "1.0" diff --git a/crates/fluss/src/metadata/datatype.rs b/crates/fluss/src/metadata/datatype.rs index f1574665..e3652370 100644 --- a/crates/fluss/src/metadata/datatype.rs +++ b/crates/fluss/src/metadata/datatype.rs @@ -453,16 +453,40 @@ impl DecimalType { pub const DEFAULT_SCALE: u32 = 0; - pub fn new(precision: u32, scale: u32) -> Self { + pub fn new(precision: u32, scale: u32) -> Result { Self::with_nullable(true, precision, scale) } - pub fn with_nullable(nullable: bool, precision: u32, scale: u32) -> Self { - DecimalType { + /// Create a DecimalType with validation, returning an error if parameters are invalid. + pub fn with_nullable(nullable: bool, precision: u32, scale: u32) -> Result { + // Validate precision + if !(Self::MIN_PRECISION..=Self::MAX_PRECISION).contains(&precision) { + return Err(IllegalArgument { + message: format!( + "Decimal precision must be between {} and {} (both inclusive), got: {}", + Self::MIN_PRECISION, + Self::MAX_PRECISION, + precision + ), + }); + } + // Validate scale + // Note: MIN_SCALE is 0, and scale is u32, so scale >= MIN_SCALE is always true + if scale > precision { + return Err(IllegalArgument { + message: format!( + "Decimal scale must be between {} and the precision {} (both inclusive), got: {}", + Self::MIN_SCALE, + precision, + scale + ), + }); + } + Ok(DecimalType { nullable, precision, scale, - } + }) } pub fn precision(&self) -> u32 { @@ -475,6 +499,7 @@ impl DecimalType { pub fn as_non_nullable(&self) -> Self { Self::with_nullable(false, self.precision, self.scale) + .expect("Invalid decimal precision or scale") } } @@ -531,7 +556,7 @@ pub struct TimeType { impl TimeType { fn default() -> Self { - Self::new(Self::DEFAULT_PRECISION) + Self::new(Self::DEFAULT_PRECISION).expect("Invalid default time precision") } } @@ -542,15 +567,27 @@ impl TimeType { pub const DEFAULT_PRECISION: u32 = 0; - pub fn new(precision: u32) -> Self { + pub fn new(precision: u32) -> Result { Self::with_nullable(true, precision) } - pub fn with_nullable(nullable: bool, precision: u32) -> Self { - TimeType { + /// Create a TimeType with validation, returning an error if precision is invalid. + pub fn with_nullable(nullable: bool, precision: u32) -> Result { + // Validate precision + if !(Self::MIN_PRECISION..=Self::MAX_PRECISION).contains(&precision) { + return Err(IllegalArgument { + message: format!( + "Time precision must be between {} and {} (both inclusive), got: {}", + Self::MIN_PRECISION, + Self::MAX_PRECISION, + precision + ), + }); + } + Ok(TimeType { nullable, precision, - } + }) } pub fn precision(&self) -> u32 { @@ -558,7 +595,7 @@ impl TimeType { } pub fn as_non_nullable(&self) -> Self { - Self::with_nullable(false, self.precision) + Self::with_nullable(false, self.precision).expect("Invalid time precision") } } @@ -580,7 +617,7 @@ pub struct TimestampType { impl Default for TimestampType { fn default() -> Self { - Self::new(Self::DEFAULT_PRECISION) + Self::new(Self::DEFAULT_PRECISION).expect("Invalid default timestamp precision") } } @@ -591,15 +628,27 @@ impl TimestampType { pub const DEFAULT_PRECISION: u32 = 6; - pub fn new(precision: u32) -> Self { + pub fn new(precision: u32) -> Result { Self::with_nullable(true, precision) } - pub fn with_nullable(nullable: bool, precision: u32) -> Self { - TimestampType { + /// Create a TimestampType with validation, returning an error if precision is invalid. + pub fn with_nullable(nullable: bool, precision: u32) -> Result { + // Validate precision + if !(Self::MIN_PRECISION..=Self::MAX_PRECISION).contains(&precision) { + return Err(IllegalArgument { + message: format!( + "Timestamp precision must be between {} and {} (both inclusive), got: {}", + Self::MIN_PRECISION, + Self::MAX_PRECISION, + precision + ), + }); + } + Ok(TimestampType { nullable, precision, - } + }) } pub fn precision(&self) -> u32 { @@ -607,7 +656,7 @@ impl TimestampType { } pub fn as_non_nullable(&self) -> Self { - Self::with_nullable(false, self.precision) + Self::with_nullable(false, self.precision).expect("Invalid timestamp precision") } } @@ -630,6 +679,7 @@ pub struct TimestampLTzType { impl Default for TimestampLTzType { fn default() -> Self { Self::new(Self::DEFAULT_PRECISION) + .expect("Invalid default timestamp with local time zone precision") } } @@ -640,15 +690,27 @@ impl TimestampLTzType { pub const DEFAULT_PRECISION: u32 = 6; - pub fn new(precision: u32) -> Self { + pub fn new(precision: u32) -> Result { Self::with_nullable(true, precision) } - pub fn with_nullable(nullable: bool, precision: u32) -> Self { - TimestampLTzType { + /// Create a TimestampLTzType with validation, returning an error if precision is invalid. + pub fn with_nullable(nullable: bool, precision: u32) -> Result { + // Validate precision + if !(Self::MIN_PRECISION..=Self::MAX_PRECISION).contains(&precision) { + return Err(IllegalArgument { + message: format!( + "Timestamp with local time zone precision must be between {} and {} (both inclusive), got: {}", + Self::MIN_PRECISION, + Self::MAX_PRECISION, + precision + ), + }); + } + Ok(TimestampLTzType { nullable, precision, - } + }) } pub fn precision(&self) -> u32 { @@ -657,6 +719,7 @@ impl TimestampLTzType { pub fn as_non_nullable(&self) -> Self { Self::with_nullable(false, self.precision) + .expect("Invalid timestamp with local time zone precision") } } @@ -985,7 +1048,7 @@ impl DataTypes { /// digits to the right of the decimal point in a number (=scale). `p` must have a value /// between 1 and 38 (both inclusive). `s` must have a value between 0 and `p` (both inclusive). pub fn decimal(precision: u32, scale: u32) -> DataType { - DataType::Decimal(DecimalType::new(precision, scale)) + DataType::Decimal(DecimalType::new(precision, scale).expect("Invalid decimal parameters")) } pub fn date() -> DataType { @@ -1000,7 +1063,7 @@ impl DataTypes { /// Data type of a time WITHOUT time zone `TIME(p)` where `p` is the number of digits /// of fractional seconds (=precision). `p` must have a value between 0 and 9 (both inclusive). pub fn time_with_precision(precision: u32) -> DataType { - DataType::Time(TimeType::new(precision)) + DataType::Time(TimeType::new(precision).expect("Invalid time precision")) } /// Data type of a timestamp WITHOUT time zone `TIMESTAMP` with 6 digits of fractional @@ -1013,7 +1076,7 @@ impl DataTypes { /// of digits of fractional seconds (=precision). `p` must have a value between 0 and 9 /// (both inclusive). pub fn timestamp_with_precision(precision: u32) -> DataType { - DataType::Timestamp(TimestampType::new(precision)) + DataType::Timestamp(TimestampType::new(precision).expect("Invalid timestamp precision")) } /// Data type of a timestamp WITH time zone `TIMESTAMP WITH TIME ZONE` with 6 digits of @@ -1025,7 +1088,10 @@ impl DataTypes { /// Data type of a timestamp WITH time zone `TIMESTAMP WITH TIME ZONE(p)` where `p` is the number /// of digits of fractional seconds (=precision). `p` must have a value between 0 and 9 (both inclusive). pub fn timestamp_ltz_with_precision(precision: u32) -> DataType { - DataType::TimestampLTz(TimestampLTzType::new(precision)) + DataType::TimestampLTz( + TimestampLTzType::new(precision) + .expect("Invalid timestamp with local time zone precision"), + ) } /// Data type of an array of elements with same subtype. @@ -1100,82 +1166,56 @@ impl Display for DataField { } #[test] -fn test_boolean_display() { +fn test_primitive_types_display() { + // Test simple primitive types with nullable and non-nullable variants assert_eq!(BooleanType::new().to_string(), "BOOLEAN"); assert_eq!( BooleanType::with_nullable(false).to_string(), "BOOLEAN NOT NULL" ); -} -#[test] -fn test_tinyint_display() { assert_eq!(TinyIntType::new().to_string(), "TINYINT"); assert_eq!( TinyIntType::with_nullable(false).to_string(), "TINYINT NOT NULL" ); -} -#[test] -fn test_smallint_display() { assert_eq!(SmallIntType::new().to_string(), "SMALLINT"); assert_eq!( SmallIntType::with_nullable(false).to_string(), "SMALLINT NOT NULL" ); -} -#[test] -fn test_int_display() { assert_eq!(IntType::new().to_string(), "INT"); assert_eq!(IntType::with_nullable(false).to_string(), "INT NOT NULL"); -} -#[test] -fn test_bigint_display() { assert_eq!(BigIntType::new().to_string(), "BIGINT"); assert_eq!( BigIntType::with_nullable(false).to_string(), "BIGINT NOT NULL" ); -} -#[test] -fn test_float_display() { assert_eq!(FloatType::new().to_string(), "FLOAT"); assert_eq!( FloatType::with_nullable(false).to_string(), "FLOAT NOT NULL" ); -} -#[test] -fn test_double_display() { assert_eq!(DoubleType::new().to_string(), "DOUBLE"); assert_eq!( DoubleType::with_nullable(false).to_string(), "DOUBLE NOT NULL" ); -} -#[test] -fn test_string_display() { assert_eq!(StringType::new().to_string(), "STRING"); assert_eq!( StringType::with_nullable(false).to_string(), "STRING NOT NULL" ); -} -#[test] -fn test_date_display() { assert_eq!(DateType::new().to_string(), "DATE"); assert_eq!(DateType::with_nullable(false).to_string(), "DATE NOT NULL"); -} -#[test] -fn test_bytes_display() { assert_eq!(BytesType::new().to_string(), "BYTES"); assert_eq!( BytesType::with_nullable(false).to_string(), @@ -1184,59 +1224,58 @@ fn test_bytes_display() { } #[test] -fn test_char_display() { +fn test_parameterized_types_display() { + // Test types with parameters (length, precision, scale, etc.) assert_eq!(CharType::new(10).to_string(), "CHAR(10)"); assert_eq!( CharType::with_nullable(20, false).to_string(), "CHAR(20) NOT NULL" ); -} -#[test] -fn test_decimal_display() { - assert_eq!(DecimalType::new(10, 2).to_string(), "DECIMAL(10, 2)"); + assert_eq!(BinaryType::new(100).to_string(), "BINARY(100)"); + assert_eq!( + BinaryType::with_nullable(false, 256).to_string(), + "BINARY(256) NOT NULL" + ); + assert_eq!( - DecimalType::with_nullable(false, 38, 10).to_string(), + DecimalType::new(10, 2).unwrap().to_string(), + "DECIMAL(10, 2)" + ); + assert_eq!( + DecimalType::with_nullable(false, 38, 10) + .unwrap() + .to_string(), "DECIMAL(38, 10) NOT NULL" ); -} -#[test] -fn test_time_display() { - assert_eq!(TimeType::new(0).to_string(), "TIME(0)"); - assert_eq!(TimeType::new(3).to_string(), "TIME(3)"); + assert_eq!(TimeType::new(0).unwrap().to_string(), "TIME(0)"); + assert_eq!(TimeType::new(3).unwrap().to_string(), "TIME(3)"); assert_eq!( - TimeType::with_nullable(false, 9).to_string(), + TimeType::with_nullable(false, 9).unwrap().to_string(), "TIME(9) NOT NULL" ); -} -#[test] -fn test_timestamp_display() { - assert_eq!(TimestampType::new(6).to_string(), "TIMESTAMP(6)"); - assert_eq!(TimestampType::new(0).to_string(), "TIMESTAMP(0)"); + assert_eq!(TimestampType::new(6).unwrap().to_string(), "TIMESTAMP(6)"); + assert_eq!(TimestampType::new(0).unwrap().to_string(), "TIMESTAMP(0)"); assert_eq!( - TimestampType::with_nullable(false, 9).to_string(), + TimestampType::with_nullable(false, 9).unwrap().to_string(), "TIMESTAMP(9) NOT NULL" ); -} -#[test] -fn test_timestamp_ltz_display() { - assert_eq!(TimestampLTzType::new(6).to_string(), "TIMESTAMP_LTZ(6)"); - assert_eq!(TimestampLTzType::new(3).to_string(), "TIMESTAMP_LTZ(3)"); assert_eq!( - TimestampLTzType::with_nullable(false, 9).to_string(), - "TIMESTAMP_LTZ(9) NOT NULL" + TimestampLTzType::new(6).unwrap().to_string(), + "TIMESTAMP_LTZ(6)" ); -} - -#[test] -fn test_binary_display() { - assert_eq!(BinaryType::new(100).to_string(), "BINARY(100)"); assert_eq!( - BinaryType::with_nullable(false, 256).to_string(), - "BINARY(256) NOT NULL" + TimestampLTzType::new(3).unwrap().to_string(), + "TIMESTAMP_LTZ(3)" + ); + assert_eq!( + TimestampLTzType::with_nullable(false, 9) + .unwrap() + .to_string(), + "TIMESTAMP_LTZ(9) NOT NULL" ); } @@ -1352,3 +1391,68 @@ fn test_deeply_nested_types() { )); assert_eq!(nested.to_string(), "ARRAY>>"); } + +#[test] +fn test_decimal_invalid_precision() { + // DecimalType::with_nullable should return an error for invalid precision + let result = DecimalType::with_nullable(true, 50, 2); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Decimal precision must be between 1 and 38") + ); +} + +#[test] +fn test_decimal_invalid_scale() { + // DecimalType::with_nullable should return an error when scale > precision + let result = DecimalType::with_nullable(true, 10, 15); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Decimal scale must be between 0 and the precision 10") + ); +} + +#[test] +fn test_time_invalid_precision() { + // TimeType::with_nullable should return an error for invalid precision + let result = TimeType::with_nullable(true, 10); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Time precision must be between 0 and 9") + ); +} + +#[test] +fn test_timestamp_invalid_precision() { + // TimestampType::with_nullable should return an error for invalid precision + let result = TimestampType::with_nullable(true, 10); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Timestamp precision must be between 0 and 9") + ); +} + +#[test] +fn test_timestamp_ltz_invalid_precision() { + // TimestampLTzType::with_nullable should return an error for invalid precision + let result = TimestampLTzType::with_nullable(true, 10); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Timestamp with local time zone precision must be between 0 and 9") + ); +} diff --git a/crates/fluss/src/metadata/json_serde.rs b/crates/fluss/src/metadata/json_serde.rs index 7d94e194..faa5583b 100644 --- a/crates/fluss/src/metadata/json_serde.rs +++ b/crates/fluss/src/metadata/json_serde.rs @@ -202,7 +202,12 @@ impl JsonSerde for DataType { .get(Self::FIELD_NAME_SCALE) .and_then(|v| v.as_u64()) .unwrap_or(0) as u32; - DataTypes::decimal(precision, scale) + DataType::Decimal( + crate::metadata::datatype::DecimalType::with_nullable(true, precision, scale) + .map_err(|e| Error::JsonSerdeError { + message: format!("Invalid DECIMAL parameters: {}", e), + })?, + ) } "DATE" => DataTypes::date(), "TIME_WITHOUT_TIME_ZONE" => { @@ -210,21 +215,43 @@ impl JsonSerde for DataType { .get(Self::FIELD_NAME_PRECISION) .and_then(|v| v.as_u64()) .unwrap_or(0) as u32; - DataTypes::time_with_precision(precision) + DataType::Time( + crate::metadata::datatype::TimeType::with_nullable(true, precision).map_err( + |e| Error::JsonSerdeError { + message: format!("Invalid TIME_WITHOUT_TIME_ZONE precision: {}", e), + }, + )?, + ) } "TIMESTAMP_WITHOUT_TIME_ZONE" => { let precision = node .get(Self::FIELD_NAME_PRECISION) .and_then(|v| v.as_u64()) .unwrap_or(6) as u32; - DataTypes::timestamp_with_precision(precision) + DataType::Timestamp( + crate::metadata::datatype::TimestampType::with_nullable(true, precision) + .map_err(|e| Error::JsonSerdeError { + message: format!( + "Invalid TIMESTAMP_WITHOUT_TIME_ZONE precision: {}", + e + ), + })?, + ) } "TIMESTAMP_WITH_LOCAL_TIME_ZONE" => { let precision = node .get(Self::FIELD_NAME_PRECISION) .and_then(|v| v.as_u64()) .unwrap_or(6) as u32; - DataTypes::timestamp_ltz_with_precision(precision) + DataType::TimestampLTz( + crate::metadata::datatype::TimestampLTzType::with_nullable(true, precision) + .map_err(|e| Error::JsonSerdeError { + message: format!( + "Invalid TIMESTAMP_WITH_LOCAL_TIME_ZONE precision: {}", + e + ), + })?, + ) } "BYTES" => DataTypes::bytes(), "BINARY" => { @@ -689,4 +716,81 @@ mod tests { assert_eq!(dt, deserialized); } } + + #[test] + fn test_invalid_datatype_validation() { + use serde_json::json; + + // Invalid DECIMAL precision (> 38) + let invalid_decimal = json!({ + "type": "DECIMAL", + "precision": 50, + "scale": 2 + }); + let result = DataType::deserialize_json(&invalid_decimal); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid DECIMAL parameters") + ); + + // Invalid TIME precision (> 9) + let invalid_time = json!({ + "type": "TIME_WITHOUT_TIME_ZONE", + "precision": 15 + }); + let result = DataType::deserialize_json(&invalid_time); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid TIME_WITHOUT_TIME_ZONE precision") + ); + + // Invalid TIMESTAMP precision (> 9) + let invalid_timestamp = json!({ + "type": "TIMESTAMP_WITHOUT_TIME_ZONE", + "precision": 20 + }); + let result = DataType::deserialize_json(&invalid_timestamp); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid TIMESTAMP_WITHOUT_TIME_ZONE precision") + ); + + // Invalid TIMESTAMP_LTZ precision (> 9) + let invalid_timestamp_ltz = json!({ + "type": "TIMESTAMP_WITH_LOCAL_TIME_ZONE", + "precision": 10 + }); + let result = DataType::deserialize_json(&invalid_timestamp_ltz); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid TIMESTAMP_WITH_LOCAL_TIME_ZONE precision") + ); + + // Invalid DECIMAL scale (> precision) + let invalid_decimal_scale = json!({ + "type": "DECIMAL", + "precision": 10, + "scale": 15 + }); + let result = DataType::deserialize_json(&invalid_decimal_scale); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid DECIMAL parameters") + ); + } } diff --git a/crates/fluss/src/record/arrow.rs b/crates/fluss/src/record/arrow.rs index b331ae9d..bd1eee5f 100644 --- a/crates/fluss/src/record/arrow.rs +++ b/crates/fluss/src/record/arrow.rs @@ -1056,8 +1056,7 @@ pub struct MyVec(pub StreamReader); mod tests { use super::*; use crate::error::Error; - use crate::metadata::DataField; - use crate::metadata::DataTypes; + use crate::metadata::{DataField, DataTypes}; #[test] fn test_to_array_type() { @@ -1161,24 +1160,6 @@ mod tests { ); } - #[test] - #[should_panic(expected = "Invalid precision value for TimeType: 10")] - fn test_time_invalid_precision() { - to_arrow_type(&DataTypes::time_with_precision(10)); - } - - #[test] - #[should_panic(expected = "Invalid precision value for TimestampType: 10")] - fn test_timestamp_invalid_precision() { - to_arrow_type(&DataTypes::timestamp_with_precision(10)); - } - - #[test] - #[should_panic(expected = "Invalid precision value for TimestampLTzType: 10")] - fn test_timestamp_ltz_invalid_precision() { - to_arrow_type(&DataTypes::timestamp_ltz_with_precision(10)); - } - #[test] fn test_parse_ipc_message() { let empty_body: &[u8] = &le_bytes(&[0xFFFFFFFF, 0x00000000]); diff --git a/crates/fluss/src/row/binary/binary_writer.rs b/crates/fluss/src/row/binary/binary_writer.rs index 9917c7b7..af2765c4 100644 --- a/crates/fluss/src/row/binary/binary_writer.rs +++ b/crates/fluss/src/row/binary/binary_writer.rs @@ -52,14 +52,20 @@ pub trait BinaryWriter { fn write_binary(&mut self, bytes: &[u8], length: usize); - // TODO Decimal type - // fn write_decimal(&mut self, pos: i32, value: f64); + fn write_decimal(&mut self, value: &crate::row::Decimal, precision: u32); - // TODO Timestamp type - // fn write_timestamp_ntz(&mut self, pos: i32, value: i64); + /// Writes a TIME value. + /// + /// Note: TIME is physically stored as an i32 (milliseconds since midnight). + /// This method exists for type safety and semantic clarity, even though it's + /// currently equivalent to `write_int()`. The precision parameter is accepted + /// for API consistency with TIMESTAMP types, though TIME encoding doesn't + /// currently vary by precision. + fn write_time(&mut self, value: i32, precision: u32); - // TODO Timestamp type - // fn write_timestamp_ltz(&mut self, pos: i32, value: i64); + fn write_timestamp_ntz(&mut self, value: &crate::row::datum::TimestampNtz, precision: u32); + + fn write_timestamp_ltz(&mut self, value: &crate::row::datum::TimestampLtz, precision: u32); // TODO InternalArray, ArraySerializer // fn write_array(&mut self, pos: i32, value: i64); @@ -125,7 +131,12 @@ pub enum InnerValueWriter { BigInt, Float, Double, - // TODO Decimal, Date, TimeWithoutTimeZone, TimestampWithoutTimeZone, TimestampWithLocalTimeZone, Array, Row + Decimal(u32, u32), // precision, scale + Date, + Time(u32), // precision (not used in wire format, but kept for consistency) + TimestampNtz(u32), // precision + TimestampLtz(u32), // precision + // TODO Array, Row } /// Accessor for writing the fields/elements of a binary writer during runtime, the @@ -147,6 +158,23 @@ impl InnerValueWriter { DataType::BigInt(_) => Ok(InnerValueWriter::BigInt), DataType::Float(_) => Ok(InnerValueWriter::Float), DataType::Double(_) => Ok(InnerValueWriter::Double), + DataType::Decimal(d) => { + // Validation is done at DecimalType construction time + Ok(InnerValueWriter::Decimal(d.precision(), d.scale())) + } + DataType::Date(_) => Ok(InnerValueWriter::Date), + DataType::Time(t) => { + // Validation is done at TimeType construction time + Ok(InnerValueWriter::Time(t.precision())) + } + DataType::Timestamp(t) => { + // Validation is done at TimestampType construction time + Ok(InnerValueWriter::TimestampNtz(t.precision())) + } + DataType::TimestampLTz(t) => { + // Validation is done at TimestampLTzType construction time + Ok(InnerValueWriter::TimestampLtz(t.precision())) + } _ => unimplemented!( "ValueWriter for DataType {:?} is currently not implemented", data_type @@ -194,6 +222,21 @@ impl InnerValueWriter { (InnerValueWriter::Double, Datum::Float64(v)) => { writer.write_double(v.into_inner()); } + (InnerValueWriter::Decimal(p, _s), Datum::Decimal(v)) => { + writer.write_decimal(v, *p); + } + (InnerValueWriter::Date, Datum::Date(d)) => { + writer.write_int(d.get_inner()); + } + (InnerValueWriter::Time(p), Datum::Time(t)) => { + writer.write_time(t.get_inner(), *p); + } + (InnerValueWriter::TimestampNtz(p), Datum::TimestampNtz(ts)) => { + writer.write_timestamp_ntz(ts, *p); + } + (InnerValueWriter::TimestampLtz(p), Datum::TimestampLtz(ts)) => { + writer.write_timestamp_ltz(ts, *p); + } _ => { return Err(IllegalArgument { message: format!("{self:?} used to write value {value:?}"), diff --git a/crates/fluss/src/row/column.rs b/crates/fluss/src/row/column.rs index 90437c11..615e0384 100644 --- a/crates/fluss/src/row/column.rs +++ b/crates/fluss/src/row/column.rs @@ -17,9 +17,10 @@ use crate::row::InternalRow; use arrow::array::{ - AsArray, BinaryArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int8Array, Int16Array, - Int32Array, Int64Array, RecordBatch, StringArray, + Array, AsArray, BinaryArray, Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, + Int8Array, Int16Array, Int32Array, Int64Array, RecordBatch, StringArray, }; +use arrow::datatypes::{DataType as ArrowDataType, TimeUnit}; use std::sync::Arc; #[derive(Clone)] @@ -54,6 +55,49 @@ impl ColumnarRow { pub fn get_record_batch(&self) -> &RecordBatch { &self.record_batch } + + /// Generic helper to read timestamp from Arrow, handling all TimeUnit conversions. + /// Like Java, the precision parameter is ignored - conversion is determined by Arrow TimeUnit. + fn read_timestamp_from_arrow( + &self, + pos: usize, + _precision: u32, + construct_compact: impl FnOnce(i64) -> T, + construct_with_nanos: impl FnOnce(i64, i32) -> crate::error::Result, + ) -> T { + let schema = self.record_batch.schema(); + let arrow_field = schema.field(pos); + let value = self.get_long(pos); + + match arrow_field.data_type() { + ArrowDataType::Timestamp(time_unit, _) => { + // Convert based on Arrow TimeUnit + let (millis, nanos) = match time_unit { + TimeUnit::Second => (value * 1000, 0), + TimeUnit::Millisecond => (value, 0), + TimeUnit::Microsecond => { + let millis = value / 1000; + let nanos = ((value % 1000) * 1000) as i32; + (millis, nanos) + } + TimeUnit::Nanosecond => { + let millis = value / 1_000_000; + let nanos = (value % 1_000_000) as i32; + (millis, nanos) + } + }; + + if nanos == 0 { + construct_compact(millis) + } else { + // nanos is guaranteed to be in valid range [0, 999_999] by arithmetic + construct_with_nanos(millis, nanos) + .expect("nanos in valid range by construction") + } + } + other => panic!("Expected Timestamp column at position {pos}, got {other:?}"), + } + } } impl InternalRow for ColumnarRow { @@ -126,6 +170,88 @@ impl InternalRow for ColumnarRow { .value(self.row_id) } + fn get_decimal(&self, pos: usize, precision: usize, scale: usize) -> crate::row::Decimal { + use arrow::datatypes::DataType; + + let column = self.record_batch.column(pos); + let array = column + .as_any() + .downcast_ref::() + .unwrap_or_else(|| { + panic!( + "Expected Decimal128Array at column {}, found: {:?}", + pos, + column.data_type() + ) + }); + + // Contract: caller must check is_null_at() before calling get_decimal. + // Calling on null value violates the contract and returns garbage data + debug_assert!( + !array.is_null(self.row_id), + "get_decimal called on null value at pos {} row {}", + pos, + self.row_id + ); + + // Read scale from Arrow schema field metadata + let schema = self.record_batch.schema(); + let field = schema.field(pos); + let arrow_scale = match field.data_type() { + DataType::Decimal128(_p, s) => *s as i64, + dt => panic!( + "Expected Decimal128 data type at column {}, found: {:?}", + pos, dt + ), + }; + + let i128_val = array.value(self.row_id); + + // Convert Arrow Decimal128 to Fluss Decimal (handles rescaling and validation) + crate::row::Decimal::from_arrow_decimal128( + i128_val, + arrow_scale, + precision as u32, + scale as u32, + ) + .unwrap_or_else(|e| { + panic!( + "Failed to create Decimal at column {} row {}: {}", + pos, self.row_id, e + ) + }) + } + + fn get_date(&self, pos: usize) -> crate::row::datum::Date { + crate::row::datum::Date::new(self.get_int(pos)) + } + + fn get_time(&self, pos: usize) -> crate::row::datum::Time { + crate::row::datum::Time::new(self.get_int(pos)) + } + + fn get_timestamp_ntz(&self, pos: usize, precision: u32) -> crate::row::datum::TimestampNtz { + // Like Java's ArrowTimestampNtzColumnVector, we ignore the precision parameter + // and determine the conversion from the Arrow column's TimeUnit. + self.read_timestamp_from_arrow( + pos, + precision, + crate::row::datum::TimestampNtz::new, + crate::row::datum::TimestampNtz::from_millis_nanos, + ) + } + + fn get_timestamp_ltz(&self, pos: usize, precision: u32) -> crate::row::datum::TimestampLtz { + // Like Java's ArrowTimestampLtzColumnVector, we ignore the precision parameter + // and determine the conversion from the Arrow column's TimeUnit. + self.read_timestamp_from_arrow( + pos, + precision, + crate::row::datum::TimestampLtz::new, + crate::row::datum::TimestampLtz::from_millis_nanos, + ) + } + fn get_char(&self, pos: usize, _length: usize) -> &str { let array = self .record_batch @@ -229,4 +355,72 @@ mod tests { row.set_row_id(0); assert_eq!(row.get_row_id(), 0); } + + #[test] + fn columnar_row_reads_decimal() { + use arrow::datatypes::DataType; + use bigdecimal::{BigDecimal, num_bigint::BigInt}; + + // Test with Decimal128 + let schema = Arc::new(Schema::new(vec![ + Field::new("dec1", DataType::Decimal128(10, 2), false), + Field::new("dec2", DataType::Decimal128(20, 5), false), + Field::new("dec3", DataType::Decimal128(38, 10), false), + ])); + + // Create decimal values: 123.45, 12345.67890, large decimal + let dec1_val = 12345i128; // 123.45 with scale 2 + let dec2_val = 1234567890i128; // 12345.67890 with scale 5 + let dec3_val = 999999999999999999i128; // Large value (18 nines) with scale 10 + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new( + Decimal128Array::from(vec![dec1_val]) + .with_precision_and_scale(10, 2) + .unwrap(), + ), + Arc::new( + Decimal128Array::from(vec![dec2_val]) + .with_precision_and_scale(20, 5) + .unwrap(), + ), + Arc::new( + Decimal128Array::from(vec![dec3_val]) + .with_precision_and_scale(38, 10) + .unwrap(), + ), + ], + ) + .expect("record batch"); + + let row = ColumnarRow::new(Arc::new(batch)); + assert_eq!(row.get_field_count(), 3); + + // Verify decimal values + assert_eq!( + row.get_decimal(0, 10, 2), + crate::row::Decimal::from_big_decimal(BigDecimal::new(BigInt::from(12345), 2), 10, 2) + .unwrap() + ); + assert_eq!( + row.get_decimal(1, 20, 5), + crate::row::Decimal::from_big_decimal( + BigDecimal::new(BigInt::from(1234567890), 5), + 20, + 5 + ) + .unwrap() + ); + assert_eq!( + row.get_decimal(2, 38, 10), + crate::row::Decimal::from_big_decimal( + BigDecimal::new(BigInt::from(999999999999999999i128), 10), + 38, + 10 + ) + .unwrap() + ); + } } diff --git a/crates/fluss/src/row/compacted/compacted_key_writer.rs b/crates/fluss/src/row/compacted/compacted_key_writer.rs index 1152b0c5..339e3661 100644 --- a/crates/fluss/src/row/compacted/compacted_key_writer.rs +++ b/crates/fluss/src/row/compacted/compacted_key_writer.rs @@ -20,6 +20,7 @@ use bytes::Bytes; use crate::error::Result; use crate::metadata::DataType; +use crate::row::Decimal; use crate::row::binary::{BinaryRowFormat, BinaryWriter, ValueWriter}; use delegate::delegate; @@ -93,7 +94,13 @@ impl BinaryWriter for CompactedKeyWriter { fn write_double(&mut self, value: f64); + fn write_decimal(&mut self, value: &Decimal, precision: u32); + fn write_time(&mut self, value: i32, precision: u32); + + fn write_timestamp_ntz(&mut self, value: &crate::row::datum::TimestampNtz, precision: u32); + + fn write_timestamp_ltz(&mut self, value: &crate::row::datum::TimestampLtz, precision: u32); } } diff --git a/crates/fluss/src/row/compacted/compacted_row.rs b/crates/fluss/src/row/compacted/compacted_row.rs index 144f8985..bc68ea10 100644 --- a/crates/fluss/src/row/compacted/compacted_row.rs +++ b/crates/fluss/src/row/compacted/compacted_row.rs @@ -133,6 +133,26 @@ impl<'a> InternalRow for CompactedRow<'a> { fn get_bytes(&self, pos: usize) -> &[u8] { self.decoded_row().get_bytes(pos) } + + fn get_decimal(&self, pos: usize, precision: usize, scale: usize) -> crate::row::Decimal { + self.decoded_row().get_decimal(pos, precision, scale) + } + + fn get_date(&self, pos: usize) -> crate::row::datum::Date { + self.decoded_row().get_date(pos) + } + + fn get_time(&self, pos: usize) -> crate::row::datum::Time { + self.decoded_row().get_time(pos) + } + + fn get_timestamp_ntz(&self, pos: usize, precision: u32) -> crate::row::datum::TimestampNtz { + self.decoded_row().get_timestamp_ntz(pos, precision) + } + + fn get_timestamp_ltz(&self, pos: usize, precision: u32) -> crate::row::datum::TimestampLtz { + self.decoded_row().get_timestamp_ltz(pos, precision) + } } #[cfg(test)] @@ -174,7 +194,7 @@ mod tests { writer.write_bytes(&[1, 2, 3, 4, 5]); let bytes = writer.to_bytes(); - let mut row = CompactedRow::from_bytes(&row_type, bytes.as_ref()); + let row = CompactedRow::from_bytes(&row_type, bytes.as_ref()); assert_eq!(row.get_field_count(), 9); assert!(row.get_boolean(0)); @@ -187,70 +207,107 @@ mod tests { assert_eq!(row.get_string(7), "Hello World"); assert_eq!(row.get_bytes(8), &[1, 2, 3, 4, 5]); - // Test with nulls - let row_type = RowType::with_data_types( - [ - DataType::Int(IntType::new()), - DataType::String(StringType::new()), - DataType::Double(DoubleType::new()), - ] - .to_vec(), - ); + // Test with nulls and negative values + let row_type = RowType::with_data_types(vec![ + DataType::Int(IntType::new()), + DataType::String(StringType::new()), + DataType::Double(DoubleType::new()), + ]); let mut writer = CompactedRowWriter::new(row_type.fields().len()); - - writer.write_int(100); + writer.write_int(-42); writer.set_null_at(1); writer.write_double(2.71); let bytes = writer.to_bytes(); - row = CompactedRow::from_bytes(&row_type, bytes.as_ref()); + let row = CompactedRow::from_bytes(&row_type, bytes.as_ref()); assert!(!row.is_null_at(0)); assert!(row.is_null_at(1)); assert!(!row.is_null_at(2)); - assert_eq!(row.get_int(0), 100); + assert_eq!(row.get_int(0), -42); assert_eq!(row.get_double(2), 2.71); + // Verify caching works on repeated reads + assert_eq!(row.get_int(0), -42); + } - // Test multiple reads (caching) - assert_eq!(row.get_int(0), 100); - assert_eq!(row.get_int(0), 100); + #[test] + fn test_compacted_row_temporal_and_decimal_types() { + // Comprehensive test covering DATE, TIME, TIMESTAMP (compact/non-compact), and DECIMAL (compact/non-compact) + use crate::metadata::{DataTypes, DecimalType, TimestampLTzType, TimestampType}; + use crate::row::Decimal; + use crate::row::datum::{TimestampLtz, TimestampNtz}; + use bigdecimal::{BigDecimal, num_bigint::BigInt}; - // Test from_bytes let row_type = RowType::with_data_types(vec![ - DataType::Int(IntType::new()), - DataType::String(StringType::new()), + DataTypes::date(), + DataTypes::time(), + DataType::Timestamp(TimestampType::with_nullable(true, 3).unwrap()), // Compact (precision <= 3) + DataType::TimestampLTz(TimestampLTzType::with_nullable(true, 3).unwrap()), // Compact + DataType::Timestamp(TimestampType::with_nullable(true, 6).unwrap()), // Non-compact (precision > 3) + DataType::TimestampLTz(TimestampLTzType::with_nullable(true, 9).unwrap()), // Non-compact + DataType::Decimal(DecimalType::new(10, 2).unwrap()), // Compact (precision <= 18) + DataType::Decimal(DecimalType::new(28, 10).unwrap()), // Non-compact (precision > 18) ]); let mut writer = CompactedRowWriter::new(row_type.fields().len()); - writer.write_int(-1); - writer.write_string("test"); - - let bytes = writer.to_bytes(); - let mut row = CompactedRow::from_bytes(&row_type, bytes.as_ref()); - - assert_eq!(row.get_int(0), -1); - assert_eq!(row.get_string(1), "test"); - // Test large row - let num_fields = 100; - let row_type = RowType::with_data_types( - (0..num_fields) - .map(|_| DataType::Int(IntType::new())) - .collect(), - ); - - let mut writer = CompactedRowWriter::new(num_fields); + // Write values + writer.write_int(19651); // Date: 2023-10-25 + writer.write_time(34200000, 0); // Time: 09:30:00.0 + writer.write_timestamp_ntz(&TimestampNtz::new(1698235273182), 3); // Compact timestamp + writer.write_timestamp_ltz(&TimestampLtz::new(1698235273182), 3); // Compact timestamp ltz + let ts_ntz_high = TimestampNtz::from_millis_nanos(1698235273182, 123456).unwrap(); + let ts_ltz_high = TimestampLtz::from_millis_nanos(1698235273182, 987654).unwrap(); + writer.write_timestamp_ntz(&ts_ntz_high, 6); // Non-compact timestamp with nanos + writer.write_timestamp_ltz(&ts_ltz_high, 9); // Non-compact timestamp ltz with nanos + + // Create Decimal values for testing + let small_decimal = + Decimal::from_big_decimal(BigDecimal::new(BigInt::from(12345), 2), 10, 2).unwrap(); // Compact decimal: 123.45 + let large_decimal = Decimal::from_big_decimal( + BigDecimal::new(BigInt::from(999999999999999999i128), 10), + 28, + 10, + ) + .unwrap(); // Non-compact decimal - for i in 0..num_fields { - writer.write_int((i * 10) as i32); - } + writer.write_decimal(&small_decimal, 10); + writer.write_decimal(&large_decimal, 28); let bytes = writer.to_bytes(); - row = CompactedRow::from_bytes(&row_type, bytes.as_ref()); - - for i in 0..num_fields { - assert_eq!(row.get_int(i), (i * 10) as i32); - } + let row = CompactedRow::from_bytes(&row_type, bytes.as_ref()); + + // Verify all values + assert_eq!(row.get_date(0).get_inner(), 19651); + assert_eq!(row.get_time(1).get_inner(), 34200000); + assert_eq!(row.get_timestamp_ntz(2, 3).get_millisecond(), 1698235273182); + assert_eq!( + row.get_timestamp_ltz(3, 3).get_epoch_millisecond(), + 1698235273182 + ); + let read_ts_ntz = row.get_timestamp_ntz(4, 6); + assert_eq!(read_ts_ntz.get_millisecond(), 1698235273182); + assert_eq!(read_ts_ntz.get_nano_of_millisecond(), 123456); + let read_ts_ltz = row.get_timestamp_ltz(5, 9); + assert_eq!(read_ts_ltz.get_epoch_millisecond(), 1698235273182); + assert_eq!(read_ts_ltz.get_nano_of_millisecond(), 987654); + // Assert on Decimal equality + assert_eq!(row.get_decimal(6, 10, 2), small_decimal); + assert_eq!(row.get_decimal(7, 28, 10), large_decimal); + + // Assert on Decimal components to catch any regressions + let read_small_decimal = row.get_decimal(6, 10, 2); + assert_eq!(read_small_decimal.precision(), 10); + assert_eq!(read_small_decimal.scale(), 2); + assert_eq!(read_small_decimal.to_unscaled_long().unwrap(), 12345); + + let read_large_decimal = row.get_decimal(7, 28, 10); + assert_eq!(read_large_decimal.precision(), 28); + assert_eq!(read_large_decimal.scale(), 10); + assert_eq!( + read_large_decimal.to_unscaled_long().unwrap(), + 999999999999999999i64 + ); } } diff --git a/crates/fluss/src/row/compacted/compacted_row_reader.rs b/crates/fluss/src/row/compacted/compacted_row_reader.rs index 408706cc..40470db1 100644 --- a/crates/fluss/src/row/compacted/compacted_row_reader.rs +++ b/crates/fluss/src/row/compacted/compacted_row_reader.rs @@ -19,7 +19,7 @@ use crate::metadata::RowType; use crate::row::compacted::compacted_row::calculate_bit_set_width_in_bytes; use crate::{ metadata::DataType, - row::{Datum, GenericRow, compacted::compacted_row_writer::CompactedRowWriter}, + row::{Datum, Decimal, GenericRow, compacted::compacted_row_writer::CompactedRowWriter}, util::varint::{read_unsigned_varint_at, read_unsigned_varint_u64_at}, }; use std::borrow::Cow; @@ -97,7 +97,75 @@ impl<'a> CompactedRowDeserializer<'a> { let (val, next) = reader.read_bytes(cursor); (Datum::Blob(val.into()), next) } - _ => panic!("unsupported DataType in CompactedRowDeserializer"), + DataType::Decimal(decimal_type) => { + let precision = decimal_type.precision(); + let scale = decimal_type.scale(); + if Decimal::is_compact_precision(precision) { + // Compact: stored as i64 + let (val, next) = reader.read_long(cursor); + let decimal = Decimal::from_unscaled_long(val, precision, scale) + .expect("Failed to create decimal from unscaled long"); + (Datum::Decimal(decimal), next) + } else { + // Non-compact: stored as minimal big-endian bytes + let (bytes, next) = reader.read_bytes(cursor); + let decimal = Decimal::from_unscaled_bytes(bytes, precision, scale) + .expect("Failed to create decimal from unscaled bytes"); + (Datum::Decimal(decimal), next) + } + } + DataType::Date(_) => { + let (val, next) = reader.read_int(cursor); + (Datum::Date(crate::row::datum::Date::new(val)), next) + } + DataType::Time(_) => { + let (val, next) = reader.read_int(cursor); + (Datum::Time(crate::row::datum::Time::new(val)), next) + } + DataType::Timestamp(timestamp_type) => { + let precision = timestamp_type.precision(); + if crate::row::datum::TimestampNtz::is_compact(precision) { + // Compact: only milliseconds + let (millis, next) = reader.read_long(cursor); + ( + Datum::TimestampNtz(crate::row::datum::TimestampNtz::new(millis)), + next, + ) + } else { + // Non-compact: milliseconds + nanos + let (millis, mid) = reader.read_long(cursor); + let (nanos, next) = reader.read_int(mid); + let timestamp = + crate::row::datum::TimestampNtz::from_millis_nanos(millis, nanos) + .expect("Invalid nano_of_millisecond value in compacted row"); + (Datum::TimestampNtz(timestamp), next) + } + } + DataType::TimestampLTz(timestamp_ltz_type) => { + let precision = timestamp_ltz_type.precision(); + if crate::row::datum::TimestampLtz::is_compact(precision) { + // Compact: only epoch milliseconds + let (epoch_millis, next) = reader.read_long(cursor); + ( + Datum::TimestampLtz(crate::row::datum::TimestampLtz::new(epoch_millis)), + next, + ) + } else { + // Non-compact: epoch milliseconds + nanos + let (epoch_millis, mid) = reader.read_long(cursor); + let (nanos, next) = reader.read_int(mid); + let timestamp_ltz = + crate::row::datum::TimestampLtz::from_millis_nanos(epoch_millis, nanos) + .expect("Invalid nano_of_millisecond value in compacted row"); + (Datum::TimestampLtz(timestamp_ltz), next) + } + } + _ => { + panic!( + "Unsupported DataType in CompactedRowDeserializer: {:?}", + dtype + ); + } }; cursor = next_cursor; row.set_field(col_pos, datum); diff --git a/crates/fluss/src/row/compacted/compacted_row_writer.rs b/crates/fluss/src/row/compacted/compacted_row_writer.rs index c130e94c..d1ad047a 100644 --- a/crates/fluss/src/row/compacted/compacted_row_writer.rs +++ b/crates/fluss/src/row/compacted/compacted_row_writer.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::row::Decimal; use crate::row::binary::BinaryWriter; use crate::row::compacted::compacted_row::calculate_bit_set_width_in_bytes; use crate::util::varint::{write_unsigned_varint_to_slice, write_unsigned_varint_u64_to_slice}; @@ -76,6 +77,7 @@ impl CompactedRowWriter { self.position = end; } } + impl BinaryWriter for CompactedRowWriter { fn reset(&mut self) { self.position = self.header_size_in_bytes; @@ -91,32 +93,34 @@ impl BinaryWriter for CompactedRowWriter { fn write_boolean(&mut self, value: bool) { let b = if value { 1u8 } else { 0u8 }; - self.write_raw(&[b]); + self.write_raw(&[b]) } fn write_byte(&mut self, value: u8) { - self.write_raw(&[value]); + self.write_raw(&[value]) } fn write_bytes(&mut self, value: &[u8]) { - let len_i32 = - i32::try_from(value.len()).expect("byte slice too large to encode length as i32"); + let len_i32 = i32::try_from(value.len()) + .expect("Byte slice too large to encode length as i32: exceeds i32::MAX"); self.write_int(len_i32); - self.write_raw(value); + self.write_raw(value) } fn write_char(&mut self, value: &str, _length: usize) { // TODO: currently, we encoding CHAR(length) as the same with STRING, the length info can be // omitted and the bytes length should be enforced in the future. - self.write_string(value); + self.write_string(value) } fn write_string(&mut self, value: &str) { - self.write_bytes(value.as_ref()); + self.write_bytes(value.as_ref()) } fn write_short(&mut self, value: i16) { - self.write_raw(&value.to_ne_bytes()); + // Use native endianness to match Java's UnsafeUtils.putShort behavior + // Java uses sun.misc.Unsafe which writes in native byte order (typically LE on x86/ARM) + self.write_raw(&value.to_ne_bytes()) } fn write_int(&mut self, value: i32) { @@ -132,21 +136,120 @@ impl BinaryWriter for CompactedRowWriter { write_unsigned_varint_u64_to_slice(value as u64, &mut self.buffer[self.position..]); self.position += bytes_written; } + fn write_float(&mut self, value: f32) { - self.write_raw(&value.to_ne_bytes()); + // Use native endianness to match Java's UnsafeUtils.putFloat behavior + self.write_raw(&value.to_ne_bytes()) } fn write_double(&mut self, value: f64) { - self.write_raw(&value.to_ne_bytes()); + // Use native endianness to match Java's UnsafeUtils.putDouble behavior + self.write_raw(&value.to_ne_bytes()) } fn write_binary(&mut self, bytes: &[u8], length: usize) { // TODO: currently, we encoding BINARY(length) as the same with BYTES, the length info can // be omitted and the bytes length should be enforced in the future. - self.write_bytes(&bytes[..length.min(bytes.len())]); + self.write_bytes(&bytes[..length.min(bytes.len())]) } fn complete(&mut self) { // do nothing } + + fn write_decimal(&mut self, value: &Decimal, precision: u32) { + // Decimal is already validated and rescaled during construction. + // Just serialize the precomputed unscaled representation. + if Decimal::is_compact_precision(precision) { + self.write_long( + value + .to_unscaled_long() + .expect("Decimal should fit in i64 for compact precision"), + ) + } else { + self.write_bytes(&value.to_unscaled_bytes()) + } + } + + fn write_time(&mut self, value: i32, _precision: u32) { + // TIME is always encoded as i32 (milliseconds since midnight) regardless of precision + self.write_int(value) + } + + fn write_timestamp_ntz(&mut self, value: &crate::row::datum::TimestampNtz, precision: u32) { + if crate::row::datum::TimestampNtz::is_compact(precision) { + // Compact: write only milliseconds + self.write_long(value.get_millisecond()); + } else { + // Non-compact: write milliseconds + nanoOfMillisecond + self.write_long(value.get_millisecond()); + self.write_int(value.get_nano_of_millisecond()); + } + } + + fn write_timestamp_ltz(&mut self, value: &crate::row::datum::TimestampLtz, precision: u32) { + if crate::row::datum::TimestampLtz::is_compact(precision) { + // Compact: write only epoch milliseconds + self.write_long(value.get_epoch_millisecond()); + } else { + // Non-compact: write epoch milliseconds + nanoOfMillisecond + self.write_long(value.get_epoch_millisecond()); + self.write_int(value.get_nano_of_millisecond()); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bigdecimal::{BigDecimal, num_bigint::BigInt}; + + #[test] + fn test_write_decimal_compact() { + // Compact decimal (precision <= 18) + let bd = BigDecimal::new(BigInt::from(12345), 2); // 123.45 + let decimal = Decimal::from_big_decimal(bd, 10, 2).unwrap(); + + let mut w = CompactedRowWriter::new(1); + w.write_decimal(&decimal, 10); + + let (val, _) = crate::util::varint::read_unsigned_varint_u64_at( + w.buffer(), + w.header_size_in_bytes, + CompactedRowWriter::MAX_LONG_SIZE, + ) + .unwrap(); + assert_eq!(val as i64, 12345); + } + + #[test] + fn test_write_decimal_rounding() { + // Test HALF_UP rounding: 12.345 → 12.35 + let bd = BigDecimal::new(BigInt::from(12345), 3); + let decimal = Decimal::from_big_decimal(bd, 10, 2).unwrap(); + + let mut w = CompactedRowWriter::new(1); + w.write_decimal(&decimal, 10); + + let (val, _) = crate::util::varint::read_unsigned_varint_u64_at( + w.buffer(), + w.header_size_in_bytes, + CompactedRowWriter::MAX_LONG_SIZE, + ) + .unwrap(); + assert_eq!(val as i64, 1235); // 12.35 with scale 2 + } + + #[test] + fn test_write_decimal_non_compact() { + // Non-compact (precision > 18): uses byte array + let bd = BigDecimal::new(BigInt::from(12345), 0); + let decimal = Decimal::from_big_decimal(bd, 28, 0).unwrap(); + + let mut w = CompactedRowWriter::new(1); + w.write_decimal(&decimal, 28); + + // Verify something was written (at least length varint + some bytes) + assert!(w.position() > w.header_size_in_bytes); + } } diff --git a/crates/fluss/src/row/datum.rs b/crates/fluss/src/row/datum.rs index ad7948dc..5b21b389 100644 --- a/crates/fluss/src/row/datum.rs +++ b/crates/fluss/src/row/datum.rs @@ -17,6 +17,7 @@ use crate::error::Error::RowConvertError; use crate::error::Result; +use crate::row::Decimal; use arrow::array::{ ArrayBuilder, BinaryBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, StringBuilder, @@ -24,7 +25,6 @@ use arrow::array::{ use jiff::ToSpan; use ordered_float::OrderedFloat; use parse_display::Display; -use rust_decimal::Decimal; use serde::Serialize; use std::borrow::Cow; @@ -58,9 +58,11 @@ pub enum Datum<'a> { #[display("{0}")] Date(Date), #[display("{0}")] - Timestamp(Timestamp), + Time(Time), #[display("{0}")] - TimestampTz(TimestampLtz), + TimestampNtz(TimestampNtz), + #[display("{0}")] + TimestampLtz(TimestampLtz), } impl Datum<'_> { @@ -296,7 +298,11 @@ impl Datum<'_> { Datum::Float64(v) => append_value_to_arrow!(Float64Builder, v.into_inner()), Datum::String(v) => append_value_to_arrow!(StringBuilder, v.as_ref()), Datum::Blob(v) => append_value_to_arrow!(BinaryBuilder, v.as_ref()), - Datum::Decimal(_) | Datum::Date(_) | Datum::Timestamp(_) | Datum::TimestampTz(_) => { + Datum::Decimal(_) + | Datum::Date(_) + | Datum::Time(_) + | Datum::TimestampNtz(_) + | Datum::TimestampLtz(_) => { return Err(RowConvertError { message: format!( "Type {:?} is not yet supported for Arrow conversion", @@ -350,10 +356,122 @@ pub type F64 = OrderedFloat; pub struct Date(i32); #[derive(PartialOrd, Ord, Display, PartialEq, Eq, Debug, Copy, Clone, Default, Hash, Serialize)] -pub struct Timestamp(i64); +pub struct Time(i32); + +impl Time { + pub const fn new(inner: i32) -> Self { + Time(inner) + } + + /// Get the inner value of time type (milliseconds since midnight) + pub fn get_inner(&self) -> i32 { + self.0 + } +} + +/// Maximum timestamp precision that can be stored compactly (milliseconds only). +/// Values with precision > MAX_COMPACT_TIMESTAMP_PRECISION require additional nanosecond storage. +pub const MAX_COMPACT_TIMESTAMP_PRECISION: u32 = 3; + +/// Maximum valid value for nanoseconds within a millisecond (0 to 999,999 inclusive). +/// A millisecond contains 1,000,000 nanoseconds, so the fractional part ranges from 0 to 999,999. +pub const MAX_NANO_OF_MILLISECOND: i32 = 999_999; #[derive(PartialOrd, Ord, Display, PartialEq, Eq, Debug, Copy, Clone, Default, Hash, Serialize)] -pub struct TimestampLtz(i64); +#[display("{millisecond}")] +pub struct TimestampNtz { + millisecond: i64, + nano_of_millisecond: i32, +} + +impl TimestampNtz { + pub const fn new(millisecond: i64) -> Self { + TimestampNtz { + millisecond, + nano_of_millisecond: 0, + } + } + + pub fn from_millis_nanos( + millisecond: i64, + nano_of_millisecond: i32, + ) -> crate::error::Result { + if !(0..=MAX_NANO_OF_MILLISECOND).contains(&nano_of_millisecond) { + return Err(crate::error::Error::IllegalArgument { + message: format!( + "nanoOfMillisecond must be in range [0, {}], got: {}", + MAX_NANO_OF_MILLISECOND, nano_of_millisecond + ), + }); + } + Ok(TimestampNtz { + millisecond, + nano_of_millisecond, + }) + } + + pub fn get_millisecond(&self) -> i64 { + self.millisecond + } + + pub fn get_nano_of_millisecond(&self) -> i32 { + self.nano_of_millisecond + } + + /// Check if the timestamp is compact based on precision. + /// Precision <= MAX_COMPACT_TIMESTAMP_PRECISION means millisecond precision, no need for nanos. + pub fn is_compact(precision: u32) -> bool { + precision <= MAX_COMPACT_TIMESTAMP_PRECISION + } +} + +#[derive(PartialOrd, Ord, Display, PartialEq, Eq, Debug, Copy, Clone, Default, Hash, Serialize)] +#[display("{epoch_millisecond}")] +pub struct TimestampLtz { + epoch_millisecond: i64, + nano_of_millisecond: i32, +} + +impl TimestampLtz { + pub const fn new(epoch_millisecond: i64) -> Self { + TimestampLtz { + epoch_millisecond, + nano_of_millisecond: 0, + } + } + + pub fn from_millis_nanos( + epoch_millisecond: i64, + nano_of_millisecond: i32, + ) -> crate::error::Result { + if !(0..=MAX_NANO_OF_MILLISECOND).contains(&nano_of_millisecond) { + return Err(crate::error::Error::IllegalArgument { + message: format!( + "nanoOfMillisecond must be in range [0, {}], got: {}", + MAX_NANO_OF_MILLISECOND, nano_of_millisecond + ), + }); + } + Ok(TimestampLtz { + epoch_millisecond, + nano_of_millisecond, + }) + } + + pub fn get_epoch_millisecond(&self) -> i64 { + self.epoch_millisecond + } + + pub fn get_nano_of_millisecond(&self) -> i32 { + self.nano_of_millisecond + } + + /// Check if the timestamp is compact based on precision. + /// Precision <= MAX_COMPACT_TIMESTAMP_PRECISION means millisecond precision, no need for nanos. + pub fn is_compact(precision: u32) -> bool { + precision <= MAX_COMPACT_TIMESTAMP_PRECISION + } +} pub type Blob<'a> = Cow<'a, [u8]>; @@ -461,3 +579,54 @@ mod tests { assert_eq!(date.day(), 1); } } + +#[cfg(test)] +mod timestamp_tests { + use super::*; + + #[test] + fn test_timestamp_valid_nanos() { + // Valid range: 0 to MAX_NANO_OF_MILLISECOND for both TimestampNtz and TimestampLtz + let ntz1 = TimestampNtz::from_millis_nanos(1000, 0).unwrap(); + assert_eq!(ntz1.get_nano_of_millisecond(), 0); + + let ntz2 = TimestampNtz::from_millis_nanos(1000, MAX_NANO_OF_MILLISECOND).unwrap(); + assert_eq!(ntz2.get_nano_of_millisecond(), MAX_NANO_OF_MILLISECOND); + + let ntz3 = TimestampNtz::from_millis_nanos(1000, 500_000).unwrap(); + assert_eq!(ntz3.get_nano_of_millisecond(), 500_000); + + let ltz1 = TimestampLtz::from_millis_nanos(1000, 0).unwrap(); + assert_eq!(ltz1.get_nano_of_millisecond(), 0); + + let ltz2 = TimestampLtz::from_millis_nanos(1000, MAX_NANO_OF_MILLISECOND).unwrap(); + assert_eq!(ltz2.get_nano_of_millisecond(), MAX_NANO_OF_MILLISECOND); + } + + #[test] + fn test_timestamp_nanos_out_of_range() { + // Test that both TimestampNtz and TimestampLtz reject invalid nanos + let expected_msg = format!( + "nanoOfMillisecond must be in range [0, {}]", + MAX_NANO_OF_MILLISECOND + ); + + // Too large (1,000,000 is just beyond the valid range) + let result_ntz = TimestampNtz::from_millis_nanos(1000, MAX_NANO_OF_MILLISECOND + 1); + assert!(result_ntz.is_err()); + assert!(result_ntz.unwrap_err().to_string().contains(&expected_msg)); + + let result_ltz = TimestampLtz::from_millis_nanos(1000, MAX_NANO_OF_MILLISECOND + 1); + assert!(result_ltz.is_err()); + assert!(result_ltz.unwrap_err().to_string().contains(&expected_msg)); + + // Negative + let result_ntz = TimestampNtz::from_millis_nanos(1000, -1); + assert!(result_ntz.is_err()); + assert!(result_ntz.unwrap_err().to_string().contains(&expected_msg)); + + let result_ltz = TimestampLtz::from_millis_nanos(1000, -1); + assert!(result_ltz.is_err()); + assert!(result_ltz.unwrap_err().to_string().contains(&expected_msg)); + } +} diff --git a/crates/fluss/src/row/decimal.rs b/crates/fluss/src/row/decimal.rs new file mode 100644 index 00000000..b14bde50 --- /dev/null +++ b/crates/fluss/src/row/decimal.rs @@ -0,0 +1,477 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{Error, Result}; +use bigdecimal::num_bigint::BigInt; +use bigdecimal::num_traits::Zero; +use bigdecimal::{BigDecimal, RoundingMode}; +use std::fmt; + +#[cfg(test)] +use std::str::FromStr; + +/// Maximum decimal precision that can be stored compactly as a single i64. +/// Values with precision > MAX_COMPACT_PRECISION require byte array storage. +pub const MAX_COMPACT_PRECISION: u32 = 18; + +/// An internal data structure representing a decimal value with fixed precision and scale. +/// +/// This data structure is immutable and stores decimal values in a compact representation +/// (as a long value) if values are small enough (precision ≤ 18). +/// +/// Matches Java's org.apache.fluss.row.Decimal class. +#[derive(Debug, Clone, serde::Serialize)] +pub struct Decimal { + precision: u32, + scale: u32, + // If precision <= MAX_COMPACT_PRECISION, this holds the unscaled value + long_val: Option, + // BigDecimal representation (may be cached) + decimal_val: Option, +} + +impl Decimal { + /// Returns the precision of this Decimal. + /// + /// The precision is the number of digits in the unscaled value. + pub fn precision(&self) -> u32 { + self.precision + } + + /// Returns the scale of this Decimal. + pub fn scale(&self) -> u32 { + self.scale + } + + /// Returns whether the decimal value is small enough to be stored in a long. + pub fn is_compact(&self) -> bool { + self.precision <= MAX_COMPACT_PRECISION + } + + /// Returns whether a given precision can be stored compactly. + pub fn is_compact_precision(precision: u32) -> bool { + precision <= MAX_COMPACT_PRECISION + } + + /// Converts this Decimal into a BigDecimal. + pub fn to_big_decimal(&self) -> BigDecimal { + if let Some(bd) = &self.decimal_val { + bd.clone() + } else if let Some(long_val) = self.long_val { + BigDecimal::new(BigInt::from(long_val), self.scale as i64) + } else { + // Should never happen - we always have one representation + BigDecimal::new(BigInt::from(0), self.scale as i64) + } + } + + /// Returns a long describing the unscaled value of this Decimal. + pub fn to_unscaled_long(&self) -> Result { + if let Some(long_val) = self.long_val { + Ok(long_val) + } else { + // Extract unscaled value from BigDecimal + let bd = self.to_big_decimal(); + let (unscaled, _) = bd.as_bigint_and_exponent(); + unscaled.try_into().map_err(|_| Error::IllegalArgument { + message: format!( + "Decimal unscaled value does not fit in i64: precision={}", + self.precision + ), + }) + } + } + + /// Returns a byte array describing the unscaled value of this Decimal. + pub fn to_unscaled_bytes(&self) -> Vec { + let bd = self.to_big_decimal(); + let (unscaled, _) = bd.as_bigint_and_exponent(); + unscaled.to_signed_bytes_be() + } + + /// Creates a Decimal from Arrow's Decimal128 representation. + // TODO: For compact decimals with matching scale we may call from_unscaled_long + pub fn from_arrow_decimal128( + i128_val: i128, + arrow_scale: i64, + precision: u32, + scale: u32, + ) -> Result { + let bd = BigDecimal::new(BigInt::from(i128_val), arrow_scale); + Self::from_big_decimal(bd, precision, scale) + } + + /// Creates an instance of Decimal from a BigDecimal with the given precision and scale. + /// + /// The returned decimal value may be rounded to have the desired scale. The precision + /// will be checked. If the precision overflows, an error is returned. + pub fn from_big_decimal(bd: BigDecimal, precision: u32, scale: u32) -> Result { + // Rescale to the target scale with HALF_UP rounding (matches Java) + let scaled = bd.with_scale_round(scale as i64, RoundingMode::HalfUp); + + // Extract unscaled value + let (unscaled, exp) = scaled.as_bigint_and_exponent(); + + // Sanity check that scale matches + debug_assert_eq!( + exp, scale as i64, + "Scaled decimal exponent ({}) != expected scale ({})", + exp, scale + ); + + let actual_precision = Self::compute_precision(&unscaled); + if actual_precision > precision as usize { + return Err(Error::IllegalArgument { + message: format!( + "Decimal precision overflow: value has {} digits but precision is {} (value: {})", + actual_precision, precision, scaled + ), + }); + } + + // Compute compact representation if possible + let long_val = if precision <= MAX_COMPACT_PRECISION { + Some(i64::try_from(&unscaled).map_err(|_| Error::IllegalArgument { + message: format!( + "Decimal mantissa exceeds i64 range for compact precision {}: unscaled={} (value={})", + precision, unscaled, scaled + ), + })?) + } else { + None + }; + + Ok(Decimal { + precision, + scale, + long_val, + decimal_val: Some(scaled), + }) + } + + /// Creates an instance of Decimal from an unscaled long value with the given precision and scale. + pub fn from_unscaled_long(unscaled_long: i64, precision: u32, scale: u32) -> Result { + if precision > MAX_COMPACT_PRECISION { + return Err(Error::IllegalArgument { + message: format!( + "Precision {} exceeds MAX_COMPACT_PRECISION ({})", + precision, MAX_COMPACT_PRECISION + ), + }); + } + + let actual_precision = Self::compute_precision(&BigInt::from(unscaled_long)); + if actual_precision > precision as usize { + return Err(Error::IllegalArgument { + message: format!( + "Decimal precision overflow: unscaled value has {} digits but precision is {}", + actual_precision, precision + ), + }); + } + + Ok(Decimal { + precision, + scale, + long_val: Some(unscaled_long), + decimal_val: None, + }) + } + + /// Creates an instance of Decimal from an unscaled byte array with the given precision and scale. + pub fn from_unscaled_bytes(unscaled_bytes: &[u8], precision: u32, scale: u32) -> Result { + let unscaled = BigInt::from_signed_bytes_be(unscaled_bytes); + let bd = BigDecimal::new(unscaled, scale as i64); + Self::from_big_decimal(bd, precision, scale) + } + + /// Computes the precision of a decimal's unscaled value, matching Java's BigDecimal.precision(). + pub fn compute_precision(unscaled: &BigInt) -> usize { + if unscaled.is_zero() { + return 1; + } + + // Count ALL digits in the unscaled value (matches Java's BigDecimal.precision()) + // For bounded precision (≤ 38 digits), string conversion is cheap and simple. + unscaled.magnitude().to_str_radix(10).len() + } +} + +impl fmt::Display for Decimal { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.to_big_decimal()) + } +} + +// Manual implementations of comparison traits to ignore cached fields +impl PartialEq for Decimal { + fn eq(&self, other: &Self) -> bool { + // Use numeric equality like Java's Decimal.equals() which delegates to compareTo. + // This means 1.0 (scale=1) equals 1.00 (scale=2). + self.cmp(other) == std::cmp::Ordering::Equal + } +} + +impl Eq for Decimal {} + +impl PartialOrd for Decimal { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Decimal { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // If both are compact and have the same scale, compare directly + if self.is_compact() && other.is_compact() && self.scale == other.scale { + self.long_val.cmp(&other.long_val) + } else { + // Otherwise, compare as BigDecimal + self.to_big_decimal().cmp(&other.to_big_decimal()) + } + } +} + +impl std::hash::Hash for Decimal { + fn hash(&self, state: &mut H) { + // Hash the BigDecimal representation. + // + // IMPORTANT: Unlike Java's BigDecimal, Rust's bigdecimal crate normalizes + // before hashing, so hash(1.0) == hash(1.00). Combined with our numeric + // equality (1.0 == 1.00), this CORRECTLY satisfies the hash/equals contract. + // + // This is BETTER than Java's implementation which has a hash/equals violation: + // - Java: equals(1.0, 1.00) = true, but hashCode(1.0) != hashCode(1.00) + // - Rust: equals(1.0, 1.00) = true, and hash(1.0) == hash(1.00) ✓ + // + // Result: HashMap/HashSet will work correctly even if you create Decimals + // with different scales for the same numeric value (though this is rare in + // practice since decimals are schema-driven with fixed precision/scale). + self.to_big_decimal().hash(state); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_precision_calculation() { + // Zero is special case + assert_eq!(Decimal::compute_precision(&BigInt::from(0)), 1); + + // Must count ALL digits including trailing zeros (matches Java BigDecimal.precision()) + assert_eq!(Decimal::compute_precision(&BigInt::from(10)), 2); + assert_eq!(Decimal::compute_precision(&BigInt::from(100)), 3); + assert_eq!(Decimal::compute_precision(&BigInt::from(12300)), 5); + assert_eq!( + Decimal::compute_precision(&BigInt::from(10000000000i64)), + 11 + ); + + // Test the case: value=1, scale=10 → unscaled=10000000000 (11 digits) + let bd = BigDecimal::new(BigInt::from(1), 0); + assert!( + Decimal::from_big_decimal(bd.clone(), 1, 10).is_err(), + "Should reject: unscaled 10000000000 has 11 digits, precision=1 is too small" + ); + assert!( + Decimal::from_big_decimal(bd, 11, 10).is_ok(), + "Should accept with correct precision=11" + ); + } + + /// Test precision validation boundaries + #[test] + fn test_precision_validation() { + let test_cases = vec![ + (10i64, 1, 2), // 1.0 → unscaled: 10 (2 digits) + (100i64, 2, 3), // 1.00 → unscaled: 100 (3 digits) + (10000000000i64, 10, 11), // 1.0000000000 → unscaled: 10000000000 (11 digits) + ]; + + for (unscaled, scale, min_precision) in test_cases { + let bd = BigDecimal::new(BigInt::from(unscaled), scale as i64); + + // Reject if precision too small + assert!(Decimal::from_big_decimal(bd.clone(), min_precision - 1, scale).is_err()); + // Accept with correct precision + assert!(Decimal::from_big_decimal(bd, min_precision, scale).is_ok()); + } + + // i64::MAX has 19 digits, should reject with precision=5 + let bd = BigDecimal::new(BigInt::from(i64::MAX), 0); + assert!(Decimal::from_big_decimal(bd, 5, 0).is_err()); + } + + /// Test creation and basic operations for both compact and non-compact decimals + #[test] + fn test_creation_and_representation() { + // Compact (precision ≤ 18): from unscaled long + let compact = Decimal::from_unscaled_long(12345, 10, 2).unwrap(); + assert_eq!(compact.precision(), 10); + assert_eq!(compact.scale(), 2); + assert!(compact.is_compact()); + assert_eq!(compact.to_unscaled_long().unwrap(), 12345); + assert_eq!(compact.to_big_decimal().to_string(), "123.45"); + + // Non-compact (precision > 18): from BigDecimal + let bd = BigDecimal::new(BigInt::from(12345), 0); + let non_compact = Decimal::from_big_decimal(bd, 28, 0).unwrap(); + assert_eq!(non_compact.precision(), 28); + assert!(!non_compact.is_compact()); + assert_eq!( + non_compact.to_unscaled_bytes(), + BigInt::from(12345).to_signed_bytes_be() + ); + + // Test compact boundary + assert!(Decimal::is_compact_precision(18)); + assert!(!Decimal::is_compact_precision(19)); + + // Test rounding during creation + let bd = BigDecimal::new(BigInt::from(12345), 3); // 12.345 + let rounded = Decimal::from_big_decimal(bd, 10, 2).unwrap(); + assert_eq!(rounded.to_unscaled_long().unwrap(), 1235); // 12.35 + } + + /// Test serialization round-trip (unscaled bytes) + #[test] + fn test_serialization_roundtrip() { + // Compact decimal + let bd1 = BigDecimal::new(BigInt::from(1314567890123i64), 5); // 13145678.90123 + let decimal1 = Decimal::from_big_decimal(bd1.clone(), 15, 5).unwrap(); + let (unscaled1, _) = bd1.as_bigint_and_exponent(); + let from_bytes1 = + Decimal::from_unscaled_bytes(&unscaled1.to_signed_bytes_be(), 15, 5).unwrap(); + assert_eq!(from_bytes1, decimal1); + assert_eq!( + from_bytes1.to_unscaled_bytes(), + unscaled1.to_signed_bytes_be() + ); + + // Non-compact decimal + let bd2 = BigDecimal::new(BigInt::from(12345678900987654321i128), 10); + let decimal2 = Decimal::from_big_decimal(bd2.clone(), 23, 10).unwrap(); + let (unscaled2, _) = bd2.as_bigint_and_exponent(); + let from_bytes2 = + Decimal::from_unscaled_bytes(&unscaled2.to_signed_bytes_be(), 23, 10).unwrap(); + assert_eq!(from_bytes2, decimal2); + assert_eq!( + from_bytes2.to_unscaled_bytes(), + unscaled2.to_signed_bytes_be() + ); + } + + /// Test numeric equality and ordering (matches Java semantics) + #[test] + fn test_equality_and_ordering() { + // Same value, different precision/scale → should be equal (numeric equality) + let d1 = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(10), 1), 2, 1).unwrap(); // 1.0 + let d2 = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(100), 2), 3, 2).unwrap(); // 1.00 + assert_eq!(d1, d2, "Numeric equality: 1.0 == 1.00"); + assert_eq!(d1.cmp(&d2), std::cmp::Ordering::Equal); + + // Test ordering with positive values + let small = Decimal::from_unscaled_long(10, 5, 0).unwrap(); + let large = Decimal::from_unscaled_long(15, 5, 0).unwrap(); + assert!(small < large); + assert_eq!(small.cmp(&large), std::cmp::Ordering::Less); + + // Test ordering with negative values + let negative_large = Decimal::from_unscaled_long(-10, 5, 0).unwrap(); // -10 + let negative_small = Decimal::from_unscaled_long(-15, 5, 0).unwrap(); // -15 + assert!(negative_small < negative_large); // -15 < -10 + assert_eq!( + negative_small.cmp(&negative_large), + std::cmp::Ordering::Less + ); + + // Test ordering with mixed positive and negative + let positive = Decimal::from_unscaled_long(5, 5, 0).unwrap(); + let negative = Decimal::from_unscaled_long(-5, 5, 0).unwrap(); + assert!(negative < positive); + assert_eq!(negative.cmp(&positive), std::cmp::Ordering::Less); + + // Test clone and round-trip equality + let original = Decimal::from_unscaled_long(10, 5, 0).unwrap(); + assert_eq!(original.clone(), original); + assert_eq!( + Decimal::from_unscaled_long(original.to_unscaled_long().unwrap(), 5, 0).unwrap(), + original + ); + } + + /// Test hash/equals contract (Rust implementation is correct, unlike Java) + #[test] + fn test_hash_equals_contract() { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let d1 = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(10), 1), 2, 1).unwrap(); // 1.0 + let d2 = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(100), 2), 3, 2).unwrap(); // 1.00 + + // Numeric equality + assert_eq!(d1, d2); + + // Hash contract: if a == b, then hash(a) == hash(b) + let mut hasher1 = DefaultHasher::new(); + d1.hash(&mut hasher1); + let hash1 = hasher1.finish(); + + let mut hasher2 = DefaultHasher::new(); + d2.hash(&mut hasher2); + let hash2 = hasher2.finish(); + + assert_eq!(hash1, hash2, "Equal decimals must have equal hashes"); + + // Verify HashMap works correctly (this would fail in Java due to their hash/equals bug) + let mut map = std::collections::HashMap::new(); + map.insert(d1.clone(), "value"); + assert_eq!(map.get(&d2), Some(&"value")); + } + + /// Test edge cases: zeros, large numbers, rescaling + #[test] + fn test_edge_cases() { + // Zero handling (compact and non-compact) + let zero_compact = Decimal::from_unscaled_long(0, 5, 2).unwrap(); + assert_eq!( + zero_compact.to_big_decimal(), + BigDecimal::new(BigInt::from(0), 2) + ); + + let zero_non_compact = + Decimal::from_big_decimal(BigDecimal::new(BigInt::from(0), 2), 20, 2).unwrap(); + assert_eq!( + zero_non_compact.to_big_decimal(), + BigDecimal::new(BigInt::from(0), 2) + ); + + // Large number (39 digits) + let large_bd = BigDecimal::from_str("123456789012345678901234567890123456789").unwrap(); + let large = Decimal::from_big_decimal(large_bd, 39, 0).unwrap(); + let double_val = large.to_big_decimal().to_string().parse::().unwrap(); + assert!((double_val - 1.2345678901234568E38).abs() < 0.01); + + // Rescaling: 5.0 (scale=1) → 5.00 (scale=2) + let d1 = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(50), 1), 10, 1).unwrap(); + let d2 = Decimal::from_big_decimal(d1.to_big_decimal(), 10, 2).unwrap(); + assert_eq!(d2.to_big_decimal().to_string(), "5.00"); + assert_eq!(d2.scale(), 2); + } +} diff --git a/crates/fluss/src/row/encode/compacted_key_encoder.rs b/crates/fluss/src/row/encode/compacted_key_encoder.rs index ebe3da2a..563c1c96 100644 --- a/crates/fluss/src/row/encode/compacted_key_encoder.rs +++ b/crates/fluss/src/row/encode/compacted_key_encoder.rs @@ -238,86 +238,121 @@ mod tests { } #[test] - fn test_all_data_types() { + fn test_all_data_types_java_compatible() { + // Test encoding compatibility with Java using reference from: + // https://github.com/apache/fluss/blob/main/fluss-common/src/test/resources/encoding/encoded_key.hex + use crate::metadata::{DataType, TimestampLTzType, TimestampType}; + let row_type = RowType::with_data_types(vec![ - DataTypes::boolean(), - DataTypes::tinyint(), - DataTypes::smallint(), - DataTypes::int(), - DataTypes::bigint(), - DataTypes::float(), - DataTypes::double(), - // TODO Date - // TODO Time - DataTypes::binary(20), - DataTypes::bytes(), - DataTypes::char(2), - DataTypes::string(), - // TODO Decimal - // TODO Timestamp - // TODO Timestamp LTZ - // TODO Array of Int - // TODO Array of Float - // TODO Array of String - // TODO: Add Map and Row fields in Issue #1973 + DataTypes::boolean(), // BOOLEAN + DataTypes::tinyint(), // TINYINT + DataTypes::smallint(), // SMALLINT + DataTypes::int(), // INT + DataTypes::bigint(), // BIGINT + DataTypes::float(), // FLOAT + DataTypes::double(), // DOUBLE + DataTypes::date(), // DATE + DataTypes::time(), // TIME + DataTypes::binary(20), // BINARY(20) + DataTypes::bytes(), // BYTES + DataTypes::char(2), // CHAR(2) + DataTypes::string(), // STRING + DataTypes::decimal(5, 2), // DECIMAL(5,2) + DataTypes::decimal(20, 0), // DECIMAL(20,0) + DataType::Timestamp(TimestampType::with_nullable(false, 1).unwrap()), // TIMESTAMP(1) + DataType::Timestamp(TimestampType::with_nullable(false, 5).unwrap()), // TIMESTAMP(5) + DataType::TimestampLTz(TimestampLTzType::with_nullable(false, 1).unwrap()), // TIMESTAMP_LTZ(1) + DataType::TimestampLTz(TimestampLTzType::with_nullable(false, 5).unwrap()), // TIMESTAMP_LTZ(5) + // TODO: Add support for ARRAY type + // TODO: Add support for MAP type + // TODO: Add support for ROW type ]); + // Exact values from Java's IndexedRowTest.genRecordForAllTypes() let row = GenericRow::from_data(vec![ - Datum::from(true), - Datum::from(2i8), - Datum::from(10i16), - Datum::from(100i32), - Datum::from(-6101065172474983726i64), // from Java test case: new BigInteger("12345678901234567890").longValue() - Datum::from(13.2f32), - Datum::from(15.21f64), - // TODO Date - // TODO Time - Datum::from("1234567890".as_bytes()), - Datum::from("20".as_bytes()), - Datum::from("1"), - Datum::from("hello"), - // TODO Decimal - // TODO Timestamp - // TODO Timestamp LTZ - // TODO Array of Int - // TODO Array of Float - // TODO Array of String - // TODO: Add Map and Row fields in Issue #1973 + Datum::from(true), // BOOLEAN: true + Datum::from(2i8), // TINYINT: 2 + Datum::from(10i16), // SMALLINT: 10 + Datum::from(100i32), // INT: 100 + Datum::from(-6101065172474983726i64), // BIGINT + Datum::from(13.2f32), // FLOAT: 13.2 + Datum::from(15.21f64), // DOUBLE: 15.21 + Datum::Date(crate::row::datum::Date::new(19655)), // DATE: 2023-10-25 (19655 days since epoch) + Datum::Time(crate::row::datum::Time::new(34200000)), // TIME: 09:30:00.0 + Datum::from("1234567890".as_bytes()), // BINARY(20) + Datum::from("20".as_bytes()), // BYTES + Datum::from("1"), // CHAR(2): "1" + Datum::from("hello"), // STRING: "hello" + Datum::Decimal(crate::row::Decimal::from_unscaled_long(9, 5, 2).unwrap()), // DECIMAL(5,2) + Datum::Decimal( + crate::row::Decimal::from_big_decimal( + bigdecimal::BigDecimal::new(bigdecimal::num_bigint::BigInt::from(10), 0), + 20, + 0, + ) + .unwrap(), + ), // DECIMAL(20,0) + Datum::TimestampNtz(crate::row::datum::TimestampNtz::new(1698235273182)), // TIMESTAMP(1) + Datum::TimestampNtz(crate::row::datum::TimestampNtz::new(1698235273182)), // TIMESTAMP(5) + Datum::TimestampLtz(crate::row::datum::TimestampLtz::new(1698235273182)), // TIMESTAMP_LTZ(1) + Datum::TimestampLtz(crate::row::datum::TimestampLtz::new(1698235273182)), // TIMESTAMP_LTZ(5) ]); - let mut encoder = for_test_row_type(&row_type); - - let mut expected: Vec = Vec::new(); - // BOOLEAN: true - expected.extend(vec![0x01]); - // TINYINT: 2 - expected.extend(vec![0x02]); - // SMALLINT: 10 - expected.extend(vec![0x0A]); - // INT: 100 - expected.extend(vec![0x00, 0x64]); - // BIGINT: -6101065172474983726 - expected.extend(vec![ + // Expected bytes from Java's encoded_key.hex reference file + #[rustfmt::skip] + let expected: Vec = vec![ + // BOOLEAN: true + 0x01, + // TINYINT: 2 + 0x02, + // SMALLINT: 10 (varint encoded) + 0x0A, + // INT: 100 (varint encoded) + 0x00, 0x64, + // BIGINT: -6101065172474983726 0xD2, 0x95, 0xFC, 0xD8, 0xCE, 0xB1, 0xAA, 0xAA, 0xAB, 0x01, - ]); - // FLOAT: 13.2 - expected.extend(vec![0x33, 0x33, 0x53, 0x41]); - // DOUBLE: 15.21 - expected.extend(vec![0xEC, 0x51, 0xB8, 0x1E, 0x85, 0x6B, 0x2E, 0x40]); - // BINARY(20): "1234567890".getBytes() - expected.extend(vec![ + // FLOAT: 13.2 + 0x33, 0x33, 0x53, 0x41, + // DOUBLE: 15.21 + 0xEC, 0x51, 0xB8, 0x1E, 0x85, 0x6B, 0x2E, 0x40, + // DATE: 2023-10-25 + 0xC7, 0x99, 0x01, + // TIME: 09:30:00.0 + 0xC0, 0xB3, 0xA7, 0x10, + // BINARY(20): "1234567890" 0x0A, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x30, - ]); + // BYTES: "20" + 0x02, 0x32, 0x30, + // CHAR(2): "1" + 0x01, 0x31, + // STRING: "hello" + 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, + // DECIMAL(5,2): 9 + 0x09, + // DECIMAL(20,0): 10 + 0x01, 0x0A, + // TIMESTAMP(1): 1698235273182 + 0xDE, 0x9F, 0xD7, 0xB5, 0xB6, 0x31, + // TIMESTAMP(5): 1698235273182 + 0xDE, 0x9F, 0xD7, 0xB5, 0xB6, 0x31, 0x00, + // TIMESTAMP_LTZ(1): 1698235273182 + 0xDE, 0x9F, 0xD7, 0xB5, 0xB6, 0x31, + // TIMESTAMP_LTZ(5): 1698235273182 + 0xDE, 0x9F, 0xD7, 0xB5, 0xB6, 0x31, 0x00, + ]; - // BYTES: "20".getBytes() - expected.extend(vec![0x02, 0x32, 0x30]); - // CHAR(2): "1" - expected.extend(vec![0x01, 0x31]); - // STRING: String: "hello" - expected.extend(vec![0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F]); + let mut encoder = for_test_row_type(&row_type); + let encoded = encoder.encode_key(&row).unwrap(); + + // Assert byte-for-byte compatibility with Java's encoded_key.hex assert_eq!( - encoder.encode_key(&row).unwrap().iter().as_slice(), - expected.as_slice() + encoded.iter().as_slice(), + expected.as_slice(), + "\n\nRust encoding does not match Java reference from encoded_key.hex\n\ + Expected: {:02X?}\n\ + Actual: {:02X?}\n", + expected, + encoded.iter().as_slice() ); } } diff --git a/crates/fluss/src/row/field_getter.rs b/crates/fluss/src/row/field_getter.rs index 97f9e395..cbffa4d0 100644 --- a/crates/fluss/src/row/field_getter.rs +++ b/crates/fluss/src/row/field_getter.rs @@ -66,6 +66,21 @@ impl FieldGetter { DataType::BigInt(_) => InnerFieldGetter::BigInt { pos }, DataType::Float(_) => InnerFieldGetter::Float { pos }, DataType::Double(_) => InnerFieldGetter::Double { pos }, + DataType::Decimal(decimal_type) => InnerFieldGetter::Decimal { + pos, + precision: decimal_type.precision() as usize, + scale: decimal_type.scale() as usize, + }, + DataType::Date(_) => InnerFieldGetter::Date { pos }, + DataType::Time(_) => InnerFieldGetter::Time { pos }, + DataType::Timestamp(t) => InnerFieldGetter::Timestamp { + pos, + precision: t.precision(), + }, + DataType::TimestampLTz(t) => InnerFieldGetter::TimestampLtz { + pos, + precision: t.precision(), + }, _ => unimplemented!("DataType {:?} is currently unimplemented", data_type), }; @@ -79,17 +94,60 @@ impl FieldGetter { #[derive(Clone)] pub enum InnerFieldGetter { - Char { pos: usize, len: usize }, - String { pos: usize }, - Bool { pos: usize }, - Binary { pos: usize, len: usize }, - Bytes { pos: usize }, - TinyInt { pos: usize }, - SmallInt { pos: usize }, - Int { pos: usize }, - BigInt { pos: usize }, - Float { pos: usize }, - Double { pos: usize }, + Char { + pos: usize, + len: usize, + }, + String { + pos: usize, + }, + Bool { + pos: usize, + }, + Binary { + pos: usize, + len: usize, + }, + Bytes { + pos: usize, + }, + TinyInt { + pos: usize, + }, + SmallInt { + pos: usize, + }, + Int { + pos: usize, + }, + BigInt { + pos: usize, + }, + Float { + pos: usize, + }, + Double { + pos: usize, + }, + Decimal { + pos: usize, + precision: usize, + scale: usize, + }, + Date { + pos: usize, + }, + Time { + pos: usize, + }, + Timestamp { + pos: usize, + precision: u32, + }, + TimestampLtz { + pos: usize, + precision: u32, + }, } impl InnerFieldGetter { @@ -106,7 +164,19 @@ impl InnerFieldGetter { InnerFieldGetter::BigInt { pos } => Datum::from(row.get_long(*pos)), InnerFieldGetter::Float { pos } => Datum::from(row.get_float(*pos)), InnerFieldGetter::Double { pos } => Datum::from(row.get_double(*pos)), - //TODO Decimal, Date, Time, Timestamp, TimestampLTZ, Array, Map, Row + InnerFieldGetter::Decimal { + pos, + precision, + scale, + } => Datum::Decimal(row.get_decimal(*pos, *precision, *scale)), + InnerFieldGetter::Date { pos } => Datum::Date(row.get_date(*pos)), + InnerFieldGetter::Time { pos } => Datum::Time(row.get_time(*pos)), + InnerFieldGetter::Timestamp { pos, precision } => { + Datum::TimestampNtz(row.get_timestamp_ntz(*pos, *precision)) + } + InnerFieldGetter::TimestampLtz { pos, precision } => { + Datum::TimestampLtz(row.get_timestamp_ltz(*pos, *precision)) + } //TODO Array, Map, Row } } @@ -122,7 +192,12 @@ impl InnerFieldGetter { | Self::Int { pos } | Self::BigInt { pos } | Self::Float { pos, .. } - | Self::Double { pos } => *pos, + | Self::Double { pos } + | Self::Decimal { pos, .. } + | Self::Date { pos } + | Self::Time { pos } + | Self::Timestamp { pos, .. } + | Self::TimestampLtz { pos, .. } => *pos, } } } diff --git a/crates/fluss/src/row/mod.rs b/crates/fluss/src/row/mod.rs index 536409ef..d2f640e4 100644 --- a/crates/fluss/src/row/mod.rs +++ b/crates/fluss/src/row/mod.rs @@ -18,6 +18,7 @@ mod column; mod datum; +mod decimal; pub mod binary; pub mod compacted; @@ -28,6 +29,7 @@ mod row_decoder; pub use column::*; pub use compacted::CompactedRow; pub use datum::*; +pub use decimal::{Decimal, MAX_COMPACT_PRECISION}; pub use encode::KeyEncoder; pub use row_decoder::{CompactedRowDecoder, RowDecoder, RowDecoderFactory}; @@ -71,14 +73,26 @@ pub trait InternalRow { /// Returns the string value at the given position fn get_string(&self, pos: usize) -> &str; - // /// Returns the decimal value at the given position - // fn get_decimal(&self, pos: usize, precision: usize, scale: usize) -> Decimal; + /// Returns the decimal value at the given position + fn get_decimal(&self, pos: usize, precision: usize, scale: usize) -> Decimal; - // /// Returns the timestamp value at the given position - // fn get_timestamp_ntz(&self, pos: usize, precision: usize) -> TimestampNtz; + /// Returns the date value at the given position (date as days since epoch) + fn get_date(&self, pos: usize) -> datum::Date; - // /// Returns the timestamp value at the given position - // fn get_timestamp_ltz(&self, pos: usize, precision: usize) -> TimestampLtz; + /// Returns the time value at the given position (time as milliseconds since midnight) + fn get_time(&self, pos: usize) -> datum::Time; + + /// Returns the timestamp value at the given position (timestamp without timezone) + /// + /// The precision is required to determine whether the timestamp value was stored + /// in a compact representation (precision <= 3) or with nanosecond precision. + fn get_timestamp_ntz(&self, pos: usize, precision: u32) -> datum::TimestampNtz; + + /// Returns the timestamp value at the given position (timestamp with local timezone) + /// + /// The precision is required to determine whether the timestamp value was stored + /// in a compact representation (precision <= 3) or with nanosecond precision. + fn get_timestamp_ltz(&self, pos: usize, precision: u32) -> datum::TimestampLtz; /// Returns the binary value at the given position with fixed length fn get_binary(&self, pos: usize, length: usize) -> &[u8]; @@ -123,6 +137,43 @@ impl<'a> InternalRow for GenericRow<'a> { self.values.get(_pos).unwrap().try_into().unwrap() } + fn get_decimal(&self, pos: usize, _precision: usize, _scale: usize) -> Decimal { + match self.values.get(pos).unwrap() { + Datum::Decimal(d) => d.clone(), + other => panic!("Expected Decimal at pos {pos:?}, got {other:?}"), + } + } + + fn get_date(&self, pos: usize) -> datum::Date { + match self.values.get(pos).unwrap() { + Datum::Date(d) => *d, + Datum::Int32(i) => datum::Date::new(*i), + other => panic!("Expected Date or Int32 at pos {pos:?}, got {other:?}"), + } + } + + fn get_time(&self, pos: usize) -> datum::Time { + match self.values.get(pos).unwrap() { + Datum::Time(t) => *t, + Datum::Int32(i) => datum::Time::new(*i), + other => panic!("Expected Time or Int32 at pos {pos:?}, got {other:?}"), + } + } + + fn get_timestamp_ntz(&self, pos: usize, _precision: u32) -> datum::TimestampNtz { + match self.values.get(pos).unwrap() { + Datum::TimestampNtz(t) => *t, + other => panic!("Expected TimestampNtz at pos {pos:?}, got {other:?}"), + } + } + + fn get_timestamp_ltz(&self, pos: usize, _precision: u32) -> datum::TimestampLtz { + match self.values.get(pos).unwrap() { + Datum::TimestampLtz(t) => *t, + other => panic!("Expected TimestampLtz at pos {pos:?}, got {other:?}"), + } + } + fn get_float(&self, pos: usize) -> f32 { self.values.get(pos).unwrap().try_into().unwrap() }