move uops add logic to linearize (#4952)

* move logic to linearize

* idk how this should work

* empty
This commit is contained in:
George Hotz 2024-06-14 03:52:37 -07:00 committed by GitHub
parent 7e32b8c930
commit 63a8add2c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 46 additions and 52 deletions

View File

@ -31,7 +31,7 @@ class UOpsFuzzerRunner(CompiledRunner):
for i, path in enumerate(self.p.uops.fuzz_paths):
# setup prg
uops = UOpGraph()
uops = UOpGraph([])
uops._uops = list(path)
if DEBUG >= 6: uops.print()
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)

View File

@ -7,7 +7,7 @@ class TestDeviceSpeed(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dev = Device[Device.DEFAULT]
cls.empty = Device[Device.DEFAULT].renderer.render("test", UOpGraph())
cls.empty = Device[Device.DEFAULT].renderer.render("test", UOpGraph([]))
def test_empty_compile(self):
with Timing("compiler "):

View File

@ -9,7 +9,7 @@ class TestUOpGraph(unittest.TestCase):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
@ -21,7 +21,7 @@ class TestUOpGraph(unittest.TestCase):
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
@ -32,7 +32,7 @@ class TestUOpGraph(unittest.TestCase):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
@ -41,7 +41,7 @@ class TestUOpGraph(unittest.TestCase):
def test_const_cast(self):
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
out = UOp(UOps.CAST, dtypes.int, (bf,))
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
@ -55,7 +55,7 @@ class TestUOpGraph(unittest.TestCase):
x = UOp(UOps.GEP, dtypes.float, (cast, ), arg=0)
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = UOp(UOps.STORE, dtypes.float, (d0, idx, alu))
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
g = UOpGraph([out])
self.assertEqual(len([x for x in g.uops if x.uop is UOps.CAST]), 0)
def test_depth_2_const_fold(self):
@ -64,7 +64,7 @@ class TestUOpGraph(unittest.TestCase):
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
g = UOpGraph([out])
self.assertEqual(len(g.uops), 3)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.ALU)

View File

@ -301,7 +301,7 @@ class TestAssembly(unittest.TestCase):
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
uops = UOpGraph([UOp(UOps.SINK, None, (a1,a2))])
uops = UOpGraph([a1,a2])
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops.uops[-1].arg, BinaryOps.MUL)
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHL)
@ -313,7 +313,7 @@ class TestAssembly(unittest.TestCase):
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
uops = UOpGraph([UOp(UOps.SINK, None, (a1,a2))])
uops = UOpGraph([a1,a2])
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops.uops[-1].arg, BinaryOps.IDIV)
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHR)

View File

@ -63,7 +63,7 @@ class TestUOpsStats(unittest.TestCase):
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
uops = UOpGraph([UOp(UOps.SINK, None, (u5,))])
uops = UOpGraph([u5])
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
@ -72,7 +72,7 @@ class TestUOpsStats(unittest.TestCase):
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
uops_fma = UOpGraph([UOp(UOps.SINK, None, (u4,))])
uops_fma = UOpGraph([u4])
self.assertEqual(uops.flops_mem(), uops_fma.flops_mem())

View File

@ -251,39 +251,14 @@ constant_folder = PatternMatcher([
# *** uop graph ***
class UOpGraph:
def __init__(self, add_nodes:Optional[List[UOp]]=None):
self.nodes: Dict[Tuple, UOp] = {}
def __init__(self, sinks:List[UOp]):
self.sinks: List[UOp] = sinks
# used by linearizer
self._uops: Optional[List[UOp]] = None
if add_nodes is not None: self._multiadd(add_nodes)
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
def __getitem__(self, index) -> UOp: return self.uops[index]
def _multiadd(self, unprocessed_nodes:List[UOp]):
# add nodes to graph in reverse BFS order
# TODO: i feel like this is written in a few places, possible to library it?
in_degree: DefaultDict[UOp, int] = defaultdict(int)
children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
all_nodes: Dict[UOp, None] = dict()
while len(unprocessed_nodes):
n = unprocessed_nodes.pop(0)
if n in all_nodes: continue
all_nodes[n] = None
for x in n.vin:
in_degree[n] += 1
children[x].append(n)
unprocessed_nodes += list(n.vin)
queue = [x for x in all_nodes if in_degree[x] == 0]
replace_nodes: Dict[UOp, UOp] = {}
while len(queue):
n = queue.pop(0)
if n in replace_nodes: continue
replace_nodes[n] = self._add(n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
for x in children[n]:
in_degree[x] -= 1
if in_degree[x] == 0:
queue.append(x)
def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]
@ -335,15 +310,39 @@ class UOpGraph:
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
# NOTE: relinearizering should be okay
# NOTE: relinearizing should be okay
self.nodes: Dict[Tuple, UOp] = {}
# get sink
_sinks: List[UOp] = []
for u in self.nodes.values():
if u.uop is UOps.STORE: _sinks.append(u)
if u.uop is UOps.SINK: _sinks.extend(u.vin)
sink = UOp(UOps.SINK, None, tuple(_sinks))
del _sinks
# add nodes to graph in reverse BFS order
# dedup all nodes
# TODO: this BFS is written in a few places; can it be factored into a shared library helper?
sink = UOp(UOps.SINK, None, tuple(self.sinks))
unprocessed_nodes = [sink]
early_in_degree: DefaultDict[UOp, int] = defaultdict(int)
children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
all_nodes: Dict[UOp, None] = dict()
while len(unprocessed_nodes):
n = unprocessed_nodes.pop(0)
if n in all_nodes: continue
all_nodes[n] = None
for x in n.vin:
early_in_degree[n] += 1
children[x].append(n)
unprocessed_nodes += list(n.vin)
early_queue = [x for x in all_nodes if early_in_degree[x] == 0]
replace_nodes: Dict[UOp, UOp] = {}
while len(early_queue):
n = early_queue.pop(0)
if n in replace_nodes: continue
key = (n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
if found:=self.nodes.get(key): replace_nodes[n] = found
else: replace_nodes[n] = self.nodes[key] = UOp(*key)
for x in children[n]:
early_in_degree[x] -= 1
if early_in_degree[x] == 0:
early_queue.append(x)
sink = replace_nodes.get(sink, sink)
# do graph rewrite
sink = self.graph_rewrite(sink, constant_folder)
if extra_pm: sink = self.graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))
@ -411,11 +410,6 @@ class UOpGraph:
if type_verify: self.type_verify()
def _add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
self.nodes[key] = ret = UOp(*key)
return ret
# *** checker functions ***
def flops_mem(self) -> Tuple[sint, sint]: