Buffer isn't always allocated (#3974)

* buffer alloc

* allocate

* missing allocates

* last one
George Hotz 2024-03-28 13:33:47 -07:00 committed by GitHub
parent 9c03fe3e5d
commit 42b9d999ea
14 changed files with 89 additions and 74 deletions

@@ -43,9 +43,9 @@ from tinygrad.ops import LazyOp, BufferOps, MemBuffer, BinaryOps
from tinygrad.shape.shapetracker import ShapeTracker
# allocate some buffers + load in values
out = Buffer(DEVICE, 1, dtypes.int32)
a = Buffer(DEVICE, 1, dtypes.int32).copyin(memoryview(bytearray(struct.pack("I", 2))))
b = Buffer(DEVICE, 1, dtypes.int32).copyin(memoryview(bytearray(struct.pack("I", 3))))
out = Buffer(DEVICE, 1, dtypes.int32).allocate()
a = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
# describe the computation
@@ -79,8 +79,8 @@ from tinygrad.engine.schedule import create_schedule
# allocate some values + load in values
a = LazyBuffer.loadop(LoadOps.EMPTY, (1,), dtypes.int32, DEVICE)
b = LazyBuffer.loadop(LoadOps.EMPTY, (1,), dtypes.int32, DEVICE)
a.realized = Buffer(DEVICE, 1, dtypes.int32).copyin(memoryview(bytearray(struct.pack("I", 2))))
b.realized = Buffer(DEVICE, 1, dtypes.int32).copyin(memoryview(bytearray(struct.pack("I", 3))))
a.realized = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
b.realized = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
# describe the computation
out = a.e(BinaryOps.ADD, b)
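
In the docs example above, construction and allocation are now separate steps. A minimal standalone sketch of the new pattern (an illustration, not from the commit; assumes the CLANG backend, and any Compiled device should behave the same):

import struct
from tinygrad.dtype import dtypes
from tinygrad.buffer import Buffer

DEVICE = "CLANG"  # assumption: substitute any available backend

buf = Buffer(DEVICE, 1, dtypes.int32)   # only metadata, the allocator is not touched yet
assert not hasattr(buf, "_buf")         # no backing memory exists
buf.allocate()                          # now the device allocator backs it
buf.copyin(memoryview(bytearray(struct.pack("I", 2))))
print(struct.unpack("I", buf.as_buffer())[0])  # prints 2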

@@ -39,7 +39,7 @@ def get_fuzz_rawbufs(lin):
return rawbufs
def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
rawbuf = type(rawbuf)(Device.DEFAULT, rawbuf.size if size is None else size, rawbuf.dtype)
rawbuf = type(rawbuf)(Device.DEFAULT, rawbuf.size if size is None else size, rawbuf.dtype).allocate()
if zero:
with Context(DEBUG=0):
mv = memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize))

@@ -278,7 +278,7 @@ def helper_realized_ast(r:Tensor):
run_schedule(s[:-1]) # run all kernels except the last one
# now all input LazyBuffers buffers in s[-1] should be realized
# allocate an output buffer
output_buffer = Buffer((out:=s[-1].outputs[0]).device, prod((s if isinstance(s, int) else s.max for s in out.shape)), out.dtype)
output_buffer = Buffer((out:=s[-1].outputs[0]).device, prod((s if isinstance(s, int) else s.max for s in out.shape)), out.dtype).allocate()
return s[-1].ast[0], [output_buffer] + [l.realized for l in s[-1].inputs]
@unittest.skipUnless(Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4, "need backends that support float4")

@@ -10,8 +10,8 @@ from tinygrad.tensor import Tensor
class TestTimeLinearizer(unittest.TestCase):
def test_reasonable_time(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
out = Buffer(Device.DEFAULT, si.outputs[0].st.real_size(), si.outputs[0].dtype)
rawbufs = [out] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
out = Buffer(Device.DEFAULT, si.outputs[0].st.real_size(), si.outputs[0].dtype).allocate()
rawbufs = [out] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype).allocate() for x in si.inputs]
tm = time_linearizer(Linearizer(*si.ast), rawbufs, allow_test_size=False, cnt=10)
assert tm > 0 and tm != float('inf')

@@ -27,8 +27,8 @@ def _test_single_value(vals, op, dts):
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype)
buf2 = [Buffer(Device.DEFAULT, 1, dtype).copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
prg = _uops_to_prg(UOpGraph(uops))
prg.exec([buf]+buf2)
ret = np.empty(1, output_dtype.np)
@@ -42,7 +42,7 @@ def _test_single_value_const(vals, op, dts):
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype)
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
prg = _uops_to_prg(UOpGraph(uops))
prg.exec([buf])
ret = np.empty(1, output_dtype.np)
@@ -54,7 +54,7 @@ def _test_uops_result(output_dtype, uops, res):
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
# res = output_fn(uops)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
buf = Buffer(Device.DEFAULT, 1, output_dtype)
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
prg = _uops_to_prg(UOpGraph(uops))
prg.exec([buf])
ret = np.empty(1, output_dtype.np)

tinygrad/buffer.py (new file, +58 lines)
@@ -0,0 +1,58 @@
from __future__ import annotations
from typing import Any, Optional
from dataclasses import dataclass
from tinygrad.helpers import flat_mv
from tinygrad.dtype import DType, ImageDType
from tinygrad.ops import GlobalCounters

@dataclass(frozen=True, eq=True)
class BufferOptions:
  image: Optional[ImageDType] = None
  uncached: bool = False
  host: bool = False
  nolru: bool = False

class Buffer:
  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None, initial_value:Optional[bytes]=None):
    assert isinstance(dtype, DType)
    if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
    self.device, self.size, self.dtype, self.options = device, size, dtype, options
    if opaque is not None: self.allocate(opaque)
    if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
    if initial_value is not None:
      self.allocate()
      self.copyin(memoryview(initial_value))
  def allocate(self, opaque=None) -> Buffer:
    assert not hasattr(self, '_buf'), "can't alloc"
    from tinygrad.device import Device
    self.allocator = Device[self.device].allocator
    self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
    return self
  def __reduce__(self):
    buf = None
    if hasattr(self, '_buf'):
      buf = bytearray(self.nbytes)
      self.copyout(memoryview(buf))
    return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf)
  @property
  def nbytes(self): return self.size*self.dtype.itemsize
  def __del__(self):
    if not hasattr(self, '_buf'): return
    if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
    self.allocator.free(self._buf, self.nbytes, self.options)
  def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}" + (">" if self.options is None else f"{self.options=}>")
  def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
    # zero copy with as_buffer (disabled by default due to use after free)
    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
    assert not force_zero_copy, "force zero copy was passed, but copy is required"
    return self.copyout(memoryview(bytearray(self.nbytes)))
  def copyin(self, mv:memoryview):
    mv = flat_mv(mv)
    assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
    self.allocator.copyin(self._buf, mv)
    return self
  def copyout(self, mv:memoryview) -> memoryview:
    mv = flat_mv(mv)
    assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
    self.allocator.copyout(mv, self._buf)
    return mv
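
As a quick sanity check of the new lifecycle (a sketch, not part of the commit; assumes the CLANG backend): an unallocated Buffer is pure metadata, __reduce__ pickles it without snapshotting data, and passing initial_value still allocates eagerly.

import pickle
from tinygrad.dtype import dtypes
from tinygrad.buffer import Buffer

spec = Buffer("CLANG", 16, dtypes.float32)        # unallocated: just device/size/dtype
clone = pickle.loads(pickle.dumps(spec))          # __reduce__ sends buf=None
assert not hasattr(clone, "_buf")                 # still unallocated after the round trip

filled = Buffer("CLANG", 16, dtypes.float32, initial_value=bytes(64))
assert hasattr(filled, "_buf")                    # initial_value forces allocate() + copyin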

@@ -2,13 +2,12 @@ from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar
import importlib, inspect, functools, pathlib, time, ctypes
from tinygrad.dtype import DType, ImageDType
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.helpers import prod, CACHECOLLECTING
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, get_lazyop_info, GlobalCounters
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.codegen.uops import UOpGraph
from dataclasses import dataclass
if TYPE_CHECKING:
from tinygrad.codegen.linearizer import Linearizer
@@ -68,50 +67,6 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
# **************** Buffer / Allocator ****************
@dataclass(frozen=True, eq=True)
class BufferOptions:
image: Optional[ImageDType] = None
uncached: bool = False
host: bool = False
nolru: bool = False
class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None, initial_value:Optional[bytes]=None):
assert isinstance(dtype, DType)
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
self.device, self.size, self.dtype, self.d, self.options = device, size, dtype, Device[device], options
self.allocator = self.d.allocator
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, options)
# TODO: mem_used for all devices
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
if initial_value is not None: self.copyin(memoryview(initial_value))
def __reduce__(self):
buf = bytearray(self.nbytes)
self.copyout(memoryview(buf))
return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf)
@property
def nbytes(self): return self.size*self.dtype.itemsize
def __del__(self):
if not hasattr(self, '_buf'): return # happens when __init__ has raised exception
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
self.allocator.free(self._buf, self.nbytes, self.options)
def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}" + (">" if self.options is None else f"{self.options=}>")
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
# zero copy with as_buffer (disabled by default due to use after free)
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
assert not force_zero_copy, "force zero copy was passed, but copy is required"
return self.copyout(memoryview(bytearray(self.nbytes)))
def copyin(self, mv:memoryview):
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
self.allocator.copyin(self._buf, mv)
return self
def copyout(self, mv:memoryview) -> memoryview:
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
self.allocator.copyout(mv, self._buf)
return mv
class BufferCopy(JITRunner):
def copy(self, dest, src):
if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_fd') and src.nbytes >= 4096 and src._buf.ud.fd is not None:
@@ -128,7 +83,7 @@ class BufferCopy(JITRunner):
self.copy(dest, src)
et = None
if wait or DEBUG >= 2:
dest.d.synchronize()
Device[dest.device].synchronize()
et = time.perf_counter() - st
total_sz = dest.size*dest.dtype.itemsize
if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest.device[:7]:>7s} <- {src.device[:7]:7s}"
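
Because Buffer no longer caches a Device handle in .d, callers look the device up by name through the Device index, as the BufferCopy change above does. A small sketch of the substitution (CLANG assumed; illustration only):

from tinygrad.device import Device
from tinygrad.buffer import Buffer
from tinygrad.dtype import dtypes

buf = Buffer("CLANG", 4, dtypes.int32)   # can stay unallocated for this lookup
dev = Device[buf.device]                 # replaces the old buf.d
dev.synchronize()                        # what BufferCopy now calls on the destination device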

@@ -4,7 +4,7 @@ import functools, itertools, operator
from tinygrad.nn.state import get_parameters
from tinygrad.dtype import DType
from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph
from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
@@ -57,8 +57,8 @@ def apply_graph_to_jit(jit_cache: List[JitItem], input_rawbuffers: List[Buffer],
for ji in jit_cache:
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
if isinstance(ji.prg, CompiledASTRunner): ji_graph_dev = ji.prg.device
elif isinstance(ji.prg, BufferXfer) and ji.rawbufs[0] and ji.rawbufs[0].d.dname.split(":", 1)[0] in {"HSA", "CUDA"}:
ji_graph_dev = ji.rawbufs[0].d
elif isinstance(ji.prg, BufferXfer) and ji.rawbufs[0] and ji.rawbufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
ji_graph_dev = Device[ji.rawbufs[0].device]
can_be_graphed = ji_graph_dev and ji_graph_dev.graph
can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or
@@ -153,7 +153,7 @@ class PlaceHolder:
def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
ret = self.ref()
if ret: return ret
if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options)
if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options).allocate()
return buffer_cache[self]
class _CacheCollector:
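
The JIT's PlaceHolder materializes its Buffer on demand, and after this change that materialization is spelled out with .allocate(). A hypothetical, self-contained mirror of alloc_if_needed (the cache dict, key type, and function name here are illustration only, not tinygrad API):

from typing import Dict
from tinygrad.buffer import Buffer
from tinygrad.dtype import dtypes

buffer_cache: Dict[str, Buffer] = {}
def alloc_if_needed_sketch(key:str, device:str, size:int, dtype) -> Buffer:
  # construct once, then explicitly allocate, mirroring PlaceHolder.alloc_if_needed
  if key not in buffer_cache: buffer_cache[key] = Buffer(device, size, dtype).allocate()
  return buffer_cache[key]

b = alloc_if_needed_sketch("x", "CLANG", 4, dtypes.int32)  # assumption: CLANG backend
assert hasattr(b, "_buf")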

@@ -47,7 +47,8 @@ def run_schedule(schedule:List[ScheduleItem]):
# if the buffer isn't realized, it might be a const or something. this is fine
out.realized = out.srcs[1].base.realized
else:
out.realized = Buffer(out.device, out.size, out.dtype, "PLACEHOLDER" if getattr(prg, "skip_allocation", False) else None)
out.realized = Buffer(out.device, out.size, out.dtype)
if not getattr(prg, "skip_allocation", False): out.realized.allocate()
del out.srcs
# run the function (put it in JIT)
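
This hunk is the core of the commit: run_schedule builds the output Buffer unallocated and only calls allocate() when the program needs real memory, while runners that supply their own backing (like the DiskRunner below) opt out via skip_allocation. A self-contained sketch of that branch; FakeRunner and realize_output are hypothetical names, not tinygrad API:

from tinygrad.buffer import Buffer
from tinygrad.dtype import dtypes

class FakeRunner:                # stand-in for a runner like DiskRunner
  skip_allocation = True         # it will hand the output its own backing later

def realize_output(prg, device="CLANG", size=4):   # assumption: CLANG backend
  out = Buffer(device, size, dtypes.int32)         # constructed unallocated
  if not getattr(prg, "skip_allocation", False): out.allocate()
  return out

assert not hasattr(realize_output(FakeRunner()), "_buf")  # allocation deferred to the runner
assert hasattr(realize_output(object()), "_buf")          # ordinary kernels get memory now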

@@ -74,7 +74,7 @@ def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
for k,lx in bufsts.items():
buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
rawbufs[k] = Buffer(lin.opts.device, buf_size, lx[0].dtype)
rawbufs[k] = Buffer(lin.opts.device, buf_size, lx[0].dtype).allocate()
assert all(r is not None for r in rawbufs)
return cast(List[Buffer], rawbufs)
@@ -148,7 +148,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
return beam[0][0]
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
MAX_WORKGROUP = 1024
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice

@@ -2,7 +2,7 @@ import ctypes, collections
from typing import Any, Optional, Tuple, Dict, List, cast
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, GraphException
from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer
from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
@@ -41,7 +41,7 @@ class CUDAGraph(MultiDeviceJITGraph):
self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
src_dev = cast(CUDADevice, src.d)
src_dev = cast(CUDADevice, Device[src.device])
new_node = cuda.CUgraphNode()
deps = self.access_resources(read=[src], write=[dest], new_dependency=new_node)

@@ -1,7 +1,7 @@
import ctypes, collections, time, itertools
from typing import List, Any, Dict, cast, Optional, Union, Tuple
from tinygrad.helpers import GraphException, init_c_var, round_up
from tinygrad.device import Compiled, Buffer, BufferOptions, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats
from tinygrad.device import Compiled, Buffer, BufferOptions, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
@@ -37,7 +37,7 @@ class HSAGraph(MultiDeviceJITGraph):
for ji in self.jit_cache:
if isinstance(ji.prg, CompiledASTRunner): compiled_devices.add(ji.prg.device)
elif isinstance(ji.prg, BufferXfer):
for x in ji.rawbufs[0:2]: compiled_devices.add(cast(Buffer, x).d)
for x in ji.rawbufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
else: raise GraphException
if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
@@ -86,7 +86,7 @@ class HSAGraph(MultiDeviceJITGraph):
if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
dest_dev, src_dev = cast(HSADevice, dest.d), cast(HSADevice, src.d)
dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)

@@ -68,7 +68,7 @@ class DiskRunner(JITRunner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False, jit=False):
assert len(rawbufs) == 2
src = rawbufs[1]._buf
rawbufs[0]._buf = DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset)
rawbufs[0].allocate(DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset))
class DiskDevice(Compiled):
def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), None, None)
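
allocate() can also adopt an opaque handle created elsewhere, which is how the DiskRunner now hands the output buffer its DiskBuffer view instead of poking _buf directly. A sketch of that adoption using the malloc-backed CLANG allocator (assumed available; illustration only):

from tinygrad.device import Device
from tinygrad.buffer import Buffer
from tinygrad.dtype import dtypes

raw = Device["CLANG"].allocator.alloc(4, None)          # an opaque handle made outside Buffer
buf = Buffer("CLANG", 1, dtypes.int32).allocate(raw)    # adopt it instead of allocating fresh memory
assert buf._buf is raw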

@@ -12,7 +12,8 @@ from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps
from tinygrad.device import Buffer, Device, BufferOptions
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Device
from tinygrad.shape.symbolic import sint
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.schedule import create_schedule
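
tensor.py now imports Buffer from its new home. Existing imports of Buffer from tinygrad.device keep working because device.py re-imports the class, so both names resolve to the same object:

from tinygrad.buffer import Buffer as BufferA
from tinygrad.device import Buffer as BufferB
assert BufferA is BufferB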