import unittest from tinygrad import Tensor, GlobalCounters from tinygrad.ops import UOps from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv from tinygrad.codegen.kernel import Kernel from tinygrad.engine.schedule import create_schedule class TestWinograd(unittest.TestCase): def setUp(self): self.old = WINO.value WINO.value = 1 def tearDown(self): WINO.value = self.old def test_speed(self): x = Tensor.empty(1,4,9,9) w = Tensor.empty(4,4,3,3) with Timing("running conv: "): out = Tensor.conv2d(x, w) with Timing("scheduling: "): sched = create_schedule([out.lazydata]) for i,s in enumerate(sched): if s.ast.op is not UOps.SINK: continue ops = s.ast.parents with Timing(f"linearize {i} with {len(ops):4d} ops: "): l = Kernel(s.ast) l.hand_coded_optimizations() l.linearize() assert len(l.sts) <= 256 # just the current value to prevent regression if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views") for st in l.sts: assert len(st.views) <= 2, "too many views in winograd" if DEBUG >= 3: print(f"{len(st.views):3d} views") for v in st.views: print(v) def test_profile(self): x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize() with Profiling(enabled=not CI, sort='time'): out = Tensor.conv2d(x,w).realize() out.numpy() def test_four_kernels(self): x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize() GlobalCounters.reset() out = Tensor.conv2d(x,w).realize() assert GlobalCounters.kernel_count == 4 out.numpy() @unittest.skipIf(getenv("PTX"), "winograd uses too much in PTX") def test_counters(self): IC, OC, X, Y = 4,4,9,9 #OC, IC, X, Y = 512, 256, 8, 8 x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize() GlobalCounters.reset() Tensor.conv2d(x,w).realize() ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem WINO.value = 0 GlobalCounters.reset() Tensor.conv2d(x,w).realize() ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}") print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}") self.assertLess(ops_ratio, 2.6) # TODO: there's issues with factorization now self.assertLess(mem_ratio, 10) if __name__ == '__main__': unittest.main(verbosity=2)