mirror of https://github.com/commaai/tinygrad.git
shapetracker check for noop
This commit is contained in:
parent
52505faaf4
commit
a11deb5150
|
@ -191,7 +191,7 @@ tinygrad will always be below 1000 lines. If it isn't, we will revert commits un
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
GRAPH=1 python3 test/test_mnist.py TestMNIST.test_sgd_onestep
|
GRAPH=1 python3 test/test_mnist.py TestMNIST.test_sgd_onestep
|
||||||
dot -Tsvg /tmp/net.dot -o /tmp/net.svg && open /tmp/net.svg
|
# requires dot, outputs /tmp/net.svg
|
||||||
```
|
```
|
||||||
|
|
||||||
### Running tests
|
### Running tests
|
||||||
|
|
|
@ -45,5 +45,14 @@ class TestConv(unittest.TestCase):
|
||||||
w = Tensor.ones(32,1,3,3)
|
w = Tensor.ones(32,1,3,3)
|
||||||
x = x.conv2d(w, padding=(1,1), groups=32)
|
x = x.conv2d(w, padding=(1,1), groups=32)
|
||||||
|
|
||||||
|
def test_bias(self):
|
||||||
|
from tinygrad.nn import Conv2d
|
||||||
|
x = Tensor.ones(1,12,128,256)
|
||||||
|
c = Conv2d(12, 32, 3)
|
||||||
|
x = c(x)
|
||||||
|
x = x.relu()
|
||||||
|
w = Tensor.uniform(32, 1, 3, 3)
|
||||||
|
x = x.conv2d(w, groups=32)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
|
@ -18,6 +18,7 @@ if GRAPH:
|
||||||
def save_graph_exit():
|
def save_graph_exit():
|
||||||
print("saving", G)
|
print("saving", G)
|
||||||
nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot')
|
nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot')
|
||||||
|
os.system('dot -Tsvg /tmp/net.dot -o /tmp/net.svg')
|
||||||
atexit.register(save_graph_exit)
|
atexit.register(save_graph_exit)
|
||||||
|
|
||||||
global_num_max = 0
|
global_num_max = 0
|
||||||
|
|
|
@ -73,40 +73,48 @@ class ShapeTracker:
|
||||||
def reshape(self, *new_shape):
|
def reshape(self, *new_shape):
|
||||||
assert all([isinstance(x, int) for x in new_shape])
|
assert all([isinstance(x, int) for x in new_shape])
|
||||||
assert prod(self.shape) == prod(new_shape)
|
assert prod(self.shape) == prod(new_shape)
|
||||||
|
if self.shape == new_shape: return
|
||||||
self.views.append(View(new_shape, strides_for_shape(new_shape)))
|
self.views.append(View(new_shape, strides_for_shape(new_shape)))
|
||||||
|
|
||||||
def permute(self, *axis):
|
def permute(self, *axis):
|
||||||
self.contiguous = False
|
|
||||||
assert all([isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis])
|
assert all([isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis])
|
||||||
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape)
|
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape)
|
||||||
|
if tuple(range(len(axis))) == axis: return
|
||||||
|
self.contiguous = False
|
||||||
strides = strides_for_shape(self.shape)
|
strides = strides_for_shape(self.shape)
|
||||||
self.views.append(View([self.shape[a] for a in axis], [strides[a] for a in axis]))
|
self.views.append(View([self.shape[a] for a in axis], [strides[a] for a in axis]))
|
||||||
|
|
||||||
|
# TODO: this is a special case of slice with strides, remove it
|
||||||
|
# though it's nice that it can't change size
|
||||||
|
def flip(self, *axis):
|
||||||
|
self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))])
|
||||||
|
|
||||||
|
# *** under this line are not invertible ***
|
||||||
|
|
||||||
def slice(self, *arg):
|
def slice(self, *arg):
|
||||||
self.contiguous = False
|
|
||||||
assert len(arg) == len(self.shape)
|
assert len(arg) == len(self.shape)
|
||||||
|
if all([(x,y) == (0,s) for s,(x,y) in zip(self.shape, arg)]): return
|
||||||
|
self.contiguous = False
|
||||||
strides = strides_for_shape(self.shape)
|
strides = strides_for_shape(self.shape)
|
||||||
offset = sum([strides[i]*x for i,(x,_) in enumerate(arg)])
|
offset = sum([strides[i]*x for i,(x,_) in enumerate(arg)])
|
||||||
self.views += [View([y-x for x,y in arg], strides, offset), ZeroView(self.shape, arg)]
|
self.views += [View([y-x for x,y in arg], strides, offset), ZeroView(self.shape, arg)]
|
||||||
|
|
||||||
def expand(self, *new_shape):
|
def expand(self, *new_shape):
|
||||||
self.contiguous = False
|
|
||||||
assert all([isinstance(x, int) for x in new_shape])
|
assert all([isinstance(x, int) for x in new_shape])
|
||||||
assert all([x == y or x == 1 for x,y in zip(self.shape, new_shape)])
|
assert all([x == y or x == 1 for x,y in zip(self.shape, new_shape)])
|
||||||
|
if self.shape == new_shape: return
|
||||||
|
self.contiguous = False
|
||||||
strides = [s if x == y else 0 for s,(x,y) in zip(strides_for_shape(self.shape), zip(self.shape, new_shape))]
|
strides = [s if x == y else 0 for s,(x,y) in zip(strides_for_shape(self.shape), zip(self.shape, new_shape))]
|
||||||
self.views.append(View(new_shape, strides))
|
self.views.append(View(new_shape, strides))
|
||||||
|
|
||||||
# TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either
|
# TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either
|
||||||
def stride(self, *mul):
|
def stride(self, *mul):
|
||||||
self.contiguous = False
|
|
||||||
assert all([isinstance(x, int) for x in mul])
|
assert all([isinstance(x, int) for x in mul])
|
||||||
|
if all([x==1 for x in mul]): return
|
||||||
|
self.contiguous = False
|
||||||
old_strides = strides_for_shape(self.shape)
|
old_strides = strides_for_shape(self.shape)
|
||||||
strides = [z*m for z,m in zip(old_strides, mul)]
|
strides = [z*m for z,m in zip(old_strides, mul)]
|
||||||
new_shape = [(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)]
|
new_shape = [(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)]
|
||||||
offset = sum([(s-1)*z for s,z,m in zip(self.shape,old_strides,mul) if m < 0])
|
offset = sum([(s-1)*z for s,z,m in zip(self.shape,old_strides,mul) if m < 0])
|
||||||
self.views.append(View(new_shape, strides, offset))
|
self.views.append(View(new_shape, strides, offset))
|
||||||
|
|
||||||
# TODO: this is a special case of slice with strides, remove it
|
|
||||||
# though it's nice that it can't change size
|
|
||||||
def flip(self, *axis):
|
|
||||||
self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))])
|
|
||||||
|
|
Loading…
Reference in New Issue