From 16d1f37ae53a368254cc636121ec7e6ea1dfd1cc Mon Sep 17 00:00:00 2001 From: charles Date: Sun, 12 Apr 2026 14:45:20 -0700 Subject: [PATCH] Initial commit --- .gitignore | 1 + Cargo.lock | 7 + Cargo.toml | 6 + src/lib.rs | 738 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 82 ++++++ 5 files changed, 834 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 src/lib.rs create mode 100644 src/main.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..2b7cb4d --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "matrix-testing" +version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..e48d2ed --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "matrix-testing" +version = "0.1.0" +edition = "2024" + +[dependencies] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..0a1cb56 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,738 @@ +//! Naive Q4_K_M × FP16 matrix multiplication. +//! +//! Q4_K_M (called `block_q4_K` in GGML) is a 4-bit K-quant format with 256 +//! elements per super-block. Each super-block carries: +//! +//! - `d` (fp16) – super-block scale for the sub-block scales +//! - `dmin` (fp16) – super-block scale for the sub-block mins +//! - `scales` [12 u8] – 8 pairs of (scale, min), each 6-bit, packed together +//! - `qs` [128 u8] – 256 values at 4 bits each (two values per byte) +//! +//! The dequantised value for nibble `q` in sub-block `s` is: +//! +//! ```text +//! d * scales[s] * q - dmin * mins[s] +//! ``` +//! +//! This library deliberately uses the simplest possible O(M·N·K) algorithm: +//! 
// dequantise each row of A into f32, convert each element of B from fp16 to
// f32, accumulate dot-products. No SIMD, no tiling, no tricks.

// ---------------------------------------------------------------------------
// Constants matching GGML's ggml-common.h
// ---------------------------------------------------------------------------

/// Number of elements in one Q4_K super-block.
pub const QK_K: usize = 256;

/// Number of bytes used to store the 8 (scale, min) pairs.
pub const K_SCALE_SIZE: usize = 12;

// ---------------------------------------------------------------------------
// Block definition
// ---------------------------------------------------------------------------

/// One Q4_K super-block, binary-compatible with GGML's `block_q4_K`.
///
/// Memory layout (in order):
///
/// | Offset | Field    | Size  |
/// |--------|----------|-------|
/// | 0      | `d`      | 2 B   |
/// | 2      | `dmin`   | 2 B   |
/// | 4      | `scales` | 12 B  |
/// | 16     | `qs`     | 128 B |
///
/// Total: 144 bytes.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct BlockQ4K {
    /// Super-block scale for the quantised sub-block scales (fp16 bits).
    pub d: u16,
    /// Super-block scale for the quantised sub-block mins (fp16 bits).
    pub dmin: u16,
    /// Packed 6-bit sub-block scales and mins.
    /// 8 scales + 8 mins (6 bits each) encoded into 12 bytes.
    pub scales: [u8; K_SCALE_SIZE],
    /// 4-bit quantised weights: two weights per byte, 128 bytes for 256 values.
    pub qs: [u8; QK_K / 2],
}

// ---------------------------------------------------------------------------
// FP16 → f32 conversion (no external dependencies)
// ---------------------------------------------------------------------------

/// Convert an IEEE 754 half-precision float stored as raw `u16` bits to `f32`.
///
/// Handles all IEEE 754 cases: ±zero, normal numbers, infinity, NaN.
/// Subnormal fp16 values (exponent field = 0, non-zero mantissa) are treated
/// as zero — they are vanishingly small and irrelevant for LLM weights.
#[inline]
pub fn fp16_to_f32(bits: u16) -> f32 {
    let sign = (bits as u32 & 0x8000) << 16; // sign bit → f32 position
    let exp_mant = bits as u32 & 0x7FFF;

    let f32_bits = if (bits & 0x7C00) == 0 {
        // ±zero or subnormal (exponent field = 0) → treat as signed zero.
        sign
    } else if (bits & 0x7C00) == 0x7C00 {
        // Infinity or NaN: all exponent bits set.
        sign | 0x7F80_0000 | ((bits as u32 & 0x03FF) << 13)
    } else {
        // Normal number: rebias exponent from fp16 (bias 15) to f32 (bias 127).
        // Δbias = 127 − 15 = 112 = 112 × 1024 in the 13-shifted representation.
        sign | ((exp_mant + (112 << 10)) << 13)
    };

    f32::from_bits(f32_bits)
}

