mirror of https://github.com/commaai/tinygrad.git
Fix flop counting for mulacc (#4640)
* Fix flop counting for mulacc * add test_simple_mulacc * Update test_uops_stats.py * Update test_uops_stats.py * revert test_mulacc * Test for MULACC vs MUL+ADD
This commit is contained in:
parent
b144d4b460
commit
1e7b7b2c3c
|
@ -2,6 +2,9 @@ import unittest
|
|||
from tinygrad import Tensor
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import lower_schedule_item
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps
|
||||
from tinygrad.ops import BinaryOps, TernaryOps
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
# TODO: can copy this in here when we remove it
|
||||
#from tinygrad.ops import get_lazyop_info
|
||||
|
@ -50,5 +53,31 @@ class TestUOpsStats(unittest.TestCase):
|
|||
# NOTE: it's hard to assert on the memory here, all depends on caching
|
||||
assert required_mem <= mem
|
||||
|
||||
#MULACC should have the same stats as MUL + ADD
|
||||
def test_mulacc(self):
|
||||
uops = UOpGraph()
|
||||
globl = uops.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
|
||||
o1 = uops.add(UOps.CONST, dtypes.int, tuple(), 1)
|
||||
o2 = uops.add(UOps.CONST, dtypes.int, tuple(), 2)
|
||||
u1 = uops.add(UOps.LOAD, dtypes.int, (globl, o1))
|
||||
u2 = uops.add(UOps.LOAD, dtypes.int, (globl, o2))
|
||||
u3 = uops.add(UOps.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = uops.add(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
|
||||
u5 = uops.add(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
|
||||
uops.add(UOps.SINK, None, (u5,))
|
||||
|
||||
uops_fma = UOpGraph()
|
||||
globl = uops_fma.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
|
||||
o1 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 1)
|
||||
o2 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 2)
|
||||
u1 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o1))
|
||||
u2 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o2))
|
||||
u3 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = uops_fma.add(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
|
||||
uops_fma.add(UOps.SINK, None, (u4,))
|
||||
|
||||
self.assertEqual(uops.flops_mem(), uops_fma.flops_mem())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
@ -382,7 +382,7 @@ class UOpGraph:
|
|||
elif u.uop is UOps.ENDRANGE:
|
||||
mults = mult_stack.pop(-1)
|
||||
elif u.uop is UOps.ALU:
|
||||
flops += mults
|
||||
flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
|
||||
elif u.uop is UOps.LOAD:
|
||||
assert u.dtype is not None
|
||||
mem += u.dtype.itemsize * mults
|
||||
|
|
Loading…
Reference in New Issue