Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions generate_expert_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
"""
generate_expert_index.py — Generate expert_index.json from Qwen3.5-397B-A17B safetensors.

Scans safetensors file headers to build a mapping of (layer, component) -> (file, offset, stride).
This index is required by repack_experts.py.

Usage:
python generate_expert_index.py [--model PATH] [--output expert_index.json]
"""

import argparse
import json
import os
import re
import struct
import sys
from collections import defaultdict
from pathlib import Path


# Expected component sizes per expert (bytes)
# Expected per-expert component sizes in bytes. Used only as a sanity
# check against the sizes computed from the safetensors headers below.
# The shape comments describe the presumed 4-bit packed layout
# (weights packed into uint32, scales/biases as uint16) — TODO confirm
# against the model config.
COMPONENT_SIZES = {
    "gate_proj.weight": 2097152,  # [1024, 512] uint32
    "gate_proj.scales": 131072,   # [1024, 64] uint16
    "gate_proj.biases": 131072,   # [1024, 64] uint16
    "up_proj.weight": 2097152,    # [1024, 512] uint32
    "up_proj.scales": 131072,     # [1024, 64] uint16
    "up_proj.biases": 131072,     # [1024, 64] uint16
    "down_proj.weight": 2097152,  # [4096, 128] uint32
    "down_proj.scales": 131072,   # [4096, 16] uint16
    "down_proj.biases": 131072,   # [4096, 16] uint16
}

# Model topology: experts per MoE layer, and transformer layer count.
NUM_EXPERTS = 512
NUM_LAYERS = 60

# Tensor-name pattern: language_model.model.layers.{L}.mlp.switch_mlp.{component}
# Group 1 = layer index, group 2 = component (e.g. "gate_proj.weight").
EXPERT_PATTERN = re.compile(
    r'^language_model\.model\.layers\.(\d+)\.mlp\.switch_mlp\.((?:gate|up|down)_proj\.(?:weight|scales|biases))$'
)


def parse_safetensors_header(filepath):
    """Read the JSON header of a safetensors file.

    A safetensors file begins with a little-endian uint64 header length,
    followed by that many bytes of UTF-8 JSON, followed by the raw tensor
    data section.

    Args:
        filepath: Path to the safetensors file.

    Returns:
        A ``(header_dict, data_start_offset)`` tuple: the parsed header
        mapping and the absolute byte offset where tensor data begins.
    """
    with open(filepath, 'rb') as fh:
        (header_len,) = struct.unpack('<Q', fh.read(8))
        meta = json.loads(fh.read(header_len))
    return meta, 8 + header_len


