# tinygrad/test/test_uop_graph.py

from typing import List
import unittest, time
from test.helpers import TestUOps
from tinygrad import dtypes, Variable, Device
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, NOp, PatternMatcher, KernelInfo
from tinygrad.codegen.lowerer import ast_to_uop
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, constant_folder, float4_folding
from tinygrad.shape.shapetracker import ShapeTracker, View
# Small hand-written PatternMatcher used by TestGraphRewrite to exercise graph_rewrite in isolation.
# Each rule folds constants into a float CONST so the tests can check the final value.
# NOTE(review): the lambda parameter names mirror the names given to NOp.cvar/NOp.var;
# presumably PatternMatcher binds them by name — confirm before renaming anything here.
simple_pm = PatternMatcher([
  # an int CONST is replaced by the float expression 1.0 + 2.0, which the const+const rule
  # below folds to 3.0 (test_magic_4 relies on this)
  (NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
  # const + const -> float CONST holding the sum
  (NOp.cvar('x') + NOp.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
  # const * const * const -> float CONST holding the product
  (NOp.cvar('x') * NOp.cvar('y') * NOp.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
  # (x + c1) + c2 -> x + (c1 + c2): folds two trailing consts around a non-const operand
  ((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + x.const(c1.arg+c2.arg)),
])
def to_uops_list(u:List[UOp]) -> List[UOp]:
  """Wrap `u` in a SINK, run full_graph_rewrite on it, and return the linearized uop list."""
  rewritten_sink = full_graph_rewrite(UOp.sink(*u))
  return linearize_uop(rewritten_sink)
class TestGraphRewriteEfficiency(unittest.TestCase):
  """Timing smoke tests for UOp construction and full_graph_rewrite.

  Both tests only print timings; neither makes assertions.
  """
  def test_create_many_uops(self):
    c1 = UOp.const(dtypes.int, 1)
    c2 = UOp.const(dtypes.int, 2)
    st = time.perf_counter()
    uops = [UOp(UOps.ALU, dtypes.int, (c1, c2), BinaryOps.ADD) for _ in range(10000)]
    et = time.perf_counter() - st
    print(f"created {len(uops)} uops in {et*1000:.2f} ms")

  def test_expand_rewrite(self):
    # kernel AST: store into buffer 0 a REDUCE_AXIS (ADD over axes 5, 6, 10) of the float CAST of the
    # product of two half LOADs (buffers 1 and 2)
    sink = UOp(UOps.SINK, None, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
      UOp(UOps.STORE, None, arg=None, src=(
        UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
        UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
          strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
          offset=0, mask=None, contiguous=False),)), src=()),
        UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 10)), src=(
          UOp(UOps.CAST, dtypes.float, arg=None, src=(
            UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
              UOp(UOps.LOAD, dtypes.half, arg=None, src=(
                UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()),
                UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(
                  View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
                    mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
                  View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0,
                    mask=None, contiguous=False))), src=()),)),
              UOp(UOps.LOAD, dtypes.half, arg=None, src=(
                UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()),
                UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(
                  View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
                    mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
    lower_sink = ast_to_uop(sink, Device[Device.DEFAULT].renderer)
    # count UOp constructions during the rewrite by temporarily wrapping UOp.__init__
    cnt = [0]
    old_init = UOp.__init__
    def uop_hook(self, *args, **kwargs):
      cnt[0] += 1
      old_init(self, *args, **kwargs)
    UOp.__init__ = uop_hook
    try:
      st = time.perf_counter()
      new_sink = full_graph_rewrite(lower_sink)
      et = time.perf_counter() - st
    finally:
      # always restore the real constructor, even if the rewrite raises, so later tests are unaffected
      UOp.__init__ = old_init
    print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.sparents)} -> {len(new_sink.sparents)}, creating {cnt[0]} uops")
    # debugging aids:
    #from collections import Counter
    #print(Counter(x.op for x in new_sink.sparents))
    #from tinygrad.engine.graph import graph_uops
    #graph_uops(linearize_uop(new_sink))
class TestGraphRewrite(unittest.TestCase):
  """Checks graph_rewrite with the local `simple_pm` matcher and with `constant_folder`."""
  def test_dedup(self):
    # two structurally identical DEFINE_VARs must come out of the rewrite as the very same object
    v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
    v2 = UOp(UOps.DEFINE_VAR, dtypes.float)
    nout = graph_rewrite(v1+v2, PatternMatcher([]))
    self.assertIs(nout.src[0], nout.src[1])

  def test_simple(self):
    c1 = UOp.const(dtypes.float, 1.0)
    c2 = UOp.const(dtypes.float, 2.0)
    nout = graph_rewrite(c1+c2, simple_pm)
    self.assertEqual(nout.op, UOps.CONST)
    self.assertEqual(nout.arg, 3.0)

  def test_depth_2_late(self):
    c1 = UOp.const(dtypes.float, 1.0)
    c2 = UOp.const(dtypes.float, 2.0)
    c3 = UOp.const(dtypes.float, 3.0)
    # the add (c3+c3) must fold before the triple-multiply rule can match
    nout = graph_rewrite(c1*c2*(c3+c3), simple_pm)
    self.assertEqual(nout.op, UOps.CONST)
    self.assertEqual(nout.arg, 12.0)

  def test_double(self):
    c1 = UOp.const(dtypes.float, 1.0)
    c2 = UOp.const(dtypes.float, 2.0)
    c3 = UOp.const(dtypes.float, 3.0)
    nout = graph_rewrite(c1+c2+c3, simple_pm)
    self.assertEqual(nout.op, UOps.CONST)
    self.assertEqual(nout.arg, 6.0)

  def test_triple(self):
    c1 = UOp.const(dtypes.float, 1.0)
    c2 = UOp.const(dtypes.float, 2.0)
    c3 = UOp.const(dtypes.float, 3.0)
    c4 = UOp.const(dtypes.float, 4.0)
    nout = graph_rewrite(c1+c2+c3+c4, simple_pm)
    self.assertEqual(nout.op, UOps.CONST)
    self.assertEqual(nout.arg, 10.0)

  def test_diamond(self):
    # c1 is shared by both branches of the outer add
    c1 = UOp.const(dtypes.float, 1.0)
    c2 = UOp.const(dtypes.float, 2.0)
    c3 = UOp.const(dtypes.float, 3.0)
    nout = graph_rewrite((c1+c2)+(c1+c3), simple_pm)
    self.assertEqual(nout.op, UOps.CONST)
    self.assertEqual(nout.arg, 7.0)

  def test_magic_4(self):
    # any int const is rewritten by simple_pm's first rule into 1.0 + 2.0, which folds to 3.0
    c1 = UOp.const(dtypes.int, 4.0)
    nout = graph_rewrite(c1, simple_pm)
    self.assertEqual(nout.op, UOps.CONST)
    self.assertEqual(nout.arg, 3.0)

  def test_depth_2_fold(self):
    v = UOp(UOps.DEFINE_VAR, dtypes.float)
    c1 = UOp.const(dtypes.float, 1.0)
    c2 = UOp.const(dtypes.float, 2.0)
    nout = graph_rewrite(v+c1+c2, simple_pm)
    # (v + 1.0) + 2.0 -> v + 3.0
    self.assertEqual(nout.op, UOps.ALU)
    self.assertEqual(nout.src[0].op, UOps.DEFINE_VAR)
    self.assertEqual(nout.src[1].op, UOps.CONST)
    self.assertEqual(nout.src[1].arg, 3.0)

  def test_consts_go_last(self):
    a = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('a', 0, 1))
    b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1))
    c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1))
    d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1))
    outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
    # the literal constants are expected to end up as a single CONST in src[1]; with the DEFINE_VAR bound
    # consts (0 and 1) that gives exactly 3 CONSTs in the graph
    for i, out in enumerate(outs):
      with self.subTest(i=i):
        sink = graph_rewrite(out, constant_folder)
        if DEBUG >= 4: print(sink)
        self.assertEqual(sink.op, UOps.ALU)
        self.assertEqual(sink.src[1].op, UOps.CONST)
        self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
