mirror of https://github.com/commaai/tinygrad.git
noop removal can replace with reshape
This commit is contained in:
parent d985217fa4
commit d04b274cd2
@@ -15,7 +15,7 @@ def sparse_categorical_crossentropy(out, Y):
   return out.mul(y).mean()
 
 def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categorical_crossentropy,
-        transform=lambda x: x, target_transform=lambda x: x):
+        transform=lambda x: x, target_transform=lambda x: x, noloss=False):
   Tensor.training = True
   losses, accuracies = [], []
   for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
@@ -31,14 +31,15 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
     loss.backward()
     optim.step()
 
-    cat = np.argmax(out.cpu().data, axis=-1)
-    accuracy = (cat == y).mean()
-
-    # printing
-    loss = loss.detach().cpu().data
-    losses.append(loss)
-    accuracies.append(accuracy)
-    t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
+    if not noloss:
+      cat = np.argmax(out.cpu().data, axis=-1)
+      accuracy = (cat == y).mean()
+
+      # printing
+      loss = loss.detach().cpu().data
+      losses.append(loss)
+      accuracies.append(accuracy)
+      t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
 
 def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False, transform=lambda x: x,
              target_transform=lambda y: y):
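Context for the new noloss flag: reading the loss back with .cpu().data forces tinygrad's lazy graph to realize on every training step, so a test that only wants to build the graph can now skip that readback entirely. A minimal sketch of the distinction, assuming this era's tinygrad API (Tensor, .cpu(), .data appear in the diff above; the tensors here are illustrative):

from tinygrad.tensor import Tensor

x = Tensor.ones(4, 4)
y = (x + x).relu()   # stays lazy: no kernel has run yet
# with noloss=False the loop reads loss.detach().cpu().data every step,
# forcing realization just to print the loss; noloss=True skips exactly this
print(y.cpu().data)  # the readback that forces the lazy graph to execute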
@@ -75,7 +75,8 @@ class TestMNIST(unittest.TestCase):
     np.random.seed(1337)
     model = TinyConvNet()
     optimizer = optim.SGD(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, BS=69, steps=1)
+    train(model, X_train, Y_train, optimizer, BS=69, steps=1, noloss=True)
+    for p in model.parameters(): p.realize()
 
   def test_conv(self):
     np.random.seed(1337)
@@ -82,10 +82,11 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
   if nm(ret) not in G.nodes: G.add_node(nm(ret))
 
   if getattr(ret, "st", None) is not None and not ret.st.contiguous:
-    G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
+    #G.nodes[nm(ret)]['label'] = str(ret.shape)+"\n"+str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
+    G.nodes[nm(ret)]['label'] = str(tuple(x[0] if x[1]!=0 else 0 for x in ret.st.views[-1].shape_strides))
     dashed = True
-  else:
-    G.nodes[nm(ret)]['label'] = str(ret.shape)
+  elif optype == ReduceOps: G.nodes[nm(ret)]['label'] = str(inp[0].shape)+"\n"+str(ret.shape)
+  else: G.nodes[nm(ret)]['label'] = str(ret.shape)
   G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else '')) if optype in top_colors else "#ffffff"
   G.nodes[nm(ret)]['style'] = 'filled, dashed' if dashed else 'filled'
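The relabeling above makes non-contiguous graph nodes show the view's shape with stride-0 (broadcast) dimensions zeroed, instead of shape plus strides. A standalone sketch of the label expression, with a hypothetical shape_strides value (in the real code it comes from ret.st.views[-1].shape_strides as (dim_size, stride) pairs):

shape_strides = [(4, 3), (3, 1), (5, 0)]  # stride 0 marks an expanded dim
label = str(tuple(x[0] if x[1] != 0 else 0 for x in shape_strides))
print(label)  # (4, 3, 0): real dims keep their size, broadcast dims become 0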
@@ -250,8 +251,8 @@ class LazyBuffer:
     # NOTE: if ret is in the cache, it can already be realized
     if REMOVE_MOVEMENT_NOPS and ret.realized is None and x.realized is None and ret.st.contiguous:
       root = get_lazybuffers(ret.op)[0]
-      if ret.st.shape == root.shape and root.st.contiguous:
-        return root
+      if root.st.contiguous and root != x:
+        return root.movement_op(MovementOps.RESHAPE, ret.st.shape) if ret.st.shape != root.shape else root
 
     return ret
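This hunk is the commit's core change: the noop-removal pass previously fired only when the movement chain landed back on the root's exact shape, while now any contiguous result over a contiguous root folds to at most one RESHAPE, because a contiguous output reads the root's elements in their original order. The added root != x guard presumably keeps the pass from re-deriving a movement op from its own direct input. A sketch of why the fold is sound, using numpy views as a stand-in for tinygrad's ShapeTracker (the array and the op chain are illustrative):

import numpy as np

root = np.arange(24).reshape(2, 3, 4)
# a movement chain that ends contiguous: two permutes cancel, then a reshape
out = root.transpose(2, 1, 0).transpose(2, 1, 0).reshape(6, 4)
assert out.flags['C_CONTIGUOUS']
assert np.array_equal(out, root.reshape(6, 4))  # the whole chain == one RESHAPE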