-
Notifications
You must be signed in to change notification settings - Fork 0
3670: perf: use aligned pointer reads for SparkUnsafeRow field accessors #46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0ee482d
6963a92
c59b81b
8bc5761
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,13 +71,32 @@ const NESTED_TYPE_BUILDER_CAPACITY: usize = 100; | |
| /// safe to call as long as: | ||
| /// - The index is within bounds (caller's responsibility) | ||
| /// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data | ||
| /// | ||
| /// # Alignment | ||
| /// | ||
| /// Primitive accessor methods are implemented separately for each type because they have | ||
| /// different alignment guarantees: | ||
| /// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is a multiple of 8, | ||
| /// and each field slot is 8 bytes), so accessors use aligned `ptr::read()`. | ||
| /// - `SparkUnsafeArray`: The array base address may be unaligned when nested within a row's | ||
| /// variable-length region, so accessors use `ptr::read_unaligned()`. | ||
| pub trait SparkUnsafeObject { | ||
| /// Returns the address of the row. | ||
| fn get_row_addr(&self) -> i64; | ||
|
|
||
| /// Returns the offset of the element at the given index. | ||
| fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8; | ||
|
|
||
| fn get_boolean(&self, index: usize) -> bool; | ||
| fn get_byte(&self, index: usize) -> i8; | ||
| fn get_short(&self, index: usize) -> i16; | ||
| fn get_int(&self, index: usize) -> i32; | ||
| fn get_long(&self, index: usize) -> i64; | ||
| fn get_float(&self, index: usize) -> f32; | ||
| fn get_double(&self, index: usize) -> f64; | ||
| fn get_date(&self, index: usize) -> i32; | ||
| fn get_timestamp(&self, index: usize) -> i64; | ||
|
|
||
| /// Returns the offset and length of the element at the given index. | ||
| #[inline] | ||
| fn get_offset_and_len(&self, index: usize) -> (i32, i32) { | ||
|
|
@@ -87,79 +106,6 @@ pub trait SparkUnsafeObject { | |
| (offset, len) | ||
| } | ||
|
|
||
| /// Returns boolean value at the given index of the object. | ||
| #[inline] | ||
| fn get_boolean(&self, index: usize) -> bool { | ||
| let addr = self.get_element_offset(index, 1); | ||
| // SAFETY: addr points to valid element data within the UnsafeRow/UnsafeArray region. | ||
| // The caller ensures index is within bounds. | ||
| debug_assert!( | ||
| !addr.is_null(), | ||
| "get_boolean: null pointer at index {index}" | ||
| ); | ||
| unsafe { *addr != 0 } | ||
| } | ||
|
|
||
| /// Returns byte value at the given index of the object. | ||
| #[inline] | ||
| fn get_byte(&self, index: usize) -> i8 { | ||
| let addr = self.get_element_offset(index, 1); | ||
| // SAFETY: addr points to valid element data (1 byte) within the row/array region. | ||
| debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}"); | ||
| let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) }; | ||
| i8::from_le_bytes(slice.try_into().unwrap()) | ||
| } | ||
|
|
||
| /// Returns short value at the given index of the object. | ||
| #[inline] | ||
| fn get_short(&self, index: usize) -> i16 { | ||
| let addr = self.get_element_offset(index, 2); | ||
| // SAFETY: addr points to valid element data (2 bytes) within the row/array region. | ||
| debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}"); | ||
| let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) }; | ||
| i16::from_le_bytes(slice.try_into().unwrap()) | ||
| } | ||
|
|
||
| /// Returns integer value at the given index of the object. | ||
| #[inline] | ||
| fn get_int(&self, index: usize) -> i32 { | ||
| let addr = self.get_element_offset(index, 4); | ||
| // SAFETY: addr points to valid element data (4 bytes) within the row/array region. | ||
| debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}"); | ||
| let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; | ||
| i32::from_le_bytes(slice.try_into().unwrap()) | ||
| } | ||
|
|
||
| /// Returns long value at the given index of the object. | ||
| #[inline] | ||
| fn get_long(&self, index: usize) -> i64 { | ||
| let addr = self.get_element_offset(index, 8); | ||
| // SAFETY: addr points to valid element data (8 bytes) within the row/array region. | ||
| debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}"); | ||
| let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; | ||
| i64::from_le_bytes(slice.try_into().unwrap()) | ||
| } | ||
|
|
||
| /// Returns float value at the given index of the object. | ||
| #[inline] | ||
| fn get_float(&self, index: usize) -> f32 { | ||
| let addr = self.get_element_offset(index, 4); | ||
| // SAFETY: addr points to valid element data (4 bytes) within the row/array region. | ||
| debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}"); | ||
| let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; | ||
| f32::from_le_bytes(slice.try_into().unwrap()) | ||
| } | ||
|
|
||
| /// Returns double value at the given index of the object. | ||
| #[inline] | ||
| fn get_double(&self, index: usize) -> f64 { | ||
| let addr = self.get_element_offset(index, 8); | ||
| // SAFETY: addr points to valid element data (8 bytes) within the row/array region. | ||
| debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}"); | ||
| let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; | ||
| f64::from_le_bytes(slice.try_into().unwrap()) | ||
| } | ||
|
|
||
| /// Returns string value at the given index of the object. | ||
| fn get_string(&self, index: usize) -> &str { | ||
| let (offset, len) = self.get_offset_and_len(index); | ||
|
|
@@ -190,29 +136,6 @@ pub trait SparkUnsafeObject { | |
| unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) } | ||
| } | ||
|
|
||
| /// Returns date value at the given index of the object. | ||
| #[inline] | ||
| fn get_date(&self, index: usize) -> i32 { | ||
| let addr = self.get_element_offset(index, 4); | ||
| // SAFETY: addr points to valid element data (4 bytes) within the row/array region. | ||
| debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}"); | ||
| let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; | ||
| i32::from_le_bytes(slice.try_into().unwrap()) | ||
| } | ||
|
|
||
| /// Returns timestamp value at the given index of the object. | ||
| #[inline] | ||
| fn get_timestamp(&self, index: usize) -> i64 { | ||
| let addr = self.get_element_offset(index, 8); | ||
| // SAFETY: addr points to valid element data (8 bytes) within the row/array region. | ||
| debug_assert!( | ||
| !addr.is_null(), | ||
| "get_timestamp: null pointer at index {index}" | ||
| ); | ||
| let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; | ||
| i64::from_le_bytes(slice.try_into().unwrap()) | ||
| } | ||
|
|
||
| /// Returns decimal value at the given index of the object. | ||
| fn get_decimal(&self, index: usize, precision: u8) -> i128 { | ||
| if precision <= MAX_LONG_DIGITS { | ||
|
|
@@ -244,6 +167,94 @@ pub trait SparkUnsafeObject { | |
| } | ||
| } | ||
|
|
||
| /// Generates primitive accessor implementations for `SparkUnsafeObject`. | ||
| /// | ||
| /// Uses `$read_method` to read typed values from raw pointers: | ||
| /// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte aligned) | ||
| /// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray) | ||
| macro_rules! impl_primitive_accessors { | ||
| ($read_method:ident) => { | ||
| #[inline] | ||
| fn get_boolean(&self, index: usize) -> bool { | ||
| let addr = self.get_element_offset(index, 1); | ||
| debug_assert!( | ||
| !addr.is_null(), | ||
| "get_boolean: null pointer at index {index}" | ||
| ); | ||
| // SAFETY: addr points to valid element data within the row/array region. | ||
| unsafe { *addr != 0 } | ||
| } | ||
|
|
||
| #[inline] | ||
| fn get_byte(&self, index: usize) -> i8 { | ||
| let addr = self.get_element_offset(index, 1); | ||
| debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}"); | ||
| // SAFETY: addr points to valid element data (1 byte) within the row/array region. | ||
| unsafe { *(addr as *const i8) } | ||
| } | ||
|
|
||
| #[inline] | ||
| fn get_short(&self, index: usize) -> i16 { | ||
| let addr = self.get_element_offset(index, 2) as *const i16; | ||
| debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}"); | ||
| // SAFETY: addr points to valid element data (2 bytes) within the row/array region. | ||
| unsafe { addr.$read_method() } | ||
| } | ||
|
|
||
| #[inline] | ||
| fn get_int(&self, index: usize) -> i32 { | ||
| let addr = self.get_element_offset(index, 4) as *const i32; | ||
| debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}"); | ||
| // SAFETY: addr points to valid element data (4 bytes) within the row/array region. | ||
| unsafe { addr.$read_method() } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Severity: medium 🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage. |
||
| } | ||
|
|
||
| #[inline] | ||
| fn get_long(&self, index: usize) -> i64 { | ||
| let addr = self.get_element_offset(index, 8) as *const i64; | ||
| debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}"); | ||
| // SAFETY: addr points to valid element data (8 bytes) within the row/array region. | ||
| unsafe { addr.$read_method() } | ||
| } | ||
|
|
||
| #[inline] | ||
| fn get_float(&self, index: usize) -> f32 { | ||
| let addr = self.get_element_offset(index, 4) as *const f32; | ||
| debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}"); | ||
| // SAFETY: addr points to valid element data (4 bytes) within the row/array region. | ||
| unsafe { addr.$read_method() } | ||
| } | ||
|
|
||
| #[inline] | ||
| fn get_double(&self, index: usize) -> f64 { | ||
| let addr = self.get_element_offset(index, 8) as *const f64; | ||
| debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}"); | ||
| // SAFETY: addr points to valid element data (8 bytes) within the row/array region. | ||
| unsafe { addr.$read_method() } | ||
| } | ||
|
|
||
| #[inline] | ||
| fn get_date(&self, index: usize) -> i32 { | ||
| let addr = self.get_element_offset(index, 4) as *const i32; | ||
| debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}"); | ||
| // SAFETY: addr points to valid element data (4 bytes) within the row/array region. | ||
| unsafe { addr.$read_method() } | ||
| } | ||
|
|
||
| #[inline] | ||
| fn get_timestamp(&self, index: usize) -> i64 { | ||
| let addr = self.get_element_offset(index, 8) as *const i64; | ||
| debug_assert!( | ||
| !addr.is_null(), | ||
| "get_timestamp: null pointer at index {index}" | ||
| ); | ||
| // SAFETY: addr points to valid element data (8 bytes) within the row/array region. | ||
| unsafe { addr.$read_method() } | ||
| } | ||
| }; | ||
| } | ||
|
Comment on lines
+175
to
+255
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The current implementation of primitive accessors has a correctness issue on big-endian systems. To ensure portability and correctness, you should explicitly handle the little-endian format. For integer types, you can use `from_le()` on the value that was read; for floating-point types, read the raw bits as an unsigned integer of the same width, convert them with `from_le()`, and then build the float with `from_bits()`:

