From 7c4b177e3abfe00beec968f3b0789f092480e4a2 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 22 Jul 2024 21:57:03 -0700
Subject: [PATCH] add tests for uops stats (#5649)

* add tests for uops stats

* no locals skip is fine

* eh
---
 test/test_uops_stats.py  | 62 ++++++++++++++++++++++++++++++++++++++++
 tinygrad/codegen/uops.py | 11 +++----
 2 files changed, 68 insertions(+), 5 deletions(-)

diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py
index 449622d1..0d2a7007 100644
--- a/test/test_uops_stats.py
+++ b/test/test_uops_stats.py
@@ -7,6 +7,7 @@ from tinygrad.codegen.uops import flops_mem, UOps, UOp
 from tinygrad.codegen.uopgraph import UOpGraph
 from tinygrad.ops import BinaryOps, TernaryOps
 from tinygrad.dtype import dtypes
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
 
 # **************** new FlopCounter ****************
 
@@ -118,6 +119,67 @@ class TestUOpsStats(unittest.TestCase):
 
     self.assertEqual(flops_mem(uops.uops), flops_mem(uops_fma.uops))
 
+N = 100
+@unittest.skipIf(getenv("PTX"), "wrong in PTX") # maybe?
+class TestStatsOptimized(unittest.TestCase):
+  @classmethod
+  def setUpClass(cls):
+    cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
+
+  def check_gemm(self, p, extra_flops=0):
+    #p.uops.print()
+    #print(p.src)
+    print(p.name, p.op_estimate, p.mem_estimate, p.lds_estimate)
+    self.assertEqual(p.op_estimate, 2*N*N*N + extra_flops) # N**3 mulaccs
+    self.assertEqual(p.mem_estimate, 3*N*N*4) # 3 NxN mats with floats
+
+  def test_gemm(self):
+    p = Kernel(self.ast_gemm).to_program()
+    self.check_gemm(p)
+    self.assertEqual(p.lds_estimate, 2*N*N*N*4 + 4*N*N)
+
+  # this is a good lesson about why UPCASTing is a good idea
+
+  def test_gemm_one_upcasted(self):
+    k = Kernel(self.ast_gemm)
+    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
+    p = k.to_program()
+    self.check_gemm(p)
+    self.assertEqual(p.lds_estimate, N*N*N*4 + N*N*N*4//4 + 4*N*N)
+
+  def test_gemm_upcasted(self):
+    k = Kernel(self.ast_gemm)
+    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
+    k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
+    k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
+    p = k.to_program()
+    self.check_gemm(p)
+    self.assertEqual(p.lds_estimate, 2*N*N*N*4//4 + 4*N*N)
+
+  def test_gemm_upcasted_locals(self):
+    k = Kernel(self.ast_gemm)
+    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
+    k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
+    try:
+      k.apply_opt(Opt(OptOps.LOCAL, 0, 5))
+      k.apply_opt(Opt(OptOps.LOCAL, 1, 5))
+    except KernelOptError:
+      raise unittest.SkipTest("no locals")
+    p = k.to_program()
+    self.check_gemm(p)
+    self.assertEqual(p.lds_estimate, 2*N*N*N*4//4 + 4*N*N)
+
+  def test_gemm_group(self):
+    k = Kernel(self.ast_gemm)
+    try:
+      k.apply_opt(Opt(OptOps.GROUP, 0, 4))
+    except KernelOptError:
+      raise unittest.SkipTest("no locals")
+    SZ = N*N*4
+    p = k.to_program()
+    # NOTE: these are sort of wrong. they aren't honoring the IF statement
+    self.check_gemm(p, extra_flops=SZ*4)
+    self.assertEqual(p.lds_estimate, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
 
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index b487540b..a047d570 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -179,8 +179,8 @@ def type_verify(uops):
       assert dtype is None, f"{uop} dtype must be None, got {dtype}"
       if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
     if uop is UOps.ALU:
-      if arg in UnaryOps:
-        assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
+      assert dtype.count == 1, f"wide ALU is not supported on {dtype}"
+      if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
       elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
         assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
         assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
@@ -190,8 +190,7 @@ def type_verify(uops):
       elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
         # the distance to shift isn't typechecked
         assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
-      elif arg in BinaryOps:
-        assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
+      elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
       elif arg == TernaryOps.WHERE:
         assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
         assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
@@ -216,10 +215,12 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
       elif u.op is UOps.STORE:
         dont_count = dont_count.union(u.src[1].sparents)
         if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
+      elif u.op is UOps.IF:
+        dont_count = dont_count.union(u.src[0].sparents)
   for u in uops:
     if u.op is UOps.RANGE:
       mult_stack.append(mults)
-      mults *= uop_alu_resolve(u.src[1])
+      mults *= uop_alu_resolve(u.src[1]) - uop_alu_resolve(u.src[0])
     elif u.op is UOps.ENDRANGE:
       mults = mult_stack.pop(-1)
     elif u.op is UOps.LOAD:
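
A quick plain-Python sanity check of the numbers these tests expect (illustrative only, independent of tinygrad; N, the 4-byte float32 size, and the 4x upcast factor are taken from the tests above):

N = 100

# check_gemm: one multiply-accumulate (2 flops) per point of the N*N*N space.
op_estimate = 2 * N * N * N                        # 2,000,000 flops

# check_gemm: two NxN float32 inputs plus one NxN float32 output, once each.
mem_estimate = 3 * N * N * 4                       # 120,000 bytes

# test_gemm: every output element loads N floats from each input matrix
# (2*N**3 loads) and stores once (N*N stores).
lds_naive = 2 * N * N * N * 4 + 4 * N * N          # 8,040,000 bytes

# test_gemm_one_upcasted: UPCAST 0 by 4 reuses one input's loads across four
# outputs, cutting that matrix's load traffic by 4x.
lds_one_upcast = N * N * N * 4 + N * N * N * 4 // 4 + 4 * N * N  # 5,040,000

# test_gemm_upcasted: upcasting both output dims gives both inputs the 4x cut.
lds_upcasted = 2 * N * N * N * 4 // 4 + 4 * N * N  # 2,040,000 bytes

assert (op_estimate, mem_estimate) == (2_000_000, 120_000)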
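
And a minimal sketch of the accounting change in the flops_mem hunk: flops inside a RANGE now scale by the loop's trip count, end - start, instead of end alone (the same pass also stops counting ALU ops that only feed an IF condition, mirroring the existing LOAD/STORE gate exclusions). The tuple-based op encoding below is a stand-in for illustration, not tinygrad's UOp type:

from typing import List, Tuple

def count_flops(ops: List[Tuple[str, tuple]]) -> int:
  # Mirrors the RANGE/ENDRANGE multiplier stack in flops_mem.
  flops, mults, stack = 0, 1, []
  for op, arg in ops:
    if op == "RANGE":
      start, end = arg
      stack.append(mults)
      mults *= end - start   # the fix: this used to be `mults *= end`
    elif op == "ENDRANGE":
      mults = stack.pop(-1)
    elif op == "ALU":
      flops += mults         # one op per iteration of all enclosing loops
  return flops

# A loop over [2, 10) runs 8 times, so a single ALU op inside counts as 8
# flops; before the patch it would have been reported as 10.
assert count_flops([("RANGE", (2, 10)), ("ALU", ()), ("ENDRANGE", ())]) == 8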