From 2affd226b30e592b29d2bc027745758cd80cb827 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 17 Jun 2021 16:38:34 -0700 Subject: [PATCH] speed up sum --- extra/cherry.py | 44 ++++++++++++++++++++++++++++++++------------ test/test_ops.py | 2 ++ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/extra/cherry.py b/extra/cherry.py index e6efc708..c2a6a001 100755 --- a/extra/cherry.py +++ b/extra/cherry.py @@ -242,6 +242,8 @@ def cherry_reduceop(inp, op, axis): osize = [1] dimlist = [np.prod(inp.shape), 1] redlist = [True, False] + #dimlist = [1, np.prod(inp.shape)] + #redlist = [False, True] else: osize = np.array(inp.shape) osize[list(axis)] = 1 @@ -268,24 +270,42 @@ def cherry_reduceop(inp, op, axis): if i not in axis: nosize.append(osize[i]) osize = nosize - if redlist[-1] != False: - dimlist = dimlist+[1] - redlist = redlist+[False] - if len(dimlist)%2 == 1: - dimlist = [1]+dimlist - redlist = [not redlist[0]]+redlist osize = tuple(osize) - print(op, inp.shape, axis, "->", osize, dimlist, redlist) + print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist) inslot, outslot = 0, 2 cherry_dmar(SLOT(inslot), inp) # dimlist is always the inshape # redlist is always [False, True, False, True, ...., True, False] + # special case if redlist ends with True + if redlist[-1] == True: + print("special case redlist[-1] == True") + outside = int(np.prod(dimlist[:-1])) + for l in range(0, outside, SZ): + reduce_size = min(SZ, outside-l) + j = 0 + while j < dimlist[-1]: + riski_load(Reg.MATMUL_INPUT, + SLOT(inslot) + l*dimlist[-1] + j, + stride_y=1, stride_x=dimlist[-1], + len_y=min(SZ if j == 0 else SZ-1, dimlist[-1]-j), + len_x=reduce_size, + zero=j==0, skip_first=j!=0) + reduceops[op]() + riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row + j += SZ if j == 0 else SZ-1 + riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l, len_y=1, len_x=reduce_size) + # remove last dimension + redlist = redlist[:-1] + dimlist = dimlist[:-1] + inslot, outslot = outslot, inslot + # do the reduce merges one at a time - while True: - print(dimlist) + while len(dimlist) >= 2: + print("proc", dimlist, redlist) + for l in range(0, int(np.prod(dimlist[:-2]))): for k in range(0, dimlist[-1], SZ): reduce_size = min(SZ, dimlist[-1]-k) @@ -302,13 +322,13 @@ def cherry_reduceop(inp, op, axis): riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row j += SZ if j == 0 else SZ-1 riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size) - if len(dimlist) == 2: + inslot, outslot = outslot, inslot + if len(dimlist) <= 2: break # merge [False, True] into [False] dimlist = dimlist[:-3] + [dimlist[-3]*dimlist[-1]] - inslot, outslot = outslot, inslot - return cherry_dmaw(SLOT(outslot), osize) + return cherry_dmaw(SLOT(inslot), osize) def cherry_unop(x, op): cherry_dmar(SLOT(0), x) diff --git a/test/test_ops.py b/test/test_ops.py index 235d2aa6..23bef2ce 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -83,6 +83,8 @@ class TestOps(unittest.TestCase): helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) def test_sum(self): helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3), lambda x: Tensor.sum(x, axis=3)) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3)), lambda x: Tensor.sum(x, axis=(1,3))) helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2))) helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2))) helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))