optimizer: simplify GROUP and LOCAL to have one of each (#2162)

* optimizer: simplify GROUP and LOCAL to have one of each

Now that tensor cores only use LASTLOCAL, the two local ops can be
collapsed into one: the LASTLOCAL behavior survives everywhere, under
the LOCAL name.
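For intuition, a minimal standalone sketch of the placement rule that
survives (a toy stand-in for the kernel's shift_to; the real tinygrad
method also rewrites views and strides):

# Shape layout here: [globals..., locals..., reduces...].
def shift_to(shape, axis, amt, insert_before):
  assert shape[axis] % amt == 0, "no longer valid shift"
  out = list(shape)
  out[axis] //= amt                # keep the quotient on the original axis
  out.insert(insert_before, amt)   # split-off factor becomes a new dim
  return out

shape, first_reduce, local_dims = [64, 64, 16], 2, 1  # one local exists already
# old LOCAL: new local landed innermost, right before the reduce dims
print(shift_to(shape, 0, 4, first_reduce))               # [16, 64, 4, 16]
# surviving behavior (old LASTLOCAL): lands before the existing locals
print(shift_to(shape, 0, 4, first_reduce - local_dims))  # [16, 4, 64, 16]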

The only use of GROUP was in the hand-coded matvec opts, where the
choice makes no performance difference, so everything now uses the
top (GROUPTOP) behavior.
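For reference, a hedged sketch of the index math behind the two grouping
flavors (assuming top=True splits off the outer factor of the reduce
axis, as the GROUPTOP branch below suggests): both cover the same
iteration space, only the traversal order differs, which is why matvec
performance is unaffected.

# Split a reduce index r in range(size) into a grouped part g (handled
# cooperatively by `amt` workgroup threads) and a remainder r2.
def split_reduce(r, size, amt, top):
  if top:  # GROUPTOP: g is the outer, slow-moving factor
    return r // (size // amt), r % (size // amt)
  return r % amt, r // amt  # old GROUP: g is the inner, fast-moving factor

# same (g, r2) pairs either way, just visited in a different order
assert sorted(split_reduce(r, 16, 4, True) for r in range(16)) == \
       sorted(split_reduce(r, 16, 4, False) for r in range(16))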

Also adds asserts that prevent tensor core dims from being altered,
which would otherwise generate bad kernels.
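A simplified illustration of the new guard (the condition is copied from
the diff below; the helper name is hypothetical, the real checks live
inline in apply_opt):

# Once a tensor core is applied, the len(tensor_core.threads) dims just
# before first_reduce are TC thread dims and must stay intact, so LOCAL
# and UPCAST may only touch axes strictly below that boundary.
def check_not_tc_dim(axis, first_reduce, tensor_core):
  tc_dims = len(tensor_core.threads) if tensor_core else 0
  assert axis < first_reduce - tc_dims, "local/upcast is for non-reduce dims that aren't TC dims"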

* search: remove duplicated actions
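One plausible way to express that dedup, assuming it relies on Opt being
a frozen (hashable) dataclass; the actual change may differ:

# equal actions hash identically; dict.fromkeys drops duplicates in order
actions = list(dict.fromkeys(actions))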
Francis Lam authored on 2023-10-27 14:37:44 -07:00, committed by GitHub
parent e0201922e3
commit 8cf0bb9351
3 changed files with 13 additions and 23 deletions


@@ -472,12 +472,12 @@ class TestLinearizerOpts(unittest.TestCase):
       [Opt(OptOps.UPCAST, 1, 4)],
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)],  # check upcasts
       [Opt(OptOps.UNROLL, 0, 2)],  # check last unroll
-      [Opt(OptOps.LASTLOCAL, 0, 4)],  # check last local
+      [Opt(OptOps.LOCAL, 0, 4)],  # check last local
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)],  # check combo of last unroll and last local
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LASTLOCAL, 0, 2)],
+      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
-      # [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
+      # [Opt(OptOps.GROUPTOP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
     ], apply_tc=True)


@@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
 from enum import Enum, auto
 class OptOps(Enum):
-  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
+  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); GROUPTOP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value
 @dataclass(frozen=True, order=True)
@@ -197,7 +197,7 @@ class OptimizedKernel(Kernel):
       self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
       self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
       for (tc_dim, tc_amt) in tc.threads:
-        fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
+        fix(self.apply_opt(Opt(OptOps.LOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
       # assert tensor core and prevent extra_opts from altering the key shape structure
       if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
@@ -216,7 +216,7 @@ class OptimizedKernel(Kernel):
     if self.tensor_core and s0_exists:
       for upc in [4,2]:
         if self.full_shape[s0] % upc == 0:
-          self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
+          self.apply_opt(Opt(OptOps.LOCAL, s0, upc))
           break
     # alias buffer
@@ -228,26 +228,16 @@ class OptimizedKernel(Kernel):
   def apply_opt(self, opt:Opt):
     self.applied_opts.append(opt)
-    axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
+    axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUPTOP else 0))
     amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
     assert self.full_shape[axis] % amt == 0, "no longer valid shift"
     assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
     if opt.op == OptOps.LOCAL:      # cyan
-      assert axis < self.first_reduce, "can't local a reduce"
-      assert not(self.tensor_core), "can't local with tensor cores"
-      self.shift_to(axis, amt, insert_before=self.first_reduce)
-      self.local_dims += 1
-    elif opt.op == OptOps.LASTLOCAL:  # cyan
-      assert axis < self.first_reduce, "can't local a reduce"
+      assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "local is for non-reduce that aren't TC dims"
       self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
       self.local_dims += 1
-    elif opt.op == OptOps.GROUP:      # green
-      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
-      assert not(self.tensor_core), "can't group with tensor cores"
-      self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
-      self.group_for_reduce.append(amt)
-    elif opt.op == OptOps.GROUPTOP:   # green
-      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
+    elif opt.op == OptOps.GROUPTOP:   # green
+      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "group is for reduce dims"
       assert not(self.tensor_core), "can't group with tensor cores"
       self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
       self.group_for_reduce.append(amt)
@@ -257,7 +247,7 @@ class OptimizedKernel(Kernel):
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
     elif opt.op == OptOps.UPCAST:     # yellow
-      assert axis < self.first_reduce, "upcast is for non-reduce"
+      assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "upcast is for non-reduce that aren't TC dims"
       assert amt <= 8, "don't upcast more than 8"
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
@@ -302,7 +292,7 @@ class OptimizedKernel(Kernel):
     if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
       if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
       if MV_THREADS_PER_ROW > 1:
-        self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
+        self.apply_opt(Opt(OptOps.GROUPTOP, 0, MV_THREADS_PER_ROW))
       if MV_BLOCKSIZE > 1:
         self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
       if MV_ROWS_PER_THREAD > 1:


@@ -13,7 +13,7 @@ actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,
 actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
 actions += [
   Opt(op=OptOps.LOCAL, axis=0, amt=32),
-  Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
+  Opt(op=OptOps.GROUPTOP, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=8), Opt(op=OptOps.GROUPTOP, axis=1, amt=8),
   Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
 ]