Initial commit
This commit is contained in:
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/target
|
||||
7
Cargo.lock
generated
Normal file
7
Cargo.lock
generated
Normal file
@@ -0,0 +1,7 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "matrix-testing"
|
||||
version = "0.1.0"
|
||||
6
Cargo.toml
Normal file
6
Cargo.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
[package]
|
||||
name = "matrix-testing"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
738
src/lib.rs
Normal file
738
src/lib.rs
Normal file
@@ -0,0 +1,738 @@
|
||||
//! Naive Q4_K_M × FP16 matrix multiplication.
|
||||
//!
|
||||
//! Q4_K_M (called `block_q4_K` in GGML) is a 4-bit K-quant format with 256
|
||||
//! elements per super-block. Each super-block carries:
|
||||
//!
|
||||
//! - `d` (fp16) – super-block scale for the sub-block scales
|
||||
//! - `dmin` (fp16) – super-block scale for the sub-block mins
|
||||
//! - `scales` [12 u8] – 8 pairs of (scale, min), each 6-bit, packed together
|
||||
//! - `qs` [128 u8] – 256 values at 4 bits each (two values per byte)
|
||||
//!
|
||||
//! The dequantised value for nibble `q` in sub-block `s` is:
|
||||
//!
|
||||
//! ```text
|
||||
//! d * scales[s] * q - dmin * mins[s]
|
||||
//! ```
|
||||
//!
|
||||
//! This library deliberately uses the simplest possible O(M·N·K) algorithm:
|
||||
//! dequantise each row of A into f32, convert each element of B from fp16 to
|
||||
//! f32, accumulate dot-products. No SIMD, no tiling, no tricks.
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Constants matching GGML's ggml-common.h
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Element count of a single Q4_K super-block.
pub const QK_K: usize = 256;

/// Byte count of the packed 6-bit (scale, min) pair storage.
pub const K_SCALE_SIZE: usize = 12;

// ---------------------------------------------------------------------------
// Block definition
// ---------------------------------------------------------------------------

