From ee42f7962a11444dfe2bf2983f2234eb9ed0a55c Mon Sep 17 00:00:00 2001 From: Iha Shin Date: Mon, 15 Jun 2026 17:09:00 +0900 Subject: [PATCH] feat(arrow_csv): add header validation option --- arrow-csv/src/reader/mod.rs | 111 ++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index e26072fea917..647392bdb8c0 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -273,6 +273,7 @@ impl InferredDataType { #[derive(Debug, Clone, Default)] pub struct Format { header: bool, + header_validation: bool, delimiter: Option, escape: Option, quote: Option, @@ -291,6 +292,16 @@ impl Format { self } + /// Specify whether to validate the CSV header against the schema, defaults to `false` + /// + /// When `true`, the first row gets validated against the schema before any data is read + /// + /// Only applies when [`Self::with_header`] is set to `true` + pub fn with_header_validation(mut self, validate_header: bool) -> Self { + self.header_validation = validate_header; + self + } + /// Specify a custom delimiter character, defaults to comma `','` pub fn with_delimiter(mut self, delimiter: u8) -> Self { self.delimiter = Some(delimiter); @@ -610,6 +621,9 @@ pub struct Decoder { /// Rows to skip to_skip: usize, + /// Whether to validate the first skipped row against the schema + header_validation: bool, + /// Current line number line_number: usize, @@ -635,6 +649,20 @@ impl Decoder { /// network sources such as object storage pub fn decode(&mut self, buf: &[u8]) -> Result { if self.to_skip != 0 { + if self.header_validation { + let (skipped, bytes) = self.record_decoder.decode(buf, 1)?; + + if skipped == 0 { + return Ok(bytes); + } + + let rows = self.record_decoder.flush()?; + validate_header(&rows, self.schema.fields())?; + self.header_validation = false; + self.to_skip -= 1; + return Ok(bytes); + } + // Skip in units of `to_read` to avoid over-allocating buffers let to_skip = self.to_skip.min(self.batch_size); let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?; @@ -678,6 +706,24 @@ impl Decoder { } } +fn validate_header(rows: &StringRecords<'_>, fields: &Fields) -> Result<(), ArrowError> { + let header = rows.iter().next().ok_or_else(|| { + ArrowError::CsvError("CSV header validation failed: no header row found".to_string()) + })?; + + for (idx, field) in fields.iter().enumerate() { + let actual = header.get(idx); + let expected = field.name(); + if actual != expected { + return Err(ArrowError::CsvError(format!( + "CSV header does not match schema at column {idx}: expected {expected:?} but found {actual:?}" + ))); + } + } + + Ok(()) +} + /// Parses a slice of [`StringRecords`] into a [RecordBatch] fn parse( rows: &StringRecords<'_>, @@ -1154,6 +1200,14 @@ impl ReaderBuilder { self } + /// Set whether to validate the CSV header against the schema + /// + /// This option only applies when [`Self::with_header`] is set to `true`, and defaults to `false` + pub fn with_header_validation(mut self, validate_header: bool) -> Self { + self.format.header_validation = validate_header; + self + } + /// Overrides the [Format] of this [ReaderBuilder] pub fn with_format(mut self, format: Format) -> Self { self.format = format; @@ -1261,6 +1315,7 @@ impl ReaderBuilder { Decoder { schema: self.schema, to_skip: start, + header_validation: self.format.header && self.format.header_validation, record_decoder, line_number: start, end, @@ -2351,6 +2406,62 @@ mod tests { } } + #[test] + fn test_header_validation() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let csv = "a,c\n1,2\n"; + let err = ReaderBuilder::new(schema.clone()) + .with_header(true) + .with_header_validation(true) + .build_buffered(Cursor::new(csv.as_bytes())) + .unwrap() + .next() + .unwrap() + .unwrap_err() + .to_string(); + assert_eq!( + err, + "Csv error: CSV header does not match schema at column 1: expected \"b\" but found \"c\"" + ); + + let batch = ReaderBuilder::new(schema) + .with_header(true) + .with_header_validation(false) + .build_buffered(Cursor::new(csv.as_bytes())) + .unwrap() + .next() + .unwrap() + .unwrap(); + assert_eq!(batch.num_rows(), 1); + } + + #[test] + fn test_header_validation_with_buffered_reader() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let csv = "a,b\n1,2\n"; + let buffered = std::io::BufReader::with_capacity(1, Cursor::new(csv.as_bytes())); + let batch = ReaderBuilder::new(schema) + .with_header(true) + .with_header_validation(true) + .build_buffered(buffered) + .unwrap() + .next() + .unwrap() + .unwrap(); + + assert_eq!(batch.num_rows(), 1); + let a = batch.column(0).as_primitive::(); + assert_eq!(a.value(0), 1); + } + #[test] fn test_null_boolean() { let csv = "true,false\nFalse,True\n,True\nFalse,";