diff --git a/arrow-json/benches/json-reader.rs b/arrow-json/benches/json-reader.rs index 504839f8ffe2..d02790ad8428 100644 --- a/arrow-json/benches/json-reader.rs +++ b/arrow-json/benches/json-reader.rs @@ -179,7 +179,13 @@ fn bench_binary_hex(c: &mut Criterion) { bench_decode_binary(c, "decode_binary_view_hex_json", &binary_data, view_field); } -fn bench_decode_schema(c: &mut Criterion, name: &str, data: &[u8], schema: Arc) { +fn bench_decode_schema( + c: &mut Criterion, + name: &str, + data: &[u8], + schema: Arc, + projection: bool, +) { let mut group = c.benchmark_group(name); group.throughput(Throughput::Bytes(data.len() as u64)); group.sample_size(50); @@ -190,6 +196,7 @@ fn bench_decode_schema(c: &mut Criterion, name: &str, data: &[u8], schema: Arc Self { + Self { projection, ..self } + } + /// Set the [`StructMode`] for the reader, which determines whether structs /// can be decoded from JSON as objects or lists. For more details refer to /// the enum documentation. Default is to use `ObjectOnly`. @@ -303,6 +313,19 @@ impl ReaderBuilder { } }; + let num_fields = self.schema.flattened_fields().len(); + + // Extract projection field set from schema for projection-aware parsing + // - strict_mode: fail-fast on unknown fields during tape parsing + // - projection: skip JSON fields not present in the schema + let enable_projection = self.strict_mode || self.projection; + let projection: Option> = match &data_type { + DataType::Struct(fields) if enable_projection && !fields.is_empty() => { + Some(fields.iter().map(|f| f.name().clone()).collect()) + } + _ => None, + }; + let decoder = make_decoder( data_type, self.coerce_primitive, @@ -311,12 +334,15 @@ impl ReaderBuilder { self.struct_mode, )?; - let num_fields = self.schema.flattened_fields().len(); - Ok(Decoder { decoder, is_field: self.is_field, - tape_decoder: TapeDecoder::new(self.batch_size, num_fields), + tape_decoder: TapeDecoder::new( + self.batch_size, + num_fields, + projection, + self.strict_mode, + ), batch_size: self.batch_size, schema: self.schema, }) @@ -1783,6 +1809,39 @@ mod tests { ); } + #[test] + fn test_projection_skip_unknown_fields() { + // JSON has fields a, b, c but schema only has a, c + let buf = r#" + {"a": 1, "b": "ignored", "c": true} + {"a": 2, "b": "also ignored", "c": false} + "#; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("c", DataType::Boolean, true), + ])); + + // with_projection(true): skip unknown field "b" and succeed + let batch = ReaderBuilder::new(schema) + .with_projection(true) + .build(Cursor::new(buf.as_bytes())) + .unwrap() + .read() + .unwrap() + .unwrap(); + + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 2); + + let a = batch.column(0).as_primitive::(); + assert_eq!(a.values(), &[1, 2]); + + let c = batch.column(1).as_boolean(); + assert!(c.value(0)); + assert!(!c.value(1)); + } + fn read_file(path: &str, schema: Option) -> Reader> { let file = File::open(path).unwrap(); let mut reader = BufReader::new(file); diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs index 89ee3f778765..86c8152de441 100644 --- a/arrow-json/src/reader/tape.rs +++ b/arrow-json/src/reader/tape.rs @@ -19,6 +19,7 @@ use crate::reader::serializer::TapeSerializer; use arrow_schema::ArrowError; use memchr::memchr2; use serde_core::Serialize; +use std::collections::HashSet; use std::fmt::Write; /// We decode JSON to a flattened tape representation, @@ -237,6 +238,18 @@ enum DecoderState { /// /// Consists of `(literal, decoded length)` Literal(Literal, u8), + /// Skipping a value (for unprojected fields), not inside a string + /// + /// Contains the nesting depth of objects/arrays being skipped + SkipValue(u32), + /// Skipping inside a string literal (for unprojected fields) + /// + /// Contains the nesting depth of objects/arrays + SkipString(u32), + /// Skipping an escape sequence inside a string (for unprojected fields) + /// + /// Contains the nesting depth of objects/arrays + SkipEscape(u32), } impl DecoderState { @@ -251,6 +264,9 @@ impl DecoderState { DecoderState::Escape => "escape", DecoderState::Unicode(_, _, _) => "unicode literal", DecoderState::Literal(d, _) => d.as_str(), + DecoderState::SkipValue(_) => "skip value", + DecoderState::SkipString(_) => "skip string", + DecoderState::SkipEscape(_) => "skip escape", } } } @@ -315,12 +331,44 @@ pub struct TapeDecoder { /// A stack of [`DecoderState`] stack: Vec, + + /// Optional projection: set of field names to include + /// If None, all fields are parsed. If Some, only fields in the set are parsed. + projection: Option>, + + /// If true, return error when encountering fields not in projection + strict_mode: bool, + + /// Cache current nesting depth to avoid O(depth) stack traversal on every field + /// Incremented when entering Object/List, decremented when exiting + current_nesting_depth: usize, } impl TapeDecoder { + /// Returns projection info if we should check field projection. + /// Only applies at top level (nesting_depth == 1) with a projection set. + fn projection_info(&self) -> Option<(usize, &HashSet)> { + if self.current_nesting_depth != 1 { + return None; + } + let projection = self.projection.as_ref()?; + let TapeElement::String(string_idx) = *self.elements.last()? else { + return None; + }; + Some((string_idx as usize, projection)) + } + /// Create a new [`TapeDecoder`] with the provided batch size /// and an estimated number of fields in each row - pub fn new(batch_size: usize, num_fields: usize) -> Self { + /// + /// If `projection` is Some, only fields in the set will be parsed and written to the tape. + /// Other fields will be skipped during parsing (or rejected if `strict_mode` is true). + pub fn new( + batch_size: usize, + num_fields: usize, + projection: Option>, + strict_mode: bool, + ) -> Self { let tokens_per_row = 2 + num_fields * 2; let mut offsets = Vec::with_capacity(batch_size * (num_fields * 2) + 1); offsets.push(0); @@ -335,6 +383,9 @@ impl TapeDecoder { cur_row: 0, bytes: Vec::with_capacity(num_fields * 2 * 8), stack: Vec::with_capacity(10), + projection, + strict_mode, + current_nesting_depth: 0, } } @@ -372,6 +423,7 @@ impl TapeDecoder { let end_idx = self.elements.len() as u32; self.elements[start_idx as usize] = TapeElement::StartObject(end_idx); self.elements.push(TapeElement::EndObject(start_idx)); + self.current_nesting_depth -= 1; self.stack.pop(); } b => return Err(err(b, "parsing object")), @@ -387,6 +439,7 @@ impl TapeDecoder { let end_idx = self.elements.len() as u32; self.elements[start_idx as usize] = TapeElement::StartList(end_idx); self.elements.push(TapeElement::EndList(start_idx)); + self.current_nesting_depth -= 1; self.stack.pop(); } Some(_) => self.stack.push(DecoderState::Value), @@ -423,11 +476,13 @@ impl TapeDecoder { b'[' => { let idx = self.elements.len() as u32; self.elements.push(TapeElement::StartList(u32::MAX)); + self.current_nesting_depth += 1; DecoderState::List(idx) } b'{' => { let idx = self.elements.len() as u32; self.elements.push(TapeElement::StartObject(u32::MAX)); + self.current_nesting_depth += 1; DecoderState::Object(idx) } b => return Err(err(b, "parsing value")), @@ -449,7 +504,41 @@ impl TapeDecoder { DecoderState::Colon => { iter.skip_whitespace(); match next!(iter) { - b':' => self.stack.pop(), + b':' => { + self.stack.pop(); + + // Check projection at top level only + if let Some((string_idx, projection)) = self.projection_info() { + let start = self.offsets[string_idx]; + let end = self.offsets[string_idx + 1]; + let field_name = std::str::from_utf8(&self.bytes[start..end]) + .map_err(|e| { + ArrowError::JsonError(format!( + "Invalid UTF-8 in field name: {e}" + )) + })?; + + match (projection.contains(field_name), self.strict_mode) { + (true, _) => {} + (false, true) => { + return Err(ArrowError::JsonError(format!( + "column '{field_name}' missing from schema" + ))); + } + (false, false) => { + // Field not in projection: skip its value + // Remove field name from tape (must have paired field_name:value) + self.elements.pop(); + self.bytes.truncate(start); + self.offsets.pop(); + + // Replace Value state with SkipValue + *self.stack.last_mut().unwrap() = + DecoderState::SkipValue(0); + } + } + } + } b => return Err(err(b, "parsing colon")), }; } @@ -519,6 +608,113 @@ impl TapeDecoder { } *idx += 1; }, + // Skip a value (not inside a string) + DecoderState::SkipValue(depth) => { + while !iter.is_empty() { + if *depth > 0 { + // Inside nested structure - fast skip to next structural character + iter.advance_until(|b| matches!(b, b'"' | b'{' | b'[' | b'}' | b']')); + if iter.is_empty() { + break; + } + match next!(iter) { + b'"' => { + *state = DecoderState::SkipString(*depth); + break; + } + b'{' | b'[' => *depth += 1, + b'}' | b']' => { + *depth -= 1; + if *depth == 0 { + self.stack.pop(); + break; + } + } + _ => {} + } + } else { + // depth == 0: Skip simple value (number/literal/start of compound) + iter.advance_until(|b| { + matches!( + b, + b',' | b'}' + | b']' + | b' ' + | b'\n' + | b'\r' + | b'\t' + | b'"' + | b'{' + | b'[' + ) + }); + if iter.is_empty() { + break; + } + match iter.peek() { + Some(b',' | b'}' | b']') => { + self.stack.pop(); + break; + } + Some(b' ' | b'\n' | b'\r' | b'\t') => { + iter.skip_whitespace(); + if iter.peek().is_some_and(|b| matches!(b, b',' | b'}' | b']')) + { + self.stack.pop(); + break; + } + if iter.is_empty() { + break; + } + } + Some(b'"') => { + next!(iter); + *state = DecoderState::SkipString(0); + break; + } + Some(b'{' | b'[') => { + next!(iter); + *depth = 1; + } + _ => {} + } + } + } + } + // Skip inside a string literal + DecoderState::SkipString(depth) => { + iter.skip_chrs(b'\\', b'"'); + if iter.is_empty() { + break; + } + match next!(iter) { + b'\\' => *state = DecoderState::SkipEscape(*depth), + b'"' => { + if *depth == 0 { + // String value ended at top level - check completion + iter.skip_whitespace(); + if iter.peek().is_some_and(|b| matches!(b, b',' | b'}' | b']')) { + self.stack.pop(); + } else if iter.is_empty() { + // Need more data, stay in a "finished string but not yet popped" state + // For simplicity, transition to SkipValue(0) and let it handle + *state = DecoderState::SkipValue(0); + } + } else { + *state = DecoderState::SkipValue(*depth); + } + } + _ => unreachable!(), + } + } + // Skip an escape sequence inside a string + DecoderState::SkipEscape(depth) => { + if iter.is_empty() { + break; + } + next!(iter); // consume escaped character + *state = DecoderState::SkipString(*depth); + } } } @@ -767,7 +963,7 @@ mod tests { {"a": ["", "foo", ["bar", "c"]], "b": {"1": []}, "c": {"2": [1, 2, 3]} } "#; - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); decoder.decode(a.as_bytes()).unwrap(); assert!(!decoder.has_partial_row()); assert_eq!(decoder.num_buffered_rows(), 7); @@ -877,21 +1073,21 @@ mod tests { #[test] fn test_invalid() { // Test invalid - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); let err = decoder.decode(b"hello").unwrap_err().to_string(); assert_eq!( err, "Json error: Encountered unexpected 'h' whilst parsing value" ); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); let err = decoder.decode(b"{\"hello\": }").unwrap_err().to_string(); assert_eq!( err, "Json error: Encountered unexpected '}' whilst parsing value" ); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); let err = decoder .decode(b"{\"hello\": [ false, tru ]}") .unwrap_err() @@ -901,7 +1097,7 @@ mod tests { "Json error: Encountered unexpected ' ' whilst parsing literal" ); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); let err = decoder .decode(b"{\"hello\": \"\\ud8\"}") .unwrap_err() @@ -912,7 +1108,7 @@ mod tests { ); // Missing surrogate pair - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); let err = decoder .decode(b"{\"hello\": \"\\ud83d\"}") .unwrap_err() @@ -923,40 +1119,40 @@ mod tests { ); // Test truncation - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); decoder.decode(b"{\"he").unwrap(); assert!(decoder.has_partial_row()); assert_eq!(decoder.num_buffered_rows(), 1); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Truncated record whilst reading string"); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); decoder.decode(b"{\"hello\" : ").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Truncated record whilst reading value"); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); decoder.decode(b"{\"hello\" : [").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Truncated record whilst reading list"); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); decoder.decode(b"{\"hello\" : tru").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Truncated record whilst reading true"); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); decoder.decode(b"{\"hello\" : nu").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Truncated record whilst reading null"); // Test invalid UTF-8 - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); decoder.decode(b"{\"hello\" : \"world\xFF\"}").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Encountered non-UTF-8 data"); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); decoder.decode(b"{\"\xe2\" : \"\x96\xa1\"}").unwrap(); let err = decoder.finish().unwrap_err().to_string(); assert_eq!(err, "Json error: Encountered truncated UTF-8 sequence"); @@ -964,11 +1160,11 @@ mod tests { #[test] fn test_invalid_surrogates() { - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); let res = decoder.decode(b"{\"test\": \"\\ud800\\ud801\"}"); assert!(res.is_err()); - let mut decoder = TapeDecoder::new(16, 2); + let mut decoder = TapeDecoder::new(16, 2, None, false); let res = decoder.decode(b"{\"test\": \"\\udc00\\udc01\"}"); assert!(res.is_err()); }