allow local + grouped reduce in hand_coded (#1996)

* allow local + grouped reduce in hand_coded

* allowed loop size based on global_dims

* fix const

* fix const one more time

* better divisor

* a small fix

* can take 2, why not

* fix linter

* better comments

* start with 2

* not always pick group reduce

* fix images

* better images

* better
Authored by nimlgen on 2023-10-06 16:11:28 +03:00, committed by GitHub
parent fa9945dac0
commit 219a1f7063
1 changed file with 18 additions and 16 deletions


@@ -330,18 +330,23 @@ class OptimizedKernel(Kernel):
self.upcast()
return
-if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
-  # are we grouping? (requires local shape support)
-  if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
-    # TODO: use 1024 if it's allowed in a smarter way
-    for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
-      if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
-        self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
-        self.group_for_reduce.append(sz)
-        break
+if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]) and not self.float4_axis(0):
+  # are we grouping? (requires local shape support).
+  # Determining the number of elements to group:
+  # - The grouping is influenced by the number of elements we aim to retain in the reduction loop.
+  # - As global dimensions increase, larger reduction loops are acceptable because the GPU remains occupied and can run threads longer.
+  # - For small global dimensions (<=512), maximize the elements pulled into local memory to ensure GPU utilization.
+  has_images = any(self.bufs[buf_index].dtype.__class__ is ImageDType for buf_index,_ in enumerate(self.bufs))
+  if self.first_reduce + 1 <= self.shape_len and isinstance(self.full_shape[self.first_reduce], int):
+    divisors = [d for d in range(1, min(257, self.full_shape[self.first_reduce])) if self.full_shape[self.first_reduce] % d == 0] # type: ignore
+    divisors = [d for d in divisors if d % 4 == 0 and (self.full_shape[self.first_reduce] // d) % 4 == 0] if has_images else divisors # images need a unit stride axis (see required_optimizations()).
+    suitable_divisors = [d for d in divisors if self.full_shape[self.first_reduce] // d <= prod(self.full_shape[:self.first_reduce]) // 8]
+    if divisors and (sz := (suitable_divisors[0] if suitable_divisors and prod(self.full_shape[:self.first_reduce]) > 512 else divisors[-1])) and sz > 1:
+      self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce+len(self.group_for_reduce))
+      self.group_for_reduce.append(sz)
# are we upcasting in mid reduce? (only for images)
-if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
+if self.bufs[0].dtype.name.startswith('image') and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
axes = self.sts[0].unit_stride_axes()
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
if self.sts[0].shape[axes[0]]%4 == 0:
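
As an aside for readers skimming the hunk above: the new group-size choice can be restated on its own over plain ints. The sketch below is illustrative only and is not part of the commit; `pick_group_size`, `reduce_dim`, and `global_size` are made-up names standing in for the inline expressions on `self.full_shape[self.first_reduce]` and `prod(self.full_shape[:self.first_reduce])`.

def pick_group_size(reduce_dim: int, global_size: int, has_images: bool) -> int:
  # candidate group sizes are divisors of the reduce dimension, capped at 256
  divisors = [d for d in range(1, min(257, reduce_dim)) if reduce_dim % d == 0]
  if has_images:
    # images need a unit-stride axis, so keep d and the remaining loop both multiples of 4
    divisors = [d for d in divisors if d % 4 == 0 and (reduce_dim // d) % 4 == 0]
  if not divisors: return 1  # 1 means "don't group"
  # "suitable" divisors leave a reduction loop no longer than global_size // 8
  suitable = [d for d in divisors if reduce_dim // d <= global_size // 8]
  # large global dims tolerate longer loops, so take the smallest suitable divisor;
  # small global dims (<=512) instead grab as much into local memory as possible
  sz = suitable[0] if suitable and global_size > 512 else divisors[-1]
  return sz if sz > 1 else 1

For example, pick_group_size(4096, 1024, False) returns 32 (the leftover loop of 128 fits within 1024 // 8), while pick_group_size(4096, 256, False) returns 256, maximizing the grouping for small global dims.
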
@@ -362,9 +367,6 @@ class OptimizedKernel(Kernel):
self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
self.simplify_ones()
-# no more opt if we are grouping
-if self.group_for_reduce: return
# **** below this line need to be optional and benchmarked ****
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
@@ -401,11 +403,11 @@ class OptimizedKernel(Kernel):
break
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
-if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
+if self.first_reduce+len(self.group_for_reduce) < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
self.upcast()
# if it's small, upcast a second reduce dimension too
-if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and self.full_unupcasted_shape[-1] <= 3: self.upcast()
+if self.first_reduce+len(self.group_for_reduce) < (self.shape_len-self.upcasted) and s <= 3 and self.full_unupcasted_shape[-1] <= 3: self.upcast()
else:
for splits in [4]:
if self.full_unupcasted_shape[-1]%splits == 0:
@@ -427,7 +429,7 @@ class OptimizedKernel(Kernel):
local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
to_local: List[Tuple[int, int]] = []
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
-local_size = prod(sz for _, sz in to_local)
+local_size = prod(self.group_for_reduce) * prod(sz for _, sz in to_local)
local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
if local_sz is not None: to_local.append((axis, local_sz))
for axis, local_sz in sorted(to_local[:3]):
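
The intent of the last hunk is that local sizes for the global axes now share the 128-thread workgroup budget with whatever group_for_reduce has already claimed. Below is a minimal sketch of that budget logic, with the stride-0 ranking dropped and hypothetical names (`pick_local_sizes`, `global_dims`); it is not the kernel's API.

from math import prod
from typing import List, Optional, Tuple

def pick_local_sizes(global_dims: List[int], group_for_reduce: List[int]) -> List[Tuple[int, int]]:
  to_local: List[Tuple[int, int]] = []
  for axis in reversed(range(len(global_dims))):  # the real code ranks axes by stride-0 first
    # threads already committed: the grouped reduce plus locals picked so far
    local_size = prod(group_for_reduce) * prod(sz for _, sz in to_local)
    candidates = ([32] if axis == 0 else []) + [16, 8, 4, 3, 2]
    local_sz: Optional[int] = next((x for x in candidates if global_dims[axis] % x == 0 and local_size * x <= 128), None)
    if local_sz is not None: to_local.append((axis, local_sz))
  return to_local[:3]

For example, pick_local_sizes([64, 64], [16]) gives [(1, 8)]: only 128 // 16 = 8 threads remain after the grouped reduce, and once axis 1 takes them, axis 0 cannot fit even a factor of 2.
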