diff --git a/src/equihash.rs b/src/equihash.rs index b4eae08..5121a60 100644 --- a/src/equihash.rs +++ b/src/equihash.rs @@ -274,24 +274,30 @@ unsafe impl Send for OrderPtr {} unsafe impl Sync for OrderPtr {} /// Partition the `n` entries into `TOP_BUCKETS` runs by the high `TOP_BITS` of -/// their (dense) leading block. Returns `(starts, order)`, where partition `v` -/// owns the input indices `order[starts[v]..starts[v + 1]]`. The histogram -/// passes stream over `keys[]` (4 bytes/entry) instead of striding the slots. +/// their (dense) leading block. Returns `(starts, order, keys_part)`: partition +/// `v` owns the input indices `order[starts[v]..starts[v + 1]]`, and `keys_part` +/// is the entries' leading words in the same partition-contiguous order +/// (`keys_part[p] == keys[order[p]]`). Carrying `keys_part` lets the per-partition +/// `count_pairs`/`low_group` sweeps read keys sequentially instead of gathering +/// `keys[order[..]]` over the whole array — at the cost of one extra 4-byte/entry +/// scatter here, folded into the (already parallel) phase-3 pass. /// /// Parallel counting sort: the input is split into one contiguous chunk per /// rayon worker. Each chunk histograms its slice (phase 1), a small serial pass /// turns those into per-chunk base offsets within each bucket's output region -/// (phase 2), and each chunk scatters its entries into `order` (phase 3). Chunk -/// `c`'s bucket-`b` writes land in `[off[c][b], off[c+1][b])`, disjoint from -/// every other chunk and bucket, so the concurrent writes never alias. Entries -/// within a bucket end up chunk-major rather than index-major; that reordering -/// is immaterial — `count_pairs`/`low_group` depend only on the multiset of low -/// keys, and final solutions are canonicalised, de-duplicated, and verified. -fn partition_top(keys: &[u32], n: usize) -> (Vec, Vec) { +/// (phase 2), and each chunk scatters its entries into `order`/`keys_part` +/// (phase 3). Chunk `c`'s bucket-`b` writes land in `[off[c][b], off[c+1][b])`, +/// disjoint from every other chunk and bucket, so the concurrent writes never +/// alias. Entries within a bucket end up chunk-major rather than index-major; +/// that reordering is immaterial — `count_pairs`/`low_group` depend only on the +/// multiset of low keys, and final solutions are canonicalised, de-duplicated, +/// and verified. +fn partition_top(keys: &[u32], n: usize) -> (Vec, Vec, Vec) { let mut starts = vec![0u32; TOP_BUCKETS + 1]; let mut order = vec![0u32; n]; + let mut keys_part = vec![0u32; n]; if n == 0 { - return (starts, order); + return (starts, order, keys_part); } let nthreads = rayon::current_num_threads().max(1); @@ -334,33 +340,39 @@ fn partition_top(keys: &[u32], n: usize) -> (Vec, Vec) { // Phase 3 (parallel): each chunk scatters into its disjoint sub-ranges. let optr = OrderPtr(order.as_mut_ptr()); + let kptr = OrderPtr(keys_part.as_mut_ptr()); offsets.into_par_iter().enumerate().for_each(|(c, mut cur)| { - let optr = optr; // capture the whole (Sync) wrapper, not the inner ptr + let (optr, kptr) = (optr, kptr); // capture whole (Sync) wrappers let lo = c * chunk; let hi = ((c + 1) * chunk).min(n); - let base = optr.0; + let (obase, kbase) = (optr.0, kptr.0); for k in lo..hi { - let b = (keys[k] >> LOW_BITS) as usize; - // SAFETY: `cur[b]` ranges over `[off[c][b], off[c+1][b])`, a range - // owned exclusively by chunk `c` and within `order`'s bounds. - unsafe { *base.add(cur[b] as usize) = k as u32 }; + let key = keys[k]; + let b = (key >> LOW_BITS) as usize; + let pos = cur[b] as usize; + // SAFETY: `pos` ranges over `[off[c][b], off[c+1][b])`, a range owned + // exclusively by chunk `c` and within `order`/`keys_part` bounds. + unsafe { + *obase.add(pos) = k as u32; + *kbase.add(pos) = key; + } cur[b] += 1; } }); - (starts, order) + (starts, order, keys_part) } /// Count the colliding pairs a partition will emit, from the low-bit histogram -/// alone (no reordering). `clamp` caps each exact-collision group, matching the -/// emit pass so the output offsets line up. `hist` is reusable `LOW_BUCKETS` -/// scratch. -fn count_pairs(keys: &[u32], run: &[u32], hist: &mut [u32], clamp: usize) -> usize { +/// of its (partition-contiguous) leading keys `keys_run`. `clamp` caps each +/// exact-collision group, matching the emit pass so the output offsets line up. +/// `hist` is reusable `LOW_BUCKETS` scratch. +fn count_pairs(keys_run: &[u32], hist: &mut [u32], clamp: usize) -> usize { for h in hist.iter_mut() { *h = 0; } - for &k in run { - hist[(keys[k as usize] & LOW_MASK) as usize] += 1; + for &key in keys_run { + hist[(key & LOW_MASK) as usize] += 1; } let mut pairs = 0usize; for i in 0..LOW_BUCKETS { @@ -370,31 +382,44 @@ fn count_pairs(keys: &[u32], run: &[u32], hist: &mut [u32], clamp: usize) -> usi pairs } -/// Within one partition, group `run`'s entries by the low bits of their leading -/// block, writing the grouped indices into `sorted`. `hist` is reusable -/// `LOW_BUCKETS + 1` scratch. After this call `sorted` lists the run's indices -/// with equal low keys contiguous, so callers recover each exact-collision -/// group by walking adjacent equal keys. -fn low_group(keys: &[u32], run: &[u32], hist: &mut [u32], sorted: &mut Vec) { - let m = run.len(); +/// Within one partition, group its entries by the low bits of their leading +/// block. Inputs are the partition-contiguous slabs `keys_run` (leading keys) and +/// `order_run` (matching global indices). Outputs, in low-key-sorted order: +/// `sorted` (the global indices, for the emit slot gather + back-refs) and +/// `keys_sorted` (the leading keys, so the emit group walk streams a dense local +/// array instead of gathering `keys[sorted[i]]`). `hist` is reusable +/// `LOW_BUCKETS + 1` scratch. Both reads are sequential over the slabs. +fn low_group( + keys_run: &[u32], + order_run: &[u32], + hist: &mut [u32], + sorted: &mut Vec, + keys_sorted: &mut Vec, +) { + let m = order_run.len(); sorted.clear(); + keys_sorted.clear(); if m == 0 { return; } for h in hist.iter_mut() { *h = 0; } - for &k in run { - hist[(keys[k as usize] & LOW_MASK) as usize + 1] += 1; + for &key in keys_run { + hist[(key & LOW_MASK) as usize + 1] += 1; } for i in 0..LOW_BUCKETS { hist[i + 1] += hist[i]; } sorted.resize(m, 0); + keys_sorted.resize(m, 0); // hist[low] now holds the run-start offset; reuse it as the live cursor. - for &k in run { - let low = (keys[k as usize] & LOW_MASK) as usize; - sorted[hist[low] as usize] = k; + for i in 0..m { + let key = keys_run[i]; + let low = (key & LOW_MASK) as usize; + let pos = hist[low] as usize; + sorted[pos] = order_run[i]; + keys_sorted[pos] = key; hist[low] += 1; } } @@ -405,17 +430,18 @@ fn low_group(keys: &[u32], run: &[u32], hist: &mut [u32], sorted: &mut Vec) /// Monomorphised over the XOR producer so the AVX2 intrinsics inline cleanly /// inside a `target_feature` wrapper while sharing one source of truth. macro_rules! emit_bucket_body { - ($keys:expr, $slots:expr, $sorted:expr, $kout:expr, $sout:expr, $pout:expr, $clamp:expr, $w_in:expr, $w_out:expr, $xor:path) => {{ + ($keys_sorted:expr, $slots:expr, $sorted:expr, $kout:expr, $sout:expr, $pout:expr, $clamp:expr, $w_in:expr, $w_out:expr, $xor:path) => {{ let s = $sorted; + let ks = $keys_sorted; // leading keys in `s`-order; group walk streams it let m = s.len(); let w_in = $w_in; let w_out = $w_out; let mut w = 0usize; let mut i = 0; while i < m { - let key = $keys[s[i] as usize] & LOW_MASK; + let key = ks[i] & LOW_MASK; let mut j = i + 1; - while j < m && ($keys[s[j] as usize] & LOW_MASK) == key { + while j < m && (ks[j] & LOW_MASK) == key { j += 1; } let hi = j.min(i.saturating_add($clamp)); @@ -441,7 +467,7 @@ macro_rules! emit_bucket_body { } unsafe fn emit_bucket_scalar( - keys: &[u32], + keys_sorted: &[u32], slots: &[u32], sorted: &[u32], kout: &mut [u32], @@ -451,13 +477,13 @@ unsafe fn emit_bucket_scalar( w_in: usize, w_out: usize, ) -> usize { - emit_bucket_body!(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_scalar) + emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_scalar) } #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx2")] unsafe fn emit_bucket_avx2( - keys: &[u32], + keys_sorted: &[u32], slots: &[u32], sorted: &[u32], kout: &mut [u32], @@ -467,13 +493,13 @@ unsafe fn emit_bucket_avx2( w_in: usize, w_out: usize, ) -> usize { - emit_bucket_body!(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_avx2) + emit_bucket_body!(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out, xor_child_avx2) } /// Emit a partition's children, dispatching to the AVX2 producer when available. /// `w_in`/`w_out` are the input/output slot pitches (`w_out == w_in - 1`). unsafe fn emit_bucket( - keys: &[u32], + keys_sorted: &[u32], slots: &[u32], sorted: &[u32], kout: &mut [u32], @@ -486,10 +512,10 @@ unsafe fn emit_bucket( #[cfg(target_arch = "x86_64")] { if is_x86_feature_detected!("avx2") { - return emit_bucket_avx2(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out); + return emit_bucket_avx2(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out); } } - emit_bucket_scalar(keys, slots, sorted, kout, sout, pout, clamp, w_in, w_out) + emit_bucket_scalar(keys_sorted, slots, sorted, kout, sout, pout, clamp, w_in, w_out) } /// Group `n` entries by their leading block, then emit one child per colliding @@ -514,17 +540,17 @@ fn collide( let prof = std::env::var_os("EQ_PROFILE").is_some(); let t0 = std::time::Instant::now(); - let (starts, order) = partition_top(keys, n); + let (starts, order, keys_part) = partition_top(keys, n); let t_part = std::time::Instant::now(); - // Pass 1: per-partition child counts (histogram-derived, no reordering). + // Pass 1: per-partition child counts from each partition's contiguous keys. let counts: Vec = (0..TOP_BUCKETS) .into_par_iter() .map_init( || vec![0u32; LOW_BUCKETS], |hist, v| { - let run = &order[starts[v] as usize..starts[v + 1] as usize]; - count_pairs(keys, run, hist, clamp) + let keys_run = &keys_part[starts[v] as usize..starts[v + 1] as usize]; + count_pairs(keys_run, hist, clamp) }, ) .collect(); @@ -569,11 +595,14 @@ fn collide( .zip(pparts) .enumerate() .for_each(|(v, ((kout, sout), pout))| { - let run = &order[starts[v] as usize..starts[v + 1] as usize]; + let lo = starts[v] as usize; + let hi = starts[v + 1] as usize; let mut hist = vec![0u32; LOW_BUCKETS + 1]; let mut sorted = Vec::new(); - low_group(keys, run, &mut hist, &mut sorted); - let w = unsafe { emit_bucket(keys, slots, &sorted, kout, sout, pout, clamp, w_in, w_out) }; + let mut keys_sorted = Vec::new(); + low_group(&keys_part[lo..hi], &order[lo..hi], &mut hist, &mut sorted, &mut keys_sorted); + let w = + unsafe { emit_bucket(&keys_sorted, slots, &sorted, kout, sout, pout, clamp, w_in, w_out) }; debug_assert_eq!(w, kout.len()); }); @@ -595,22 +624,23 @@ fn collide( /// leading block `w0`, a pair whose `w1` also matches XORs the last two blocks to /// zero — a candidate. Returns the `(l, mr)` parents of each candidate. fn collide_final(keys: &[u32], slots: &[u32], n: usize, clamp: usize, w_in: usize) -> Vec<(u32, u32)> { - let (starts, order) = partition_top(keys, n); + let (starts, order, keys_part) = partition_top(keys, n); (0..TOP_BUCKETS) .into_par_iter() .map_init( - || (vec![0u32; LOW_BUCKETS + 1], Vec::::new()), - |(hist, sorted), v| { - let run = &order[starts[v] as usize..starts[v + 1] as usize]; - low_group(keys, run, hist, sorted); + || (vec![0u32; LOW_BUCKETS + 1], Vec::::new(), Vec::::new()), + |(hist, sorted, keys_sorted), v| { + let lo = starts[v] as usize; + let hi = starts[v + 1] as usize; + low_group(&keys_part[lo..hi], &order[lo..hi], hist, sorted, keys_sorted); let m = sorted.len(); let mut local = Vec::new(); let mut i = 0; while i < m { - let key = keys[sorted[i] as usize] & LOW_MASK; + let key = keys_sorted[i] & LOW_MASK; let mut j = i + 1; - while j < m && (keys[sorted[j] as usize] & LOW_MASK) == key { + while j < m && (keys_sorted[j] & LOW_MASK) == key { j += 1; } let hi = j.min(i.saturating_add(clamp));