Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/null/test_uop_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ def test_simplex_lt(self):

def test_where_removal(self):
cond = Variable("a", 0, 3) < 2
u1, u0 = cond.ufix(1), cond.ufix(0)
u1, u0 = cond.const_like(True), cond.const_like(False)
self.helper_test_variable(cond, 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
Expand Down
39 changes: 38 additions & 1 deletion test/unit/test_gguf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os, struct, unittest, sys
import os, struct, unittest, tempfile, pathlib, sys
from tinygrad import dtypes, Tensor, fetch, Device
from tinygrad.helpers import disable_gc
from tinygrad.llm.gguf import _ggml_iq_grid, ggml_data_to_tensor, gguf_load
Expand Down Expand Up @@ -115,6 +115,43 @@ def test_expected_failure_unknown_type(self):
with self.assertRaises(ValueError):
ggml_data_to_tensor(Tensor.empty(512, dtype=dtypes.uint8), 256, 1337)

def test_multi_part_load(self):
  # Constructs a minimal, hand-packed two-part GGUF split on disk, then checks that
  # gguf_load merges the tensors from both parts and that a missing later part raises.
  def build(n_total, part_no, tensors):
    # Layout of one GGUF file: [header] [kv_data] [tensor_infos] [padding] [tensor_data_blob]
    buf = bytearray()
    # Header: magic "GGUF" + version=3 (int32) + n_tensors (int64) + n_kv=2 (int64)
    buf += struct.pack("<4siqq", b"GGUF", 3, len(tensors), 2)
    # KV entries: [key_len: uint64][key bytes][type: int32][value]
    # type id 4 is uint32, matching the value packed with "<I" below
    for k, v in [("split.count", n_total), ("split.no", part_no)]:
      kb = k.encode()
      buf += struct.pack("<Q", len(kb)) + kb + struct.pack("<i", 4) + struct.pack("<I", v)
    data_off = 0
    # Tensor infos: [name_len][name][ndims][dims reversed][qtype][offset_into_data_blob]
    for name, dims, qtype, data in tensors:
      nb = name.encode()
      buf += struct.pack("<Q", len(nb)) + nb + struct.pack("<I", len(dims))
      for d in reversed(dims): buf += struct.pack("<Q", d)
      buf += struct.pack("<i", qtype) + struct.pack("<Q", data_off)
      data_off += len(data)
    # pad tensor data blob start to a 32-byte boundary (the GGUF default alignment)
    buf += b"\x00" * ((32 - len(buf) % 32) % 32)
    for _, _, _, data in tensors: buf += data
    return bytes(buf)

  with tempfile.TemporaryDirectory() as d:
    d = pathlib.Path(d)
    # tensor "a" lives in part 1 (split.no=0), tensor "b" in part 2 (split.no=1); qtype 0 is plain f32
    a, b = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), np.array([5.0, 6.0], dtype=np.float32)
    (d / "test-00001-of-00002.gguf").write_bytes(build(2, 0, [("a", (4,), 0, a.tobytes())]))
    (d / "test-00002-of-00002.gguf").write_bytes(build(2, 1, [("b", (2,), 0, b.tobytes())]))
    # loading the first split must pull in tensors from every part
    kv, ts = gguf_load(d / "test-00001-of-00002.gguf")
    self.assertEqual(kv["split.count"], 2)
    np.testing.assert_equal(ts["a"].numpy(), a)
    np.testing.assert_equal(ts["b"].numpy(), b)

    # missing part 2
    (d / "test-00002-of-00002.gguf").unlink()
    with self.assertRaises(FileNotFoundError):
      gguf_load(d / "test-00001-of-00002.gguf")

def _test_dequantization(self, qtype: GGMLQuantizationType):
block_size, type_size = GGML_QUANT_SIZES[qtype]
n_el, n_bytes = ggml_test_block_count * block_size, ggml_test_block_count * type_size
Expand Down
12 changes: 6 additions & 6 deletions tinygrad/llm/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations
import sys, argparse, codecs, typing, re, unicodedata, json, uuid, time, pathlib
from tinygrad import Tensor, nn
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context
from tinygrad import nn
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context, fetch
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
from tinygrad.llm.model import Transformer

Expand Down Expand Up @@ -190,11 +191,10 @@ def main():
args = parser.parse_args()

