From 7fca0bc91242aa4d55719edbed839902603fc030 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:49:09 +0800 Subject: [PATCH] use pattern matcher for image [run_process_replay] (#6762) * use pattern matcher for image [run_process_replay] * try again * this --- tinygrad/dtype.py | 2 +- tinygrad/ops.py | 2 +- tinygrad/renderer/cstyle.py | 30 ++++++++++++++++++------------ 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index f4e9d32b..c8f40c8e 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -27,7 +27,7 @@ class ImageDType(DType): shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape base: DType local: bool = False # images are never local - def scalar(self): return self.base + def scalar(self) -> DType: return self.base def vec(self, sz:int): return self.base.vec(sz) def __repr__(self): return f"dtypes.{self.name}({self.shape})" diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 583f87e5..fcb2f7ea 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -441,7 +441,7 @@ class UPat(MathTrait): def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: if (self.name is not None and store.setdefault(self.name, uop) is not uop) or \ - (self.dtype is not None and uop.dtype.scalar() not in self.dtype) or \ + (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \ (self.arg is not None and self.arg != uop.arg) or \ (self.op is not None and uop.op not in self.op) or \ (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return [] diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 8b1d42d5..0ec5582f 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -9,10 +9,7 @@ from tinygrad.renderer import Renderer, TensorCore def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str: sidx = strip_parens(r[load.src[1]]) if load.src[1].arg == BinaryOps.ADD else r[load.src[1]] - if isinstance(buf.dtype, ImageDType): - assert load.dtype == dtypes.float.vec(4), f"images must be float4, getting {load.dtype}" - val = f"read_imagef({r[buf]}, smp, {sidx})" - elif load.dtype.count > 1 and isinstance(buf.dtype, PtrDType): + if load.dtype.count > 1 and isinstance(buf.dtype, PtrDType): val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))" else: val = f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]" @@ -21,17 +18,15 @@ def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str: if len(load.src) > 3 and load.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[load.src[3]], val, r[load.src[2]], load.dtype) return val -def render_store(r:CStyleLanguage, store:UOp, buf:UOp, var:UOp) -> str: - sidx = strip_parens(r[store.src[1]]) if store.src[1].arg == BinaryOps.ADD else r[store.src[1]] - if isinstance(buf.dtype, ImageDType): - assert var.dtype == dtypes.float.vec(4), f"images must be float4, getting {var.dtype}" - val = f"write_imagef({r[buf]}, {sidx}, {r[var]});" - elif var.dtype.count > 1 and isinstance(buf.dtype, PtrDType): +def render_store(r:CStyleLanguage, buf:UOp, idx:UOp, var:UOp, gate:Optional[UOp]=None) -> str: + sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx] + if var.dtype.count > 1 and isinstance(buf.dtype, PtrDType): prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};" else: val = f"*({r[buf]}+{sidx}) = {r[var]};" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}] = {r[var]};" - return f"if ({r[store.src[3]]}) {{ {val} }}" if len(store.src) > 3 and store.src[3].op is not UOps.IF else val + # TODO: this if should be in UOps, not here + return f"if ({r[gate]}) {{ {val} }}" if gate is not None else val def render_alu(r:CStyleLanguage, x:UOp): if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src] @@ -50,6 +45,16 @@ base_pm = PatternMatcher([ (UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"), (UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"), (UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"), + # load/store image + (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))), + lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"), + (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)))), + lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"), + # TODO: this if should be in UOps, not here + (UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4)), + UPat.var("gate", dtype=dtypes.bool))), lambda r,buf,idx,var,gate: f"if ({r[gate]}) {{ write_imagef({r[buf]}, {r[idx]}, {r[var]}); }}"), + (UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4))), allow_any_len=True), + lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"), # r method accesses (UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"), (UPat(UOps.VECTORIZE, name="x"), @@ -73,7 +78,8 @@ base_pm = PatternMatcher([ (UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)), # function calls (UPat(UOps.LOAD, src=(UPat.var("buf"),), allow_any_len=True, name="load"), render_load), - (UPat(UOps.STORE, src=(UPat.var("buf"), UPat(), UPat.var("var")), allow_any_len=True, name="store"), render_store), + (UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate", dtype=dtypes.bool))), render_store), + (UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), render_store), (UPat(UOps.ALU, name="x"), render_alu), (UPat(UOps.GEP, name="x"), render_gep), ])