speed up sum

This commit is contained in:
George Hotz 2021-06-17 16:38:34 -07:00
parent e8eb7d1b7e
commit 2affd226b3
2 changed files with 34 additions and 12 deletions

View File

@ -242,6 +242,8 @@ def cherry_reduceop(inp, op, axis):
osize = [1] osize = [1]
dimlist = [np.prod(inp.shape), 1] dimlist = [np.prod(inp.shape), 1]
redlist = [True, False] redlist = [True, False]
#dimlist = [1, np.prod(inp.shape)]
#redlist = [False, True]
else: else:
osize = np.array(inp.shape) osize = np.array(inp.shape)
osize[list(axis)] = 1 osize[list(axis)] = 1
@ -268,24 +270,42 @@ def cherry_reduceop(inp, op, axis):
if i not in axis: if i not in axis:
nosize.append(osize[i]) nosize.append(osize[i])
osize = nosize osize = nosize
if redlist[-1] != False:
dimlist = dimlist+[1]
redlist = redlist+[False]
if len(dimlist)%2 == 1:
dimlist = [1]+dimlist
redlist = [not redlist[0]]+redlist
osize = tuple(osize) osize = tuple(osize)
print(op, inp.shape, axis, "->", osize, dimlist, redlist) print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist)
inslot, outslot = 0, 2 inslot, outslot = 0, 2
cherry_dmar(SLOT(inslot), inp) cherry_dmar(SLOT(inslot), inp)
# dimlist is always the inshape # dimlist is always the inshape
# redlist is always [False, True, False, True, ...., True, False] # redlist is always [False, True, False, True, ...., True, False]
# special case if redlist ends with True
if redlist[-1] == True:
print("special case redlist[-1] == True")
outside = int(np.prod(dimlist[:-1]))
for l in range(0, outside, SZ):
reduce_size = min(SZ, outside-l)
j = 0
while j < dimlist[-1]:
riski_load(Reg.MATMUL_INPUT,
SLOT(inslot) + l*dimlist[-1] + j,
stride_y=1, stride_x=dimlist[-1],
len_y=min(SZ if j == 0 else SZ-1, dimlist[-1]-j),
len_x=reduce_size,
zero=j==0, skip_first=j!=0)
reduceops[op]()
riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
j += SZ if j == 0 else SZ-1
riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l, len_y=1, len_x=reduce_size)
# remove last dimension
redlist = redlist[:-1]
dimlist = dimlist[:-1]
inslot, outslot = outslot, inslot
# do the reduce merges one at a time # do the reduce merges one at a time
while True: while len(dimlist) >= 2:
print(dimlist) print("proc", dimlist, redlist)
for l in range(0, int(np.prod(dimlist[:-2]))): for l in range(0, int(np.prod(dimlist[:-2]))):
for k in range(0, dimlist[-1], SZ): for k in range(0, dimlist[-1], SZ):
reduce_size = min(SZ, dimlist[-1]-k) reduce_size = min(SZ, dimlist[-1]-k)
@ -302,13 +322,13 @@ def cherry_reduceop(inp, op, axis):
riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
j += SZ if j == 0 else SZ-1 j += SZ if j == 0 else SZ-1
riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size) riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size)
if len(dimlist) == 2: inslot, outslot = outslot, inslot
if len(dimlist) <= 2:
break break
# merge [False, True] into [False] # merge [False, True] into [False]
dimlist = dimlist[:-3] + [dimlist[-3]*dimlist[-1]] dimlist = dimlist[:-3] + [dimlist[-3]*dimlist[-1]]
inslot, outslot = outslot, inslot
return cherry_dmaw(SLOT(outslot), osize) return cherry_dmaw(SLOT(inslot), osize)
def cherry_unop(x, op): def cherry_unop(x, op):
cherry_dmar(SLOT(0), x) cherry_dmar(SLOT(0), x)

View File

@ -83,6 +83,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
def test_sum(self): def test_sum(self):
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum) helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3), lambda x: Tensor.sum(x, axis=3))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3)), lambda x: Tensor.sum(x, axis=(1,3)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2))) helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2))) helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1)) helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))