mirror of https://github.com/commaai/tinygrad.git
allow zerosized tensors (#1659)

* allow zerosized tensors
* works with numpy
This commit is contained in:
parent f9cb31fdc2
commit 355b02dc3f
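The diff below adds zero-size handling to the raw buffer layer and the LRU allocator, plus tests. As a quick illustration of what the change enables (a minimal sketch, assuming a tinygrad checkout at this commit and the tinygrad.tensor import path of the time):

    import numpy as np
    from tinygrad.tensor import Tensor

    t = Tensor([])              # a zero-sized tensor
    t.realize()                 # allocation no longer fails for size 0
    out = Tensor([]).numpy()    # round-trips through numpy as an empty array
    assert isinstance(out, np.ndarray) and out.size == 0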
@@ -53,10 +53,11 @@ def cmp_trace_and_buf(buf, trace_ref): return trace_ref and trace_ref() == buf._
 class TestAllocators(unittest.TestCase):
   def test_lru_allocator_reusage(self):
+    mc, mu = GlobalCounters.mem_cached, GlobalCounters.mem_used
     def test():
       lru_allocator = FakeAllocator(2048)
       traced_buf = alloc_free_trace(lru_allocator, 16, dtypes.float32)
-      assert GlobalCounters.mem_cached == 16*dtypes.float32.itemsize, "Buffer should be cached"
+      assert GlobalCounters.mem_cached - mc == 16*dtypes.float32.itemsize, "Buffer should be cached"
       for _ in range(32):
         def __test():
           buf = alloc(lru_allocator, 16, dtypes.float32)
@@ -69,19 +70,20 @@ class TestAllocators(unittest.TestCase):
           buf = alloc(lru_allocator, 16, dtypes.float32)
           assert usedbuf != buf, "Nobody should get used buffer"
         __test()
-      assert GlobalCounters.mem_used == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated."
+      assert GlobalCounters.mem_used - mu == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated."
     test()
     check_gc()
 
   def test_lru_allocator_cache_free(self):
+    mc, mu = GlobalCounters.mem_cached, GlobalCounters.mem_used
     def test():
       lru_allocator = FakeAllocator(128)
       refs = []
       for _ in range(32):
         refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32))
-      for sz in range(32):
+      for sz in range(1, 32):
         alloc_free_trace(lru_allocator, sz, dtypes.float32)
-      assert GlobalCounters.mem_used + GlobalCounters.mem_cached <= 128, "Should not allocate on device more than allowed (128)"
+      assert GlobalCounters.mem_used + GlobalCounters.mem_cached - mc - mu <= 128, "Should not allocate on device more than allowed (128)"
       for r in refs: assert r() is None, "All refs should be dead, since buffers were cleared from cache"
     test()
     check_gc()
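The allocator tests above switch from asserting absolute counter values to asserting deltas against a snapshot taken at the start of each test (mc, mu), and the size sweep now starts at 1, presumably so it doesn't include a zero-element allocation now that the allocator backs every buffer with at least one element. A minimal sketch of the snapshot/delta pattern, using a hypothetical Counters stand-in rather than the real GlobalCounters:

    # Hypothetical stand-in for GlobalCounters, only to show the snapshot/delta idea.
    class Counters:
      mem_used = 0

    def workload():
      Counters.mem_used += 64   # pretend the test body allocates 64 bytes

    mu = Counters.mem_used      # snapshot before the test body
    workload()
    # assert on the delta, so state left over from earlier tests doesn't matter
    assert Counters.mem_used - mu == 64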
@@ -220,5 +220,9 @@ class TestTinygrad(unittest.TestCase):
     x = Tensor.randn(1, 1, 1)
     x.dot(layer).mean().backward()
 
+  def test_zerosized_tensors(self):
+    Tensor([]).realize()
+    Tensor([]).numpy()
+
 if __name__ == '__main__':
   unittest.main()
@@ -42,13 +42,13 @@ class RawBufferCopyIn(RawBuffer):
   @classmethod
   def fromCPU(cls, x:np.ndarray, **kwargs):
     ret = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
-    ret._copyin(x)
+    if x.size > 0: ret._copyin(x)
     return ret
 
 class RawBufferMapped(RawBufferCopyIn):
   def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
   # NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
-  def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
+  def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size) # type: ignore
   def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))
 
 # this one is simple enough that i moved it out of the runtimes
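The count=self.size argument matters because a zero-sized buffer is now backed by at least one element on the device (see the max(1, size) change further down), so viewing the whole backing memoryview would report one spurious element. A small numpy-only sketch of the difference, using a hypothetical backing buffer rather than tinygrad code:

    import numpy as np

    # Pretend this is the backing storage of a logically zero-sized float32 buffer:
    # the allocator bumps it to one element (4 bytes).
    backing = bytearray(4)

    # Without count, frombuffer views the whole backing buffer: one element.
    assert np.frombuffer(backing, dtype=np.float32).size == 1

    # Pinning count to the logical size gives the correct empty view.
    assert np.frombuffer(backing, dtype=np.float32, count=0).size == 0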
@@ -61,7 +61,7 @@ class RawBufferCopyInOut(RawBufferCopyIn):
 
   def toCPU(self) -> np.ndarray:
     x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
-    self._copyout(x)
+    if x.size > 0: self._copyout(x)
     return x
 
 class RawBufferTransfer(RawBuffer):
@@ -91,7 +91,7 @@ class LRUAllocator:
     while len(self.aging_order[device]) and self.free_space[device] < 0: # When OOM removing lru buffers.
       bucket, epoch = self.aging_order[device].popleft()
       if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
-    newbuf = self._do_alloc(size, dtype, device, **kwargs)
+    newbuf = self._do_alloc(max(1, size), dtype, device, **kwargs)
     self.buffer_info[newbuf] = (size, dtype, device)
     return newbuf
   def _free_buffer(self, buf_to_free):
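The allocator hunk keeps the bookkeeping split explicit: the device allocation is bumped to at least one element so backends never see a zero-byte allocation, while buffer_info still records the logical size (0 for an empty tensor). A toy sketch of that split, with hypothetical names standing in for _do_alloc and buffer_info:

    # Hypothetical sketch: device gets >= 1 element, the table keeps the real size.
    buffer_info = {}

    def alloc_buffer(size: int, itemsize: int = 4) -> bytearray:
      backing = bytearray(max(1, size) * itemsize)  # never a zero-byte device allocation
      buffer_info[id(backing)] = size               # logical size stays 0 for empty tensors
      return backing

    buf = alloc_buffer(0)
    assert len(buf) == 4 and buffer_info[id(buf)] == 0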