macro_rules! impl_primitive_accessors {
($read_method:ident) => {
#[inline]
fn get_boolean(&self, index: usize) -> bool {
let addr = self.get_element_offset(index, 1);
debug_assert!(
!addr.is_null(),
"get_boolean: null pointer at index {index}"
);
// SAFETY: addr points to valid element data within the row/array region.
unsafe { *addr != 0 }
}
#[inline]
fn get_byte(&self, index: usize) -> i8 {
let addr = self.get_element_offset(index, 1);
debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}");
// SAFETY: addr points to valid element data (1 byte) within the row/array region.
unsafe { *(addr as *const i8) }
}
#[inline]
fn get_short(&self, index: usize) -> i16 {
let addr = self.get_element_offset(index, 2) as *const i16;
debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}");
// SAFETY: addr points to valid element data (2 bytes) within the row/array region.
// Spark's Unsafe format is little-endian, so we must convert from LE to native.
unsafe { i16::from_le(addr.$read_method()) }
}
#[inline]
fn get_int(&self, index: usize) -> i32 {
let addr = self.get_element_offset(index, 4) as *const i32;
debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}");
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
// Spark's Unsafe format is little-endian, so we must convert from LE to native.
unsafe { i32::from_le(addr.$read_method()) }
}
#[inline]
fn get_long(&self, index: usize) -> i64 {
let addr = self.get_element_offset(index, 8) as *const i64;
debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}");
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
// Spark's Unsafe format is little-endian, so we must convert from LE to native.
unsafe { i64::from_le(addr.$read_method()) }
}
#[inline]
fn get_float(&self, index: usize) -> f32 {
let addr = self.get_element_offset(index, 4) as *const u32;
debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}");
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
// Spark's Unsafe format is little-endian. Read as u32, convert, then transmute.
unsafe { f32::from_bits(u32::from_le(addr.$read_method())) }
}
#[inline]
fn get_double(&self, index: usize) -> f64 {
let addr = self.get_element_offset(index, 8) as *const u64;
debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}");
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
// Spark's Unsafe format is little-endian. Read as u64, convert, then transmute.
unsafe { f64::from_bits(u64::from_le(addr.$read_method())) }
}
#[inline]
fn get_date(&self, index: usize) -> i32 {
let addr = self.get_element_offset(index, 4) as *const i32;
debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}");
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
// Spark's Unsafe format is little-endian, so we must convert from LE to native.
unsafe { i32::from_le(addr.$read_method()) }
}
#[inline]
fn get_timestamp(&self, index: usize) -> i64 {
let addr = self.get_element_offset(index, 8) as *const i64;
debug_assert!(
!addr.is_null(),
"get_timestamp: null pointer at index {index}"
);
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
// Spark's Unsafe format is little-endian, so we must convert from LE to native.
unsafe { i64::from_le(addr.$read_method()) }
}
};
} |
||
| pub(crate) use impl_primitive_accessors; | ||
|
|
||
| pub struct SparkUnsafeRow { | ||
| row_addr: i64, | ||
| row_size: i32, | ||
|
|
@@ -265,6 +276,11 @@ impl SparkUnsafeObject for SparkUnsafeRow { | |
| ); | ||
| (self.row_addr + offset) as *const u8 | ||
| } | ||
|
|
||
| // SparkUnsafeRow field offsets are always 8-byte aligned: the base address is 8-byte | ||
| // aligned (JVM guarantee), bitset_width is a multiple of 8, and each field slot is | ||
| // 8 bytes. This means we can safely use aligned ptr::read() for all typed accesses. | ||
| impl_primitive_accessors!(read); | ||
|
Comment on lines
+280
to
+283
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🏁 Script executed:
#!/bin/bash
rg -n -C3 '\bpoint_to_slice\s*\(' --type rust

