loop unrolling upcast

This commit is contained in:
George Hotz 2023-01-28 14:51:24 -08:00
parent 381f3e92da
commit 2f194aadad
2 changed files with 13 additions and 4 deletions

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python
import os
import unittest
import numpy as np
from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
@ -9,7 +10,7 @@ from test.lib_test_ast import test_ast
def compile_and_test_ast(ast):
    """Compile `ast` into a CLASTKernel, run it, and optionally verify it.

    The kernel is built from `ast`, `codegen()` is called to obtain a
    callable, and that callable is invoked with the kernel's own `bufs`.

    Set the environment variable NOTEST to a non-zero integer to skip the
    `test_ast` check (unset or "0" runs it).
    """
    k = CLASTKernel(ast)
    # codegen() returns a callable; invoke it on the kernel's buffers.
    k.codegen()(*k.bufs)
    # Exactly one, gated, call: an additional unconditional test_ast(k)
    # here would make the NOTEST switch ineffective and run the check twice.
    # NOTE(review): test_ast presumably compares against a reference
    # implementation -- confirm in test.lib_test_ast.
    if not int(os.getenv("NOTEST", "0")):
        test_ast(k)
class TestAST(unittest.TestCase):
def test_conv_zeroview_ast(self):

View File

@ -77,8 +77,11 @@ class CLASTKernel(ASTKernel):
else:
ldr = Token(f"{self.buftokens[buf_index].tok}[{(idxy//(4 if self.buftokens[buf_index].typ == Types.FLOAT4 else 1)).cl}]", self.buftokens[buf_index].typ)
ldr = ldr if str(valid) == "1" or (VALIDHACKS and isinstance(self.bufs[buf_index]._buf, CLImage)) else Token(f"({valid.cl} ? {ldr.tok} : 0.0f)", ldr.typ)
self.kernel.append(f"{ldr.decltype()} val{buf_index}_{o} = {ldr.tok};\n")
self.loaded_keys[(buf_index,o)] = Token(f"val{buf_index}_{o}", ldr.typ)
if const is not None:
self.loaded_keys[(buf_index,o)] = ldr
else:
self.kernel.append(f"{ldr.decltype()} val{buf_index}_{o} = {ldr.tok};\n")
self.loaded_keys[(buf_index,o)] = Token(f"val{buf_index}_{o}", ldr.typ)
tokens.append(self.loaded_keys[(buf_index,o)])
return tokens
@ -206,6 +209,10 @@ class CLASTKernel(ASTKernel):
permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len+1))
self.reshape_and_permute(lambda x: list(x[0:self.first_reduce]) + [max(1, x[self.first_reduce]//self.group_for_reduce[0]), min(x[self.first_reduce], self.group_for_reduce[0])] + list(x[self.first_reduce+1:]), permute_axis)
# if last dim <= 3 and it's a reduce dim, upcast (loop unrolling)
end_dimension = max([st.shape[-1] for st in self.sts])
if self.first_reduce < self.shape_len and end_dimension > 1 and end_dimension <= 3: self.upcast()
# TODO: stop spending time on hand-written reshapes and permutes; kernel search is the only way this will ever be good.
# group_for_reduce will have to be better first
def codegen(self):
@ -248,12 +255,13 @@ class CLASTKernel(ASTKernel):
# early ast
accumulators : List[Token] = [Token("acc%d" % i, self.buftokens[0].typ) for i in range(self.buftokens[0].size())]
if self.reduceop:
broadcast = self.buftokens[self.bufs.index(self.earlybufs[0])].size() // self.buftokens[0].size()
full_shape = [x.shape for x in self.sts if x.shape != self.sts[0].shape]
full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]
self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceop.op]};\n" for accumulator in accumulators]
self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, accumulators, do_reduce=True)] + ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce)))
self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, accumulators*broadcast, do_reduce=True)] + ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce)))
# middle
if self.group_for_reduce: