mirror of https://github.com/commaai/tinygrad.git
spec for in order LOAD/STORE indexing (#6073)
* test_unaligns_idxs
* spec for in order LOAD/STORE indexing
* test UOps.SPECIAL
* check for supports_float4
This commit is contained in:
parent
5048f9a4d5
commit
83a2543c74
|
@ -5,6 +5,7 @@
|
|||
import unittest
|
||||
from tinygrad import Device, dtypes
|
||||
from tinygrad.codegen.uops import UOps
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, TernaryOps, BufferOps, MemBuffer, ConstBuffer, MetaOps # noqa: F401 # pylint: disable=unused-import
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.engine.search import Opt, OptOps
|
||||
|
@ -100,5 +101,71 @@ class TestLinearizerDumb(unittest.TestCase):
|
|||
prg = k.to_program()
|
||||
print(prg.src)
|
||||
|
||||
# from process replay https://github.com/tinygrad/tinygrad/actions/runs/10389229290/job/28766762085#step:18:6490
|
||||
@unittest.expectedFailure
def test_unaligns_idxs(self):
  """Check that the first two matching LOAD indices come out in ascending order.

  Builds a kernel from a captured AST, applies an UNROLL and a LOCAL opt,
  prints the generated source, then compares the index operands (`src[1]`)
  of the first two LOAD uops whose `src[0].arg` is 3.
  Marked as an expected failure.
  """
  ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
    LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))), src=(
      LazyOp(ReduceOps.SUM, arg=(2,), src=(
        LazyOp(BinaryOps.MUL, arg=None, src=(
          LazyOp(UnaryOps.CAST, arg=dtypes.float, src=(
            LazyOp(BinaryOps.CMPNE, arg=None, src=(
              LazyOp(BinaryOps.CMPNE, arg=None, src=(
                LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.long, st=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),
                LazyOp(UnaryOps.CAST, arg=dtypes.long, src=(
                  LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),
              LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),
          LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),))
  kernel = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
  for applied in (Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)):
    kernel.apply_opt(applied)
  program = kernel.to_program()
  print(program.src)
  # presumably src[0].arg == 3 selects loads from the buffer declared with idx=3 above -- TODO confirm
  matching_loads = [u for u in kernel.uops if u.op is UOps.LOAD and u.src[0].arg == 3]
  load_idxs = [u.src[1] for u in matching_loads]
  assert load_idxs[0] < load_idxs[1], f"first loaded idx {load_idxs[0].arg} then {load_idxs[1].arg}!"
|
||||
|
||||
@unittest.expectedFailure
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4")
def test_unrolled_float4_align(self):
  """Check that the first two matching LOAD indices come out in ascending order after UNROLL.

  Builds a kernel from a captured AST, applies a single UNROLL opt, prints
  the generated source, then compares the index operands (`src[1]`) of the
  first two LOAD uops whose `src[0].arg` is 2.
  Skipped unless the default renderer reports `supports_float4`; marked as
  an expected failure.
  """
  ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
    LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),))), src=(
      LazyOp(ReduceOps.SUM, arg=(0, 1), src=(
        LazyOp(TernaryOps.WHERE, arg=None, src=(
          LazyOp(BinaryOps.CMPNE, arg=None, src=(
            LazyOp(BinaryOps.CMPNE, arg=None, src=(
              LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.long, st=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),))), src=()),
              LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.long, st=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
            LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
          LazyOp(BufferOps.CONST, arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),
          LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),))), src=()),)),)),)),))
  kernel = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
  for applied in (Opt(op=OptOps.UNROLL, axis=0, amt=0),):
    kernel.apply_opt(applied)
  program = kernel.to_program()
  print(program.src)
  # presumably src[0].arg == 2 selects loads from the buffer declared with idx=2 above -- TODO confirm
  matching_loads = [u for u in kernel.uops if u.op is UOps.LOAD and u.src[0].arg == 2]
  load_idxs = [u.src[1] for u in matching_loads]
  assert load_idxs[0] < load_idxs[1], f"first loaded idx {load_idxs[0].arg} then {load_idxs[1].arg}!"
|
||||
|
||||
@unittest.expectedFailure
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4")
@unittest.skipIf(getenv("PTX"), "this is somehow correct in PTX")
def test_upcasted_stores_out_of_order(self):
  """Check that consecutive STORE indices have strictly increasing bounds after two UPCASTs.

  Builds a kernel from a captured AST, applies UPCAST on axis 3 and axis 2,
  prints the generated source, then walks every adjacent pair of STORE uops
  and compares `vmin.arg + vmax.arg` of their index operands (`src[1]`).
  Skipped unless the default renderer reports `supports_float4`, and skipped
  when the PTX env var is set; marked as an expected failure.
  """
  ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
    LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),))), src=(
      LazyOp(ReduceOps.SUM, arg=(6,), src=(
        LazyOp(BinaryOps.MUL, arg=None, src=(
          LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),))), src=()),
          LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),))
  kernel = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
  for applied in (Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)):
    kernel.apply_opt(applied)
  program = kernel.to_program()
  print(program.src)
  store_idxs = [u.src[1] for u in kernel.uops if u.op is UOps.STORE]
  # walk adjacent (earlier, later) pairs; an empty or single-element list yields no iterations
  for earlier, later in zip(store_idxs, store_idxs[1:]):
    first_bounds = earlier.vmin.arg + earlier.vmax.arg
    next_bounds = later.vmin.arg + later.vmax.arg
    assert first_bounds < next_bounds, f"first stored (max) idx {first_bounds} then {next_bounds}!"
|
||||
|
||||
# Script entry point: run this module's tests with unittest's default runner.
if __name__ == "__main__":
  unittest.main()
|
||||
|
|
|
@ -375,5 +375,44 @@ class TestUOpStr(TestEqUOps):
|
|||
a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1")
|
||||
assert str(eval(str(a))) == str(a)
|
||||
|
||||
class TestIndexingOrdering(unittest.TestCase):
  """Ordering checks for STORE uops after linearizing a hand-built UOpGraph.

  Each test constructs STORE uops as (buffer, index, value), linearizes them
  with `skip_check=True`, and compares the index operands (`src[1]`) of the
  resulting STOREs in emission order.
  """
  # NOTE: these tests skip type_verify since they add dtype to STORE

  @unittest.expectedFailure
  def test_simple_order(self):
    """Two stores to one buffer, passed in reverse index order; expects idx 0 before idx 4."""
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
    st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
    st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
    # the sinks are deliberately given as [st1, st0]
    uops = UOpGraph([st1, st0]).linearize(skip_check=True)
    stores = [st for st in uops.uops if st.op is UOps.STORE]
    assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"

  @unittest.expectedFailure
  def test_ordering_multi_output(self):
    """Two buffers with two stores each; expects stores grouped per buffer and ascending by index."""
    buf0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
    buf1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1)
    st0_0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
    st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
    st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
    st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
    uops = UOpGraph([st0_0, st1_0, st0_1, st1_1]).linearize(skip_check=True)
    stores = [st for st in uops.uops if st.op is UOps.STORE]
    print("\n".join(map(str, stores)))
    # buf0 stores come first
    self.assertEqual(stores[0].src[0].arg, stores[1].src[0].arg)
    # buf1 stores come next
    self.assertEqual(stores[2].src[0].arg, stores[3].src[0].arg)
    # both stores are aligned based on idx
    assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
    # the failure message reports the same pair (2, 3) that the condition compares
    assert stores[2].src[1] < stores[3].src[1], f"stored at idx {stores[3].src[1].arg} AFTER {stores[2].src[1].arg}"

  def test_simple_order_with_special(self):
    """Same as the simple case, but the first store's index is `gidx0 + 0` built from a SPECIAL uop."""
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
    gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
    st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
    st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
    # the sinks are deliberately given as [st1, st0]
    uops = UOpGraph([st1, st0]).linearize(skip_check=True)
    stores = [st for st in uops.uops if st.op is UOps.STORE]
    assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
|
||||
|
||||
# Script entry point: run this module's tests with per-test (verbosity=2) output.
if __name__ == "__main__":
  unittest.main(verbosity=2)
|
||||
|
|
Loading…
Reference in New Issue