mirror of https://github.com/commaai/tinygrad.git
move uops add logic to linearize (#4952)
* move logic to linearize
* idk how this should work
* empty
This commit is contained in:
parent
7e32b8c930
commit
63a8add2c2
|
@ -31,7 +31,7 @@ class UOpsFuzzerRunner(CompiledRunner):
|
|||
|
||||
for i, path in enumerate(self.p.uops.fuzz_paths):
|
||||
# setup prg
|
||||
uops = UOpGraph()
|
||||
uops = UOpGraph([])
|
||||
uops._uops = list(path)
|
||||
if DEBUG >= 6: uops.print()
|
||||
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)
|
||||
|
|
|
@ -7,7 +7,7 @@ class TestDeviceSpeed(unittest.TestCase):
|
|||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.dev = Device[Device.DEFAULT]
|
||||
cls.empty = Device[Device.DEFAULT].renderer.render("test", UOpGraph())
|
||||
cls.empty = Device[Device.DEFAULT].renderer.render("test", UOpGraph([]))
|
||||
|
||||
def test_empty_compile(self):
|
||||
with Timing("compiler "):
|
||||
|
|
|
@ -9,7 +9,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.uop, UOps.CONST)
|
||||
|
@ -21,7 +21,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.uop, UOps.CONST)
|
||||
|
@ -32,7 +32,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.uop, UOps.CONST)
|
||||
|
@ -41,7 +41,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
def test_const_cast(self):
|
||||
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
out = UOp(UOps.CAST, dtypes.int, (bf,))
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.uop, UOps.CONST)
|
||||
|
@ -55,7 +55,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
x = UOp(UOps.GEP, dtypes.float, (cast, ), arg=0)
|
||||
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
|
||||
out = UOp(UOps.STORE, dtypes.float, (d0, idx, alu))
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len([x for x in g.uops if x.uop is UOps.CAST]), 0)
|
||||
|
||||
def test_depth_2_const_fold(self):
|
||||
|
@ -64,7 +64,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
|
||||
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
|
||||
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len(g.uops), 3)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.uop, UOps.ALU)
|
||||
|
|
|
@ -301,7 +301,7 @@ class TestAssembly(unittest.TestCase):
|
|||
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
|
||||
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
|
||||
uops = UOpGraph([UOp(UOps.SINK, None, (a1,a2))])
|
||||
uops = UOpGraph([a1,a2])
|
||||
Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
self.assertEqual(uops.uops[-1].arg, BinaryOps.MUL)
|
||||
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHL)
|
||||
|
@ -313,7 +313,7 @@ class TestAssembly(unittest.TestCase):
|
|||
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
|
||||
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
|
||||
uops = UOpGraph([UOp(UOps.SINK, None, (a1,a2))])
|
||||
uops = UOpGraph([a1,a2])
|
||||
Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
self.assertEqual(uops.uops[-1].arg, BinaryOps.IDIV)
|
||||
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHR)
|
||||
|
|
|
@ -63,7 +63,7 @@ class TestUOpsStats(unittest.TestCase):
|
|||
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
|
||||
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
|
||||
uops = UOpGraph([UOp(UOps.SINK, None, (u5,))])
|
||||
uops = UOpGraph([u5])
|
||||
|
||||
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
|
||||
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
|
||||
|
@ -72,7 +72,7 @@ class TestUOpsStats(unittest.TestCase):
|
|||
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
|
||||
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
|
||||
uops_fma = UOpGraph([UOp(UOps.SINK, None, (u4,))])
|
||||
uops_fma = UOpGraph([u4])
|
||||
|
||||
self.assertEqual(uops.flops_mem(), uops_fma.flops_mem())
|
||||
|
||||
|
|
|
@ -251,39 +251,14 @@ constant_folder = PatternMatcher([
|
|||
# *** uop graph ***
|
||||
|
||||
class UOpGraph:
|
||||
def __init__(self, add_nodes:Optional[List[UOp]]=None):
|
||||
self.nodes: Dict[Tuple, UOp] = {}
|
||||
def __init__(self, sinks:List[UOp]):
|
||||
self.sinks: List[UOp] = sinks
|
||||
# used by linearizer
|
||||
self._uops: Optional[List[UOp]] = None
|
||||
if add_nodes is not None: self._multiadd(add_nodes)
|
||||
|
||||
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
|
||||
def __getitem__(self, index) -> UOp: return self.uops[index]
|
||||
|
||||
def _multiadd(self, unprocessed_nodes:List[UOp]):
|
||||
# add nodes to graph in reverse BFS order
|
||||
# TODO: i feel like this is written in a few places, possible to library it?
|
||||
in_degree: DefaultDict[UOp, int] = defaultdict(int)
|
||||
children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
|
||||
all_nodes: Dict[UOp, None] = dict()
|
||||
while len(unprocessed_nodes):
|
||||
n = unprocessed_nodes.pop(0)
|
||||
if n in all_nodes: continue
|
||||
all_nodes[n] = None
|
||||
for x in n.vin:
|
||||
in_degree[n] += 1
|
||||
children[x].append(n)
|
||||
unprocessed_nodes += list(n.vin)
|
||||
queue = [x for x in all_nodes if in_degree[x] == 0]
|
||||
replace_nodes: Dict[UOp, UOp] = {}
|
||||
while len(queue):
|
||||
n = queue.pop(0)
|
||||
if n in replace_nodes: continue
|
||||
replace_nodes[n] = self._add(n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
|
||||
for x in children[n]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0:
|
||||
queue.append(x)
|
||||
|
||||
def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
|
||||
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]
|
||||
|
||||
|
@ -335,15 +310,39 @@ class UOpGraph:
|
|||
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
|
||||
# NOTE: relinearizing should be okay
|
||||
#assert self._uops is None, "already linearized"
|
||||
self.nodes: Dict[Tuple, UOp] = {}
|
||||
|
||||
# get sink
|
||||
_sinks: List[UOp] = []
|
||||
for u in self.nodes.values():
|
||||
if u.uop is UOps.STORE: _sinks.append(u)
|
||||
if u.uop is UOps.SINK: _sinks.extend(u.vin)
|
||||
sink = UOp(UOps.SINK, None, tuple(_sinks))
|
||||
del _sinks
|
||||
# add nodes to graph in reverse BFS order
|
||||
# dedup all nodes
|
||||
# TODO: i feel like this BFS is written in a few places, possible to library it?
|
||||
sink = UOp(UOps.SINK, None, tuple(self.sinks))
|
||||
unprocessed_nodes = [sink]
|
||||
early_in_degree: DefaultDict[UOp, int] = defaultdict(int)
|
||||
children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
|
||||
all_nodes: Dict[UOp, None] = dict()
|
||||
while len(unprocessed_nodes):
|
||||
n = unprocessed_nodes.pop(0)
|
||||
if n in all_nodes: continue
|
||||
all_nodes[n] = None
|
||||
for x in n.vin:
|
||||
early_in_degree[n] += 1
|
||||
children[x].append(n)
|
||||
unprocessed_nodes += list(n.vin)
|
||||
early_queue = [x for x in all_nodes if early_in_degree[x] == 0]
|
||||
replace_nodes: Dict[UOp, UOp] = {}
|
||||
while len(early_queue):
|
||||
n = early_queue.pop(0)
|
||||
if n in replace_nodes: continue
|
||||
key = (n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
|
||||
if found:=self.nodes.get(key): replace_nodes[n] = found
|
||||
else: replace_nodes[n] = self.nodes[key] = UOp(*key)
|
||||
for x in children[n]:
|
||||
early_in_degree[x] -= 1
|
||||
if early_in_degree[x] == 0:
|
||||
early_queue.append(x)
|
||||
sink = replace_nodes.get(sink, sink)
|
||||
|
||||
# do graph rewrite
|
||||
sink = self.graph_rewrite(sink, constant_folder)
|
||||
if extra_pm: sink = self.graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))
|
||||
|
||||
|
@ -411,11 +410,6 @@ class UOpGraph:
|
|||
|
||||
if type_verify: self.type_verify()
|
||||
|
||||
def _add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
|
||||
if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
|
||||
self.nodes[key] = ret = UOp(*key)
|
||||
return ret
|
||||
|
||||
# *** checker functions ***
|
||||
|
||||
def flops_mem(self) -> Tuple[sint, sint]:
|
||||
|
|
Loading…
Reference in New Issue