diff --git a/extra/cherry.py b/extra/cherry.py index 500ac7c7..5238cfeb 100755 --- a/extra/cherry.py +++ b/extra/cherry.py @@ -179,11 +179,10 @@ binops = {BinaryOps.ADD: riski_add, reduceops = {ReduceOps.SUM: riski_reduce_sum, ReduceOps.MAX: riski_reduce_max} -#@count -# TODO: add masks to matmul instruction? -def riski_matmul(): +@count +def riski_matmul(slow=False): #print("LLL:\n",regfile[Reg.MATMUL_INPUT],"\n",regfile[Reg.MATMUL_WEIGHTS]) - if False: + if not slow: regfile[Reg.MATMUL_ACC] += \ regfile[Reg.MATMUL_INPUT] @ \ regfile[Reg.MATMUL_WEIGHTS] @@ -242,12 +241,12 @@ def cherry_dmar(address, arr): arr = arr.reshape(-1) assert(arr.shape[0] <= SLOTSIZE) maxdma = max(maxdma, arr.shape[0]) - print("DMAR %d elements" % arr.shape[0]) + #print("DMAR %d elements" % arr.shape[0]) sram[address:address+arr.shape[0]] = arr @count def cherry_dmaw(address, shp): - print("DMAW %d elements" % np.prod(shp)) + #print("DMAW %d elements" % np.prod(shp)) return np.copy(sram[address:address+np.prod(shp)].reshape(shp)) # *** CHERRY code to be compiled *** @@ -294,7 +293,7 @@ def cherry_reduceop(inp, op, axis, keepdims=False): osize = nosize osize = tuple(osize) - print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist) + #print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist) inslot, outslot = 0, 2 cherry_dmar(SLOT(inslot), inp) @@ -303,7 +302,7 @@ def cherry_reduceop(inp, op, axis, keepdims=False): # special case if redlist ends with True if len(redlist) > 0 and redlist[-1] == True: - print("special case redlist[-1] == True") + #print("special case redlist[-1] == True") outside = int(np.prod(dimlist[:-1])) for l in range(0, outside, SZ): reduce_size = min(SZ, outside-l) @@ -327,7 +326,7 @@ def cherry_reduceop(inp, op, axis, keepdims=False): # do the reduce merges one at a time while len(dimlist) >= 2: - print("proc", dimlist, redlist) + #print("proc", dimlist, redlist) for l in range(0, int(np.prod(dimlist[:-2]))): for k in range(0, dimlist[-1], SZ): @@ -371,7 +370,7 @@ def cherry_binop(x, y, op): if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)): raise Exception(f"binary op unbroadcastable shape mismatch: {x.shape} vs {y.shape}") shape_ret = np.maximum(shape_x, shape_y) - print(shape_x, shape_y, shape_ret) + #print(shape_x, shape_y, shape_ret) dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims def push(dim, comp): @@ -382,7 +381,7 @@ def cherry_binop(x, y, op): for i in range(n_dims): # group together any adjacent dimensions that we can to simplify broadcasting push(max(shape_x[i], shape_y[i]), (shape_x[i] > 1, shape_y[i] > 1)) - print(dimlist, complist) + #print(dimlist, complist) cherry_dmar(SLOT(0), x) cherry_dmar(SLOT(1), y) @@ -518,19 +517,12 @@ class TestCherry(unittest.TestCase): def test_mulacc_matmul(self): regfile[Reg.MATMUL_INPUT] = np.arange(1, SZ*SZ+1).reshape((SZ, SZ)) regfile[Reg.MATMUL_WEIGHTS] = np.arange(1, SZ*SZ+1).reshape((SZ, SZ))*-1 - print(regfile[Reg.MATMUL_INPUT]) - print(regfile[Reg.MATMUL_WEIGHTS]) riski_zero(Reg.MATMUL_ACC) riski_matmul() tst1 = np.copy(regfile[Reg.MATMUL_ACC]) - - for l in range(0, SZ): - riski_mul(l) - print("TST",l,regfile[Reg.MATMUL_OUTPUT]) - riski_reduce_sum(tgt=l) + riski_zero(Reg.MATMUL_ACC) + riski_matmul(True) tst2 = np.copy(regfile[Reg.MATMUL_ACC]) - print(tst1) - print(tst2) np.testing.assert_allclose(tst1, tst2) diff --git a/extra/ops_cherry.py b/extra/ops_cherry.py index 0dc17c90..ce8c1bac 100644 --- a/extra/ops_cherry.py +++ b/extra/ops_cherry.py @@ -176,7 +176,7 @@ class Conv2D(Function): # bs x groups x yx -- groups x 1 --> bs x groups x yx # it looks like a broadcasted multiply - print("opt1") + #print("opt1") # x: bs x groups x iy x ix # w: groups x H x W @@ -203,7 +203,7 @@ class Conv2D(Function): 1, oy*ox, min(SZ, ox-X), min(SZ, groups-g)) elif H == 1 and W == 1 and xs == 1 and ys == 1: - print("opt2") + #print("opt2") # oxy x cin x rcout -- unstrided 1x1 # this is a simple matmul for g in range(0, groups): @@ -225,7 +225,7 @@ class Conv2D(Function): SLOT(2) + B*groups*rcout*yx + g*rcout*yx + c*yx + YX, 1, yx, min(SZ, yx-YX), min(SZ, rcout-c)) else: - print("unoptimized") + #print("unoptimized") # ox x cin x rcout -- unoptimized for g in range(0, groups): for c in range(0, rcout, SZ): diff --git a/fpga/TPUNOTES b/fpga/TPUNOTES index 8846f760..069ec4a6 100644 --- a/fpga/TPUNOTES +++ b/fpga/TPUNOTES @@ -51,10 +51,10 @@ Scalar Unit: 128x Vector Unit: 32Ki x 32b Vector (Lane) Memory (8 ports) - 8x: + 8x: Sublane 32 x 32b Vector (Lane) Reg File - 2x ALUs - Connections to matrix unix + 2x ALUs on a float each, 2048 total + Connections to matrix unit 2048 Vector ALUs 2x Matrix Multiply Unit