cloud device [pr] (#6964)

* first try at cloud device [pr]

* real separation

* we're free

* clang works

* unhappy with timeout

* better timeouts and free

* unrelated

* use http verbs + add test

* lines + better test

* fix DELETE

* shorter cloud

* split key

* fix sending renderer

* PTXRenderer serialization

* add sessions

* http.client

* minor timeout bump

* fix keep-alive

* inc server timeout

* real fix timeout

* that one too
This commit is contained in:
George Hotz 2024-10-11 12:24:06 +08:00 committed by GitHub
parent 23c09f4b4c
commit f50d0e0ee0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 175 additions and 3 deletions

View File

@ -229,6 +229,9 @@ jobs:
- if: ${{ matrix.task == 'onnx' }}
name: Test ONNX (CLANG)
run: CLANG=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- if: ${{ matrix.task == 'onnx' }}
name: Run CLOUD=1 Test
run: CLOUDDEV=CLANG CLOUD=1 python3 test/test_ops.py TestOps.test_tiny_add
- if: ${{ matrix.task == 'onnx' }}
name: Test Action Space
run: PYTHONPATH="." GPU=1 python3 extra/optimization/get_action_space.py

View File

@ -85,6 +85,11 @@ class TestPickle(unittest.TestCase):
sched_pk = pickle.loads(pk)
assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
def test_pickle_renderer(self):
from tinygrad.device import Device
pk = pickle.dumps(Device.default.renderer)
pickle.loads(pk)
class TestPickleJIT(unittest.TestCase):
@classmethod
def setUpClass(cls):

View File

@ -180,7 +180,7 @@ class CompileError(Exception): pass
class Compiler:
def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
def compile(self, src:str) -> bytes: return src.encode() # NOTE: empty compiler is the default
def compile_cached(self, src:str) -> bytes:
if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
assert not getenv("ASSERT_COMPILE"), f"tried to compile with ASSERT_COMPILE set\n{src}"

View File

@ -89,4 +89,5 @@ class Renderer:
extra_matcher: Any = None
code_for_op: Dict[Op, Callable] = {}
def __reduce__(self): return self.__class__, ()
def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer")

View File

@ -65,7 +65,9 @@ class PTXRenderer(Renderer):
tensor_cores = [tc for tc in CUDARenderer.tensor_cores if tc.dtype_in == dtypes.half]
code_for_op = asm_for_op
extra_matcher = ptx_matcher
def __init__(self, arch:str, device="CUDA"): self.device, self.tensor_cores = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
def __init__(self, arch:str, device="CUDA"):
self.device, self.tensor_cores, self.arch = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
def __reduce__(self): return self.__class__, (self.arch, self.device)
# language options
kernel_prefix = """.version VERSION

View File

@ -313,7 +313,8 @@ class CUDARenderer(CStyleLanguage):
st1_pattern=(((1,1),(1,0),(0,2),(0,3),(0,4)),((1,3),(1,5),(1,2),(0,0),(0,1),(1,4))),
st2_pattern=(((1,1),(1,0),(1,4),(0,0),(0,1)),((0,4),(0,2),(1,5),(0,3),(1,3),(1,2))), reduce_axes=[(0,8),(1,2)],
upcast_axes=([(0,8)],[(2,2),(3,2)],[(3,2),(2,2)])) for di, do in ([(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)])]
def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
def __init__(self, arch:str): self.tensor_cores, self.arch = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
def __reduce__(self): return self.__class__, (self.arch,)
# language options
kernel_prefix = "extern \"C\" __global__ "

View File

@ -0,0 +1,160 @@
# the CLOUD=1 device is a process boundary between the frontend/runtime
# normally tinygrad is frontend <-> middleware <-> runtime <-> hardware
# with CLOUD tinygrad is frontend <-> middleware <-> CloudDevice ///HTTP/// cloud_server <-> runtime <-> hardware
# this client and server can be on the same machine, same network, or just same internet
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
from __future__ import annotations
from typing import Tuple, Optional, Dict, Any, DefaultDict
from collections import defaultdict
import multiprocessing, functools, http.client, hashlib, json, time, contextlib, os, binascii
from dataclasses import dataclass, field
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap
from tinygrad.device import Compiled, Allocator, Compiler, Device
from http.server import HTTPServer, BaseHTTPRequestHandler
# ***** backend *****
@dataclass
class CloudSession:
programs: Dict[Tuple[str, str], Any] = field(default_factory=dict)
buffers: Dict[int, Tuple[Any, int]] = field(default_factory=dict)
buffer_num = 0
class CloudHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
dname: str
sessions: DefaultDict[str, CloudSession] = defaultdict(CloudSession)
def setup(self):
super().setup()
print(f"connection established with {self.client_address}, socket: {self.connection.fileno()}")
def get_data(self):
content_len = self.headers.get('Content-Length')
assert content_len is not None
return self.rfile.read(int(content_len))
def get_json(self): return json.loads(self.get_data())
def _fail(self):
self.send_response(404)
self.end_headers()
return 0
def _do(self, method):
session = CloudHandler.sessions[unwrap(self.headers.get("Cookie")).split("session=")[1]]
ret = b""
if self.path == "/renderer" and method == "GET":
cls, args = Device[CloudHandler.dname].renderer.__reduce__()
ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
elif self.path.startswith("/alloc") and method == "POST":
size = int(self.path.split("=")[-1])
session.buffer_num += 1
session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size), size)
ret = str(session.buffer_num).encode()
elif self.path.startswith("/buffer"):
key = int(self.path.split("/")[-1])
buf,sz = session.buffers[key]
if method == "GET": Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
elif method == "PUT": Device[CloudHandler.dname].allocator.copyin(buf, memoryview(bytearray(self.get_data())))
elif method == "DELETE":
Device[CloudHandler.dname].allocator.free(buf,sz)
del session.buffers[key]
else: return self._fail()
elif self.path.startswith("/program"):
name, hsh = self.path.split("/")[-2:]
if method == "PUT":
src = self.get_data()
assert hashlib.sha256(src).hexdigest() == hsh
lib = Device[CloudHandler.dname].compiler.compile_cached(src.decode())
session.programs[(name, hsh)] = Device[CloudHandler.dname].runtime(name, lib)
elif method == "POST":
j = self.get_json()
bufs = [session.buffers[x][0] for x in j['bufs']]
del j['bufs']
r = session.programs[(name, hsh)](*bufs, **j)
if r is not None: ret = str(r).encode()
elif method == "DELETE": del session.programs[(name, hsh)]
else: return self._fail()
else: return self._fail()
self.send_response(200)
self.send_header('Content-Length', str(len(ret)))
self.end_headers()
return self.wfile.write(ret)
def do_GET(self): return self._do("GET")
def do_POST(self): return self._do("POST")
def do_PUT(self): return self._do("PUT")
def do_DELETE(self): return self._do("DELETE")
def cloud_server(port:int):
multiprocessing.current_process().name = "MainProcess"
CloudHandler.dname = getenv("CLOUDDEV", "METAL") if Device.DEFAULT == "CLOUD" else Device.DEFAULT
print(f"start cloud server on {port} with device {CloudHandler.dname}")
server = HTTPServer(('', port), CloudHandler)
server.serve_forever()
# ***** frontend *****
class CloudAllocator(Allocator):
def __init__(self, device:CloudDevice):
self.device = device
super().__init__()
def _alloc(self, size:int, options) -> int: return int(self.device.send("POST", f"alloc?size={size}"))
def _free(self, opaque, options):
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected):
self.device.send("DELETE", f"buffer/{opaque}", data=b"")
def copyin(self, dest:int, src:memoryview): self.device.send("PUT", f"buffer/{dest}", data=bytes(src))
def copyout(self, dest:memoryview, src:int):
resp = self.device.send("GET", f"buffer/{src}")
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
dest[:] = resp
class CloudProgram:
def __init__(self, device:CloudDevice, name:str, lib:bytes):
self.device = device
self.prgid = f"{name}/{hashlib.sha256(lib).hexdigest()}"
self.device.send("PUT", "program/"+self.prgid, lib)
super().__init__()
def __del__(self): self.device.send("DELETE", "program/"+self.prgid)
def __call__(self, *bufs, global_size=None, local_size=None, vals:Tuple[int, ...]=(), wait=False):
args = {"bufs": bufs, "vals": vals, "wait": wait}
if global_size is not None: args["global_size"] = global_size
if local_size is not None: args["local_size"] = local_size
ret = self.device.send("POST", "program/"+self.prgid, json.dumps(args).encode())
if wait: return float(ret)
class CloudDevice(Compiled):
def __init__(self, device:str):
if (host:=getenv("HOST", "")) != "":
self.host = host
else:
p = multiprocessing.Process(target=cloud_server, args=(6667,))
p.daemon = True
p.start()
self.host = "127.0.0.1:6667"
self.cookie = binascii.hexlify(os.urandom(0x10)).decode()
if DEBUG >= 1: print(f"cloud with host {self.host}")
while 1:
try:
self.conn = http.client.HTTPConnection(self.host, timeout=60.0)
clouddev = json.loads(self.send("GET", "renderer").decode())
break
except Exception as e:
print(e)
time.sleep(0.1)
if DEBUG >= 1: print(f"remote has device {clouddev}")
# TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
assert clouddev[0].startswith("tinygrad.renderer."), f"bad renderer {clouddev}"
renderer = fromimport(clouddev[0], clouddev[1])(*clouddev[2])
super().__init__(device, CloudAllocator(self), renderer, Compiler(), functools.partial(CloudProgram, self))
def send(self, method, path, data:Optional[bytes]=None) -> bytes:
# TODO: retry logic
self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.cookie}"})
response = self.conn.getresponse()
assert response.status == 200, f"failed on {method} {path}"
return response.read()
if __name__ == "__main__": cloud_server(getenv("PORT", 6667))