mirror of https://github.com/commaai/tinygrad.git
Hip driver (#2992)
* start hip driver * fix hip llama * make HIP default if we can * don't change those
This commit is contained in:
parent
f290ca3924
commit
753a7ecc05
|
@ -122,11 +122,10 @@ jobs:
|
||||||
run: HIP=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
|
run: HIP=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
|
||||||
- name: Run Stable Diffusion
|
- name: Run Stable Diffusion
|
||||||
run: python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
run: python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
||||||
# TODO: rocm 6.0 broke this
|
- name: Run LLaMA (with HIP)
|
||||||
# - name: Run LLaMA
|
run: |
|
||||||
# run: |
|
HIP=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
|
||||||
# JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
|
HIP=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
|
||||||
# JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
|
|
||||||
- name: Run GPT2 (with HIP)
|
- name: Run GPT2 (with HIP)
|
||||||
run: |
|
run: |
|
||||||
HIP=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
|
HIP=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
import ctypes, ctypes.util, struct, platform, pathlib, re, time
|
||||||
|
|
||||||
|
# *** ioctl lib ***
|
||||||
|
# handle to the C library; used for the raw syscall, mprotect and memcpy below
libc = ctypes.CDLL(ctypes.util.find_library("c"))
# platform.processor() is allowed to return "" when the value cannot be determined
# (common on Linux), which made the table lookup below fail with KeyError.
# platform.machine() reports the architecture name directly.
processor = platform.machine()
# syscall number of ioctl for each supported architecture; any other
# architecture raises KeyError here, at import time
IOCTL_SYSCALL = {"aarch64": 0x1d, "x86_64":16}[processor]
|
||||||
|
|
||||||
|
def get_struct(argp, stype):
  """Return the memory at address `argp` viewed as a ctypes object of type `stype`.

  argp:  integer address of the data
  stype: ctypes type to interpret that memory as
  """
  typed_ptr = ctypes.cast(ctypes.c_void_p(argp), ctypes.POINTER(stype))
  return typed_ptr.contents
|
||||||
|
|
||||||
|
def format_struct(s):
  """Render every field of the ctypes structure `s` as a "name:value" string.

  Integer fields are printed in upper-case hex with a 0x prefix; every other
  field uses its default string form. Returns the list of strings in
  declaration order.
  """
  sdats = []
  # _fields_ entries are (name, type), or (name, type, bit_width) for bit fields,
  # so only the leading name is unpacked; the rest is not needed here.
  for field_name, *_ in s._fields_:
    dat = getattr(s, field_name)
    if isinstance(dat, int): sdats.append(f"{field_name}:0x{dat:X}")
    else: sdats.append(f"{field_name}:{dat}")
  return sdats
|
||||||
|
|
||||||
|
def install_hook(c_function, python_function):
  """Overwrite the first bytes of `c_function` with a jump to `python_function`.

  c_function:      ctypes function object whose machine code is patched in place
  python_function: ctypes callback object that should receive the calls instead

  Raises Exception if the module-level `processor` is neither "aarch64" nor
  "x86_64", and OSError if the target pages cannot be made writable.
  """
  import mmap  # local import: only the page size and PROT_* constants are needed

  # read the first machine word stored in the ctypes function object: the raw function pointer
  python_function_addr = ctypes.cast(ctypes.byref(python_function), ctypes.POINTER(ctypes.c_ulong)).contents.value
  # AARCH64 trampoline to ioctl
  if processor == "aarch64":
    # 0x0000000000000000:  70 00 00 10    adr x16, #0xc
    # 0x0000000000000004:  10 02 40 F9    ldr x16, [x16]
    # 0x0000000000000008:  00 02 1F D6    br  x16
    tramp = b"\x70\x00\x00\x10\x10\x02\x40\xf9\x00\x02\x1f\xd6"
    # the 8-byte target address is stored right after the three instructions
    tramp += struct.pack("Q", python_function_addr)
  elif processor == "x86_64":
    # 0x0000000000000000:  49 B8 aa aa aa aa aa aa aa aa    movabs r8, <address>
    # 0x000000000000000a:  41 FF E0                         jmp    r8
    tramp = b"\x49\xB8" + struct.pack("Q", python_function_addr) + b"\x41\xFF\xE0"
  else:
    raise Exception(f"processor {processor} not supported")

  # get real ioctl address
  ioctl_address = ctypes.cast(ctypes.byref(c_function), ctypes.POINTER(ctypes.c_ulong)).contents.value

  # hook ioctl
  # mprotect needs a page-aligned address; use the real page size instead of assuming 4 KiB.
  # Two pages are unlocked so a trampoline that straddles a page boundary is still covered.
  page_size = mmap.PAGESIZE
  page_start = (ioctl_address // page_size) * page_size
  ret = libc.mprotect(ctypes.c_ulong(page_start), 2 * page_size, mmap.PROT_READ | mmap.PROT_WRITE | mmap.PROT_EXEC)
  # an explicit check (not `assert`) so it still runs under `python -O`
  if ret != 0: raise OSError(f"mprotect on 0x{page_start:X} failed with return value {ret}")
  libc.memcpy(ctypes.c_ulong(ioctl_address), ctypes.create_string_buffer(tramp), len(tramp))
|
||||||
|
|
||||||
|
# *** ioctl lib end ***
|
||||||
|
|
||||||
|
# clang2py kfd_ioctl.h -o kfd_ioctl.py
|
||||||
|
from extra.hip_gpu_driver import kfd_ioctl
|
||||||
|
def ioctls_from_header(hdr=None, structs=None):
  """Build a table of ioctls from the `#define AMDKFD_IOC_*` lines of kfd_ioctl.h.

  hdr:     header text to parse; when None, extra/hip_gpu_driver/kfd_ioctl.h is
           read relative to this file
  structs: object holding the `struct_<name>` ctypes classes; when None, the
           module-level `kfd_ioctl` bindings are used

  Returns {ioctl number: (define name, ctypes struct class)}.
  Raises AttributeError if a struct named in the header is missing from `structs`.
  """
  if hdr is None:
    hdr = (pathlib.Path(__file__).parent.parent.parent / "extra/hip_gpu_driver/kfd_ioctl.h").read_text()
  if structs is None: structs = kfd_ioctl
  # join backslash-continued lines so every #define sits on one physical line
  hdr = hdr.replace("\\\n", "")
  # accept all three direction macros (AMDKFD_IOR / AMDKFD_IOW / AMDKFD_IOWR);
  # the previous `IOWR?` form silently skipped every AMDKFD_IOR define
  pattern = r'#define\s+(AMDKFD_IOC_[A-Z0-9_]+)\s+AMDKFD_IO(?:WR|W|R)\((0x[0-9a-fA-F]+),\s+struct\s([A-Za-z0-9_]+)\)'
  matches = re.findall(pattern, hdr)
  return {int(nr, 0x10):(name, getattr(structs, "struct_"+sname)) for name, nr, sname in matches}
|
||||||
|
# ioctl number -> (define name, ctypes struct class); built once when the module is imported
nrs = ioctls_from_header()
|
||||||
|
|
||||||
|
@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_ulong, ctypes.c_void_p)
def ioctl(fd, request, argp):
  """Forward an ioctl to the kernel through a raw syscall and print one trace line for it.

  Requests whose type field is 75 and whose number is listed in `nrs` are
  printed with their name, duration and decoded argument struct; all others
  are printed as raw request fields. Returns the syscall's return value.
  """
  start = time.perf_counter()
  ret = libc.syscall(IOCTL_SYSCALL, ctypes.c_int(fd), ctypes.c_ulong(request), ctypes.c_void_p(argp))
  elapsed = time.perf_counter() - start

  # split the request word into its bit fields
  idir = request >> 30
  size = (request >> 16) & 0x3FFF
  itype = (request >> 8) & 0xFF
  nr = request & 0xFF

  # 75 == ord('K'); presumably the AMDKFD ioctl type -- confirm against kfd_ioctl.h
  if itype != 75 or nr not in nrs:
    print("ioctl", idir, size, itype, nr, f"fd={fd} ret={ret}")
    return ret

  name, stype = nrs[nr]
  decoded = get_struct(argp, stype)
  print(f"{elapsed*1000.:7.2f} ms : {ret:2d} = {name:40s}", ' '.join(format_struct(decoded)))
  return ret
|
||||||
|
|
||||||
|
# patch libc's ioctl in place so that calls to it jump to the Python `ioctl` hook defined above;
# this runs at import time
install_hook(libc.ioctl, ioctl)
|
||||||
|
|
||||||
|
# AMD_LOG_LEVEL=4 HSAKMT_DEBUG_LEVEL=7
|
||||||
|
if __name__ == "__main__":
  # Demo workload: touch the HIP device, upload two tensors and run a jitted add a few
  # times, printing a banner before each phase so the ioctl trace lines printed by the
  # hook can be attributed to it.
  # NOTE(review): the source was flattened; the extent of the loop body below was
  # reconstructed and should be confirmed.
  print("***** import tinygrad")
  from tinygrad import Tensor, Device, TinyJit
  print("***** access HIP")
  dev = Device["HIP"]
  print("***** create tensor a")
  a = Tensor([1.,2.]*200, device="HIP").realize()
  print("***** create tensor b")
  b = Tensor([3.,4.]*200, device="HIP").realize()
  @TinyJit
  def add(a, b): return (a+b).realize()
  for i in range(4):
    print(f"***** add tensors {i}")
    c = add(a, b)
    #dev.synchronize()
    c = add(b, a)
    dev.synchronize()
  print(f"***** delete")
  # drop every reference before the final banner
  del add, a, b, c, dev
  print(f"***** done")
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,5 +1,6 @@
|
||||||
|
from __future__ import annotations
|
||||||
import ctypes, functools, subprocess
|
import ctypes, functools, subprocess
|
||||||
from typing import Tuple, TypeVar
|
from typing import Tuple, TypeVar, List
|
||||||
import gpuctypes.hip as hip
|
import gpuctypes.hip as hip
|
||||||
from tinygrad.helpers import DEBUG, getenv, from_mv, init_c_var, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
|
from tinygrad.helpers import DEBUG, getenv, from_mv, init_c_var, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
|
||||||
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
|
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
|
||||||
|
@ -40,22 +41,24 @@ class HIPProgram:
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
class HIPAllocator(LRUAllocator):
  # Allocator for HIP device buffers. Every operation first selects the owning
  # device with hipSetDevice before calling into the HIP runtime.
  def __init__(self, device:HIPDevice):
    # keep the whole HIPDevice (not just its index) so copyin can reach device.pending_copyin
    self.device = device
    super().__init__()
  def _alloc(self, size:int):
    # allocate `size` bytes with hipMalloc and return the device pointer
    check(hip.hipSetDevice(self.device.device))
    return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
  def _free(self, opaque:T): check(hip.hipFree(opaque))
  def copyin(self, dest:T, src: memoryview):
    # Host -> device copy that does not wait for completion: `src` is first copied into a
    # freshly allocated hipHostMalloc buffer, and that buffer is what hipMemcpyAsync reads.
    check(hip.hipSetDevice(self.device.device))
    host_mem = init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipHostMalloc(ctypes.byref(x), len(src), 0)))
    # record the staging buffer on the device; HIPDevice.synchronize releases the recorded buffers
    self.device.pending_copyin.append(host_mem)
    ctypes.memmove(host_mem, from_mv(src), len(src))
    check(hip.hipMemcpyAsync(dest, host_mem, len(src), hip.hipMemcpyHostToDevice, None))
  def copyout(self, dest:memoryview, src:T):
    # device -> host copy using the blocking hipMemcpy call
    check(hip.hipSetDevice(self.device.device))
    check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
  def transfer(self, dest:T, src:T, sz:int):
    # device -> device copy of `sz` bytes
    check(hip.hipSetDevice(self.device.device))
    # TODO: hipMemcpyAsync, but you have to track the "src" buffer to not free it
    check(hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice))
|
||||||
|
|
||||||
|
@ -63,11 +66,14 @@ class HIPDevice(Compiled):
|
||||||
default_arch_name = "gfx1100"
|
default_arch_name = "gfx1100"
|
||||||
def __init__(self, device:str=""):
  # `device` may carry an index after a colon (e.g. "HIP:1"); without one, device 0 is used
  self.device = int(device.split(":")[1]) if ":" in device else 0
  # host staging buffers queued by HIPAllocator.copyin; released in synchronize()
  self.pending_copyin: List[hip.hipDeviceptr_t] = []
  # only device 0 (and only when not mocking HIP) updates the class-wide default arch name
  if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() # noqa: E501

  # imported inside the method rather than at module level -- presumably to avoid an import cycle; confirm
  from tinygrad.runtime.graph.hip import HIPGraph
  # the allocator receives this HIPDevice itself so it can append to pending_copyin
  super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions(device="HIP"), HIPRenderer,
                   compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
|
||||||
def synchronize(self):
  """Wait for all queued work on this device, then release the copy-in staging buffers."""
  check(hip.hipSetDevice(self.device))
  check(hip.hipDeviceSynchronize())
  # The buffers in pending_copyin were allocated with hipHostMalloc (see copyin), so they
  # are released with the matching hipHostFree rather than hipFree, which is the
  # counterpart of hipMalloc. Each buffer is popped before it is freed, so if a free
  # raises, no already-freed pointer stays in the list to be freed again later.
  while self.pending_copyin: check(hip.hipHostFree(self.pending_copyin.pop()))
|
Loading…
Reference in New Issue