mirror of https://github.com/commaai/tinygrad.git
clean up print spam
This commit is contained in:
parent 3a91d5434f
commit b48d4bad2e
@@ -179,11 +179,10 @@ binops = {BinaryOps.ADD: riski_add,
 reduceops = {ReduceOps.SUM: riski_reduce_sum,
              ReduceOps.MAX: riski_reduce_max}
 
-#@count
+@count
 # TODO: add masks to matmul instruction?
-def riski_matmul():
-  #print("LLL:\n",regfile[Reg.MATMUL_INPUT],"\n",regfile[Reg.MATMUL_WEIGHTS])
-  if False:
+def riski_matmul(slow=False):
+  if not slow:
     regfile[Reg.MATMUL_ACC] += \
       regfile[Reg.MATMUL_INPUT] @ \
       regfile[Reg.MATMUL_WEIGHTS]
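For reference, the identity the new slow path leans on: row l of a matmul is the elementwise product of broadcast input row l with the weights tile, summed down the rows. A minimal NumPy sketch of that equivalence (toy SZ and values; the riski_mul/riski_reduce_sum sequence is modeled, not called):

import numpy as np

SZ = 4  # toy tile width, not the hardware's

A = np.arange(1, SZ*SZ+1, dtype=np.float32).reshape(SZ, SZ)
B = -A  # arbitrary weights tile

# fast path: one whole-tile multiply-accumulate, as in riski_matmul
acc_fast = A @ B

# slow path, modeled: per-row broadcast multiply, then a sum-reduce,
# mirroring the riski_mul/riski_reduce_sum loop in the test further down
acc_slow = np.zeros((SZ, SZ), dtype=np.float32)
for l in range(SZ):
  acc_slow[l] = (A[l][:, None] * B).sum(axis=0)

np.testing.assert_allclose(acc_fast, acc_slow)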
@@ -242,12 +241,12 @@ def cherry_dmar(address, arr):
   arr = arr.reshape(-1)
   assert(arr.shape[0] <= SLOTSIZE)
   maxdma = max(maxdma, arr.shape[0])
-  print("DMAR %d elements" % arr.shape[0])
+  #print("DMAR %d elements" % arr.shape[0])
   sram[address:address+arr.shape[0]] = arr
 
 @count
 def cherry_dmaw(address, shp):
-  print("DMAW %d elements" % np.prod(shp))
+  #print("DMAW %d elements" % np.prod(shp))
   return np.copy(sram[address:address+np.prod(shp)].reshape(shp))
 
 # *** CHERRY code to be compiled ***
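cherry_dmar/cherry_dmaw model DMA against a flat SRAM: dmar flattens a host array and copies it in at an address, dmaw copies a region back out with a shape. A self-contained sketch of that model (SLOTSIZE and the 4-slot sizing here are assumptions, not the source's constants):

import numpy as np

SLOTSIZE = 1024                                # assumed slot capacity
sram = np.zeros(4*SLOTSIZE, dtype=np.float32)  # flat on-chip memory model

def dmar(address, arr):
  # DMA read: flatten the host array and copy it into sram at address
  arr = arr.reshape(-1)
  assert arr.shape[0] <= SLOTSIZE
  sram[address:address+arr.shape[0]] = arr

def dmaw(address, shp):
  # DMA write: copy a region of sram back out as a host array of shape shp
  return np.copy(sram[address:address+int(np.prod(shp))].reshape(shp))

x = np.arange(6, dtype=np.float32).reshape(2, 3)
dmar(0, x)
np.testing.assert_allclose(dmaw(0, (2, 3)), x)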
@@ -294,7 +293,7 @@ def cherry_reduceop(inp, op, axis, keepdims=False):
       osize = nosize
 
   osize = tuple(osize)
-  print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist)
+  #print("reduce", op, inp.shape, axis, "->", osize, dimlist, redlist)
   inslot, outslot = 0, 2
   cherry_dmar(SLOT(inslot), inp)
 
@@ -303,7 +302,7 @@ def cherry_reduceop(inp, op, axis, keepdims=False):
 
   # special case if redlist ends with True
   if len(redlist) > 0 and redlist[-1] == True:
-    print("special case redlist[-1] == True")
+    #print("special case redlist[-1] == True")
     outside = int(np.prod(dimlist[:-1]))
     for l in range(0, outside, SZ):
       reduce_size = min(SZ, outside-l)
@@ -327,7 +326,7 @@ def cherry_reduceop(inp, op, axis, keepdims=False):
 
   # do the reduce merges one at a time
   while len(dimlist) >= 2:
-    print("proc", dimlist, redlist)
+    #print("proc", dimlist, redlist)
 
     for l in range(0, int(np.prod(dimlist[:-2]))):
       for k in range(0, dimlist[-1], SZ):
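The merge loop reduces one trailing dimension at a time, stepping through it SZ columns per iteration. A toy sketch of that chunked reduction (SZ and the shapes are illustrative only):

import numpy as np

SZ = 4  # toy vector width

def reduce_last_axis_chunked(x):
  # accumulate partial sums over the reduce axis, SZ columns at a time,
  # the way the cherry_reduceop merge loop steps through dimlist[-1]
  acc = np.zeros(x.shape[:-1], dtype=x.dtype)
  for k in range(0, x.shape[-1], SZ):
    acc += x[..., k:k+SZ].sum(axis=-1)
  return acc

x = np.arange(24, dtype=np.float32).reshape(2, 12)
np.testing.assert_allclose(reduce_last_axis_chunked(x), x.sum(axis=-1))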
@@ -371,7 +370,7 @@ def cherry_binop(x, y, op):
   if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
     raise Exception(f"binary op unbroadcastable shape mismatch: {x.shape} vs {y.shape}")
   shape_ret = np.maximum(shape_x, shape_y)
-  print(shape_x, shape_y, shape_ret)
+  #print(shape_x, shape_y, shape_ret)
 
   dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
   def push(dim, comp):
@@ -382,7 +381,7 @@ def cherry_binop(x, y, op):
   for i in range(n_dims): # group together any adjacent dimensions that we can to simplify broadcasting
     push(max(shape_x[i], shape_y[i]), (shape_x[i] > 1, shape_y[i] > 1))
 
-  print(dimlist, complist)
+  #print(dimlist, complist)
 
   cherry_dmar(SLOT(0), x)
   cherry_dmar(SLOT(1), y)
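push() collapses adjacent dimensions whose broadcast pattern (does x broadcast? does y broadcast?) is identical, so a pair of fully matching shapes degenerates to one long elementwise op. A standalone sketch of that grouping (the function name and the rank-padding behavior are assumptions):

def group_broadcast_dims(shape_x, shape_y):
  # pad to a common rank, then merge adjacent dims that share the same
  # (x is non-broadcast, y is non-broadcast) pattern, like push() above
  n = max(len(shape_x), len(shape_y))
  sx = (1,)*(n-len(shape_x)) + tuple(shape_x)
  sy = (1,)*(n-len(shape_y)) + tuple(shape_y)
  dimlist, complist = [], []
  for dx, dy in zip(sx, sy):
    comp = (dx > 1, dy > 1)
    if complist and complist[-1] == comp:
      dimlist[-1] *= max(dx, dy)   # same pattern: fold into previous dim
    else:
      dimlist.append(max(dx, dy))
      complist.append(comp)
  return dimlist, complist

print(group_broadcast_dims((8, 4, 4), (8, 4, 4)))  # ([128], [(True, True)])
print(group_broadcast_dims((8, 4, 4), (1, 4, 4)))  # ([8, 16], [(True, False), (True, True)])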
@@ -518,19 +517,12 @@ class TestCherry(unittest.TestCase):
   def test_mulacc_matmul(self):
     regfile[Reg.MATMUL_INPUT] = np.arange(1, SZ*SZ+1).reshape((SZ, SZ))
     regfile[Reg.MATMUL_WEIGHTS] = np.arange(1, SZ*SZ+1).reshape((SZ, SZ))*-1
-    print(regfile[Reg.MATMUL_INPUT])
-    print(regfile[Reg.MATMUL_WEIGHTS])
     riski_zero(Reg.MATMUL_ACC)
     riski_matmul()
     tst1 = np.copy(regfile[Reg.MATMUL_ACC])
 
     for l in range(0, SZ):
       riski_mul(l)
-      print("TST",l,regfile[Reg.MATMUL_OUTPUT])
       riski_reduce_sum(tgt=l)
     riski_zero(Reg.MATMUL_ACC)
     riski_matmul(True)
     tst2 = np.copy(regfile[Reg.MATMUL_ACC])
-    print(tst1)
-    print(tst2)
     np.testing.assert_allclose(tst1, tst2)
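In effect, test_mulacc_matmul checks the identity sketched after the first hunk: the one-shot riski_matmul() accumulate and the per-row riski_mul/riski_reduce_sum sequence behind riski_matmul(True) must leave the same values in MATMUL_ACC.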
@@ -176,7 +176,7 @@ class Conv2D(Function):
     # bs x groups x yx -- groups x 1 --> bs x groups x yx
     # it looks like a broadcasted multiply
 
-    print("opt1")
+    #print("opt1")
 
     # x: bs x groups x iy x ix
     # w: groups x H x W
@@ -203,7 +203,7 @@ class Conv2D(Function):
           1, oy*ox, min(SZ, ox-X), min(SZ, groups-g))
 
     elif H == 1 and W == 1 and xs == 1 and ys == 1:
-      print("opt2")
+      #print("opt2")
       # oxy x cin x rcout -- unstrided 1x1
       # this is a simple matmul
       for g in range(0, groups):
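The opt2 branch relies on an unstrided 1x1 convolution being a plain matmul over the channel dimension. A toy NumPy check of that equivalence (shapes are made up, groups collapsed to 1 for brevity; einsum stands in for the riski_* sequence):

import numpy as np

bs, cin, rcout, y, x = 2, 3, 5, 4, 4   # arbitrary toy shapes
inp = np.random.randn(bs, cin, y, x).astype(np.float32)
w = np.random.randn(rcout, cin, 1, 1).astype(np.float32)

# 1x1 unstrided conv collapses to a matmul over the channel dimension
out_mm = np.einsum('bcyx,oc->boyx', inp, w[:, :, 0, 0])

# reference: the direct convolution sum with H == W == 1, stride 1
out_ref = np.zeros((bs, rcout, y, x), dtype=np.float32)
for o in range(rcout):
  for c in range(cin):
    out_ref[:, o] += inp[:, c] * w[o, c, 0, 0]

np.testing.assert_allclose(out_mm, out_ref, rtol=1e-5)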
@@ -225,7 +225,7 @@ class Conv2D(Function):
             SLOT(2) + B*groups*rcout*yx + g*rcout*yx + c*yx + YX,
             1, yx, min(SZ, yx-YX), min(SZ, rcout-c))
     else:
-      print("unoptimized")
+      #print("unoptimized")
       # ox x cin x rcout -- unoptimized
       for g in range(0, groups):
         for c in range(0, rcout, SZ):
@@ -51,10 +51,10 @@ Scalar Unit:
 
 128x Vector Unit:
   32Ki x 32b Vector (Lane) Memory (8 ports)
-  8x:
+  8x: Sublane
     32 x 32b Vector (Lane) Reg File
-    2x ALUs
-  Connections to matrix unix
+    2x ALUs on a float each, 2048 total
+  Connections to matrix unit
   2048 Vector ALUs
 
 2x Matrix Multiply Unit
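The corrected line makes the arithmetic check out: 128 vector units x 8 sublanes x 2 ALUs each = 2048 vector ALUs.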