mirror of https://github.com/commaai/tinygrad.git
UOps.IF* to graph spec (#4894)
This commit is contained in:
parent
b9afb0d577
commit
1dde829e34
|
@ -2,7 +2,7 @@ from typing import Optional, Tuple, Any, List
|
|||
import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import CI, getenv
|
||||
from tinygrad.helpers import CI, DEBUG, getenv
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
|
||||
|
@ -222,6 +222,46 @@ class TestConstantFolding(unittest.TestCase):
|
|||
ji = lower_schedule_item(si[-1])
|
||||
assert any(uop.uop is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.uop for uop in ji.prg.p.uops]} does not contain bitcast"
|
||||
|
||||
class TestGatedStoreRewrite(unittest.TestCase):
  """Expectations for how a gated STORE is placed between UOps.IF and UOps.ENDIF.

  Both tests are skipped ("not yet implemented"); they record the expected
  contents of the IF branch once the rewrite exists.
  """

  def _uops_inside_if(self, uops):
    """Return the entries of `uops.uops` strictly between the first IF and the first ENDIF.

    Fails the test with a message (rather than erroring with a bare
    StopIteration) when either marker is absent, and checks that the ENDIF's
    first input is the IF that was found.
    """
    if_uop = next((u for u in uops if u.uop is UOps.IF), None)
    self.assertIsNotNone(if_uop, "no UOps.IF in the graph")
    endif = next((u for u in uops if u.uop is UOps.ENDIF), None)
    self.assertIsNotNone(endif, "no UOps.ENDIF in the graph")
    self.assertIs(endif.vin[0], if_uop)
    ordered = uops.uops
    return tuple(ordered[ordered.index(if_uop)+1:ordered.index(endif)])

  @unittest.skip("not yet implemented")
  def test_wrap_store_parents(self):
    # wraps all store parents in the valid branch
    uops = UOpGraph()
    gmem = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
    gidx0 = uops.add(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
    idx = gidx0 * UOp.const(dtypes.int, 2)
    value = uops.add(UOps.CONST, dtypes.float, (), 42.0)

    # the gate is BinaryOps.CMPLT applied to (gidx0, const 1); it is the 4th input of the STORE
    gate = uops.add(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
    uops.add(UOps.STORE, None, (gmem, idx, value, gate))
    if DEBUG >= 4:
      print(Device[Device.DEFAULT].renderer.render("test", uops))
    self.assertEqual(self._uops_inside_if(uops), (gmem, gidx0, idx, value))

  @unittest.skip("not yet implemented")
  def test_wrap_some_parents(self):
    # some parents are used outside the branch
    uops = UOpGraph()
    gmem0 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
    gmem1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
    gidx0 = uops.add(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
    idx = gidx0 * UOp.const(dtypes.int, 2)
    value0 = uops.add(UOps.CONST, dtypes.float, (), 42.0)
    value1 = uops.add(UOps.CONST, dtypes.float, (), 43.0)

    # only the first store is gated; the second store has no gate and shares idx with the first
    gate = uops.add(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
    uops.add(UOps.STORE, None, (gmem0, idx, value0, gate))
    uops.add(UOps.STORE, None, (gmem1, idx, value1))
    if DEBUG >= 4:
      print(Device[Device.DEFAULT].renderer.render("test", uops))
    self.assertEqual(self._uops_inside_if(uops), (gmem0, value0))
|
||||
|
||||
class TestLocalAccess(unittest.TestCase):
|
||||
# NOTE: this is failing on METAL CI, no idea why. Works locally.
|
||||
@unittest.skipIf(Device.DEFAULT == "METAL" and CI, "failing only in CI")
|
||||
|
|
Loading…
Reference in New Issue