Fix GPU 2**31 virtual size limit (#392)

* in progress

* big conv test works

* that's unneeded

* fix opencl with reduce

* rewrite contiguous_view_constant_fold

* clean up mids in loop code

* subidx

* print cl kernel before run

* no reduce, no loop

* Revert "no reduce, no loop"

This reverts commit 92777e40e9fbecd9f49fc520b48a12d42d6cbd42.
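
The thread running through all three files below: a kernel's flat index can exceed 2**31 even when every physical buffer is small, so the generated index variable becomes a 64-bit long whenever the virtual element count crosses that line, and reduce inputs are addressed as gid * reduce_length + subidx instead of through one pre-multiplied idx. A minimal standalone sketch of that decision, mirroring the idx_getter logic in the diff (the function name and shapes here are illustrative, not from the patch):

from functools import reduce
import operator

def prod(xs): return reduce(operator.mul, xs, 1)

def index_decl(shape, reduce_len=None) -> str:
  # a virtual element count >= 2**31 overflows OpenCL's 32-bit int, so emit long
  itype = 'long' if prod(shape) >= 2**31 else 'int'
  if reduce_len is None:
    return f"{itype} idx = gid;"
  # gid and subidx each fit in an int; only their combination needs 64 bits
  return f"{itype} idx = gid; idx *= {reduce_len}; idx += subidx;"

# the conv from test_sd_big_conv below: ~9.07e9 virtual elements
print(index_decl((1, 1, 512, 62, 62, 512, 3, 3), reduce_len=512*3*3))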
George Hotz, 2022-10-04 21:55:20 -07:00 (committed via GitHub)
commit b7f748c15a, parent 392e57aea7
3 changed files with 23 additions and 23 deletions

accel/opencl/ops_opencl.py

@@ -210,8 +210,10 @@ class OpenCLBuffer(GPUBuffer):
     assert op == ProcessingOps.CONV, f"{op} isn't supported"
     return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C)
-  def contiguous_view_constant_fold(x, name:str) -> Tuple[str, Optional[str], str]:
+  def contiguous_view_constant_fold(x, name:str, reduce:Optional[int]=None) -> Tuple[str, Optional[str], str]:
     if x.is_image():
+      # this will only be for convs, so it shouldn't be a reduce
+      assert reduce is None
       #print("is image")
       return f"""inline float get_{name}(const sampler_t smp, read_only image2d_t x, int gid) {{
         int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')};
@@ -223,9 +225,9 @@ class OpenCLBuffer(GPUBuffer):
         int2 l_smp = {x._image.pos_to_sample_pos('l')};
         float4 dat = read_imagef(x, smp, l_smp);
         return valid ? (idx4 == 0 ? dat.x : (idx4 == 1 ? dat.y : (idx4 == 2 ? dat.z : dat.w))) : 0.0;
-      }}""", f"read_only image2d_t {name}_g", f"get_{name}(smp, {name}_g, idx);"
+      }}""", f"read_only image2d_t {name}_g", f"get_{name}(smp, {name}_g, gid);"
       #ewtypes.append(f"read_only image2d_t {name}_g")
-    return super().contiguous_view_constant_fold(name)
+    return super().contiguous_view_constant_fold(name, reduce)
   def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc"):
     if C is None or earlycode != "acc":
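
Note the call-string change from idx to gid: in the rewritten base class (third file below), kernels no longer carry a caller-side idx at all; each getter receives gid (plus subidx for reduce inputs) and derives the flat index internally. Each view is a (getter source, kernel parameter, call expression) triple; a sketch of that contract with a hypothetical buffer named A (getter body elided):

view = (
  "inline float get_A(__global const float *x, int gid) { /* index math */ }",  # source prepended to the kernel
  "__global const float *A_g",  # kernel parameter; None when the value is constant-folded
  "get_A(A_g, gid);",           # pasted into the kernel body as: float A = get_A(A_g, gid);
)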

test/test_ops.py

@@ -207,7 +207,6 @@ class TestOps(unittest.TestCase):
     arg = (4,3,2,6)
     helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
-  @unittest.skipUnless(Device.DEFAULT != "GPU", "GPU doesn't work with convs with virtual dimensions > 2**31")
   def test_sd_big_conv(self):
     # internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows a int
     helper_test_op([(1,256,64,64), (512,256,3,3)],
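
The skip can go because indexing is now 64-bit where it matters. The arithmetic behind the old limit, as a quick check (2**31 = 2147483648):

from functools import reduce
import operator
virtual = reduce(operator.mul, (1, 1, 512, 62, 62, 512, 3, 3))
print(virtual, virtual >= 2**31)  # 9069133824 True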

tinygrad/llops/ops_gpu.py

@@ -64,6 +64,8 @@ class CLProgram:
       CL.CACHE.append((self, args))
     else:
       e = self.clprg(CL().cl_queue, *args)
+    if DEBUG >= 4:
+      print(self.prg)
     if DEBUG >= 2:
       CL.cl_queue.finish()
     if DEBUG >= 1:
@@ -71,8 +73,6 @@ class CLProgram:
       CL.ops_sum += op_estimate
       print(f"**CL** {CL.kernel_count:6d} {self.name:20s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {op_estimate/1e6:6.1f}M/{CL.ops_sum/1e9:7.2f}G mem {CL.mem_used/1e9:5.2f} GB " +
             ("" if DEBUG <= 1 or CL.CACHE is not None else f"tm {(e.profile.end - e.profile.start)/1e3:9.2f}us/{CL.time_sum/1e6:9.2f}ms ({op_estimate/(e.profile.end - e.profile.start):8.2f} GFLOPS)"))
-    if DEBUG >= 4:
-      print(self.prg)
   # **** end CL wrappers ****
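
Dumping the source before the queue is flushed means a kernel that crashes or hangs still leaves its code on the console. Assuming tinygrad's usual environment-variable switches (GPU=1 to select the OpenCL backend, DEBUG to set verbosity), an invocation would look like:

GPU=1 DEBUG=4 python3 test/test_ops.py TestOps.test_sd_big_conv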
@@ -115,15 +115,13 @@ class GPUBuffer:
     CL.enqueue_copy(data, self.contiguous_op().cl, is_blocking=True)
     return data
   def contiguous_view(x, name:str) -> str:
     return f"inline float get_{name}(__global const float *x, int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? x[idx] : 0.0;}}"
-  def contiguous_view_constant_fold(x, name:str) -> Tuple[str, Optional[str], str]:
-    if x._base_shape == (1,) and x._backing is not None:
-      # this function doesn't need a memory access
-      return f"inline float get_{name}(int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? {x._backing[0]} : 0.0;}}", None, f"get_{name}(idx);"
-    else:
-      return x.contiguous_view(name), f"__global const float *{name}_g", f"get_{name}({name}_g, idx);"
+  def contiguous_view_constant_fold(x, name:str, reduce:Optional[int]=None) -> Tuple[str, Optional[str], str]:
+    idx_getter = f"int valid = 1; {'long' if prod(x.shape) >= 2**31 else 'int'} idx = gid; {'idx *= '+str(reduce)+'; idx += subidx;' if reduce is not None else ''} {x.st.expr().replace('//', '/')};"
+    constant = x._backing[0] if x._base_shape == (1,) and x._backing is not None else None
+    args = (["__global const float *x"] if constant is None else []) + ["int gid"] + (["int subidx"] if reduce is not None else [])
+    return f"inline float get_{name}({','.join(args)}) {{ {idx_getter} return valid ? {constant if constant is not None else 'x[idx]'} : 0.0;}}", \
+      f"__global const float *{name}_g" if constant is None else None, \
+      f"get_{name}({name+'_g, ' if constant is None else ''}gid{', subidx' if reduce is not None else ''});"
   def unary_op(x, op:UnaryOps): return type(x)(x.shape)._processing_op([("A", x)], GPUBuffer.code_for_op[op])
   def binary_op(x, op:BinaryOps, y:GPUBuffer): return type(x)(x.shape)._processing_op([("A", x), ("B", y)], GPUBuffer.code_for_op[op])
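
The rewrite folds the constant and memory-access branches into one string builder and adds the two new behaviors: a long index once prod(x.shape) reaches 2**31, and an optional (gid, subidx) addressing mode for reduce inputs. For a non-constant reduce input on an oversized buffer, the emitted getter looks roughly like this (illustrative expansion with name A and reduce=4608; the ShapeTracker expression is elided), with matching call expression get_A(A_g, gid, subidx);

inline float get_A(__global const float *x,int gid,int subidx) {
  int valid = 1; long idx = gid; idx *= 4608; idx += subidx; /* ShapeTracker expr */
  return valid ? x[idx] : 0.0;
}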
@@ -133,19 +131,20 @@ class GPUBuffer:
   def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
     assert C is None
-    for _, b in bufs:
-      assert prod(b.shape) < 2**31, f"GPU buffers must be under 2**31, {b.shape} isn't"
     # get the input/output shape and the reduce amount
     reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape
     red = prod([s for s,n in zip(*reduce_shape) if n == 1])
+    assert red < 2**31, f"reduce must be under 2**31, {red} isn't"
     # if it's a partial reduce, assert last non reduced axis is before the first reduced axis
     if red > 1 and prod(ret.shape) != 1:
       assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
     kernel_name = "reduce" if red > 1 else "elementwise"
-    views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
+    early_views = {name:buf.contiguous_view_constant_fold(name, red) for name, buf in bufs if name in earlybufs}
+    late_views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs if name not in earlybufs}
+    views = {**early_views, **late_views}
     buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None] # type: ignore
     buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] # type: ignore
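
The blanket per-buffer prod(b.shape) < 2**31 assert is gone; only the reduce length must stay under 2**31, because subidx remains a 32-bit int inside the kernel. Views are now built by role: buffers read inside the reduce loop (earlybufs) get the (gid, subidx) getters with the reduce stride baked in, while everything else keeps plain gid addressing. A runnable sketch of the split, with each view reduced to just its call expression (names illustrative):

def split_views(buf_names, earlybufs):
  # early buffers are read per reduce step and need (gid, subidx);
  # late buffers are read once per output element and only need gid
  early = {n: f"get_{n}({n}_g, gid, subidx);" for n in buf_names if n in earlybufs}
  late = {n: f"get_{n}({n}_g, gid);" for n in buf_names if n not in earlybufs}
  return {**early, **late}

print(split_views(["A", "B"], {"A"}))
# {'A': 'get_A(A_g, gid, subidx);', 'B': 'get_B(B_g, gid);'}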
@@ -155,22 +154,22 @@ class GPUBuffer:
     if inter_red > 1:
       buf_cl.append(cl.LocalMemory(inter_red*4))
+    reduce_loop = f"int mid = get_global_id(1); for (int subidx = {red//inter_red + 1} * mid; subidx < min({red}, {red//inter_red + 1} * (mid+1)); subidx++)" if inter_red > 1 else f"for (int subidx = 0; subidx < {red}; subidx++)"
     conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
     __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
       const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
       float acc = {GPUBuffer.start_for_op[op]};
       int gid = get_global_id(0);
-      {'int mid = get_global_id(1);' if inter_red > 1 else 'int mid = 0;'}
-      for (int idx = gid * {red} + {red//inter_red + 1} * mid; idx < gid * {red} + min({red}, {red//inter_red + 1} * (mid+1)); idx++) {{
-        {chr(10).join([f' float {name} = ' + views[name][2] for name, _ in bufs if name in earlybufs])}
+      {reduce_loop} {{
+        {chr(10).join([f' float {name} = ' + early_views[name][2] for name in early_views])}
         acc = {earlycode};
-      }} int idx = gid;"""+(f"""
+      }}"""+(f"""
       temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
       if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
         for (int rdx = 0; rdx < {inter_red}; rdx++) {{
           acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')};
         }}""" if inter_red != 1 else "{")+f"""
-      {chr(10).join([f' float {name} = ' + views[name][2] for name, _ in bufs if name not in earlybufs])}
+      {chr(10).join([f' float {name} = ' + late_views[name][2] for name in late_views])}
       output[gid] = {code};
     }}
   }}""")