Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions arrow-csv/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ impl InferredDataType {
#[derive(Debug, Clone, Default)]
pub struct Format {
header: bool,
header_validation: bool,
delimiter: Option<u8>,
escape: Option<u8>,
quote: Option<u8>,
Expand All @@ -291,6 +292,16 @@ impl Format {
self
}

/// Specify whether to validate the CSV header against the schema, defaults to `false`
///
/// When `true`, the first row gets validated against the schema before any data is read
///
/// Only applies when [`Self::with_header`] is set to `true`
pub fn with_header_validation(mut self, validate_header: bool) -> Self {
self.header_validation = validate_header;
self
}

/// Specify a custom delimiter character, defaults to comma `','`
pub fn with_delimiter(mut self, delimiter: u8) -> Self {
self.delimiter = Some(delimiter);
Expand Down Expand Up @@ -610,6 +621,9 @@ pub struct Decoder {
/// Rows to skip
to_skip: usize,

/// Whether to validate the first skipped row against the schema
header_validation: bool,

/// Current line number
line_number: usize,

Expand All @@ -635,6 +649,20 @@ impl Decoder {
/// network sources such as object storage
pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
if self.to_skip != 0 {
if self.header_validation {
let (skipped, bytes) = self.record_decoder.decode(buf, 1)?;

if skipped == 0 {
return Ok(bytes);
}

let rows = self.record_decoder.flush()?;
validate_header(&rows, self.schema.fields())?;
self.header_validation = false;
self.to_skip -= 1;
return Ok(bytes);
}

// Skip in units of `to_read` to avoid over-allocating buffers
let to_skip = self.to_skip.min(self.batch_size);
let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
Expand Down Expand Up @@ -678,6 +706,24 @@ impl Decoder {
}
}

fn validate_header(rows: &StringRecords<'_>, fields: &Fields) -> Result<(), ArrowError> {
let header = rows.iter().next().ok_or_else(|| {
ArrowError::CsvError("CSV header validation failed: no header row found".to_string())
})?;

for (idx, field) in fields.iter().enumerate() {
let actual = header.get(idx);
let expected = field.name();
if actual != expected {
return Err(ArrowError::CsvError(format!(
"CSV header does not match schema at column {idx}: expected {expected:?} but found {actual:?}"
)));
}
}

Ok(())
}

/// Parses a slice of [`StringRecords`] into a [RecordBatch]
fn parse(
rows: &StringRecords<'_>,
Expand Down Expand Up @@ -1154,6 +1200,14 @@ impl ReaderBuilder {
self
}

/// Set whether to validate the CSV header against the schema
///
/// This option only applies when [`Self::with_header`] is set to `true`, and defaults to `false`
pub fn with_header_validation(mut self, validate_header: bool) -> Self {
self.format.header_validation = validate_header;
self
}

/// Overrides the [Format] of this [ReaderBuilder]
pub fn with_format(mut self, format: Format) -> Self {
self.format = format;
Expand Down Expand Up @@ -1261,6 +1315,7 @@ impl ReaderBuilder {
Decoder {
schema: self.schema,
to_skip: start,
header_validation: self.format.header && self.format.header_validation,
record_decoder,
line_number: start,
end,
Expand Down Expand Up @@ -2351,6 +2406,62 @@ mod tests {
}
}

#[test]
fn test_header_validation() {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]));

let csv = "a,c\n1,2\n";
let err = ReaderBuilder::new(schema.clone())
.with_header(true)
.with_header_validation(true)
.build_buffered(Cursor::new(csv.as_bytes()))
.unwrap()
.next()
.unwrap()
.unwrap_err()
.to_string();
assert_eq!(
err,
"Csv error: CSV header does not match schema at column 1: expected \"b\" but found \"c\""
);

let batch = ReaderBuilder::new(schema)
.with_header(true)
.with_header_validation(false)
.build_buffered(Cursor::new(csv.as_bytes()))
.unwrap()
.next()
.unwrap()
.unwrap();
assert_eq!(batch.num_rows(), 1);
}

#[test]
fn test_header_validation_with_buffered_reader() {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]));

let csv = "a,b\n1,2\n";
let buffered = std::io::BufReader::with_capacity(1, Cursor::new(csv.as_bytes()));
let batch = ReaderBuilder::new(schema)
.with_header(true)
.with_header_validation(true)
.build_buffered(buffered)
.unwrap()
.next()
.unwrap()
.unwrap();

assert_eq!(batch.num_rows(), 1);
let a = batch.column(0).as_primitive::<Int32Type>();
assert_eq!(a.value(0), 1);
}

#[test]
fn test_null_boolean() {
let csv = "true,false\nFalse,True\n,True\nFalse,";
Expand Down
Loading