class TestUOpGraph(TestUOps):
  """End-to-end checks of full_graph_rewrite + linearize_uop (via to_uops_list) on small hand-built graphs."""
  def test_add_constant_fold(self):
    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
    c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
    out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
    uops = to_uops_list([out])
    self.assertEqual(len(uops), 1)
    out = uops[-1]
    self.assertEqual(out.op, UOps.CONST)
    self.assertEqual(out.arg, 3.0)

  def test_where_same_fold(self):
    # WHERE(cond, c1, c1) must fold to c1 regardless of cond
    v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp(UOps.CONST, dtypes.int, (), 0), UOp(UOps.CONST, dtypes.int, (), 1)), arg=Variable('tmp', 0, 1))
    c0 = UOp(UOps.CONST, dtypes.int, arg=0)
    vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
    out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
    uops = to_uops_list([out])
    self.assertEqual(len(uops), 1)
    out = uops[-1]
    self.assertEqual(out.op, UOps.CONST)
    self.assertEqual(out.arg, 1.0)

  def test_where_const_fold(self):
    bf = UOp(UOps.CONST, dtypes.bool, arg=False)
    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
    c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
    out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
    uops = to_uops_list([out])
    self.assertEqual(len(uops), 1)
    out = uops[-1]
    self.assertEqual(out.op, UOps.CONST)
    self.assertEqual(out.arg, 2.0)

  def test_const_cast(self):
    bf = UOp(UOps.CONST, dtypes.bool, arg=False)
    out = UOp(UOps.CAST, dtypes.int, (bf,))
    uops = to_uops_list([out])
    self.assertEqual(len(uops), 1)
    out = uops[-1]
    self.assertEqual(out.op, UOps.CONST)
    self.assertEqual(out.arg, 0)

  def test_noop_vectorize_fold(self):
    d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0)
    idx = UOp.const(dtypes.int, 0)
    ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
    vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
    x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
    alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
    out = UOp(UOps.STORE, None, (d0, idx, alu))
    uops = to_uops_list([out])
    self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)

  def test_gep_vec_fold(self):
    d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
    d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1)
    d2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 2)
    idx = UOp.const(dtypes.int, 0)
    def _test_vec(geps, count=4):
      # store VECTORIZE(geps) into d0 and return the value operand of the final (store) uop
      vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
      out = UOp(UOps.STORE, None, (d0, idx, vec))
      uops = to_uops_list([out])
      if DEBUG >= 4:
        print(Device[Device.DEFAULT].renderer.render("test", uops))
      return uops[-1].src[-1]
    # possible: in-order GEPs of the whole vector collapse back to the vector
    val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
    xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), i) for i in range(4))
    self.assert_equiv_uops(_test_vec(xyzw), val)
    # unaligned
    val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
    wzyx = tuple(UOp(UOps.GEP, dtypes.float, (val,), i) for i in reversed(range(4)))
    self.assertIs(_test_vec(wzyx).op, UOps.VECTORIZE)
    # different_size
    val = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
    xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), i) for i in range(2))
    self.assertIs(_test_vec(xy+xy).op, UOps.VECTORIZE)
    val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
    xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), i) for i in range(2))
    self.assertIs(_test_vec(xy, count=2).op, UOps.VECTORIZE)
    # different vals
    val1 = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
    val2 = UOp(UOps.LOAD, dtypes.float.vec(2), (d2, idx))
    xy1 = tuple(UOp(UOps.GEP, dtypes.float, (val1, ), i) for i in range(2))
    xy2 = tuple(UOp(UOps.GEP, dtypes.float, (val2, ), i) for i in range(2))
    self.assertIs(_test_vec(xy1+xy2).op, UOps.VECTORIZE)

  def test_gep_vec_const_fold(self):
    for vec_size in [2, 4, 8]:
      consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
      vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
      uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)])
      for uop, const in zip(uops, consts):
        self.assert_equiv_uops(uop, const)

  def test_wmma_vectorize_fold(self):
    # an all-zero operand (either A or B) folds the WMMA away to its accumulator
    for i in [2, 4, 8]:
      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
      uops = to_uops_list([wmma])
      self.assert_equiv_uops(uops[0], acc)
      self.assertEqual(len(uops), 1)
    for i in [2, 4, 8]:
      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
      uops = to_uops_list([wmma])
      self.assert_equiv_uops(uops[0], acc)
      self.assertEqual(len(uops), 1)

  def test_wmma_vectorize_no_fold(self):
    # operands that are only partially zero (or contain a nonzero const) must keep the WMMA
    for i in [4, 8]:
      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
                tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
                tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
      uops = to_uops_list([wmma])
      self.assert_equiv_uops(uops[-1], wmma)
    for i in [4, 8]:
      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
                tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
                tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
      uops = to_uops_list([wmma])
      self.assert_equiv_uops(uops[-1], wmma)
    for i in [2, 4, 8]:
      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
                tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
      uops = to_uops_list([wmma])
      self.assert_equiv_uops(uops[-1], wmma)
    for i in [2, 4, 8]:
      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
                tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
      uops = to_uops_list([wmma])
      self.assert_equiv_uops(uops[-1], wmma)

  def test_cast_alu_fold(self):
    d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0)
    d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1)
    idx = UOp.const(dtypes.int, 0)
    ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
    alu = ld.lt(1).cast(dtypes.bool)
    out = UOp(UOps.STORE, None, (d0, idx, alu))
    uops = to_uops_list([out])
    self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)

  def test_double_cast_fold(self):
    d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0)
    d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1)
    idx = UOp.const(dtypes.int, 0)
    ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
    alu = ld.cast(dtypes.float).cast(dtypes.float)
    out = UOp(UOps.STORE, None, (d0, idx, alu))
    uops = to_uops_list([out])
    self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)

  def test_depth_2_const_fold(self):
    v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('tmp', 0, 1))
    c2 = UOp(UOps.CONST, dtypes.int, arg=2)
    c4 = UOp(UOps.CONST, dtypes.int, arg=4)
    vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
    out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
    uops = to_uops_list([out])
    self.assertEqual(len(uops), 5)
    out = uops[-1]
    self.assertEqual(out.op, UOps.ALU)
    self.assertEqual(out.arg, BinaryOps.ADD)
    self.assertEqual(out.src[1].op, UOps.CONST)
    self.assertEqual(out.src[1].arg, 6)

  def test_fold_gated_load(self):
    glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
    glbl1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1)
    glbl2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 2)
    idx = UOp.const(dtypes.int, 0)
    # ld0 has a constant False gate (invalid value 2), ld1 a constant True gate (invalid value 3)
    ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
    ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
    uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
    # sources of the rewritten add: the surviving load first, the folded constant last
    folded_load, folded_const = uops[-1].src[2].src
    # ld0 becomes its invalid value
    self.assert_equiv_uops(folded_const, UOp.const(dtypes.int, 2))
    # the gate and invalid value are deleted from ld1
    self.assert_equiv_uops(folded_load, UOp.load(glbl2, idx, dtype=dtypes.int))

  def test_fold_gated_load_local(self):
    glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
    smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1))
    lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
    st = UOp(UOps.STORE, None, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
    barrier = UOp(UOps.BARRIER, None, (st, ))
    # same as test_fold_gated_load, but on local memory with a barrier as the last source
    ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
    ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
    uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
    # sources of the rewritten add: the surviving load first, the folded constant last
    folded_load, folded_const = uops[-1].src[2].src
    # ld0 becomes its invalid value
    self.assert_equiv_uops(folded_const, UOp.const(dtypes.int, 2))
    # the gate and invalid value are deleted from ld1; the barrier is kept
    self.assert_equiv_uops(folded_load, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))

  def test_fold_gated_store(self):
    glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
    idx0 = UOp.const(dtypes.int, 0)
    idx1 = UOp.const(dtypes.int, 0)
    val = UOp.const(dtypes.int, 42)
    st0 = UOp(UOps.STORE, None, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
    st1 = UOp(UOps.STORE, None, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
    uops = to_uops_list([st0, st1])
    # only the second store happens
    self.assertEqual(len(uops), 4)
    self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))

  def test_asserts_bad_gate(self):
    # a non-bool gate must be rejected
    glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
    idx = UOp.const(dtypes.int, 0)
    bad_gate = UOp.const(dtypes.int, 1)
    with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])

  def test_switched_range_order(self):
    glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
    c0 = UOp.const(dtypes.int, 0)
    c2 = UOp.const(dtypes.int, 2)
    cf = UOp.const(dtypes.float, 0.0)
    r1 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 0, False))
    r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
    alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
    store = UOp(UOps.STORE, None, (glbl, alu, cf))
    uops = to_uops_list([store])
    ranges = [x for x in uops if x.op is UOps.RANGE]
    endranges = [x for x in uops if x.op is UOps.ENDRANGE]
    # ranges are closed in the right order
    self.assertEqual(endranges[-1].src[0], ranges[0])
