mirror of https://github.com/commaai/tinygrad.git
add Tensor.from_blob (#6765)
* draft tensor from pointer init * some docs and types * comment * cleaner * test * malloc * qcom cl interop * jit example * cleaner * dealoc * wording * docs
This commit is contained in:
parent
14ad47b515
commit
3c56aeee70
|
@ -9,6 +9,7 @@
|
|||
::: tinygrad.Tensor.full_like
|
||||
::: tinygrad.Tensor.zeros_like
|
||||
::: tinygrad.Tensor.ones_like
|
||||
::: tinygrad.Tensor.from_blob
|
||||
|
||||
## Creation (random)
|
||||
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
import ctypes, array
|
||||
from hexdump import hexdump
|
||||
from tinygrad.runtime.ops_gpu import GPUDevice
|
||||
from tinygrad.helpers import getenv, to_mv, mv_address
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad import Tensor, TinyJit
|
||||
from tinygrad.runtime.autogen import opencl as cl
|
||||
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
# create raw opencl buffer.
|
||||
gdev = GPUDevice()
|
||||
cl_buf = cl.clCreateBuffer(gdev.context, cl.CL_MEM_READ_WRITE, 0x100, None, status := ctypes.c_int32())
|
||||
assert status.value == 0
|
||||
|
||||
# fill it with something for fun
|
||||
data = memoryview(array.array('I', [i for i in range(64)]))
|
||||
cl.clEnqueueWriteBuffer(gdev.queue, cl_buf, False, 0, 0x100, mv_address(data), 0, None, None)
|
||||
cl.clFinish(gdev.queue) # wait writes to complete
|
||||
|
||||
# get raw gpu pointer from opencl buffer.
|
||||
|
||||
## get buf desc
|
||||
hexdump(to_mv(ctypes.addressof(cl_buf), 0x40))
|
||||
cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_buf), 8).cast('Q')[0]
|
||||
|
||||
## get buf device ptr
|
||||
hexdump(to_mv(cl_buf_desc_ptr, 0x100))
|
||||
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # offset 0xA0 is a raw gpu pointer.
|
||||
|
||||
# create QCOM tensor with the externally managed buffer
|
||||
x = Tensor.from_blob(rawbuf_ptr, (8, 8), dtype=dtypes.int, device='QCOM')
|
||||
y = (x + 1).numpy()
|
||||
print(y)
|
||||
|
||||
# all calculations are done, save to free the object
|
||||
cl.clReleaseMemObject(cl_buf)
|
||||
|
||||
# all together with jit
|
||||
@TinyJit
|
||||
def calc(x): return x + 2
|
||||
|
||||
for i in range(4):
|
||||
cl_buf = cl.clCreateBuffer(gdev.context, cl.CL_MEM_READ_WRITE, 2*2*4, None, status := ctypes.c_int32())
|
||||
assert status.value == 0
|
||||
data = memoryview(array.array('I', [x+i for x in range(2*2)]))
|
||||
cl.clEnqueueWriteBuffer(gdev.queue, cl_buf, False, 0, 2*2*4, mv_address(data), 0, None, None)
|
||||
cl.clFinish(gdev.queue) # wait writes to complete
|
||||
|
||||
cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_buf), 8).cast('Q')[0]
|
||||
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20]
|
||||
|
||||
y = calc(x = Tensor.from_blob(rawbuf_ptr, (2, 2), dtype=dtypes.int, device='QCOM')).numpy()
|
||||
print(f'jit {i}\n', y)
|
||||
|
||||
# all calculations are done, save to free the object
|
||||
cl.clReleaseMemObject(cl_buf)
|
||||
|
||||
# now images!
|
||||
|
||||
h, w = 128, 128
|
||||
cl_img = cl.clCreateImage2D(gdev.context, cl.CL_MEM_READ_WRITE, cl.cl_image_format(cl.CL_RGBA, cl.CL_FLOAT), w, h, 0, None, status := ctypes.c_int32())
|
||||
assert status.value == 0
|
||||
|
||||
# fill it with something for fun
|
||||
data = memoryview(array.array('f', [i for i in range(h*w*4)]))
|
||||
cl.clEnqueueWriteImage(gdev.queue, cl_img, False, (ctypes.c_size_t * 3)(0,0,0), (ctypes.c_size_t * 3)(w,h,1), 0, 0, mv_address(data), 0, None, None)
|
||||
cl.clFinish(gdev.queue) # wait writes to complete
|
||||
|
||||
# get raw gpu pointer from opencl buffer.
|
||||
|
||||
## get buf desc
|
||||
hexdump(to_mv(ctypes.addressof(cl_img), 0x40))
|
||||
cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_img), 8).cast('Q')[0]
|
||||
|
||||
## get buf device ptr
|
||||
hexdump(to_mv(cl_buf_desc_ptr, 0x100))
|
||||
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # offset 0xA0 is a raw gpu pointer.
|
||||
|
||||
# create QCOM tensor with the externally managed buffer
|
||||
# dtypes.imageh = cl.cl_image_format(cl.CL_RGBA, cl.CL_HALF_FLOAT)
|
||||
# dtypes.imagef = cl.cl_image_format(cl.CL_RGBA, cl.CL_FLOAT)
|
||||
x = Tensor.from_blob(rawbuf_ptr, (h*w*4,), dtype=dtypes.imagef((h,w)), device='QCOM')
|
||||
y = (x + 1).numpy()
|
||||
print(y)
|
||||
|
||||
# all calculations are done, save to free the object
|
||||
cl.clReleaseMemObject(cl_img)
|
|
@ -1,10 +1,10 @@
|
|||
import subprocess
|
||||
import numpy as np
|
||||
import torch
|
||||
import unittest, copy, mmap, random, math
|
||||
import unittest, copy, mmap, random, math, array
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.helpers import getenv, temp, CI, _METADATA
|
||||
from tinygrad.helpers import getenv, temp, CI, _METADATA, mv_address
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import is_dtype_supported
|
||||
|
@ -330,6 +330,17 @@ class TestTinygrad(unittest.TestCase):
|
|||
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype
|
||||
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else
|
||||
|
||||
def test_tensor_from_blob(self):
|
||||
x = memoryview(bytearray(16)).cast('I')
|
||||
|
||||
t = Tensor.from_blob(mv_address(x), (4,), dtype=dtypes.int, device="CLANG")
|
||||
z = (t+1)
|
||||
np.testing.assert_equal(z.numpy(), [1, 1, 1, 1])
|
||||
|
||||
x[:] = array.array('I', [0, 1, 2, 3])
|
||||
z = (t+1)
|
||||
np.testing.assert_equal(z.numpy(), [1, 2, 3, 4])
|
||||
|
||||
def test_tensor_list_dtype(self):
|
||||
for arr in ([1], [[[1]]], [[1,1],[1,1]], [[[1,1],[1,1]],[[1,1],[1,1]]]):
|
||||
assert Tensor(arr).dtype == dtypes.default_int
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
import multiprocessing, decimal, statistics, random
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
|
||||
import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
|
||||
|
@ -48,6 +48,7 @@ class BufferOptions:
|
|||
cpu_access: bool = False
|
||||
host: bool = False
|
||||
nolru: bool = False
|
||||
external_ptr: Optional[int] = None
|
||||
|
||||
class Buffer:
|
||||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
|
||||
|
@ -75,9 +76,11 @@ class Buffer:
|
|||
def ref(self, cnt): self.base._lb_refcount += cnt
|
||||
def is_allocated(self) -> bool: return hasattr(self, '_buf')
|
||||
def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
|
||||
def allocate(self, opaque=None) -> Buffer:
|
||||
def allocate(self, opaque=None, external_ptr=None) -> Buffer:
|
||||
assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
|
||||
self.allocator = Device[self.device].allocator
|
||||
if external_ptr is not None:
|
||||
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr)
|
||||
if self._base is not None:
|
||||
self._base.ensure_allocated()
|
||||
assert hasattr(self.allocator, "offset"), "offset function required for view"
|
||||
|
@ -99,7 +102,7 @@ class Buffer:
|
|||
def nbytes(self): return self.size*self.dtype.itemsize
|
||||
def __del__(self):
|
||||
if not hasattr(self, '_buf'): return
|
||||
if self._base is None:
|
||||
if self._base is None and (self.options is None or self.options.external_ptr is None):
|
||||
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
|
||||
self.allocator.free(self._buf, self.nbytes, self.options)
|
||||
def __repr__(self):
|
||||
|
@ -162,7 +165,8 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
|
|||
else: super().free(opaque, size, options)
|
||||
|
||||
class _MallocAllocator(LRUAllocator):
|
||||
def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
|
||||
def _alloc(self, size:int, options:BufferOptions):
|
||||
return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)()
|
||||
def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
|
||||
def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
|
||||
def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
|
||||
|
|
|
@ -294,7 +294,8 @@ class QCOMAllocator(HCQAllocator):
|
|||
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
|
||||
pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
|
||||
|
||||
texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)
|
||||
if options.external_ptr: texture = QCOMBuffer(options.external_ptr, size)
|
||||
else: texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)
|
||||
|
||||
# Extend HCQBuffer with texture-related info.
|
||||
texture.pitch, texture.real_stride, texture.desc, texture.ibo = pitch, real_stride, [0] * 16, [0] * 16
|
||||
|
@ -308,7 +309,7 @@ class QCOMAllocator(HCQAllocator):
|
|||
|
||||
return texture
|
||||
|
||||
return self.device._gpu_alloc(size, map_to_cpu=True)
|
||||
return QCOMBuffer(options.external_ptr, size) if options.external_ptr else self.device._gpu_alloc(size, map_to_cpu=True)
|
||||
|
||||
def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, dest_off=0, src_off=0):
|
||||
while src_off < src_size:
|
||||
|
|
|
@ -398,6 +398,21 @@ class Tensor:
|
|||
"""
|
||||
return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:
|
||||
"""
|
||||
Exposes the pointer as a Tensor without taking ownership of the original data.
|
||||
The pointer must remain valid for the entire lifetime of the created Tensor.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
"""
|
||||
|
||||
r = Tensor._metaop(MetaOps.EMPTY, shape, **kwargs)
|
||||
r.lazydata.buffer.allocate(external_ptr=ptr)
|
||||
del r.lazydata.srcs # fake realize
|
||||
return r
|
||||
|
||||
_seed: int = int(time.time())
|
||||
_device_seeds: Dict[str, int] = {}
|
||||
_device_rng_counters: Dict[str, Tensor] = {}
|
||||
|
|
Loading…
Reference in New Issue