optimize rle a bit
This commit is contained in:
143
src/rle.rs
143
src/rle.rs
@@ -253,13 +253,113 @@ pub fn dequantize_block_q4k_rle(block: &BlockQ4KRle, out: &mut [f32; QK_K]) {
|
||||
// Matrix multiplication C = A × B
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Accumulate the contribution of one RLE-encoded block into `c_row`.
|
||||
///
|
||||
/// For each `(value, count)` pair the dequantised weight is constant within
|
||||
/// every 32-byte sub-block group, so the per-output-column dot-product
|
||||
/// contribution reduces from `2 * run_len` multiplies to just `2`:
|
||||
///
|
||||
/// ```text
|
||||
/// original: Σ_{l} ( dq_lo * B[ki_lo+l, j] + dq_hi * B[ki_hi+l, j] )
|
||||
///
|
||||
/// optimised: dq_lo * Σ_{l} B[ki_lo+l, j] + dq_hi * Σ_{l} B[ki_hi+l, j]
|
||||
/// ```
|
||||
///
|
||||
/// A run that crosses a 32-byte group boundary (and thus a scale/min change)
|
||||
/// is split at the boundary; each resulting segment is handled independently.
|
||||
///
|
||||
/// `sum_lo` and `sum_hi` are caller-provided scratch slices (length `≥ n`)
|
||||
/// reused across calls to avoid repeated allocation.
|
||||
fn accumulate_rle_block(
|
||||
block: &BlockQ4KRle,
|
||||
b: &[u16],
|
||||
ki_base: usize, // first B-row index for this block (= b_idx * QK_K)
|
||||
n: usize,
|
||||
c_row: &mut [f32],
|
||||
sum_lo: &mut [f32],
|
||||
sum_hi: &mut [f32],
|
||||
) {
|
||||
let d = fp16_to_f32(block.d);
|
||||
let dmin = fp16_to_f32(block.dmin);
|
||||
|
||||
let mut byte_pos = 0usize; // running cursor into the 128-byte qs payload
|
||||
|
||||
for p in 0..block.rle_len() {
|
||||
let val = block.qs[2 * p];
|
||||
let run = block.qs[2 * p + 1] as usize;
|
||||
let lo = (val & 0x0F) as f32;
|
||||
let hi = (val >> 4) as f32;
|
||||
|
||||
let mut remaining = run;
|
||||
let mut pos = byte_pos;
|
||||
|
||||
while remaining > 0 {
|
||||
// Clip the current run to the boundary of the 32-byte group so
|
||||
// that the sub-block scale/min stays constant over the segment.
|
||||
let group = pos / 32; // 0..4
|
||||
let in_group = pos % 32; // byte offset within this group
|
||||
let seg_len = remaining.min((group + 1) * 32 - pos);
|
||||
|
||||
// Constant dequantised values for both nibble levels in this group.
|
||||
let (sc_lo, mn_lo) = get_scale_min(group * 2, &block.scales);
|
||||
let (sc_hi, mn_hi) = get_scale_min(group * 2 + 1, &block.scales);
|
||||
let dq_lo = d * sc_lo as f32 * lo - dmin * mn_lo as f32;
|
||||
let dq_hi = d * sc_hi as f32 * hi - dmin * mn_hi as f32;
|
||||
|
||||
// Map byte positions to dequantised-output indices (0..QK_K):
|
||||
// lo nibbles → group*64 + in_group .. + seg_len
|
||||
// hi nibbles → group*64 + 32 + in_group .. + seg_len
|
||||
let out_lo = group * 64 + in_group;
|
||||
let out_hi = group * 64 + 32 + in_group;
|
||||
|
||||
// Sum B rows for every j across the segment (B accessed stride-1
|
||||
// within each row — cache-friendly).
|
||||
sum_lo[..n].fill(0.0);
|
||||
sum_hi[..n].fill(0.0);
|
||||
for l in 0..seg_len {
|
||||
let base_lo = (ki_base + out_lo + l) * n;
|
||||
let base_hi = (ki_base + out_hi + l) * n;
|
||||
for j in 0..n {
|
||||
sum_lo[j] += fp16_to_f32(b[base_lo + j]);
|
||||
sum_hi[j] += fp16_to_f32(b[base_hi + j]);
|
||||
}
|
||||
}
|
||||
|
||||
// One multiply per output column instead of one per weight element.
|
||||
for j in 0..n {
|
||||
c_row[j] += dq_lo * sum_lo[j] + dq_hi * sum_hi[j];
|
||||
}
|
||||
|
||||
pos += seg_len;
|
||||
remaining -= seg_len;
|
||||
}
|
||||
|
||||
byte_pos += run;
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiply a Q4_K_RLE matrix **A** by an FP16 matrix **B**, producing an f32
|
||||
/// matrix **C**.
|
||||
///
|
||||
/// Identical semantics to [`crate::matmul_q4k_fp16`] but accepts
/// [`BlockQ4KRle`] blocks, transparently handling mixed raw/RLE blocks
/// within the same matrix; only raw-mode blocks are dequantised on the
/// fly via [`dequantize_block_q4k_rle`].
|
||||
/// For blocks in **RLE mode** (`IS_RLE = 1`) the intermediate decompressed row
|
||||
/// is eliminated entirely. [`accumulate_rle_block`] works directly over the
|
||||
/// `(value, count)` pairs: within each run the dequantised weight is constant
|
||||
/// across all elements in the run, so each output column `j` requires only
|
||||
/// **2 multiplies per group-segment** rather than 2 per weight element:
|
||||
///
|
||||
/// ```text
|
||||
/// c[i, j] += dq_lo * Σ B[ki_lo, j] + dq_hi * Σ B[ki_hi, j]
|
||||
/// ───────────────────────────────────────────
|
||||
/// summed over seg_len consecutive positions
|
||||
/// ```
|
||||
///
|
||||
/// For a single-run block (all bytes identical) this reduces the multiply
|
||||
/// count from `2 * QK_K = 512` to `2 * 4 = 8` per output column (4 groups,
|
||||
/// 2 nibble levels each), while B is still read exactly once.
|
||||
///
|
||||
/// For blocks in **raw mode** (`IS_RLE = 0`) the block is dequantised into a
|
||||
/// scratch buffer and its contribution is accumulated via a saxpy loop
|
||||
/// (weight-outer, column-inner), which accesses B in row-major order.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -309,25 +409,38 @@ pub fn matmul_q4k_rle_fp16(
|
||||
);
|
||||
|
||||
let mut c = vec![0.0f32; m * n];
|
||||
let mut a_row = vec![0.0f32; k];
|
||||
|
||||
// Scratch for raw-mode block dequantisation.
|
||||
let mut block_buf = [0.0f32; QK_K];
|
||||
// Scratch for RLE-mode B-column sums; allocated once and reused per segment.
|
||||
let mut sum_lo = vec![0.0f32; n];
|
||||
let mut sum_hi = vec![0.0f32; n];
|
||||
|
||||
for i in 0..m {
|
||||
// Dequantise row i of A into a_row (f32).
|
||||
let c_row = &mut c[i * n..(i + 1) * n];
|
||||
|
||||
for b_idx in 0..blocks_per_row {
|
||||
let block = &a[i * blocks_per_row + b_idx];
|
||||
dequantize_block_q4k_rle(block, &mut block_buf);
|
||||
let start = b_idx * QK_K;
|
||||
a_row[start..start + QK_K].copy_from_slice(&block_buf);
|
||||
}
|
||||
let ki_base = b_idx * QK_K;
|
||||
|
||||
// Dot-product with each column of B.
|
||||
if block.is_rle() {
|
||||
// RLE path: accumulate directly from runs, no decompression.
|
||||
accumulate_rle_block(
|
||||
block, b, ki_base, n, c_row,
|
||||
&mut sum_lo, &mut sum_hi,
|
||||
);
|
||||
} else {
|
||||
// Raw path: dequantise once, then saxpy into c_row.
|
||||
// Outer loop over weights (l) keeps B access stride-1 per row.
|
||||
dequantize_block_q4k_rle(block, &mut block_buf);
|
||||
for l in 0..QK_K {
|
||||
let w = block_buf[l];
|
||||
let b_off = (ki_base + l) * n;
|
||||
for j in 0..n {
|
||||
let mut sum = 0.0f32;
|
||||
for ki in 0..k {
|
||||
sum += a_row[ki] * fp16_to_f32(b[ki * n + j]);
|
||||
c_row[j] += w * fp16_to_f32(b[b_off + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
c[i * n + j] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user