optimizer: simplify GROUP and LOCAL to have one of each (#2162)

* optimizer: simplify GROUP and LOCAL to have one of each

Now that tensor cores only use LASTLOCAL, we can simplify to use
only that op everywhere.

The only use of GROUP is in matvec hand-coded opts and it doesn't
make a performance difference so switching to use only the top
behavior.

Also adds additional asserts to prevent tensor core dims from
being altered which causes bad kernels to be generated.

* search: remove duplicated actions
This commit is contained in:
Francis Lam 2023-10-27 14:37:44 -07:00 committed by GitHub
parent e0201922e3
commit 8cf0bb9351
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 23 deletions

View File

@ -472,12 +472,12 @@ class TestLinearizerOpts(unittest.TestCase):
[Opt(OptOps.UPCAST, 1, 4)],
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
[Opt(OptOps.UNROLL, 0, 2)], # check last unroll
[Opt(OptOps.LASTLOCAL, 0, 4)], # check last local
[Opt(OptOps.LOCAL, 0, 4)], # check last local
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of last unroll and last local
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LASTLOCAL, 0, 2)],
# [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
# [Opt(OptOps.GROUPTOP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
], apply_tc=True)

View File

@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
from enum import Enum, auto
class OptOps(Enum):
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); GROUPTOP = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value
@dataclass(frozen=True, order=True)
@ -197,7 +197,7 @@ class OptimizedKernel(Kernel):
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
for (tc_dim, tc_amt) in tc.threads:
fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
fix(self.apply_opt(Opt(OptOps.LOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
# assert tensor core and prevent extra_opts from altering the key shape structure
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
@ -216,7 +216,7 @@ class OptimizedKernel(Kernel):
if self.tensor_core and s0_exists:
for upc in [4,2]:
if self.full_shape[s0] % upc == 0:
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
self.apply_opt(Opt(OptOps.LOCAL, s0, upc))
break
# alias buffer
@ -228,26 +228,16 @@ class OptimizedKernel(Kernel):
def apply_opt(self, opt:Opt):
self.applied_opts.append(opt)
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUPTOP else 0))
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
if opt.op == OptOps.LOCAL: # cyan
assert axis < self.first_reduce, "can't local a reduce"
assert not(self.tensor_core), "can't local with tensor cores"
self.shift_to(axis, amt, insert_before=self.first_reduce)
self.local_dims += 1
elif opt.op == OptOps.LASTLOCAL: # cyan
assert axis < self.first_reduce, "can't local a reduce"
assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "local is for non-reduce that aren't TC dims"
self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
self.local_dims += 1
elif opt.op == OptOps.GROUP: # green
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
assert not(self.tensor_core), "can't group with tensor cores"
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
self.group_for_reduce.append(amt)
elif opt.op == OptOps.GROUPTOP: # green
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
elif opt.op == OptOps.GROUPTOP: # green
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "group is for reduce dims"
assert not(self.tensor_core), "can't group with tensor cores"
self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
self.group_for_reduce.append(amt)
@ -257,7 +247,7 @@ class OptimizedKernel(Kernel):
self.shift_to(axis, amt, insert_before=None)
self.upcast()
elif opt.op == OptOps.UPCAST: # yellow
assert axis < self.first_reduce, "upcast is for non-reduce"
assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "upcast is for non-reduce that aren't TC dims"
assert amt <= 8, "don't upcast more than 8"
self.shift_to(axis, amt, insert_before=None)
self.upcast()
@ -302,7 +292,7 @@ class OptimizedKernel(Kernel):
if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
if MV_THREADS_PER_ROW > 1:
self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
self.apply_opt(Opt(OptOps.GROUPTOP, 0, MV_THREADS_PER_ROW))
if MV_BLOCKSIZE > 1:
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
if MV_ROWS_PER_THREAD > 1:

View File

@ -13,7 +13,7 @@ actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
actions += [
Opt(op=OptOps.LOCAL, axis=0, amt=32),
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
Opt(op=OptOps.GROUPTOP, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=8), Opt(op=OptOps.GROUPTOP, axis=1, amt=8),
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
]