From ce15bf2bdbb164c2d37024c4b2442894718a0465 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 16 Jun 2022 11:41:29 -0700 Subject: [PATCH] the big memory gradient didn't even need to be computed --- test/test_train.py | 1 - tinygrad/mlops.py | 53 ++++++++++++++++++++++++---------------------- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/test/test_train.py b/test/test_train.py index d3b7f818..6ca34641 100644 --- a/test/test_train.py +++ b/test/test_train.py @@ -32,7 +32,6 @@ class TestTrain(unittest.TestCase): Y = np.zeros((BS), dtype=np.int32) train_one_step(model,X,Y) - @unittest.skip("OOM in GPU test") def test_vit(self): model = ViT() X = np.zeros((BS,3,224,224), dtype=np.float32) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index f5d5925b..ba5843e0 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -174,30 +174,33 @@ class Conv2D(Function): def backward(ctx, grad_output): x, w, C = ctx.saved_tensors - #dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None - xt = grad_output - if C.xs > 1 or C.ys > 1: # unstride - xt = ctx.movement_op(MovementOps.RESHAPE, xt, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1)) - xt = ctx.movement_op(MovementOps.SLICE, xt, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.ys), (0,xt.shape[4]), (0,C.xs))) - xt = ctx.movement_op(MovementOps.RESHAPE, xt, (xt.shape[0], xt.shape[1], xt.shape[2]*C.ys, xt.shape[4]*C.xs)) - wt = ctx.movement_op(MovementOps.RESHAPE, w, (C.groups, C.rcout, C.cin, C.H, C.W)) - wt = ctx.movement_op(MovementOps.FLIP, wt, (3, 4)) - wt = ctx.movement_op(MovementOps.PERMUTE, wt, (0, 2, 1, 3, 4)) - wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W)) - Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=C.groups) - # TODO: this shape can be wrong. support asymmetric padding to remove the slice - dx = ctx.processing_op(ProcessingOps.CONV, xt, wt, (Cdx.bs, Cdx.cout, Cdx.oy, Cdx.ox), Cdx) - dx = ctx.movement_op(MovementOps.SLICE, dx, [(0,s) for s in x.shape]) + dx, dw = None, None + if ctx.needs_input_grad[0]: + #dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None + xt = grad_output + if C.xs > 1 or C.ys > 1: # unstride. note, this is really memory intensive for big strides. + xt = ctx.movement_op(MovementOps.RESHAPE, xt, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1)) + xt = ctx.movement_op(MovementOps.SLICE, xt, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.ys), (0,xt.shape[4]), (0,C.xs))) + xt = ctx.movement_op(MovementOps.RESHAPE, xt, (xt.shape[0], xt.shape[1], xt.shape[2]*C.ys, xt.shape[4]*C.xs)) + wt = ctx.movement_op(MovementOps.RESHAPE, w, (C.groups, C.rcout, C.cin, C.H, C.W)) + wt = ctx.movement_op(MovementOps.FLIP, wt, (3, 4)) + wt = ctx.movement_op(MovementOps.PERMUTE, wt, (0, 2, 1, 3, 4)) + wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W)) + Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=C.groups) + # TODO: this shape can be wrong. support asymmetric padding to remove the slice + dx = ctx.processing_op(ProcessingOps.CONV, xt, wt, (Cdx.bs, Cdx.cout, Cdx.oy, Cdx.ox), Cdx) + dx = ctx.movement_op(MovementOps.SLICE, dx, [(0,s) for s in x.shape]) - # compute derivative of weights using ProcessingOps.CONV - xdw = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix)) - xdw = ctx.movement_op(MovementOps.PERMUTE, xdw, (2,1,0,3,4)) - xdw = ctx.movement_op(MovementOps.RESHAPE, xdw, (C.cin, C.groups*C.bs, C.iy, C.ix)) - grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3)) - grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.cout, C.bs, C.oy, C.ox)) - Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups) - grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (C.cin, C.cout, Cdw.oy, Cdw.ox), Cdw) - grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3)) - # TODO: remove this slice using asymmetric padding - dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3]))) + if ctx.needs_input_grad[1]: + # compute derivative of weights using ProcessingOps.CONV + xdw = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix)) + xdw = ctx.movement_op(MovementOps.PERMUTE, xdw, (2,1,0,3,4)) + xdw = ctx.movement_op(MovementOps.RESHAPE, xdw, (C.cin, C.groups*C.bs, C.iy, C.ix)) + grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3)) + grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.cout, C.bs, C.oy, C.ox)) + Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups) + grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (C.cin, C.cout, Cdw.oy, Cdw.ox), Cdw) + grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3)) + # TODO: remove this slice using asymmetric padding + dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3]))) return dx, dw