Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/shims/x86/avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::callconv::FnAbi;

use super::{permute, pmaddbw, psadbw, pshufb};
use super::{permute, pmaddbw, psadbw, pshufb, vpdpbusd};
use crate::*;

impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
Expand Down Expand Up @@ -109,6 +109,18 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {

pshufb(this, left, right, dest)?;
}

// Used to implement the _mm512_dpbusd_epi32 function.
"vpdpbusd.512" | "vpdpbusd.256" | "vpdpbusd.128" => {
this.expect_target_feature_for_intrinsic(link_name, "avx512vnni")?;
if matches!(unprefixed_name, "vpdpbusd.128" | "vpdpbusd.256") {
this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?;
}

let [src, a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;

vpdpbusd(this, src, a, b, dest)?;
}
_ => return interp_ok(EmulateItemResult::NotSupported),
}
interp_ok(EmulateItemResult::NeedsReturn)
Expand Down
47 changes: 47 additions & 0 deletions src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,53 @@ fn pshufb<'tcx>(
interp_ok(())
}

/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding signed
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding signed
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in `a` with corresponding signed

same for all the other references to variable names in the doc comment

/// 8-bit integers in b, producing 4 intermediate signed 16-bit results. Sum these 4 results with
/// the corresponding 32-bit integer in src, and store the packed 32-bit results in dst.
///
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_dpbusd_epi32>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_dpbusd_epi32>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_dpbusd_epi32>
fn vpdpbusd<'tcx>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is only used in avx512, please move the function to that file.

ecx: &mut crate::MiriInterpCx<'tcx>,
src: &OpTy<'tcx>,
a: &OpTy<'tcx>,
b: &OpTy<'tcx>,
dest: &MPlaceTy<'tcx>,
) -> InterpResult<'tcx, ()> {
let (src, src_len) = ecx.project_to_simd(src)?;
let (a, a_len) = ecx.project_to_simd(a)?;
let (b, b_len) = ecx.project_to_simd(b)?;
let (dest, dest_len) = ecx.project_to_simd(dest)?;

// fn vpdpbusd(src: i32x16, a: i32x16, b: i32x16) -> i32x16;
// fn vpdpbusd256(src: i32x8, a: i32x8, b: i32x8) -> i32x8;
// fn vpdpbusd128(src: i32x4, a: i32x4, b: i32x4) -> i32x4;
assert_eq!(dest_len, src_len);
assert_eq!(dest_len, a_len);
assert_eq!(dest_len, b_len);

for i in 0..dest_len {
let src = ecx.read_scalar(&ecx.project_index(&src, i)?)?.to_i32()?;
let a = ecx.read_scalar(&ecx.project_index(&a, i)?)?.to_u32()?;
let b = ecx.read_scalar(&ecx.project_index(&b, i)?)?.to_u32()?;
let dest = ecx.project_index(&dest, i)?;

let [a1, a2, a3, a4] = a.to_le_bytes();
let [b1, b2, b3, b4] = b.to_le_bytes();

let intermediate = i32::from(i16::from(a1).wrapping_mul(i16::from(b1.cast_signed())))
.wrapping_add(i32::from(i16::from(a2).wrapping_mul(i16::from(b2.cast_signed()))))
.wrapping_add(i32::from(i16::from(a3).wrapping_mul(i16::from(b3.cast_signed()))))
.wrapping_add(i32::from(i16::from(a4).wrapping_mul(i16::from(b4.cast_signed()))));

Comment on lines +1238 to +1242
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should find a way to make this more readable...
As a start, why are you mixing i16 and i32? And why wrapping_mul? multiplying two i8 as an i16 cannot overflow, right? Same for add. If things can never overflow, please use the strict operations.

Also, I think it would make sense to let-bind the 4 multiplications. Maybe that could even be written in a loop, e.g. via from_fn?

let res = Scalar::from_i32(intermediate.wrapping_add(src));
ecx.write_scalar(res, &dest)?;
}

interp_ok(())
}

