diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index e65d133a..ba228f47 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -2,6 +2,9 @@ import unittest from tinygrad import Tensor from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import lower_schedule_item +from tinygrad.codegen.uops import UOpGraph, UOps +from tinygrad.ops import BinaryOps, TernaryOps +from tinygrad.dtype import dtypes # TODO: can copy this in here when we remove it #from tinygrad.ops import get_lazyop_info @@ -50,5 +53,31 @@ class TestUOpsStats(unittest.TestCase): # NOTE: it's hard to assert on the memory here, all depends on caching assert required_mem <= mem + #MULACC should have the same stats as MUL + ADD + def test_mulacc(self): + uops = UOpGraph() + globl = uops.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple()) + o1 = uops.add(UOps.CONST, dtypes.int, tuple(), 1) + o2 = uops.add(UOps.CONST, dtypes.int, tuple(), 2) + u1 = uops.add(UOps.LOAD, dtypes.int, (globl, o1)) + u2 = uops.add(UOps.LOAD, dtypes.int, (globl, o2)) + u3 = uops.add(UOps.CONST, dtypes.int, tuple(), 3) + u4 = uops.add(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL) + u5 = uops.add(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD) + uops.add(UOps.SINK, None, (u5,)) + + uops_fma = UOpGraph() + globl = uops_fma.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple()) + o1 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 1) + o2 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 2) + u1 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o1)) + u2 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o2)) + u3 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 3) + u4 = uops_fma.add(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC) + uops_fma.add(UOps.SINK, None, (u4,)) + + self.assertEqual(uops.flops_mem(), uops_fma.flops_mem()) + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 110627ae..d3e55c7f 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -382,7 +382,7 @@ class UOpGraph: elif u.uop is UOps.ENDRANGE: mults = mult_stack.pop(-1) elif u.uop is UOps.ALU: - flops += mults + flops += mults * (2 if u.arg == TernaryOps.MULACC else 1) elif u.uop is UOps.LOAD: assert u.dtype is not None mem += u.dtype.itemsize * mults