shapetracker check for noop

This commit is contained in:
George Hotz 2022-06-16 16:29:18 -07:00
parent 52505faaf4
commit a11deb5150
4 changed files with 27 additions and 9 deletions

View File

@ -191,7 +191,7 @@ tinygrad will always be below 1000 lines. If it isn't, we will revert commits un
```bash
GRAPH=1 python3 test/test_mnist.py TestMNIST.test_sgd_onestep
dot -Tsvg /tmp/net.dot -o /tmp/net.svg && open /tmp/net.svg
# requires dot, outputs /tmp/net.svg
```
### Running tests

View File

@ -45,5 +45,14 @@ class TestConv(unittest.TestCase):
w = Tensor.ones(32,1,3,3)
x = x.conv2d(w, padding=(1,1), groups=32)
def test_bias(self):
from tinygrad.nn import Conv2d
x = Tensor.ones(1,12,128,256)
c = Conv2d(12, 32, 3)
x = c(x)
x = x.relu()
w = Tensor.uniform(32, 1, 3, 3)
x = x.conv2d(w, groups=32)
if __name__ == '__main__':
unittest.main()

View File

@ -18,6 +18,7 @@ if GRAPH:
def save_graph_exit():
print("saving", G)
nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot')
os.system('dot -Tsvg /tmp/net.dot -o /tmp/net.svg')
atexit.register(save_graph_exit)
global_num_max = 0

View File

@ -73,40 +73,48 @@ class ShapeTracker:
def reshape(self, *new_shape):
assert all([isinstance(x, int) for x in new_shape])
assert prod(self.shape) == prod(new_shape)
if self.shape == new_shape: return
self.views.append(View(new_shape, strides_for_shape(new_shape)))
def permute(self, *axis):
self.contiguous = False
assert all([isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis])
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape)
if tuple(range(len(axis))) == axis: return
self.contiguous = False
strides = strides_for_shape(self.shape)
self.views.append(View([self.shape[a] for a in axis], [strides[a] for a in axis]))
# TODO: this is a special case of slice with strides, remove it
# though it's nice that it can't change size
def flip(self, *axis):
self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))])
# *** under this line are not invertible ***
def slice(self, *arg):
self.contiguous = False
assert len(arg) == len(self.shape)
if all([(x,y) == (0,s) for s,(x,y) in zip(self.shape, arg)]): return
self.contiguous = False
strides = strides_for_shape(self.shape)
offset = sum([strides[i]*x for i,(x,_) in enumerate(arg)])
self.views += [View([y-x for x,y in arg], strides, offset), ZeroView(self.shape, arg)]
def expand(self, *new_shape):
self.contiguous = False
assert all([isinstance(x, int) for x in new_shape])
assert all([x == y or x == 1 for x,y in zip(self.shape, new_shape)])
if self.shape == new_shape: return
self.contiguous = False
strides = [s if x == y else 0 for s,(x,y) in zip(strides_for_shape(self.shape), zip(self.shape, new_shape))]
self.views.append(View(new_shape, strides))
# TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either
def stride(self, *mul):
self.contiguous = False
assert all([isinstance(x, int) for x in mul])
if all([x==1 for x in mul]): return
self.contiguous = False
old_strides = strides_for_shape(self.shape)
strides = [z*m for z,m in zip(old_strides, mul)]
new_shape = [(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)]
offset = sum([(s-1)*z for s,z,m in zip(self.shape,old_strides,mul) if m < 0])
self.views.append(View(new_shape, strides, offset))
# TODO: this is a special case of slice with strides, remove it
# though it's nice that it can't change size
def flip(self, *axis):
self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))])