2024-02-20 16:36:30 +08:00
|
|
|
import unittest
|
2024-05-11 11:09:22 +08:00
|
|
|
from tinygrad import Tensor
|
2024-07-19 18:05:33 +08:00
|
|
|
from tinygrad.helpers import getenv
|
2024-03-27 12:02:46 +08:00
|
|
|
from tinygrad.engine.schedule import create_schedule
|
2024-05-11 11:09:22 +08:00
|
|
|
from tinygrad.engine.realize import lower_schedule_item
|
2024-07-11 08:34:50 +08:00
|
|
|
from tinygrad.codegen.uops import flops_mem, UOps, UOp
|
|
|
|
from tinygrad.codegen.uopgraph import UOpGraph
|
2024-05-21 00:06:00 +08:00
|
|
|
from tinygrad.ops import BinaryOps, TernaryOps
|
|
|
|
from tinygrad.dtype import dtypes
|
2024-02-20 16:36:30 +08:00
|
|
|
|
|
|
|
# **************** new FlopCounter ****************
|
|
|
|
|
|
|
|
def get_stats(x:Tensor):
  """Return (op_estimate, mem_estimate) for the program that realizes `x`.

  `x` is scheduled, and only the LAST schedule item is lowered and inspected; any
  earlier items in the schedule are ignored. The two numbers are read straight off
  the lowered item's `prg`. The tests below compare mem_estimate against tensor
  byte counts.
  """
  last_item = create_schedule([x.lazydata])[-1]
  runner = lower_schedule_item(last_item).prg
  return runner.op_estimate, runner.mem_estimate
|
2024-02-20 16:36:30 +08:00
|
|
|
|
|
|
|
class TestUOpsStats(unittest.TestCase):
  """Checks the op/memory estimates reported for lowered kernels and for hand-built UOp graphs."""

  def _assert_ops_in_range(self, ops, expected_ops, max_ratio):
    """Assert expected_ops <= ops <= expected_ops * max_ratio.

    Uses TestCase assertions (not bare `assert`), so the check still runs under
    `python -O` and a failure reports the actual values.
    """
    self.assertGreaterEqual(ops, expected_ops)
    self.assertLessEqual(ops, expected_ops * max_ratio)

  @unittest.skipIf(getenv("PTX"), "wrong in PTX")
  def test_simple_add(self):
    a = Tensor.empty(100,100)
    b = Tensor.empty(100,100)
    c = a+b
    ops, mem = get_stats(c)
    expected_ops = c.numel()
    # read both inputs once and write the output once
    expected_mem = a.nbytes() + b.nbytes() + c.nbytes()
    self.assertEqual(mem, expected_mem)
    # NOTE: ops also include indexing ops, hence the 2x upper slack
    self._assert_ops_in_range(ops, expected_ops, 2)

  def test_simple_add_sq(self):
    a = Tensor.empty(100,100)
    b = Tensor.empty(100,100)
    c = (a+b)*(a+b)
    ops, mem = get_stats(c)
    # one add and one mul per output element
    expected_ops = c.numel()*2
    expected_mem = a.nbytes() + b.nbytes() + c.nbytes()
    self.assertEqual(mem, expected_mem)
    # NOTE: ops also include indexing ops, hence the 2x upper slack
    self._assert_ops_in_range(ops, expected_ops, 2)

  def test_simple_matmul(self):
    a = Tensor.empty(1024,1024)
    b = Tensor.empty(1024,1024)
    c = a@b
    ops, mem = get_stats(c)
    # one multiply and one add per (output element, reduction index) pair
    expected_ops = c.numel() * 1024 * 2
    required_mem = a.nbytes() + b.nbytes() + c.nbytes()
    self._assert_ops_in_range(ops, expected_ops, 1.2)
    # NOTE: it's hard to assert on the memory here, all depends on caching;
    # only the lower bound (touch every input and output byte at least once) is checked
    self.assertLessEqual(required_mem, mem)

  @staticmethod
  def _mulacc_operands():
    """Build fresh int operands (load, load, const 3): two LOADs from one DEFINE_GLOBAL at constant offsets 1 and 2."""
    globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
    o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
    o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
    u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
    u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
    u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
    return u1, u2, u3

  # MULACC should have the same stats as MUL + ADD
  def test_mulacc(self):
    # each graph is built from its own freshly constructed operands
    u1, u2, u3 = self._mulacc_operands()
    mul = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
    add = UOp(UOps.ALU, dtypes.int, (mul,u3), BinaryOps.ADD)
    uops = UOpGraph([add])

    f1, f2, f3 = self._mulacc_operands()
    fma = UOp(UOps.ALU, dtypes.int, (f1,f2,f3), TernaryOps.MULACC)
    uops_fma = UOpGraph([fma])

    self.assertEqual(flops_mem(uops.uops), flops_mem(uops_fma.uops))
|
2024-05-21 00:06:00 +08:00
|
|
|
|
|
|
|
|
2024-02-20 16:36:30 +08:00
|
|
|
if __name__ == '__main__':
  # Run every test in this module when executed as a script; verbosity=2 prints each test name and its result.
  unittest.main(verbosity=2)
|