add tests for uops stats (#5649)

* add tests for uops stats

* no locals skip is fine

* eh
This commit is contained in:
George Hotz 2024-07-22 21:57:03 -07:00 committed by GitHub
parent 4f83da626e
commit 7c4b177e3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 5 deletions

View File

@ -7,6 +7,7 @@ from tinygrad.codegen.uops import flops_mem, UOps, UOp
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.dtype import dtypes
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
# **************** new FlopCounter ****************
@ -118,6 +119,67 @@ class TestUOpsStats(unittest.TestCase):
self.assertEqual(flops_mem(uops.uops), flops_mem(uops_fma.uops))
N = 100
@unittest.skipIf(getenv("PTX"), "wrong in PTX") # maybe?
class TestStatsOptimized(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
def check_gemm(self, p, extra_flops=0):
#p.uops.print()
#print(p.src)
print(p.name, p.op_estimate, p.mem_estimate, p.lds_estimate)
self.assertEqual(p.op_estimate, 2*N*N*N + extra_flops) # N**3 mulaccs
self.assertEqual(p.mem_estimate, 3*N*N*4) # 3 NxN mats with floats
def test_gemm(self):
p = Kernel(self.ast_gemm).to_program()
self.check_gemm(p)
self.assertEqual(p.lds_estimate, 2*N*N*N*4 + 4*N*N)
# this is a good lesson about why UPCASTing is a good idea
def test_gemm_one_upcasted(self):
k = Kernel(self.ast_gemm)
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
p = k.to_program()
self.check_gemm(p)
self.assertEqual(p.lds_estimate, N*N*N*4 + N*N*N*4//4 + 4*N*N)
def test_gemm_upcasted(self):
k = Kernel(self.ast_gemm)
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
p = k.to_program()
self.check_gemm(p)
self.assertEqual(p.lds_estimate, 2*N*N*N*4//4 + 4*N*N)
def test_gemm_upcasted_locals(self):
k = Kernel(self.ast_gemm)
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
try:
k.apply_opt(Opt(OptOps.LOCAL, 0, 5))
k.apply_opt(Opt(OptOps.LOCAL, 1, 5))
except KernelOptError:
raise unittest.SkipTest("no locals")
p = k.to_program()
self.check_gemm(p)
self.assertEqual(p.lds_estimate, 2*N*N*N*4//4 + 4*N*N)
def test_gemm_group(self):
k = Kernel(self.ast_gemm)
try:
k.apply_opt(Opt(OptOps.GROUP, 0, 4))
except KernelOptError:
raise unittest.SkipTest("no locals")
SZ = N*N*4
p = k.to_program()
# NOTE: these are sort of wrong. they aren't honoring the IF statement
self.check_gemm(p, extra_flops=SZ*4)
self.assertEqual(p.lds_estimate, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@ -179,8 +179,8 @@ def type_verify(uops):
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
if uop is UOps.ALU:
if arg in UnaryOps:
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
assert dtype.count == 1, f"wide ALU is not supported on {dtype}"
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
@ -190,8 +190,7 @@ def type_verify(uops):
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
# the distance to shift isn't typechecked
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in BinaryOps:
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg == TernaryOps.WHERE:
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
@ -216,10 +215,12 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
elif u.op is UOps.STORE:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
elif u.op is UOps.IF:
dont_count = dont_count.union(u.src[0].sparents)
for u in uops:
if u.op is UOps.RANGE:
mult_stack.append(mults)
mults *= uop_alu_resolve(u.src[1])
mults *= uop_alu_resolve(u.src[1]) - uop_alu_resolve(u.src[0])
elif u.op is UOps.ENDRANGE:
mults = mult_stack.pop(-1)
elif u.op is UOps.LOAD: