mirror of https://github.com/commaai/tinygrad.git
rewrite 0 size loadop into a CONST (#2556)
* rewrite 0 size loadop into a CONST
* check alloc size
* EMPTY is better
* Revert "EMPTY is better"
  This reverts commit 574fe0f9ed28f1b97da5a81afdfd2cd5d9a94ff9.
* no ast is created
* fix test
parent 4447188051
commit 67f4e03724
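For illustration only (not part of the diff), based on the new test_zero_size test and the Tensor.numpy() change below, this is the behavior the commit is after: operations on zero-size tensors schedule no kernels, allocate no device memory, and numpy() returns an empty array.

from tinygrad.tensor import Tensor

x = Tensor.rand(2, 3, 0)   # 0 in the shape: the loadop is rewritten into a CONST
out = (x + 1).realize()    # nothing is scheduled and no device memory is allocated
print(out.numpy().shape)   # (2, 3, 0) -- numpy() short-circuits to np.zeros for 0-size shapes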
@@ -6,6 +6,7 @@ from tinygrad.nn.state import get_parameters
 from tinygrad.jit import TinyJit
 from tinygrad import Device, GlobalCounters
 from tinygrad.helpers import CI, dtypes
+from tinygrad.shape.symbolic import Variable
 from test.helpers import derandomize_model

 from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
@@ -67,8 +68,8 @@ class TestRealWorld(unittest.TestCase):
     model = GPT2Transformer(**(args_tiny if CI else GPT2_MODEL_PARAMS["gpt2-medium"]))
     derandomize_model(model)
     @TinyJit
-    def test(t): return model(t, 0).realize()
-    helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 140 if CI else 396, all_jitted=True)
+    def test(t, v): return model(t, v).realize()
+    helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.21 if CI else 0.9, 180 if CI else 516, all_jitted=True)

   @unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CLANG", "CPU"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
   def test_train_cifar(self):
@@ -332,5 +332,10 @@ class TestSchedule(unittest.TestCase):
     out = x ** Tensor(2)
     check_schedule(out, 1)

+  def test_zero_size(self):
+    x = Tensor.rand(2, 3, 0)
+    out = x + 1
+    check_schedule(out, 0, filter_loadops=False)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
@@ -83,7 +83,9 @@ class Buffer:

 # TODO: size, dest, src are the same type. can we enforce this?
 class Allocator:
-  def alloc(self, size:int, dtype:DType): return self._alloc(size, dtype)
+  def alloc(self, size:int, dtype:DType):
+    assert size > 0, f"alloc size must be positve, getting {size}"
+    return self._alloc(size, dtype)
   def _alloc(self, size:int, dtype:DType): raise NotImplementedError("need alloc")
   def free(self, opaque, size:int, dtype:DType): self._free(opaque) # if you are returning a Python object, you don't need a free
   def _free(self, opaque): pass
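Illustration only, not part of the diff: a toy host-memory allocator written against the Allocator interface patched above (import path omitted; the class name is hypothetical). It shows the new contract: _alloc can assume size > 0, because alloc() now rejects zero-size requests, which is why the per-backend "if size == 0: return None" guards are dropped in the hunks below.

import ctypes

class ToyMallocAllocator(Allocator):  # Allocator as defined in the hunk above
  def _alloc(self, size, dtype):
    # alloc() has already asserted size > 0, so no zero-size special case is needed here
    return (ctypes.c_uint8 * (size * dtype.itemsize))()
  def _free(self, opaque): pass  # memory is released when the ctypes array is garbage collected

# ToyMallocAllocator().alloc(16, dtypes.float32)  # ok
# ToyMallocAllocator().alloc(0, dtypes.float32)   # AssertionError from Allocator.alloc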
@@ -82,6 +82,9 @@ def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg

 lazycache: WeakValueDictionary = WeakValueDictionary()
 def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
+  # rewrite 0 size into a CONST
+  if 0 in st.shape: return LazyBuffer(device, ShapeTracker.from_shape(st.shape), LoadOps, LazyOp(LoadOps.CONST, tuple(), 0.0), dtype)
+
   # fromcpu aren't cached
   if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)

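Illustration only (not from the diff), assuming the LazyBuffer attribute names at this commit: any lazy buffer created with a 0 in its shape comes back as a CONST loadop, whatever op was requested, while the original shape is preserved.

from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps

t = Tensor.empty(4, 0)     # requests a LoadOps.EMPTY buffer
print(t.lazydata.op.op)    # expected: LoadOps.CONST, the 0-size loadop was rewritten
print(t.lazydata.shape)    # (4, 0) -- shape is kept via ShapeTracker.from_shape(st.shape)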
@@ -183,7 +186,7 @@ class LazyBuffer:

   @staticmethod
   def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Optional[LazyBuffer]=None) -> LazyBuffer:
-    return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
+    return create_lazybuffer(device, ShapeTracker.from_shape(shape), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)

   # create a constant with the shape and dtype of self
   def const(self, val:Union[float, int]) -> LazyBuffer:
@@ -27,16 +27,14 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
     # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
     si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
       Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype)
-    # TODO: size 0 should be removed from the schedule
-    if si.out.realized.size != 0:
-      if si.ast.op in LoadOps:
-        if DEBUG >= 2: print(f"*** {si.ast.op:>15s} {f'{si.out.device} <- {si.inputs[0].device}' if si.ast.op is LoadOps.FROM else si.out.device:25s} sz {si.out.realized.size:5d} shape {si.out.shape} dtype {si.out.dtype} arg {si.ast.arg}")
-        # confirm the LoadOps are contiguous and in order
-        for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
-        kwargs = {"arg": si.ast.arg} if si.ast.arg is not None else {}
-        LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out.realized, *[x.realized for x in si.inputs], **kwargs)
-      else:
-        Device[si.out.device].get_runner(si.ast).exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals)
+    if si.ast.op in LoadOps:
+      if DEBUG >= 2: print(f"*** {si.ast.op:>15s} {f'{si.out.device} <- {si.inputs[0].device}' if si.ast.op is LoadOps.FROM else si.out.device:25s} sz {si.out.realized.size:5d} shape {si.out.shape} dtype {si.out.dtype} arg {si.ast.arg}")
+      # confirm the LoadOps are contiguous and in order
+      for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
+      kwargs = {"arg": si.ast.arg} if si.ast.arg is not None else {}
+      LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out.realized, *[x.realized for x in si.inputs], **kwargs)
+    else:
+      Device[si.out.device].get_runner(si.ast).exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals)
     del si.out.op
     for v in si.out.views: del v.op
     #assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
@@ -46,9 +46,7 @@ class CUDAProgram:
     return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait)

 class CUDAAllocator(LRUAllocator):
-  def _alloc(self, size, dtype):
-    if size == 0: return None
-    return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size * dtype.itemsize)))
+  def _alloc(self, size, dtype): return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size * dtype.itemsize)))
   def _free(self, opaque): check(cuda.cuMemFree_v2(opaque))
   def copyin(self, dest, src:memoryview): check(cuda.cuMemcpyHtoD_v2(dest, from_mv(src), len(src), None))
   def copyout(self, dest:memoryview, src): check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest)))
@@ -59,7 +59,6 @@ class CLAllocator(LRUAllocator):
     self.device = device
     super().__init__()
   def _alloc(self, size:int, dtype:DType):
-    if size == 0: return None
     if isinstance(dtype, ImageDType):
       # NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
       assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
@@ -49,7 +49,6 @@ class HIPAllocator(LRUAllocator):
     self.device = device
     super().__init__()
   def _alloc(self, size: int, dtype: DType):
-    if size == 0: return None
     check(hip.hipSetDevice(self.device))
     return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size * dtype.itemsize)))
   def _free(self, opaque:T): check(hip.hipFree(opaque))
@@ -52,7 +52,6 @@ class MetalAllocator(LRUAllocator):
     self.device:MetalDevice = device
     super().__init__()
   def _alloc(self, size:int, dtype:DType) -> Any:
-    if size == 0: return None
     ret = self.device.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
     if ret is None: raise MemoryError(f"Metal OOM while allocating {size=} {dtype=}")
     return ret
@@ -123,6 +123,7 @@ class Tensor:
   def numpy(self) -> np.ndarray:
     assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
     assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
+    if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
     return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
   def item(self) -> Union[float, int]: return self.numpy().item()
