mirror of https://github.com/commaai/tinygrad.git
decouple buffer mutability from cstyle (#3617)
* buffer mutability as an arg * update test_uops
This commit is contained in:
parent
3275260c98
commit
eb83e2d3a0
|
@ -22,8 +22,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], ar
|
|||
def _test_single_value(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0'))
|
||||
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, f'data{i+1}')) for i,dtype in enumerate(dts)]
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
|
||||
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, f'data{i+1}',False)) for i,dtype in enumerate(dts)]
|
||||
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
|
||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
|
@ -38,7 +38,7 @@ def _test_single_value(vals, op, dts):
|
|||
def _test_single_value_const(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0'))
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
|
||||
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
|
|
|
@ -195,7 +195,7 @@ class Linearizer(Kernel):
|
|||
if isinstance(buf, MemBuffer):
|
||||
self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
|
||||
buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
|
||||
(buf.idx, f"data{buf.idx}"))
|
||||
(buf.idx, f"data{buf.idx}", i == 0))
|
||||
# add var vals
|
||||
for i,var in enumerate(self.ast.vars()):
|
||||
assert var.expr is not None
|
||||
|
|
|
@ -62,11 +62,12 @@ class CStyleLanguage(NamedTuple):
|
|||
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
|
||||
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], uops:List[UOp], prefix=None) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
|
||||
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
("const " if i > 0 else "")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
|
||||
self.arg_int_prefix if dtype == dtypes.int else None) for i,(name,dtype) in enumerate(bufs)]
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]],
|
||||
local_size:List[int], uops:List[UOp], prefix=None) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
|
||||
buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
|
||||
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
|
||||
prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
|
||||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
||||
|
@ -90,7 +91,7 @@ class CStyleLanguage(NamedTuple):
|
|||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
|
||||
local_size: List[int] = []
|
||||
kernel = []
|
||||
bufs: List[Tuple[str, DType]] = []
|
||||
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
|
||||
#pend_close = None
|
||||
depth = 1
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
|
@ -162,11 +163,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
|||
kk(lang.render_local(args[0], dtype, args[1]))
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, dtype))
|
||||
bufs.append((args.expr, (dtype,False)))
|
||||
r[u] = args.expr
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
|
||||
bufs.append((args[1], dtype))
|
||||
bufs.append((args[1], (dtype,args[2])))
|
||||
r[u] = args[1]
|
||||
elif uop is UOps.WMMA: kk(f"{dtype.name} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
|
||||
|
|
Loading…
Reference in New Issue