mirror of https://github.com/commaai/tinygrad.git
linearizer: fix up edge case bugs in UNROLL opt (#3362)
Fully UNROLLing the first_reduce axis should not change the number of local_dims. Fully UNROLLing a GROUP dim should reduce the number of group_for_reduces by one. Also replaced the group_for_reduce list with a group_for_reduces count, as the list's contents aren't used anywhere, only its length (the grouped dims are always the first reduce dims).
This commit is contained in:
parent
dc82ef6660
commit
ddb22a60c8
|
@ -383,7 +383,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
|||
k = Linearizer(s.ast)
|
||||
k.hand_coded_optimizations()
|
||||
|
||||
assert len(k.group_for_reduce) == 1
|
||||
assert k.group_for_reduces == 1
|
||||
assert k.local_dims == 1
|
||||
assert k.upcasted == 1
|
||||
|
||||
|
@ -612,9 +612,14 @@ class TestLinearizerOpts(unittest.TestCase):
|
|||
opts_shapes = [
|
||||
([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]),
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]),
|
||||
# TODO: fix these broken transformations
|
||||
# ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
# ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
|
||||
# check to ensure local_dims are stable for full UNROLL of first_reduce
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
# check behavior for full UNROLL on an existing GROUP
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]),
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
|
||||
]
|
||||
helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes])
|
||||
|
||||
|
|
|
@ -97,7 +97,7 @@ class Kernel:
|
|||
|
||||
# parameters for optimization
|
||||
self.applied_opts: List[Opt] = []
|
||||
self.group_for_reduce: List[int] = []
|
||||
self.group_for_reduces: int = 0
|
||||
self.upcasted: int = 0
|
||||
self.local_dims: int = 0
|
||||
self.local_alias: Dict[int, LocalBuffer] = {}
|
||||
|
@ -123,8 +123,8 @@ class Kernel:
|
|||
self.info, self.reduceop, self.bufs[:], self.earlybufs, self.full_buf_index, self.sts[:]
|
||||
|
||||
# parameters for optimizations
|
||||
ret.applied_opts, ret.group_for_reduce, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \
|
||||
self.applied_opts[:], self.group_for_reduce[:], self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals
|
||||
ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \
|
||||
self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals
|
||||
|
||||
# uncached since linearize didn't run
|
||||
ret.applied_opts_cache = None
|
||||
|
@ -172,7 +172,7 @@ class Kernel:
|
|||
|
||||
@property
|
||||
def upcast_in_mid_reduce_axes(self) -> List[int]:
|
||||
return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
|
||||
return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]
|
||||
|
||||
@property
|
||||
def global_dims(self) -> int: return self.first_reduce-self.local_dims  # axes before first_reduce that are not local dims
|
||||
|
@ -192,10 +192,10 @@ class Kernel:
|
|||
colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
|
||||
# after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
|
||||
colors += ["cyan"] * self.local_dims
|
||||
# between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green)
|
||||
colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))] # noqa: E501
|
||||
# between first_reduce + group_for_reduce and upcasted, they are reduce (red)
|
||||
colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce)))
|
||||
# between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
|
||||
colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)] # noqa: E501
|
||||
# between first_reduce + group_for_reduces and upcasted, they are reduce (red)
|
||||
colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + self.group_for_reduces))
|
||||
# upcasted dimensions are reduce (magenta) or normal (yellow)
|
||||
colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
|
||||
assert len(colors) == self.shape_len, "colors size mismatch"
|
||||
|
@ -399,7 +399,7 @@ class Kernel:
|
|||
assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" # noqa: E501
|
||||
self.applied_opts.append(opt)
|
||||
if opt.axis is not None:
|
||||
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
|
||||
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+self.group_for_reduces if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
|
||||
else:
|
||||
axis = -1
|
||||
if opt.amt is not None:
|
||||
|
@ -419,16 +419,18 @@ class Kernel:
|
|||
self.local_dims += 1
|
||||
elif opt.op in [OptOps.GROUP, OptOps.GROUPTOP]: # green
|
||||
assert self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem"
|
||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
||||
assert axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
||||
assert not self.tensor_core, "can't group with tensor cores"
|
||||
self.shift_to(axis, amt, top=(opt.op==OptOps.GROUPTOP), insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
self.shift_to(axis, amt, top=(opt.op==OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
|
||||
self.group_for_reduces += 1
|
||||
elif opt.op == OptOps.UNROLL: # purple
|
||||
assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
|
||||
assert amt <= 32, "don't unroll more than 32"
|
||||
# TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
|
||||
#upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
|
||||
#self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
|
||||
if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
|
||||
if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
|
||||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
elif opt.op == OptOps.UPCAST: # yellow
|
||||
|
@ -437,16 +439,16 @@ class Kernel:
|
|||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
elif opt.op == OptOps.UPCASTMID: # white
|
||||
assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce" # noqa: E501
|
||||
assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce" # noqa: E501
|
||||
axes = self.sts[0].unit_stride_axes()
|
||||
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
|
||||
assert axes[0] == axis, "wrong axis"
|
||||
assert amt == 4, "don't upcast mid anything but 4"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
|
||||
self.group_for_reduces += 1
|
||||
elif opt.op == OptOps.NOLOCALS:
|
||||
assert self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals"
|
||||
assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
|
||||
assert self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals"
|
||||
self.dont_use_locals = True
|
||||
elif opt.op == OptOps.PADTO:
|
||||
assert not self.ast.vars(), "does not work with symbolic shape"
|
||||
|
@ -499,7 +501,7 @@ class Kernel:
|
|||
break
|
||||
|
||||
# are we upcasting in mid reduce? (only for images)
|
||||
if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
|
||||
if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: # noqa: E501
|
||||
axes = self.sts[0].unit_stride_axes()
|
||||
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
|
||||
if self.sts[0].shape[axes[0]]%4 == 0:
|
||||
|
@ -517,7 +519,7 @@ class Kernel:
|
|||
self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
|
||||
|
||||
# no more opt if we are grouping
|
||||
if self.group_for_reduce: return
|
||||
if self.group_for_reduces: return
|
||||
|
||||
# **** below this line need to be optional and benchmarked ****
|
||||
|
||||
|
@ -574,7 +576,7 @@ class Kernel:
|
|||
# **** local groups ****
|
||||
|
||||
if self.opts.has_local:
|
||||
if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
|
||||
if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
|
||||
self.apply_opt(Opt(OptOps.NOLOCALS))
|
||||
else:
|
||||
# prioritize making expand axes local
|
||||
|
|
|
@ -165,7 +165,7 @@ class Linearizer(Kernel):
|
|||
if self.applied_opts == self.applied_opts_cache: return self
|
||||
|
||||
# save backups
|
||||
sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduce[:], self.upcasted
|
||||
sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted
|
||||
|
||||
# global uop cache
|
||||
self.saved_exprs: Dict[Tuple, UOp] = dict()
|
||||
|
@ -190,9 +190,9 @@ class Linearizer(Kernel):
|
|||
for lb in self.local_alias.values():
|
||||
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
|
||||
# add a local buffer for multistage reduce. # TODO: use local alias
|
||||
if self.group_for_reduce:
|
||||
if self.group_for_reduces:
|
||||
# TODO: the strides of this can be controlled
|
||||
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
|
||||
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
|
||||
temp_dtype = self.get_base_dtype(get_lazyop_info(self.reduceop).dtype)
|
||||
self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
|
||||
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
|
||||
|
@ -207,7 +207,7 @@ class Linearizer(Kernel):
|
|||
|
||||
# define indexes
|
||||
global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
|
||||
local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0) # noqa: E501
|
||||
local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0) # noqa: E501
|
||||
full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
|
||||
upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
|
||||
|
||||
|
@ -241,7 +241,7 @@ class Linearizer(Kernel):
|
|||
fake_reduce_idxs: List[Variable] = []
|
||||
if self.reduceop is not None:
|
||||
# define indexes
|
||||
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)] # noqa: E501
|
||||
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
|
||||
fake_reduce_idxs = [x*0 for x in reduce_idxs]
|
||||
|
||||
# define accumulator
|
||||
|
@ -321,7 +321,7 @@ class Linearizer(Kernel):
|
|||
self.load_cache.clear()
|
||||
|
||||
# end the local loop, do the local reduce
|
||||
if self.group_for_reduce:
|
||||
if self.group_for_reduces:
|
||||
fake_global_idxs = [x*0 for x in global_idxs]
|
||||
stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
|
||||
barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
|
||||
|
@ -332,14 +332,14 @@ class Linearizer(Kernel):
|
|||
barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)
|
||||
|
||||
# create new late reduce local loops and replace local_idxs that have been used
|
||||
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] # noqa: E501
|
||||
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
|
||||
local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
|
||||
|
||||
# if any group_for_reduce items aren't reduces, upcast them here
|
||||
for j in self.upcast_in_mid_reduce_axes:
|
||||
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
|
||||
self.upcast()
|
||||
self.group_for_reduce.pop()
|
||||
self.group_for_reduces -= 1
|
||||
local_idxs = local_idxs[:-1]
|
||||
end_local_idxs = end_local_idxs[:-1]
|
||||
# regenerate upcast_idxs
|
||||
|
@ -364,7 +364,7 @@ class Linearizer(Kernel):
|
|||
|
||||
# all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
|
||||
# been rewritten with fake end_local_idxs.
|
||||
local_idxs = local_idxs[:self.local_dims] + [NumNode(0) for i in range(len(self.group_for_reduce))]
|
||||
local_idxs = local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)]
|
||||
|
||||
# load latebufs
|
||||
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer}) # noqa: E501
|
||||
|
@ -387,7 +387,7 @@ class Linearizer(Kernel):
|
|||
graph_uops(self.uops)
|
||||
|
||||
# restore backups
|
||||
self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup
|
||||
self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
|
||||
|
||||
# set cache and return
|
||||
self.applied_opts_cache = self.applied_opts[:]
|
||||
|
|
Loading…
Reference in New Issue