Create gguf_matmul.rs

This commit is contained in:
2026-04-12 19:40:58 -07:00
parent c7909bc112
commit 992097b4e0

457
src/bin/gguf_matmul.rs Normal file
View File

@@ -0,0 +1,457 @@
//! Read one Q4_K layer from a GGUF model file, multiply it against a random
//! FP16 activation matrix, and compare the baseline and RLE implementations.
//!
//! # Usage
//!
//! ```text
//! cargo run --release --bin gguf_matmul -- <model.gguf> [n]
//! ```
//!
//! `n` is the number of activation columns (token count / batch size).
//! Defaults to 1 (single-token inference).
//!
//! # GGUF layout (v2 / v3)
//!
//! ```text
//! ┌─────────────────────────────────────────────────────┐
//! │ magic u32 │ version u32 │ n_tensors u64 │ n_kv u64 │
//! ├─────────────────────────────────────────────────────┤
//! │ metadata key-value pairs (variable length) │
//! ├─────────────────────────────────────────────────────┤
//! │ tensor info records (variable length) │
//! ├─────────────────────────────────────────────────────┤
//! │ padding to `alignment` boundary (default 32 bytes) │
//! ├─────────────────────────────────────────────────────┤
//! │ tensor data (concatenated, each individually padded)│
//! └─────────────────────────────────────────────────────┘
//! ```
//!
//! Each Q4_K block is 144 bytes:
//! `d(2) + dmin(2) + scales(12) + qs(128)` — identical to our `BlockQ4K`.
use std::{
env,
error::Error,
fs::File,
io::{self, BufReader, Read, Seek, SeekFrom},
time::{Duration, Instant},
};
use matrix_testing::{
matmul_q4k_fp16, BlockQ4K, K_SCALE_SIZE, QK_K,
rle::{encode, matmul_q4k_rle_fp16, BlockQ4KRle},
};
// ---------------------------------------------------------------------------
// GGUF constants
// ---------------------------------------------------------------------------
/// File magic: bytes b"GGUF" interpreted as a little-endian u32.
const GGUF_MAGIC: u32 = 0x4655_4747;
/// Default tensor data alignment when not overridden by `general.alignment`.
const GGUF_DEFAULT_ALIGNMENT: u64 = 32;
/// GGML tensor type code for Q4_K (matches ggml.h `GGML_TYPE_Q4_K`).
const GGML_TYPE_Q4_K: u32 = 12;
/// Size in bytes of one Q4_K block: d(2) + dmin(2) + scales(12) + qs(128).
const BLOCK_BYTES: usize = 2 + 2 + K_SCALE_SIZE + QK_K / 2; // 144
// GGUF metadata value type tags (gguf-spec §3.2). `skip_value` uses these to
// know how many bytes each metadata value occupies.
const GTYPE_U8: u32 = 0;
const GTYPE_I8: u32 = 1;
const GTYPE_U16: u32 = 2;
const GTYPE_I16: u32 = 3;
const GTYPE_U32: u32 = 4;
const GTYPE_I32: u32 = 5;
const GTYPE_F32: u32 = 6;
const GTYPE_BOOL: u32 = 7;
// Length-prefixed UTF-8 string.
const GTYPE_STR: u32 = 8;
// Homogeneous array: element tag (u32), count (u64), then the elements.
const GTYPE_ARR: u32 = 9;
const GTYPE_U64: u32 = 10;
const GTYPE_I64: u32 = 11;
const GTYPE_F64: u32 = 12;
// ---------------------------------------------------------------------------
// Primitive binary readers (little-endian, no deps)
// ---------------------------------------------------------------------------
/// Read a single byte from `r`.
fn read_u8(r: &mut impl Read) -> io::Result<u8> {
    let mut byte = [0u8];
    r.read_exact(&mut byte)?;
    Ok(byte[0])
}
/// Read a little-endian u16 from `r`.
fn read_u16(r: &mut impl Read) -> io::Result<u16> {
    let mut raw = [0u8; 2];
    r.read_exact(&mut raw)?;
    Ok(u16::from_le_bytes(raw))
}
/// Read a little-endian u32 from `r`.
fn read_u32(r: &mut impl Read) -> io::Result<u32> {
    let mut raw = [0u8; 4];
    r.read_exact(&mut raw)?;
    Ok(u32::from_le_bytes(raw))
}
/// Read a little-endian u64 from `r`.
fn read_u64(r: &mut impl Read) -> io::Result<u64> {
    let mut raw = [0u8; 8];
    r.read_exact(&mut raw)?;
    Ok(u64::from_le_bytes(raw))
}
/// Read a GGUF length-prefixed UTF-8 string (u64 byte length, then bytes).
/// Invalid UTF-8 is replaced lossily rather than rejected.
fn read_str(r: &mut impl Read) -> io::Result<String> {
    let n = read_u64(r)? as usize;
    let mut bytes = vec![0u8; n];
    r.read_exact(&mut bytes)?;
    Ok(String::from_utf8_lossy(&bytes).into_owned())
}
/// Consume and discard one GGUF metadata value with the given type tag.
/// Arrays are skipped recursively element by element.
fn skip_value(r: &mut impl Read, tag: u32) -> io::Result<()> {
    match tag {
        GTYPE_U8 | GTYPE_I8 | GTYPE_BOOL => read_u8(r).map(|_| ()),
        GTYPE_U16 | GTYPE_I16 => read_u16(r).map(|_| ()),
        GTYPE_U32 | GTYPE_I32 | GTYPE_F32 => read_u32(r).map(|_| ()),
        GTYPE_U64 | GTYPE_I64 | GTYPE_F64 => read_u64(r).map(|_| ()),
        GTYPE_STR => read_str(r).map(|_| ()),
        GTYPE_ARR => {
            // Arrays are homogeneous: one element tag, a u64 count, payload.
            let elem_tag = read_u32(r)?;
            let count = read_u64(r)?;
            (0..count).try_for_each(|_| skip_value(r, elem_tag))
        }
        t => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("unknown GGUF value type {t}"),
        )),
    }
}
// ---------------------------------------------------------------------------
// Tensor info
// ---------------------------------------------------------------------------
/// One parsed GGUF tensor-info record: name, shape, type tag, data offset.
struct TensorInfo {
    // Tensor name exactly as stored in the file.
    name: String,
    /// Dimensions in GGML order: dims[0] is the innermost (fastest-varying).
    /// For a 2-D weight matrix: dims[0] = K (in-features), dims[1] = M (out-features).
    dims: Vec<u64>,
    // GGML type code (compare against `GGML_TYPE_Q4_K`); not validated at parse time.
    dtype: u32,
    /// Byte offset of this tensor's data measured from the start of the data
    /// section (i.e. add `data_start` to get the absolute file offset).
    offset: u64,
}
impl TensorInfo {
    /// Total number of scalar elements (product of all dimensions).
    fn n_elements(&self) -> u64 {
        self.dims.iter().product()
    }
    /// On-disk size in bytes, assuming Q4_K blocks of QK_K elements each.
    fn data_bytes(&self) -> u64 {
        debug_assert_eq!(self.dtype, GGML_TYPE_Q4_K);
        self.n_elements() / QK_K as u64 * BLOCK_BYTES as u64
    }
    /// Return (m, k) matrix dimensions.
    /// dims[0] = K (column / inner dim), dims[1] = M (row / outer dim).
    fn matrix_dims(&self) -> (usize, usize) {
        assert_eq!(self.dims.len(), 2, "expected 2-D tensor");
        (self.dims[1] as usize, self.dims[0] as usize)
    }
}
// ---------------------------------------------------------------------------
// GGUF header parser
// ---------------------------------------------------------------------------
/// Parse the GGUF file header and return `(tensor_infos, data_start_offset)`.
///
/// `data_start_offset` is the absolute byte position where tensor data begins.
///
/// # Errors
/// I/O failures, a bad magic number, or an unknown metadata tag return an
/// error. An unexpected version or an invalid alignment only warn.
fn parse_header(path: &str) -> Result<(Vec<TensorInfo>, u64), Box<dyn Error>> {
    let mut r = BufReader::new(File::open(path)?);
    // Magic + version
    let magic = read_u32(&mut r)?;
    if magic != GGUF_MAGIC {
        return Err(format!(
            "not a GGUF file (expected magic {GGUF_MAGIC:#010x}, got {magic:#010x})"
        ).into());
    }
    let version = read_u32(&mut r)?;
    if !(2..=3).contains(&version) {
        eprintln!("warning: unexpected GGUF version {version} — proceeding anyway");
    }
    let n_tensors = read_u64(&mut r)? as usize;
    let n_metadata = read_u64(&mut r)?;
    // Scan metadata KV pairs; capture `general.alignment` if present.
    let mut alignment = GGUF_DEFAULT_ALIGNMENT;
    for _ in 0..n_metadata {
        let key = read_str(&mut r)?;
        let tag = read_u32(&mut r)?;
        if key == "general.alignment" && tag == GTYPE_U32 {
            alignment = read_u32(&mut r)? as u64;
        } else {
            skip_value(&mut r, tag)?;
        }
    }
    // The spec requires a non-zero power-of-two alignment. A zero value from
    // a corrupt/hostile file would make the `div_ceil` below divide by zero
    // and panic, so validate and fall back to the default instead.
    if alignment == 0 || !alignment.is_power_of_two() {
        eprintln!(
            "warning: invalid general.alignment {alignment} — using {GGUF_DEFAULT_ALIGNMENT}"
        );
        alignment = GGUF_DEFAULT_ALIGNMENT;
    }
    // Tensor info records.
    let mut tensors = Vec::with_capacity(n_tensors);
    for _ in 0..n_tensors {
        let name = read_str(&mut r)?;
        let n_dims = read_u32(&mut r)? as usize;
        let dims: Vec<u64> = (0..n_dims)
            .map(|_| read_u64(&mut r))
            .collect::<io::Result<_>>()?;
        let dtype = read_u32(&mut r)?;
        let offset = read_u64(&mut r)?;
        tensors.push(TensorInfo { name, dims, dtype, offset });
    }
    // Data starts at the next `alignment`-byte boundary after the header.
    let header_end = r.stream_position()?;
    let data_start = header_end.div_ceil(alignment) * alignment;
    Ok((tensors, data_start))
}
// ---------------------------------------------------------------------------
// Block loader
// ---------------------------------------------------------------------------
/// Seek to the tensor's data and read its Q4_K blocks into a Vec.
fn load_blocks(
    r: &mut (impl Read + Seek),
    data_start: u64,
    tensor: &TensorInfo,
) -> io::Result<Vec<BlockQ4K>> {
    // Tensor offsets are relative to the start of the data section.
    r.seek(SeekFrom::Start(data_start + tensor.offset))?;
    let n_blocks = (tensor.n_elements() / QK_K as u64) as usize;
    let mut raw = [0u8; BLOCK_BYTES];
    let mut blocks = Vec::with_capacity(n_blocks);
    while blocks.len() < n_blocks {
        r.read_exact(&mut raw)?;
        // Decode field by field — safe, no transmute.
        // Layout: d(0..2) dmin(2..4) scales(4..16) qs(16..144)
        let d = u16::from_le_bytes([raw[0], raw[1]]);
        let dmin = u16::from_le_bytes([raw[2], raw[3]]);
        blocks.push(BlockQ4K {
            d,
            dmin,
            scales: raw[4..16].try_into().unwrap(),
            qs: raw[16..BLOCK_BYTES].try_into().unwrap(),
        });
    }
    Ok(blocks)
}
// ---------------------------------------------------------------------------
// Random FP16 activation matrix (no external rand dep)
// ---------------------------------------------------------------------------
/// Minimal 64-bit LCG (Knuth / PCG constants).
struct Lcg(u64);
impl Lcg {
    /// Seed the generator; equal seeds yield identical sequences.
    fn new(seed: u64) -> Self { Self(seed) }
    /// Return the next pseudo-random f32 in [-0.05, +0.05) — a plausible
    /// scale for normalised transformer activations.
    fn next_f32(&mut self) -> f32 {
        const MUL: u64 = 6_364_136_223_846_793_005;
        const INC: u64 = 1_442_695_040_888_963_407;
        self.0 = self.0.wrapping_mul(MUL).wrapping_add(INC);
        // Take the high (best-mixed) 32 bits, map to [0, 1), then rescale.
        let unit = (self.0 >> 32) as f32 / 4_294_967_296.0;
        unit * 0.10 - 0.05
    }
}
/// Lossily convert an f32 to IEEE-754 binary16 bits (mantissa truncated,
/// i.e. round-toward-zero).
///
/// Values below the fp16 normal range flush to signed zero, values above it
/// saturate to signed infinity, and NaN is preserved as a quiet NaN (the
/// previous version silently turned NaN into infinity because the NaN bit
/// pattern also has `exp >= 31`).
fn f32_to_fp16(f: f32) -> u16 {
    if f == 0.0 { return 0; }
    let bits = f.to_bits();
    let sign = ((bits >> 31) as u16) << 15;
    if f.is_nan() { return sign | 0x7E00; } // quiet NaN, sign preserved
    let exp = ((bits >> 23) & 0xFF) as i32 - 127 + 15;
    let mantissa = (bits & 0x007F_FFFF) >> 13;
    if exp <= 0 { return sign; } // underflow → signed zero
    if exp >= 31 { return sign | 0x7C00; } // overflow → signed infinity
    sign | ((exp as u16) << 10) | mantissa as u16
}
/// Build a K × N FP16 matrix filled with pseudo-random activations.
fn random_fp16_matrix(k: usize, n: usize, seed: u64) -> Vec<u16> {
    let mut rng = Lcg::new(seed);
    let mut out = Vec::with_capacity(k * n);
    for _ in 0..k * n {
        out.push(f32_to_fp16(rng.next_f32()));
    }
    out
}
// ---------------------------------------------------------------------------
// Timing helper
// ---------------------------------------------------------------------------
/// Run `f` `trials` times (at least once) and return the result of the last
/// run plus the minimum elapsed time across all runs (best-of-N timing).
///
/// `trials` is clamped to a minimum of 1: a `0` (reachable via the third
/// CLI argument) previously made `last.unwrap()` panic.
fn bench<T, F: FnMut() -> T>(trials: usize, mut f: F) -> (T, Duration) {
    let mut best = Duration::MAX;
    let mut last = None;
    for _ in 0..trials.max(1) {
        let t = Instant::now();
        let r = f();
        let dt = t.elapsed();
        if dt < best { best = dt; }
        last = Some(r);
    }
    // Safe: the loop above runs at least one iteration.
    (last.expect("bench ran at least one trial"), best)
}
/// MFLOP/s for an (m × k) · (k × n) matmul (2·m·k·n flops) over `dur`.
fn mflops(m: usize, k: usize, n: usize, dur: Duration) -> f64 {
    // Widen each factor to f64 before multiplying so huge shapes
    // cannot overflow usize arithmetic.
    let flops = 2.0 * m as f64 * k as f64 * n as f64;
    flops / dur.as_secs_f64() / 1e6
}
// ---------------------------------------------------------------------------
// Main
// ---------------------------------------------------------------------------
/// Entry point: parse a GGUF header, load the first suitable 2-D Q4_K
/// tensor, multiply it by a random FP16 activation matrix with both the
/// baseline and RLE kernels, and report timings plus a correctness check.
///
/// CLI: `<model.gguf> [n_cols=1] [trials=3]`. Exits with status 1 on a
/// missing argument or when the RLE result diverges from the baseline.
///
/// Fix: removed a leftover debug loop that dumped every nibble pair of the
/// first RLE block (128+ lines) to stdout on every run.
fn main() -> Result<(), Box<dyn Error>> {
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        eprintln!("usage: {} <model.gguf> [n_cols=1] [trials=3]", args[0]);
        std::process::exit(1);
    }
    let path = &args[1];
    // Malformed numeric arguments silently fall back to the defaults.
    let n = args.get(2).and_then(|s| s.parse::<usize>().ok()).unwrap_or(1);
    let trials = args.get(3).and_then(|s| s.parse::<usize>().ok()).unwrap_or(3);
    // ── Parse header ────────────────────────────────────────────────────────
    println!("Parsing {path}");
    let (tensors, data_start) = parse_header(path)?;
    // List all tensor names and types for context.
    let n_q4k = tensors.iter().filter(|t| t.dtype == GGML_TYPE_Q4_K).count();
    println!(" {} tensors total, {} are Q4_K", tensors.len(), n_q4k);
    // ── Select the first suitable 2-D Q4_K tensor ───────────────────────────
    // "Suitable" means 2-D and K divisible by QK_K (required for our matmul).
    let tensor = tensors
        .iter()
        .find(|t| {
            t.dtype == GGML_TYPE_Q4_K
                && t.dims.len() == 2
                && t.dims[0] % QK_K as u64 == 0
        })
        .ok_or("no suitable 2-D Q4_K tensor found in this GGUF file")?;
    let (m, k) = tensor.matrix_dims();
    let n_blocks = m * (k / QK_K);
    let data_mib = tensor.data_bytes() as f64 / (1u64 << 20) as f64;
    println!();
    println!("┌─ Tensor ───────────────────────────────────────");
    println!("│ name : {}", tensor.name);
    println!("│ shape : {m} rows × {k} cols");
    println!("│ blocks : {n_blocks} ({data_mib:.1} MiB on disk)");
    println!("│ activation: {k} × {n} (random FP16)");
    println!("│ trials : {trials} (best-of reported)");
    println!("└────────────────────────────────────────────────");
    // ── Load blocks ─────────────────────────────────────────────────────────
    print!("\nLoading blocks … ");
    let t0 = Instant::now();
    let mut file = BufReader::new(File::open(path)?);
    let blocks = load_blocks(&mut file, data_start, tensor)?;
    println!("{:.3} s ({} blocks × {} B)", t0.elapsed().as_secs_f64(), n_blocks, BLOCK_BYTES);
    // ── Build random activation matrix [K × N] ───────────────────────────────
    let b_fp16 = random_fp16_matrix(k, n, 0xDEAD_BEEF_CAFE_1234);
    // ── Baseline matmul (best of `trials`) ──────────────────────────────────
    let (c_base, t_base) = bench(trials, || {
        matmul_q4k_fp16(&blocks, &b_fp16, m, k, n)
    });
    println!(
        "Baseline: {:.3} s {:.0} MFLOP/s",
        t_base.as_secs_f64(),
        mflops(m, k, n, t_base),
    );
    // ── RLE encode (best of `trials`) ────────────────────────────────────────
    let (rle_blocks, t_enc) = bench(trials, || -> Vec<BlockQ4KRle> {
        blocks.iter().map(encode).collect()
    });
    let n_rle = rle_blocks.iter().filter(|b| b.is_rle()).count();
    let n_raw = n_blocks - n_rle;
    let avg_pairs = if n_rle > 0 {
        rle_blocks.iter()
            .filter(|b| b.is_rle())
            .map(|b| b.rle_len() as f64)
            .sum::<f64>() / n_rle as f64
    } else { 0.0 };
    println!(
        "Encode : {:.3} s RLE {n_rle}/{n_blocks} blocks ({:.1}%), \
raw {n_raw}/{n_blocks} ({:.1}%), avg {avg_pairs:.1} pairs/RLE block",
        t_enc.as_secs_f64(),
        100.0 * n_rle as f64 / n_blocks as f64,
        100.0 * n_raw as f64 / n_blocks as f64,
    );
    // ── RLE matmul (best of `trials`) ────────────────────────────────────────
    let (c_rle, t_rle) = bench(trials, || {
        matmul_q4k_rle_fp16(&rle_blocks, &b_fp16, m, k, n)
    });
    println!(
        "RLE : {:.3} s {:.0} MFLOP/s",
        t_rle.as_secs_f64(),
        mflops(m, k, n, t_rle),
    );
    // ── Correctness check ────────────────────────────────────────────────────
    // Element-wise max absolute difference between the two result matrices.
    let max_err = c_base
        .iter()
        .zip(&c_rle)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);
    let ok = max_err < 1e-2;
    println!(
        "\nMax |baseline rle| = {max_err:.2e} {}",
        if ok { "" } else { "✗ (unexpectedly large — check block layout)" }
    );
    // ── Summary ──────────────────────────────────────────────────────────────
    println!();
    println!("Speedup (matmul only): {:.2}×", t_base.as_secs_f64() / t_rle.as_secs_f64());
    println!("Speedup (matmul + encode once): {:.2}×",
        t_base.as_secs_f64() / (t_rle + t_enc).as_secs_f64());
    // Show a small slice of the output so it's clear something real happened.
    let show = n.min(4);
    print!("First {show} output(s) of row 0: ");
    for j in 0..show {
        print!("{:.4} ", c_base[j]);
    }
    println!();
    if !ok {
        std::process::exit(1);
    }
    Ok(())
}