improve render_dtype [pr] (#7117)

* improve render_dtype [pr]

* don't deref in index
This commit is contained in:
George Hotz 2024-10-17 14:50:40 +08:00 committed by GitHub
parent ca0dca35f7
commit 0b2621f63f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 10 additions and 11 deletions

View File

@ -10,8 +10,8 @@ from tinygrad.renderer import Renderer, TensorCore
# Render the C expression that addresses element `idx` of buffer `buf`.
# When `dtype` has more than one element and `buf.dtype` is a PtrDType, the
# sum `buf+idx` is cast to a pointer to the rendered `dtype`, prefixed with
# `smem_prefix` (if the buffer is local and `smem_prefix_for_cast` is set) or
# `buffer_prefix`; otherwise the plain `(buf+idx)` sum is produced.
# NOTE(review): this text is a diff view with indentation and +/- markers
# stripped, so removed and added lines are interleaved. Judging by the commit
# note "don't deref in index" and the LOAD/STORE patterns that prepend `*`
# themselves, the two returns starting with `*` look like the removed lines
# and the two without it the added ones - confirm against the repository.
def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType):
# strip_parens is applied only when the index expression is an ADD
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
return f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
return f"*({r[buf]}+{sidx})"
return f"(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
return f"({r[buf]}+{sidx})"
base_rewrite = PatternMatcher([
(UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
@ -42,11 +42,11 @@ base_rewrite = PatternMatcher([
(UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)),
# load/store
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate")), name="load"),
lambda r,buf,idx,load,var,gate: f"({r[gate]}?{_render_index(r, buf, idx, load.dtype)}:{r[var]})"),
lambda r,buf,idx,load,var,gate: f"({r[gate]}?*{_render_index(r, buf, idx, load.dtype)}:{r[var]})"),
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True, name="load"),
lambda r,buf,idx,load: _render_index(r, buf, idx, load.dtype)),
lambda r,buf,idx,load: f"*{_render_index(r, buf, idx, load.dtype)}"),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True),
lambda r,buf,idx,var: f"{_render_index(r, buf, idx, var.dtype)} = {r[var]};"),
lambda r,buf,idx,var: f"*{_render_index(r, buf, idx, var.dtype)} = {r[var]};"),
# alu/gep
(UPat(UOps.ALU, name="x"), lambda r,x: r.code_for_op[x.arg](
*([strip_parens(r[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else r[v] for v in x.src]), x.dtype)),
@ -103,15 +103,17 @@ class CStyleLanguage(Renderer):
# Assemble the kernel source text: the signature built from `bufs` and
# `extra_args`, an optional sampler declaration, the `kernel` body lines and the
# closing brace. When `prefix` is given, its lines are joined with newlines and
# placed before the kernel.
# NOTE(review): this text is a diff view with indentation and +/- markers
# stripped. The two adjacent `("" if mutable else "const ")...` lines are the
# before/after versions of a single line: the first adds `buffer_prefix` and "*"
# itself, the second leaves both to `render_dtype`. Presumably the first is the
# removed line - confirm against the repository.
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
# the sampler declaration is emitted only when at least one buffer dtype is an ImageDType
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
# (name, rendered parameter type) per buffer: image types, pointer types, plain ints, else None
buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
("" if mutable else "const ")+self.render_dtype(dtype)+self.buffer_suffix if isinstance(dtype, PtrDType) else
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
# signature, then "{", the optional sampler line, the body, and the closing "}"
prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
def render_dtype(self, var_dtype:DType) -> str:
return self.type_map.get(scalar:=var_dtype.scalar(), scalar.name) + (str(var_dtype.count) if (var_dtype.count) > 1 else "")
def render_dtype(self, dt:DType) -> str:
  """Return the C type name for `dt`.

  A PtrDType renders as `<prefix><base type>*`, where the prefix is
  `smem_prefix` when `dt.local` is set and `buffer_prefix` otherwise, and the
  base type is rendered recursively. Any other dtype renders as the `type_map`
  entry for its scalar (falling back to the scalar's own name), followed by
  the element count when that count is greater than 1.
  """
  if isinstance(dt, PtrDType):
    # dt is already known to be a pointer in this branch, so the "*" suffix is unconditional
    return (self.smem_prefix if dt.local else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
  return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
def __getitem__(self, key): return self.r[key] # hacky helper
def render(self, name:str, uops:List[UOp]) -> str:
@ -239,9 +241,6 @@ class IntelRenderer(OpenCLRenderer):
(UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda r,x: f"intel_convert_as_bfloat16_float({r[x[0]]})"),
]) + OpenCLRenderer.string_rewrite
def render_dtype(self, var_dtype:DType) -> str:
return f"ushort{var_dtype.count}" if "bfloat16" in var_dtype.name else super().render_dtype(var_dtype)
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = []
for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):