mirror of https://github.com/commaai/tinygrad.git
add tests for uops stats (#5649)
* add tests for uops stats * no locals skip is fine * eh
This commit is contained in:
parent
4f83da626e
commit
7c4b177e3a
|
@ -7,6 +7,7 @@ from tinygrad.codegen.uops import flops_mem, UOps, UOp
|
|||
from tinygrad.codegen.uopgraph import UOpGraph
|
||||
from tinygrad.ops import BinaryOps, TernaryOps
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
|
||||
|
||||
# **************** new FlopCounter ****************
|
||||
|
||||
|
@ -118,6 +119,67 @@ class TestUOpsStats(unittest.TestCase):
|
|||
|
||||
self.assertEqual(flops_mem(uops.uops), flops_mem(uops_fma.uops))
|
||||
|
||||
# side length of the square NxN matrices multiplied by the TestStatsOptimized GEMM tests
N = 100
|
||||
@unittest.skipIf(getenv("PTX"), "wrong in PTX") # maybe?
class TestStatsOptimized(unittest.TestCase):
  """Checks op/mem/lds estimates of an NxN @ NxN GEMM program under several kernel opts."""

  @classmethod
  def setUpClass(cls):
    # build the matmul once; every test starts a fresh Kernel from the AST of the last scheduled item
    lhs, rhs = Tensor.empty(N, N), Tensor.empty(N, N)
    cls.ast_gemm = (lhs @ rhs).schedule()[-1].ast

  def check_gemm(self, p, extra_flops=0):
    # shared assertions: prints the program's estimates, then checks the op and mem estimates
    # debugging aids: p.uops.print() / print(p.src)
    print(p.name, p.op_estimate, p.mem_estimate, p.lds_estimate)
    expected_ops = 2*N*N*N + extra_flops  # N**3 mulaccs
    expected_mem = 3*N*N*4  # 3 NxN mats with floats
    self.assertEqual(p.op_estimate, expected_ops)
    self.assertEqual(p.mem_estimate, expected_mem)

  def test_gemm(self):
    # no opts applied
    prog = Kernel(self.ast_gemm).to_program()
    self.check_gemm(prog)
    self.assertEqual(prog.lds_estimate, 2*N*N*N*4 + 4*N*N)

  # this is a good lesson about why UPCASTing is a good idea

  def test_gemm_one_upcasted(self):
    # a single UPCAST opt
    kernel = Kernel(self.ast_gemm)
    kernel.apply_opt(Opt(OptOps.UPCAST, 0, 4))
    prog = kernel.to_program()
    self.check_gemm(prog)
    self.assertEqual(prog.lds_estimate, N*N*N*4 + N*N*N*4//4 + 4*N*N)

  def test_gemm_upcasted(self):
    # two UPCAST opts followed by an UNROLL
    kernel = Kernel(self.ast_gemm)
    for opt in (Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)):
      kernel.apply_opt(opt)
    prog = kernel.to_program()
    self.check_gemm(prog)
    self.assertEqual(prog.lds_estimate, 2*N*N*N*4//4 + 4*N*N)

  def test_gemm_upcasted_locals(self):
    # two UPCAST opts, then two LOCAL opts; skipped when a LOCAL opt is rejected
    kernel = Kernel(self.ast_gemm)
    for opt in (Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)):
      kernel.apply_opt(opt)
    try:
      for opt in (Opt(OptOps.LOCAL, 0, 5), Opt(OptOps.LOCAL, 1, 5)):
        kernel.apply_opt(opt)
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    prog = kernel.to_program()
    self.check_gemm(prog)
    self.assertEqual(prog.lds_estimate, 2*N*N*N*4//4 + 4*N*N)

  def test_gemm_group(self):
    # a single GROUP opt; skipped when the opt is rejected
    kernel = Kernel(self.ast_gemm)
    try:
      kernel.apply_opt(Opt(OptOps.GROUP, 0, 4))
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    out_bytes = N*N*4
    prog = kernel.to_program()
    # NOTE: these are sort of wrong. they aren't honoring the IF statement
    self.check_gemm(prog, extra_flops=out_bytes*4)
    self.assertEqual(prog.lds_estimate, 2*N*N*N*4 + out_bytes*4 + (out_bytes*4 + 4*N*N)*4)
|
||||
|
||||
# when executed as a script, run the unittest runner with verbose output
if __name__ == "__main__":
  unittest.main(verbosity=2)
|
||||
|
|
|
@ -179,8 +179,8 @@ def type_verify(uops):
|
|||
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
|
||||
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
|
||||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps:
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
assert dtype.count == 1, f"wide ALU is not supported on {dtype}"
|
||||
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
|
||||
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
|
||||
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
|
@ -190,8 +190,7 @@ def type_verify(uops):
|
|||
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
|
||||
# the distance to shift isn't typechecked
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in BinaryOps:
|
||||
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg == TernaryOps.WHERE:
|
||||
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
|
||||
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
||||
|
@ -216,10 +215,12 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
|||
elif u.op is UOps.STORE:
|
||||
dont_count = dont_count.union(u.src[1].sparents)
|
||||
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
|
||||
elif u.op is UOps.IF:
|
||||
dont_count = dont_count.union(u.src[0].sparents)
|
||||
for u in uops:
|
||||
if u.op is UOps.RANGE:
|
||||
mult_stack.append(mults)
|
||||
mults *= uop_alu_resolve(u.src[1])
|
||||
mults *= uop_alu_resolve(u.src[1]) - uop_alu_resolve(u.src[0])
|
||||
elif u.op is UOps.ENDRANGE:
|
||||
mults = mult_stack.pop(-1)
|
||||
elif u.op is UOps.LOAD:
|
||||
|
|
Loading…
Reference in New Issue