def expander_rewrite(sink):
  """Rewrite `sink` with the combined constant-folding, expander and reducer patterns."""
  patterns = constant_folder + expander + reducer
  return graph_rewrite(sink, patterns)
def float4_rewrite(sink):
  """Rewrite `sink` with the combined constant-folding, expander and float4-folding patterns."""
  patterns = constant_folder + expander + float4_folding
  return graph_rewrite(sink, patterns)
class TestExpander(unittest.TestCase):
  """Checks how expander_rewrite resolves EXPAND and CONTRACT uops into EXPAND/VECTORIZE results."""
  @staticmethod
  def _const_expand(values, axes):
    """Return an int EXPAND over `axes` whose sources are int CONSTs holding `values`, in order."""
    return UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, v) for v in values), axes)

  def test_expand_add_broadcast(self):
    out = expander_rewrite(self._const_expand(range(4), ((1,4),)) + 3)
    assert out.op is UOps.EXPAND and len(out.src) == 4
    self.assertListEqual([u.arg for u in out.src], [3,4,5,6])

  def test_contract_simple(self):
    expanded = self._const_expand(range(4), ((1,4),))
    contracted = UOp(UOps.CONTRACT, dtypes.int.vec(4), (expanded,), ((1,4),))
    out = expander_rewrite(contracted)
    assert out.op is UOps.VECTORIZE and len(out.src) == 4
    self.assertListEqual([u.arg for u in out.src], [0,1,2,3])

  def test_contract_axis_1(self):
    expanded = self._const_expand(range(16), ((1,4),(2,4)))
    contracted = UOp(UOps.CONTRACT, dtypes.int.vec(4), (expanded,), ((1,4),))
    out = expander_rewrite(contracted)
    assert out.op is UOps.EXPAND and len(out.src) == 4 and out.arg == ((2,4),)
    assert out.src[0].op is UOps.VECTORIZE and len(out.src[0].src) == 4
    self.assertListEqual([u.arg for u in out.src[0].src], [0,4,8,12])
    self.assertListEqual([u.arg for u in out.src[3].src], [3,7,11,15])

  def test_contract_axis_2(self):
    expanded = self._const_expand(range(16), ((1,4),(2,4)))
    contracted = UOp(UOps.CONTRACT, dtypes.int.vec(4), (expanded,), ((2,4),))
    out = expander_rewrite(contracted)
    assert out.op is UOps.EXPAND and len(out.src) == 4 and out.arg == ((1,4),)
    assert out.src[0].op is UOps.VECTORIZE and len(out.src[0].src) == 4
    self.assertListEqual([u.arg for u in out.src[0].src], [0,1,2,3])
    self.assertListEqual([u.arg for u in out.src[3].src], [12,13,14,15])

  def test_contract_axis_2_big(self):
    expanded = self._const_expand(range(16), ((1,2),(2,2),(3,2),(4,2)))
    contracted = UOp(UOps.CONTRACT, dtypes.int.vec(2), (expanded,), ((2,2),))
    out = expander_rewrite(contracted)
    assert out.op is UOps.EXPAND and out.arg == ((1, 2), (3, 2), (4, 2))
    self.assertListEqual([u.arg for u in out.src[0].src], [0,4])
    self.assertListEqual([u.arg for u in out.src[6].src], [10,14])

  def test_contract_multi_axis(self):
    expanded = self._const_expand(range(16), ((1,2),(2,2),(3,2),(4,2)))
    # the order of the contracted axes decides the lane order inside each VECTORIZE
    out = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (expanded,), ((3,2),(2,2))))
    assert out.op is UOps.EXPAND and out.arg == ((1, 2), (4, 2))
    self.assertListEqual([u.arg for u in out.src[0].src], [0,4,2,6])
    out = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (expanded,), ((2,2),(3,2))))
    assert out.op is UOps.EXPAND and out.arg == ((1, 2), (4, 2))
    self.assertListEqual([u.arg for u in out.src[0].src], [0,2,4,6])

  def test_contract_mid(self):
    expanded = self._const_expand(range(8), ((1,2),(2,2),(3,2)))
    contracted = UOp(UOps.CONTRACT, dtypes.int.vec(2), (expanded,), ((2,2),))
    out = expander_rewrite(contracted)
    assert out.op is UOps.EXPAND and len(out.src) == 4 and out.arg == ((1,2),(3,2))
    assert out.src[0].op is UOps.VECTORIZE and len(out.src[0].src) == 2
    for lane, expected in enumerate([[0,2], [1,3], [4,6], [5,7]]):
      self.assertListEqual([u.arg for u in out.src[lane].src], expected)

  def test_contract_no_expand(self):
    scalar = UOp(UOps.DEFINE_VAR, dtypes.int)
    contracted = UOp(UOps.CONTRACT, dtypes.int.vec(2), (scalar,), ((2,2),))
    out = expander_rewrite(contracted)
    assert out.op is UOps.VECTORIZE and len(out.src) == 2
    assert out.src[0] == out.src[1]

  def test_contract_half_expand(self):
    expanded = self._const_expand(range(4), ((1,4),))
    contracted = UOp(UOps.CONTRACT, dtypes.int.vec(8), (expanded,), ((1,4), (2,2)))
    out = expander_rewrite(contracted)
    assert out.op is UOps.VECTORIZE and len(out.src) == 8
    assert out.src[0] == out.src[1]
    assert out.src[0] != out.src[2]
    assert out.src[6] == out.src[7]

  def test_expand_same_axis(self):
    lhs = self._const_expand(range(4), ((1,4),))
    rhs = self._const_expand(range(0, 16, 4), ((1,4),))
    out = expander_rewrite(lhs+rhs)
    assert out.op is UOps.EXPAND and len(out.src) == 4
    self.assertListEqual([u.arg for u in out.src], [0,5,10,15])

  def test_expand_different_axis(self, flip=False):
    outer = self._const_expand(range(0, 16, 4), ((1,4),))
    inner = self._const_expand(range(4), ((2,4),))
    out = expander_rewrite((inner+outer) if flip else (outer+inner))
    assert out.op is UOps.EXPAND and len(out.src) == 16
    assert out.arg == ((1, 4), (2, 4))
    self.assertListEqual([u.arg for u in out.src], list(range(16)))

  def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)

  @unittest.skip("no longer supported")
  def test_reduce_known_axis(self):
    expanded = self._const_expand(range(4), ((1,4),))
    out = expander_rewrite(UOp(UOps.REDUCE, dtypes.int, (3*expanded,expanded), BinaryOps.ADD))
    assert out.op is UOps.CONST
    self.assertEqual(out.arg, 3*(0+1+2+3))

  @unittest.skip("no longer supported")
  def test_reduce_const(self):
    expanded = self._const_expand(range(4), ((1,4),))
    out = expander_rewrite(UOp(UOps.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), expanded), BinaryOps.ADD))
    assert out.op is UOps.CONST
    self.assertEqual(out.arg, 3*4)

  def test_double_expand(self):
    low = self._const_expand(range(4), ((2,4),))
    high = self._const_expand(range(4, 8), ((2,4),))
    nested = UOp(UOps.EXPAND, dtypes.int, (low, high), ((1,2),))
    out = expander_rewrite(nested)
    assert out.op is UOps.EXPAND and len(out.src) == 8
    assert out.arg == ((1, 2), (2, 4))
    self.assertListEqual([u.arg for u in out.src], [0,1,2,3,4,5,6,7])

  def test_double_expand_reverse(self):
    low = self._const_expand(range(4), ((1,4),))
    high = self._const_expand(range(4, 8), ((1,4),))
    nested = UOp(UOps.EXPAND, dtypes.int, (low, high), ((2,2),))
    out = expander_rewrite(nested)
    assert out.op is UOps.EXPAND and len(out.src) == 8
    assert out.arg == ((1, 4), (2, 2))
    self.assertListEqual([u.arg for u in out.src], [0, 4, 1, 5, 2, 6, 3, 7])

  def test_double_expand_middle(self):
    low = self._const_expand(range(4), ((1,2),(3,2)))
    high = self._const_expand(range(4, 8), ((1,2),(3,2)))
    nested = UOp(UOps.EXPAND, dtypes.int, (low, high), ((2,2),))
    out = expander_rewrite(nested)
    assert out.op is UOps.EXPAND and len(out.src) == 8
    assert out.arg == ((1, 2), (2, 2), (3, 2))
    self.assertListEqual([u.arg for u in out.src], [0, 1, 4, 5, 2, 3, 6, 7])

  # does this need to work?
  @unittest.expectedFailure
  @unittest.skip
  def test_reduce_different_axis(self):
    first = self._const_expand(range(4), ((1,4),))
    second = self._const_expand(range(4), ((2,4),))
    out = expander_rewrite(UOp(UOps.REDUCE, dtypes.int, (first,second), BinaryOps.ADD))
    print(out)
