Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion native/core/src/execution/shuffle/spark_unsafe/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ use crate::{
errors::CometError,
execution::shuffle::spark_unsafe::{
map::append_map_elements,
row::{append_field, downcast_builder_ref, SparkUnsafeObject, SparkUnsafeRow},
row::{
append_field, downcast_builder_ref, impl_primitive_accessors, SparkUnsafeObject,
SparkUnsafeRow,
},
},
};
use arrow::array::{
Expand Down Expand Up @@ -101,6 +104,10 @@ impl SparkUnsafeObject for SparkUnsafeArray {
fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8 {
(self.element_offset + (index * element_size) as i64) as *const u8
}

// SparkUnsafeArray base address may be unaligned when nested within a row's variable-length
// region, so we must use ptr::read_unaligned() for all typed accesses.
impl_primitive_accessors!(read_unaligned);
}

impl SparkUnsafeArray {
Expand Down
226 changes: 124 additions & 102 deletions native/core/src/execution/shuffle/spark_unsafe/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,32 @@ const NESTED_TYPE_BUILDER_CAPACITY: usize = 100;
/// safe to call as long as:
/// - The index is within bounds (caller's responsibility)
/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data
///
/// # Alignment
///
/// Primitive accessor methods are implemented separately for each type because they have
/// different alignment guarantees:
/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is a multiple of 8,
/// and each field slot is 8 bytes), so accessors use aligned `ptr::read()`.
/// - `SparkUnsafeArray`: The array base address may be unaligned when nested within a row's
/// variable-length region, so accessors use `ptr::read_unaligned()`.
pub trait SparkUnsafeObject {
/// Returns the address of the row.
fn get_row_addr(&self) -> i64;

/// Returns the offset of the element at the given index.
fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8;

fn get_boolean(&self, index: usize) -> bool;
fn get_byte(&self, index: usize) -> i8;
fn get_short(&self, index: usize) -> i16;
fn get_int(&self, index: usize) -> i32;
fn get_long(&self, index: usize) -> i64;
fn get_float(&self, index: usize) -> f32;
fn get_double(&self, index: usize) -> f64;
fn get_date(&self, index: usize) -> i32;
fn get_timestamp(&self, index: usize) -> i64;

/// Returns the offset and length of the element at the given index.
#[inline]
fn get_offset_and_len(&self, index: usize) -> (i32, i32) {
Expand All @@ -87,79 +106,6 @@ pub trait SparkUnsafeObject {
(offset, len)
}

/// Returns boolean value at the given index of the object.
#[inline]
fn get_boolean(&self, index: usize) -> bool {
let addr = self.get_element_offset(index, 1);
// SAFETY: addr points to valid element data within the UnsafeRow/UnsafeArray region.
// The caller ensures index is within bounds.
debug_assert!(
!addr.is_null(),
"get_boolean: null pointer at index {index}"
);
unsafe { *addr != 0 }
}

/// Returns byte value at the given index of the object.
#[inline]
fn get_byte(&self, index: usize) -> i8 {
let addr = self.get_element_offset(index, 1);
// SAFETY: addr points to valid element data (1 byte) within the row/array region.
debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) };
i8::from_le_bytes(slice.try_into().unwrap())
}

/// Returns short value at the given index of the object.
#[inline]
fn get_short(&self, index: usize) -> i16 {
let addr = self.get_element_offset(index, 2);
// SAFETY: addr points to valid element data (2 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) };
i16::from_le_bytes(slice.try_into().unwrap())
}

/// Returns integer value at the given index of the object.
#[inline]
fn get_int(&self, index: usize) -> i32 {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
i32::from_le_bytes(slice.try_into().unwrap())
}

/// Returns long value at the given index of the object.
#[inline]
fn get_long(&self, index: usize) -> i64 {
let addr = self.get_element_offset(index, 8);
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
i64::from_le_bytes(slice.try_into().unwrap())
}

/// Returns float value at the given index of the object.
#[inline]
fn get_float(&self, index: usize) -> f32 {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
f32::from_le_bytes(slice.try_into().unwrap())
}

/// Returns double value at the given index of the object.
#[inline]
fn get_double(&self, index: usize) -> f64 {
let addr = self.get_element_offset(index, 8);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aligned read on potentially unaligned nested struct from array

High Severity

The trait method get_struct returns a SparkUnsafeRow when called on a SparkUnsafeArray. Since SparkUnsafeRow uses aligned ptr::read() for all typed accesses, but the PR's own documentation states that SparkUnsafeArray base addresses "may be unaligned when nested within a row's variable-length region," the resulting SparkUnsafeRow from array.get_struct(idx, ...) would inherit that misalignment. Calling is_null_at, get_int, get_long, etc. on such a row invokes aligned ptr::read() on potentially unaligned memory, which is undefined behavior. This path is exercised in list.rs when arrays contain struct elements.

Additional Locations (2)
Fix in Cursor Fix in Web

// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
f64::from_le_bytes(slice.try_into().unwrap())
}

/// Returns string value at the given index of the object.
fn get_string(&self, index: usize) -> &str {
let (offset, len) = self.get_offset_and_len(index);
Expand Down Expand Up @@ -190,29 +136,6 @@ pub trait SparkUnsafeObject {
unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
}

/// Returns date value at the given index of the object.
#[inline]
fn get_date(&self, index: usize) -> i32 {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}");
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
i32::from_le_bytes(slice.try_into().unwrap())
}

/// Returns timestamp value at the given index of the object.
#[inline]
fn get_timestamp(&self, index: usize) -> i64 {
let addr = self.get_element_offset(index, 8);
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
debug_assert!(
!addr.is_null(),
"get_timestamp: null pointer at index {index}"
);
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
i64::from_le_bytes(slice.try_into().unwrap())
}

/// Returns decimal value at the given index of the object.
fn get_decimal(&self, index: usize, precision: u8) -> i128 {
if precision <= MAX_LONG_DIGITS {
Expand Down Expand Up @@ -244,6 +167,94 @@ pub trait SparkUnsafeObject {
}
}

/// Generates primitive accessor implementations for `SparkUnsafeObject`.
///
/// Uses `$read_method` to read typed values from raw pointers:
/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte aligned)
/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray)
macro_rules! impl_primitive_accessors {
($read_method:ident) => {
#[inline]
fn get_boolean(&self, index: usize) -> bool {
let addr = self.get_element_offset(index, 1);
debug_assert!(
!addr.is_null(),
"get_boolean: null pointer at index {index}"
);
// SAFETY: addr points to valid element data within the row/array region.
unsafe { *addr != 0 }
}

#[inline]
fn get_byte(&self, index: usize) -> i8 {
let addr = self.get_element_offset(index, 1);
debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}");
// SAFETY: addr points to valid element data (1 byte) within the row/array region.
unsafe { *(addr as *const i8) }
}

#[inline]
fn get_short(&self, index: usize) -> i16 {
let addr = self.get_element_offset(index, 2) as *const i16;
debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}");
// SAFETY: addr points to valid element data (2 bytes) within the row/array region.
unsafe { addr.$read_method() }
}

#[inline]
fn get_int(&self, index: usize) -> i32 {
let addr = self.get_element_offset(index, 4) as *const i32;
debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}");
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
unsafe { addr.$read_method() }
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ptr::$read_method() reads in native endianness, but Spark UnsafeRow/UnsafeArray encodes primitives as little-endian (e.g., SparkUnsafeArray::new uses from_le_bytes). If this crate ever targets big-endian platforms, these accessors would decode incorrectly unless you explicitly guard/document little-endian-only support.

Severity: medium

Fix This in Augment

🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage.

}

