multi cl_queue (#762)

* multi cl_queue

* only platforms 1

* gpus first, then cpus

* put device on underlying buffer

* cl_queue array
George Hotz 2023-05-03 12:15:28 -07:00 committed by GitHub
parent 7757f5fed2
commit 7ecf4dff68
4 changed files with 28 additions and 25 deletions
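In outline, sketched here from the runtime hunks in the last file below rather than taken verbatim from the commit: CL.cl_queue stops being a single cl.CommandQueue and becomes a list holding one in-order, profiling-enabled queue per device of the selected platform (GPU platforms listed before CPU platforms, index chosen via CL_PLATFORM), and CL.synchronize() finishes every queue. A minimal pyopencl sketch of that pattern:

import pyopencl as cl

# sketch only: mirrors the new _CL.__init__ / synchronize shown in the diff below
platforms = [devs for devs in
             ([p.get_devices(device_type=cl.device_type.GPU) for p in cl.get_platforms()] +
              [p.get_devices(device_type=cl.device_type.CPU) for p in cl.get_platforms()]) if devs]
ctx = cl.Context(devices=platforms[0])          # the real code indexes with getenv('CL_PLATFORM', 0)
queues = [cl.CommandQueue(ctx, device=d, properties=cl.command_queue_properties.PROFILING_ENABLE)
          for d in ctx.devices]                 # one in-order queue per device

def synchronize():
  for q in queues: q.finish()                   # replaces the old single cl_queue.finish()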


@@ -193,7 +193,7 @@ class Thneed:
})
if needs_load:
data = np.empty(a.size//4, dtype=np.float32)
cl.enqueue_copy(CL.cl_queue, data, a, is_blocking=True)
cl.enqueue_copy(CL.cl_queue[0], data, a, is_blocking=True)
weights.append(data.tobytes())
elif isinstance(a, cl.Image):
assert a.format == cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT), "wrong type"
@@ -204,7 +204,7 @@ class Thneed:
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
# zero out the buffer
cl.enqueue_copy(CL.cl_queue, buf, b'\x00'*buf.size, is_blocking=True)
cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True)
CLProgram("from_image_strided", """
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
@@ -224,7 +224,7 @@ class Thneed:
if needs_load:
data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
cl.enqueue_copy(CL.cl_queue, data, buf, is_blocking=True)
cl.enqueue_copy(CL.cl_queue[0], data, buf, is_blocking=True)
if FLOAT16: data = data.astype(np.float16)
weights.append(data.tobytes())
else:
@@ -271,9 +271,9 @@ class Thneed:
events = []
st = time.monotonic()
for prg, args in self.cl_cache:
events.append(prg.clprg(CL.cl_queue, *args))
events.append(prg.clprg(CL.cl_queue[0], *args))
mt = time.monotonic()
CL.cl_queue.finish()
CL.synchronize()
et = time.monotonic() - st
print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")


@@ -88,7 +88,7 @@ def compile(dat, output_fn):
# confirm thneed found the right output
thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
cl.enqueue_copy(CL.cl_queue, thneed_out, t.outputs[0], is_blocking=True)
cl.enqueue_copy(CL.cl_queue[0], thneed_out, t.outputs[0], is_blocking=True)
np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
# testing is float32 only (fix this)
@@ -106,11 +106,11 @@ def compile(dat, output_fn):
# try old thneed with a different input
for k,v in t.inputs.items():
cl.enqueue_copy(CL.cl_queue, v, new_np_inputs[k], is_blocking=True)
cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True)
t.run()
old_thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
cl.enqueue_copy(CL.cl_queue, old_thneed_out, t.outputs[0], is_blocking=True)
cl.enqueue_copy(CL.cl_queue[0], old_thneed_out, t.outputs[0], is_blocking=True)
# compare thneed (rerun) with torch
np.testing.assert_allclose(new_torch_out, old_thneed_out, atol=1e-4, rtol=1e-2)
@@ -123,11 +123,11 @@ def compile(dat, output_fn):
# inputs
for k,v in nt.inputs.items():
cl.enqueue_copy(CL.cl_queue, v, new_np_inputs[k], is_blocking=True)
cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True)
nt.run()
new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
cl.enqueue_copy(CL.cl_queue, new_thneed_out, nt.outputs[0], is_blocking=True)
cl.enqueue_copy(CL.cl_queue[0], new_thneed_out, nt.outputs[0], is_blocking=True)
# compare torch to thneed
np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)
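A minimal sketch of the round-trip check this file performs, with hypothetical names (in_buf/out_buf for the device buffers, reference for the known-good output, run for replaying the cached kernels): the new input goes in through queue 0, the model runs, and the result is copied back blocking before the comparison.

import numpy as np
import pyopencl as cl

def check_output(queue, in_buf, new_input, out_buf, out_shape, reference, run):
  cl.enqueue_copy(queue, in_buf, new_input, is_blocking=True)   # upload the fresh input
  run()                                                         # replay the cached kernels
  out = np.empty(out_shape, dtype=np.float32)
  cl.enqueue_copy(queue, out, out_buf, is_blocking=True)        # download the result
  np.testing.assert_allclose(reference, out, atol=1e-4, rtol=1e-2)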


@@ -10,16 +10,16 @@ prg = CLProgram("test", """__kernel void test(__global float *a, __global float
int idx = get_global_id(0);
a[idx] = b[idx] + c[idx];
}""")
prg.clprg(CL.cl_queue, [N,], None, a._cl, b._cl, c._cl)
prg.clprg(CL.cl_queue[0], [N,], None, a._cl, b._cl, c._cl)
t1 = time.monotonic_ns()
e1 = prg.clprg(CL.cl_queue, [N,], None, a._cl, b._cl, c._cl)
CL.cl_queue.finish() # type: ignore
e1 = prg.clprg(CL.cl_queue[0], [N,], None, a._cl, b._cl, c._cl)
CL.synchronize()
t2 = time.monotonic_ns()
time.sleep(3)
t3 = time.monotonic_ns()
e2 = prg.clprg(CL.cl_queue, [N,], None, a._cl, b._cl, c._cl)
CL.cl_queue.finish() # type: ignore
e2 = prg.clprg(CL.cl_queue[0], [N,], None, a._cl, b._cl, c._cl)
CL.synchronize()
t4 = time.monotonic_ns()
print(e1.profile.queued)
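The e1/e2 events above come from a profiling-enabled queue, so their timestamps can be read directly. A small self-contained pyopencl sketch of that (not part of the commit):

import numpy as np
import pyopencl as cl

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)
buf = cl.Buffer(ctx, cl.mem_flags.READ_WRITE, size=4096)
evt = cl.enqueue_copy(queue, buf, np.zeros(1024, dtype=np.float32), is_blocking=True)
evt.wait()
# queued -> submit -> start -> end, all in device-clock nanoseconds
print(evt.profile.queued, evt.profile.submit, evt.profile.start, evt.profile.end)
print(f"device time: {(evt.profile.end - evt.profile.start)*1e-6:.3f} ms")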


@@ -14,16 +14,17 @@ FLOAT16 = getenv("FLOAT16", 0)
class _CL:
def __init__(self):
devices: List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
if len(devices) == 0: devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], []) # settle for CPU
if len(devices) > 1 or DEBUG >= 1: print(f"using {devices[getenv('CL_DEVICE', 0)]}")
self.cl_ctx: cl.Context = cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]])
self.cl_queue: cl.CommandQueue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue
platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)]
if DEBUG >= 1: print(f"using {platforms[getenv('CL_PLATFORM', 0)]}")
self.cl_ctx: cl.Context = cl.Context(devices=platforms[getenv('CL_PLATFORM', 0)])
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(self.cl_ctx, device=device, properties=cl.command_queue_properties.PROFILING_ENABLE) for device in self.cl_ctx.devices]
def synchronize(self):
for q in self.cl_queue: q.finish()
CL = _CL()
# TODO: merge CLImage in here
class CLBuffer(RawBufferCopyInOut):
def __init__(self, size, dtype):
def __init__(self, size, dtype, device=0):
if isinstance(dtype, ImageDType):
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
buf = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
@@ -31,13 +32,14 @@ class CLBuffer(RawBufferCopyInOut):
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
else:
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size * dtype.itemsize)
setattr(buf, 'device', device) # device is tracked on the underlying buffer
super().__init__(size, dtype, buf)
def _copyin(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
cl.enqueue_copy(CL.cl_queue, self._buf, x, is_blocking=False)
cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, x, is_blocking=False)
def _copyout(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
cl.enqueue_copy(CL.cl_queue, x, self._buf, is_blocking=True)
cl.enqueue_copy(CL.cl_queue[self._buf.device], x, self._buf, is_blocking=True)
class CLProgram:
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):
@@ -63,7 +65,8 @@ class CLProgram:
def max_work_group_size(): return CL.cl_ctx.devices[0].max_work_group_size
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
e = self.clprg(CL.cl_queue, global_size, local_size, *[x._buf if isinstance(x, CLBuffer) else x for x in bufs])
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
e = self.clprg(CL.cl_queue[cl_bufs[0].device], global_size, local_size, *cl_bufs)
if wait:
e.wait()
return ((e.profile.end - e.profile.start) * OSX_TIMING_RATIO) * 1e-9
@@ -76,4 +79,4 @@ class CLCodegen(CStyleCodegen):
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)
GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.cl_queue.finish)
GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.synchronize)
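Putting the pieces together, a hedged pyopencl-only sketch (names are illustrative, not from the commit): the device index rides on the cl.Buffer itself, just as setattr(buf, 'device', device) does above, and every copy or launch picks the queue matching that buffer's device.

import numpy as np
import pyopencl as cl

devs = cl.get_platforms()[0].get_devices()
ctx = cl.Context(devices=devs)
queues = [cl.CommandQueue(ctx, device=d) for d in ctx.devices]   # CL.cl_queue analogue

def alloc(nbytes, device=0):
  buf = cl.Buffer(ctx, cl.mem_flags.READ_WRITE, nbytes)
  buf.device = device                                            # same trick as setattr(buf, 'device', device)
  return buf

x = alloc(16, device=0)
cl.enqueue_copy(queues[x.device], x, np.zeros(4, dtype=np.float32), is_blocking=True)  # routed by the buffer
for q in queues: q.finish()                                      # CL.synchronize() analogue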