mirror of https://github.com/commaai/tinygrad.git
sum op
This commit is contained in:
parent
d6517a8a7c
commit
c1d469d440
|
@ -159,7 +159,7 @@ python3 -m pytest
|
|||
PYTHONPATH="." DEBUG=1 CHERRY=1 python3 examples/efficientnet.py https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg
|
||||
```
|
||||
|
||||
* Add reduce ops to CHERRY, and fully support forward pass. See `extra/ops_risk.py` and `extra/risk.py`
|
||||
* ~~Add reduce ops to CHERRY, and fully support forward pass. See `extra/ops_risk.py` and `extra/risk.py`~~
|
||||
* Switch convolution backward pass to CHERRY instead of the numpy placeholder
|
||||
* Confirm EfficientNet backward pass fully uses CHERRY instructions
|
||||
* Benchmark that and transformers
|
||||
|
|
107
extra/cherry.py
107
extra/cherry.py
|
@ -38,7 +38,7 @@ from collections import defaultdict
|
|||
# <empty> <output> <input> <weight>
|
||||
# <weight> <input> <empty> <output>
|
||||
|
||||
SZ = 32
|
||||
SZ = 4
|
||||
SLOTSIZE = 1024*1024*2 # 5MB, for 20MB total. 8M elements
|
||||
sram = np.zeros((SLOTSIZE*4), dtype=np.float32)
|
||||
regfile = {}
|
||||
|
@ -148,12 +148,12 @@ def riski_pow():
|
|||
regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] ** regfile[Reg.MATMUL_WEIGHTS]
|
||||
|
||||
@count
|
||||
def riski_reduce_sum():
|
||||
regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT].sum(axis=0)
|
||||
def riski_reduce_sum(out=0, cnt=SZ):
|
||||
regfile[Reg.MATMUL_OUTPUT][out] = regfile[Reg.MATMUL_INPUT][0:cnt].sum(axis=0)
|
||||
|
||||
@count
|
||||
def riski_reduce_max():
|
||||
regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT].max(axis=0)
|
||||
def riski_reduce_max(out=0, cnt=SZ):
|
||||
regfile[Reg.MATMUL_OUTPUT][out] = regfile[Reg.MATMUL_INPUT][0:cnt].max(axis=0)
|
||||
|
||||
# TODO: make accumulate a bit in the instruction available to all
|
||||
binops = {BinaryOps.ADD: riski_add,
|
||||
|
@ -175,21 +175,28 @@ def riski_matmul():
|
|||
regfile[Reg.MATMUL_WEIGHTS]
|
||||
|
||||
@count
|
||||
def riski_mov(tout, tin):
|
||||
regfile[tout][:] = regfile[tin]
|
||||
def riski_mov(tout, tin, transpose=False):
|
||||
ret = regfile[tin]
|
||||
if transpose:
|
||||
ret = ret.T
|
||||
regfile[tout] = np.copy(ret)
|
||||
|
||||
load_log = open("/tmp/risk_load_log", "w") if os.getenv("LOAD_LOG") else None
|
||||
|
||||
@count
|
||||
def riski_load(target, address, stride_y=SZ, stride_x=1, len_y=SZ, len_x=SZ):
|
||||
def riski_load(target, address, stride_y=SZ, stride_x=1, len_y=SZ, len_x=SZ, zero=True, skip_first=False):
|
||||
global util_n, util_d
|
||||
if load_log is not None:
|
||||
load_log.write("%d %d %d %d %d\n" % (address, stride_y, stride_x, len_y, len_x))
|
||||
utils[(len_y, len_x)] += 1
|
||||
stride_y, stride_x = int(stride_y), int(stride_x)
|
||||
d = regfile[target]
|
||||
d[:] = 0
|
||||
d[:len_y, :len_x] = np.lib.stride_tricks.as_strided(sram[address:], (len_y, len_x), (stride_y*4, stride_x*4))
|
||||
if zero:
|
||||
d[:] = 0
|
||||
if skip_first:
|
||||
d[1:(len_y+1), :len_x] = np.lib.stride_tricks.as_strided(sram[address:], (len_y, len_x), (stride_y*4, stride_x*4))
|
||||
else:
|
||||
d[:len_y, :len_x] = np.lib.stride_tricks.as_strided(sram[address:], (len_y, len_x), (stride_y*4, stride_x*4))
|
||||
"""
|
||||
for y in range(0, len_y):
|
||||
for x in range(0, len_x):
|
||||
|
@ -225,11 +232,83 @@ def cherry_dmaw(address, shp):
|
|||
|
||||
# *** CHERRY code to be compiled ***
|
||||
|
||||
def cherry_reduceop(x, op, axis):
|
||||
print(op, x.shape, axis)
|
||||
cherry_dmar(SLOT(0), x)
|
||||
def cherry_reduceop(inp, op, axis):
|
||||
dimlist, redlist = [], []
|
||||
if type(axis) == int:
|
||||
axis = [axis]
|
||||
if axis is None:
|
||||
# full reduce
|
||||
#osize = [1]*len(inp.shape)
|
||||
osize = [1]
|
||||
dimlist = [np.prod(inp.shape), 1]
|
||||
redlist = [True, False]
|
||||
else:
|
||||
osize = np.array(inp.shape)
|
||||
osize[list(axis)] = 1
|
||||
|
||||
return cherry_dmaw(SLOT(2), x.shape)
|
||||
# skip any early 1s
|
||||
sd = 0
|
||||
while sd < len(osize) and osize[sd] == 1 and inp.shape[sd] == 1:
|
||||
sd += 1
|
||||
|
||||
# this would be good in the GPU
|
||||
for i in range(sd, len(inp.shape)):
|
||||
is_reduce_axis = osize[i] == 1 and inp.shape[i] != -1
|
||||
if len(dimlist) == 0:
|
||||
dimlist.append(inp.shape[i])
|
||||
redlist.append(is_reduce_axis)
|
||||
else:
|
||||
if redlist[-1] == is_reduce_axis:
|
||||
dimlist[-1] *= inp.shape[i]
|
||||
else:
|
||||
dimlist.append(inp.shape[i])
|
||||
redlist.append(is_reduce_axis)
|
||||
nosize = []
|
||||
for i in range(osize.shape[0]):
|
||||
if i not in axis:
|
||||
nosize.append(osize[i])
|
||||
osize = nosize
|
||||
if redlist[-1] != False:
|
||||
dimlist = dimlist+[1]
|
||||
redlist = redlist+[False]
|
||||
if len(dimlist)%2 == 1:
|
||||
dimlist = [1]+dimlist
|
||||
redlist = [not redlist[0]]+redlist
|
||||
|
||||
osize = tuple(osize)
|
||||
print(op, inp.shape, axis, "->", osize, dimlist, redlist)
|
||||
inslot, outslot = 0, 2
|
||||
cherry_dmar(SLOT(inslot), inp)
|
||||
|
||||
# dimlist is always the inshape
|
||||
# redlist is always [False, True, False, True, ...., True, False]
|
||||
|
||||
# do the reduce merges one at a time
|
||||
while True:
|
||||
print(dimlist)
|
||||
for l in range(0, int(np.prod(dimlist[:-2]))):
|
||||
for k in range(0, dimlist[-1], SZ):
|
||||
reduce_size = min(SZ, dimlist[-1]-k)
|
||||
j = 0
|
||||
while j < dimlist[-2]:
|
||||
riski_load(Reg.MATMUL_INPUT,
|
||||
SLOT(inslot) + l*dimlist[-2]*dimlist[-1] + j*dimlist[-1] + k,
|
||||
stride_y=dimlist[-1], stride_x=1,
|
||||
len_y=min(SZ if j == 0 else SZ-1, dimlist[-2]-j),
|
||||
len_x=reduce_size,
|
||||
zero=j==0, skip_first=j!=0)
|
||||
#cherry_regdump()
|
||||
reduceops[op]()
|
||||
riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
|
||||
j += SZ if j == 0 else SZ-1
|
||||
riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size)
|
||||
if len(dimlist) == 2:
|
||||
break
|
||||
# merge [False, True] into [False]
|
||||
dimlist = dimlist[:-3] + [dimlist[-3]*dimlist[-1]]
|
||||
inslot, outslot = outslot, inslot
|
||||
|
||||
return cherry_dmaw(SLOT(outslot), osize)
|
||||
|
||||
def cherry_unop(x, op):
|
||||
cherry_dmar(SLOT(0), x)
|
||||
|
|
|
@ -34,7 +34,6 @@ class Exp(Function):
|
|||
|
||||
# ************* reduce ops *************
|
||||
|
||||
"""
|
||||
class Sum(Function):
|
||||
def forward(ctx, input, axis=None):
|
||||
ctx.save_for_backward(input, axis)
|
||||
|
@ -46,6 +45,7 @@ class Sum(Function):
|
|||
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
return grad_output.reshape(shape) + np.zeros_like(input)
|
||||
|
||||
"""
|
||||
class Max(Function):
|
||||
def forward(ctx, inp, axis=None):
|
||||
if isinstance(axis, int): axis = [axis]
|
||||
|
|
|
@ -83,6 +83,7 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
def test_sum(self):
|
||||
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
|
||||
def test_max(self):
|
||||
|
|
Loading…
Reference in New Issue