/// One Q4_K super-block, binary-compatible with GGML's `block_q4_K`.
///
/// Field layout (byte offsets):
///
/// | Offset | Field    | Size  |
/// |--------|----------|-------|
/// | 0      | `d`      | 2 B   |
/// | 2      | `dmin`   | 2 B   |
/// | 4      | `scales` | 12 B  |
/// | 16     | `qs`     | 128 B |
///
/// Total: 144 bytes for 256 weights (4.5 bits per weight).
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct BlockQ4K {
    /// fp16 bits of the super-block scale applied to the sub-block scales.
    pub d: u16,
    /// fp16 bits of the super-block scale applied to the sub-block mins.
    pub dmin: u16,
    /// 8 (scale, min) pairs at 6 bits each, packed into 12 bytes.
    pub scales: [u8; K_SCALE_SIZE],
    /// 256 weights at 4 bits each: two weights per byte, 128 bytes total.
    pub qs: [u8; QK_K / 2],
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// FP16 → f32 conversion (no external dependencies)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Decode an IEEE 754 half-precision float, given as raw `u16` bits, to `f32`.
///
/// Handles every IEEE 754 case: ±zero, normal numbers, infinity, NaN.
/// fp16 subnormals (exponent field = 0, non-zero mantissa) are flushed to
/// signed zero — they are vanishingly small and irrelevant for LLM weights.
#[inline]
pub fn fp16_to_f32(bits: u16) -> f32 {
    let wide = bits as u32;
    let sign = (wide & 0x8000) << 16; // sign bit moved to the f32 position
    let exponent = wide & 0x7C00;
    let mantissa = wide & 0x03FF;

    let f32_bits = match exponent {
        // Exponent field 0: ±zero, or a subnormal flushed to signed zero.
        0 => sign,
        // Exponent field all-ones: ±infinity or NaN (mantissa carried over).
        0x7C00 => sign | 0x7F80_0000 | (mantissa << 13),
        // Normal number: shift exponent+mantissa into place and rebias the
        // exponent from fp16 (bias 15) to f32 (bias 127):
        // Δbias = 112, applied as 112 << 10 before the 13-bit shift.
        _ => sign | (((exponent | mantissa) + (112 << 10)) << 13),
    };

    f32::from_bits(f32_bits)
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Scale extraction – mirrors GGML's get_scale_min_k4
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Extract the 6-bit scale and 6-bit min for sub-block index `j` (0..8).
|
||||
///
|
||||
/// GGML packs 8 pairs of 6-bit values into 12 bytes using a two-part scheme:
|
||||
///
|
||||
/// **j = 0..3**
|
||||
/// ```text
|
||||
/// scale = scales[j] & 0x3F
|
||||
/// min = scales[j + 4] & 0x3F
|
||||
/// ```
|
||||
///
|
||||
/// **j = 4..7**
|
||||
/// ```text
|
||||
/// scale = (scales[j+4] & 0x0F) | ((scales[j-4] >> 6) << 4)
|
||||
/// min = (scales[j+4] >> 4) | ((scales[j] >> 6) << 4)
|
||||
/// ```
|
||||
#[inline]
|
||||
pub(crate) fn get_scale_min(j: usize, scales: &[u8; K_SCALE_SIZE]) -> (u8, u8) {
|
||||
debug_assert!(j < 8);
|
||||
if j < 4 {
|
||||
let sc = scales[j] & 63;
|
||||
let mn = scales[j + 4] & 63;
|
||||
(sc, mn)
|
||||
} else {
|
||||
let sc = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4);
|
||||
let mn = (scales[j + 4] >> 4) | ((scales[j ] >> 6) << 4);
|
||||
(sc, mn)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Dequantisation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Dequantise one Q4_K super-block into 256 `f32` values.
|
||||
///
|
||||
/// The loop mirrors GGML's `dequantize_row_q4_K`:
|
||||
///
|
||||
/// ```text
|
||||
/// for each group of 64 elements (4 groups total):
|
||||
/// d1, m1 = scale × d, min × dmin for sub-block (is + 0)
|
||||
/// d2, m2 = scale × d, min × dmin for sub-block (is + 1)
|
||||
/// out[0..32] = d1 × lower_nibble(qs[0..32]) − m1
|
||||
/// out[32..64] = d2 × upper_nibble(qs[0..32]) − m2
|
||||
/// advance qs by 32 bytes, is by 2
|
||||
/// ```
|
||||
pub fn dequantize_block_q4k(block: &BlockQ4K, out: &mut [f32; QK_K]) {
|
||||
let d = fp16_to_f32(block.d);
|
||||
let dmin = fp16_to_f32(block.dmin);
|
||||
|
||||
let mut q_off = 0usize; // byte cursor into block.qs
|
||||
let mut out_off = 0usize; // element cursor into out
|
||||
let mut is = 0usize; // sub-block pair index (0, 2, 4, 6)
|
||||
|
||||
while out_off < QK_K {
|
||||
let (sc1, mn1) = get_scale_min(is, &block.scales);
|
||||
let (sc2, mn2) = get_scale_min(is + 1, &block.scales);
|
||||
|
||||
let d1 = d * sc1 as f32;
|
||||
let m1 = dmin * mn1 as f32;
|
||||
let d2 = d * sc2 as f32;
|
||||
let m2 = dmin * mn2 as f32;
|
||||
|
||||
for l in 0..32 {
|
||||
out[out_off + l] = d1 * (block.qs[q_off + l] & 0x0F) as f32 - m1;
|
||||
}
|
||||
for l in 0..32 {
|
||||
out[out_off + 32 + l] = d2 * (block.qs[q_off + l] >> 4) as f32 - m2;
|
||||
}
|
||||
|
||||
q_off += 32;
|
||||
out_off += 64;
|
||||
is += 2;
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Matrix multiplication C = A × B
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Multiply a Q4_K_M matrix **A** by an FP16 matrix **B**, producing an f32
|
||||
/// matrix **C**.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` – Row-major slice of [`BlockQ4K`]. Row `i` occupies blocks
|
||||
/// `a[i * blocks_per_row .. (i+1) * blocks_per_row]`.
|
||||
/// * `b` – Row-major fp16 matrix stored as raw `u16` bits, shape \[K, N\].
|
||||
/// Element `(ki, j)` is at index `ki * n + j`.
|
||||
/// * `m` – Number of rows in A (and C).
|
||||
/// * `k` – Number of columns in A = number of rows in B.
|
||||
/// **Must** be a multiple of [`QK_K`] (256).
|
||||
/// * `n` – Number of columns in B (and C).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A flat row-major `Vec<f32>` of shape \[M, N\].
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `k` is not a multiple of `QK_K`, or if the lengths of `a` or `b`
|
||||
/// do not match the declared dimensions.
|
||||
pub fn matmul_q4k_fp16(
|
||||
a: &[BlockQ4K],
|
||||
b: &[u16],
|
||||
m: usize,
|
||||
k: usize,
|
||||
n: usize,
|
||||
) -> Vec<f32> {
|
||||
assert_eq!(k % QK_K, 0,
|
||||
"k ({k}) must be a multiple of QK_K ({QK_K})");
|
||||
assert_eq!(a.len(), m * (k / QK_K),
|
||||
"A block count mismatch: expected {} blocks, got {}", m * (k / QK_K), a.len());
|
||||
assert_eq!(b.len(), k * n,
|
||||
"B element count mismatch: expected {}, got {}", k * n, b.len());
|
||||
|
||||
let blocks_per_row = k / QK_K;
|
||||
let mut c = vec![0.0f32; m * n];
|
||||
|
||||
// Scratch buffers allocated once and reused across rows.
|
||||
let mut a_row = vec![0.0f32; k];
|
||||
let mut block_buf = [0.0f32; QK_K];
|
||||
|
||||
for i in 0..m {
|
||||
// Step 1: dequantise row i of A into a_row (f32).
|
||||
for b_idx in 0..blocks_per_row {
|
||||
let block = &a[i * blocks_per_row + b_idx];
|
||||
dequantize_block_q4k(block, &mut block_buf);
|
||||
let start = b_idx * QK_K;
|
||||
a_row[start..start + QK_K].copy_from_slice(&block_buf);
|
||||
}
|
||||
|
||||
// Step 2: for each output column j, compute dot(a_row, B[:, j]).
|
||||
for j in 0..n {
|
||||
let mut sum = 0.0f32;
|
||||
for ki in 0..k {
|
||||
// B is row-major [K, N]: element (ki, j) → index ki*n+j.
|
||||
sum += a_row[ki] * fp16_to_f32(b[ki * n + j]);
|
||||
}
|
||||
c[i * n + j] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
c
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // Test helpers
    // -------------------------------------------------------------------------

    /// Convert a normal finite f32 (including ±0.0, ±inf) to its fp16 bit
    /// pattern. Panics if the value is a non-representable normal number
    /// (magnitude too large or too small for fp16's normal range).
    ///
    /// The mantissa is truncated (`mant >> 13`), not rounded — exact only
    /// for values already representable in fp16, which is all the tests use.
    fn f32_to_fp16_bits(f: f32) -> u16 {
        if f == 0.0 { return 0x0000; }
        if f == f32::INFINITY { return 0x7C00; }
        if f == f32::NEG_INFINITY { return 0xFC00; }
        if f.is_nan() { return 0x7E00; }

        let bits = f.to_bits();
        let sign = ((bits >> 31) as u16) << 15;
        let exp = (bits >> 23) & 0xFF;
        let mant = bits & 0x007F_FFFF;
        // Rebias from f32 (127) to fp16 (15); must land in the normal range 1..=30.
        let fp16_exp = exp as i32 - 127 + 15;
        assert!(
            fp16_exp > 0 && fp16_exp < 31,
            "f32 value {f} is outside the representable fp16 normal range"
        );
        sign | ((fp16_exp as u16) << 10) | ((mant >> 13) as u16)
    }

    /// Build a [`BlockQ4K`] where all 8 sub-blocks share the same `scale` and
    /// `min` (both **must** be < 16 to keep the encoding simple), and every
    /// byte in `qs` is `qs_byte`.
    ///
    /// With scale, min < 16, the scales array encoding simplifies to:
    /// ```text
    /// scales[0..4]  = scale                                (used for sub-blocks 0..3)
    /// scales[4..8]  = min                                  (used for sub-blocks 0..3)
    /// scales[8..12] = (scale & 0xF) | ((min & 0xF) << 4)   (sub-blocks 4..7)
    /// ```
    fn make_block(d: f32, dmin: f32, scale: u8, min: u8, qs_byte: u8) -> BlockQ4K {
        assert!(
            scale < 16 && min < 16,
            "make_block: scale ({scale}) and min ({min}) must both be < 16"
        );
        let mut scales = [0u8; K_SCALE_SIZE];
        for j in 0..4 {
            scales[j] = scale;
            scales[j + 4] = min;
        }
        for j in 8..12 {
            scales[j] = (scale & 0x0F) | ((min & 0x0F) << 4);
        }
        BlockQ4K {
            d: f32_to_fp16_bits(d),
            dmin: f32_to_fp16_bits(dmin),
            scales,
            qs: [qs_byte; QK_K / 2],
        }
    }

    /// Assert two f32 values are within `tol` of each other.
    fn assert_close(got: f32, expected: f32, tol: f32) {
        assert!(
            (got - expected).abs() <= tol,
            "got {got}, expected {expected} (tol {tol})"
        );
    }

    /// Assert every element of `got` equals `expected_scalar` within `tol`.
    fn assert_all_close(got: &[f32], expected_scalar: f32, tol: f32) {
        for (i, &g) in got.iter().enumerate() {
            assert!(
                (g - expected_scalar).abs() <= tol,
                "element {i}: got {g}, expected {expected_scalar} (tol {tol})"
            );
        }
    }

    /// Build a flat fp16 matrix (shape K×N) where every element is `value`.
    fn fp16_uniform(k: usize, n: usize, value: f32) -> Vec<u16> {
        vec![f32_to_fp16_bits(value); k * n]
    }

    // =========================================================================
    // fp16_to_f32
    // =========================================================================

    #[test]
    fn fp16_positive_zero() {
        let v = fp16_to_f32(0x0000);
        assert_eq!(v, 0.0f32);
        assert!(v.is_sign_positive());
    }

    #[test]
    fn fp16_negative_zero() {
        let v = fp16_to_f32(0x8000);
        // IEEE 754: -0.0 == +0.0
        assert_eq!(v, 0.0f32);
        assert!(v.is_sign_negative());
    }

    #[test]
    fn fp16_one_and_negative_one() {
        // 0x3C00: exp=15, mant=0 → (1+0) × 2^0 = 1.0
        assert_eq!(fp16_to_f32(0x3C00), 1.0f32);
        // 0xBC00: sign=1, same magnitude
        assert_eq!(fp16_to_f32(0xBC00), -1.0f32);
    }

    #[test]
    fn fp16_powers_of_two() {
        // 0.5 = 2^-1: exp=14, mant=0 → 0x3800
        assert_eq!(fp16_to_f32(0x3800), 0.5f32);
        // 2.0 = 2^1: exp=16, mant=0 → 0x4000
        assert_eq!(fp16_to_f32(0x4000), 2.0f32);
        // 4.0 = 2^2: exp=17, mant=0 → 0x4400
        assert_eq!(fp16_to_f32(0x4400), 4.0f32);
        // 8.0 = 2^3: exp=18, mant=0 → 0x4800
        assert_eq!(fp16_to_f32(0x4800), 8.0f32);
        // 64.0 = 2^6: exp=21, mant=0 → 0x5400
        assert_eq!(fp16_to_f32(0x5400), 64.0f32);
    }

    #[test]
    fn fp16_non_power_of_two() {
        // 3.0 = 1.5 × 2^1: exp=16, mant=512 → 0x4200
        // 0x4200 = bit14=1, bit9=1 → exp=16, mant=512 → (1+0.5)×2=3.0
        assert_eq!(fp16_to_f32(0x4200), 3.0f32);
        // 10.0 = 1.25 × 2^3: exp=18, mant=256 → 0x4900
        // 0x4900 = bits14-10=10010=18, bits9-0=01_0000_0000=256
        // → (1+256/1024)×8 = 1.25×8 = 10.0
        assert_eq!(fp16_to_f32(0x4900), 10.0f32);
    }

    #[test]
    fn fp16_positive_infinity() {
        let v = fp16_to_f32(0x7C00);
        assert!(v.is_infinite());
        assert!(v.is_sign_positive());
    }

    #[test]
    fn fp16_negative_infinity() {
        let v = fp16_to_f32(0xFC00);
        assert!(v.is_infinite());
        assert!(v.is_sign_negative());
    }

    #[test]
    fn fp16_nan() {
        // 0x7E00: exp all-ones, non-zero mantissa → NaN
        assert!(fp16_to_f32(0x7E00).is_nan());
        // 0xFE00: signed NaN
        assert!(fp16_to_f32(0xFE00).is_nan());
    }

    #[test]
    fn fp16_subnormals_become_zero() {
        // Subnormals (exp=0, mant≠0) are too small for LLM weights; we
        // return signed zero rather than performing the full decode.
        assert_eq!(fp16_to_f32(0x0001), 0.0f32); // smallest positive subnormal
        assert_eq!(fp16_to_f32(0x03FF), 0.0f32); // largest positive subnormal
        assert_eq!(fp16_to_f32(0x8200), 0.0f32); // a negative subnormal
    }

    #[test]
    fn fp16_roundtrip_via_helper() {
        // Verify that f32_to_fp16_bits + fp16_to_f32 recovers the original
        // value exactly for numbers that are precisely representable in fp16.
        let values: &[f32] = &[0.5, 1.0, 2.0, 3.0, 4.0, 8.0, 10.0, 64.0, -1.0, -3.0];
        for &v in values {
            let bits = f32_to_fp16_bits(v);
            let recovered = fp16_to_f32(bits);
            assert_eq!(recovered, v, "round-trip failed for {v}");
        }
    }

    // =========================================================================
    // get_scale_min
    // =========================================================================

    #[test]
    fn scale_min_j_lt_4_basic() {
        // For j < 4: scale = scales[j] & 63, min = scales[j+4] & 63.
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[0] = 42;
        scales[4] = 17;
        let (sc, mn) = get_scale_min(0, &scales);
        assert_eq!(sc, 42);
        assert_eq!(mn, 17);
    }

    #[test]
    fn scale_min_j_lt_4_all_four_indices() {
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[0] = 10; scales[4] = 20;
        scales[1] = 11; scales[5] = 21;
        scales[2] = 12; scales[6] = 22;
        scales[3] = 13; scales[7] = 23;

        for (j, (exp_sc, exp_mn)) in (0..4).zip([(10u8, 20u8), (11, 21), (12, 22), (13, 23)]) {
            let (sc, mn) = get_scale_min(j, &scales);
            assert_eq!(sc, exp_sc, "scale mismatch at j={j}");
            assert_eq!(mn, exp_mn, "min mismatch at j={j}");
        }
    }

    #[test]
    fn scale_min_j_lt_4_masks_high_bits() {
        // Bits 6–7 of scales[j] / scales[j+4] contribute to the j+4 sub-block
        // encoding, not to the j sub-block. The & 63 must strip them.
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[2] = 0b11_101010; // raw=0xEA=234, lower 6 bits = 0b101010 = 42
        scales[6] = 0b10_010101; // raw=0x95=149, lower 6 bits = 0b010101 = 21
        let (sc, mn) = get_scale_min(2, &scales);
        assert_eq!(sc, 42);
        assert_eq!(mn, 21);
    }

    #[test]
    fn scale_min_j_gte_4_small_values() {
        // j = 4, scale = 5 (<16), min = 3 (<16) – upper bits are zero.
        // scales[8] = (5 & 0xF) | ((3 & 0xF) << 4) = 0x05 | 0x30 = 0x35
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[8] = 0x35;
        let (sc, mn) = get_scale_min(4, &scales);
        assert_eq!(sc, 5);
        assert_eq!(mn, 3);
    }

    #[test]
    fn scale_min_j_gte_4_needs_upper_bits() {
        // j = 4, scale = 31 (0b011111), min = 25 (0b011001).
        // Both need 5 bits, so the top bits must come from scales[0] and scales[4].
        //
        // Encoding:
        //   scales[8] = (31 & 0xF) | ((25 & 0xF) << 4) = 0x0F | 0x90 = 0x9F
        //   scales[0] |= ((31 >> 4) & 3) << 6 → bits 6-7 = 1 → 0x40
        //   scales[4] |= ((25 >> 4) & 3) << 6 → bits 6-7 = 1 → 0x40
        //
        // Decoding at j = 4:
        //   scale = (0x9F & 0x0F) | ((0x40 >> 6) << 4) = 15 | 16 = 31
        //   min   = (0x9F >> 4)   | ((0x40 >> 6) << 4) = 9  | 16 = 25
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[0] = 0x40;
        scales[4] = 0x40;
        scales[8] = 0x9F;
        let (sc, mn) = get_scale_min(4, &scales);
        assert_eq!(sc, 31, "scale mismatch");
        assert_eq!(mn, 25, "min mismatch");
    }

    #[test]
    fn scale_min_j_7_last_index() {
        // j = 7: scale = (scales[11] & 0x0F) | ((scales[3] >> 6) << 4)
        //        min   = (scales[11] >> 4)   | ((scales[7] >> 6) << 4)
        // Choose scale = 20 (0b010100), min = 10 (0b001010).
        //   scales[11] = (20 & 0xF) | ((10 & 0xF) << 4) = 0x04 | 0xA0 = 0xA4
        //   scales[3] |= ((20 >> 4) & 3) << 6 → bit 6 = 1 → 0x40
        //   scales[7]: (10 >> 4) & 3 = 0, no change needed
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[3] = 0x40;
        scales[11] = 0xA4;
        let (sc, mn) = get_scale_min(7, &scales);
        assert_eq!(sc, 20);
        assert_eq!(mn, 10);
    }

    // =========================================================================
    // dequantize_block_q4k
    // =========================================================================

    #[test]
    fn dequant_zero_d_all_outputs_zero() {
        // d = 0.0 → every product is 0; dmin = 0 so the subtracted min is also 0.
        let block = make_block(0.0, 0.0, 5, 3, 0xFF);
        let mut out = [f32::NAN; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 0.0, 0.0);
    }

    #[test]
    fn dequant_uniform_nibble_one_scale_one() {
        // d=1.0, dmin=0.0, scale=1, min=0, all nibbles=1
        // formula: 1.0 × 1 × 1 − 0.0 × 0 = 1.0
        let block = make_block(1.0, 0.0, 1, 0, 0x11);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 1.0, 0.0);
    }

    #[test]
    fn dequant_max_nibble_with_larger_scale() {
        // d=2.0, dmin=0.0, scale=3, min=0, all nibbles=15
        // formula: 2.0 × 3 × 15 − 0 = 90.0
        let block = make_block(2.0, 0.0, 3, 0, 0xFF);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 90.0, 1e-4);
    }

    #[test]
    fn dequant_non_zero_min_subtracts() {
        // d=1.0, dmin=1.0, scale=4, min=3, nibble=5
        // formula: 1.0 × 4 × 5 − 1.0 × 3 = 20 − 3 = 17.0
        let block = make_block(1.0, 1.0, 4, 3, 0x55);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, 17.0, 1e-4);
    }

    #[test]
    fn dequant_zero_nibble_with_nonzero_min() {
        // nibble = 0, but the min offset is still subtracted.
        // d=1.0, dmin=1.0, scale=5, min=3, nibble=0
        // formula: 1.0 × 5 × 0 − 1.0 × 3 = −3.0
        let block = make_block(1.0, 1.0, 5, 3, 0x00);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        assert_all_close(&out, -3.0, 1e-4);
    }

    #[test]
    fn dequant_mixed_nibbles_correct_element_layout() {
        // qs_byte = 0x21: lower nibble = 1, upper nibble = 2.
        // d=1.0, dmin=0.0, scale=1, min=0.
        //
        // Per 64-element group the layout is:
        //   elements [0..32]  ← lower nibbles (nibble=1) → 1.0
        //   elements [32..64] ← upper nibbles (nibble=2) → 2.0
        //
        // This must hold for all four groups (elements 0–255).
        let block = make_block(1.0, 0.0, 1, 0, 0x21);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);

        for group in 0..4_usize {
            let base = group * 64;
            for l in 0..32 {
                assert_eq!(
                    out[base + l], 1.0,
                    "group {group} lower element {l} (out[{}])", base + l
                );
                assert_eq!(
                    out[base + 32 + l], 2.0,
                    "group {group} upper element {l} (out[{}])", base + 32 + l
                );
            }
        }
    }

    #[test]
    fn dequant_output_count_is_qk_k() {
        // Sanity-check: exactly QK_K = 256 values are written.
        let block = make_block(1.0, 0.0, 1, 0, 0x33);
        let mut out = [0.0f32; QK_K];
        dequantize_block_q4k(&block, &mut out);
        // All elements should be non-NaN (were actually written).
        assert!(out.iter().all(|v| !v.is_nan()));
    }

    #[test]
    fn block_q4k_size_is_144_bytes() {
        // 2 (d) + 2 (dmin) + 12 (scales) + 128 (qs) = 144 bytes
        assert_eq!(core::mem::size_of::<BlockQ4K>(), 144);
    }

    // =========================================================================
    // matmul_q4k_fp16
    // =========================================================================

    #[test]
    fn matmul_1x256_times_256x1_all_ones() {
        // A: 1×256, all weights = 1.0 (one block)
        // B: 256×1, all values = 1.0
        // C: 1×1, expected = 256.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = fp16_uniform(256, 1, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 1);
        assert_eq!(c.len(), 1);
        assert_close(c[0], 256.0, 0.1);
    }

    #[test]
    fn matmul_2x256_times_256x3_all_ones() {
        // A: 2×256, all weights = 1.0
        // B: 256×3, all values = 1.0
        // C: 2×3, all = 256.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11); 2];
        let b = fp16_uniform(256, 3, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 2, 256, 3);
        assert_eq!(c.len(), 6);
        assert_all_close(&c, 256.0, 0.1);
    }

    #[test]
    fn matmul_zero_a_gives_zero_c() {
        // d = 0.0 → every weight = 0.0 → every output = 0.0
        let a = vec![make_block(0.0, 0.0, 1, 0, 0xFF); 3];
        let b = fp16_uniform(256, 4, 7.0);
        let c = matmul_q4k_fp16(&a, &b, 3, 256, 4);
        assert_all_close(&c, 0.0, 0.0);
    }

    #[test]
    fn matmul_zero_b_gives_zero_c() {
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x55)];
        let b = fp16_uniform(256, 3, 0.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 3);
        assert_all_close(&c, 0.0, 0.0);
    }

    #[test]
    fn matmul_two_blocks_per_row() {
        // K = 512 = 2 × QK_K; each row has two all-ones blocks.
        // B = all ones → C = 512.0
        let block = make_block(1.0, 0.0, 1, 0, 0x11);
        let a = vec![block; 2]; // 1 row × 2 blocks
        let b = fp16_uniform(512, 1, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 512, 1);
        assert_close(c[0], 512.0, 0.1);
    }

    #[test]
    fn matmul_multiple_rows_multiple_blocks_per_row() {
        // A: 3 rows × 2 blocks each (K = 512), all weights = 1.0
        // B: 512×2, all values = 1.0
        // C: 3×2, all = 512.0
        let block = make_block(1.0, 0.0, 1, 0, 0x11);
        let a = vec![block; 3 * 2];
        let b = fp16_uniform(512, 2, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 3, 512, 2);
        assert_eq!(c.len(), 6);
        assert_all_close(&c, 512.0, 0.1);
    }

    #[test]
    fn matmul_alternating_weights_known_dot_product() {
        // qs_byte = 0x21: lower nibble = 1, upper nibble = 2.
        // d=1.0, dmin=0.0, scale=1, min=0.
        // The dequantised row has the pattern: 32×1.0, 32×2.0 (×4 groups)
        // → 128 ones and 128 twos.
        // With B = all ones: C = 128×1.0 + 128×2.0 = 384.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x21)];
        let b = fp16_uniform(256, 1, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 1);
        assert_close(c[0], 384.0, 0.1);
    }

    #[test]
    fn matmul_b_scalar_scales_output() {
        // All A weights = 1.0, all B values = 2.0 → C = 256 × 2.0 = 512.0
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = fp16_uniform(256, 1, 2.0);
        let c = matmul_q4k_fp16(&a, &b, 1, 256, 1);
        assert_close(c[0], 512.0, 0.2);
    }

    #[test]
    fn matmul_output_has_correct_shape() {
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11); 5];
        let b = fp16_uniform(256, 7, 1.0);
        let c = matmul_q4k_fp16(&a, &b, 5, 256, 7);
        assert_eq!(c.len(), 5 * 7);
    }

    #[test]
    #[should_panic(expected = "must be a multiple of QK_K")]
    fn matmul_panics_when_k_not_multiple_of_qkk() {
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = fp16_uniform(128, 1, 1.0);
        let _ = matmul_q4k_fp16(&a, &b, 1, 128, 1);
    }

    #[test]
    #[should_panic(expected = "A block count mismatch")]
    fn matmul_panics_on_wrong_a_length() {
        // 1×256 needs 1 block, but we supply 2.
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11); 2];
        let b = fp16_uniform(256, 1, 1.0);
        let _ = matmul_q4k_fp16(&a, &b, 1, 256, 1);
    }

    #[test]
    #[should_panic(expected = "B element count mismatch")]
    fn matmul_panics_on_wrong_b_length() {
        // 256×1 needs 256 elements, but we supply 128.
        let a = vec![make_block(1.0, 0.0, 1, 0, 0x11)];
        let b = vec![0x3C00u16; 128]; // 128 fp16 ones
        let _ = matmul_q4k_fp16(&a, &b, 1, 256, 1);
    }
}
|
||||
82
src/main.rs
Normal file
82
src/main.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
//! Demo binary for the `matrix-testing` library.
|
||||
//!
|
||||
//! Constructs a small Q4_K_M × FP16 matrix multiply and prints the result.
|
||||
|
||||
use matrix_testing::{matmul_q4k_fp16, BlockQ4K, K_SCALE_SIZE, QK_K};
|
||||
|
||||
fn main() {
|
||||
// -----------------------------------------------------------------------
|
||||
// Build a tiny test case: A is (2 × 256) Q4_K_M, B is (256 × 3) fp16.
|
||||
//
|
||||
// We construct A so that every dequantised weight is exactly 1.0 and B so
|
||||
// that every fp16 value is also 1.0. Then every output element should
|
||||
// equal K = 256.
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
const M: usize = 2; // rows of A / rows of C
|
||||
const K: usize = 256; // cols of A / rows of B (one Q4_K block wide)
|
||||
const N: usize = 3; // cols of B / cols of C
|
||||
|
||||
let fp16_one = 0x3C00u16; // 1.0 in fp16
|
||||
let fp16_zero = 0x0000u16; // 0.0 in fp16
|
||||
|
||||
// ---- Build A -----------------------------------------------------------
|
||||
// Goal: dequant(q) == 1.0 for every element.
|
||||
//
|
||||
// formula: d * scale * q - dmin * min = 1.0
|
||||
//
|
||||
// Choosing d=1.0, dmin=0.0, scale=1, min=0, nibble=1:
|
||||
// 1.0 * 1 * 1 - 0.0 * 0 = 1.0 ✓
|
||||
//
|
||||
// Scale encoding for values < 16 (upper bits are zero):
|
||||
// scales[0..4] = scale = 1
|
||||
// scales[4..8] = min = 0
|
||||
// scales[8..12] = (scale & 0xF) | ((min & 0xF) << 4) = 0x01
|
||||
|
||||
let mut scales = [0u8; K_SCALE_SIZE];
|
||||
for j in 0..4 {
|
||||
scales[j] = 1; // scale = 1 for sub-blocks 0..3
|
||||
scales[j + 4] = 0; // min = 0 for sub-blocks 0..3
|
||||
}
|
||||
for s in scales.iter_mut().skip(8) {
|
||||
*s = 0x01; // scale=1, min=0 for sub-blocks 4..7
|
||||
}
|
||||
|
||||
let block_template = BlockQ4K {
|
||||
d: fp16_one,
|
||||
dmin: fp16_zero,
|
||||
scales,
|
||||
qs: [0x11u8; QK_K / 2], // nibble=1 in both halves of every byte
|
||||
};
|
||||
|
||||
let a_blocks: Vec<BlockQ4K> = vec![block_template; M * (K / QK_K)];
|
||||
|
||||
// ---- Build B -----------------------------------------------------------
|
||||
let b_fp16: Vec<u16> = vec![fp16_one; K * N];
|
||||
|
||||
// ---- Run the multiply --------------------------------------------------
|
||||
let c = matmul_q4k_fp16(&a_blocks, &b_fp16, M, K, N);
|
||||
|
||||
// ---- Print results -----------------------------------------------------
|
||||
println!("Output matrix C ({M} x {N}):");
|
||||
for i in 0..M {
|
||||
print!(" row {i}: ");
|
||||
for j in 0..N {
|
||||
print!("{:.1} ", c[i * N + j]);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
let expected = K as f32;
|
||||
let all_ok = c.iter().all(|&v| (v - expected).abs() < 0.1);
|
||||
if all_ok {
|
||||
println!("All outputs == {expected:.1} ✓");
|
||||
} else {
|
||||
eprintln!("FAIL: unexpected output values");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
println!();
|
||||
println!("Note: this is the naive O(M·N·K) implementation.");
|
||||
println!("It is intentionally simple – no SIMD, no tiling, no tricks.");
|
||||
}
|
||||
Reference in New Issue
Block a user