// ---------------------------------------------------------------------------
// Scale extraction – mirrors GGML's get_scale_min_k4
// ---------------------------------------------------------------------------

/// Extract the 6-bit scale and 6-bit min for sub-block index `j` (0..8).
///
/// GGML packs 8 pairs of 6-bit values into 12 bytes using a two-part scheme:
///
/// **j = 0..3**
/// ```text
/// scale = scales[j] & 0x3F
/// min   = scales[j + 4] & 0x3F
/// ```
///
/// **j = 4..7**
/// ```text
/// scale = (scales[j+4] & 0x0F) | ((scales[j-4] >> 6) << 4)
/// min   = (scales[j+4] >> 4)   | ((scales[j]   >> 6) << 4)
/// ```
#[inline]
pub(crate) fn get_scale_min(j: usize, scales: &[u8; K_SCALE_SIZE]) -> (u8, u8) {
    debug_assert!(j < 8);
    if j < 4 {
        let sc = scales[j] & 63;
        let mn = scales[j + 4] & 63;
        (sc, mn)
    } else {
        let sc = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4);
        let mn = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4);
        (sc, mn)
    }
}

// ---------------------------------------------------------------------------
// Dequantisation
// ---------------------------------------------------------------------------

/// Dequantise one Q4_K super-block into 256 `f32` values.
///
/// The loop mirrors GGML's `dequantize_row_q4_K`:
///
/// ```text
/// for each group of 64 elements (4 groups total):
///     d1, m1 = scale × d, min × dmin for sub-block (is + 0)
///     d2, m2 = scale × d, min × dmin for sub-block (is + 1)
///     out[0..32]  = d1 × lower_nibble(qs[0..32]) − m1
///     out[32..64] = d2 × upper_nibble(qs[0..32]) − m2
///     advance qs by 32 bytes, is by 2
/// ```
pub fn dequantize_block_q4k(block: &BlockQ4K, out: &mut [f32; QK_K]) {
    let d = fp16_to_f32(block.d);
    let dmin = fp16_to_f32(block.dmin);

    let mut q_off = 0usize; // byte cursor into block.qs
    let mut out_off = 0usize; // element cursor into out
    let mut is = 0usize; // sub-block pair index (0, 2, 4, 6)

    while out_off < QK_K {
        let (sc1, mn1) = get_scale_min(is, &block.scales);
        let (sc2, mn2) = get_scale_min(is + 1, &block.scales);

        let d1 = d * sc1 as f32;
        let m1 = dmin * mn1 as f32;
        let d2 = d * sc2 as f32;
        let m2 = dmin * mn2 as f32;

        for l in 0..32 {
            out[out_off + l] = d1 * (block.qs[q_off + l] & 0x0F) as f32 - m1;
        }
        for l in 0..32 {
            out[out_off + 32 + l] = d2 * (block.qs[q_off + l] >> 4) as f32 - m2;
        }

        q_off += 32;
        out_off += 64;
        is += 2;
    }
}

// ---------------------------------------------------------------------------
// Matrix multiplication C = A × B
// ---------------------------------------------------------------------------

/// Multiply a Q4_K_M matrix **A** by an FP16 matrix **B**, producing an f32
/// matrix **C**.
///
/// # Arguments
///
/// * `a` – Row-major slice of [`BlockQ4K`]. Row `i` occupies blocks
///   `a[i * blocks_per_row .. (i+1) * blocks_per_row]`.
/// * `b` – Row-major fp16 matrix stored as raw `u16` bits, shape \[K, N\].
///   Element `(ki, j)` is at index `ki * n + j`.
/// * `m` – Number of rows in A (and C).
/// * `k` – Number of columns in A = number of rows in B.
///   **Must** be a multiple of [`QK_K`] (256).
/// * `n` – Number of columns in B (and C).
///
/// # Returns
///
/// A flat row-major `Vec<f32>` of shape \[M, N\].
///
/// # Panics
///
/// Panics if `k` is not a multiple of `QK_K`, or if the lengths of `a` or `b`
/// do not match the declared dimensions.
pub fn matmul_q4k_fp16(
    a: &[BlockQ4K],
    b: &[u16],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    assert_eq!(k % QK_K, 0,
        "k ({k}) must be a multiple of QK_K ({QK_K})");
    assert_eq!(a.len(), m * (k / QK_K),
        "A block count mismatch: expected {} blocks, got {}", m * (k / QK_K), a.len());
    assert_eq!(b.len(), k * n,
        "B element count mismatch: expected {}, got {}", k * n, b.len());

    let blocks_per_row = k / QK_K;
    let mut c = vec![0.0f32; m * n];

    // Convert B to f32 once up front. Converting inside the row loop would
    // redo every fp16→f32 conversion for each of the M rows of A (M·K·N
    // conversions instead of K·N). The accumulation order is unchanged, so
    // the results are bit-identical to the per-element version.
    let b_f32: Vec<f32> = b.iter().map(|&bits| fp16_to_f32(bits)).collect();

    // Scratch buffers allocated once and reused across rows.
    let mut a_row = vec![0.0f32; k];
    let mut block_buf = [0.0f32; QK_K];

    for i in 0..m {
        // Step 1: dequantise row i of A into a_row (f32).
        for b_idx in 0..blocks_per_row {
            let block = &a[i * blocks_per_row + b_idx];
            dequantize_block_q4k(block, &mut block_buf);
            let start = b_idx * QK_K;
            a_row[start..start + QK_K].copy_from_slice(&block_buf);
        }

        // Step 2: for each output column j, compute dot(a_row, B[:, j]).
        for j in 0..n {
            let mut sum = 0.0f32;
            for ki in 0..k {
                // B is row-major [K, N]: element (ki, j) → index ki*n+j.
                sum += a_row[ki] * b_f32[ki * n + j];
            }
            c[i * n + j] = sum;
        }
    }

    c
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // Test helpers
    // -------------------------------------------------------------------------

    /// Convert a normal finite f32 (including ±0.0, ±inf) to its fp16 bit
    /// pattern. Panics if the value is a non-representable normal number
    /// (magnitude too large or too small for fp16's normal range).
    fn f32_to_fp16_bits(f: f32) -> u16 {
        if f == 0.0 {
            // `f == 0.0` is true for both zeros; keep the sign of -0.0.
            return if f.is_sign_negative() { 0x8000 } else { 0x0000 };
        }
        if f == f32::INFINITY { return 0x7C00; }
        if f == f32::NEG_INFINITY { return 0xFC00; }
        if f.is_nan() { return 0x7E00; }

        let bits = f.to_bits();
        let sign = ((bits >> 31) as u16) << 15;
        let exp = (bits >> 23) & 0xFF;
        let mant = bits & 0x007F_FFFF;
        let fp16_exp = exp as i32 - 127 + 15;
        assert!(
            fp16_exp > 0 && fp16_exp < 31,
            "f32 value {f} is outside the representable fp16 normal range"
        );
        sign | ((fp16_exp as u16) << 10) | ((mant >> 13) as u16)
    }

    /// Build a [`BlockQ4K`] where all 8 sub-blocks share the same `scale` and
    /// `min` (both **must** be < 16 to keep the encoding simple), and every
    /// byte in `qs` is `qs_byte`.
    ///
    /// With scale, min < 16, the scales array encoding simplifies to:
    /// ```text
    /// scales[0..4]  = scale                                 (sub-blocks 0..3)
    /// scales[4..8]  = min                                   (sub-blocks 0..3)
    /// scales[8..12] = (scale & 0xF) | ((min & 0xF) << 4)    (sub-blocks 4..7)
    /// ```
    fn make_block(d: f32, dmin: f32, scale: u8, min: u8, qs_byte: u8) -> BlockQ4K {
        assert!(
            scale < 16 && min < 16,
            "make_block: scale ({scale}) and min ({min}) must both be < 16"
        );
        let mut scales = [0u8; K_SCALE_SIZE];
        for j in 0..4 {
            scales[j] = scale;
            scales[j + 4] = min;
        }
        for j in 8..12 {
            scales[j] = (scale & 0x0F) | ((min & 0x0F) << 4);
        }
        BlockQ4K {
            d: f32_to_fp16_bits(d),
            dmin: f32_to_fp16_bits(dmin),
            scales,
            qs: [qs_byte; QK_K / 2],
        }
    }

    /// Assert two f32 values are within `tol` of each other.
    fn assert_close(got: f32, expected: f32, tol: f32) {
        assert!(
            (got - expected).abs() <= tol,
            "got {got}, expected {expected} (tol {tol})"
        );
    }

    /// Assert every element of `got` equals `expected_scalar` within `tol`.
    fn assert_all_close(got: &[f32], expected_scalar: f32, tol: f32) {
        for (i, &g) in got.iter().enumerate() {
            assert!(
                (g - expected_scalar).abs() <= tol,
                "element {i}: got {g}, expected {expected_scalar} (tol {tol})"
            );
        }
    }

    /// Build a flat fp16 matrix (shape K×N) where every element is `value`.
    fn fp16_uniform(k: usize, n: usize, value: f32) -> Vec<u16> {
        vec![f32_to_fp16_bits(value); k * n]
    }

    // =========================================================================
    // fp16_to_f32
    // =========================================================================

    #[test]
    fn fp16_positive_zero() {
        let v = fp16_to_f32(0x0000);
        assert_eq!(v, 0.0f32);
        assert!(v.is_sign_positive());
    }

    #[test]
    fn fp16_negative_zero() {
        let v = fp16_to_f32(0x8000);
        // IEEE 754: -0.0 == +0.0
        assert_eq!(v, 0.0f32);
        assert!(v.is_sign_negative());
    }

    #[test]
    fn fp16_one_and_negative_one() {
        // 0x3C00: exp=15, mant=0 → (1+0) × 2^0 = 1.0
        assert_eq!(fp16_to_f32(0x3C00), 1.0f32);
        // 0xBC00: sign=1, same magnitude
        assert_eq!(fp16_to_f32(0xBC00), -1.0f32);
    }

    #[test]
    fn fp16_powers_of_two() {
        // 0.5 = 2^-1: exp=14, mant=0 → 0x3800
        assert_eq!(fp16_to_f32(0x3800), 0.5f32);
        // 2.0 = 2^1: exp=16, mant=0 → 0x4000
        assert_eq!(fp16_to_f32(0x4000), 2.0f32);
        // 4.0 = 2^2: exp=17, mant=0 → 0x4400
        assert_eq!(fp16_to_f32(0x4400), 4.0f32);
        // 8.0 = 2^3: exp=18, mant=0 → 0x4800
        assert_eq!(fp16_to_f32(0x4800), 8.0f32);
        // 64.0 = 2^6: exp=21, mant=0 → 0x5400
        assert_eq!(fp16_to_f32(0x5400), 64.0f32);
    }

    #[test]
    fn fp16_non_power_of_two() {
        // 3.0 = 1.5 × 2^1: exp=16, mant=512 → 0x4200
        // 0x4200 = bit14=1, bit9=1 → exp=16, mant=512 → (1+0.5)×2=3.0
        assert_eq!(fp16_to_f32(0x4200), 3.0f32);
        // 10.0 = 1.25 × 2^3: exp=18, mant=256 → 0x4900
        // 0x4900 = bits14-10=10010=18, bits9-0=01_0000_0000=256
        //        → (1+256/1024)×8 = 1.25×8 = 10.0
        assert_eq!(fp16_to_f32(0x4900), 10.0f32);
    }

    #[test]
    fn fp16_positive_infinity() {
        let v = fp16_to_f32(0x7C00);
        assert!(v.is_infinite());
        assert!(v.is_sign_positive());
    }

    #[test]
    fn fp16_negative_infinity() {
        let v = fp16_to_f32(0xFC00);
        assert!(v.is_infinite());
        assert!(v.is_sign_negative());
    }

    #[test]
    fn fp16_nan() {
        // 0x7E00: exp all-ones, non-zero mantissa → NaN
        assert!(fp16_to_f32(0x7E00).is_nan());
        // 0xFE00: signed NaN
        assert!(fp16_to_f32(0xFE00).is_nan());
    }

    #[test]
    fn fp16_subnormals_become_zero() {
        // Subnormals (exp=0, mant≠0) are too small for LLM weights; we
        // return signed zero rather than performing the full decode.
        assert_eq!(fp16_to_f32(0x0001), 0.0f32); // smallest positive subnormal
        assert_eq!(fp16_to_f32(0x03FF), 0.0f32); // largest positive subnormal
        assert_eq!(fp16_to_f32(0x8200), 0.0f32); // a negative subnormal
    }

    #[test]
    fn fp16_roundtrip_via_helper() {
        // Verify that f32_to_fp16_bits + fp16_to_f32 recovers the original
        // value exactly for numbers that are precisely representable in fp16.
        let values: &[f32] = &[0.5, 1.0, 2.0, 3.0, 4.0, 8.0, 10.0, 64.0, -1.0, -3.0];
        for &v in values {
            let bits = f32_to_fp16_bits(v);
            let recovered = fp16_to_f32(bits);
            assert_eq!(recovered, v, "round-trip failed for {v}");
        }
    }

    // =========================================================================
    // get_scale_min
    // =========================================================================

    #[test]
    fn scale_min_j_lt_4_basic() {
        // For j < 4: scale = scales[j] & 63, min = scales[j+4] & 63.
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[0] = 42;
        scales[4] = 17;
        let (sc, mn) = get_scale_min(0, &scales);
        assert_eq!(sc, 42);
        assert_eq!(mn, 17);
    }

    #[test]
    fn scale_min_j_lt_4_all_four_indices() {
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[0] = 10; scales[4] = 20;
        scales[1] = 11; scales[5] = 21;
        scales[2] = 12; scales[6] = 22;
        scales[3] = 13; scales[7] = 23;

        for (j, (exp_sc, exp_mn)) in (0..4).zip([(10u8, 20u8), (11, 21), (12, 22), (13, 23)]) {
            let (sc, mn) = get_scale_min(j, &scales);
            assert_eq!(sc, exp_sc, "scale mismatch at j={j}");
            assert_eq!(mn, exp_mn, "min mismatch at j={j}");
        }
    }

    #[test]
    fn scale_min_j_lt_4_masks_high_bits() {
        // Bits 6–7 of scales[j] / scales[j+4] contribute to the j+4 sub-block
        // encoding, not to the j sub-block. The & 63 must strip them.
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[2] = 0b11_101010; // raw=0xEA=234, lower 6 bits = 0b101010 = 42
        scales[6] = 0b10_010101; // raw=0x95=149, lower 6 bits = 0b010101 = 21
        let (sc, mn) = get_scale_min(2, &scales);
        assert_eq!(sc, 42);
        assert_eq!(mn, 21);
    }

    #[test]
    fn scale_min_j_gte_4_small_values() {
        // j = 4, scale = 5 (<16), min = 3 (<16) – upper bits are zero.
        // scales[8] = (5 & 0xF) | ((3 & 0xF) << 4) = 0x05 | 0x30 = 0x35
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[8] = 0x35;
        let (sc, mn) = get_scale_min(4, &scales);
        assert_eq!(sc, 5);
        assert_eq!(mn, 3);
    }

    #[test]
    fn scale_min_j_gte_4_needs_upper_bits() {
        // j = 4, scale = 31 (0b011111), min = 25 (0b011001).
        // Both need 5 bits, so the top bits must come from scales[0] and scales[4].
        //
        // Encoding:
        //   scales[8] = (31 & 0xF) | ((25 & 0xF) << 4) = 0x0F | 0x90 = 0x9F
        //   scales[0] |= ((31 >> 4) & 3) << 6 → bits 6-7 = 1 → 0x40
        //   scales[4] |= ((25 >> 4) & 3) << 6 → bits 6-7 = 1 → 0x40
        //
        // Decoding at j = 4:
        //   scale = (0x9F & 0x0F) | ((0x40 >> 6) << 4) = 15 | 16 = 31
        //   min   = (0x9F >> 4)   | ((0x40 >> 6) << 4) = 9  | 16 = 25
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[0] = 0x40;
        scales[4] = 0x40;
        scales[8] = 0x9F;
        let (sc, mn) = get_scale_min(4, &scales);
        assert_eq!(sc, 31, "scale mismatch");
        assert_eq!(mn, 25, "min mismatch");
    }

    #[test]
    fn scale_min_j_7_last_index() {
        // j = 7: scale = (scales[11] & 0x0F) | ((scales[3] >> 6) << 4)
        //        min   = (scales[11] >> 4)   | ((scales[7] >> 6) << 4)
        // Choose scale = 20 (0b010100), min = 10 (0b001010).
        //   scales[11] = (20 & 0xF) | ((10 & 0xF) << 4) = 0x04 | 0xA0 = 0xA4
        //   scales[3] |= ((20 >> 4) & 3) << 6 → bit 6 = 1 → 0x40
        //   scales[7]: (10 >> 4) & 3 = 0, no change needed
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[3] = 0x40;
        scales[11] = 0xA4;
        let (sc, mn) = get_scale_min(7, &scales);
        assert_eq!(sc, 20);
        assert_eq!(mn, 10);
    }

    // =========================================================================
    // dequantize_block_q4k
    // =========================================================================

    #[test]
    fn dequant_zero_d_all_outputs_zero() {
        // d = 0.0 → every product is 0; dmin = 0 so the subtracted min is also 0.
        let block = make_block(0.0, 0.0, 5, 3, 0xFF);
        let mut out = [f32::NAN; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 0.0, 0.0);
    }

    #[test]
    fn dequant_uniform_nibble_one_scale_one() {
        // d=1.0, dmin=0.0, scale=1, min=0, all nibbles=1
        // formula: 1.0 × 1 × 1 − 0.0 × 0 = 1.0
        let block = make_block(1.0, 0.0, 1, 0, 0x11);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 1.0, 0.0);
    }

    #[test]
    fn dequant_max_nibble_with_larger_scale() {
        // d=2.0, dmin=0.0, scale=3, min=0, all nibbles=15
        // formula: 2.0 × 3 × 15 − 0 = 90.0
        let block = make_block(2.0, 0.0, 3, 0, 0xFF);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 90.0, 1e-4);
    }

    #[test]
    fn dequant_non_zero_min_subtracts() {
        // d=1.0, dmin=1.0, scale=4, min=3, nibble=5
        // formula: 1.0 × 4 × 5 − 1.0 × 3 = 20 − 3 = 17.0
        let block = make_block(1.0, 1.0, 4, 3, 0x55);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 17.0, 1e-4);
    }

    #[test]
    fn dequant_zero_nibble_with_nonzero_min() {
        // nibble = 0, but the min offset is still subtracted.
        // d=1.0, dmin=1.0, scale=5, min=3, nibble=0
        // formula: 1.0 × 5 × 0 − 1.0 × 3 = −3.0
        let block = make_block(1.0, 1.0, 5, 3, 0x00);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, -3.0, 1e-4);
    }

    #[test]
    fn dequant_mixed_nibbles_correct_element_layout() {
        // qs_byte = 0x21: lower nibble = 1, upper nibble = 2.
        // d=1.0, dmin=0.0, scale=1, min=0.
        //
        // Per 64-element group the layout is:
        //   elements [0..32]  ← lower nibbles (nibble=1) → 1.0
        //   elements [32..64] ← upper nibbles (nibble=2) → 2.0
        //
        // This must hold for all four groups (elements 0–255).
        let block = make_block(1.0, 0.0, 1, 0, 0x21);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);

        for group in 0..4_usize {
            let base = group * 64;
            for l in 0..32 {
                assert_eq!(
                    out[base + l], 1.0,
                    "group {group} lower element {l} (out[{}])", base + l
                );
                assert_eq!(
                    out[base + 32 + l], 2.0,
                    "group {group} upper element {l} (out[{}])", base + 32 + l
                );
            }
        }
    }

    #[test]
    fn dequant_output_count_is_qk_k() {
        // Sanity-check: exactly QK_K = 256 values are written.
        let block = make_block(1.0, 0.0, 1, 0, 0x33);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        // All elements should be non-NaN (were actually written).
        assert!(out.iter().all(|v| !v.is_nan()));
    }

    #[test]
    fn block_q4k_size_is_144_bytes() {
        // 2 (d) + 2 (dmin) + 12 (scales) + 128 (qs) = 144 bytes
        assert_eq!(core::mem::size_of::<BlockQ4K>(), 144);
    }

    // =========================================================================
    // matmul_q4k_fp16
    // =========================================================================

    #[test]
    fn matmul_1x256_times_256x1_all_ones() {
        // A: 1×256, all weights = 1.0 (one block)
        // B: 256×1, all values = 1.0
        // C: 1×1, expected = 256.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = fp16_uniform(256, 1, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 1);
        assert_eq!(c.len(), 1);
        assert_close(c[0], 256.0, 0.1);
    }

    #[test]
    fn matmul_2x256_times_256x3_all_ones() {
        // A: 2×256, all weights = 1.0
        // B: 256×3, all values = 1.0
        // C: 2×3, all = 256.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11); 2];
        let b = fp16_uniform(256, 3, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 2, 256, 3);
        assert_eq!(c.len(), 6);
        assert_all_close(&c, 256.0, 0.1);
    }

    #[test]
    fn matmul_zero_a_gives_zero_c() {
        // d = 0.0 → every weight = 0.0 → every output = 0.0
        let a = vec![make_block(0.0, 0.0, 1, 0, 0xFF); 3];
        let b = fp16_uniform(256, 4, 7.0);
        let c = matmul_q4k_fp16(&a, &b, 3, 256, 4);
        assert_all_close(&c, 0.0, 0.0);
    }

    #[test]
    fn matmul_zero_b_gives_zero_c() {
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x55)];
        let b = fp16_uniform(256, 3, 0.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 3);
        assert_all_close(&c, 0.0, 0.0);
    }

    #[test]
    fn matmul_two_blocks_per_row() {
        // K = 512 = 2 × QK_K; each row has two all-ones blocks.
        // B = all ones → C = 512.0
        let block = make_block(1.0, 0.0, 1, 0, 0x11);
        let a = vec![block; 2]; // 1 row × 2 blocks
        let b = fp16_uniform(512, 1, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 512, 1);
        assert_close(c[0], 512.0, 0.1);
    }

    #[test]
    fn matmul_multiple_rows_multiple_blocks_per_row() {
        // A: 3 rows × 2 blocks each (K = 512), all weights = 1.0
        // B: 512×2, all values = 1.0
        // C: 3×2, all = 512.0
        let block = make_block(1.0, 0.0, 1, 0, 0x11);
        let a = vec![block; 3 * 2];
        let b = fp16_uniform(512, 2, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 3, 512, 2);
        assert_eq!(c.len(), 6);
        assert_all_close(&c, 512.0, 0.1);
    }

    #[test]
    fn matmul_alternating_weights_known_dot_product() {
        // qs_byte = 0x21: lower nibble = 1, upper nibble = 2.
        // d=1.0, dmin=0.0, scale=1, min=0.
        // The dequantised row has the pattern: 32×1.0, 32×2.0 (×4 groups)
        // → 128 ones and 128 twos.
        // With B = all ones: C = 128×1.0 + 128×2.0 = 384.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x21)];
        let b = fp16_uniform(256, 1, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 1);
        assert_close(c[0], 384.0, 0.1);
    }

    #[test]
    fn matmul_b_scalar_scales_output() {
        // All A weights = 1.0, all B values = 2.0 → C = 256 × 2.0 = 512.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = fp16_uniform(256, 1, 2.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 1);
        assert_close(c[0], 512.0, 0.2);
    }

    #[test]
    fn matmul_output_has_correct_shape() {
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11); 5];
        let b = fp16_uniform(256, 7, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 5, 256, 7);
        assert_eq!(c.len(), 5 * 7);
    }

    #[test]
    #[should_panic(expected = "must be a multiple of QK_K")]
    fn matmul_panics_when_k_not_multiple_of_qkk() {
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = fp16_uniform(128, 1, 1.0);
        let _ = matmul_q4k_fp16(&a, &b, 1, 128, 1);
    }

    #[test]
    #[should_panic(expected = "A block count mismatch")]
    fn matmul_panics_on_wrong_a_length() {
        // 1×256 needs 1 block, but we supply 2.
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11); 2];
        let b = fp16_uniform(256, 1, 1.0);
        let _ = matmul_q4k_fp16(&a, &b, 1, 256, 1);
    }

    #[test]
    #[should_panic(expected = "B element count mismatch")]
    fn matmul_panics_on_wrong_b_length() {
        // 256×1 needs 256 elements, but we supply 128.
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = vec![0x3C00u16; 128]; // 128 fp16 ones
        let _ = matmul_q4k_fp16(&a, &b, 1, 256, 1);
    }
}

