linearizer: support local and group_for_reduce dimensions together (#1821)

also minor changes to test_speed_v_torch.py and size of UOps.SPECIAL
This commit is contained in:
Francis Lam 2023-09-08 12:39:27 -07:00 committed by GitHub
parent 9e8c1dbf34
commit 651205fa5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 8 deletions

View File

@ -42,7 +42,7 @@ def colorize_float(x):
return colored(ret, 'yellow')
save_ops, save_mem = 0, 0
CNT = 8
CNT = getenv("CNT", 8)
def helper_test_speed(f1, *args):
global save_ops, save_mem
ets = []
@ -108,7 +108,7 @@ def helper_test_generic(name, f1, f1_args, f2, f2_args):
flops = save_ops*1e-6
mem = save_mem*1e-6
print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB")
np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4, rtol=1e-3)
np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-3, rtol=1e-3)
def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
torch.manual_seed(0)

View File

@ -206,7 +206,7 @@ class Linearizer(OptimizedKernel):
# add a local buffer for multistage reduce. # TODO: use local alias
if self.group_for_reduce:
# TODO: the strides of this can be controlled
self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
self.sts.append(ShapeTracker(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))
@ -347,8 +347,9 @@ class Linearizer(OptimizedKernel):
self.uop(UOps.BARRIER, None, ())
end_loop(loop_local_idxs)
# local indexs are over, 0 them out
local_idxs = [x*0 for x in local_idxs]
# create new late reduce local loops and replace local_idxs that have been used
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
# if any group_for_reduce items aren't reduces, upcast them here
for j in self.upcast_in_mid_reduce_axes:
@ -356,6 +357,7 @@ class Linearizer(OptimizedKernel):
self.upcast()
self.group_for_reduce.pop()
local_idxs = local_idxs[:-1]
end_local_idxs = end_local_idxs[:-1]
# regenerate upcast_idxs
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
@ -365,11 +367,10 @@ class Linearizer(OptimizedKernel):
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
# late reduce loop
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
render_loop(end_local_idxs)
# load localbufs
loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs+upcast_idxs)
loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
# there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, do_reduce=True) # type: ignore

View File

@ -165,7 +165,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.render_const(args, dtype)};")
elif uop == UOps.SPECIAL:
xid = lang.gid if args[1].startswith("g") else lang.lid
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]};")
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
if args[1].startswith("l"): local_size.append(args[2])
r[u] = args[1]
elif uop == UOps.CONST: