mirror of https://github.com/commaai/tinygrad.git
optimizer: simplify GROUP and LOCAL to have one of each (#2162)
* optimizer: simplify GROUP and LOCAL to have one of each Now that tensor cores only use LASTLOCAL, we can simplify to use only that op everywhere. The only use of GROUP is in matvec hand-coded opts and it doesn't make a performance difference so switching to use only the top behavior. Also adds additional asserts to prevent tensor core dims from being altered which causes bad kernels to be generated. * search: remove duplicated actions
This commit is contained in:
parent
e0201922e3
commit
8cf0bb9351
|
@ -472,12 +472,12 @@ class TestLinearizerOpts(unittest.TestCase):
|
|||
[Opt(OptOps.UPCAST, 1, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
|
||||
[Opt(OptOps.UNROLL, 0, 2)], # check last unroll
|
||||
[Opt(OptOps.LASTLOCAL, 0, 4)], # check last local
|
||||
[Opt(OptOps.LOCAL, 0, 4)], # check last local
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of last unroll and last local
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LASTLOCAL, 0, 2)],
|
||||
# [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
|
||||
# [Opt(OptOps.GROUPTOP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
|
||||
], apply_tc=True)
|
||||
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
|
|||
from enum import Enum, auto
|
||||
|
||||
class OptOps(Enum):
|
||||
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
|
||||
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); GROUPTOP = auto() # noqa: E702
|
||||
def __lt__(self, x:OptOps): return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
|
@ -197,7 +197,7 @@ class OptimizedKernel(Kernel):
|
|||
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
|
||||
self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
|
||||
for (tc_dim, tc_amt) in tc.threads:
|
||||
fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
|
||||
fix(self.apply_opt(Opt(OptOps.LOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
|
||||
|
||||
# assert tensor core and prevent extra_opts from altering the key shape structure
|
||||
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
|
||||
|
@ -216,7 +216,7 @@ class OptimizedKernel(Kernel):
|
|||
if self.tensor_core and s0_exists:
|
||||
for upc in [4,2]:
|
||||
if self.full_shape[s0] % upc == 0:
|
||||
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
|
||||
self.apply_opt(Opt(OptOps.LOCAL, s0, upc))
|
||||
break
|
||||
|
||||
# alias buffer
|
||||
|
@ -228,26 +228,16 @@ class OptimizedKernel(Kernel):
|
|||
|
||||
def apply_opt(self, opt:Opt):
|
||||
self.applied_opts.append(opt)
|
||||
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
|
||||
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUPTOP else 0))
|
||||
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
|
||||
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
|
||||
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
|
||||
if opt.op == OptOps.LOCAL: # cyan
|
||||
assert axis < self.first_reduce, "can't local a reduce"
|
||||
assert not(self.tensor_core), "can't local with tensor cores"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce)
|
||||
self.local_dims += 1
|
||||
elif opt.op == OptOps.LASTLOCAL: # cyan
|
||||
assert axis < self.first_reduce, "can't local a reduce"
|
||||
assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "local is for non-reduce that aren't TC dims"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
|
||||
self.local_dims += 1
|
||||
elif opt.op == OptOps.GROUP: # green
|
||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
||||
assert not(self.tensor_core), "can't group with tensor cores"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
elif opt.op == OptOps.GROUPTOP: # green
|
||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
||||
elif opt.op == OptOps.GROUPTOP: # green
|
||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "group is for reduce dims"
|
||||
assert not(self.tensor_core), "can't group with tensor cores"
|
||||
self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
|
@ -257,7 +247,7 @@ class OptimizedKernel(Kernel):
|
|||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
elif opt.op == OptOps.UPCAST: # yellow
|
||||
assert axis < self.first_reduce, "upcast is for non-reduce"
|
||||
assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "upcast is for non-reduce that aren't TC dims"
|
||||
assert amt <= 8, "don't upcast more than 8"
|
||||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
|
@ -302,7 +292,7 @@ class OptimizedKernel(Kernel):
|
|||
if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
|
||||
if MV_THREADS_PER_ROW > 1:
|
||||
self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
|
||||
self.apply_opt(Opt(OptOps.GROUPTOP, 0, MV_THREADS_PER_ROW))
|
||||
if MV_BLOCKSIZE > 1:
|
||||
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
||||
if MV_ROWS_PER_THREAD > 1:
|
||||
|
|
|
@ -13,7 +13,7 @@ actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,
|
|||
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
|
||||
actions += [
|
||||
Opt(op=OptOps.LOCAL, axis=0, amt=32),
|
||||
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
|
||||
Opt(op=OptOps.GROUPTOP, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=8), Opt(op=OptOps.GROUPTOP, axis=1, amt=8),
|
||||
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue