Scatter partition-contiguous keys to kill the per-round key gathers

count_pairs, low_group, and the emit group walk all read each entry's leading
key via `keys[order[..]]`, a random gather over the whole ~128 MB keys array,
three times per round. partition_top now also produces `keys_part` (the leading
keys in partition order, keys_part[p] == keys[order[p]]), written by the same
parallel, disjoint phase-3 scatter at 4 bytes/entry. count_pairs and low_group
then stream their partition's keys sequentially, and low_group emits a
`keys_sorted` array so the emit group walk streams a dense local copy instead of
gathering keys[sorted[i]]. The only remaining DRAM-random access in the rounds is
the unavoidable slot gather.
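
A minimal sketch of the data flow described above, under stated assumptions:
it is serial (the real phase-3 scatter is chunk-parallel and uses raw
pointers), the constants are toy values chosen for 8-bit keys, and the names
`partition_top_serial` and `low_group_sketch` are hypothetical stand-ins, not
the functions in this commit. It only illustrates the invariant
`keys_part[p] == keys[order[p]]` and how `keys_sorted` is carried alongside
`sorted`.

```rust
// Toy sizes for illustration only; the real constants are not shown here.
// Keys must fit in 8 bits: 4 high bits pick the partition, 4 low bits the group.
const LOW_BITS: u32 = 4;
const LOW_MASK: u32 = (1 << LOW_BITS) - 1;
const LOW_BUCKETS: usize = 1 << LOW_BITS;
const TOP_BUCKETS: usize = 16;

/// Serial stand-in for `partition_top`: a counting sort by the high bits that
/// scatters each key next to its index. Returns `(starts, order, keys_part)`.
fn partition_top_serial(keys: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
    let n = keys.len();
    let mut starts = vec![0u32; TOP_BUCKETS + 1];
    for &key in keys {
        starts[(key >> LOW_BITS) as usize + 1] += 1;
    }
    for b in 0..TOP_BUCKETS {
        starts[b + 1] += starts[b]; // prefix sum: starts[b] = first slot of bucket b
    }
    let mut cur = starts.clone(); // live write cursor per bucket
    let mut order = vec![0u32; n];
    let mut keys_part = vec![0u32; n];
    for (k, &key) in keys.iter().enumerate() {
        let b = (key >> LOW_BITS) as usize;
        let pos = cur[b] as usize;
        order[pos] = k as u32; // global index, still needed for the slot gather
        keys_part[pos] = key; // the extra 4-byte/entry scatter
        cur[b] += 1;
    }
    (starts, order, keys_part)
}

/// Stand-in for `low_group`: groups one partition by low bits, reading both
/// input slabs sequentially. Returns `(sorted, keys_sorted)`.
fn low_group_sketch(keys_run: &[u32], order_run: &[u32]) -> (Vec<u32>, Vec<u32>) {
    let m = keys_run.len();
    let mut hist = vec![0u32; LOW_BUCKETS + 1];
    for &key in keys_run {
        hist[(key & LOW_MASK) as usize + 1] += 1;
    }
    for i in 0..LOW_BUCKETS {
        hist[i + 1] += hist[i];
    }
    let mut sorted = vec![0u32; m];
    let mut keys_sorted = vec![0u32; m];
    for (&key, &idx) in keys_run.iter().zip(order_run) {
        let low = (key & LOW_MASK) as usize;
        let pos = hist[low] as usize;
        sorted[pos] = idx;
        keys_sorted[pos] = key; // lets the group walk compare keys without a gather
        hist[low] += 1;
    }
    (sorted, keys_sorted)
}
```

A caller would slice `keys_part` and `order` with the same
`starts[v]..starts[v + 1]` range and pass both slices to the grouping step, so
neither pass indexes the original `keys` array.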

Measured (16 threads, clamp 16/32): count ~160 -> ~10 ms/round, emit ~770 -> ~550
ms/round, partition +~80 ms (the added 128 MB scatter); full solve ~8.4 -> ~7.04 s
(about -16%). Cumulative across the three CPU-solver changes: ~13.4 -> ~7.04 s
(-47%), 0.07 -> 0.14 solve/s. Identical solution yield; cross-clamp validity and
full_solve_baseline pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
jackpotincorporated
2026-06-06 11:23:10 -04:00
parent 966ce3e262
commit 1b4a2a4dd9
+90 -60
@@ -274,24 +274,30 @@ unsafe impl Send for OrderPtr {}
 unsafe impl Sync for OrderPtr {}
 /// Partition the `n` entries into `TOP_BUCKETS` runs by the high `TOP_BITS` of
-/// their (dense) leading block. Returns `(starts, order)`, where partition `v`
-/// owns the input indices `order[starts[v]..starts[v + 1]]`. The histogram
-/// passes stream over `keys[]` (4 bytes/entry) instead of striding the slots.
+/// their (dense) leading block. Returns `(starts, order, keys_part)`: partition
+/// `v` owns the input indices `order[starts[v]..starts[v + 1]]`, and `keys_part`
+/// is the entries' leading words in the same partition-contiguous order
+/// (`keys_part[p] == keys[order[p]]`). Carrying `keys_part` lets the per-partition
+/// `count_pairs`/`low_group` sweeps read keys sequentially instead of gathering
+/// `keys[order[..]]` over the whole array — at the cost of one extra 4-byte/entry
+/// scatter here, folded into the (already parallel) phase-3 pass.
 ///
 /// Parallel counting sort: the input is split into one contiguous chunk per
 /// rayon worker. Each chunk histograms its slice (phase 1), a small serial pass
 /// turns those into per-chunk base offsets within each bucket's output region
-/// (phase 2), and each chunk scatters its entries into `order` (phase 3). Chunk
-/// `c`'s bucket-`b` writes land in `[off[c][b], off[c+1][b])`, disjoint from
-/// every other chunk and bucket, so the concurrent writes never alias. Entries
-/// within a bucket end up chunk-major rather than index-major; that reordering
-/// is immaterial — `count_pairs`/`low_group` depend only on the multiset of low
-/// keys, and final solutions are canonicalised, de-duplicated, and verified.
-fn partition_top(keys: &[u32], n: usize) -> (Vec<u32>, Vec<u32>) {
+/// (phase 2), and each chunk scatters its entries into `order`/`keys_part`
+/// (phase 3). Chunk `c`'s bucket-`b` writes land in `[off[c][b], off[c+1][b])`,
+/// disjoint from every other chunk and bucket, so the concurrent writes never
+/// alias. Entries within a bucket end up chunk-major rather than index-major;
+/// that reordering is immaterial — `count_pairs`/`low_group` depend only on the
+/// multiset of low keys, and final solutions are canonicalised, de-duplicated,
+/// and verified.
+fn partition_top(keys: &[u32], n: usize) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
     let mut starts = vec![0u32; TOP_BUCKETS + 1];
     let mut order = vec![0u32; n];
+    let mut keys_part = vec![0u32; n];
     if n == 0 {
-        return (starts, order);
+        return (starts, order, keys_part);
     }
     let nthreads = rayon::current_num_threads().max(1);
@@ -334,33 +340,39 @@ fn partition_top(keys: &[u32], n: usize) -> (Vec<u32>, Vec<u32>) {
     // Phase 3 (parallel): each chunk scatters into its disjoint sub-ranges.
     let optr = OrderPtr(order.as_mut_ptr());
+    let kptr = OrderPtr(keys_part.as_mut_ptr());
     offsets.into_par_iter().enumerate().for_each(|(c, mut cur)| {
-        let optr = optr; // capture the whole (Sync) wrapper, not the inner ptr
+        let (optr, kptr) = (optr, kptr); // capture whole (Sync) wrappers
         let lo = c * chunk;
         let hi = ((c + 1) * chunk).min(n);
-        let base = optr.0;
+        let (obase, kbase) = (optr.0, kptr.0);
         for k in lo..hi {
-            let b = (keys[k] >> LOW_BITS) as usize;
-            // SAFETY: `cur[b]` ranges over `[off[c][b], off[c+1][b])`, a range
-            // owned exclusively by chunk `c` and within `order`'s bounds.
-            unsafe { *base.add(cur[b] as usize) = k as u32 };
+            let key = keys[k];
+            let b = (key >> LOW_BITS) as usize;
+            let pos = cur[b] as usize;
+            // SAFETY: `pos` ranges over `[off[c][b], off[c+1][b])`, a range owned
+            // exclusively by chunk `c` and within `order`/`keys_part` bounds.
+            unsafe {
+                *obase.add(pos) = k as u32;
+                *kbase.add(pos) = key;
+            }
             cur[b] += 1;
         }
     });
-    (starts, order)
+    (starts, order, keys_part)
 }
 /// Count the colliding pairs a partition will emit, from the low-bit histogram
-/// alone (no reordering). `clamp` caps each exact-collision group, matching the
-/// emit pass so the output offsets line up. `hist` is reusable `LOW_BUCKETS`
-/// scratch.
-fn count_pairs(keys: &[u32], run: &[u32], hist: &mut [u32], clamp: usize) -> usize {
+/// of its (partition-contiguous) leading keys `keys_run`. `clamp` caps each
+/// exact-collision group, matching the emit pass so the output offsets line up.
+/// `hist` is reusable `LOW_BUCKETS` scratch.
+fn count_pairs(keys_run: &[u32], hist: &mut [u32], clamp: usize) -> usize {
     for h in hist.iter_mut() {
         *h = 0;
     }
-    for &k in run {
-        hist[(keys[k as usize] & LOW_MASK) as usize] += 1;
+    for &key in keys_run {
+        hist[(key & LOW_MASK) as usize] += 1;
     }
     let mut pairs = 0usize;
     for i in 0..LOW_BUCKETS {
@@ -370,31 +382,44 @@ fn count_pairs(keys: &[u32], run: &[u32], hist: &mut [u32], clamp: usize) -> usi
     pairs
 }
-/// Within one partition, group `run`'s entries by the low bits of their leading
-/// block, writing the grouped indices into `sorted`. `hist` is reusable
-/// `LOW_BUCKETS + 1` scratch. After this call `sorted` lists the run's indices
-/// with equal low keys contiguous, so callers recover each exact-collision
-/// group by walking adjacent equal keys.
-fn low_group(keys: &[u32], run: &[u32], hist: &mut [u32], sorted: &mut Vec<u32>) {
-    let m = run.len();
+/// Within one partition, group its entries by the low bits of their leading
+/// block. Inputs are the partition-contiguous slabs `keys_run` (leading keys) and
+/// `order_run` (matching global indices). Outputs, in low-key-sorted order:
+/// `sorted` (the global indices, for the emit slot gather + back-refs) and
+/// `keys_sorted` (the leading keys, so the emit group walk streams a dense local
+/// array instead of gathering `keys[sorted[i]]`). `hist` is reusable
+/// `LOW_BUCKETS + 1` scratch. Both reads are sequential over the slabs.
+fn low_group(
+    keys_run: &[u32],
+    order_run: &[u32],
+    hist: &mut [u32],
+    sorted: &mut Vec<u32>,
+    keys_sorted: &mut Vec<u32>,
+) {
+    let m = order_run.len();
     sorted.clear();
+    keys_sorted.clear();
     if m == 0 {
         return;
     }
     for h in hist.iter_mut() {
         *h = 0;
     }
-    for &k in run {
-        hist[(keys[k as usize] & LOW_MASK) as usize + 1] += 1;
+    for &key in keys_run {
+        hist[(key & LOW_MASK) as usize + 1] += 1;
     }
     for i in 0..LOW_BUCKETS {
         hist[i + 1] += hist[i];
     }
     sorted.resize(m, 0);
+    keys_sorted.resize(m, 0);
     // hist[low] now holds the run-start offset; reuse it as the live cursor.
-    for &k in run {
-        let low = (keys[k as usize] & LOW_MASK) as usize;
-        sorted[hist[low] as usize] = k;
+    for i in 0..m {
+        let key = keys_run[i];
+        let low = (key & LOW_MASK) as usize;
+        let pos = hist[low] as usize;
+        sorted[pos] = order_run[i];
+        keys_sorted[pos] = key;
         hist[low] += 1;
     }
 }
@@ -405,17 +430,18 @@ fn low_group(keys: &[u32], run: &[u32], hist: &mut [u32], sorted: &mut Vec<u32>)
 /// Monomorphised over the XOR producer so the AVX2 intrinsics inline cleanly
 /// inside a `target_feature` wrapper while sharing one source of truth.
 macro_rules! emit_bucket_body {
-    ($keys:expr, $slots:expr, $sorted:expr, $kout:expr, $sout:expr, $pout:expr, $clamp:expr, $w_in:expr, $w_out:expr, $xor:path) => {{
+    ($keys_sorted:expr, $slots:expr, $sorted:expr, $kout:expr, $sout:expr, $pout:expr, $clamp:expr, $w_in:expr, $w_out:expr, $xor:path) => {{
         let s = $sorted;
+        let ks = $keys_sorted; // leading keys in `s`-order; group walk streams it
         let m = s.len();
         let w_in = $w_in;
         let w_out = $w_out;
         let mut w = 0usize;
         let mut i = 0;
         while i < m {
-            let key = $keys[s[i] as usize] & LOW_MASK;
+            let key = ks[i] & LOW_MASK;
             let mut j = i + 1;
-            while j < m && ($keys[s[j] as usize] & LOW_MASK) == key {
+            while j < m && (ks[j] & LOW_MASK) == key {
                 j += 1;
             }
             let hi = j.min(i.saturating_add($clamp));
@@ -441,7 +467,7 @@ macro_rules! emit_bucket_body {
 }
 unsafe fn emit_bucket_scalar(
-    keys: &[u32],
+    keys_sorted: &[u32],
     slots: &[u32],
     sorted: &[u32],
     kout: &mut [u32],
@@ -451,13 +477,13 @@ unsafe fn emit_bucket_scalar(
     w_in: usize,
     w_out: usize,
 ) -> usize {
-    emit_bucket_body!(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_scalar)
+    emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_scalar)
 }
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "avx2")]
 unsafe fn emit_bucket_avx2(
-    keys: &[u32],
+    keys_sorted: &[u32],
     slots: &[u32],
     sorted: &[u32],
     kout: &mut [u32],
@@ -467,13 +493,13 @@ unsafe fn emit_bucket_avx2(
     w_in: usize,
     w_out: usize,
 ) -> usize {
-    emit_bucket_body!(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_avx2)
+    emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_avx2)
 }
 /// Emit a partition's children, dispatching to the AVX2 producer when available.
 /// `w_in`/`w_out` are the input/output slot pitches (`w_out == w_in - 1`).
 unsafe fn emit_bucket(
-    keys: &[u32],
+    keys_sorted: &[u32],
     slots: &[u32],
     sorted: &[u32],
     kout: &mut [u32],
@@ -486,10 +512,10 @@ unsafe fn emit_bucket(
     #[cfg(target_arch = "x86_64")]
     {
         if is_x86_feature_detected!("avx2") {
-            return emit_bucket_avx2(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out);
+            return emit_bucket_avx2(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out);
         }
     }
-    emit_bucket_scalar(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out)
+    emit_bucket_scalar(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out)
 }
 /// Group `n` entries by their leading block, then emit one child per colliding
@@ -514,17 +540,17 @@ fn collide(
     let prof = std::env::var_os("EQ_PROFILE").is_some();
     let t0 = std::time::Instant::now();
-    let (starts, order) = partition_top(keys, n);
+    let (starts, order, keys_part) = partition_top(keys, n);
     let t_part = std::time::Instant::now();
-    // Pass 1: per-partition child counts (histogram-derived, no reordering).
+    // Pass 1: per-partition child counts from each partition's contiguous keys.
     let counts: Vec<usize> = (0..TOP_BUCKETS)
         .into_par_iter()
         .map_init(
             || vec![0u32; LOW_BUCKETS],
             |hist, v| {
-                let run = &order[starts[v] as usize..starts[v + 1] as usize];
-                count_pairs(keys, run, hist, clamp)
+                let keys_run = &keys_part[starts[v] as usize..starts[v + 1] as usize];
+                count_pairs(keys_run, hist, clamp)
             },
         )
         .collect();
@@ -569,11 +595,14 @@ fn collide(
         .zip(pparts)
         .enumerate()
         .for_each(|(v, ((kout, sout), pout))| {
-            let run = &order[starts[v] as usize..starts[v + 1] as usize];
+            let lo = starts[v] as usize;
+            let hi = starts[v + 1] as usize;
             let mut hist = vec![0u32; LOW_BUCKETS + 1];
             let mut sorted = Vec::new();
-            low_group(keys, run, &mut hist, &mut sorted);
-            let w = unsafe { emit_bucket(keys, slots, &sorted, kout, sout, pout, clamp, w_in, w_out) };
+            let mut keys_sorted = Vec::new();
+            low_group(&keys_part[lo..hi], &order[lo..hi], &mut hist, &mut sorted, &mut keys_sorted);
+            let w =
+                unsafe { emit_bucket(&keys_sorted, slots, &sorted, kout, sout, pout, clamp, w_in, w_out) };
             debug_assert_eq!(w, kout.len());
         });
@@ -595,22 +624,23 @@ fn collide(
 /// leading block `w0`, a pair whose `w1` also matches XORs the last two blocks to
 /// zero — a candidate. Returns the `(l, mr)` parents of each candidate.
 fn collide_final(keys: &[u32], slots: &[u32], n: usize, clamp: usize, w_in: usize) -> Vec<(u32, u32)> {
-    let (starts, order) = partition_top(keys, n);
+    let (starts, order, keys_part) = partition_top(keys, n);
     (0..TOP_BUCKETS)
         .into_par_iter()
         .map_init(
-            || (vec![0u32; LOW_BUCKETS + 1], Vec::<u32>::new()),
-            |(hist, sorted), v| {
-                let run = &order[starts[v] as usize..starts[v + 1] as usize];
-                low_group(keys, run, hist, sorted);
+            || (vec![0u32; LOW_BUCKETS + 1], Vec::<u32>::new(), Vec::<u32>::new()),
+            |(hist, sorted, keys_sorted), v| {
+                let lo = starts[v] as usize;
+                let hi = starts[v + 1] as usize;
+                low_group(&keys_part[lo..hi], &order[lo..hi], hist, sorted, keys_sorted);
                 let m = sorted.len();
                 let mut local = Vec::new();
                 let mut i = 0;
                 while i < m {
-                    let key = keys[sorted[i] as usize] & LOW_MASK;
+                    let key = keys_sorted[i] & LOW_MASK;
                     let mut j = i + 1;
-                    while j < m && (keys[sorted[j] as usize] & LOW_MASK) == key {
+                    while j < m && (keys_sorted[j] & LOW_MASK) == key {
                         j += 1;
                     }
                     let hi = j.min(i.saturating_add(clamp));