mirror of https://github.com/commaai/tinygrad.git
optimizer: simplify GROUP and LOCAL to have one of each (#2162)
* optimizer: simplify GROUP and LOCAL to have one of each Now that tensor cores only use LASTLOCAL, we can simplify to use only that op everywhere. The only use of GROUP is in matvec hand-coded opts and it doesn't make a performance difference so switching to use only the top behavior. Also adds additional asserts to prevent tensor core dims from being altered which causes bad kernels to be generated. * search: remove duplicated actions
This commit is contained in:
parent
e0201922e3
commit
8cf0bb9351
|
@ -472,12 +472,12 @@ class TestLinearizerOpts(unittest.TestCase):
|
||||||
[Opt(OptOps.UPCAST, 1, 4)],
|
[Opt(OptOps.UPCAST, 1, 4)],
|
||||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
|
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
|
||||||
[Opt(OptOps.UNROLL, 0, 2)], # check last unroll
|
[Opt(OptOps.UNROLL, 0, 2)], # check last unroll
|
||||||
[Opt(OptOps.LASTLOCAL, 0, 4)], # check last local
|
[Opt(OptOps.LOCAL, 0, 4)], # check last local
|
||||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of last unroll and last local
|
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of last unroll and last local
|
||||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
|
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
|
||||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
|
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
|
||||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LASTLOCAL, 0, 2)],
|
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
|
||||||
# [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
|
# [Opt(OptOps.GROUPTOP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
|
||||||
], apply_tc=True)
|
], apply_tc=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
|
||||||
class OptOps(Enum):
|
class OptOps(Enum):
|
||||||
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
|
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); GROUPTOP = auto() # noqa: E702
|
||||||
def __lt__(self, x:OptOps): return self.value < x.value
|
def __lt__(self, x:OptOps): return self.value < x.value
|
||||||
|
|
||||||
@dataclass(frozen=True, order=True)
|
@dataclass(frozen=True, order=True)
|
||||||
|
@ -197,7 +197,7 @@ class OptimizedKernel(Kernel):
|
||||||
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
|
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
|
||||||
self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
|
self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
|
||||||
for (tc_dim, tc_amt) in tc.threads:
|
for (tc_dim, tc_amt) in tc.threads:
|
||||||
fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
|
fix(self.apply_opt(Opt(OptOps.LOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
|
||||||
|
|
||||||
# assert tensor core and prevent extra_opts from altering the key shape structure
|
# assert tensor core and prevent extra_opts from altering the key shape structure
|
||||||
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
|
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
|
||||||
|
@ -216,7 +216,7 @@ class OptimizedKernel(Kernel):
|
||||||
if self.tensor_core and s0_exists:
|
if self.tensor_core and s0_exists:
|
||||||
for upc in [4,2]:
|
for upc in [4,2]:
|
||||||
if self.full_shape[s0] % upc == 0:
|
if self.full_shape[s0] % upc == 0:
|
||||||
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
|
self.apply_opt(Opt(OptOps.LOCAL, s0, upc))
|
||||||
break
|
break
|
||||||
|
|
||||||
# alias buffer
|
# alias buffer
|
||||||
|
@ -228,26 +228,16 @@ class OptimizedKernel(Kernel):
|
||||||
|
|
||||||
def apply_opt(self, opt:Opt):
|
def apply_opt(self, opt:Opt):
|
||||||
self.applied_opts.append(opt)
|
self.applied_opts.append(opt)
|
||||||
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
|
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUPTOP else 0))
|
||||||
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
|
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
|
||||||
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
|
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
|
||||||
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
|
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
|
||||||
if opt.op == OptOps.LOCAL: # cyan
|
if opt.op == OptOps.LOCAL: # cyan
|
||||||
assert axis < self.first_reduce, "can't local a reduce"
|
assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "local is for non-reduce that aren't TC dims"
|
||||||
assert not(self.tensor_core), "can't local with tensor cores"
|
|
||||||
self.shift_to(axis, amt, insert_before=self.first_reduce)
|
|
||||||
self.local_dims += 1
|
|
||||||
elif opt.op == OptOps.LASTLOCAL: # cyan
|
|
||||||
assert axis < self.first_reduce, "can't local a reduce"
|
|
||||||
self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
|
self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
|
||||||
self.local_dims += 1
|
self.local_dims += 1
|
||||||
elif opt.op == OptOps.GROUP: # green
|
elif opt.op == OptOps.GROUPTOP: # green
|
||||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "group is for reduce dims"
|
||||||
assert not(self.tensor_core), "can't group with tensor cores"
|
|
||||||
self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
|
||||||
self.group_for_reduce.append(amt)
|
|
||||||
elif opt.op == OptOps.GROUPTOP: # green
|
|
||||||
assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
|
|
||||||
assert not(self.tensor_core), "can't group with tensor cores"
|
assert not(self.tensor_core), "can't group with tensor cores"
|
||||||
self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
|
self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||||
self.group_for_reduce.append(amt)
|
self.group_for_reduce.append(amt)
|
||||||
|
@ -257,7 +247,7 @@ class OptimizedKernel(Kernel):
|
||||||
self.shift_to(axis, amt, insert_before=None)
|
self.shift_to(axis, amt, insert_before=None)
|
||||||
self.upcast()
|
self.upcast()
|
||||||
elif opt.op == OptOps.UPCAST: # yellow
|
elif opt.op == OptOps.UPCAST: # yellow
|
||||||
assert axis < self.first_reduce, "upcast is for non-reduce"
|
assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "upcast is for non-reduce that aren't TC dims"
|
||||||
assert amt <= 8, "don't upcast more than 8"
|
assert amt <= 8, "don't upcast more than 8"
|
||||||
self.shift_to(axis, amt, insert_before=None)
|
self.shift_to(axis, amt, insert_before=None)
|
||||||
self.upcast()
|
self.upcast()
|
||||||
|
@ -302,7 +292,7 @@ class OptimizedKernel(Kernel):
|
||||||
if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||||
if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
|
if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
|
||||||
if MV_THREADS_PER_ROW > 1:
|
if MV_THREADS_PER_ROW > 1:
|
||||||
self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
|
self.apply_opt(Opt(OptOps.GROUPTOP, 0, MV_THREADS_PER_ROW))
|
||||||
if MV_BLOCKSIZE > 1:
|
if MV_BLOCKSIZE > 1:
|
||||||
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
||||||
if MV_ROWS_PER_THREAD > 1:
|
if MV_ROWS_PER_THREAD > 1:
|
||||||
|
|
|
@ -13,7 +13,7 @@ actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,
|
||||||
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
|
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
|
||||||
actions += [
|
actions += [
|
||||||
Opt(op=OptOps.LOCAL, axis=0, amt=32),
|
Opt(op=OptOps.LOCAL, axis=0, amt=32),
|
||||||
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
|
Opt(op=OptOps.GROUPTOP, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=8), Opt(op=OptOps.GROUPTOP, axis=1, amt=8),
|
||||||
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
|
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue