George Hotz 2021-06-10 16:52:37 -07:00
parent 4535d39baa
commit 10b1306525
2 changed files with 120 additions and 9 deletions

View File

@@ -32,6 +32,50 @@ class Exp(Function):
     ret, = ctx.saved_tensors
     return risk_binop(grad_output, ret, BinaryOps.MUL)
 
+# ************* binary ops *************
+
+def unbroadcast(out, in_sh):
+  # the adjoint of broadcast is sum: reduce every axis i where in_sh[i] == 1 and out.shape[i] > 1
+  sum_axis = tuple([i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1]) if in_sh != (1,) else None
+  return out.sum(axis=sum_axis).reshape(in_sh)
+
+class Add(Function):
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x.shape, y.shape)
+    return risk_binop(x, y, BinaryOps.ADD)
+
+  def backward(ctx, grad_output):
+    shape_x, shape_y = ctx.saved_tensors
+    return unbroadcast(grad_output, shape_x), unbroadcast(grad_output, shape_y)
+
+class Sub(Function):
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x.shape, y.shape)
+    return risk_binop(x, y, BinaryOps.SUB)
+
+  def backward(ctx, grad_output):
+    shape_x, shape_y = ctx.saved_tensors
+    return unbroadcast(grad_output, shape_x), unbroadcast(-grad_output, shape_y)
+
+class Mul(Function):
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x, y)
+    return risk_binop(x, y, BinaryOps.MUL)
+
+  def backward(ctx, grad_output):
+    x,y = ctx.saved_tensors
+    return unbroadcast(y*grad_output, x.shape), unbroadcast(x*grad_output, y.shape)
+
+class Pow(Function):
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x, y)
+    return risk_binop(x, y, BinaryOps.POW)
+
+  def backward(ctx, grad_output):
+    x,y = ctx.saved_tensors
+    return unbroadcast(y * (x**(y-1.0)) * grad_output, x.shape), \
+           unbroadcast((x**y) * np.log(x) * grad_output, y.shape)
+
 # ************* processing ops *************
 
 class Matmul(Function):
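
The `unbroadcast` helper appears only inside the hunk, so here is a minimal standalone NumPy sketch of its behavior (the helper body is copied from the diff; the gradient shapes are arbitrary examples):

import numpy as np

def unbroadcast(out, in_sh):
  # reduce every axis that broadcasting expanded (in_sh[i] == 1 while out.shape[i] > 1)
  sum_axis = tuple([i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1]) if in_sh != (1,) else None
  return out.sum(axis=sum_axis).reshape(in_sh)

grad = np.ones((2, 4, 3))                  # upstream gradient of a broadcast result
print(unbroadcast(grad, (2, 4, 3)).shape)  # (2, 4, 3): nothing was broadcast, nothing is summed
print(unbroadcast(grad, (2, 1, 3)).shape)  # (2, 1, 3): axis 1 was expanded in forward, so sum it
print(unbroadcast(grad, (1,)).shape)       # (1,): a scalar was broadcast everywhere, sum everything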

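The two expressions in `Pow.backward` are the textbook partials d(x**y)/dx = y * x**(y-1) and d(x**y)/dy = x**y * log(x). A quick central-difference sanity check of those formulas; the shapes, value ranges, and tolerances here are arbitrary choices:

import numpy as np

x = np.random.uniform(1.0, 2.0, size=(3, 3))  # keep x > 0 so log(x) and x**y are well behaved
y = np.random.uniform(0.5, 1.5, size=(3, 3))
eps = 1e-6

dx = y * (x**(y-1.0))      # analytic d(x**y)/dx, as in Pow.backward
dy = (x**y) * np.log(x)    # analytic d(x**y)/dy

num_dx = ((x+eps)**y - (x-eps)**y) / (2*eps)  # numerical gradients
num_dy = (x**(y+eps) - x**(y-eps)) / (2*eps)

assert np.allclose(dx, num_dx, atol=1e-4)
assert np.allclose(dy, num_dy, atol=1e-4)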
View File

@@ -139,7 +139,7 @@ def riski_mulacc():
 @count
 def riski_pow():
-  regfile[Reg.MATMUL_OUTPUT] = np.pow(regfile[Reg.MATMUL_INPUT], regfile[Reg.MATMUL_WEIGHTS])
+  regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] ** regfile[Reg.MATMUL_WEIGHTS]
 
 # TODO: make accumulate an instruction bit available to all binops
 binops = {BinaryOps.ADD: riski_add,
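
The one-line fix above matters because `np.pow` did not exist in the NumPy of this era (the elementwise ufunc is `np.power`; `np.pow` was only added much later, as a NumPy 2.0 alias), while `**` on ndarrays dispatches to that same ufunc. A tiny check:

import numpy as np

a = np.full((4, 4), 2.0)  # stand-ins for the SZ x SZ register tiles
b = np.full((4, 4), 3.0)

# on ndarrays, "a ** b" dispatches to the np.power ufunc, elementwise
assert np.allclose(a ** b, np.power(a, b))
print((a ** b)[0, 0])  # 8.0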
@@ -202,15 +202,82 @@ def risk_unop(x, op):
     riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
   return riski_dmaw(SLOT(2), x.shape)
 
-def risk_binop(x, w, op):
+def risk_binop(x, y, op):
+  n_dims = max(len(x.shape), len(y.shape))
+  shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
+  shape_x[:len(x.shape)] = np.array(x.shape, dtype=np.int32)
+  shape_y[:len(y.shape)] = np.array(y.shape, dtype=np.int32)
+  if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
+    raise Exception(f"binary op unbroadcastable shape mismatch: {x.shape} vs {y.shape}")
+  shape_ret = np.maximum(shape_x, shape_y)
+  print(shape_x, shape_y, shape_ret)
+
+  dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
+  def push(dim, comp):
+    if len(complist) > 0 and complist[-1] == comp:
+      dimlist[-1] *= dim
+    elif comp != (False, False):
+      dimlist.append(dim); complist.append(comp)
+  for i in range(n_dims): # group together any adjacent dimensions that we can to simplify broadcasting
+    push(max(shape_x[i], shape_y[i]), (shape_x[i] > 1, shape_y[i] > 1))
+  print(dimlist, complist)
+
   riski_dmar(SLOT(0), x)
-  riski_dmar(SLOT(1), w)
-  for i in range(0, np.prod(x.shape), SZ*SZ):
-    riski_load(Reg.MATMUL_INPUT, SLOT(0)+i)
-    riski_load(Reg.MATMUL_WEIGHTS, SLOT(1)+i)
-    binops[op]()
-    riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
-  return riski_dmaw(SLOT(2), x.shape)
+  riski_dmar(SLOT(1), y)
+
+  if len(dimlist) <= 1:
+    if len(complist) == 0:
+      complist = [(True, True)]
+    for i in range(0, np.prod(shape_ret), SZ*SZ):
+      if complist[0][0]:
+        riski_load(Reg.MATMUL_INPUT, SLOT(0)+i)
+      else:
+        riski_load(Reg.MATMUL_INPUT, SLOT(0), stride_y=0, stride_x=0)
+      if complist[0][1]:
+        riski_load(Reg.MATMUL_WEIGHTS, SLOT(1)+i)
+      else:
+        riski_load(Reg.MATMUL_WEIGHTS, SLOT(1), stride_y=0, stride_x=0)
+      binops[op]()
+      riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
+  else:
+    # broadcasting on the inner 2 "real" dimensions is sped up with strided loads
+    # NOTE: this could be made faster by supporting strides on any number of dimensions
+    def gd(idx, dims, comps):
+      ret = 0
+      mult = 1
+      in_idx = idx
+      for c,d in zip(comps[::-1], dims[::-1]):
+        tt = idx % d
+        idx = idx // d
+        if c == False:
+          continue
+        ret += tt*mult
+        mult *= d
+      #print(ret, in_idx, dims, comps)
+      return ret
+
+    for i in range(0, int(np.prod(dimlist[:-2]))):
+      off_0 = SLOT(0) + gd(i, dimlist[:-2], [x[0] for x in complist[:-2]])*\
+        (dimlist[-2] if complist[-2][0] else 1)*(dimlist[-1] if complist[-1][0] else 1)
+      off_1 = SLOT(1) + gd(i, dimlist[:-2], [x[1] for x in complist[:-2]])*\
+        (dimlist[-2] if complist[-2][1] else 1)*(dimlist[-1] if complist[-1][1] else 1)
+      off_2 = SLOT(2) + gd(i, dimlist[:-2], [True]*len(dimlist[:-2]))*dimlist[-2]*dimlist[-1]
+      for j in range(0, dimlist[-2], SZ):
+        for k in range(0, dimlist[-1], SZ):
+          sy = complist[-2][0]*(dimlist[-1] if complist[-1][0] else 1)
+          riski_load(Reg.MATMUL_INPUT,
+            off_0 + j*sy + k*complist[-1][0],
+            stride_y=sy, stride_x=complist[-1][0])
+          sy = complist[-2][1]*(dimlist[-1] if complist[-1][1] else 1)
+          riski_load(Reg.MATMUL_WEIGHTS,
+            off_1 + j*sy + k*complist[-1][1],
+            stride_y=sy, stride_x=complist[-1][1])
+          binops[op]()
+          # the output is never broadcast, so it is always stored dense
+          riski_store(Reg.MATMUL_OUTPUT, off_2 + j*dimlist[-1] + k,
+            stride_y=dimlist[-1], stride_x=1,
+            len_y=min(SZ, dimlist[-2]-j), len_x=min(SZ, dimlist[-1]-k))
+  return riski_dmaw(SLOT(2), shape_ret)
 
 def risk_matmul(x, w, transpose_x=False, transpose_w=False):
   # copy matrices into SRAM
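
To make the new `risk_binop` control flow concrete, here is the shape padding, dimension grouping, and `gd` offset logic pulled out into a standalone sketch (`broadcast_plan` is a hypothetical wrapper name; the `push` and `gd` bodies mirror the diff, lightly cleaned up with plain-bool casts, and `gd` is exercised on all but the innermost dimension for illustration). Grouping adjacent dimensions that share a broadcast pattern is what lets the strided inner loops only ever deal with the last two "real" dimensions:

import numpy as np

def broadcast_plan(sh_x, sh_y):
  # pad the shorter shape with trailing 1s, exactly as risk_binop does
  n_dims = max(len(sh_x), len(sh_y))
  shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
  shape_x[:len(sh_x)] = np.array(sh_x, dtype=np.int32)
  shape_y[:len(sh_y)] = np.array(sh_y, dtype=np.int32)
  dimlist, complist = [], []
  def push(dim, comp):
    # merge adjacent dims that share a (x real?, y real?) pattern; (False, False) dims have size 1 and vanish
    if len(complist) > 0 and complist[-1] == comp:
      dimlist[-1] *= dim
    elif comp != (False, False):
      dimlist.append(dim); complist.append(comp)
  for i in range(n_dims):
    push(int(max(shape_x[i], shape_y[i])), (bool(shape_x[i] > 1), bool(shape_y[i] > 1)))
  return dimlist, complist

def gd(idx, dims, comps):
  # map an outer output index to an input offset, skipping dims the input is broadcast over
  ret, mult = 0, 1
  for c, d in zip(comps[::-1], dims[::-1]):
    tt = idx % d
    idx = idx // d
    if not c:
      continue
    ret += tt * mult
    mult *= d
  return ret

print(broadcast_plan((8, 4, 3), (8, 4, 1)))  # ([32, 3], [(True, True), (True, False)]): 8 and 4 merge into 32
print(broadcast_plan((5, 1, 3), (1, 1, 3)))  # ([5, 3], [(True, False), (True, True)]): the size-1 middle dim vanishes

dims, comps = broadcast_plan((5, 1, 3), (1, 1, 3))
# y repeats over the leading 5, so its offset never advances; x advances once per outer step
print(gd(2, dims[:-1], [c[1] for c in comps[:-1]]))  # 0
print(gd(2, dims[:-1], [c[0] for c in comps[:-1]]))  # 2

One caveat worth noting: `risk_binop` pads the shorter shape with trailing ones (`shape_x[:len(x.shape)] = ...`), so the alignment is from the left, which only agrees with NumPy's right-aligned broadcasting when both inputs already have the same rank.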