Split UnaryOps.CAST into CAST and BITCAST (#4487)

* Separate cast and bitcast

* Fix lint

* No more arg[0]

* Revert "No more arg[0]"

This reverts commit dee6911335513f092fe2cbb9684e8a9d26aad964.

* CAST/BITCAST arg is the dtype only, no more tuple

* No image bitcast, regenerate dataset

* Small fixes

Author: Ahmed Harmouche, committed 2024-05-15 17:43:31 +02:00 by GitHub
Parent: 53d082a2aa, commit: 662bca8134
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
9 changed files with 20 additions and 18 deletions
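
The shape of the change, as the hunks below show: a cast LazyOp no longer carries a (target_dtype, is_bitcast) tuple as its arg; the flag became its own op, and the arg is the dtype alone. A minimal before/after sketch, with x standing in for any source op:

# before: one op, the second element of the arg tuple selected bitcast behavior
LazyOp(op=UnaryOps.CAST, src=(x,), arg=(dtypes.float, False))  # value-converting cast
LazyOp(op=UnaryOps.CAST, src=(x,), arg=(dtypes.float, True))   # bit reinterpretation

# after: two ops, the arg is just the target dtype
LazyOp(op=UnaryOps.CAST, src=(x,), arg=dtypes.float)
LazyOp(op=UnaryOps.BITCAST, src=(x,), arg=dtypes.float)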

Binary file not shown.


@@ -812,7 +812,7 @@ class TestKernelOpts(unittest.TestCase):
     if Device.DEFAULT not in tensor_cores:
       self.skipTest("No tensor cores for device")
-    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
+    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
     k = Linearizer(ast, opts=Device[Device.DEFAULT].renderer)
     with self.assertRaises(KernelOptError):
       k.apply_opt(Opt(OptOps.TC, 0, 1))

File diff suppressed because one or more lines are too long


@@ -343,7 +343,7 @@ class Kernel:
     if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op is ReduceOps.SUM and self.opts.device in tensor_cores:
       for tc in tensor_cores[self.opts.device]:
         has_cast = tc.dtype_in != tc.dtype_out
-        if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
+        if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue
         mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
         if mul_op.op is not BinaryOps.MUL: continue
@@ -352,7 +352,7 @@ class Kernel:
           # TODO: apply tc even if the sources are not from LOAD
           if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
           try:
-            if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+            if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
           except ValueError: return None
           return None
         if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue


@@ -423,8 +423,8 @@ class Linearizer(Kernel):
     if cache is None: cache = {}
     if x in cache: return cache[x]
     if x.op in BufferOps: return loaded_buffers[x.arg]
-    if x.op is UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \
-      for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
+    if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return [self.uops.add(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST, \
+      self.get_base_dtype(x.arg), (u,), x.arg) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
     if x.op in ReduceOps and not do_reduce:
       assert offs is None, "not available if we aren't doing reduce"
       return acc
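
The lowering decision now keys on the op itself instead of a flag buried in the arg tuple. A minimal sketch of just the dispatch rule, with import paths assumed for this commit's tree:

from tinygrad.ops import UnaryOps
from tinygrad.codegen.uops import UOps

def uop_kind_for(op: UnaryOps) -> UOps:
  # the UOp kind follows the LazyOp kind; there is no arg[1] flag to consult
  return UOps.BITCAST if op is UnaryOps.BITCAST else UOps.CAST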


@@ -132,7 +132,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
     if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
        prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
       simple_pads.add(buf.base)
-    elif buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg[0], ImageDType):
+    elif buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
       pass # don't realize image to image casts. this is part of a larger problem
     else:
       realizes[buf.base] = None
@@ -235,7 +235,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
         if not st.contiguous or tr_next.op in ReduceOps: break
         tr = tr_next
       # don't cast to higher size before store (tr cannot be realized if forced_realize)
-      if tr.op is UnaryOps.CAST and tr.arg[0].itemsize > tr.srcs[0].dtype.itemsize:
+      if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
         tr = tr.srcs[0].base
       reduce_for_op[tr] = r
     realizes[tr] = None


@@ -15,7 +15,7 @@ from tinygrad.engine.realize import CompiledRunner
 from tinygrad.renderer import Program
 actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
-actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(4)]
+actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
 actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(5)]
 actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
 actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]


@@ -34,7 +34,7 @@ class LazyBuffer:
     self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
     assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
-    if (self.op is LoadOps.CONTIGUOUS or (self.op is UnaryOps.CAST and self.arg[1] is True)) and srcs[0].st.consecutive and \
+    if (self.op is LoadOps.CONTIGUOUS or self.op is UnaryOps.BITCAST) and srcs[0].st.consecutive and \
       not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
       # some LazyBuffers can be processed with only a view, no AST required
       self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
@@ -106,7 +106,8 @@ class LazyBuffer:
     # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
     if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
     new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
-    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
+    cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
   def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
   def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
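
From the Tensor API, the two ops are reached through cast and bitcast. A small usage sketch (results in comments assume a float32 input):

from tinygrad import Tensor, dtypes

t = Tensor([1.0, 2.0])       # dtypes.float
a = t.cast(dtypes.int32)     # UnaryOps.CAST: converts values, a.tolist() == [1, 2]
b = t.bitcast(dtypes.int32)  # UnaryOps.BITCAST: reinterprets bytes, b.tolist() == [1065353216, 1073741824]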


@@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
 # NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
 class UnaryOps(Enum):
   """A -> A (elementwise)"""
-  EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
+  EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
 class BinaryOps(Enum):
   """A + A -> A (elementwise)"""
   ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
@@ -60,7 +60,7 @@ class LazyOp:
   @functools.cached_property
   def dtype(self) -> DType:
     if self.op in BufferOps: return self.arg.dtype
-    if self.op is UnaryOps.CAST: return self.arg[0]
+    if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
     return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else self.src[-1].dtype
   @functools.cached_property
@@ -94,7 +94,8 @@ InterpretedFlopCounter: Dict[Op, Callable] = {
   BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
   BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
   UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
-  **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op is not UnaryOps.CAST},
+  UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
+  **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
   **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
   **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
   TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
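
Both cast entries pass the input's flop count through untouched, while every other unary op charges one operation per element. A toy restatement of that rule in plain Python (not tinygrad's API):

from math import prod

def unary_flops(op: str, shape: tuple) -> int:
  # CAST and BITCAST are free; other elementwise unaries cost prod(shape) flops
  return 0 if op in {"CAST", "BITCAST"} else prod(shape)

assert unary_flops("SIN", (4, 4)) == 16
assert unary_flops("BITCAST", (4, 4)) == 0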