#[inline]
fn get_long(&self, index: usize) -> i64 {
let addr = self.get_element_offset(index, 8) as *const i64;
debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}");
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
unsafe { addr.$read_method() }
}

#[inline]
fn get_float(&self, index: usize) -> f32 {
let addr = self.get_element_offset(index, 4) as *const f32;
debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}");
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
unsafe { addr.$read_method() }
}

#[inline]
fn get_double(&self, index: usize) -> f64 {
let addr = self.get_element_offset(index, 8) as *const f64;
debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}");
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
unsafe { addr.$read_method() }
}

#[inline]
fn get_date(&self, index: usize) -> i32 {
let addr = self.get_element_offset(index, 4) as *const i32;
debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}");
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
unsafe { addr.$read_method() }
}

#[inline]
fn get_timestamp(&self, index: usize) -> i64 {
let addr = self.get_element_offset(index, 8) as *const i64;
debug_assert!(
!addr.is_null(),
"get_timestamp: null pointer at index {index}"
);
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
unsafe { addr.$read_method() }
}
};
}
Comment on lines +175 to +255
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current implementation of primitive accessors has a correctness issue on big-endian systems. ptr::read() and ptr::read_unaligned() interpret bytes using the native endianness of the CPU. However, Spark's Unsafe format is little-endian. This will lead to incorrect data being read on big-endian architectures, causing silent data corruption.

To ensure portability and correctness, you should explicitly handle the little-endian format. For integer types, you can use T::from_le(). For floating-point types, you should read the data as an integer of the same size, convert its endianness, and then convert the bits to a float (e.g., f32::from_bits(u32::from_le(...))).

macro_rules! impl_primitive_accessors {
    ($read_method:ident) => {
        #[inline]
        fn get_boolean(&self, index: usize) -> bool {
            let addr = self.get_element_offset(index, 1);
            debug_assert!(
                !addr.is_null(),
                "get_boolean: null pointer at index {index}"
            );
            // SAFETY: addr points to valid element data within the row/array region.
            unsafe { *addr != 0 }
        }

        #[inline]
        fn get_byte(&self, index: usize) -> i8 {
            let addr = self.get_element_offset(index, 1);
            debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}");
            // SAFETY: addr points to valid element data (1 byte) within the row/array region.
            unsafe { *(addr as *const i8) }
        }

        #[inline]
        fn get_short(&self, index: usize) -> i16 {
            let addr = self.get_element_offset(index, 2) as *const i16;
            debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}");
            // SAFETY: addr points to valid element data (2 bytes) within the row/array region.
            // Spark's Unsafe format is little-endian, so we must convert from LE to native.
            unsafe { i16::from_le(addr.$read_method()) }
        }

        #[inline]
        fn get_int(&self, index: usize) -> i32 {
            let addr = self.get_element_offset(index, 4) as *const i32;
            debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}");
            // SAFETY: addr points to valid element data (4 bytes) within the row/array region.
            // Spark's Unsafe format is little-endian, so we must convert from LE to native.
            unsafe { i32::from_le(addr.$read_method()) }
        }

        #[inline]
        fn get_long(&self, index: usize) -> i64 {
            let addr = self.get_element_offset(index, 8) as *const i64;
            debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}");
            // SAFETY: addr points to valid element data (8 bytes) within the row/array region.
            // Spark's Unsafe format is little-endian, so we must convert from LE to native.
            unsafe { i64::from_le(addr.$read_method()) }
        }

        #[inline]
        fn get_float(&self, index: usize) -> f32 {
            let addr = self.get_element_offset(index, 4) as *const u32;
            debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}");
            // SAFETY: addr points to valid element data (4 bytes) within the row/array region.
            // Spark's Unsafe format is little-endian. Read as u32, convert, then transmute.
            unsafe { f32::from_bits(u32::from_le(addr.$read_method())) }
        }

        #[inline]
        fn get_double(&self, index: usize) -> f64 {
            let addr = self.get_element_offset(index, 8) as *const u64;
            debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}");
            // SAFETY: addr points to valid element data (8 bytes) within the row/array region.
            // Spark's Unsafe format is little-endian. Read as u64, convert, then transmute.
            unsafe { f64::from_bits(u64::from_le(addr.$read_method())) }
        }

        #[inline]
        fn get_date(&self, index: usize) -> i32 {
            let addr = self.get_element_offset(index, 4) as *const i32;
            debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}");
            // SAFETY: addr points to valid element data (4 bytes) within the row/array region.
            // Spark's Unsafe format is little-endian, so we must convert from LE to native.
            unsafe { i32::from_le(addr.$read_method()) }
        }

        #[inline]
        fn get_timestamp(&self, index: usize) -> i64 {
            let addr = self.get_element_offset(index, 8) as *const i64;
            debug_assert!(
                !addr.is_null(),
                "get_timestamp: null pointer at index {index}"
            );
            // SAFETY: addr points to valid element data (8 bytes) within the row/array region.
            // Spark's Unsafe format is little-endian, so we must convert from LE to native.
            unsafe { i64::from_le(addr.$read_method()) }
        }
    };
}

