decouple buffer mutability from cstyle (#3617)

* buffer mutability as an arg

* update test_uops
This commit is contained in:
qazal 2024-03-05 16:20:59 +02:00 committed by GitHub
parent 3275260c98
commit eb83e2d3a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 12 deletions

View File

@ -22,8 +22,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], ar
def _test_single_value(vals, op, dts):
uops = []
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0'))
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, f'data{i+1}')) for i,dtype in enumerate(dts)]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, f'data{i+1}',False)) for i,dtype in enumerate(dts)]
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
@ -38,7 +38,7 @@ def _test_single_value(vals, op, dts):
def _test_single_value_const(vals, op, dts):
uops = []
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0'))
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))

View File

@ -195,7 +195,7 @@ class Linearizer(Kernel):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
(buf.idx, f"data{buf.idx}"))
(buf.idx, f"data{buf.idx}", i == 0))
# add var vals
for i,var in enumerate(self.ast.vars()):
assert var.expr is not None

View File

@ -62,11 +62,12 @@ class CStyleLanguage(NamedTuple):
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], uops:List[UOp], prefix=None) -> str:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
("const " if i > 0 else "")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
self.arg_int_prefix if dtype == dtypes.int else None) for i,(name,dtype) in enumerate(bufs)]
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]],
local_size:List[int], uops:List[UOp], prefix=None) -> str:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
@ -90,7 +91,7 @@ class CStyleLanguage(NamedTuple):
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
local_size: List[int] = []
kernel = []
bufs: List[Tuple[str, DType]] = []
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
#pend_close = None
depth = 1
def kk(s): kernel.append(" "*depth+s)
@ -162,11 +163,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
kk(lang.render_local(args[0], dtype, args[1]))
r[u] = args[0]
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, dtype))
bufs.append((args.expr, (dtype,False)))
r[u] = args.expr
elif uop is UOps.DEFINE_GLOBAL:
assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
bufs.append((args[1], dtype))
bufs.append((args[1], (dtype,args[2])))
r[u] = args[1]
elif uop is UOps.WMMA: kk(f"{dtype.name} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")