import unittest
from tinygrad import Tensor
from tinygrad.helpers import getenv, GlobalCounters
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
from tinygrad.dtype import dtypes
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError

# **************** new FlopCounter ****************

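# get_stats schedules a tensor and lowers the final schedule item without executing it,
# returning the kernel's (op_estimate, mem_estimate): roughly the ALU ops it performs
# and the bytes it moves to and from global buffers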
def get_stats(x:Tensor):
  si = create_schedule([x.lazydata])[-1]
  ei = lower_schedule_item(si)
  return ei.prg.op_estimate, ei.prg.mem_estimate

class TestMemoryCount(unittest.TestCase):
  def test_add(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    b = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+b)
    self.assertEqual(mem, 1024*1024*3)  # 2 reads + 1 write

  def test_add_const(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+3)
    self.assertEqual(mem, 1024*1024*2)  # 1 read + 1 write

  def test_add_slice(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)[:512]
    _, mem = get_stats(a+3)
    self.assertEqual(mem, 512*1024*2)  # 1 read + 1 write

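  # for broadcast (expanded) inputs, only the underlying buffer's bytes are counted,
  # not one read per output element, as the assertions below check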
  def test_expanded(self):
    a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
    b = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+b)
    self.assertEqual(mem, 1024*1024*2 + 1024)  # 1 full read + 1 lil read + 1 write

  def test_both_expanded(self):
    # TODO: this probably should be a full write
    a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
    b = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
    _, mem = get_stats(a+b)
    self.assertEqual(mem, 1024*1024 + 2*1024)  # 2 lil reads + 1 write

  def test_self_add(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+a)
    self.assertEqual(mem, 1024*1024*2)  # 1 read + 1 write

  def test_self_add_transposed(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+a.T)
    self.assertEqual(mem, 1024*1024*2)  # 1 read + 1 write

  def test_self_add_assign(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8).realize()
    _, mem = get_stats(a.assign(a+a))
    self.assertEqual(mem, 1024*1024*2)  # 1 read + 1 write

# NOTE: this still isn't testing unroll using the acc
@unittest.skipUnless(getenv("PYTHON"), "only run test on emulated tensor cores")
class TestUOpsStatsMatmulHalf(unittest.TestCase):
  def test_simple_matmul_half(self, N=16):
    GlobalCounters.reset()
    a, b = Tensor.empty(N, N, dtype=dtypes.half), Tensor.empty(N, N, dtype=dtypes.half)
    c = a.matmul(b)
    c.realize()
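    # each of the N*N outputs takes N multiply-accumulates, and a MAC counts as 2 ops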
    expected_ops = N ** 3 * 2
    self.assertEqual(expected_ops, GlobalCounters.global_ops)

  def test_bigger_matmul_half(self): self.test_simple_matmul_half(64)

  def test_batched_matmul_half(self, N=16):
    GlobalCounters.reset()
    a, b = Tensor.empty(4, N, N, dtype=dtypes.half), Tensor.empty(1, N, N, dtype=dtypes.half)
    c = a.matmul(b)
    c.realize()
    expected_ops = 4 * N ** 3 * 2
    self.assertEqual(expected_ops, GlobalCounters.global_ops)

class TestUOpsStats(unittest.TestCase):
  @unittest.skipIf(getenv("PTX"), "wrong in PTX")
  def test_simple_add(self):
    a = Tensor.empty(100,100)
    b = Tensor.empty(100,100)
    c = a+b
    ops, mem = get_stats(c)
    expected_ops = c.numel()
    expected_mem = a.nbytes() + b.nbytes() + c.nbytes()
    self.assertEqual(mem, expected_mem)
    # NOTE: ops also include indexing ops
    assert expected_ops <= ops and ops <= expected_ops * 2

  @unittest.skipIf(getenv("PTX"), "wrong in PTX")
  def test_simple_add_sq(self):
    a = Tensor.empty(100,100)
    b = Tensor.empty(100,100)
    c = (a+b)*(a+b)
    ops, mem = get_stats(c)
    expected_ops = c.numel()*2  # per output element: the shared (a+b) and the mul
    expected_mem = a.nbytes() + b.nbytes() + c.nbytes()
    self.assertEqual(mem, expected_mem)
    # NOTE: ops also include indexing ops
    assert expected_ops <= ops and ops <= expected_ops * 2

  def test_simple_matmul(self):
    a = Tensor.empty(1024,1024)
    b = Tensor.empty(1024,1024)
    c = a@b
    ops, mem = get_stats(c)
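    # 1024 MACs per output element, 2 ops (mul + add) each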
    expected_ops = c.numel() * 1024 * 2
    required_mem = a.nbytes() + b.nbytes() + c.nbytes()
    assert expected_ops <= ops and ops <= expected_ops * 1.2
    # NOTE: it's hard to assert on the memory here, all depends on caching
    assert required_mem <= mem

  # MULACC should have the same stats as MUL + ADD
  def test_mulacc(self):
    # a tiny UOp graph built by hand: two loads from a global buffer, then a separate mul and add
    globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
    o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
    o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
    u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
    u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
    u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
    u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
    u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
    uops = linearize_uop(u5.sink())

    # the same graph with the mul and add fused into a single MULACC
    globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
    o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
    o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
    u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
    u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
    u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
    u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
    uops_fma = linearize_uop(u4.sink())

    self.assertEqual(flops_mem(uops), flops_mem(uops_fma))

N = 100
@unittest.skipIf(getenv("PTX"), "wrong in PTX")  # maybe?
class TestStatsOptimized(unittest.TestCase):
  @classmethod
  def setUpClass(cls):
    cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
    cls.ast_reduce = (Tensor.empty(N*N).sum()).schedule()[-1].ast

  def check_gemm(self, p, extra_flops=0):
    #p.uops.print()
    #print(p.src)
    print(p.name, p.op_estimate, p.mem_estimate, p.lds_estimate)
    self.assertEqual(p.op_estimate, 2*N*N*N + extra_flops)  # N**3 mulaccs, 2 ops each
    self.assertEqual(p.mem_estimate, 3*N*N*4)  # 3 NxN matrices of 4-byte floats

  def test_gemm(self):
    p = Kernel(self.ast_gemm).to_program()
    self.check_gemm(p)
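    # unoptimized: every MAC issues two 4-byte loads (2*N*N*N*4), plus the N*N float output store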
    self.assertEqual(p.lds_estimate, 2*N*N*N*4 + 4*N*N)

  # this is a good lesson about why UPCASTing is a good idea
  def test_gemm_one_upcasted(self):
    k = Kernel(self.ast_gemm)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
    p = k.to_program()
    self.check_gemm(p)
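    # upcasting dim 0 by 4 reuses one operand's loaded values across 4 outputs, cutting its loads 4x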
    self.assertEqual(p.lds_estimate, N*N*N*4 + N*N*N*4//4 + 4*N*N)

  def test_gemm_upcasted(self):
    k = Kernel(self.ast_gemm)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
    k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
    k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
    p = k.to_program()
    self.check_gemm(p)
    self.assertEqual(p.lds_estimate, 2*N*N*N*4//4 + 4*N*N)  # both operands' loads cut 4x

  def test_gemm_upcasted_locals(self):
    k = Kernel(self.ast_gemm)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
    k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
    try:
      k.apply_opt(Opt(OptOps.LOCAL, 0, 5))
      k.apply_opt(Opt(OptOps.LOCAL, 1, 5))
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    p = k.to_program()
    self.check_gemm(p)
    self.assertEqual(p.lds_estimate, 2*N*N*N*4//4 + 4*N*N)

  def test_gemm_group(self):
    k = Kernel(self.ast_gemm)
    try:
      k.apply_opt(Opt(OptOps.GROUP, 0, 4))
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    SZ = N*N*4
    p = k.to_program()
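    # GROUP splits the reduce into two stages through a local buffer, which shows up
    # below as extra flops and extra local loads/stores in the estimates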
    # NOTE: these are sort of wrong. they aren't honoring the IF statement
    self.check_gemm(p, extra_flops=SZ*4)
    self.assertEqual(p.lds_estimate, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)

  def test_reduce(self):
    k = Kernel(self.ast_reduce)
    p = k.to_program()
    print(p.name, p.op_estimate, p.mem_estimate, p.lds_estimate)
    self.assertEqual(p.op_estimate, N*N)  # one add per input element
    self.assertEqual(p.mem_estimate, N*N*4 + 4)  # read N*N floats, write a single float

  def test_reduce_group(self):
    k = Kernel(self.ast_reduce)
    try:
      k.apply_opt(Opt(OptOps.GROUP, 0, 50))
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    p = k.to_program()
    # NOTE: these are wrong, they don't respect the if statement
    print(p.name, p.op_estimate, p.mem_estimate, p.lds_estimate)

if __name__ == '__main__':
  unittest.main(verbosity=2)