# Patch notes for tinygrad/renderer/cstyle.py (blob bcdb5a3c -> 34e8158a).
# The line breaks of this unified diff appear to have been collapsed into
# single spaces in this copy; restore them before trying to apply it.
#
# Hunk -10,8 +10,8 (_render_index):
#   The helper stops emitting the leading "*". It now returns only the
#   parenthesised address expression: "((<prefix><dtype>*)(buf+idx))" when
#   dtype.count > 1 and buf.dtype is a PtrDType, otherwise "(buf+idx)".
#
# Hunk -42,11 +42,11 (base_rewrite, load/store entries):
#   The gated LOAD, plain LOAD and STORE lambdas now prepend "*" to the
#   _render_index result themselves, so the text rendered by these three
#   patterns is the same as before the patch.
#
# Hunk -103,15 +103,17 (CStyleLanguage):
#   render_kernel no longer concatenates buffer_prefix and "*" around
#   render_dtype for PtrDType arguments; it calls render_dtype(dtype) and
#   appends buffer_suffix. render_dtype gains a PtrDType branch returning
#   (smem_prefix if dt.local else buffer_prefix) + render_dtype(dt.base) + "*",
#   and its parameter is renamed from var_dtype to dt.
#
# Hunk -239,9 +241,6 (IntelRenderer):
#   Removes the render_dtype override that rendered dtypes whose name
#   contains "bfloat16" as f"ushort{var_dtype.count}".
#
# NOTE(review): in the added render_dtype, ("*" if isinstance(dt, PtrDType)
#   else "") sits inside a branch already guarded by the same isinstance
#   check, so it always yields "*"; a plain "*" would be equivalent.
# NOTE(review): renaming var_dtype -> dt changes the keyword name; no
#   keyword callers are visible in these hunks -- confirm none exist.
# NOTE(review): no replacement for the Intel bfloat16 -> ushort mapping is
#   visible in these hunks -- confirm it is covered elsewhere (e.g. type_map).
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index bcdb5a3c..34e8158a 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -10,8 +10,8 @@ from tinygrad.renderer import Renderer, TensorCore def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType): sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx] if dtype.count > 1 and isinstance(buf.dtype, PtrDType): - return f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))" - return f"*({r[buf]}+{sidx})" + return f"(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))" + return f"({r[buf]}+{sidx})" base_rewrite = PatternMatcher([ (UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]), @@ -42,11 +42,11 @@ base_rewrite = PatternMatcher([ (UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)), # load/store (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate")), name="load"), - lambda r,buf,idx,load,var,gate: f"({r[gate]}?{_render_index(r, buf, idx, load.dtype)}:{r[var]})"), + lambda r,buf,idx,load,var,gate: f"({r[gate]}?*{_render_index(r, buf, idx, load.dtype)}:{r[var]})"), (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True, name="load"), - lambda r,buf,idx,load: _render_index(r, buf, idx, load.dtype)), + lambda r,buf,idx,load: f"*{_render_index(r, buf, idx, load.dtype)}"), (UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), - lambda r,buf,idx,var: f"{_render_index(r, buf, idx, var.dtype)} = {r[var]};"), + lambda r,buf,idx,var: f"*{_render_index(r, buf, idx, var.dtype)} = {r[var]};"), # alu/gep (UPat(UOps.ALU, name="x"), lambda r,x: r.code_for_op[x.arg]( *([strip_parens(r[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else r[v] for v in x.src]), x.dtype)), @@ 
-103,15 +103,17 @@ class CStyleLanguage(Renderer): def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str: tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501 buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else - ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else + ("" if mutable else "const ")+self.render_dtype(dtype)+self.buffer_suffix if isinstance(dtype, PtrDType) else self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs] prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] + [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] + [") {\n" + tmp] + ['\n'.join(kernel), "\n}"]) return prg if prefix is None else "\n".join(prefix)+f"\n{prg}" - def render_dtype(self, var_dtype:DType) -> str: - return self.type_map.get(scalar:=var_dtype.scalar(), scalar.name) + (str(var_dtype.count) if (var_dtype.count) > 1 else "") + def render_dtype(self, dt:DType) -> str: + if isinstance(dt, PtrDType): + return (self.smem_prefix if dt.local else self.buffer_prefix) + self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "") + return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "") def __getitem__(self, key): return self.r[key] # hacky helper def render(self, name:str, uops:List[UOp]) -> str: @@ -239,9 +241,6 @@ class IntelRenderer(OpenCLRenderer): (UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda r,x: f"intel_convert_as_bfloat16_float({r[x[0]]})"), ]) + OpenCLRenderer.string_rewrite - def render_dtype(self, var_dtype:DType) -> str: - return 
f"ushort{var_dtype.count}" if "bfloat16" in var_dtype.name else super().render_dtype(var_dtype) - def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str: prefix = [] for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):