mirror of https://github.com/commaai/tinygrad.git
shapetracker check for noop
This commit is contained in:
parent
52505faaf4
commit
a11deb5150
|
@ -191,7 +191,7 @@ tinygrad will always be below 1000 lines. If it isn't, we will revert commits un
|
|||
|
||||
```bash
|
||||
GRAPH=1 python3 test/test_mnist.py TestMNIST.test_sgd_onestep
|
||||
dot -Tsvg /tmp/net.dot -o /tmp/net.svg && open /tmp/net.svg
|
||||
# requires dot, outputs /tmp/net.svg
|
||||
```
|
||||
|
||||
### Running tests
|
||||
|
|
|
@ -45,5 +45,14 @@ class TestConv(unittest.TestCase):
|
|||
w = Tensor.ones(32,1,3,3)
|
||||
x = x.conv2d(w, padding=(1,1), groups=32)
|
||||
|
||||
def test_bias(self):
|
||||
from tinygrad.nn import Conv2d
|
||||
x = Tensor.ones(1,12,128,256)
|
||||
c = Conv2d(12, 32, 3)
|
||||
x = c(x)
|
||||
x = x.relu()
|
||||
w = Tensor.uniform(32, 1, 3, 3)
|
||||
x = x.conv2d(w, groups=32)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -18,6 +18,7 @@ if GRAPH:
|
|||
def save_graph_exit():
|
||||
print("saving", G)
|
||||
nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot')
|
||||
os.system('dot -Tsvg /tmp/net.dot -o /tmp/net.svg')
|
||||
atexit.register(save_graph_exit)
|
||||
|
||||
global_num_max = 0
|
||||
|
|
|
@ -73,40 +73,48 @@ class ShapeTracker:
|
|||
def reshape(self, *new_shape):
|
||||
assert all([isinstance(x, int) for x in new_shape])
|
||||
assert prod(self.shape) == prod(new_shape)
|
||||
if self.shape == new_shape: return
|
||||
self.views.append(View(new_shape, strides_for_shape(new_shape)))
|
||||
|
||||
def permute(self, *axis):
|
||||
self.contiguous = False
|
||||
assert all([isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis])
|
||||
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape)
|
||||
if tuple(range(len(axis))) == axis: return
|
||||
self.contiguous = False
|
||||
strides = strides_for_shape(self.shape)
|
||||
self.views.append(View([self.shape[a] for a in axis], [strides[a] for a in axis]))
|
||||
|
||||
# TODO: this is a special case of slice with strides, remove it
|
||||
# though it's nice that it can't change size
|
||||
def flip(self, *axis):
|
||||
self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))])
|
||||
|
||||
# *** under this line are not invertible ***
|
||||
|
||||
def slice(self, *arg):
|
||||
self.contiguous = False
|
||||
assert len(arg) == len(self.shape)
|
||||
if all([(x,y) == (0,s) for s,(x,y) in zip(self.shape, arg)]): return
|
||||
self.contiguous = False
|
||||
strides = strides_for_shape(self.shape)
|
||||
offset = sum([strides[i]*x for i,(x,_) in enumerate(arg)])
|
||||
self.views += [View([y-x for x,y in arg], strides, offset), ZeroView(self.shape, arg)]
|
||||
|
||||
def expand(self, *new_shape):
|
||||
self.contiguous = False
|
||||
assert all([isinstance(x, int) for x in new_shape])
|
||||
assert all([x == y or x == 1 for x,y in zip(self.shape, new_shape)])
|
||||
if self.shape == new_shape: return
|
||||
self.contiguous = False
|
||||
strides = [s if x == y else 0 for s,(x,y) in zip(strides_for_shape(self.shape), zip(self.shape, new_shape))]
|
||||
self.views.append(View(new_shape, strides))
|
||||
|
||||
# TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either
|
||||
def stride(self, *mul):
|
||||
self.contiguous = False
|
||||
assert all([isinstance(x, int) for x in mul])
|
||||
if all([x==1 for x in mul]): return
|
||||
self.contiguous = False
|
||||
old_strides = strides_for_shape(self.shape)
|
||||
strides = [z*m for z,m in zip(old_strides, mul)]
|
||||
new_shape = [(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)]
|
||||
offset = sum([(s-1)*z for s,z,m in zip(self.shape,old_strides,mul) if m < 0])
|
||||
self.views.append(View(new_shape, strides, offset))
|
||||
|
||||
# TODO: this is a special case of slice with strides, remove it
|
||||
# though it's nice that it can't change size
|
||||
def flip(self, *axis):
|
||||
self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))])
|
||||
|
|
Loading…
Reference in New Issue