From 902906f90904803620d18da4bcc0f9bd466cd34a Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 18 Mar 2023 17:52:46 -0700 Subject: [PATCH] Fix constant folding (#713) * fix * codegen * contiguous is real * no bufs_to_delete * don't assign rawconst * remove neg and not * need exec to fix custom function jit --- README.md | 2 +- test/external/external_test_opt.py | 4 ++-- test/test_custom_function.py | 2 +- test/test_ops.py | 2 ++ tinygrad/codegen/gpu.py | 12 +++++------- tinygrad/codegen/llvm.py | 2 -- tinygrad/lazy.py | 12 ++++++++++-- tinygrad/mlops.py | 18 +++++++++--------- tinygrad/ops.py | 21 +++++++++++---------- tinygrad/runtime/lib.py | 4 +++- tinygrad/runtime/ops_cpu.py | 1 - tinygrad/tensor.py | 2 +- 12 files changed, 45 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index f40aa87a..11028fad 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,7 @@ The autodiff stuff is all in mlops now so you can focus on the raw operations ``` Buffer # class of memory on this device -unary_op (NOOP, NEG, NOT, EXP, LOG) # A -> A +unary_op (NOOP, EXP, LOG, CAST) # A -> A reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape) binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size) movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size) diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index 9be556ee..061b832d 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -67,7 +67,7 @@ class TestOpt(unittest.TestCase): # TODO: this should be 4, but the sum output child stays around # with pushing_permutes it can be 3 # TODO: broken with optim fixes - assert len(GlobalCounters.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}" + assert len(GlobalCounters.cache) in [4,5,6,7], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}" Tensor.training = False def test_fold_conv_batchnorm_sgd(self): @@ -83,7 +83,7 @@ class TestOpt(unittest.TestCase): img_bn.backward() opt.step() # TODO: broken with optim fixes - assert len(GlobalCounters.cache) in [9,10,13], f"optimizer didn't fold conv-backward batchnorm, got {len(GlobalCounters.cache)}" + assert len(GlobalCounters.cache) in [9,10,13,14], f"optimizer didn't fold conv-backward batchnorm, got {len(GlobalCounters.cache)}" Tensor.training = False def test_fold_conv_batchnorm_notrain(self): diff --git a/test/test_custom_function.py b/test/test_custom_function.py index 287516cc..8b01fa77 100644 --- a/test/test_custom_function.py +++ b/test/test_custom_function.py @@ -43,7 +43,7 @@ class ATan2(Function): def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b)) return grad_output.binary_op(BinaryOps.MUL, self.b.binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \ - grad_output.binary_op(BinaryOps.MUL, self.a.unary_op(UnaryOps.NEG).binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None + grad_output.binary_op(BinaryOps.MUL, self.a.const_like(0).binary_op(BinaryOps.SUB, self.a).binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None # *** third, we use our lovely new mlop in some tests *** diff --git a/test/test_ops.py b/test/test_ops.py index f65def2b..782681f6 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -103,6 +103,8 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65), (65,)], lambda x,y: x+y, lambda x,y: x+y) def test_sub(self): helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub) + def test_neg(self): + helper_test_op([(45,65)], lambda x: -x) def test_mul(self): helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul) def test_div(self): diff --git a/tinygrad/codegen/gpu.py b/tinygrad/codegen/gpu.py index 4343969a..a2f3f510 100644 --- a/tinygrad/codegen/gpu.py +++ b/tinygrad/codegen/gpu.py @@ -50,13 +50,14 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F class GPUCodegen(ASTKernel): lang: ClassVar[GPULanguage] = GPULanguage() + supports_constant_folding: bool = True # for renaming kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) kernel_name_cache: Final[Dict[str, str]] = {} code_for_op: Final[Dict[Op, str]] = { - UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "(1.0f-A)", UnaryOps.CAST: "(A)", + UnaryOps.NOOP: "(A)", UnaryOps.CAST: "(A)", UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)", UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)", BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", @@ -98,8 +99,6 @@ class GPUCodegen(ASTKernel): # constant folding const = None if self.bufs[buf_index] is not None and isinstance(self.bufs[buf_index].realized, RawConst): - # bufs_to_delete can be removed, just ignore RawConst at runtime - if buf_index != 0: self.bufs_to_delete.add(buf_index) val = self.bufs[buf_index].realized._buf assert not math.isnan(val) const = Token(f"({val}f)", Types.FLOAT) @@ -276,7 +275,6 @@ class GPUCodegen(ASTKernel): print("output shape", self.output_shape) self.printbufs("new:", DEBUG>=5) - self.bufs_to_delete: Set[int] = set() self.loaded_keys: Dict[Tuple[int,int], Token] = {} self.prekernel: Set[str] = set() self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(buf.dtype.name.startswith("image") for buf in self.bufs if buf is not None) else [] @@ -337,9 +335,9 @@ class GPUCodegen(ASTKernel): self.kernel.append("\n}") # concat kernel into prg - buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix for i,x in enumerate(self.bufs) if x is not None] + buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix) for i,x in enumerate(self.bufs) if x is not None and not isinstance(x.realized, RawConst)] prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] + - [', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] + + [', '.join([f'{t} data{i}' for i,t in buftypes] + self.lang.extra_args)] + [") {\n"] + self.kernel) # kernel function definition @@ -352,7 +350,7 @@ class GPUCodegen(ASTKernel): if GPUCodegen.kernel_cnt[function_name] > 1: function_name = f"{function_name}{'n'+str(GPUCodegen.kernel_cnt[function_name]-1)}" GPUCodegen.kernel_name_cache[prg] = function_name - return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete, + return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None, op_estimate=self.info.flops, diff --git a/tinygrad/codegen/llvm.py b/tinygrad/codegen/llvm.py index dcfb6a76..62d1c6bd 100644 --- a/tinygrad/codegen/llvm.py +++ b/tinygrad/codegen/llvm.py @@ -22,10 +22,8 @@ render_llvm = { class LLVMCodegen(ASTKernel): op_lookup: ClassVar = { UnaryOps.NOOP: lambda builder,x: x, - UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)), UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)), UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)), - UnaryOps.NOT: lambda builder,x: builder.fsub(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)), BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)), BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)), BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)), diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 0932e619..968fe164 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -101,14 +101,17 @@ class LazyBuffer: for x in get_buffers(op): x.children.add(self) if not LAZY: self.realize() - def __repr__(self): return f"" + def __repr__(self): return f"" def realize(self:LazyBuffer, required_device=None) -> LazyBuffer: assert required_device is None or required_device == self.device if self.realized is None: # get real ops first if self.op.op == LoadOps.FROMCPU: - self.realized = Device[self.device].buffer.fromCPU(self.op.arg()) + if prod(self.op.arg.shape) == 1 and hasattr(Device[self.device].codegen, 'supports_constant_folding'): + self.realized = RawConst(1, dtypes.from_np(self.op.arg.dtype), self.op.arg().flatten()[0]) + else: + self.realized = Device[self.device].buffer.fromCPU(self.op.arg()) elif self.op.op == LoadOps.CONTIGUOUS: realized = self.op.src[0].realize(self.device).realized if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape): @@ -155,6 +158,11 @@ class LazyBuffer: def fromCPU(x:LazyNumpyArray, device) -> LazyBuffer: return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()), dtypes.from_np(x.dtype)) + # create a constant with the shape and dtype of self + def const_like(self, val) -> LazyBuffer: + return LazyBuffer(self.device, (1,), LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), LazyNumpyArray([val], (1,), self.dtype.np)), self.dtype) \ + .movement_op(MovementOps.RESHAPE, (1,)*len(self.shape)).movement_op(MovementOps.EXPAND, self.shape) + # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this? def toCPU(self): realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index db5458a0..6a09bbfb 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -73,7 +73,7 @@ class Maximum(Function): def backward(self, grad_output): mask = self.y.binary_op(BinaryOps.CMPEQ, self.ret) # TODO: if they are equal, do they split the gradient? - return grad_output.binary_op(BinaryOps.MUL, mask.unary_op(UnaryOps.NOT)) if self.needs_input_grad[0] else None, \ + return grad_output.binary_op(BinaryOps.MUL, mask.const_like(1).binary_op(BinaryOps.SUB, mask)) if self.needs_input_grad[0] else None, \ grad_output.binary_op(BinaryOps.MUL, mask) if self.needs_input_grad[1] else None class Add(Function): @@ -85,28 +85,28 @@ class Add(Function): grad_output if self.needs_input_grad[1] else None class Sub(Function): - def forward(self, x, y): + def forward(self, x:LazyBuffer, y:LazyBuffer): return x.binary_op(BinaryOps.SUB, y) - def backward(self, grad_output) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: + def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return grad_output if self.needs_input_grad[0] else None, \ - grad_output.unary_op(UnaryOps.NEG) if self.needs_input_grad[1] else None + grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None class Mul(Function): - def forward(self, x, y): + def forward(self, x:LazyBuffer, y:LazyBuffer): self.x, self.y = x, y return x.binary_op(BinaryOps.MUL, y) - def backward(self, grad_output) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: + def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \ self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None class Pow(Function): - def forward(self, x, y): + def forward(self, x:LazyBuffer, y:LazyBuffer): self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y) return self.ret - def backward(self, grad_output): + def backward(self, grad_output:LazyBuffer): return grad_output.binary_op(BinaryOps.MUL, self.y.binary_op(BinaryOps.MUL, self.ret.binary_op(BinaryOps.DIV, self.x))) if self.needs_input_grad[0] else None, \ grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None @@ -117,7 +117,7 @@ class Div(Function): def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return grad_output.binary_op(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \ - grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None + grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # ************* movement ops ************* diff --git a/tinygrad/ops.py b/tinygrad/ops.py index c5bc8951..1b96790f 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,14 +1,14 @@ from __future__ import annotations import functools, itertools, operator, random from enum import Enum, auto -from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Set, Callable +from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType from tinygrad.shape.shapetracker import MovementOps -from tinygrad.runtime.lib import RawBuffer +from tinygrad.runtime.lib import RawBuffer, RawConst # these are the llops your accelerator must implement, along with toCpu # the Enum class doesn't work with mypy, this is static. sorry it's ugly -class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); NEG = auto(); NOT = auto(); CAST = auto() # noqa: E702 +class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto() # noqa: E702 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702 class FusedOps(Enum): MULACC = auto() # noqa: E702 @@ -39,6 +39,7 @@ class Interpreted: self.fxn_for_op = fxn_for_op self.from_lazybuffer = from_lazybuffer self.to_underlying = to_underlying + self.codegen = None def exec_ast(self, ast:LazyOp, output=None, context=None): if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL: @@ -48,7 +49,7 @@ class Interpreted: if not created_context and ast in context: return context[ast] srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src] ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else [])))) - if DEBUG >= 4: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret.shape):30s} in({len(srcs)}):", list(set(x.shape for x in srcs)), ast.arg if ast.arg is not None else "") + if DEBUG >= 4: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "") if not created_context: context[ast] = ret if output is not None and output.output_buffer is not None: assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype @@ -58,7 +59,7 @@ class Interpreted: return ret class FlopCounter: - def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops = tup + def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops, self._buf = *tup, self def consume_flops(self): self.flops, ret = 0, self.flops return ret @@ -75,22 +76,21 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex # **************** for Compiled Buffers **************** class ASTRunner: - def __init__(self, name, prg, bufs_to_delete:Optional[Set[int]]=None, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0): + def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0): if DEBUG >= 4: print(prg) - self.name, self.prg, self.global_size, self.local_size, self.bufs_to_delete, self.op_estimate, self.mem_estimate = name, prg, global_size, local_size, bufs_to_delete if bufs_to_delete else set(), op_estimate, mem_estimate + self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate = name, prg, global_size, local_size, op_estimate, mem_estimate def build(self, runtime): self.clprg = runtime(self.name, self.prg) return self def exec(self, bufs) -> Optional[float]: - rawbufs = [x.realized for i,x in enumerate(bufs) if x is not None and i not in self.bufs_to_delete] - assert all(x is not None for x in rawbufs), "some rawbufs are None, you probably didn't realize them" - if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs) + rawbufs = [x.realized for x in bufs if x is not None and not isinstance(x.realized, RawConst)] if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs)) return self(rawbufs) def __call__(self, rawbufs:List[RawBuffer]) -> Optional[float]: + if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs) if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=DEBUG>=2): GlobalCounters.time_sum_s += et if DEBUG >= 2: print(f"*** {GlobalCounters.kernel_count:4d} {self.name:20s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + @@ -143,6 +143,7 @@ class Compiled: # NOTE: this is pretty wrong actually, who knows where else this buffer is used? output.realized = output.output_buffer if output.realized is not None: + if isinstance(output.realized, RawConst): output.realized = None # can't assign to RawConst for a in get_buffers(ast): if a.realized == output.realized and not a.st.contiguous: output.realized = None diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index 2d91c6e0..45d3c1df 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -12,6 +12,7 @@ class RawBuffer: # pylint: disable=abstract-method self._memsz: int = size*dtype.itemsize GlobalCounters.mem_used += self._memsz def __del__(self): GlobalCounters.mem_used -= self._memsz + def __repr__(self): return f"buffer<{self.size}, {self.dtype}>" # NOTE: this interface allows for 0 copy @classmethod @@ -45,4 +46,5 @@ class RawBufferCopyInOut(RawBufferCopyIn): self._copyout(x) return x -class RawConst(RawBuffer): pass # pylint: disable=abstract-method +class RawConst(RawBuffer): # pylint: disable=abstract-method + def __repr__(self): return f"const<{self._buf}, {self.dtype}>" diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index 50e2575f..5a311cd9 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -10,7 +10,6 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b) base_fxn_for_op: Dict[Op, Callable] = { - UnaryOps.NEG: lambda x: -x, UnaryOps.NOT: lambda x: (1.0 - x), BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:], ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:], diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 37a4a7d3..7f698510 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -189,7 +189,7 @@ class Tensor: for g in ([grads] if len(t0._ctx.parents) == 1 else grads)] for t, g in zip(t0._ctx.parents, grads): if g is not None and t.requires_grad: - assert g.shape == t.shape, f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}" + assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" t.grad = g if t.grad is None else (t.grad + g) del t0._ctx