diff --git a/test/graph_batchnorm.py b/test/graph_batchnorm.py
index 2bb894a6..772c65f1 100644
--- a/test/graph_batchnorm.py
+++ b/test/graph_batchnorm.py
@@ -4,28 +4,51 @@ from tinygrad import optim
 from extra.utils import get_parameters # TODO: move to optim
 import unittest
 
-class TestBatchnorm(unittest.TestCase):
-  def test_conv_bn(self):
-    Tensor.training = True
+def model_step(lm):
+  Tensor.training = True
+  x = Tensor.ones(8,12,128,256, requires_grad=False)
+  loss = lm.forward(x).sum()
+  optimizer = optim.SGD(get_parameters(lm), lr=0.001)
+  optimizer.zero_grad()
+  loss.backward()
+  optimizer.step()
+  #out = loss.detach().numpy()
+  for p in optimizer.params:
+    p.realize()
+  #x.grad.realize()
+  Tensor.training = False
 
-    x = Tensor.ones(1,12,128,256, requires_grad=False)
+
+class TestBatchnorm(unittest.TestCase):
+  def test_conv(self):
+    class LilModel:
+      def __init__(self):
+        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
+      def forward(self, x):
+        return self.c(x).relu()
+    lm = LilModel()
+    model_step(lm)
+
+  def test_two_conv(self):
+    class LilModel:
+      def __init__(self):
+        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
+        self.c2 = Conv2d(32, 32, 3, padding=1, bias=False)
+      def forward(self, x):
+        return self.c2(self.c(x)).relu()
+    lm = LilModel()
+    model_step(lm)
+
+  def test_conv_bn(self):
 
     class LilModel:
       def __init__(self):
         self.c = Conv2d(12, 32, 3, padding=1, bias=False)
         self.bn = BatchNorm2D(32, track_running_stats=False)
       def forward(self, x):
         return self.bn(self.c(x)).relu()
-    lm = LilModel()
-    loss = lm.forward(x).sum()
-    optimizer = optim.SGD(get_parameters(lm), lr=0.001)
-    optimizer.zero_grad()
-    loss.backward()
-    optimizer.step()
-    #out = loss.detach().numpy()
-    for p in optimizer.params:
-      p.realize()
-    Tensor.training = False
+    model_step(lm)
+
 
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index c1d391c4..bada0f34 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -53,8 +53,14 @@ def log_op(optype, op, ret, inp):
     if nm(ret) not in G.nodes: G.add_node(nm(ret))
     st = getattr(ret, "st", None)
     non_contiguous = st is not None and not st.contiguous
-    G.nodes[nm(ret)]['label'] = str(ret.shape)
-    G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if non_contiguous else '')) if optype in top_colors else "#ffffff"
+    if non_contiguous:
+      G.nodes[nm(ret)]['label'] = str(tuple(x[0] if x[1]!=0 else 0 for x in st.views[-1].shape_strides))
+    else:
+      G.nodes[nm(ret)]['label'] = str(ret.shape)
+    if 'contiguous' in str(op).lower():
+      G.nodes[nm(ret)]['fillcolor'] = '#FFFF80'
+    else:
+      G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if non_contiguous else '')) if optype in top_colors else "#ffffff"
     G.nodes[nm(ret)]['style'] = 'filled, dashed' if non_contiguous else 'filled'
 
 class Ops: