mirror of https://github.com/commaai/tinygrad.git
improve render_dtype [pr] (#7117)
* improve render_dtype [pr] * don't deref in index
This commit is contained in:
parent
ca0dca35f7
commit
0b2621f63f
|
@ -10,8 +10,8 @@ from tinygrad.renderer import Renderer, TensorCore
|
|||
def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType):
|
||||
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
|
||||
if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
return f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
|
||||
return f"*({r[buf]}+{sidx})"
|
||||
return f"(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
|
||||
return f"({r[buf]}+{sidx})"
|
||||
|
||||
base_rewrite = PatternMatcher([
|
||||
(UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
|
||||
|
@ -42,11 +42,11 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)),
|
||||
# load/store
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate")), name="load"),
|
||||
lambda r,buf,idx,load,var,gate: f"({r[gate]}?{_render_index(r, buf, idx, load.dtype)}:{r[var]})"),
|
||||
lambda r,buf,idx,load,var,gate: f"({r[gate]}?*{_render_index(r, buf, idx, load.dtype)}:{r[var]})"),
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True, name="load"),
|
||||
lambda r,buf,idx,load: _render_index(r, buf, idx, load.dtype)),
|
||||
lambda r,buf,idx,load: f"*{_render_index(r, buf, idx, load.dtype)}"),
|
||||
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True),
|
||||
lambda r,buf,idx,var: f"{_render_index(r, buf, idx, var.dtype)} = {r[var]};"),
|
||||
lambda r,buf,idx,var: f"*{_render_index(r, buf, idx, var.dtype)} = {r[var]};"),
|
||||
# alu/gep
|
||||
(UPat(UOps.ALU, name="x"), lambda r,x: r.code_for_op[x.arg](
|
||||
*([strip_parens(r[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else r[v] for v in x.src]), x.dtype)),
|
||||
|
@ -103,15 +103,17 @@ class CStyleLanguage(Renderer):
|
|||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
|
||||
buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
|
||||
("" if mutable else "const ")+self.render_dtype(dtype)+self.buffer_suffix if isinstance(dtype, PtrDType) else
|
||||
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
|
||||
prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
|
||||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
||||
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
|
||||
|
||||
def render_dtype(self, var_dtype:DType) -> str:
|
||||
return self.type_map.get(scalar:=var_dtype.scalar(), scalar.name) + (str(var_dtype.count) if (var_dtype.count) > 1 else "")
|
||||
def render_dtype(self, dt:DType) -> str:
|
||||
if isinstance(dt, PtrDType):
|
||||
return (self.smem_prefix if dt.local else self.buffer_prefix) + self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
|
||||
return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
|
||||
|
||||
def __getitem__(self, key): return self.r[key] # hacky helper
|
||||
def render(self, name:str, uops:List[UOp]) -> str:
|
||||
|
@ -239,9 +241,6 @@ class IntelRenderer(OpenCLRenderer):
|
|||
(UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda r,x: f"intel_convert_as_bfloat16_float({r[x[0]]})"),
|
||||
]) + OpenCLRenderer.string_rewrite
|
||||
|
||||
def render_dtype(self, var_dtype:DType) -> str:
|
||||
return f"ushort{var_dtype.count}" if "bfloat16" in var_dtype.name else super().render_dtype(var_dtype)
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
prefix = []
|
||||
for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
|
||||
|
|
Loading…
Reference in New Issue