class TestLoadStoreFolder(unittest.TestCase):
  """Checks that float4_rewrite merges consecutive scalar LOADs/STOREs and keeps differently-gated ones apart."""
  def test_simple_load_fold(self):
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    loads = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
    sink = UOp(UOps.EXPAND, dtypes.float, tuple(loads), ((0,4),))
    sink = float4_rewrite(sink)
    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1

  def test_two_load_fold(self):
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    loads = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
    sink = UOp(UOps.EXPAND, dtypes.float, tuple(loads), ((0,8),))
    sink = float4_rewrite(sink)
    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2

  def test_simple_load_fold_gated(self):
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
    loads = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
    sink = UOp(UOps.EXPAND, dtypes.float, tuple(loads), ((0,4),))
    sink = float4_rewrite(sink)
    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
    single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
    # the per-lane invalid values are combined into one vector source
    self.assertListEqual([src.arg for src in single_load.src[2].src], [0.0, 1.0, 2.0, 3.0])

  def test_simple_load_dont_fold_different_gated(self):
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
    gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
    # lane 0 uses gate g1, lanes 1-3 use g2: loads with different gates must not be merged together
    loads = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
    sink = UOp(UOps.EXPAND, dtypes.float, tuple(loads), ((0,4),))
    sink = float4_rewrite(sink)
    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3

  def test_simple_store_fold(self):
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    stores = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i))) for i in range(4)]
    sink = UOp(UOps.SINK, None, tuple(stores))
    sink = float4_rewrite(sink)
    assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1

  def test_simple_store_fold_gate(self):
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
    stores = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
    sink = UOp(UOps.SINK, None, tuple(stores))
    sink = float4_rewrite(sink)
    assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
    one_store = [x for x in sink.sparents if x.op is UOps.STORE][0]
    # the shared gate survives as the 4th source of the merged store
    assert len(one_store.src) == 4
    # NOTE(review): compared via str(); the original author was unsure why a direct comparison is not used
    assert str(one_store.src[3]) == str(gate)

  def test_simple_store_dont_fold(self):
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
    gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
    gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
    # lane 0 uses gate g1, lanes 1-3 use g2: stores with different gates must not be merged together
    stores = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
    sink = UOp(UOps.SINK, None, tuple(stores))
    sink = float4_rewrite(sink)
    if DEBUG >= 4: print(sink)
    assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
