mirror of https://github.com/commaai/tinygrad.git
Fix constant folding (#713)

* fix
* codegen
* contiguous is real
* no bufs_to_delete
* don't assign rawconst
* remove neg and not
* need exec to fix custom function jit

parent 73bd0b217b · commit 902906f909
````diff
@@ -159,7 +159,7 @@ The autodiff stuff is all in mlops now so you can focus on the raw operations
 ```
 Buffer                                    # class of memory on this device
-unary_op  (NOOP, NEG, NOT, EXP, LOG)      # A -> A
+unary_op  (NOOP, EXP, LOG, CAST)          # A -> A
 reduce_op (SUM, MAX)                      # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size)
 movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size)
````
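Dropping NEG and NOT from the required unary ops works because both decompose into the remaining binary ops, which is the pattern the rest of this commit applies everywhere. A minimal numpy sketch of the two rewrites (the `neg`/`logical_not` names are illustrative stand-ins, not tinygrad API):

```python
import numpy as np

# NEG(x) rewritten as SUB(const(0), x): 0 - x == -x.
def neg(x: np.ndarray) -> np.ndarray:
    zero = np.zeros((1,), dtype=x.dtype)  # broadcastable (1,)-shaped constant
    return zero - x

# NOT(x) (the old "1.0f-A" form) rewritten as SUB(const(1), x).
def logical_not(x: np.ndarray) -> np.ndarray:
    one = np.ones((1,), dtype=x.dtype)
    return one - x

assert np.array_equal(neg(np.array([1.0, -2.0])), np.array([-1.0, 2.0]))
assert np.array_equal(logical_not(np.array([0.0, 1.0])), np.array([1.0, 0.0]))
```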
```diff
@@ -67,7 +67,7 @@ class TestOpt(unittest.TestCase):
     # TODO: this should be 4, but the sum output child stays around
     # with pushing_permutes it can be 3
     # TODO: broken with optim fixes
-    assert len(GlobalCounters.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
+    assert len(GlobalCounters.cache) in [4,5,6,7], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
     Tensor.training = False

   def test_fold_conv_batchnorm_sgd(self):
@@ -83,7 +83,7 @@ class TestOpt(unittest.TestCase):
     img_bn.backward()
     opt.step()
     # TODO: broken with optim fixes
-    assert len(GlobalCounters.cache) in [9,10,13], f"optimizer didn't fold conv-backward batchnorm, got {len(GlobalCounters.cache)}"
+    assert len(GlobalCounters.cache) in [9,10,13,14], f"optimizer didn't fold conv-backward batchnorm, got {len(GlobalCounters.cache)}"
     Tensor.training = False

   def test_fold_conv_batchnorm_notrain(self):
```
```diff
@@ -43,7 +43,7 @@ class ATan2(Function):
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b))
     return grad_output.binary_op(BinaryOps.MUL, self.b.binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \
-           grad_output.binary_op(BinaryOps.MUL, self.a.unary_op(UnaryOps.NEG).binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None
+           grad_output.binary_op(BinaryOps.MUL, self.a.const_like(0).binary_op(BinaryOps.SUB, self.a).binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None

 # *** third, we use our lovely new mlop in some tests ***
```
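The y-gradient of atan2(a, b) is -a/(a² + b²); with UnaryOps.NEG gone, the -a term is built as (0 - a) via const_like, exactly as in the hunk above. A small numpy sketch checking that formula against a finite difference (the `atan2_backward` name is hypothetical, for illustration only):

```python
import numpy as np

# Gradients of atan2(a, b): d/da = b/(a^2+b^2), d/db = -a/(a^2+b^2).
# The -a is written as (0.0 - a), matching the const_like(0) SUB rewrite.
def atan2_backward(grad, a, b):
    denom = a * a + b * b
    return grad * (b / denom), grad * ((0.0 - a) / denom)

# quick finite-difference check on scalars
a, b, eps = 0.3, 0.7, 1e-6
ga, gb = atan2_backward(np.float64(1.0), np.float64(a), np.float64(b))
assert abs(ga - (np.arctan2(a + eps, b) - np.arctan2(a, b)) / eps) < 1e-5
assert abs(gb - (np.arctan2(a, b + eps) - np.arctan2(a, b)) / eps) < 1e-5
```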
```diff
@@ -103,6 +103,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65), (65,)], lambda x,y: x+y, lambda x,y: x+y)
   def test_sub(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
+  def test_neg(self):
+    helper_test_op([(45,65)], lambda x: -x)
   def test_mul(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)
   def test_div(self):
```
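The new test_neg exercises negation end to end, now that it lowers to SUB rather than a dedicated unary op. A rough sketch of the kind of forward check a helper like helper_test_op presumably performs, with numpy standing in for both the reference and the backend under test:

```python
import numpy as np

# Stripped-down stand-in for helper_test_op's forward check: compare a
# reference result against the implementation under test. Here the
# "implementation" is the SUB-based lowering of negation.
def check_neg(shape, atol=1e-6):
    x = np.random.randn(*shape).astype(np.float32)
    ref = -x                                  # reference (torch, in the real test)
    out = np.zeros((1,), dtype=x.dtype) - x   # 0 - x, the rewritten form
    np.testing.assert_allclose(out, ref, atol=atol)

check_neg((45, 65))
```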
```diff
@@ -50,13 +50,14 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F

 class GPUCodegen(ASTKernel):
   lang: ClassVar[GPULanguage] = GPULanguage()
+  supports_constant_folding: bool = True

   # for renaming
   kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
   kernel_name_cache: Final[Dict[str, str]] = {}

   code_for_op: Final[Dict[Op, str]] = {
-    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "(1.0f-A)", UnaryOps.CAST: "(A)",
+    UnaryOps.NOOP: "(A)", UnaryOps.CAST: "(A)",
     UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
     UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
@@ -98,8 +99,6 @@ class GPUCodegen(ASTKernel):
     # constant folding
     const = None
     if self.bufs[buf_index] is not None and isinstance(self.bufs[buf_index].realized, RawConst):
-      # bufs_to_delete can be removed, just ignore RawConst at runtime
-      if buf_index != 0: self.bufs_to_delete.add(buf_index)
       val = self.bufs[buf_index].realized._buf
       assert not math.isnan(val)
       const = Token(f"({val}f)", Types.FLOAT)
@@ -276,7 +275,6 @@ class GPUCodegen(ASTKernel):
     print("output shape", self.output_shape)
     self.printbufs("new:", DEBUG>=5)

-    self.bufs_to_delete: Set[int] = set()
     self.loaded_keys: Dict[Tuple[int,int], Token] = {}
     self.prekernel: Set[str] = set()
     self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(buf.dtype.name.startswith("image") for buf in self.bufs if buf is not None) else []
@@ -337,9 +335,9 @@ class GPUCodegen(ASTKernel):
     self.kernel.append("\n}")

     # concat kernel into prg
-    buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix for i,x in enumerate(self.bufs) if x is not None]
+    buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix) for i,x in enumerate(self.bufs) if x is not None and not isinstance(x.realized, RawConst)]
     prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
-      [', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] +
+      [', '.join([f'{t} data{i}' for i,t in buftypes] + self.lang.extra_args)] +
       [") {\n"] + self.kernel)

     # kernel function definition
@@ -352,7 +350,7 @@ class GPUCodegen(ASTKernel):
     if GPUCodegen.kernel_cnt[function_name] > 1: function_name = f"{function_name}{'n'+str(GPUCodegen.kernel_cnt[function_name]-1)}"
     GPUCodegen.kernel_name_cache[prg] = function_name

-    return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
+    return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
       list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1],
       (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
       op_estimate=self.info.flops,
```
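This is the codegen half of the fold: a buffer realized as a RawConst is inlined into the kernel source as a float literal and dropped from the parameter list, so buftypes now carries the original buffer index and the bufs_to_delete set becomes unnecessary. A simplified sketch of that pruning, with made-up class and function names (not the actual GPUCodegen internals):

```python
from typing import List, Optional, Tuple

class FakeBuf:
    """Stand-in for a realized buffer; const_val is set for folded constants."""
    def __init__(self, name: str, const_val: Optional[float] = None):
        self.name, self.const_val = name, const_val

def render_args_and_loads(bufs: List[FakeBuf]) -> Tuple[str, dict]:
    # Constants become inline literals like "(3.0f)"; everything else stays a
    # kernel parameter, numbered by its original buffer index so the runtime
    # can line arguments up without any deletion bookkeeping.
    args, loads = [], {}
    for i, b in enumerate(bufs):
        if b.const_val is not None:
            loads[i] = f"({b.const_val}f)"
        else:
            args.append(f"__global float* data{i}")
    return ", ".join(args), loads

sig, loads = render_args_and_loads([FakeBuf("out"), FakeBuf("x"), FakeBuf("c", 3.0)])
assert sig == "__global float* data0, __global float* data1"
assert loads[2] == "(3.0f)"
```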
```diff
@@ -22,10 +22,8 @@ render_llvm = {
 class LLVMCodegen(ASTKernel):
   op_lookup: ClassVar = {
     UnaryOps.NOOP: lambda builder,x: x,
-    UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
     UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
     UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
-    UnaryOps.NOT: lambda builder,x: builder.fsub(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
     BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
     BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
     BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),
```
```diff
@@ -101,14 +101,17 @@ class LazyBuffer:
     for x in get_buffers(op): x.children.add(self)
     if not LAZY: self.realize()

-  def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else 'realized'} st:{self.st}>"
+  def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else self.realized} st:{self.st}>"

   def realize(self:LazyBuffer, required_device=None) -> LazyBuffer:
     assert required_device is None or required_device == self.device
     if self.realized is None:
       # get real ops first
       if self.op.op == LoadOps.FROMCPU:
-        self.realized = Device[self.device].buffer.fromCPU(self.op.arg())
+        if prod(self.op.arg.shape) == 1 and hasattr(Device[self.device].codegen, 'supports_constant_folding'):
+          self.realized = RawConst(1, dtypes.from_np(self.op.arg.dtype), self.op.arg().flatten()[0])
+        else:
+          self.realized = Device[self.device].buffer.fromCPU(self.op.arg())
       elif self.op.op == LoadOps.CONTIGUOUS:
         realized = self.op.src[0].realize(self.device).realized
         if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
@@ -155,6 +158,11 @@ class LazyBuffer:
   def fromCPU(x:LazyNumpyArray, device) -> LazyBuffer:
     return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()), dtypes.from_np(x.dtype))

+  # create a constant with the shape and dtype of self
+  def const_like(self, val) -> LazyBuffer:
+    return LazyBuffer(self.device, (1,), LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), LazyNumpyArray([val], (1,), self.dtype.np)), self.dtype) \
+      .movement_op(MovementOps.RESHAPE, (1,)*len(self.shape)).movement_op(MovementOps.EXPAND, self.shape)
+
   # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
     realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized
```
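And this is the lazy-buffer half: a one-element FROMCPU array is realized as a RawConst whenever the device's codegen advertises supports_constant_folding, and const_like manufactures such a (1,)-shaped constant and broadcasts it to any shape. A rough sketch of the FROMCPU dispatch under those assumptions (class and function names here are illustrative, not the real ones):

```python
import numpy as np

class RawConstSketch:
    """Stand-in for RawConst: holds a single scalar instead of device memory."""
    def __init__(self, val): self._buf = val
    def __repr__(self): return f"const<{self._buf}>"

def realize_fromcpu(arr: np.ndarray, codegen_folds_constants: bool):
    # One-element arrays become inline constants when the backend can fold
    # them; everything else takes the normal copy-to-device path.
    if arr.size == 1 and codegen_folds_constants:
        return RawConstSketch(arr.flatten()[0])
    return arr.copy()  # stand-in for Device[...].buffer.fromCPU(...)

assert isinstance(realize_fromcpu(np.array([3.0]), True), RawConstSketch)
assert isinstance(realize_fromcpu(np.array([3.0]), False), np.ndarray)
assert isinstance(realize_fromcpu(np.ones((4, 4)), True), np.ndarray)
```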
```diff
@@ -73,7 +73,7 @@ class Maximum(Function):
   def backward(self, grad_output):
     mask = self.y.binary_op(BinaryOps.CMPEQ, self.ret)
     # TODO: if they are equal, do they split the gradient?
-    return grad_output.binary_op(BinaryOps.MUL, mask.unary_op(UnaryOps.NOT)) if self.needs_input_grad[0] else None, \
+    return grad_output.binary_op(BinaryOps.MUL, mask.const_like(1).binary_op(BinaryOps.SUB, mask)) if self.needs_input_grad[0] else None, \
            grad_output.binary_op(BinaryOps.MUL, mask) if self.needs_input_grad[1] else None

 class Add(Function):
@@ -85,28 +85,28 @@ class Add(Function):
            grad_output if self.needs_input_grad[1] else None

 class Sub(Function):
-  def forward(self, x, y):
+  def forward(self, x:LazyBuffer, y:LazyBuffer):
     return x.binary_op(BinaryOps.SUB, y)

-  def backward(self, grad_output) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
-           grad_output.unary_op(UnaryOps.NEG) if self.needs_input_grad[1] else None
+           grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None

 class Mul(Function):
-  def forward(self, x, y):
+  def forward(self, x:LazyBuffer, y:LazyBuffer):
     self.x, self.y = x, y
     return x.binary_op(BinaryOps.MUL, y)

-  def backward(self, grad_output) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
            self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None

 class Pow(Function):
-  def forward(self, x, y):
+  def forward(self, x:LazyBuffer, y:LazyBuffer):
     self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y)
     return self.ret

-  def backward(self, grad_output):
+  def backward(self, grad_output:LazyBuffer):
     return grad_output.binary_op(BinaryOps.MUL, self.y.binary_op(BinaryOps.MUL, self.ret.binary_op(BinaryOps.DIV, self.x))) if self.needs_input_grad[0] else None, \
            grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None

@@ -117,7 +117,7 @@ class Div(Function):

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output.binary_op(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
-           grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
+           grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None

 # ************* movement ops *************

```
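The mlop rewrites lean on two calculus identities: d/dy (x - y) = -1, so the y-gradient of Sub is 0 - g, and d/dy (x / y) = -x/y², which Div spells with the same const_like(0) SUB pattern. A numpy sketch of both backward passes (function names are illustrative):

```python
import numpy as np

def const_like(x: np.ndarray, val: float) -> np.ndarray:
    # mirrors LazyBuffer.const_like: a (1,)-shaped constant broadcast to x's shape
    return np.broadcast_to(np.array([val], dtype=x.dtype), x.shape)

def sub_backward(g: np.ndarray):
    # d/dx (x - y) = 1, d/dy (x - y) = -1; the -g is spelled (0 - g)
    return g, const_like(g, 0) - g

def div_backward(g: np.ndarray, x: np.ndarray, y: np.ndarray):
    # d/dx (x / y) = 1/y, d/dy (x / y) = -x / y^2
    return g / y, (const_like(g, 0) - g) * x / (y * y)

g, x, y = np.ones(3), np.array([2., 4., 6.]), np.array([1., 2., 3.])
gx, gy = div_backward(g, x, y)
assert np.allclose(gx, [1.0, 0.5, 1/3]) and np.allclose(gy, [-2.0, -1.0, -2/3])
```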
```diff
@@ -1,14 +1,14 @@
 from __future__ import annotations
 import functools, itertools, operator, random
 from enum import Enum, auto
-from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Set, Callable
+from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable
 from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType
 from tinygrad.shape.shapetracker import MovementOps
-from tinygrad.runtime.lib import RawBuffer
+from tinygrad.runtime.lib import RawBuffer, RawConst

 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
-class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); NEG = auto(); NOT = auto(); CAST = auto() # noqa: E702
+class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class FusedOps(Enum): MULACC = auto() # noqa: E702
@@ -39,6 +39,7 @@ class Interpreted:
     self.fxn_for_op = fxn_for_op
     self.from_lazybuffer = from_lazybuffer
     self.to_underlying = to_underlying
+    self.codegen = None

   def exec_ast(self, ast:LazyOp, output=None, context=None):
     if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
@@ -48,7 +49,7 @@ class Interpreted:
     if not created_context and ast in context: return context[ast]
     srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
     ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
-    if DEBUG >= 4: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret.shape):30s} in({len(srcs)}):", list(set(x.shape for x in srcs)), ast.arg if ast.arg is not None else "")
+    if DEBUG >= 4: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "")
     if not created_context: context[ast] = ret
     if output is not None and output.output_buffer is not None:
       assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
@@ -58,7 +59,7 @@ class Interpreted:
     return ret

 class FlopCounter:
-  def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops = tup
+  def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops, self._buf = *tup, self
   def consume_flops(self):
     self.flops, ret = 0, self.flops
     return ret
@@ -75,22 +76,21 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
 # **************** for Compiled Buffers ****************

 class ASTRunner:
-  def __init__(self, name, prg, bufs_to_delete:Optional[Set[int]]=None, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0):
+  def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0):
     if DEBUG >= 4: print(prg)
-    self.name, self.prg, self.global_size, self.local_size, self.bufs_to_delete, self.op_estimate, self.mem_estimate = name, prg, global_size, local_size, bufs_to_delete if bufs_to_delete else set(), op_estimate, mem_estimate
+    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate = name, prg, global_size, local_size, op_estimate, mem_estimate

   def build(self, runtime):
     self.clprg = runtime(self.name, self.prg)
     return self

   def exec(self, bufs) -> Optional[float]:
-    rawbufs = [x.realized for i,x in enumerate(bufs) if x is not None and i not in self.bufs_to_delete]
-    assert all(x is not None for x in rawbufs), "some rawbufs are None, you probably didn't realize them"
-    if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs)
+    rawbufs = [x.realized for x in bufs if x is not None and not isinstance(x.realized, RawConst)]
     if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs))
     return self(rawbufs)

   def __call__(self, rawbufs:List[RawBuffer]) -> Optional[float]:
+    if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs)
     if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=DEBUG>=2): GlobalCounters.time_sum_s += et
     if DEBUG >= 2:
       print(f"*** {GlobalCounters.kernel_count:4d} {self.name:20s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
@@ -143,6 +143,7 @@ class Compiled:
       # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
       output.realized = output.output_buffer
     if output.realized is not None:
+      if isinstance(output.realized, RawConst): output.realized = None # can't assign to RawConst
       for a in get_buffers(ast):
         if a.realized == output.realized and not a.st.contiguous:
           output.realized = None
```
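With constants inlined at codegen time, ASTRunner.exec can simply skip RawConst-realized buffers when collecting kernel arguments, which is what makes the bufs_to_delete bookkeeping removable. A minimal sketch of that filter with stand-in classes (not the real ops.py types):

```python
from typing import List, Optional

class RawConst:      # stand-in: the realized value is a literal, not device memory
    def __init__(self, val): self._buf = val

class Buf:           # stand-in for a LazyBuffer with a .realized field
    def __init__(self, realized): self.realized = realized

def gather_rawbufs(bufs: List[Optional[Buf]]) -> list:
    # Constants never reach the kernel as arguments: they were inlined into
    # the program text at codegen time, so drop them (and holes) here.
    return [b.realized for b in bufs if b is not None and not isinstance(b.realized, RawConst)]

bufs = [Buf("out_raw"), Buf(RawConst(3.0)), None, Buf("x_raw")]
assert gather_rawbufs(bufs) == ["out_raw", "x_raw"]
```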
```diff
@@ -12,6 +12,7 @@ class RawBuffer: # pylint: disable=abstract-method
     self._memsz: int = size*dtype.itemsize
     GlobalCounters.mem_used += self._memsz
   def __del__(self): GlobalCounters.mem_used -= self._memsz
+  def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"

   # NOTE: this interface allows for 0 copy
   @classmethod
@@ -45,4 +46,5 @@ class RawBufferCopyInOut(RawBufferCopyIn):
     self._copyout(x)
     return x

-class RawConst(RawBuffer): pass # pylint: disable=abstract-method
+class RawConst(RawBuffer): # pylint: disable=abstract-method
+  def __repr__(self): return f"const<{self._buf}, {self.dtype}>"
```
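The two __repr__s make folded constants visible in debug output; together with the LazyBuffer.__repr__ change above, a realized constant now prints as const<...> instead of a generic 'realized'. A tiny sketch of the resulting formatting, using simplified stand-in classes with made-up constructors:

```python
class RawBufferSketch:
    def __init__(self, size, dtype): self.size, self.dtype = size, dtype
    def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"

class RawConstSketch(RawBufferSketch):
    def __init__(self, val, dtype):
        super().__init__(1, dtype)  # a constant occupies no real device memory
        self._buf = val
    def __repr__(self): return f"const<{self._buf}, {self.dtype}>"

print(RawBufferSketch(16, "float32"))  # buffer<16, float32>
print(RawConstSketch(3.0, "float32"))  # const<3.0, float32>
```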
```diff
@@ -10,7 +10,6 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
   return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)

 base_fxn_for_op: Dict[Op, Callable] = {
-  UnaryOps.NEG: lambda x: -x, UnaryOps.NOT: lambda x: (1.0 - x),
   BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
```
```diff
@@ -189,7 +189,7 @@ class Tensor:
       for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
     for t, g in zip(t0._ctx.parents, grads):
       if g is not None and t.requires_grad:
-        assert g.shape == t.shape, f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
+        assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
         t.grad = g if t.grad is None else (t.grad + g)
     del t0._ctx
```