/// Packs two N-bit integer vectors to a single N/2-bit integers.
///
/// The conversion from N-bit to N/2-bit should be provided by `f`.
Expand Down
94 changes: 93 additions & 1 deletion tests/pass/shims/x86/intrinsics-x86-avx512.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// We're testing x86 target specific features
//@only-target: x86_64 i686
//@compile-flags: -C target-feature=+avx512f,+avx512vl,+avx512bitalg,+avx512vpopcntdq
//@compile-flags: -C target-feature=+avx512f,+avx512vl,+avx512bitalg,+avx512vpopcntdq,+avx512vnni

#[cfg(target_arch = "x86")]
use std::arch::x86::*;
Expand All @@ -13,12 +13,14 @@ fn main() {
assert!(is_x86_feature_detected!("avx512vl"));
assert!(is_x86_feature_detected!("avx512bitalg"));
assert!(is_x86_feature_detected!("avx512vpopcntdq"));
assert!(is_x86_feature_detected!("avx512vnni"));

unsafe {
test_avx512();
test_avx512bitalg();
test_avx512vpopcntdq();
test_avx512ternarylogic();
test_avx512vnni();
}
}

Expand Down Expand Up @@ -411,6 +413,96 @@ unsafe fn test_avx512ternarylogic() {
test_mm_ternarylogic_epi32();
}

#[target_feature(enable = "avx512vnni")]
unsafe fn test_avx512vnni() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mentioned that this is aiming to hit a bunch of overflow and truncation cases. Please add comments pointing that out.

#[target_feature(enable = "avx512vnni")]
unsafe fn test_mm512_dpbusd_epi32() {
const SRC: [i32; 16] = [
1,
0,
0,
7,
i32::MAX - 10,
i32::MIN + 10,
12345,
-9876,
0x01020304,
-1,
42,
0,
1_000_000_000,
-1_000_000_000,
17,
-17,
];

const A: [i32; 16] = [
0x01010101,
0xFFFF_FFFFu32 as i32,
0xFFFF_FFFFu32 as i32,
0x02_80_01_FF,
0xFFFF_FFFFu32 as i32,
0xFFFF_FFFFu32 as i32,
0x00_FF_00_FF,
0x7F_80_FF_01,
0x10_20_30_40,
0xDE_AD_BE_EFu32 as i32,
0x00_00_00_FF,
0x12_34_56_78,
0xFF_00_FF_00u32 as i32,
0x01_02_03_04,
0xAA_55_AA_55u32 as i32,
0x11_22_33_44,
];

const B: [i32; 16] = [
0x01010101,
0x7F7F_7F7F,
0x8080_8080u32 as i32,
0xFF_01_80_7Fu32 as i32,
0x7F7F_7F7F,
0x8080_8080u32 as i32,
0x01_FF_01_FF,
0x80_7F_00_FFu32 as i32,
0x7F_01_FF_80u32 as i32,
0x01_02_03_04,
0xFF_FF_FF_FFu32 as i32,
0x80_00_7F_FFu32 as i32,
0x7F_80_7F_80u32 as i32,
0x40_C0_20_E0u32 as i32,
0x00_01_02_03,
0x7F_7E_80_81u32 as i32,
];

const DST: [i32; 16] = [
5,
129540,
-130560,
32390,
-2147354119,
2147353098,
11835,
-9877,
16902884,
2093,
-213,
8498,
1000064770,
-1000000096,
697,
-8738,
];

let src = _mm512_loadu_si512(SRC.as_ptr().cast::<__m512i>());
let a = _mm512_loadu_si512(A.as_ptr().cast::<__m512i>());
let b = _mm512_loadu_si512(B.as_ptr().cast::<__m512i>());
let dst = _mm512_loadu_si512(DST.as_ptr().cast::<__m512i>());

assert_eq_m512i(_mm512_dpbusd_epi32(src, a, b), dst);
}
test_mm512_dpbusd_epi32();
}

#[track_caller]
unsafe fn assert_eq_m512i(a: __m512i, b: __m512i) {
assert_eq!(transmute::<_, [i32; 16]>(a), transmute::<_, [i32; 16]>(b))
Expand Down