mirror of https://github.com/commaai/tinygrad.git
allow local + grouped reduce in hand_coded (#1996)
* allow local + grouped reduce in hand_coded
* allowed loop size based on global_dims
* fix const
* fix const one more time
* better divisor
* a bit fix
* can take 2, why not
* fix linter
* better comments
* start with 2
* not always pick group reduce
* fix images
* better images
* better
parent fa9945dac0
commit 219a1f7063
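What "local + grouped reduce" means in this change: instead of a single thread looping over the entire reduce axis, part of that axis is mapped onto the local workgroup (tracked in group_for_reduce), each thread accumulates its own slice of the loop, and the per-thread partial sums are then combined through shared memory. A minimal NumPy sketch of the idea, using made-up names (grouped_reduce_sketch and group_size are not identifiers from the kernel code):

import numpy as np

def grouped_reduce_sketch(x: np.ndarray, group_size: int) -> np.ndarray:
  # x has shape (global, reduce); split the reduce axis into (reduce // group_size, group_size)
  glob, red = x.shape
  assert red % group_size == 0
  # each "local thread" (last axis) sums its own strided slice of the reduce loop
  partial = x.reshape(glob, red // group_size, group_size).sum(axis=1)
  # the partial sums are then combined across the workgroup (the shared-memory step)
  return partial.sum(axis=1)

x = np.arange(32.0).reshape(2, 16)
assert np.allclose(grouped_reduce_sketch(x, group_size=4), x.sum(axis=1))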
@@ -330,18 +330,23 @@ class OptimizedKernel(Kernel):
               self.upcast()
             return

-    if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
-      # are we grouping? (requires local shape support)
-      if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
-        # TODO: use 1024 if it's allowed in a smarter way
-        for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
-          if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
-            self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
-            self.group_for_reduce.append(sz)
-            break
+    if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]) and not self.float4_axis(0):
+      # are we grouping? (requires local shape support).
+      # Determining the number of elements to group:
+      # - The grouping is influenced by the number of elements we aim to retain in the reduction loop.
+      # - As global dimensions increase, larger reduction loops are acceptable because the GPU remains occupied and can run threads longer.
+      # - For small global dimensions (<=512), maximize the elements pulled into local memory to ensure GPU utilization.
+      has_images = any(self.bufs[buf_index].dtype.__class__ is ImageDType for buf_index,_ in enumerate(self.bufs))
+      if self.first_reduce + 1 <= self.shape_len and isinstance(self.full_shape[self.first_reduce], int):
+        divisors = [d for d in range(1, min(257, self.full_shape[self.first_reduce])) if self.full_shape[self.first_reduce] % d == 0] # type: ignore
+        divisors = [d for d in divisors if d % 4 == 0 and (self.full_shape[self.first_reduce] // d) % 4 == 0] if has_images else divisors # images need a unit stride axis (see required_optimizations()).
+        suitable_divisors = [d for d in divisors if self.full_shape[self.first_reduce] // d <= prod(self.full_shape[:self.first_reduce]) // 8]
+        if divisors and (sz := (suitable_divisors[0] if suitable_divisors and prod(self.full_shape[:self.first_reduce]) > 512 else divisors[-1])) and sz > 1:
+          self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce+len(self.group_for_reduce))
+          self.group_for_reduce.append(sz)

       # are we upcasting in mid reduce? (only for images)
-      if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
+      if self.bufs[0].dtype.name.startswith('image') and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
         axes = self.sts[0].unit_stride_axes()
         assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
         if self.sts[0].shape[axes[0]]%4 == 0:
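The new comments describe the group-size selection in prose; restated as a standalone sketch (pick_group_size and its parameters are hypothetical names, not part of OptimizedKernel), the two regimes are: with plenty of global work (> 512 output elements) take the smallest divisor whose leftover per-thread reduction loop is at most global_size // 8, otherwise take the largest divisor so as much of the reduce as possible is pulled into local memory.

from math import prod
from typing import List, Optional

def pick_group_size(global_dims: List[int], reduce_dim: int, has_images: bool) -> Optional[int]:
  # candidate group sizes: divisors of the reduce dimension, capped at 256
  divisors = [d for d in range(1, min(257, reduce_dim)) if reduce_dim % d == 0]
  if has_images:
    # images need a unit-stride axis, so both the group and the remaining loop stay multiples of 4
    divisors = [d for d in divisors if d % 4 == 0 and (reduce_dim // d) % 4 == 0]
  global_size = prod(global_dims)
  # "suitable" divisors leave a reduction loop no bigger than 1/8 of the global work
  suitable = [d for d in divisors if reduce_dim // d <= global_size // 8]
  if not divisors: return None
  sz = suitable[0] if suitable and global_size > 512 else divisors[-1]
  return sz if sz > 1 else None

print(pick_group_size([1024], 4096, has_images=False))  # 32: lots of global work, keep the group small
print(pick_group_size([64], 4096, has_images=False))    # 256: little global work, maximize the group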
@@ -362,9 +367,6 @@ class OptimizedKernel(Kernel):
         self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
         self.simplify_ones()

-    # no more opt if we are grouping
-    if self.group_for_reduce: return
-
     # **** below this line need to be optional and benchmarked ****

     # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
@@ -401,11 +403,11 @@ class OptimizedKernel(Kernel):
         break

     # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
-    if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
+    if self.first_reduce+len(self.group_for_reduce) < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
       if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
         self.upcast()
         # if it's small, upcast a second reduce dimension too
-        if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and self.full_unupcasted_shape[-1] <= 3: self.upcast()
+        if self.first_reduce+len(self.group_for_reduce) < (self.shape_len-self.upcasted) and s <= 3 and self.full_unupcasted_shape[-1] <= 3: self.upcast()
       else:
         for splits in [4]:
           if self.full_unupcasted_shape[-1]%splits == 0:
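Both changed conditions in this hunk add len(self.group_for_reduce) to first_reduce: once axes have been pulled into the workgroup they sit between the global dims and whatever remains of the reduce, so the check "is there still a reduce loop left to unroll?" has to look past them. A worked example with hypothetical shape numbers, assuming that axis layout:

first_reduce = 2          # two global output axes
group_for_reduce = [16]   # one reduce axis of size 16 was moved into the workgroup
upcasted = 0
shape_len = 3             # 2 global axes + 1 grouped axis, no reduce loop left over

# old check: still sees an axis past first_reduce and would try to unroll the grouped axis
print(first_reduce < shape_len - upcasted)                           # True
# new check: skips the grouped axes, so no unroll is attempted here
print(first_reduce + len(group_for_reduce) < shape_len - upcasted)   # False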
@@ -427,7 +429,7 @@ class OptimizedKernel(Kernel):
       local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
       to_local: List[Tuple[int, int]] = []
       for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
-        local_size = prod(sz for _, sz in to_local)
+        local_size = prod(self.group_for_reduce) * prod(sz for _, sz in to_local)
         local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
         if local_sz is not None: to_local.append((axis, local_sz))
       for axis, local_sz in sorted(to_local[:3]):
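The last change keeps the 128-thread workgroup budget honest: axes already in group_for_reduce consume local threads too, so the running local_size that gates each new local axis now starts from their product rather than from 1. A standalone sketch of that budget check (next_local_sz and its parameters are illustrative names only):

from math import prod
from typing import List, Optional

def next_local_sz(axis_len: int, group_for_reduce: List[int], to_local_sizes: List[int], is_first_axis: bool) -> Optional[int]:
  # local threads already spent on the grouped reduce plus previously chosen local axes
  local_size = prod(group_for_reduce) * prod(to_local_sizes)
  candidates = ([32] if is_first_axis else []) + [16, 8, 4, 3, 2]
  # largest candidate that divides the axis and keeps the workgroup at or under 128 threads
  return next((x for x in candidates if axis_len % x == 0 and local_size * x <= 128), None)

# with a grouped reduce of 16 already using 16 local threads, an axis of 64 only gets 8 more (16*8 = 128)
print(next_local_sz(64, group_for_reduce=[16], to_local_sizes=[], is_first_axis=True))  # 8
# without the grouping term the same call would pick 32 and overshoot the 128-thread budget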