mirror of https://github.com/commaai/tinygrad.git

speed up sum
commit 2affd226b3, parent e8eb7d1b7e
@@ -242,6 +242,8 @@ def cherry_reduceop(inp, op, axis):
    osize = [1]
    dimlist = [np.prod(inp.shape), 1]
    redlist = [True, False]
    #dimlist = [1, np.prod(inp.shape)]
    #redlist = [False, True]
  else:
    osize = np.array(inp.shape)
    osize[list(axis)] = 1
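The hunk above only shows the axis=None branch and the first lines of the else branch; the code that actually builds dimlist/redlist for an explicit axis sits between this hunk and the next and is not shown here. As a rough, illustrative sketch of that bookkeeping (build_dimlist is a hypothetical name, not from the diff), the alternating kept/reduced runs the later comments refer to can be produced like this:

import numpy as np

def build_dimlist(shape, axis):
  # collapse a shape into runs of kept/reduced dimensions, the role that
  # dimlist/redlist play in cherry_reduceop
  dimlist, redlist = [], []
  for i, d in enumerate(shape):
    red = i in axis
    if redlist and redlist[-1] == red:
      dimlist[-1] *= d            # fuse adjacent dims with the same reduce flag
    else:
      dimlist.append(d)
      redlist.append(red)
  return dimlist, redlist

print(build_dimlist((3,4,5,6), (1,3)))  # ([3, 4, 5, 6], [False, True, False, True])
print(build_dimlist((3,4,5,6), (1,2)))  # ([3, 20, 6], [False, True, False])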
@@ -268,24 +270,42 @@ def cherry_reduceop(inp, op, axis):
      if i not in axis:
        nosize.append(osize[i])
    osize = nosize
  if redlist[-1] != False:
    dimlist = dimlist+[1]
    redlist = redlist+[False]
  if len(dimlist)%2 == 1:
    dimlist = [1]+dimlist
    redlist = [not redlist[0]]+redlist

  osize = tuple(osize)
  print(op, inp.shape, axis, "->", osize, dimlist, redlist)
  print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist)
  inslot, outslot = 0, 2
  cherry_dmar(SLOT(inslot), inp)

  # dimlist is always the inshape
  # redlist is always [False, True, False, True, ...., True, False]

  # special case if redlist ends with True
  if redlist[-1] == True:
    print("special case redlist[-1] == True")
    outside = int(np.prod(dimlist[:-1]))
    for l in range(0, outside, SZ):
      reduce_size = min(SZ, outside-l)
      j = 0
      while j < dimlist[-1]:
        riski_load(Reg.MATMUL_INPUT,
          SLOT(inslot) + l*dimlist[-1] + j,
          stride_y=1, stride_x=dimlist[-1],
          len_y=min(SZ if j == 0 else SZ-1, dimlist[-1]-j),
          len_x=reduce_size,
          zero=j==0, skip_first=j!=0)
        reduceops[op]()
        riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
        j += SZ if j == 0 else SZ-1
      riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l, len_y=1, len_x=reduce_size)
    # remove last dimension
    redlist = redlist[:-1]
    dimlist = dimlist[:-1]
    inslot, outslot = outslot, inslot

  # do the reduce merges one at a time
  while True:
    print(dimlist)
  while len(dimlist) >= 2:
    print("proc", dimlist, redlist)

    for l in range(0, int(np.prod(dimlist[:-2]))):
      for k in range(0, dimlist[-1], SZ):
        reduce_size = min(SZ, dimlist[-1]-k)
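The special case added above handles a reduce over the innermost dimension in one pass instead of going through the generic merge loop below. A plain-NumPy sketch of what that loop computes for a sum reduce follows; it keeps a separate accumulator instead of folding the running partial into the first tile row (which is what the zero=/skip_first= flags and the SZ-vs-SZ-1 stepping do on the real hardware), and SZ and the function name are assumptions for illustration only:

import numpy as np

SZ = 16  # tile size, assumed here purely for illustration

def sum_last_axis_tiled(inp, dimlist):
  # reference for the redlist[-1] == True path: reduce the innermost axis,
  # producing up to SZ outputs per tile and walking the inner axis in chunks
  outside = int(np.prod(dimlist[:-1]))
  x = np.asarray(inp).reshape(outside, dimlist[-1])
  out = np.zeros(outside, dtype=x.dtype)
  for l in range(0, outside, SZ):
    reduce_size = min(SZ, outside - l)     # outputs covered by this tile
    acc = np.zeros(reduce_size, dtype=x.dtype)
    j = 0
    while j < dimlist[-1]:
      chunk = min(SZ, dimlist[-1] - j)     # columns of the inner axis this pass
      acc += x[l:l+reduce_size, j:j+chunk].sum(axis=1)
      j += chunk
    out[l:l+reduce_size] = acc
  return out

x = np.random.randn(3, 4, 5, 6).astype(np.float32)
assert np.allclose(sum_last_axis_tiled(x, [3, 4, 5, 6]), x.sum(axis=3).ravel(), atol=1e-5)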
@@ -302,13 +322,13 @@ def cherry_reduceop(inp, op, axis):
          riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
          j += SZ if j == 0 else SZ-1
        riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size)
    if len(dimlist) == 2:
      inslot, outslot = outslot, inslot
    if len(dimlist) <= 2:
      break
    # merge [False, True] into [False]
    dimlist = dimlist[:-3] + [dimlist[-3]*dimlist[-1]]
    inslot, outslot = outslot, inslot

  return cherry_dmaw(SLOT(outslot), osize)
  return cherry_dmaw(SLOT(inslot), osize)

def cherry_unop(x, op):
  cherry_dmar(SLOT(0), x)
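After each pass the merge loop collapses the [False, True, False] tail of dimlist into a single kept dimension and swaps the in/out slots. The matching redlist update isn't visible in this hunk, so the short trace below is only an illustrative sketch of that bookkeeping, using sum over axis=(1,3) of a (3,4,5,6) tensor (the trailing 1 comes from the redlist[-1] != False padding in the previous hunk):

dimlist = [3, 4, 5, 6, 1]
redlist = [False, True, False, True, False]
while len(dimlist) > 2:
  # each pass reduces the True dim at dimlist[-2]; its two False neighbours
  # then become adjacent and are fused into one dimension
  dimlist = dimlist[:-3] + [dimlist[-3]*dimlist[-1]]
  redlist = redlist[:-3] + [False]  # assumed update, not shown in the hunk
  print(dimlist, redlist)
# [3, 4, 5] [False, True, False]
# [15] [False]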
@@ -83,6 +83,8 @@ class TestOps(unittest.TestCase):
    helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
  def test_sum(self):
    helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3), lambda x: Tensor.sum(x, axis=3))
    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3)), lambda x: Tensor.sum(x, axis=(1,3)))
    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
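The new test_sum cases cover full, single-axis, and axis-tuple sums. helper_test_op's internals aren't shown in this hunk; as a plain-NumPy statement of the property being exercised (an illustration, not the test helper itself):

import numpy as np

x = np.random.randn(3, 4, 5, 6).astype(np.float32)
# summing over a tuple of axes matches reducing them one at a time
# (higher axis first so the remaining axis index doesn't shift)
assert np.allclose(x.sum(axis=(1, 3)), x.sum(axis=3).sum(axis=1))
assert x.sum(axis=(1, 3)).shape == (3, 5)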