Repository: martin-augment/datafusion-comet
Length of output: 1955

🏁 Script executed:
#!/bin/bash
# Check the Aligned struct definition and its repr attribute
sed -n '1690,1700p' native/core/src/execution/shuffle/spark_unsafe/row.rs
# Check the benchmark context
sed -n '225,250p' native/core/benches/row_columnar.rs

Repository: martin-augment/datafusion-comet
Length of output: 1426

🏁 Script executed:
#!/bin/bash
# Find the benchmark function signature and where it's called from
rg -B5 'rows: &\[Vec<u8>\]' native/core/benches/row_columnar.rs | head -40

Repository: martin-augment/datafusion-comet
Length of output: 254

🏁 Script executed:
#!/bin/bash
# Find calls to run_benchmark to see how rows are created
rg -A10 'run_benchmark\(' native/core/benches/row_columnar.rs | head -60

Repository: martin-augment/datafusion-comet
Length of output: 1397

🏁 Script executed:
#!/bin/bash
# Find where rows Vec<Vec<u8>> is populated in the benchmark
rg -B15 'run_benchmark.*rows' native/core/benches/row_columnar.rs | grep -A15 'let mut rows\|rows\.push\|rows = ' | head -50

Repository: martin-augment/datafusion-comet
Length of output: 57

🏁 Script executed:
#!/bin/bash
# Get broader context - look at the main benchmark setup
sed -n '1,150p' native/core/benches/row_columnar.rs | tail -80

Repository: martin-augment/datafusion-comet
Length of output: 2852

🏁 Script executed:
#!/bin/bash
# Check the impl_primitive_accessors macro and how it's used with aligned reads
sed -n '280,290p' native/core/src/execution/shuffle/spark_unsafe/row.rs
# Also check the aligned read/write at lines 347-370
sed -n '347,375p' native/core/src/execution/shuffle/spark_unsafe/row.rs

Repository: martin-augment/datafusion-comet
Length of output: 2007

Enforce the 8-byte row alignment invariant at the safe entry point.

`point_to_slice` is a safe function that accepts any `&[u8]`, and a byte slice carries no 8-byte alignment guarantee, while the row accessors now perform aligned reads on the address it stores. Add an alignment assertion at the safe boundary to catch this at runtime:

Suggested guard:

  pub fn point_to_slice(&mut self, slice: &[u8]) {
+     assert!(
+         slice.is_empty()
+             || (slice.as_ptr() as usize) % std::mem::align_of::<i64>() == 0,
+         "SparkUnsafeRow::point_to_slice requires an 8-byte aligned buffer"
+     );
      self.row_addr = slice.as_ptr() as i64;
      self.row_size = slice.len() as i32;
  }

