mirror of https://github.com/commaai/tinygrad.git
speed up sum
This commit is contained in:
parent
e8eb7d1b7e
commit
2affd226b3
|
@ -242,6 +242,8 @@ def cherry_reduceop(inp, op, axis):
|
||||||
osize = [1]
|
osize = [1]
|
||||||
dimlist = [np.prod(inp.shape), 1]
|
dimlist = [np.prod(inp.shape), 1]
|
||||||
redlist = [True, False]
|
redlist = [True, False]
|
||||||
|
#dimlist = [1, np.prod(inp.shape)]
|
||||||
|
#redlist = [False, True]
|
||||||
else:
|
else:
|
||||||
osize = np.array(inp.shape)
|
osize = np.array(inp.shape)
|
||||||
osize[list(axis)] = 1
|
osize[list(axis)] = 1
|
||||||
|
@ -268,24 +270,42 @@ def cherry_reduceop(inp, op, axis):
|
||||||
if i not in axis:
|
if i not in axis:
|
||||||
nosize.append(osize[i])
|
nosize.append(osize[i])
|
||||||
osize = nosize
|
osize = nosize
|
||||||
if redlist[-1] != False:
|
|
||||||
dimlist = dimlist+[1]
|
|
||||||
redlist = redlist+[False]
|
|
||||||
if len(dimlist)%2 == 1:
|
|
||||||
dimlist = [1]+dimlist
|
|
||||||
redlist = [not redlist[0]]+redlist
|
|
||||||
|
|
||||||
osize = tuple(osize)
|
osize = tuple(osize)
|
||||||
print(op, inp.shape, axis, "->", osize, dimlist, redlist)
|
print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist)
|
||||||
inslot, outslot = 0, 2
|
inslot, outslot = 0, 2
|
||||||
cherry_dmar(SLOT(inslot), inp)
|
cherry_dmar(SLOT(inslot), inp)
|
||||||
|
|
||||||
# dimlist is always the inshape
|
# dimlist is always the inshape
|
||||||
# redlist is always [False, True, False, True, ...., True, False]
|
# redlist is always [False, True, False, True, ...., True, False]
|
||||||
|
|
||||||
|
# special case if redlist ends with True
|
||||||
|
if redlist[-1] == True:
|
||||||
|
print("special case redlist[-1] == True")
|
||||||
|
outside = int(np.prod(dimlist[:-1]))
|
||||||
|
for l in range(0, outside, SZ):
|
||||||
|
reduce_size = min(SZ, outside-l)
|
||||||
|
j = 0
|
||||||
|
while j < dimlist[-1]:
|
||||||
|
riski_load(Reg.MATMUL_INPUT,
|
||||||
|
SLOT(inslot) + l*dimlist[-1] + j,
|
||||||
|
stride_y=1, stride_x=dimlist[-1],
|
||||||
|
len_y=min(SZ if j == 0 else SZ-1, dimlist[-1]-j),
|
||||||
|
len_x=reduce_size,
|
||||||
|
zero=j==0, skip_first=j!=0)
|
||||||
|
reduceops[op]()
|
||||||
|
riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
|
||||||
|
j += SZ if j == 0 else SZ-1
|
||||||
|
riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l, len_y=1, len_x=reduce_size)
|
||||||
|
# remove last dimension
|
||||||
|
redlist = redlist[:-1]
|
||||||
|
dimlist = dimlist[:-1]
|
||||||
|
inslot, outslot = outslot, inslot
|
||||||
|
|
||||||
# do the reduce merges one at a time
|
# do the reduce merges one at a time
|
||||||
while True:
|
while len(dimlist) >= 2:
|
||||||
print(dimlist)
|
print("proc", dimlist, redlist)
|
||||||
|
|
||||||
for l in range(0, int(np.prod(dimlist[:-2]))):
|
for l in range(0, int(np.prod(dimlist[:-2]))):
|
||||||
for k in range(0, dimlist[-1], SZ):
|
for k in range(0, dimlist[-1], SZ):
|
||||||
reduce_size = min(SZ, dimlist[-1]-k)
|
reduce_size = min(SZ, dimlist[-1]-k)
|
||||||
|
@ -302,13 +322,13 @@ def cherry_reduceop(inp, op, axis):
|
||||||
riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
|
riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
|
||||||
j += SZ if j == 0 else SZ-1
|
j += SZ if j == 0 else SZ-1
|
||||||
riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size)
|
riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size)
|
||||||
if len(dimlist) == 2:
|
inslot, outslot = outslot, inslot
|
||||||
|
if len(dimlist) <= 2:
|
||||||
break
|
break
|
||||||
# merge [False, True] into [False]
|
# merge [False, True] into [False]
|
||||||
dimlist = dimlist[:-3] + [dimlist[-3]*dimlist[-1]]
|
dimlist = dimlist[:-3] + [dimlist[-3]*dimlist[-1]]
|
||||||
inslot, outslot = outslot, inslot
|
|
||||||
|
|
||||||
return cherry_dmaw(SLOT(outslot), osize)
|
return cherry_dmaw(SLOT(inslot), osize)
|
||||||
|
|
||||||
def cherry_unop(x, op):
|
def cherry_unop(x, op):
|
||||||
cherry_dmar(SLOT(0), x)
|
cherry_dmar(SLOT(0), x)
|
||||||
|
|
|
@ -83,6 +83,8 @@ class TestOps(unittest.TestCase):
|
||||||
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||||
def test_sum(self):
|
def test_sum(self):
|
||||||
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
|
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
|
||||||
|
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3), lambda x: Tensor.sum(x, axis=3))
|
||||||
|
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3)), lambda x: Tensor.sum(x, axis=(1,3)))
|
||||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
|
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
|
||||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
|
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
|
||||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
|
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
|
||||||
|
|
Loading…
Reference in New Issue