From 1e7b7b2c3cf79e3947e52ef0bb08ff39eb710f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com> Date: Mon, 20 May 2024 18:06:00 +0200 Subject: [PATCH] Fix flop coutning for mulacc (#4640) * Fix flop coutning for mulacc * add test_simple_mulacc * Update test_uops_stats.py * Update test_uops_stats.py * revert test_mulacc * Test for MULACC vs MUL+ADD --- test/test_uops_stats.py | 29 +++++++++++++++++++++++++++++ tinygrad/codegen/uops.py | 2 +- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index e65d133a..ba228f47 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -2,6 +2,9 @@ import unittest from tinygrad import Tensor from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import lower_schedule_item +from tinygrad.codegen.uops import UOpGraph, UOps +from tinygrad.ops import BinaryOps, TernaryOps +from tinygrad.dtype import dtypes # TODO: can copy this in here when we remove it #from tinygrad.ops import get_lazyop_info @@ -50,5 +53,31 @@ class TestUOpsStats(unittest.TestCase): # NOTE: it's hard to assert on the memory here, all depends on caching assert required_mem <= mem + #MULACC should have the same stats as MUL + ADD + def test_mulacc(self): + uops = UOpGraph() + globl = uops.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple()) + o1 = uops.add(UOps.CONST, dtypes.int, tuple(), 1) + o2 = uops.add(UOps.CONST, dtypes.int, tuple(), 2) + u1 = uops.add(UOps.LOAD, dtypes.int, (globl, o1)) + u2 = uops.add(UOps.LOAD, dtypes.int, (globl, o2)) + u3 = uops.add(UOps.CONST, dtypes.int, tuple(), 3) + u4 = uops.add(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL) + u5 = uops.add(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD) + uops.add(UOps.SINK, None, (u5,)) + + uops_fma = UOpGraph() + globl = uops_fma.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple()) + o1 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 1) + o2 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 2) + u1 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o1)) + u2 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o2)) + u3 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 3) + u4 = uops_fma.add(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC) + uops_fma.add(UOps.SINK, None, (u4,)) + + self.assertEqual(uops.flops_mem(), uops_fma.flops_mem()) + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 110627ae..d3e55c7f 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -382,7 +382,7 @@ class UOpGraph: elif u.uop is UOps.ENDRANGE: mults = mult_stack.pop(-1) elif u.uop is UOps.ALU: - flops += mults + flops += mults * (2 if u.arg == TernaryOps.MULACC else 1) elif u.uop is UOps.LOAD: assert u.dtype is not None mem += u.dtype.itemsize * mults