def main():
    """Build expert_index.json mapping (layer, component) -> file/offset/stride.

    Reads the HF weight index to locate every switch-MLP expert tensor,
    parses each containing safetensors shard header once, and records the
    absolute byte offset of expert 0 plus the per-expert stride for every
    (layer, component) pair. Exits non-zero if the index is incomplete.
    """
    parser = argparse.ArgumentParser(description='Generate expert_index.json from safetensors')
    parser.add_argument('--model', type=str, required=True,
                        help='Path to model directory (containing safetensors files)')
    parser.add_argument('--output', type=str, default='expert_index.json',
                        help='Output path for expert_index.json')
    args = parser.parse_args()

    model_path = Path(args.model)

    # Load the HF weight index (tensor name -> containing shard filename).
    index_path = model_path / 'model.safetensors.index.json'
    if not index_path.exists():
        print(f"ERROR: {index_path} not found", file=sys.stderr)
        sys.exit(1)

    with open(index_path) as f:
        idx = json.load(f)

    weight_map = idx['weight_map']

    # Find all expert tensors, keyed by (layer, component).
    expert_tensors = {}  # (layer_idx, component) -> (tensor_name, filename)
    for name, filename in weight_map.items():
        m = EXPERT_PATTERN.match(name)
        if m:
            layer_idx = int(m.group(1))
            component = m.group(2)
            expert_tensors[(layer_idx, component)] = (name, filename)

    print(f"Model: {model_path}")
    print(f"Found {len(expert_tensors)} expert tensors")
    print(f"Expected: {NUM_LAYERS * len(COMPONENT_SIZES)} = {NUM_LAYERS} layers x {len(COMPONENT_SIZES)} components")

    if len(expert_tensors) != NUM_LAYERS * len(COMPONENT_SIZES):
        print("WARNING: tensor count mismatch", file=sys.stderr)

    # Parse safetensors headers once per shard file.
    needed_files = {fn for _, fn in expert_tensors.values()}
    print(f"\nParsing {len(needed_files)} safetensors file headers...")

    header_cache = {}
    for filename in sorted(needed_files):
        filepath = model_path / filename
        header_cache[filename] = parse_safetensors_header(str(filepath))
        # BUGFIX: report the actual shard name (was a literal "(unknown)"
        # placeholder).
        print(f"  {filename}: header parsed")

    # Build expert_reads index.
    expert_reads = defaultdict(dict)

    for (layer_idx, component), (tensor_name, filename) in sorted(expert_tensors.items()):
        header, data_start = header_cache[filename]

        if tensor_name not in header:
            # Note: names produced by EXPERT_PATTERN can never be the
            # '__metadata__' key, so no special-casing is needed here.
            # BUGFIX: name the shard in the warning (was "(unknown)").
            print(f"WARNING: {tensor_name} not in {filename} header", file=sys.stderr)
            continue

        meta = header[tensor_name]
        tensor_offset = meta['data_offsets'][0]
        tensor_size = meta['data_offsets'][1] - meta['data_offsets'][0]

        # The stacked tensor holds all NUM_EXPERTS experts contiguously,
        # so each expert's slice is total_size / num_experts.
        expert_size = tensor_size // NUM_EXPERTS

        expected_size = COMPONENT_SIZES.get(component)
        if expected_size and expert_size != expected_size:
            print(f"WARNING: {tensor_name}: computed expert_size={expert_size}, "
                  f"expected={expected_size}", file=sys.stderr)

        # abs_offset = data section start + tensor's offset within data section.
        abs_offset = data_start + tensor_offset

        expert_reads[str(layer_idx)][component] = {
            "file": filename,
            "abs_offset": abs_offset,
            # Experts are packed back to back, so stride == size.
            "expert_stride": expert_size,
            "expert_size": expert_size,
        }

    # Verify completeness: every layer must have every component.
    complete = True
    for layer_idx in range(NUM_LAYERS):
        layer_key = str(layer_idx)
        if layer_key not in expert_reads:
            print(f"ERROR: layer {layer_idx} missing entirely", file=sys.stderr)
            complete = False
            continue
        for comp in COMPONENT_SIZES:
            if comp not in expert_reads[layer_key]:
                print(f"ERROR: layer {layer_idx} missing {comp}", file=sys.stderr)
                complete = False

    if not complete:
        print("\nERROR: Index is incomplete", file=sys.stderr)
        sys.exit(1)

    # Write output.
    output = {
        "model_path": str(model_path),
        "expert_reads": dict(expert_reads),
    }

    with open(args.output, 'w') as f:
        json.dump(output, f, indent=2)

    print(f"\nWrote {args.output}")
    print(f"  {len(expert_reads)} layers, {len(COMPONENT_SIZES)} components each")
    print(f"  Total: {len(expert_reads) * len(COMPONENT_SIZES)} entries")


# Script entry point.
if __name__ == '__main__':
    main()
9 changes: 6 additions & 3 deletions metal_infer/export_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@
For each entry: uint32 token_id, uint16 str_len, char[str_len]
"""
import json
import os
import struct
import sys

def main():
import glob
default_tok = glob.glob(os.path.expanduser(
'~/.cache/huggingface/hub/models--mlx-community--Qwen3.5-397B-A17B-4bit'
'/snapshots/*/tokenizer.json'))
tok_path = sys.argv[1] if len(sys.argv) > 1 else (
'/Users/danielwoods/.cache/huggingface/hub/'
'models--mlx-community--Qwen3.5-397B-A17B-4bit/'
'snapshots/39159bd8aa74f5c8446d2b2dc584f62bb51cb0d3/tokenizer.json'
default_tok[0] if default_tok else 'tokenizer.json'
)
out_path = sys.argv[2] if len(sys.argv) > 2 else 'tokenizer.bin'

Expand Down
72 changes: 72 additions & 0 deletions metal_infer/export_vocab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
"""Export vocab.bin in the simple format expected by infer.m's load_vocab().

Format:
[num_entries: uint32] [max_id: uint32]
For each entry (0..max_id): [byte_len: uint16] [utf8_bytes: byte_len]

