mirror of https://github.com/commaai/tinygrad.git
add constant folding for WHERE in uops (#3584)
* add constant folding for WHERE in uops
* prereqs for generic constant folding
* fix test
* disable slow overflow logic
* make that test faster
This commit is contained in:
parent
3b7e3fa2e4
commit
aa9b013d79
|
@ -27,6 +27,14 @@ class TestLinearizer(unittest.TestCase):
|
|||
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
|
||||
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_load_removed(self):
  """WHERE with a constant True/False condition must return exactly the selected operand."""
  lhs = Tensor.rand(1).realize()
  rhs = Tensor.rand(1).realize()
  # evaluate the True case first, then the False case
  picked = {cond: Tensor.where(Tensor(cond), lhs, rhs).numpy() for cond in (True, False)}
  # True selects the first operand, False the second
  for source, cond in ((lhs, True), (rhs, False)):
    np.testing.assert_equal(source.numpy(), picked[cond])
|
||||
|
||||
def test_load_dedup(self):
|
||||
# for different leaves in the AST, the same loads may occur.
|
||||
|
||||
|
@ -209,7 +217,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0)
|
||||
assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1),
|
||||
arg=TernaryOps.WHERE).uop == UOps.ALU
|
||||
arg=TernaryOps.WHERE).uop == UOps.CONST
|
||||
|
||||
def helper_realized_ast(r:Tensor):
|
||||
s = create_schedule([r.lazydata])
|
||||
|
|
|
@ -1082,7 +1082,7 @@ class TestOps(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(IMAGE>0, "no conv3d on images")
|
||||
def test_padded_conv3d(self):
|
||||
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
|
||||
helper_test_op([(1,4,5,5,5), (4,4,3,3,3)],
|
||||
lambda x,w: torch.nn.functional.conv3d(x,w,padding=1).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad.device import Buffer, Device
|
|||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.device import CompiledASTRunner, Compiled
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.runtime.ops_python import exec_alu
|
||||
from tinygrad.codegen.uops import exec_alu
|
||||
from test.test_dtype import is_dtype_supported
|
||||
|
||||
def _uops_to_prg(uops):
|
||||
|
@ -113,5 +113,21 @@ class TestExecALU(TestUOps):
|
|||
def test_sqrt(self):
  # SQRT of integer zero must evaluate to zero
  result = exec_alu(UnaryOps.SQRT, dtypes.int, (0,))
  self.assertEqual(result, 0)
|
||||
|
||||
@unittest.skip("not enabled because it's slow")
def test_overflow(self):
  """Integer results of exec_alu are expected to wrap around to the range of the given dtype.

  Currently skipped (see the skip reason above), so these expectations are not enforced.
  """
  # uint8: results wrap modulo 256
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0)
  self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.uint8, (0, 1)), 255)
  self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.uint8, (0, 1000)), 24)

  # int8: in-range values are unchanged, out-of-range values wrap into [-128, 127]
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (127, 0)), 127)
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
  self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-100, 100)), 56)
  self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-1000, 0)), 24)
  self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-130, 0)), 126)

  # float operands combined with an integer dtype
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1.0, 1.0)), 2)
  # 2.0**7 rather than math.exp2(7): same value (128.0), but math.exp2 only exists on Python >= 3.11
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-(2.0**7), 0)), -128)
|
||||
|
||||
# Script entry point: when executed directly, run this module's unittest suite with verbosity level 2.
if __name__ == '__main__':
  unittest.main(verbosity=2)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
