mirror of https://github.com/commaai/tinygrad.git
make contiguous ops yellow
commit f9a8412b68
parent 207b9e1df3
@@ -4,19 +4,9 @@ from tinygrad import optim
 from extra.utils import get_parameters # TODO: move to optim
 import unittest
 
-class TestBatchnorm(unittest.TestCase):
-  def test_conv_bn(self):
+def model_step(lm):
   Tensor.training = True
-
-  x = Tensor.ones(1,12,128,256, requires_grad=False)
-  class LilModel:
-    def __init__(self):
-      self.c = Conv2d(12, 32, 3, padding=1, bias=False)
-      self.bn = BatchNorm2D(32, track_running_stats=False)
-    def forward(self, x):
-      return self.bn(self.c(x)).relu()
-
-  lm = LilModel()
+  x = Tensor.ones(8,12,128,256, requires_grad=False)
   loss = lm.forward(x).sum()
   optimizer = optim.SGD(get_parameters(lm), lr=0.001)
   optimizer.zero_grad()
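model_step pulls the trainable tensors out of the model with get_parameters (still in extra.utils, per the TODO above). As a rough sketch of the idea only, not the actual extra.utils implementation: walk an object's attributes and collect every Tensor, recursing into sub-modules.

from tinygrad.tensor import Tensor

# Hypothetical sketch of what get_parameters amounts to (the real helper
# lives in extra.utils): gather Tensor attributes off a model object,
# recursing into sub-modules like Conv2d and BatchNorm2D.
def get_parameters_sketch(obj):
  params = []
  for value in vars(obj).values():
    if isinstance(value, Tensor):
      params.append(value)
    elif hasattr(value, '__dict__'):  # sub-module: recurse into it
      params.extend(get_parameters_sketch(value))
  return params

For a LilModel holding a Conv2d and a BatchNorm2D, this would surface the conv weight and the batchnorm parameters, which is what optim.SGD needs above.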
@@ -25,7 +15,40 @@ class TestBatchnorm(unittest.TestCase):
   #out = loss.detach().numpy()
   for p in optimizer.params:
     p.realize()
   #x.grad.realize()
   Tensor.training = False
+
+
+class TestBatchnorm(unittest.TestCase):
+  def test_conv(self):
+    class LilModel:
+      def __init__(self):
+        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
+      def forward(self, x):
+        return self.c(x).relu()
+    lm = LilModel()
+    model_step(lm)
+
+  def test_two_conv(self):
+    class LilModel:
+      def __init__(self):
+        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
+        self.c2 = Conv2d(32, 32, 3, padding=1, bias=False)
+      def forward(self, x):
+        return self.c2(self.c(x)).relu()
+    lm = LilModel()
+    model_step(lm)
+
+  def test_conv_bn(self):
+    class LilModel:
+      def __init__(self):
+        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
+        self.bn = BatchNorm2D(32, track_running_stats=False)
+      def forward(self, x):
+        return self.bn(self.c(x)).relu()
+    lm = LilModel()
+    model_step(lm)
+
+
 if __name__ == '__main__':
   unittest.main()
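All three tests push the same (8,12,128,256) batch through convs built as Conv2d(c_in, 32, 3, padding=1). A quick check with the standard conv output-size formula (plain Python, independent of tinygrad) shows why the spatial dims survive unchanged and only the channel count moves to 32:

# Standard conv output-size formula; with kernel=3, padding=1, stride=1
# every spatial dim is preserved, so (8,12,128,256) -> (8,32,128,256).
def conv_out(size, kernel=3, padding=1, stride=1):
  return (size + 2*padding - kernel)//stride + 1

assert conv_out(128) == 128
assert conv_out(256) == 256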
@@ -53,7 +53,13 @@ def log_op(optype, op, ret, inp):
   if nm(ret) not in G.nodes: G.add_node(nm(ret))
   st = getattr(ret, "st", None)
   non_contiguous = st is not None and not st.contiguous
   if non_contiguous:
     G.nodes[nm(ret)]['label'] = str(tuple(x[0] if x[1]!=0 else 0 for x in st.views[-1].shape_strides))
   else:
     G.nodes[nm(ret)]['label'] = str(ret.shape)
+  if 'contiguous' in str(op).lower():
+    G.nodes[nm(ret)]['fillcolor'] = '#FFFF80'
+  else:
+    G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if non_contiguous else '')) if optype in top_colors else "#ffffff"
   G.nodes[nm(ret)]['style'] = 'filled, dashed' if non_contiguous else 'filled'
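The label expression renders a non-contiguous result by its shape with any zero-stride axis displayed as 0, so broadcast/expanded axes stand out in the graph. A standalone sketch, using a made-up shape_strides list (the (shape, stride) pairs of the view's last View; the sample values are hypothetical, not from this commit):

# Hypothetical shape_strides for a view expanded along its middle axis:
# (shape, stride) pairs, where stride 0 means the axis was broadcast.
shape_strides = [(3, 4), (5, 0), (4, 1)]
label = str(tuple(x[0] if x[1] != 0 else 0 for x in shape_strides))
print(label)  # -> (3, 0, 4): the expanded axis shows up as 0

Combined with the fill rules above: ops whose name contains 'contiguous' now render yellow (#FFFF80), while non-contiguous results keep their optype color with an '80' alpha suffix (translucent) and a dashed border. This is the graph tinygrad draws when run with the GRAPH=1 environment flag.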