make contiguous ops yellow

This commit is contained in:
George Hotz 2022-07-02 17:54:04 -07:00
parent 207b9e1df3
commit f9a8412b68
2 changed files with 45 additions and 16 deletions

View File

@ -4,28 +4,51 @@ from tinygrad import optim
from extra.utils import get_parameters # TODO: move to optim
import unittest
class TestBatchnorm(unittest.TestCase):
def test_conv_bn(self):
Tensor.training = True
def model_step(lm):
Tensor.training = True
x = Tensor.ones(8,12,128,256, requires_grad=False)
loss = lm.forward(x).sum()
optimizer = optim.SGD(get_parameters(lm), lr=0.001)
optimizer.zero_grad()
loss.backward()
optimizer.step()
#out = loss.detach().numpy()
for p in optimizer.params:
p.realize()
#x.grad.realize()
Tensor.training = False
x = Tensor.ones(1,12,128,256, requires_grad=False)
class TestBatchnorm(unittest.TestCase):
def test_conv(self):
class LilModel:
def __init__(self):
self.c = Conv2d(12, 32, 3, padding=1, bias=False)
def forward(self, x):
return self.c(x).relu()
lm = LilModel()
model_step(lm)
def test_two_conv(self):
class LilModel:
def __init__(self):
self.c = Conv2d(12, 32, 3, padding=1, bias=False)
self.c2 = Conv2d(32, 32, 3, padding=1, bias=False)
def forward(self, x):
return self.c2(self.c(x)).relu()
lm = LilModel()
model_step(lm)
def test_conv_bn(self):
class LilModel:
def __init__(self):
self.c = Conv2d(12, 32, 3, padding=1, bias=False)
self.bn = BatchNorm2D(32, track_running_stats=False)
def forward(self, x):
return self.bn(self.c(x)).relu()
lm = LilModel()
loss = lm.forward(x).sum()
optimizer = optim.SGD(get_parameters(lm), lr=0.001)
optimizer.zero_grad()
loss.backward()
optimizer.step()
#out = loss.detach().numpy()
for p in optimizer.params:
p.realize()
Tensor.training = False
model_step(lm)
if __name__ == '__main__':
unittest.main()

View File

@ -53,8 +53,14 @@ def log_op(optype, op, ret, inp):
if nm(ret) not in G.nodes: G.add_node(nm(ret))
st = getattr(ret, "st", None)
non_contiguous = st is not None and not st.contiguous
G.nodes[nm(ret)]['label'] = str(ret.shape)
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if non_contiguous else '')) if optype in top_colors else "#ffffff"
if non_contiguous:
G.nodes[nm(ret)]['label'] = str(tuple(x[0] if x[1]!=0 else 0 for x in st.views[-1].shape_strides))
else:
G.nodes[nm(ret)]['label'] = str(ret.shape)
if 'contiguous' in str(op).lower():
G.nodes[nm(ret)]['fillcolor'] = '#FFFF80'
else:
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if non_contiguous else '')) if optype in top_colors else "#ffffff"
G.nodes[nm(ret)]['style'] = 'filled, dashed' if non_contiguous else 'filled'
class Ops: