mirror of https://github.com/commaai/tinygrad.git
Parens and gls (#1768)
* more paren stripping * remove global and local size from renderers * complex strip parens * extra helpers + minor webgpu fix * fix test uops * one more parens test
This commit is contained in:
parent
3473c9e88d
commit
63c46e0287
|
@ -7,9 +7,8 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
|
|||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
|
||||
def _uops_to_prg(uops):
|
||||
ret = Device[Device.DEFAULT].renderer("test", uops)
|
||||
src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
|
||||
return ASTRunner("test", src, global_size, local_size, runtime_args={"binary": binary}).build(Device[Device.DEFAULT].runtime)
|
||||
src = Device[Device.DEFAULT].renderer("test", uops)
|
||||
return ASTRunner("test", src, [1], [1], runtime_args={"binary": False}).build(Device[Device.DEFAULT].runtime)
|
||||
|
||||
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts
|
||||
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens
|
||||
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
|
||||
|
@ -125,5 +125,10 @@ class TestDtypes(unittest.TestCase):
|
|||
self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
|
||||
self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))
|
||||
|
||||
class TestStripParens(unittest.TestCase):
|
||||
def test_simple(self): self.assertEqual("1+2", strip_parens("(1+2)"))
|
||||
def test_nested(self): self.assertEqual("1+(2+3)", strip_parens("(1+(2+3))"))
|
||||
def test_casted_no_strip(self): self.assertEqual("(int)(1+2)", strip_parens("(int)(1+2)"))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -69,6 +69,9 @@ class Kernel:
|
|||
self.exclude_local_upcast: int = 0
|
||||
self.reverse_upcast_dir: bool = False
|
||||
|
||||
self.global_size: Optional[List[int]] = None
|
||||
self.local_size: Optional[List[int]] = None
|
||||
|
||||
def has_variable_shape(self) -> bool:
|
||||
for b in self.bufs:
|
||||
if any(not isinstance(x, int) for x in b.st.shape): return True
|
||||
|
|
|
@ -234,6 +234,7 @@ class Linearizer(OptimizedKernel):
|
|||
if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, [loop_uop])
|
||||
|
||||
if self.opts.has_local:
|
||||
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
|
||||
else:
|
||||
|
|
|
@ -17,7 +17,7 @@ def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
|
|||
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
|
||||
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
|
||||
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
|
||||
def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' else fst
|
||||
def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
|
||||
def merge_dicts(ds:Iterable[Dict]) -> Dict:
|
||||
kvs = set([(k,v) for d in ds for k,v in d.items()])
|
||||
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
|
||||
|
|
|
@ -172,11 +172,10 @@ class Compiled:
|
|||
|
||||
def to_program(self, k):
|
||||
k.linearize()
|
||||
ret = self.renderer(k.function_name, k.uops)
|
||||
src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
|
||||
return ASTRunner(k.function_name, src, global_size, local_size,
|
||||
src = self.renderer(k.function_name, k.uops)
|
||||
return ASTRunner(k.function_name, src, k.global_size, k.local_size,
|
||||
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
|
||||
display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)
|
||||
display_name=k.display_name, runtime_args={"binary": False}).build(self.runtime)
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output, **kwargs):
|
||||
# all movementops do nothing in a Compiled buffer!
|
||||
|
|
|
@ -74,7 +74,7 @@ class CStyleLanguage(NamedTuple):
|
|||
def render_conditional(self, cond: str, x:str, y:str) -> str:
|
||||
return f"({cond})?({x}):{y}"
|
||||
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
|
||||
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
self.arg_int_prefix if dtype == dtypes._arg_int32 else
|
||||
|
@ -83,22 +83,20 @@ class CStyleLanguage(NamedTuple):
|
|||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
||||
if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
|
||||
|
||||
return prg, global_size[::-1], local_size[::-1]
|
||||
return prg
|
||||
|
||||
# returns a str statement that does the store
|
||||
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
|
||||
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert var_dtype == dtypes._float4, "images must be float4"
|
||||
return f"write_imagef({buf_name}, {idx}, {var_name});"
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{strip_parens(idx)});"
|
||||
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
|
||||
if var_dtype.sz > 1:
|
||||
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{strip_parens(idx)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
|
||||
return f"*({buf_name}+{strip_parens(idx)}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
|
||||
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
|
||||
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
|
||||
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
|
||||
global_size: List[int] = []
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
|
||||
local_size: List[int] = []
|
||||
kernel,prekernel = [],[]
|
||||
#pend_close = None
|
||||
|
@ -168,13 +166,13 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
|||
elif uop == UOps.SPECIAL:
|
||||
xid = lang.gid if args[1].startswith("g") else lang.lid
|
||||
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]};")
|
||||
(global_size if args[1].startswith("g") else local_size).append(args[2])
|
||||
if args[1].startswith("l"): local_size.append(args[2])
|
||||
r[u] = args[1]
|
||||
elif uop == UOps.CONST:
|
||||
r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
|
||||
elif uop == UOps.LOAD:
|
||||
assert dtype is not None
|
||||
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, r[vin[1]], vin[0].uop == UOps.DEFINE_LOCAL)
|
||||
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
|
||||
if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
|
||||
r[u] = ssa('val')
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
|
||||
|
@ -183,7 +181,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
|||
kk(f"{r[vin[0]]} = {r[vin[1]]};")
|
||||
elif len(vin) == 3:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, r[vin[1]], vin[0].uop == UOps.DEFINE_LOCAL))
|
||||
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
|
||||
elif uop == UOps.CAST and dtype is not None and dtype.sz > 1:
|
||||
val = lang.render_cast([r[x] for x in vin], dtype)
|
||||
if child_count[u] <= 1:
|
||||
|
@ -205,4 +203,4 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
|||
else:
|
||||
raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
return lang.render_kernel(function_name, kernel, bufs, global_size, local_size, prekernel)
|
||||
return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Final, Dict, Callable, Any, List, Optional, Tuple
|
||||
from typing import Final, Dict, Callable, Any, List, Optional
|
||||
from llvmlite import ir # type: ignore
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.helpers import dtypes
|
||||
|
@ -55,7 +55,7 @@ def cast(bb, val, input_type, output_type):
|
|||
|
||||
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
|
||||
|
||||
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
|
||||
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
|
||||
# all llvm stuff goes into a module
|
||||
module = ir.Module(name=__file__)
|
||||
|
||||
|
@ -137,4 +137,4 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
|
|||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
|
||||
|
||||
bb[-1].ret_void()
|
||||
return str(module), None, None
|
||||
return str(module)
|
||||
|
|
|
@ -29,16 +29,16 @@ class WGSLLanguage(CStyleLanguage):
|
|||
def render_const(self, x:Union[float,int], var_dtype) -> str:
|
||||
if math.isnan(x): val = "nan()"
|
||||
elif math.isinf(x): val = ("-" if x < 0 else "") + "0x1.fffffep+127f"
|
||||
else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
|
||||
else: val = f"({x}" + ("" if dtypes.is_int(var_dtype) else "f") + ")"
|
||||
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
|
||||
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
|
||||
local_size = local_size[::-1] if local_size else [1]
|
||||
bind_it = iter(range(len(bufs)))
|
||||
prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
|
||||
prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> {name}: array<{type_map[dtype]}>;" for name,dtype in bufs])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
|
||||
return prg, global_size[::-1] if global_size else [1], local_size
|
||||
return prg
|
||||
|
||||
def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
|
||||
return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
|
||||
|
|
Loading…
Reference in New Issue