148 changes: 95 additions & 53 deletions picolm/model.c
@@ -42,54 +42,69 @@ typedef struct {
size_t size;
} reader_t;

+static void gguf_ensure(reader_t *r, size_t n) {
+    if (r->pos + n > r->size) {
+        fprintf(stderr, "GGUF read out of bounds\n");
+        exit(1);
+    }
+}
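/* [Reviewer note, not part of this diff] If a corrupt file yields a huge 'n',
 * 'r->pos + n' can wrap around and slip past this check. An overflow-safe
 * variant, relying on the invariant pos <= size:
 *
 *     if (n > r->size - r->pos) { ... }
 */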

static uint8_t read_u8(reader_t *r) {
+    gguf_ensure(r, 1);
uint8_t v = r->data[r->pos];
r->pos += 1;
return v;
}

static uint16_t read_u16(reader_t *r) {
+    gguf_ensure(r, 2);
uint16_t v;
memcpy(&v, r->data + r->pos, 2);
r->pos += 2;
return v;
}

static uint32_t read_u32(reader_t *r) {
+    gguf_ensure(r, 4);
uint32_t v;
memcpy(&v, r->data + r->pos, 4);
r->pos += 4;
return v;
}

static int32_t read_i32(reader_t *r) {
+    gguf_ensure(r, 4);
int32_t v;
memcpy(&v, r->data + r->pos, 4);
r->pos += 4;
return v;
}

static uint64_t read_u64(reader_t *r) {
+    gguf_ensure(r, 8);
uint64_t v;
memcpy(&v, r->data + r->pos, 8);
r->pos += 8;
return v;
}

static float read_f32(reader_t *r) {
+    gguf_ensure(r, 4);
float v;
memcpy(&v, r->data + r->pos, 4);
r->pos += 4;
return v;
}

+/* Important: GGUF strings are views into the mmap buffer (no malloc, no leaks). */
typedef struct { const char *str; uint64_t len; } gguf_str_t;

static gguf_str_t read_gguf_string(reader_t *r) {
gguf_str_t s;
s.len = read_u64(r);
+    gguf_ensure(r, (size_t)s.len);
s.str = (const char *)(r->data + r->pos);
-    r->pos += s.len;
+    r->pos += (size_t)s.len;
return s;
}

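/* [Reviewer sketch, not part of this diff] These views are not NUL-terminated,
 * so key comparisons must be length-aware (strcmp would over-read). A minimal
 * helper, with the hypothetical name gguf_str_eq (needs <string.h>):
 */
static int gguf_str_eq(gguf_str_t s, const char *lit) {
    size_t n = strlen(lit);
    return (size_t)s.len == n && memcmp(s.str, lit, n) == 0;
}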
@@ -110,9 +125,9 @@ static uint64_t skip_meta_value(reader_t *r, uint32_t vtype, int *is_numeric) {
case GGUF_META_UINT64: return read_u64(r);
case GGUF_META_INT64: return read_u64(r);
case GGUF_META_FLOAT32: { read_f32(r); *is_numeric = 0; return 0; }
-    case GGUF_META_FLOAT64: { r->pos += 8; *is_numeric = 0; return 0; }
+    case GGUF_META_FLOAT64: { gguf_ensure(r, 8); r->pos += 8; *is_numeric = 0; return 0; }
case GGUF_META_BOOL: return read_u8(r);
-    case GGUF_META_STRING: { read_gguf_string(r); *is_numeric = 0; return 0; }
+    case GGUF_META_STRING: { (void)read_gguf_string(r); *is_numeric = 0; return 0; }
case GGUF_META_ARRAY: {
*is_numeric = 0;
uint32_t arr_type = read_u32(r);
@@ -141,7 +156,11 @@ static int mmap_file(model_t *m, const char *path) {
}

LARGE_INTEGER fsize;
-    GetFileSizeEx(fh, &fsize);
+    if (!GetFileSizeEx(fh, &fsize)) {
+        fprintf(stderr, "GetFileSizeEx failed\n");
+        CloseHandle(fh);
+        return -1;
+    }
m->mmap_size = (size_t)fsize.QuadPart;

HANDLE mh = CreateFileMappingA(fh, NULL, PAGE_READONLY, 0, 0, NULL);
@@ -165,21 +184,25 @@ static int mmap_file(model_t *m, const char *path) {
#else
int fd = open(path, O_RDONLY);
if (fd < 0) {
fprintf(stderr, "Cannot open file: %s\n", path);
perror("open");
return -1;
}

struct stat st;
-    fstat(fd, &st);
+    if (fstat(fd, &st) != 0) {
+        perror("fstat");
+        close(fd);
+        return -1;
+    }
m->mmap_size = (size_t)st.st_size;

void *addr = mmap(NULL, m->mmap_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (addr == MAP_FAILED) {
fprintf(stderr, "mmap failed\n");
perror("mmap");
close(fd);
return -1;
}
-    madvise(addr, m->mmap_size, MADV_SEQUENTIAL);
+    (void)madvise(addr, m->mmap_size, MADV_SEQUENTIAL);

m->mmap_addr = addr;
m->fd = fd;
@@ -202,10 +225,20 @@ static void munmap_file(model_t *m) {

/* ---- GGUF Parser ---- */

+/* Align 'x' up to the next multiple of 'a'. Works for any a >= 1. */
+static size_t align_up_any(size_t x, size_t a) {
+    if (a == 0) return x; /* defensive */
+    size_t rem = x % a;
+    return rem ? (x + (a - rem)) : x;
+}
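/* [Reviewer note] Worked example of why the modulo form replaces the bitmask
 * below (the bitmask is only valid for power-of-two 'a'): for x = 100, a = 24,
 *   (100 + 23) & ~(size_t)23 == 104   (not a multiple of 24)
 *   align_up_any(100, 24)    == 120   (correct)
 */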

static int parse_gguf(model_t *m, int max_seq_len) {
reader_t r = { .data = (const uint8_t *)m->mmap_addr, .pos = 0, .size = m->mmap_size };
model_config_t *cfg = &m->config;

+    /* Start from a known state in case some GGUF metadata fields are missing. */
+    memset(cfg, 0, sizeof(*cfg));

uint32_t magic = read_u32(&r);
if (magic != GGUF_MAGIC) {
fprintf(stderr, "Invalid GGUF magic: 0x%08X\n", magic);
@@ -277,19 +310,46 @@ static int parse_gguf(model_t *m, int max_seq_len) {
} else {
uint32_t arr_type = read_u32(&r);
uint64_t arr_len = read_u64(&r);
-            (void)arr_type;
-            m->tok_scores_data = r.data + r.pos;
-            m->tok_n_scores = arr_len;
-            r.pos += arr_len * 4;
+            /* scores should be float32; if not, skip safely */
+            if (arr_type != GGUF_META_FLOAT32) {
+                int dummy;
+                for (uint64_t j = 0; j < arr_len; j++) {
+                    skip_meta_value(&r, arr_type, &dummy);
+                }
+            } else {
+                gguf_ensure(&r, (size_t)arr_len * 4);
+                m->tok_scores_data = r.data + r.pos;
+                m->tok_n_scores = arr_len;
+                r.pos += (size_t)arr_len * 4;
+            }
}
} else {
int dummy; skip_meta_value(&r, vtype, &dummy);
}
}

+    /* ---- Validate and default required config ---- */
+    if (cfg->n_embd <= 0 || cfg->n_heads <= 0 || cfg->n_layers <= 0) {
+        fprintf(stderr, "Invalid model config in GGUF metadata (n_embd=%d n_heads=%d n_layers=%d)\n",
+                cfg->n_embd, cfg->n_heads, cfg->n_layers);
+        return -1;
+    }
+    if (cfg->n_kv_heads <= 0) {
+        /* Some GGUFs omit head_count_kv; default to MHA. */
+        cfg->n_kv_heads = cfg->n_heads;
+    }
+    if (cfg->max_seq_len <= 0) cfg->max_seq_len = 2048;
+    if (cfg->alignment <= 0) cfg->alignment = 32;
+
if (max_seq_len > 0 && max_seq_len < cfg->max_seq_len) {
cfg->max_seq_len = max_seq_len;
}
+    if (cfg->n_embd % cfg->n_heads != 0) {
+        fprintf(stderr, "Invalid head configuration: n_embd=%d not divisible by n_heads=%d\n",
+                cfg->n_embd, cfg->n_heads);
+        return -1;
+    }
cfg->head_dim = cfg->n_embd / cfg->n_heads;
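/* [Reviewer note] Worked example with illustrative values: a Llama-7B-style
 * GGUF has n_embd=4096, n_heads=32 -> head_dim=128; with n_kv_heads=8 the
 * model is GQA and each KV head serves n_heads / n_kv_heads = 4 query heads. */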

/* Parse tensor info entries */
Expand All @@ -315,7 +375,7 @@ static int parse_gguf(model_t *m, int max_seq_len) {
}

size_t alignment = (size_t)cfg->alignment;
-    size_t tensor_data_base = (r.pos + alignment - 1) & ~(alignment - 1);
+    size_t tensor_data_base = align_up_any(r.pos, alignment);

model_weights_t *w = &m->weights;
memset(w, 0, sizeof(*w));
@@ -342,13 +402,11 @@ static int parse_gguf(model_t *m, int max_seq_len) {
layer = layer * 10 + (*p - '0');
p++;
}
-        if (p < end && *p == '.') {
-            p++;
-            size_t slen = (size_t)(end - p);
-            if (slen < sizeof(suffix)) {
-                memcpy(suffix, p, slen);
-                suffix[slen] = '\0';
-            }
-        }
+        if (p < end && *p == '.') p++;
+        size_t slen = (size_t)(end - p);
+        if (slen < sizeof(suffix)) {
+            memcpy(suffix, p, slen);
+            suffix[slen] = '\0';
+        }
}
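/* [Reviewer note] e.g. the conventional GGUF name "blk.17.attn_q.weight"
 * parses to layer = 17, suffix = "attn_q.weight" here. */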

@@ -461,7 +519,7 @@ static int allocate_run_state(model_t *m) {
sz_scratch + sz_rope + sz_norm;

/* FP16 KV cache: separate allocation */
-    size_t kv_elements = (size_t)c->n_layers * c->max_seq_len * kv_dim;
+    size_t kv_elements = (size_t)c->n_layers * (size_t)c->max_seq_len * (size_t)kv_dim;
size_t sz_kv = kv_elements * sizeof(uint16_t) * 2; /* key + val */
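/* [Reviewer note] Sizing example with illustrative values: n_layers=32,
 * max_seq_len=2048, kv_dim=1024 -> 32*2048*1024 elements * 2 bytes * 2 (K+V)
 * = 256 MiB. */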

fprintf(stderr, "Allocating %.2f MB for runtime state (+ %.2f MB FP16 KV cache)\n",
@@ -480,6 +538,7 @@ static int allocate_run_state(model_t *m) {
if (!s->kv_block) {
fprintf(stderr, "OOM: cannot allocate %zu bytes for KV cache\n", sz_kv);
free(s->mem_block);
+        s->mem_block = NULL;
return -1;
}
s->kv_size = sz_kv;
@@ -539,8 +598,15 @@ int model_load(model_t *m, const char *path, int max_seq_len) {
memset(m, 0, sizeof(*m));

if (mmap_file(m, path) != 0) return -1;
-    if (parse_gguf(m, max_seq_len) != 0) return -1;
-    if (allocate_run_state(m) != 0) return -1;
+
+    if (parse_gguf(m, max_seq_len) != 0) {
+        munmap_file(m);
+        return -1;
+    }
+    if (allocate_run_state(m) != 0) {
+        munmap_file(m);
+        return -1;
+    }

return 0;
}
@@ -613,47 +679,25 @@ float *model_forward(model_t *m, int token, int pos) {
val_pos_fp16[d] = fp32_to_fp16(v_tmp[d]);
}

-    /* ---- Flash Attention (online softmax) ----
-     *
-     * Instead of materializing the full [n_heads * seq_len] score array,
-     * compute attention in a single pass using the online softmax trick:
-     *
-     *   max_s = -inf, sum_exp = 0, acc[d] = 0
-     *   for each cached position t:
-     *     s = dot(Q_h, K_t) / sqrt(d)
-     *     if s > max_s:
-     *       correction = exp(max_s - s)
-     *       acc *= correction, sum_exp *= correction
-     *       sum_exp += 1, acc += V_t
-     *       max_s = s
-     *     else:
-     *       w = exp(s - max_s)
-     *       sum_exp += w, acc += w * V_t
-     *   result = acc / sum_exp
-     *
-     * This saves memory (no att[] buffer) and is more cache-friendly.
-     */
+    /* ---- Flash Attention (online softmax) ---- */
for (int h = 0; h < n_heads; h++) {
float *qh = s->q + h * head_dim;
int kv_h = h / kv_mul;
float *xbh = s->xb + h * head_dim;

float max_score = -1e30f;
float sum_exp = 0.0f;
-        /* Accumulator for weighted V values */
-        float acc[256]; /* head_dim is typically 64-128 */
+        float acc[256];
memset(acc, 0, (size_t)head_dim * sizeof(float));
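/* [Reviewer note] acc[] is fixed at 256 floats, but neither this function nor
 * the new validation in parse_gguf bounds head_dim; a model with
 * head_dim > 256 would overflow this stack buffer. */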

for (int t = 0; t <= pos; t++) {
-            /* Compute score: dot(Q_h, K_t) / sqrt(head_dim) */
const uint16_t *kt = kcache_layer + (size_t)t * kv_dim + kv_h * head_dim;
float score = 0.0f;
for (int d = 0; d < head_dim; d++) {
score += qh[d] * fp16_to_fp32(kt[d]);
}
score /= sqrtf((float)head_dim);

-            /* Online softmax update */
const uint16_t *vt = vcache_layer + (size_t)t * kv_dim + kv_h * head_dim;

if (score > max_score) {
Expand All @@ -664,15 +708,14 @@ float *model_forward(model_t *m, int token, int pos) {
}
max_score = score;
} else {
-                float w = expf(score - max_score);
-                sum_exp += w;
+                float ww = expf(score - max_score);
+                sum_exp += ww;
                 for (int d = 0; d < head_dim; d++) {
-                    acc[d] += w * fp16_to_fp32(vt[d]);
+                    acc[d] += ww * fp16_to_fp32(vt[d]);
}
}
}

-        /* Normalize */
float inv_sum = 1.0f / sum_exp;
for (int d = 0; d < head_dim; d++) {
xbh[d] = acc[d] * inv_sum;
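/* [Reviewer sketch, not part of this diff] Since the long explanatory comment
 * above is being trimmed, here is the same online-softmax recurrence as a
 * standalone reference over plain float arrays (illustrative names; needs
 * <math.h>):
 */
static void online_softmax_attention(const float *q, const float *K, const float *V,
                                     float *out, int n_pos, int head_dim) {
    float max_s = -1e30f, sum_exp = 0.0f;
    for (int d = 0; d < head_dim; d++) out[d] = 0.0f;
    for (int t = 0; t < n_pos; t++) {
        float s = 0.0f;                               /* s = dot(q, K_t) / sqrt(d) */
        for (int d = 0; d < head_dim; d++) s += q[d] * K[(size_t)t * head_dim + d];
        s /= sqrtf((float)head_dim);
        if (s > max_s) {
            float corr = expf(max_s - s);             /* rescale existing mass */
            sum_exp = sum_exp * corr + 1.0f;          /* new entry has weight 1 */
            for (int d = 0; d < head_dim; d++)
                out[d] = out[d] * corr + V[(size_t)t * head_dim + d];
            max_s = s;
        } else {
            float w = expf(s - max_s);
            sum_exp += w;
            for (int d = 0; d < head_dim; d++)
                out[d] += w * V[(size_t)t * head_dim + d];
        }
    }
    for (int d = 0; d < head_dim; d++) out[d] /= sum_exp;   /* normalize */
}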
@@ -750,7 +793,6 @@ int kvcache_save(const model_t *m, const char *path, int n_pos) {
};
fwrite(header, sizeof(uint32_t), 4, f);
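/* [Reviewer note] fwrite's return value is unchecked here; a short write
 * (e.g. disk full) would silently leave a truncated cache file. */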

-    /* Write KV cache for each layer, only the first n_pos positions */
size_t row_size = (size_t)kv_dim * sizeof(uint16_t);
for (int l = 0; l < c->n_layers; l++) {
const uint16_t *kcache_l = m->state.key_cache + (size_t)l * seq_len * kv_dim;
@@ -830,4 +872,4 @@ int kvcache_load(model_t *m, const char *path) {
fclose(f);
fprintf(stderr, "KV cache loaded: %d positions from %s\n", n_pos, path);
return n_pos;
-}
+}