Fix flop coutning for mulacc (#4640)

* Fix flop coutning for mulacc

* add test_simple_mulacc

* Update test_uops_stats.py

* Update test_uops_stats.py

* revert test_mulacc

* Test for MULACC vs MUL+ADD
This commit is contained in:
Szymon Ożóg 2024-05-20 18:06:00 +02:00 committed by GitHub
parent b144d4b460
commit 1e7b7b2c3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 1 deletions

View File

@ -2,6 +2,9 @@ import unittest
from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.dtype import dtypes
# TODO: can copy this in here when we remove it
#from tinygrad.ops import get_lazyop_info
@ -50,5 +53,31 @@ class TestUOpsStats(unittest.TestCase):
# NOTE: it's hard to assert on the memory here, all depends on caching
assert required_mem <= mem
#MULACC should have the same stats as MUL + ADD
def test_mulacc(self):
uops = UOpGraph()
globl = uops.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
o1 = uops.add(UOps.CONST, dtypes.int, tuple(), 1)
o2 = uops.add(UOps.CONST, dtypes.int, tuple(), 2)
u1 = uops.add(UOps.LOAD, dtypes.int, (globl, o1))
u2 = uops.add(UOps.LOAD, dtypes.int, (globl, o2))
u3 = uops.add(UOps.CONST, dtypes.int, tuple(), 3)
u4 = uops.add(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = uops.add(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
uops.add(UOps.SINK, None, (u5,))
uops_fma = UOpGraph()
globl = uops_fma.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
o1 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 1)
o2 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 2)
u1 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o1))
u2 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o2))
u3 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 3)
u4 = uops_fma.add(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
uops_fma.add(UOps.SINK, None, (u4,))
self.assertEqual(uops.flops_mem(), uops_fma.flops_mem())
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@ -382,7 +382,7 @@ class UOpGraph:
elif u.uop is UOps.ENDRANGE:
mults = mult_stack.pop(-1)
elif u.uop is UOps.ALU:
flops += mults
flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
elif u.uop is UOps.LOAD:
assert u.dtype is not None
mem += u.dtype.itemsize * mults