pub(crate) use impl_primitive_accessors;

pub struct SparkUnsafeRow {
row_addr: i64,
row_size: i32,
Expand All @@ -265,6 +276,11 @@ impl SparkUnsafeObject for SparkUnsafeRow {
);
(self.row_addr + offset) as *const u8
}

// SparkUnsafeRow field offsets are always 8-byte aligned: the base address is 8-byte
// aligned (JVM guarantee), bitset_width is a multiple of 8, and each field slot is
// 8 bytes. This means we can safely use aligned ptr::read() for all typed accesses.
impl_primitive_accessors!(read);
Comment on lines +280 to +283
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
rg -n -C3 '\bpoint_to_slice\s*\(' --type rust

Repository: martin-augment/datafusion-comet

Length of output: 1955


🏁 Script executed:

#!/bin/bash
# Check the Aligned struct definition and its repr attribute
sed -n '1690,1700p' native/core/src/execution/shuffle/spark_unsafe/row.rs

# Check the benchmark context
sed -n '225,250p' native/core/benches/row_columnar.rs

Repository: martin-augment/datafusion-comet

Length of output: 1426


🏁 Script executed:

#!/bin/bash
# Find the benchmark function signature and where it's called from
rg -B5 'rows: &\[Vec<u8>\]' native/core/benches/row_columnar.rs | head -40

Repository: martin-augment/datafusion-comet

Length of output: 254


🏁 Script executed:

#!/bin/bash
# Find calls to run_benchmark to see how rows are created
rg -A10 'run_benchmark\(' native/core/benches/row_columnar.rs | head -60

Repository: martin-augment/datafusion-comet

Length of output: 1397


🏁 Script executed:

#!/bin/bash
# Find where rows Vec<Vec<u8>> is populated in the benchmark
rg -B15 'run_benchmark.*rows' native/core/benches/row_columnar.rs | grep -A15 'let mut rows\|rows\.push\|rows = ' | head -50

Repository: martin-augment/datafusion-comet

Length of output: 57


🏁 Script executed:

#!/bin/bash
# Get broader context - look at the main benchmark setup
sed -n '1,150p' native/core/benches/row_columnar.rs | tail -80

Repository: martin-augment/datafusion-comet

Length of output: 2852


🏁 Script executed:

#!/bin/bash
# Check the impl_primitive_accessors macro and how it's used with aligned reads
sed -n '280,290p' native/core/src/execution/shuffle/spark_unsafe/row.rs

# Also check the aligned read/write at lines 347-370
sed -n '347,375p' native/core/src/execution/shuffle/spark_unsafe/row.rs

Repository: martin-augment/datafusion-comet

Length of output: 2007


Enforce the 8-byte row alignment invariant at the safe entry point.

impl_primitive_accessors!(read) and the aligned ptr::read()/ptr::write() calls in is_null_at() and set_not_null_at() require row_addr to be 8-byte aligned, but SparkUnsafeRow::point_to_slice(&[u8]) accepts unaligned slices. The benchmark passes Vec<u8> buffers (alignment = 1) to point_to_slice(), which later causes undefined behavior when aligned pointer operations dereference misaligned addresses.

