mirror of https://github.com/commaai/tinygrad.git
fix conv args. fix spacing
This commit is contained in:
parent
365e62a609
commit
b49bfb6e02
|
@ -192,36 +192,36 @@ def inner_slice(ctx, x, arg):
|
|||
|
||||
# c = a@b
|
||||
def matmul(a, b, c, transpose_a=False, transpose_b=False):
|
||||
cnt = np.prod(a.shape[0:-2]) if len(a.shape) > 2 else 1
|
||||
isize, msize, osize = i32(a.shape[-2]), i32(a.shape[-1]), i32(c.shape[-1])
|
||||
if transpose_a: isize,msize = msize,isize
|
||||
assert isize == c.shape[-2]
|
||||
assert (msize == b.shape[-1]) if transpose_b else (msize == b.shape[-2])
|
||||
assert (osize == b.shape[-2]) if transpose_b else (osize == b.shape[-1])
|
||||
|
||||
matmul_prg = clbuild("matmul", """
|
||||
__kernel void matmul(
|
||||
__global const float *input, __global const float *weight, __global float *res,
|
||||
int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
|
||||
) {
|
||||
int stride = get_global_id(2);
|
||||
cnt = np.prod(a.shape[0:-2]) if len(a.shape) > 2 else 1
|
||||
isize, msize, osize = i32(a.shape[-2]), i32(a.shape[-1]), i32(c.shape[-1])
|
||||
if transpose_a: isize,msize = msize,isize
|
||||
assert isize == c.shape[-2]
|
||||
assert (msize == b.shape[-1]) if transpose_b else (msize == b.shape[-2])
|
||||
assert (osize == b.shape[-2]) if transpose_b else (osize == b.shape[-1])
|
||||
|
||||
matmul_prg = clbuild("matmul", """
|
||||
__kernel void matmul(
|
||||
__global const float *input, __global const float *weight, __global float *res,
|
||||
int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
|
||||
) {
|
||||
int stride = get_global_id(2);
|
||||
|
||||
int X = get_global_id(0); // isize
|
||||
int Y = get_global_id(1); // osize
|
||||
int X = get_global_id(0); // isize
|
||||
int Y = get_global_id(1); // osize
|
||||
|
||||
float ret = 0.0;
|
||||
for (int x = 0; x < msize; x++) {
|
||||
ret += input[X * is0 + x * is1 + isize*msize*stride] *
|
||||
weight[Y * ws0 + x * ws1 + msize*osize*stride];
|
||||
}
|
||||
float ret = 0.0;
|
||||
for (int x = 0; x < msize; x++) {
|
||||
ret += input[X * is0 + x * is1 + isize*msize*stride] *
|
||||
weight[Y * ws0 + x * ws1 + msize*osize*stride];
|
||||
}
|
||||
|
||||
res[X * osize + Y + isize*osize*stride] = ret;
|
||||
}""")
|
||||
res[X * osize + Y + isize*osize*stride] = ret;
|
||||
}""")
|
||||
|
||||
matmul_prg([isize, osize, cnt], None,
|
||||
a.cl, b.cl, c.cl,
|
||||
isize,
|
||||
msize if not transpose_a else i32(1), i32(1) if not transpose_a else isize,
|
||||
msize,
|
||||
i32(1) if not transpose_b else msize, osize if not transpose_b else i32(1),
|
||||
osize)
|
||||
matmul_prg([isize, osize, cnt], None,
|
||||
a.cl, b.cl, c.cl,
|
||||
isize,
|
||||
msize if not transpose_a else i32(1), i32(1) if not transpose_a else isize,
|
||||
msize,
|
||||
i32(1) if not transpose_b else msize, osize if not transpose_b else i32(1),
|
||||
osize)
|
||||
|
|
|
@ -199,11 +199,8 @@ class Conv2D(Function):
|
|||
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
|
||||
}""")
|
||||
|
||||
conv([bs*groups*rcout, oy, ox], None,
|
||||
x.cl, w.cl, ret.cl,
|
||||
i32(H), i32(W), i32(groups), i32(rcout), i32(cin),
|
||||
i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs)
|
||||
)
|
||||
conv_args = H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs
|
||||
conv([bs*groups*rcout, oy, ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in conv_args])
|
||||
return ret
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
|
@ -270,7 +267,7 @@ class Conv2D(Function):
|
|||
}
|
||||
""")
|
||||
|
||||
conv_args = i32(H), i32(W), i32(ctx.groups), i32(rcout), i32(cin), i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs), i32(bs)
|
||||
convw([ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
|
||||
convx([bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
|
||||
conv_args = H, W, ctx.groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
|
||||
convw([ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in conv_args])
|
||||
convx([bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in conv_args])
|
||||
return dx, dw
|
||||
|
|
Loading…
Reference in New Issue