# load the model
raw_model = Tensor.from_url(models.get(args.model, args.model))
model, kv = Transformer.from_gguf(raw_model, args.max_context)
model, kv = Transformer.from_gguf(fetch(models.get(args.model, args.model)), args.max_context)
model_name = kv.get('general.name') or kv.get('general.basename') or args.model
print(f"using model \"{model_name}\" with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
del raw_model
file_sizes = [y.nbytes() for y in UOp.sink(*[x.uop for x in nn.state.get_parameters(model)]).toposort() if y.op is Ops.BUFFER]
print(f"using model \"{model_name}\" with {sum(file_sizes):,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")

# get tokenizer
tok = SimpleTokenizer.from_gguf_kv(kv)
Expand Down
50 changes: 32 additions & 18 deletions tinygrad/llm/gguf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import functools, io, struct
import functools, io, pathlib, re, struct
from typing import Any, Callable

from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, round_up
from tinygrad.nn.state import TensorIO, accept_filename
from tinygrad.nn.state import TensorIO

# ggml packs each iq grid entry as N bytes (N=4 for uint32 grids, N=8 for uint64 grids) in a single word. See ggml-common.h.
@functools.lru_cache(None)
Expand Down Expand Up @@ -131,22 +131,7 @@ def read_arr(r:io.BufferedIOBase):
[ (0,"c",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]

@accept_filename
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
"""
Loads a .gguf file, returning the `kv_data` and `state_dict`.

```python
import pathlib
from tinygrad import Device, Tensor
from tinygrad.llm.gguf import gguf_load

gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
kv_data, state_dict = gguf_load(gguf_tensor)
```

NOTE: The provided tensor must be on a device that supports execution.
"""
def _gguf_parse(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
r = io.BufferedReader(TensorIO(tensor), 1_000_000)
magic, version, n_tensors, n_kv = r.read(4), read_int32(r), read_int64(r), read_int64(r)
if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
Expand All @@ -162,3 +147,32 @@ def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:

state_dict = {name: ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims)) for name, dims, typ, off in t_infos}
return kv_data, state_dict

def _gguf_split_paths(path: pathlib.Path, kv: dict) -> list[pathlib.Path]:
  """Expand the first split of a multi-part GGUF into the full list of part paths.

  For a single-file model (split.count absent or <= 1) the input path is returned as-is.
  Raises ValueError when kv indicates this file is not part 0, or when the filename does
  not follow the -00001-of-NNNNN.gguf split naming convention.
  """
  total = kv.get('split.count', 1)
  if total <= 1: return [path]
  # only the first split carries the authoritative metadata, so refuse any other part
  if kv.get('split.no', 0) != 0:
    raise ValueError(f"multi-part GGUF must be loaded from the first split, got split.no={kv['split.no']}")
  m = re.match(r"^(.*)-00001-of-\d{5}\.gguf$", str(path))
  if m is None:
    raise ValueError(f"first split path must end with -00001-of-NNNNN.gguf: {path}")
  stem = m.group(1)
  return [pathlib.Path(f"{stem}-{part:05d}-of-{total:05d}.gguf") for part in range(1, total + 1)]

def gguf_load(fn: Tensor|str|pathlib.Path) -> tuple[dict, dict[str, Tensor]]:
  """
  Loads a .gguf file, returning the `kv_data` and `state_dict`. Multi-part splits are auto-merged when loaded by path.

  ```python
  import pathlib
  from tinygrad import Device, Tensor
  from tinygrad.llm.gguf import gguf_load

  gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
  kv_data, state_dict = gguf_load(gguf_tensor)
  ```

  NOTE: The provided tensor must be on a device that supports execution.
  """
  # TODO: remove the need for copy to default device
  def _parse_one(src):
    t = src if isinstance(src, Tensor) else Tensor(src).to(None).realize()
    return _gguf_parse(t)
  kv, state_dict = _parse_one(fn)
  if kv.get('split.count', 1) > 1:
    # remaining split paths can only be derived from a filesystem path, not a raw tensor
    if isinstance(fn, Tensor): raise ValueError("multi-part GGUF requires a path argument (got Tensor)")
    for extra in _gguf_split_paths(pathlib.Path(fn), kv)[1:]:
      state_dict.update(_parse_one(extra)[1])
  return kv, state_dict
24 changes: 15 additions & 9 deletions tinygrad/llm/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
import functools, itertools
import functools, itertools, pathlib
from dataclasses import dataclass, replace
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
from tinygrad.llm.gguf import gguf_load
Expand Down Expand Up @@ -69,6 +69,8 @@ class TransformerConfig:
leading_dense_blocks: int = 0
dense_hidden_dim: int = 0
routed_scaling_factor: float = 1.0
qkv_bias: bool = False
expert_bias: bool = False

class FFNBlock:
def __init__(self, config:TransformerConfig):
Expand All @@ -81,7 +83,7 @@ def __init__(self, config:TransformerConfig):
# --- feed-forward (MoE or dense) -------------------------------------
if config.num_experts > 0:
self.ffn_gate_inp = nn.Linear(config.dim, config.num_experts, bias=False) # router
if config.kv_lora_rank > 0: self.exp_probs_b = {"bias": Tensor.zeros(config.num_experts)}
if config.expert_bias: self.exp_probs_b = {"bias": Tensor.zeros(config.num_experts)}
self.ffn_gate_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
self.ffn_up_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
self.ffn_down_exps = ExpertWeights(config.num_experts, config.hidden_dim, config.dim)
Expand Down Expand Up @@ -142,9 +144,9 @@ def __init__(self, config:TransformerConfig):
# --- attention projections (all linear, bias-free) ------------------
q_proj_out = config.head_dim * config.n_heads * (2 if config.attn_output_gate else 1)
kv_proj_out = config.head_dim * config.n_kv_heads
self.attn_q = nn.Linear(config.dim, q_proj_out, bias=False)
self.attn_k = nn.Linear(config.dim, kv_proj_out, bias=False)
self.attn_v = nn.Linear(config.dim, kv_proj_out, bias=False)
self.attn_q = nn.Linear(config.dim, q_proj_out, bias=config.qkv_bias)
self.attn_k = nn.Linear(config.dim, kv_proj_out, bias=config.qkv_bias)
self.attn_v = nn.Linear(config.dim, kv_proj_out, bias=config.qkv_bias)
self.attn_output = nn.Linear(config.head_dim * config.n_heads, config.dim, bias=False)
if config.qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(config.qk_norm, config.norm_eps), nn.RMSNorm(config.qk_norm, config.norm_eps)

Expand Down Expand Up @@ -319,9 +321,10 @@ def __call__(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tens
return (self.prefill_jit if resolve(tokens.shape[1] != 1) else self.rollout_jit)(tokens.contiguous(), start_pos, temperature)

@staticmethod
def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
def from_gguf(gguf:Tensor|str|pathlib.Path, max_context:int|None=None,
realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
# TODO: remove the need for copy to default device
kv, state_dict = gguf_load(gguf.to(None).realize())
kv, state_dict = gguf_load(gguf.to(None).realize() if isinstance(gguf, Tensor) else gguf)

# all state items should be float16, not float32
state_dict = {k:v.cast('float16') if getenv("HALF", 1) else v for k,v in state_dict.items()}
Expand All @@ -336,6 +339,7 @@ def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALI
ssm = None
if arch in ('qwen35', 'qwen35moe'):
ssm = SSMConfig(**{k: kv[f'{arch}.ssm.{k}'] for k in ('conv_kernel','state_size','group_count','time_step_rank','inner_size')})
if arch in ('qwen35', 'qwen35moe', 'glm4moe'):
state_dict = {k.replace('post_attention_norm', 'ffn_norm'):v for k,v in state_dict.items()}

kv_lora_rank = kv.get(f'{arch}.attention.kv_lora_rank', 0)
Expand All @@ -354,7 +358,7 @@ def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALI
elif kv_lora_rank and 'attn_kv_a_mqa.weight' in name:
state_dict[name] = state_dict[name][:kv_lora_rank].cat(state_dict[name][kv_lora_rank:].rearrange("(h two) d -> (two h) d", two=2), dim=0)
config = TransformerConfig(
num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
num_blocks=kv[f'{arch}.block_count'] - kv.get(f'{arch}.nextn_predict_layers', 0), dim=kv[f'{arch}.embedding_length'],
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv.get(f'{arch}.feed_forward_length', 0)),
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
vocab_size=len(kv['tokenizer.ggml.tokens']),
Expand All @@ -374,7 +378,9 @@ def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALI
shared_expert_gate=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.ffn_gate_inp_shexp.weight" in state_dict,
dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if kv.get(f'{arch}.leading_dense_block_count', 0) else 0,
routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0))
full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0),
qkv_bias='blk.0.attn_q.bias' in state_dict,
expert_bias=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.exp_probs_b.bias" in state_dict)
model = Transformer(config)
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
Expand Down
Loading