add constant folding for WHERE in uops (#3584)

* add constant folding for WHERE in uops

* prereqs for generic constant folding

* fix test

* disable slow overflow logic

* make that test faster
This commit is contained in:
George Hotz 2024-03-02 10:37:14 -08:00 committed by GitHub
parent 3b7e3fa2e4
commit aa9b013d79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 56 additions and 29 deletions

View File

@ -27,6 +27,14 @@ class TestLinearizer(unittest.TestCase):
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)
def test_load_removed(self):
a = Tensor.rand(1).realize()
b = Tensor.rand(1).realize()
ta = Tensor.where(Tensor(True), a, b).numpy()
tb = Tensor.where(Tensor(False), a, b).numpy()
np.testing.assert_equal(a.numpy(), ta)
np.testing.assert_equal(b.numpy(), tb)
def test_load_dedup(self):
# for different leaves in the AST, the same loads may occur.
@ -209,7 +217,7 @@ class TestLinearizer(unittest.TestCase):
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0)
assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1),
arg=TernaryOps.WHERE).uop == UOps.ALU
arg=TernaryOps.WHERE).uop == UOps.CONST
def helper_realized_ast(r:Tensor):
s = create_schedule([r.lazydata])

View File

@ -1082,7 +1082,7 @@ class TestOps(unittest.TestCase):
@unittest.skipIf(IMAGE>0, "no conv3d on images")
def test_padded_conv3d(self):
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
helper_test_op([(1,4,5,5,5), (4,4,3,3,3)],
lambda x,w: torch.nn.functional.conv3d(x,w,padding=1).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]).relu(), atol=1e-4, grad_rtol=1e-5)

View File

@ -7,7 +7,7 @@ from tinygrad.device import Buffer, Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.device import CompiledASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.runtime.ops_python import exec_alu
from tinygrad.codegen.uops import exec_alu
from test.test_dtype import is_dtype_supported
def _uops_to_prg(uops):
@ -113,5 +113,21 @@ class TestExecALU(TestUOps):
def test_sqrt(self):
self.assertEqual(exec_alu(UnaryOps.SQRT, dtypes.int, (0,)), 0)
@unittest.skip("not enabled because it's slow")
def test_overflow(self):
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0)
self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.uint8, (0, 1)), 255)
self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.uint8, (0, 1000)), 24)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (127, 0)), 127)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-100, 100)), 56)
self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-1000, 0)), 24)
self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-130, 0)), 126)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1.0, 1.0)), 2)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-math.exp2(7), 0)), -128)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@ -1,5 +1,5 @@
from __future__ import annotations
import functools
import functools, math
from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict
from collections import defaultdict
from tinygrad.helpers import DEBUG, flatten, all_same
@ -25,6 +25,30 @@ class UOp:
def __repr__(self):
return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
def exec_alu(arg, dtype, p):
if arg == TernaryOps.WHERE: ret = p[1] if p[0] else p[2]
elif arg == UnaryOps.LOG2: ret = math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
elif arg == UnaryOps.EXP2:
try: ret = math.exp(p[0]*math.log(2))
except OverflowError: ret = math.inf
elif arg == UnaryOps.SQRT: ret = math.sqrt(p[0]) if p[0] >= 0 else math.nan
elif arg == UnaryOps.SIN: ret = math.sin(p[0])
elif arg == UnaryOps.NEG: ret = -p[0]
elif arg == BinaryOps.MUL: ret = p[0]*p[1]
elif arg == BinaryOps.ADD: ret = p[0]+p[1]
elif arg == BinaryOps.SUB: ret = p[0]-p[1]
elif arg == BinaryOps.XOR: ret = p[0]^p[1]
elif arg == BinaryOps.MAX: ret = max(p[0], p[1])
elif arg == BinaryOps.CMPEQ: ret = p[0] == p[1]
elif arg == BinaryOps.CMPLT: ret = p[0] < p[1]
elif arg == BinaryOps.DIV: ret = p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan)
elif arg == BinaryOps.MOD: ret = p[0]%p[1]
return ret
#else: raise NotImplementedError(f"no support for {arg}")
#if not dtypes.is_int(dtype): return ret
#adjusted = 0 if dtypes.is_unsigned(dtype) else 2 ** (dtype.itemsize * 8 - 1)
#return (ret + adjusted) % 2 ** (dtype.itemsize * 8) - adjusted
def uop_alu_resolve(u:UOp) -> sint:
if u.uop == UOps.CONST: return u.arg
elif u.uop == UOps.DEFINE_VAR: return u.arg
@ -68,6 +92,7 @@ class UOpGraph:
# constant folding
if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=-vin[0].arg, insert_before=insert_before)
if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2]
if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
return self.add(UOps.CONST, dtype, arg=vin[0].arg * vin[1].arg, insert_before=insert_before)
# zero folding

View File

@ -2,36 +2,14 @@
# works to test the tensor cores, and all the uops in general
# this is the (living) definition of uops
from typing import Tuple, List, Optional, Any, Dict
import pickle, base64, itertools, time, math, struct
import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Allocator, Compiler
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.codegen.uops import UOp, UOps, exec_alu
from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.codegen.kernel import LinearizerOptions
def exec_alu(arg, dtype, p):
# TODO: make this complete and correctly honor the dtypes
# TODO: use this for constant folding
if arg == TernaryOps.WHERE: return p[1] if p[0] else p[2]
if arg == UnaryOps.LOG2: return math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
if arg == UnaryOps.EXP2:
try: return math.exp(p[0]*math.log(2))
except OverflowError: return math.inf
if arg == UnaryOps.SQRT: return math.sqrt(p[0]) if p[0] >= 0 else math.nan
if arg == UnaryOps.SIN: return math.sin(p[0])
if arg == UnaryOps.NEG: return -p[0]
if arg == BinaryOps.MUL: return p[0]*p[1]
if arg == BinaryOps.ADD: return p[0]+p[1]
if arg == BinaryOps.SUB: return p[0]-p[1]
if arg == BinaryOps.XOR: return p[0]^p[1]
if arg == BinaryOps.MAX: return max(p[0], p[1])
if arg == BinaryOps.CMPEQ: return p[0] == p[1]
if arg == BinaryOps.CMPLT: return p[0] < p[1]
if arg == BinaryOps.DIV: return p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan)
if arg == BinaryOps.MOD: return p[0]%p[1]
raise NotImplementedError(f"no support for {arg}")
def _load(m, i):
if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
return m[i]