add copyin copyout for image on GPU [run_process_replay] (#6580)

* add copyin copyout for image on GPU [run_process_replay]

* add timing

* enqueue vs total run

* it's failing but that's fine
This commit is contained in:
George Hotz 2024-09-18 16:06:20 +08:00 committed by GitHub
parent 162ead02a9
commit d02bb270b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 24 deletions

View File

@ -1,12 +1,12 @@
import os, sys, pickle
import os, sys, pickle, time
import numpy as np
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
from tinygrad import fetch, Tensor, TinyJit, Device, Context, GlobalCounters
from tinygrad.helpers import OSX, DEBUG
from tinygrad.helpers import OSX, DEBUG, Timing
from tinygrad.tensor import _from_np_dtype
Device.DEFAULT = "GPU" # should be QCOM on comma device
import onnx
from onnx.helper import tensor_dtype_to_np_dtype
@ -15,7 +15,6 @@ from extra.onnx import get_run_onnx # TODO: port to main tinygrad
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
OUTPUT = "/tmp/openpilot.pkl"
def compile():
# hack to fix GPU on OSX: max doesn't work on half, see test/external/external_gpu_fail_osx.py
if OSX:
@ -31,10 +30,10 @@ def compile():
run_onnx = get_run_onnx(onnx_model)
print("loaded model")
Tensor.manual_seed(100)
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input}
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in input_shapes.items()}
Tensor.manual_seed(100)
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
print("created tensors")
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
@ -43,7 +42,10 @@ def compile():
print(f"run {i}")
with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)):
ret = next(iter(run_onnx_jit(**new_inputs).values())).cast('float32').numpy()
if i == 0: test_val = np.copy(ret)
print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels")
np.testing.assert_equal(test_val, ret)
print("jit run validated")
with open(OUTPUT, "wb") as f:
pickle.dump(run_onnx_jit, f)
@ -52,17 +54,26 @@ def compile():
print(f"mdl size is {mdl_sz/1e6:.2f}M")
print(f"pkl size is {pkl_sz/1e6:.2f}M")
print("**** compile done ****")
return test_val
def test():
def test(test_val):
with open(OUTPUT, "rb") as f:
run = pickle.load(f)
new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype) for nm, (st, _, dtype, _) in
zip(run.captured.expected_names, run.captured.expected_st_vars_dtype_device)}
out = run(**new_inputs)
val = out['outputs'].numpy()
Tensor.manual_seed(100)
new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
sorted(zip(run.captured.expected_names, run.captured.expected_st_vars_dtype_device))}
for _ in range(20):
st = time.perf_counter()
out = run(**new_inputs)
mt = time.perf_counter()
val = out['outputs'].numpy()
et = time.perf_counter()
print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms")
print(out, val.shape, val.dtype)
np.testing.assert_equal(test_val, val)
print("**** test done ****")
if __name__ == "__main__":
compile()
test()
test_val = compile()
test(test_val)

View File

@ -3,6 +3,33 @@ import numpy as np
from tinygrad import Device, dtypes, Tensor, Context
from tinygrad.dtype import ImageDType
from tinygrad.engine.realize import lower_schedule
from tinygrad.helpers import prod
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
class TestImageCopy(unittest.TestCase):
def test_image_copyout_1x1(self):
it = Tensor.arange(4).cast(dtypes.imagef((1,1,4))).realize()
buf = it.lazydata.buffer
out = buf.as_buffer()
np.testing.assert_equal(out.cast('f').tolist(), np.arange(4))
def test_image_copyout_2x3(self):
it = Tensor.arange(2*3*4).cast(dtypes.imagef((2,3,4))).realize()
buf = it.lazydata.buffer
out = buf.as_buffer()
np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*3*4))
def test_image_roundtrip(self):
sz = (4,2,4)
it = Tensor.rand(prod(sz)).cast(dtypes.imagef(sz)).realize()
buf = it.lazydata.buffer
out = buf.as_buffer()
it2 = Tensor.rand(prod(sz)).cast(dtypes.imagef(sz)).realize()
buf2 = it2.lazydata.buffer
buf2.copyin(out)
assert (it == it2).sum().item() == prod(sz)
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
class TestImageDType(unittest.TestCase):

View File

@ -44,8 +44,8 @@ class CLProgram:
if hasattr(self, 'kernel'): check(cl.clReleaseKernel(self.kernel))
if hasattr(self, 'program'): check(cl.clReleaseProgram(self.program))
def __call__(self, *bufs:ctypes._CData, global_size:Tuple[int,int,int]=(1,1,1), local_size:Optional[Tuple[int,int,int]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: # noqa: E501
for i,b in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))
def __call__(self, *bufs:Tuple[ctypes._CData, BufferOptions], global_size:Tuple[int,int,int]=(1,1,1), local_size:Optional[Tuple[int,int,int]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: # noqa: E501
for i,(b,_) in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))
for i,v in enumerate(vals,start=len(bufs)): cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v)))
if local_size is not None: global_size = cast(Tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
event = cl.cl_event() if wait else None
@ -62,18 +62,26 @@ class CLAllocator(LRUAllocator):
def __init__(self, device:CLDevice):
self.device = device
super().__init__()
def _alloc(self, size:int, options:BufferOptions) -> ctypes._CData:
def _alloc(self, size:int, options:BufferOptions) -> Tuple[ctypes._CData, BufferOptions]:
if options.image is not None:
return checked(cl.clCreateImage2D(self.device.context, cl.CL_MEM_READ_WRITE,
return (checked(cl.clCreateImage2D(self.device.context, cl.CL_MEM_READ_WRITE,
cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status)
return checked(cl.clCreateBuffer(self.device.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status)
def _free(self, opaque:ctypes._CData, options:BufferOptions): check(cl.clReleaseMemObject(opaque))
def copyin(self, dest:ctypes._CData, src:memoryview):
check(cl.clEnqueueWriteBuffer(self.device.queue, dest, False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options)
return (checked(cl.clCreateBuffer(self.device.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options)
def _free(self, opaque:Tuple[ctypes._CData, BufferOptions], options:BufferOptions): check(cl.clReleaseMemObject(opaque[0]))
def copyin(self, dest:Tuple[ctypes._CData, BufferOptions], src:memoryview):
if dest[1].image is not None:
check(cl.clEnqueueWriteImage(self.device.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0),
(ctypes.c_size_t * 3)(dest[1].image.shape[1],dest[1].image.shape[0],1), 0, 0, from_mv(src), 0, None, None))
else:
check(cl.clEnqueueWriteBuffer(self.device.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
self.device.pending_copyin.append(src) # NOTE: these can't be freed until the GPU actually executes this command
def copyout(self, dest:memoryview, src:ctypes._CData):
check(cl.clEnqueueReadBuffer(self.device.queue, src, False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
def copyout(self, dest:memoryview, src:Tuple[ctypes._CData, BufferOptions]):
if src[1].image is not None:
check(cl.clEnqueueReadImage(self.device.queue, src[0], False, (ctypes.c_size_t * 3)(0,0,0),
(ctypes.c_size_t * 3)(src[1].image.shape[1],src[1].image.shape[0],1), 0, 0, from_mv(dest), 0, None, None))
else:
check(cl.clEnqueueReadBuffer(self.device.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
self.device.synchronize()
class CLDevice(Compiled):