mirror of https://github.com/commaai/tinygrad.git
nv mockgpu (#4600)
* mockgpu nv
* works
* comment that out
* fix merge
* setup gpuocelot
* install packages
* not run all of them
* passes
* fix ci
* almost
* should pass
* linter
* linter 2
* try this?
* ugn, not supported
* ci
* remove ticket from description
* better descs
parent 3c11ca452e
commit eb9689336e
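For context, a minimal sketch of what the new nv CI lane exercises: with MOCKGPU=1 the NV backend is served by the userspace mock driver in extra/mockgpu instead of the real /dev/nvidia* nodes, and kernels run through gpuocelot. The environment mirrors the workflow change below (NV=1, MOCKGPU=1, FORWARD_ONLY=1); running this locally also assumes gpuocelot and its ptx_run library are installed.

    import os
    os.environ["NV"] = "1"           # select the NV backend as Device.DEFAULT
    os.environ["MOCKGPU"] = "1"      # route ioctl/mmap to extra.mockgpu instead of real hardware
    os.environ["FORWARD_ONLY"] = "1" # this CI lane runs forward-only

    from tinygrad import Tensor, Device  # import after the env vars are set
    assert Device.DEFAULT == "NV"
    print((Tensor([1, 2, 3, 4]) + 1).numpy())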
@@ -332,7 +332,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        backend: [llvm, clang, gpu, cuda, ptx, amd] #, triton]
+        backend: [llvm, clang, gpu, cuda, ptx, amd, nv] #, triton]

     name: Tests on (${{ matrix.backend }})
     runs-on: ubuntu-latest
@@ -356,7 +356,7 @@ jobs:
         path: ~/.cache/tinygrad/downloads/
         key: downloads-cache-${{ matrix.backend }}-${{ env.DOWNLOAD_CACHE_VERSION }}
     - name: Set env
-      run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'amd' && 'AMD=1\nMOCKGPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
+      run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'amd' && 'AMD=1\nMOCKGPU=1\nFORWARD_ONLY=1' || matrix.backend == 'nv' && 'NV=1\nMOCKGPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
     - name: Install OpenCL
       if: matrix.backend == 'gpu'
       run: |
@@ -368,14 +368,14 @@ jobs:
           intel-oneapi-runtime-dpcpp-sycl-opencl-cpu=2023.2.1-16 intel-oneapi-runtime-tbb-common=2021.10.0-49541 \
           intel-oneapi-runtime-tbb=2021.10.0-49541 intel-oneapi-runtime-opencl=2023.2.1-16
     - name: Install packages (cuda)
-      if: matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton'
+      if: matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv'
       run: |
         echo 'Acquire::http::Pipeline-Depth "5";' | sudo tee -a /etc/apt/apt.conf.d/99parallel
         sudo apt update -y || true
        sudo apt install -y --no-install-recommends git g++ cmake ninja-build llvm-15-dev zlib1g-dev libglew-dev \
          flex bison libfl-dev libboost-thread-dev libboost-filesystem-dev nvidia-cuda-toolkit-gcc libzstd-dev
     - name: Cache gpuocelot
-      if: matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton'
+      if: matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv'
       id: cache-build
       uses: actions/cache@v4
       env:
@@ -384,7 +384,7 @@ jobs:
         path: ${{ github.workspace }}/gpuocelot/ocelot
         key: ubuntu22.04-gpuocelot-18401f4245b27ca4b3af433196583cc81ef84480-rebuild-5
     - name: Clone/compile gpuocelot
-      if: (matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton') && steps.cache-build.outputs.cache-hit != 'true'
+      if: (matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv') && steps.cache-build.outputs.cache-hit != 'true'
       run: |
         git clone --recurse-submodules https://github.com/gpuocelot/gpuocelot.git ${{ github.workspace }}/gpuocelot
         cd ${{ github.workspace }}/gpuocelot/ocelot
@@ -394,7 +394,7 @@ jobs:
         cmake .. -Wno-dev -G Ninja -DOCELOT_BUILD_TOOLS=OFF -DCMAKE_BUILD_ALWAYS=0 -DBUILD_TESTS_CUDA=OFF
         ninja
     - name: Install gpuocelot
-      if: matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton'
+      if: matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv'
       run: |
         cd ${{ github.workspace }}/gpuocelot/ocelot/build
         sudo ninja install -d explain
@@ -416,7 +416,7 @@ jobs:
       run: pip install -e '.[testing${{matrix.backend=='llvm'&&',llvm'||matrix.backend=='cuda'&&',cuda'||matrix.backend=='ptx'&&',cuda'||matrix.backend=='triton'&&',triton'||''}}]' --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
     - name: Check Device.DEFAULT and print some source
       run: |
-        PYTHONPATH=${{ github.workspace }} python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU','AMD'], Device.DEFAULT"
+        PYTHONPATH=${{ github.workspace }} python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU','AMD','NV'], Device.DEFAULT"
         DEBUG=5 PYTHONPATH=${{ github.workspace }} FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
     - name: Verify OpenCL autogen
       if: matrix.backend == 'gpu'
@@ -443,13 +443,13 @@ jobs:
         diff /tmp/hsa.py.bak tinygrad/runtime/autogen/hsa.py
         diff /tmp/comgr.py.bak tinygrad/runtime/autogen/comgr.py
     - name: Run pytest (not cuda or amd)
-      if: matrix.backend!='cuda' && matrix.backend!='ptx' && matrix.backend!='triton' && matrix.backend != 'amd'
+      if: matrix.backend!='cuda' && matrix.backend!='ptx' && matrix.backend!='triton' && matrix.backend != 'amd' && matrix.backend != 'nv'
       run: python -m pytest -n=auto test/ --durations=20
     - name: Run ONNX (only LLVM)
       if: matrix.backend == 'llvm'
       run: python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
     - name: Run pytest (cuda)
-      if: matrix.backend=='cuda'||matrix.backend=='ptx'||matrix.backend=='triton'
+      if: matrix.backend=='cuda'||matrix.backend=='ptx'||matrix.backend=='triton'||matrix.backend=='nv'
       run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors)' --ignore=test/external --ignore=test/models --durations=20
     - name: Run pytest (amd)
       if: matrix.backend=='amd'
@@ -19,9 +19,11 @@ WAIT_REG_MEM_FUNCTION_ALWAYS = 0
 WAIT_REG_MEM_FUNCTION_EQ = 3 # ==
 WAIT_REG_MEM_FUNCTION_GEQ = 5 # >=

-remu = ctypes.CDLL("/usr/local/lib/libremu.so")
-remu.run_asm.restype = ctypes.c_uint32
-remu.run_asm.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_void_p]
+try:
+  remu = ctypes.CDLL("/usr/local/lib/libremu.so")
+  remu.run_asm.restype = ctypes.c_uint32
+  remu.run_asm.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_void_p]
+except Exception: pass

 def create_sdma_packets():
   # TODO: clean up this, if we want to keep it
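A note on the change above: wrapping the libremu.so load in try/except means importing the AMD mock no longer hard-fails on machines without the library (the new NV lane presumably imports extra.mockgpu too but never calls run_asm). A minimal illustration of the guarded-load pattern:

    import ctypes
    try:
      remu = ctypes.CDLL("/usr/local/lib/libremu.so")  # may be absent on non-AMD machines
    except Exception:
      remu = None  # a later remu.run_asm(...) would fail, but the import itself succeeds
    print("remu available:", remu is not None)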
@@ -93,4 +93,4 @@ class VirtDriver:
     self.tracked_files = []
     self.tracked_addresses = []
   def track_address(self, staddr, enaddr, rcb, wcb): self.tracked_addresses.append((staddr, enaddr, rcb, wcb))
-  def open(self, name, flags, mode): raise NotImplementedError()
+  def open(self, name, flags, mode, fdcls): raise NotImplementedError()
@@ -1,4 +1,5 @@
 import ctypes, ctypes.util, struct, platform, pathlib, re, time, os, builtins, atexit
+from extra.mockgpu.nv.nvdriver import NVDriver
 from extra.mockgpu.amd.amddriver import AMDDriver
 from tinygrad.helpers import from_mv, to_mv
 start = time.perf_counter()
@@ -51,7 +52,7 @@ def install_hook(c_function, python_function):
   def __restore(): libc.memcpy(ioctl_address.contents, original_bc, len(tramp))
   atexit.register(__restore)

-drivers = [AMDDriver()]
+drivers = [AMDDriver(), NVDriver()]
 tracked_fds = {}

 @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_ulong)
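The hook above patches libc entry points and now consults both drivers. A standalone sketch of the dispatch idea, with toy types in place of the real extra.mockgpu classes (the field names here are illustrative, only fdcls is taken from the source):

    import collections
    VirtFile = collections.namedtuple('VirtFile', ['path', 'fdcls'])

    class ToyDriver:
      def __init__(self, path): self.tracked_files = [VirtFile(path, fdcls=lambda fd: fd)]
      def open(self, name, flags, mode, virtfile): return virtfile.fdcls(1 << 30)

    drivers = [ToyDriver('/dev/nvidiactl'), ToyDriver('/dev/kfd')]

    def hooked_open(name, flags, mode):
      for d in drivers:
        for vf in d.tracked_files:
          if name == vf.path: return d.open(name, flags, mode, vf)  # virtual fd, answered in Python
      return None  # not a tracked path: fall through to the real libc call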
@@ -0,0 +1,217 @@
+import pathlib, re, ctypes, ctypes.util, mmap, collections, struct, functools, os, copy
+import tinygrad.runtime.autogen.nv_gpu as nv_gpu
+from typing import Optional, Any
+from tinygrad.helpers import from_mv
+from extra.mockgpu.driver import VirtDriver, VirtFileDesc, TextFileDesc, DirFileDesc, VirtFile
+from extra.mockgpu.nv.nvgpu import NVGPU
+
+MAP_FIXED = 0x10
+libc = ctypes.CDLL(ctypes.util.find_library("c"))
+libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
+libc.mmap.restype = ctypes.c_void_p
+libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
+libc.munmap.restype = ctypes.c_int
+
+NVSubDevice = collections.namedtuple('NVSubDevice', ['device'])
+NVUserMode = collections.namedtuple('NVUserMode', ['subdevice'])
+NVVASpace = collections.namedtuple('NVVASpace', ['device'])
+NVAllocation = collections.namedtuple('NVAllocation', ['device', 'size'])
+NVChannelGroup = collections.namedtuple('NVChannelGroup', ['device'])
+NVContextShare = collections.namedtuple('NVContextShare', ['channel_group'])
+NVGPFIFO = collections.namedtuple('NVGPFIFO', ['device', 'token'])
+
+class NVCtlFileDesc(VirtFileDesc):
+  def __init__(self, fd, driver):
+    super().__init__(fd)
+    self.driver = driver
+
+  def ioctl(self, fd, request, argp): return self.driver.ctl_ioctl(request, argp)
+  def mmap(self, start, sz, prot, flags, fd, offset): return libc.mmap(start, sz, prot, flags|mmap.MAP_ANONYMOUS, -1, 0)
+
+class NVUVMFileDesc(VirtFileDesc):
+  def __init__(self, fd, driver):
+    super().__init__(fd)
+    self.driver = driver
+
+  def ioctl(self, fd, request, argp): return self.driver.uvm_ioctl(request, argp)
+  def mmap(self, start, sz, prot, flags, fd, offset): return libc.mmap(start, sz, prot, flags|mmap.MAP_ANONYMOUS, -1, 0)
+
+class NVDevFileDesc(VirtFileDesc):
+  def __init__(self, fd, driver, gpu):
+    super().__init__(fd)
+    self.driver, self.gpu = driver, gpu
+    self._mapping_userland = False
+
+  def ioctl(self, fd, request, argp): return self.driver.dev_ioctl(self.gpu, request, argp)
+  def mmap(self, start, sz, prot, flags, fd, offset):
+    start = libc.mmap(start, sz, prot, flags|mmap.MAP_ANONYMOUS, -1, 0)
+    if self._mapping_userland: self.driver.track_address(start, start+sz, lambda mv,off: None, lambda mv, off: self.driver._gpu_mmio_write(mv, off, self.gpu))
+    return start
+
+class NVDriver(VirtDriver):
+  def __init__(self, gpus=6):
+    super().__init__()
+
+    self.tracked_files += [VirtFile('/dev/nvidiactl', functools.partial(NVCtlFileDesc, driver=self)),
+                           VirtFile('/dev/nvidia-uvm', functools.partial(NVUVMFileDesc, driver=self))]
+
+    self.root_handle = None
+
+    self.gpus = {}
+    self.next_fd = (1 << 30)
+    self.next_handle = 1
+
+    self.object_by_handle = {}
+    self.opened_fds = {}
+    self.next_doorbell = collections.defaultdict(int)
+
+    for i in range(gpus): self._prepare_gpu(i)
+
+  def _alloc_fd(self):
+    my_fd = self.next_fd
+    self.next_fd = self.next_fd + 1
+    return my_fd
+
+  def _alloc_handle(self):
+    handle = self.next_handle
+    self.next_handle += 1
+    return handle
+
+  def _prepare_gpu(self, gpu_id):
+    self.gpus[gpu_id] = NVGPU(gpu_id)
+    self.tracked_files += [VirtFile(f'/dev/nvidia{gpu_id}', functools.partial(NVDevFileDesc, driver=self, gpu=self.gpus[gpu_id]))]
+
+  def open(self, name, flags, mode, virtfile):
+    cl = virtfile.fdcls(self._alloc_fd())
+    self.opened_fds[cl.fd] = cl
+    return cl
+
+  def rm_alloc(self, argp):
+    struct = nv_gpu.NVOS21_PARAMETERS.from_address(argp)
+    params_ptr = struct.pAllocParms if struct.pAllocParms else None
+    if struct.hClass == nv_gpu.NV01_ROOT_CLIENT: self.root_handle = struct.hObjectNew = self._alloc_handle()
+    elif struct.hClass == nv_gpu.NV01_DEVICE_0:
+      params:Any = nv_gpu.NV0080_ALLOC_PARAMETERS.from_address(params_ptr)
+      assert params.hClientShare == self.root_handle
+      struct.hObjectNew = self._alloc_handle()
+      self.object_by_handle[struct.hObjectNew] = self.gpus[params.deviceId]
+    elif struct.hClass == nv_gpu.NV20_SUBDEVICE_0:
+      assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVGPU)
+      struct.hObjectNew = self._alloc_handle()
+      self.object_by_handle[struct.hObjectNew] = NVSubDevice(self.object_by_handle[struct.hObjectParent])
+    elif struct.hClass == nv_gpu.TURING_USERMODE_A:
+      assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVSubDevice)
+      struct.hObjectNew = self._alloc_handle()
+      self.object_by_handle[struct.hObjectNew] = NVUserMode(self.object_by_handle[struct.hObjectParent])
+    elif struct.hClass == nv_gpu.FERMI_VASPACE_A:
+      assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVGPU)
+      struct.hObjectNew = self._alloc_handle()
+      self.object_by_handle[struct.hObjectNew] = NVVASpace(self.object_by_handle[struct.hObjectParent])
+    elif struct.hClass == nv_gpu.NV1_MEMORY_SYSTEM or struct.hClass == nv_gpu.NV1_MEMORY_USER:
+      assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVGPU)
+      params = nv_gpu.NV_MEMORY_ALLOCATION_PARAMS.from_address(params_ptr)
+      struct.hObjectNew = self._alloc_handle()
+      self.object_by_handle[struct.hObjectNew] = NVAllocation(self.object_by_handle[struct.hObjectParent], params.size)
+    elif struct.hClass == nv_gpu.KEPLER_CHANNEL_GROUP_A:
+      assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVGPU)
+      struct.hObjectNew = self._alloc_handle()
+      self.object_by_handle[struct.hObjectNew] = NVChannelGroup(self.object_by_handle[struct.hObjectParent])
+    elif struct.hClass == nv_gpu.FERMI_CONTEXT_SHARE_A:
+      assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVChannelGroup)
+      struct.hObjectNew = self._alloc_handle()
+      self.object_by_handle[struct.hObjectNew] = NVContextShare(self.object_by_handle[struct.hObjectParent])
+    elif struct.hClass == nv_gpu.AMPERE_CHANNEL_GPFIFO_A:
+      assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVChannelGroup)
+      struct.hObjectNew = self._alloc_handle()
+      params = nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS.from_address(params_ptr)
+      gpu = self.object_by_handle[struct.hObjectParent].device
+      gpfifo_token = gpu.add_gpfifo(params.gpFifoOffset, params.gpFifoEntries)
+      self.object_by_handle[struct.hObjectNew] = NVGPFIFO(gpu, gpfifo_token)
+    elif struct.hClass == nv_gpu.AMPERE_DMA_COPY_B or struct.hClass == nv_gpu.ADA_COMPUTE_A:
+      assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVGPFIFO)
+    else: raise RuntimeError(f"Unknown {struct.hClass} to rm_alloc")
+    return 0
+
+  def rm_control(self, argp):
+    struct = nv_gpu.NVOS54_PARAMETERS.from_address(argp)
+    params_ptr = struct.params if struct.params else None
+    if struct.cmd == nv_gpu.NV0000_CTRL_CMD_GPU_GET_ID_INFO_V2:
+      params:Any = nv_gpu.NV0000_CTRL_GPU_GET_ID_INFO_V2_PARAMS.from_address(params_ptr)
+      params.deviceInstance = params.gpuId # emulate them to be the same
+    elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_GPU_GET_GID_INFO:
+      assert struct.hObject in self.object_by_handle and isinstance(self.object_by_handle[struct.hObject], NVSubDevice)
+      gpu = self.object_by_handle[struct.hObject].device
+      params = nv_gpu.NV2080_CTRL_GPU_GET_GID_INFO_PARAMS.from_address(params_ptr)
+      if params.flags != nv_gpu.NV2080_GPU_CMD_GPU_GET_GID_FLAGS_FORMAT_BINARY: raise RuntimeError("Unknown format")
+      bts = gpu.gpu_uuid(sz=params.length)
+      for i in range(params.length): params.data[i] = bts[i]
+    elif struct.cmd == nv_gpu.NVC36F_CTRL_CMD_GPFIFO_GET_WORK_SUBMIT_TOKEN:
+      assert struct.hObject in self.object_by_handle and isinstance(self.object_by_handle[struct.hObject], NVGPFIFO)
+      params = nv_gpu.NVC36F_CTRL_CMD_GPFIFO_GET_WORK_SUBMIT_TOKEN_PARAMS.from_address(params_ptr)
+      gpu_fifo = self.object_by_handle[struct.hObject]
+      params.workSubmitToken = gpu_fifo.token
+    elif struct.cmd == nv_gpu.NVA06C_CTRL_CMD_GPFIFO_SCHEDULE: pass
+    elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_PERF_BOOST: pass
+    else: raise RuntimeError(f"Unknown {struct.cmd} to rm_control")
+    return 0
+
+  def ctl_ioctl(self, req, argp):
+    nr = req & 0xff
+    if nr == nv_gpu.NV_ESC_RM_ALLOC: return self.rm_alloc(argp)
+    elif nr == nv_gpu.NV_ESC_RM_ALLOC_MEMORY: pass
+    elif nr == nv_gpu.NV_ESC_RM_CONTROL: return self.rm_control(argp)
+    elif nr == nv_gpu.NV_ESC_RM_MAP_MEMORY:
+      st:Any = nv_gpu.nv_ioctl_nvos33_parameters_with_fd.from_address(argp)
+      obj = self.object_by_handle[st.params.hMemory]
+      if isinstance(obj, NVUserMode):
+        file = self.opened_fds[st.fd]
+        assert isinstance(file, NVDevFileDesc)
+        file._mapping_userland = True
+    elif nr == nv_gpu.NV_ESC_RM_FREE:
+      st = nv_gpu.NVOS00_PARAMETERS.from_address(argp)
+      self.object_by_handle.pop(st.hObjectOld)
+    elif nr == nv_gpu.NV_ESC_CARD_INFO:
+      for i,gpu in enumerate(self.gpus.values()):
+        st = nv_gpu.nv_ioctl_card_info_t.from_address(argp + i * ctypes.sizeof(nv_gpu.nv_ioctl_card_info_t))
+        st.gpu_id = gpu.gpuid
+        st.pci_info.device_id = 0x2684
+        st.valid = True
+    else: raise RuntimeError(f"Unknown {nr} to nvidiactl")
+    return 0
+  def uvm_ioctl(self, nr, argp):
+    if nr == nv_gpu.UVM_INITIALIZE: pass
+    elif nr == nv_gpu.UVM_MM_INITIALIZE: pass
+    elif nr == nv_gpu.UVM_REGISTER_GPU:
+      st:Any = nv_gpu.UVM_REGISTER_GPU_PARAMS.from_address(argp)
+      assert any(all(st.gpu_uuid.uuid[i] == gpu.gpu_uuid()[i] for i in range(16)) for gpu in self.gpus.values())
+    elif nr == nv_gpu.UVM_REGISTER_GPU_VASPACE: pass
+    elif nr == nv_gpu.UVM_ENABLE_PEER_ACCESS: pass # uvm and shared space are set up already, no emulation for now
+    elif nr == nv_gpu.UVM_CREATE_EXTERNAL_RANGE:
+      st = nv_gpu.UVM_CREATE_EXTERNAL_RANGE_PARAMS.from_address(argp)
+      libc.mmap(st.base, st.length, mmap.PROT_READ|mmap.PROT_WRITE, MAP_FIXED|mmap.MAP_SHARED|mmap.MAP_ANONYMOUS, -1, 0)
+    elif nr == nv_gpu.UVM_MAP_EXTERNAL_ALLOCATION:
+      st = nv_gpu.UVM_MAP_EXTERNAL_ALLOCATION_PARAMS.from_address(argp)
+      for gpu_attr_id in range(st.gpuAttributesCount):
+        gpu = None
+        for _gpu in self.gpus.values():
+          if all(st.perGpuAttributes[gpu_attr_id].gpuUuid.uuid[i] == _gpu.gpu_uuid()[i] for i in range(16)):
+            gpu = _gpu
+            break
+        if gpu is None: return -1
+        gpu.map_range(st.base, st.length)
+    elif nr == nv_gpu.UVM_REGISTER_CHANNEL: pass
+    elif nr == nv_gpu.UVM_FREE:
+      st = nv_gpu.UVM_FREE_PARAMS.from_address(argp)
+      libc.munmap(st.base, st.length)
+    else: raise RuntimeError(f"Unknown {nr} to nvidia-uvm")
+    return 0
+
+  def dev_ioctl(self, dev, req, argp): return 0
+  def _gpu_mmio_write(self, mv, off, gpu):
+    any_progress = True
+    while any_progress:
+      any_progress = False
+      for gpu in self.gpus.values():
+        for q in gpu.queues:
+          if (prev_rptr:=q.ctrl.GPGet) != q.ctrl.GPPut:
+            any_progress |= q.execute()
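The heart of rm_alloc above is plain handle bookkeeping: every RM allocation returns a fresh handle, and child objects (subdevice, usermode, channel group, GPFIFO, ...) are validated against their parent's handle before being recorded. A standalone toy version of that table:

    import collections
    NVSubDevice = collections.namedtuple('NVSubDevice', ['device'])

    class HandleTable:
      def __init__(self): self.next_handle, self.object_by_handle = 1, {}
      def alloc(self, obj, parent=None):
        if parent is not None: assert parent in self.object_by_handle  # parent must already exist
        handle, self.next_handle = self.next_handle, self.next_handle + 1
        self.object_by_handle[handle] = obj
        return handle

    tbl = HandleTable()
    root = tbl.alloc("client")            # NV01_ROOT_CLIENT
    dev = tbl.alloc("gpu0", parent=root)  # NV01_DEVICE_0
    sub = tbl.alloc(NVSubDevice(tbl.object_by_handle[dev]), parent=dev)  # NV20_SUBDEVICE_0
    print(tbl.object_by_handle)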
@@ -0,0 +1,165 @@
+import ctypes, ctypes.util, time
+import tinygrad.runtime.autogen.nv_gpu as nv_gpu
+from enum import Enum, auto
+from extra.mockgpu.gpu import VirtGPU
+from tinygrad.helpers import to_mv, init_c_struct_t
+
+def make_qmd_struct_type():
+  fields = []
+  bits = [(name,dt) for name,dt in nv_gpu.__dict__.items() if name.startswith("NVC6C0_QMDV03_00") and isinstance(dt, tuple)]
+  bits += [(name+f"_{i}",dt(i)) for name,dt in nv_gpu.__dict__.items() for i in range(8) if name.startswith("NVC6C0_QMDV03_00") and callable(dt)]
+  bits = sorted(bits, key=lambda x: x[1][1])
+  for i,(name, data) in enumerate(bits):
+    if i > 0 and (gap:=(data[1] - bits[i-1][1][0] - 1)) != 0: fields.append((f"_reserved{i}", ctypes.c_uint32, gap))
+    fields.append((name.replace("NVC6C0_QMDV03_00_", "").lower(), ctypes.c_uint32, data[0]-data[1]+1))
+  return init_c_struct_t(tuple(fields))
+qmd_struct_t = make_qmd_struct_type()
+assert ctypes.sizeof(qmd_struct_t) == 0x40 * 4
+
+try:
+  gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
+  gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] # noqa: E501
+except Exception: pass
+
+class SchedResult(Enum): CONT = auto(); YIELD = auto() # noqa: E702
+
+class GPFIFO:
+  def __init__(self, token, base, entries_cnt):
+    self.token, self.base, self.entries_cnt = token, base, entries_cnt
+    self.gpfifo = to_mv(self.base, self.entries_cnt * 8).cast("Q")
+    self.ctrl = nv_gpu.AmpereAControlGPFifo.from_address(self.base + self.entries_cnt * 8)
+    self.state = {}
+
+    # Buf exec state
+    self.buf = None
+    self.buf_sz = 0
+    self.buf_ptr = 0
+
+  def _next_dword(self):
+    assert self.buf is not None
+    x = self.buf[self.buf_ptr]
+    self.buf_ptr += 1
+    return x
+
+  def _next_header(self):
+    header = self._next_dword()
+    typ = (header >> 28) & 0b111
+    size = (header >> 16) & 0xFFF
+    subc = (header >> 13) & 0x7
+    mthd = (header & 0x1FFF) << 2
+    return typ, size, subc, mthd
+
+  def _state(self, reg): return self.state[reg]
+  def _state64(self, reg): return (self.state[reg] << 32) + self.state[reg + 4]
+  def _state64_le(self, reg): return (self.state[reg + 4] << 32) + self.state[reg]
+
+  def _reset_buf_state(self): self.buf, self.buf_ptr = None, 0
+  def _set_buf_state(self, gpfifo_entry):
+    ptr = ((gpfifo_entry >> 2) & 0xfffffffff) << 2
+    sz = ((gpfifo_entry >> 42) & 0x1fffff) << 2
+    self.buf = to_mv(ptr, sz).cast("I")
+    self.buf_sz = sz // 4
+
+  def execute(self) -> bool:
+    initial_off = self.buf_ptr
+    while self.ctrl.GPGet != self.ctrl.GPPut:
+      self._set_buf_state(self.gpfifo[self.ctrl.GPGet])
+
+      if not self.execute_buf():
+        # Buffer isn't executed fully, check if any progress and report.
+        # Do not move GPGet in this case, will continue from the same state next time.
+        return self.buf_ptr != initial_off
+
+      self.ctrl.GPGet = (self.ctrl.GPGet + 1) % self.entries_cnt
+      self._reset_buf_state()
+    return True
+
+  def execute_buf(self) -> bool:
+    while self.buf_ptr < self.buf_sz:
+      init_off = self.buf_ptr
+      typ, size, subc, mthd = self._next_header()
+      cmd_end_off = self.buf_ptr + size
+
+      while self.buf_ptr < cmd_end_off:
+        res = self.execute_cmd(mthd)
+        if res == SchedResult.YIELD:
+          self.buf_ptr = init_off # just revert to the header
+          return False
+        mthd += 4
+    return True
+
+  def execute_qmd(self, qmd_addr):
+    qmd = qmd_struct_t.from_address(qmd_addr)
+    prg_addr = qmd.program_address_lower + (qmd.program_address_upper << 32)
+    const0 = to_mv(qmd.constant_buffer_addr_lower_0 + (qmd.constant_buffer_addr_upper_0 << 32), 0x160).cast('I')
+    args_cnt, vals_cnt = const0[0], const0[1]
+    args_addr = qmd.constant_buffer_addr_lower_0 + (qmd.constant_buffer_addr_upper_0 << 32) + 0x160
+    args = to_mv(args_addr, args_cnt*8).cast('Q')
+    vals = to_mv(args_addr + args_cnt*8, vals_cnt*4).cast('I')
+    cargs = [ctypes.cast(args[i], ctypes.c_void_p) for i in range(args_cnt)] + [ctypes.cast(vals[i], ctypes.c_void_p) for i in range(vals_cnt)]
+    gx, gy, gz = qmd.cta_raster_width, qmd.cta_raster_height, qmd.cta_raster_depth
+    lx, ly, lz = qmd.cta_thread_dimension0, qmd.cta_thread_dimension1, qmd.cta_thread_dimension2
+    gpuocelot_lib.ptx_run(ctypes.cast(prg_addr, ctypes.c_char_p), args_cnt+vals_cnt, (ctypes.c_void_p*len(cargs))(*cargs), lx, ly, lz, gx, gy, gz, 0)
+
+  def execute_cmd(self, cmd) -> SchedResult:
+    if cmd == nv_gpu.NVC56F_SEM_EXECUTE: return self._exec_signal()
+    elif cmd == nv_gpu.NVC6C0_LAUNCH_DMA: return self._exec_nvc6c0_dma()
+    elif cmd == nv_gpu.NVC6B5_LAUNCH_DMA: return self._exec_nvc6b5_dma()
+    elif cmd == 0x0320: return self._exec_load_inline_qmd() # NVC6C0_LOAD_INLINE_QMD_DATA
+    else: self.state[cmd] = self._next_dword() # just state update
+    return SchedResult.CONT
+
+  def _exec_signal(self) -> SchedResult:
+    signal = self._state64_le(nv_gpu.NVC56F_SEM_ADDR_LO)
+    val = self._state64_le(nv_gpu.NVC56F_SEM_PAYLOAD_LO)
+    flags = self._next_dword()
+    typ = (flags >> 0) & 0b111
+    if typ == 1: to_mv(signal, 8).cast('Q')[0] = val
+    elif typ == 3:
+      mval = to_mv(signal, 8).cast('Q')[0]
+      return SchedResult.CONT if mval >= val else SchedResult.YIELD
+    else: raise RuntimeError(f"Unsupported type={typ} in exec wait/signal")
+    return SchedResult.CONT
+
+  def _exec_load_inline_qmd(self):
+    qmd_addr = self._state64(nv_gpu.NVC6C0_SET_INLINE_QMD_ADDRESS_A) << 8
+    assert qmd_addr != 0x0, f"invalid qmd address {qmd_addr}"
+    qmd_data = [self._next_dword() for _ in range(0x40)]
+    cdata = (ctypes.c_uint32 * len(qmd_data))(*qmd_data)
+    ctypes.memmove(qmd_addr, cdata, 0x40 * 4)
+    self.execute_qmd(qmd_addr)
+
+  def _exec_nvc6c0_dma(self):
+    addr = self._state64(nv_gpu.NVC6C0_OFFSET_OUT_UPPER)
+    sz = self._state(nv_gpu.NVC6C0_LINE_LENGTH_IN)
+    lanes = self._state(nv_gpu.NVC6C0_LINE_COUNT)
+    assert lanes == 1, f"unsupported lanes > 1 in _exec_nvc6c0_dma: {lanes}"
+    flags = self._next_dword()
+    assert flags == 0x41, f"unsupported flags in _exec_nvc6c0_dma: {flags}"
+    typ, dsize, subc, mthd = self._next_header()
+    assert typ == 6 and mthd == nv_gpu.NVC6C0_LOAD_INLINE_DATA, f"Expected inline data not found after nvc6c0_dma, {typ=} {mthd=}"
+    copy_data = [self._next_dword() for _ in range(dsize)]
+    assert len(copy_data) * 4 == sz, f"different copy sizes in _exec_nvc6c0_dma: {len(copy_data) * 4} != {sz}"
+    cdata = (ctypes.c_uint32 * len(copy_data))(*copy_data)
+    ctypes.memmove(addr, cdata, sz)
+
+  def _exec_nvc6b5_dma(self):
+    src = self._state64(nv_gpu.NVC6B5_OFFSET_IN_UPPER)
+    dst = self._state64(nv_gpu.NVC6B5_OFFSET_OUT_UPPER)
+    sz = self._state(nv_gpu.NVC6B5_LINE_LENGTH_IN)
+    flags = self._next_dword()
+    assert flags == 0x182, f"unsupported flags in _exec_nvc6b5_dma: {flags}"
+    ctypes.memmove(dst, src, sz)
+
+class NVGPU(VirtGPU):
+  def __init__(self, gpuid):
+    super().__init__(gpuid)
+    self.mapped_ranges = set()
+    self.queues = []
+
+  def map_range(self, vaddr, size): self.mapped_ranges.add((vaddr, size))
+  def unmap_range(self, vaddr, size): self.mapped_ranges.remove((vaddr, size))
+  def add_gpfifo(self, base, entries_count):
+    self.queues.append(GPFIFO(token:=len(self.queues), base, entries_count))
+    return token
+  def gpu_uuid(self, sz=16): return self.gpuid.to_bytes(sz, byteorder='big', signed=False)
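A worked example of the pushbuffer decoding above: _next_header splits one 32-bit dword into type (bits 28-30), dword count (bits 16-27), subchannel (bits 13-15), and a method address (bits 0-12, shifted left by 2 to get bytes). For a hypothetical header dword 0x20014048:

    header = 0x20014048              # hypothetical pushbuffer header dword
    typ  = (header >> 28) & 0b111    # 2
    size = (header >> 16) & 0xFFF    # 1 data dword follows
    subc = (header >> 13) & 0x7      # subchannel 2
    mthd = (header & 0x1FFF) << 2    # method address 0x120
    assert (typ, size, subc, mthd) == (2, 1, 2, 0x120)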
@@ -26,7 +26,7 @@ def assert_jit_cache_len(fxn, expected_len):
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
   if dtype == dtypes.bfloat16:
     # NOTE: this requires bf16 buffer support
-    return device in {"HSA", "AMD"} or (device == "CUDA" and not CI and not getenv("PTX"))
+    return device in {"HSA", "AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
   if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
   if device == "CUDA" and getenv("PTX") and dtype in (dtypes.int8, dtypes.uint8): return False
   # for CI GPU and OSX, cl_khr_fp16 isn't supported
@@ -35,7 +35,7 @@ def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
   # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
   if dtype == dtypes.half:
     if device == "GPU": return not CI and not OSX
-    if device in ["LLVM", "CUDA"]: return not CI
+    if device in ["LLVM", "CUDA", "NV"]: return not CI
     if device == "PYTHON": return sys.version_info >= (3, 12)
   if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
   return True
@@ -200,7 +200,7 @@ class TestFloatDType(TestDType):

 class TestDoubleDtype(TestDType):
   DTYPE = dtypes.double
-  @unittest.skipIf(getenv("CUDACPU") or getenv("PTX"), "conversion not supported on CUDACPU and PTX") # TODO: why not?
+  @unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or getenv("PTX"), "conversion not supported on CUDACPU and PTX") # TODO: why not?
   def test_float64_increased_precision(self):
     for func in [
       lambda t: t.exp(),
@@ -40,7 +40,7 @@ unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (T

 # TODO: CUDACPU segfaults on sin
 # TODO: METAL sin is flaky for float16
-if getenv("CUDACPU") or Device.DEFAULT == "METAL": unary_operations.remove((Tensor.sin, np.sin))
+if getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV") or Device.DEFAULT == "METAL": unary_operations.remove((Tensor.sin, np.sin))

 class ht:
   float64 = strat.floats(width=64, allow_subnormal=False)
@@ -145,7 +145,7 @@ class TestDTypeALU(unittest.TestCase):
   def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)

   # Metal and CUDACPU and HIP behave differently than numpy in CI for overflows
-  skip_overflow = CI and (Device.DEFAULT in {"HSA", "AMD"} or getenv("CUDACPU"))
+  skip_overflow = CI and (Device.DEFAULT in {"HSA", "AMD", "NV"} or getenv("CUDACPU"))
   @given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
          strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
          ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
@@ -287,7 +287,7 @@ class TestJit(unittest.TestCase):
     for i in range(5):
       np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))

-  @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "HSA"}, "no GPU CI")
+  @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "HSA", "NV", "AMD"}, "no GPU CI")
   def test_jitted_transfers(self):
     d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"

@@ -852,7 +852,7 @@ class TestKernelOpts(unittest.TestCase):
     ], apply_tc=True, atol=atol, rtol=rtol)

   def test_padto_matmul(self):
-    if CI and Device.DEFAULT in ["CUDA", "AMD"]: self.skipTest("super slow on CUDA and AMD because of the big grid dims")
+    if CI and Device.DEFAULT in ["CUDA", "AMD", "NV"]: self.skipTest("super slow on CUDA and AMD because of the big grid dims")
     N = 17 * 17
     Tensor.manual_seed(289)
     a = Tensor.rand(N, N)
@@ -32,7 +32,7 @@ def helper_test_lin(lin: Linearizer, opts, failed_platforms, rtol=1e-2, atol=1e-
   else:
     assert Device.DEFAULT in failed_platforms, f"failed on {Device.DEFAULT} with {compare_result[0]}"

-@unittest.skipIf(CI and Device.DEFAULT=="CUDA", "failed on CUDA CI")
+@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "failed on CUDA CI")
 class TestLinearizerFailures(unittest.TestCase):
   def setUp(self):
     random.seed(42)
@@ -148,6 +148,7 @@ class TestMultiTensor(unittest.TestCase):
     a,b = _test_allreduce(Tensor.rand(256, 256))
     np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

+  @unittest.skipIf(Device.DEFAULT in {"NV", "AMD"}, "not supported in HCQ")
   def test_copy_jit(self):
     @TinyJit
     def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1)
@@ -292,6 +293,7 @@ class TestMultiTensor(unittest.TestCase):
     y_shard = layer_norm_sharded(x_sharded).realize()
     np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

+  @unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
   def test_data_parallel_resnet(self):
     import sys, pathlib
     sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
@@ -310,6 +312,7 @@ class TestMultiTensor(unittest.TestCase):
     shard_output_np = shard_output.numpy()
     np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)

+  @unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
   def test_data_parallel_resnet_train_step(self):
     import sys, pathlib
     sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
@@ -5,7 +5,7 @@ import torch
 from tinygrad import Tensor, Device
 from tinygrad.helpers import Profiling, CI

-@unittest.skipIf(CI and Device.DEFAULT == "CUDA", "slow")
+@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
 class TestConvSpeed(unittest.TestCase):

   def test_mnist(self):
@@ -10,7 +10,7 @@ from tinygrad.nn.state import load_state_dict
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule

-@unittest.skipIf(CI and Device.DEFAULT == "CUDA", "slow")
+@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
 class TestNN(unittest.TestCase):
   @unittest.skipIf(Device.DEFAULT == "WEBGPU", "no int64 on WebGPU")
   def test_sparse_cat_cross_entropy(self):
@@ -74,7 +74,7 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
 class TestOps(unittest.TestCase):

   def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, exact=False, vals=None, low=-1.5, high=1.5):
-    if getenv("CUDACPU"): self.skipTest('helper_test_exception fails in CUDACPU')
+    if getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV"): self.skipTest('helper_test_exception fails in CUDACPU')
     ts, tst = prepare_test_op(low, high, shps, vals)
     with self.assertRaises(expected) as torch_cm:
       torch_fxn(*ts)
@@ -1493,7 +1493,7 @@ class TestOps(unittest.TestCase):
                      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation),
                      lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation))

-  @unittest.skipIf(Device.DEFAULT == "CUDA", "CUDA fails on this")
+  @unittest.skipIf(Device.DEFAULT in {"CUDA", "NV"}, "CUDA fails on this")
   def test_maxpool2d_unit_stride(self):
     helper_test_op([(8, 2, 17, 14)],
                    lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1),
@@ -41,7 +41,7 @@ def step(tensor, optim, steps=1, teeny=False, **kwargs):
     optim.step()
   return net.x.detach().numpy(), net.W.detach().numpy()

-@unittest.skipIf(CI and Device.DEFAULT == "CUDA", "slow")
+@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
 class TestOptim(unittest.TestCase):
   def setUp(self):
     self.old_training = Tensor.training
@@ -11,6 +11,7 @@ from tinygrad.helpers import Context
 from tinygrad.engine.realize import capturing

 class TestTimeLinearizer(unittest.TestCase):
+  @unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
   def test_reasonable_time(self):
     si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
     out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
@@ -19,6 +20,7 @@ class TestTimeLinearizer(unittest.TestCase):
     tm = time_linearizer(Linearizer(*si.ast), rawbufs, allow_test_size=False, cnt=10)
     assert tm > 0 and tm != float('inf')

+  @unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
   def test_bufs_from_lin(self):
     si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
     rawbufs = bufs_from_lin(lin:=Linearizer(*si.ast))
@@ -28,6 +30,7 @@ class TestTimeLinearizer(unittest.TestCase):
     assert all(r.size > 0 for r in rawbufs)

 class TestBEAM(unittest.TestCase):
+  @unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
   def test_dynamic_beam(self):
     # TODO: make this infra globally usable
     class Capture:
@@ -4,7 +4,7 @@ from tinygrad import Tensor, Device, dtypes
 from test.helpers import is_dtype_supported
 # similar to test/external/external_test_gpu_ast.py, but universal

-@unittest.skipIf(Device.DEFAULT == "CUDA" and CI, "slow on CUDA CI")
+@unittest.skipIf(Device.DEFAULT in {"CUDA", "NV"} and CI, "slow on CUDA CI")
 class TestSpecific(unittest.TestCase):
   # from openpilot

@@ -117,7 +117,7 @@ def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_
   helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:{kernel_size}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))

 @unittest.skipIf(getenv("BIG") == 0, "no big tests")
-@unittest.skipIf(getenv("CUDACPU"), "no CUDACPU")
+@unittest.skipIf(getenv("CUDACPU") or getenv("MOCKGPU"), "no CUDACPU or MOCKGPUs")
 class TestBigSpeed(unittest.TestCase):
   def test_add(self):
     def f(a, b): return a+b
@@ -138,7 +138,7 @@ class TestBigSpeed(unittest.TestCase):
   def test_matvec_16384_4096(self): helper_test_matvec('matvec_16384_4096', 16384, 4096)

 @unittest.skipIf(getenv("BIG") == 1, "only big tests")
-@unittest.skipIf(getenv("CUDACPU"), "no CUDACPU")
+@unittest.skipIf(getenv("CUDACPU") or getenv("MOCKGPU"), "no CUDACPU or MOCKGPUs")
 class TestSpeed(unittest.TestCase):
   def test_sub(self):
     def f(a, b): return a-b
@@ -360,7 +360,7 @@ class TestTinygrad(unittest.TestCase):
     c = (a + b).mean().backward()
     print(c)

-@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
+@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "NV", "AMD"}, "no GPU CI")
 class TestMoveTensor(unittest.TestCase):
   d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
   @given(strat.sampled_from([d0, d1]), strat.sampled_from([d0, d1]),
@@ -4,18 +4,22 @@ from typing import Tuple, List, Any, cast
 from tinygrad.device import Compiled, Compiler, LRUAllocator, BufferOptions
 from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes
+from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes, CUDACompiler
 import tinygrad.runtime.autogen.cuda as cuda
 import tinygrad.runtime.autogen.nv_gpu as nv_gpu
 if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401

-libc = ctypes.CDLL("libc.so.6")
-libc.memset.argtypes = [ctypes.c_void_p, ctypes.c_char, ctypes.c_int]
+libc = ctypes.CDLL(ctypes.util.find_library("c"))
 libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
 libc.mmap.restype = ctypes.c_void_p
 libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
 libc.munmap.restype = ctypes.c_int

+if MOCKGPU:=getenv("MOCKGPU"):
+  import extra.mockgpu.mockgpu # noqa: F401
+  libc.mmap = extra.mockgpu.mockgpu._mmap # type: ignore
+  libc.munmap = extra.mockgpu.mockgpu._munmap # type: ignore
+
 def nv_iowr(fd, nr, args):
   ret = fcntl.ioctl(fd, (3 << 30) | (ctypes.sizeof(args) & 0x1FFF) << 16 | (ord('F') & 0xFF) << 8 | (nr & 0xFF), args)
   if ret != 0: raise RuntimeError(f"ioctl returned {ret}")
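The MOCKGPU block above works by plain attribute rebinding: a function looked up on a ctypes.CDLL object is just an attribute, so assigning a Python callable makes every later libc.mmap(...) call in this module hit the mock. A minimal, self-contained illustration of the pattern:

    import ctypes, ctypes.util
    libc = ctypes.CDLL(ctypes.util.find_library("c"))

    def _mock_mmap(start, sz, prot, flags, fd, offset):
      print(f"mmap intercepted: {sz:#x} bytes")
      return 0

    libc.mmap = _mock_mmap          # later libc.mmap calls in this module go to Python
    libc.mmap(None, 0x1000, 0, 0, -1, 0)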
@@ -161,25 +165,29 @@ class NVProgram:
       print(subprocess.check_output(["nvdisasm", fn+".cubin"]).decode('utf-8'))
     except Exception as e: print("failed to disasm cubin", str(e))

-    _phoff, _shoff, _flags, _ehsize, _phentsize, _phnum, _shentsize, _shnum, _shstrndx = struct.unpack_from("<QQIHHHHHH", self.lib, 0x20)
-    sections = [struct.unpack_from("<IIQQQQIIQ", self.lib, _shoff + i * _shentsize) for i in range(_shnum)]
-    shstrtab = memoryview(bytearray(self.lib[sections[_shstrndx][4]:sections[_shstrndx][4]+sections[_shstrndx][5]]))
-
     self.shmem_usage = 0
     constant_buffers_data = {}
-    for sh_name, sh_type, sh_flags, _, sh_offset, sh_size, _, sh_info, _ in sections:
-      section_name = shstrtab[sh_name:].tobytes().split(b'\0', 1)[0].decode('utf-8')
-      if sh_type == SHT_NOBITS and sh_flags & SHF_ALLOC: self.shmem_usage = sh_size
-      elif sh_type == SHT_PROGBITS and sh_flags & SHF_ALLOC and sh_flags & SHF_EXECINSTR:
-        self.program = memoryview(bytearray(self.lib[sh_offset:sh_offset+sh_size])).cast("I")
-        self.registers_usage = sh_info >> 24
-      if match := re.match(r'\.nv\.constant(\d+)', section_name):
-        constant_buffers_data[int(match.group(1))] = memoryview(bytearray(self.lib[sh_offset:sh_offset+sh_size])).cast("I")
-      if section_name == ".nv.info":
-        section_data = memoryview(bytearray(self.lib[sh_offset:sh_offset+sh_size])).cast("I")
-        for i in range(sh_size // 12):
-          if section_data[i * 3 + 0] & 0xffff == 0x1204 and section_data[i * 3 + 2] + 0x240 > self.device.slm_per_thread:
-            raise RuntimeError("too high local memory")
+    if MOCKGPU:
+      self.program, self.registers_usage = memoryview(bytearray(lib) + b'\x00' * (4 - len(lib)%4)).cast("I"), 0x10
+      constant_buffers_data[0] = memoryview(bytearray(0x190))
+    else:
+      _phoff, _shoff, _flags, _ehsize, _phentsize, _phnum, _shentsize, _shnum, _shstrndx = struct.unpack_from("<QQIHHHHHH", self.lib, 0x20)
+      sections = [struct.unpack_from("<IIQQQQIIQ", self.lib, _shoff + i * _shentsize) for i in range(_shnum)]
+      shstrtab = memoryview(bytearray(self.lib[sections[_shstrndx][4]:sections[_shstrndx][4]+sections[_shstrndx][5]]))
+      for sh_name, sh_type, sh_flags, _, sh_offset, sh_size, _, sh_info, _ in sections:
+        section_name = shstrtab[sh_name:].tobytes().split(b'\0', 1)[0].decode('utf-8')
+        if sh_type == SHT_NOBITS and sh_flags & SHF_ALLOC: self.shmem_usage = sh_size
+        elif sh_type == SHT_PROGBITS and sh_flags & SHF_ALLOC and sh_flags & SHF_EXECINSTR:
+          self.program = memoryview(bytearray(self.lib[sh_offset:sh_offset+sh_size])).cast("I")
+          self.registers_usage = sh_info >> 24
+        if match := re.match(r'\.nv\.constant(\d+)', section_name):
+          constant_buffers_data[int(match.group(1))] = memoryview(bytearray(self.lib[sh_offset:sh_offset+sh_size])).cast("I")
+        if section_name == ".nv.info":
+          section_data = memoryview(bytearray(self.lib[sh_offset:sh_offset+sh_size])).cast("I")
+          for i in range(sh_size // 12):
+            if section_data[i * 3 + 0] & 0xffff == 0x1204 and section_data[i * 3 + 2] + 0x240 > self.device.slm_per_thread:
+              raise RuntimeError("too high local memory")

     # Registers allocation granularity per warp is 256, warp allocation granularity is 4. Register file size is 65536.
     self.max_threads = ((65536 // round_up(self.registers_usage * 32, 256)) // 4) * 4 * 32
@@ -187,8 +195,8 @@ class NVProgram:
     # Load program and constant buffers (if any)
     self.lib_sz = round_up(round_up(self.program.nbytes, 128) + sum([round_up(x.nbytes, 128) for i,x in constant_buffers_data.items()]), 0x1000)
     self.lib_gpu = self.device.allocator.alloc(self.lib_sz)
-    for st in range(0, len(self.program), 4096):
-      HWComputeQueue().copy_from_cpu(self.lib_gpu.base+st*4, self.program[st:st+4096]).submit(self.device)
+    for st in range(0, len(self.program), 4095):
+      HWComputeQueue().copy_from_cpu(self.lib_gpu.base+st*4, self.program[st:st+4095]).submit(self.device)

     self.constbuffer_0 = [0] * 88
     self.constbuffer_0[6:12] = [*nvdata64_le(self.device.shared_mem_window), *nvdata64_le(self.device.local_mem_window), *nvdata64_le(0xfffdc0)]
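The 4096-to-4095 change above is presumably tied to the pushbuffer header format the mock decodes: the dword count lives in a 12-bit field, so a single inline-data command can carry at most 0xFFF = 4095 dwords. A quick check of that packing (the make_header helper is illustrative, not from the source):

    def make_header(typ, size, subc, mthd):
      # mirror of the _next_header bit layout in extra/mockgpu/nv/nvgpu.py
      return (typ << 28) | ((size & 0xFFF) << 16) | (subc << 13) | (mthd >> 2)

    hdr = make_header(6, 4095, 1, 0x0188)
    assert (hdr >> 16) & 0xFFF == 4095   # 4095 fits; 4096 would wrap to 0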
@@ -238,6 +246,8 @@ class NVProgram:
     if self.device.kernargs_ptr >= (self.device.kernargs_page.base + self.device.kernargs_page.length - self.kernargs_segment_size):
       self.device.kernargs_ptr = self.device.kernargs_page.base

+    # HACK: Save counts of args and vars to "unused" constbuffer for later extraction in mockgpu to pass into gpuocelot.
+    if MOCKGPU: self.constbuffer_0[0:2] = [len(args), len(vals)]
     kernargs = [arg_half for arg in args for arg_half in nvdata64_le(arg.base)] + [val for val in vals]

     queue = HWComputeQueue()
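Both ends of the HACK above in one runnable sketch: the runtime writes len(args) and len(vals) into the first two dwords of constant buffer 0, and the mock GPU (execute_qmd in extra/mockgpu/nv/nvgpu.py) reads those dwords back to know how many pointers and values to hand to gpuocelot's ptx_run. Buffer size and values here are illustrative:

    import struct
    const0 = bytearray(0x190)                    # stand-in for constant buffer 0
    args, vals = [0x10000, 0x20000], [7]         # hypothetical buffer pointers and ints
    struct.pack_into("II", const0, 0, len(args), len(vals))   # producer: NVProgram.__call__
    args_cnt, vals_cnt = struct.unpack_from("II", const0, 0)  # consumer: GPFIFO.execute_qmd
    assert (args_cnt, vals_cnt) == (2, 1)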
@@ -488,11 +498,11 @@ class NVDevice(Compiled):
     self.kernargs_page: nv_gpu.UVM_MAP_EXTERNAL_ALLOCATION_PARAMS = self._gpu_alloc(0x4000000, map_to_cpu=True)
     self.kernargs_ptr: int = self.kernargs_page.base

-    self.arch: str = 'sm_89' # TODO: fix
+    self.arch: str = "sm_89" if not MOCKGPU else "sm_35" # TODO: fix

     from tinygrad.runtime.graph.hcq import HCQGraph
-    super().__init__(device, NVAllocator(self), CUDARenderer(self.arch), NVCompiler(self.arch), functools.partial(NVProgram, self),
-                     functools.partial(HCQGraph, NVDevice, HWComputeQueue, HWCopyQueue))
+    super().__init__(device, NVAllocator(self), CUDARenderer(self.arch), CUDACompiler(self.arch) if MOCKGPU else NVCompiler(self.arch),
+                     functools.partial(NVProgram, self), functools.partial(HCQGraph, NVDevice, HWComputeQueue, HWCopyQueue))

     self._cmdq_setup_compute_gpfifo()
     self._cmdq_setup_dma_gpfifo()