Fix constant folding (#713)

* fix

* codegen

* contiguous is real

* no bufs_to_delete

* don't assign rawconst

* remove neg and not

* need exec to fix custom function jit
George Hotz 2023-03-18 17:52:46 -07:00 committed by GitHub
parent 73bd0b217b
commit 902906f909
12 changed files with 45 additions and 37 deletions
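
The pattern behind most of these diffs: `UnaryOps.NEG` and `UnaryOps.NOT` are dropped from the op set and every use is rewritten as a `SUB` against a constant, which the new `RawConst` constant folding then inlines. A minimal NumPy sketch (illustrative only, not tinygrad code) of the two identities this relies on:

```python
# NumPy sketch of the identities that let NEG and NOT be removed:
#   -x      == 0 - x   (NEG as SUB from a const 0)
#   not(m)  == 1 - m   (NOT as SUB from a const 1, for 0/1 masks)
import numpy as np

x = np.array([1.0, -2.0, 3.0], dtype=np.float32)
mask = np.array([1.0, 0.0, 1.0], dtype=np.float32)  # a CMPEQ-style 0/1 mask

assert np.array_equal(np.float32(0) - x, -x)
assert np.array_equal(np.float32(1) - mask, np.logical_not(mask).astype(np.float32))
```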

View File

@@ -159,7 +159,7 @@ The autodiff stuff is all in mlops now so you can focus on the raw operations
 ```
 Buffer # class of memory on this device
-unary_op (NOOP, NEG, NOT, EXP, LOG) # A -> A
+unary_op (NOOP, EXP, LOG, CAST) # A -> A
 reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size)
 movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size)

View File

@@ -67,7 +67,7 @@ class TestOpt(unittest.TestCase):
     # TODO: this should be 4, but the sum output child stays around
     # with pushing_permutes it can be 3
     # TODO: broken with optim fixes
-    assert len(GlobalCounters.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
+    assert len(GlobalCounters.cache) in [4,5,6,7], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
     Tensor.training = False

   def test_fold_conv_batchnorm_sgd(self):
@@ -83,7 +83,7 @@ class TestOpt(unittest.TestCase):
     img_bn.backward()
     opt.step()
     # TODO: broken with optim fixes
-    assert len(GlobalCounters.cache) in [9,10,13], f"optimizer didn't fold conv-backward batchnorm, got {len(GlobalCounters.cache)}"
+    assert len(GlobalCounters.cache) in [9,10,13,14], f"optimizer didn't fold conv-backward batchnorm, got {len(GlobalCounters.cache)}"
     Tensor.training = False

   def test_fold_conv_batchnorm_notrain(self):

View File

@@ -43,7 +43,7 @@ class ATan2(Function):
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b))
     return grad_output.binary_op(BinaryOps.MUL, self.b.binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \
-           grad_output.binary_op(BinaryOps.MUL, self.a.unary_op(UnaryOps.NEG).binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None
+           grad_output.binary_op(BinaryOps.MUL, self.a.const_like(0).binary_op(BinaryOps.SUB, self.a).binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None

 # *** third, we use our lovely new mlop in some tests ***
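
A quick numerical sanity check (NumPy, not part of the commit) that the two gradients returned above are the standard atan2 derivatives, with the negation spelled as `0 - a` exactly as the new line does:

```python
import numpy as np

a, b, eps = 0.7, -1.3, 1e-6
denom = a*a + b*b
# central finite differences of atan2 w.r.t. each argument
num_da = (np.arctan2(a + eps, b) - np.arctan2(a - eps, b)) / (2*eps)
num_db = (np.arctan2(a, b + eps) - np.arctan2(a, b - eps)) / (2*eps)
assert abs(num_da - b/denom) < 1e-5        # grad for input 0
assert abs(num_db - (0 - a)/denom) < 1e-5  # NEG written as 0 - a, as in the diff
```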

View File

@@ -103,6 +103,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65), (65,)], lambda x,y: x+y, lambda x,y: x+y)
   def test_sub(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
+  def test_neg(self):
+    helper_test_op([(45,65)], lambda x: -x)
   def test_mul(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)
   def test_div(self):

View File

@@ -50,13 +50,14 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F
 class GPUCodegen(ASTKernel):
   lang: ClassVar[GPULanguage] = GPULanguage()
+  supports_constant_folding: bool = True

   # for renaming
   kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
   kernel_name_cache: Final[Dict[str, str]] = {}

   code_for_op: Final[Dict[Op, str]] = {
-    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "(1.0f-A)", UnaryOps.CAST: "(A)",
+    UnaryOps.NOOP: "(A)", UnaryOps.CAST: "(A)",
     UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
     UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
@@ -98,8 +99,6 @@ class GPUCodegen(ASTKernel):
     # constant folding
     const = None
     if self.bufs[buf_index] is not None and isinstance(self.bufs[buf_index].realized, RawConst):
-      # bufs_to_delete can be removed, just ignore RawConst at runtime
-      if buf_index != 0: self.bufs_to_delete.add(buf_index)
       val = self.bufs[buf_index].realized._buf
       assert not math.isnan(val)
       const = Token(f"({val}f)", Types.FLOAT)
@@ -276,7 +275,6 @@ class GPUCodegen(ASTKernel):
       print("output shape", self.output_shape)
       self.printbufs("new:", DEBUG>=5)

-    self.bufs_to_delete: Set[int] = set()
     self.loaded_keys: Dict[Tuple[int,int], Token] = {}
     self.prekernel: Set[str] = set()
     self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(buf.dtype.name.startswith("image") for buf in self.bufs if buf is not None) else []
@@ -337,9 +335,9 @@ class GPUCodegen(ASTKernel):
     self.kernel.append("\n}")

     # concat kernel into prg
-    buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix for i,x in enumerate(self.bufs) if x is not None]
+    buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix) for i,x in enumerate(self.bufs) if x is not None and not isinstance(x.realized, RawConst)]
     prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
-      [', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] +
+      [', '.join([f'{t} data{i}' for i,t in buftypes] + self.lang.extra_args)] +
       [") {\n"] + self.kernel)

     # kernel function definition
@@ -352,7 +350,7 @@ class GPUCodegen(ASTKernel):
     if GPUCodegen.kernel_cnt[function_name] > 1: function_name = f"{function_name}{'n'+str(GPUCodegen.kernel_cnt[function_name]-1)}"
     GPUCodegen.kernel_name_cache[prg] = function_name
-    return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
+    return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
       list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1],
       (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
       op_estimate=self.info.flops,
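
A toy sketch of what the codegen change amounts to, using hypothetical names (`ConstBuf`, `render`) rather than the real `GPUCodegen` internals: a buffer realized as `RawConst` is emitted as a literal in the kernel source and skipped in the argument list, so the `bufs_to_delete` bookkeeping is no longer needed:

```python
class ConstBuf:  # hypothetical stand-in for a buffer realized as RawConst
    def __init__(self, val): self.val = val

def render(bufs):
    args, loads = [], []
    for i, buf in enumerate(bufs):
        if isinstance(buf, ConstBuf):
            loads.append(f"({buf.val}f)")            # folded into the source
        else:
            args.append(f"__global float* data{i}")  # stays a kernel argument
            loads.append(f"data{i}[gid]")
    return args, loads

args, loads = render(["out", ConstBuf(2.0), "x"])
assert args == ["__global float* data0", "__global float* data2"]
assert loads == ["data0[gid]", "(2.0f)", "data2[gid]"]
```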

View File

@@ -22,10 +22,8 @@ render_llvm = {
 class LLVMCodegen(ASTKernel):
   op_lookup: ClassVar = {
     UnaryOps.NOOP: lambda builder,x: x,
-    UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
     UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
     UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
-    UnaryOps.NOT: lambda builder,x: builder.fsub(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
     BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
     BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
     BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),

View File

@@ -101,14 +101,17 @@ class LazyBuffer:
     for x in get_buffers(op): x.children.add(self)
     if not LAZY: self.realize()

-  def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else 'realized'} st:{self.st}>"
+  def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else self.realized} st:{self.st}>"

   def realize(self:LazyBuffer, required_device=None) -> LazyBuffer:
     assert required_device is None or required_device == self.device
     if self.realized is None:
       # get real ops first
       if self.op.op == LoadOps.FROMCPU:
-        self.realized = Device[self.device].buffer.fromCPU(self.op.arg())
+        if prod(self.op.arg.shape) == 1 and hasattr(Device[self.device].codegen, 'supports_constant_folding'):
+          self.realized = RawConst(1, dtypes.from_np(self.op.arg.dtype), self.op.arg().flatten()[0])
+        else:
+          self.realized = Device[self.device].buffer.fromCPU(self.op.arg())
       elif self.op.op == LoadOps.CONTIGUOUS:
         realized = self.op.src[0].realize(self.device).realized
         if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
@@ -155,6 +158,11 @@ class LazyBuffer:
   def fromCPU(x:LazyNumpyArray, device) -> LazyBuffer:
     return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()), dtypes.from_np(x.dtype))

+  # create a constant with the shape and dtype of self
+  def const_like(self, val) -> LazyBuffer:
+    return LazyBuffer(self.device, (1,), LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), LazyNumpyArray([val], (1,), self.dtype.np)), self.dtype) \
+      .movement_op(MovementOps.RESHAPE, (1,)*len(self.shape)).movement_op(MovementOps.EXPAND, self.shape)

   # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
     realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized
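
A minimal mimic (hypothetical `realize_fromcpu`, not the real method) of the branch added to `realize()`: a one-element `FROMCPU` on a backend whose codegen advertises `supports_constant_folding` becomes a `RawConst` instead of a device allocation:

```python
import numpy as np
from math import prod

class FakeCodegen: supports_constant_folding = True  # mirrors GPUCodegen's new flag
class NoFoldCodegen: pass                            # a backend without folding

def realize_fromcpu(arr: np.ndarray, codegen) -> str:
    # same test as the diff: only true scalars become RawConst
    if prod(arr.shape) == 1 and hasattr(codegen, 'supports_constant_folding'):
        return f"RawConst({arr.flatten()[0]})"
    return f"RawBuffer(size={arr.size})"

assert realize_fromcpu(np.array([2.0]), FakeCodegen) == "RawConst(2.0)"
assert realize_fromcpu(np.array([2.0]), NoFoldCodegen) == "RawBuffer(size=1)"
assert realize_fromcpu(np.ones((4, 4)), FakeCodegen) == "RawBuffer(size=16)"
```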

View File

@@ -73,7 +73,7 @@ class Maximum(Function):
   def backward(self, grad_output):
     mask = self.y.binary_op(BinaryOps.CMPEQ, self.ret)
     # TODO: if they are equal, do they split the gradient?
-    return grad_output.binary_op(BinaryOps.MUL, mask.unary_op(UnaryOps.NOT)) if self.needs_input_grad[0] else None, \
+    return grad_output.binary_op(BinaryOps.MUL, mask.const_like(1).binary_op(BinaryOps.SUB, mask)) if self.needs_input_grad[0] else None, \
            grad_output.binary_op(BinaryOps.MUL, mask) if self.needs_input_grad[1] else None

 class Add(Function):
@@ -85,28 +85,28 @@ class Add(Function):
            grad_output if self.needs_input_grad[1] else None

 class Sub(Function):
-  def forward(self, x, y):
+  def forward(self, x:LazyBuffer, y:LazyBuffer):
     return x.binary_op(BinaryOps.SUB, y)

-  def backward(self, grad_output) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
-           grad_output.unary_op(UnaryOps.NEG) if self.needs_input_grad[1] else None
+           grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None

 class Mul(Function):
-  def forward(self, x, y):
+  def forward(self, x:LazyBuffer, y:LazyBuffer):
     self.x, self.y = x, y
     return x.binary_op(BinaryOps.MUL, y)

-  def backward(self, grad_output) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
            self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None

 class Pow(Function):
-  def forward(self, x, y):
+  def forward(self, x:LazyBuffer, y:LazyBuffer):
     self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y)
     return self.ret

-  def backward(self, grad_output):
+  def backward(self, grad_output:LazyBuffer):
     return grad_output.binary_op(BinaryOps.MUL, self.y.binary_op(BinaryOps.MUL, self.ret.binary_op(BinaryOps.DIV, self.x))) if self.needs_input_grad[0] else None, \
            grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None
@@ -117,7 +117,7 @@ class Div(Function):
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output.binary_op(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
-           grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
+           grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None

 # ************* movement ops *************
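
For the `Maximum` change above, a NumPy check (not tinygrad code) that `const_like(1)` SUB `mask` reproduces the old `NOT` masking, including the tie behavior the TODO mentions:

```python
import numpy as np

x = np.array([1.0, 5.0, 3.0]); y = np.array([2.0, 4.0, 3.0])
ret = np.maximum(x, y)
mask = (y == ret).astype(np.float32)  # BinaryOps.CMPEQ
grad = np.ones_like(x, dtype=np.float32)
grad_x = grad * (1.0 - mask)          # const_like(1) SUB mask, the old NOT
grad_y = grad * mask
assert np.array_equal(grad_x, [0.0, 1.0, 0.0])
assert np.array_equal(grad_y, [1.0, 0.0, 1.0])
# note the tie at index 2: the full gradient goes to y, matching the TODO
```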

View File

@@ -1,14 +1,14 @@
 from __future__ import annotations
 import functools, itertools, operator, random
 from enum import Enum, auto
-from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Set, Callable
+from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable
 from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType
 from tinygrad.shape.shapetracker import MovementOps
-from tinygrad.runtime.lib import RawBuffer
+from tinygrad.runtime.lib import RawBuffer, RawConst

 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
-class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); NEG = auto(); NOT = auto(); CAST = auto() # noqa: E702
+class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class FusedOps(Enum): MULACC = auto() # noqa: E702
@@ -39,6 +39,7 @@ class Interpreted:
     self.fxn_for_op = fxn_for_op
     self.from_lazybuffer = from_lazybuffer
     self.to_underlying = to_underlying
+    self.codegen = None

   def exec_ast(self, ast:LazyOp, output=None, context=None):
     if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
@@ -48,7 +49,7 @@ class Interpreted:
     if not created_context and ast in context: return context[ast]
     srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
     ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
-    if DEBUG >= 4: print(f"*** {'exec' if created_context else '    '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret.shape):30s} in({len(srcs)}):", list(set(x.shape for x in srcs)), ast.arg if ast.arg is not None else "")
+    if DEBUG >= 4: print(f"*** {'exec' if created_context else '    '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "")
     if not created_context: context[ast] = ret
     if output is not None and output.output_buffer is not None:
       assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
@@ -58,7 +59,7 @@ class Interpreted:
     return ret

 class FlopCounter:
-  def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops = tup
+  def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops, self._buf = *tup, self
   def consume_flops(self):
     self.flops, ret = 0, self.flops
     return ret
@@ -75,22 +76,21 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
 # **************** for Compiled Buffers ****************

 class ASTRunner:
-  def __init__(self, name, prg, bufs_to_delete:Optional[Set[int]]=None, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0):
+  def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0):
     if DEBUG >= 4: print(prg)
-    self.name, self.prg, self.global_size, self.local_size, self.bufs_to_delete, self.op_estimate, self.mem_estimate = name, prg, global_size, local_size, bufs_to_delete if bufs_to_delete else set(), op_estimate, mem_estimate
+    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate = name, prg, global_size, local_size, op_estimate, mem_estimate

   def build(self, runtime):
     self.clprg = runtime(self.name, self.prg)
     return self

   def exec(self, bufs) -> Optional[float]:
-    rawbufs = [x.realized for i,x in enumerate(bufs) if x is not None and i not in self.bufs_to_delete]
-    assert all(x is not None for x in rawbufs), "some rawbufs are None, you probably didn't realize them"
-    if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs)
+    rawbufs = [x.realized for x in bufs if x is not None and not isinstance(x.realized, RawConst)]
     if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs))
     return self(rawbufs)

   def __call__(self, rawbufs:List[RawBuffer]) -> Optional[float]:
+    if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs)
     if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=DEBUG>=2): GlobalCounters.time_sum_s += et
     if DEBUG >= 2:
       print(f"*** {GlobalCounters.kernel_count:4d} {self.name:20s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
@@ -143,6 +143,7 @@ class Compiled:
       # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
       output.realized = output.output_buffer
     if output.realized is not None:
+      if isinstance(output.realized, RawConst): output.realized = None # can't assign to RawConst
       for a in get_buffers(ast):
         if a.realized == output.realized and not a.st.contiguous:
           output.realized = None
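
A stripped-down sketch of the new `exec`/`__call__` split, with stand-in classes (`Buf` and this `RawConst` are toys): `exec` drops `RawConst`-realized buffers before recording the call in the JIT cache, which is why the custom-function JIT needed to go through `exec` (per the commit message) rather than calling the program directly:

```python
from typing import List, Optional

class RawConst: pass                 # stand-in for tinygrad's RawConst
class Buf:
    def __init__(self, realized): self.realized = realized

cache: Optional[list] = []           # stand-in for GlobalCounters.cache

def exec_(runner, bufs: List[Optional[Buf]]):
    # consts never reach the device: they were inlined as literals at codegen
    rawbufs = [x.realized for x in bufs if x is not None and not isinstance(x.realized, RawConst)]
    if cache is not None: cache.append((runner, rawbufs))  # what the JIT replays
    return rawbufs

out = exec_("kernel", [Buf("out"), Buf(RawConst()), Buf("x"), None])
assert out == ["out", "x"] and len(cache) == 1
```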

View File

@@ -12,6 +12,7 @@ class RawBuffer: # pylint: disable=abstract-method
     self._memsz: int = size*dtype.itemsize
     GlobalCounters.mem_used += self._memsz
   def __del__(self): GlobalCounters.mem_used -= self._memsz
+  def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"

   # NOTE: this interface allows for 0 copy
   @classmethod
@@ -45,4 +46,5 @@ class RawBufferCopyInOut(RawBufferCopyIn):
     self._copyout(x)
     return x

-class RawConst(RawBuffer): pass # pylint: disable=abstract-method
+class RawConst(RawBuffer): # pylint: disable=abstract-method
+  def __repr__(self): return f"const<{self._buf}, {self.dtype}>"
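
A toy illustration (stand-in classes, not tinygrad's; dtype shown as a plain string) of what the two new `__repr__`s print, which is what makes folded constants visible through the updated `LazyBuffer.__repr__` above:

```python
class FakeRawBuffer:
    def __init__(self, size, dtype): self.size, self.dtype = size, dtype
    def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"

class FakeRawConst(FakeRawBuffer):
    def __init__(self, val, dtype): self._buf, self.dtype = val, dtype
    def __repr__(self): return f"const<{self._buf}, {self.dtype}>"

print(FakeRawBuffer(16, "dtypes.float32"))  # buffer<16, dtypes.float32>
print(FakeRawConst(2.0, "dtypes.float32"))  # const<2.0, dtypes.float32>
```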

View File

@@ -10,7 +10,6 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
   return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)

 base_fxn_for_op: Dict[Op, Callable] = {
-  UnaryOps.NEG: lambda x: -x, UnaryOps.NOT: lambda x: (1.0 - x),
   BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],

View File

@@ -189,7 +189,7 @@ class Tensor:
                for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
       for t, g in zip(t0._ctx.parents, grads):
         if g is not None and t.requires_grad:
-          assert g.shape == t.shape, f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
+          assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
           t.grad = g if t.grad is None else (t.grad + g)
       del t0._ctx