From dc8fa7999c641c5ddf515a3e4e646a859768514c Mon Sep 17 00:00:00 2001
From: Marcel Bischoff <65973015+marcelbischoff@users.noreply.github.com>
Date: Tue, 29 Dec 2020 10:40:11 -0500
Subject: [PATCH] Transpose on GPU (#221)

* 2serious
* load/save
* fixing GPU
* added DEBUG
* needs BatchNorm or doesn't learn anything
* old file not needed
* added conv biases
* added extra/training.py and checkpoint
* assert in test only
* save
* padding
* num_classes
* checkpoint
* checkpoints for padding
* training was broken
* merge
* rotation augmentation
* more aug
* needs testing
* streamline augment, augment is fast thus bicubic
* tidying up
* transformer eval
* axis=-1
* transpose
* test for permutation using torch.movedims
* another test
* line
---
 extra/training.py   |  2 +-
 test/test_ops.py    |  3 ++-
 tinygrad/ops_gpu.py | 35 +++++++++++++++++++++++++++++++++++
 3 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/extra/training.py b/extra/training.py
index a3e153cc..d3462fa4 100644
--- a/extra/training.py
+++ b/extra/training.py
@@ -48,7 +48,7 @@ def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128)
     Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
     for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
       Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS], device=device)).cpu().data
-    Y_test_preds = np.argmax(Y_test_preds_out, axis=len(Y_test.shape))
+    Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
     return (Y_test == Y_test_preds).mean()
 
   if num_classes is None: num_classes = Y_test.max().astype(int)+1
diff --git a/test/test_ops.py b/test/test_ops.py
index 97ae5669..457f4873 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -127,9 +127,10 @@ class TestOps(unittest.TestCase):
   def test_pad2d(self):
     helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
 
-  @cpu_only
   def test_transpose(self):
     helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device)
+    helper_test_op([(21,22,23,24)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)), device=self.device)
+    helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)), device=self.device)
 
   def test_reshape(self):
     helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), device=self.device)
diff --git a/tinygrad/ops_gpu.py b/tinygrad/ops_gpu.py
index 954d69d3..d366dbe2 100644
--- a/tinygrad/ops_gpu.py
+++ b/tinygrad/ops_gpu.py
@@ -161,12 +161,47 @@ def reduce_op(ctx, code, code2, inp, axis=None):
     buffer_np(np.array(osize, dtype=np.int32)))
   return ret
 
+def perm_axis(ctx, inp, order):
+  osize = np.array(inp.shape)[list(order)]
+  ret = buffer_new(ctx, osize)
+  perm = clbuild(ctx.cl_ctx, "perm", """
+  __kernel void perm(__global const float *a_g, __global float *res_g, int n_axis,
+                     __global const int *shape, __global const int *order) {
+    int gid = get_global_id(0);
+    int gi = gid;
+    int idx = 0;
+    for(int i = n_axis-1; i>-1; i--) {
+      int stride = 1;
+      for(int j=order[i]+1; j<n_axis; j++) stride *= shape[j];
+      idx += (gi % shape[order[i]])*stride;
+      gi /= shape[order[i]];
+    }
+    res_g[gid] = a_g[idx];
+  }
+  """)
+  perm(ctx.cl_queue, [np.prod(osize)], None, inp, ret, np.int32(len(osize)),
+    buffer_np(np.array(inp.shape, dtype=np.int32)),
+    buffer_np(np.array(order, dtype=np.int32)))
+  return ret
+
 def unbroadcast(ctx, out, in_sh):
   sum_axis = [i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1] if in_sh != (1,) else None
   return reduce_op(ctx, "out += a", "out", out, sum_axis)
 
 # ***** now for the ops themselves *****
 
+class Transpose(Function):
+  @staticmethod
+  def forward(ctx, x, order=(1,0)):
+    ctx.save_for_backward(order)
+    return perm_axis(ctx, x, order)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    return perm_axis(ctx,
+      grad_output, np.argsort(ctx.order))
+register('transpose', Transpose, device=Device.GPU)
+
 class Add(Function):
   @staticmethod
   def forward(ctx, x, y):
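
Reviewer note: the `perm` kernel assigns one work item per output element. For
output index `gid` it peels off output coordinates from the innermost axis
outwards (`gi % shape[order[i]]`), multiplies each by the row-major input
stride of axis `order[i]` (the product of all input dimensions after it), and
accumulates the results into `idx`, the flat offset into the input buffer. The
backward pass reuses the same kernel with `np.argsort(ctx.order)`, since
argsort of a permutation yields its inverse, routing each gradient axis back
to its original position. Below is a minimal NumPy sketch of the same index
arithmetic, not part of the patch; `perm_axis_ref` is a hypothetical name used
only for illustration:

    import numpy as np

    def perm_axis_ref(a, order):
        # Mirror of the OpenCL `perm` kernel, one output element at a time.
        shape, n_axis = a.shape, a.ndim
        osize = [shape[o] for o in order]       # output shape = input shape permuted
        flat_in = a.reshape(-1)                 # row-major view of the input
        out = np.empty(int(np.prod(osize)), dtype=a.dtype)
        for gid in range(out.size):
            gi, idx = gid, 0
            for i in range(n_axis - 1, -1, -1):     # innermost output axis first
                stride = 1
                for j in range(order[i] + 1, n_axis):
                    stride *= shape[j]              # input stride of axis order[i]
                idx += (gi % shape[order[i]]) * stride
                gi //= shape[order[i]]
            out[gid] = flat_in[idx]
        return out.reshape(osize)

    x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
    order = (2, 0, 1)
    assert np.array_equal(perm_axis_ref(x, order), x.transpose(order))
    # np.argsort(order) is the inverse permutation, so applying it undoes the transpose:
    inv = tuple(np.argsort(order))              # (1, 2, 0)
    assert np.array_equal(perm_axis_ref(perm_axis_ref(x, order), inv), x)

Because the kernel recomputes strides per element, it trades speed for
simplicity, matching the style of the existing `reduce` kernel; the new
`test_transpose` cases cross-check it against `torch.movedim` on 3-D and
4-D shapes.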