4 changes: 2 additions & 2 deletions test/mockgpu/amd/emu.py
@@ -161,7 +161,7 @@ def emit(wave_id: int, inst, branch_taken: bool|None):
"""Emit an SQTT packet for one executed instruction."""
w = wave_id & 0x1F
if wave_id not in started:
_emit_nibbles(nibbles, WAVESTART, delta=1, simd=0, cu_lo=0, wave=w, id7=wave_id)
_emit_nibbles(nibbles, WAVESTART, delta=1, simd=0, wgp=0, wave=w, id7=wave_id)
started.add(wave_id)
inst_type, inst_op, op_name = type(inst), inst.op.value if hasattr(inst, 'op') else 0, inst.op.name if hasattr(inst, 'op') else ""
if issubclass(inst_type, _SOPP):
@@ -180,7 +180,7 @@ def emit(wave_id: int, inst, branch_taken: bool|None):

def finish(wave_id: int):
"""Emit WAVEEND for a completed wave."""
if wave_id in started: _emit_nibbles(nibbles, WAVEEND, delta=1, simd=0, cu_lo=0, wave=wave_id & 0x1F)
if wave_id in started: _emit_nibbles(nibbles, WAVEEND, delta=1, simd=0, wgp=0, wave=wave_id & 0x1F)

def finalize() -> bytes:
"""Pad and return the encoded SQTT blob."""
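The emulator masks wave ids into the 5-bit `wave` field (`wave_id & 0x1F`, as in `emit`/`finish` above), so wave ids wrap at 32. A quick plain-Python illustration of just that masking (not the emulator itself):

```python
# the SQTT wave field is 5 bits, so wave ids wrap modulo 32
for wave_id in (0, 31, 32, 45):
    print(wave_id, "->", wave_id & 0x1F)   # 0 -> 0, 31 -> 31, 32 -> 0, 45 -> 13
```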
32 changes: 32 additions & 0 deletions test/null/test_tensor_uop_mixin.py
@@ -22,6 +22,19 @@ def test_mul_bool_int(self):
self.assertIs(_strip_unique((t.eq(1) * Tensor.arange(3)).uop), _strip_unique(t.uop.eq(1) * UOp.arange(3)))
# Tensor's ufix picks float dtype when scalar is float and self is int; UOp should match.
def test_add_scalar_float_on_int(self): _check(self, _t(3), lambda x: x + 1.5)
# div: Tensor.div (default case) delegates to ElementwiseMixin.div; trees must match for Tensor and UOp.
def test_div_tensor_by_tensor(self):
a, b = _t(4).float(), _t(4).float() + 1
self.assertIs(_strip_unique((a/b).uop), _strip_unique(a.uop/b.uop))
def test_div_int_by_int(self): _check(self, _t(4), lambda x: x / 3)
def test_div_sum_by_sum(self): _check(self, _t(4).float(), lambda x: x.sum() / (x + 1).sum())
def test_div_broadcast_tensor_by_tensor(self):
a, b = _t(3, 4).float(), _t(4).float() + 1
self.assertIs(_strip_unique((a/b).uop), _strip_unique(a.uop/b.uop))
# isclose used `self == other` which is Python identity on UOp (not elementwise); now uses .eq().
def test_isclose(self):
t = _t(4).float()
self.assertIs(_strip_unique(t.isclose(t).uop), _strip_unique(t.uop.isclose(t.uop)))

class TestTensorUOpGetitem(unittest.TestCase):
# ---- pure slice patterns ----
@@ -118,6 +131,25 @@ class TestTensorUOpLoss(unittest.TestCase):
def test_cross_entropy(self):
t, Y = _t(2, 3).float(), Tensor([1, 2], dtype=dtypes.int32)
self.assertIs(_strip_unique(t.cross_entropy(Y).uop), _strip_unique(t.uop.cross_entropy(Y.uop)))
def test_sparse_categorical_crossentropy(self):
t, Y = _t(2, 3).float(), Tensor([1, 2], dtype=dtypes.int32)
self.assertIs(_strip_unique(t.sparse_categorical_crossentropy(Y).uop), _strip_unique(t.uop.sparse_categorical_crossentropy(Y.uop)))
def test_sparse_categorical_crossentropy_ignore_index(self):
t, Y = _t(2, 3).float(), Tensor([1, 2], dtype=dtypes.int32)
self.assertIs(_strip_unique(t.sparse_categorical_crossentropy(Y, ignore_index=0).uop),
_strip_unique(t.uop.sparse_categorical_crossentropy(Y.uop, ignore_index=0)))

class TestTensorUOpScatterReduce(unittest.TestCase):
def _check(self, x, idx, src, **kw):
self.assertIs(_strip_unique(x.scatter_reduce(0, idx, src, **kw).uop),
_strip_unique(x.uop.scatter_reduce(0, idx.uop, src.uop, **kw)))
def test_sum(self): self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="sum")
def test_prod(self): self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="prod")
def test_mean(self): self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="mean")
def test_amax(self): self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="amax")
def test_amin(self): self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="amin")
def test_mean_exclude_self(self):
self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="mean", include_self=False)

class TestTensorUOpCat(unittest.TestCase):
def test_cat_dim0(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=0))
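All of these tests follow one pattern: build the same expression once through the `Tensor` API and once directly on the underlying `UOp`, then assert the two graphs match after the `_strip_unique` normalization. A minimal sketch of that pattern outside the test harness, assuming only the public `Tensor`/`.uop` API (the `assertIs` comparison itself is left to the suite):

```python
from tinygrad import Tensor

a = Tensor.arange(4).float()
b = Tensor.arange(4).float() + 1

tensor_tree = (a / b).uop   # Tensor.div delegates to ElementwiseMixin.div
uop_tree = a.uop / b.uop    # the same mixin method, called on the UOps directly

# the tests assert these are the same graph once _strip_unique has normalized them
print(tensor_tree)
print(uop_tree)
```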
85 changes: 83 additions & 2 deletions tinygrad/mixin/__init__.py
@@ -6,7 +6,7 @@
from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
from tinygrad.dtype import ConstType, DTypeLike, Invalid, InvalidType, PtrDType, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.dtype import ConstType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up

if TYPE_CHECKING:
@@ -419,7 +419,7 @@ def normalize(self, p:float=2.0, dim:int=1, eps:float=1e-12) -> Self:
print(t.normalize(p=1, dim=0).numpy())
```
"""
if p == 0: return self / (self != 0).sum(dim, keepdim=True).maximum(eps) # type: ignore[comparison-overlap]
if p == 0: return self / self.ne(0).sum(dim, keepdim=True).maximum(eps)
return self / self.abs().pow(p).sum(dim, keepdim=True).pow(1/p).maximum(eps)

def logsumexp(self, axis=None, keepdim=False) -> Self:
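The `.ne(0)` spelling drops the mypy `comparison-overlap` ignore; the behaviour (divide by the count of nonzero entries along `dim`) is unchanged. A small usage check, assuming the public Tensor API, with values worked out by hand from the formula above:

```python
from tinygrad import Tensor

t = Tensor([1.0, 0.0, 2.0])
# p=0: divide by the number of nonzero entries along dim (clamped below by eps)
print(t.normalize(p=0, dim=0).numpy())   # [0.5 0.  1. ]
```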
@@ -772,6 +772,63 @@ def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:boo
x = x.gather(i, index)
return x.cast(self.dtype)

def _pre_scatter(self, dim:int, index:Self, src:Self) -> tuple[Self, Self]:
if index.device != self.device: raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
if src.device != self.device: raise RuntimeError(f"expected src and self on the same device, {src.device=}, {self.device=}")
dim = self._resolve_dim(dim)
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.ndim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
if self.dtype != src.dtype: raise RuntimeError(f"expect {self.dtype=} to be equal to {src.dtype=}")
# shrink src to index shape to shrink away the unused values
src = src.shrink_to(index.shape)
# prepare src and mask for reduce with respect to dim
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
return src.pad_to(*self.shape, None), mask.pad_to(*self.shape, None)

def scatter_reduce(self, dim:int, index:Self, src:Self, reduce:Literal["sum", "prod", "mean", "amax", "amin"],
include_self:bool=True) -> Self:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `"sum"`, `"prod"`, `"mean"`, `"amax"`, or `"amin"` reduction operations with `reduce`.

Set `include_self=False` to exclude values in the `self` Tensor from the reduction.

```python exec="true" source="above" session="tensor" result="python"
src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)
print(src.numpy())
index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
print(index.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='sum').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='prod').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='mean', include_self=False).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amax').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amin').numpy())
```
"""
src, mask = self._pre_scatter(dim, index, src)
def _inv_mask(a:Self|PyConst, b:Self|PyConst) -> Self: return mask.any(-1).logical_not().where(a, b)
if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
if reduce == "amax": return mask.where(src, m := src.dtype.min).max(-1).maximum(self if include_self else _inv_mask(self, m))
if reduce == "amin": return mask.where(src, m := src.dtype.max).min(-1).minimum(self if include_self else _inv_mask(self, m))
if reduce == "mean":
count = mask.where(1, 0).sum(-1).add(1 if include_self else _inv_mask(1, 0))
return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count)
raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")

# ***** functional nn ops *****

def sequential(self, ll:list[Callable[[Self], Self]]) -> Self:
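For reference, the mask/where/reduce pipeline above implements the usual scatter-reduce semantics. A slow plain-Python sketch of `reduce='sum'` with `include_self=True` along `dim=0` on 2-D inputs (`scatter_reduce_sum_2d` is a name made up for this sketch; it is not how the mixin computes it):

```python
def scatter_reduce_sum_2d(x, index, src):
    # out[index[i][j]][j] += src[i][j], starting from a copy of x (include_self=True)
    out = [row[:] for row in x]
    for i in range(len(index)):
        for j in range(len(index[0])):
            out[index[i][j]][j] += src[i][j]
    return out

x = [[1.0, 1.0, 1.0, 1.0, 1.0]]                                  # self, shape (1, 5)
index = [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]                       # everything scatters into row 0
src = [[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]    # same values as the docstring example
print(scatter_reduce_sum_2d(x, index, src))                      # [[8.0, 10.0, 12.0, 14.0, 16.0]]
```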
@@ -936,6 +993,30 @@ def binary_crossentropy_logits(self, Y:Self, reduction:ReductionStr="mean", pos_
log_p, log_1_minus_p = self.logsigmoid(), (-self).logsigmoid()
return (-((1 if pos_weight is None else pos_weight) * Y * log_p + (1-Y) * log_1_minus_p))._do_reduction(reduction)

def sparse_categorical_crossentropy(self, Y:Self, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Self:
"""
Computes the sparse categorical cross-entropy loss between `self` and `Y`.

NOTE: `self` is logits and `Y` is the target labels.
NOTE: unlike PyTorch, this function expects the class axis to be -1

See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[-1, 2, -3], [1, -2, 3]])
Y = Tensor([1, 2])
print(t.sparse_categorical_crossentropy(Y).item())
```
"""
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
if Y.device != self.device: raise RuntimeError(f"expected Y and self on the same device, {Y.device=}, {self.device=}")
log_probs = self.log_softmax()
loss_mask = Y.ne(ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
y = Y.unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1)
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
return -unreduced.sum() / loss_mask.sum() if reduction == "mean" else -unreduced._do_reduction(reduction)

def cross_entropy(self, Y:Self, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Self:
"""
Computes the cross entropy loss between input logits and target.
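With the default arguments (no label smoothing, nothing ignored), the loss above reduces to the mean negative log-softmax of the target class along the last axis. A NumPy sketch of that reduction as an independent cross-check of the docstring example (not the mixin code):

```python
import numpy as np

logits = np.array([[-1.0, 2.0, -3.0], [1.0, -2.0, 3.0]])
Y = np.array([1, 2])

m = logits.max(-1, keepdims=True)   # numerically stable log_softmax over the class axis (-1)
log_probs = logits - m - np.log(np.exp(logits - m).sum(-1, keepdims=True))
loss = -log_probs[np.arange(len(Y)), Y].mean()
print(loss)   # should agree with t.sparse_categorical_crossentropy(Y).item() above
```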
5 changes: 3 additions & 2 deletions tinygrad/mixin/elementwise.py
@@ -182,7 +182,8 @@ def mod(self, x: Self | ConstType, reverse: bool = False) -> Self:
return self._binop(Ops.MOD, x, reverse)

def div(self, x: Self | ConstType, reverse: bool = False) -> Self:
return (self.ufix(x) * self.alu(Ops.RECIPROCAL)) if reverse else (self * self.ufix(x).alu(Ops.RECIPROCAL))
lhs, rhs = self._broadcasted(x, reverse)
return lhs * rhs.reciprocal()

def __neg__(self) -> Self:
return self.neg()
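The new `div` broadcasts both operands first and then multiplies by the reciprocal, so `Tensor` and `UOp` build the same tree (which is what the new `test_div_*` cases assert). A small numeric sanity check using only the public Tensor API:

```python
from tinygrad import Tensor

a = Tensor([[1.0, 2.0, 3.0, 4.0]])   # (1, 4)
b = Tensor([2.0])                    # broadcasts against a
print((a / b).numpy())               # [[0.5 1.  1.5 2. ]]
print((a * b.reciprocal()).numpy())  # same result: div is multiply-by-reciprocal
```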
@@ -566,7 +567,7 @@ def isclose(self, other, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) ->
```
"""
is_finite_close = self.isfinite() & other.isfinite() & ((self - other).abs() <= atol + rtol * other.abs())
is_infinite_close = (self.isinf() | other.isinf()) & (self == other)
is_infinite_close = (self.isinf() | other.isinf()) & self.eq(other)
is_nan_close = (self.isnan() & other.isnan()) & equal_nan
return is_finite_close | is_infinite_close | is_nan_close

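The `isclose` change matters because this mixin is shared between `Tensor` and `UOp`: per the new test comment, `==` on a `UOp` is plain Python identity and yields a bool, while `.eq()` stays elementwise on both types, so the infinity branch keeps building a graph. A quick check on the Tensor side:

```python
from tinygrad import Tensor

t = Tensor([1.0, float("inf")])
print(t.eq(t).numpy())       # [ True  True]
print(t.isclose(t).numpy())  # [ True  True]  -- inf is close to itself via the eq() branch
```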
18 changes: 9 additions & 9 deletions tinygrad/renderer/amd/sqtt.py
@@ -288,34 +288,34 @@ class WAVERDY(PacketType): # exclude: 1 << 3
class WAVEEND(PacketType): # exclude: 1 << 4
encoding = bits[4:0] == 0b10101
delta = bits[7:5]
flag7 = bits[8:8]
sa = bits[8:8]
simd = bits[10:9]
cu_lo = bits[13:11]
wgp = bits[13:11]
wave = bits[19:15]
@property
def cu(self) -> int: return self.cu_lo | (self.flag7 << 3)
def cu(self) -> int: return self.wgp | (self.sa << 3)

class WAVESTART(PacketType): # exclude: 1 << 4
encoding = bits[4:0] == 0b01100
delta = bits[6:5]
flag7 = bits[7:7]
sa = bits[7:7]
simd = bits[9:8]
cu_lo = bits[12:10]
wgp = bits[12:10]
wave = bits[17:13]
id7 = bits[31:18]
@property
def cu(self) -> int: return self.cu_lo | (self.flag7 << 3)
def cu(self) -> int: return self.wgp | (self.sa << 3)

class WAVESTART_RDNA4(PacketType): # Layout 4 has wave field at different position
encoding = bits[4:0] == 0b01100
delta = bits[6:5]
flag7 = bits[7:7]
sa = bits[7:7]
simd = bits[9:8]
cu_lo = bits[12:10]
wgp = bits[12:10]
wave = bits[19:15]
id7 = bits[31:20]
@property
def cu(self) -> int: return self.cu_lo | (self.flag7 << 3)
def cu(self) -> int: return self.wgp | (self.sa << 3)

class WAVEALLOC(PacketType): # exclude: 1 << 10
encoding = bits[4:0] == 0b00101
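The renames (`flag7` -> `sa`, `cu_lo` -> `wgp`) only relabel the bitfields; the derived `cu` index is still the shader-array bit stacked on top of the 3-bit WGP field. Decoding one packet's fields by hand:

```python
# cu = wgp | (sa << 3): a 4-bit CU index from the 3-bit wgp field plus the sa bit
sa, wgp = 1, 0b101
print(wgp | (sa << 3))   # 13
```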
3 changes: 1 addition & 2 deletions tinygrad/renderer/cstyle.py
@@ -87,8 +87,7 @@ def create_non_native_float_pats(dts:tuple[DType, ...], casting:bool=True):
def cast_float_to_bf16(x: UOp) -> UOp:
assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
x = x.bitcast(dtypes.uint)
# NOTE: != returns UOp, not bool, issue with mypy
x = ((-x & 0x7f800000) != 0).where(x + ((x >> 16) & 1) + 0x7fff, ((x & 0xffff) != 0).where((x | 0x10000), x)) # type: ignore[comparison-overlap]
x = (-x & 0x7f800000).ne(0).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).ne(0).where((x | 0x10000), x))
return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)

# manual bfloat16 casting patterns (shared between LLVM, Clang, and AMD renderers to avoid compiler intrinsics)
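The rewritten line keeps the bf16 conversion logic unchanged and only swaps `!=` for `.ne()` (the deleted NOTE explained that `!=` already returned a `UOp`, it just needed a mypy ignore). A plain-Python mirror of the same bit expression, as a reading aid (`f32_to_bf16_bits` is a name made up for this sketch, not the renderer's output):

```python
import struct

def f32_to_bf16_bits(f: float) -> int:
    x = struct.unpack("<I", struct.pack("<f", f))[0]
    if (-x) & 0x7f800000:              # ordinary values (and infinities): round to nearest even
        x = x + ((x >> 16) & 1) + 0x7fff
    elif x & 0xffff:                   # NaN whose payload sits in the low bits: keep it a NaN
        x = x | 0x10000
    return (x >> 16) & 0xffff

print(hex(f32_to_bf16_bits(1.0)))           # 0x3f80
print(hex(f32_to_bf16_bits(float("nan"))))  # exponent stays all ones, so still a NaN
```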