use pattern matcher for image [run_process_replay] (#6762)

* use pattern matcher for image [run_process_replay]

* try again

* this
This commit is contained in:
George Hotz 2024-09-26 15:49:09 +08:00 committed by GitHub
parent 197f8fd986
commit 7fca0bc912
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 14 deletions

View File

@@ -27,7 +27,7 @@ class ImageDType(DType):
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
base: DType
local: bool = False # images are never local
def scalar(self): return self.base
def scalar(self) -> DType: return self.base
def vec(self, sz:int): return self.base.vec(sz)
def __repr__(self): return f"dtypes.{self.name}({self.shape})"

View File

@@ -441,7 +441,7 @@ class UPat(MathTrait):
def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
if (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
(self.dtype is not None and uop.dtype.scalar() not in self.dtype) or \
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
(self.arg is not None and self.arg != uop.arg) or \
(self.op is not None and uop.op not in self.op) or \
(self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []

View File

@@ -9,10 +9,7 @@ from tinygrad.renderer import Renderer, TensorCore
def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
sidx = strip_parens(r[load.src[1]]) if load.src[1].arg == BinaryOps.ADD else r[load.src[1]]
if isinstance(buf.dtype, ImageDType):
assert load.dtype == dtypes.float.vec(4), f"images must be float4, getting {load.dtype}"
val = f"read_imagef({r[buf]}, smp, {sidx})"
elif load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
if load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))"
else:
val = f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
@@ -21,17 +18,15 @@ def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
if len(load.src) > 3 and load.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[load.src[3]], val, r[load.src[2]], load.dtype)
return val
def render_store(r:CStyleLanguage, store:UOp, buf:UOp, var:UOp) -> str:
sidx = strip_parens(r[store.src[1]]) if store.src[1].arg == BinaryOps.ADD else r[store.src[1]]
if isinstance(buf.dtype, ImageDType):
assert var.dtype == dtypes.float.vec(4), f"images must be float4, getting {var.dtype}"
val = f"write_imagef({r[buf]}, {sidx}, {r[var]});"
elif var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
def render_store(r:CStyleLanguage, buf:UOp, idx:UOp, var:UOp, gate:Optional[UOp]=None) -> str:
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
if var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix
val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};"
else:
val = f"*({r[buf]}+{sidx}) = {r[var]};" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}] = {r[var]};"
return f"if ({r[store.src[3]]}) {{ {val} }}" if len(store.src) > 3 and store.src[3].op is not UOps.IF else val
# TODO: this if should be in UOps, not here
return f"if ({r[gate]}) {{ {val} }}" if gate is not None else val
def render_alu(r:CStyleLanguage, x:UOp):
if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src]
@@ -50,6 +45,16 @@ base_pm = PatternMatcher([
(UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"),
(UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"),
(UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"),
# load/store image
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)))),
lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
# TODO: this if should be in UOps, not here
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4)),
UPat.var("gate", dtype=dtypes.bool))), lambda r,buf,idx,var,gate: f"if ({r[gate]}) {{ write_imagef({r[buf]}, {r[idx]}, {r[var]}); }}"),
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4))), allow_any_len=True),
lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
# r method accesses
(UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"),
(UPat(UOps.VECTORIZE, name="x"),
@@ -73,7 +78,8 @@ base_pm = PatternMatcher([
(UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)),
# function calls
(UPat(UOps.LOAD, src=(UPat.var("buf"),), allow_any_len=True, name="load"), render_load),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat(), UPat.var("var")), allow_any_len=True, name="store"), render_store),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate", dtype=dtypes.bool))), render_store),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), render_store),
(UPat(UOps.ALU, name="x"), render_alu),
(UPat(UOps.GEP, name="x"), render_gep),
])