mirror of https://github.com/commaai/tinygrad.git
For cuda get current free space from device, and retry alloc failures (#2197)
* For cuda get current free space from device, and rery alloc failures * type ignore for mypy * add init to get free mem in cuda * Move retry logic in common lib. Fix typo in override _get_cur_free_space * linter error fix in test file * Not catch all, as it will catch KeyboardInterrupt * fix unintened line changes
This commit is contained in:
parent
2465d5d267
commit
6051f0ce82
|
@ -1,7 +1,9 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import pytest
|
||||
import numpy as np
|
||||
from weakref import ref
|
||||
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
|
||||
from tinygrad.helpers import dtypes, prod
|
||||
|
@ -23,7 +25,9 @@ class FakeDeviceBuffer:
|
|||
assert self.id == 0, "Should called _do_free() before"
|
||||
|
||||
class FakeAllocator(LRUAllocator):
|
||||
def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device)
|
||||
def _do_alloc(self, size, dtype, device, **kwargs):
|
||||
if size*dtype.itemsize > self._get_cur_free_space(device): raise Exception("OOM")
|
||||
return FakeDeviceBuffer(size, dtype, device)
|
||||
def _do_free(self, buf):
|
||||
buf.id -= 1
|
||||
assert buf.id == 0, f"Free should be called once, but {buf.id}"
|
||||
|
@ -108,6 +112,44 @@ class TestAllocators(unittest.TestCase):
|
|||
test()
|
||||
check_gc()
|
||||
|
||||
def test_lru_allocator_failing_alloc_cleans_cache(self):
|
||||
def test():
|
||||
lru_allocator = FakeAllocator(128)
|
||||
for size in range(1, 4):
|
||||
alloc_free_trace(lru_allocator, size, dtypes.float32, device='0')
|
||||
assert len(lru_allocator.aging_order['0']) == 3, "All buffers should be cached"
|
||||
assert lru_allocator.free_space['0'] == 128 - 24, "24 bytes to be used by current cached buffers"
|
||||
|
||||
def always_raise_exception(*args, **kwargs):
|
||||
raise Exception("OOM")
|
||||
lru_allocator._do_alloc = always_raise_exception
|
||||
|
||||
with pytest.raises(Exception):
|
||||
buff = alloc(lru_allocator, 5, dtypes.float32, device='0')
|
||||
assert len(lru_allocator.aging_order['0']) == 0, "All buffers should be freed from cache due to failing alloc"
|
||||
test()
|
||||
check_gc()
|
||||
|
||||
def test_lru_allocator_fail_first_alloc_pass_after_clear_cahce(self):
|
||||
def test():
|
||||
lru_allocator = FakeAllocator(128)
|
||||
for size in range(1, 4):
|
||||
alloc_free_trace(lru_allocator, size, dtypes.float32, device='0')
|
||||
cache_length = 3
|
||||
assert len(lru_allocator.aging_order['0']) == cache_length, "All buffers should be cached"
|
||||
assert lru_allocator.free_space['0'] == 128 - 24, "24 bytes to be used by current cached buffers"
|
||||
|
||||
original_do_alloc = lru_allocator._do_alloc # save the original method
|
||||
def single_fail_then_pass(*args, **kwargs):
|
||||
lru_allocator._do_alloc = original_do_alloc # restore the original method
|
||||
raise Exception("OOM")
|
||||
lru_allocator._do_alloc = single_fail_then_pass
|
||||
|
||||
buff = alloc(lru_allocator, 5, dtypes.float32, device='0')
|
||||
assert len(lru_allocator.aging_order['0']) < cache_length, "Some buffers should be cleaned as first alloc failed"
|
||||
test()
|
||||
check_gc()
|
||||
|
||||
@unittest.skip("failing in CI")
|
||||
def test_gpu_copyout(self):
|
||||
def test():
|
||||
|
|
|
@ -76,15 +76,21 @@ class LRUAllocator:
|
|||
GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
|
||||
return rawbufs.popleft()[0]
|
||||
|
||||
def ensure_has_free_space(self, size, dtype, device):
|
||||
while len(self.aging_order[device]) and (self.free_space[device]-size*dtype.itemsize) < 0: # When OOM removing lru buffers.
|
||||
def ensure_has_free_space(self, space_to_free, device):
|
||||
while len(self.aging_order[device]) and self._get_cur_free_space(device) < space_to_free: # When OOM removing lru buffers.
|
||||
bucket, epoch = self.aging_order[device].popleft()
|
||||
if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
|
||||
|
||||
def _alloc_buffer(self, size, dtype, device, **kwargs):
|
||||
self.ensure_has_free_space(size, dtype, device)
|
||||
self.ensure_has_free_space(size*dtype.itemsize, device)
|
||||
while True:
|
||||
try:
|
||||
newbuf = self._do_alloc(max(1, size), dtype, device, **kwargs)
|
||||
break
|
||||
except Exception:
|
||||
if len(self.aging_order[device]) == 0: raise
|
||||
self.ensure_has_free_space(1.1*self._get_cur_free_space(device), device) # increase free space by 10% and try again.
|
||||
self.free_space[device] -= size*dtype.itemsize
|
||||
newbuf = self._do_alloc(max(1, size), dtype, device, **kwargs)
|
||||
self.buffer_info[newbuf] = (size, dtype, device)
|
||||
return newbuf
|
||||
|
||||
|
@ -109,3 +115,4 @@ class LRUAllocator:
|
|||
def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides a key for reusing device buffers with identical keys.
|
||||
def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
|
||||
def _do_free(self, buf): pass
|
||||
def _get_cur_free_space(self, device): return self.free_space[device]
|
||||
|
|
|
@ -49,9 +49,11 @@ else:
|
|||
import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
|
||||
import pycuda.driver as cuda # type: ignore
|
||||
class CUDAAllocator(LRUAllocator):
|
||||
def __init__(self): super().__init__(self._get_cur_free_space(None))
|
||||
def _do_alloc(self, size, dtype, device, **kwargs): return cuda.mem_alloc(size * dtype.itemsize) # type: ignore
|
||||
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
|
||||
CUDAAlloc = CUDAAllocator(pycuda.driver.Context.get_device().total_memory())
|
||||
def _get_cur_free_space(self, device): return cuda.mem_get_info()[0] # type: ignore
|
||||
CUDAAlloc = CUDAAllocator() # type: ignore
|
||||
class RawCUDABuffer(RawBufferCopyInOut): # type: ignore
|
||||
def __init__(self, size, dtype): super().__init__(size, dtype, allocator=CUDAAlloc)
|
||||
def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x.ravel(), stream) # type: ignore
|
||||
|
|
|
@ -54,7 +54,7 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
|||
self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements=['C', 'A']), is_blocking=False)
|
||||
def _copyout(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
|
||||
CL.cl_allocator.ensure_has_free_space(self.size, self.dtype, self._device)
|
||||
CL.cl_allocator.ensure_has_free_space(self.size*self.dtype.itemsize, self._device)
|
||||
buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
|
||||
mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False)
|
||||
with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else []))
|
||||
|
|
Loading…
Reference in New Issue