From 7ecf4dff68a247035d36dfb63e1e3500d41d7655 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 3 May 2023 12:15:28 -0700 Subject: [PATCH] multi cl_queue (#762) * multi cl_queue * only platforms 1 * gpus first, then cpus * put device on underlying buffer * cl_queue array --- extra/thneed.py | 10 +++++----- openpilot/compile.py | 10 +++++----- test/external/external_osx_profiling.py | 10 +++++----- tinygrad/runtime/ops_gpu.py | 23 +++++++++++++---------- 4 files changed, 28 insertions(+), 25 deletions(-) diff --git a/extra/thneed.py b/extra/thneed.py index 32b1b270..968a0ee4 100644 --- a/extra/thneed.py +++ b/extra/thneed.py @@ -193,7 +193,7 @@ class Thneed: }) if needs_load: data = np.empty(a.size//4, dtype=np.float32) - cl.enqueue_copy(CL.cl_queue, data, a, is_blocking=True) + cl.enqueue_copy(CL.cl_queue[0], data, a, is_blocking=True) weights.append(data.tobytes()) elif isinstance(a, cl.Image): assert a.format == cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT), "wrong type" @@ -204,7 +204,7 @@ class Thneed: buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1)) # zero out the buffer - cl.enqueue_copy(CL.cl_queue, buf, b'\x00'*buf.size, is_blocking=True) + cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True) CLProgram("from_image_strided", """ __kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) { @@ -224,7 +224,7 @@ class Thneed: if needs_load: data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32) - cl.enqueue_copy(CL.cl_queue, data, buf, is_blocking=True) + cl.enqueue_copy(CL.cl_queue[0], data, buf, is_blocking=True) if FLOAT16: data = data.astype(np.float16) weights.append(data.tobytes()) else: @@ -271,9 +271,9 @@ class Thneed: events = [] st = time.monotonic() for prg, args in self.cl_cache: - events.append(prg.clprg(CL.cl_queue, *args)) + events.append(prg.clprg(CL.cl_queue[0], *args)) mt = time.monotonic() - CL.cl_queue.finish() + CL.synchronize() et = time.monotonic() - st print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms") diff --git a/openpilot/compile.py b/openpilot/compile.py index 31c0c66f..9daa6a6c 100644 --- a/openpilot/compile.py +++ b/openpilot/compile.py @@ -88,7 +88,7 @@ def compile(dat, output_fn): # confirm thneed found the right output thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape) - cl.enqueue_copy(CL.cl_queue, thneed_out, t.outputs[0], is_blocking=True) + cl.enqueue_copy(CL.cl_queue[0], thneed_out, t.outputs[0], is_blocking=True) np.testing.assert_allclose(thneed_out, tinygrad_out.numpy()) # testing is float32 only (fix this) @@ -106,11 +106,11 @@ def compile(dat, output_fn): # try old thneed with a different input for k,v in t.inputs.items(): - cl.enqueue_copy(CL.cl_queue, v, new_np_inputs[k], is_blocking=True) + cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True) t.run() old_thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape) - cl.enqueue_copy(CL.cl_queue, old_thneed_out, t.outputs[0], is_blocking=True) + cl.enqueue_copy(CL.cl_queue[0], old_thneed_out, t.outputs[0], is_blocking=True) # compare thneed (rerun) with torch np.testing.assert_allclose(new_torch_out, old_thneed_out, atol=1e-4, rtol=1e-2) @@ -123,11 +123,11 @@ def compile(dat, output_fn): # inputs for k,v in nt.inputs.items(): - cl.enqueue_copy(CL.cl_queue, v, new_np_inputs[k], is_blocking=True) + cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True) nt.run() new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape) - cl.enqueue_copy(CL.cl_queue, new_thneed_out, nt.outputs[0], is_blocking=True) + cl.enqueue_copy(CL.cl_queue[0], new_thneed_out, nt.outputs[0], is_blocking=True) # compare torch to thneed np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2) diff --git a/test/external/external_osx_profiling.py b/test/external/external_osx_profiling.py index 83cb2677..7380ed9a 100644 --- a/test/external/external_osx_profiling.py +++ b/test/external/external_osx_profiling.py @@ -10,16 +10,16 @@ prg = CLProgram("test", """__kernel void test(__global float *a, __global float int idx = get_global_id(0); a[idx] = b[idx] + c[idx]; }""") -prg.clprg(CL.cl_queue, [N,], None, a._cl, b._cl, c._cl) +prg.clprg(CL.cl_queue[0], [N,], None, a._cl, b._cl, c._cl) t1 = time.monotonic_ns() -e1 = prg.clprg(CL.cl_queue, [N,], None, a._cl, b._cl, c._cl) -CL.cl_queue.finish() # type: ignore +e1 = prg.clprg(CL.cl_queue[0], [N,], None, a._cl, b._cl, c._cl) +CL.synchronize() t2 = time.monotonic_ns() time.sleep(3) t3 = time.monotonic_ns() -e2 = prg.clprg(CL.cl_queue, [N,], None, a._cl, b._cl, c._cl) -CL.cl_queue.finish() # type: ignore +e2 = prg.clprg(CL.cl_queue[0], [N,], None, a._cl, b._cl, c._cl) +CL.synchronize() t4 = time.monotonic_ns() print(e1.profile.queued) diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 6681bc0b..4dfda730 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -14,16 +14,17 @@ FLOAT16 = getenv("FLOAT16", 0) class _CL: def __init__(self): - devices: List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], []) - if len(devices) == 0: devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], []) # settle for CPU - if len(devices) > 1 or DEBUG >= 1: print(f"using {devices[getenv('CL_DEVICE', 0)]}") - self.cl_ctx: cl.Context = cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]]) - self.cl_queue: cl.CommandQueue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue + platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)] + if DEBUG >= 1: print(f"using {platforms[getenv('CL_PLATFORM', 0)]}") + self.cl_ctx: cl.Context = cl.Context(devices=platforms[getenv('CL_PLATFORM', 0)]) + self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(self.cl_ctx, device=device, properties=cl.command_queue_properties.PROFILING_ENABLE) for device in self.cl_ctx.devices] + def synchronize(self): + for q in self.cl_queue: q.finish() CL = _CL() # TODO: merge CLImage in here class CLBuffer(RawBufferCopyInOut): - def __init__(self, size, dtype): + def __init__(self, size, dtype, device=0): if isinstance(dtype, ImageDType): fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize]) buf = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0])) @@ -31,13 +32,14 @@ class CLBuffer(RawBufferCopyInOut): # NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize else: buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size * dtype.itemsize) + setattr(buf, 'device', device) # device is tracked on the underlying buffer super().__init__(size, dtype, buf) def _copyin(self, x:np.ndarray): assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}" - cl.enqueue_copy(CL.cl_queue, self._buf, x, is_blocking=False) + cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, x, is_blocking=False) def _copyout(self, x:np.ndarray): assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}" - cl.enqueue_copy(CL.cl_queue, x, self._buf, is_blocking=True) + cl.enqueue_copy(CL.cl_queue[self._buf.device], x, self._buf, is_blocking=True) class CLProgram: def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None): @@ -63,7 +65,8 @@ class CLProgram: def max_work_group_size(): return CL.cl_ctx.devices[0].max_work_group_size def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]: - e = self.clprg(CL.cl_queue, global_size, local_size, *[x._buf if isinstance(x, CLBuffer) else x for x in bufs]) + cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs] + e = self.clprg(CL.cl_queue[cl_bufs[0].device], global_size, local_size, *cl_bufs) if wait: e.wait() return ((e.profile.end - e.profile.start) * OSX_TIMING_RATIO) * 1e-9 @@ -76,4 +79,4 @@ class CLCodegen(CStyleCodegen): barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)", gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True) -GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.cl_queue.finish) +GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.synchronize)