From 63c46e02871712f6c99e3747eeccf430957ca79f Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 4 Sep 2023 16:09:01 -0700
Subject: [PATCH] Parens and gls (#1768)

* more paren stripping

* remove global and local size from renderers

* complex strip parens

* extra helpers + minor webgpu fix

* fix test uops

* one more parens test
---
 ...{test_helpers.py => test_extra_helpers.py} |  0
 test/test_uops.py                             |  5 ++--
 test/{ => unit}/test_helpers.py               |  7 +++++-
 tinygrad/codegen/kernel.py                    |  3 +++
 tinygrad/codegen/linearizer.py                |  1 +
 tinygrad/helpers.py                           |  2 +-
 tinygrad/ops.py                               |  7 +++---
 tinygrad/renderer/cstyle.py                   | 24 +++++++++----------
 tinygrad/renderer/llvmir.py                   |  6 ++---
 tinygrad/renderer/wgsl.py                     |  6 ++---
 10 files changed, 33 insertions(+), 28 deletions(-)
 rename test/extra/{test_helpers.py => test_extra_helpers.py} (100%)
 rename test/{ => unit}/test_helpers.py (93%)

diff --git a/test/extra/test_helpers.py b/test/extra/test_extra_helpers.py
similarity index 100%
rename from test/extra/test_helpers.py
rename to test/extra/test_extra_helpers.py
diff --git a/test/test_uops.py b/test/test_uops.py
index 143fa9c3..ab7c3077 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -7,9 +7,8 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
 from tinygrad.codegen.linearizer import UOps, UOp
 
 def _uops_to_prg(uops):
-  ret = Device[Device.DEFAULT].renderer("test", uops)
-  src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
-  return ASTRunner("test", src, global_size, local_size, runtime_args={"binary": binary}).build(Device[Device.DEFAULT].runtime)
+  src = Device[Device.DEFAULT].renderer("test", uops)
+  return ASTRunner("test", src, [1], [1], runtime_args={"binary": False}).build(Device[Device.DEFAULT].runtime)
 
 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
diff --git a/test/test_helpers.py b/test/unit/test_helpers.py
similarity index 93%
rename from test/test_helpers.py
rename to test/unit/test_helpers.py
index 5ce824b5..15445bb7 100644
--- a/test/test_helpers.py
+++ b/test/unit/test_helpers.py
@@ -1,6 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts
+from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens
 
 VARIABLE = ContextVar("VARIABLE", 0)
 
@@ -125,5 +125,10 @@ class TestDtypes(unittest.TestCase):
     self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
     self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))
 
+class TestStripParens(unittest.TestCase):
+  def test_simple(self): self.assertEqual("1+2", strip_parens("(1+2)"))
+  def test_nested(self): self.assertEqual("1+(2+3)", strip_parens("(1+(2+3))"))
+  def test_casted_no_strip(self): self.assertEqual("(int)(1+2)", strip_parens("(int)(1+2)"))
+
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
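
The three tests added above pin down the tightened strip_parens semantics. As a standalone sketch (the real one-liner lands in tinygrad/helpers.py further down in this patch; this commented version adds an empty-string guard the original omits):

def strip_parens(fst: str) -> str:
    # Strip only when the leading '(' actually matches the trailing ')'.
    # For "(int)(1+2)" the inner text is "int)(1+2": its first ')' comes
    # before its first '(', so the outer pair is not a matching pair.
    inner = fst[1:-1]
    if fst and fst[0] == '(' and fst[-1] == ')' and inner.find('(') <= inner.find(')'):
        return inner
    return fst

assert strip_parens("(1+2)") == "1+2"              # simple pair stripped
assert strip_parens("(1+(2+3))") == "1+(2+3)"      # only the outermost pair goes
assert strip_parens("(int)(1+2)") == "(int)(1+2)"  # cast expression left alone
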
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 194c9a37..21663027 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -69,6 +69,9 @@ class Kernel:
     self.exclude_local_upcast: int = 0
     self.reverse_upcast_dir: bool = False
 
+    self.global_size: Optional[List[int]] = None
+    self.local_size: Optional[List[int]] = None
+
   def has_variable_shape(self) -> bool:
     for b in self.bufs:
       if any(not isinstance(x, int) for x in b.st.shape): return True
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index f1b9757b..afff43bd 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -234,6 +234,7 @@ class Linearizer(OptimizedKernel):
       if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, [loop_uop])
 
     if self.opts.has_local:
+      self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
     else:
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 77bbb8c0..17c7dc54 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -17,7 +17,7 @@ def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
 def flatten(l:Iterator): return [item for sublist in l for item in sublist]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
-def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' else fst
+def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
 def merge_dicts(ds:Iterable[Dict]) -> Dict:
   kvs = set([(k,v) for d in ds for k,v in d.items()])
   assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index a7373ff9..29fcf434 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -172,11 +172,10 @@ class Compiled:
 
   def to_program(self, k):
     k.linearize()
-    ret = self.renderer(k.function_name, k.uops)
-    src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
-    return ASTRunner(k.function_name, src, global_size, local_size,
+    src = self.renderer(k.function_name, k.uops)
+    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
                      op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
-                     display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)
+                     display_name=k.display_name, runtime_args={"binary": False}).build(self.runtime)
 
   def exec_ast(self, ast:LazyOp, output, **kwargs):
     # all movementops do nothing in a Compiled buffer!
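
Together, the kernel.py, linearizer.py, and ops.py changes above move launch geometry out of the renderer's return value: linearize() now records global_size/local_size on the Kernel, and renderers hand back only source. A rough before/after sketch of that contract (simplified pseudo-plumbing, not the actual ASTRunner code):

def to_program_before(renderer, k):
    # old contract: renderers returned (src, global_size, local_size) or
    # (src, global_size, local_size, binary)
    ret = renderer(k.function_name, k.uops)
    src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
    return src, global_size, local_size, binary

def to_program_after(renderer, k):
    # new contract: k.linearize() sets k.global_size/k.local_size as a side
    # effect, and the renderer returns only the source string
    k.linearize()
    src = renderer(k.function_name, k.uops)
    return src, k.global_size, k.local_size, False
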
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index f4e01068..d0c35ac6 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -74,7 +74,7 @@ class CStyleLanguage(NamedTuple):
   def render_conditional(self, cond: str, x:str, y:str) -> str:
     return f"({cond})?({x}):{y}"
 
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
     buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
                 self.arg_int_prefix if dtype == dtypes._arg_int32 else
@@ -83,22 +83,20 @@ class CStyleLanguage(NamedTuple):
       [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
       [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
     if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
-
-    return prg, global_size[::-1], local_size[::-1]
+    return prg
 
   # returns a str statement that does the store
-  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
+  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
     if isinstance(buf_dtype, ImageDType):
       assert var_dtype == dtypes._float4, "images must be float4"
       return f"write_imagef({buf_name}, {idx}, {var_name});"
     if self.uses_vload and buf_dtype == dtypes.float16:
-      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{strip_parens(idx)});"
+      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
     if var_dtype.sz > 1:
-      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{strip_parens(idx)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
-    return f"*({buf_name}+{strip_parens(idx)}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
+      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
+    return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
 
-def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
-  global_size: List[int] = []
+def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
   local_size: List[int] = []
   kernel,prekernel = [],[]
   #pend_close = None
@@ -168,13 +166,13 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
     elif uop == UOps.SPECIAL:
       xid = lang.gid if args[1].startswith("g") else lang.lid
       kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]};")
-      (global_size if args[1].startswith("g") else local_size).append(args[2])
+      if args[1].startswith("l"): local_size.append(args[2])
       r[u] = args[1]
     elif uop == UOps.CONST:
       r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
     elif uop == UOps.LOAD:
       assert dtype is not None
-      val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, r[vin[1]], vin[0].uop == UOps.DEFINE_LOCAL)
+      val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
       if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
       r[u] = ssa('val')
       kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
@@ -183,7 +181,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
         kk(f"{r[vin[0]]} = {r[vin[1]]};")
       elif len(vin) == 3:
         assert vin[0].dtype is not None and vin[2].dtype is not None
-        kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, r[vin[1]], vin[0].uop == UOps.DEFINE_LOCAL))
+        kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
     elif uop == UOps.CAST and dtype is not None and dtype.sz > 1:
       val = lang.render_cast([r[x] for x in vin], dtype)
       if child_count[u] <= 1:
@@ -205,4 +203,4 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
     else:
       raise RuntimeError(f"failed to render {uop}")
 
-  return lang.render_kernel(function_name, kernel, bufs, global_size, local_size, prekernel)
+  return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel)
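
The cstyle.py diff above relocates paren stripping from inside render_store to the LOAD/STORE call sites, so both loads and stores get cleaned indices. A small illustration of the effect on the emitted C, reusing the real helper (render_index_use is a made-up name for this sketch):

from tinygrad.helpers import strip_parens

def render_index_use(buf_name: str, idx: str) -> str:
    # rendered index expressions arrive fully parenthesized; stripping once
    # at the use site keeps the emitted C readable
    return f"{buf_name}[{strip_parens(idx)}]"

assert render_index_use("data0", "(gidx0*4+lidx1)") == "data0[gidx0*4+lidx1]"
assert render_index_use("data0", "lidx1") == "data0[lidx1]"  # no-op when bare
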
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 098efa97..71fda62d 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -1,4 +1,4 @@
-from typing import Final, Dict, Callable, Any, List, Optional, Tuple
+from typing import Final, Dict, Callable, Any, List, Optional
 from llvmlite import ir  # type: ignore
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.helpers import dtypes
@@ -55,7 +55,7 @@ def cast(bb, val, input_type, output_type):
 
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
 
-def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
+def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
   # all llvm stuff goes into a module
   module = ir.Module(name=__file__)
 
@@ -137,4 +137,4 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
       lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
 
   bb[-1].ret_void()
-  return str(module), None, None
+  return str(module)
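
After this patch all renderers share one shape: function name and uops in, source string out. A hypothetical Protocol (not in the repo) documenting that uniform type; grid-less backends like LLVM simply leave the kernel's sizes as None:

from typing import List, Protocol
from tinygrad.codegen.linearizer import UOp

class Renderer(Protocol):
    # post-patch renderer type: no more (src, global_size, local_size[, binary]) tuples
    def __call__(self, function_name: str, uops: List[UOp]) -> str: ...
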
diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py
index f1820f90..897aa709 100644
--- a/tinygrad/renderer/wgsl.py
+++ b/tinygrad/renderer/wgsl.py
@@ -29,16 +29,16 @@ class WGSLLanguage(CStyleLanguage):
   def render_const(self, x:Union[float,int], var_dtype) -> str:
     if math.isnan(x): val = "nan()"
     elif math.isinf(x): val = ("-" if x < 0 else "") + "0x1.fffffep+127f"
-    else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
+    else: val = f"({x}" + ("" if dtypes.is_int(var_dtype) else "f") + ")"
     return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
 
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
     local_size = local_size[::-1] if local_size else [1]
     bind_it = iter(range(len(bufs)))
     prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
     prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> {name}: array<{type_map[dtype]}>;" for name,dtype in bufs])
     prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
-    return prg, global_size[::-1] if global_size else [1], local_size
+    return prg
 
   def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
     return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
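
The "minor webgpu fix" above changes scalar constant rendering: literals now come out pre-parenthesized, which plausibly keeps the new call-site strip_parens from ever eating into a literal. A standalone restatement of the scalar path (render_wgsl_const is a hypothetical name for this sketch):

import math

def render_wgsl_const(x, is_int: bool) -> str:
    # mirrors the patched WGSLLanguage.render_const for the scalar case
    if math.isnan(x): return "nan()"
    if math.isinf(x): return ("-" if x < 0 else "") + "0x1.fffffep+127f"
    return f"({x}" + ("" if is_int else "f") + ")"

assert render_wgsl_const(2.5, is_int=False) == "(2.5f)"
assert render_wgsl_const(3, is_int=True) == "(3)"
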