CPU solver: AVX-512 round-0 hashing + prefetch the emit gather

Branch the CPU solver onto modern x86 extensions, runtime-dispatched with the
existing AVX2/scalar fallbacks:

- BatchHasher::hash8 — an 8-lane AVX-512 BLAKE2b final-block compression (native
  _mm512_ror_epi64 rotates), falling back to two AVX2 hash4s or scalar. Round 0
  now hashes eight g-values per chunk. round0-hash drops ~1.45x on AVX-512 CPUs
  (≈225→155 ms here, AMD Zen4).
- emit_bucket software-prefetches each collision group's randomly-gathered
  member slots (the ~1 GB slot arena is the round's cache-miss bottleneck),
  shaving a few percent off the dominant emit phase.

Controlled A/B on this Zen4 box (same thermal state): ~4-5% faster overall.
The collision rounds are memory-bandwidth bound, so SIMD width is not the
limiter — the modern-ISA win is modest by nature. EQ_NO_AVX512 / EQ_NO_PREFETCH
opt out per-CPU (e.g. parts where AVX-512 downclocks) and back the A/B harness.

hash8 is validated against the scalar reference (batch_matches_reference) and
full solves still find valid solutions in every dispatch configuration.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
jackpotincorporated
2026-06-06 13:12:57 -04:00
parent 00531fb591
commit afd56bee1b
2 changed files with 139 additions and 13 deletions
+95
View File
@@ -215,6 +215,92 @@ impl BatchHasher {
} }
} }
/// Hash eight consecutive indices `g0..g0+8`, writing each 48-byte digest
/// into `out[0..8]`. Uses AVX-512 (one 8-wide BLAKE2b compression) when
/// available, else two AVX2 `hash4`s. Modern AMD (Zen4+) and Intel CPUs run
/// AVX-512 without the clock penalty older Intel parts had.
#[inline]
pub fn hash8(&self, g0: u32, out: &mut [[u8; HASH_OUTPUT]; 8]) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
unsafe { self.hash8_avx512(g0, out) };
return;
}
}
// Fallback: two 4-lane batches into the two halves of `out`.
let (lo, hi) = out.split_at_mut(4);
self.hash4(g0, (&mut lo[..4]).try_into().unwrap());
self.hash4(g0 + 4, (&mut hi[..4]).try_into().unwrap());
}
/// Eight-lane BLAKE2b final-block compression (AVX-512). Same structure as
/// [`Self::hash4_avx2`] but 512-bit lanes and native 64-bit rotates.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn hash8_avx512(&self, g0: u32, out: &mut [[u8; HASH_OUTPUT]; 8]) {
use core::arch::x86_64::*;
#[inline(always)]
unsafe fn g8(
v: &mut [__m512i; 16],
a: usize, b: usize, c: usize, d: usize,
x: __m512i, y: __m512i,
) {
v[a] = _mm512_add_epi64(_mm512_add_epi64(v[a], v[b]), x);
v[d] = _mm512_ror_epi64::<32>(_mm512_xor_si512(v[d], v[a]));
v[c] = _mm512_add_epi64(v[c], v[d]);
v[b] = _mm512_ror_epi64::<24>(_mm512_xor_si512(v[b], v[c]));
v[a] = _mm512_add_epi64(_mm512_add_epi64(v[a], v[b]), y);
v[d] = _mm512_ror_epi64::<16>(_mm512_xor_si512(v[d], v[a]));
v[c] = _mm512_add_epi64(v[c], v[d]);
v[b] = _mm512_ror_epi64::<63>(_mm512_xor_si512(v[b], v[c]));
}
// Only m0 and m1 are nonzero; m1's high 32 bits hold the per-lane `g`.
let tail0 = u64::from_le_bytes(self.tail[0..8].try_into().unwrap());
let tail_hi = u32::from_le_bytes(self.tail[8..12].try_into().unwrap()) as u64;
let m1 = |g: u32| (tail_hi | ((g as u64) << 32)) as i64;
let mut m = [_mm512_setzero_si512(); 16];
m[0] = _mm512_set1_epi64(tail0 as i64);
m[1] = _mm512_set_epi64(
m1(g0 + 7), m1(g0 + 6), m1(g0 + 5), m1(g0 + 4),
m1(g0 + 3), m1(g0 + 2), m1(g0 + 1), m1(g0),
);
let mut v = [_mm512_setzero_si512(); 16];
for i in 0..8 {
v[i] = _mm512_set1_epi64(self.mid[i] as i64);
v[i + 8] = _mm512_set1_epi64(IV[i] as i64);
}
v[12] = _mm512_xor_si512(v[12], _mm512_set1_epi64(FINAL_COUNT as i64));
v[14] = _mm512_xor_si512(v[14], _mm512_set1_epi64(-1)); // last-block flag
for s in &SIGMA {
g8(&mut v, 0, 4, 8, 12, m[s[0]], m[s[1]]);
g8(&mut v, 1, 5, 9, 13, m[s[2]], m[s[3]]);
g8(&mut v, 2, 6, 10, 14, m[s[4]], m[s[5]]);
g8(&mut v, 3, 7, 11, 15, m[s[6]], m[s[7]]);
g8(&mut v, 0, 5, 10, 15, m[s[8]], m[s[9]]);
g8(&mut v, 1, 6, 11, 12, m[s[10]], m[s[11]]);
g8(&mut v, 2, 7, 8, 13, m[s[12]], m[s[13]]);
g8(&mut v, 3, 4, 9, 14, m[s[14]], m[s[15]]);
}
// h[i] = mid[i] ^ v[i] ^ v[i+8]; first HASH_OUTPUT/8 words per lane, LE.
let mut tmp = [0u64; 8];
for i in 0..HASH_OUTPUT / 8 {
let o = _mm512_xor_si512(
_mm512_xor_si512(_mm512_set1_epi64(self.mid[i] as i64), v[i]),
v[i + 8],
);
_mm512_storeu_si512(tmp.as_mut_ptr() as *mut _, o);
for l in 0..8 {
out[l][i * 8..i * 8 + 8].copy_from_slice(&tmp[l].to_le_bytes());
}
}
}
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")] #[target_feature(enable = "avx2")]
unsafe fn hash4_avx2(&self, g0: u32, out: &mut [[u8; HASH_OUTPUT]; 4]) { unsafe fn hash4_avx2(&self, g0: u32, out: &mut [[u8; HASH_OUTPUT]; 4]) {
@@ -316,5 +402,14 @@ mod tests {
assert_eq!(out[l], generate_hash(&base, base_g + l as u32), "hash4 mismatch at g={}", base_g + l as u32); assert_eq!(out[l], generate_hash(&base, base_g + l as u32), "hash4 mismatch at g={}", base_g + l as u32);
} }
} }
// hash8 (AVX-512 or hash4 fallback) must match for several batches.
for base_g in [0u32, 8, 64, 1_000_000] {
let mut out = [[0u8; HASH_OUTPUT]; 8];
hasher.hash8(base_g, &mut out);
for l in 0..8 {
assert_eq!(out[l], generate_hash(&base, base_g + l as u32), "hash8 mismatch at g={}", base_g + l as u32);
}
}
} }
} }
+44 -13
View File
@@ -438,7 +438,7 @@ fn low_group(
/// Monomorphised over the XOR producer so the AVX2 intrinsics inline cleanly /// Monomorphised over the XOR producer so the AVX2 intrinsics inline cleanly
/// inside a `target_feature` wrapper while sharing one source of truth. /// inside a `target_feature` wrapper while sharing one source of truth.
macro_rules! emit_bucket_body { macro_rules! emit_bucket_body {
($keys_sorted:expr, $slots:expr, $sorted:expr, $kout:expr, $sout:expr, $pout:expr, $clamp:expr, $w_in:expr, $w_out:expr, $xor:path) => {{ ($keys_sorted:expr, $slots:expr, $sorted:expr, $kout:expr, $sout:expr, $pout:expr, $clamp:expr, $w_in:expr, $w_out:expr, $prefetch:expr, $xor:path) => {{
let s = $sorted; let s = $sorted;
let ks = $keys_sorted; // leading keys in `s`-order; group walk streams it let ks = $keys_sorted; // leading keys in `s`-order; group walk streams it
let m = s.len(); let m = s.len();
@@ -453,6 +453,18 @@ macro_rules! emit_bucket_body {
j += 1; j += 1;
} }
let hi = j.min(i.saturating_add($clamp)); let hi = j.min(i.saturating_add($clamp));
// The pair loops below read `slots[s[a]]` / `slots[s[b]]` at random
// global positions in the ~1 GB slot arena — each a likely cache
// miss. Prefetch this group's member slots up front so the misses
// overlap with the (L1-resident) pair XORs that follow.
#[cfg(target_arch = "x86_64")]
if $prefetch {
for a in i..hi {
core::arch::x86_64::_mm_prefetch::<{ core::arch::x86_64::_MM_HINT_T0 }>(
$slots.as_ptr().add(s[a] as usize * w_in) as *const i8,
);
}
}
for a in i..hi { for a in i..hi {
let l = s[a] as usize; let l = s[a] as usize;
for b in (a + 1)..hi { for b in (a + 1)..hi {
@@ -484,8 +496,9 @@ unsafe fn emit_bucket_scalar(
clamp: usize, clamp: usize,
w_in: usize, w_in: usize,
w_out: usize, w_out: usize,
prefetch: bool,
) -> usize { ) -> usize {
emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_scalar) emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, prefetch, xor_child_scalar)
} }
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
@@ -500,12 +513,14 @@ unsafe fn emit_bucket_avx2(
clamp: usize, clamp: usize,
w_in: usize, w_in: usize,
w_out: usize, w_out: usize,
prefetch: bool,
) -> usize { ) -> usize {
emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_avx2) emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, prefetch, xor_child_avx2)
} }
/// Emit a partition's children, dispatching to the AVX2 producer when available. /// Emit a partition's children, dispatching to the AVX2 producer when available.
/// `w_in`/`w_out` are the input/output slot pitches (`w_out == w_in - 1`). /// `w_in`/`w_out` are the input/output slot pitches (`w_out == w_in - 1`).
/// `prefetch` software-prefetches each group's randomly-gathered member slots.
unsafe fn emit_bucket( unsafe fn emit_bucket(
keys_sorted: &[u32], keys_sorted: &[u32],
slots: &[u32], slots: &[u32],
@@ -516,14 +531,15 @@ unsafe fn emit_bucket(
clamp: usize, clamp: usize,
w_in: usize, w_in: usize,
w_out: usize, w_out: usize,
prefetch: bool,
) -> usize { ) -> usize {
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
{ {
if is_x86_feature_detected!("avx2") { if is_x86_feature_detected!("avx2") {
return emit_bucket_avx2(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out); return emit_bucket_avx2(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, prefetch);
} }
} }
emit_bucket_scalar(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out) emit_bucket_scalar(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, prefetch)
} }
/// Group `n` entries by their leading block, then emit one child per colliding /// Group `n` entries by their leading block, then emit one child per colliding
@@ -546,6 +562,9 @@ fn collide(
// Sub-phase timing, gated on `EQ_PROFILE`. Prints partition / count / alloc / // Sub-phase timing, gated on `EQ_PROFILE`. Prints partition / count / alloc /
// emit splits so we can see which part of the round dominates. // emit splits so we can see which part of the round dominates.
let prof = std::env::var_os("EQ_PROFILE").is_some(); let prof = std::env::var_os("EQ_PROFILE").is_some();
// Software-prefetch the emit gather (on by default; `EQ_NO_PREFETCH` disables
// it for A/B benchmarking). Read once, not per group.
let do_prefetch = std::env::var_os("EQ_NO_PREFETCH").is_none();
let t0 = std::time::Instant::now(); let t0 = std::time::Instant::now();
let (starts, order, keys_part) = partition_top(keys, n); let (starts, order, keys_part) = partition_top(keys, n);
@@ -609,8 +628,9 @@ fn collide(
let mut sorted = Vec::new(); let mut sorted = Vec::new();
let mut keys_sorted = Vec::new(); let mut keys_sorted = Vec::new();
low_group(&keys_part[lo..hi], &order[lo..hi], &mut hist, &mut sorted, &mut keys_sorted); low_group(&keys_part[lo..hi], &order[lo..hi], &mut hist, &mut sorted, &mut keys_sorted);
let w = let w = unsafe {
unsafe { emit_bucket(&keys_sorted, slots, &sorted, kout, sout, pout, clamp, w_in, w_out) }; emit_bucket(&keys_sorted, slots, &sorted, kout, sout, pout, clamp, w_in, w_out, do_prefetch)
};
debug_assert_eq!(w, kout.len()); debug_assert_eq!(w, kout.len());
}); });
@@ -722,18 +742,29 @@ pub fn solve_with(header: &[u8], clamp: Option<usize>) -> Vec<Vec<u32>> {
let hasher = BatchHasher::new(header); let hasher = BatchHasher::new(header);
let mut keys: Vec<u32> = vec![0u32; n0]; let mut keys: Vec<u32> = vec![0u32; n0];
let mut slots: Vec<u32> = vec![0u32; n0 * SLOT]; let mut slots: Vec<u32> = vec![0u32; n0 * SLOT];
let kgroup = 4 * INDICES_PER_HASH_OUTPUT; // eight entries // Hash eight `g` values per chunk via `hash8`, which uses one AVX-512
// compression where available and falls back to two AVX2 `hash4`s (or scalar)
// otherwise — so this single path covers every CPU.
let kgroup = 8 * INDICES_PER_HASH_OUTPUT; // sixteen entries
let sgroup = kgroup * SLOT; let sgroup = kgroup * SLOT;
debug_assert_eq!(n0 % kgroup, 0, "round-0 buffer must split into whole 4-g groups"); debug_assert_eq!(n0 % kgroup, 0, "round-0 buffer must split into whole 8-g groups");
// `EQ_NO_AVX512` forces the AVX2 fallback (two hash4) for A/B benchmarking.
let use_avx512 = std::env::var_os("EQ_NO_AVX512").is_none();
slots slots
.par_chunks_mut(sgroup) .par_chunks_mut(sgroup)
.zip(keys.par_chunks_mut(kgroup)) .zip(keys.par_chunks_mut(kgroup))
.enumerate() .enumerate()
.for_each(|(c, (schunk, kchunk))| { .for_each(|(c, (schunk, kchunk))| {
let g0 = (c * 4) as u32; let g0 = (c * 8) as u32;
let mut hs = [[0u8; HASH_OUTPUT]; 4]; let mut hs = [[0u8; HASH_OUTPUT]; 8];
hasher.hash4(g0, &mut hs); if use_avx512 {
for j in 0..4 { hasher.hash8(g0, &mut hs);
} else {
let (lo, hi) = hs.split_at_mut(4);
hasher.hash4(g0, (&mut lo[..4]).try_into().unwrap());
hasher.hash4(g0 + 4, (&mut hi[..4]).try_into().unwrap());
}
for j in 0..8 {
for i in 0..INDICES_PER_HASH_OUTPUT { for i in 0..INDICES_PER_HASH_OUTPUT {
let e = j * INDICES_PER_HASH_OUTPUT + i; let e = j * INDICES_PER_HASH_OUTPUT + i;
let src = &hs[j][i * HASH_BYTES..i * HASH_BYTES + HASH_BYTES]; let src = &hs[j][i * HASH_BYTES..i * HASH_BYTES + HASH_BYTES];