mirror of https://github.com/commaai/tinygrad.git
tinygrad does forward pass convs on GPU
This commit is contained in:
parent 23c39d9f52
commit ec03eb44bd
@@ -14,7 +14,6 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7, gpu
out = torch_fxn(*ts)
ret = tinygrad_fxn(*tst)

# TODO: why so inaccurate?
np.testing.assert_allclose(ret.cpu().data, out.detach().numpy(), atol=atol)

if not forward_only:
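The forward_only flag seen above is what lets the new GPU conv be tested at all: when it is set, the helper compares only the forward outputs against torch and skips the gradient check. A minimal sketch of that pattern (the name check_forward is hypothetical, not the repo's helper_test_op, and the gradient branch is elided):

```python
# Sketch of the forward_only gating used by the test helper (illustrative only).
import numpy as np
import torch
from tinygrad.tensor import Tensor

def check_forward(shapes, torch_fxn, tinygrad_fxn, atol=1e-7, forward_only=False):
  ts  = [torch.rand(s, requires_grad=True) for s in shapes]   # torch inputs
  tst = [Tensor(t.detach().numpy()) for t in ts]              # matching tinygrad inputs
  out = torch_fxn(*ts)
  ret = tinygrad_fxn(*tst)
  # forward results must agree within atol (ret stays on CPU here, so .data is numpy)
  np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=atol)
  if forward_only:
    return  # e.g. the GPU conv2d below has no backward yet, so stop here
  # otherwise a gradient comparison would follow, as in the real helper
```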
@@ -66,11 +65,11 @@ class TestOps(unittest.TestCase):
for bs in [1,8]:
for cin in [1,3]:
for groups in [1,3] if cin == 3 else [1]:
- for H in [2,5]:
- for W in [2,3,5]:
+ for H in [1,2,5]:
+ for W in [1,2,3,5]:
helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
- lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=2e-5, grad_atol=2e-6, gpu=self.gpu)
+ lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=2e-5, grad_atol=2e-6, gpu=self.gpu, forward_only=self.gpu)

def test_strided_conv2d(self):
bs = 4
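As a reference for the shapes that grid exercises: the weight tensor is (cout, cin//groups, H, W), so with cin=3 and groups=3 each of the 6 output channels convolves a single input channel. A standalone shape check for one point of the grid (illustrative only, not part of test_ops.py):

```python
# bs=8, cin=3, groups=3, H=W=2 from the loops above.
import torch
x = torch.rand(8, 3, 11, 28)   # (bs, cin, iy, ix) as in the test
w = torch.rand(6, 1, 2, 2)     # (cout, cin//groups, H, W) with groups=3
out = torch.nn.functional.conv2d(x, w, groups=3)
print(out.shape)               # torch.Size([8, 6, 10, 27]): oy = 11-2+1, ox = 28-2+1
```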
@@ -304,4 +304,64 @@ class LogSoftmax(Function):
return grad_input
register('logsoftmax', LogSoftmax, gpu=True)

# ************* conv ops *************

class Conv2D(Function):
@staticmethod
def forward(ctx, x, w, stride=1, groups=1):
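# note: stride and groups are read back below as ctx.stride / ctx.groups,
# which the op apply machinery sets on the context from the keyword args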
if type(ctx.stride) == int:
ctx.stride = (ctx.stride, ctx.stride)
cout,cin,H,W = w.shape
ys,xs = ctx.stride
bs,cin_,iy,ix = x.shape
oy,ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs
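# e.g. iy=11, H=5: stride 1 gives (11-4)//1 = 7 output rows, stride 2 gives (11-3)//2 = 4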
assert cin*ctx.groups == cin_
assert cout % ctx.groups == 0
rcout = cout//ctx.groups

# output buffer
ret = buffer_new(ctx, (bs, cout, oy, ox))

prg = clbuild(ctx.cl_ctx, """
__kernel void conv(__global const float *input, __global const float *weight, __global float *output,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix) {

int B = get_global_id(0); // range 0-bs
int Y = get_global_id(1); // range 0-oy
int X = get_global_id(2); // range 0-ox

// input = (bs, groups, cin, iy, ix)
// weight = (groups, rcout, cin, H, W)
// output = (bs, groups, rcout, oy, ox)
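// the flat offsets below are row-major flattenings of these layouts,
// e.g. input[((B*groups + g)*cin + ci)*iy*ix + y*ix + x]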
for (int g = 0; g < groups; g++) {
for (int c = 0; c < rcout; c++) {
float acc = 0.0;
for (int ci = 0; ci < cin; ci++) {
for (int y = Y; y < Y+H; y++) {
for (int x = X; x < X+W; x++) {
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-Y)*W + (x-X)];
}
}
}
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
}
}
}
""")
prg.conv(ctx.cl_queue, [bs, oy, ox], None,
x, w, ret,
np.int32(H), np.int32(W),
np.int32(groups), np.int32(rcout), np.int32(cin),
np.int32(oy), np.int32(ox),
np.int32(iy), np.int32(ix)
)
return ret

@staticmethod
def backward(ctx, grad_output):
raise Exception("not implemented")

register('conv2d', Conv2D, gpu=True)
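To make the kernel's arithmetic easier to follow: each work-item produces one output pixel, and for the default stride of 1 the output size reduces to oy = iy - H + 1 and ox = ix - W + 1. Below is a NumPy sketch of the same grouped convolution, using the (bs, groups, cin, iy, ix) / (groups, rcout, cin, H, W) / (bs, groups, rcout, oy, ox) layouts from the kernel comments (illustrative only, not part of the commit):

```python
# NumPy reference for the stride-1 grouped conv the OpenCL kernel computes.
import numpy as np

def conv2d_ref(x, w, groups=1):
  cout, cin, H, W = w.shape          # cin here is channels per group
  bs, cin_, iy, ix = x.shape
  assert cin * groups == cin_ and cout % groups == 0
  rcout = cout // groups
  oy, ox = iy - H + 1, ix - W + 1
  gx = x.reshape(bs, groups, cin, iy, ix)       # input layout from the kernel
  gw = w.reshape(groups, rcout, cin, H, W)      # weight layout from the kernel
  out = np.zeros((bs, groups, rcout, oy, ox), dtype=x.dtype)
  for Y in range(oy):
    for X in range(ox):
      win = gx[:, :, :, Y:Y+H, X:X+W]           # (bs, groups, cin, H, W) window at (Y, X)
      # sum over cin, H, W for every (batch, group, rcout) triple, like the kernel's acc loop
      out[:, :, :, Y, X] = np.einsum('bgchw,grchw->bgr', win, gw)
  return out.reshape(bs, cout, oy, ox)
```

Running the same random x and w through this reference and through torch.nn.functional.conv2d with the matching groups is a quick way to sanity-check the GPU kernel's output layout and indexing.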