uop resolve [run_process_replay] (#6826)

* uop bool and int and stuff [run_process_replay]

* add ne support

* can't even be None anymore

* BinaryOps.AND support

* less compare
This commit is contained in:
George Hotz 2024-10-01 13:11:42 +08:00 committed by GitHub
parent a42b177533
commit d726eb6f48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 122 additions and 16 deletions

View File

@ -185,7 +185,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
def _is_simple(lin: Kernel) -> bool:
if len(lin.ast.src) > 1: return False
ast:UOp = lin.ast.src[0]
if ast.src[0] and ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is UOps.LOAD: return True
if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is UOps.LOAD: return True
return False
if __name__ == "__main__":

View File

@ -618,7 +618,7 @@ class TestMultiTensor(unittest.TestCase):
ast = si.ast.src[0]
assert ast.op is UOps.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].src[0].op is UOps.LOAD and ast.src[2].src[0]
assert ast.src[2].src[0].op is UOps.LOAD
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1
t = 2 * t
for si in t.schedule():

View File

@ -0,0 +1,84 @@
import unittest
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
class TestUOpResolve(unittest.TestCase):
  # Each test builds a tiny UOp expression and resolves it to a concrete Python
  # value through __bool__/__int__/__float__. Resolution succeeds only when the
  # expression's min/max bounds pin down a single value; otherwise ValueError.

  def test_simple_int(self):
    # a lone integer constant resolves to itself
    self.assertEqual(int(UOp.const(dtypes.int, 4)), 4)

  def test_int_add(self):
    # constant folding through an add
    self.assertEqual(int(UOp.const(dtypes.int, 4) + 7), 11)

  def test_lt(self):
    # 4 < 7 is provably true
    self.assertTrue(UOp.const(dtypes.int, 4) < 7)

  def test_leq(self):
    # boundary case: 4 <= 4
    self.assertTrue(UOp.const(dtypes.int, 4) <= 4)

  def test_ne(self):
    self.assertTrue(UOp.const(dtypes.int, 4).ne(7))

  def test_ne_f(self):
    self.assertFalse(UOp.const(dtypes.int, 4).ne(4))

  def test_ngt(self):
    # 4 > 7 is provably false
    self.assertFalse(UOp.const(dtypes.int, 4) > 7)

  def test_float_direct(self):
    # float constants resolve via __float__
    self.assertEqual(float(UOp.const(dtypes.float, 4.5) + 7), 11.5)

  def test_var_cmp_t(self):
    # a variable bounded to [1, 10] is always < 20
    self.assertTrue(UOp.define_var("i", dtypes.pyint, 1, 10) < 20)

  def test_var_cmp_t2(self):
    # bounds propagate through integer division
    self.assertTrue(UOp.define_var("i", dtypes.pyint, 1, 10) // 2 < 20)

  def test_var_cmp_f(self):
    # nothing in [1, 10] is < 1
    self.assertFalse(UOp.define_var("i", dtypes.pyint, 1, 10) < 1)

  def test_var_cmp_f2(self):
    # nothing in [1, 10] is > 11
    self.assertFalse(UOp.define_var("i", dtypes.pyint, 1, 10) > 11)

  def test_or_true(self):
    # (anything | True) is True regardless of the variable's value
    self.assertTrue(UOp.define_var("b", dtypes.bool, False, True) | True)

  def test_or_false(self):
    # (b | False) stays unresolved — bool() must raise
    with self.assertRaises(ValueError):
      self.assertTrue(UOp.define_var("b", dtypes.bool, False, True) | False)

  def test_and_false(self):
    # (anything & False) is False regardless of the variable's value
    self.assertFalse(UOp.define_var("b", dtypes.bool, False, True) & False)

  def test_and_true(self):
    # (b & True) stays unresolved — bool() must raise
    with self.assertRaises(ValueError):
      self.assertFalse(UOp.define_var("b", dtypes.bool, False, True) & True)

  @unittest.skip("too fancy to be supported right now")
  def test_var_cmp_range(self):
    # would need reasoning over a disjunction of ranges
    idx = UOp.define_var("i", dtypes.pyint, 1, 10)
    self.assertTrue(idx > 4 or idx < 6)

  def test_var_cmp_assert(self):
    # i < 5 is true for some values in [1, 10] and false for others — must raise
    with self.assertRaises(ValueError):
      self.assertFalse(UOp.define_var("i", dtypes.pyint, 1, 10) < 5)

if __name__ == '__main__':
  unittest.main()

View File

@ -385,7 +385,8 @@ class Kernel:
if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
else: amt = -1
if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
(self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
acc_sz = self.reduceop.dtype.itemsize
upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
@ -598,7 +599,7 @@ class Kernel:
@functools.cached_property
def name(self) -> str:
# kernel name (before late upcast)
name = ("r" if self.reduceop else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \
name = ("r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \
(f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])

View File

@ -39,7 +39,7 @@ def fold_expanded(ex, buf):
if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
load_1 = new_srcs[offsets[o]]
new_src = list(load_1.src)
if not new_src[1].divides(fold_length): continue
if new_src[1].divides(fold_length) is None: continue
# for images, we rewrite the index. it must evenly divide 4 from the above check
if is_image:
new_src[1] = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((new_src[1] // 4) % buf.dtype.shape[1], (new_src[1] // (4 * buf.dtype.shape[1]))))
@ -264,7 +264,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
if not drop_stmt and idx.key == start_idx.key: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid else (buf, idx)))
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx)))
# ***** optional patterns *****

View File

@ -50,7 +50,7 @@ class ScheduleItemContext:
# ** helpers for doing movementops on uops
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
if (n:=cache.get(u)): return n
if (n:=cache.get(u)) is not None: return n
if u.op is UOps.SHAPETRACKER:
new_st = apply_to_st(u.arg)
return u if u.arg == new_st else UOp(UOps.SHAPETRACKER, dtypes.void, (), new_st)
@ -140,7 +140,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
# buffer ops define ShapeTracker
# if it's realized, it's a load and we add it to the inputs
if (ubuf:=buf_uops.get(buf.buffer)) and buf not in outputs:
if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs:
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
if buf.op is MetaOps.CONST:

View File

@ -16,7 +16,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[
if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret
ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
if enable_cache: lazycache[cache_key] = ret

View File

@ -65,9 +65,13 @@ class MathTrait:
def eq(self, x): return self.ne(x).ne(True)
def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
def gt(self, x): return self.ufix(x).alu(BinaryOps.CMPLT, self)
# TODO: use this one instead
def ge(self, x): return self.lt(x).ne(True)
#def ge(self, x): return (-self).lt(-x+1)
def le(self, x): return self.gt(x).ne(True)
# NOTE: __eq__/__ne__ can't be overridden, and means the same thing as is and is not
def __lt__(self, x): return self.lt(x)
def __gt__(self, x): return self.gt(x)
def __ge__(self, x): return self.ge(x)
def __le__(self, x): return self.le(x)
def max(self, x): return self.alu(BinaryOps.MAX, self.ufix(x))
def min(self, x): return -(-self).max(-x)
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
@ -166,6 +170,16 @@ class UOp(MathTrait):
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
# *** uop evaluation ***
def _eval(self, dtype, expected_type) -> ConstType:
assert self.dtype in dtype, f"eval with wrong dtype {self}"
vmin, vmax = self._min_max
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax}")
assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}"
return vmin
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)
def __float__(self): return self._eval(dtypes.floats, float)
# *** uop syntactic sugar
@property
def st_arg(self) -> ShapeTracker:
@ -284,8 +298,15 @@ class UOp(MathTrait):
if s1.arg < 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg)
if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax)
if self.arg is BinaryOps.CMPLT: return (s0.vmax<s1.vmin, s0.vmin<s1.vmax)
if self.arg is BinaryOps.CMPNE:
always_ne = (s0.vmax < s1.vmin) or (s1.vmin > s0.vmax)
sometimes_ne = not (s0.vmin == s0.vmax == s1.vmin == s1.vmax)
return (always_ne, sometimes_ne)
# float has NAN issue and we use explicit NAN in transcendental
if self.arg is TernaryOps.WHERE and dtypes.is_int(s1.dtype): return min(s1.vmin, s2.vmin), max(s1.vmax, s2.vmax)
if self.dtype is dtypes.bool:
if self.arg is BinaryOps.OR: return s0.vmin or s1.vmin, s0.vmax or s1.vmax
if self.arg is BinaryOps.AND: return s0.vmin and s1.vmin, s0.vmax and s1.vmax
return dtypes.min(self.dtype), dtypes.max(self.dtype)
@dataclass(frozen=True)
@ -563,12 +584,12 @@ class RewriteContext:
self.nodes: Dict[Tuple, UOp] = {}
self.replace: Dict[UOp, UOp] = {}
def rewrite(self, n:UOp) -> UOp:
if rn := self.replace.get(n): return rn
if (rn := self.replace.get(n)) is not None: return rn
replace_source = (n.op, n.dtype, new_src:=tuple(map(self.rewrite, n.src)), n.arg)
if found := self.nodes.get(replace_source): self.replace[n] = found
if (found := self.nodes.get(replace_source)) is not None: self.replace[n] = found
else:
x = UOp(*replace_source) if new_src != n.src else n
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) else x
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) is not None else x
return found
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
if TRACK_MATCH_STATS >= 2:

View File

@ -144,7 +144,7 @@ class PTXRenderer(Renderer):
def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
if atype == dtype or isinstance(atype, PtrDType):
if u: r[u] = a
if u is not None: r[u] = a
return a
kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
return ret

View File

@ -74,7 +74,7 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
return graph
def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
if (found:=replaces.get(base.key)): return found
if (found:=replaces.get(base.key)) is not None: return found
new_srcs = tuple(replace_uop(x, replaces) for x in base.src)
replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
return ret