Skip to content

Commit 1b3d1fa

Browse files
committed
Inline both branches and spare_capacity_mut
Signed-off-by: Adam Gutglick <adam@spiraldb.com>
1 parent 089c9a9 commit 1b3d1fa

File tree

1 file changed

+45
-46
lines changed

1 file changed

+45
-46
lines changed

arrow-select/src/take.rs

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -490,52 +490,12 @@ fn take_boolean<IndexType: ArrowPrimitiveType>(
490490
BooleanArray::new(val_buf, null_buf)
491491
}
492492

493-
/// Copies byte ranges from `src` into a new contiguous buffer.
494-
///
495-
/// # Safety
496-
/// Each `(start, end)` in `ranges` must be in-bounds of `src`, and
497-
/// `capacity` must equal the total bytes across all ranges.
498-
unsafe fn copy_byte_ranges(
499-
src: &[u8],
500-
ranges: &[(usize, usize)],
501-
capacity: usize,
502-
values: &mut Vec<u8>,
503-
) {
504-
values.reserve(capacity);
505-
debug_assert_eq!(
506-
ranges.iter().map(|(s, e)| e - s).sum::<usize>(),
507-
capacity,
508-
"capacity must equal total bytes across all ranges"
509-
);
510-
let src_len = src.len();
511-
let src = src.as_ptr();
512-
let mut dst = values.as_mut_ptr();
513-
for &(start, end) in ranges {
514-
debug_assert!(start <= end, "invalid range: start ({start}) > end ({end})");
515-
debug_assert!(
516-
end <= src_len,
517-
"range end ({end}) out of bounds (src len {src_len})"
518-
);
519-
let len = end - start;
520-
// SAFETY: caller guarantees each (start, end) is in-bounds of `src`.
521-
// `dst` advances within the `capacity` bytes we allocated.
522-
// The regions don't overlap (src is input, dst is a fresh allocation).
523-
unsafe {
524-
std::ptr::copy_nonoverlapping(src.add(start), dst, len);
525-
dst = dst.add(len);
526-
}
527-
}
528-
// SAFETY: caller guarantees `capacity` == total bytes across all ranges,
529-
// so the loop above wrote exactly `capacity` bytes.
530-
unsafe { values.set_len(capacity) };
531-
}
532-
533493
/// `take` implementation for string arrays
534494
fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
535495
array: &GenericByteArray<T>,
536496
indices: &PrimitiveArray<IndexType>,
537497
) -> Result<GenericByteArray<T>, ArrowError> {
538-
let mut values = Vec::new();
498+
let mut values: Vec<u8> = Vec::new();
539499
let mut offsets = Vec::with_capacity(indices.len() + 1);
540500
offsets.push(T::Offset::default());
541501

@@ -560,15 +520,21 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
560520

561521
values.reserve(capacity);
562522

563-
let mut dst = values.as_mut_ptr();
523+
let dst = values.spare_capacity_mut();
524+
debug_assert!(dst.len() >= capacity);
525+
let mut offset = 0;
564526

565527
for index in indices.values() {
566528
// SAFETY: in-bounds proven by the first loop's bounds-checked offset access.
567-
// dst stays within reserved capacity computed from the same indices.
529+
// dst asserted above to include the required capacity.
568530
unsafe {
569531
let data: &[u8] = array.value_unchecked(index.as_usize()).as_ref();
570-
std::ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len());
571-
dst = dst.add(data.len());
532+
std::ptr::copy_nonoverlapping(
533+
data.as_ptr(),
534+
dst[offset..].as_mut_ptr().cast::<u8>(),
535+
data.len(),
536+
);
537+
offset += data.len();
572538
}
573539
}
574540

@@ -601,6 +567,9 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
601567
capacity += end - start;
602568
offsets[i + 1] = T::Offset::from_usize(capacity)
603569
.ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
570+
571+
debug_assert!(end >= start, "invalid range: start ({start}) > end ({end})");
572+
604573
ranges.push((start, end));
605574
last_filled = i + 1;
606575
}
@@ -610,7 +579,37 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
610579
.ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
611580
offsets[last_filled + 1..].fill(final_offset);
612581
// Pass 2: copy byte data for all collected ranges.
613-
unsafe { copy_byte_ranges(array.value_data(), &ranges, capacity, &mut values) };
582+
values.reserve(capacity);
583+
debug_assert_eq!(
584+
ranges.iter().map(|(s, e)| e - s).sum::<usize>(),
585+
capacity,
586+
"capacity must equal total bytes across all ranges"
587+
);
588+
589+
let src = array.value_data();
590+
let src = src.as_ptr();
591+
let dst = values.spare_capacity_mut();
592+
debug_assert!(dst.len() >= capacity);
593+
594+
let mut offset = 0;
595+
596+
for (start, end) in ranges.into_iter() {
597+
let value_len = end - start;
598+
// SAFETY: caller guarantees each (start, end) is in-bounds of `src`.
599+
// `dst` asserted above to include the required capacity.
600+
// The regions don't overlap (src is input, dst is a fresh allocation).
601+
unsafe {
602+
std::ptr::copy_nonoverlapping(
603+
src.add(start),
604+
dst[offset..].as_mut_ptr().cast::<u8>(),
605+
value_len,
606+
);
607+
offset += value_len;
608+
}
609+
}
610+
// SAFETY: caller guarantees `capacity` == total bytes across all ranges,
611+
// so the loop above wrote exactly `capacity` bytes.
612+
unsafe { values.set_len(capacity) };
614613
}
615614
};
616615

0 commit comments

Comments
 (0)