Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions extra/viz/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ By default, VIZ CLI automatically loads the latest trace files.

Use `extra/viz/cli.py --profile -s ALL` to inspect the complete timing data of kernels, JIT, codegen and scheduling.

- Add DEBUG=3 to see AST, DEGUG=4 to also see source code.
- Add DEBUG=3 to see AST, DEBUG=4 to also see source code.
- Make sure to add NO_COLOR=1 to disable colored output.
- Add `--jsonl` to see JSON output.

Expand All @@ -22,12 +22,6 @@ DEBUG=3 extra/viz/cli.py --profile -s ALL > asts.txt

# Get kernel timing information in JSONL format
extra/viz/cli.py --profile -s ALL --jsonl

# View top 40 slowest kernels on the AMD device and their AST (DEBUG=4 to see source code)
DEBUG=3 extra/viz/cli.py --profile -s AMD --top 40

# List top 10 slowest operations across all devices
extra/viz/cli.py --profile --top 10 -s ALL
```

## Inspect codegen and PatternMatcher
Expand Down
44 changes: 28 additions & 16 deletions extra/viz/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,19 @@ def get(data:dict, key:str):
def main(args) -> None:
viz.load_rewrites(viz_data:=viz.VizData(viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))))

def fmt(val, to_str=str) -> str: return json.dumps(val if isinstance(val, dict) else {"value":val}) if args.jsonl else to_str(val)

rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz_data.ctxs if c.get("steps")}
def print_step(step:dict) -> None:
data = viz.get_render(viz_data, step["query"])
if isinstance(data.get("value"), Iterator):
for m in data["value"]:
if m.get("uop"): print(json.dumps({"ast":m["uop"]}) if args.jsonl else m["uop"])
if m.get("uop"): print(fmt(m["uop"]))
if m.get("diff"):
loc = pathlib.Path(m["upat"][0][0])
print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}")
for line in m["diff"]: print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
if data.get("src") is not None: print(json.dumps({"src":data["src"]}) if args.jsonl else data["src"])
print(fmt(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}"))
for line in m["diff"]: print(fmt(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None)))
if data.get("src") is not None: print(fmt(data["src"]))

# ** Graph rewrites printer
if args.rewrites:
Expand Down Expand Up @@ -116,7 +118,8 @@ def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16
phase, delay = "EXEC", int(e.st) - dispatch_st
if inst and phase: info = f"{phase:<8} {inst}"
unit = e.device.replace(" ", "-")
print(f"{int(e.st)-inst_st:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {str(delay or ''):<4} {info}")
row = {"clk":int(e.st)-inst_st, "unit":unit, "op":op_name, "dur":int(unwrap(e.en)-e.st), "delay":delay or "", "info":info}
print(fmt(row, lambda _: f"{row['clk']:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {row['dur']:<4} {str(row['delay']):<4} {info}"))

# ** PMC printer
elif "PMC" in args.src:
Expand All @@ -130,17 +133,19 @@ def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16
cols = r[2]["cols"] if len(r) > 2 else cols
pmc_data = [[x for x in cols], *[[str(x) for x in r] for r in rows]]
widths = [max(len(r[i]) for r in pmc_data) for i in range(len(cols))]
def fmt(r): return "| "+" | ".join(x+" "*(w-len(x)) for x,w in zip(r, widths))+" |"
print(fmt(pmc_data[0])+"\n"+fmt(["-"*w for w in widths])+"\n"+("\n".join([fmt(row) for row in pmc_data[1:]])))
def pad(r): return "| "+" | ".join(x+" "*(w-len(x)) for x,w in zip(r, widths))+" |"
table_str = pad(pmc_data[0])+"\n"+pad(["-"*w for w in widths])+"\n"+("\n".join([pad(row) for row in pmc_data[1:]]))
print(fmt({"cols":cols, "rows":rows}, lambda _: table_str))

# ** Memory printer
elif data is not None and data["event_type"] == 1:
print(f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info")
print(fmt({"peak":data["peak"], "cols":["ts", "event", "key", "info"]},
lambda _: f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info"))
for e in data["events"]:
info = str(e.get("arg", {}))
info = str(arg:=e.pop("arg", {}))
if e["event"] == "free":
info = ', '.join([f"{fmt_colored(kernel)} {['read','write','write+read'][mode]}@data{num}" for _,kernel,num,mode in e["arg"]["users"]])
print(f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}")
info = ', '.join([f"{fmt_colored(kernel)} {['read','write','write+read'][mode]}@data{num}" for _,kernel,num,mode in arg["users"]])
print(fmt({**e, "info":info}, lambda _: f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}"))

# ** Profiler printer
else:
Expand All @@ -157,7 +162,7 @@ def produce_top_kernels() -> Iterator[dict]:
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
num_rows = len(items) if args.top < 0 else args.top
for (dev,name),(t,c,ref) in items[:num_rows]:
display = f"{dev[:7]:7s} {fmt_colored(name)}" if args.src == "ALL" else name
display = f"{dev[:7]:7s} {fmt_colored(name)}" if args.src == "ALL" else fmt_colored(name)
yield {"name":display, "dur_ms":t, "count":c, "pct":t/total*100.0, "ref":ref}
if num_rows > 0 and items[num_rows:]:
other_t = sum(t for _,(t,_,_) in items[num_rows:])
Expand All @@ -171,9 +176,16 @@ def produce_all_kernels() -> Iterator[dict]:
if dev == "MARKER":
yield {"device":dev, "name":fmt_colored(e["name"]), "et_ms":ts*1e-3, "ref":None, "ext":None}
continue
if e["fmt"].startswith("TB:"): e["fmt"] = "" # TODO: print python backtrace at a reasonable DEBUG level
ext:list[str] = []
if (fmt:=e["fmt"]).startswith("TB:"):
tb, fmt = json.loads(e["fmt"].replace("TB:", "")), ""
while tb:
file, lineno, fxn, code = tb.pop()
line = f"{file.split('/')[-1]}:{lineno} {fxn}"
if fmt: ext.append(f"{line} {code}")
elif not file.startswith("<") and not fxn.startswith("<"): fmt = line
yield {"device":dev, "name":fmt_colored(e["name"]), "dur_ms":e["dur"]*1e-3,
"et_ms":(e["st"]+e["dur"])*1e-3, "fmt":e["fmt"], "ref":e["ref"], "ext":None}
"et_ms":(e["st"]+e["dur"])*1e-3, "fmt":fmt, "ref":e["ref"], "ext":"\n".join(ext)}
def fmt_top(k:dict) -> str:
return f"{fmt_colored(k['name'])}{' ' * max(0, 36-ansilen(k['name']))} {time_to_str(k['dur_ms']*1e-3, w=9)} {k['count']:7d} {k['pct']:6.2f}%"
def fmt_all(k:dict) -> str:
Expand All @@ -184,12 +196,12 @@ def fmt_all(k:dict) -> str:
return f"{name} tm {ptm}/{k['et_ms']:9.2f}ms"+(f" ({fmt_str})" if k["fmt"] else "")
fmt_row = fmt_top if args.top else fmt_all
for k in (produce_top_kernels if args.top else produce_all_kernels)():
if args.jsonl: print(json.dumps(k))
else: print(fmt_row(k))
print(fmt(k, to_str=fmt_row))
if k["ref"] is not None:
steps = rewrites[viz_data.ctxs[k["ref"]]["name"]]
if DEBUG >= 3 and (ast_step:=steps.get("View Base AST")) is not None: print_step(ast_step)
if DEBUG >= 4 and (src_step:=steps.get("View Source")) is not None: print_step(src_step)
elif DEBUG >= 3 and k.get("ext"): print(fmt(k["ext"]))

def get_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(add_help=False)
Expand Down
26 changes: 25 additions & 1 deletion test/null/test_tensor_uop_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math, unittest
from tinygrad import Tensor
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp

def _t(*shape):
return Tensor.arange(math.prod(shape)).reshape(*shape)
Expand Down Expand Up @@ -76,5 +77,28 @@ def test_softmax_axis0(self): _check(self, _t(2, 3).float(), lambda x: x.s
def test_log_softmax_default(self): _check(self, _t(2, 3).float(), lambda x: x.log_softmax())
def test_log_softmax_axis0(self): _check(self, _t(2, 3).float(), lambda x: x.log_softmax(axis=0))

# UOp.empty / UOp.empty_like are the canonical buffer allocators; Tensor.empty / Tensor.empty_like just forward.
class TestUOpEmpty(unittest.TestCase):
  def test_empty_dtype_string(self):
    # dtype may be passed as a string; it should be normalized to the dtypes.* object
    self.assertEqual(UOp.empty((3, 4), dtype="float32").dtype, dtypes.float32)

  def test_empty_like_dtype_override(self):
    # empty_like keeps the source shape but honors an explicit dtype override
    u = Tensor.ones(3, 4).uop.empty_like(dtype=dtypes.int8)
    self.assertEqual((u.shape, u.dtype), ((3, 4), dtypes.int8))
    self.assertTrue(u.has_buffer_identity())

  def test_empty_like_sharded_to_single_device(self):
    # regression: sharded source, override to single device must yield full logical shape with no axis
    t = Tensor.ones(8, 4).shard(("NULL:0", "NULL:1"), axis=0)
    for dev in ("NULL:2", ("NULL:2",)):  # singleton tuple also canonicalizes to single device
      u = t.uop.empty_like(device=dev, dtype=dtypes.int32)
      self.assertEqual((u.shape, u.device, u.dtype, u.axis), ((8, 4), "NULL:2", dtypes.int32, None))
      self.assertTrue(u.has_buffer_identity())

  def test_empty_direct_singleton_tuple_device(self):
    # regression: direct UOp.empty with a singleton-tuple device + axis must not trip .multi()'s tuple assert
    u = UOp.empty((4,), dtype=dtypes.float32, device=("NULL:0",), axis=0)
    self.assertEqual((u.shape, u.device, u.axis), ((4,), "NULL", None))

if __name__ == "__main__":
unittest.main()
11 changes: 3 additions & 8 deletions tinygrad/callify.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def tag_uop(ctx:AllocCtx, x:UOp):
def disk_copy_is_buffer(ctx:AllocCtx, u:UOp):
# copies to disk are replaced with the disk buffer
to_disk = isinstance(u._device, str) and u._device.startswith(("DISK", "TINYFS"))
if to_disk: ctx.buffer_map[u] = UOp.new_buffer(u.device, u.shard_size, u.dtype).reshape(u.max_shard_shape)
if to_disk: ctx.buffer_map[u] = u.empty_like()
# all copies from disk/numpy are realized into a real buffer
from_creation = isinstance(u.src[0]._device, str) and any(u.src[0]._device.startswith(x) for x in ["NPY", "DISK", "PYTHON", "TINYFS"])
if from_creation: return tag_uop(ctx, u)
Expand All @@ -40,19 +40,14 @@ def apply_after(ctx:AllocCtx, u:UOp):
(UPat(GroupOp.All, name="x"), lambda ctx,x: tag_uop(ctx,x) if x in ctx.bases else None),
])

def _buffer_like(u:UOp) -> UOp:
buffer = UOp.new_buffer(u.device, u.shard_size, u.dtype).reshape(u.max_shard_shape).shrink_to(u.shard_shape)
if isinstance(u.device, tuple) and u.axis is not None: buffer = buffer.multi(u.axis)
return buffer

def replace_contig_with_store_after(u:UOp):
# can't allocate a buffer without a device (e.g., inside a CALL function body with only PARAMs)
if u._device is None: return None
# if size is 0, remove the contig
if 0 in u.shape: return u.src[0]
# no real contig for DISK/TINYFS tensors, they are left alone
if isinstance(u._device, str) and u._device.startswith(("DISK", "TINYFS")): return u.rtag(None)
buf = _buffer_like(u)
buf = u.empty_like()
return buf.after(buf.store(u.src[0])).rtag(u.tag)

def replace_store_after_with_contig(u:UOp, src:UOp):
Expand Down Expand Up @@ -102,7 +97,7 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
# add the outputs to the call
srcs = c.src[0].src
resolved = [c.gettuple(i) for i in range(len(srcs))]
outs = tuple(_buffer_like(r) for r in resolved)
outs = tuple(r.empty_like() for r in resolved)
targets = [o.param_like(len(c.src)-1+i).shrink_to(s.shape) for i,(o,s) in enumerate(zip(outs, srcs))]
fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)])

Expand Down
13 changes: 3 additions & 10 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_shape(x) -> tuple[int, ...]:
def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str,...]) -> UOp:
if isinstance(x, bytes): ret, data = UOp.new_buffer("PYTHON", len(x)//dtype.itemsize, dtype), x
else:
ret = UOp.new_buffer("PYTHON", prod(shape:=get_shape(x)), dtype).reshape(shape)
ret = UOp.empty(shape:=get_shape(x), dtype, "PYTHON")
assert dtype.fmt is not None, f"{dtype=} has None fmt"
truncate_function = truncate[dtype]
data = struct.pack(f"{prod(shape)}{dtype.fmt}", *[truncate_function(dtype.const(xi)) for xi in fully_flatten(x)])
Expand Down Expand Up @@ -516,21 +516,14 @@ def empty(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=Non
print(t.shape)
```
"""
dtype, shape = to_dtype(dtype) if dtype is not None else dtypes.default_float, argfix(*shape)
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
# TODO: add test for multidevice tensor
device = canonicalize_device(device)
return Tensor(UOp.new_buffer(device, size, dtype), **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
return Tensor(UOp.empty(argfix(*shape), dtype, device), **kwargs)

def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, **kwargs) -> Tensor:
"""
Creates an empty tensor with the same shape as `self`.
If `dtype` is not specified, the dtype of `self` is used.
"""
dtype, device = self.dtype if dtype is None else dtype, self.device if device is None else device
if isinstance(device, tuple) and (axis := self.uop.axis) is not None:
return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis))
return Tensor.empty(self.shape, dtype=dtype, device=device, **kwargs)
return Tensor(self.uop.empty_like(dtype, device), **kwargs)

@staticmethod
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
Expand Down
20 changes: 15 additions & 5 deletions tinygrad/uop/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from dataclasses import dataclass
from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, AddrSpace, ConstFloat, PyConst
from tinygrad.dtype import storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, DTypeLike, to_dtype, truncate, PtrDType, least_upper_dtype, Invalid, AddrSpace
from tinygrad.dtype import ConstFloat, PyConst, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
from tinygrad.device import Buffer, MultiBuffer, canonicalize_device
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CAPTURE_PROCESS_REPLAY
from tinygrad.helpers import strip_parens, colored, ansilen, printable
Expand Down Expand Up @@ -328,8 +328,7 @@ def shape(self) -> tuple[sint, ...]:
return ret

@property
def max_shape(self) -> tuple[int, ...]:
return tuple([int(x.vmax) if isinstance(x, UOp) else x for x in self.shape])
def max_shape(self) -> tuple[int, ...]: return to_max_shape(self.shape)

@property
def shard_shape(self) -> tuple[sint, ...]:
Expand Down Expand Up @@ -656,6 +655,16 @@ def unique(arg:int|None=None): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num) i
@staticmethod
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType, num=None):
return UOp(Ops.BUFFER, dtype, (UOp.unique(num), UOp(Ops.DEVICE, arg=device)), size)
  @staticmethod
  def empty(shape:tuple[sint, ...], dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, axis:int|None=None, num=None) -> UOp:
    """Allocate an uninitialized BUFFER UOp of `shape`.

    dtype defaults to dtypes.default_float; device is run through canonicalize_device
    (presumably normalizing None and singleton tuples — confirm against canonicalize_device).
    When `device` is a multi-device tuple and `axis` is given, the result is sharded via .multi(axis).
    """
    dtype, device = to_dtype(dtype) if dtype is not None else dtypes.default_float, canonicalize_device(device)
    # allocate for the maximum (vmax) extent of any symbolic dims, then shrink back to the symbolic shape
    max_shape = to_max_shape(shape)
    ret = UOp.new_buffer(device, prod(max_shape), dtype, num).reshape(max_shape).shrink_to(shape)
    return ret.multi(axis) if isinstance(device, tuple) and axis is not None else ret
  def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> UOp:
    """Allocate an uninitialized buffer shaped like `self`, with optional dtype/device overrides.

    If the (canonicalized) target device is a multi-device tuple, self's shard axis and
    per-shard shape are preserved; overriding to a single device drops the axis and
    allocates the full logical shape.
    """
    device = canonicalize_device(self.device if device is None else device)
    # keep the shard axis only when the destination is still a multi-device tuple
    axis = self.axis if isinstance(device, tuple) else None
    return UOp.empty(self.shard_shape if axis is not None else self.shape, self.dtype if dtype is None else dtype, device, axis)
@property
def device(self) -> str|tuple[str, ...]: return unwrap(self._device)
@recursive_property
Expand Down Expand Up @@ -1435,6 +1444,7 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N
return rewrite_ctx.walk_rewrite(sink) if walk else rewrite_ctx.unified_rewrite(sink)

def sint_to_uop(x:sint, dtype=dtypes.weakint) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
def to_max_shape(shape:tuple[sint, ...]) -> tuple[int, ...]:
  """Collapse every symbolic dimension to its maximum concrete extent (UOp dims become int(vmax))."""
  dims = []
  for dim in shape:
    dims.append(int(dim.vmax) if isinstance(dim, UOp) else dim)
  return tuple(dims)

def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
pm_lower_index_dtype = PatternMatcher([
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/viz/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_full_rewrite(data:VizData, ctx:TrackedGraphRewrite) -> Generator[GraphRe
next_sink = _reconstruct(data, ctx.sink)
yield {"graph":uop_to_json(data, next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches, disable=not ctx.matches):
replaces[u0:=_reconstruct(data, u0_num)] = u1 = _reconstruct(data, u1_num)
try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
Expand Down
Loading