delete extra/ops.py [pr] (#7072)

qazal 2024-10-15 22:14:21 +03:00 committed by GitHub
parent 09de958855
commit 8ff6514ba3
2 changed files with 1 addition and 137 deletions

extra/ops.py

@@ -1,135 +0,0 @@
from __future__ import annotations
from typing import Dict, Union, Tuple, Any, List, cast
import functools, hashlib
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import dedup, prod
from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps, pretty_print
from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
from tinygrad.ops import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
# these ops are deleted after AST is UOp
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
class MetaOps(Enum): KERNEL = auto()
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, BufferOps]
@dataclass(frozen=True)
class MemBuffer:
  idx: int
  dtype: DType
  st: ShapeTracker
@dataclass(frozen=True)
class ConstBuffer:
  val: ConstType | Variable
  dtype: DType
  st: ShapeTracker
@dataclass(frozen=True, eq=False)
class LazyOp:
  op: Op
  src: Tuple[LazyOp, ...] = ()
  arg: Any = None
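  # equality is structural: op/arg are compared directly and src recursively,
  # memoized on (id(self), id(x)) so shared subtrees of the DAG are walked once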
  def cached_compare(self, x, context):
    if id(self) == id(x): return True
    if self.op != x.op or self.arg != x.arg or len(self.src) != len(x.src): return False
    if (key := (id(self), id(x))) in context: return context[key]
    ret = context[key] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
    return ret
  def __eq__(self, x): return self.cached_compare(x, context={})
  def __repr__(self:LazyOp): return pretty_print(self, lambda x: f'LazyOp({x.op}, arg={x.arg}, src=(%s))')
  @functools.cached_property
  def dtype(self) -> DType:
    if self.op in BufferOps: return self.arg.dtype
    if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
    return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
  @functools.cached_property
  def full_shape(self) -> Tuple[sint, ...]:
    if len(self.src) == 0 and self.op in BufferOps: return self.arg.st.shape
    return tuple(max(x) for x in zip(*[x.full_shape for x in self.src]))
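  # key is a stable content digest (sha256 over op/arg plus all source keys);
  # hash/__hash__ is the cheaper per-process hash of the same fields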
  @functools.cached_property
  def key(self) -> bytes:
    return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
  @functools.cached_property
  def hash(self): return hash((self.op, self.src, self.arg))
  def __hash__(self): return self.hash
  @functools.cached_property
  def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
  def vars(self) -> List[Variable]:
    extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
    const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
    return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
  def __add__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, x))
  def __sub__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, -x))
  def __mul__(self, x:LazyOp): return LazyOp(BinaryOps.MUL, (self, x))
  def ne(self, x:LazyOp): return LazyOp(BinaryOps.CMPNE, (self, x))
  def eq(self, x:LazyOp): return -self.ne(x)
  def __neg__(self): return LazyOp(UnaryOps.NEG, (self,))
  @staticmethod
  def const(val, dtype:DType, shape:Tuple[sint, ...]):
    return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)))
# the living definition of LazyOps
def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
  assert ast.op is MetaOps.KERNEL, "must be SINK"
  sts: Dict[LazyOp, ShapeTracker] = {}
  def assert_valid(op:LazyOp, st:ShapeTracker):
    if op in sts: return
    # restore globals from the two stage reduce
    if op.op is BufferOps.LOAD and op.arg.idx < 0:
      assert_valid(local_reduce:=op.src[0].src[0], op.arg.st)
      return sts.setdefault(op, sts[local_reduce])
    for x in op.src: assert_valid(x, st)
    # only reduceop is allowed to change shape, limited to turning n to 1
    if op.op in ReduceOps:
      axis = op.arg
      assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
      st = ShapeTracker.from_shape(sts[op.src[0]].reduce(axis))
    else:
      # movementops are pushed to the edges with LOAD
      # elementwise inherits shape
      st = op.arg.st if op.op in BufferOps else sts[op.src[0]]
      for x in op.src:
        if sts[x].shape != st.shape:
          if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
          raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
    sts[op] = st
  for i, out in enumerate(ast.src):
    assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
    assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
    assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
    assert_valid(out, out.arg.st)
  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
  return sts
def to_uop(*a) -> UOp:
  assert isinstance(a[0], LazyOp), f"{a} must be a LazyOp ast"
  if a[0].op is BufferOps.STORE: ast = LazyOp(MetaOps.KERNEL, a)
  else:
    assert a[0].op is MetaOps.KERNEL
    ast = a[0]
  verify_lazyop(ast)
  @functools.lru_cache(None)
  def create_uop(lop:LazyOp) -> UOp:
    if lop.op in BufferOps:
      st_uop = lop.arg.st.to_uop()
      membuf_dtype: DType = lop.arg.dtype
      dtype = membuf_dtype.base if isinstance(membuf_dtype, ImageDType) else membuf_dtype
      if lop.op is BufferOps.CONST:
        return UOp(UOps.CONST, dtype, (st_uop,), lop.arg.val)
      buf = UOp(UOps.DEFINE_GLOBAL, membuf_dtype if isinstance(membuf_dtype, ImageDType) else PtrDType(membuf_dtype), (), lop.arg.idx)
      if lop.op is BufferOps.LOAD: return UOp(UOps.LOAD, dtype, (buf, st_uop))
      return UOp(UOps.STORE, dtypes.void, (buf, st_uop, create_uop(lop.src[0])))
    src = tuple(create_uop(x) for x in lop.src)
    if lop.op is MetaOps.KERNEL: return UOp(UOps.SINK, dtypes.void, src)
    if lop.op in ReduceOps:
      alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, lop.op)]
      return UOp(UOps.REDUCE_AXIS, src[0].dtype, src, (alu_op, lop.arg))
    if lop.op is UnaryOps.CAST: return UOp(UOps.CAST, lop.arg.scalar(), src)
    if lop.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, lop.arg.scalar(), src)
    return src[0].alu(lop.op, *src[1:])
  ret = create_uop(ast)
  # with open("/tmp/ast", "w") as f: f.write(str(ret))
  return ret
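
For reference, a minimal usage sketch of the deleted API (not part of the diff), assuming a checkout of the parent commit 09de958855 where extra/ops.py still exists. It builds out[i] = a[i] + 1.0 as a LazyOp AST, with buffer idx 0 as the output per the verify_lazyop check above:

from extra.ops import LazyOp, BufferOps, MemBuffer, to_uop
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((4,))
# BufferOps nodes carry their ShapeTracker in arg; idx 1 is an input buffer
a = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float32, st))
one = LazyOp.const(1.0, dtypes.float32, (4,))
out = LazyOp(BufferOps.STORE, (a + one,), MemBuffer(0, dtypes.float32, st))
sink = to_uop(out)  # wraps in MetaOps.KERNEL, runs verify_lazyop, builds the UOp tree
print(sink)         # a UOp tree rooted at UOps.SINK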


@@ -1,9 +1,8 @@
 # stuff needed to unpack a kernel
 from typing import Tuple
-from extra.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
 from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
-from tinygrad.ops import UOp, UOps, KernelInfo
+from tinygrad.ops import UOp, UOps, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
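
For any other script still importing the deleted shim, the migration this commit implies is a one-line import change (a sketch, assuming the script only used the re-exported op enums and not LazyOp itself):

# before (pre-#7072): from extra.ops import TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
from tinygrad.ops import TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps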