UOps.IF* to graph spec (#4894)

This commit is contained in:
qazal 2024-06-09 19:00:12 +08:00 committed by GitHub
parent b9afb0d577
commit 1dde829e34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 41 additions and 1 deletions

View File

@ -2,7 +2,7 @@ from typing import Optional, Tuple, Any, List
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI, getenv
from tinygrad.helpers import CI, DEBUG, getenv
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
@ -222,6 +222,46 @@ class TestConstantFolding(unittest.TestCase):
ji = lower_schedule_item(si[-1])
assert any(uop.uop is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.uop for uop in ji.prg.p.uops]} does not contain bitcast"
class TestGatedStoreRewrite(unittest.TestCase):
@unittest.skip("not yet implemented")
def test_wrap_store_parents(self):
# wraps all store parents in the valid branch
uops = UOpGraph()
gmem = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gidx0 = uops.add(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
idx = gidx0 * UOp.const(dtypes.int, 2)
value = uops.add(UOps.CONST, dtypes.float, (), 42.0)
gate = uops.add(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
uops.add(UOps.STORE, None, (gmem, idx, value, gate))
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.uop is UOps.IF)
endif = next(u for u in uops if u.uop is UOps.ENDIF)
assert endif.vin[0] is if_uop
nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
assert nested_uops == (gmem, gidx0, idx, value)
@unittest.skip("not yet implemented")
def test_wrap_some_parents(self):
# some parents are used outside the branch
uops = UOpGraph()
gmem0 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gmem1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
gidx0 = uops.add(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
idx = gidx0 * UOp.const(dtypes.int, 2)
value0 = uops.add(UOps.CONST, dtypes.float, (), 42.0)
value1 = uops.add(UOps.CONST, dtypes.float, (), 43.0)
gate = uops.add(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
uops.add(UOps.STORE, None, (gmem0, idx, value0, gate))
uops.add(UOps.STORE, None, (gmem1, idx, value1))
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.uop is UOps.IF)
endif = next(u for u in uops if u.uop is UOps.ENDIF)
assert endif.vin[0] is if_uop
nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
assert nested_uops == (gmem0, value0)
class TestLocalAccess(unittest.TestCase):
# NOTE: this is failing on METAL CI, no idea why. Works locally.
@unittest.skipIf(Device.DEFAULT == "METAL" and CI, "failing only in CI")