# tinygrad/test/test_schedule.py

# this will be the new test_ops for the next level
# schedule confirms the right things are capable of fusing
# NOTE: this has overlap with external_test_opt.py
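# roughly, each item returned by create_schedule lowers to one kernel, so the counts
# asserted below are statements about how much of a graph fuses into a single kernel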
import unittest
from typing import List, Optional
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.helpers import DEBUG, GRAPH
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.graph import print_tree, realized_lazybuffer
from tinygrad.engine.schedule import create_schedule
from tinygrad import nn, dtypes

def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
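  """Assert that scheduling `t` yields exactly `allowed` schedule items (roughly, kernels).

  `to_prerealize` tensors are scheduled first and their outputs marked as seen, so work
  they already cover is not counted again. With `filter_loadops=True` (the default),
  LoadOps items are excluded from the count.
  """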
  seen = set()
  if to_prerealize:
    for pre in to_prerealize:
      for s in create_schedule([pre.lazydata], seen.copy()):
        for i,out in enumerate(s.outputs):
          if GRAPH: realized_lazybuffer(out, 0)
          seen.add(out)
  sched = create_schedule([t.lazydata], seen)
  if GRAPH:
    for i,s in enumerate(sched):
      for out in s.outputs: realized_lazybuffer(out, i+1)
  if filter_loadops: sched = [s for s in sched if s.ast[0].op not in LoadOps]
  if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
  if len(sched) != allowed or DEBUG >= 3:
    for i, s in enumerate(sched):
      print("kernel", i+1)
      for op in s.ast: print_tree(op)
  assert len(sched) == allowed

  # test that the (non-LoadOps) ops linearize
  for s in sched:
    if s.ast[0].op in LoadOps: continue
    l = Linearizer(*s.ast)
    l.hand_coded_optimizations()
    l.linearize()
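
# typical usage (a minimal sketch, mirroring test_basic_binop_fusion below): chained
# elementwise adds over the same shape are expected to fuse into a single kernel
#   a, b, c = Tensor.empty(10), Tensor.empty(10), Tensor.empty(10)
#   check_schedule(a+b+c, 1)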

class TestSchedule(unittest.TestCase):
  def test_basic_binop_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    d = a+b+c
    check_schedule(d, 1)

  def test_basic_binop_fusion_deep(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    d = Tensor.empty(10)
    e = a+b+c+d
    check_schedule(e, 1)

  def test_mulacc_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = (a*b).sum()
    check_schedule(c, 1)

  def test_mulacc_relu_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = (a*b).sum().relu()
    check_schedule(c, 1)

  def test_binop_reshape_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(5,2)
    d = (a+b).reshape(5,2)+c
    check_schedule(d, 1)

  def test_binop_permute_fusion(self):
    a = Tensor.empty(2,5)
    b = Tensor.empty(2,5)
    c = Tensor.empty(5,2)
    d = (a+b).permute(1,0)+c
    check_schedule(d, 1)

  def test_constants_are_embedded(self):
    a = Tensor.empty(3,3) * 2
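    # with LoadOps included, this is (presumably) the EMPTY alloc for `a` plus one
    # elementwise kernel; the constant 2 is embedded in the AST rather than loaded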
    check_schedule(a, 2, filter_loadops=False)

  def test_binop_elu_fusion(self):
    a = Tensor.empty(10)
    b = a.elu()
    check_schedule(b, 1)

  def test_binop_reshape_reduce_fusion(self):
    a = Tensor.empty(100)
    b = Tensor.empty(100)
    c = (a+b).reshape(10, 10).sum(axis=0, keepdim=True)
    check_schedule(c, 1)

  def test_reduce_reshape_binop_fusion(self):
    a = Tensor.empty(10,10)
    b = Tensor.empty(10)
    c = a.sum(axis=0) + b
    check_schedule(c, 1)

  @unittest.skip("not pushing permutes through reduces")
  def test_reduce_permute_binop_fusion(self):
    a = Tensor.empty(10,10,10)
    b = Tensor.empty(10,10,1)
    c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
    check_schedule(c, 1)

  def test_binop_early_reshape_reduce_fusion(self):
    a = Tensor.empty(100)
    b = Tensor.empty(100)
    c = Tensor.empty(10,10)
    d = ((a+b).reshape(10,10) + c).sum(axis=0)
    check_schedule(d, 1)

  def test_diamond_folded(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    d = Tensor.empty(10)
    ab = a+b
    e = (ab+c) + (ab+d)
    check_schedule(e, 1)

  def test_cache_binaryop(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a+b
    d = a+b
    check_schedule(d, 0, [c])

  @unittest.skip("failing in old lazy")
  def test_cache_binaryop_reshaped(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a+b
    d = a.reshape(10,1)+b.reshape(10,1)
    check_schedule(d, 0, [c])

  @unittest.skip("failing in new lazy")
  def test_cache_binaryop_transpose(self):
    a = Tensor.empty(10,10)
    b = Tensor.empty(10,10)
    c = (a.T*b.T).T #.contiguous()
    d = a*b
    check_schedule(d, 0, [c])

  def test_cache_two_reduceops(self):
    a = Tensor.empty(10)
    b = a.sum()
    c = a.sum()
    bc = b+c
    check_schedule(bc, 1)

  def test_fold_double_unary(self):
    y = Tensor.empty(2)
    out = y.sum(keepdim=True).sqrt().__neg__()
    check_schedule(out, 1)

  #@unittest.skip("may want to reconsider this")
  def test_fold_batchnorm(self):
    with Tensor.train():
      img = Tensor.empty(1,32,4,4)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(img)
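      # 3 kernels is the rough expectation here: a reduce for the batch mean, a reduce
      # for the variance (which depends on the mean), then the elementwise normalization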
      check_schedule(out, 3)

  def test_fold_conv_relu(self):
    c1 = nn.Conv2d(3,16,3)
    # run
    img = Tensor.ones(2,3,64,64)
    out = c1(img).relu()
    check_schedule(out, 1, [c1.weight, c1.bias])

  def test_fold_conv_elu(self):
    c1 = nn.Conv2d(3,16,3)
    # run
    img = Tensor.rand(2,3,64,64)
    out = c1(img).elu()
    check_schedule(out, 1, [c1.weight, c1.bias, img])

  def test_two_sum(self):
    img = Tensor.empty(64,64)
    x = (img.sum(0) + img.sum(1))
    out = x.relu()
    del x # is 3 without this
    check_schedule(out, 2)

  #@unittest.skip("failing in old lazy")
  def test_push_permute_through_reshape(self):
    a = Tensor.empty(16,16)
    b = Tensor.empty(16,16)
    c = (a+b).reshape(4,4,4,4).permute(2,3,0,1).contiguous()
    check_schedule(c, 1)

  #@unittest.skip("failing in old lazy")
  def test_push_permute_through_reshape_alt(self):
    a = Tensor.empty(4,4,4,4)
    b = Tensor.empty(4,4,4,4)
    c = (a+b).reshape(16,16).permute(1,0).contiguous()
    check_schedule(c, 1)

  def test_no_binop_rerun(self):
    a = Tensor.empty(16)
    b = Tensor.empty(16)
    c = a+b
    d = (a+b).reshape(16,1)
    check_schedule(d, 0, [c])

  def test_multi_permute_should_collapse(self):
    a = Tensor.empty(4,4,4,4)
    b = Tensor.empty(16)
    c = a.sum((0,1)).cast(dtypes.float16).permute(1,0).reshape(4,4,1).permute(1,0,2).reshape(16) + b
    check_schedule(c, 1)

  @unittest.skip("failing in old lazy")
  def test_fancy_reshape_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a+b
    d = a.reshape(10,1)+b.reshape(10,1)
    out = c.sum() + d.sum()
    check_schedule(out, 1)

  # NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first
  @unittest.skip("not real world")
  def test_children_dont_push(self):
    a = Tensor.empty(10, 10, 1)
    b = Tensor.empty(10, 10, 1)
    d = (a+b).expand(10, 10, 10)
    e = (a+b).permute(2,1,0)
    f = d+e
    check_schedule(f, 2)

  @unittest.skip("failing in new lazy")
  def test_dont_fuse_binops_with_children(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    keep_me = a+b
    e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
    d = keep_me+c
    check_schedule(d, 2)
    check_schedule(keep_me, 0, [d])

  #@unittest.skip("failing in old lazy")
  def test_permute_breaks_fusion(self):
    a = Tensor.empty(10, 10, 10)
    b = Tensor.empty(10, 10)
    c = (a.sum(axis=2) + b).permute(1,0)
    d = c.permute(1,0)
    check_schedule(d, 1)

  def test_some_permute_fusion(self):
    a = Tensor.empty(8192, 16)
    b = Tensor.empty(1, 16)
    d = (a.T + b.expand(8192, 16).T)
    c = a + b.expand(8192, 16)
    e = d.T
    check_schedule(c, 1)
    check_schedule(e, 1)

  def test_shrink_fuse(self):
    a = Tensor.empty(8192, 16)
    b = Tensor.empty(8192, 16)
    c = a * b
    d = Tensor.empty(1, 16)
    e = c[0] * d
    check_schedule(e, 1)

  def test_expand_nofuse(self):
    a = Tensor.empty(1, 16)
    b = Tensor.empty(1, 16)
    c = a * b
    d = Tensor.empty(8192, 16)
    e = c * d
    check_schedule(e, 2)

  # this is the failing case in openpilot...it's very simple like this
  @unittest.skip("failing in old lazy")
  def test_image_conv_fusion(self):
    from tinygrad.features.image import image_conv2d
    w1 = Tensor.empty(16, 16, 1, 1)
    b1 = Tensor.empty(16)
    w2 = Tensor.empty(16, 16, 1, 1)
    b2 = Tensor.empty(16)
    w3 = Tensor.empty(16, 16, 1, 1)
    b3 = Tensor.empty(16)
    x = Tensor.empty(1, 16, 32, 32)
    x = base = image_conv2d(x, w1, b1)
    x = image_conv2d(x, w2, b2) + base
    x = image_conv2d(x, w3, b3)
    # NOOP, 3 convs, contiguous
    check_schedule(x, 5)

  def test_image_conv_fusion_minimal(self):
    b1 = Tensor.empty(16)
    b2 = Tensor.empty(16)
    def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)
    x = Tensor.empty(16, 32)
    x = base = p(x) + b1.reshape(16,1)
    x = p(x)
    x = x + b2.reshape(16,1)
    x = x + base
    del base
    x = p(x)
    check_schedule(x, 4)

  def test_image_conv_fusion_more_minimal(self):
    b1 = Tensor.empty(16)
    def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)
    x = Tensor.empty(16, 32)
    x = base = p(x) + b1.reshape(16,1)
    x = p(x)
    del base
    check_schedule(x, 3)

  def test_resnet_block(self):
    Tensor.training = False
    in_planes, planes = 64, 64
    conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    bn1 = nn.BatchNorm2d(planes)
    conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
    bn2 = nn.BatchNorm2d(planes)
    x = Tensor.empty(1, 64, 32, 32)
    out = bn1(conv1(x)).relu()
    out = bn2(conv2(out))
    out = (out + x).relu()
    check_schedule(out, 2, [conv1.weight, conv2.weight])

  def test_contiguous_while_contiguous(self):
    x = Tensor.empty(1, 64, 32, 32)
    out = x.contiguous()
    check_schedule(out, 1, filter_loadops=False)

  def test_contiguous_while_not_contiguous(self):
    x = Tensor.empty(1, 64, 32, 32)
    out = x.permute(0,2,3,1).contiguous()
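    # the permuted view isn't contiguous, so contiguous() needs a real copy kernel in
    # addition to the EMPTY load (hence 2 with LoadOps counted, vs 1 above)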
    check_schedule(out, 2, filter_loadops=False)

  def test_double_from(self):
    x = Tensor([1,2,3,4])
    out = x.to('npy')
    check_schedule(out, 0, filter_loadops=False)

  def test_pow_const_tensor_simplified(self):
    x = Tensor([1,2,3,4])
    # NOTE: this does not test that ** Tensor(2) is simpler in the AST than ** Tensor(2.5)
    out = x ** Tensor(2)
    check_schedule(out, 1)

  def test_pow_const_tensor_to_zero(self):
    x = Tensor([1,2,3,4])
    out = x ** Tensor(0)
    # NOTE: this is ConstBuffer 0 + ConstBuffer 1
    check_schedule(out, 0)

  def test_zero_size(self):
    x = Tensor.empty(2, 3, 0)
    out = x + 1
    check_schedule(out, 0, filter_loadops=False)

  def test_reduce_permute_nofuse(self):
    x = Tensor.empty(32, 32, 32)
    y = Tensor.empty(32, 32)
    out = x.sum(axis=2).T+y
    check_schedule(out, 2)

  def test_two_elus_sum(self):
    x = Tensor.empty(32, 32)
    y = Tensor.empty(32, 32)
    out = x.sum(1).relu().elu() + y.sum(1).relu().elu()
    check_schedule(out, 2)

  def test_multistage_reduce(self):
    x = Tensor.empty(32, 32, 32)
    out = x.sum(2).relu().sum(1)
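    # two reduces over different axes can't collapse into one kernel, hence 2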
    check_schedule(out, 2)

  def test_multistage_reduce_fork(self):
    x = Tensor.empty(32, 32, 32)
    x = x.sum(2)
    out2 = x + 1
    out = x.relu().sum(1) + out2[0]
    check_schedule(out, 2)

  def test_example_matmul(self):
    x = Tensor.eye(64, requires_grad=True)
    y = Tensor.eye(64, requires_grad=True)
    z = y.matmul(x).sum()
    z.backward()
    out = x.grad.contiguous()
    check_schedule(out, 2)

  def test_contiguous_add(self):
    x = Tensor.empty(32)
    y = Tensor.empty(32)
    z = Tensor.empty(32)
    out = (x+y).contiguous()+z
    check_schedule(out, 2)

  def test_double_sum_ref(self):
    x = Tensor.empty(32, 32, 32)
    x = x.sum(2)
    out = x + x[:, 4]
    check_schedule(out, 2)

  def test_reduce_shrink(self):
    x = Tensor.empty(32, 32)
    y = Tensor.empty(16)
    x = x.sum(1)
    x = x[:16]
    out = x + y
    check_schedule(out, 2) # TODO: this should be 1

  @unittest.skip("broken due to const folding and two contiguous are different kernels")
  def test_const_no_recompute(self):
    x = Tensor(2) + Tensor(2)
    y = Tensor(2) + Tensor(2)
    out = x.contiguous() + y.contiguous()
    check_schedule(out, 2)
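
# check_schedule prints each kernel's AST whenever a count is off (and always at DEBUG>=3);
# a quick manual sketch for poking at a schedule outside the tests:
#   a, b = Tensor.empty(10), Tensor.empty(10)
#   for item in create_schedule([(a+b).lazydata]):
#     for op in item.ast: print_tree(op)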

if __name__ == '__main__':
  unittest.main(verbosity=2)