4 changes: 4 additions & 0 deletions test/null/test_tensor_uop_mixin.py
@@ -71,6 +71,10 @@ def test_stack_dim1(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=1
def test_stack_3tensors(self): _check(self, _t(2, 3), lambda x: x.stack(x, x, dim=0))
def test_stack_new_last(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=-1))

class TestTensorUOpEinsum(unittest.TestCase):
def test_einsum_dot(self): _check(self, _t(2, 3), lambda x: type(x).einsum("ij,ij->", x, x))
def test_einsum_transpose(self): _check(self, _t(2, 3), lambda x: type(x).einsum("ij->ji", x))

class TestTensorUOpSoftmax(unittest.TestCase):
def test_softmax_default(self): _check(self, _t(2, 3).float(), lambda x: x.softmax())
def test_softmax_axis0(self): _check(self, _t(2, 3).float(), lambda x: x.softmax(axis=0))
2 changes: 1 addition & 1 deletion tinygrad/codegen/late/devectorizer.py
@@ -318,7 +318,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in ended_ranges])
identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=ctx.acc_num)
acc = UOp.placeholder((1,), red.dtype, ctx.acc_num, AddrSpace.REG)
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity)
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0))] + lst # put acc as the first element
ctx.acc_num += 1
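
The `placeholder` definition is not part of this diff, so the sketch below infers its signature purely from the two call sites changed in this PR; the import paths are assumptions from the tree layout.

```python
# assumption: UOp.placeholder(shape, dtype, arg, addrspace) builds the pointer-typed
# UOp that the hand-rolled DEFINE_REG/DEFINE_LOCAL constructions used to spell out
from tinygrad.uop.ops import UOp
from tinygrad.dtype import dtypes, AddrSpace

acc = UOp.placeholder((1,), dtypes.float, 0, AddrSpace.REG)     # 1-element register accumulator
buf = UOp.placeholder((16,), dtypes.float, 0, AddrSpace.LOCAL)  # 16-element local scratch buffer
```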
2 changes: 1 addition & 1 deletion tinygrad/codegen/simplify.py
@@ -134,7 +134,7 @@ def reduce_collapse(red:UOp, u:UOp, pm:PatternMatcher=pm_reduce_collapse) -> UOp
for u in included:
for s in u.src:
if s in included or s in replaces or s.op in {Ops.CONST, Ops.VCONST, Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}: continue
replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
replaces[s] = UOp.variable(f'in{len(replaces)}', s.vmin, s.vmax, s.dtype)
collapse_fxn = u.substitute(replaces).reduce(r, arg=Ops.ADD)
sink = graph_rewrite(collapse_fxn, pm, name="reduce_collapse")
if not no_range(sink): return None
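
Same cleanup pattern as elsewhere in this PR: a raw constructor is replaced with an existing helper. A minimal sketch of the equivalence, assuming the in-tree import paths:

```python
from tinygrad.uop import Ops
from tinygrad.uop.ops import UOp
from tinygrad.dtype import dtypes

# old spelling: UOp(Ops.DEFINE_VAR, dtype=dtypes.int, arg=('in0', 0, 255))
v = UOp.variable('in0', 0, 255, dtypes.int)
assert v.op is Ops.DEFINE_VAR and v.arg == ('in0', 0, 255)
```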
6 changes: 6 additions & 0 deletions tinygrad/helpers.py
@@ -45,6 +45,12 @@ def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in
def fully_flatten(l):
if not (hasattr(l, "__len__") and hasattr(l, "__getitem__")) or isinstance(l, str): return [l]
return [l[()]] if hasattr(l, "shape") and l.shape == () else [x for li in l for x in fully_flatten(li)]
# `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
def flat_to_grouped(padding:Sequence[T]) -> tuple[tuple[T, T], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
def resolve_pool_pads(padding:int|Sequence[int], dims:int) -> Sequence[int]:
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} with {dims=}.")
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def _is_balanced(s:str) -> bool: return (d := 0, all((d := d + (c == '(') - (c == ')')) >= 0 for c in s))[1] and d == 0
def strip_parens(fst:str) -> str: return fst[1:-1] if fst[:1]=='(' and fst[-1:]==')' and _is_balanced(fst[1:-1]) else fst
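
A small worked example of the two helpers that now live in `tinygrad/helpers.py`, following the definitions above:

```python
from tinygrad.helpers import flat_to_grouped, resolve_pool_pads

# torch-style flat pads (left, right, top, bottom) -> grouped ((top, bottom), (left, right))
assert flat_to_grouped((1, 2, 3, 4)) == ((3, 4), (1, 2))
assert resolve_pool_pads(1, dims=2) == [1, 1, 1, 1]              # an int broadcasts to 2*dims
assert resolve_pool_pads((1, 2), dims=2) == [2, 2, 1, 1]         # per-dim pads doubled, reversed
assert resolve_pool_pads((1, 2, 3, 4), dims=2) == (1, 2, 3, 4)   # full-length pads pass through
```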
46 changes: 44 additions & 2 deletions tinygrad/mixin/reduce.py
@@ -1,7 +1,8 @@
from typing import Self, Sequence
import string
from typing import Self, Sequence, cast
from tinygrad.uop import Ops
from tinygrad.dtype import DTypeLike, dtypes, sum_acc_dtype, to_dtype
from tinygrad.helpers import make_tuple
from tinygrad.helpers import argfix, argsort, make_tuple, merge_dicts
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.movement import MovementMixin

@@ -135,3 +136,44 @@ def all(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Self:
```
"""
return self.bool().prod(axis, keepdim)

@classmethod
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.

See: https://pytorch.org/docs/stable/generated/torch.einsum.html

```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
# expand ellipsis to letters, determine output
if "..." in formula:
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
inputs = lhs.split(",")
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
# trace: take diagonal when letter repeats in single input
for i, (s, x) in enumerate(zip(inputs, xs)):
for c in set(s):
while s.count(c) > 1:
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
s = s[:k] + s[k+1:]
inputs[i], xs[i] = s, x
# check sizes and build sorted alphabet
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
alpha = sorted(sz)
# align all tensors to alphabet, multiply, sum non-output, permute to output order
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
for s, x in zip(inputs, xs)]
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
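
A few usage sketches of the relocated `einsum`; as a classmethod on the mixin it is inherited by `Tensor`, so the old `Tensor.einsum(...)` spelling should still work:

```python
from tinygrad import Tensor

x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,jk->ik", x, y).numpy())  # matrix multiply
print(Tensor.einsum("ii->", x).numpy())          # trace, via the repeated-index diagonal path
print(Tensor.einsum("...i->...", x).numpy())     # ellipsis expansion over leading dims
```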
2 changes: 1 addition & 1 deletion tinygrad/renderer/nir.py
@@ -122,7 +122,7 @@ class NIRRenderer(Renderer):

extra_matcher = PatternMatcher([
# handle negative unsigned CONST
(UPat.cvar("x", dtypes.uints), lambda x: UOp(Ops.CONST, dtype=x.dtype, arg=x.dtype.max+x.arg+1) if x.arg < 0 else None),
(UPat.cvar("x", dtypes.uints), lambda x: UOp.const(x.dtype, x.dtype.max+x.arg+1) if x.arg < 0 else None),
# from ptx
(UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
# load/store bool -> uint8
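
Same constructor-to-helper cleanup as in `simplify.py`. A quick sketch of the unsigned wrap-around this pattern implements, assuming the in-tree import paths:

```python
from tinygrad.uop.ops import UOp
from tinygrad.dtype import dtypes

# a negative arg on an unsigned dtype becomes its wrapped (two's-complement) value
arg = -1
print(UOp.const(dtypes.uint8, dtypes.uint8.max + arg + 1).arg)  # 255
```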
2 changes: 1 addition & 1 deletion tinygrad/schedule/rangeify.py
@@ -410,7 +410,7 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):

if allow_locals:
# handle locals
buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=next(ctx))
buf = UOp.placeholder((size,), x.dtype, next(ctx), AddrSpace.LOCAL)
do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
return buf.after(do_store.barrier())

83 changes: 17 additions & 66 deletions tinygrad/tensor.py
@@ -1,13 +1,13 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
import time, math, itertools, functools, struct, sys, inspect, pathlib, hashlib, weakref
from contextlib import ContextDecorator
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, Literal, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_float, least_upper_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid, InvalidType
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch, flat_to_grouped
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin, ReductionStr
@@ -46,8 +46,8 @@ def _fromnp(x: 'numpy.ndarray') -> UOp:
return ret.reshape(x.shape)

def get_shape(x) -> tuple[int, ...]:
# NOTE: str is special because __getitem__ on a str is still a str
if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str) or (hasattr(x, "shape") and x.shape == ()): return ()
# NOTE: str is special because iterating it still yields strs
if not hasattr(x, "__len__") or isinstance(x, str) or getattr(x, "shape", None) == (): return ()
if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
return (len(subs),) + (subs[0] if subs else ())
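
The simplified guard should preserve behavior; a few spot checks, assuming `get_shape` stays importable as a module-level function of `tinygrad/tensor.py`:

```python
from tinygrad.tensor import get_shape

assert get_shape([[1, 2], [3, 4]]) == (2, 2)
assert get_shape("abc") == ()  # str is scalar-like: indexing/iterating it still yields strs
assert get_shape(3.0) == ()
```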

@@ -87,9 +87,6 @@ def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, .
# select from values for each True element in mask else select from target
return mask.where(values, target)

# `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))

class Tensor(OpMixin):
"""
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
@@ -1105,7 +1102,7 @@ def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str=
# normalize to grouped format
if all(isinstance(p, (int,UOp)) for p in padding):
if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
pX = ((0,0),)*(self.ndim - len(padding)//2) + flat_to_grouped(cast(Sequence[sint], padding))
else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding))
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
# dispatch
Expand Down Expand Up @@ -1518,66 +1515,20 @@ def argmin(self, axis=None, keepdim=False) -> Tensor:
"""
return self._inverse().argmax(axis=axis, keepdim=keepdim)

@staticmethod
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], dtype:DTypeLike|None=None) -> Tensor:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.

See: https://pytorch.org/docs/stable/generated/torch.einsum.html

```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
# expand ellipsis to letters, determine output
if "..." in formula:
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
inputs = lhs.split(",")
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
# trace: take diagonal when letter repeats in single input
for i, (s, x) in enumerate(zip(inputs, xs)):
for c in set(s):
while s.count(c) > 1:
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
s = s[:k] + s[k+1:]
inputs[i], xs[i] = s, x
# check sizes and build sorted alphabet
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
alpha = sorted(sz)
# align all tensors to alphabet, multiply, sum non-output, permute to output order
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
for s, x in zip(inputs, xs)]
return Tensor.uprod(*xs).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))

# ***** processing ops *****

def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])

def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:int|tuple[int, ...], d_:int|tuple[int, ...]) -> list[int]:
(d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
pads, grouped_pads = list(pads), _flat_to_grouped(pads)
grouped_pads = list(flat_to_grouped(pads))
# https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
o_ = [ceildiv(i+pB+pA - (d*(k-1)+1), s) + 1 for i,d,k,s,(pB,pA) in zip(i_,d_,k_,s_,grouped_pads)]
for dim,(o,i,s,k,d,(pB,pA)) in enumerate(zip(o_,i_,s_,k_,d_,grouped_pads)):
# we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly
# `s*(o-1) + (d*(k-1)+1) - (i+pB+pA)` -> last_sliding_window_start + full_kernel_size - padded_input_shape
# we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool`
# `smax(s*(o-1) - (pB+i-1), 0)` -> last_sliding_window_start - (pad_before + input_size - zero_offset)
pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+pB+pA) - smax(s*(o-1) - (pB+i-1), 0)
return pads
grouped_pads[dim] = (pB, pA + s*(o-1) + (d*(k-1)+1) - (i+pB+pA) - smax(s*(o-1) - (pB+i-1), 0))
return flatten(reversed(grouped_pads))

# NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
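
A behavioral check of the ceil_mode path that `_apply_ceil_mode` feeds, assuming torch-style pooling semantics:

```python
from tinygrad import Tensor

t = Tensor.arange(25).reshape(1, 1, 5, 5).float()
print(t.max_pool2d(kernel_size=(2, 2), ceil_mode=False).shape)  # (1, 1, 2, 2): floor(5/2)
print(t.max_pool2d(kernel_size=(2, 2), ceil_mode=True).shape)   # (1, 1, 3, 3): ceil(5/2)
```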
Expand Down Expand Up @@ -1618,7 +1569,7 @@ def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1,
"""
axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
reg_pads = self._resolve_pool_pads(padding, len(k_))
reg_pads = resolve_pool_pads(padding, len(k_))
ceil_pads = self._apply_ceil_mode(reg_pads, k_, stride if stride is not None else k_, dilation)
if not count_include_pad:
pads = ceil_pads if ceil_mode else reg_pads
Expand Down Expand Up @@ -1660,7 +1611,7 @@ def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1,
```
"""
axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
pads = self._resolve_pool_pads(padding, len(k_))
pads = resolve_pool_pads(padding, len(k_))
if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
pooled = self.pad(pads, value=self.dtype.min)._pool(k_, stride if stride is not None else k_, dilation)
if not return_indices: return pooled.max(axis)
Expand Down Expand Up @@ -1694,7 +1645,7 @@ def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride
bs,c,*spatial_shape = self.shape
if output_size is None:
k_,d_,s_ = (make_tuple(x, len(spatial_shape)) for x in (kernel_size, dilation, stride if stride is not None else kernel_size))
p_ = _flat_to_grouped(self._resolve_pool_pads(padding, len(spatial_shape)))
p_ = flat_to_grouped(resolve_pool_pads(padding, len(spatial_shape)))
# https://arxiv.org/pdf/1603.07285 inverse of relationship 15 in section 5.1.
output_size = tuple((i-1)*s - (pB+pA) + (d*(k-1)+1) for i,k,d,s,(pA,pB) in zip(spatial_shape,k_,d_,s_,p_))
else: output_size = output_size[-len(spatial_shape):]
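
A worked arithmetic check of the inverted relationship 15 (arXiv:1603.07285) used above, for a 2x2 pooled input with k=2, s=2, d=1 and no padding:

```python
i, k, s, d, pA, pB = 2, 2, 2, 1, 0, 0
print((i - 1) * s - (pB + pA) + (d * (k - 1) + 1))  # 4: a 2x2 pooled map unpools to 4x4
```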
Expand Down Expand Up @@ -1730,7 +1681,7 @@ def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilat
"""
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
padding_ = self._resolve_pool_pads(padding, len(HW))
padding_ = resolve_pool_pads(padding, len(HW))
assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"

@@ -1755,8 +1706,8 @@ def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilat
# TODO: stride == dilation
# use padding to round up to 4x4 output tiles
# (bs, cin_, tyx, HWI)
pads = [[padding_[i*2], padding_[i*2+1] + (-(dim+sum(padding_[i*2:(i+1)*2])-2) % 4)] for i, dim in enumerate(reversed(self.shape[-len(HW):]))]
d = self.pad(sum(pads, []))._pool(HWI, HWO)
pads = [(pB, pA + (-(s + pB + pA - 2) % 4)) for (pB, pA), s in zip(flat_to_grouped(padding_), self.shape[-len(HW):])]
d = self.pad(flatten(reversed(pads)))._pool(HWI, HWO)
# move HW to the front: # (HWI, bs, cin_, tyx)
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
tyx = d.shape[-len(HWI):] # dim of tiling
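
The winograd pre-pad construction is meant to be a pure refactor; a hedged equivalence check of the old and new spellings for a sample shape and asymmetric pads:

```python
from tinygrad.helpers import flat_to_grouped, flatten

padding_ = [1, 2, 0, 3]   # resolved flat pads for a 2D conv: (left, right, top, bottom)
shape_hw = (10, 12)       # trailing (H, W) dims of the input

old = [[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i*2:(i+1)*2]) - 2) % 4)]
       for i, dim in enumerate(reversed(shape_hw))]
new = [(pB, pA + (-(s + pB + pA - 2) % 4)) for (pB, pA), s in zip(flat_to_grouped(padding_), shape_hw)]
assert sum(old, []) == flatten(reversed(new))  # both produce [1, 5, 0, 4] here
```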
Expand Down Expand Up @@ -1807,7 +1758,7 @@ def conv_transpose2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, strid
"""
x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
HW = weight.shape[2:]
padding = _flat_to_grouped(self._resolve_pool_pads(padding, len(HW)))
padding = flat_to_grouped(resolve_pool_pads(padding, len(HW)))
stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)]
if any(s>1 for s in stride):
# handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
Expand Down Expand Up @@ -2537,7 +2488,7 @@ def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1,
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)

padding_neg, padding_pos = [min(0, p) for p in self._resolve_pool_pads(padding, 2)], [max(0, p) for p in self._resolve_pool_pads(padding, 2)]
padding_neg, padding_pos = [min(0, p) for p in resolve_pool_pads(padding, 2)], [max(0, p) for p in resolve_pool_pads(padding, 2)]
x = x.pad(padding_neg)
iy, ix = x.shape[2:]
