mirror of https://github.com/commaai/tinygrad.git
binops
This commit is contained in:
parent
4535d39baa
commit
10b1306525
|
@ -32,6 +32,50 @@ class Exp(Function):
|
|||
ret, = ctx.saved_tensors
|
||||
return risk_binop(grad_output, ret, BinaryOps.MUL)
|
||||
|
||||
# ************* binary ops *************
|
||||
|
||||
def unbroadcast(out, in_sh):
|
||||
# adjoint operation to broadcast is sum. Need to sum all axis with 1 = in_sh[i] < out.shape[i]
|
||||
sum_axis = tuple([i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1]) if in_sh != (1,) else None
|
||||
return out.sum(axis=sum_axis).reshape(in_sh)
|
||||
|
||||
class Add(Function):
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x.shape, y.shape)
|
||||
return risk_binop(x, y, BinaryOps.ADD)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
shape_x, shape_y = ctx.saved_tensors
|
||||
return unbroadcast(grad_output, shape_x), unbroadcast(grad_output, shape_y)
|
||||
|
||||
class Sub(Function):
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x.shape, y.shape)
|
||||
return risk_binop(x, y, BinaryOps.SUB)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
shape_x, shape_y = ctx.saved_tensors
|
||||
return unbroadcast(grad_output, shape_x), unbroadcast(-grad_output, shape_y)
|
||||
|
||||
class Mul(Function):
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x, y)
|
||||
return risk_binop(x, y, BinaryOps.MUL)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
x,y = ctx.saved_tensors
|
||||
return unbroadcast(y*grad_output, x.shape), unbroadcast(x*grad_output, y.shape)
|
||||
|
||||
class Pow(Function):
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x, y)
|
||||
return risk_binop(x, y, BinaryOps.POW)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
x,y = ctx.saved_tensors
|
||||
return unbroadcast(y * (x**(y-1.0)) * grad_output, x.shape), \
|
||||
unbroadcast((x**y) * np.log(x) * grad_output, y.shape)
|
||||
|
||||
# ************* processing ops *************
|
||||
|
||||
class Matmul(Function):
|
||||
|
|
|
@ -139,7 +139,7 @@ def riski_mulacc():
|
|||
|
||||
@count
|
||||
def riski_pow():
|
||||
regfile[Reg.MATMUL_OUTPUT] = np.pow(regfile[Reg.MATMUL_INPUT], regfile[Reg.MATMUL_WEIGHTS])
|
||||
regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] ** regfile[Reg.MATMUL_WEIGHTS]
|
||||
|
||||
# TODO: make accumulate a bit in the instruction available to all
|
||||
binops = {BinaryOps.ADD: riski_add,
|
||||
|
@ -202,15 +202,82 @@ def risk_unop(x, op):
|
|||
riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
|
||||
return riski_dmaw(SLOT(2), x.shape)
|
||||
|
||||
def risk_binop(x, w, op):
|
||||
def risk_binop(x, y, op):
|
||||
n_dims = max(len(x.shape), len(y.shape))
|
||||
shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
|
||||
shape_x[:len(x.shape)] = np.array(x.shape, dtype=np.int32)
|
||||
shape_y[:len(y.shape)] = np.array(y.shape, dtype=np.int32)
|
||||
if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
|
||||
raise Exception(f"binary op unbroadcastable shape mismatch: {x.shape} vs {y.shape}")
|
||||
shape_ret = np.maximum(shape_x, shape_y)
|
||||
print(shape_x, shape_y, shape_ret)
|
||||
|
||||
dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
|
||||
def push(dim, comp):
|
||||
if len(complist) > 0 and complist[-1] == comp:
|
||||
dimlist[-1] *= dim
|
||||
elif comp != (False, False):
|
||||
dimlist.append(dim); complist.append(comp)
|
||||
for i in range(n_dims): # group together any adjacent dimensions that we can to simplify broadcasting
|
||||
push(max(shape_x[i], shape_y[i]), (shape_x[i] > 1, shape_y[i] > 1))
|
||||
|
||||
print(dimlist, complist)
|
||||
|
||||
riski_dmar(SLOT(0), x)
|
||||
riski_dmar(SLOT(1), w)
|
||||
for i in range(0, np.prod(x.shape), SZ*SZ):
|
||||
riski_load(Reg.MATMUL_INPUT, SLOT(0)+i)
|
||||
riski_load(Reg.MATMUL_WEIGHTS, SLOT(1)+i)
|
||||
binops[op]()
|
||||
riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
|
||||
return riski_dmaw(SLOT(2), x.shape)
|
||||
riski_dmar(SLOT(1), y)
|
||||
if len(dimlist) <= 1:
|
||||
if len(complist) == 0:
|
||||
complist = [(True, True)]
|
||||
for i in range(0, np.prod(shape_ret), SZ*SZ):
|
||||
if complist[0][0]:
|
||||
riski_load(Reg.MATMUL_INPUT, SLOT(0)+i)
|
||||
else:
|
||||
riski_load(Reg.MATMUL_INPUT, SLOT(0), stride_y=0, stride_x=0)
|
||||
if complist[0][1]:
|
||||
riski_load(Reg.MATMUL_WEIGHTS, SLOT(1)+i)
|
||||
else:
|
||||
riski_load(Reg.MATMUL_WEIGHTS, SLOT(1), stride_y=0, stride_x=0)
|
||||
binops[op]()
|
||||
riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
|
||||
else:
|
||||
# broadcasting on the inner 2 "real" dimensions sped up
|
||||
# NOTE: this can be made faster by supporting any dimensions
|
||||
def gd(idx, dims, comps):
|
||||
ret = 0
|
||||
mult = 1
|
||||
in_idx = idx
|
||||
for c,d in zip(comps[::-1], dims[::-1]):
|
||||
tt = idx % d
|
||||
idx = idx // d
|
||||
if c == False:
|
||||
continue
|
||||
ret += tt*mult
|
||||
mult *= d
|
||||
#print(ret, in_idx, dims, comps)
|
||||
return ret
|
||||
for i in range(0, int(np.prod(dimlist[:-2]))):
|
||||
off_0 = SLOT(0) + gd(i, dimlist[:-2], [x[0] for x in complist[:-2]])*\
|
||||
(dimlist[-2] if complist[-2][0] else 1)*(dimlist[-1] if complist[-1][0] else 1)
|
||||
off_1 = SLOT(1) + gd(i, dimlist[:-2], [x[1] for x in complist[:-2]])*\
|
||||
(dimlist[-2] if complist[-2][1] else 1)*(dimlist[-1] if complist[-1][1] else 1)
|
||||
off_2 = SLOT(2) + gd(i, dimlist[:-2], [True]*len(dimlist[:-2]))*dimlist[-2]*dimlist[-1]
|
||||
for j in range(0, dimlist[-2], SZ):
|
||||
for k in range(0, dimlist[-1], SZ):
|
||||
sy = complist[-2][0]*(dimlist[-1] if complist[-1][0] else 1)
|
||||
riski_load(Reg.MATMUL_INPUT,
|
||||
off_0 + j*sy + k*complist[-1][0],
|
||||
stride_y=sy, stride_x=complist[-1][0])
|
||||
sy = complist[-2][1]*(dimlist[-1] if complist[-1][1] else 1)
|
||||
riski_load(Reg.MATMUL_WEIGHTS,
|
||||
off_1 + j*sy + k*complist[-1][1],
|
||||
stride_y=sy, stride_x=complist[-1][1])
|
||||
binops[op]()
|
||||
# output is always "True"
|
||||
riski_store(Reg.MATMUL_OUTPUT, off_2 + j*dimlist[-1] + k,
|
||||
stride_y=dimlist[-1], stride_x=1,
|
||||
len_y=min(SZ, dimlist[-2]-j), len_x=min(SZ, dimlist[-1]-k))
|
||||
|
||||
return riski_dmaw(SLOT(2), shape_ret)
|
||||
|
||||
def risk_matmul(x, w, transpose_x=False, transpose_w=False):
|
||||
# copy matrices into SRAM
|
||||
|
|
Loading…
Reference in New Issue