Add an alignment assertion at the safe boundary to catch this at runtime:

Suggested guard
 pub fn point_to_slice(&mut self, slice: &[u8]) {
+    assert!(
+        slice.is_empty()
+            || (slice.as_ptr() as usize) % std::mem::align_of::<i64>() == 0,
+        "SparkUnsafeRow::point_to_slice requires an 8-byte aligned buffer"
+    );
     self.row_addr = slice.as_ptr() as i64;
     self.row_size = slice.len() as i32;
 }

The test at line 1695 is safe because it uses #[repr(align(8))], but the benchmark at line 236 passes unaligned Vec<u8> slices and will panic or exhibit undefined behavior without this guard.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@native/core/src/execution/shuffle/spark_unsafe/row.rs` around lines 280 -
283, The safe constructor SparkUnsafeRow::point_to_slice(&[u8]) currently
accepts unaligned slices but downstream helpers generated by
impl_primitive_accessors!(read) and methods is_null_at() / set_not_null_at() use
aligned ptr::read()/write(), so add a runtime alignment check in point_to_slice
that asserts the row base pointer is 8-byte aligned (e.g., check (row_addr as
usize) % 8 == 0 or use is_aligned_to(8)) and panic with a clear message if not;
this enforces the 8-byte alignment invariant at the safe entry point and
prevents undefined behavior when benchmarks pass unaligned Vec<u8> buffers.

}

impl Default for SparkUnsafeRow {
Expand Down Expand Up @@ -328,11 +344,13 @@ impl SparkUnsafeRow {
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
// The bitset starts at row_addr (8-byte aligned) and each word is at offset 8*k,
// so word_offset is always 8-byte aligned — we can use aligned ptr::read().
debug_assert!(self.row_addr != -1, "is_null_at: row not initialized");
unsafe {
let mask: i64 = 1i64 << (index & 0x3f);
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64;
let word: i64 = word_offset.read_unaligned();
let word: i64 = word_offset.read();
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switching to word_offset.read() assumes row_addr is 8-byte aligned, but point_to_slice can point at arbitrary &[u8] (and Vec<u8> is not guaranteed 8-byte aligned by its type). If an unaligned slice is ever passed in, this becomes UB; it may be worth enforcing/documenting the alignment precondition at the API boundary.

Severity: high

Other Locations
  • native/core/src/execution/shuffle/spark_unsafe/row.rs:369
  • native/core/src/execution/shuffle/spark_unsafe/row.rs:370

Fix This in Augment

🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage.

(word & mask) != 0
}
}
Expand All @@ -343,12 +361,13 @@ impl SparkUnsafeRow {
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
// Writing is safe because we have mutable access and the memory is owned by the JVM.
// The bitset is always 8-byte aligned — we can use aligned ptr::read()/write().
debug_assert!(self.row_addr != -1, "set_not_null_at: row not initialized");
unsafe {
let mask: i64 = 1i64 << (index & 0x3f);
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64;
let word: i64 = word_offset.read_unaligned();
word_offset.write_unaligned(word & !mask);
let word: i64 = word_offset.read();
word_offset.write(word & !mask);
}
}
}
Expand Down Expand Up @@ -1668,9 +1687,12 @@ mod test {
let mut row = SparkUnsafeRow::new_with_num_fields(1);
// 8 bytes null bitset + 8 bytes field value = 16 bytes
// Set bit 0 in the null bitset to mark field 0 as null
let mut data = [0u8; 16];
data[0] = 1;
row.point_to_slice(&data);
// Use aligned buffer to match real Spark UnsafeRow layout (8-byte aligned)
#[repr(align(8))]
struct Aligned([u8; 16]);
let mut data = Aligned([0u8; 16]);
data.0[0] = 1;
row.point_to_slice(&data.0);
append_field(&data_type, &mut struct_builder, &row, 0).expect("append field");
struct_builder.append_null();
let struct_array = struct_builder.finish();
Expand Down
Loading