This commit is contained in:
George Hotz 2021-06-17 16:19:35 -07:00
parent d6517a8a7c
commit c1d469d440
4 changed files with 96 additions and 16 deletions

View File

@ -159,7 +159,7 @@ python3 -m pytest
PYTHONPATH="." DEBUG=1 CHERRY=1 python3 examples/efficientnet.py https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg
```
* Add reduce ops to CHERRY, and fully support forward pass. See `extra/ops_risk.py` and `extra/risk.py`
* ~~Add reduce ops to CHERRY, and fully support forward pass. See `extra/ops_risk.py` and `extra/risk.py`~~
* Switch convolution backward pass to CHERRY instead of the numpy placeholder
* Confirm EfficientNet backward pass fully uses CHERRY instructions
* Benchmark that and transformers

View File

@ -38,7 +38,7 @@ from collections import defaultdict
# <empty> <output> <input> <weight>
# <weight> <input> <empty> <output>
SZ = 32
SZ = 4
SLOTSIZE = 1024*1024*2 # 5MB, for 20MB total. 8M elements
sram = np.zeros((SLOTSIZE*4), dtype=np.float32)
regfile = {}
@ -148,12 +148,12 @@ def riski_pow():
regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] ** regfile[Reg.MATMUL_WEIGHTS]
@count
def riski_reduce_sum():
regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT].sum(axis=0)
def riski_reduce_sum(out=0, cnt=SZ):
regfile[Reg.MATMUL_OUTPUT][out] = regfile[Reg.MATMUL_INPUT][0:cnt].sum(axis=0)
@count
def riski_reduce_max():
regfile[Reg.MATMUL_OUTPUT][0] = regfile[Reg.MATMUL_INPUT].max(axis=0)
def riski_reduce_max(out=0, cnt=SZ):
regfile[Reg.MATMUL_OUTPUT][out] = regfile[Reg.MATMUL_INPUT][0:cnt].max(axis=0)
# TODO: make accumulate a bit in the instruction available to all
binops = {BinaryOps.ADD: riski_add,
@ -175,21 +175,28 @@ def riski_matmul():
regfile[Reg.MATMUL_WEIGHTS]
@count
def riski_mov(tout, tin):
regfile[tout][:] = regfile[tin]
def riski_mov(tout, tin, transpose=False):
ret = regfile[tin]
if transpose:
ret = ret.T
regfile[tout] = np.copy(ret)
load_log = open("/tmp/risk_load_log", "w") if os.getenv("LOAD_LOG") else None
@count
def riski_load(target, address, stride_y=SZ, stride_x=1, len_y=SZ, len_x=SZ):
def riski_load(target, address, stride_y=SZ, stride_x=1, len_y=SZ, len_x=SZ, zero=True, skip_first=False):
global util_n, util_d
if load_log is not None:
load_log.write("%d %d %d %d %d\n" % (address, stride_y, stride_x, len_y, len_x))
utils[(len_y, len_x)] += 1
stride_y, stride_x = int(stride_y), int(stride_x)
d = regfile[target]
d[:] = 0
d[:len_y, :len_x] = np.lib.stride_tricks.as_strided(sram[address:], (len_y, len_x), (stride_y*4, stride_x*4))
if zero:
d[:] = 0
if skip_first:
d[1:(len_y+1), :len_x] = np.lib.stride_tricks.as_strided(sram[address:], (len_y, len_x), (stride_y*4, stride_x*4))
else:
d[:len_y, :len_x] = np.lib.stride_tricks.as_strided(sram[address:], (len_y, len_x), (stride_y*4, stride_x*4))
"""
for y in range(0, len_y):
for x in range(0, len_x):
@ -225,11 +232,83 @@ def cherry_dmaw(address, shp):
# *** CHERRY code to be compiled ***
def cherry_reduceop(x, op, axis):
print(op, x.shape, axis)
cherry_dmar(SLOT(0), x)
def cherry_reduceop(inp, op, axis):
dimlist, redlist = [], []
if type(axis) == int:
axis = [axis]
if axis is None:
# full reduce
#osize = [1]*len(inp.shape)
osize = [1]
dimlist = [np.prod(inp.shape), 1]
redlist = [True, False]
else:
osize = np.array(inp.shape)
osize[list(axis)] = 1
return cherry_dmaw(SLOT(2), x.shape)
# skip any early 1s
sd = 0
while sd < len(osize) and osize[sd] == 1 and inp.shape[sd] == 1:
sd += 1
# this would be good in the GPU
for i in range(sd, len(inp.shape)):
is_reduce_axis = osize[i] == 1 and inp.shape[i] != -1
if len(dimlist) == 0:
dimlist.append(inp.shape[i])
redlist.append(is_reduce_axis)
else:
if redlist[-1] == is_reduce_axis:
dimlist[-1] *= inp.shape[i]
else:
dimlist.append(inp.shape[i])
redlist.append(is_reduce_axis)
nosize = []
for i in range(osize.shape[0]):
if i not in axis:
nosize.append(osize[i])
osize = nosize
if redlist[-1] != False:
dimlist = dimlist+[1]
redlist = redlist+[False]
if len(dimlist)%2 == 1:
dimlist = [1]+dimlist
redlist = [not redlist[0]]+redlist
osize = tuple(osize)
print(op, inp.shape, axis, "->", osize, dimlist, redlist)
inslot, outslot = 0, 2
cherry_dmar(SLOT(inslot), inp)
# dimlist is always the inshape
# redlist is always [False, True, False, True, ...., True, False]
# do the reduce merges one at a time
while True:
print(dimlist)
for l in range(0, int(np.prod(dimlist[:-2]))):
for k in range(0, dimlist[-1], SZ):
reduce_size = min(SZ, dimlist[-1]-k)
j = 0
while j < dimlist[-2]:
riski_load(Reg.MATMUL_INPUT,
SLOT(inslot) + l*dimlist[-2]*dimlist[-1] + j*dimlist[-1] + k,
stride_y=dimlist[-1], stride_x=1,
len_y=min(SZ if j == 0 else SZ-1, dimlist[-2]-j),
len_x=reduce_size,
zero=j==0, skip_first=j!=0)
#cherry_regdump()
reduceops[op]()
riski_mov(Reg.MATMUL_INPUT, Reg.MATMUL_OUTPUT) # move the first row
j += SZ if j == 0 else SZ-1
riski_store(Reg.MATMUL_OUTPUT, SLOT(outslot) + l*dimlist[-1] + k, len_y=1, len_x=reduce_size)
if len(dimlist) == 2:
break
# merge [False, True] into [False]
dimlist = dimlist[:-3] + [dimlist[-3]*dimlist[-1]]
inslot, outslot = outslot, inslot
return cherry_dmaw(SLOT(outslot), osize)
def cherry_unop(x, op):
cherry_dmar(SLOT(0), x)

View File

@ -34,7 +34,6 @@ class Exp(Function):
# ************* reduce ops *************
"""
class Sum(Function):
def forward(ctx, input, axis=None):
ctx.save_for_backward(input, axis)
@ -46,6 +45,7 @@ class Sum(Function):
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
return grad_output.reshape(shape) + np.zeros_like(input)
"""
class Max(Function):
def forward(ctx, inp, axis=None):
if isinstance(axis, int): axis = [axis]

View File

@ -83,6 +83,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
def test_sum(self):
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
def test_max(self):