mirror of https://github.com/commaai/tinygrad.git
Split UnaryOps.CAST into CAST and BITCAST (#4487)
* Separate cast and bitcast

* Fix lint

* No more arg[0]

* Revert "No more arg[0]"

  This reverts commit dee6911335513f092fe2cbb9684e8a9d26aad964.

* CAST/BITCAST arg is the dtype only, no more tuple

* No image bitcast, regenerate dataset

* Small fixes
parent 53d082a2aa
commit 662bca8134
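
In short: the IR previously encoded a bitcast as UnaryOps.CAST carrying a (dtype, bitcast) tuple arg; after this commit, CAST and BITCAST are distinct ops and arg is just the target dtype. A toy sketch of the shape of the change (illustrative stand-ins, not tinygrad's real classes):

    from enum import Enum, auto
    from dataclasses import dataclass
    from typing import Any, Tuple

    class UnaryOps(Enum):
        CAST = auto(); BITCAST = auto()

    @dataclass(frozen=True)
    class LazyOp:
        op: UnaryOps
        src: Tuple[Any, ...] = ()
        arg: Any = None

    # before the split, a bitcast was spelled:
    #   LazyOp(op=UnaryOps.CAST, src=(x,), arg=(some_dtype, True))
    # after the split it is:
    #   LazyOp(op=UnaryOps.BITCAST, src=(x,), arg=some_dtype)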
Binary file not shown.
@@ -812,7 +812,7 @@ class TestKernelOpts(unittest.TestCase):
     if Device.DEFAULT not in tensor_cores:
       self.skipTest("No tensor cores for device")
 
-    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
+    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
     k = Linearizer(ast, opts=Device[Device.DEFAULT].renderer)
     with self.assertRaises(KernelOptError):
       k.apply_opt(Opt(OptOps.TC, 0, 1))
File diff suppressed because one or more lines are too long
@@ -343,7 +343,7 @@ class Kernel:
     if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op is ReduceOps.SUM and self.opts.device in tensor_cores:
       for tc in tensor_cores[self.opts.device]:
         has_cast = tc.dtype_in != tc.dtype_out
-        if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
+        if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue
 
         mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
         if mul_op.op is not BinaryOps.MUL: continue
@@ -352,7 +352,7 @@ class Kernel:
           # TODO: apply tc even if the sources are not from LOAD
           if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
           try:
-            if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+            if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
           except ValueError: return None
           return None
         if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
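
The two kernel.py hunks above are the same mechanical fix: with the tuple arg gone, the dtype is the arg itself, so arg[0] becomes arg. A minimal sketch of the difference, using a stand-in DType rather than tinygrad's:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class DType:            # stand-in for tinygrad's DType
        itemsize: int
        name: str

    float32 = DType(4, "float")

    old_arg = (float32, False)  # before: (target dtype, is_bitcast) tuple
    new_arg = float32           # after: the target dtype itself

    assert old_arg[0] == float32  # old code indexed into the tuple
    assert new_arg == float32     # new code compares the arg directly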
@@ -423,8 +423,8 @@ class Linearizer(Kernel):
     if cache is None: cache = {}
     if x in cache: return cache[x]
     if x.op in BufferOps: return loaded_buffers[x.arg]
-    if x.op is UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \
-      for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
+    if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return [self.uops.add(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST, \
+      self.get_base_dtype(x.arg), (u,), x.arg) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
     if x.op in ReduceOps and not do_reduce:
       assert offs is None, "not available if we aren't doing reduce"
       return acc
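
In the linearizer, the bitcast decision now comes from the op identity rather than from a flag tucked inside arg. A toy sketch of that dispatch (names mirror the diff, but these are illustrative enums, not tinygrad's real UOps):

    from enum import Enum, auto

    class UnaryOps(Enum):   # illustrative stand-in
        CAST = auto(); BITCAST = auto()

    class UOps(Enum):       # illustrative stand-in
        CAST = auto(); BITCAST = auto()

    def lower(op: UnaryOps) -> UOps:
        # before: UOps.BITCAST if x.arg[1] else UOps.CAST (flag in the arg)
        # after:  the op itself decides, and arg stays a plain dtype
        return UOps.BITCAST if op is UnaryOps.BITCAST else UOps.CAST

    assert lower(UnaryOps.BITCAST) is UOps.BITCAST
    assert lower(UnaryOps.CAST) is UOps.CAST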
@@ -132,7 +132,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
     if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
         prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
       simple_pads.add(buf.base)
-    elif buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg[0], ImageDType):
+    elif buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
       pass # don't realize image to image casts. this is part of a larger problem
     else:
       realizes[buf.base] = None
@@ -235,7 +235,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
       if not st.contiguous or tr_next.op in ReduceOps: break
       tr = tr_next
     # don't cast to higher size before store (tr cannot be realized if forced_realize)
-    if tr.op is UnaryOps.CAST and tr.arg[0].itemsize > tr.srcs[0].dtype.itemsize:
+    if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
       tr = tr.srcs[0].base
     reduce_for_op[tr] = r
     realizes[tr] = None
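
The itemsize guard above keeps the scheduler from realizing a widened buffer right before a store: when the cast grows each element, it steps back to the cast's source. A worked example of that comparison, using numpy dtypes as stand-ins (assumes numpy is installed):

    import numpy as np

    src = np.dtype(np.float16)  # 2 bytes per element
    dst = np.dtype(np.float32)  # 4 bytes per element

    # a float16 -> float32 cast grows every element, so the scheduler
    # realizes the cast's source rather than the wider buffer
    assert dst.itemsize > src.itemsize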
@@ -15,7 +15,7 @@ from tinygrad.engine.realize import CompiledRunner
 from tinygrad.renderer import Program
 
 actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
-actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(4)]
+actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
 actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(5)]
 actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
 actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
@@ -34,7 +34,7 @@ class LazyBuffer:
     self.op, self.arg, self.srcs = op, arg, srcs  # this is a LazyOp, except the src is LazyBuffers and not LazyOps
     assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
 
-    if (self.op is LoadOps.CONTIGUOUS or (self.op is UnaryOps.CAST and self.arg[1] is True)) and srcs[0].st.consecutive and \
+    if (self.op is LoadOps.CONTIGUOUS or self.op is UnaryOps.BITCAST) and srcs[0].st.consecutive and \
       not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
       # some LazyBuffers can be processed with only a view, no AST required
       self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
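
This is why BITCAST earns the view fast path: a bitcast of a consecutive buffer changes no bytes, so it can be served by a reinterpreting view instead of a kernel. The same idea in numpy terms (an analogy, not tinygrad code; assumes numpy is installed):

    import numpy as np

    buf = np.arange(4, dtype=np.float32)
    view = buf.view(np.int32)    # same bytes, reinterpreted dtype
    assert view.base is buf      # no copy was made
    assert view.nbytes == buf.nbytes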
@@ -106,7 +106,8 @@ class LazyBuffer:
     # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
     if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
     new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
-    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
+    cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
 
   def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
   def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
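
The shape arithmetic above follows torch.Tensor.view semantics: the last axis is rescaled by the ratio of itemsizes, and the byte count of that axis must divide evenly into the new itemsize. A worked example with numpy as a stand-in (assumes numpy is installed):

    import numpy as np

    old = np.zeros((3, 4), dtype=np.float32)   # last axis holds 4 * 4 = 16 bytes
    new_dtype = np.dtype(np.int16)
    assert (4 * old.dtype.itemsize) % new_dtype.itemsize == 0  # the guard above
    new_last = (4 * old.dtype.itemsize) // new_dtype.itemsize  # 16 // 2 = 8
    assert old.view(np.int16).shape == (3, new_last)           # (3, 8)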
@@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
 # NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
 class UnaryOps(Enum):
   """A -> A (elementwise)"""
-  EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
+  EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
 class BinaryOps(Enum):
   """A + A -> A (elementwise)"""
   ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
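
For intuition on why these deserve separate ops: a cast converts the value, a bitcast reinterprets the bytes. A plain-Python illustration using the standard struct module:

    import struct

    x = 1.0
    cast_result = int(x)  # value conversion: 1
    bitcast_result = struct.unpack("<I", struct.pack("<f", x))[0]  # bit reinterpretation
    assert cast_result == 1
    assert bitcast_result == 0x3F800000  # the IEEE-754 bit pattern of 1.0f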
@@ -60,7 +60,7 @@ class LazyOp:
   @functools.cached_property
   def dtype(self) -> DType:
     if self.op in BufferOps: return self.arg.dtype
-    if self.op is UnaryOps.CAST: return self.arg[0]
+    if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
     return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else self.src[-1].dtype
 
   @functools.cached_property
@@ -94,7 +94,8 @@ InterpretedFlopCounter: Dict[Op, Callable] = {
   BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
   BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
   UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
-  **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op is not UnaryOps.CAST},
+  UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
+  **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
   **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
   **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
   TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
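
Both CAST and BITCAST move no data and do no arithmetic, so both are zero-flop in the counter, while every other unary op costs one flop per element. A toy version of that dispatch (illustrative, not tinygrad's FlopCounter):

    from enum import Enum, auto
    from math import prod

    class UnaryOps(Enum):  # illustrative stand-in
        EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto()

    ZERO_FLOP = {UnaryOps.CAST, UnaryOps.BITCAST}

    def unary_flops(op: UnaryOps, shape: tuple) -> int:
        # casts and bitcasts are free; other unary ops cost one flop per element
        return 0 if op in ZERO_FLOP else prod(shape)

    assert unary_flops(UnaryOps.BITCAST, (1024, 1024)) == 0
    assert unary_flops(UnaryOps.SIN, (1024, 1024)) == 1024 * 1024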