// ---------------------------------------------------------------------------
// (end of src/lib.rs — src/main.rs follows)
// Demo binary for the `matrix-testing` library: constructs a small
// Q4_K_M × FP16 matrix multiply and prints the result.
// ---------------------------------------------------------------------------
+ +use matrix_testing::{matmul_q4k_fp16, BlockQ4K, K_SCALE_SIZE, QK_K}; + +fn main() { + // ----------------------------------------------------------------------- + // Build a tiny test case: A is (2 × 256) Q4_K_M, B is (256 × 3) fp16. + // + // We construct A so that every dequantised weight is exactly 1.0 and B so + // that every fp16 value is also 1.0. Then every output element should + // equal K = 256. + // ----------------------------------------------------------------------- + + const M: usize = 2; // rows of A / rows of C + const K: usize = 256; // cols of A / rows of B (one Q4_K block wide) + const N: usize = 3; // cols of B / cols of C + + let fp16_one = 0x3C00u16; // 1.0 in fp16 + let fp16_zero = 0x0000u16; // 0.0 in fp16 + + // ---- Build A ----------------------------------------------------------- + // Goal: dequant(q) == 1.0 for every element. + // + // formula: d * scale * q - dmin * min = 1.0 + // + // Choosing d=1.0, dmin=0.0, scale=1, min=0, nibble=1: + // 1.0 * 1 * 1 - 0.0 * 0 = 1.0 ✓ + // + // Scale encoding for values < 16 (upper bits are zero): + // scales[0..4] = scale = 1 + // scales[4..8] = min = 0 + // scales[8..12] = (scale & 0xF) | ((min & 0xF) << 4) = 0x01 + + let mut scales = [0u8; K_SCALE_SIZE]; + for j in 0..4 { + scales[j] = 1; // scale = 1 for sub-blocks 0..3 + scales[j + 4] = 0; // min = 0 for sub-blocks 0..3 + } + for s in scales.iter_mut().skip(8) { + *s = 0x01; // scale=1, min=0 for sub-blocks 4..7 + } + + let block_template = BlockQ4K { + d: fp16_one, + dmin: fp16_zero, + scales, + qs: [0x11u8; QK_K / 2], // nibble=1 in both halves of every byte + }; + + let a_blocks: Vec = vec![block_template; M * (K / QK_K)]; + + // ---- Build B ----------------------------------------------------------- + let b_fp16: Vec = vec![fp16_one; K * N]; + + // ---- Run the multiply -------------------------------------------------- + let c = matmul_q4k_fp16(&a_blocks, &b_fp16, M, K, N); + + // ---- Print results 
----------------------------------------------------- + println!("Output matrix C ({M} x {N}):"); + for i in 0..M { + print!(" row {i}: "); + for j in 0..N { + print!("{:.1} ", c[i * N + j]); + } + println!(); + } + + let expected = K as f32; + let all_ok = c.iter().all(|&v| (v - expected).abs() < 0.1); + if all_ok { + println!("All outputs == {expected:.1} ✓"); + } else { + eprintln!("FAIL: unexpected output values"); + std::process::exit(1); + } + + println!(); + println!("Note: this is the naive O(M·N·K) implementation."); + println!("It is intentionally simple – no SIMD, no tiling, no tricks."); +}