The test at line 1695 is safe because it uses a `#[repr(align(8))]` wrapper struct (`Aligned`) for its buffer.

🤖 Prompt for AI Agents |
||
| } | ||
|
|
||
| impl Default for SparkUnsafeRow { | ||
|
|
@@ -328,11 +344,13 @@ impl SparkUnsafeRow { | |
| // SAFETY: row_addr points to valid Spark UnsafeRow data with at least | ||
| // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields. | ||
| // word_offset is within the bitset region since (index >> 6) << 3 < bitset size. | ||
| // The bitset starts at row_addr (8-byte aligned) and each word is at offset 8*k, | ||
| // so word_offset is always 8-byte aligned — we can use aligned ptr::read(). | ||
| debug_assert!(self.row_addr != -1, "is_null_at: row not initialized"); | ||
| unsafe { | ||
| let mask: i64 = 1i64 << (index & 0x3f); | ||
| let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64; | ||
| let word: i64 = word_offset.read_unaligned(); | ||
| let word: i64 = word_offset.read(); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Switching to the aligned `read()` here is undefined behavior whenever `row_addr` is not 8-byte aligned.
Severity: high
Other Locations
🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage. |
||
| (word & mask) != 0 | ||
| } | ||
| } | ||
|
|
@@ -343,12 +361,13 @@ impl SparkUnsafeRow { | |
| // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields. | ||
| // word_offset is within the bitset region since (index >> 6) << 3 < bitset size. | ||
| // Writing is safe because we have mutable access and the memory is owned by the JVM. | ||
| // The bitset is always 8-byte aligned — we can use aligned ptr::read()/write(). | ||
| debug_assert!(self.row_addr != -1, "set_not_null_at: row not initialized"); | ||
| unsafe { | ||
| let mask: i64 = 1i64 << (index & 0x3f); | ||
| let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64; | ||
| let word: i64 = word_offset.read_unaligned(); | ||
| word_offset.write_unaligned(word & !mask); | ||
| let word: i64 = word_offset.read(); | ||
| word_offset.write(word & !mask); | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -1668,9 +1687,12 @@ mod test { | |
| let mut row = SparkUnsafeRow::new_with_num_fields(1); | ||
| // 8 bytes null bitset + 8 bytes field value = 16 bytes | ||
| // Set bit 0 in the null bitset to mark field 0 as null | ||
| let mut data = [0u8; 16]; | ||
| data[0] = 1; | ||
| row.point_to_slice(&data); | ||
| // Use aligned buffer to match real Spark UnsafeRow layout (8-byte aligned) | ||
| #[repr(align(8))] | ||
| struct Aligned([u8; 16]); | ||
| let mut data = Aligned([0u8; 16]); | ||
| data.0[0] = 1; | ||
| row.point_to_slice(&data.0); | ||
| append_field(&data_type, &mut struct_builder, &row, 0).expect("append field"); | ||
| struct_builder.append_null(); | ||
| let struct_array = struct_builder.finish(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aligned read on potentially unaligned nested struct from array
High Severity
The trait method `get_struct` returns a `SparkUnsafeRow` when called on a `SparkUnsafeArray`. `SparkUnsafeRow` uses aligned `ptr::read()` for all typed accesses, but the PR's own documentation states that `SparkUnsafeArray` base addresses "may be unaligned when nested within a row's variable-length region," so the resulting `SparkUnsafeRow` from `array.get_struct(idx, ...)` would inherit that misalignment. Calling `is_null_at`, `get_int`, `get_long`, etc. on such a row invokes aligned `ptr::read()` on potentially unaligned memory, which is undefined behavior. This path is exercised in `list.rs` when arrays contain struct elements.
Additional Locations (2)
native/core/src/execution/shuffle/spark_unsafe/row.rs#L282-L283
native/core/src/execution/shuffle/spark_unsafe/list.rs#L425-L439