diff --git a/src/lib.rs b/src/lib.rs index 8b2a21e..ded81b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -211,27 +211,28 @@ pub fn matmul_q4k_fp16( let blocks_per_row = k / QK_K; let mut c = vec![0.0f32; m * n]; - // Scratch buffers allocated once and reused across rows. - let mut a_row = vec![0.0f32; k]; + // Scratch buffer for one dequantised block, reused across iterations. let mut block_buf = [0.0f32; QK_K]; for i in 0..m { - // Step 1: dequantise row i of A into a_row (f32). - for b_idx in 0..blocks_per_row { - let block = &a[i * blocks_per_row + b_idx]; - dequantize_block_q4k(block, &mut block_buf); - let start = b_idx * QK_K; - a_row[start..start + QK_K].copy_from_slice(&block_buf); - } + let c_row = &mut c[i * n..(i + 1) * n]; - // Step 2: for each output column j, compute dot(a_row, B[:, j]). - for j in 0..n { - let mut sum = 0.0f32; - for ki in 0..k { - // B is row-major [K, N]: element (ki, j) → index ki*n+j. - sum += a_row[ki] * fp16_to_f32(b[ki * n + j]); + // Dequantise one block at a time and saxpy its weights directly into + // c_row. The inner loop order (weight-outer, column-inner) keeps each + // B row in a contiguous stride-1 access, which is more cache-friendly + // than the alternative (column-outer, weight-inner) that jumps by N + // between consecutive B reads. + for b_idx in 0..blocks_per_row { + let block = &a[i * blocks_per_row + b_idx]; + let ki_base = b_idx * QK_K; + dequantize_block_q4k(block, &mut block_buf); + for l in 0..QK_K { + let w = block_buf[l]; + let b_off = (ki_base + l) * n; + for j in 0..n { + c_row[j] += w * fp16_to_f32(b[b_off + j]); + } } - c[i * n + j] = sum; } }