mirror of https://github.com/commaai/tinygrad.git
Transpose on GPU (#221)
* 2serious * load/save * fixing GPU * added DEBUG * needs BatchNorm or doesn't learn anything * old file not needed * added conv biases * added extra/training.py and checkpoint * assert in test only * save * padding * num_classes * checkpoint * checkpoints for padding * training was broken * merge * rotation augmentation * more aug * needs testing * streamline augment, augment is fast thus bicubic * tidying up * transformer eval * axis=-1 * transpose * test for permutation using torch.movedims * another test * line
This commit is contained in:
parent
36579f66bf
commit
dc8fa7999c
|
@ -48,7 +48,7 @@ def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128)
|
|||
Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
|
||||
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
|
||||
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS], device=device)).cpu().data
|
||||
Y_test_preds = np.argmax(Y_test_preds_out, axis=len(Y_test.shape))
|
||||
Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
|
||||
return (Y_test == Y_test_preds).mean()
|
||||
|
||||
if num_classes is None: num_classes = Y_test.max().astype(int)+1
|
||||
|
|
|
@ -127,9 +127,10 @@ class TestOps(unittest.TestCase):
|
|||
def test_pad2d(self):
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
|
||||
|
||||
@cpu_only
|
||||
def test_transpose(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device)
|
||||
helper_test_op([(21,22,23,24)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)), device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)), device=self.device)
|
||||
|
||||
def test_reshape(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), device=self.device)
|
||||
|
|
|
@ -161,12 +161,47 @@ def reduce_op(ctx, code, code2, inp, axis=None):
|
|||
buffer_np(np.array(osize, dtype=np.int32)))
|
||||
return ret
|
||||
|
||||
def perm_axis(ctx, inp, order):
|
||||
osize = np.array(inp.shape)[list(order)]
|
||||
ret = buffer_new(ctx, osize)
|
||||
perm = clbuild(ctx.cl_ctx, "perm", """
|
||||
__kernel void perm(__global const float *a_g, __global float *res_g, int n_axis,
|
||||
__global const int *shape, __global const int *order) {
|
||||
int gid = get_global_id(0);
|
||||
int gi = gid;
|
||||
int idx = 0;
|
||||
for(int i = n_axis-1; i>-1; i--) {
|
||||
int stride = 1;
|
||||
for(int j=order[i]+1; j<n_axis; j++) stride *= shape[j];
|
||||
idx += (gi % shape[order[i]])*stride;
|
||||
gi /= shape[order[i]];
|
||||
}
|
||||
res_g[gid] = a_g[idx];
|
||||
}""")
|
||||
buffer_np = lambda x: cl.Buffer(ctx.cl_ctx,
|
||||
cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)
|
||||
perm(ctx.cl_queue, [np.prod(osize)], None, inp.cl, ret.cl, i32(len(osize)),
|
||||
buffer_np(np.array(inp.shape, dtype=np.int32)),
|
||||
buffer_np(np.array(order, dtype=np.int32)))
|
||||
return ret
|
||||
|
||||
def unbroadcast(ctx, out, in_sh):
|
||||
sum_axis = [i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1] if in_sh != (1,) else None
|
||||
return reduce_op(ctx, "out += a", "out", out, sum_axis)
|
||||
|
||||
# ***** now for the ops themselves *****
|
||||
|
||||
class Transpose(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, order=(1,0)):
|
||||
ctx.save_for_backward(order)
|
||||
return perm_axis(ctx, x, order)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return perm_axis(ctx, grad_output, np.argsort(ctx.order))
|
||||
register('transpose', Transpose, device=Device.GPU)
|
||||
|
||||
class Add(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
|
|
Loading…
Reference in New Issue