Scatter partition-contiguous keys to kill the per-round key gathers
count_pairs, low_group, and the emit group-walk all read each entry's leading key via `keys[order[..]]` — a random gather over the whole ~128 MB keys array, three times per round.

partition_top now also produces `keys_part` (the leading keys in partition order, i.e. keys_part[p] == keys[order[p]]), written by the same parallel, disjoint phase-3 scatter at 4 bytes/entry. count_pairs and low_group then stream their partition's keys sequentially, and low_group emits a `keys_sorted` array so the emit group walk streams a dense local copy instead of gathering keys[sorted[i]]. The only remaining DRAM-random access in the rounds is the unavoidable slot gather.

Measured (16 threads, clamp 16/32): count ~160 -> ~10 ms/round, emit ~770 -> ~550 ms/round, partition +~80 ms (the added 128 MB scatter); full solve ~8.4 -> ~7.04 s (~-16%). Cumulative across the three CPU-solver changes: ~13.4 -> ~7.04 s (-47%), 0.07 -> 0.14 solve/s.

Identical solution yield; cross-clamp validity and full_solve_baseline pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
+90
-60
@@ -274,24 +274,30 @@ unsafe impl Send for OrderPtr {}
|
|||||||
unsafe impl Sync for OrderPtr {}
|
unsafe impl Sync for OrderPtr {}
|
||||||
|
|
||||||
/// Partition the `n` entries into `TOP_BUCKETS` runs by the high `TOP_BITS` of
|
/// Partition the `n` entries into `TOP_BUCKETS` runs by the high `TOP_BITS` of
|
||||||
/// their (dense) leading block. Returns `(starts, order)`, where partition `v`
|
/// their (dense) leading block. Returns `(starts, order, keys_part)`: partition
|
||||||
/// owns the input indices `order[starts[v]..starts[v + 1]]`. The histogram
|
/// `v` owns the input indices `order[starts[v]..starts[v + 1]]`, and `keys_part`
|
||||||
/// passes stream over `keys[]` (4 bytes/entry) instead of striding the slots.
|
/// is the entries' leading words in the same partition-contiguous order
|
||||||
|
/// (`keys_part[p] == keys[order[p]]`). Carrying `keys_part` lets the per-partition
|
||||||
|
/// `count_pairs`/`low_group` sweeps read keys sequentially instead of gathering
|
||||||
|
/// `keys[order[..]]` over the whole array — at the cost of one extra 4-byte/entry
|
||||||
|
/// scatter here, folded into the (already parallel) phase-3 pass.
|
||||||
///
|
///
|
||||||
/// Parallel counting sort: the input is split into one contiguous chunk per
|
/// Parallel counting sort: the input is split into one contiguous chunk per
|
||||||
/// rayon worker. Each chunk histograms its slice (phase 1), a small serial pass
|
/// rayon worker. Each chunk histograms its slice (phase 1), a small serial pass
|
||||||
/// turns those into per-chunk base offsets within each bucket's output region
|
/// turns those into per-chunk base offsets within each bucket's output region
|
||||||
/// (phase 2), and each chunk scatters its entries into `order` (phase 3). Chunk
|
/// (phase 2), and each chunk scatters its entries into `order`/`keys_part`
|
||||||
/// `c`'s bucket-`b` writes land in `[off[c][b], off[c+1][b])`, disjoint from
|
/// (phase 3). Chunk `c`'s bucket-`b` writes land in `[off[c][b], off[c+1][b])`,
|
||||||
/// every other chunk and bucket, so the concurrent writes never alias. Entries
|
/// disjoint from every other chunk and bucket, so the concurrent writes never
|
||||||
/// within a bucket end up chunk-major rather than index-major; that reordering
|
/// alias. Entries within a bucket end up chunk-major rather than index-major;
|
||||||
/// is immaterial — `count_pairs`/`low_group` depend only on the multiset of low
|
/// that reordering is immaterial — `count_pairs`/`low_group` depend only on the
|
||||||
/// keys, and final solutions are canonicalised, de-duplicated, and verified.
|
/// multiset of low keys, and final solutions are canonicalised, de-duplicated,
|
||||||
fn partition_top(keys: &[u32], n: usize) -> (Vec<u32>, Vec<u32>) {
|
/// and verified.
|
||||||
|
fn partition_top(keys: &[u32], n: usize) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
|
||||||
let mut starts = vec![0u32; TOP_BUCKETS + 1];
|
let mut starts = vec![0u32; TOP_BUCKETS + 1];
|
||||||
let mut order = vec![0u32; n];
|
let mut order = vec![0u32; n];
|
||||||
|
let mut keys_part = vec![0u32; n];
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
return (starts, order);
|
return (starts, order, keys_part);
|
||||||
}
|
}
|
||||||
|
|
||||||
let nthreads = rayon::current_num_threads().max(1);
|
let nthreads = rayon::current_num_threads().max(1);
|
||||||
@@ -334,33 +340,39 @@ fn partition_top(keys: &[u32], n: usize) -> (Vec<u32>, Vec<u32>) {
|
|||||||
|
|
||||||
// Phase 3 (parallel): each chunk scatters into its disjoint sub-ranges.
|
// Phase 3 (parallel): each chunk scatters into its disjoint sub-ranges.
|
||||||
let optr = OrderPtr(order.as_mut_ptr());
|
let optr = OrderPtr(order.as_mut_ptr());
|
||||||
|
let kptr = OrderPtr(keys_part.as_mut_ptr());
|
||||||
offsets.into_par_iter().enumerate().for_each(|(c, mut cur)| {
|
offsets.into_par_iter().enumerate().for_each(|(c, mut cur)| {
|
||||||
let optr = optr; // capture the whole (Sync) wrapper, not the inner ptr
|
let (optr, kptr) = (optr, kptr); // capture whole (Sync) wrappers
|
||||||
let lo = c * chunk;
|
let lo = c * chunk;
|
||||||
let hi = ((c + 1) * chunk).min(n);
|
let hi = ((c + 1) * chunk).min(n);
|
||||||
let base = optr.0;
|
let (obase, kbase) = (optr.0, kptr.0);
|
||||||
for k in lo..hi {
|
for k in lo..hi {
|
||||||
let b = (keys[k] >> LOW_BITS) as usize;
|
let key = keys[k];
|
||||||
// SAFETY: `cur[b]` ranges over `[off[c][b], off[c+1][b])`, a range
|
let b = (key >> LOW_BITS) as usize;
|
||||||
// owned exclusively by chunk `c` and within `order`'s bounds.
|
let pos = cur[b] as usize;
|
||||||
unsafe { *base.add(cur[b] as usize) = k as u32 };
|
// SAFETY: `pos` ranges over `[off[c][b], off[c+1][b])`, a range owned
|
||||||
|
// exclusively by chunk `c` and within `order`/`keys_part` bounds.
|
||||||
|
unsafe {
|
||||||
|
*obase.add(pos) = k as u32;
|
||||||
|
*kbase.add(pos) = key;
|
||||||
|
}
|
||||||
cur[b] += 1;
|
cur[b] += 1;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
(starts, order)
|
(starts, order, keys_part)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Count the colliding pairs a partition will emit, from the low-bit histogram
|
/// Count the colliding pairs a partition will emit, from the low-bit histogram
|
||||||
/// alone (no reordering). `clamp` caps each exact-collision group, matching the
|
/// of its (partition-contiguous) leading keys `keys_run`. `clamp` caps each
|
||||||
/// emit pass so the output offsets line up. `hist` is reusable `LOW_BUCKETS`
|
/// exact-collision group, matching the emit pass so the output offsets line up.
|
||||||
/// scratch.
|
/// `hist` is reusable `LOW_BUCKETS` scratch.
|
||||||
fn count_pairs(keys: &[u32], run: &[u32], hist: &mut [u32], clamp: usize) -> usize {
|
fn count_pairs(keys_run: &[u32], hist: &mut [u32], clamp: usize) -> usize {
|
||||||
for h in hist.iter_mut() {
|
for h in hist.iter_mut() {
|
||||||
*h = 0;
|
*h = 0;
|
||||||
}
|
}
|
||||||
for &k in run {
|
for &key in keys_run {
|
||||||
hist[(keys[k as usize] & LOW_MASK) as usize] += 1;
|
hist[(key & LOW_MASK) as usize] += 1;
|
||||||
}
|
}
|
||||||
let mut pairs = 0usize;
|
let mut pairs = 0usize;
|
||||||
for i in 0..LOW_BUCKETS {
|
for i in 0..LOW_BUCKETS {
|
||||||
@@ -370,31 +382,44 @@ fn count_pairs(keys: &[u32], run: &[u32], hist: &mut [u32], clamp: usize) -> usi
|
|||||||
pairs
|
pairs
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Within one partition, group `run`'s entries by the low bits of their leading
|
/// Within one partition, group its entries by the low bits of their leading
|
||||||
/// block, writing the grouped indices into `sorted`. `hist` is reusable
|
/// block. Inputs are the partition-contiguous slabs `keys_run` (leading keys) and
|
||||||
/// `LOW_BUCKETS + 1` scratch. After this call `sorted` lists the run's indices
|
/// `order_run` (matching global indices). Outputs, in low-key-sorted order:
|
||||||
/// with equal low keys contiguous, so callers recover each exact-collision
|
/// `sorted` (the global indices, for the emit slot gather + back-refs) and
|
||||||
/// group by walking adjacent equal keys.
|
/// `keys_sorted` (the leading keys, so the emit group walk streams a dense local
|
||||||
fn low_group(keys: &[u32], run: &[u32], hist: &mut [u32], sorted: &mut Vec<u32>) {
|
/// array instead of gathering `keys[sorted[i]]`). `hist` is reusable
|
||||||
let m = run.len();
|
/// `LOW_BUCKETS + 1` scratch. Both reads are sequential over the slabs.
|
||||||
|
fn low_group(
|
||||||
|
keys_run: &[u32],
|
||||||
|
order_run: &[u32],
|
||||||
|
hist: &mut [u32],
|
||||||
|
sorted: &mut Vec<u32>,
|
||||||
|
keys_sorted: &mut Vec<u32>,
|
||||||
|
) {
|
||||||
|
let m = order_run.len();
|
||||||
sorted.clear();
|
sorted.clear();
|
||||||
|
keys_sorted.clear();
|
||||||
if m == 0 {
|
if m == 0 {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for h in hist.iter_mut() {
|
for h in hist.iter_mut() {
|
||||||
*h = 0;
|
*h = 0;
|
||||||
}
|
}
|
||||||
for &k in run {
|
for &key in keys_run {
|
||||||
hist[(keys[k as usize] & LOW_MASK) as usize + 1] += 1;
|
hist[(key & LOW_MASK) as usize + 1] += 1;
|
||||||
}
|
}
|
||||||
for i in 0..LOW_BUCKETS {
|
for i in 0..LOW_BUCKETS {
|
||||||
hist[i + 1] += hist[i];
|
hist[i + 1] += hist[i];
|
||||||
}
|
}
|
||||||
sorted.resize(m, 0);
|
sorted.resize(m, 0);
|
||||||
|
keys_sorted.resize(m, 0);
|
||||||
// hist[low] now holds the run-start offset; reuse it as the live cursor.
|
// hist[low] now holds the run-start offset; reuse it as the live cursor.
|
||||||
for &k in run {
|
for i in 0..m {
|
||||||
let low = (keys[k as usize] & LOW_MASK) as usize;
|
let key = keys_run[i];
|
||||||
sorted[hist[low] as usize] = k;
|
let low = (key & LOW_MASK) as usize;
|
||||||
|
let pos = hist[low] as usize;
|
||||||
|
sorted[pos] = order_run[i];
|
||||||
|
keys_sorted[pos] = key;
|
||||||
hist[low] += 1;
|
hist[low] += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -405,17 +430,18 @@ fn low_group(keys: &[u32], run: &[u32], hist: &mut [u32], sorted: &mut Vec<u32>)
|
|||||||
/// Monomorphised over the XOR producer so the AVX2 intrinsics inline cleanly
|
/// Monomorphised over the XOR producer so the AVX2 intrinsics inline cleanly
|
||||||
/// inside a `target_feature` wrapper while sharing one source of truth.
|
/// inside a `target_feature` wrapper while sharing one source of truth.
|
||||||
macro_rules! emit_bucket_body {
|
macro_rules! emit_bucket_body {
|
||||||
($keys:expr, $slots:expr, $sorted:expr, $kout:expr, $sout:expr, $pout:expr, $clamp:expr, $w_in:expr, $w_out:expr, $xor:path) => {{
|
($keys_sorted:expr, $slots:expr, $sorted:expr, $kout:expr, $sout:expr, $pout:expr, $clamp:expr, $w_in:expr, $w_out:expr, $xor:path) => {{
|
||||||
let s = $sorted;
|
let s = $sorted;
|
||||||
|
let ks = $keys_sorted; // leading keys in `s`-order; group walk streams it
|
||||||
let m = s.len();
|
let m = s.len();
|
||||||
let w_in = $w_in;
|
let w_in = $w_in;
|
||||||
let w_out = $w_out;
|
let w_out = $w_out;
|
||||||
let mut w = 0usize;
|
let mut w = 0usize;
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
while i < m {
|
while i < m {
|
||||||
let key = $keys[s[i] as usize] & LOW_MASK;
|
let key = ks[i] & LOW_MASK;
|
||||||
let mut j = i + 1;
|
let mut j = i + 1;
|
||||||
while j < m && ($keys[s[j] as usize] & LOW_MASK) == key {
|
while j < m && (ks[j] & LOW_MASK) == key {
|
||||||
j += 1;
|
j += 1;
|
||||||
}
|
}
|
||||||
let hi = j.min(i.saturating_add($clamp));
|
let hi = j.min(i.saturating_add($clamp));
|
||||||
@@ -441,7 +467,7 @@ macro_rules! emit_bucket_body {
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn emit_bucket_scalar(
|
unsafe fn emit_bucket_scalar(
|
||||||
keys: &[u32],
|
keys_sorted: &[u32],
|
||||||
slots: &[u32],
|
slots: &[u32],
|
||||||
sorted: &[u32],
|
sorted: &[u32],
|
||||||
kout: &mut [u32],
|
kout: &mut [u32],
|
||||||
@@ -451,13 +477,13 @@ unsafe fn emit_bucket_scalar(
|
|||||||
w_in: usize,
|
w_in: usize,
|
||||||
w_out: usize,
|
w_out: usize,
|
||||||
) -> usize {
|
) -> usize {
|
||||||
emit_bucket_body!(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_scalar)
|
emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_scalar)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
#[target_feature(enable = "avx2")]
|
#[target_feature(enable = "avx2")]
|
||||||
unsafe fn emit_bucket_avx2(
|
unsafe fn emit_bucket_avx2(
|
||||||
keys: &[u32],
|
keys_sorted: &[u32],
|
||||||
slots: &[u32],
|
slots: &[u32],
|
||||||
sorted: &[u32],
|
sorted: &[u32],
|
||||||
kout: &mut [u32],
|
kout: &mut [u32],
|
||||||
@@ -467,13 +493,13 @@ unsafe fn emit_bucket_avx2(
|
|||||||
w_in: usize,
|
w_in: usize,
|
||||||
w_out: usize,
|
w_out: usize,
|
||||||
) -> usize {
|
) -> usize {
|
||||||
emit_bucket_body!(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_avx2)
|
emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_avx2)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emit a partition's children, dispatching to the AVX2 producer when available.
|
/// Emit a partition's children, dispatching to the AVX2 producer when available.
|
||||||
/// `w_in`/`w_out` are the input/output slot pitches (`w_out == w_in - 1`).
|
/// `w_in`/`w_out` are the input/output slot pitches (`w_out == w_in - 1`).
|
||||||
unsafe fn emit_bucket(
|
unsafe fn emit_bucket(
|
||||||
keys: &[u32],
|
keys_sorted: &[u32],
|
||||||
slots: &[u32],
|
slots: &[u32],
|
||||||
sorted: &[u32],
|
sorted: &[u32],
|
||||||
kout: &mut [u32],
|
kout: &mut [u32],
|
||||||
@@ -486,10 +512,10 @@ unsafe fn emit_bucket(
|
|||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
{
|
{
|
||||||
if is_x86_feature_detected!("avx2") {
|
if is_x86_feature_detected!("avx2") {
|
||||||
return emit_bucket_avx2(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out);
|
return emit_bucket_avx2(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
emit_bucket_scalar(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out)
|
emit_bucket_scalar(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Group `n` entries by their leading block, then emit one child per colliding
|
/// Group `n` entries by their leading block, then emit one child per colliding
|
||||||
@@ -514,17 +540,17 @@ fn collide(
|
|||||||
let prof = std::env::var_os("EQ_PROFILE").is_some();
|
let prof = std::env::var_os("EQ_PROFILE").is_some();
|
||||||
let t0 = std::time::Instant::now();
|
let t0 = std::time::Instant::now();
|
||||||
|
|
||||||
let (starts, order) = partition_top(keys, n);
|
let (starts, order, keys_part) = partition_top(keys, n);
|
||||||
let t_part = std::time::Instant::now();
|
let t_part = std::time::Instant::now();
|
||||||
|
|
||||||
// Pass 1: per-partition child counts (histogram-derived, no reordering).
|
// Pass 1: per-partition child counts from each partition's contiguous keys.
|
||||||
let counts: Vec<usize> = (0..TOP_BUCKETS)
|
let counts: Vec<usize> = (0..TOP_BUCKETS)
|
||||||
.into_par_iter()
|
.into_par_iter()
|
||||||
.map_init(
|
.map_init(
|
||||||
|| vec![0u32; LOW_BUCKETS],
|
|| vec![0u32; LOW_BUCKETS],
|
||||||
|hist, v| {
|
|hist, v| {
|
||||||
let run = &order[starts[v] as usize..starts[v + 1] as usize];
|
let keys_run = &keys_part[starts[v] as usize..starts[v + 1] as usize];
|
||||||
count_pairs(keys, run, hist, clamp)
|
count_pairs(keys_run, hist, clamp)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.collect();
|
.collect();
|
||||||
@@ -569,11 +595,14 @@ fn collide(
|
|||||||
.zip(pparts)
|
.zip(pparts)
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.for_each(|(v, ((kout, sout), pout))| {
|
.for_each(|(v, ((kout, sout), pout))| {
|
||||||
let run = &order[starts[v] as usize..starts[v + 1] as usize];
|
let lo = starts[v] as usize;
|
||||||
|
let hi = starts[v + 1] as usize;
|
||||||
let mut hist = vec![0u32; LOW_BUCKETS + 1];
|
let mut hist = vec![0u32; LOW_BUCKETS + 1];
|
||||||
let mut sorted = Vec::new();
|
let mut sorted = Vec::new();
|
||||||
low_group(keys, run, &mut hist, &mut sorted);
|
let mut keys_sorted = Vec::new();
|
||||||
let w = unsafe { emit_bucket(keys, slots, &sorted, kout, sout, pout, clamp, w_in, w_out) };
|
low_group(&keys_part[lo..hi], &order[lo..hi], &mut hist, &mut sorted, &mut keys_sorted);
|
||||||
|
let w =
|
||||||
|
unsafe { emit_bucket(&keys_sorted, slots, &sorted, kout, sout, pout, clamp, w_in, w_out) };
|
||||||
debug_assert_eq!(w, kout.len());
|
debug_assert_eq!(w, kout.len());
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -595,22 +624,23 @@ fn collide(
|
|||||||
/// leading block `w0`, a pair whose `w1` also matches XORs the last two blocks to
|
/// leading block `w0`, a pair whose `w1` also matches XORs the last two blocks to
|
||||||
/// zero — a candidate. Returns the `(l, mr)` parents of each candidate.
|
/// zero — a candidate. Returns the `(l, mr)` parents of each candidate.
|
||||||
fn collide_final(keys: &[u32], slots: &[u32], n: usize, clamp: usize, w_in: usize) -> Vec<(u32, u32)> {
|
fn collide_final(keys: &[u32], slots: &[u32], n: usize, clamp: usize, w_in: usize) -> Vec<(u32, u32)> {
|
||||||
let (starts, order) = partition_top(keys, n);
|
let (starts, order, keys_part) = partition_top(keys, n);
|
||||||
|
|
||||||
(0..TOP_BUCKETS)
|
(0..TOP_BUCKETS)
|
||||||
.into_par_iter()
|
.into_par_iter()
|
||||||
.map_init(
|
.map_init(
|
||||||
|| (vec![0u32; LOW_BUCKETS + 1], Vec::<u32>::new()),
|
|| (vec![0u32; LOW_BUCKETS + 1], Vec::<u32>::new(), Vec::<u32>::new()),
|
||||||
|(hist, sorted), v| {
|
|(hist, sorted, keys_sorted), v| {
|
||||||
let run = &order[starts[v] as usize..starts[v + 1] as usize];
|
let lo = starts[v] as usize;
|
||||||
low_group(keys, run, hist, sorted);
|
let hi = starts[v + 1] as usize;
|
||||||
|
low_group(&keys_part[lo..hi], &order[lo..hi], hist, sorted, keys_sorted);
|
||||||
let m = sorted.len();
|
let m = sorted.len();
|
||||||
let mut local = Vec::new();
|
let mut local = Vec::new();
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
while i < m {
|
while i < m {
|
||||||
let key = keys[sorted[i] as usize] & LOW_MASK;
|
let key = keys_sorted[i] & LOW_MASK;
|
||||||
let mut j = i + 1;
|
let mut j = i + 1;
|
||||||
while j < m && (keys[sorted[j] as usize] & LOW_MASK) == key {
|
while j < m && (keys_sorted[j] & LOW_MASK) == key {
|
||||||
j += 1;
|
j += 1;
|
||||||
}
|
}
|
||||||
let hi = j.min(i.saturating_add(clamp));
|
let hi = j.min(i.saturating_add(clamp));
|
||||||
|
|||||||
Reference in New Issue
Block a user