diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 806d4cab..f4cde3d8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -229,6 +229,9 @@ jobs: - if: ${{ matrix.task == 'onnx' }} name: Test ONNX (CLANG) run: CLANG=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20 + - if: ${{ matrix.task == 'onnx' }} + name: Run CLOUD=1 Test + run: CLOUDDEV=CLANG CLOUD=1 python3 test/test_ops.py TestOps.test_tiny_add - if: ${{ matrix.task == 'onnx' }} name: Test Action Space run: PYTHONPATH="." GPU=1 python3 extra/optimization/get_action_space.py diff --git a/test/test_pickle.py b/test/test_pickle.py index 6366eacf..ad69d382 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -85,6 +85,11 @@ class TestPickle(unittest.TestCase): sched_pk = pickle.loads(pk) assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast) + def test_pickle_renderer(self): + from tinygrad.device import Device + pk = pickle.dumps(Device.default.renderer) + pickle.loads(pk) + class TestPickleJIT(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/tinygrad/device.py b/tinygrad/device.py index 27e9da82..89c95880 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -180,7 +180,7 @@ class CompileError(Exception): pass class Compiler: def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey - def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function") + def compile(self, src:str) -> bytes: return src.encode() # NOTE: empty compiler is the default def compile_cached(self, src:str) -> bytes: if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None: assert not getenv("ASSERT_COMPILE"), f"tried to compile with ASSERT_COMPILE set\n{src}" diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 2a3d7572..1a8006c1 100644 --- a/tinygrad/renderer/__init__.py +++ 
b/tinygrad/renderer/__init__.py @@ -89,4 +89,5 @@ class Renderer: extra_matcher: Any = None code_for_op: Dict[Op, Callable] = {} + def __reduce__(self): return self.__class__, () def render(self, name:str, uops:List[UOp]) -> str: raise NotImplementedError("needs a renderer") diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 5709fb7d..e27e2fdf 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -65,7 +65,9 @@ class PTXRenderer(Renderer): tensor_cores = [tc for tc in CUDARenderer.tensor_cores if tc.dtype_in == dtypes.half] code_for_op = asm_for_op extra_matcher = ptx_matcher - def __init__(self, arch:str, device="CUDA"): self.device, self.tensor_cores = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [] + def __init__(self, arch:str, device="CUDA"): + self.device, self.tensor_cores, self.arch = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch + def __reduce__(self): return self.__class__, (self.arch, self.device) # language options kernel_prefix = """.version VERSION diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 69de5e90..d57deae1 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -313,7 +313,8 @@ class CUDARenderer(CStyleLanguage): st1_pattern=(((1,1),(1,0),(0,2),(0,3),(0,4)),((1,3),(1,5),(1,2),(0,0),(0,1),(1,4))), st2_pattern=(((1,1),(1,0),(1,4),(0,0),(0,1)),((0,4),(0,2),(1,5),(0,3),(1,3),(1,2))), reduce_axes=[(0,8),(1,2)], upcast_axes=([(0,8)],[(2,2),(3,2)],[(3,2),(2,2)])) for di, do in ([(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)])] - def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else [] + def __init__(self, arch:str): self.tensor_cores, self.arch = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch + def __reduce__(self): return self.__class__, (self.arch,) # language options kernel_prefix = "extern \"C\" __global__ " diff --git 
a/tinygrad/runtime/ops_cloud.py b/tinygrad/runtime/ops_cloud.py new file mode 100644 index 00000000..8be9a56b --- /dev/null +++ b/tinygrad/runtime/ops_cloud.py @@ -0,0 +1,160 @@ +# the CLOUD=1 device is a process boundary between the frontend/runtime +# normally tinygrad is frontend <-> middleware <-> runtime <-> hardware +# with CLOUD tinygrad is frontend <-> middleware <-> CloudDevice ///HTTP/// cloud_server <-> runtime <-> hardware +# this client and server can be on the same machine, same network, or just same internet +# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC + +from __future__ import annotations +from typing import Tuple, Optional, Dict, Any, DefaultDict +from collections import defaultdict +import multiprocessing, functools, http.client, hashlib, json, time, contextlib, os, binascii +from dataclasses import dataclass, field +from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap +from tinygrad.device import Compiled, Allocator, Compiler, Device +from http.server import HTTPServer, BaseHTTPRequestHandler + +# ***** backend ***** + +@dataclass +class CloudSession: + programs: Dict[Tuple[str, str], Any] = field(default_factory=dict) + buffers: Dict[int, Tuple[Any, int]] = field(default_factory=dict) + buffer_num = 0 + +class CloudHandler(BaseHTTPRequestHandler): + protocol_version = 'HTTP/1.1' + dname: str + sessions: DefaultDict[str, CloudSession] = defaultdict(CloudSession) + + def setup(self): + super().setup() + print(f"connection established with {self.client_address}, socket: {self.connection.fileno()}") + + def get_data(self): + content_len = self.headers.get('Content-Length') + assert content_len is not None + return self.rfile.read(int(content_len)) + def get_json(self): return json.loads(self.get_data()) + + def _fail(self): + self.send_response(404) + self.end_headers() + return 0 + + def _do(self, method): + session = CloudHandler.sessions[unwrap(self.headers.get("Cookie")).split("session=")[1]] + 
ret = b"" + if self.path == "/renderer" and method == "GET": + cls, args = Device[CloudHandler.dname].renderer.__reduce__() + ret = json.dumps((cls.__module__, cls.__name__, args)).encode() + elif self.path.startswith("/alloc") and method == "POST": + size = int(self.path.split("=")[-1]) + session.buffer_num += 1 + session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size), size) + ret = str(session.buffer_num).encode() + elif self.path.startswith("/buffer"): + key = int(self.path.split("/")[-1]) + buf,sz = session.buffers[key] + if method == "GET": Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf) + elif method == "PUT": Device[CloudHandler.dname].allocator.copyin(buf, memoryview(bytearray(self.get_data()))) + elif method == "DELETE": + Device[CloudHandler.dname].allocator.free(buf,sz) + del session.buffers[key] + else: return self._fail() + elif self.path.startswith("/program"): + name, hsh = self.path.split("/")[-2:] + if method == "PUT": + src = self.get_data() + assert hashlib.sha256(src).hexdigest() == hsh + lib = Device[CloudHandler.dname].compiler.compile_cached(src.decode()) + session.programs[(name, hsh)] = Device[CloudHandler.dname].runtime(name, lib) + elif method == "POST": + j = self.get_json() + bufs = [session.buffers[x][0] for x in j['bufs']] + del j['bufs'] + r = session.programs[(name, hsh)](*bufs, **j) + if r is not None: ret = str(r).encode() + elif method == "DELETE": del session.programs[(name, hsh)] + else: return self._fail() + else: return self._fail() + self.send_response(200) + self.send_header('Content-Length', str(len(ret))) + self.end_headers() + return self.wfile.write(ret) + + def do_GET(self): return self._do("GET") + def do_POST(self): return self._do("POST") + def do_PUT(self): return self._do("PUT") + def do_DELETE(self): return self._do("DELETE") + +def cloud_server(port:int): + multiprocessing.current_process().name = "MainProcess" + CloudHandler.dname = 
getenv("CLOUDDEV", "METAL") if Device.DEFAULT == "CLOUD" else Device.DEFAULT + print(f"start cloud server on {port} with device {CloudHandler.dname}") + server = HTTPServer(('', port), CloudHandler) + server.serve_forever() + +# ***** frontend ***** + +class CloudAllocator(Allocator): + def __init__(self, device:CloudDevice): + self.device = device + super().__init__() + def _alloc(self, size:int, options) -> int: return int(self.device.send("POST", f"alloc?size={size}")) + def _free(self, opaque, options): + with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected): + self.device.send("DELETE", f"buffer/{opaque}", data=b"") + def copyin(self, dest:int, src:memoryview): self.device.send("PUT", f"buffer/{dest}", data=bytes(src)) + def copyout(self, dest:memoryview, src:int): + resp = self.device.send("GET", f"buffer/{src}") + assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}" + dest[:] = resp + +class CloudProgram: + def __init__(self, device:CloudDevice, name:str, lib:bytes): + self.device = device + self.prgid = f"{name}/{hashlib.sha256(lib).hexdigest()}" + self.device.send("PUT", "program/"+self.prgid, lib) + super().__init__() + def __del__(self): self.device.send("DELETE", "program/"+self.prgid) + + def __call__(self, *bufs, global_size=None, local_size=None, vals:Tuple[int, ...]=(), wait=False): + args = {"bufs": bufs, "vals": vals, "wait": wait} + if global_size is not None: args["global_size"] = global_size + if local_size is not None: args["local_size"] = local_size + ret = self.device.send("POST", "program/"+self.prgid, json.dumps(args).encode()) + if wait: return float(ret) + +class CloudDevice(Compiled): + def __init__(self, device:str): + if (host:=getenv("HOST", "")) != "": + self.host = host + else: + p = multiprocessing.Process(target=cloud_server, args=(6667,)) + p.daemon = True + p.start() + self.host = "127.0.0.1:6667" + self.cookie = 
binascii.hexlify(os.urandom(0x10)).decode() + if DEBUG >= 1: print(f"cloud with host {self.host}") + while 1: + try: + self.conn = http.client.HTTPConnection(self.host, timeout=60.0) + clouddev = json.loads(self.send("GET", "renderer").decode()) + break + except Exception as e: + print(e) + time.sleep(0.1) + if DEBUG >= 1: print(f"remote has device {clouddev}") + # TODO: how do we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer + assert clouddev[0].startswith("tinygrad.renderer."), f"bad renderer {clouddev}" + renderer = fromimport(clouddev[0], clouddev[1])(*clouddev[2]) + super().__init__(device, CloudAllocator(self), renderer, Compiler(), functools.partial(CloudProgram, self)) + + def send(self, method, path, data:Optional[bytes]=None) -> bytes: + # TODO: retry logic + self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.cookie}"}) + response = self.conn.getresponse() + assert response.status == 200, f"failed on {method} {path}" + return response.read() + +if __name__ == "__main__": cloud_server(getenv("PORT", 6667))