Parens and gls (#1768)

* more paren stripping

* remove global and local size from renderers

* complex strip parens

* extra helpers + minor webgpu fix

* fix test uops

* one more parens test
George Hotz 2023-09-04 16:09:01 -07:00 committed by GitHub
parent 3473c9e88d
commit 63c46e0287
10 changed files with 33 additions and 28 deletions

View File

@@ -7,9 +7,8 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
 from tinygrad.codegen.linearizer import UOps, UOp
 def _uops_to_prg(uops):
-  ret = Device[Device.DEFAULT].renderer("test", uops)
-  src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
-  return ASTRunner("test", src, global_size, local_size, runtime_args={"binary": binary}).build(Device[Device.DEFAULT].runtime)
+  src = Device[Device.DEFAULT].renderer("test", uops)
+  return ASTRunner("test", src, [1], [1], runtime_args={"binary": False}).build(Device[Device.DEFAULT].runtime)
 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))

View File

@@ -1,6 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts
+from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens
 VARIABLE = ContextVar("VARIABLE", 0)
@@ -125,5 +125,10 @@ class TestDtypes(unittest.TestCase):
     self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
     self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))
+class TestStripParens(unittest.TestCase):
+  def test_simple(self): self.assertEqual("1+2", strip_parens("(1+2)"))
+  def test_nested(self): self.assertEqual("1+(2+3)", strip_parens("(1+(2+3))"))
+  def test_casted_no_strip(self): self.assertEqual("(int)(1+2)", strip_parens("(int)(1+2)"))
 if __name__ == '__main__':
   unittest.main()

View File

@@ -69,6 +69,9 @@ class Kernel:
     self.exclude_local_upcast: int = 0
     self.reverse_upcast_dir: bool = False
+    self.global_size: Optional[List[int]] = None
+    self.local_size: Optional[List[int]] = None
   def has_variable_shape(self) -> bool:
     for b in self.bufs:
       if any(not isinstance(x, int) for x in b.st.shape): return True

View File

@@ -234,6 +234,7 @@ class Linearizer(OptimizedKernel):
       if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, [loop_uop])
     if self.opts.has_local:
+      self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
     else:
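Each launch dimension is simply the loop index's max+1, and the list is reversed so it lines up with the dimension numbering used by the SPECIAL uops (len(idxs)-1-i). A tiny worked example with made-up index ranges; Idx is a hypothetical stand-in exposing only the .max attribute used above:

class Idx:
  def __init__(self, max_): self.max = max_

loop_global_idxs = [Idx(15), Idx(7), Idx(3)]      # ranges 0..15, 0..7, 0..3
print([x.max+1 for x in loop_global_idxs][::-1])  # [4, 8, 16]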

View File

@@ -17,7 +17,7 @@ def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
 def flatten(l:Iterator): return [item for sublist in l for item in sublist]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
-def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' else fst
+def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
 def merge_dicts(ds:Iterable[Dict]) -> Dict:
   kvs = set([(k,v) for d in ds for k,v in d.items()])
   assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
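The new guard keeps strip_parens from mangling strings whose outer parentheses do not wrap the whole expression, such as a C-style cast. A minimal sketch of the behavior (the helper is copied verbatim from the hunk above; the asserts mirror the new tests):

def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst

assert strip_parens("(1+2)") == "1+2"          # plain wrapping parens are removed
assert strip_parens("(1+(2+3))") == "1+(2+3)"  # only the outermost pair is stripped
# "(int)(1+2)": stripping would give "int)(1+2"; the guard catches it because
# the first ')' of the inner string appears before the first '('
assert strip_parens("(int)(1+2)") == "(int)(1+2)"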

View File

@@ -172,11 +172,10 @@ class Compiled:
   def to_program(self, k):
     k.linearize()
-    ret = self.renderer(k.function_name, k.uops)
-    src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
-    return ASTRunner(k.function_name, src, global_size, local_size,
+    src = self.renderer(k.function_name, k.uops)
+    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
                      op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
-                     display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)
+                     display_name=k.display_name, runtime_args={"binary": False}).build(self.runtime)
   def exec_ast(self, ast:LazyOp, output, **kwargs):
     # all movementops do nothing in a Compiled buffer!
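The net effect is a simpler renderer contract: a renderer now hands back only the program source, and the launch dimensions ride on the kernel object (initialized on Kernel and filled in by Linearizer.linearize, per the hunks above). A rough, self-contained sketch of where each piece now lives; ToyKernel and toy_renderer are made-up stand-ins, not tinygrad classes:

from typing import List, Optional

class ToyKernel:
  def __init__(self):
    self.global_size: Optional[List[int]] = None  # owned by the kernel, not the renderer
    self.local_size: Optional[List[int]] = None
  def linearize(self):
    self.global_size, self.local_size = [16, 16], [8, 8]  # pretend the linearizer derived these

def toy_renderer(function_name:str, uops:list) -> str:
  return f"__kernel void {function_name}() {{}}"  # source only, no sizes

k = ToyKernel()
k.linearize()
src = toy_renderer("test", [])
print(src, k.global_size, k.local_size)  # the caller pairs the source with k.global_size/k.local_size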

View File

@@ -74,7 +74,7 @@ class CStyleLanguage(NamedTuple):
   def render_conditional(self, cond: str, x:str, y:str) -> str:
     return f"({cond})?({x}):{y}"
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
     buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
                 self.arg_int_prefix if dtype == dtypes._arg_int32 else
@@ -83,22 +83,20 @@ class CStyleLanguage(NamedTuple):
       [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
       [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
     if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
-    return prg, global_size[::-1], local_size[::-1]
+    return prg
   # returns a str statement that does the store
-  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
+  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
     if isinstance(buf_dtype, ImageDType):
       assert var_dtype == dtypes._float4, "images must be float4"
       return f"write_imagef({buf_name}, {idx}, {var_name});"
     if self.uses_vload and buf_dtype == dtypes.float16:
-      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{strip_parens(idx)});"
+      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
     if var_dtype.sz > 1:
-      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{strip_parens(idx)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
-    return f"*({buf_name}+{strip_parens(idx)}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
+      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
+    return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
-def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
-  global_size: List[int] = []
+def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
   local_size: List[int] = []
   kernel,prekernel = [],[]
   #pend_close = None
@@ -168,13 +166,13 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
     elif uop == UOps.SPECIAL:
       xid = lang.gid if args[1].startswith("g") else lang.lid
       kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]};")
-      (global_size if args[1].startswith("g") else local_size).append(args[2])
+      if args[1].startswith("l"): local_size.append(args[2])
       r[u] = args[1]
     elif uop == UOps.CONST:
      r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
     elif uop == UOps.LOAD:
       assert dtype is not None
-      val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, r[vin[1]], vin[0].uop == UOps.DEFINE_LOCAL)
+      val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
       if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
       r[u] = ssa('val')
       kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
@@ -183,7 +181,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
         kk(f"{r[vin[0]]} = {r[vin[1]]};")
       elif len(vin) == 3:
         assert vin[0].dtype is not None and vin[2].dtype is not None
-        kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, r[vin[1]], vin[0].uop == UOps.DEFINE_LOCAL))
+        kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
     elif uop == UOps.CAST and dtype is not None and dtype.sz > 1:
       val = lang.render_cast([r[x] for x in vin], dtype)
       if child_count[u] <= 1:
@@ -205,4 +203,4 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
     else:
       raise RuntimeError(f"failed to render {uop}")
-  return lang.render_kernel(function_name, kernel, bufs, global_size, local_size, prekernel)
+  return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel)
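With the helper in place, paren stripping happens at the load/store call sites: the rendered index expression typically arrives wrapped in one redundant pair of parentheses, and removing it keeps the emitted C readable without changing its meaning. A toy illustration of the effect (data0/gidx0/lidx0/acc0 are made-up identifiers in the style of generated kernels, not values from this diff):

from tinygrad.helpers import strip_parens

idx = "((gidx0*4)+lidx0)"                     # a typical fully-parenthesized index expression
print(f"data0[{idx}] = acc0;")                # before: data0[((gidx0*4)+lidx0)] = acc0;
print(f"data0[{strip_parens(idx)}] = acc0;")  # after:  data0[(gidx0*4)+lidx0] = acc0;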

View File

@@ -1,4 +1,4 @@
-from typing import Final, Dict, Callable, Any, List, Optional, Tuple
+from typing import Final, Dict, Callable, Any, List, Optional
 from llvmlite import ir # type: ignore
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.helpers import dtypes
@@ -55,7 +55,7 @@ def cast(bb, val, input_type, output_type):
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
-def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
+def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
   # all llvm stuff goes into a module
   module = ir.Module(name=__file__)
@@ -137,4 +137,4 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
       lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
   bb[-1].ret_void()
-  return str(module), None, None
+  return str(module)

View File

@@ -29,16 +29,16 @@ class WGSLLanguage(CStyleLanguage):
   def render_const(self, x:Union[float,int], var_dtype) -> str:
     if math.isnan(x): val = "nan()"
     elif math.isinf(x): val = ("-" if x < 0 else "") + "0x1.fffffep+127f"
-    else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
+    else: val = f"({x}" + ("" if dtypes.is_int(var_dtype) else "f") + ")"
     return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
     local_size = local_size[::-1] if local_size else [1]
     bind_it = iter(range(len(bufs)))
     prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
     prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> {name}: array<{type_map[dtype]}>;" for name,dtype in bufs])
     prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
-    return prg, global_size[::-1] if global_size else [1], local_size
+    return prg
   def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
     return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"