import functools
|
||||
import functools, math
|
||||
from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict
|
||||
from collections import defaultdict
|
||||
from tinygrad.helpers import DEBUG, flatten, all_same
|
||||
|
@ -25,6 +25,30 @@ class UOp:
|
|||
def __repr__(self):
|
||||
return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
|
||||
|
||||
def exec_alu(arg, dtype, p):
  """Evaluate a single ALU op on Python values and return the result.

  arg is the op to apply (compared against UnaryOps / BinaryOps / TernaryOps members),
  p is the indexable of operands (p[0], p[1], p[2] are read as the op requires), and
  dtype is only consulted by DIV to choose between integer and float division.

  Out-of-domain inputs map to inf/nan instead of raising: LOG2 yields -inf for 0 and nan
  for negatives, SQRT yields nan for negatives, EXP2 yields inf on OverflowError, and
  DIV on a non-integer dtype yields nan for a zero divisor. DIV on an integer dtype and
  MOD do not guard the divisor, so Python's ZeroDivisionError propagates.

  Raises NotImplementedError if arg is not one of the ops handled here.
  """
  if arg == TernaryOps.WHERE: ret = p[1] if p[0] else p[2]
  elif arg == UnaryOps.LOG2: ret = math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
  elif arg == UnaryOps.EXP2:
    # 2**x computed as e**(x*ln 2); overflow saturates to inf
    try: ret = math.exp(p[0]*math.log(2))
    except OverflowError: ret = math.inf
  elif arg == UnaryOps.SQRT: ret = math.sqrt(p[0]) if p[0] >= 0 else math.nan
  elif arg == UnaryOps.SIN: ret = math.sin(p[0])
  elif arg == UnaryOps.NEG: ret = -p[0]
  elif arg == BinaryOps.MUL: ret = p[0]*p[1]
  elif arg == BinaryOps.ADD: ret = p[0]+p[1]
  elif arg == BinaryOps.SUB: ret = p[0]-p[1]
  elif arg == BinaryOps.XOR: ret = p[0]^p[1]
  elif arg == BinaryOps.MAX: ret = max(p[0], p[1])
  elif arg == BinaryOps.CMPEQ: ret = p[0] == p[1]
  elif arg == BinaryOps.CMPLT: ret = p[0] < p[1]
  # NOTE(review): // floors toward -inf for negative operands; confirm this is the intended integer division
  elif arg == BinaryOps.DIV: ret = p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan)
  elif arg == BinaryOps.MOD: ret = p[0]%p[1]
  # without this branch an unsupported op reached `return ret` with ret unbound (UnboundLocalError)
  else: raise NotImplementedError(f"no support for {arg}")
  # Wrapping integer results to the dtype's range is intentionally disabled; kept for reference:
  #if not dtypes.is_int(dtype): return ret
  #adjusted = 0 if dtypes.is_unsigned(dtype) else 2 ** (dtype.itemsize * 8 - 1)
  #return (ret + adjusted) % 2 ** (dtype.itemsize * 8) - adjusted
  return ret
|
||||
|
||||
def uop_alu_resolve(u:UOp) -> sint:
|
||||
if u.uop == UOps.CONST: return u.arg
|
||||
elif u.uop == UOps.DEFINE_VAR: return u.arg
|
||||
|
@ -68,6 +92,7 @@ class UOpGraph:
|
|||
# constant folding
|
||||
if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=-vin[0].arg, insert_before=insert_before)
|
||||
if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
|
||||
if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2]
|
||||
if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
|
||||
return self.add(UOps.CONST, dtype, arg=vin[0].arg * vin[1].arg, insert_before=insert_before)
|
||||
# zero folding
|
||||
|
|
|
@ -2,36 +2,14 @@
|
|||
# works to test the tensor cores, and all the uops in general
|
||||
# this is the (living) definition of uops
|
||||
from typing import Tuple, List, Optional, Any, Dict
|
||||
import pickle, base64, itertools, time, math, struct
|
||||
import pickle, base64, itertools, time, struct
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
from tinygrad.device import Compiled, Allocator, Compiler
|
||||
from tinygrad.codegen.uops import UOp, UOps
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.uops import UOp, UOps, exec_alu
|
||||
from tinygrad.ops import BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
|
||||
def exec_alu(arg, dtype, p):
  """Apply the ALU op `arg` to the operands in `p`; `dtype` only selects integer vs float DIV.

  Raises NotImplementedError when `arg` matches none of the supported ops.
  """
  # TODO: make this complete and correctly honor the dtypes
  # TODO: use this for constant folding
  def exp2():
    # 2**x as e**(x*ln 2), saturating to inf on overflow
    try: return math.exp(p[0]*math.log(2))
    except OverflowError: return math.inf

  # (op, thunk) pairs are scanned in order; only the thunk of the first matching op is evaluated
  handlers = (
    (TernaryOps.WHERE, lambda: p[1] if p[0] else p[2]),
    (UnaryOps.LOG2, lambda: math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan),
    (UnaryOps.EXP2, exp2),
    (UnaryOps.SQRT, lambda: math.sqrt(p[0]) if p[0] >= 0 else math.nan),
    (UnaryOps.SIN, lambda: math.sin(p[0])),
    (UnaryOps.NEG, lambda: -p[0]),
    (BinaryOps.MUL, lambda: p[0]*p[1]),
    (BinaryOps.ADD, lambda: p[0]+p[1]),
    (BinaryOps.SUB, lambda: p[0]-p[1]),
    (BinaryOps.XOR, lambda: p[0]^p[1]),
    (BinaryOps.MAX, lambda: max(p[0], p[1])),
    (BinaryOps.CMPEQ, lambda: p[0] == p[1]),
    (BinaryOps.CMPLT, lambda: p[0] < p[1]),
    (BinaryOps.DIV, lambda: p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan)),
    (BinaryOps.MOD, lambda: p[0]%p[1]),
  )
  for op, compute in handlers:
    if arg == op: return compute()
  raise NotImplementedError(f"no support for {arg}")
|
||||
|
||||
def _load(m, i):
|
||||
if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
|
||||
return m[i]
|
||||
|
|
Loading…
Reference in New Issue