Usage:
python export_vocab.py <tokenizer.json> [output.bin]
"""
import json
import os
import struct
import sys

def main():
    """Export vocab.bin from a HuggingFace tokenizer.json.

    Output format (little-endian):
        [num_entries: uint32] [max_id: uint32]
        then for each id in 0..max_id: [byte_len: uint16] [utf8_bytes]
    Ids with no token are written as zero-length entries.
    """
    # ROBUSTNESS: fail with a usage message instead of a bare IndexError.
    if len(sys.argv) < 2:
        print("usage: export_vocab.py <tokenizer.json> [output.bin]", file=sys.stderr)
        sys.exit(1)
    tok_path = sys.argv[1]
    out_path = sys.argv[2] if len(sys.argv) > 2 else 'vocab.bin'

    with open(tok_path, 'r', encoding='utf-8') as f:
        t = json.load(f)

    vocab = t['model']['vocab']        # token string -> id
    added = t.get('added_tokens', [])  # list of {id, content, ...}

    # Merge added tokens into vocab (their ids may exceed the base range).
    for tok in added:
        vocab[tok['content']] = tok['id']

    max_id = max(vocab.values())
    num_entries = max_id + 1

    # GPT-2-style byte-level BPE represents raw bytes as printable Unicode
    # chars: Ġ (U+0120) = space, Ċ (U+010A) = newline, etc. Build the
    # reverse mapping to decode those chars back to real bytes.
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    printable = set(bs)  # PERF: O(1) membership instead of O(n) list scans
    n = 0
    for b in range(256):
        if b not in printable:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    bpe_decode_map = {chr(c): bytes([b]) for b, c in zip(bs, cs)}

    def decode_bpe_token(s):
        """Convert a BPE token string to actual bytes (best-effort)."""
        try:
            return b''.join(bpe_decode_map.get(ch, ch.encode('utf-8')) for ch in s)
        except Exception:
            return s.encode('utf-8')

    # Build id -> bytes mapping with BPE decoding.
    id_to_str = {tid: decode_bpe_token(s) for s, tid in vocab.items()}

    with open(out_path, 'wb') as f:
        f.write(struct.pack('<I', num_entries))
        f.write(struct.pack('<I', max_id))

        for i in range(num_entries):
            b = id_to_str.get(i, b'')
            f.write(struct.pack('<H', len(b)))
            if b:
                f.write(b)

    sz = os.path.getsize(out_path)
    print(f"Exported vocab.bin: {num_entries} entries (max_id={max_id}), {sz / 1024 / 1024:.1f} MB")

# Script entry point.
if __name__ == '__main__':
    main()
45 changes: 34 additions & 11 deletions metal_infer/infer.m
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include <unistd.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <dirent.h>
#include <sys/time.h>
#include <math.h>
#include <getopt.h>
Expand Down Expand Up @@ -123,7 +124,8 @@
#define THINK_START_TOKEN 248068 // <think>
#define THINK_END_TOKEN 248069 // </think>

#define MODEL_PATH_DEFAULT "/Users/danielwoods/.cache/huggingface/hub/models--mlx-community--Qwen3.5-397B-A17B-4bit/snapshots/39159bd8aa74f5c8446d2b2dc584f62bb51cb0d3"
// MODEL_PATH_DEFAULT is resolved at runtime via get_default_model_path() below
#define MODEL_PATH_DEFAULT NULL
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove this, since it is not used anywhere.


// ============================================================================
// Timing helper
Expand Down Expand Up @@ -6283,9 +6285,35 @@ static void print_usage(const char *prog) {
printf(" --help This message\n");
}

// Resolve the default model path at runtime using $HOME.
//
// Looks under the HuggingFace hub cache for the Qwen3.5-397B-A17B-4bit
// snapshots directory and returns the first non-hidden entry found.
// NOTE(review): readdir() order is filesystem-dependent, so with multiple
// snapshots the choice is arbitrary; the entry is also not verified to be
// a directory — confirm this is acceptable. Returns a pointer to static
// storage (not thread-safe, valid for process lifetime).
static const char *get_default_model_path(void) {
    static char path[1024];
    const char *home = getenv("HOME");
    if (!home) home = "/tmp";  // degenerate fallback when HOME is unset
    // Find the first snapshot directory
    char snapshots_dir[1024];
    snprintf(snapshots_dir, sizeof(snapshots_dir),
             "%s/.cache/huggingface/hub/models--mlx-community--Qwen3.5-397B-A17B-4bit/snapshots", home);
    DIR *d = opendir(snapshots_dir);
    if (d) {
        struct dirent *entry;
        while ((entry = readdir(d)) != NULL) {
            if (entry->d_name[0] != '.') {  // skip ".", "..", and hidden entries
                snprintf(path, sizeof(path), "%s/%s", snapshots_dir, entry->d_name);
                closedir(d);
                return path;
            }
        }
        closedir(d);
    }
    // No snapshot found: fall back to a known snapshot hash under $HOME.
    snprintf(path, sizeof(path),
             "%s/.cache/huggingface/hub/models--mlx-community--Qwen3.5-397B-A17B-4bit/snapshots/39159bd8aa74f5c8446d2b2dc584f62bb51cb0d3", home);
    return path;
}

int main(int argc, char **argv) {
@autoreleasepool {
const char *model_path = MODEL_PATH_DEFAULT;
const char *model_path = get_default_model_path();
const char *weights_path = NULL;
const char *manifest_path = NULL;
const char *vocab_path = NULL;
Expand Down Expand Up @@ -6517,15 +6545,10 @@ int main(int argc, char **argv) {
fcntl(layer_fds[i], F_RDAHEAD, 1);
struct stat st;
if (fstat(layer_fds[i], &st) == 0 && st.st_size > 0) {
layer_mmaps[i] = mmap(NULL, st.st_size, PROT_READ, MAP_PRIVATE, layer_fds[i], 0);
if (layer_mmaps[i] != MAP_FAILED) {
layer_mmap_sizes[i] = st.st_size;
// No madvise: kernel default is best.
// MADV_RANDOM disables readahead (tested: hurts).
// MADV_SEQUENTIAL doesn't reduce I/O fragmentation (tested: no effect).
// The kernel fragments 3.9MB preads into ~5.7 disk ops regardless
// of hints — this is inherent to the page cache's physical page layout.
}
// Skip mmap for expert files — 120GB of mmap reservations
// can trigger OOM kills on systems with memory pressure.
// The engine falls back to pread() which works fine.
layer_mmap_sizes[i] = st.st_size;
}
}
}
Expand Down
30 changes: 27 additions & 3 deletions metal_infer/main.m
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include <getopt.h>
#include <pthread.h>
#include <errno.h>
#include <dirent.h>

// ============================================================================
// Constants matching the Qwen3.5-397B packed expert layout
Expand Down Expand Up @@ -77,8 +78,8 @@

#define EXPERT_SIZE 7077888 // Total bytes per expert

// Default model path
#define MODEL_PATH "/Users/danielwoods/.cache/huggingface/hub/models--mlx-community--Qwen3.5-397B-A17B-4bit/snapshots/39159bd8aa74f5c8446d2b2dc584f62bb51cb0d3"
// Default model path — resolved at runtime, can be overridden with --model
#define MODEL_PATH NULL

// ============================================================================
// Timing helper
Expand Down Expand Up @@ -1504,7 +1505,30 @@ int main(int argc, char **argv) {
int num_active_experts = 4; // --k flag
int do_verify = 0;
int use_fast = 0;
const char *model_path = MODEL_PATH;
// Resolve default model path at runtime using $HOME
static char default_model_path[1024];
const char *home = getenv("HOME");
if (!home) home = "/tmp";
{
char snapshots_dir[1024];
snprintf(snapshots_dir, sizeof(snapshots_dir),
"%s/.cache/huggingface/hub/models--mlx-community--Qwen3.5-397B-A17B-4bit/snapshots", home);
DIR *d = opendir(snapshots_dir);
if (d) {
struct dirent *entry;
while ((entry = readdir(d)) != NULL) {
if (entry->d_name[0] != '.') {
snprintf(default_model_path, sizeof(default_model_path), "%s/%s", snapshots_dir, entry->d_name);
break;
}
}
closedir(d);
} else {
snprintf(default_model_path, sizeof(default_model_path),
"%s/.cache/huggingface/hub/models--mlx-community--Qwen3.5-397B-A17B-4bit/snapshots/39159bd8aa74f5c8446d2b2dc584f62bb51cb0d3", home);
}
}
const char *model_path = default_model_path;

static struct option long_options[] = {
{"layer", required_argument, 0, 'l'},
Expand Down
Loading