def gate_rewrite(sink):
  """Rewrite `sink` with the constant-folding, expander and reducer patterns (used by the IF/gate tests)."""
  patterns = constant_folder + expander + reducer
  return graph_rewrite(sink, patterns)
class TestIFUOps(TestUOps):
  """Checks that gate_rewrite turns gated STOREs into a single IF uop carrying the gate.

  After the rewrite, each store is expected to have only 3 sources, i.e. no gate of its own.
  """
  def test_create_ifs(self):
    gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
    sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
    valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
    lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
    gate = valid&(lidx.ne(2))
    idx = UOp.const(dtypes.int, 0)
    smem_store = UOp(UOps.STORE, None, (sbuf, idx, UOp.const(dtypes.float, 42)))
    barrier = UOp(UOps.BARRIER, None, (smem_store,))
    lbuf = UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier))
    store = UOp(UOps.STORE, None, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate))
    sink = UOp(UOps.SINK, None, (store,))
    sink = gate_rewrite(sink)
    if_uops = [u for u in sink.parents if u.op is UOps.IF]
    self.assertEqual(len(if_uops), 1)
    self.assert_equiv_uops(if_uops[0].src[0], gate)
    for out_store in sink.src:
      self.assertEqual(len(out_store.src), 3)

  def test_expand_ifs_one_gate(self):
    gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
    sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16))
    valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
    lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
    gate = valid&(lidx.ne(2))
    smem_store = UOp(UOps.STORE, None, (sbuf, lidx, UOp.const(dtypes.float, 42)))
    barrier = UOp(UOps.BARRIER, None, (smem_store,))
    lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
    # four stores share one gate, so exactly one IF is expected
    stores = [UOp(UOps.STORE, None, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)]
    sink = UOp(UOps.SINK, None, tuple(stores))
    sink = gate_rewrite(sink)
    if_uops = [u for u in sink.parents if u.op is UOps.IF]
    self.assertEqual(len(if_uops), 1)
    self.assert_equiv_uops(if_uops[0].src[0], gate)
    for out_store in sink.src:
      self.assertEqual(len(out_store.src), 3)

  # this will be fixed with the merge gated stores bounty
  @unittest.expectedFailure
  def test_expand_ifs_dumb(self):
    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
    valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
    lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
    gate = valid&(lidx.ne(2))
    stores = [UOp(UOps.STORE, None, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
    sink = UOp(UOps.SINK, None, tuple(stores))
    sink = gate_rewrite(sink)
    if_uops = [u for u in sink.parents if u.op is UOps.IF]
    self.assertEqual(len(if_uops), 1)
    self.assert_equiv_uops(if_uops[0].src[0], gate)
    for out_store in sink.src:
      self.assertEqual(len(out_store.src), 3)
# run every test in this module with per-test output when executed as a script
if __name__ == '__main__': unittest.main(verbosity=2)