diff --git a/src/rle.rs b/src/rle.rs index 4c6788a..e3e6cc5 100644 --- a/src/rle.rs +++ b/src/rle.rs @@ -253,13 +253,113 @@ pub fn dequantize_block_q4k_rle(block: &BlockQ4KRle, out: &mut [f32; QK_K]) { // Matrix multiplication C = A × B // --------------------------------------------------------------------------- +/// Accumulate the contribution of one RLE-encoded block into `c_row`. +/// +/// For each `(value, count)` pair the dequantised weight is constant within +/// every 32-byte sub-block group, so the per-output-column dot-product +/// contribution reduces from `2 * run_len` multiplies to just `2`: +/// +/// ```text +/// original: Σ_{l} ( dq_lo * B[ki_lo+l, j] + dq_hi * B[ki_hi+l, j] ) +/// +/// optimised: dq_lo * Σ_{l} B[ki_lo+l, j] + dq_hi * Σ_{l} B[ki_hi+l, j] +/// ``` +/// +/// A run that crosses a 32-byte group boundary (and thus a scale/min change) +/// is split at the boundary; each resulting segment is handled independently. +/// +/// `sum_lo` and `sum_hi` are caller-provided scratch slices (length `≥ n`) +/// reused across calls to avoid repeated allocation. +fn accumulate_rle_block( + block: &BlockQ4KRle, + b: &[u16], + ki_base: usize, // first B-row index for this block (= b_idx * QK_K) + n: usize, + c_row: &mut [f32], + sum_lo: &mut [f32], + sum_hi: &mut [f32], +) { + let d = fp16_to_f32(block.d); + let dmin = fp16_to_f32(block.dmin); + + let mut byte_pos = 0usize; // running cursor into the 128-byte qs payload + + for p in 0..block.rle_len() { + let val = block.qs[2 * p]; + let run = block.qs[2 * p + 1] as usize; + let lo = (val & 0x0F) as f32; + let hi = (val >> 4) as f32; + + let mut remaining = run; + let mut pos = byte_pos; + + while remaining > 0 { + // Clip the current run to the boundary of the 32-byte group so + // that the sub-block scale/min stays constant over the segment. 
+ let group = pos / 32; // 0..4 + let in_group = pos % 32; // byte offset within this group + let seg_len = remaining.min((group + 1) * 32 - pos); + + // Constant dequantised values for both nibble levels in this group. + let (sc_lo, mn_lo) = get_scale_min(group * 2, &block.scales); + let (sc_hi, mn_hi) = get_scale_min(group * 2 + 1, &block.scales); + let dq_lo = d * sc_lo as f32 * lo - dmin * mn_lo as f32; + let dq_hi = d * sc_hi as f32 * hi - dmin * mn_hi as f32; + + // Map byte positions to dequantised-output indices (0..QK_K): + // lo nibbles → group*64 + in_group .. + seg_len + // hi nibbles → group*64 + 32 + in_group .. + seg_len + let out_lo = group * 64 + in_group; + let out_hi = group * 64 + 32 + in_group; + + // Sum B rows for every j across the segment (B accessed stride-1 + // within each row — cache-friendly). + sum_lo[..n].fill(0.0); + sum_hi[..n].fill(0.0); + for l in 0..seg_len { + let base_lo = (ki_base + out_lo + l) * n; + let base_hi = (ki_base + out_hi + l) * n; + for j in 0..n { + sum_lo[j] += fp16_to_f32(b[base_lo + j]); + sum_hi[j] += fp16_to_f32(b[base_hi + j]); + } + } + + // One multiply per output column instead of one per weight element. + for j in 0..n { + c_row[j] += dq_lo * sum_lo[j] + dq_hi * sum_hi[j]; + } + + pos += seg_len; + remaining -= seg_len; + } + + byte_pos += run; + } +} + /// Multiply a Q4_K_RLE matrix **A** by an FP16 matrix **B**, producing an f32 /// matrix **C**. /// -/// Identical semantics to [`crate::matmul_q4k_fp16`] but accepts -/// [`BlockQ4KRle`] blocks. Each block is dequantised on the fly via -/// [`dequantize_block_q4k_rle`], transparently handling mixed raw/RLE blocks -/// within the same matrix. +/// For blocks in **RLE mode** (`IS_RLE = 1`) the intermediate decompressed row +/// is eliminated entirely. 
[`accumulate_rle_block`] works directly over the +/// `(value, count)` pairs: within each run the dequantised weight is constant +/// across all elements in the run, so each output column `j` requires only +/// **2 multiplies per group-segment** rather than one per weight element: +/// +/// ```text +/// c[i, j] += dq_lo * Σ B[ki_lo, j] + dq_hi * Σ B[ki_hi, j] +/// ─────────────────────────────────────────── +/// summed over seg_len consecutive positions +/// ``` +/// +/// For a single-run block (all bytes identical) this reduces the multiply +/// count from `2 * 128 = 256` to `2 * 4 = 8` per output column (4 groups, +/// 2 nibble levels each), while B is still read exactly once. +/// +/// For blocks in **raw mode** (`IS_RLE = 0`) the block is dequantised into a +/// scratch buffer and its contribution is accumulated via a saxpy loop +/// (weight-outer, column-inner), which accesses B in row-major order. /// /// # Arguments /// @@ -308,26 +308,39 @@ pub fn matmul_q4k_rle_fp16( b.len() ); - let mut c = vec![0.0f32; m * n]; - let mut a_row = vec![0.0f32; k]; + let mut c = vec![0.0f32; m * n]; + + // Scratch for raw-mode block dequantisation. let mut block_buf = [0.0f32; QK_K]; + // Scratch for RLE-mode B-column sums; allocated once and reused per segment. + let mut sum_lo = vec![0.0f32; n]; + let mut sum_hi = vec![0.0f32; n]; for i in 0..m { - // Dequantise row i of A into a_row (f32). - for b_idx in 0..blocks_per_row { - let block = &a[i * blocks_per_row + b_idx]; - dequantize_block_q4k_rle(block, &mut block_buf); - let start = b_idx * QK_K; - a_row[start..start + QK_K].copy_from_slice(&block_buf); - } + let c_row = &mut c[i * n..(i + 1) * n]; - // Dot-product with each column of B. 
- for j in 0..n { - let mut sum = 0.0f32; - for ki in 0..k { - sum += a_row[ki] * fp16_to_f32(b[ki * n + j]); + for b_idx in 0..blocks_per_row { + let block = &a[i * blocks_per_row + b_idx]; + let ki_base = b_idx * QK_K; + + if block.is_rle() { + // RLE path: accumulate directly from runs, no decompression. + accumulate_rle_block( + block, b, ki_base, n, c_row, + &mut sum_lo, &mut sum_hi, + ); + } else { + // Raw path: dequantise once, then saxpy into c_row. + // Outer loop over weights (l) keeps B access stride-1 per row. + dequantize_block_q4k_rle(block, &mut block_buf); + for l in 0..QK_K { + let w = block_buf[l]; + let b_off = (ki_base + l) * n; + for j in 0..n { + c_row[j] += w * fp16_to_f32(b[b_off + j]); + } + } } - c[i * n + j] = sum; } }