add: rl2
This commit is contained in:
@@ -18,6 +18,8 @@
|
|||||||
//! dequantise each row of A into f32, convert each element of B from fp16 to
|
//! dequantise each row of A into f32, convert each element of B from fp16 to
|
||||||
//! f32, accumulate dot-products. No SIMD, no tiling, no tricks.
|
//! f32, accumulate dot-products. No SIMD, no tiling, no tricks.
|
||||||
|
|
||||||
|
pub mod rle;
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Constants matching GGML's ggml-common.h
|
// Constants matching GGML's ggml-common.h
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|||||||
992
src/rle.rs
Normal file
992
src/rle.rs
Normal file
@@ -0,0 +1,992 @@
|
|||||||
|
//! RLE-optional Q4_K super-block encoding.
|
||||||
|
//!
|
||||||
|
//! This module provides [`BlockQ4KRle`], a variant of [`crate::BlockQ4K`] that
|
||||||
|
//! optionally compresses the 128-byte weight payload using **byte-level
|
||||||
|
//! run-length encoding** (RLE). A flag bit in the [`BlockQ4KRle::flags`]
|
||||||
|
//! field indicates which mode is active:
|
||||||
|
//!
|
||||||
|
//! | `IS_RLE` bit | `qs` interpretation |
|
||||||
|
//! |--------------|------------------------------------------------------------|
|
||||||
|
//! | 0 | Raw packed nibbles, identical to [`crate::BlockQ4K::qs`] |
|
||||||
|
//! | 1 | RLE stream of `(value, count)` byte-pairs |
|
||||||
|
//!
|
||||||
|
//! ## RLE format (when `IS_RLE` = 1)
|
||||||
|
//!
|
||||||
|
//! - `flags >> 1` gives the number of `(value, count)` pairs stored in `qs`.
|
||||||
|
//! - For each pair `i`:
|
||||||
|
//! - `qs[2*i]` — the byte value (two packed 4-bit weights, same packing
|
||||||
|
//! as the raw format).
|
||||||
|
//!   - `qs[2*i + 1]` — the run length in bytes (1..=128: stored as a `u8`, but
//!     the runs of one block must sum to 128, so no single run can exceed that).
|
||||||
|
//! - The run lengths must sum to exactly 128 (the uncompressed `qs` size).
|
||||||
|
//!
|
||||||
|
//! RLE encoding is chosen only when the compressed representation is
|
||||||
|
//! **strictly shorter** than the 128-byte raw payload, i.e. when
|
||||||
|
//! `pairs * 2 < 128`. That caps the useful range at ≤ 63 pairs. The 7-bit
|
||||||
|
//! `flags >> 1` sub-field can hold up to 127, so this ceiling is never a
|
||||||
|
//! concern in practice.
|
||||||
|
//!
|
||||||
|
//! ## Constructing blocks
|
||||||
|
//!
|
||||||
|
//! Use [`encode`] to convert an existing [`crate::BlockQ4K`] into a
|
||||||
|
//! [`BlockQ4KRle`]. The function automatically selects the better mode.
|
||||||
|
//!
|
||||||
|
//! ## Adding this module to your crate
|
||||||
|
//!
|
||||||
|
//! Add `pub mod rle;` to `lib.rs`.
|
||||||
|
|
||||||
|
use crate::{fp16_to_f32, get_scale_min, BlockQ4K, K_SCALE_SIZE, QK_K};
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Flag constants
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Flag bit in [`BlockQ4KRle::flags`]: if set, `qs` contains an RLE stream of
/// `(value, count)` byte-pairs; if clear, `qs` holds the raw packed nibbles.
pub const IS_RLE: u8 = 0x01;
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Block definition
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// A Q4_K super-block with optional byte-level RLE compression on the weights.
|
||||||
|
///
|
||||||
|
/// Identical to [`crate::BlockQ4K`] except for the additional [`flags`](Self::flags)
|
||||||
|
/// byte inserted between `scales` and `qs`.
|
||||||
|
///
|
||||||
|
/// Memory layout (repr C):
|
||||||
|
///
|
||||||
|
/// | Offset | Field | Size | Notes |
|
||||||
|
/// |--------|------------|-------|--------------------------------|
|
||||||
|
/// | 0 | `d` | 2 B | fp16 super-block scale |
|
||||||
|
/// | 2 | `dmin` | fp16 super-block min scale | 2 B |
|
||||||
|
/// | 4 | `scales` | 12 B | packed 6-bit sub-block params |
|
||||||
|
/// | 16 | `flags` | 1 B | encoding flags (see below) |
|
||||||
|
/// | 17 | `qs` | 128 B | raw nibbles or RLE stream |
|
||||||
|
/// | 145 | (padding) | 1 B | implicit trailing alignment pad|
|
||||||
|
///
|
||||||
|
/// **sizeof = 146 bytes** (padded to 2-byte alignment imposed by `u16` fields).
|
||||||
|
///
|
||||||
|
/// ## `flags` bit layout
|
||||||
|
///
|
||||||
|
/// | Bits | Meaning |
|
||||||
|
/// |------|---------------------------------------------------------------|
|
||||||
|
/// | 0 | [`IS_RLE`] — 1 = `qs` is RLE-encoded, 0 = raw packed nibbles |
|
||||||
|
/// | 1–7 | When `IS_RLE`=1: number of `(value, count)` pairs in `qs` |
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub struct BlockQ4KRle {
|
||||||
|
/// Super-block scale for quantised sub-block scales (fp16 bits).
|
||||||
|
pub d: u16,
|
||||||
|
/// Super-block scale for quantised sub-block mins (fp16 bits).
|
||||||
|
pub dmin: u16,
|
||||||
|
/// Packed 6-bit sub-block scales and mins (same layout as [`crate::BlockQ4K`]).
|
||||||
|
pub scales: [u8; K_SCALE_SIZE],
|
||||||
|
/// Encoding flags. Bit 0 = [`IS_RLE`]. Bits 1-7 = RLE pair count when
|
||||||
|
/// `IS_RLE` is set.
|
||||||
|
pub flags: u8,
|
||||||
|
/// Raw packed-nibble weights (`IS_RLE` = 0) or RLE byte stream (`IS_RLE` = 1).
|
||||||
|
pub qs: [u8; QK_K / 2],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BlockQ4KRle {
|
||||||
|
/// Returns `true` when `qs` holds an RLE-encoded stream.
|
||||||
|
#[inline]
|
||||||
|
pub fn is_rle(&self) -> bool {
|
||||||
|
self.flags & IS_RLE != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of `(value, count)` byte-pairs stored at the start of `qs`.
|
||||||
|
///
|
||||||
|
/// Only meaningful when [`is_rle`](Self::is_rle) returns `true`.
|
||||||
|
#[inline]
|
||||||
|
pub fn rle_len(&self) -> usize {
|
||||||
|
(self.flags >> 1) as usize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Encoding
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Encode a [`BlockQ4K`] block into a [`BlockQ4KRle`] block.
|
||||||
|
///
|
||||||
|
/// The 128-byte `qs` payload is scanned for runs of identical bytes. If the
|
||||||
|
/// RLE representation fits in the same 128-byte field **and is strictly
|
||||||
|
/// shorter** than the raw payload, it is stored with `IS_RLE` set. Otherwise
|
||||||
|
/// the raw bytes are copied unchanged and `IS_RLE` is cleared.
|
||||||
|
///
|
||||||
|
/// The `d`, `dmin`, and `scales` fields are always copied verbatim.
|
||||||
|
pub fn encode(block: &BlockQ4K) -> BlockQ4KRle {
|
||||||
|
let raw = &block.qs;
|
||||||
|
|
||||||
|
// Scan the 128-byte raw payload for runs of equal bytes.
|
||||||
|
let mut pairs: Vec<(u8, u8)> = Vec::with_capacity(64);
|
||||||
|
let mut i = 0usize;
|
||||||
|
while i < raw.len() {
|
||||||
|
let val = raw[i];
|
||||||
|
// Count consecutive equal bytes; saturate at u8::MAX to stay in-range.
|
||||||
|
let mut run = 1u8;
|
||||||
|
while i + (run as usize) < raw.len()
|
||||||
|
&& raw[i + (run as usize)] == val
|
||||||
|
&& run < u8::MAX
|
||||||
|
{
|
||||||
|
run += 1;
|
||||||
|
}
|
||||||
|
pairs.push((val, run));
|
||||||
|
i += run as usize;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only switch to RLE when the encoded form is strictly smaller than the
|
||||||
|
// raw payload. Because each pair costs 2 bytes and the raw payload is
|
||||||
|
// 128 bytes, the condition pairs.len() * 2 < 128 also guarantees that
|
||||||
|
// pairs.len() ≤ 63, which fits in bits 1-7 of the flags byte.
|
||||||
|
if pairs.len() * 2 < raw.len() {
|
||||||
|
let n = pairs.len();
|
||||||
|
debug_assert!(n <= 63, "RLE pair count {n} unexpectedly exceeds 63");
|
||||||
|
|
||||||
|
let mut qs = [0u8; QK_K / 2];
|
||||||
|
for (k, &(val, count)) in pairs.iter().enumerate() {
|
||||||
|
qs[2 * k] = val;
|
||||||
|
qs[2 * k + 1] = count;
|
||||||
|
}
|
||||||
|
|
||||||
|
BlockQ4KRle {
|
||||||
|
d: block.d,
|
||||||
|
dmin: block.dmin,
|
||||||
|
scales: block.scales,
|
||||||
|
flags: IS_RLE | ((n as u8) << 1),
|
||||||
|
qs,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No space savings — copy raw bytes and leave IS_RLE clear.
|
||||||
|
BlockQ4KRle {
|
||||||
|
d: block.d,
|
||||||
|
dmin: block.dmin,
|
||||||
|
scales: block.scales,
|
||||||
|
flags: 0,
|
||||||
|
qs: block.qs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Decoding helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Expand the `qs` field of a [`BlockQ4KRle`] block into the 128-byte raw
|
||||||
|
/// packed-nibble array, handling both raw and RLE modes transparently.
|
||||||
|
///
|
||||||
|
/// # Panics (debug builds only)
|
||||||
|
///
|
||||||
|
/// Panics if the decoded RLE stream does not sum to exactly 128 bytes.
|
||||||
|
fn decode_qs(block: &BlockQ4KRle) -> [u8; QK_K / 2] {
|
||||||
|
if !block.is_rle() {
|
||||||
|
return block.qs;
|
||||||
|
}
|
||||||
|
|
||||||
|
let n = block.rle_len();
|
||||||
|
let mut raw = [0u8; QK_K / 2];
|
||||||
|
let mut pos = 0usize;
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
let val = block.qs[2 * i];
|
||||||
|
let count = block.qs[2 * i + 1] as usize;
|
||||||
|
raw[pos..pos + count].fill(val);
|
||||||
|
pos += count;
|
||||||
|
}
|
||||||
|
|
||||||
|
debug_assert_eq!(
|
||||||
|
pos,
|
||||||
|
QK_K / 2,
|
||||||
|
"RLE run lengths sum to {pos}, expected {}",
|
||||||
|
QK_K / 2
|
||||||
|
);
|
||||||
|
raw
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Dequantisation
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Dequantise one [`BlockQ4KRle`] super-block into [`QK_K`] (256) `f32` values.
|
||||||
|
///
|
||||||
|
/// When `IS_RLE` is set the RLE stream is first expanded into a 128-byte raw
|
||||||
|
/// buffer; thereafter the dequantisation is identical to
|
||||||
|
/// [`crate::dequantize_block_q4k`]:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// out[i] = d * scale[s] * nibble[i] - dmin * min[s]
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// where `s` is the sub-block index (0..8) that the element belongs to.
|
||||||
|
pub fn dequantize_block_q4k_rle(block: &BlockQ4KRle, out: &mut [f32; QK_K]) {
|
||||||
|
let d = fp16_to_f32(block.d);
|
||||||
|
let dmin = fp16_to_f32(block.dmin);
|
||||||
|
let qs = decode_qs(block);
|
||||||
|
|
||||||
|
let mut q_off = 0usize; // byte cursor into the raw qs array
|
||||||
|
let mut out_off = 0usize; // element cursor into `out`
|
||||||
|
let mut is = 0usize; // sub-block pair index (0, 2, 4, 6)
|
||||||
|
|
||||||
|
while out_off < QK_K {
|
||||||
|
let (sc1, mn1) = get_scale_min(is, &block.scales);
|
||||||
|
let (sc2, mn2) = get_scale_min(is + 1, &block.scales);
|
||||||
|
|
||||||
|
let d1 = d * sc1 as f32;
|
||||||
|
let m1 = dmin * mn1 as f32;
|
||||||
|
let d2 = d * sc2 as f32;
|
||||||
|
let m2 = dmin * mn2 as f32;
|
||||||
|
|
||||||
|
for l in 0..32 {
|
||||||
|
out[out_off + l] = d1 * (qs[q_off + l] & 0x0F) as f32 - m1;
|
||||||
|
}
|
||||||
|
for l in 0..32 {
|
||||||
|
out[out_off + 32 + l] = d2 * (qs[q_off + l] >> 4) as f32 - m2;
|
||||||
|
}
|
||||||
|
|
||||||
|
q_off += 32;
|
||||||
|
out_off += 64;
|
||||||
|
is += 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Matrix multiplication C = A × B
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Multiply a Q4_K_RLE matrix **A** by an FP16 matrix **B**, producing an f32
|
||||||
|
/// matrix **C**.
|
||||||
|
///
|
||||||
|
/// Identical semantics to [`crate::matmul_q4k_fp16`] but accepts
|
||||||
|
/// [`BlockQ4KRle`] blocks. Each block is dequantised on the fly via
|
||||||
|
/// [`dequantize_block_q4k_rle`], transparently handling mixed raw/RLE blocks
|
||||||
|
/// within the same matrix.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `a` – Row-major slice of [`BlockQ4KRle`]. Row `i` occupies blocks
|
||||||
|
/// `a[i * blocks_per_row .. (i+1) * blocks_per_row]`.
|
||||||
|
/// * `b` – Row-major fp16 matrix stored as raw `u16` bits, shape \[K, N\].
|
||||||
|
/// Element `(ki, j)` is at index `ki * n + j`.
|
||||||
|
/// * `m` – Number of rows in A (and C).
|
||||||
|
/// * `k` – Number of columns in A = number of rows in B.
|
||||||
|
/// **Must** be a multiple of [`QK_K`] (256).
|
||||||
|
/// * `n` – Number of columns in B (and C).
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A flat row-major `Vec<f32>` of shape \[M, N\].
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `k` is not a multiple of `QK_K`, or if the lengths of `a` or `b`
|
||||||
|
/// do not match the declared dimensions.
|
||||||
|
pub fn matmul_q4k_rle_fp16(
|
||||||
|
a: &[BlockQ4KRle],
|
||||||
|
b: &[u16],
|
||||||
|
m: usize,
|
||||||
|
k: usize,
|
||||||
|
n: usize,
|
||||||
|
) -> Vec<f32> {
|
||||||
|
assert_eq!(
|
||||||
|
k % QK_K,
|
||||||
|
0,
|
||||||
|
"k ({k}) must be a multiple of QK_K ({QK_K})"
|
||||||
|
);
|
||||||
|
let blocks_per_row = k / QK_K;
|
||||||
|
assert_eq!(
|
||||||
|
a.len(),
|
||||||
|
m * blocks_per_row,
|
||||||
|
"A block count mismatch: expected {} blocks, got {}",
|
||||||
|
m * blocks_per_row,
|
||||||
|
a.len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
b.len(),
|
||||||
|
k * n,
|
||||||
|
"B element count mismatch: expected {}, got {}",
|
||||||
|
k * n,
|
||||||
|
b.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut c = vec![0.0f32; m * n];
|
||||||
|
let mut a_row = vec![0.0f32; k];
|
||||||
|
let mut block_buf = [0.0f32; QK_K];
|
||||||
|
|
||||||
|
for i in 0..m {
|
||||||
|
// Dequantise row i of A into a_row (f32).
|
||||||
|
for b_idx in 0..blocks_per_row {
|
||||||
|
let block = &a[i * blocks_per_row + b_idx];
|
||||||
|
dequantize_block_q4k_rle(block, &mut block_buf);
|
||||||
|
let start = b_idx * QK_K;
|
||||||
|
a_row[start..start + QK_K].copy_from_slice(&block_buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dot-product with each column of B.
|
||||||
|
for j in 0..n {
|
||||||
|
let mut sum = 0.0f32;
|
||||||
|
for ki in 0..k {
|
||||||
|
sum += a_row[ki] * fp16_to_f32(b[ki * n + j]);
|
||||||
|
}
|
||||||
|
c[i * n + j] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{dequantize_block_q4k, matmul_q4k_fp16, BlockQ4K};
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------
|
||||||
|
// Test helpers
|
||||||
|
// -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Produce the IEEE 754 binary16 bit pattern for an `f32`.
///
/// Zero (either sign) maps to `0x0000`, the infinities to `0x7C00` / `0xFC00`
/// and NaN to `0x7E00`. Any other value must land in the fp16 *normal* range;
/// the low 13 mantissa bits are dropped (truncation, no rounding).
///
/// Panics if the value falls outside the representable fp16 normal range.
fn f32_to_fp16_bits(f: f32) -> u16 {
    if f.is_nan() {
        return 0x7E00;
    }
    if f == 0.0 {
        return 0x0000;
    }
    if f.is_infinite() {
        return if f > 0.0 { 0x7C00 } else { 0xFC00 };
    }

    let raw = f.to_bits();
    // Bit 31 of the f32 becomes bit 15 of the fp16.
    let sign_bit = ((raw >> 16) & 0x8000) as u16;
    // Re-bias the exponent from 127 (f32) to 15 (fp16).
    let fp16_exp = ((raw >> 23) & 0xFF) as i32 - (127 - 15);
    assert!(
        (1..31).contains(&fp16_exp),
        "f32 value {f} is outside the representable fp16 normal range"
    );
    // Keep the top 10 of the 23 mantissa bits.
    let fraction = ((raw & 0x007F_FFFF) >> 13) as u16;

    sign_bit | ((fp16_exp as u16) << 10) | fraction
}
|
||||||
|
|
||||||
|
/// Build a [`BlockQ4K`] where all 8 sub-blocks share the same `scale` and
|
||||||
|
/// `min` (both < 16), and every byte in `qs` is `qs_byte`.
|
||||||
|
fn make_block(d: f32, dmin: f32, scale: u8, min: u8, qs_byte: u8) -> BlockQ4K {
|
||||||
|
assert!(
|
||||||
|
scale < 16 && min < 16,
|
||||||
|
"make_block: scale ({scale}) and min ({min}) must both be < 16"
|
||||||
|
);
|
||||||
|
let mut scales = [0u8; K_SCALE_SIZE];
|
||||||
|
for j in 0..4 {
|
||||||
|
scales[j] = scale;
|
||||||
|
scales[j + 4] = min;
|
||||||
|
}
|
||||||
|
for j in 8..12 {
|
||||||
|
scales[j] = (scale & 0x0F) | ((min & 0x0F) << 4);
|
||||||
|
}
|
||||||
|
BlockQ4K {
|
||||||
|
d: f32_to_fp16_bits(d),
|
||||||
|
dmin: f32_to_fp16_bits(dmin),
|
||||||
|
scales,
|
||||||
|
qs: [qs_byte; QK_K / 2],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a [`BlockQ4K`] with a custom `qs` array.
|
||||||
|
fn make_block_with_qs(
|
||||||
|
d: f32,
|
||||||
|
dmin: f32,
|
||||||
|
scale: u8,
|
||||||
|
min: u8,
|
||||||
|
qs: [u8; QK_K / 2],
|
||||||
|
) -> BlockQ4K {
|
||||||
|
assert!(
|
||||||
|
scale < 16 && min < 16,
|
||||||
|
"make_block_with_qs: scale ({scale}) and min ({min}) must both be < 16"
|
||||||
|
);
|
||||||
|
let mut scales = [0u8; K_SCALE_SIZE];
|
||||||
|
for j in 0..4 {
|
||||||
|
scales[j] = scale;
|
||||||
|
scales[j + 4] = min;
|
||||||
|
}
|
||||||
|
for j in 8..12 {
|
||||||
|
scales[j] = (scale & 0x0F) | ((min & 0x0F) << 4);
|
||||||
|
}
|
||||||
|
BlockQ4K {
|
||||||
|
d: f32_to_fp16_bits(d),
|
||||||
|
dmin: f32_to_fp16_bits(dmin),
|
||||||
|
scales,
|
||||||
|
qs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Assert that `got` lies within `tol` (inclusive) of `expected`.
///
/// A NaN difference fails the comparison and therefore panics.
fn assert_close(got: f32, expected: f32, tol: f32) {
    let delta = (got - expected).abs();
    assert!(delta <= tol, "got {got}, expected {expected} (tol {tol})");
}
|
||||||
|
|
||||||
|
/// Assert that every element of `got` lies within `tol` (inclusive) of the
/// single value `expected_scalar`; the panic message names the first
/// offending index.
fn assert_all_close(got: &[f32], expected_scalar: f32, tol: f32) {
    got.iter().enumerate().for_each(|(i, &g)| {
        let delta = (g - expected_scalar).abs();
        assert!(
            delta <= tol,
            "element {i}: got {g}, expected {expected_scalar} (tol {tol})"
        );
    });
}
|
||||||
|
|
||||||
|
/// Assert that `got` and `expected` have equal length and agree element-wise
/// to within `tol` (inclusive); the panic message names the first offending
/// index.
fn assert_slices_close(got: &[f32], expected: &[f32], tol: f32) {
    assert_eq!(got.len(), expected.len(), "slice length mismatch");
    got.iter()
        .zip(expected)
        .enumerate()
        .for_each(|(i, (&g, &e))| {
            let delta = (g - e).abs();
            assert!(delta <= tol, "element {i}: got {g}, expected {e} (tol {tol})");
        });
}
|
||||||
|
|
||||||
|
fn fp16_uniform(k: usize, n: usize, value: f32) -> Vec<u16> {
|
||||||
|
vec![f32_to_fp16_bits(value); k * n]
|
||||||
|
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Struct layout
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
#[test]
fn block_q4k_rle_size_is_146_bytes() {
    // Field bytes: 2 (d) + 2 (dmin) + 12 (scales) + 1 (flags) + 128 (qs) = 145.
    // The u16 members force 2-byte alignment, so one pad byte makes it 146.
    let actual = core::mem::size_of::<BlockQ4KRle>();
    assert_eq!(actual, 146);
}
|
||||||
|
|
||||||
|
#[test]
fn block_q4k_rle_is_two_bytes_larger_than_block_q4k() {
    // The extra flags byte plus its alignment pad account for the difference.
    let rle_size = core::mem::size_of::<BlockQ4KRle>();
    let base_size = core::mem::size_of::<BlockQ4K>();
    assert_eq!(rle_size, base_size + 2);
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// is_rle / rle_len
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
#[test]
fn is_rle_false_when_flag_clear() {
    // An all-zero block has flags == 0, so the IS_RLE bit is clear.
    let raw_block = BlockQ4KRle {
        d: 0,
        dmin: 0,
        scales: [0; K_SCALE_SIZE],
        flags: 0,
        qs: [0; QK_K / 2],
    };
    assert!(!raw_block.is_rle());
}
|
||||||
|
|
||||||
|
#[test]
fn rle_len_zero_when_flag_clear() {
    // With flags == 0 the pair-count bits are zero as well.
    let raw_block = BlockQ4KRle {
        d: 0,
        dmin: 0,
        scales: [0; K_SCALE_SIZE],
        flags: 0,
        qs: [0; QK_K / 2],
    };
    assert_eq!(raw_block.rle_len(), 0);
}
|
||||||
|
|
||||||
|
#[test]
fn is_rle_true_when_flag_set() {
    // IS_RLE bit set, with an arbitrary pair count of 5 in the upper bits.
    let flags = IS_RLE | (5u8 << 1);
    let rle_block = BlockQ4KRle {
        d: 0,
        dmin: 0,
        scales: [0; K_SCALE_SIZE],
        flags,
        qs: [0; QK_K / 2],
    };
    assert!(rle_block.is_rle());
}
|
||||||
|
|
||||||
|
#[test]
fn rle_len_reports_pair_count_from_flags() {
    // Each count is packed into bits 1-7 and must read back unchanged.
    [0usize, 1, 7, 31, 63].into_iter().for_each(|n| {
        let flags = IS_RLE | ((n as u8) << 1);
        let block = BlockQ4KRle {
            d: 0,
            dmin: 0,
            scales: [0; K_SCALE_SIZE],
            flags,
            qs: [0; QK_K / 2],
        };
        assert_eq!(block.rle_len(), n, "expected rle_len {n}");
    });
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// encode: mode selection
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
#[test]
fn encode_uniform_qs_uses_rle() {
    // A payload of 128 equal bytes collapses to one pair: 2 bytes < 128 raw.
    let encoded = encode(&make_block(1.0, 0.0, 1, 0, 0x77));
    assert!(encoded.is_rle(), "uniform qs should trigger RLE mode");
}
|
||||||
|
|
||||||
|
#[test]
fn encode_uniform_qs_rle_len_is_one() {
    // One run covers the whole payload, so exactly one pair is recorded.
    let encoded = encode(&make_block(1.0, 0.0, 1, 0, 0x55));
    assert_eq!(encoded.rle_len(), 1);
}
|
||||||
|
|
||||||
|
#[test]
fn encode_uniform_qs_rle_entry_is_correct() {
    // The single pair sits at the front of qs: value byte, then run length.
    let encoded = encode(&make_block(1.0, 0.0, 1, 0, 0xAB));
    let (value, run) = (encoded.qs[0], encoded.qs[1]);
    assert_eq!(value, 0xAB, "RLE value byte should equal the repeated byte");
    assert_eq!(run, 128, "RLE run length should be 128 bytes");
}
|
||||||
|
|
||||||
|
#[test]
fn encode_alternating_bytes_stays_raw() {
    // 0xAA at even offsets, 0x55 at odd ones: 128 runs of length 1, which
    // would need 256 bytes of pairs — more than the 128-byte raw payload.
    let mut qs = [0x55u8; QK_K / 2];
    for even in qs.iter_mut().step_by(2) {
        *even = 0xAA;
    }
    let encoded = encode(&make_block_with_qs(1.0, 0.0, 1, 0, qs));
    assert!(!encoded.is_rle(), "alternating bytes cannot be compressed → raw mode");
}
|
||||||
|
|
||||||
|
#[test]
fn encode_raw_mode_copies_qs_verbatim() {
    // Repeating 0x11, 0x22, 0x33 keeps neighbours distinct, so every run has
    // length 1 and the encoder has to fall back to raw mode.
    const CYCLE: [u8; 3] = [0x11, 0x22, 0x33];
    let mut qs = [0u8; QK_K / 2];
    for (dst, &src) in qs.iter_mut().zip(CYCLE.iter().cycle()) {
        *dst = src;
    }
    let encoded = encode(&make_block_with_qs(1.0, 0.0, 1, 0, qs));
    assert!(!encoded.is_rle());
    assert_eq!(encoded.qs, qs, "raw mode must preserve qs bytes unchanged");
}
|
||||||
|
|
||||||
|
#[test]
fn encode_two_runs_uses_rle_and_stores_correct_pairs() {
    // First half 0x11, second half 0x22: two runs of 64 bytes each, i.e.
    // 2 pairs = 4 bytes, well under the 128-byte raw payload.
    let mut qs = [0x22u8; QK_K / 2];
    qs[..64].fill(0x11);

    let encoded = encode(&make_block_with_qs(1.0, 0.0, 1, 0, qs));
    assert!(encoded.is_rle());
    assert_eq!(encoded.rle_len(), 2);

    let stream = &encoded.qs[..4];
    assert_eq!(stream[0], 0x11, "first pair: value");
    assert_eq!(stream[1], 64, "first pair: run length");
    assert_eq!(stream[2], 0x22, "second pair: value");
    assert_eq!(stream[3], 64, "second pair: run length");
}
|
||||||
|
|
||||||
|
#[test]
fn encode_63_pairs_uses_rle() {
    // Layout: 62 two-byte runs (124 bytes) followed by one four-byte run.
    // That is 63 pairs = 126 bytes, still strictly below 128 → RLE wins.
    let mut qs = [0u8; QK_K / 2];
    for (run, pair) in qs.chunks_exact_mut(2).take(62).enumerate() {
        // Stepping by 3 guarantees neighbouring runs never share a value.
        let v = (run as u8).wrapping_mul(3).wrapping_add(1);
        pair.fill(v);
    }
    // Closing run: the last 4 bytes, with a value unlike the run before it.
    qs[124..].fill(0xFE);

    let encoded = encode(&make_block_with_qs(1.0, 0.0, 1, 0, qs));
    assert!(encoded.is_rle(), "63 pairs should use RLE");
    assert_eq!(encoded.rle_len(), 63);
}
|
||||||
|
|
||||||
|
#[test]
fn encode_64_pairs_stays_raw() {
    // 64 two-byte runs fill the payload exactly. Their pairs would occupy
    // 128 bytes — equal to, not smaller than, the raw form → raw mode.
    let mut qs = [0u8; QK_K / 2];
    for (run, pair) in qs.chunks_exact_mut(2).enumerate() {
        let v = (run as u8).wrapping_mul(3).wrapping_add(1);
        pair.fill(v);
    }

    let encoded = encode(&make_block_with_qs(1.0, 0.0, 1, 0, qs));
    assert!(!encoded.is_rle(), "64 pairs offers no saving → raw mode");
}
|
||||||
|
|
||||||
|
#[test]
fn encode_preserves_d_dmin_scales() {
    // Whatever mode is chosen, the header fields must be carried over as-is.
    let original = make_block(2.0, 0.5, 3, 2, 0x00);
    let encoded = encode(&original);
    assert_eq!(encoded.d, original.d);
    assert_eq!(encoded.dmin, original.dmin);
    assert_eq!(encoded.scales, original.scales);
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// decode_qs (tested indirectly through dequantise, but also directly)
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
#[test]
fn decode_qs_raw_mode_returns_qs_unchanged() {
    // Raw block (flags = 0) whose payload is the ramp 0, 1, 2, …, 127.
    let mut qs = [0u8; QK_K / 2];
    for (slot, value) in qs.iter_mut().zip(0u8..) {
        *slot = value;
    }
    let raw_block = BlockQ4KRle {
        d: 0,
        dmin: 0,
        scales: [0; K_SCALE_SIZE],
        flags: 0,
        qs,
    };
    assert_eq!(decode_qs(&raw_block), qs);
}
|
||||||
|
|
||||||
|
#[test]
fn decode_qs_rle_expands_two_pair_stream() {
    // Manually written stream: 64 × 0xAA followed by 64 × 0xBB.
    let mut qs = [0u8; QK_K / 2];
    qs[..4].copy_from_slice(&[0xAA, 64, 0xBB, 64]);
    let rle_block = BlockQ4KRle {
        d: 0,
        dmin: 0,
        scales: [0; K_SCALE_SIZE],
        flags: IS_RLE | (2u8 << 1),
        qs,
    };

    let expanded = decode_qs(&rle_block);
    let (head, tail) = expanded.split_at(64);
    assert!(head.iter().all(|&b| b == 0xAA), "first 64 bytes must be 0xAA");
    assert!(tail.iter().all(|&b| b == 0xBB), "last 64 bytes must be 0xBB");
}
|
||||||
|
|
||||||
|
#[test]
fn decode_qs_rle_single_run_covers_all() {
    // A lone pair (0xCD, 128) must expand to the full 128-byte payload.
    let mut qs = [0u8; QK_K / 2];
    qs[..2].copy_from_slice(&[0xCD, 128]);
    let rle_block = BlockQ4KRle {
        d: 0,
        dmin: 0,
        scales: [0; K_SCALE_SIZE],
        flags: IS_RLE | (1u8 << 1),
        qs,
    };

    let expanded = decode_qs(&rle_block);
    assert!(expanded.iter().all(|&b| b == 0xCD));
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// dequantize_block_q4k_rle
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
#[test]
fn dequant_rle_zero_d_all_outputs_zero() {
    // With d = 0 and dmin = 0 both terms of the formula vanish.
    let encoded = encode(&make_block(0.0, 0.0, 1, 0, 0x77));
    let mut values = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut values);
    assert_all_close(&values, 0.0, 0.0);
}
|
||||||
|
|
||||||
|
#[test]
fn dequant_rle_uniform_nibble_one_scale_one() {
    // Byte 0x11 carries nibble 1 twice; with scale = 1, d = 1.0 and min = 0
    // every output is 1.0 * 1 * 1 - 0.0 = 1.0.
    let encoded = encode(&make_block(1.0, 0.0, 1, 0, 0x11));
    let mut values = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut values);
    assert_all_close(&values, 1.0, 1e-5);
}
|
||||||
|
|
||||||
|
#[test]
fn dequant_rle_non_zero_min_subtracts() {
    // Nibble 0 removes the scale term; min = 2 with dmin = 1.0 leaves
    // 1.0 * 1 * 0 - 1.0 * 2 = -2.0 everywhere.
    let encoded = encode(&make_block(1.0, 1.0, 1, 2, 0x00));
    let mut values = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut values);
    assert_all_close(&values, -2.0, 1e-5);
}
|
||||||
|
|
||||||
|
#[test]
fn dequant_rle_max_nibble_15() {
    // Byte 0xFF carries the largest nibble (15) twice; with scale = 1,
    // d = 1.0 and min = 0 every output is 1.0 * 1 * 15 - 0.0 = 15.0.
    let encoded = encode(&make_block(1.0, 0.0, 1, 0, 0xFF));
    let mut values = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut values);
    assert_all_close(&values, 15.0, 1e-5);
}
|
||||||
|
|
||||||
|
#[test]
fn dequant_rle_output_count_is_qk_k() {
    let src = make_block(1.0, 0.0, 1, 0, 0x00);
    let rle = encode(&src);
    // Start from a NaN sentinel: `out.len()` is fixed by the array type, so
    // the length check alone cannot fail. Any slot the dequantiser skips
    // stays NaN and is counted below. (This block decodes to 0.0 everywhere,
    // so a legitimate result is never NaN.)
    let mut out = [f32::NAN; QK_K];
    dequantize_block_q4k_rle(&rle, &mut out);
    assert_eq!(out.len(), QK_K);
    let unwritten = out.iter().filter(|v| v.is_nan()).count();
    assert_eq!(unwritten, 0, "{unwritten} of {QK_K} outputs were never written");
}
|
||||||
|
|
||||||
|
#[test]
fn dequant_rle_larger_scale_multiplies() {
    // Byte 0x33 carries nibble 3 twice. With scale = 4, d = 2.0 and min = 0
    // every output is 2.0 * 4 * 3 - 0.0 = 24.0.
    let encoded = encode(&make_block(2.0, 0.0, 4, 0, 0x33));
    let mut values = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut values);
    assert_all_close(&values, 24.0, 1e-4);
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Roundtrip: encode → dequantize must match original dequantize
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
#[test]
fn roundtrip_rle_mode_matches_original() {
    // A constant payload makes the encoder pick RLE mode.
    let original = make_block(2.0, 0.5, 3, 1, 0x37);
    let encoded = encode(&original);
    assert!(encoded.is_rle());

    // Decode through the RLE path, then through the plain Q4_K path.
    let mut via_rle = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut via_rle);
    let mut reference = [0.0f32; QK_K];
    dequantize_block_q4k(&original, &mut reference);

    assert_slices_close(&via_rle, &reference, 1e-5);
}
|
||||||
|
|
||||||
|
#[test]
fn roundtrip_raw_mode_matches_original() {
    // 0x13 at even offsets and 0x24 at odd ones cannot be compressed, so the
    // encoder stays in raw mode — the decoded values must still agree.
    let mut qs = [0x24u8; QK_K / 2];
    for even in qs.iter_mut().step_by(2) {
        *even = 0x13;
    }
    let original = make_block_with_qs(1.5, 0.25, 2, 1, qs);
    let encoded = encode(&original);
    assert!(!encoded.is_rle());

    let mut via_rle = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut via_rle);
    let mut reference = [0.0f32; QK_K];
    dequantize_block_q4k(&original, &mut reference);

    assert_slices_close(&via_rle, &reference, 1e-5);
}
|
||||||
|
|
||||||
|
/// A payload made of exactly two runs must be stored in RLE mode and
/// round-trip to the same values as the plain block.
#[test]
fn roundtrip_two_run_block_matches_original() {
    // First 64 bytes are 0x59, the remainder is 0x8C.
    let mut payload = [0u8; QK_K / 2];
    {
        let (front, back) = payload.split_at_mut(64);
        front.fill(0x59);
        back.fill(0x8C);
    }

    let original = make_block_with_qs(3.0, 1.0, 5, 2, payload);
    let encoded = encode(&original);
    assert!(encoded.is_rle());

    let mut reference = [0.0f32; QK_K];
    dequantize_block_q4k(&original, &mut reference);

    let mut decoded = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut decoded);

    assert_slices_close(&decoded, &reference, 1e-5);
}
|
||||||
|
|
||||||
|
/// A payload made of four runs of uneven length must still be stored in RLE
/// mode, report four pairs, and round-trip to the plain block's values.
#[test]
fn roundtrip_many_short_runs_matches_original() {
    let mut payload = [0u8; QK_K / 2];
    let payload_len = payload.len();

    // (start, end, byte): four back-to-back runs covering the whole payload.
    let runs = [
        (0usize, 10usize, 0x11u8),
        (10, 30, 0x22),
        (30, 31, 0x33),
        (31, payload_len, 0x44),
    ];
    for &(start, end, byte) in runs.iter() {
        payload[start..end].fill(byte);
    }

    let original = make_block_with_qs(1.0, 0.5, 7, 3, payload);
    let encoded = encode(&original);
    assert!(encoded.is_rle(), "4-run block should compress");
    assert_eq!(encoded.rle_len(), 4);

    let mut reference = [0.0f32; QK_K];
    dequantize_block_q4k(&original, &mut reference);

    let mut decoded = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut decoded);

    assert_slices_close(&decoded, &reference, 1e-5);
}
|
||||||
|
|
||||||
|
/// An all-zero payload must round-trip to the same values as the plain block.
#[test]
fn roundtrip_zero_qs_matches_original() {
    let original = make_block(1.0, 0.5, 2, 1, 0x00);
    let encoded = encode(&original);

    let mut reference = [0.0f32; QK_K];
    dequantize_block_q4k(&original, &mut reference);

    let mut decoded = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut decoded);

    assert_slices_close(&decoded, &reference, 1e-5);
}
|
||||||
|
|
||||||
|
/// The low and high nibble of a payload byte must land in the correct halves
/// of the dequantised output.
#[test]
fn roundtrip_nibble_split_low_high_correct() {
    // 0x37 packs 7 in the low nibble and 3 in the high nibble. With d = 1.0,
    // scale = 1 and no min contribution, the expected weights are the raw
    // nibble values themselves.
    let original = make_block(1.0, 0.0, 1, 0, 0x37);
    let encoded = encode(&original);

    let mut reference = [0.0f32; QK_K];
    dequantize_block_q4k(&original, &mut reference);

    let mut decoded = [0.0f32; QK_K];
    dequantize_block_q4k_rle(&encoded, &mut decoded);

    assert_slices_close(&decoded, &reference, 1e-5);

    // Spot-check the first 64-element group: its first 32 outputs come from
    // the low nibble (7) ...
    assert_close(decoded[0], 7.0, 1e-5);
    // ... and its next 32 outputs come from the high nibble (3).
    assert_close(decoded[32], 3.0, 1e-5);
}
|
||||||
|
|
||||||
|
// =========================================================================
// matmul_q4k_rle_fp16: C = A × B with RLE-optional Q4_K blocks in A and fp16 B
// =========================================================================
|
||||||
|
|
||||||
|
/// Smallest possible product: one block in A times a single fp16 column.
#[test]
fn matmul_rle_1x256_times_256x1_all_ones() {
    // A is 1×256: every nibble is 1 with scale = 1 and d = 1.0, so every
    // weight is 1.0. B is 256×1 filled with fp16 1.0. The single output is
    // therefore dot([1.0; 256], [1.0; 256]) = 256.0.
    let a = vec![encode(&make_block(1.0, 0.0, 1, 0, 0x11))];
    let b = fp16_uniform(QK_K, 1, 1.0);

    let product = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 1);

    assert_eq!(product.len(), 1);
    assert_close(product[0], 256.0, 1e-3);
}
|
||||||
|
|
||||||
|
/// Two identical all-ones rows times three all-ones columns: every one of the
/// six outputs is the same 256-element dot product.
#[test]
fn matmul_rle_2x256_times_256x3_all_ones() {
    let template = make_block(1.0, 0.0, 1, 0, 0x11);

    // One encoded block per row of A.
    let a = (0..2).map(|_| encode(&template)).collect::<Vec<_>>();
    let b = fp16_uniform(QK_K, 3, 1.0);

    let product = matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 3);

    assert_eq!(product.len(), 6);
    assert_all_close(&product, 256.0, 1e-3);
}
|
||||||
|
|
||||||
|
/// When A dequantises to all zeros, the product must be exactly zero.
#[test]
fn matmul_rle_zero_a_gives_zero_c() {
    // d = 0.0 and dmin = 0.0 cancel every weight even though all nibbles
    // are at their maximum (0xFF).
    let a = vec![encode(&make_block(0.0, 0.0, 1, 0, 0xFF))];
    let b = fp16_uniform(QK_K, 4, 1.0);

    let product = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 4);

    assert_all_close(&product, 0.0, 0.0);
}
|
||||||
|
|
||||||
|
/// When B is all zeros, the product must be exactly zero whatever A holds.
#[test]
fn matmul_rle_zero_b_gives_zero_c() {
    let a = vec![encode(&make_block(1.0, 0.0, 1, 0, 0x11))];
    // Two columns of fp16 0.0.
    let b = fp16_uniform(QK_K, 2, 0.0);

    let product = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 2);

    assert_all_close(&product, 0.0, 0.0);
}
|
||||||
|
|
||||||
|
/// A single row that spans two super-blocks must accumulate across both.
#[test]
fn matmul_rle_two_blocks_per_row() {
    // A is 1×512 (two blocks of nibble-1 weights), B is 512×1 of fp16 1.0,
    // so the single output is expected to be 512.0.
    let template = make_block(1.0, 0.0, 1, 0, 0x11);
    let a = (0..2).map(|_| encode(&template)).collect::<Vec<_>>();
    let b = fp16_uniform(2 * QK_K, 1, 1.0);

    let product = matmul_q4k_rle_fp16(&a, &b, 1, 2 * QK_K, 1);

    assert_eq!(product.len(), 1);
    assert_close(product[0], 512.0, 1e-3);
}
|
||||||
|
|
||||||
|
/// The result must hold exactly m × n elements.
#[test]
fn matmul_rle_output_shape_m_times_n() {
    // A is 3×512, i.e. 3 rows × 2 blocks per row = 6 blocks.
    // B is 512×4, so C is 3×4 = 12 elements.
    let template = make_block(1.0, 0.0, 1, 0, 0x00);
    let mut a: Vec<BlockQ4KRle> = Vec::with_capacity(6);
    for _ in 0..6 {
        a.push(encode(&template));
    }
    let b = fp16_uniform(2 * QK_K, 4, 0.0);

    let product = matmul_q4k_rle_fp16(&a, &b, 3, 2 * QK_K, 4);

    assert_eq!(product.len(), 12);
}
|
||||||
|
|
||||||
|
/// Scaling every element of B by a constant must scale C by that constant.
#[test]
fn matmul_rle_scalar_b_scales_output() {
    // Nibble 2 with d = 1.0 and scale = 1 → every weight in A is 2.0.
    let a = vec![encode(&make_block(1.0, 0.0, 1, 0, 0x22))];

    // Same A against a unit column and against a column three times larger.
    let unit_column = fp16_uniform(QK_K, 1, 1.0);
    let tripled_column = fp16_uniform(QK_K, 1, 3.0);

    let baseline = matmul_q4k_rle_fp16(&a, &unit_column, 1, QK_K, 1);
    let tripled = matmul_q4k_rle_fp16(&a, &tripled_column, 1, QK_K, 1);

    assert_close(tripled[0], baseline[0] * 3.0, 1e-2);
}
|
||||||
|
|
||||||
|
/// The RLE-aware matmul must agree with the plain Q4_K matmul on a row that
/// mixes one RLE-mode block and one raw-mode block.
#[test]
fn matmul_rle_matches_original_matmul_mixed_blocks() {
    // First block: uniform payload, expected to be stored in RLE mode.
    let uniform_block = make_block(2.0, 0.5, 3, 1, 0x37);

    // Second block: strictly alternating payload, expected to stay raw.
    let mut alternating = [0u8; QK_K / 2];
    for (byte, &pattern) in alternating.iter_mut().zip([0x13u8, 0x24].iter().cycle()) {
        *byte = pattern;
    }
    let alternating_block = make_block_with_qs(1.5, 0.25, 2, 1, alternating);

    // Encode each source block before handing ownership to the plain vector.
    let a_rle: Vec<BlockQ4KRle> = vec![encode(&uniform_block), encode(&alternating_block)];
    let a_orig: Vec<BlockQ4K> = vec![uniform_block, alternating_block];

    // A is 1×512, B is 512×2.
    let b = fp16_uniform(2 * QK_K, 2, 1.0);
    let c_rle = matmul_q4k_rle_fp16(&a_rle, &b, 1, 2 * QK_K, 2);
    let c_orig = matmul_q4k_fp16(&a_orig, &b, 1, 2 * QK_K, 2);

    assert_slices_close(&c_rle, &c_orig, 1e-4);
}
|
||||||
|
|
||||||
|
/// Several rows, each spanning several super-blocks, against several columns.
#[test]
fn matmul_rle_multiple_rows_multiple_blocks_per_row() {
    // A is 2×512 (2 rows × 2 blocks = 4 blocks) of weight 1.0; B is 512×3 of
    // fp16 1.0. Every row/column dot product is therefore 512.0.
    let template = make_block(1.0, 0.0, 1, 0, 0x11);
    let mut a: Vec<BlockQ4KRle> = Vec::with_capacity(4);
    for _ in 0..4 {
        a.push(encode(&template));
    }
    let b = fp16_uniform(2 * QK_K, 3, 1.0);

    let product = matmul_q4k_rle_fp16(&a, &b, 2, 2 * QK_K, 3);

    assert_eq!(product.len(), 6);
    assert_all_close(&product, 512.0, 1e-3);
}
|
||||||
|
|
||||||
|
// =========================================================================
// Panic / contract checks: inconsistent dimensions must panic
// =========================================================================
|
||||||
|
|
||||||
|
/// A `k` that is not a whole number of super-blocks must be rejected.
///
/// The previous version of this test passed `k = 512`, which *is* a multiple
/// of `QK_K`; any panic it observed came from the A/B length mismatch, so the
/// divisibility contract was never exercised.
#[test]
fn matmul_rle_panics_when_k_not_multiple_of_qkk() {
    let src = make_block(1.0, 0.0, 1, 0, 0x00);
    let a = vec![encode(&src)];

    // One element past a full super-block: never divisible by QK_K.
    let k = QK_K + 1;
    let n = 2;
    // B is sized exactly k × n so that a B-length mismatch cannot be the
    // reason for the panic.
    let b = vec![0u16; k * n];

    let result = std::panic::catch_unwind(move || {
        matmul_q4k_rle_fp16(&a, &b, 1, k, n);
    });
    assert!(result.is_err(), "should panic when k is not a multiple of QK_K");
}
|
||||||
|
|
||||||
|
/// Supplying fewer blocks in A than the dimensions require must panic.
#[test]
fn matmul_rle_panics_on_wrong_a_length() {
    // With m = 2 and k = QK_K the call needs two blocks; only one is given.
    let a = vec![encode(&make_block(1.0, 0.0, 1, 0, 0x00))];
    let b = fp16_uniform(QK_K, 1, 1.0);

    let outcome = std::panic::catch_unwind(move || {
        let _ = matmul_q4k_rle_fp16(&a, &b, 2, QK_K, 1);
    });

    assert!(outcome.is_err(), "should panic on wrong A block count");
}
|
||||||
|
|
||||||
|
/// Supplying fewer elements in B than the dimensions require must panic.
#[test]
fn matmul_rle_panics_on_wrong_b_length() {
    let a = vec![encode(&make_block(1.0, 0.0, 1, 0, 0x00))];
    // Only 10 elements, far too few for k = QK_K and n = 3.
    let b = vec![0u16; 10];

    let outcome = std::panic::catch_unwind(move || {
        let _ = matmul_q4k_rle_fp16(&a, &b, 1, QK_K, 3);
    });

    assert!(outcome.is_err(), "should panic on wrong B element count");
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user