CPU solver: AVX-512 round-0 hashing + prefetch the emit gather
Branch the CPU solver onto modern x86 extensions, runtime-dispatched with the existing AVX2/scalar fallbacks:

- `BatchHasher::hash8` — an 8-lane AVX-512 BLAKE2b final-block compression (native `_mm512_ror_epi64` rotates), falling back to two AVX2 `hash4`s or scalar. Round 0 now hashes eight g-values per chunk. round0-hash time drops ~1.45x on AVX-512 CPUs (≈225→155 ms here, AMD Zen4).
- `emit_bucket` software-prefetches each collision group's randomly-gathered member slots (the ~1 GB slot arena is the round's cache-miss bottleneck), shaving a few percent off the dominant emit phase.

Controlled A/B on this Zen4 box (same thermal state): ~4-5% faster overall. The collision rounds are memory-bandwidth bound, so SIMD width is not the limiter — the modern-ISA win is modest by nature.

`EQ_NO_AVX512` / `EQ_NO_PREFETCH` opt out per-CPU (e.g. on parts where AVX-512 downclocks) and back the A/B harness.

`hash8` is validated against the scalar reference (`batch_matches_reference`), and full solves still find valid solutions in every dispatch configuration.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -215,6 +215,92 @@ impl BatchHasher {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Hash eight consecutive indices `g0..g0+8`, writing each 48-byte digest
|
||||||
|
/// into `out[0..8]`. Uses AVX-512 (one 8-wide BLAKE2b compression) when
|
||||||
|
/// available, else two AVX2 `hash4`s. Modern AMD (Zen4+) and Intel CPUs run
|
||||||
|
/// AVX-512 without the clock penalty older Intel parts had.
|
||||||
|
#[inline]
|
||||||
|
pub fn hash8(&self, g0: u32, out: &mut [[u8; HASH_OUTPUT]; 8]) {
|
||||||
|
#[cfg(target_arch = "x86_64")]
|
||||||
|
{
|
||||||
|
if is_x86_feature_detected!("avx512f") {
|
||||||
|
unsafe { self.hash8_avx512(g0, out) };
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback: two 4-lane batches into the two halves of `out`.
|
||||||
|
let (lo, hi) = out.split_at_mut(4);
|
||||||
|
self.hash4(g0, (&mut lo[..4]).try_into().unwrap());
|
||||||
|
self.hash4(g0 + 4, (&mut hi[..4]).try_into().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Eight-lane BLAKE2b final-block compression (AVX-512). Same structure as
|
||||||
|
/// [`Self::hash4_avx2`] but 512-bit lanes and native 64-bit rotates.
|
||||||
|
#[cfg(target_arch = "x86_64")]
|
||||||
|
#[target_feature(enable = "avx512f")]
|
||||||
|
unsafe fn hash8_avx512(&self, g0: u32, out: &mut [[u8; HASH_OUTPUT]; 8]) {
|
||||||
|
use core::arch::x86_64::*;
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
unsafe fn g8(
|
||||||
|
v: &mut [__m512i; 16],
|
||||||
|
a: usize, b: usize, c: usize, d: usize,
|
||||||
|
x: __m512i, y: __m512i,
|
||||||
|
) {
|
||||||
|
v[a] = _mm512_add_epi64(_mm512_add_epi64(v[a], v[b]), x);
|
||||||
|
v[d] = _mm512_ror_epi64::<32>(_mm512_xor_si512(v[d], v[a]));
|
||||||
|
v[c] = _mm512_add_epi64(v[c], v[d]);
|
||||||
|
v[b] = _mm512_ror_epi64::<24>(_mm512_xor_si512(v[b], v[c]));
|
||||||
|
v[a] = _mm512_add_epi64(_mm512_add_epi64(v[a], v[b]), y);
|
||||||
|
v[d] = _mm512_ror_epi64::<16>(_mm512_xor_si512(v[d], v[a]));
|
||||||
|
v[c] = _mm512_add_epi64(v[c], v[d]);
|
||||||
|
v[b] = _mm512_ror_epi64::<63>(_mm512_xor_si512(v[b], v[c]));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only m0 and m1 are nonzero; m1's high 32 bits hold the per-lane `g`.
|
||||||
|
let tail0 = u64::from_le_bytes(self.tail[0..8].try_into().unwrap());
|
||||||
|
let tail_hi = u32::from_le_bytes(self.tail[8..12].try_into().unwrap()) as u64;
|
||||||
|
let m1 = |g: u32| (tail_hi | ((g as u64) << 32)) as i64;
|
||||||
|
let mut m = [_mm512_setzero_si512(); 16];
|
||||||
|
m[0] = _mm512_set1_epi64(tail0 as i64);
|
||||||
|
m[1] = _mm512_set_epi64(
|
||||||
|
m1(g0 + 7), m1(g0 + 6), m1(g0 + 5), m1(g0 + 4),
|
||||||
|
m1(g0 + 3), m1(g0 + 2), m1(g0 + 1), m1(g0),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut v = [_mm512_setzero_si512(); 16];
|
||||||
|
for i in 0..8 {
|
||||||
|
v[i] = _mm512_set1_epi64(self.mid[i] as i64);
|
||||||
|
v[i + 8] = _mm512_set1_epi64(IV[i] as i64);
|
||||||
|
}
|
||||||
|
v[12] = _mm512_xor_si512(v[12], _mm512_set1_epi64(FINAL_COUNT as i64));
|
||||||
|
v[14] = _mm512_xor_si512(v[14], _mm512_set1_epi64(-1)); // last-block flag
|
||||||
|
|
||||||
|
for s in &SIGMA {
|
||||||
|
g8(&mut v, 0, 4, 8, 12, m[s[0]], m[s[1]]);
|
||||||
|
g8(&mut v, 1, 5, 9, 13, m[s[2]], m[s[3]]);
|
||||||
|
g8(&mut v, 2, 6, 10, 14, m[s[4]], m[s[5]]);
|
||||||
|
g8(&mut v, 3, 7, 11, 15, m[s[6]], m[s[7]]);
|
||||||
|
g8(&mut v, 0, 5, 10, 15, m[s[8]], m[s[9]]);
|
||||||
|
g8(&mut v, 1, 6, 11, 12, m[s[10]], m[s[11]]);
|
||||||
|
g8(&mut v, 2, 7, 8, 13, m[s[12]], m[s[13]]);
|
||||||
|
g8(&mut v, 3, 4, 9, 14, m[s[14]], m[s[15]]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// h[i] = mid[i] ^ v[i] ^ v[i+8]; first HASH_OUTPUT/8 words per lane, LE.
|
||||||
|
let mut tmp = [0u64; 8];
|
||||||
|
for i in 0..HASH_OUTPUT / 8 {
|
||||||
|
let o = _mm512_xor_si512(
|
||||||
|
_mm512_xor_si512(_mm512_set1_epi64(self.mid[i] as i64), v[i]),
|
||||||
|
v[i + 8],
|
||||||
|
);
|
||||||
|
_mm512_storeu_si512(tmp.as_mut_ptr() as *mut _, o);
|
||||||
|
for l in 0..8 {
|
||||||
|
out[l][i * 8..i * 8 + 8].copy_from_slice(&tmp[l].to_le_bytes());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
#[target_feature(enable = "avx2")]
|
#[target_feature(enable = "avx2")]
|
||||||
unsafe fn hash4_avx2(&self, g0: u32, out: &mut [[u8; HASH_OUTPUT]; 4]) {
|
unsafe fn hash4_avx2(&self, g0: u32, out: &mut [[u8; HASH_OUTPUT]; 4]) {
|
||||||
@@ -316,5 +402,14 @@ mod tests {
|
|||||||
assert_eq!(out[l], generate_hash(&base, base_g + l as u32), "hash4 mismatch at g={}", base_g + l as u32);
|
assert_eq!(out[l], generate_hash(&base, base_g + l as u32), "hash4 mismatch at g={}", base_g + l as u32);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hash8 (AVX-512 or hash4 fallback) must match for several batches.
|
||||||
|
for base_g in [0u32, 8, 64, 1_000_000] {
|
||||||
|
let mut out = [[0u8; HASH_OUTPUT]; 8];
|
||||||
|
hasher.hash8(base_g, &mut out);
|
||||||
|
for l in 0..8 {
|
||||||
|
assert_eq!(out[l], generate_hash(&base, base_g + l as u32), "hash8 mismatch at g={}", base_g + l as u32);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+44
-13
@@ -438,7 +438,7 @@ fn low_group(
|
|||||||
/// Monomorphised over the XOR producer so the AVX2 intrinsics inline cleanly
|
/// Monomorphised over the XOR producer so the AVX2 intrinsics inline cleanly
|
||||||
/// inside a `target_feature` wrapper while sharing one source of truth.
|
/// inside a `target_feature` wrapper while sharing one source of truth.
|
||||||
macro_rules! emit_bucket_body {
|
macro_rules! emit_bucket_body {
|
||||||
($keys_sorted:expr, $slots:expr, $sorted:expr, $kout:expr, $sout:expr, $pout:expr, $clamp:expr, $w_in:expr, $w_out:expr, $xor:path) => {{
|
($keys_sorted:expr, $slots:expr, $sorted:expr, $kout:expr, $sout:expr, $pout:expr, $clamp:expr, $w_in:expr, $w_out:expr, $prefetch:expr, $xor:path) => {{
|
||||||
let s = $sorted;
|
let s = $sorted;
|
||||||
let ks = $keys_sorted; // leading keys in `s`-order; group walk streams it
|
let ks = $keys_sorted; // leading keys in `s`-order; group walk streams it
|
||||||
let m = s.len();
|
let m = s.len();
|
||||||
@@ -453,6 +453,18 @@ macro_rules! emit_bucket_body {
|
|||||||
j += 1;
|
j += 1;
|
||||||
}
|
}
|
||||||
let hi = j.min(i.saturating_add($clamp));
|
let hi = j.min(i.saturating_add($clamp));
|
||||||
|
// The pair loops below read `slots[s[a]]` / `slots[s[b]]` at random
|
||||||
|
// global positions in the ~1 GB slot arena — each a likely cache
|
||||||
|
// miss. Prefetch this group's member slots up front so the misses
|
||||||
|
// overlap with the (L1-resident) pair XORs that follow.
|
||||||
|
#[cfg(target_arch = "x86_64")]
|
||||||
|
if $prefetch {
|
||||||
|
for a in i..hi {
|
||||||
|
core::arch::x86_64::_mm_prefetch::<{ core::arch::x86_64::_MM_HINT_T0 }>(
|
||||||
|
$slots.as_ptr().add(s[a] as usize * w_in) as *const i8,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
for a in i..hi {
|
for a in i..hi {
|
||||||
let l = s[a] as usize;
|
let l = s[a] as usize;
|
||||||
for b in (a + 1)..hi {
|
for b in (a + 1)..hi {
|
||||||
@@ -484,8 +496,9 @@ unsafe fn emit_bucket_scalar(
|
|||||||
clamp: usize,
|
clamp: usize,
|
||||||
w_in: usize,
|
w_in: usize,
|
||||||
w_out: usize,
|
w_out: usize,
|
||||||
|
prefetch: bool,
|
||||||
) -> usize {
|
) -> usize {
|
||||||
emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_scalar)
|
emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, prefetch, xor_child_scalar)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
@@ -500,12 +513,14 @@ unsafe fn emit_bucket_avx2(
|
|||||||
clamp: usize,
|
clamp: usize,
|
||||||
w_in: usize,
|
w_in: usize,
|
||||||
w_out: usize,
|
w_out: usize,
|
||||||
|
prefetch: bool,
|
||||||
) -> usize {
|
) -> usize {
|
||||||
emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_avx2)
|
emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, prefetch, xor_child_avx2)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emit a partition's children, dispatching to the AVX2 producer when available.
|
/// Emit a partition's children, dispatching to the AVX2 producer when available.
|
||||||
/// `w_in`/`w_out` are the input/output slot pitches (`w_out == w_in - 1`).
|
/// `w_in`/`w_out` are the input/output slot pitches (`w_out == w_in - 1`).
|
||||||
|
/// `prefetch` software-prefetches each group's randomly-gathered member slots.
|
||||||
unsafe fn emit_bucket(
|
unsafe fn emit_bucket(
|
||||||
keys_sorted: &[u32],
|
keys_sorted: &[u32],
|
||||||
slots: &[u32],
|
slots: &[u32],
|
||||||
@@ -516,14 +531,15 @@ unsafe fn emit_bucket(
|
|||||||
clamp: usize,
|
clamp: usize,
|
||||||
w_in: usize,
|
w_in: usize,
|
||||||
w_out: usize,
|
w_out: usize,
|
||||||
|
prefetch: bool,
|
||||||
) -> usize {
|
) -> usize {
|
||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
{
|
{
|
||||||
if is_x86_feature_detected!("avx2") {
|
if is_x86_feature_detected!("avx2") {
|
||||||
return emit_bucket_avx2(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out);
|
return emit_bucket_avx2(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, prefetch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
emit_bucket_scalar(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out)
|
emit_bucket_scalar(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, prefetch)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Group `n` entries by their leading block, then emit one child per colliding
|
/// Group `n` entries by their leading block, then emit one child per colliding
|
||||||
@@ -546,6 +562,9 @@ fn collide(
|
|||||||
// Sub-phase timing, gated on `EQ_PROFILE`. Prints partition / count / alloc /
|
// Sub-phase timing, gated on `EQ_PROFILE`. Prints partition / count / alloc /
|
||||||
// emit splits so we can see which part of the round dominates.
|
// emit splits so we can see which part of the round dominates.
|
||||||
let prof = std::env::var_os("EQ_PROFILE").is_some();
|
let prof = std::env::var_os("EQ_PROFILE").is_some();
|
||||||
|
// Software-prefetch the emit gather (on by default; `EQ_NO_PREFETCH` disables
|
||||||
|
// it for A/B benchmarking). Read once, not per group.
|
||||||
|
let do_prefetch = std::env::var_os("EQ_NO_PREFETCH").is_none();
|
||||||
let t0 = std::time::Instant::now();
|
let t0 = std::time::Instant::now();
|
||||||
|
|
||||||
let (starts, order, keys_part) = partition_top(keys, n);
|
let (starts, order, keys_part) = partition_top(keys, n);
|
||||||
@@ -609,8 +628,9 @@ fn collide(
|
|||||||
let mut sorted = Vec::new();
|
let mut sorted = Vec::new();
|
||||||
let mut keys_sorted = Vec::new();
|
let mut keys_sorted = Vec::new();
|
||||||
low_group(&keys_part[lo..hi], &order[lo..hi], &mut hist, &mut sorted, &mut keys_sorted);
|
low_group(&keys_part[lo..hi], &order[lo..hi], &mut hist, &mut sorted, &mut keys_sorted);
|
||||||
let w =
|
let w = unsafe {
|
||||||
unsafe { emit_bucket(&keys_sorted, slots, &sorted, kout, sout, pout, clamp, w_in, w_out) };
|
emit_bucket(&keys_sorted, slots, &sorted, kout, sout, pout, clamp, w_in, w_out, do_prefetch)
|
||||||
|
};
|
||||||
debug_assert_eq!(w, kout.len());
|
debug_assert_eq!(w, kout.len());
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -722,18 +742,29 @@ pub fn solve_with(header: &[u8], clamp: Option<usize>) -> Vec<Vec<u32>> {
|
|||||||
let hasher = BatchHasher::new(header);
|
let hasher = BatchHasher::new(header);
|
||||||
let mut keys: Vec<u32> = vec![0u32; n0];
|
let mut keys: Vec<u32> = vec![0u32; n0];
|
||||||
let mut slots: Vec<u32> = vec![0u32; n0 * SLOT];
|
let mut slots: Vec<u32> = vec![0u32; n0 * SLOT];
|
||||||
let kgroup = 4 * INDICES_PER_HASH_OUTPUT; // eight entries
|
// Hash eight `g` values per chunk via `hash8`, which uses one AVX-512
|
||||||
|
// compression where available and falls back to two AVX2 `hash4`s (or scalar)
|
||||||
|
// otherwise — so this single path covers every CPU.
|
||||||
|
let kgroup = 8 * INDICES_PER_HASH_OUTPUT; // sixteen entries
|
||||||
let sgroup = kgroup * SLOT;
|
let sgroup = kgroup * SLOT;
|
||||||
debug_assert_eq!(n0 % kgroup, 0, "round-0 buffer must split into whole 4-g groups");
|
debug_assert_eq!(n0 % kgroup, 0, "round-0 buffer must split into whole 8-g groups");
|
||||||
|
// `EQ_NO_AVX512` forces the AVX2 fallback (two hash4) for A/B benchmarking.
|
||||||
|
let use_avx512 = std::env::var_os("EQ_NO_AVX512").is_none();
|
||||||
slots
|
slots
|
||||||
.par_chunks_mut(sgroup)
|
.par_chunks_mut(sgroup)
|
||||||
.zip(keys.par_chunks_mut(kgroup))
|
.zip(keys.par_chunks_mut(kgroup))
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.for_each(|(c, (schunk, kchunk))| {
|
.for_each(|(c, (schunk, kchunk))| {
|
||||||
let g0 = (c * 4) as u32;
|
let g0 = (c * 8) as u32;
|
||||||
let mut hs = [[0u8; HASH_OUTPUT]; 4];
|
let mut hs = [[0u8; HASH_OUTPUT]; 8];
|
||||||
hasher.hash4(g0, &mut hs);
|
if use_avx512 {
|
||||||
for j in 0..4 {
|
hasher.hash8(g0, &mut hs);
|
||||||
|
} else {
|
||||||
|
let (lo, hi) = hs.split_at_mut(4);
|
||||||
|
hasher.hash4(g0, (&mut lo[..4]).try_into().unwrap());
|
||||||
|
hasher.hash4(g0 + 4, (&mut hi[..4]).try_into().unwrap());
|
||||||
|
}
|
||||||
|
for j in 0..8 {
|
||||||
for i in 0..INDICES_PER_HASH_OUTPUT {
|
for i in 0..INDICES_PER_HASH_OUTPUT {
|
||||||
let e = j * INDICES_PER_HASH_OUTPUT + i;
|
let e = j * INDICES_PER_HASH_OUTPUT + i;
|
||||||
let src = &hs[j][i * HASH_BYTES..i * HASH_BYTES + HASH_BYTES];
|
let src = &hs[j][i * HASH_BYTES..i * HASH_BYTES + HASH_BYTES];
|
||||||
|
|||||||
Reference in New Issue
Block a user