Transpose on GPU (#221)

* 2serious

* load/save

* fixing GPU

* added DEBUG

* needs BatchNorm or doesn't learn anything

* old file not needed

* added conv biases

* added extra/training.py and checkpoint

* assert in test only

* save

* padding

* num_classes

* checkpoint

* checkpoints for padding

* training was broken

* merge

* rotation augmentation

* more aug

* needs testing

* streamline augment, augment is fast thus bicubic

* tidying up

* transformer eval

* axis=-1

* transpose

* test for permutation using torch.movedims

* another test

* line
This commit is contained in:
Marcel Bischoff 2020-12-29 10:40:11 -05:00 committed by GitHub
parent 36579f66bf
commit dc8fa7999c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 2 deletions

View File

@ -48,7 +48,7 @@ def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128)
Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS], device=device)).cpu().data
Y_test_preds = np.argmax(Y_test_preds_out, axis=len(Y_test.shape))
Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
return (Y_test == Y_test_preds).mean()
if num_classes is None: num_classes = Y_test.max().astype(int)+1

View File

@ -127,9 +127,10 @@ class TestOps(unittest.TestCase):
def test_pad2d(self):
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
@cpu_only
def test_transpose(self):
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device)
helper_test_op([(21,22,23,24)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)), device=self.device)
helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)), device=self.device)
def test_reshape(self):
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), device=self.device)

View File

@ -161,12 +161,47 @@ def reduce_op(ctx, code, code2, inp, axis=None):
buffer_np(np.array(osize, dtype=np.int32)))
return ret
def perm_axis(ctx, inp, order):
osize = np.array(inp.shape)[list(order)]
ret = buffer_new(ctx, osize)
perm = clbuild(ctx.cl_ctx, "perm", """
__kernel void perm(__global const float *a_g, __global float *res_g, int n_axis,
__global const int *shape, __global const int *order) {
int gid = get_global_id(0);
int gi = gid;
int idx = 0;
for(int i = n_axis-1; i>-1; i--) {
int stride = 1;
for(int j=order[i]+1; j<n_axis; j++) stride *= shape[j];
idx += (gi % shape[order[i]])*stride;
gi /= shape[order[i]];
}
res_g[gid] = a_g[idx];
}""")
buffer_np = lambda x: cl.Buffer(ctx.cl_ctx,
cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)
perm(ctx.cl_queue, [np.prod(osize)], None, inp.cl, ret.cl, i32(len(osize)),
buffer_np(np.array(inp.shape, dtype=np.int32)),
buffer_np(np.array(order, dtype=np.int32)))
return ret
def unbroadcast(ctx, out, in_sh):
sum_axis = [i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1] if in_sh != (1,) else None
return reduce_op(ctx, "out += a", "out", out, sum_axis)
# ***** now for the ops themselves *****
class Transpose(Function):
@staticmethod
def forward(ctx, x, order=(1,0)):
ctx.save_for_backward(order)
return perm_axis(ctx, x, order)
@staticmethod
def backward(ctx, grad_output):
return perm_axis(ctx, grad_output, np.argsort(ctx.order))
register('transpose', Transpose, device=Device.GPU)
class Add(Function):
@staticmethod
def forward(ctx, x, y):