clean up print spam

This commit is contained in:
George Hotz 2021-06-19 10:30:51 -07:00
parent 3a91d5434f
commit b48d4bad2e
3 changed files with 18 additions and 26 deletions

View File

@ -179,11 +179,10 @@ binops = {BinaryOps.ADD: riski_add,
reduceops = {ReduceOps.SUM: riski_reduce_sum,
ReduceOps.MAX: riski_reduce_max}
#@count
# TODO: add masks to matmul instruction?
def riski_matmul():
@count
def riski_matmul(slow=False):
#print("LLL:\n",regfile[Reg.MATMUL_INPUT],"\n",regfile[Reg.MATMUL_WEIGHTS])
if False:
if not slow:
regfile[Reg.MATMUL_ACC] += \
regfile[Reg.MATMUL_INPUT] @ \
regfile[Reg.MATMUL_WEIGHTS]
@ -242,12 +241,12 @@ def cherry_dmar(address, arr):
arr = arr.reshape(-1)
assert(arr.shape[0] <= SLOTSIZE)
maxdma = max(maxdma, arr.shape[0])
print("DMAR %d elements" % arr.shape[0])
#print("DMAR %d elements" % arr.shape[0])
sram[address:address+arr.shape[0]] = arr
@count
def cherry_dmaw(address, shp):
print("DMAW %d elements" % np.prod(shp))
#print("DMAW %d elements" % np.prod(shp))
return np.copy(sram[address:address+np.prod(shp)].reshape(shp))
# *** CHERRY code to be compiled ***
@ -294,7 +293,7 @@ def cherry_reduceop(inp, op, axis, keepdims=False):
osize = nosize
osize = tuple(osize)
print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist)
#print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist)
inslot, outslot = 0, 2
cherry_dmar(SLOT(inslot), inp)
@ -303,7 +302,7 @@ def cherry_reduceop(inp, op, axis, keepdims=False):
# special case if redlist ends with True
if len(redlist) > 0 and redlist[-1] == True:
print("special case redlist[-1] == True")
#print("special case redlist[-1] == True")
outside = int(np.prod(dimlist[:-1]))
for l in range(0, outside, SZ):
reduce_size = min(SZ, outside-l)
@ -327,7 +326,7 @@ def cherry_reduceop(inp, op, axis, keepdims=False):
# do the reduce merges one at a time
while len(dimlist) >= 2:
print("proc", dimlist, redlist)
#print("proc", dimlist, redlist)
for l in range(0, int(np.prod(dimlist[:-2]))):
for k in range(0, dimlist[-1], SZ):
@ -371,7 +370,7 @@ def cherry_binop(x, y, op):
if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
raise Exception(f"binary op unbroadcastable shape mismatch: {x.shape} vs {y.shape}")
shape_ret = np.maximum(shape_x, shape_y)
print(shape_x, shape_y, shape_ret)
#print(shape_x, shape_y, shape_ret)
dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
def push(dim, comp):
@ -382,7 +381,7 @@ def cherry_binop(x, y, op):
for i in range(n_dims): # group together any adjacent dimensions that we can to simplify broadcasting
push(max(shape_x[i], shape_y[i]), (shape_x[i] > 1, shape_y[i] > 1))
print(dimlist, complist)
#print(dimlist, complist)
cherry_dmar(SLOT(0), x)
cherry_dmar(SLOT(1), y)
@ -518,19 +517,12 @@ class TestCherry(unittest.TestCase):
def test_mulacc_matmul(self):
regfile[Reg.MATMUL_INPUT] = np.arange(1, SZ*SZ+1).reshape((SZ, SZ))
regfile[Reg.MATMUL_WEIGHTS] = np.arange(1, SZ*SZ+1).reshape((SZ, SZ))*-1
print(regfile[Reg.MATMUL_INPUT])
print(regfile[Reg.MATMUL_WEIGHTS])
riski_zero(Reg.MATMUL_ACC)
riski_matmul()
tst1 = np.copy(regfile[Reg.MATMUL_ACC])
for l in range(0, SZ):
riski_mul(l)
print("TST",l,regfile[Reg.MATMUL_OUTPUT])
riski_reduce_sum(tgt=l)
riski_zero(Reg.MATMUL_ACC)
riski_matmul(True)
tst2 = np.copy(regfile[Reg.MATMUL_ACC])
print(tst1)
print(tst2)
np.testing.assert_allclose(tst1, tst2)

View File

@ -176,7 +176,7 @@ class Conv2D(Function):
# bs x groups x yx -- groups x 1 --> bs x groups x yx
# it looks like a broadcasted multiply
print("opt1")
#print("opt1")
# x: bs x groups x iy x ix
# w: groups x H x W
@ -203,7 +203,7 @@ class Conv2D(Function):
1, oy*ox, min(SZ, ox-X), min(SZ, groups-g))
elif H == 1 and W == 1 and xs == 1 and ys == 1:
print("opt2")
#print("opt2")
# oxy x cin x rcout -- unstrided 1x1
# this is a simple matmul
for g in range(0, groups):
@ -225,7 +225,7 @@ class Conv2D(Function):
SLOT(2) + B*groups*rcout*yx + g*rcout*yx + c*yx + YX,
1, yx, min(SZ, yx-YX), min(SZ, rcout-c))
else:
print("unoptimized")
#print("unoptimized")
# ox x cin x rcout -- unoptimized
for g in range(0, groups):
for c in range(0, rcout, SZ):

View File

@ -51,10 +51,10 @@ Scalar Unit:
128x Vector Unit:
32Ki x 32b Vector (Lane) Memory (8 ports)
8x:
8x: Sublane
32 x 32b Vector (Lane) Reg File
2x ALUs
Connections to matrix unix
2x ALUs on a float each, 2048 total
Connections to matrix unit
2048 Vector ALUs
2x Matrix Multiply Unit