mirror of https://github.com/commaai/tinygrad.git
uop resolve [run_process_replay] (#6826)
* uop bool and int and stuff [run_process_replay] * add ne support * can't even be None anymore * BinaryOps.AND support * less compare
This commit is contained in:
parent
a42b177533
commit
d726eb6f48
|
@ -185,7 +185,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
|
|||
def _is_simple(lin: Kernel) -> bool:
  """Return True if this kernel's ast is trivial: a single output whose first source is a CAST of a LOAD."""
  # only single-output asts count as simple
  if len(lin.ast.src) > 1: return False
  ast:UOp = lin.ast.src[0]
  # NOTE(review): the scraped diff contained both the pre-fix line (`if ast.src[0] and ...`, which
  # relied on UOp truthiness) and this post-fix line; only the fixed check is kept. UOp.__bool__ now
  # resolves the uop and may raise, so explicit op/arg checks are required here.
  if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is UOps.LOAD: return True
  return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -618,7 +618,7 @@ class TestMultiTensor(unittest.TestCase):
|
|||
ast = si.ast.src[0]
|
||||
assert ast.op is UOps.STORE
|
||||
assert ast.src[2].arg is BinaryOps.ADD
|
||||
assert ast.src[2].src[0].op is UOps.LOAD and ast.src[2].src[0]
|
||||
assert ast.src[2].src[0].op is UOps.LOAD
|
||||
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1
|
||||
t = 2 * t
|
||||
for si in t.schedule():
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
import unittest
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp
|
||||
|
||||
class TestUOpResolve(unittest.TestCase):
  """Exercises UOp resolution to concrete Python values (__bool__/__int__/__float__).

  Constant expressions must resolve exactly; expressions over define_var only resolve
  when their min/max range collapses to a single value, otherwise ValueError is raised.
  """
  def test_simple_int(self):
    uop = UOp.const(dtypes.int, 4)
    self.assertEqual(int(uop), 4)

  def test_int_add(self):
    uop = UOp.const(dtypes.int, 4) + 7
    self.assertEqual(int(uop), 11)

  def test_lt(self):
    uop = UOp.const(dtypes.int, 4) < 7
    self.assertTrue(uop)

  def test_leq(self):
    uop = UOp.const(dtypes.int, 4) <= 4
    self.assertTrue(uop)

  def test_ne(self):
    uop = UOp.const(dtypes.int, 4).ne(7)
    self.assertTrue(uop)

  def test_ne_f(self):
    uop = UOp.const(dtypes.int, 4).ne(4)
    self.assertFalse(uop)

  def test_ngt(self):
    uop = UOp.const(dtypes.int, 4) > 7
    self.assertFalse(uop)

  def test_float_direct(self):
    uop = UOp.const(dtypes.float, 4.5) + 7
    self.assertEqual(float(uop), 11.5)

  def test_var_cmp_t(self):
    # var in [1, 10] is always < 20, so this resolves to True
    uop = UOp.define_var("i", dtypes.pyint, 1, 10) < 20
    self.assertTrue(uop)

  def test_var_cmp_t2(self):
    uop = UOp.define_var("i", dtypes.pyint, 1, 10)//2 < 20
    self.assertTrue(uop)

  def test_var_cmp_f(self):
    # var in [1, 10] is never < 1, so this resolves to False
    uop = UOp.define_var("i", dtypes.pyint, 1, 10) < 1
    self.assertFalse(uop)

  def test_var_cmp_f2(self):
    uop = UOp.define_var("i", dtypes.pyint, 1, 10) > 11
    self.assertFalse(uop)

  def test_or_true(self):
    uop = UOp.define_var("b", dtypes.bool, False, True) | True
    self.assertTrue(uop)

  def test_or_false(self):
    # var | False does not collapse the range, so bool() raises
    with self.assertRaises(ValueError):
      uop = UOp.define_var("b", dtypes.bool, False, True) | False
      self.assertTrue(uop)

  def test_and_false(self):
    uop = UOp.define_var("b", dtypes.bool, False, True) & False
    self.assertFalse(uop)

  def test_and_true(self):
    # var & True does not collapse the range, so bool() raises
    with self.assertRaises(ValueError):
      uop = UOp.define_var("b", dtypes.bool, False, True) & True
      self.assertFalse(uop)

  @unittest.skip("too fancy to be supported right now")
  def test_var_cmp_range(self):
    var = UOp.define_var("i", dtypes.pyint, 1, 10)
    uop = var > 4 or var < 6
    self.assertTrue(uop)

  def test_var_cmp_assert(self):
    # a genuinely unknown comparison cannot resolve
    with self.assertRaises(ValueError):
      uop = UOp.define_var("i", dtypes.pyint, 1, 10) < 5
      self.assertFalse(uop)
|
||||
|
||||
# allow running this test file directly
if __name__ == '__main__':
  unittest.main()
|
|
@ -385,7 +385,8 @@ class Kernel:
|
|||
if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
|
||||
else: amt = -1
|
||||
|
||||
if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
|
||||
if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
|
||||
(self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
|
||||
acc_sz = self.reduceop.dtype.itemsize
|
||||
upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
|
||||
local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
|
||||
|
@ -598,7 +599,7 @@ class Kernel:
|
|||
@functools.cached_property
|
||||
def name(self) -> str:
|
||||
# kernel name (before late upcast)
|
||||
name = ("r" if self.reduceop else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \
|
||||
name = ("r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \
|
||||
(f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
|
||||
colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ def fold_expanded(ex, buf):
|
|||
if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
|
||||
load_1 = new_srcs[offsets[o]]
|
||||
new_src = list(load_1.src)
|
||||
if not new_src[1].divides(fold_length): continue
|
||||
if new_src[1].divides(fold_length) is None: continue
|
||||
# for images, we rewrite the index. it must evenly divide 4 from the above check
|
||||
if is_image:
|
||||
new_src[1] = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((new_src[1] // 4) % buf.dtype.shape[1], (new_src[1] // (4 * buf.dtype.shape[1]))))
|
||||
|
@ -264,7 +264,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
|
|||
|
||||
if not drop_stmt and idx.key == start_idx.key: return None
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
|
||||
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid else (buf, idx)))
|
||||
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx)))
|
||||
|
||||
# ***** optional patterns *****
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ class ScheduleItemContext:
|
|||
# ** helpers for doing movementops on uops
|
||||
|
||||
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
|
||||
if (n:=cache.get(u)): return n
|
||||
if (n:=cache.get(u)) is not None: return n
|
||||
if u.op is UOps.SHAPETRACKER:
|
||||
new_st = apply_to_st(u.arg)
|
||||
return u if u.arg == new_st else UOp(UOps.SHAPETRACKER, dtypes.void, (), new_st)
|
||||
|
@ -140,7 +140,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
|||
|
||||
# buffer ops define ShapeTracker
|
||||
# if it's realized, it's a load and we add it to the inputs
|
||||
if (ubuf:=buf_uops.get(buf.buffer)) and buf not in outputs:
|
||||
if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs:
|
||||
unbound_st, st_var_vals = st.simplify().unbind()
|
||||
var_vals.update(st_var_vals)
|
||||
if buf.op is MetaOps.CONST:
|
||||
|
|
|
@ -16,7 +16,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[
|
|||
if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
|
||||
|
||||
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
|
||||
if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
|
||||
if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret
|
||||
|
||||
ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
|
||||
if enable_cache: lazycache[cache_key] = ret
|
||||
|
|
|
@ -65,9 +65,13 @@ class MathTrait:
|
|||
# -- comparison / arithmetic sugar on MathTrait --
# CMPNE and CMPLT are the only comparison primitives; everything else is derived from them.
def eq(self, x): return self.ne(x).ne(True)  # a == b as not (a != b)
def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
def gt(self, x): return self.ufix(x).alu(BinaryOps.CMPLT, self)  # a > b rewritten as b < a (operand order matters)
# TODO: use this one instead
def ge(self, x): return self.lt(x).ne(True)  # a >= b as not (a < b)
#def ge(self, x): return (-self).lt(-x+1)
def le(self, x): return self.gt(x).ne(True)  # a <= b as not (a > b)
# NOTE: __eq__/__ne__ can't be overridden, and means the same thing as is and is not
def __lt__(self, x): return self.lt(x)
def __gt__(self, x): return self.gt(x)
def __ge__(self, x): return self.ge(x)
def __le__(self, x): return self.le(x)
def max(self, x): return self.alu(BinaryOps.MAX, self.ufix(x))
def min(self, x): return -(-self).max(-x)  # min derived from max on negated values
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
|
||||
|
@ -166,6 +170,16 @@ class UOp(MathTrait):
|
|||
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
|
||||
# *** uop evaluation ***
def _eval(self, dtype, expected_type) -> ConstType:
  # Resolve this UOp to one concrete Python value.
  # dtype: container of acceptable dtypes; expected_type: required Python type of the result.
  # Raises ValueError when the uop's value range does not collapse to a single number.
  assert self.dtype in dtype, f"eval with wrong dtype {self}"
  vmin, vmax = self._min_max
  # resolvable only when min == max, i.e. the value is fully determined
  if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax}")
  assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}"
  return vmin
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)
def __float__(self): return self._eval(dtypes.floats, float)
|
||||
# *** uop syntactic sugar
|
||||
@property
|
||||
def st_arg(self) -> ShapeTracker:
|
||||
|
@ -284,8 +298,15 @@ class UOp(MathTrait):
|
|||
if s1.arg < 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg)
|
||||
if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax)
|
||||
if self.arg is BinaryOps.CMPLT: return (s0.vmax<s1.vmin, s0.vmin<s1.vmax)
|
||||
if self.arg is BinaryOps.CMPNE:
|
||||
always_ne = (s0.vmax < s1.vmin) or (s1.vmin > s0.vmax)
|
||||
sometimes_ne = not (s0.vmin == s0.vmax == s1.vmin == s1.vmax)
|
||||
return (always_ne, sometimes_ne)
|
||||
# float has NAN issue and we use explicit NAN in transcendental
|
||||
if self.arg is TernaryOps.WHERE and dtypes.is_int(s1.dtype): return min(s1.vmin, s2.vmin), max(s1.vmax, s2.vmax)
|
||||
if self.dtype is dtypes.bool:
|
||||
if self.arg is BinaryOps.OR: return s0.vmin or s1.vmin, s0.vmax or s1.vmax
|
||||
if self.arg is BinaryOps.AND: return s0.vmin and s1.vmin, s0.vmax and s1.vmax
|
||||
return dtypes.min(self.dtype), dtypes.max(self.dtype)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -563,12 +584,12 @@ class RewriteContext:
|
|||
self.nodes: Dict[Tuple, UOp] = {}
|
||||
self.replace: Dict[UOp, UOp] = {}
|
||||
def rewrite(self, n:UOp) -> UOp:
  # Memoized recursive rewrite of n: rewrite sources first, dedup identical nodes, then
  # apply the pattern matcher until it stops matching.
  # NOTE(review): the scraped diff contained both pre- and post-commit lines for three
  # statements here; only the post-commit `is not None` forms are kept — a cached UOp
  # can be falsy now that UOp.__bool__ resolves, so truthiness checks are wrong.
  if (rn := self.replace.get(n)) is not None: return rn
  replace_source = (n.op, n.dtype, new_src:=tuple(map(self.rewrite, n.src)), n.arg)
  if (found := self.nodes.get(replace_source)) is not None: self.replace[n] = found
  else:
    # only allocate a new UOp when a source actually changed
    x = UOp(*replace_source) if new_src != n.src else n
    self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) is not None else x
  return found
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
|
|
|
@ -144,7 +144,7 @@ class PTXRenderer(Renderer):
|
|||
|
||||
def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
  # Render a PTX cast of register `a` from atype to dtype; closure over r/kk/ssa/self
  # from the enclosing render method (not visible in this chunk — confirm against caller).
  if atype == dtype or isinstance(atype, PtrDType):
    # no conversion needed; alias the register.
    # NOTE(review): diff showed both `if u:` and `if u is not None:` — keep the latter,
    # since a falsy-but-valid u must still be recorded.
    if u is not None: r[u] = a
    return a
  kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
  return ret
|
||||
|
|
|
@ -74,7 +74,7 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
|||
return graph
|
||||
|
||||
def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
  """Rebuild base with all replacements in `replaces` applied, memoized by UOp.key."""
  # NOTE(review): diff showed both the truthy check and `is not None`; keep the latter —
  # a replacement UOp can be falsy now that UOp.__bool__ resolves.
  if (found:=replaces.get(base.key)) is not None: return found
  new_srcs = tuple(replace_uop(x, replaces) for x in base.src)
  # only allocate a new UOp when a source actually changed
  replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
  return ret
|
||||
|
|
Loading…
Reference in New Issue