diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index cb7bd5c3..38f94fad 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -472,12 +472,12 @@ class TestLinearizerOpts(unittest.TestCase):
       [Opt(OptOps.UPCAST, 1, 4)],
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
       [Opt(OptOps.UNROLL, 0, 2)], # check last unroll
-      [Opt(OptOps.LASTLOCAL, 0, 4)], # check last local
+      [Opt(OptOps.LOCAL, 0, 4)], # check last local
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of last unroll and last local
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LASTLOCAL, 0, 2)],
-      # [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
+      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
+      # [Opt(OptOps.GROUPTOP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
     ], apply_tc=True)
diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py
index 75c02c87..4ee30e71 100644
--- a/tinygrad/codegen/optimizer.py
+++ b/tinygrad/codegen/optimizer.py
@@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
 from enum import Enum, auto
 
 class OptOps(Enum):
-  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
+  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); GROUPTOP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value
 
 @dataclass(frozen=True, order=True)
@@ -197,7 +197,7 @@ class OptimizedKernel(Kernel):
         self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
         self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
         for (tc_dim, tc_amt) in tc.threads:
-          fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
+          fix(self.apply_opt(Opt(OptOps.LOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
 
         # assert tensor core and prevent extra_opts from altering the key shape structure
         if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
@@ -216,7 +216,7 @@ class OptimizedKernel(Kernel):
       if self.tensor_core and s0_exists:
         for upc in [4,2]:
           if self.full_shape[s0] % upc == 0:
-            self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
+            self.apply_opt(Opt(OptOps.LOCAL, s0, upc))
             break
 
       # alias buffer
@@ -228,26 +228,16 @@ class OptimizedKernel(Kernel):
 
   def apply_opt(self, opt:Opt):
     self.applied_opts.append(opt)
-    axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
+    axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUPTOP else 0))
     amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
     assert self.full_shape[axis] % amt == 0, "no longer valid shift"
     assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
     if opt.op == OptOps.LOCAL: # cyan
-      assert axis < self.first_reduce, "can't local a reduce"
-      assert not(self.tensor_core), "can't local with tensor cores"
-      self.shift_to(axis, amt, insert_before=self.first_reduce)
-      self.local_dims += 1
-    elif opt.op == OptOps.LASTLOCAL: # cyan
-      assert axis < self.first_reduce, "can't local a reduce"
+      assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "local is for non-reduce that aren't TC dims"
       self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
       self.local_dims += 1
-    elif opt.op == OptOps.GROUP: # green
-      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
-      assert not(self.tensor_core), "can't group with tensor cores"
-      self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
-      self.group_for_reduce.append(amt)
-    elif opt.op == OptOps.GROUPTOP: # green
-      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
+    elif opt.op == OptOps.GROUPTOP: # green
+      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "group is for reduce dims"
       assert not(self.tensor_core), "can't group with tensor cores"
       self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
       self.group_for_reduce.append(amt)
@@ -257,7 +247,7 @@ class OptimizedKernel(Kernel):
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
     elif opt.op == OptOps.UPCAST: # yellow
-      assert axis < self.first_reduce, "upcast is for non-reduce"
+      assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "upcast is for non-reduce that aren't TC dims"
       assert amt <= 8, "don't upcast more than 8"
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
@@ -302,7 +292,7 @@ class OptimizedKernel(Kernel):
       if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
         if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
         if MV_THREADS_PER_ROW > 1:
-          self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
+          self.apply_opt(Opt(OptOps.GROUPTOP, 0, MV_THREADS_PER_ROW))
         if MV_BLOCKSIZE > 1:
           self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
         if MV_ROWS_PER_THREAD > 1:
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index 4acaa0e5..18cbaddf 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -13,7 +13,7 @@ actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,
 actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
 actions += [
   Opt(op=OptOps.LOCAL, axis=0, amt=32),
-  Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
+  Opt(op=OptOps.GROUPTOP, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=8), Opt(op=OptOps.GROUPTOP, axis=1, amt=8),
   Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
 ]
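For context, a minimal sketch of the opt API that survives this patch (all identifiers are taken from the diff; building a real OptimizedKernel is elided, so the apply_opt loop at the end is illustrative only):

    from tinygrad.codegen.optimizer import Opt, OptOps

    # LASTLOCAL and GROUP are folded into their surviving counterparts:
    # LOCAL now inserts at first_reduce - local_dims (the old LASTLOCAL position),
    # and GROUPTOP is the only grouping op (shift_to with top=True).
    opts = [
      Opt(op=OptOps.LOCAL, axis=0, amt=4),     # split a non-reduce, non-TC dim into a local dim
      Opt(op=OptOps.GROUPTOP, axis=0, amt=8),  # group a reduce dim (appends to group_for_reduce)
    ]
    # given a kernel k (an OptimizedKernel built elsewhere):
    # for opt in opts: k.apply_opt(opt)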