diff --git a/visionipc/visionipc.pxd b/visionipc/visionipc.pxd index 32baa9c..3151dfc 100644 --- a/visionipc/visionipc.pxd +++ b/visionipc/visionipc.pxd @@ -8,7 +8,12 @@ from libc.stdint cimport uint32_t, uint64_t from libcpp cimport bool, int cdef extern from "cereal/visionipc/visionbuf.h": + struct _cl_device_id + struct _cl_context struct _cl_mem + + ctypedef _cl_device_id * cl_device_id + ctypedef _cl_context * cl_context ctypedef _cl_mem * cl_mem cdef enum VisionStreamType: diff --git a/visionipc/visionipc_pyx.pxd b/visionipc/visionipc_pyx.pxd index 6e2d5ed..ec431ce 100644 --- a/visionipc/visionipc_pyx.pxd +++ b/visionipc/visionipc_pyx.pxd @@ -2,6 +2,11 @@ #cython: language_level=3 from .visionipc cimport VisionBuf as cppVisionBuf +from .visionipc cimport cl_device_id, cl_context + +cdef class CLContext: + cdef cl_device_id device_id + cdef cl_context context cdef class VisionBuf: cdef cppVisionBuf * buf diff --git a/visionipc/visionipc_pyx.pyx b/visionipc/visionipc_pyx.pyx index bcce534..0ff270e 100644 --- a/visionipc/visionipc_pyx.pyx +++ b/visionipc/visionipc_pyx.pyx @@ -98,8 +98,11 @@ cdef class VisionIpcClient: cdef cppVisionIpcClient * client cdef VisionIpcBufExtra extra - def __cinit__(self, string name, VisionStreamType stream, bool conflate): - self.client = new cppVisionIpcClient(name, stream, conflate, NULL, NULL) + def __cinit__(self, string name, VisionStreamType stream, bool conflate, CLContext context = None): + if context: + self.client = new cppVisionIpcClient(name, stream, conflate, context.device_id, context.context) + else: + self.client = new cppVisionIpcClient(name, stream, conflate, NULL, NULL) def __dealloc__(self): del self.client