mirror of https://github.com/commaai/tinygrad.git
io_uring for copies from disk (#5035)
* exp uring
* fixes and old version
* nv
* cleaner
* cmp vs aio
* fix
* no lib
* fix nv
* linter
* disk_speed_test now runs default
* fixes
* uring -> io_uring
* linter happy
* get_temp_buf comment added
* tiny nits
* put wait back
* test runs everywhere
* remove consts
* remove mmap consts
* do not require iouring to run test, they are generic
This commit is contained in:
parent b69afc67d8
commit fb1bf48cfe
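In short: disk-backed tensors can now stream into the compute device through io_uring when the platform supports it. A hedged usage sketch (the file path is hypothetical; the behavior follows the new tests and the BufferCopy logic below, where copies of at least 4096 bytes from a disk: device take the io_uring path if DiskDevice managed to set it up, and fall back to readinto otherwise):

import os
from tinygrad import Tensor, Device, dtypes

fn = "/tmp/weights.bin"  # hypothetical file; any regular file on disk works
t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
on_dev = t.to(Device.DEFAULT).realize()  # >=4096-byte copies go through io_uring when available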
@@ -467,7 +467,7 @@ jobs:
        EOF
        echo -e 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' | sudo tee /etc/apt/preferences.d/rocm-pin-600
        sudo apt update || true
        sudo apt install --no-install-recommends --allow-unauthenticated -y hsa-rocr comgr hsa-rocr-dev
        sudo apt install --no-install-recommends --allow-unauthenticated -y hsa-rocr comgr hsa-rocr-dev liburing-dev
        curl -s https://api.github.com/repos/Qazalin/remu/releases/latest | \
          jq -r '.assets[] | select(.name == "libremu.so").browser_download_url' | \
          sudo xargs curl -L -o /usr/local/lib/libremu.so
@@ -509,6 +509,12 @@ jobs:
        diff /tmp/hsa.py.bak tinygrad/runtime/autogen/hsa.py
        diff /tmp/comgr.py.bak tinygrad/runtime/autogen/comgr.py
        diff /tmp/amd_gpu.py.bak tinygrad/runtime/autogen/amd_gpu.py
    - name: Verify Linux autogen
      if: matrix.backend == 'amd'
      run: |
        cp tinygrad/runtime/autogen/io_uring.py /tmp/io_uring.py.bak
        ./autogen_stubs.sh io_uring
        diff /tmp/io_uring.py.bak tinygrad/runtime/autogen/io_uring.py
    - name: Run pytest (not cuda or amd)
      if: matrix.backend!='ptx' && matrix.backend!='triton' && matrix.backend != 'amd' && matrix.backend != 'nv'
      run: python -m pytest -n=auto test/ --durations=20
@@ -185,6 +185,19 @@ generate_hsa() {
  python3 -c "import tinygrad.runtime.autogen.hsa"
}

generate_io_uring() {
  clang2py \
    /usr/include/liburing.h \
    /usr/include/linux/io_uring.h \
    -o $BASE/io_uring.py

  # clang2py can't parse defines
  sed -r '/^#define __NR_io_uring/ s/^#define __(NR_io_uring[^ ]+) (.*)$/\1 = \2/; t; d' /usr/include/asm-generic/unistd.h >> $BASE/io_uring.py # io_uring syscalls numbers
  sed -r '/^#define\s+([^ \t]+)\s+([^ \t]+)/ s/^#define\s+([^ \t]+)\s*([^/]*).*$/\1 = \2/; s/1U/1/g; s/0ULL/0/g; t; d' /usr/include/linux/io_uring.h >> $BASE/io_uring.py # #define name (val) -> name = val

  fixup $BASE/io_uring.py
}

if [ "$1" == "opencl" ]; then generate_opencl
elif [ "$1" == "hip" ]; then generate_hip
elif [ "$1" == "comgr" ]; then generate_comgr
@@ -193,6 +206,7 @@ elif [ "$1" == "hsa" ]; then generate_hsa
elif [ "$1" == "kfd" ]; then generate_kfd
elif [ "$1" == "nv" ]; then generate_nv
elif [ "$1" == "amd" ]; then generate_amd
elif [ "$1" == "io_uring" ]; then generate_io_uring
elif [ "$1" == "all" ]; then generate_opencl; generate_hip; generate_comgr; generate_cuda; generate_hsa; generate_kfd; generate_nv; generate_amd
else echo "usage: $0 <type>"
fi
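For orientation, the two sed passes above append plain Python assignments to the generated io_uring.py, alongside the ctypes structs emitted by clang2py. A hedged illustration of the kind of lines they produce (the exact values come from the installed kernel headers, so treat these as examples rather than the generated output):

# Appended by the first sed (syscall numbers from asm-generic/unistd.h; values shown are the usual ones)
NR_io_uring_setup = 425
NR_io_uring_enter = 426
NR_io_uring_register = 427
# Appended by the second sed (#define name (val) -> name = val, with 1U/0ULL suffixes stripped)
IORING_ENTER_GETEVENTS = (1 << 0)
IORING_OFF_SQ_RING = 0
IORING_OFF_CQ_RING = 0x8000000
IORING_OFF_SQES = 0x10000000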
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
import os, ctypes, ctypes.util, io, mmap
import os, ctypes, ctypes.util, io, mmap, pathlib
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import Timing, from_mv
libc = ctypes.CDLL(ctypes.util.find_library("c"))
@@ -75,18 +75,17 @@ def read_to_gpu_pingpong(fd, sz, gpubuf):
MAP_LOCKED = 0x2000
MAP_HUGETLB = 0x40000

from tinygrad.runtime.ops_hip import HIPDevice

if __name__ == "__main__":
  dev = Device["HIP"]
  dev = Device[Device.DEFAULT]

  warm = (Tensor.ones(1024, device="HIP").contiguous() + Tensor.ones(1024, device="HIP").contiguous()).realize()
  warm = (Tensor.ones(1024, device=Device.DEFAULT).contiguous() + Tensor.ones(1024, device=Device.DEFAULT).contiguous()).realize()
  #fn = "/home/tiny/tinygrad/weights/rng"
  fn = "/home/tiny/tinygrad/weights/LLaMA/7B/consolidated.00.pth"
  fn = pathlib.Path(__file__).parents[1] / "weights/LLaMA-2/70B/consolidated.00.pth"
  sz = os.stat(fn).st_size
  t = Tensor.empty(sz, dtype=dtypes.uint8, device=f"disk:{fn}")
  with Timing("copy: ", lambda x: f", {sz/x:.2f} GB/s"):
    on_hip = t.to("HIP").realize()
    on_dev = t.to(Device.DEFAULT).realize()

  exit(0)

  # 4GB of random numbers
@@ -3,7 +3,7 @@ import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import Timing, fetch, temp
from tinygrad.helpers import Timing, fetch, temp, CI
from test.helpers import is_dtype_supported

def compare_weights_both(url):
@@ -302,5 +302,36 @@ class TestDiskTensor(unittest.TestCase):
    ct = t.llvm_bf16_cast(dtypes.float)
    assert ct.numpy().tolist() == [9984., -1, -1000, -9984, 20]

  def test_copy_from_disk(self):
    fn = pathlib.Path(temp("shco1"))
    fn.unlink(missing_ok=True)
    fn.write_bytes(bytes(range(256))*1024)

    t = Tensor.empty(256*1024, device=f"disk:{temp('shco1')}", dtype=dtypes.uint8)
    on_dev = t.to(Device.DEFAULT).realize()
    np.testing.assert_equal(on_dev.numpy(), t.numpy())

  def test_copy_from_disk_offset(self):
    fn = pathlib.Path(temp("shco2"))
    fn.unlink(missing_ok=True)
    fn.write_bytes(bytes(range(256))*1024)

    for off in [314, 991, 2048, 4096]:
      t = Tensor.empty(256*1024, device=f"disk:{temp('shco2')}", dtype=dtypes.uint8)[off:]
      on_dev = t.to(Device.DEFAULT).realize()
      np.testing.assert_equal(on_dev.numpy(), t.numpy())

  def test_copy_from_disk_huge(self):
    if CI and not hasattr(Device["DISK"], 'io_uring'): self.skipTest("slow on ci without iouring")

    fn = pathlib.Path(temp("shco3"))
    fn.unlink(missing_ok=True)
    fn.write_bytes(bytes(range(256))*1024*256)

    for off in [0, 551]:
      t = Tensor.empty(256*1024*256, device=f"disk:{temp('shco3')}", dtype=dtypes.uint8)[off:]
      on_dev = t.to(Device.DEFAULT).realize()
      np.testing.assert_equal(on_dev.numpy(), t.numpy())

if __name__ == "__main__":
  unittest.main()
@@ -101,8 +101,9 @@ class BufferCopy(Runner):
    else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
    super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
  def copy(self, dest, src):
    if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_fd') and src.nbytes >= 4096 and hasattr(src.allocator.device, 'fd'):
      dest.allocator.copy_from_fd(dest._buf, src.allocator.device.fd, src._buf.offset, src.nbytes)
    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.device, 'io_uring') and hasattr(src.allocator.device, 'fd')
    if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
      dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
    elif src.device.startswith("DISK") and hasattr(dest.allocator, 'as_buffer'):
      # fast(ish) path, uses readinto in diskbuffers
      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
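The old copy_from_fd hook is replaced by a copy_from_disk hook plus a feature check on the disk device. A hedged restatement of when the fast path is taken (a hypothetical helper mirroring the condition above, not code from the commit):

def _uses_io_uring_fast_path(src, dest) -> bool:
  # Illustration only; mirrors the condition in BufferCopy.copy above.
  return (src.device.startswith("DISK")                  # source is a disk-backed buffer
          and hasattr(dest.allocator, 'copy_from_disk')  # destination allocator can stage sharded reads
          and hasattr(src.allocator.device, 'io_uring')  # DiskDevice._iouring_setup succeeded on this machine
          and hasattr(src.allocator.device, 'fd')        # the disk device kept an open file descriptor
          and src.nbytes >= 4096)                        # small copies keep the readinto fallback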
File diff suppressed because it is too large
@@ -395,27 +395,6 @@ class AMDAllocator(LRUAllocator):
    else: raise

  def _free(self, opaque, options:BufferOptions): self.device._gpu_free(opaque)
  #def as_buffer(self, src:Any) -> memoryview:
  #  self.device.synchronize()
  #  return to_mv(src.va_addr, src.size)

  #def copy_from_fd(self, dest, fd, offset, size):
  #  fo = io.FileIO(fd, "a+b", closefd=False)
  #  fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))
  #  copied_in, total_copy_size = 0, round_up(size+minor_offset, PAGE_SIZE)
  #  for i in range(0, size+minor_offset, self.b[0].size):
  #    local_size = min(self.b[0].size, total_copy_size-i)
  #    copy_size = min(local_size-minor_offset, size-copied_in)
  #    if copy_size == 0: break

  #    fo.readinto(to_mv(self.b[1].va_addr, local_size))
  #    if i != 0: self.device._wait_signal(self.device.signal_sdma)
  #    self.b = self.b[::-1]
  #    self.device._submit_sdma(dest.va_addr+copied_in, self.b[0].va_addr+minor_offset, copy_size, completion_signal=self.device.signal_sdma)

  #    copied_in += copy_size
  #    minor_offset = 0 # only on the first
  #  self.device._wait_signal(self.device.signal_sdma)

  def copyin(self, dest, src: memoryview):
    for i in range(0, src.nbytes, self.b[0].size):
@@ -428,6 +407,21 @@ class AMDAllocator(LRUAllocator):
      self.b_timeline[self.b_next] = self.device.timeline_value
      self.device.timeline_value += 1

  def copy_from_disk(self, dest, src, size):
    def _get_temp_buf():
      # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
      if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.timeline_signal.value:
        self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
        return (self.b[self.b_next].va_addr, self.b_next)
      return None

    for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=SDMA_MAX_COPY_SIZE):
      HWCopyQueue().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
                   .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
                   .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
      self.b_timeline[batch_info[1]] = self.device.timeline_value
      self.device.timeline_value += 1

  def copyout(self, dest:memoryview, src):
    self.device.synchronize()
    for i in range(0, dest.nbytes, self.b[0].size):
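The staging scheme above reuses the allocator's pinned host buffers (self.b) as a ring: a buffer is handed to the disk reader only once the GPU timeline shows its previous use has finished, and it is parked at a sentinel value until the new copy signals. A simplified, hedged sketch of that reservation pattern in generic Python (illustrative class, not the AMD-specific code):

class StagingRing:
  """Illustrative double/triple-buffering helper, modeled on _get_temp_buf above (assumption: not tinygrad API)."""
  def __init__(self, buffers, read_timeline):
    self.buffers = buffers              # list of reusable staging buffers
    self.timeline = [0] * len(buffers)  # timeline value at which each buffer becomes free again
    self.next = 0
    self.read_timeline = read_timeline  # callable returning the device's completed timeline value

  def reserve(self):
    cand = (self.next + 1) % len(self.buffers)
    if self.timeline[cand] <= self.read_timeline():  # previous use of this buffer has completed
      self.timeline[cand] = 1 << 64                  # sentinel: in flight until the new copy signals
      self.next = cand
      return self.buffers[cand], cand
    return None                                      # nothing free yet; the caller drains completions instead

  def release(self, idx, signaled_value):
    self.timeline[idx] = signaled_value              # buffer becomes free once the device passes this value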
@@ -1,8 +1,13 @@
from __future__ import annotations
import os, mmap, _posixshmem, io
from typing import Optional
from tinygrad.helpers import OSX
import os, mmap, _posixshmem, io, ctypes, ctypes.util, platform
from typing import Optional, Generator, Tuple, Callable, List
from tinygrad.helpers import OSX, round_up
from tinygrad.device import Compiled, Allocator
import tinygrad.runtime.autogen.io_uring as io_uring

libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
libc.mmap.restype = ctypes.c_void_p

class DiskBuffer:
  def __init__(self, device:DiskDevice, size:int, offset=0):
@@ -29,10 +34,47 @@ class DiskAllocator(Allocator):
      fo.readinto(dest)
    else:
      dest[:] = src._buf()

  def _copyout_sharded(self, src:DiskBuffer, size:int, _get_free_buf:Callable, seg_len:int) -> Generator[Tuple[int, int, int, int], None, None]:
    assert hasattr(DiskDevice, 'io_uring'), "function requires io uring support"

    fd_offset = src.offset - (minor_offset := src.offset % mmap.PAGESIZE)
    processed_reqs_cnt, copied_in, next_read_offset, total_copy_size = 0, 0, 0, round_up(size + minor_offset, mmap.PAGESIZE)
    reqs: List[Tuple[int, int, int, int]] = []

    while next_read_offset < total_copy_size or len(reqs) != processed_reqs_cnt:
      if next_read_offset < total_copy_size and (copy_batch := _get_free_buf()) is not None:
        # Prepare sqe
        sqe_index = (tail:=DiskDevice.io_uring.sq.ktail[0]) & DiskDevice.io_uring.sq.kring_mask[0]
        sqe = DiskDevice.io_uring.sq.sqes[sqe_index]
        sqe.opcode, sqe.fd, sqe.off = io_uring.IORING_OP_READ, self.device.fd, fd_offset + next_read_offset
        sqe.addr, sqe.len, sqe.user_data = copy_batch[0], min(seg_len, total_copy_size - next_read_offset), len(reqs)

        # Send sqe
        DiskDevice.io_uring.sq.array[sqe_index] = sqe_index
        DiskDevice.io_uring.sq.ktail[0] = tail + 1
        libc.syscall(io_uring.NR_io_uring_enter, DiskDevice.io_uring.ring_fd, 1, 1, io_uring.IORING_ENTER_GETEVENTS)

        reqs.append((copy_batch, copied_in, minor_offset, real_copy_size:=min(sqe.len - minor_offset, size - copied_in)))
        next_read_offset += sqe.len
        copied_in += real_copy_size
        minor_offset = 0

      if (head:=DiskDevice.io_uring.cq.khead[0]) != DiskDevice.io_uring.cq.ktail[0]:
        cqe = DiskDevice.io_uring.cq.cqes[head & DiskDevice.io_uring.cq.kring_mask[0]]
        assert cqe.res >= 0, f"read from disk failed, err: {cqe.res}"
        yield reqs[cqe.user_data]
        DiskDevice.io_uring.cq.khead[0] = head + 1 # advance
        processed_reqs_cnt += 1

  def offset(self, buf:DiskBuffer, size:int, offset:int): return DiskBuffer(buf.device, size, offset)

class DiskDevice(Compiled):
  _tried_io_uring_init = False

  def __init__(self, device:str):
    if not DiskDevice._tried_io_uring_init: self._iouring_setup()

    self.size: Optional[int] = None
    self.count = 0
    super().__init__(device, DiskAllocator(self), None, None, None)
@@ -58,3 +100,25 @@ class DiskDevice(Compiled):
    if self.count == 0:
      if hasattr(self, 'fd'): os.close(self.fd)
      self.size = None
  def _iouring_setup(self):
    DiskDevice._tried_io_uring_init = True

    if platform.system() != 'Linux': return

    fd = libc.syscall(io_uring.NR_io_uring_setup, 4096, ctypes.byref(p:=io_uring.struct_io_uring_params()))
    if fd < 0: return

    sq_ptr = libc.mmap(0, p.sq_off.array + p.sq_entries * 4, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, 0)
    cq_ptr = libc.mmap(0, p.cq_off.cqes + p.cq_entries * ctypes.sizeof(io_uring.struct_io_uring_cqe),
                       mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_CQ_RING)
    sqes = libc.mmap(0, p.sq_entries * ctypes.sizeof(io_uring.struct_io_uring_sqe),
                     mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_SQES)

    def u32ptr(val): return ctypes.cast(val, ctypes.POINTER(ctypes.c_uint32))
    sqdesc = io_uring.struct_io_uring_sq(khead=u32ptr(sq_ptr+p.sq_off.head), ktail=u32ptr(sq_ptr+p.sq_off.tail), array=u32ptr(sq_ptr+p.sq_off.array),
                                         kring_mask=u32ptr(sq_ptr+p.sq_off.ring_mask), sqes=ctypes.cast(sqes, ctypes.POINTER(io_uring.struct_io_uring_sqe)))

    cqdesc = io_uring.struct_io_uring_cq(khead=u32ptr(cq_ptr+p.cq_off.head), ktail=u32ptr(cq_ptr+p.cq_off.tail),
                                         kring_mask=u32ptr(sq_ptr+p.cq_off.ring_mask), cqes=ctypes.cast(cq_ptr+p.cq_off.cqes, ctypes.POINTER(io_uring.struct_io_uring_cqe)))

    DiskDevice.io_uring = io_uring.struct_io_uring(ring_fd=fd, sq=sqdesc, cq=cqdesc) # type: ignore
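Putting the pieces together: _copyout_sharded submits one IORING_OP_READ per free staging buffer into the SQ ring and yields a (batch_info, dst_off, src_off, copy_size) tuple as each completion shows up in the CQ ring; the GPU allocators (AMD/NV) then enqueue a device copy out of that staging buffer. A hedged sketch of a plain host-memory consumer of the same protocol, useful for seeing the contract without a GPU in the loop (function and staging buffer are illustrative, not tinygrad API; disk_buf is a DiskBuffer from a "disk:<file>" device on which _iouring_setup succeeded):

import ctypes

def read_via_io_uring(disk_buf, size: int) -> bytes:
  """Illustration only: drain DiskAllocator._copyout_sharded into host memory with one staging buffer."""
  staging = (ctypes.c_uint8 * (2 << 20))()                  # one 2 MiB bounce buffer
  staging_addr, busy = ctypes.addressof(staging), [False]

  def _get_free_buf():
    if busy[0]: return None                                 # busy until the previous chunk was copied out
    busy[0] = True
    return (staging_addr, 0)                                # same shape as the GPU allocators return: (va_addr, index)

  out = bytearray(size)
  for (batch_info, dst_off, src_off, copy_size) in disk_buf.device.allocator._copyout_sharded(disk_buf, size, _get_free_buf, seg_len=2 << 20):
    ctypes.memmove((ctypes.c_char * copy_size).from_buffer(out, dst_off), batch_info[0] + src_off, copy_size)
    busy[0] = False                                         # the staging buffer can be reused for the next request
  return bytes(out)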
@@ -333,7 +333,7 @@ class NVProgram:
class NVAllocator(LRUAllocator):
  def __init__(self, device:NVDevice):
    self.device = device
    self.b = [self.device._gpu_host_alloc(2 << 20) for _ in range(16)]
    self.b = [self.device._gpu_host_alloc(2 << 20) for _ in range(32)]
    self.b_timeline = [0] * len(self.b)
    self.b_next = 0
    super().__init__()
@@ -347,6 +347,21 @@ class NVAllocator(LRUAllocator):
    if options.host: self.device._gpu_host_free(opaque)
    else: self.device._gpu_free(opaque)

  def copy_from_disk(self, dest, src, size):
    def _get_temp_buf():
      # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
      if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.timeline_signal[0]:
        self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
        return (self.b[self.b_next].va_addr, self.b_next)
      return None

    for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=(2 << 20)):
      HWCopyQueue().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
                   .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
                   .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
      self.b_timeline[batch_info[1]] = self.device.timeline_value
      self.device.timeline_value += 1

  def copyin(self, dest, src: memoryview):
    for i in range(0, src.nbytes, self.b[0].length):
      self.b_next = (self.b_next + 1) % len(self.b)