diff --git a/test/test_mnist.py b/test/test_mnist.py index e706a374..39d7eecd 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -25,9 +25,9 @@ class TinyBobNet: class TinyConvNet: def __init__(self): conv = 7 - self.chans = 4 - self.c1 = Tensor(layer_init_uniform(self.chans,1,conv,conv)) - self.l1 = Tensor(layer_init_uniform(((28-conv+1)**2)*self.chans, 128)) + chans = 4 + self.c1 = Tensor(layer_init_uniform(chans,1,conv,conv)) + self.l1 = Tensor(layer_init_uniform(((28-conv+1)**2)*chans, 128)) self.l2 = Tensor(layer_init_uniform(128, 10)) def forward(self, x):