From 10b13065251258637317d679376e802ecf0427b2 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Thu, 10 Jun 2021 16:52:37 -0700
Subject: [PATCH] binops

---
 extra/ops_risk.py | 44 ++++++++++++++++++++++++
 extra/risk.py     | 85 ++++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 120 insertions(+), 9 deletions(-)

diff --git a/extra/ops_risk.py b/extra/ops_risk.py
index fee4fcba..0f68242c 100644
--- a/extra/ops_risk.py
+++ b/extra/ops_risk.py
@@ -32,6 +32,50 @@ class Exp(Function):
     ret, = ctx.saved_tensors
     return risk_binop(grad_output, ret, BinaryOps.MUL)
 
+# ************* binary ops *************
+
+def unbroadcast(out, in_sh):
+  # the adjoint operation to broadcast is sum: sum all axes where 1 == in_sh[i] < out.shape[i]
+  sum_axis = tuple([i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1]) if in_sh != (1,) else None
+  return out.sum(axis=sum_axis).reshape(in_sh)
+
+class Add(Function):
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x.shape, y.shape)
+    return risk_binop(x, y, BinaryOps.ADD)
+
+  def backward(ctx, grad_output):
+    shape_x, shape_y = ctx.saved_tensors
+    return unbroadcast(grad_output, shape_x), unbroadcast(grad_output, shape_y)
+
+class Sub(Function):
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x.shape, y.shape)
+    return risk_binop(x, y, BinaryOps.SUB)
+
+  def backward(ctx, grad_output):
+    shape_x, shape_y = ctx.saved_tensors
+    return unbroadcast(grad_output, shape_x), unbroadcast(-grad_output, shape_y)
+
+class Mul(Function):
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x, y)
+    return risk_binop(x, y, BinaryOps.MUL)
+
+  def backward(ctx, grad_output):
+    x,y = ctx.saved_tensors
+    return unbroadcast(y*grad_output, x.shape), unbroadcast(x*grad_output, y.shape)
+
+class Pow(Function):
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x, y)
+    return risk_binop(x, y, BinaryOps.POW)
+
+  def backward(ctx, grad_output):
+    x,y = ctx.saved_tensors
+    return unbroadcast(y * (x**(y-1.0)) * grad_output, x.shape), \
+           unbroadcast((x**y) * np.log(x) * grad_output, y.shape)
+
 # ************* processing ops *************
 
 class Matmul(Function):
diff --git a/extra/risk.py b/extra/risk.py
index baddf4fd..6f920831 100755
--- a/extra/risk.py
+++ b/extra/risk.py
@@ -139,7 +139,7 @@ def riski_mulacc():
 
 @count
 def riski_pow():
-  regfile[Reg.MATMUL_OUTPUT] = np.pow(regfile[Reg.MATMUL_INPUT], regfile[Reg.MATMUL_WEIGHTS])
+  regfile[Reg.MATMUL_OUTPUT] = regfile[Reg.MATMUL_INPUT] ** regfile[Reg.MATMUL_WEIGHTS]
 
 # TODO: make accumulate a bit in the instruction available to all
 binops = {BinaryOps.ADD: riski_add,
@@ -202,15 +202,82 @@ def risk_unop(x, op):
     riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
   return riski_dmaw(SLOT(2), x.shape)
 
-def risk_binop(x, w, op):
+def risk_binop(x, y, op):
+  n_dims = max(len(x.shape), len(y.shape))
+  shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
+  shape_x[:len(x.shape)] = np.array(x.shape, dtype=np.int32)
+  shape_y[:len(y.shape)] = np.array(y.shape, dtype=np.int32)
+  if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
+    raise Exception(f"binary op unbroadcastable shape mismatch: {x.shape} vs {y.shape}")
+  shape_ret = np.maximum(shape_x, shape_y)
+  print(shape_x, shape_y, shape_ret)
+
+  dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
+  def push(dim, comp):
+    if len(complist) > 0 and complist[-1] == comp:
+      dimlist[-1] *= dim
+    elif comp != (False, False):
+      dimlist.append(dim); complist.append(comp)
+  for i in range(n_dims): # group together any adjacent dimensions that we can to simplify broadcasting
+    push(max(shape_x[i], shape_y[i]), (shape_x[i] > 1, shape_y[i] > 1))
+
+  print(dimlist, complist)
+
   riski_dmar(SLOT(0), x)
-  riski_dmar(SLOT(1), w)
-  for i in range(0, np.prod(x.shape), SZ*SZ):
-    riski_load(Reg.MATMUL_INPUT, SLOT(0)+i)
-    riski_load(Reg.MATMUL_WEIGHTS, SLOT(1)+i)
-    binops[op]()
-    riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
-  return riski_dmaw(SLOT(2), x.shape)
+  riski_dmar(SLOT(1), y)
+  if len(dimlist) <= 1:
+    if len(complist) == 0:
+      complist = [(True, True)]
+    for i in range(0, np.prod(shape_ret), SZ*SZ):
+      if complist[0][0]:
+        riski_load(Reg.MATMUL_INPUT, SLOT(0)+i)
+      else:
+        riski_load(Reg.MATMUL_INPUT, SLOT(0), stride_y=0, stride_x=0)
+      if complist[0][1]:
+        riski_load(Reg.MATMUL_WEIGHTS, SLOT(1)+i)
+      else:
+        riski_load(Reg.MATMUL_WEIGHTS, SLOT(1), stride_y=0, stride_x=0)
+      binops[op]()
+      riski_store(Reg.MATMUL_OUTPUT, SLOT(2)+i)
+  else:
+    # broadcasting on the inner 2 "real" dimensions is sped up
+    # NOTE: this can be made faster by supporting any dimensions
+    def gd(idx, dims, comps):
+      ret = 0
+      mult = 1
+      in_idx = idx
+      for c,d in zip(comps[::-1], dims[::-1]):
+        tt = idx % d
+        idx = idx // d
+        if c == False:
+          continue
+        ret += tt*mult
+        mult *= d
+        #print(ret, in_idx, dims, comps)
+      return ret
+    for i in range(0, int(np.prod(dimlist[:-2]))):
+      off_0 = SLOT(0) + gd(i, dimlist[:-2], [x[0] for x in complist[:-2]])*\
+        (dimlist[-2] if complist[-2][0] else 1)*(dimlist[-1] if complist[-1][0] else 1)
+      off_1 = SLOT(1) + gd(i, dimlist[:-2], [x[1] for x in complist[:-2]])*\
+        (dimlist[-2] if complist[-2][1] else 1)*(dimlist[-1] if complist[-1][1] else 1)
+      off_2 = SLOT(2) + gd(i, dimlist[:-2], [True]*len(dimlist[:-2]))*dimlist[-2]*dimlist[-1]
+      for j in range(0, dimlist[-2], SZ):
+        for k in range(0, dimlist[-1], SZ):
+          sy = complist[-2][0]*(dimlist[-1] if complist[-1][0] else 1)
+          riski_load(Reg.MATMUL_INPUT,
+            off_0 + j*sy + k*complist[-1][0],
+            stride_y=sy, stride_x=complist[-1][0])
+          sy = complist[-2][1]*(dimlist[-1] if complist[-1][1] else 1)
+          riski_load(Reg.MATMUL_WEIGHTS,
+            off_1 + j*sy + k*complist[-1][1],
+            stride_y=sy, stride_x=complist[-1][1])
+          binops[op]()
+          # output is always "True"
+          riski_store(Reg.MATMUL_OUTPUT, off_2 + j*dimlist[-1] + k,
+            stride_y=dimlist[-1], stride_x=1,
+            len_y=min(SZ, dimlist[-2]-j), len_x=min(SZ, dimlist[-1]-k))
+
+  return riski_dmaw(SLOT(2), shape_ret)
 
 def risk_matmul(x, w, transpose_x=False, transpose_w=False):
   # copy matrices into SRAM
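
For reference outside the patch itself: the shape handling in risk_binop (padding both shapes to a common rank with trailing 1s, then merging adjacent dimensions that broadcast the same way on both inputs) and the sum-based unbroadcast adjoint in ops_risk.py can be exercised without the RISK simulator. The sketch below is plain NumPy with made-up helper names (group_broadcast_dims is not part of the patch); it only mirrors the dimension grouping and the gradient reduction, not the riski_* load/store tiling.

import numpy as np

def group_broadcast_dims(shape_x, shape_y):
  # pad both shapes to the same rank with trailing 1s, the way risk_binop does
  n_dims = max(len(shape_x), len(shape_y))
  sx, sy = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
  sx[:len(shape_x)] = np.array(shape_x, dtype=np.int32)
  sy[:len(shape_y)] = np.array(shape_y, dtype=np.int32)
  if not np.all((sx == 1) | (sy == 1) | (sx == sy)):
    raise Exception(f"binary op unbroadcastable shape mismatch: {shape_x} vs {shape_y}")
  dimlist, complist = [], []
  def push(dim, comp):
    # adjacent dims that broadcast the same way on both inputs merge into one
    if len(complist) > 0 and complist[-1] == comp:
      dimlist[-1] *= dim
    elif comp != (False, False):
      dimlist.append(dim); complist.append(comp)
  for i in range(n_dims):
    push(int(max(sx[i], sy[i])), (bool(sx[i] > 1), bool(sy[i] > 1)))
  return dimlist, complist

def unbroadcast(out, in_sh):
  # adjoint of broadcast: sum every axis that was expanded from size 1
  sum_axis = tuple(i for i in range(len(in_sh)) if in_sh[i] == 1 and out.shape[i] > 1) if in_sh != (1,) else None
  return out.sum(axis=sum_axis).reshape(in_sh)

if __name__ == "__main__":
  # (4,3,1) vs (4,3,5): the leading 4x3 broadcasts identically on both sides,
  # so it collapses into a single dim of 12 -> ([12, 5], [(True, True), (False, True)])
  print(group_broadcast_dims((4, 3, 1), (4, 3, 5)))
  # gradient flowing into a broadcast operand has to be summed back to its shape
  g = np.ones((4, 5, 3))
  print(unbroadcast(g, (1, 5, 3)).shape)  # (1, 5, 3), every entry is 4.0

Note that risk_binop pads missing dimensions at the end (trailing 1s) rather than aligning trailing dimensions the way NumPy broadcasting does, so the example above keeps both shapes at the same rank.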