diff --git a/README.md b/README.md index 151f053d..c86331ec 100644 --- a/README.md +++ b/README.md @@ -191,7 +191,7 @@ tinygrad will always be below 1000 lines. If it isn't, we will revert commits un ```bash GRAPH=1 python3 test/test_mnist.py TestMNIST.test_sgd_onestep -dot -Tsvg /tmp/net.dot -o /tmp/net.svg && open /tmp/net.svg +# requires dot, outputs /tmp/net.svg ``` ### Running tests diff --git a/test/test_conv.py b/test/test_conv.py index 7c6239d9..35b02141 100644 --- a/test/test_conv.py +++ b/test/test_conv.py @@ -45,5 +45,14 @@ class TestConv(unittest.TestCase): w = Tensor.ones(32,1,3,3) x = x.conv2d(w, padding=(1,1), groups=32) + def test_bias(self): + from tinygrad.nn import Conv2d + x = Tensor.ones(1,12,128,256) + c = Conv2d(12, 32, 3) + x = c(x) + x = x.relu() + w = Tensor.uniform(32, 1, 3, 3) + x = x.conv2d(w, groups=32) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 875b6103..45301b21 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -18,6 +18,7 @@ if GRAPH: def save_graph_exit(): print("saving", G) nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot') + os.system('dot -Tsvg /tmp/net.dot -o /tmp/net.svg') atexit.register(save_graph_exit) global_num_max = 0 diff --git a/tinygrad/shapetracker.py b/tinygrad/shapetracker.py index 9adbf693..7c9f82f4 100644 --- a/tinygrad/shapetracker.py +++ b/tinygrad/shapetracker.py @@ -73,40 +73,48 @@ class ShapeTracker: def reshape(self, *new_shape): assert all([isinstance(x, int) for x in new_shape]) assert prod(self.shape) == prod(new_shape) + if self.shape == new_shape: return self.views.append(View(new_shape, strides_for_shape(new_shape))) def permute(self, *axis): - self.contiguous = False assert all([isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis]) assert len(set(axis)) == len(axis) and len(axis) == len(self.shape) + if tuple(range(len(axis))) == axis: return + self.contiguous = False strides = strides_for_shape(self.shape) self.views.append(View([self.shape[a] for a in axis], [strides[a] for a in axis])) + # TODO: this is a special case of slice with strides, remove it + # though it's nice that it can't change size + def flip(self, *axis): + self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))]) + + # *** under this line are not invertible *** + def slice(self, *arg): - self.contiguous = False assert len(arg) == len(self.shape) + if all([(x,y) == (0,s) for s,(x,y) in zip(self.shape, arg)]): return + self.contiguous = False strides = strides_for_shape(self.shape) offset = sum([strides[i]*x for i,(x,_) in enumerate(arg)]) self.views += [View([y-x for x,y in arg], strides, offset), ZeroView(self.shape, arg)] def expand(self, *new_shape): - self.contiguous = False assert all([isinstance(x, int) for x in new_shape]) assert all([x == y or x == 1 for x,y in zip(self.shape, new_shape)]) + if self.shape == new_shape: return + self.contiguous = False strides = [s if x == y else 0 for s,(x,y) in zip(strides_for_shape(self.shape), zip(self.shape, new_shape))] self.views.append(View(new_shape, strides)) # TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either def stride(self, *mul): - self.contiguous = False assert all([isinstance(x, int) for x in mul]) + if all([x==1 for x in mul]): return + self.contiguous = False old_strides = strides_for_shape(self.shape) strides = [z*m for z,m in zip(old_strides, mul)] new_shape = [(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)] offset = sum([(s-1)*z for s,z,m in zip(self.shape,old_strides,mul) if m < 0]) self.views.append(View(new_shape, strides, offset)) - # TODO: this is a special case of slice with strides, remove it - # though it's nice that it can't change size - def flip(self, *axis): - self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))])