make loops match

commit c7909bc112 (parent e80cd09415), committed 2026-04-12 19:18:14 -07:00
This commit is contained in: [branch/tag list lost in page extraction]

View File

@@ -211,27 +211,28 @@ pub fn matmul_q4k_fp16(
let blocks_per_row = k / QK_K;
let mut c = vec![0.0f32; m * n];
// Scratch buffers allocated once and reused across rows.
let mut a_row = vec![0.0f32; k];
// Scratch buffer for one dequantised block, reused across iterations.
let mut block_buf = [0.0f32; QK_K];
for i in 0..m {
// Step 1: dequantise row i of A into a_row (f32).
for b_idx in 0..blocks_per_row {
let block = &a[i * blocks_per_row + b_idx];
dequantize_block_q4k(block, &mut block_buf);
let start = b_idx * QK_K;
a_row[start..start + QK_K].copy_from_slice(&block_buf);
}
let c_row = &mut c[i * n..(i + 1) * n];
// Step 2: for each output column j, compute dot(a_row, B[:, j]).
for j in 0..n {
let mut sum = 0.0f32;
for ki in 0..k {
// B is row-major [K, N]: element (ki, j) → index ki*n+j.
sum += a_row[ki] * fp16_to_f32(b[ki * n + j]);
// Dequantise one block at a time and saxpy its weights directly into
// c_row. The inner loop order (weight-outer, column-inner) keeps each
// B row in a contiguous stride-1 access, which is more cache-friendly
// than the alternative (column-outer, weight-inner) that jumps by N
// between consecutive B reads.
for b_idx in 0..blocks_per_row {
let block = &a[i * blocks_per_row + b_idx];
let ki_base = b_idx * QK_K;
dequantize_block_q4k(block, &mut block_buf);
for l in 0..QK_K {
let w = block_buf[l];
let b_off = (ki_base + l) * n;
for j in 0..n {
c_row[j] += w * fp16_to_f32(b[b_off + j]);
}
}
c[i * n + j] = sum;
}
}